diff --git a/swift/common/memcached.py b/swift/common/memcached.py index 315ea9e205..9110f1682c 100644 --- a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -119,13 +119,6 @@ def set_msg(key, flags, timeout, value): ]) + (b'\r\n' + value + b'\r\n') -# get the prefix of a user provided memcache key by removing the content after -# the last '/', all current usages within swift are using prefix, such as -# "shard-updating-v2", "nvratelimit" and etc. -def get_key_prefix(key): - return key.rsplit('/', 1)[0] - - class MemcacheConnectionError(Exception): pass @@ -183,6 +176,29 @@ class MemcacheConnPool(Pool): raise +class MemcacheCommand(object): + """ + Helper class that encapsulates common parameters of a command. + + :param method: the name of the MemcacheRing method that was called. + :param key: the memcached key. + """ + __slots__ = ('method', 'key', 'command', 'hash_key') + + def __init__(self, method, key): + self.method = method + self.key = key + self.command = method.encode() + self.hash_key = md5hash(key) + + @property + def key_prefix(self): + # get the prefix of a user provided memcache key by removing the + # content after the last '/', all current usages within swift are using + # prefix, such as "shard-updating-v2", "nvratelimit" and etc. + return self.key.rsplit('/', 1)[0] + + class MemcacheRing(object): """ Simple, consistent-hashed memcache client. @@ -225,8 +241,21 @@ class MemcacheRing(object): def memcache_servers(self): return list(self._client_cache.keys()) - def _exception_occurred(self, server, e, key_prefix, method, - conn_start_time, action='talking', sock=None, + """ + Handles exceptions. + + :param server: a server. + :param e: an exception. + :param cmd: an instance of MemcacheCommand. + :param conn_start_time: the time at which the failed operation started. + :param action: a verb describing the operation. + :param sock: an optional socket that needs to be closed by this method. + :param fp: an optional file pointer that needs to be closed by this method. + :param got_connection: if ``True``, the server's connection will be reset + in the cached connection pool. + """ + def _exception_occurred(self, server, e, cmd, conn_start_time, + action='talking', sock=None, fp=None, got_connection=True): if isinstance(e, Timeout): self.logger.error( @@ -234,28 +263,30 @@ class MemcacheRing(object): ": with key_prefix %(key_prefix)s, method %(method)s, " "config_timeout %(config_timeout)s, time_spent %(time_spent)s", {'action': action, 'server': server, - 'key_prefix': key_prefix, 'method': method, + 'key_prefix': cmd.key_prefix, 'method': cmd.method, 'config_timeout': e.seconds, 'time_spent': tm.time() - conn_start_time}) self.logger.timing_since( - 'memcached.' + method + '.timeout.timing', conn_start_time) + 'memcached.' + cmd.method + '.timeout.timing', + conn_start_time) elif isinstance(e, (socket.error, MemcacheConnectionError)): self.logger.error( "Error %(action)s to memcached: %(server)s: " "with key_prefix %(key_prefix)s, method %(method)s, " "time_spent %(time_spent)s, %(err)s", {'action': action, 'server': server, - 'key_prefix': key_prefix, 'method': method, + 'key_prefix': cmd.key_prefix, 'method': cmd.method, 'time_spent': tm.time() - conn_start_time, 'err': e}) self.logger.timing_since( - 'memcached.' + method + '.conn_err.timing', conn_start_time) + 'memcached.' + cmd.method + '.conn_err.timing', + conn_start_time) else: self.logger.exception("Error %(action)s to memcached: %(server)s" ": with key_prefix %(key_prefix)s", {'action': action, 'server': server, - 'key_prefix': key_prefix}) + 'key_prefix': cmd.key_prefix}) self.logger.timing_since( - 'memcached.' + method + '.errors.timing', conn_start_time) + 'memcached.' + cmd.method + '.errors.timing', conn_start_time) try: if fp: @@ -286,18 +317,15 @@ class MemcacheRing(object): self._error_limited[server] = now + self._error_limit_duration self.logger.error('Error limiting server %s', server) - def _get_conns(self, method, key_prefix, hash_key): + def _get_conns(self, cmd): """ Retrieves a server conn from the pool, or connects a new one. Chooses the server based on a consistent hash of "key". - :param method: the name of memcache method. - :param key_prefix: the prefix of user provided key. - :param hash_key: the consistent hash of user key, or server key for - set_multi and get_multi. + :param cmd: an instance of MemcacheCommand. :return: generator to serve memcached connection """ - pos = bisect(self._sorted, hash_key) + pos = bisect(self._sorted, cmd.hash_key) served = [] any_yielded = False while len(served) < self._tries: @@ -316,16 +344,15 @@ class MemcacheRing(object): any_yielded = True yield server, fp, sock except MemcachePoolTimeout as e: - self._exception_occurred( - server, e, key_prefix, method, pool_start_time, - action='getting a connection', got_connection=False) + self._exception_occurred(server, e, cmd, pool_start_time, + action='getting a connection', + got_connection=False) except (Exception, Timeout) as e: # Typically a Timeout exception caught here is the one raised # by the create() method of this server's MemcacheConnPool # object. - self._exception_occurred( - server, e, key_prefix, method, pool_start_time, - action='connecting', sock=sock) + self._exception_occurred(server, e, cmd, pool_start_time, + action='connecting', sock=sock) if not any_yielded: self.logger.error('All memcached servers error-limited') @@ -353,8 +380,7 @@ class MemcacheRing(object): :param raise_on_error: if True, propagate Timeouts and other errors. By default, errors are ignored. """ - key_prefix = get_key_prefix(key) - hash_key = md5hash(key) + cmd = MemcacheCommand('set', key) timeout = sanitize_timeout(time) flags = 0 if serialize: @@ -365,11 +391,11 @@ class MemcacheRing(object): elif not isinstance(value, bytes): value = str(value).encode('utf-8') - for (server, fp, sock) in self._get_conns('set', key_prefix, hash_key): + for (server, fp, sock) in self._get_conns(cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): - sock.sendall(set_msg(hash_key, flags, timeout, value)) + sock.sendall(set_msg(cmd.hash_key, flags, timeout, value)) # Wait for the set to complete msg = fp.readline().strip() if msg != b'STORED': @@ -389,9 +415,8 @@ class MemcacheRing(object): self._return_conn(server, fp, sock) return except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, 'set', conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) if raise_on_error: raise MemcacheConnectionError( "No memcached connections succeeded.") @@ -407,21 +432,21 @@ class MemcacheRing(object): By default, errors are treated as cache misses. :returns: value of the key in memcache """ - key_prefix = get_key_prefix(key) - hash_key = md5hash(key) + cmd = MemcacheCommand('get', key) value = None - for (server, fp, sock) in self._get_conns('get', key_prefix, hash_key): + for (server, fp, sock) in self._get_conns(cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): - sock.sendall(b'get ' + hash_key + b'\r\n') + sock.sendall(b'get ' + cmd.hash_key + b'\r\n') line = fp.readline().strip().split() while True: if not line: raise MemcacheConnectionError('incomplete read') if line[0].upper() == b'END': break - if line[0].upper() == b'VALUE' and line[1] == hash_key: + if (line[0].upper() == b'VALUE' and + line[1] == cmd.hash_key): size = int(line[3]) value = fp.read(size) if int(line[2]) & PICKLE_FLAG: @@ -433,9 +458,8 @@ class MemcacheRing(object): self._return_conn(server, fp, sock) return value except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, 'get', conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) if raise_on_error: raise MemcacheConnectionError( "No memcached connections succeeded.") @@ -458,38 +482,32 @@ class MemcacheRing(object): :returns: result of incrementing :raises MemcacheConnectionError: """ - key_prefix = get_key_prefix(key) - hash_key = md5hash(key) - command = b'incr' - if delta < 0: - command = b'decr' + cmd = MemcacheCommand('incr' if delta >= 0 else 'decr', key) delta = str(abs(int(delta))).encode('ascii') timeout = sanitize_timeout(time) - method = command.decode() - for (server, fp, sock) in self._get_conns(method, key_prefix, - hash_key): + for (server, fp, sock) in self._get_conns(cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): sock.sendall(b' '.join([ - command, hash_key, delta]) + b'\r\n') + cmd.command, cmd.hash_key, delta]) + b'\r\n') line = fp.readline().strip().split() if not line: raise MemcacheConnectionError('incomplete read') if line[0].upper() == b'NOT_FOUND': add_val = delta - if command == b'decr': + if cmd.command == b'decr': add_val = b'0' sock.sendall( b' '.join( - [b'add', hash_key, b'0', str(timeout).encode( - 'ascii'), + [b'add', cmd.hash_key, b'0', + str(timeout).encode('ascii'), str(len(add_val)).encode('ascii') ]) + b'\r\n' + add_val + b'\r\n') line = fp.readline().strip().split() if line[0].upper() == b'NOT_STORED': sock.sendall(b' '.join([ - command, hash_key, delta]) + b'\r\n') + cmd.command, cmd.hash_key, delta]) + b'\r\n') line = fp.readline().strip().split() ret = int(line[0].strip()) else: @@ -499,9 +517,8 @@ class MemcacheRing(object): self._return_conn(server, fp, sock) return ret except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, method, conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) raise MemcacheConnectionError("No memcached connections succeeded.") @memcached_timing_stats(sample_rate=TIMING_SAMPLE_RATE_LOW) @@ -529,23 +546,21 @@ class MemcacheRing(object): :param server_key: key to use in determining which server in the ring is used """ - key_prefix = get_key_prefix(key) - hash_key = md5hash(key) - server_key = md5hash(server_key) if server_key else hash_key - for (server, fp, sock) in self._get_conns('delete', key_prefix, - server_key): + cmd = server_cmd = MemcacheCommand('delete', key) + if server_key: + server_cmd = MemcacheCommand('delete', server_key) + for (server, fp, sock) in self._get_conns(server_cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): - sock.sendall(b'delete ' + hash_key + b'\r\n') + sock.sendall(b'delete ' + cmd.hash_key + b'\r\n') # Wait for the delete to complete fp.readline() self._return_conn(server, fp, sock) return except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, 'delete', conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) @memcached_timing_stats(sample_rate=TIMING_SAMPLE_RATE_HIGH) def set_multi(self, mapping, server_key, serialize=True, time=0, @@ -564,8 +579,7 @@ class MemcacheRing(object): python-memcached interface. This implementation ignores it """ - key_prefix = get_key_prefix(server_key) - hash_key = md5hash(server_key) + cmd = MemcacheCommand('set_multi', server_key) timeout = sanitize_timeout(time) msg = [] for key, value in mapping.items(): @@ -577,8 +591,7 @@ class MemcacheRing(object): value = json.dumps(value).encode('ascii') flags |= JSON_FLAG msg.append(set_msg(key, flags, timeout, value)) - for (server, fp, sock) in self._get_conns('set_multi', key_prefix, - hash_key): + for (server, fp, sock) in self._get_conns(cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): @@ -589,9 +602,8 @@ class MemcacheRing(object): self._return_conn(server, fp, sock) return except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, 'set_multi', conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) @memcached_timing_stats(sample_rate=TIMING_SAMPLE_RATE_HIGH) def get_multi(self, keys, server_key): @@ -603,11 +615,9 @@ class MemcacheRing(object): is used :returns: list of values """ - key_prefix = get_key_prefix(server_key) - server_key = md5hash(server_key) + cmd = MemcacheCommand('get_multi', server_key) hash_keys = [md5hash(key) for key in keys] - for (server, fp, sock) in self._get_conns('get_multi', key_prefix, - server_key): + for (server, fp, sock) in self._get_conns(cmd): conn_start_time = tm.time() try: with Timeout(self._io_timeout): @@ -638,9 +648,8 @@ class MemcacheRing(object): self._return_conn(server, fp, sock) return values except (Exception, Timeout) as e: - self._exception_occurred( - server, e, key_prefix, 'get_multi', conn_start_time, - sock=sock, fp=fp) + self._exception_occurred(server, e, cmd, conn_start_time, + sock=sock, fp=fp) def load_memcache(conf, logger): diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py index 34968cc091..81b7fd97e6 100644 --- a/test/unit/common/test_memcached.py +++ b/test/unit/common/test_memcached.py @@ -24,7 +24,6 @@ import six import socket import time import unittest -from uuid import uuid4 import os import mock @@ -36,7 +35,7 @@ from eventlet.green import ssl from swift.common import memcached from swift.common.memcached import MemcacheConnectionError, md5hash, \ - get_key_prefix + MemcacheCommand from swift.common.utils import md5, human_readable from mock import patch, MagicMock from test.debug_logger import debug_logger @@ -196,26 +195,35 @@ class MockMemcached(object): pass +class TestMemcacheCommand(unittest.TestCase): + def test_init(self): + cmd = MemcacheCommand("set", "shard-updating-v2/a/c") + self.assertEqual(cmd.method, "set") + self.assertEqual(cmd.command, b"set") + self.assertEqual(cmd.key, "shard-updating-v2/a/c") + self.assertEqual(cmd.key_prefix, "shard-updating-v2/a") + self.assertEqual(cmd.hash_key, md5hash("shard-updating-v2/a/c")) + + def test_get_key_prefix(self): + cmd = MemcacheCommand("set", "shard-updating-v2/a/c") + self.assertEqual(cmd.key_prefix, "shard-updating-v2/a") + cmd = MemcacheCommand("set", "shard-listing-v2/accout/container3") + self.assertEqual(cmd.key_prefix, "shard-listing-v2/accout") + cmd = MemcacheCommand( + "set", "auth_reseller_name/token/X58E34EL2SDFLEY3") + self.assertEqual(cmd.key_prefix, "auth_reseller_name/token") + cmd = MemcacheCommand("set", "nvratelimit/v2/wf/2345392374") + self.assertEqual(cmd.key_prefix, "nvratelimit/v2/wf") + cmd = MemcacheCommand("set", "some_key") + self.assertEqual(cmd.key_prefix, "some_key") + + class TestMemcached(unittest.TestCase): """Tests for swift.common.memcached""" def setUp(self): self.logger = debug_logger() - - def test_get_key_prefix(self): - self.assertEqual( - get_key_prefix("shard-updating-v2/a/c"), - "shard-updating-v2/a") - self.assertEqual( - get_key_prefix("shard-listing-v2/accout/container3"), - "shard-listing-v2/accout") - self.assertEqual( - get_key_prefix("auth_reseller_name/token/X58E34EL2SDFLEY3"), - "auth_reseller_name/token") - self.assertEqual( - get_key_prefix("nvratelimit/v2/wf/2345392374"), - "nvratelimit/v2/wf") - self.assertEqual(get_key_prefix("some_key"), "some_key") + self.set_cmd = MemcacheCommand('set', 'key') def test_logger_kwarg(self): server_socket = '%s:%s' % ('[::1]', 11211) @@ -235,8 +243,7 @@ class TestMemcached(unittest.TestCase): client = memcached.MemcacheRing([server], tls_context=context) self.assertIs(client._client_cache[server]._tls_context, context) - key = uuid4().hex.encode('ascii') - list(client._get_conns('set', 'test', key)) + list(client._get_conns(self.set_cmd)) context.wrap_socket.assert_called_once() def test_get_conns(self): @@ -257,8 +264,7 @@ class TestMemcached(unittest.TestCase): logger=self.logger) one = two = True while one or two: # Run until we match hosts one and two - key = uuid4().hex.encode('ascii') - for conn in memcache_client._get_conns('set', 'test', key): + for conn in memcache_client._get_conns(self.set_cmd): if 'b' not in getattr(conn[1], 'mode', ''): self.assertIsInstance(conn[1], ( io.RawIOBase, io.BufferedIOBase)) @@ -284,8 +290,7 @@ class TestMemcached(unittest.TestCase): server_socket = '[%s]:%s' % (sock_addr[0], sock_addr[1]) memcache_client = memcached.MemcacheRing([server_socket], logger=self.logger) - key = uuid4().hex.encode('ascii') - for conn in memcache_client._get_conns('set', 'test', key): + for conn in memcache_client._get_conns(self.set_cmd): peer_sockaddr = conn[2].getpeername() peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) self.assertEqual(peer_socket, server_socket) @@ -306,8 +311,7 @@ class TestMemcached(unittest.TestCase): memcached.DEFAULT_MEMCACHED_PORT = sock_addr[1] memcache_client = memcached.MemcacheRing([server_host], logger=self.logger) - key = uuid4().hex.encode('ascii') - for conn in memcache_client._get_conns('set', 'test', key): + for conn in memcache_client._get_conns(self.set_cmd): peer_sockaddr = conn[2].getpeername() peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) self.assertEqual(peer_socket, server_socket) @@ -335,8 +339,7 @@ class TestMemcached(unittest.TestCase): ('127.0.0.1', sock_addr[1]))] memcache_client = memcached.MemcacheRing([server_socket], logger=self.logger) - key = uuid4().hex.encode('ascii') - for conn in memcache_client._get_conns('set', 'test', key): + for conn in memcache_client._get_conns(self.set_cmd): peer_sockaddr = conn[2].getpeername() peer_socket = '%s:%s' % (peer_sockaddr[0], peer_sockaddr[1]) @@ -361,8 +364,7 @@ class TestMemcached(unittest.TestCase): ('::1', sock_addr[1]))] memcache_client = memcached.MemcacheRing([server_socket], logger=self.logger) - key = uuid4().hex.encode('ascii') - for conn in memcache_client._get_conns('set', 'test', key): + for conn in memcache_client._get_conns(self.set_cmd): peer_sockaddr = conn[2].getpeername() peer_socket = '[%s]:%s' % (peer_sockaddr[0], peer_sockaddr[1]) @@ -1066,8 +1068,7 @@ class TestMemcached(unittest.TestCase): # try to get connect and no connection found # so it will result in StopIteration - conn_generator = memcache_client._get_conns( - 'set', 'key', md5hash(b'key')) + conn_generator = memcache_client._get_conns(self.set_cmd) with self.assertRaises(StopIteration): next(conn_generator)