diff --git a/README.rst b/README.rst index 02e7b18..813695a 100644 --- a/README.rst +++ b/README.rst @@ -123,3 +123,4 @@ Credits * `Ernest W. Durbin III `_ * `Remco van Oosterhout `_ * `Nicholas Charriere `_ +* `Joe Gordon `_ diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 2a93d09..c5f5b2e 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -63,8 +63,11 @@ This client implements the ASCII protocol of memcached. This means keys should n contain any of the following illegal characters: > Keys cannot have spaces, new lines, carriage returns, or null characters. We suggest that if you have unicode characters, or long keys, you use an effective -hashing mechanism before calling this client. At Pinterest, we have found that murmur3 hash is a -great candidate for this. +hashing mechanism before calling this client. At Pinterest, we have found that +murmur3 hash is a great candidate for this. Alternatively you can +set `allow_unicode_keys` to support unicode keys, but beware of +what unicode encoding you use to make sure multiple clients can find the +same key. Best Practices diff --git a/pymemcache/client/base.py b/pymemcache/client/base.py index 384223f..1d4beb6 100644 --- a/pymemcache/client/base.py +++ b/pymemcache/client/base.py @@ -81,9 +81,12 @@ STAT_TYPES = { # Common helper functions. -def _check_key(key, key_prefix=b''): +def _check_key(key, allow_unicode_keys, key_prefix=b''): """Checks key and add key_prefix.""" - if isinstance(key, VALID_STRING_TYPES): + if allow_unicode_keys: + if isinstance(key, six.text_type): + key = key.encode('utf8') + elif isinstance(key, VALID_STRING_TYPES): try: key = key.encode('ascii') except (UnicodeEncodeError, UnicodeDecodeError): @@ -177,7 +180,8 @@ class Client(object): ignore_exc=False, socket_module=socket, key_prefix=b'', - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): """ Constructor. @@ -203,6 +207,7 @@ class Client(object): default_noreply: bool, the default value for 'noreply' as passed to store commands (except from cas, incr, and decr, which default to False). + allow_unicode_keys: bool, support unicode (utf8) keys Notes: The constructor does not make a connection to memcached. The first @@ -223,10 +228,11 @@ class Client(object): raise TypeError("key_prefix should be bytes.") self.key_prefix = key_prefix self.default_noreply = default_noreply + self.allow_unicode_keys = allow_unicode_keys def check_key(self, key): """Checks key and add key_prefix.""" - return _check_key(key, key_prefix=self.key_prefix) + return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) def _connect(self): sock = self.socket_module.socket(self.socket_module.AF_INET, @@ -841,7 +847,8 @@ class PooledClient(object): key_prefix=b'', max_pool_size=None, lock_generator=None, - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): self.server = server self.serializer = serializer self.deserializer = deserializer @@ -851,6 +858,7 @@ class PooledClient(object): self.ignore_exc = ignore_exc self.socket_module = socket_module self.default_noreply = default_noreply + self.allow_unicode_keys = allow_unicode_keys if isinstance(key_prefix, six.text_type): key_prefix = key_prefix.encode('ascii') if not isinstance(key_prefix, bytes): @@ -864,7 +872,7 @@ class PooledClient(object): def check_key(self, key): """Checks key and add key_prefix.""" - return _check_key(key, key_prefix=self.key_prefix) + return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) def _create_client(self): client = Client(self.server, @@ -878,7 +886,8 @@ class PooledClient(object): ignore_exc=False, socket_module=self.socket_module, key_prefix=self.key_prefix, - default_noreply=self.default_noreply) + default_noreply=self.default_noreply, + allow_unicode_keys=self.allow_unicode_keys) return client def close(self): diff --git a/pymemcache/client/hash.py b/pymemcache/client/hash.py index 55b6287..946608d 100644 --- a/pymemcache/client/hash.py +++ b/pymemcache/client/hash.py @@ -30,6 +30,7 @@ class HashClient(object): dead_timeout=60, use_pooling=False, ignore_exc=False, + allow_unicode_keys=False ): """ Constructor. @@ -66,6 +67,7 @@ class HashClient(object): self.use_pooling = use_pooling self.key_prefix = key_prefix self.ignore_exc = ignore_exc + self.allow_unicode_keys = allow_unicode_keys self._failed_clients = {} self._dead_clients = {} self._last_dead_check_time = time.time() @@ -80,6 +82,7 @@ class HashClient(object): 'key_prefix': key_prefix, 'serializer': serializer, 'deserializer': deserializer, + 'allow_unicode_keys': allow_unicode_keys, } if use_pooling is True: @@ -113,7 +116,7 @@ class HashClient(object): self.hasher.remove_node(key) def _get_client(self, key): - _check_key(key, self.key_prefix) + _check_key(key, self.allow_unicode_keys, self.key_prefix) if len(self._dead_clients) > 0: current_time = time.time() ldc = self._last_dead_check_time diff --git a/pymemcache/client/rendezvous.py b/pymemcache/client/rendezvous.py index 32ecc2b..46542ef 100644 --- a/pymemcache/client/rendezvous.py +++ b/pymemcache/client/rendezvous.py @@ -36,7 +36,7 @@ class RendezvousHash(object): for node in self.nodes: score = self.hash_function( - "%s-%s" % (str(node), str(key))) + "%s-%s" % (node, key)) if score > high_score: (high_score, winner) = (score, node) diff --git a/pymemcache/test/test_client.py b/pymemcache/test/test_client.py index ab72652..ae18c28 100644 --- a/pymemcache/test/test_client.py +++ b/pymemcache/test/test_client.py @@ -1,4 +1,5 @@ # Copyright 2012 Pinterest.com +# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -92,6 +93,18 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_key_ok(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set(u'\u0FFF', b'value', noreply=False) + assert result is True + + def test_set_unicode_key_ok_snowman(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set('my☃', b'value', noreply=False) + assert result is True + def test_set_unicode_char_in_middle_of_key(self): client = self.make_client([b'STORED\r\n']) @@ -101,6 +114,15 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_char_in_middle_of_key_snowman(self): + client = self.make_client([b'STORED\r\n']) + + def _set(): + client.set('my☃', b'value', noreply=False) + + with pytest.raises(MemcacheIllegalInputError): + _set() + def test_set_unicode_value(self): client = self.make_client([b'']) @@ -110,6 +132,12 @@ class ClientTestMixin(object): with pytest.raises(MemcacheIllegalInputError): _set() + def test_set_unicode_char_in_middle_of_key_ok(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + result = client.set('helloworld_\xb1901520_%c3', b'value', noreply=False) + assert result is True + def test_set_noreply(self): client = self.make_client([]) result = client.set(b'key', b'value', noreply=True) @@ -626,6 +654,15 @@ class TestClient(ClientTestMixin, unittest.TestCase): with pytest.raises(MemcacheClientError): client.get(b'x' * 251) + def test_too_long_unicode_key(self): + client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) + + with pytest.raises(MemcacheClientError): + client.get('my☃'*150) + + with pytest.raises(MemcacheClientError): + client.get(u'\u0FFF'*150) + def test_key_contains_spae(self): client = self.make_client([b'END\r\n']) with pytest.raises(MemcacheClientError): diff --git a/pymemcache/test/test_client_hash.py b/pymemcache/test/test_client_hash.py index 393a191..f523f4a 100644 --- a/pymemcache/test/test_client_hash.py +++ b/pymemcache/test/test_client_hash.py @@ -12,8 +12,8 @@ import socket class TestHashClient(ClientTestMixin, unittest.TestCase): - def make_client_pool(self, hostname, mock_socket_values, serializer=None): - mock_client = Client(hostname, serializer=serializer) + def make_client_pool(self, hostname, mock_socket_values, serializer=None, **kwargs): + mock_client = Client(hostname, serializer=serializer, **kwargs) mock_client.sock = MockSocket(mock_socket_values) client = PooledClient(hostname, serializer=serializer) client.client_pool = pool.ObjectPool(lambda: mock_client) @@ -28,7 +28,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): s = '%s:%s' % (ip, current_port) c = self.make_client_pool( (ip, current_port), - vals + vals, + **kwargs ) client.clients[s] = c client.hasher.add_node(s) diff --git a/pymemcache/test/test_integration.py b/pymemcache/test/test_integration.py index 17b26fd..32b6b6f 100644 --- a/pymemcache/test/test_integration.py +++ b/pymemcache/test/test_integration.py @@ -1,4 +1,5 @@ # Copyright 2012 Pinterest.com +# -*- coding: utf-8 -*- # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,27 +24,47 @@ from pymemcache.exceptions import ( ) +def get_set_helper(client, key, value, key2, value2): + result = client.get(key) + assert result is None + + client.set(key, value, noreply=False) + result = client.get(key) + assert result == value + + client.set(key2, value2, noreply=True) + result = client.get(key2) + assert result == value2 + + result = client.get_many([key, key2]) + assert result == {key: value, key2: value2} + + result = client.get_many([]) + assert result == {} + + @pytest.mark.integration() def test_get_set(client_class, host, port, socket_module): client = client_class((host, port), socket_module=socket_module) client.flush_all() - result = client.get('key') - assert result is None + key = b'key' + value = b'value' + key2 = b'key2' + value2 = b'value2' + get_set_helper(client, key, value, key2, value2) - client.set(b'key', b'value', noreply=False) - result = client.get(b'key') - assert result == b'value' - client.set(b'key2', b'value2', noreply=True) - result = client.get(b'key2') - assert result == b'value2' +@pytest.mark.integration() +def test_get_set_unicode_key(client_class, host, port, socket_module): + client = client_class((host, port), socket_module=socket_module, allow_unicode_keys=True) + client.flush_all() - result = client.get_many([b'key', b'key2']) - assert result == {b'key': b'value', b'key2': b'value2'} - - result = client.get_many([]) - assert result == {} + key = u"こんにちは" + value = b'hello' + key2 = 'my☃' + value2 = b'value2' + get_set_helper(client, key, value, key2, value2) @pytest.mark.integration() diff --git a/pymemcache/test/utils.py b/pymemcache/test/utils.py index 4414031..a4a539e 100644 --- a/pymemcache/test/utils.py +++ b/pymemcache/test/utils.py @@ -26,12 +26,14 @@ class MockMemcacheClient(object): timeout=None, no_delay=False, ignore_exc=False, - default_noreply=True): + default_noreply=True, + allow_unicode_keys=False): self._contents = {} self.serializer = serializer self.deserializer = deserializer + self.allow_unicode_keys = allow_unicode_keys # Unused, but present for interface compatibility self.server = server @@ -41,13 +43,14 @@ class MockMemcacheClient(object): self.ignore_exc = ignore_exc def get(self, key, default=None): - if isinstance(key, six.text_type): - raise MemcacheIllegalInputError(key) - if isinstance(key, six.string_types): - try: - key = key.encode('ascii') - except (UnicodeEncodeError, UnicodeDecodeError): - raise MemcacheIllegalInputError + if not self.allow_unicode_keys: + if isinstance(key, six.text_type): + raise MemcacheIllegalInputError(key) + if isinstance(key, six.string_types): + try: + key = key.encode('ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + raise MemcacheIllegalInputError if key not in self._contents: return default @@ -72,15 +75,16 @@ class MockMemcacheClient(object): get_multi = get_many def set(self, key, value, expire=0, noreply=True): - if isinstance(key, six.text_type): - raise MemcacheIllegalInputError(key) + if not self.allow_unicode_keys: + if isinstance(key, six.text_type): + raise MemcacheIllegalInputError(key) + if isinstance(key, six.string_types): + try: + key = key.encode('ascii') + except (UnicodeEncodeError, UnicodeDecodeError): + raise MemcacheIllegalInputError if isinstance(value, six.text_type): raise MemcacheIllegalInputError(value) - if isinstance(key, six.string_types): - try: - key = key.encode('ascii') - except (UnicodeEncodeError, UnicodeDecodeError): - raise MemcacheIllegalInputError if isinstance(value, six.string_types): try: value = value.encode('ascii')