From 8c60f34688360049c8cc4d6920d9262343001b09 Mon Sep 17 00:00:00 2001 From: Brant Knudson Date: Sun, 1 Jun 2014 10:45:50 -0500 Subject: [PATCH] Refactor auth_token token cache members to class The token cache members are moved from AuthToken to their own class. Change-Id: Ibf00d39435fa7a6d9a92a9bdfacc3f1b07f890ef --- keystoneclient/middleware/auth_token.py | 398 ++++++++++-------- .../tests/test_auth_token_middleware.py | 81 ++-- 2 files changed, 263 insertions(+), 216 deletions(-) diff --git a/keystoneclient/middleware/auth_token.py b/keystoneclient/middleware/auth_token.py index 5fbbe7f6d..d6fe3dd8b 100644 --- a/keystoneclient/middleware/auth_token.py +++ b/keystoneclient/middleware/auth_token.py @@ -545,20 +545,18 @@ class AuthProtocol(object): self.admin_password = self._conf_get('admin_password') self.admin_tenant_name = self._conf_get('admin_tenant_name') - # Token caching - self._cache_pool = None - self._cache_initialized = False - # memcache value treatment, ENCRYPT or MAC - self._memcache_security_strategy = ( + memcache_security_strategy = ( self._conf_get('memcache_security_strategy')) - if self._memcache_security_strategy is not None: - self._memcache_security_strategy = ( - self._memcache_security_strategy.upper()) - self._memcache_secret_key = ( - self._conf_get('memcache_secret_key')) - self._assert_valid_memcache_protection_config() - # By default the token will be cached for 5 minutes - self.token_cache_time = int(self._conf_get('token_cache_time')) + + self._token_cache = TokenCache( + self.LOG, + cache_time=int(self._conf_get('token_cache_time')), + hash_algorithms=self._conf_get('hash_algorithms'), + env_cache_name=self._conf_get('cache'), + memcached_servers=self._conf_get('memcached_servers'), + memcache_security_strategy=memcache_security_strategy, + memcache_secret_key=self._conf_get('memcache_secret_key')) + self._token_revocation_list = None self._token_revocation_list_fetched_time = None self.token_revocation_list_cache_timeout = datetime.timedelta( @@ -576,22 +574,6 @@ class AuthProtocol(object): self.check_revocations_for_cached = self._conf_get( 'check_revocations_for_cached') - def _assert_valid_memcache_protection_config(self): - if self._memcache_security_strategy: - if self._memcache_security_strategy not in ('MAC', 'ENCRYPT'): - raise ConfigurationError('memcache_security_strategy must be ' - 'ENCRYPT or MAC') - if not self._memcache_secret_key: - raise ConfigurationError('memcache_secret_key must be defined ' - 'when a memcache_security_strategy ' - 'is defined') - - def _init_cache(self, env): - self._cache_pool = CachePool( - env.get(self._conf_get('cache')), - self._conf_get('memcached_servers')) - self._cache_initialized = True - def _conf_get(self, name): # try config from paste-deploy first if name in self.conf: @@ -665,9 +647,7 @@ class AuthProtocol(object): """ self.LOG.debug('Authenticating user token') - # initialize memcache if we haven't done so - if not self._cache_initialized: - self._init_cache(env) + self._token_cache.initialize(env) try: self._remove_auth_headers(env) @@ -907,7 +887,7 @@ class AuthProtocol(object): token_id = None try: - token_ids, cached = self._check_user_token_cached(user_token) + token_ids, cached = self._token_cache.get(user_token) token_id = token_ids[0] if cached: data = cached @@ -933,7 +913,7 @@ class AuthProtocol(object): data = self.verify_uuid_token(user_token, retry) expires = confirm_token_not_expired(data) self._confirm_token_bind(data, env) - self._cache_put(token_id, data, expires) + self._token_cache.store(token_id, data, expires) return data except NetworkError: self.LOG.debug('Token validation failure.', exc_info=True) @@ -942,43 +922,10 @@ class AuthProtocol(object): except Exception: self.LOG.debug('Token validation failure.', exc_info=True) if token_id: - self._cache_store_invalid(token_id) + self._token_cache.store_invalid(token_id) self.LOG.warn('Authorization failed for token') raise InvalidUserToken('Token authorization failed') - def _check_user_token_cached(self, user_token): - """Check if the token is cached already. - - Returns a tuple. The first element is a list of token IDs, where the - first one is the preferred hash. - - The second element is the token data from the cache if the token was - cached, otherwise ``None``. - - :raises InvalidUserToken: if the token is invalid - - """ - - if cms.is_asn1_token(user_token): - # user_token is a PKI token that's not hashed. - - algos = self._conf_get('hash_algorithms') - token_hashes = list(cms.cms_hash_token(user_token, mode=algo) - for algo in algos) - - for token_hash in token_hashes: - cached = self._cache_get(token_hash) - if cached: - return (token_hashes, cached) - - # The token wasn't found using any hash algorithm. - return (token_hashes, None) - - # user_token is either a UUID token or a hashed PKI token. - token_id = user_token - cached = self._cache_get(token_id) - return ([token_id], cached) - def _build_user_headers(self, token_info): """Convert token object into headers. @@ -1057,102 +1004,6 @@ class AuthProtocol(object): env_key = self._header_to_env_var(key) return env.get(env_key, default) - def _cache_get(self, token_id): - """Return token information from cache. - - If token is invalid raise InvalidUserToken - return token only if fresh (not expired). - """ - - if token_id: - if self._memcache_security_strategy is None: - key = CACHE_KEY_TEMPLATE % token_id - with self._cache_pool.reserve() as cache: - serialized = cache.get(key) - else: - secret_key = self._memcache_secret_key - if isinstance(secret_key, six.string_types): - secret_key = secret_key.encode('utf-8') - security_strategy = self._memcache_security_strategy - if isinstance(security_strategy, six.string_types): - security_strategy = security_strategy.encode('utf-8') - keys = memcache_crypt.derive_keys( - token_id, - secret_key, - security_strategy) - cache_key = CACHE_KEY_TEMPLATE % ( - memcache_crypt.get_cache_key(keys)) - with self._cache_pool.reserve() as cache: - raw_cached = cache.get(cache_key) - try: - # unprotect_data will return None if raw_cached is None - serialized = memcache_crypt.unprotect_data(keys, - raw_cached) - except Exception: - msg = 'Failed to decrypt/verify cache data' - self.LOG.exception(msg) - # this should have the same effect as data not - # found in cache - serialized = None - - if serialized is None: - return None - - # Note that 'invalid' and (data, expires) are the only - # valid types of serialized cache entries, so there is not - # a collision with jsonutils.loads(serialized) == None. - if not isinstance(serialized, six.string_types): - serialized = serialized.decode('utf-8') - cached = jsonutils.loads(serialized) - if cached == 'invalid': - self.LOG.debug('Cached Token is marked unauthorized') - raise InvalidUserToken('Token authorization failed') - - data, expires = cached - - try: - expires = timeutils.parse_isotime(expires) - except ValueError: - # Gracefully handle upgrade of expiration times from *nix - # timestamps to ISO 8601 formatted dates by ignoring old cached - # values. - return - - expires = timeutils.normalize_time(expires) - utcnow = timeutils.utcnow() - if utcnow < expires: - self.LOG.debug('Returning cached token') - return data - else: - self.LOG.debug('Cached Token seems expired') - - def _cache_store(self, token_id, data): - """Store value into memcache. - - data may be the string 'invalid' or a tuple like (data, expires) - - """ - serialized_data = jsonutils.dumps(data) - if isinstance(serialized_data, six.text_type): - serialized_data = serialized_data.encode('utf-8') - if self._memcache_security_strategy is None: - cache_key = CACHE_KEY_TEMPLATE % token_id - data_to_store = serialized_data - else: - secret_key = self._memcache_secret_key - if isinstance(secret_key, six.string_types): - secret_key = secret_key.encode('utf-8') - security_strategy = self._memcache_security_strategy - if isinstance(security_strategy, six.string_types): - security_strategy = security_strategy.encode('utf-8') - keys = memcache_crypt.derive_keys( - token_id, secret_key, security_strategy) - cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys) - data_to_store = memcache_crypt.protect_data(keys, serialized_data) - - with self._cache_pool.reserve() as cache: - cache.set(cache_key, data_to_store, time=self.token_cache_time) - def _invalid_user_token(self, msg=False): # NOTE(jamielennox): use False as the default so that None is valid if msg is False: @@ -1224,21 +1075,6 @@ class AuthProtocol(object): 'identifier': identifier}) self._invalid_user_token() - def _cache_put(self, token_id, data, expires): - """Put token data into the cache. - - Stores the parsed expire date in cache allowing - quick check of token freshness on retrieval. - - """ - self.LOG.debug('Storing token in cache') - self._cache_store(token_id, (data, expires)) - - def _cache_store_invalid(self, token_id): - """Store invalid token in cache.""" - self.LOG.debug('Marking token as unauthorized in cache') - self._cache_store(token_id, 'invalid') - def verify_uuid_token(self, user_token, retry=True): """Authenticate user token with keystone. @@ -1507,6 +1343,210 @@ class CachePool(list): self.append(c) +class TokenCache(object): + """Encapsulates the auth_token token cache functionality. + + auth_token caches tokens that it's seen so that when a token is re-used the + middleware doesn't have to do a more expensive operation (like going to the + identity server) to validate the token. + + initialize() must be called before calling the other methods. + + Store a valid token in the cache using store(); mark a token as invalid in + the cache using store_invalid(). + + Check if a token is in the cache and retrieve it using get(). + + """ + + _INVALID_INDICATOR = 'invalid' + + def __init__(self, log, cache_time=None, hash_algorithms=None, + env_cache_name=None, memcached_servers=None, + memcache_security_strategy=None, memcache_secret_key=None): + self.LOG = log + self._cache_time = cache_time + self._hash_algorithms = hash_algorithms + self._env_cache_name = env_cache_name + self._memcached_servers = memcached_servers + + # memcache value treatment, ENCRYPT or MAC + self._memcache_security_strategy = memcache_security_strategy + if self._memcache_security_strategy is not None: + self._memcache_security_strategy = ( + self._memcache_security_strategy.upper()) + self._memcache_secret_key = memcache_secret_key + + self._cache_pool = None + self._initialized = False + + self._assert_valid_memcache_protection_config() + + def initialize(self, env): + if self._initialized: + return + + self._cache_pool = CachePool(env.get(self._env_cache_name), + self._memcached_servers) + self._initialized = True + + def get(self, user_token): + """Check if the token is cached already. + + Returns a tuple. The first element is a list of token IDs, where the + first one is the preferred hash. + + The second element is the token data from the cache if the token was + cached, otherwise ``None``. + + :raises InvalidUserToken: if the token is invalid + + """ + + if cms.is_asn1_token(user_token): + # user_token is a PKI token that's not hashed. + + token_hashes = list(cms.cms_hash_token(user_token, mode=algo) + for algo in self._hash_algorithms) + + for token_hash in token_hashes: + cached = self._cache_get(token_hash) + if cached: + return (token_hashes, cached) + + # The token wasn't found using any hash algorithm. + return (token_hashes, None) + + # user_token is either a UUID token or a hashed PKI token. + token_id = user_token + cached = self._cache_get(token_id) + return ([token_id], cached) + + def store(self, token_id, data, expires): + """Put token data into the cache. + + Stores the parsed expire date in cache allowing + quick check of token freshness on retrieval. + + """ + self.LOG.debug('Storing token in cache') + self._cache_store(token_id, (data, expires)) + + def store_invalid(self, token_id): + """Store invalid token in cache.""" + self.LOG.debug('Marking token as unauthorized in cache') + self._cache_store(token_id, self._INVALID_INDICATOR) + + def _assert_valid_memcache_protection_config(self): + if self._memcache_security_strategy: + if self._memcache_security_strategy not in ('MAC', 'ENCRYPT'): + raise ConfigurationError('memcache_security_strategy must be ' + 'ENCRYPT or MAC') + if not self._memcache_secret_key: + raise ConfigurationError('memcache_secret_key must be defined ' + 'when a memcache_security_strategy ' + 'is defined') + + def _cache_get(self, token_id): + """Return token information from cache. + + If token is invalid raise InvalidUserToken + return token only if fresh (not expired). + """ + + if not token_id: + # Nothing to do + return + + if self._memcache_security_strategy is None: + key = CACHE_KEY_TEMPLATE % token_id + with self._cache_pool.reserve() as cache: + serialized = cache.get(key) + else: + secret_key = self._memcache_secret_key + if isinstance(secret_key, six.string_types): + secret_key = secret_key.encode('utf-8') + security_strategy = self._memcache_security_strategy + if isinstance(security_strategy, six.string_types): + security_strategy = security_strategy.encode('utf-8') + keys = memcache_crypt.derive_keys( + token_id, + secret_key, + security_strategy) + cache_key = CACHE_KEY_TEMPLATE % ( + memcache_crypt.get_cache_key(keys)) + with self._cache_pool.reserve() as cache: + raw_cached = cache.get(cache_key) + try: + # unprotect_data will return None if raw_cached is None + serialized = memcache_crypt.unprotect_data(keys, + raw_cached) + except Exception: + msg = 'Failed to decrypt/verify cache data' + self.LOG.exception(msg) + # this should have the same effect as data not + # found in cache + serialized = None + + if serialized is None: + return None + + # Note that _INVALID_INDICATOR and (data, expires) are the only + # valid types of serialized cache entries, so there is not + # a collision with jsonutils.loads(serialized) == None. + if not isinstance(serialized, six.string_types): + serialized = serialized.decode('utf-8') + cached = jsonutils.loads(serialized) + if cached == self._INVALID_INDICATOR: + self.LOG.debug('Cached Token is marked unauthorized') + raise InvalidUserToken('Token authorization failed') + + data, expires = cached + + try: + expires = timeutils.parse_isotime(expires) + except ValueError: + # Gracefully handle upgrade of expiration times from *nix + # timestamps to ISO 8601 formatted dates by ignoring old cached + # values. + return + + expires = timeutils.normalize_time(expires) + utcnow = timeutils.utcnow() + if utcnow < expires: + self.LOG.debug('Returning cached token') + return data + else: + self.LOG.debug('Cached Token seems expired') + + def _cache_store(self, token_id, data): + """Store value into memcache. + + data may be _INVALID_INDICATOR or a tuple like (data, expires) + + """ + serialized_data = jsonutils.dumps(data) + if isinstance(serialized_data, six.text_type): + serialized_data = serialized_data.encode('utf-8') + if self._memcache_security_strategy is None: + cache_key = CACHE_KEY_TEMPLATE % token_id + data_to_store = serialized_data + else: + secret_key = self._memcache_secret_key + if isinstance(secret_key, six.string_types): + secret_key = secret_key.encode('utf-8') + security_strategy = self._memcache_security_strategy + if isinstance(security_strategy, six.string_types): + security_strategy = security_strategy.encode('utf-8') + keys = memcache_crypt.derive_keys( + token_id, secret_key, security_strategy) + cache_key = CACHE_KEY_TEMPLATE % memcache_crypt.get_cache_key(keys) + data_to_store = memcache_crypt.protect_data(keys, serialized_data) + + with self._cache_pool.reserve() as cache: + cache.set(cache_key, data_to_store, time=self._cache_time) + + def filter_factory(global_conf, **local_conf): """Returns a WSGI filter app for use with paste.deploy.""" conf = global_conf.copy() diff --git a/keystoneclient/tests/test_auth_token_middleware.py b/keystoneclient/tests/test_auth_token_middleware.py index 0e7731c22..a38bab68a 100644 --- a/keystoneclient/tests/test_auth_token_middleware.py +++ b/keystoneclient/tests/test_auth_token_middleware.py @@ -378,8 +378,8 @@ class CachePoolTest(BaseAuthTokenMiddlewareTest): 'cache': 'swift.cache' } self.set_middleware(conf=conf) - self.middleware._init_cache(env) - with self.middleware._cache_pool.reserve() as cache: + self.middleware._token_cache.initialize(env) + with self.middleware._token_cache._cache_pool.reserve() as cache: self.assertEqual(cache, 'CACHE_TEST') def test_not_use_cache_from_env(self): @@ -388,37 +388,40 @@ class CachePoolTest(BaseAuthTokenMiddlewareTest): """ self.set_middleware() env = {'swift.cache': 'CACHE_TEST'} - self.middleware._init_cache(env) - with self.middleware._cache_pool.reserve() as cache: + self.middleware._token_cache.initialize(env) + with self.middleware._token_cache._cache_pool.reserve() as cache: self.assertNotEqual(cache, 'CACHE_TEST') def test_multiple_context_managers_share_single_client(self): self.set_middleware() + token_cache = self.middleware._token_cache env = {} - self.middleware._init_cache(env) + token_cache.initialize(env) caches = [] - with self.middleware._cache_pool.reserve() as cache: + with token_cache._cache_pool.reserve() as cache: caches.append(cache) - with self.middleware._cache_pool.reserve() as cache: + with token_cache._cache_pool.reserve() as cache: caches.append(cache) self.assertIs(caches[0], caches[1]) - self.assertEqual(set(caches), set(self.middleware._cache_pool)) + self.assertEqual(set(caches), set(token_cache._cache_pool)) def test_nested_context_managers_create_multiple_clients(self): self.set_middleware() env = {} - self.middleware._init_cache(env) + self.middleware._token_cache.initialize(env) + token_cache = self.middleware._token_cache - with self.middleware._cache_pool.reserve() as outer_cache: - with self.middleware._cache_pool.reserve() as inner_cache: + with token_cache._cache_pool.reserve() as outer_cache: + with token_cache._cache_pool.reserve() as inner_cache: self.assertNotEqual(outer_cache, inner_cache) self.assertEqual( - set([inner_cache, outer_cache]), set(self.middleware._cache_pool)) + set([inner_cache, outer_cache]), + set(token_cache._cache_pool)) class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest, @@ -470,9 +473,10 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest, some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) expires = timeutils.strtime(some_time_later) data = ('this_data', expires) - self.middleware._init_cache({}) - self.middleware._cache_store(token, data) - self.assertEqual(self.middleware._cache_get(token), data[0]) + token_cache = self.middleware._token_cache + token_cache.initialize({}) + token_cache._cache_store(token, data) + self.assertEqual(token_cache._cache_get(token), data[0]) @testtools.skipUnless(memcached_available(), 'memcached not available') def test_sign_cache_data(self): @@ -487,9 +491,10 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest, some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) expires = timeutils.strtime(some_time_later) data = ('this_data', expires) - self.middleware._init_cache({}) - self.middleware._cache_store(token, data) - self.assertEqual(self.middleware._cache_get(token), data[0]) + token_cache = self.middleware._token_cache + token_cache.initialize({}) + token_cache._cache_store(token, data) + self.assertEqual(token_cache._cache_get(token), data[0]) @testtools.skipUnless(memcached_available(), 'memcached not available') def test_no_memcache_protection(self): @@ -503,9 +508,10 @@ class GeneralAuthTokenMiddlewareTest(BaseAuthTokenMiddlewareTest, some_time_later = timeutils.utcnow() + datetime.timedelta(hours=4) expires = timeutils.strtime(some_time_later) data = ('this_data', expires) - self.middleware._init_cache({}) - self.middleware._cache_store(token, data) - self.assertEqual(self.middleware._cache_get(token), data[0]) + token_cache = self.middleware._token_cache + token_cache.initialize({}) + token_cache._cache_store(token, data) + self.assertEqual(token_cache._cache_get(token), data[0]) def test_assert_valid_memcache_protection_config(self): # test missing memcache_secret_key @@ -942,7 +948,7 @@ class CommonAuthTokenMiddlewareTest(object): def _get_cached_token(self, token, mode='md5'): token_id = cms.cms_hash_token(token, mode=mode) - return self.middleware._cache_get(token_id) + return self.middleware._token_cache._cache_get(token_id) def test_memcache(self): # NOTE(jamielennox): it appears that httpretty can mess with the @@ -1866,11 +1872,11 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest): token = 'mytoken' data = 'this_data' self.set_middleware() - self.middleware._init_cache({}) + self.middleware._token_cache.initialize({}) some_time_later = timeutils.strtime(at=(self.now + self.delta)) expires = some_time_later - self.middleware._cache_put(token, data, expires) - self.assertEqual(self.middleware._cache_get(token), data) + self.middleware._token_cache.store(token, data, expires) + self.assertEqual(self.middleware._token_cache._cache_get(token), data) def test_cached_token_not_expired_with_old_style_nix_timestamp(self): """Ensure we cannot retrieve a token from the cache. @@ -1882,44 +1888,45 @@ class TokenExpirationTest(BaseAuthTokenMiddlewareTest): token = 'mytoken' data = 'this_data' self.set_middleware() - self.middleware._init_cache({}) + token_cache = self.middleware._token_cache + token_cache.initialize({}) some_time_later = self.now + self.delta # Store a unix timestamp in the cache. expires = calendar.timegm(some_time_later.timetuple()) - self.middleware._cache_put(token, data, expires) - self.assertIsNone(self.middleware._cache_get(token)) + token_cache.store(token, data, expires) + self.assertIsNone(token_cache._cache_get(token)) def test_cached_token_expired(self): token = 'mytoken' data = 'this_data' self.set_middleware() - self.middleware._init_cache({}) + self.middleware._token_cache.initialize({}) some_time_earlier = timeutils.strtime(at=(self.now - self.delta)) expires = some_time_earlier - self.middleware._cache_put(token, data, expires) - self.assertIsNone(self.middleware._cache_get(token)) + self.middleware._token_cache.store(token, data, expires) + self.assertIsNone(self.middleware._token_cache._cache_get(token)) def test_cached_token_with_timezone_offset_not_expired(self): token = 'mytoken' data = 'this_data' self.set_middleware() - self.middleware._init_cache({}) + self.middleware._token_cache.initialize({}) timezone_offset = datetime.timedelta(hours=2) some_time_later = self.now - timezone_offset + self.delta expires = timeutils.strtime(some_time_later) + '-02:00' - self.middleware._cache_put(token, data, expires) - self.assertEqual(self.middleware._cache_get(token), data) + self.middleware._token_cache.store(token, data, expires) + self.assertEqual(self.middleware._token_cache._cache_get(token), data) def test_cached_token_with_timezone_offset_expired(self): token = 'mytoken' data = 'this_data' self.set_middleware() - self.middleware._init_cache({}) + self.middleware._token_cache.initialize({}) timezone_offset = datetime.timedelta(hours=2) some_time_earlier = self.now - timezone_offset - self.delta expires = timeutils.strtime(some_time_earlier) + '-02:00' - self.middleware._cache_put(token, data, expires) - self.assertIsNone(self.middleware._cache_get(token)) + self.middleware._token_cache.store(token, data, expires) + self.assertIsNone(self.middleware._token_cache._cache_get(token)) class CatalogConversionTests(BaseAuthTokenMiddlewareTest):