From ae8470131ead095e3bf1c290bac866a5e6e29e79 Mon Sep 17 00:00:00 2001 From: Chuck Thier Date: Wed, 4 Sep 2013 22:20:44 +0000 Subject: [PATCH] Pool memcache connections This creates a pool to each memcache server so that connections will not grow without bound. This also adds a proxy config "memcache_max_connections" which can control how many connections are available in the pool. A side effect of the change is that we had to change the memcache calls that used noreply, and instead wait for the result of the request. Leaving with noreply could cause a race condition (specifically in account auto create), due to one request calling `memcache.del(key)` and then `memcache.get(key)` with a different pooled connection. If the delete didn't complete fast enough, the get would return the old value before it was deleted, and thus believe that the account was not autocreated. ClaysMindExploded DocImpact Change-Id: I350720b7bba29e1453894d3d4105ac1ea232595b --- doc/source/deployment_guide.rst | 3 + swift/common/memcached.py | 230 ++++++++++++++++------------ swift/common/middleware/memcache.py | 4 +- test/unit/common/test_memcached.py | 80 ++++++++-- 4 files changed, 212 insertions(+), 105 deletions(-) diff --git a/doc/source/deployment_guide.rst b/doc/source/deployment_guide.rst index ed744f7f23..c10143349a 100644 --- a/doc/source/deployment_guide.rst +++ b/doc/source/deployment_guide.rst @@ -795,6 +795,9 @@ client_chunk_size 65536 Chunk size to read from clients memcache_servers 127.0.0.1:11211 Comma separated list of memcached servers ip:port +memcache_max_connections 2 Max number of connections to + each memcached server per + worker node_timeout 10 Request timeout to external services client_timeout 60 Timeout to read one chunk diff --git a/swift/common/memcached.py b/swift/common/memcached.py index 678e66a668..d1fe193b66 100644 --- a/swift/common/memcached.py +++ b/swift/common/memcached.py @@ -46,12 +46,15 @@ 
http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt import cPickle as pickle import logging -import socket import time from bisect import bisect from swift import gettext_ as _ from hashlib import md5 +from eventlet.green import socket +from eventlet.pools import Pool +from eventlet import Timeout + from swift.common.utils import json DEFAULT_MEMCACHED_PORT = 11211 @@ -91,6 +94,34 @@ class MemcacheConnectionError(Exception): pass +class MemcacheConnPool(Pool): + """Connection pool for Memcache Connections""" + + def __init__(self, server, size, connect_timeout): + Pool.__init__(self, max_size=size) + self.server = server + self._connect_timeout = connect_timeout + + def create(self): + if ':' in self.server: + host, port = self.server.split(':') + else: + host = self.server + port = DEFAULT_MEMCACHED_PORT + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + with Timeout(self._connect_timeout): + sock.connect((host, int(port))) + return (sock.makefile(), sock) + + def get(self): + fp, sock = Pool.get(self) + if fp is None: + # An error happened previously, so we need a new connection + fp, sock = self.create() + return fp, sock + + class MemcacheRing(object): """ Simple, consistent-hashed memcache client. 
@@ -98,7 +129,8 @@ class MemcacheRing(object): def __init__(self, servers, connect_timeout=CONN_TIMEOUT, io_timeout=IO_TIMEOUT, tries=TRY_COUNT, - allow_pickle=False, allow_unpickle=False): + allow_pickle=False, allow_unpickle=False, + max_conns=2): self._ring = {} self._errors = dict(((serv, []) for serv in servers)) self._error_limited = dict(((serv, 0) for serv in servers)) @@ -107,7 +139,10 @@ class MemcacheRing(object): self._ring[md5hash('%s-%s' % (server, i))] = server self._tries = tries if tries <= len(servers) else len(servers) self._sorted = sorted(self._ring) - self._client_cache = dict(((server, []) for server in servers)) + self._client_cache = dict(((server, + MemcacheConnPool(server, max_conns, + connect_timeout)) + for server in servers)) self._connect_timeout = connect_timeout self._io_timeout = io_timeout self._allow_pickle = allow_pickle @@ -115,7 +150,7 @@ class MemcacheRing(object): def _exception_occurred(self, server, e, action='talking', sock=None, fp=None): - if isinstance(e, socket.timeout): + if isinstance(e, Timeout): logging.error(_("Timeout %(action)s to memcached: %(server)s"), {'action': action, 'server': server}) else: @@ -133,6 +168,9 @@ class MemcacheRing(object): del sock except Exception: pass + # We need to return something to the pool + # A new connection will be created the next time it is retrieved + self._return_conn(server, None, None) now = time.time() self._errors[server].append(time.time()) if len(self._errors[server]) > ERROR_LIMIT_COUNT: @@ -159,28 +197,16 @@ class MemcacheRing(object): continue sock = None try: - fp, sock = self._client_cache[server].pop() + with Timeout(self._connect_timeout): + fp, sock = self._client_cache[server].get() yield server, fp, sock - except IndexError: - try: - if ':' in server: - host, port = server.split(':') - else: - host = server - port = DEFAULT_MEMCACHED_PORT - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - 
sock.settimeout(self._connect_timeout) - sock.connect((host, int(port))) - sock.settimeout(self._io_timeout) - yield server, sock.makefile(), sock - except Exception as e: - self._exception_occurred( - server, e, action='connecting', sock=sock) + except (Exception, Timeout) as e: + self._exception_occurred( + server, e, action='connecting', sock=sock) def _return_conn(self, server, fp, sock): """Returns a server connection to the pool.""" - self._client_cache[server].append((fp, sock)) + self._client_cache[server].put((fp, sock)) def set(self, key, value, serialize=True, timeout=0, time=0, min_compress_len=0): @@ -217,11 +243,14 @@ class MemcacheRing(object): flags |= JSON_FLAG for (server, fp, sock) in self._get_conns(key): try: - sock.sendall('set %s %d %d %s noreply\r\n%s\r\n' % - (key, flags, timeout, len(value), value)) - self._return_conn(server, fp, sock) - return - except Exception as e: + with Timeout(self._io_timeout): + sock.sendall('set %s %d %d %s\r\n%s\r\n' % + (key, flags, timeout, len(value), value)) + # Wait for the set to complete + fp.readline() + self._return_conn(server, fp, sock) + return + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) def get(self, key): @@ -237,24 +266,25 @@ class MemcacheRing(object): value = None for (server, fp, sock) in self._get_conns(key): try: - sock.sendall('get %s\r\n' % key) - line = fp.readline().strip().split() - while line[0].upper() != 'END': - if line[0].upper() == 'VALUE' and line[1] == key: - size = int(line[3]) - value = fp.read(size) - if int(line[2]) & PICKLE_FLAG: - if self._allow_unpickle: - value = pickle.loads(value) - else: - value = None - elif int(line[2]) & JSON_FLAG: - value = json.loads(value) - fp.readline() + with Timeout(self._io_timeout): + sock.sendall('get %s\r\n' % key) line = fp.readline().strip().split() - self._return_conn(server, fp, sock) - return value - except Exception as e: + while line[0].upper() != 'END': + if line[0].upper() == 'VALUE' 
and line[1] == key: + size = int(line[3]) + value = fp.read(size) + if int(line[2]) & PICKLE_FLAG: + if self._allow_unpickle: + value = pickle.loads(value) + else: + value = None + elif int(line[2]) & JSON_FLAG: + value = json.loads(value) + fp.readline() + line = fp.readline().strip().split() + self._return_conn(server, fp, sock) + return value + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) def incr(self, key, delta=1, time=0, timeout=0): @@ -287,26 +317,28 @@ class MemcacheRing(object): timeout = sanitize_timeout(time or timeout) for (server, fp, sock) in self._get_conns(key): try: - sock.sendall('%s %s %s\r\n' % (command, key, delta)) - line = fp.readline().strip().split() - if line[0].upper() == 'NOT_FOUND': - add_val = delta - if command == 'decr': - add_val = '0' - sock.sendall('add %s %d %d %s\r\n%s\r\n' % - (key, 0, timeout, len(add_val), add_val)) + with Timeout(self._io_timeout): + sock.sendall('%s %s %s\r\n' % (command, key, delta)) line = fp.readline().strip().split() - if line[0].upper() == 'NOT_STORED': - sock.sendall('%s %s %s\r\n' % (command, key, delta)) + if line[0].upper() == 'NOT_FOUND': + add_val = delta + if command == 'decr': + add_val = '0' + sock.sendall('add %s %d %d %s\r\n%s\r\n' % + (key, 0, timeout, len(add_val), add_val)) line = fp.readline().strip().split() - ret = int(line[0].strip()) + if line[0].upper() == 'NOT_STORED': + sock.sendall('%s %s %s\r\n' % (command, key, + delta)) + line = fp.readline().strip().split() + ret = int(line[0].strip()) + else: + ret = int(add_val) else: - ret = int(add_val) - else: - ret = int(line[0].strip()) - self._return_conn(server, fp, sock) - return ret - except Exception as e: + ret = int(line[0].strip()) + self._return_conn(server, fp, sock) + return ret + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) raise MemcacheConnectionError("No Memcached connections succeeded.") @@ -340,10 +372,13 @@ class 
MemcacheRing(object): key = md5hash(key) for (server, fp, sock) in self._get_conns(key): try: - sock.sendall('delete %s noreply\r\n' % key) - self._return_conn(server, fp, sock) - return - except Exception as e: + with Timeout(self._io_timeout): + sock.sendall('delete %s\r\n' % key) + # Wait for the delete to complete + fp.readline() + self._return_conn(server, fp, sock) + return + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) def set_multi(self, mapping, server_key, serialize=True, timeout=0, @@ -384,14 +419,18 @@ class MemcacheRing(object): elif serialize: value = json.dumps(value) flags |= JSON_FLAG - msg += ('set %s %d %d %s noreply\r\n%s\r\n' % + msg += ('set %s %d %d %s\r\n%s\r\n' % (key, flags, timeout, len(value), value)) for (server, fp, sock) in self._get_conns(server_key): try: - sock.sendall(msg) - self._return_conn(server, fp, sock) - return - except Exception as e: + with Timeout(self._io_timeout): + sock.sendall(msg) + # Wait for the set to complete + for _ in range(len(mapping)): + fp.readline() + self._return_conn(server, fp, sock) + return + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) def get_multi(self, keys, server_key): @@ -407,30 +446,31 @@ class MemcacheRing(object): keys = [md5hash(key) for key in keys] for (server, fp, sock) in self._get_conns(server_key): try: - sock.sendall('get %s\r\n' % ' '.join(keys)) - line = fp.readline().strip().split() - responses = {} - while line[0].upper() != 'END': - if line[0].upper() == 'VALUE': - size = int(line[3]) - value = fp.read(size) - if int(line[2]) & PICKLE_FLAG: - if self._allow_unpickle: - value = pickle.loads(value) - else: - value = None - elif int(line[2]) & JSON_FLAG: - value = json.loads(value) - responses[line[1]] = value - fp.readline() + with Timeout(self._io_timeout): + sock.sendall('get %s\r\n' % ' '.join(keys)) line = fp.readline().strip().split() - values = [] - for key in keys: - if key in responses: 
- values.append(responses[key]) - else: - values.append(None) - self._return_conn(server, fp, sock) - return values - except Exception as e: + responses = {} + while line[0].upper() != 'END': + if line[0].upper() == 'VALUE': + size = int(line[3]) + value = fp.read(size) + if int(line[2]) & PICKLE_FLAG: + if self._allow_unpickle: + value = pickle.loads(value) + else: + value = None + elif int(line[2]) & JSON_FLAG: + value = json.loads(value) + responses[line[1]] = value + fp.readline() + line = fp.readline().strip().split() + values = [] + for key in keys: + if key in responses: + values.append(responses[key]) + else: + values.append(None) + self._return_conn(server, fp, sock) + return values + except (Exception, Timeout) as e: self._exception_occurred(server, e, sock=sock, fp=fp) diff --git a/swift/common/middleware/memcache.py b/swift/common/middleware/memcache.py index 13e16d4c68..ae67c4ac44 100644 --- a/swift/common/middleware/memcache.py +++ b/swift/common/middleware/memcache.py @@ -28,6 +28,7 @@ class MemcacheMiddleware(object): self.app = app self.memcache_servers = conf.get('memcache_servers') serialization_format = conf.get('memcache_serialization_support') + max_conns = int(conf.get('max_connections', 2)) if not self.memcache_servers or serialization_format is None: path = os.path.join(conf.get('swift_dir', '/etc/swift'), @@ -58,7 +59,8 @@ class MemcacheMiddleware(object): self.memcache = MemcacheRing( [s.strip() for s in self.memcache_servers.split(',') if s.strip()], allow_pickle=(serialization_format == 0), - allow_unpickle=(serialization_format <= 1)) + allow_unpickle=(serialization_format <= 1), + max_conns=max_conns) def __call__(self, env, start_response): env['swift.cache'] = self.memcache diff --git a/test/unit/common/test_memcached.py b/test/unit/common/test_memcached.py index c717e6b870..46363e7907 100644 --- a/test/unit/common/test_memcached.py +++ b/test/unit/common/test_memcached.py @@ -23,10 +23,23 @@ import time import unittest from uuid 
import uuid4 +from eventlet import GreenPool, sleep, Queue +from eventlet.pools import Pool + from swift.common import memcached +from mock import patch from test.unit import NullLoggingHandler +class MockedMemcachePool(memcached.MemcacheConnPool): + def __init__(self, mocks): + Pool.__init__(self, max_size=2) + self.mocks = mocks + + def create(self): + return self.mocks.pop(0) + + class ExplodingMockMemcached(object): exploded = False @@ -173,7 +186,8 @@ class TestMemcached(unittest.TestCase): def test_set_get(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.set('some_key', [1, 2, 3]) self.assertEquals(memcache_client.get('some_key'), [1, 2, 3]) self.assertEquals(mock.cache.values()[0][1], '0') @@ -200,7 +214,8 @@ class TestMemcached(unittest.TestCase): def test_incr(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.incr('some_key', delta=5) self.assertEquals(memcache_client.get('some_key'), '5') memcache_client.incr('some_key', delta=5) @@ -219,7 +234,8 @@ class TestMemcached(unittest.TestCase): def test_incr_w_timeout(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.incr('some_key', delta=5, time=55) self.assertEquals(memcache_client.get('some_key'), '5') self.assertEquals(mock.cache.values()[0][1], '55') @@ -242,7 +258,8 @@ class TestMemcached(unittest.TestCase): def test_decr(self): memcache_client = 
memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.decr('some_key', delta=5) self.assertEquals(memcache_client.get('some_key'), '0') memcache_client.incr('some_key', delta=15) @@ -261,8 +278,10 @@ class TestMemcached(unittest.TestCase): ['1.2.3.4:11211', '1.2.3.5:11211']) mock1 = ExplodingMockMemcached() mock2 = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock2, mock2)] - memcache_client._client_cache['1.2.3.5:11211'] = [(mock1, mock1)] + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock2, mock2)]) + memcache_client._client_cache['1.2.3.5:11211'] = MockedMemcachePool( + [(mock1, mock1)]) memcache_client.set('some_key', [1, 2, 3]) self.assertEquals(memcache_client.get('some_key'), [1, 2, 3]) self.assertEquals(mock1.exploded, True) @@ -270,7 +289,8 @@ class TestMemcached(unittest.TestCase): def test_delete(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.set('some_key', [1, 2, 3]) self.assertEquals(memcache_client.get('some_key'), [1, 2, 3]) memcache_client.delete('some_key') @@ -279,7 +299,8 @@ class TestMemcached(unittest.TestCase): def test_multi(self): memcache_client = memcached.MemcacheRing(['1.2.3.4:11211']) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.set_multi( {'some_key1': [1, 2, 3], 'some_key2': [4, 5, 6]}, 'multi_key') self.assertEquals( @@ -313,7 +334,8 @@ class TestMemcached(unittest.TestCase): memcache_client = 
memcached.MemcacheRing(['1.2.3.4:11211'], allow_pickle=True) mock = MockMemcached() - memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2 + memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool( + [(mock, mock)] * 2) memcache_client.set('some_key', [1, 2, 3]) self.assertEquals(memcache_client.get('some_key'), [1, 2, 3]) memcache_client._allow_pickle = False @@ -328,6 +350,46 @@ class TestMemcached(unittest.TestCase): memcache_client._allow_pickle = True self.assertEquals(memcache_client.get('some_key'), [1, 2, 3]) + def test_connection_pooling(self): + with patch('swift.common.memcached.socket') as mock_module: + # patch socket, stub socket.socket, mock sock + mock_sock = mock_module.socket.return_value + + # track clients waiting for connections + connected = [] + connections = Queue() + + def wait_connect(addr): + connected.append(addr) + connections.get() + mock_sock.connect = wait_connect + + memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'], + connect_timeout=10) + # sanity + self.assertEquals(1, len(memcache_client._client_cache)) + for server, pool in memcache_client._client_cache.items(): + self.assertEquals(2, pool.max_size) + + # make 10 requests "at the same time" + p = GreenPool() + for i in range(10): + p.spawn(memcache_client.set, 'key', 'value') + for i in range(3): + sleep(0.1) + self.assertEquals(2, len(connected)) + # give out a connection + connections.put(None) + for i in range(3): + sleep(0.1) + self.assertEquals(2, len(connected)) + # finish up + for i in range(8): + connections.put(None) + self.assertEquals(2, len(connected)) + p.waitall() + self.assertEquals(2, len(connected)) + if __name__ == '__main__': unittest.main()