Add optional support for unicode keys

memcached's ASCII protocol supports unicode keys, so let's support them
as well. Since using unicode keys for memcache is uncommon, and to
preserve the previous behavior, support is disabled by default.
This commit is contained in:
Joe Gordon 2016-10-31 14:04:00 -07:00
parent f1c939be4d
commit 3613587536
9 changed files with 121 additions and 42 deletions

View File

@ -123,3 +123,4 @@ Credits
* `Ernest W. Durbin III <https://github.com/ewdurbin>`_
* `Remco van Oosterhout <https://github.com/Vhab>`_
* `Nicholas Charriere <https://github.com/nichochar>`_
* `Joe Gordon <https://github.com/jogo>`_

View File

@ -63,8 +63,11 @@ This client implements the ASCII protocol of memcached. This means keys should n
contain any of the following illegal characters:
> Keys cannot have spaces, new lines, carriage returns, or null characters.
We suggest that if you have unicode characters, or long keys, you use an effective
hashing mechanism before calling this client. At Pinterest, we have found that murmur3 hash is a
great candidate for this.
hashing mechanism before calling this client. At Pinterest, we have found that
murmur3 hash is a great candidate for this. Alternatively, you can
set `allow_unicode_keys` to support unicode keys (they are encoded as UTF-8),
but make sure every client that shares the cache uses the same unicode
encoding so that they all resolve to the same key.
Best Practices

View File

@ -81,9 +81,12 @@ STAT_TYPES = {
# Common helper functions.
def _check_key(key, key_prefix=b''):
def _check_key(key, allow_unicode_keys, key_prefix=b''):
"""Checks key and add key_prefix."""
if isinstance(key, VALID_STRING_TYPES):
if allow_unicode_keys:
if isinstance(key, six.text_type):
key = key.encode('utf8')
elif isinstance(key, VALID_STRING_TYPES):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
@ -177,7 +180,8 @@ class Client(object):
ignore_exc=False,
socket_module=socket,
key_prefix=b'',
default_noreply=True):
default_noreply=True,
allow_unicode_keys=False):
"""
Constructor.
@ -203,6 +207,7 @@ class Client(object):
default_noreply: bool, the default value for 'noreply' as passed to
store commands (except from cas, incr, and decr, which default to
False).
allow_unicode_keys: bool, support unicode (utf8) keys
Notes:
The constructor does not make a connection to memcached. The first
@ -223,10 +228,11 @@ class Client(object):
raise TypeError("key_prefix should be bytes.")
self.key_prefix = key_prefix
self.default_noreply = default_noreply
self.allow_unicode_keys = allow_unicode_keys
def check_key(self, key):
"""Checks key and add key_prefix."""
return _check_key(key, key_prefix=self.key_prefix)
return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix)
def _connect(self):
sock = self.socket_module.socket(self.socket_module.AF_INET,
@ -841,7 +847,8 @@ class PooledClient(object):
key_prefix=b'',
max_pool_size=None,
lock_generator=None,
default_noreply=True):
default_noreply=True,
allow_unicode_keys=False):
self.server = server
self.serializer = serializer
self.deserializer = deserializer
@ -851,6 +858,7 @@ class PooledClient(object):
self.ignore_exc = ignore_exc
self.socket_module = socket_module
self.default_noreply = default_noreply
self.allow_unicode_keys = allow_unicode_keys
if isinstance(key_prefix, six.text_type):
key_prefix = key_prefix.encode('ascii')
if not isinstance(key_prefix, bytes):
@ -864,7 +872,7 @@ class PooledClient(object):
def check_key(self, key):
"""Checks key and add key_prefix."""
return _check_key(key, key_prefix=self.key_prefix)
return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix)
def _create_client(self):
client = Client(self.server,
@ -878,7 +886,8 @@ class PooledClient(object):
ignore_exc=False,
socket_module=self.socket_module,
key_prefix=self.key_prefix,
default_noreply=self.default_noreply)
default_noreply=self.default_noreply,
allow_unicode_keys=self.allow_unicode_keys)
return client
def close(self):

View File

@ -30,6 +30,7 @@ class HashClient(object):
dead_timeout=60,
use_pooling=False,
ignore_exc=False,
allow_unicode_keys=False
):
"""
Constructor.
@ -66,6 +67,7 @@ class HashClient(object):
self.use_pooling = use_pooling
self.key_prefix = key_prefix
self.ignore_exc = ignore_exc
self.allow_unicode_keys = allow_unicode_keys
self._failed_clients = {}
self._dead_clients = {}
self._last_dead_check_time = time.time()
@ -80,6 +82,7 @@ class HashClient(object):
'key_prefix': key_prefix,
'serializer': serializer,
'deserializer': deserializer,
'allow_unicode_keys': allow_unicode_keys,
}
if use_pooling is True:
@ -113,7 +116,7 @@ class HashClient(object):
self.hasher.remove_node(key)
def _get_client(self, key):
_check_key(key, self.key_prefix)
_check_key(key, self.allow_unicode_keys, self.key_prefix)
if len(self._dead_clients) > 0:
current_time = time.time()
ldc = self._last_dead_check_time

View File

@ -36,7 +36,7 @@ class RendezvousHash(object):
for node in self.nodes:
score = self.hash_function(
"%s-%s" % (str(node), str(key)))
"%s-%s" % (node, key))
if score > high_score:
(high_score, winner) = (score, node)

View File

@ -1,4 +1,5 @@
# Copyright 2012 Pinterest.com
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -92,6 +93,18 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError):
_set()
# With allow_unicode_keys=True, set() on a non-ASCII text key must succeed
# and return True.
# NOTE(review): the list passed to make_client looks like the scripted server
# replies (here a single b'STORED\r\n') -- confirm against make_client.
def test_set_unicode_key_ok(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
# u'\u0FFF' is an explicit unicode literal on both Python 2 and Python 3.
result = client.set(u'\u0FFF', b'value', noreply=False)
assert result is True
# Same expectation as the u'\u0FFF' case, but with a literal snowman (U+2603)
# written directly in the source.
def test_set_unicode_key_ok_snowman(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
# The literal has no u prefix, so it is a byte string on Python 2 and a text
# string on Python 3; the test expects set() to return True either way.
result = client.set('my☃', b'value', noreply=False)
assert result is True
def test_set_unicode_char_in_middle_of_key(self):
client = self.make_client([b'STORED\r\n'])
@ -101,6 +114,15 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError):
_set()
# A client built without allow_unicode_keys must reject a key containing a
# non-ASCII character (snowman, U+2603) with MemcacheIllegalInputError.
def test_set_unicode_char_in_middle_of_key_snowman(self):
# allow_unicode_keys is not passed, so the client uses its default.
client = self.make_client([b'STORED\r\n'])
# Wrapped in a helper so only the set() call runs under pytest.raises.
def _set():
client.set('my☃', b'value', noreply=False)
with pytest.raises(MemcacheIllegalInputError):
_set()
def test_set_unicode_value(self):
client = self.make_client([b''])
@ -110,6 +132,12 @@ class ClientTestMixin(object):
with pytest.raises(MemcacheIllegalInputError):
_set()
# With allow_unicode_keys=True, a key with a non-ASCII character in the middle
# must be accepted and set() must return True.
def test_set_unicode_char_in_middle_of_key_ok(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
# '\xb1' is the only non-ASCII element of the key; '%c3' is plain ASCII text.
result = client.set('helloworld_\xb1901520_%c3', b'value', noreply=False)
assert result is True
def test_set_noreply(self):
client = self.make_client([])
result = client.set(b'key', b'value', noreply=True)
@ -626,6 +654,15 @@ class TestClient(ClientTestMixin, unittest.TestCase):
with pytest.raises(MemcacheClientError):
client.get(b'x' * 251)
# Over-long keys must still be rejected with MemcacheClientError when
# allow_unicode_keys=True.
def test_too_long_unicode_key(self):
client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True)
# 'my☃' * 150 is 450 characters long.
with pytest.raises(MemcacheClientError):
client.get('my☃'*150)
# u'\u0FFF' * 150 is only 150 code points but 450 bytes in UTF-8.
# NOTE(review): this presumably relies on the length limit being applied to
# the encoded key rather than the character count -- confirm in _check_key.
with pytest.raises(MemcacheClientError):
client.get(u'\u0FFF'*150)
def test_key_contains_spae(self):
client = self.make_client([b'END\r\n'])
with pytest.raises(MemcacheClientError):

View File

@ -12,8 +12,8 @@ import socket
class TestHashClient(ClientTestMixin, unittest.TestCase):
def make_client_pool(self, hostname, mock_socket_values, serializer=None):
mock_client = Client(hostname, serializer=serializer)
def make_client_pool(self, hostname, mock_socket_values, serializer=None, **kwargs):
mock_client = Client(hostname, serializer=serializer, **kwargs)
mock_client.sock = MockSocket(mock_socket_values)
client = PooledClient(hostname, serializer=serializer)
client.client_pool = pool.ObjectPool(lambda: mock_client)
@ -28,7 +28,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase):
s = '%s:%s' % (ip, current_port)
c = self.make_client_pool(
(ip, current_port),
vals
vals,
**kwargs
)
client.clients[s] = c
client.hasher.add_node(s)

View File

@ -1,4 +1,5 @@
# Copyright 2012 Pinterest.com
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -23,27 +24,47 @@ from pymemcache.exceptions import (
)
# Shared get/set round-trip assertions for the integration tests.
#
# client: object exposing get(), set() and get_many().
# key, value: first pair, stored with noreply=False.
# key2, value2: second pair, stored with noreply=True.
#
# NOTE(review): the first assertion requires `key` to be absent beforehand;
# callers presumably flush the server first -- confirm at each call site.
def get_set_helper(client, key, value, key2, value2):
# The key must not exist before anything is stored.
result = client.get(key)
assert result is None
# Store with noreply=False and read the value back.
client.set(key, value, noreply=False)
result = client.get(key)
assert result == value
# Store the second pair with noreply=True and read it back immediately.
client.set(key2, value2, noreply=True)
result = client.get(key2)
assert result == value2
# get_many must return both pairs keyed by exactly the key objects passed in.
result = client.get_many([key, key2])
assert result == {key: value, key2: value2}
# An empty key list yields an empty dict.
result = client.get_many([])
assert result == {}
@pytest.mark.integration()
def test_get_set(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module)
client.flush_all()
result = client.get('key')
assert result is None
key = b'key'
value = b'value'
key2 = b'key2'
value2 = b'value2'
get_set_helper(client, key, value, key2, value2)
client.set(b'key', b'value', noreply=False)
result = client.get(b'key')
assert result == b'value'
client.set(b'key2', b'value2', noreply=True)
result = client.get(b'key2')
assert result == b'value2'
@pytest.mark.integration()
def test_get_set_unicode_key(client_class, host, port, socket_module):
client = client_class((host, port), socket_module=socket_module, allow_unicode_keys=True)
client.flush_all()
result = client.get_many([b'key', b'key2'])
assert result == {b'key': b'value', b'key2': b'value2'}
result = client.get_many([])
assert result == {}
key = u"こんにちは"
value = b'hello'
key2 = 'my☃'
value2 = b'value2'
get_set_helper(client, key, value, key2, value2)
@pytest.mark.integration()

View File

@ -26,12 +26,14 @@ class MockMemcacheClient(object):
timeout=None,
no_delay=False,
ignore_exc=False,
default_noreply=True):
default_noreply=True,
allow_unicode_keys=False):
self._contents = {}
self.serializer = serializer
self.deserializer = deserializer
self.allow_unicode_keys = allow_unicode_keys
# Unused, but present for interface compatibility
self.server = server
@ -41,13 +43,14 @@ class MockMemcacheClient(object):
self.ignore_exc = ignore_exc
def get(self, key, default=None):
if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if not self.allow_unicode_keys:
if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if key not in self._contents:
return default
@ -72,15 +75,16 @@ class MockMemcacheClient(object):
get_multi = get_many
def set(self, key, value, expire=0, noreply=True):
if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if not self.allow_unicode_keys:
if isinstance(key, six.text_type):
raise MemcacheIllegalInputError(key)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if isinstance(value, six.text_type):
raise MemcacheIllegalInputError(value)
if isinstance(key, six.string_types):
try:
key = key.encode('ascii')
except (UnicodeEncodeError, UnicodeDecodeError):
raise MemcacheIllegalInputError
if isinstance(value, six.string_types):
try:
value = value.encode('ascii')