Add optional support for unicode keys

memcached's ASCII protocol supports unicode keys, so let's support them
as well. Since using unicode keys for memcache is uncommon, and to
preserve the previous behavior, support is disabled by default.
This commit is contained in:
Joe Gordon 2016-10-31 14:04:00 -07:00
parent f1c939be4d
commit 3613587536
9 changed files with 121 additions and 42 deletions

View File

@ -123,3 +123,4 @@ Credits
* `Ernest W. Durbin III <https://github.com/ewdurbin>`_ * `Ernest W. Durbin III <https://github.com/ewdurbin>`_
* `Remco van Oosterhout <https://github.com/Vhab>`_ * `Remco van Oosterhout <https://github.com/Vhab>`_
* `Nicholas Charriere <https://github.com/nichochar>`_ * `Nicholas Charriere <https://github.com/nichochar>`_
* `Joe Gordon <https://github.com/jogo>`_

View File

@ -63,8 +63,11 @@ This client implements the ASCII protocol of memcached. This means keys should n
contain any of the following illegal characters: contain any of the following illegal characters:
> Keys cannot have spaces, new lines, carriage returns, or null characters. > Keys cannot have spaces, new lines, carriage returns, or null characters.
We suggest that if you have unicode characters, or long keys, you use an effective We suggest that if you have unicode characters, or long keys, you use an effective
hashing mechanism before calling this client. At Pinterest, we have found that murmur3 hash is a hashing mechanism before calling this client. At Pinterest, we have found that
great candidate for this. murmur3 hash is a great candidate for this. Alternatively you can
set `allow_unicode_keys=True` to support unicode keys. Keys are then
encoded as UTF-8, so make sure every client uses the same encoding in
order to find the same key.
Best Practices Best Practices

View File

@ -81,9 +81,12 @@ STAT_TYPES = {
# Common helper functions. # Common helper functions.
def _check_key(key, key_prefix=b''): def _check_key(key, allow_unicode_keys, key_prefix=b''):
"""Checks key and add key_prefix.""" """Checks key and add key_prefix."""
if isinstance(key, VALID_STRING_TYPES): if allow_unicode_keys:
if isinstance(key, six.text_type):
key = key.encode('utf8')
elif isinstance(key, VALID_STRING_TYPES):
try: try:
key = key.encode('ascii') key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError): except (UnicodeEncodeError, UnicodeDecodeError):
@ -177,7 +180,8 @@ class Client(object):
ignore_exc=False, ignore_exc=False,
socket_module=socket, socket_module=socket,
key_prefix=b'', key_prefix=b'',
default_noreply=True): default_noreply=True,
allow_unicode_keys=False):
""" """
Constructor. Constructor.
@ -203,6 +207,7 @@ class Client(object):
default_noreply: bool, the default value for 'noreply' as passed to default_noreply: bool, the default value for 'noreply' as passed to
store commands (except from cas, incr, and decr, which default to store commands (except from cas, incr, and decr, which default to
False). False).
allow_unicode_keys: bool, support unicode (utf8) keys
Notes: Notes:
The constructor does not make a connection to memcached. The first The constructor does not make a connection to memcached. The first
@ -223,10 +228,11 @@ class Client(object):
raise TypeError("key_prefix should be bytes.") raise TypeError("key_prefix should be bytes.")
self.key_prefix = key_prefix self.key_prefix = key_prefix
self.default_noreply = default_noreply self.default_noreply = default_noreply
self.allow_unicode_keys = allow_unicode_keys
def check_key(self, key): def check_key(self, key):
"""Checks key and add key_prefix.""" """Checks key and add key_prefix."""
return _check_key(key, key_prefix=self.key_prefix) return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix)
def _connect(self): def _connect(self):
sock = self.socket_module.socket(self.socket_module.AF_INET, sock = self.socket_module.socket(self.socket_module.AF_INET,
@ -841,7 +847,8 @@ class PooledClient(object):
key_prefix=b'', key_prefix=b'',
max_pool_size=None, max_pool_size=None,
lock_generator=None, lock_generator=None,
default_noreply=True): default_noreply=True,
allow_unicode_keys=False):
self.server = server self.server = server
self.serializer = serializer self.serializer = serializer
self.deserializer = deserializer self.deserializer = deserializer
@ -851,6 +858,7 @@ class PooledClient(object):
self.ignore_exc = ignore_exc self.ignore_exc = ignore_exc
self.socket_module = socket_module self.socket_module = socket_module
self.default_noreply = default_noreply self.default_noreply = default_noreply
self.allow_unicode_keys = allow_unicode_keys
if isinstance(key_prefix, six.text_type): if isinstance(key_prefix, six.text_type):
key_prefix = key_prefix.encode('ascii') key_prefix = key_prefix.encode('ascii')
if not isinstance(key_prefix, bytes): if not isinstance(key_prefix, bytes):
@ -864,7 +872,7 @@ class PooledClient(object):
def check_key(self, key): def check_key(self, key):
"""Checks key and add key_prefix.""" """Checks key and add key_prefix."""
return _check_key(key, key_prefix=self.key_prefix) return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix)
def _create_client(self): def _create_client(self):
client = Client(self.server, client = Client(self.server,
@ -878,7 +886,8 @@ class PooledClient(object):
ignore_exc=False, ignore_exc=False,
socket_module=self.socket_module, socket_module=self.socket_module,
key_prefix=self.key_prefix, key_prefix=self.key_prefix,
default_noreply=self.default_noreply) default_noreply=self.default_noreply,
allow_unicode_keys=self.allow_unicode_keys)
return client return client
def close(self): def close(self):

View File

@ -30,6 +30,7 @@ class HashClient(object):
dead_timeout=60, dead_timeout=60,
use_pooling=False, use_pooling=False,
ignore_exc=False, ignore_exc=False,
allow_unicode_keys=False
): ):
""" """
Constructor. Constructor.
@ -66,6 +67,7 @@ class HashClient(object):
self.use_pooling = use_pooling self.use_pooling = use_pooling
self.key_prefix = key_prefix self.key_prefix = key_prefix
self.ignore_exc = ignore_exc self.ignore_exc = ignore_exc
self.allow_unicode_keys = allow_unicode_keys
self._failed_clients = {} self._failed_clients = {}
self._dead_clients = {} self._dead_clients = {}
self._last_dead_check_time = time.time() self._last_dead_check_time = time.time()
@ -80,6 +82,7 @@ class HashClient(object):
'key_prefix': key_prefix, 'key_prefix': key_prefix,
'serializer': serializer, 'serializer': serializer,
'deserializer': deserializer, 'deserializer': deserializer,
'allow_unicode_keys': allow_unicode_keys,
} }
if use_pooling is True: if use_pooling is True:
@ -113,7 +116,7 @@ class HashClient(object):
self.hasher.remove_node(key) self.hasher.remove_node(key)
def _get_client(self, key): def _get_client(self, key):
_check_key(key, self.key_prefix) _check_key(key, self.allow_unicode_keys, self.key_prefix)
if len(self._dead_clients) > 0: if len(self._dead_clients) > 0:
current_time = time.time() current_time = time.time()
ldc = self._last_dead_check_time ldc = self._last_dead_check_time

View File

@ -36,7 +36,7 @@ class RendezvousHash(object):
for node in self.nodes: for node in self.nodes:
score = self.hash_function( score = self.hash_function(
"%s-%s" % (str(node), str(key))) "%s-%s" % (node, key))
if score > high_score: if score > high_score:
(high_score, winner) = (score, node) (high_score, winner) = (score, node)

View File

@ -1,4 +1,5 @@
# Copyright 2012 Pinterest.com # Copyright 2012 Pinterest.com
# -*- coding: utf-8 -*-
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -92,6 +93,18 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError): with pytest.raises(MemcacheIllegalInputError):
_set() _set()
def test_set_unicode_key_ok(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
result = client.set(u'\u0FFF', b'value', noreply=False)
assert result is True
def test_set_unicode_key_ok_snowman(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
result = client.set('my☃', b'value', noreply=False)
assert result is True
def test_set_unicode_char_in_middle_of_key(self): def test_set_unicode_char_in_middle_of_key(self):
client = self.make_client([b'STORED\r\n']) client = self.make_client([b'STORED\r\n'])
@ -101,6 +114,15 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError): with pytest.raises(MemcacheIllegalInputError):
_set() _set()
def test_set_unicode_char_in_middle_of_key_snowman(self):
client = self.make_client([b'STORED\r\n'])
def _set():
client.set('my☃', b'value', noreply=False)
with pytest.raises(MemcacheIllegalInputError):
_set()
def test_set_unicode_value(self): def test_set_unicode_value(self):
client = self.make_client([b'']) client = self.make_client([b''])
@ -110,6 +132,12 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError): with pytest.raises(MemcacheIllegalInputError):
_set() _set()
def test_set_unicode_char_in_middle_of_key_ok(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
result = client.set('helloworld_\xb1901520_%c3', b'value', noreply=False)
assert result is True
def test_set_noreply(self): def test_set_noreply(self):
client = self.make_client([]) client = self.make_client([])
result = client.set(b'key', b'value', noreply=True) result = client.set(b'key', b'value', noreply=True)
@ -626,6 +654,15 @@ class TestClient(ClientTestMixin, unittest.TestCase):
with pytest.raises(MemcacheClientError): with pytest.raises(MemcacheClientError):
client.get(b'x' * 251) client.get(b'x' * 251)
def test_too_long_unicode_key(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
with pytest.raises(MemcacheClientError):
client.get('my☃'*150)
with pytest.raises(MemcacheClientError):
client.get(u'\u0FFF'*150)
def test_key_contains_spae(self): def test_key_contains_spae(self):
client = self.make_client([b'END\r\n']) client = self.make_client([b'END\r\n'])
with pytest.raises(MemcacheClientError): with pytest.raises(MemcacheClientError):

View File

@ -12,8 +12,8 @@ import socket
class TestHashClient(ClientTestMixin, unittest.TestCase): class TestHashClient(ClientTestMixin, unittest.TestCase):
def make_client_pool(self, hostname, mock_socket_values, serializer=None): def make_client_pool(self, hostname, mock_socket_values, serializer=None, **kwargs):
mock_client = Client(hostname, serializer=serializer) mock_client = Client(hostname, serializer=serializer, **kwargs)
mock_client.sock = MockSocket(mock_socket_values) mock_client.sock = MockSocket(mock_socket_values)
client = PooledClient(hostname, serializer=serializer) client = PooledClient(hostname, serializer=serializer)
client.client_pool = pool.ObjectPool(lambda: mock_client) client.client_pool = pool.ObjectPool(lambda: mock_client)
@ -28,7 +28,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
s = '%s:%s' % (ip, current_port) s = '%s:%s' % (ip, current_port)
c = self.make_client_pool( c = self.make_client_pool(
(ip, current_port), (ip, current_port),
vals vals,
**kwargs
) )
client.clients[s] = c client.clients[s] = c
client.hasher.add_node(s) client.hasher.add_node(s)

View File

@ -1,4 +1,5 @@
# Copyright 2012 Pinterest.com # Copyright 2012 Pinterest.com
# -*- coding: utf-8 -*-
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -23,27 +24,47 @@ from pymemcache.exceptions import (
) )
def get_set_helper(client, key, value, key2, value2):
result = client.get(key)
assert result is None
client.set(key, value, noreply=False)
result = client.get(key)
assert result == value
client.set(key2, value2, noreply=True)
result = client.get(key2)
assert result == value2
result = client.get_many([key, key2])
assert result == {key: value, key2: value2}
result = client.get_many([])
assert result == {}
@pytest.mark.integration() @pytest.mark.integration()
def test_get_set(client_class, host, port, socket_module): def test_get_set(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module) client = client_class((host, port), socket_module=socket_module)
client.flush_all() client.flush_all()
result = client.get('key') key = b'key'
assert result is None value = b'value'
key2 = b'key2'
value2 = b'value2'
get_set_helper(client, key, value, key2, value2)
client.set(b'key', b'value', noreply=False)
result = client.get(b'key')
assert result == b'value'
client.set(b'key2', b'value2', noreply=True) @pytest.mark.integration()
result = client.get(b'key2') def test_get_set_unicode_key(client_class, host, port, socket_module):
assert result == b'value2' client = client_class((host, port), socket_module=socket_module, allow_unicode_keys=True)
client.flush_all()
result = client.get_many([b'key', b'key2']) key = u"こんにちは"
assert result == {b'key': b'value', b'key2': b'value2'} value = b'hello'
key2 = 'my☃'
result = client.get_many([]) value2 = b'value2'
assert result == {} get_set_helper(client, key, value, key2, value2)
@pytest.mark.integration() @pytest.mark.integration()

View File

@ -26,12 +26,14 @@ class MockMemcacheClient(object):
timeout=None, timeout=None,
no_delay=False, no_delay=False,
ignore_exc=False, ignore_exc=False,
default_noreply=True): default_noreply=True,
allow_unicode_keys=False):
self._contents = {} self._contents = {}
self.serializer = serializer self.serializer = serializer
self.deserializer = deserializer self.deserializer = deserializer
self.allow_unicode_keys = allow_unicode_keys
# Unused, but present for interface compatibility # Unused, but present for interface compatibility
self.server = server self.server = server
@ -41,13 +43,14 @@ class MockMemcacheClient(object):
self.ignore_exc = ignore_exc self.ignore_exc = ignore_exc
def get(self, key, default=None): def get(self, key, default=None):
if isinstance(key, six.text_type): if not self.allow_unicode_keys:
raise MemcacheIllegalInputError(key) if isinstance(key, six.text_type):
if isinstance(key, six.string_types): raise MemcacheIllegalInputError(key)
try: if isinstance(key, six.string_types):
key = key.encode('ascii') try:
except (UnicodeEncodeError, UnicodeDecodeError): key = key.encode('ascii')
raise MemcacheIllegalInputError except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if key not in self._contents: if key not in self._contents:
return default return default
@ -72,15 +75,16 @@ class MockMemcacheClient(object):
get_multi = get_many get_multi = get_many
def set(self, key, value, expire=0, noreply=True): def set(self, key, value, expire=0, noreply=True):
if isinstance(key, six.text_type): if not self.allow_unicode_keys:
raise MemcacheIllegalInputError(key) if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if isinstance(value, six.text_type): if isinstance(value, six.text_type):
raise MemcacheIllegalInputError(value) raise MemcacheIllegalInputError(value)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if isinstance(value, six.string_types): if isinstance(value, six.string_types):
try: try:
value = value.encode('ascii') value = value.encode('ascii')