Merge "Standardize AccessInfo token setting"
This commit is contained in:
@@ -33,35 +33,43 @@ class AccessInfo(dict):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def factory(cls, resp=None, body=None, region_name=None, **kwargs):
|
def factory(cls, resp=None, body=None, region_name=None, auth_token=None,
|
||||||
|
**kwargs):
|
||||||
"""Create AccessInfo object given a successful auth response & body
|
"""Create AccessInfo object given a successful auth response & body
|
||||||
or a user-provided dict.
|
or a user-provided dict.
|
||||||
"""
|
"""
|
||||||
# FIXME(jamielennox): Passing region_name is deprecated. Provide an
|
# FIXME(jamielennox): Passing region_name is deprecated. Provide an
|
||||||
# appropriate warning.
|
# appropriate warning.
|
||||||
|
auth_ref = None
|
||||||
|
|
||||||
if body is not None or len(kwargs):
|
if body is not None or len(kwargs):
|
||||||
if AccessInfoV3.is_valid(body, **kwargs):
|
if AccessInfoV3.is_valid(body, **kwargs):
|
||||||
token = None
|
if resp and not auth_token:
|
||||||
if resp:
|
auth_token = resp.headers['X-Subject-Token']
|
||||||
token = resp.headers['X-Subject-Token']
|
# NOTE(jamielennox): these return AccessInfo because they
|
||||||
|
# already have auth_token installed on them.
|
||||||
if body:
|
if body:
|
||||||
if region_name:
|
if region_name:
|
||||||
body['token']['region_name'] = region_name
|
body['token']['region_name'] = region_name
|
||||||
return AccessInfoV3(token, **body['token'])
|
return AccessInfoV3(auth_token, **body['token'])
|
||||||
else:
|
else:
|
||||||
return AccessInfoV3(token, **kwargs)
|
return AccessInfoV3(auth_token, **kwargs)
|
||||||
elif AccessInfoV2.is_valid(body, **kwargs):
|
elif AccessInfoV2.is_valid(body, **kwargs):
|
||||||
if body:
|
if body:
|
||||||
if region_name:
|
if region_name:
|
||||||
body['access']['region_name'] = region_name
|
body['access']['region_name'] = region_name
|
||||||
return AccessInfoV2(**body['access'])
|
auth_ref = AccessInfoV2(**body['access'])
|
||||||
else:
|
else:
|
||||||
return AccessInfoV2(**kwargs)
|
auth_ref = AccessInfoV2(**kwargs)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError('Unrecognized auth response')
|
raise NotImplementedError('Unrecognized auth response')
|
||||||
else:
|
else:
|
||||||
return AccessInfoV2(**kwargs)
|
auth_ref = AccessInfoV2(**kwargs)
|
||||||
|
|
||||||
|
if auth_token:
|
||||||
|
auth_ref.auth_token = auth_token
|
||||||
|
|
||||||
|
return auth_ref
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super(AccessInfo, self).__init__(*args, **kwargs)
|
super(AccessInfo, self).__init__(*args, **kwargs)
|
||||||
@@ -110,7 +118,18 @@ class AccessInfo(dict):
|
|||||||
|
|
||||||
:returns: str
|
:returns: str
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
return self['auth_token']
|
||||||
|
|
||||||
|
@auth_token.setter
|
||||||
|
def auth_token(self, value):
|
||||||
|
self['auth_token'] = value
|
||||||
|
|
||||||
|
@auth_token.deleter
|
||||||
|
def auth_token(self):
|
||||||
|
try:
|
||||||
|
del self['auth_token']
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expires(self):
|
def expires(self):
|
||||||
@@ -395,9 +414,12 @@ class AccessInfoV2(AccessInfo):
|
|||||||
def has_service_catalog(self):
|
def has_service_catalog(self):
|
||||||
return 'serviceCatalog' in self
|
return 'serviceCatalog' in self
|
||||||
|
|
||||||
@property
|
@AccessInfo.auth_token.getter
|
||||||
def auth_token(self):
|
def auth_token(self):
|
||||||
return self['token']['id']
|
try:
|
||||||
|
return super(AccessInfoV2, self).auth_token
|
||||||
|
except KeyError:
|
||||||
|
return self['token']['id']
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expires(self):
|
def expires(self):
|
||||||
@@ -568,7 +590,7 @@ class AccessInfoV3(AccessInfo):
|
|||||||
token=token,
|
token=token,
|
||||||
region_name=self._region_name)
|
region_name=self._region_name)
|
||||||
if token:
|
if token:
|
||||||
self.update(auth_token=token)
|
self.auth_token = token
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_valid(cls, body, **kwargs):
|
def is_valid(cls, body, **kwargs):
|
||||||
@@ -582,10 +604,6 @@ class AccessInfoV3(AccessInfo):
|
|||||||
def has_service_catalog(self):
|
def has_service_catalog(self):
|
||||||
return 'catalog' in self
|
return 'catalog' in self
|
||||||
|
|
||||||
@property
|
|
||||||
def auth_token(self):
|
|
||||||
return self['auth_token']
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def expires(self):
|
def expires(self):
|
||||||
return timeutils.parse_isotime(self['expires_at'])
|
return timeutils.parse_isotime(self['expires_at'])
|
||||||
|
@@ -165,6 +165,37 @@ class AccessInfoTest(utils.TestCase, testresources.ResourcedTestCase):
|
|||||||
|
|
||||||
self.assertEqual(trust_id, token['access']['trust']['id'])
|
self.assertEqual(trust_id, token['access']['trust']['id'])
|
||||||
|
|
||||||
|
def test_override_auth_token(self):
|
||||||
|
token = fixture.V2Token()
|
||||||
|
token.set_scope()
|
||||||
|
token.add_role()
|
||||||
|
|
||||||
|
new_auth_token = uuid.uuid4().hex
|
||||||
|
|
||||||
|
auth_ref = access.AccessInfo.factory(body=token)
|
||||||
|
|
||||||
|
self.assertEqual(token.token_id, auth_ref.auth_token)
|
||||||
|
|
||||||
|
auth_ref.auth_token = new_auth_token
|
||||||
|
self.assertEqual(new_auth_token, auth_ref.auth_token)
|
||||||
|
|
||||||
|
del auth_ref.auth_token
|
||||||
|
self.assertEqual(token.token_id, auth_ref.auth_token)
|
||||||
|
|
||||||
|
def test_override_auth_token_in_factory(self):
|
||||||
|
token = fixture.V2Token()
|
||||||
|
token.set_scope()
|
||||||
|
token.add_role()
|
||||||
|
|
||||||
|
new_auth_token = uuid.uuid4().hex
|
||||||
|
|
||||||
|
auth_ref = access.AccessInfo.factory(body=token,
|
||||||
|
auth_token=new_auth_token)
|
||||||
|
|
||||||
|
self.assertEqual(new_auth_token, auth_ref.auth_token)
|
||||||
|
del auth_ref.auth_token
|
||||||
|
self.assertEqual(token.token_id, auth_ref.auth_token)
|
||||||
|
|
||||||
|
|
||||||
def load_tests(loader, tests, pattern):
|
def load_tests(loader, tests, pattern):
|
||||||
return testresources.OptimisingTestSuite(tests)
|
return testresources.OptimisingTestSuite(tests)
|
||||||
|
@@ -172,3 +172,12 @@ class AccessInfoTest(utils.TestCase):
|
|||||||
self.assertEqual(consumer_id, auth_ref['OS-OAUTH1']['consumer_id'])
|
self.assertEqual(consumer_id, auth_ref['OS-OAUTH1']['consumer_id'])
|
||||||
self.assertEqual(access_token_id,
|
self.assertEqual(access_token_id,
|
||||||
auth_ref['OS-OAUTH1']['access_token_id'])
|
auth_ref['OS-OAUTH1']['access_token_id'])
|
||||||
|
|
||||||
|
def test_override_auth_token(self):
|
||||||
|
token = fixture.V3Token()
|
||||||
|
token.set_project_scope()
|
||||||
|
|
||||||
|
new_auth_token = uuid.uuid4().hex
|
||||||
|
auth_ref = access.AccessInfo.factory(body=token,
|
||||||
|
auth_token=new_auth_token)
|
||||||
|
self.assertEqual(new_auth_token, auth_ref.auth_token)
|
||||||
|
Reference in New Issue
Block a user