From 3613587536673154b45dbb8fe482e736f13a3a36 Mon Sep 17 00:00:00 2001 From: Joe Gordon Date: Mon, 31 Oct 2016 14:04:00 -0700 Subject: [PATCH] Add optional support for unicode keys memcached's ASCII protocol supports unicode keys, so lets support them as well. Since using unicode keys for memcache is uncommon and to preserve the previous behavior disable support by default. --- README.rst | 1 + docs/getting_started.rst | 7 +++-- pymemcache/client/base.py | 23 +++++++++----- pymemcache/client/hash.py | 5 ++- pymemcache/client/rendezvous.py | 2 +- pymemcache/test/test_client.py | 37 +++++++++++++++++++++++ pymemcache/test/test_client_hash.py | 7 +++-- pymemcache/test/test_integration.py | 47 +++++++++++++++++++++-------- pymemcache/test/utils.py | 34 ++++++++++++--------- 9 files changed, 121 insertions(+), 42 deletions(-) diff --git a/README.rst b/README.rst index 02e7b18..813695a 100644 --- a/README.rst +++ b/README.rst @@ -123,3 +123,4 @@ Credits * `Ernest W. Durbin III `_ * `Remco van Oosterhout `_ * `Nicholas Charriere `_ +* `Joe Gordon `_ diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 2a93d09..c5f5b2e 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -63,8 +63,11 @@ This client implements the ASCII protocol of memcached. This means keys should n contain any of the following illegal characters: > Keys cannot have spaces, new lines, carriage returns, or null characters. We suggest that if you have unicode characters, or long keys, you use an effective -hashing mechanism before calling this client. At Pinterest, we have found that murmur3 hash is a -great candidate for this. +hashing mechanism before calling this client. At Pinterest, we have found that +murmur3 hash is a great candidate for this. Alternatively you can +set `allow_unicode_keys` to support unicode keys, but beware of +what unicode encoding you use to make sure multiple clients can find the +same key. Best Practices diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 384223f..1d4beb6 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -81,9 +81,12 @@ STAT_TYPES = { # Common helper functions. -def _check_key(key, key_prefix=b''): +def _check_key(key, allow_unicode_keys, key_prefix=b''): """Checks key and add key_prefix.""" - if isinstance(key, VALID_STRING_TYPES): + if allow_unicode_keys: + if isinstance(key, six.text_type): + key = key.encode('utf8') + elif isinstance(key, VALID_STRING_TYPES): try: key = key.encode('ascii') except (UnicodeEncodeError, UnicodeDecodeError): @@ -177,7 +180,8 @@ class Client(object): ignore_exc=False, socket_module=socket, key_prefix=b'', - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): """ Constructor. @@ -203,6 +207,7 @@ class Client(object): default_noreply: bool, the default value for 'noreply' as passed to store commands (except from cas, incr, and decr, which default to False). + allow_unicode_keys: bool, support unicode (utf8) keys Notes: The constructor does not make a connection to memcached. The first @@ -223,10 +228,11 @@ class Client(object): raise TypeError("key_prefix should be bytes.") self.key_prefix = key_prefix self.default_noreply = default_noreply + self.allow_unicode_keys = allow_unicode_keys def check_key(self, key): """Checks key and add key_prefix.""" - return _check_key(key, key_prefix=self.key_prefix) + return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) def _connect(self): sock = self.socket_module.socket(self.socket_module.AF_INET, @@ -841,7 +847,8 @@ class PooledClient(object): key_prefix=b'', max_pool_size=None, lock_generator=None, - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): self.server = server self.serializer = serializer self.deserializer = deserializer @@ -851,6 +858,7 @@ class PooledClient(object): self.ignore_exc = ignore_exc self.socket_module = socket_module self.default_noreply = default_noreply + self.allow_unicode_keys = allow_unicode_keys if isinstance(key_prefix, six.text_type): key_prefix = key_prefix.encode('ascii') if not isinstance(key_prefix, bytes): @@ -864,7 +872,7 @@ class PooledClient(object): def check_key(self, key): """Checks key and add key_prefix.""" - return _check_key(key, key_prefix=self.key_prefix) + return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) def _create_client(self): client = Client(self.server, @@ -878,7 +886,8 @@ class PooledClient(object): ignore_exc=False, socket_module=self.socket_module, key_prefix=self.key_prefix, - default_noreply=self.default_noreply) + default_noreply=self.default_noreply, + allow_unicode_keys=self.allow_unicode_keys) return client def close(self): diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 55b6287..946608d 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -30,6 +30,7 @@ class HashClient(object): dead_timeout=60, use_pooling=False, ignore_exc=False, + allow_unicode_keys=False ): """ Constructor. @@ -66,6 +67,7 @@ class HashClient(object): self.use_pooling = use_pooling self.key_prefix = key_prefix self.ignore_exc = ignore_exc + self.allow_unicode_keys = allow_unicode_keys self._failed_clients = {} self._dead_clients = {} self._last_dead_check_time = time.time() @@ -80,6 +82,7 @@ class HashClient(object): 'key_prefix': key_prefix, 'serializer': serializer, 'deserializer': deserializer, + 'allow_unicode_keys': allow_unicode_keys, } if use_pooling is True: @@ -113,7 +116,7 @@ class HashClient(object): self.hasher.remove_node(key) def _get_client(self, key): - _check_key(key, self.key_prefix) + _check_key(key, self.allow_unicode_keys, self.key_prefix) if len(self._dead_clients) > 0: current_time = time.time() ldc = self._last_dead_check_time diff --git a/pymemcache/client/rendezvous.py b/pymemcache/client/rendezvous.py index 32ecc2b..46542ef 100644 --- a/pymemcache/client/rendezvous.py +++ b/pymemcache/client/rendezvous.py @@ -36,7 +36,7 @@ class RendezvousHash(object): for node in self.nodes: score = self.hash_function( - "%s-%s" % (str(node), str(key))) + "%s-%s" % (node, key)) if score > high_score: (high_score, winner) = (score, node) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index ab72652..ae18c28 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1,4 +1,5 @@ # Copyright 2012 Pinterest.com +# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -92,6 +93,18 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_key_ok(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set(u'\u0FFF', b'value', noreply=False) + assert result is True + + def test_set_unicode_key_ok_snowman(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set('my☃', b'value', noreply=False) + assert result is True + def test_set_unicode_char_in_middle_of_key(self): client = self.make_client([b'STORED\r\n']) @@ -101,6 +114,15 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_char_in_middle_of_key_snowman(self): + client = self.make_client([b'STORED\r\n']) + + def _set(): + client.set('my☃', b'value', noreply=False) + + with pytest.raises(MemcacheIllegalInputError): + _set() + def test_set_unicode_value(self): client = self.make_client([b'']) @@ -110,6 +132,12 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_char_in_middle_of_key_ok(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set('helloworld_\xb1901520_%c3', b'value', noreply=False) + assert result is True + def test_set_noreply(self): client = self.make_client([]) result = client.set(b'key', b'value', noreply=True) @@ -626,6 +654,15 @@ class TestClient(ClientTestMixin, unittest.TestCase): with pytest.raises(MemcacheClientError): client.get(b'x' * 251) + def test_too_long_unicode_key(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + with pytest.raises(MemcacheClientError): + client.get('my☃'*150) + + with pytest.raises(MemcacheClientError): + client.get(u'\u0FFF'*150) + def test_key_contains_spae(self): client = self.make_client([b'END\r\n']) with pytest.raises(MemcacheClientError): diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 393a191..f523f4a 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -12,8 +12,8 @@ import socket class TestHashClient(ClientTestMixin, unittest.TestCase): - def make_client_pool(self, hostname, mock_socket_values, serializer=None): - mock_client = Client(hostname, serializer=serializer) + def make_client_pool(self, hostname, mock_socket_values, serializer=None, **kwargs): + mock_client = Client(hostname, serializer=serializer, **kwargs) mock_client.sock = MockSocket(mock_socket_values) client = PooledClient(hostname, serializer=serializer) client.client_pool = pool.ObjectPool(lambda: mock_client) @@ -28,7 +28,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): s = '%s:%s' % (ip, current_port) c = self.make_client_pool( (ip, current_port), - vals + vals, + **kwargs ) client.clients[s] = c client.hasher.add_node(s) diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 17b26fd..32b6b6f 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -1,4 +1,5 @@ # Copyright 2012 Pinterest.com +# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,27 +24,47 @@ from pymemcache.exceptions import ( ) +def get_set_helper(client, key, value, key2, value2): + result = client.get(key) + assert result is None + + client.set(key, value, noreply=False) + result = client.get(key) + assert result == value + + client.set(key2, value2, noreply=True) + result = client.get(key2) + assert result == value2 + + result = client.get_many([key, key2]) + assert result == {key: value, key2: value2} + + result = client.get_many([]) + assert result == {} + + @pytest.mark.integration() def test_get_set(client_class, host, port, socket_module): client = client_class((host, port), socket_module=socket_module) client.flush_all() - result = client.get('key') - assert result is None + key = b'key' + value = b'value' + key2 = b'key2' + value2 = b'value2' + get_set_helper(client, key, value, key2, value2) - client.set(b'key', b'value', noreply=False) - result = client.get(b'key') - assert result == b'value' - client.set(b'key2', b'value2', noreply=True) - result = client.get(b'key2') - assert result == b'value2' +@pytest.mark.integration() +def test_get_set_unicode_key(client_class, host, port, socket_module): + client = client_class((host, port), socket_module=socket_module, allow_unicode_keys=True) + client.flush_all() - result = client.get_many([b'key', b'key2']) - assert result == {b'key': b'value', b'key2': b'value2'} - - result = client.get_many([]) - assert result == {} + key = u"こんにちは" + value = b'hello' + key2 = 'my☃' + value2 = b'value2' + get_set_helper(client, key, value, key2, value2) @pytest.mark.integration() diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index 4414031..a4a539e 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -26,12 +26,14 @@ class MockMemcacheClient(object): timeout=None, no_delay=False, ignore_exc=False, - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): self._contents = {} self.serializer = serializer self.deserializer = deserializer + self.allow_unicode_keys = allow_unicode_keys # Unused, but present for interface compatibility self.server = server @@ -41,13 +43,14 @@ class MockMemcacheClient(object): self.ignore_exc = ignore_exc def get(self, key, default=None): - if isinstance(key, six.text_type): - raise MemcacheIllegalInputError(key) - if isinstance(key, six.string_types): - try: - key = key.encode('ascii') - except (UnicodeEncodeError, UnicodeDecodeError): - raise MemcacheIllegalInputError + if not self.allow_unicode_keys: + if isinstance(key, six.text_type): + raise MemcacheIllegalInputError(key) + if isinstance(key, six.string_types): + try: + key = key.encode('ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + raise MemcacheIllegalInputError if key not in self._contents: return default @@ -72,15 +75,16 @@ class MockMemcacheClient(object): get_multi = get_many def set(self, key, value, expire=0, noreply=True): - if isinstance(key, six.text_type): - raise MemcacheIllegalInputError(key) + if not self.allow_unicode_keys: + if isinstance(key, six.text_type): + raise MemcacheIllegalInputError(key) + if isinstance(key, six.string_types): + try: + key = key.encode('ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + raise MemcacheIllegalInputError if isinstance(value, six.text_type): raise MemcacheIllegalInputError(value) - if isinstance(key, six.string_types): - try: - key = key.encode('ascii') - except (UnicodeEncodeError, UnicodeDecodeError): - raise MemcacheIllegalInputError if isinstance(value, six.string_types): try: value = value.encode('ascii')