Pool memcache connections

This creates a pool of connections to each memcache server so that
connections will not grow without bound.  It also adds a proxy config
option, "memcache_max_connections", which controls how many connections
are available in each per-server pool.
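
For context, the pool is built on eventlet.pools.Pool: connections are
created lazily by create() up to max_size, and further get() calls block
until an item is put() back.  A minimal sketch of that behavior
(CountingPool is a made-up toy class for illustration only, not part of
this change):

    from eventlet.pools import Pool

    class CountingPool(Pool):
        """Toy pool that shows when new items would be created."""
        def __init__(self, max_size):
            Pool.__init__(self, max_size=max_size)
            self.created = 0

        def create(self):
            # Only called when get() finds no free item and the pool is
            # still below max_size; otherwise get() blocks.
            self.created += 1
            return 'item-%d' % self.created

    pool = CountingPool(max_size=2)
    a = pool.get()   # lazily creates 'item-1'
    b = pool.get()   # lazily creates 'item-2'
    # A third pool.get() here would block until something is put() back.
    pool.put(a)
    c = pool.get()   # reuses 'item-1'; create() is not called again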

A side effect of the change is that the memcache calls that used noreply
now wait for the result of the request instead.  Leaving noreply in
place could cause a race condition (specifically in account auto
create), because one request may call `memcache.delete(key)` and then
`memcache.get(key)` on a different pooled connection.  If the delete had
not completed yet, the get could return the old value and lead the
caller to believe that the account had not been autocreated.
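
To make the race concrete: memcached handles each connection's commands
in order, but it gives no cross-connection ordering guarantee until a
reply has been read.  A delete sent with noreply on one pooled socket
therefore says nothing about when it has been applied, and a get sent on
a different pooled socket can still be answered from the old value;
waiting for the DELETED reply closes that window.  A rough sketch
against the raw protocol (Python 2 to match the code below, assuming a
memcached listening on 127.0.0.1:11211; the key and value are made up):

    import socket

    def connect():
        sock = socket.create_connection(('127.0.0.1', 11211))
        return sock, sock.makefile()

    sock1, fp1 = connect()   # connection used for the delete
    sock2, fp2 = connect()   # different pooled connection used for the get

    sock1.sendall('set some_key 0 0 3\r\nold\r\n')
    fp1.readline()                      # 'STORED'

    # Racy version: noreply returns immediately, so the get below can
    # still be answered before the delete has been applied.
    sock1.sendall('delete some_key noreply\r\n')
    sock2.sendall('get some_key\r\n')   # may still return 'old'

    # Safe version: wait for the reply, so the delete is known to have
    # been applied before the get is ever sent.
    sock1.sendall('delete some_key\r\n')
    fp1.readline()                      # 'DELETED' (or 'NOT_FOUND')
    sock2.sendall('get some_key\r\n')   # now returns just 'END'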

ClaysMindExploded
DocImpact
Change-Id: I350720b7bba29e1453894d3d4105ac1ea232595b
Chuck Thier 2013-09-04 22:20:44 +00:00
parent 58efcb3b3e
commit ae8470131e
4 changed files with 212 additions and 105 deletions

@@ -795,6 +795,9 @@
client_chunk_size           65536            Chunk size to read from
                                             clients
memcache_servers            127.0.0.1:11211  Comma separated list of
                                             memcached servers ip:port
memcache_max_connections    2                Max number of connections to
                                             each memcached server per
                                             worker
node_timeout                10               Request timeout to external
                                             services
client_timeout              60               Timeout to read one chunk

@@ -46,12 +46,15 @@ http://github.com/memcached/memcached/blob/1.4.2/doc/protocol.txt
import cPickle as pickle
import logging
import socket
import time
from bisect import bisect
from swift import gettext_ as _
from hashlib import md5
from eventlet.green import socket
from eventlet.pools import Pool
from eventlet import Timeout
from swift.common.utils import json
DEFAULT_MEMCACHED_PORT = 11211
@@ -91,6 +94,34 @@ class MemcacheConnectionError(Exception):
pass
class MemcacheConnPool(Pool):
"""Connection pool for Memcache Connections"""
def __init__(self, server, size, connect_timeout):
Pool.__init__(self, max_size=size)
self.server = server
self._connect_timeout = connect_timeout
def create(self):
if ':' in self.server:
host, port = self.server.split(':')
else:
host = self.server
port = DEFAULT_MEMCACHED_PORT
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
with Timeout(self._connect_timeout):
sock.connect((host, int(port)))
return (sock.makefile(), sock)
def get(self):
fp, sock = Pool.get(self)
if fp is None:
# An error happened previously, so we need a new connection
fp, sock = self.create()
return fp, sock
class MemcacheRing(object):
"""
Simple, consistent-hashed memcache client.
@@ -98,7 +129,8 @@ class MemcacheRing(object):
def __init__(self, servers, connect_timeout=CONN_TIMEOUT,
io_timeout=IO_TIMEOUT, tries=TRY_COUNT,
allow_pickle=False, allow_unpickle=False):
allow_pickle=False, allow_unpickle=False,
max_conns=2):
self._ring = {}
self._errors = dict(((serv, []) for serv in servers))
self._error_limited = dict(((serv, 0) for serv in servers))
@@ -107,7 +139,10 @@
self._ring[md5hash('%s-%s' % (server, i))] = server
self._tries = tries if tries <= len(servers) else len(servers)
self._sorted = sorted(self._ring)
self._client_cache = dict(((server, []) for server in servers))
self._client_cache = dict(((server,
MemcacheConnPool(server, max_conns,
connect_timeout))
for server in servers))
self._connect_timeout = connect_timeout
self._io_timeout = io_timeout
self._allow_pickle = allow_pickle
@@ -115,7 +150,7 @@
def _exception_occurred(self, server, e, action='talking',
sock=None, fp=None):
if isinstance(e, socket.timeout):
if isinstance(e, Timeout):
logging.error(_("Timeout %(action)s to memcached: %(server)s"),
{'action': action, 'server': server})
else:
@@ -133,6 +168,9 @@
del sock
except Exception:
pass
# We need to return something to the pool
# A new connection will be created the next time it is retrieved
self._return_conn(server, None, None)
now = time.time()
self._errors[server].append(time.time())
if len(self._errors[server]) > ERROR_LIMIT_COUNT:
@@ -159,28 +197,16 @@
continue
sock = None
try:
fp, sock = self._client_cache[server].pop()
with Timeout(self._connect_timeout):
fp, sock = self._client_cache[server].get()
yield server, fp, sock
except IndexError:
try:
if ':' in server:
host, port = server.split(':')
else:
host = server
port = DEFAULT_MEMCACHED_PORT
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.settimeout(self._connect_timeout)
sock.connect((host, int(port)))
sock.settimeout(self._io_timeout)
yield server, sock.makefile(), sock
except Exception as e:
self._exception_occurred(
server, e, action='connecting', sock=sock)
except (Exception, Timeout) as e:
self._exception_occurred(
server, e, action='connecting', sock=sock)
def _return_conn(self, server, fp, sock):
"""Returns a server connection to the pool."""
self._client_cache[server].append((fp, sock))
self._client_cache[server].put((fp, sock))
def set(self, key, value, serialize=True, timeout=0, time=0,
min_compress_len=0):
@@ -217,11 +243,14 @@
flags |= JSON_FLAG
for (server, fp, sock) in self._get_conns(key):
try:
sock.sendall('set %s %d %d %s noreply\r\n%s\r\n' %
(key, flags, timeout, len(value), value))
self._return_conn(server, fp, sock)
return
except Exception as e:
with Timeout(self._io_timeout):
sock.sendall('set %s %d %d %s\r\n%s\r\n' %
(key, flags, timeout, len(value), value))
# Wait for the set to complete
fp.readline()
self._return_conn(server, fp, sock)
return
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)
def get(self, key):
@@ -237,24 +266,25 @@
value = None
for (server, fp, sock) in self._get_conns(key):
try:
sock.sendall('get %s\r\n' % key)
line = fp.readline().strip().split()
while line[0].upper() != 'END':
if line[0].upper() == 'VALUE' and line[1] == key:
size = int(line[3])
value = fp.read(size)
if int(line[2]) & PICKLE_FLAG:
if self._allow_unpickle:
value = pickle.loads(value)
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value)
fp.readline()
with Timeout(self._io_timeout):
sock.sendall('get %s\r\n' % key)
line = fp.readline().strip().split()
self._return_conn(server, fp, sock)
return value
except Exception as e:
while line[0].upper() != 'END':
if line[0].upper() == 'VALUE' and line[1] == key:
size = int(line[3])
value = fp.read(size)
if int(line[2]) & PICKLE_FLAG:
if self._allow_unpickle:
value = pickle.loads(value)
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value)
fp.readline()
line = fp.readline().strip().split()
self._return_conn(server, fp, sock)
return value
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)
def incr(self, key, delta=1, time=0, timeout=0):
@@ -287,26 +317,28 @@
timeout = sanitize_timeout(time or timeout)
for (server, fp, sock) in self._get_conns(key):
try:
sock.sendall('%s %s %s\r\n' % (command, key, delta))
line = fp.readline().strip().split()
if line[0].upper() == 'NOT_FOUND':
add_val = delta
if command == 'decr':
add_val = '0'
sock.sendall('add %s %d %d %s\r\n%s\r\n' %
(key, 0, timeout, len(add_val), add_val))
with Timeout(self._io_timeout):
sock.sendall('%s %s %s\r\n' % (command, key, delta))
line = fp.readline().strip().split()
if line[0].upper() == 'NOT_STORED':
sock.sendall('%s %s %s\r\n' % (command, key, delta))
if line[0].upper() == 'NOT_FOUND':
add_val = delta
if command == 'decr':
add_val = '0'
sock.sendall('add %s %d %d %s\r\n%s\r\n' %
(key, 0, timeout, len(add_val), add_val))
line = fp.readline().strip().split()
ret = int(line[0].strip())
if line[0].upper() == 'NOT_STORED':
sock.sendall('%s %s %s\r\n' % (command, key,
delta))
line = fp.readline().strip().split()
ret = int(line[0].strip())
else:
ret = int(add_val)
else:
ret = int(add_val)
else:
ret = int(line[0].strip())
self._return_conn(server, fp, sock)
return ret
except Exception as e:
ret = int(line[0].strip())
self._return_conn(server, fp, sock)
return ret
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)
raise MemcacheConnectionError("No Memcached connections succeeded.")
@@ -340,10 +372,13 @@
key = md5hash(key)
for (server, fp, sock) in self._get_conns(key):
try:
sock.sendall('delete %s noreply\r\n' % key)
self._return_conn(server, fp, sock)
return
except Exception as e:
with Timeout(self._io_timeout):
sock.sendall('delete %s\r\n' % key)
# Wait for the delete to complete
fp.readline()
self._return_conn(server, fp, sock)
return
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)
def set_multi(self, mapping, server_key, serialize=True, timeout=0,
@@ -384,14 +419,18 @@
elif serialize:
value = json.dumps(value)
flags |= JSON_FLAG
msg += ('set %s %d %d %s noreply\r\n%s\r\n' %
msg += ('set %s %d %d %s\r\n%s\r\n' %
(key, flags, timeout, len(value), value))
for (server, fp, sock) in self._get_conns(server_key):
try:
sock.sendall(msg)
self._return_conn(server, fp, sock)
return
except Exception as e:
with Timeout(self._io_timeout):
sock.sendall(msg)
# Wait for the set to complete
for _ in range(len(mapping)):
fp.readline()
self._return_conn(server, fp, sock)
return
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)
def get_multi(self, keys, server_key):
@@ -407,30 +446,31 @@
keys = [md5hash(key) for key in keys]
for (server, fp, sock) in self._get_conns(server_key):
try:
sock.sendall('get %s\r\n' % ' '.join(keys))
line = fp.readline().strip().split()
responses = {}
while line[0].upper() != 'END':
if line[0].upper() == 'VALUE':
size = int(line[3])
value = fp.read(size)
if int(line[2]) & PICKLE_FLAG:
if self._allow_unpickle:
value = pickle.loads(value)
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value)
responses[line[1]] = value
fp.readline()
with Timeout(self._io_timeout):
sock.sendall('get %s\r\n' % ' '.join(keys))
line = fp.readline().strip().split()
values = []
for key in keys:
if key in responses:
values.append(responses[key])
else:
values.append(None)
self._return_conn(server, fp, sock)
return values
except Exception as e:
responses = {}
while line[0].upper() != 'END':
if line[0].upper() == 'VALUE':
size = int(line[3])
value = fp.read(size)
if int(line[2]) & PICKLE_FLAG:
if self._allow_unpickle:
value = pickle.loads(value)
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value)
responses[line[1]] = value
fp.readline()
line = fp.readline().strip().split()
values = []
for key in keys:
if key in responses:
values.append(responses[key])
else:
values.append(None)
self._return_conn(server, fp, sock)
return values
except (Exception, Timeout) as e:
self._exception_occurred(server, e, sock=sock, fp=fp)

@@ -28,6 +28,7 @@ class MemcacheMiddleware(object):
self.app = app
self.memcache_servers = conf.get('memcache_servers')
serialization_format = conf.get('memcache_serialization_support')
max_conns = int(conf.get('max_connections', 2))
if not self.memcache_servers or serialization_format is None:
path = os.path.join(conf.get('swift_dir', '/etc/swift'),
@@ -58,7 +59,8 @@ class MemcacheMiddleware(object):
self.memcache = MemcacheRing(
[s.strip() for s in self.memcache_servers.split(',') if s.strip()],
allow_pickle=(serialization_format == 0),
allow_unpickle=(serialization_format <= 1))
allow_unpickle=(serialization_format <= 1),
max_conns=max_conns)
def __call__(self, env, start_response):
env['swift.cache'] = self.memcache

@@ -23,10 +23,23 @@ import time
import unittest
from uuid import uuid4
from eventlet import GreenPool, sleep, Queue
from eventlet.pools import Pool
from swift.common import memcached
from mock import patch
from test.unit import NullLoggingHandler
class MockedMemcachePool(memcached.MemcacheConnPool):
def __init__(self, mocks):
Pool.__init__(self, max_size=2)
self.mocks = mocks
def create(self):
return self.mocks.pop(0)
class ExplodingMockMemcached(object):
exploded = False
@@ -173,7 +186,8 @@ class TestMemcached(unittest.TestCase):
def test_set_get(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.set('some_key', [1, 2, 3])
self.assertEquals(memcache_client.get('some_key'), [1, 2, 3])
self.assertEquals(mock.cache.values()[0][1], '0')
@@ -200,7 +214,8 @@ class TestMemcached(unittest.TestCase):
def test_incr(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.incr('some_key', delta=5)
self.assertEquals(memcache_client.get('some_key'), '5')
memcache_client.incr('some_key', delta=5)
@@ -219,7 +234,8 @@ class TestMemcached(unittest.TestCase):
def test_incr_w_timeout(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.incr('some_key', delta=5, time=55)
self.assertEquals(memcache_client.get('some_key'), '5')
self.assertEquals(mock.cache.values()[0][1], '55')
@@ -242,7 +258,8 @@ class TestMemcached(unittest.TestCase):
def test_decr(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.decr('some_key', delta=5)
self.assertEquals(memcache_client.get('some_key'), '0')
memcache_client.incr('some_key', delta=15)
@@ -261,8 +278,10 @@
['1.2.3.4:11211', '1.2.3.5:11211'])
mock1 = ExplodingMockMemcached()
mock2 = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock2, mock2)]
memcache_client._client_cache['1.2.3.5:11211'] = [(mock1, mock1)]
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock2, mock2)])
memcache_client._client_cache['1.2.3.5:11211'] = MockedMemcachePool(
[(mock1, mock1)])
memcache_client.set('some_key', [1, 2, 3])
self.assertEquals(memcache_client.get('some_key'), [1, 2, 3])
self.assertEquals(mock1.exploded, True)
@@ -270,7 +289,8 @@ class TestMemcached(unittest.TestCase):
def test_delete(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.set('some_key', [1, 2, 3])
self.assertEquals(memcache_client.get('some_key'), [1, 2, 3])
memcache_client.delete('some_key')
@@ -279,7 +299,8 @@ class TestMemcached(unittest.TestCase):
def test_multi(self):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'])
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.set_multi(
{'some_key1': [1, 2, 3], 'some_key2': [4, 5, 6]}, 'multi_key')
self.assertEquals(
@@ -313,7 +334,8 @@ class TestMemcached(unittest.TestCase):
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'],
allow_pickle=True)
mock = MockMemcached()
memcache_client._client_cache['1.2.3.4:11211'] = [(mock, mock)] * 2
memcache_client._client_cache['1.2.3.4:11211'] = MockedMemcachePool(
[(mock, mock)] * 2)
memcache_client.set('some_key', [1, 2, 3])
self.assertEquals(memcache_client.get('some_key'), [1, 2, 3])
memcache_client._allow_pickle = False
@@ -328,6 +350,46 @@ class TestMemcached(unittest.TestCase):
memcache_client._allow_pickle = True
self.assertEquals(memcache_client.get('some_key'), [1, 2, 3])
def test_connection_pooling(self):
with patch('swift.common.memcached.socket') as mock_module:
# patch socket, stub socket.socket, mock sock
mock_sock = mock_module.socket.return_value
# track clients waiting for connections
connected = []
connections = Queue()
def wait_connect(addr):
connected.append(addr)
connections.get()
mock_sock.connect = wait_connect
memcache_client = memcached.MemcacheRing(['1.2.3.4:11211'],
connect_timeout=10)
# sanity
self.assertEquals(1, len(memcache_client._client_cache))
for server, pool in memcache_client._client_cache.items():
self.assertEquals(2, pool.max_size)
# make 10 requests "at the same time"
p = GreenPool()
for i in range(10):
p.spawn(memcache_client.set, 'key', 'value')
for i in range(3):
sleep(0.1)
self.assertEquals(2, len(connected))
# give out a connection
connections.put(None)
for i in range(3):
sleep(0.1)
self.assertEquals(2, len(connected))
# finish up
for i in range(8):
connections.put(None)
self.assertEquals(2, len(connected))
p.waitall()
self.assertEquals(2, len(connected))
if __name__ == '__main__':
unittest.main()