# Copyright (c) 2010-2012 OpenStack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import sys
from requests import RequestException
from time import sleep
import testtools
import mock
import six
from six.moves import reload_module
from swiftclient import client as c
from swiftclient import shell as s


def fake_get_auth_keystone(os_options, exc=None, **kwargs):
    """Build a fake replacement for ``get_auth_keystone``.

    :param os_options: the ``os_options`` value the fake expects to be
        called with; any other value makes it return ``("", None)``.
    :param exc: optional exception class; when given, every call raises
        ``exc('test')`` before any other check.
    :param kwargs: may contain ``required_kwargs``, a dict of keyword
        arguments that must be passed to the fake with exactly these values
        (otherwise it returns ``("", None)``).
    :returns: a callable ``(auth_url, user, key, os_options, **kwargs)``
        that returns ``("http://url/", "token")`` when every check passes.
    """
    def fake_get_auth_keystone(auth_url,
                               user,
                               key,
                               actual_os_options, **actual_kwargs):
        if exc:
            raise exc('test')
        if actual_os_options != os_options:
            return "", None

        # Simulate TLS verification failures for https URLs whose name
        # announces a bad certificate, unless verification was relaxed via
        # ``insecure`` (or, for self-signed certs, a ``cacert`` was given).
        # ``c`` is the module-level swiftclient.client import.
        if auth_url.startswith("https") and \
           auth_url.endswith("invalid-certificate") and \
           not actual_kwargs['insecure']:
            raise c.ClientException("invalid-certificate")
        if auth_url.startswith("https") and \
           auth_url.endswith("self-signed-certificate") and \
           not actual_kwargs['insecure'] and \
           actual_kwargs['cacert'] is None:
            raise c.ClientException("unverified-certificate")
        if 'required_kwargs' in kwargs:
            for k, v in kwargs['required_kwargs'].items():
                if v != actual_kwargs.get(k):
                    return "", None

        return "http://url/", "token"
    return fake_get_auth_keystone


def fake_http_connect(*code_iter, **kwargs):
    """Build a fake ``connect`` callable that hands out canned connections.

    Each call of the returned ``connect`` consumes the next status from
    ``code_iter`` and, in lockstep, the next ``etags``/``timestamps`` entry.
    It returns an already-connected ``FakeConn`` for that status.

    :param code_iter: one status code per expected connection.
        - A status ``<= 0`` makes ``connect`` raise ``RequestException``.
        - For ``getexpect()``, status -2 raises ``RequestException``,
          -3 yields a 507 response, and anything else yields 100.
    :param kwargs: optional behaviour switches:
        - ``timestamps`` / ``etags``: per-connection values.
        - ``body``: response body served by ``read()``.
        - ``headers``: extra response headers.
        - ``missing_container``: a bool or a per-call list of bools. A
          ``False`` entry adds ``x-container-timestamp``. Each
          ``getheaders()`` call on any connection consumes one entry.
        - ``slow``: ``read``/``send`` sleep 0.1s on their first 4 calls.
        - ``raise_exc``: ``getresponse()`` raises.
        - ``auth_v1``: adds ``x-storage-url``/``x-auth-token`` headers.
        - ``give_connect``: callback invoked with the ``connect`` arguments.
        - ``give_content_type``: callback invoked with ``args[6]``'s
          ``Content-Type`` (or ``''``).
    :returns: the ``connect(*args, **ckwargs)`` callable.
    """

    class FakeConn(object):
        """Minimal stand-in for an HTTP connection and its response."""

        def __init__(self, status, etag=None, body='', timestamp='1'):
            self.status = status
            self.reason = 'Fake'
            self.host = '1.2.3.4'
            self.port = '1234'
            self.sent = 0
            self.received = 0
            self.etag = etag
            self.body = body
            self.timestamp = timestamp
            self._is_closed = True

        def connect(self):
            self._is_closed = False

        def close(self):
            self._is_closed = True

        def isclosed(self):
            return self._is_closed

        def getresponse(self):
            if kwargs.get('raise_exc'):
                raise Exception('test')
            return self

        def getexpect(self):
            if self.status == -2:
                raise RequestException()
            if self.status == -3:
                return FakeConn(507)
            return FakeConn(100)

        def getheaders(self):
            headers = {'content-length': len(self.body),
                       'content-type': 'x-application/test',
                       'x-timestamp': self.timestamp,
                       'last-modified': self.timestamp,
                       'x-object-meta-test': 'testing',
                       'etag':
                       self.etag or '"68b329da9893e34099c7d8ad5cb9c940"',
                       'x-works': 'yes',
                       'x-account-container-count': 12345}
            if not self.timestamp:
                del headers['x-timestamp']
            # Shared across all connections; once exhausted the header is
            # simply never added.
            try:
                if next(container_ts_iter) is False:
                    headers['x-container-timestamp'] = '1'
            except StopIteration:
                pass
            if 'slow' in kwargs:
                headers['content-length'] = '4'
            if 'headers' in kwargs:
                headers.update(kwargs['headers'])
            if 'auth_v1' in kwargs:
                headers.update(
                    {'x-storage-url': 'storageURL',
                     'x-auth-token': 'someauthtoken'})
            return headers.items()

        def read(self, amt=None):
            if 'slow' in kwargs:
                if self.sent < 4:
                    self.sent += 1
                    sleep(0.1)
                    return ' '
            if amt is None or amt < 0:
                # Read everything and leave the body drained, like a real
                # file's read(). Slicing with ``[None:]`` would keep the
                # whole body and make every later read() repeat it.
                rv, self.body = self.body, self.body[:0]
            else:
                rv, self.body = self.body[:amt], self.body[amt:]
            return rv

        def send(self, amt=None):
            if 'slow' in kwargs:
                if self.received < 4:
                    self.received += 1
                    sleep(0.1)

        def getheader(self, name, default=None):
            return dict(self.getheaders()).get(name.lower(), default)

    timestamps_iter = iter(kwargs.get('timestamps') or ['1'] * len(code_iter))
    etag_iter = iter(kwargs.get('etags') or [None] * len(code_iter))
    missing_container = kwargs.get('missing_container',
                                   [False] * len(code_iter))
    if not isinstance(missing_container, (tuple, list)):
        missing_container = [missing_container] * len(code_iter)
    container_ts_iter = iter(missing_container)
    code_iter = iter(code_iter)

    def connect(*args, **ckwargs):
        if 'give_content_type' in kwargs:
            if len(args) >= 7 and 'Content-Type' in args[6]:
                kwargs['give_content_type'](args[6]['Content-Type'])
            else:
                kwargs['give_content_type']('')
        if 'give_connect' in kwargs:
            kwargs['give_connect'](*args, **ckwargs)
        status = next(code_iter)
        etag = next(etag_iter)
        timestamp = next(timestamps_iter)
        if status <= 0:
            raise RequestException()
        fake_conn = FakeConn(status, etag, body=kwargs.get('body', ''),
                             timestamp=timestamp)
        fake_conn.connect()
        return fake_conn

    return connect


class MockHttpTest(testtools.TestCase):
    """Base test case that provides ``self.fake_http_connection``.

    ``self.fake_http_connection(*statuses, **options)`` returns a stand-in
    for ``http_connection(url, ...)``. The stand-in calls the real
    ``c.http_connection`` only to parse the URL. It then returns a
    connection produced by ``fake_http_connect(*statuses, **options)``
    whose ``request`` and ``read`` methods are replaced.

    Recognised options, besides those ``fake_http_connect`` reads:

    - ``storage_url``: assert the URL passed in equals it.
    - ``auth_token``: assert the request's ``X-Auth-Token`` header equals it.
    - ``query_string``: assert the request URL ends with ``'?' + it``.
    - ``exc``: exception raised by ``request``.
    - ``return_read``: callable used instead of the recording ``read``.
    """

    def setUp(self):
        super(MockHttpTest, self).setUp()

        def fake_http_connection(*args, **kwargs):
            _orig_http_connection = c.http_connection
            return_read = kwargs.get('return_read')
            query_string = kwargs.get('query_string')
            storage_url = kwargs.get('storage_url')
            auth_token = kwargs.get('auth_token')
            exc = kwargs.get('exc')

            def wrapper(url, proxy=None, cacert=None, insecure=False,
                        ssl_compression=True):
                if storage_url:
                    self.assertEqual(storage_url, url)

                parsed, _conn = _orig_http_connection(url, proxy=proxy)
                conn = fake_http_connect(*args, **kwargs)()

                def request(method, url, *req_args, **req_kwargs):
                    if auth_token:
                        # Headers normally arrive positionally after the
                        # body; also accept them passed by keyword.
                        if len(req_args) > 1:
                            headers = req_args[1]
                        else:
                            headers = req_kwargs.get('headers') or {}
                        self.assertIn('X-Auth-Token', headers)
                        actual_token = headers.get('X-Auth-Token')
                        self.assertEqual(auth_token, actual_token)
                    if query_string:
                        self.assertTrue(url.endswith('?' + query_string))
                    if url.endswith('invalid_cert') and not insecure:
                        raise c.ClientException("invalid_certificate")
                    elif exc:
                        raise exc
                    return
                conn.request = request

                # Record whether the response body was ever read.
                conn.has_been_read = False
                _orig_read = conn.read

                def read(*read_args, **read_kwargs):
                    conn.has_been_read = True
                    return _orig_read(*read_args, **read_kwargs)
                conn.read = return_read or read

                return parsed, conn
            return wrapper
        self.fake_http_connection = fake_http_connection

    def tearDown(self):
        super(MockHttpTest, self).tearDown()
        # Reload the client module to undo any monkeypatching tests did.
        reload_module(c)


class CaptureStream(object):
    """File-like object that tees writes to a real stream and a buffer.

    Everything written is forwarded to the wrapped ``stream`` and is also
    recorded in an in-memory buffer, which :meth:`getvalue` returns.
    """

    def __init__(self, stream):
        self.stream = stream
        self._capture = six.StringIO()
        # Every write is forwarded to each of these, in order.
        self.streams = [self.stream, self._capture]

    def write(self, *args, **kwargs):
        for stream in self.streams:
            stream.write(*args, **kwargs)

    def writelines(self, lines):
        # Materialise first: a generator or iterator would otherwise be
        # exhausted by the first stream and never reach the capture buffer.
        lines = list(lines)
        for stream in self.streams:
            stream.writelines(lines)

    def flush(self):
        """Flush every underlying stream.

        Instances stand in for ``sys.stdout``/``sys.stderr``, and callers
        routinely call ``flush()`` on those.
        """
        for stream in self.streams:
            stream.flush()

    def getvalue(self):
        """Return everything written so far, as captured in the buffer."""
        return self._capture.getvalue()


class CaptureOutput(object):
    """Context manager capturing stdout, stderr and swiftclient shell output.

    While active:

    - ``sys.stdout`` and ``sys.stderr`` are patched with ``CaptureStream``
      tees.
    - ``swiftclient.shell.OutputManager`` is patched with a partial that
      sends its print/error output to those same tees.

    The instance itself behaves like the captured stdout string: ``str()``,
    ``len()``, ``==``/``!=`` and attribute access all delegate to ``out``.
    """

    def __init__(self):
        self._out = CaptureStream(sys.stdout)
        self._err = CaptureStream(sys.stderr)

        WrappedOutputManager = functools.partial(s.OutputManager,
                                                 print_stream=self._out,
                                                 error_stream=self._err)
        self.patchers = [
            mock.patch('swiftclient.shell.OutputManager',
                       WrappedOutputManager),
            mock.patch('sys.stdout', self._out),
            mock.patch('sys.stderr', self._err),
        ]

    def __enter__(self):
        started = []
        try:
            for patcher in self.patchers:
                patcher.start()
                started.append(patcher)
        except Exception:
            # Undo patches already applied so a partial failure does not
            # leave sys.stdout/sys.stderr replaced for the rest of the run.
            for patcher in reversed(started):
                patcher.stop()
            raise
        return self

    def __exit__(self, *args, **kwargs):
        for patcher in self.patchers:
            patcher.stop()

    @property
    def out(self):
        """Everything written to stdout while capturing."""
        return self._out.getvalue()

    @property
    def err(self):
        """Everything written to stderr while capturing."""
        return self._err.getvalue()

    # act like the string captured by stdout

    def __str__(self):
        return self.out

    def __len__(self):
        return len(self.out)

    def __eq__(self, other):
        return self.out == other

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``. Without this, it
        # would compare identity and ``!=`` would always be True.
        return not self.__eq__(other)

    def __getattr__(self, name):
        if name in ('_out', '_err'):
            # Only reached when __init__ has not run (e.g. an instance made
            # by copy/pickle). Without this guard, ``self.out`` would recurse
            # back into __getattr__ forever.
            raise AttributeError(name)
        return getattr(self.out, name)