# Copyright (c) 2010-2012 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import sys
from requests import RequestException
from requests.structures import CaseInsensitiveDict
from time import sleep
import unittest
import mock
import six
from six.moves import reload_module
from six.moves.urllib.parse import urlparse, ParseResult
from swiftclient import client as c
from swiftclient import shell as s
from swiftclient.utils import EMPTY_ETAG


def fake_get_auth_keystone(expected_os_options=None, exc=None,
                           storage_url='http://url/', token='token',
                           **kwargs):
    """
    Build a fake keystone auth function (presumably patched over
    ``get_auth_keystone`` by tests).

    :param expected_os_options: optional dict; if any truthy value in the
        ``os_options`` the fake receives differs from the value stored under
        the same key here, the fake returns ``("", None)``.
    :param exc: optional exception class; when set, every call raises
        ``exc('test')`` before any other check.
    :param storage_url: storage URL returned on success.
    :param token: auth token returned on success.
    :param kwargs: may contain ``required_kwargs``, a dict of keyword
        arguments the fake must receive with exactly these values, otherwise
        it returns ``("", None)``.
    :returns: the fake, called as ``(auth_url, user, key, actual_os_options,
        **actual_kwargs)`` and returning ``(storage_url, token)`` on success.
        For ``https`` auth URLs it raises ``c.ClientException`` to simulate
        certificate problems, selected by the URL suffix.
    """
    def fake_get_auth_keystone(auth_url,
                               user,
                               key,
                               actual_os_options, **actual_kwargs):
        if exc:
            raise exc('test')
        # TODO: some way to require auth_url, user and key?
        if expected_os_options:
            # opt_key/opt_value rather than key/value so the ``key`` argument
            # is not shadowed.
            for opt_key, opt_value in actual_os_options.items():
                if opt_value and \
                   opt_value != expected_os_options.get(opt_key):
                    return "", None
        if 'required_kwargs' in kwargs:
            for k, v in kwargs['required_kwargs'].items():
                if v != actual_kwargs.get(k):
                    return "", None

        if auth_url.startswith("https"):
            # Certificate failures are keyed off the auth URL suffix and
            # depend on the TLS-related kwargs the caller passed through.
            if auth_url.endswith("invalid-certificate") and \
               not actual_kwargs['insecure']:
                raise c.ClientException("invalid-certificate")
            if auth_url.endswith("self-signed-certificate") and \
               not actual_kwargs['insecure'] and \
               actual_kwargs['cacert'] is None:
                raise c.ClientException("unverified-certificate")
            if auth_url.endswith("client-certificate") and \
               not (actual_kwargs['cert'] and actual_kwargs['cert_key']):
                raise c.ClientException("noclient-certificate")

        return storage_url, token
    return fake_get_auth_keystone


class StubResponse(object):
    """
    Canned description of a single fake response.

    Put instances into ``fake_http_connect``'s ``code_iter`` to control the
    status, body and headers of one particular response instead of using the
    generated defaults.
    """

    def __init__(self, status=200, body='', headers=None):
        self.status, self.body = status, body
        # A falsy ``headers`` argument is replaced by a fresh empty dict.
        self.headers = {} if not headers else headers

    def __repr__(self):
        fields = (self.status, self.body, self.headers)
        return '{0}({1!r}, {2!r}, {3!r})'.format(type(self).__name__,
                                                 *fields)


def fake_http_connect(*code_iter, **kwargs):
    """
    Generate a callable which yields a series of stubbed responses.  Because
    swiftclient will reuse an HTTP connection across pipelined requests it is
    not always the case that this fake is used strictly for mocking an HTTP
    connection, but rather each HTTP response (i.e. each call to requests
    get_response).

    :param code_iter: one entry per expected response, each either an int
        status or a StubResponse.  A status <= 0 makes the corresponding
        call raise requests.RequestException.
    :param kwargs: optional switches read by the fake:
        ``timestamps`` / ``etags`` - per-response values, consumed only for
        int statuses (not StubResponse entries);
        ``missing_container`` - a single value or a per-response list/tuple;
        a ``False`` entry adds ``x-container-timestamp`` to default headers;
        ``body`` - body used for int-status responses;
        ``headers`` - extra headers merged into the default headers;
        ``slow`` - read()/send() sleep and content-length is forced to '4';
        ``auth_v1`` - adds x-storage-url / x-auth-token headers;
        ``raise_exc`` - makes getresponse() raise;
        ``give_content_type`` / ``give_connect`` - callbacks invoked on
        every connect() call.
    :returns: the ``connect`` callable.  Its ``code_iter`` attribute is the
        iterator of responses not yet handed out.
    """

    class FakeConn(object):
        """Fake connection/response object; one instance per response."""

        def __init__(self, status, etag=None, body='', timestamp='1',
                     headers=None):
            self.status_code = self.status = status
            self.reason = 'Fake'
            self.scheme = 'http'
            self.host = '1.2.3.4'
            self.port = '1234'
            # Counters used only by the ``slow`` mode of send()/read().
            self.sent = 0
            self.received = 0
            self.etag = etag
            self.content = self.body = body
            self.timestamp = timestamp
            self.headers = headers or {}
            self.request = None

        def getresponse(self):
            # ``raise_exc`` makes every getresponse() raise a generic
            # Exception instead of returning this object.
            if kwargs.get('raise_exc'):
                raise Exception('test')
            return self

        def getheaders(self):
            # Explicit headers (e.g. from a StubResponse) are returned as-is;
            # none of the defaults or kwargs-driven tweaks below apply.
            if self.headers:
                return self.headers.items()
            headers = {'content-length': str(len(self.body)),
                       'content-type': 'x-application/test',
                       'x-timestamp': self.timestamp,
                       'last-modified': self.timestamp,
                       'x-object-meta-test': 'testing',
                       'etag':
                       self.etag or '"%s"' % EMPTY_ETAG,
                       'x-works': 'yes',
                       'x-account-container-count': '12345'}
            if not self.timestamp:
                del headers['x-timestamp']
            # Each call consumes one ``missing_container`` entry (shared by
            # all responses); running out of entries is silently ignored.
            try:
                if next(container_ts_iter) is False:
                    headers['x-container-timestamp'] = '1'
            except StopIteration:
                pass
            if 'slow' in kwargs:
                headers['content-length'] = '4'
            if 'headers' in kwargs:
                headers.update(kwargs['headers'])
            if 'auth_v1' in kwargs:
                headers.update(
                    {'x-storage-url': 'storageURL',
                     'x-auth-token': 'someauthtoken'})
            return headers.items()

        def read(self, amt=None):
            # In ``slow`` mode the first four reads each sleep 0.1s and
            # return a single space (four 1-char reads, consistent with the
            # content-length of '4' reported by getheaders), then fall
            # through to the real body.
            if 'slow' in kwargs:
                if self.sent < 4:
                    self.sent += 1
                    sleep(0.1)
                    return ' '
            # Return up to ``amt`` characters and consume them; ``None``
            # returns and consumes the whole remaining body.
            rv = self.body[:amt]
            if amt is not None:
                self.body = self.body[amt:]
            else:
                self.body = ''
            return rv

        def send(self, amt=None):
            # Data is not recorded; only ``slow`` mode does anything (the
            # first four calls each sleep 0.1s).
            if 'slow' in kwargs:
                if self.received < 4:
                    self.received += 1
                    sleep(0.1)

        def getheader(self, name, default=None):
            # Only the requested name is lowercased, so explicit headers
            # must use lowercase keys to be found here.
            return dict(self.getheaders()).get(name.lower(), default)

        def close(self):
            pass

    # Iterators shared by every connect() call.  etags and timestamps are
    # consumed only for int statuses, never for StubResponse entries.
    timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
    etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
    # ``missing_container`` may be one value applied to every response or a
    # per-response list/tuple.
    x = kwargs.get('missing_container', [False] * len(code_iter))
    if not isinstance(x, (tuple, list)):
        x = [x] * len(code_iter)
    container_ts_iter = iter(x)
    code_iter = iter(code_iter)

    def connect(*args, **ckwargs):
        # Report the Content-Type from the 7th positional argument (if
        # present) to the optional callback, otherwise report ''.
        if 'give_content_type' in kwargs:
            if len(args) >= 7 and 'Content-Type' in args[6]:
                kwargs['give_content_type'](args[6]['Content-Type'])
            else:
                kwargs['give_content_type']('')
        if 'give_connect' in kwargs:
            kwargs['give_connect'](*args, **ckwargs)
        # Raises StopIteration once every stubbed response has been used.
        status = next(code_iter)
        if isinstance(status, StubResponse):
            fake_conn = FakeConn(status.status, body=status.body,
                                 headers=status.headers)
        else:
            etag = next(etag_iter)
            timestamp = next(timestamps_iter)
            fake_conn = FakeConn(status, etag, body=kwargs.get('body', ''),
                                 timestamp=timestamp)
        # Non-positive statuses raise instead of returning a response.
        if fake_conn.status <= 0:
            raise RequestException()
        return fake_conn

    # Exposed so callers (e.g. MockHttpTest) can detect unused responses.
    connect.code_iter = code_iter
    return connect


class MockHttpTest(unittest.TestCase):
    """
    TestCase base class that fakes swiftclient's HTTP layer.

    ``setUp`` starts capturing stdout/stderr and defines
    ``self.fake_http_connection``, a factory whose return value can be
    patched in place of ``c.http_connection``.  Every request made through
    it is appended to ``self.request_log`` and can be verified with
    ``assertRequests`` or ``assert_request``.  ``tearDown`` fails the test
    if stubbed responses were left unused, and always undoes the output
    capture and reloads ``swiftclient.client``.
    """

    def setUp(self):
        super(MockHttpTest, self).setUp()
        self.fake_connect = None
        self.request_log = []

        # Capture output, since the test-runner stdout/stderr monkey-patching
        # won't cover the references to sys.stdout/sys.stderr in
        # swiftclient.multithreading
        self.capture_output = CaptureOutput()
        self.capture_output.__enter__()

        def fake_http_connection(*args, **kwargs):
            """
            Return a fake replacement for ``c.http_connection``.

            ``args``/``kwargs`` are passed to ``fake_http_connect`` to build
            the stubbed responses.  These kwargs are also used here:
            ``query_string`` (each request path must end with it),
            ``storage_url`` (each connection URL must equal it),
            ``auth_token`` (each request's X-Auth-Token header must equal
            it), ``exc`` (raised from every request) and ``on_request``
            (called per request; its return value overrides the response
            status).
            """
            # Building a new fake first checks the previous one was used up.
            self.validateMockedRequestsConsumed()
            self.request_log = []
            self.fake_connect = fake_http_connect(*args, **kwargs)
            # Captured at factory-call time, i.e. presumably before the
            # caller patches c.http_connection with the returned wrapper.
            _orig_http_connection = c.http_connection
            query_string = kwargs.get('query_string')
            storage_url = kwargs.get('storage_url')
            auth_token = kwargs.get('auth_token')
            exc = kwargs.get('exc')
            on_request = kwargs.get('on_request')

            def wrapper(url, proxy=None, cacert=None, insecure=False,
                        cert=None, cert_key=None,
                        ssl_compression=True, timeout=None):
                if storage_url:
                    self.assertEqual(storage_url, url)

                # Only the parsed URL from the real http_connection is used;
                # its connection object is discarded.
                parsed, _conn = _orig_http_connection(url, proxy=proxy)

                class RequestsWrapper(object):
                    def close(self):
                        pass
                conn = RequestsWrapper()

                def request(method, path, *args, **kwargs):
                    try:
                        conn.resp = self.fake_connect()
                    except StopIteration:
                        self.fail('Unexpected %s request for %s' % (
                            method, path))
                    self.request_log.append((parsed, method, path, args,
                                             kwargs, conn.resp))
                    conn.host = conn.resp.host
                    conn.resp.request = RequestsWrapper()
                    conn.resp.request.url = '%s://%s%s' % (
                        conn.resp.scheme, conn.resp.host, path)
                    conn.resp.has_been_read = False
                    _orig_read = conn.resp.read

                    # Wrap read() so tests can tell whether the body was
                    # consumed.
                    def read(*args, **kwargs):
                        conn.resp.has_been_read = True
                        return _orig_read(*args, **kwargs)
                    conn.resp.read = read
                    if on_request:
                        status = on_request(method, path, *args, **kwargs)
                        conn.resp.status = status
                    if auth_token:
                        headers = args[1]
                        self.assertEqual(auth_token,
                                         headers.get('X-Auth-Token'))
                    if query_string:
                        self.assertTrue(path.endswith('?' + query_string))
                    if path.endswith('invalid_cert') and not insecure:
                        from swiftclient import client as c
                        raise c.ClientException("invalid_certificate")
                    if exc:
                        raise exc
                    return conn.resp

                def putrequest(path, data=None, headers=None, **kwargs):
                    request('PUT', path, data, headers, **kwargs)

                conn.request = request
                conn.putrequest = putrequest

                def getresponse():
                    return conn.resp
                conn.getresponse = getresponse

                return parsed, conn
            return wrapper
        self.fake_http_connection = fake_http_connection

    def iter_request_log(self):
        """
        Yield one dict per logged request.

        Keys: ``method``, ``path``, ``full_path`` (path joined onto the
        connection URL), ``parsed_path``, ``headers`` (a
        CaseInsensitiveDict), ``resp``, ``status`` and, when the request
        had at least one positional argument after the path, ``body``.
        """
        for parsed, method, path, args, kwargs, resp in self.request_log:
            parts = parsed._asdict()
            parts['path'] = path
            full_path = ParseResult(**parts).geturl()
            args = list(args)
            log = dict(zip(('body', 'headers'), args))
            log.update({
                'method': method,
                'full_path': full_path,
                'parsed_path': urlparse(full_path),
                'path': path,
                'headers': CaseInsensitiveDict(log.get('headers')),
                'resp': resp,
                'status': resp.status,
            })
            yield log

    # Stock unittest assertEqual, called by assert_request_equal with custom
    # messages (presumably so subclass overrides of assertEqual are skipped).
    orig_assertEqual = unittest.TestCase.assertEqual

    def assert_request_equal(self, expected, real_request):
        """
        Assert *real_request* (a dict from iter_request_log) matches
        *expected* = ``(method, path[, body[, headers]])``.

        A path with a URL scheme is compared against the full URL.  When
        headers are given, each must match and no other headers may be
        present.  Diagnostic keys are added to *real_request* for the
        failure messages.
        """
        method, path = expected[:2]
        if urlparse(path).scheme:
            match_path = real_request['full_path']
        else:
            match_path = real_request['path']
        self.assertEqual((method, path), (real_request['method'],
                                          match_path))
        if len(expected) > 2:
            body = expected[2]
            real_request['expected'] = body
            err_msg = 'Body mismatch for %(method)s %(path)s, ' \
                'expected %(expected)r, and got %(body)r' % real_request
            self.orig_assertEqual(body, real_request['body'], err_msg)

        if len(expected) > 3:
            headers = CaseInsensitiveDict(expected[3])
            for key, value in headers.items():
                real_request['key'] = key
                real_request['expected_value'] = value
                real_request['value'] = real_request['headers'].get(key)
                err_msg = (
                    'Header mismatch on %(key)r, '
                    'expected %(expected_value)r and got %(value)r '
                    'for %(method)s %(path)s %(headers)r' % real_request)
                self.orig_assertEqual(value, real_request['value'],
                                      err_msg)
            real_request['extra_headers'] = dict(
                (key, value) for key, value in real_request['headers'].items()
                if key not in headers)
            if real_request['extra_headers']:
                self.fail('Received unexpected headers for %(method)s '
                          '%(path)s, got %(extra_headers)r' % real_request)

    def assertRequests(self, expected_requests):
        """
        Make sure some requests were made like you expected, provide a list of
        expected requests, typically in the form of [(method, path), ...]
        or [(method, path, body, headers), ...]
        """
        real_requests = self.iter_request_log()
        for expected in expected_requests:
            real_request = next(real_requests)
            self.assert_request_equal(expected, real_request)
        try:
            real_request = next(real_requests)
        except StopIteration:
            pass
        else:
            self.fail('At least one extra request received: %r' %
                      real_request)

    def assert_request(self, expected_request):
        """
        Make sure a request was made as expected. Provide the
        expected request in the form of [(method, path), ...]
        """
        real_requests = self.iter_request_log()
        for real_request in real_requests:
            try:
                self.assert_request_equal(expected_request, real_request)
                break
            except AssertionError:
                pass
        else:
            raise AssertionError(
                "Expected request %s not found in actual requests %s"
                % (expected_request, self.request_log)
            )

    def validateMockedRequestsConsumed(self):
        """Fail the test if the current fake still has unused responses."""
        if not self.fake_connect:
            return
        unused_responses = list(self.fake_connect.code_iter)
        if unused_responses:
            self.fail('Unused responses %r' % (unused_responses,))

    def tearDown(self):
        """
        Fail on unused stubbed responses, then always undo the output
        capture and reload ``swiftclient.client``.

        The cleanup is in ``finally`` so a failing
        validateMockedRequestsConsumed() cannot leave sys.stdout/sys.stderr
        patched or module-level patches on swiftclient.client in place for
        later tests.
        """
        try:
            self.validateMockedRequestsConsumed()
            super(MockHttpTest, self).tearDown()
        finally:
            # TODO: this nuke from orbit clean up seems to be encouraging
            # un-hygienic mocking on the swiftclient.client module; which may
            # lead to some unfortunate test order dependency bugs by way of
            # the broken window theory if any other modules are similarly
            # patched
            self.capture_output.__exit__()
            reload_module(c)


class CaptureStreamPrinter(object):
    """
    Byte-oriented writer that feeds a parent CaptureStream's buffer.

    Text handed to ``write`` is UTF-8 encoded first; byte strings are passed
    through untouched.  This lets tests exercise unicode output on PY3.
    """
    def __init__(self, captured_stream):
        self._captured_stream = captured_stream

    def write(self, data):
        # Only text needs converting; raw bytes go straight to the buffer.
        if not isinstance(data, bytes):
            data = data.encode('utf8')
        self._captured_stream.write(data)


class CaptureStream(object):
    """
    Stand-in for sys.stdout/sys.stderr that records everything written.

    Output is accumulated as UTF-8 bytes in an in-memory buffer, which is
    also exposed as ``buffer`` on PY3 to mimic ``sys.stdout.buffer``.
    """

    def __init__(self, stream):
        # The replaced stream is kept for reference only; nothing is written
        # through to it.
        self.stream = stream
        self._buffer = six.BytesIO()
        self._capture = CaptureStreamPrinter(self._buffer)
        self.streams = [self._capture]

    @property
    def buffer(self):
        if six.PY3:
            return self._buffer
        else:
            raise AttributeError(
                'Output stream has no attribute "buffer" in Python2')

    def flush(self):
        pass

    def write(self, *args, **kwargs):
        for stream in self.streams:
            stream.write(*args, **kwargs)

    def writelines(self, lines):
        """
        Write every item of *lines*, like ``file.writelines``.

        Delegates to ``write`` because the capture printers only implement
        ``write``; forwarding ``writelines`` to them raised AttributeError.
        """
        for line in lines:
            self.write(line)

    def getvalue(self):
        """Return everything captured so far, as bytes."""
        return self._buffer.getvalue()

    def clear(self):
        """Discard captured output and rewind the buffer."""
        self._buffer.truncate(0)
        self._buffer.seek(0)


class CaptureOutput(object):
    """
    Context manager that captures stdout/stderr and swiftclient shell output.

    While active, ``sys.stdout``/``sys.stderr`` are replaced with
    CaptureStream objects and ``swiftclient.shell.OutputManager`` is
    replaced with a partial bound to those streams.  The instance behaves
    like the captured stdout text: ``str``, ``len``, ``==`` and unknown
    attributes are delegated to ``self.out``.

    :param suppress_systemexit: if true, also patch
        ``OutputManager.get_error_count`` to return 0.
    """

    def __init__(self, suppress_systemexit=False):
        self._out = CaptureStream(sys.stdout)
        self._err = CaptureStream(sys.stderr)
        self.patchers = []

        WrappedOutputManager = functools.partial(s.OutputManager,
                                                 print_stream=self._out,
                                                 error_stream=self._err)

        if suppress_systemexit:
            self.patchers += [
                mock.patch('swiftclient.shell.OutputManager.get_error_count',
                           return_value=0)
            ]

        self.patchers += [
            mock.patch('swiftclient.shell.OutputManager',
                       WrappedOutputManager),
            mock.patch('sys.stdout', self._out),
            mock.patch('sys.stderr', self._err),
        ]

    def __enter__(self):
        started = []
        try:
            for patcher in self.patchers:
                patcher.start()
                started.append(patcher)
        except Exception:
            # Undo whatever was already patched so a failed start does not
            # leave sys.stdout/sys.stderr hijacked.
            for patcher in reversed(started):
                patcher.stop()
            raise
        return self

    def __exit__(self, *args, **kwargs):
        # Unwind in LIFO order (as mock.patch.stopall does) so stacked
        # patches restore the attributes they replaced in sequence.
        for patcher in reversed(self.patchers):
            patcher.stop()

    @property
    def out(self):
        """Captured stdout, decoded from UTF-8."""
        return self._out.getvalue().decode('utf8')

    @property
    def err(self):
        """Captured stderr, decoded from UTF-8."""
        return self._err.getvalue().decode('utf8')

    def clear(self):
        self._out.clear()
        self._err.clear()

    # act like the string captured by stdout

    def __str__(self):
        return self.out

    def __len__(self):
        return len(self.out)

    def __eq__(self, other):
        return self.out == other

    def __ne__(self, other):
        return not self.__eq__(other)

    def __getattr__(self, name):
        return getattr(self.out, name)


class FakeKeystone(object):
    '''
    Stand-in for a keystone client module used by auth tests.

    ``Client(**kwargs)`` records its keyword arguments in ``calls`` and
    returns a client whose ``auth_token`` and service-catalog endpoint are
    the values given to this constructor.  The latest client is also kept
    as ``self.client``.
    '''

    class Unauthorized(Exception):
        pass

    class AuthorizationFailure(Exception):
        pass

    class EndpointNotFound(Exception):
        pass

    class _Client(object):

        class ServiceCatalog(object):
            def __init__(self, endpoint):
                self.endpoint_url = endpoint
                self.calls = []

            def url_for(self, **kwargs):
                # Record the lookup; always answer with the fixed endpoint.
                self.calls.append(kwargs)
                return self.endpoint_url

        def __init__(self, endpoint, auth_token, **kwargs):
            self.service_catalog = self.ServiceCatalog(endpoint)
            self.endpoint = endpoint
            self.auth_token = auth_token

    def __init__(self, endpoint, token):
        self.endpoint, self.token = endpoint, token
        self.auth_version = None
        self.calls = []

    def Client(self, **kwargs):
        self.calls.append(kwargs)
        new_client = self._Client(endpoint=self.endpoint,
                                  auth_token=self.token, **kwargs)
        self.client = new_client
        return new_client


class FakeStream(object):
    """
    File-like object that yields ``size`` bytes of ``b'A'`` without storing
    them.

    ``len()`` reports the total size regardless of how much has been read.
    """
    def __init__(self, size):
        self.bytes_read = 0
        self.size = size

    def read(self, size=-1):
        """
        Return up to *size* bytes of ``b'A'`` and advance the position.

        A negative or ``None`` *size* reads everything that is left, as with
        ``io.RawIOBase.read``; previously ``None`` raised TypeError and
        negative values other than -1 moved the position backwards.  Returns
        ``b''`` at end of stream.
        """
        remaining = self.size - self.bytes_read
        if remaining <= 0:
            return b''
        if size is None or size < 0 or size > remaining:
            size = remaining
        self.bytes_read += size
        return b'A' * size

    def __len__(self):
        return self.size