From 9fbe21cb37082a02e5ba841a10709a510518b305 Mon Sep 17 00:00:00 2001 From: Jarret Raim Date: Thu, 9 May 2013 15:06:02 -0500 Subject: [PATCH] Copied auth.py and tests from python-reddwarf client --- barbicanclient/common/__init__.py | 0 barbicanclient/common/auth.py | 255 +++++++++++++++++++ barbicanclient/common/http.py | 229 +++++++++++++++++ barbicanclient/exceptions.py | 164 ++++++++++++ tests/__init__.py | 0 tests/test_auth.py | 407 ++++++++++++++++++++++++++++++ 6 files changed, 1055 insertions(+) create mode 100644 barbicanclient/common/__init__.py create mode 100644 barbicanclient/common/auth.py create mode 100644 barbicanclient/common/http.py create mode 100644 barbicanclient/exceptions.py create mode 100644 tests/__init__.py create mode 100644 tests/test_auth.py diff --git a/barbicanclient/common/__init__.py b/barbicanclient/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/barbicanclient/common/auth.py b/barbicanclient/common/auth.py new file mode 100644 index 00000000..857e61b6 --- /dev/null +++ b/barbicanclient/common/auth.py @@ -0,0 +1,255 @@ +from barbicanclient import exceptions + + +def get_authenticator_cls(cls_or_name): + """Factory method to retrieve Authenticator class.""" + if isinstance(cls_or_name, type): + return cls_or_name + elif isinstance(cls_or_name, basestring): + if cls_or_name == "keystone": + return KeyStoneV2Authenticator + elif cls_or_name == "rax": + return RaxAuthenticator + elif cls_or_name == "auth1.1": + return Auth1_1 + elif cls_or_name == "fake": + return FakeAuth + + raise ValueError("Could not determine authenticator class from the given " + "value %r." % cls_or_name) + + +class Authenticator(object): + """ + Helper class to perform Keystone or other miscellaneous authentication. + + The "authenticate" method returns a ServiceCatalog, which can be used + to obtain a token. + + """ + + URL_REQUIRED = True + + def __init__(self, client, type, url, username, password, tenant, + region=None, service_type=None, service_name=None, + service_url=None): + self.client = client + self.type = type + self.url = url + self.username = username + self.password = password + self.tenant = tenant + self.region = region + self.service_type = service_type + self.service_name = service_name + self.service_url = service_url + + def _authenticate(self, url, body, root_key='access'): + """Authenticate and extract the service catalog.""" + # Make sure we follow redirects when trying to reach Keystone + tmp_follow_all_redirects = self.client.follow_all_redirects + self.client.follow_all_redirects = True + + try: + resp, body = self.client._time_request(url, "POST", body=body) + finally: + self.client.follow_all_redirects = tmp_follow_all_redirects + + if resp.status == 200: # content must always present + try: + return ServiceCatalog(body, region=self.region, + service_type=self.service_type, + service_name=self.service_name, + service_url=self.service_url, + root_key=root_key) + except exceptions.AmbiguousEndpoints: + print "Found more than one valid endpoint. Use a more "\ + "restrictive filter" + raise + except KeyError: + raise exceptions.AuthorizationFailure() + except exceptions.EndpointNotFound: + print "Could not find any suitable endpoint. Correct region?" + raise + + elif resp.status == 305: + return resp['location'] + else: + raise exceptions.from_response(resp, body) + + def authenticate(self): + raise NotImplementedError("Missing authenticate method.") + + +class KeyStoneV2Authenticator(Authenticator): + + def authenticate(self): + if self.url is None: + raise exceptions.AuthUrlNotGiven() + return self._v2_auth(self.url) + + def _v2_auth(self, url): + """Authenticate against a v2.0 auth service.""" + body = {"auth": { + "passwordCredentials": { + "username": self.username, + "password": self.password} + } + } + + if self.tenant: + body['auth']['tenantName'] = self.tenant + + return self._authenticate(url, body) + + +class Auth1_1(Authenticator): + + def authenticate(self): + """Authenticate against a v2.0 auth service.""" + if self.url is None: + raise exceptions.AuthUrlNotGiven() + auth_url = self.url + body = {"credentials": {"username": self.username, + "key": self.password}} + return self._authenticate(auth_url, body, root_key='auth') + + try: + print(resp_body) + self.auth_token = resp_body['auth']['token']['id'] + except KeyError: + raise nova_exceptions.AuthorizationFailure() + + catalog = resp_body['auth']['serviceCatalog'] + if 'cloudDatabases' not in catalog: + raise nova_exceptions.EndpointNotFound() + endpoints = catalog['cloudDatabases'] + for endpoint in endpoints: + if self.region_name is None or \ + endpoint['region'] == self.region_name: + self.management_url = endpoint['publicURL'] + return + raise nova_exceptions.EndpointNotFound() + + +class RaxAuthenticator(Authenticator): + + def authenticate(self): + if self.url is None: + raise exceptions.AuthUrlNotGiven() + return self._rax_auth(self.url) + + def _rax_auth(self, url): + """Authenticate against the Rackspace auth service.""" + body = {'auth': { + 'RAX-KSKEY:apiKeyCredentials': { + 'username': self.username, + 'apiKey': self.password, + 'tenantName': self.tenant} + } + } + + return self._authenticate(self.url, body) + + +class FakeAuth(Authenticator): + """Useful for faking auth.""" + + def authenticate(self): + class FakeCatalog(object): + def __init__(self, auth): + self.auth = auth + + def get_public_url(self): + return "%s/%s" % ('http://localhost:8779/v1.0', + self.auth.tenant) + + def get_token(self): + return self.auth.tenant + + return FakeCatalog(self) + + +class ServiceCatalog(object): + """Represents a Keystone Service Catalog which describes a service. + + This class has methods to obtain a valid token as well as a public service + url and a management url. + + """ + + def __init__(self, resource_dict, region=None, service_type=None, + service_name=None, service_url=None, root_key='access'): + self.catalog = resource_dict + self.region = region + self.service_type = service_type + self.service_name = service_name + self.service_url = service_url + self.management_url = None + self.public_url = None + self.root_key = root_key + self._load() + + def _load(self): + if not self.service_url: + self.public_url = self._url_for(attr='region', + filter_value=self.region, + endpoint_type="publicURL") + self.management_url = self._url_for(attr='region', + filter_value=self.region, + endpoint_type="adminURL") + else: + self.public_url = self.service_url + self.management_url = self.service_url + + def get_token(self): + return self.catalog[self.root_key]['token']['id'] + + def get_management_url(self): + return self.management_url + + def get_public_url(self): + return self.public_url + + def _url_for(self, attr=None, filter_value=None, + endpoint_type='publicURL'): + """ + Fetch the public URL from the Reddwarf service for a particular + endpoint attribute. If none given, return the first. + """ + matching_endpoints = [] + if 'endpoints' in self.catalog: + # We have a bastardized service catalog. Treat it special. :/ + for endpoint in self.catalog['endpoints']: + if not filter_value or endpoint[attr] == filter_value: + matching_endpoints.append(endpoint) + if not matching_endpoints: + raise exceptions.EndpointNotFound() + + # We don't always get a service catalog back ... + if not 'serviceCatalog' in self.catalog[self.root_key]: + raise exceptions.EndpointNotFound() + + # Full catalog ... + catalog = self.catalog[self.root_key]['serviceCatalog'] + + for service in catalog: + if service.get("type") != self.service_type: + continue + + if (self.service_name and self.service_type == 'database' and + service.get('name') != self.service_name): + continue + + endpoints = service['endpoints'] + for endpoint in endpoints: + if not filter_value or endpoint.get(attr) == filter_value: + endpoint["serviceName"] = service.get("name") + matching_endpoints.append(endpoint) + + if not matching_endpoints: + raise exceptions.EndpointNotFound() + elif len(matching_endpoints) > 1: + raise exceptions.AmbiguousEndpoints(endpoints=matching_endpoints) + else: + return matching_endpoints[0].get(endpoint_type, None) \ No newline at end of file diff --git a/barbicanclient/common/http.py b/barbicanclient/common/http.py new file mode 100644 index 00000000..16fdf418 --- /dev/null +++ b/barbicanclient/common/http.py @@ -0,0 +1,229 @@ +import httplib2 +import logging + +from barbicanclient.common import auth + + +class BarbicanHTTPClient(httplib2.Http): + + USER_AGENT = 'python-barbicanclient' + + def __init__(self, user, password, tenant, auth_url, service_name, + service_url=None, + auth_strategy=None, insecure=False, + timeout=None, proxy_tenant_id=None, + proxy_token=None, region_name=None, + endpoint_type='publicURL', service_type=None, + timings=False): + + super(BarbicanHTTPClient, self).__init__(timeout=timeout) + + self.username = user + self.password = password + self.tenant = tenant + if auth_url: + self.auth_url = auth_url.rstrip('/') + else: + self.auth_url = None + self.region_name = region_name + self.endpoint_type = endpoint_type + self.service_url = service_url + self.service_type = service_type + self.service_name = service_name + self.timings = timings + + self.times = [] # [("item", starttime, endtime), ...] + + self.auth_token = None + self.proxy_token = proxy_token + self.proxy_tenant_id = proxy_tenant_id + + # httplib2 overrides + self.force_exception_to_status_code = True + self.disable_ssl_certificate_validation = insecure + + auth_cls = auth.get_authenticator_cls(auth_strategy) + + self.authenticator = auth_cls(self, auth_strategy, + self.auth_url, self.username, + self.password, self.tenant, + region=region_name, + service_type=service_type, + service_name=service_name, + service_url=service_url) + + def get_timings(self): + return self.times + + def http_log(self, args, kwargs, resp, body): + if not RDC_PP: + self.simple_log(args, kwargs, resp, body) + else: + self.pretty_log(args, kwargs, resp, body) + + def simple_log(self, args, kwargs, resp, body): + if not _logger.isEnabledFor(logging.DEBUG): + return + + string_parts = ['curl -i'] + for element in args: + if element in ('GET', 'POST'): + string_parts.append(' -X %s' % element) + else: + string_parts.append(' %s' % element) + + for element in kwargs['headers']: + header = ' -H "%s: %s"' % (element, kwargs['headers'][element]) + string_parts.append(header) + + _logger.debug("REQ: %s\n" % "".join(string_parts)) + if 'body' in kwargs: + _logger.debug("REQ BODY: %s\n" % (kwargs['body'])) + _logger.debug("RESP:%s %s\n", resp, body) + + def pretty_log(self, args, kwargs, resp, body): + from reddwarfclient import common + if not _logger.isEnabledFor(logging.DEBUG): + return + + string_parts = ['curl -i'] + for element in args: + if element in ('GET', 'POST'): + string_parts.append(' -X %s' % element) + else: + string_parts.append(' %s' % element) + + for element in kwargs['headers']: + header = ' -H "%s: %s"' % (element, kwargs['headers'][element]) + string_parts.append(header) + + curl_cmd = "".join(string_parts) + _logger.debug("REQUEST:") + if 'body' in kwargs: + _logger.debug("%s -d '%s'" % (curl_cmd, kwargs['body'])) + try: + req_body = json.dumps(json.loads(kwargs['body']), + sort_keys=True, indent=4) + except: + req_body = kwargs['body'] + _logger.debug("BODY: %s\n" % (req_body)) + else: + _logger.debug(curl_cmd) + + try: + resp_body = json.dumps(json.loads(body), sort_keys=True, indent=4) + except: + resp_body = body + _logger.debug("RESPONSE HEADERS: %s" % resp) + _logger.debug("RESPONSE BODY : %s" % resp_body) + + def request(self, *args, **kwargs): + kwargs.setdefault('headers', kwargs.get('headers', {})) + kwargs['headers']['User-Agent'] = self.USER_AGENT + self.morph_request(kwargs) + + resp, body = super(ReddwarfHTTPClient, self).request(*args, **kwargs) + + # Save this in case anyone wants it. + self.last_response = (resp, body) + self.http_log(args, kwargs, resp, body) + + if body: + try: + body = self.morph_response_body(body) + except exceptions.ResponseFormatError: + # Acceptable only if the response status is an error code. + # Otherwise its the API or client misbehaving. + self.raise_error_from_status(resp, None) + raise # Not accepted! + else: + body = None + + if resp.status in expected_errors: + raise exceptions.from_response(resp, body) + + return resp, body + + def raise_error_from_status(self, resp, body): + if resp.status in expected_errors: + raise exceptions.from_response(resp, body) + + def morph_request(self, kwargs): + kwargs['headers']['Accept'] = 'application/json' + kwargs['headers']['Content-Type'] = 'application/json' + if 'body' in kwargs: + kwargs['body'] = json.dumps(kwargs['body']) + + def morph_response_body(self, body_string): + try: + return json.loads(body_string) + except ValueError: + raise exceptions.ResponseFormatError() + + def _time_request(self, url, method, **kwargs): + start_time = time.time() + resp, body = self.request(url, method, **kwargs) + self.times.append(("%s %s" % (method, url), + start_time, time.time())) + return resp, body + + def _cs_request(self, url, method, **kwargs): + def request(): + kwargs.setdefault('headers', {})['X-Auth-Token'] = self.auth_token + if self.tenant: + kwargs['headers']['X-Auth-Project-Id'] = self.tenant + + resp, body = self._time_request(self.service_url + url, method, + **kwargs) + return resp, body + + if not self.auth_token or not self.service_url: + self.authenticate() + + # Perform the request once. If we get a 401 back then it + # might be because the auth token expired, so try to + # re-authenticate and try again. If it still fails, bail. + try: + return request() + except exceptions.Unauthorized, ex: + self.authenticate() + return request() + + def get(self, url, **kwargs): + return self._cs_request(url, 'GET', **kwargs) + + def post(self, url, **kwargs): + return self._cs_request(url, 'POST', **kwargs) + + def put(self, url, **kwargs): + return self._cs_request(url, 'PUT', **kwargs) + + def delete(self, url, **kwargs): + return self._cs_request(url, 'DELETE', **kwargs) + + def authenticate(self): + """Auths the client and gets a token. May optionally set a service url. + + The client will get auth errors until the authentication step + occurs. Additionally, if a service_url was not explicitly given in + the clients __init__ method, one will be obtained from the auth + service. + + """ + catalog = self.authenticator.authenticate() + if self.service_url: + possible_service_url = None + else: + if self.endpoint_type == "publicURL": + possible_service_url = catalog.get_public_url() + elif self.endpoint_type == "adminURL": + possible_service_url = catalog.get_management_url() + self.authenticate_with_token(catalog.get_token(), possible_service_url) + + def authenticate_with_token(self, token, service_url=None): + self.auth_token = token + if not self.service_url: + if not service_url: + raise exceptions.ServiceUrlNotGiven() + else: + self.service_url = service_url \ No newline at end of file diff --git a/barbicanclient/exceptions.py b/barbicanclient/exceptions.py new file mode 100644 index 00000000..b6a8c36b --- /dev/null +++ b/barbicanclient/exceptions.py @@ -0,0 +1,164 @@ +class UnsupportedVersion(Exception): + """Indicates that the user is trying to use an unsupported + version of the API""" + pass + + +class CommandError(Exception): + pass + + +class AuthorizationFailure(Exception): + pass + + +class NoUniqueMatch(Exception): + pass + + +class NoTokenLookupException(Exception): + """This form of authentication does not support looking up + endpoints from an existing token.""" + pass + + +class EndpointNotFound(Exception): + """Could not find Service or Region in Service Catalog.""" + pass + + +class AuthUrlNotGiven(EndpointNotFound): + """The auth url was not given.""" + pass + + +class ServiceUrlNotGiven(EndpointNotFound): + """The service url was not given.""" + pass + + +class ResponseFormatError(Exception): + """Could not parse the response format.""" + pass + + +class AmbiguousEndpoints(Exception): + """Found more than one matching endpoint in Service Catalog.""" + def __init__(self, endpoints=None): + self.endpoints = endpoints + + def __str__(self): + return "AmbiguousEndpoints: %s" % repr(self.endpoints) + + +class ClientException(Exception): + """ + The base exception class for all exceptions this library raises. + """ + def __init__(self, code, message=None, details=None, request_id=None): + self.code = code + self.message = message or self.__class__.message + self.details = details + self.request_id = request_id + + def __str__(self): + formatted_string = "%s (HTTP %s)" % (self.message, self.code) + if self.request_id: + formatted_string += " (Request-ID: %s)" % self.request_id + + return formatted_string + + +class BadRequest(ClientException): + """ + HTTP 400 - Bad request: you sent some malformed data. + """ + http_status = 400 + message = "Bad request" + + +class Unauthorized(ClientException): + """ + HTTP 401 - Unauthorized: bad credentials. + """ + http_status = 401 + message = "Unauthorized" + + +class Forbidden(ClientException): + """ + HTTP 403 - Forbidden: your credentials don't give you access to this + resource. + """ + http_status = 403 + message = "Forbidden" + + +class NotFound(ClientException): + """ + HTTP 404 - Not found + """ + http_status = 404 + message = "Not found" + + +class OverLimit(ClientException): + """ + HTTP 413 - Over limit: you're over the API limits for this time period. + """ + http_status = 413 + message = "Over limit" + + +# NotImplemented is a python keyword. +class HTTPNotImplemented(ClientException): + """ + HTTP 501 - Not Implemented: the server does not support this operation. + """ + http_status = 501 + message = "Not Implemented" + + +class UnprocessableEntity(ClientException): + """ + HTTP 422 - Unprocessable Entity: The request cannot be processed. + """ + http_status = 422 + message = "Unprocessable Entity" + + +# In Python 2.4 Exception is old-style and thus doesn't have a __subclasses__() +# so we can do this: +# _code_map = dict((c.http_status, c) +# for c in ClientException.__subclasses__()) +# +# Instead, we have to hardcode it: +_code_map = dict((c.http_status, c) for c in [BadRequest, Unauthorized, + Forbidden, NotFound, OverLimit, + HTTPNotImplemented, + UnprocessableEntity]) + + +def from_response(response, body): + """ + Return an instance of an ClientException or subclass + based on an httplib2 response. + + Usage:: + + resp, body = http.request(...) + if resp.status != 200: + raise exception_from_response(resp, body) + """ + cls = _code_map.get(response.status, ClientException) + if body: + message = "n/a" + details = "n/a" + if hasattr(body, 'keys'): + error = body[body.keys()[0]] + message = error.get('message', None) + details = error.get('details', None) + return cls(code=response.status, message=message, details=details) + else: + request_id = response.get('x-compute-request-id') + return cls(code=response.status, request_id=request_id) \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 00000000..6de1191d --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,407 @@ +import contextlib + +from testtools import TestCase +from barbicanclient.common import auth +from mock import Mock + +from barbicanclient import exceptions + +#Unit tests for the classes and functions in auth.py. + + +def check_url_none(test_case, auth_class): + # url is None, it must throw exception + authObj = auth_class(url=None, type=auth_class, client=None, + username=None, password=None, tenant=None) + try: + authObj.authenticate() + test_case.fail("AuthUrlNotGiven exception expected") + except exceptions.AuthUrlNotGiven: + pass + + +class AuthenticatorTest(TestCase): + def setUp(self): + super(AuthenticatorTest, self).setUp() + self.orig_load = auth.ServiceCatalog._load + self.orig__init = auth.ServiceCatalog.__init__ + + def tearDown(self): + super(AuthenticatorTest, self).tearDown() + auth.ServiceCatalog._load = self.orig_load + auth.ServiceCatalog.__init__ = self.orig__init + + def test_get_authenticator_cls(self): + class_list = (auth.KeyStoneV2Authenticator, + auth.RaxAuthenticator, + auth.Auth1_1, + auth.FakeAuth) + + for c in class_list: + self.assertEqual(c, auth.get_authenticator_cls(c)) + + class_names = {"keystone": auth.KeyStoneV2Authenticator, + "rax": auth.RaxAuthenticator, + "auth1.1": auth.Auth1_1, + "fake": auth.FakeAuth} + + for cn in class_names.keys(): + self.assertEqual(class_names[cn], auth.get_authenticator_cls(cn)) + + cls_or_name = "_unknown_" + self.assertRaises(ValueError, auth.get_authenticator_cls, cls_or_name) + + def test__authenticate(self): + authObj = auth.Authenticator(Mock(), auth.KeyStoneV2Authenticator, + Mock(), Mock(), Mock(), Mock()) + # test response code 200 + resp = Mock() + resp.status = 200 + body = "test_body" + + auth.ServiceCatalog._load = Mock(return_value=1) + authObj.client._time_request = Mock(return_value=(resp, body)) + + sc = authObj._authenticate(Mock(), Mock()) + self.assertEqual(body, sc.catalog) + + # test AmbiguousEndpoints exception + auth.ServiceCatalog.__init__ = \ + Mock(side_effect=exceptions.AmbiguousEndpoints) + self.assertRaises(exceptions.AmbiguousEndpoints, + authObj._authenticate, Mock(), Mock()) + + # test handling KeyError and raising AuthorizationFailure exception + auth.ServiceCatalog.__init__ = Mock(side_effect=KeyError) + self.assertRaises(exceptions.AuthorizationFailure, + authObj._authenticate, Mock(), Mock()) + + # test EndpointNotFound exception + mock = Mock(side_effect=exceptions.EndpointNotFound) + auth.ServiceCatalog.__init__ = mock + self.assertRaises(exceptions.EndpointNotFound, + authObj._authenticate, Mock(), Mock()) + mock.side_effect = None + + # test response code 305 + resp.__getitem__ = Mock(return_value='loc') + resp.status = 305 + body = "test_body" + authObj.client._time_request = Mock(return_value=(resp, body)) + + l = authObj._authenticate(Mock(), Mock()) + self.assertEqual('loc', l) + + # test any response code other than 200 and 305 + resp.status = 404 + exceptions.from_response = Mock(side_effect=ValueError) + self.assertRaises(ValueError, authObj._authenticate, Mock(), Mock()) + + def test_authenticate(self): + authObj = auth.Authenticator(Mock(), auth.KeyStoneV2Authenticator, + Mock(), Mock(), Mock(), Mock()) + self.assertRaises(NotImplementedError, authObj.authenticate) + + +class KeyStoneV2AuthenticatorTest(TestCase): + def test_authenticate(self): + # url is None + check_url_none(self, auth.KeyStoneV2Authenticator) + + # url is not None, so it must not throw exception + url = "test_url" + cls_type = auth.KeyStoneV2Authenticator + authObj = auth.KeyStoneV2Authenticator(url=url, type=cls_type, + client=None, username=None, + password=None, tenant=None) + + def side_effect_func(url): + return url + + mock = Mock() + mock.side_effect = side_effect_func + authObj._v2_auth = mock + r = authObj.authenticate() + self.assertEqual(url, r) + + def test__v2_auth(self): + username = "reddwarf_user" + password = "reddwarf_password" + tenant = "tenant" + cls_type = auth.KeyStoneV2Authenticator + authObj = auth.KeyStoneV2Authenticator(url=None, type=cls_type, + client=None, + username=username, + password=password, + tenant=tenant) + + def side_effect_func(url, body): + return body + + mock = Mock() + mock.side_effect = side_effect_func + authObj._authenticate = mock + body = authObj._v2_auth(Mock()) + self.assertEqual(username, + body['auth']['passwordCredentials']['username']) + self.assertEqual(password, + body['auth']['passwordCredentials']['password']) + self.assertEqual(tenant, body['auth']['tenantName']) + + +class Auth1_1Test(TestCase): + def test_authenticate(self): + # handle when url is None + check_url_none(self, auth.Auth1_1) + + # url is not none + username = "reddwarf_user" + password = "reddwarf_password" + url = "test_url" + authObj = auth.Auth1_1(url=url, + type=auth.Auth1_1, + client=None, username=username, + password=password, tenant=None) + + def side_effect_func(auth_url, body, root_key): + return auth_url, body, root_key + + mock = Mock() + mock.side_effect = side_effect_func + authObj._authenticate = mock + auth_url, body, root_key = authObj.authenticate() + + self.assertEqual(username, body['credentials']['username']) + self.assertEqual(password, body['credentials']['key']) + self.assertEqual(auth_url, url) + self.assertEqual('auth', root_key) + + +class RaxAuthenticatorTest(TestCase): + def test_authenticate(self): + # url is None + check_url_none(self, auth.RaxAuthenticator) + + # url is not None, so it must not throw exception + url = "test_url" + authObj = auth.RaxAuthenticator(url=url, + type=auth.RaxAuthenticator, + client=None, username=None, + password=None, tenant=None) + + def side_effect_func(url): + return url + + mock = Mock() + mock.side_effect = side_effect_func + authObj._rax_auth = mock + r = authObj.authenticate() + self.assertEqual(url, r) + + def test__rax_auth(self): + username = "reddwarf_user" + password = "reddwarf_password" + tenant = "tenant" + authObj = auth.RaxAuthenticator(url=None, + type=auth.RaxAuthenticator, + client=None, username=username, + password=password, tenant=tenant) + + def side_effect_func(url, body): + return body + + mock = Mock() + mock.side_effect = side_effect_func + authObj._authenticate = mock + body = authObj._rax_auth(Mock()) + + v = body['auth']['RAX-KSKEY:apiKeyCredentials']['username'] + self.assertEqual(username, v) + + v = body['auth']['RAX-KSKEY:apiKeyCredentials']['apiKey'] + self.assertEqual(password, v) + + v = body['auth']['RAX-KSKEY:apiKeyCredentials']['tenantName'] + self.assertEqual(tenant, v) + + +class FakeAuthTest(TestCase): + def test_authenticate(self): + tenant = "tenant" + authObj = auth.FakeAuth(url=None, + type=auth.FakeAuth, + client=None, username=None, + password=None, tenant=tenant) + + fc = authObj.authenticate() + public_url = "%s/%s" % ('http://localhost:8779/v1.0', tenant) + self.assertEqual(public_url, fc.get_public_url()) + self.assertEqual(tenant, fc.get_token()) + + +class ServiceCatalogTest(TestCase): + def setUp(self): + super(ServiceCatalogTest, self).setUp() + self.orig_url_for = auth.ServiceCatalog._url_for + self.orig__init__ = auth.ServiceCatalog.__init__ + auth.ServiceCatalog.__init__ = Mock(return_value=None) + self.test_url = "http://localhost:1234/test" + + def tearDown(self): + super(ServiceCatalogTest, self).tearDown() + auth.ServiceCatalog._url_for = self.orig_url_for + auth.ServiceCatalog.__init__ = self.orig__init__ + + def test__load(self): + url = "random_url" + auth.ServiceCatalog._url_for = Mock(return_value=url) + + # when service_url is None + scObj = auth.ServiceCatalog() + scObj.region = None + scObj.service_url = None + scObj._load() + self.assertEqual(url, scObj.public_url) + self.assertEqual(url, scObj.management_url) + + # service url is not None + service_url = "service_url" + scObj = auth.ServiceCatalog() + scObj.region = None + scObj.service_url = service_url + scObj._load() + self.assertEqual(service_url, scObj.public_url) + self.assertEqual(service_url, scObj.management_url) + + def test_get_token(self): + test_id = "test_id" + scObj = auth.ServiceCatalog() + scObj.root_key = "root_key" + scObj.catalog = dict() + scObj.catalog[scObj.root_key] = dict() + scObj.catalog[scObj.root_key]['token'] = dict() + scObj.catalog[scObj.root_key]['token']['id'] = test_id + self.assertEqual(test_id, scObj.get_token()) + + def test_get_management_url(self): + test_mng_url = "test_management_url" + scObj = auth.ServiceCatalog() + scObj.management_url = test_mng_url + self.assertEqual(test_mng_url, scObj.get_management_url()) + + def test_get_public_url(self): + test_public_url = "test_public_url" + scObj = auth.ServiceCatalog() + scObj.public_url = test_public_url + self.assertEqual(test_public_url, scObj.get_public_url()) + + def test__url_for(self): + scObj = auth.ServiceCatalog() + + # case for no endpoint found + self.case_no_endpoint_match(scObj) + + # case for empty service catalog + self.case_endpoing_with_empty_catalog(scObj) + + # more than one matching endpoints + self.case_ambiguous_endpoint(scObj) + + # happy case + self.case_unique_endpoint(scObj) + + # testing if-statements in for-loop to iterate services in catalog + self.case_iterating_services_in_catalog(scObj) + + def case_no_endpoint_match(self, scObj): + # empty endpoint list + scObj.catalog = dict() + scObj.catalog['endpoints'] = list() + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for) + + def side_effect_func_ep(attr): + return "test_attr_value" + + # simulating dict + endpoint = Mock() + mock = Mock() + mock.side_effect = side_effect_func_ep + endpoint.__getitem__ = mock + scObj.catalog['endpoints'].append(endpoint) + + # not-empty list but not matching endpoint + filter_value = "not_matching_value" + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for, + attr="test_attr", filter_value=filter_value) + + filter_value = "test_attr_value" # so that we have an endpoint match + scObj.root_key = "access" + scObj.catalog[scObj.root_key] = dict() + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for, + attr="test_attr", filter_value=filter_value) + + def case_endpoing_with_empty_catalog(self, scObj): + # first, test with empty catalog, this should pass since + # there is already enpoint added + scObj.catalog[scObj.root_key]['serviceCatalog'] = list() + + endpoint = scObj.catalog['endpoints'][0] + endpoint.get = Mock(return_value=self.test_url) + r_url = scObj._url_for(attr="test_attr", + filter_value="test_attr_value") + self.assertEqual(self.test_url, r_url) + + def case_ambiguous_endpoint(self, scObj): + scObj.service_type = "reddwarf" + scObj.service_name = "test_service_name" + + def side_effect_func_service(key): + if key == "type": + return "reddwarf" + elif key == "name": + return "test_service_name" + return None + + mock1 = Mock() + mock1.side_effect = side_effect_func_service + service1 = Mock() + service1.get = mock1 + + endpoint2 = {"test_attr": "test_attr_value"} + service1.__getitem__ = Mock(return_value=[endpoint2]) + scObj.catalog[scObj.root_key]['serviceCatalog'] = [service1] + self.assertRaises(exceptions.AmbiguousEndpoints, scObj._url_for, + attr="test_attr", filter_value="test_attr_value") + + def case_unique_endpoint(self, scObj): + # changing the endpoint2 attribute to pass the filter + service1 = scObj.catalog[scObj.root_key]['serviceCatalog'][0] + endpoint2 = service1[0][0] + endpoint2["test_attr"] = "new value not matching filter" + r_url = scObj._url_for(attr="test_attr", + filter_value="test_attr_value") + self.assertEqual(self.test_url, r_url) + + def case_iterating_services_in_catalog(self, scObj): + service1 = scObj.catalog[scObj.root_key]['serviceCatalog'][0] + + scObj.catalog = dict() + scObj.root_key = "access" + scObj.catalog[scObj.root_key] = dict() + scObj.service_type = "no_match" + + scObj.catalog[scObj.root_key]['serviceCatalog'] = [service1] + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for) + + scObj.service_type = "database" + scObj.service_name = "no_match" + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for) + + # no endpoints and no 'serviceCatalog' in catalog => raise exception + scObj = auth.ServiceCatalog() + scObj.catalog = dict() + scObj.root_key = "access" + scObj.catalog[scObj.root_key] = dict() + scObj.catalog[scObj.root_key]['serviceCatalog'] = [] + self.assertRaises(exceptions.EndpointNotFound, scObj._url_for, + attr="test_attr", filter_value="test_attr_value") \ No newline at end of file