# Copyright (c) 2010-2012 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import sys
from requests import RequestException
from requests.structures import CaseInsensitiveDict
from time import sleep
import unittest
import testtools
import mock
import six
from six.moves import reload_module
from six.moves.urllib.parse import urlparse, ParseResult
from swiftclient import client as c
from swiftclient import shell as s


def fake_get_auth_keystone(expected_os_options=None, exc=None,
                           storage_url='http://url/', token='token',
                           **kwargs):
    """
    Build a stand-in for a keystone auth function.

    The returned callable accepts ``(auth_url, user, key, os_options,
    **kwargs)``. It returns ``(storage_url, token)`` when the call matches
    the expectations configured here, and ``("", None)`` when it does not.

    :param expected_os_options: when truthy, the exact ``os_options`` value
        the fake must receive; anything else yields ``("", None)``.
    :param exc: when set, an exception class raised with message ``'test'``
        on every call, before any other check.
    :param storage_url: storage URL returned on a matching call.
    :param token: auth token returned on a matching call.
    :param kwargs: may carry ``required_kwargs``, a mapping whose values must
        equal the same-named keyword arguments passed to the fake.
    """
    def _fake_auth(auth_url, user, key, actual_os_options, **actual_kwargs):
        if exc:
            raise exc('test')
        no_match = "", None
        # TODO: some way to require auth_url, user and key?
        if expected_os_options and actual_os_options != expected_os_options:
            return no_match
        if 'required_kwargs' in kwargs:
            required = kwargs['required_kwargs']
            if any(wanted != actual_kwargs.get(name)
                   for name, wanted in required.items()):
                return no_match

        # Simulated TLS verification failures for specially-named URLs.
        if auth_url.startswith("https"):
            if (auth_url.endswith("invalid-certificate") and
                    not actual_kwargs['insecure']):
                from swiftclient import client
                raise client.ClientException("invalid-certificate")
            if (auth_url.endswith("self-signed-certificate") and
                    not actual_kwargs['insecure'] and
                    actual_kwargs['cacert'] is None):
                from swiftclient import client
                raise client.ClientException("unverified-certificate")

        return storage_url, token
    return _fake_auth


class StubResponse(object):
    """
    Scripted response for ``fake_http_connect``.

    Pass an instance in place of a bare status code to control the status,
    body and headers of one particular response.
    """

    def __init__(self, status=200, body='', headers=None):
        # A falsy ``headers`` (None or empty) is replaced by a fresh dict.
        if not headers:
            headers = {}
        self.headers = headers
        self.body = body
        self.status = status


def fake_http_connect(*code_iter, **kwargs):
    """
    Generate a callable which yields a series of stubbed responses.  Because
    swiftclient will reuse an HTTP connection across pipelined requests it is
    not always the case that this fake is used strictly for mocking an HTTP
    connection, but rather each HTTP response (i.e. each call to requests
    get_response).
    """

    class FakeConn(object):

        def __init__(self, status, etag=None, body='', timestamp='1',
                     headers=None):
            self.status = status
            self.reason = 'Fake'
            self.host = '1.2.3.4'
            self.port = '1234'
            self.sent = 0
            self.received = 0
            self.etag = etag
            self.body = body
            self.timestamp = timestamp
            self._is_closed = True
            self.headers = headers or {}

        def connect(self):
            self._is_closed = False

        def close(self):
            self._is_closed = True

        def isclosed(self):
            return self._is_closed

        def getresponse(self):
            if kwargs.get('raise_exc'):
                raise Exception('test')
            return self

        def getexpect(self):
            if self.status == -2:
                raise RequestException()
            if self.status == -3:
                return FakeConn(507)
            return FakeConn(100)

        def getheaders(self):
            if self.headers:
                return self.headers.items()
            headers = {'content-length': len(self.body),
                       'content-type': 'x-application/test',
                       'x-timestamp': self.timestamp,
                       'last-modified': self.timestamp,
                       'x-object-meta-test': 'testing',
                       'etag':
                       self.etag or '"68b329da9893e34099c7d8ad5cb9c940"',
                       'x-works': 'yes',
                       'x-account-container-count': 12345}
            if not self.timestamp:
                del headers['x-timestamp']
            try:
                if next(container_ts_iter) is False:
                    headers['x-container-timestamp'] = '1'
            except StopIteration:
                pass
            if 'slow' in kwargs:
                headers['content-length'] = '4'
            if 'headers' in kwargs:
                headers.update(kwargs['headers'])
            if 'auth_v1' in kwargs:
                headers.update(
                    {'x-storage-url': 'storageURL',
                     'x-auth-token': 'someauthtoken'})
            return headers.items()

        def read(self, amt=None):
            if 'slow' in kwargs:
                if self.sent < 4:
                    self.sent += 1
                    sleep(0.1)
                    return ' '
            rv = self.body[:amt]
            self.body = self.body[amt:]
            return rv

        def send(self, amt=None):
            if 'slow' in kwargs:
                if self.received < 4:
                    self.received += 1
                    sleep(0.1)

        def getheader(self, name, default=None):
            return dict(self.getheaders()).get(name.lower(), default)

    timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
    etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
    x = kwargs.get('missing_container', [False] * len(code_iter))
    if not isinstance(x, (tuple, list)):
        x = [x] * len(code_iter)
    container_ts_iter = iter(x)
    code_iter = iter(code_iter)

    def connect(*args, **ckwargs):
        if 'give_content_type' in kwargs:
            if len(args) >= 7 and 'Content-Type' in args[6]:
                kwargs['give_content_type'](args[6]['Content-Type'])
            else:
                kwargs['give_content_type']('')
        if 'give_connect' in kwargs:
            kwargs['give_connect'](*args, **ckwargs)
        status = next(code_iter)
        if isinstance(status, StubResponse):
            fake_conn = FakeConn(status.status, body=status.body,
                                 headers=status.headers)
        else:
            etag = next(etag_iter)
            timestamp = next(timestamps_iter)
            fake_conn = FakeConn(status, etag, body=kwargs.get('body', ''),
                                 timestamp=timestamp)
        if fake_conn.status <= 0:
            raise RequestException()
        fake_conn.connect()
        return fake_conn

    connect.code_iter = code_iter
    return connect


class MockHttpTest(testtools.TestCase):
    """
    Base test case providing a scripted fake for ``c.http_connection`` and a
    log of the requests made through it.

    ``self.fake_http_connection(*responses, **kwargs)`` returns a callable
    with the same call shape as ``c.http_connection``.  Every request made on
    the connections it produces is appended to ``self.request_log`` and can
    be checked with ``assertRequests``.  Scripted responses left unconsumed
    fail the test; this is checked on the next ``fake_http_connection`` call
    and in ``tearDown``.
    """

    def setUp(self):
        super(MockHttpTest, self).setUp()
        # Most recent fake_http_connect callable; checked for leftovers.
        self.fake_connect = None
        # One (parsed, method, url, args, kwargs, resp) tuple per request.
        self.request_log = []

        def fake_http_connection(*args, **kwargs):
            """
            Build a replacement for ``c.http_connection``.

            All arguments are forwarded to ``fake_http_connect``.  These
            kwargs are also interpreted here:

            :param query_string: if set, every request URL must end with
                ``'?' + query_string``.
            :param storage_url: if set, the URL passed to the returned
                callable must equal it.
            :param auth_token: if set, every request must carry it as the
                ``X-Auth-Token`` header (second positional request argument).
            :param exc: if set, raised from every request after it is logged.
            """
            self.validateMockedRequestsConsumed()
            self.request_log = []
            self.fake_connect = fake_http_connect(*args, **kwargs)
            # Capture the real implementation now; the returned wrapper
            # delegates URL parsing to it (presumably before a test patches
            # c.http_connection with the wrapper -- confirm in callers).
            _orig_http_connection = c.http_connection
            query_string = kwargs.get('query_string')
            storage_url = kwargs.get('storage_url')
            auth_token = kwargs.get('auth_token')
            exc = kwargs.get('exc')

            def wrapper(url, proxy=None, cacert=None, insecure=False,
                        ssl_compression=True):
                if storage_url:
                    self.assertEqual(storage_url, url)

                parsed, _conn = _orig_http_connection(url, proxy=proxy)

                class RequestsWrapper(object):
                    pass
                conn = RequestsWrapper()

                def request(method, url, *args, **kwargs):
                    try:
                        conn.resp = self.fake_connect()
                    except StopIteration:
                        self.fail('Unexpected %s request for %s' % (
                            method, url))
                    self.request_log.append((parsed, method, url, args,
                                             kwargs, conn.resp))
                    conn.host = conn.resp.host
                    conn.isclosed = conn.resp.isclosed
                    # Track whether the caller read the response body.
                    conn.resp.has_been_read = False
                    _orig_read = conn.resp.read

                    def read(*args, **kwargs):
                        conn.resp.has_been_read = True
                        return _orig_read(*args, **kwargs)
                    conn.resp.read = read
                    if auth_token:
                        headers = args[1]
                        self.assertTrue('X-Auth-Token' in headers)
                        actual_token = headers.get('X-Auth-Token')
                        self.assertEqual(auth_token, actual_token)
                    if query_string:
                        self.assertTrue(url.endswith('?' + query_string))
                    if url.endswith('invalid_cert') and not insecure:
                        from swiftclient import client as c
                        raise c.ClientException("invalid_certificate")
                    if exc:
                        raise exc
                    return conn.resp
                conn.request = request

                def getresponse():
                    return conn.resp
                conn.getresponse = getresponse

                return parsed, conn
            return wrapper
        self.fake_http_connection = fake_http_connection

    def iter_request_log(self):
        """
        Yield one dict per logged request.

        Keys: ``method``, ``path`` (the URL passed to ``request``),
        ``full_path`` and ``parsed_path`` (that path combined with the
        connection's parsed URL), ``headers`` (``CaseInsensitiveDict`` of the
        second positional argument, if any), ``resp`` and ``status``.
        ``body`` is present only when the request had a positional body.
        """
        for parsed, method, path, args, kwargs, resp in self.request_log:
            parts = parsed._asdict()
            parts['path'] = path
            full_path = ParseResult(**parts).geturl()
            args = list(args)
            log = dict(zip(('body', 'headers'), args))
            log.update({
                'method': method,
                'full_path': full_path,
                'parsed_path': urlparse(full_path),
                'path': path,
                'headers': CaseInsensitiveDict(log.get('headers')),
                'resp': resp,
                'status': resp.status,
            })
            yield log

    # unittest's plain assertEqual, used below where an explicit failure
    # message is supplied.
    orig_assertEqual = unittest.TestCase.assertEqual

    def assertRequests(self, expected_requests):
        """
        Make sure some requests were made like you expected, provide a list of
        expected requests, typically in the form of [(method, path), ...]
        or [(method, path, body), ...] or [(method, path, body, headers), ...]

        A path with a URL scheme is compared against the full URL; otherwise
        it is compared against the path passed to ``request``.  Only the
        headers listed in ``headers`` are checked.  Requests made beyond the
        expected list are not checked.  Fails if fewer requests were made
        than expected.
        """
        real_requests = self.iter_request_log()
        for index, expected in enumerate(expected_requests):
            method, path = expected[:2]
            try:
                real_request = next(real_requests)
            except StopIteration:
                # Without this, a short request log surfaces as a bare
                # StopIteration error instead of a readable test failure.
                self.fail('Expected request #%d (%s %s) but only %d '
                          'request(s) were made' % (index + 1, method, path,
                                                    index))
            if urlparse(path).scheme:
                match_path = real_request['full_path']
            else:
                match_path = real_request['path']
            self.assertEqual((method, path), (real_request['method'],
                                              match_path))
            if len(expected) > 2:
                body = expected[2]
                real_request['expected'] = body
                # A request made without a positional body has no 'body'
                # key; default it so the message below can be formatted.
                real_request.setdefault('body', None)
                err_msg = 'Body mismatch for %(method)s %(path)s, ' \
                    'expected %(expected)r, and got %(body)r' % real_request
                self.orig_assertEqual(body, real_request['body'], err_msg)

            if len(expected) > 3:
                headers = expected[3]
                for key, value in headers.items():
                    real_request['key'] = key
                    real_request['expected_value'] = value
                    real_request['value'] = real_request['headers'].get(key)
                    err_msg = (
                        'Header mismatch on %(key)r, '
                        'expected %(expected_value)r and got %(value)r '
                        'for %(method)s %(path)s %(headers)r' % real_request)
                    self.orig_assertEqual(value, real_request['value'],
                                          err_msg)

    def validateMockedRequestsConsumed(self):
        """
        Fail if the current fake still has scripted responses left.

        The remaining responses are drained from the iterator as a side
        effect.
        """
        if not self.fake_connect:
            return
        unused_responses = list(self.fake_connect.code_iter)
        if unused_responses:
            self.fail('Unused responses %r' % (unused_responses,))

    def tearDown(self):
        self.validateMockedRequestsConsumed()
        super(MockHttpTest, self).tearDown()
        # TODO: this nuke from orbit clean up seems to be encouraging
        # un-hygienic mocking on the swiftclient.client module; which may lead
        # to some unfortunate test order dependency bugs by way of the broken
        # window theory if any other modules are similarly patched
        reload_module(c)


class CaptureStream(object):
    """
    File-like tee that writes to a wrapped stream and to an in-memory buffer.

    Used as a replacement for ``sys.stdout`` / ``sys.stderr`` so output is
    still shown while it is being captured.
    """

    def __init__(self, stream):
        self.stream = stream
        self._capture = six.StringIO()
        # Every write goes to both the real stream and the capture buffer.
        self.streams = [self.stream, self._capture]

    def write(self, *args, **kwargs):
        for stream in self.streams:
            stream.write(*args, **kwargs)

    def writelines(self, *args, **kwargs):
        for stream in self.streams:
            stream.writelines(*args, **kwargs)

    def flush(self):
        """
        Flush both underlying streams.

        Needed because this object stands in for ``sys.stdout`` /
        ``sys.stderr``; without it any ``sys.stdout.flush()`` call made
        while output is captured raises ``AttributeError``.
        """
        for stream in self.streams:
            stream.flush()

    def getvalue(self):
        """Return everything captured since creation or the last clear()."""
        return self._capture.getvalue()

    def clear(self):
        """Discard captured text; the wrapped stream is left untouched."""
        self._capture.truncate(0)
        self._capture.seek(0)


class CaptureOutput(object):
    """
    Context manager that captures stdout/stderr and swiftclient shell output.

    While active, ``sys.stdout``/``sys.stderr`` are replaced by
    ``CaptureStream`` tees, and ``swiftclient.shell.OutputManager`` is
    replaced by a version bound to those streams.  The instance compares and
    behaves like the captured stdout string.
    """

    def __init__(self, suppress_systemexit=False):
        """
        :param suppress_systemexit: if true, also patch
            ``OutputManager.get_error_count`` to return 0 (presumably so the
            shell does not exit non-zero; confirm against swiftclient.shell).
        """
        self._out = CaptureStream(sys.stdout)
        self._err = CaptureStream(sys.stderr)
        self.patchers = []

        WrappedOutputManager = functools.partial(s.OutputManager,
                                                 print_stream=self._out,
                                                 error_stream=self._err)

        if suppress_systemexit:
            self.patchers += [
                mock.patch('swiftclient.shell.OutputManager.get_error_count',
                           return_value=0)
            ]

        self.patchers += [
            mock.patch('swiftclient.shell.OutputManager',
                       WrappedOutputManager),
            mock.patch('sys.stdout', self._out),
            mock.patch('sys.stderr', self._err),
        ]

    def __enter__(self):
        for patcher in self.patchers:
            patcher.start()
        return self

    def __exit__(self, *args, **kwargs):
        for patcher in self.patchers:
            patcher.stop()

    @property
    def out(self):
        """Text captured from stdout so far."""
        return self._out.getvalue()

    @property
    def err(self):
        """Text captured from stderr so far."""
        return self._err.getvalue()

    def clear(self):
        """Discard everything captured so far on both streams."""
        self._out.clear()
        self._err.clear()

    # act like the string captured by stdout

    def __str__(self):
        return self.out

    def __len__(self):
        return len(self.out)

    def __eq__(self, other):
        return self.out == other

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; without this, != would
        # fall back to identity and always be True against a string.
        return not self.__eq__(other)

    def __getattr__(self, name):
        # Only reached for names not found normally.  On an instance whose
        # __init__ has not run (e.g. created by copy/pickle), self.out would
        # call back into __getattr__ for '_out' and recurse forever.
        if '_out' not in self.__dict__:
            raise AttributeError(name)
        return getattr(self.out, name)