Add optional support for unicode keys
memcached's ASCII protocol supports unicode keys, so lets support them as well. Since using unicode keys for memcache is uncommon and to preserve the previous behavior disable support by default.
This commit is contained in:
		| @@ -123,3 +123,4 @@ Credits | ||||
| * `Ernest W. Durbin III <https://github.com/ewdurbin>`_ | ||||
| * `Remco van Oosterhout <https://github.com/Vhab>`_ | ||||
| * `Nicholas Charriere <https://github.com/nichochar>`_ | ||||
| * `Joe Gordon <https://github.com/jogo>`_ | ||||
|   | ||||
| @@ -63,8 +63,11 @@ This client implements the ASCII protocol of memcached. This means keys should n | ||||
| contain any of the following illegal characters: | ||||
| > Keys cannot have spaces, new lines, carriage returns, or null characters. | ||||
| We suggest that if you have unicode characters, or long keys, you use an effective | ||||
| hashing mechanism before calling this client. At Pinterest, we have found that murmur3 hash is a | ||||
| great candidate for this. | ||||
| hashing mechanism before calling this client. At Pinterest, we have found that | ||||
| murmur3 hash is a great candidate for this. Alternatively you can | ||||
| set `allow_unicode_keys` to support unicode keys, but beware of | ||||
| what unicode encoding you use to make sure multiple clients can find the | ||||
| same key. | ||||
|  | ||||
|  | ||||
| Best Practices | ||||
|   | ||||
| @@ -81,9 +81,12 @@ STAT_TYPES = { | ||||
| # Common helper functions. | ||||
|  | ||||
|  | ||||
| def _check_key(key, key_prefix=b''): | ||||
| def _check_key(key, allow_unicode_keys, key_prefix=b''): | ||||
|     """Checks key and add key_prefix.""" | ||||
|     if isinstance(key, VALID_STRING_TYPES): | ||||
|     if allow_unicode_keys: | ||||
|         if isinstance(key, six.text_type): | ||||
|             key = key.encode('utf8') | ||||
|     elif isinstance(key, VALID_STRING_TYPES): | ||||
|         try: | ||||
|             key = key.encode('ascii') | ||||
|         except (UnicodeEncodeError, UnicodeDecodeError): | ||||
| @@ -177,7 +180,8 @@ class Client(object): | ||||
|                  ignore_exc=False, | ||||
|                  socket_module=socket, | ||||
|                  key_prefix=b'', | ||||
|                  default_noreply=True): | ||||
|                  default_noreply=True, | ||||
|                  allow_unicode_keys=False): | ||||
|         """ | ||||
|         Constructor. | ||||
|  | ||||
| @@ -203,6 +207,7 @@ class Client(object): | ||||
|           default_noreply: bool, the default value for 'noreply' as passed to | ||||
|             store commands (except from cas, incr, and decr, which default to | ||||
|             False). | ||||
|           allow_unicode_keys: bool, support unicode (utf8) keys | ||||
|  | ||||
|         Notes: | ||||
|           The constructor does not make a connection to memcached. The first | ||||
| @@ -223,10 +228,11 @@ class Client(object): | ||||
|             raise TypeError("key_prefix should be bytes.") | ||||
|         self.key_prefix = key_prefix | ||||
|         self.default_noreply = default_noreply | ||||
|         self.allow_unicode_keys = allow_unicode_keys | ||||
|  | ||||
|     def check_key(self, key): | ||||
|         """Checks key and add key_prefix.""" | ||||
|         return _check_key(key, key_prefix=self.key_prefix) | ||||
|         return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) | ||||
|  | ||||
|     def _connect(self): | ||||
|         sock = self.socket_module.socket(self.socket_module.AF_INET, | ||||
| @@ -841,7 +847,8 @@ class PooledClient(object): | ||||
|                  key_prefix=b'', | ||||
|                  max_pool_size=None, | ||||
|                  lock_generator=None, | ||||
|                  default_noreply=True): | ||||
|                  default_noreply=True, | ||||
|                  allow_unicode_keys=False): | ||||
|         self.server = server | ||||
|         self.serializer = serializer | ||||
|         self.deserializer = deserializer | ||||
| @@ -851,6 +858,7 @@ class PooledClient(object): | ||||
|         self.ignore_exc = ignore_exc | ||||
|         self.socket_module = socket_module | ||||
|         self.default_noreply = default_noreply | ||||
|         self.allow_unicode_keys = allow_unicode_keys | ||||
|         if isinstance(key_prefix, six.text_type): | ||||
|             key_prefix = key_prefix.encode('ascii') | ||||
|         if not isinstance(key_prefix, bytes): | ||||
| @@ -864,7 +872,7 @@ class PooledClient(object): | ||||
|  | ||||
|     def check_key(self, key): | ||||
|         """Checks key and add key_prefix.""" | ||||
|         return _check_key(key, key_prefix=self.key_prefix) | ||||
|         return _check_key(key, allow_unicode_keys=self.allow_unicode_keys, key_prefix=self.key_prefix) | ||||
|  | ||||
|     def _create_client(self): | ||||
|         client = Client(self.server, | ||||
| @@ -878,7 +886,8 @@ class PooledClient(object): | ||||
|                         ignore_exc=False, | ||||
|                         socket_module=self.socket_module, | ||||
|                         key_prefix=self.key_prefix, | ||||
|                         default_noreply=self.default_noreply) | ||||
|                         default_noreply=self.default_noreply, | ||||
|                         allow_unicode_keys=self.allow_unicode_keys) | ||||
|         return client | ||||
|  | ||||
|     def close(self): | ||||
|   | ||||
| @@ -30,6 +30,7 @@ class HashClient(object): | ||||
|         dead_timeout=60, | ||||
|         use_pooling=False, | ||||
|         ignore_exc=False, | ||||
|         allow_unicode_keys=False | ||||
|     ): | ||||
|         """ | ||||
|         Constructor. | ||||
| @@ -66,6 +67,7 @@ class HashClient(object): | ||||
|         self.use_pooling = use_pooling | ||||
|         self.key_prefix = key_prefix | ||||
|         self.ignore_exc = ignore_exc | ||||
|         self.allow_unicode_keys = allow_unicode_keys | ||||
|         self._failed_clients = {} | ||||
|         self._dead_clients = {} | ||||
|         self._last_dead_check_time = time.time() | ||||
| @@ -80,6 +82,7 @@ class HashClient(object): | ||||
|             'key_prefix': key_prefix, | ||||
|             'serializer': serializer, | ||||
|             'deserializer': deserializer, | ||||
|             'allow_unicode_keys': allow_unicode_keys, | ||||
|         } | ||||
|  | ||||
|         if use_pooling is True: | ||||
| @@ -113,7 +116,7 @@ class HashClient(object): | ||||
|         self.hasher.remove_node(key) | ||||
|  | ||||
|     def _get_client(self, key): | ||||
|         _check_key(key, self.key_prefix) | ||||
|         _check_key(key, self.allow_unicode_keys, self.key_prefix) | ||||
|         if len(self._dead_clients) > 0: | ||||
|             current_time = time.time() | ||||
|             ldc = self._last_dead_check_time | ||||
|   | ||||
| @@ -36,7 +36,7 @@ class RendezvousHash(object): | ||||
|  | ||||
|         for node in self.nodes: | ||||
|             score = self.hash_function( | ||||
|                 "%s-%s" % (str(node), str(key))) | ||||
|                 "%s-%s" % (node, key)) | ||||
|  | ||||
|             if score > high_score: | ||||
|                 (high_score, winner) = (score, node) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| # Copyright 2012 Pinterest.com | ||||
| # -*- coding: utf-8 -*- | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @@ -92,6 +93,18 @@ class ClientTestMixin(object): | ||||
|         with pytest.raises(MemcacheIllegalInputError): | ||||
|             _set() | ||||
|  | ||||
|     def test_set_unicode_key_ok(self): | ||||
|         client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) | ||||
|  | ||||
|         result = client.set(u'\u0FFF', b'value', noreply=False) | ||||
|         assert result is True | ||||
|  | ||||
|     def test_set_unicode_key_ok_snowman(self): | ||||
|         client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) | ||||
|  | ||||
|         result = client.set('my☃', b'value', noreply=False) | ||||
|         assert result is True | ||||
|  | ||||
|     def test_set_unicode_char_in_middle_of_key(self): | ||||
|         client = self.make_client([b'STORED\r\n']) | ||||
|  | ||||
| @@ -101,6 +114,15 @@ class ClientTestMixin(object): | ||||
|         with pytest.raises(MemcacheIllegalInputError): | ||||
|             _set() | ||||
|  | ||||
|     def test_set_unicode_char_in_middle_of_key_snowman(self): | ||||
|         client = self.make_client([b'STORED\r\n']) | ||||
|  | ||||
|         def _set(): | ||||
|             client.set('my☃', b'value', noreply=False) | ||||
|  | ||||
|         with pytest.raises(MemcacheIllegalInputError): | ||||
|             _set() | ||||
|  | ||||
|     def test_set_unicode_value(self): | ||||
|         client = self.make_client([b'']) | ||||
|  | ||||
| @@ -110,6 +132,12 @@ class ClientTestMixin(object): | ||||
|         with pytest.raises(MemcacheIllegalInputError): | ||||
|             _set() | ||||
|  | ||||
|     def test_set_unicode_char_in_middle_of_key_ok(self): | ||||
|         client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) | ||||
|  | ||||
|         result = client.set('helloworld_\xb1901520_%c3', b'value', noreply=False) | ||||
|         assert result is True | ||||
|  | ||||
|     def test_set_noreply(self): | ||||
|         client = self.make_client([]) | ||||
|         result = client.set(b'key', b'value', noreply=True) | ||||
| @@ -626,6 +654,15 @@ class TestClient(ClientTestMixin, unittest.TestCase): | ||||
|         with pytest.raises(MemcacheClientError): | ||||
|             client.get(b'x' * 251) | ||||
|  | ||||
|     def test_too_long_unicode_key(self): | ||||
|         client = self.make_client([b'STORED\r\n'], allow_unicode_keys=True) | ||||
|  | ||||
|         with pytest.raises(MemcacheClientError): | ||||
|             client.get('my☃'*150) | ||||
|  | ||||
|         with pytest.raises(MemcacheClientError): | ||||
|             client.get(u'\u0FFF'*150) | ||||
|  | ||||
|     def test_key_contains_spae(self): | ||||
|         client = self.make_client([b'END\r\n']) | ||||
|         with pytest.raises(MemcacheClientError): | ||||
|   | ||||
| @@ -12,8 +12,8 @@ import socket | ||||
|  | ||||
| class TestHashClient(ClientTestMixin, unittest.TestCase): | ||||
|  | ||||
|     def make_client_pool(self, hostname, mock_socket_values, serializer=None): | ||||
|         mock_client = Client(hostname, serializer=serializer) | ||||
|     def make_client_pool(self, hostname, mock_socket_values, serializer=None, **kwargs): | ||||
|         mock_client = Client(hostname, serializer=serializer, **kwargs) | ||||
|         mock_client.sock = MockSocket(mock_socket_values) | ||||
|         client = PooledClient(hostname, serializer=serializer) | ||||
|         client.client_pool = pool.ObjectPool(lambda: mock_client) | ||||
| @@ -28,7 +28,8 @@ class TestHashClient(ClientTestMixin, unittest.TestCase): | ||||
|             s = '%s:%s' % (ip, current_port) | ||||
|             c = self.make_client_pool( | ||||
|                 (ip, current_port), | ||||
|                 vals | ||||
|                 vals, | ||||
|                 **kwargs | ||||
|             ) | ||||
|             client.clients[s] = c | ||||
|             client.hasher.add_node(s) | ||||
|   | ||||
| @@ -1,4 +1,5 @@ | ||||
| # Copyright 2012 Pinterest.com | ||||
| # -*- coding: utf-8 -*- | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| @@ -23,27 +24,47 @@ from pymemcache.exceptions import ( | ||||
| ) | ||||
|  | ||||
|  | ||||
| def get_set_helper(client, key, value, key2, value2): | ||||
|     result = client.get(key) | ||||
|     assert result is None | ||||
|  | ||||
|     client.set(key, value, noreply=False) | ||||
|     result = client.get(key) | ||||
|     assert result == value | ||||
|  | ||||
|     client.set(key2, value2, noreply=True) | ||||
|     result = client.get(key2) | ||||
|     assert result == value2 | ||||
|  | ||||
|     result = client.get_many([key, key2]) | ||||
|     assert result == {key: value, key2: value2} | ||||
|  | ||||
|     result = client.get_many([]) | ||||
|     assert result == {} | ||||
|  | ||||
|  | ||||
| @pytest.mark.integration() | ||||
| def test_get_set(client_class, host, port, socket_module): | ||||
|     client = client_class((host, port), socket_module=socket_module) | ||||
|     client.flush_all() | ||||
|  | ||||
|     result = client.get('key') | ||||
|     assert result is None | ||||
|     key = b'key' | ||||
|     value = b'value' | ||||
|     key2 = b'key2' | ||||
|     value2 = b'value2' | ||||
|     get_set_helper(client, key, value, key2, value2) | ||||
|  | ||||
|     client.set(b'key', b'value', noreply=False) | ||||
|     result = client.get(b'key') | ||||
|     assert result == b'value' | ||||
|  | ||||
|     client.set(b'key2', b'value2', noreply=True) | ||||
|     result = client.get(b'key2') | ||||
|     assert result == b'value2' | ||||
| @pytest.mark.integration() | ||||
| def test_get_set_unicode_key(client_class, host, port, socket_module): | ||||
|     client = client_class((host, port), socket_module=socket_module, allow_unicode_keys=True) | ||||
|     client.flush_all() | ||||
|  | ||||
|     result = client.get_many([b'key', b'key2']) | ||||
|     assert result == {b'key': b'value', b'key2': b'value2'} | ||||
|  | ||||
|     result = client.get_many([]) | ||||
|     assert result == {} | ||||
|     key = u"こんにちは" | ||||
|     value = b'hello' | ||||
|     key2 = 'my☃' | ||||
|     value2 = b'value2' | ||||
|     get_set_helper(client, key, value, key2, value2) | ||||
|  | ||||
|  | ||||
| @pytest.mark.integration() | ||||
|   | ||||
| @@ -26,12 +26,14 @@ class MockMemcacheClient(object): | ||||
|                  timeout=None, | ||||
|                  no_delay=False, | ||||
|                  ignore_exc=False, | ||||
|                  default_noreply=True): | ||||
|                  default_noreply=True, | ||||
|                  allow_unicode_keys=False): | ||||
|  | ||||
|         self._contents = {} | ||||
|  | ||||
|         self.serializer = serializer | ||||
|         self.deserializer = deserializer | ||||
|         self.allow_unicode_keys = allow_unicode_keys | ||||
|  | ||||
|         # Unused, but present for interface compatibility | ||||
|         self.server = server | ||||
| @@ -41,6 +43,7 @@ class MockMemcacheClient(object): | ||||
|         self.ignore_exc = ignore_exc | ||||
|  | ||||
|     def get(self, key, default=None): | ||||
|         if not self.allow_unicode_keys: | ||||
|             if isinstance(key, six.text_type): | ||||
|                 raise MemcacheIllegalInputError(key) | ||||
|             if isinstance(key, six.string_types): | ||||
| @@ -72,15 +75,16 @@ class MockMemcacheClient(object): | ||||
|     get_multi = get_many | ||||
|  | ||||
|     def set(self, key, value, expire=0, noreply=True): | ||||
|         if not self.allow_unicode_keys: | ||||
|             if isinstance(key, six.text_type): | ||||
|                 raise MemcacheIllegalInputError(key) | ||||
|         if isinstance(value, six.text_type): | ||||
|             raise MemcacheIllegalInputError(value) | ||||
|             if isinstance(key, six.string_types): | ||||
|                 try: | ||||
|                     key = key.encode('ascii') | ||||
|                 except (UnicodeEncodeError, UnicodeDecodeError): | ||||
|                     raise MemcacheIllegalInputError | ||||
|         if isinstance(value, six.text_type): | ||||
|             raise MemcacheIllegalInputError(value) | ||||
|         if isinstance(value, six.string_types): | ||||
|             try: | ||||
|                 value = value.encode('ascii') | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Joe Gordon
					Joe Gordon