Merge pull request #10 from krallin/custom-socket-module

Let users provide Client with a Custom socket module
This commit is contained in:
Charles Gordon
2013-09-02 16:49:06 -07:00
3 changed files with 116 additions and 48 deletions

View File

@@ -226,7 +226,8 @@ class Client(object):
connect_timeout=None,
timeout=None,
no_delay=False,
ignore_exc=False):
ignore_exc=False,
socket_module=socket):
"""
Constructor.
@@ -245,6 +246,8 @@ class Client(object):
ignore_exc: optional bool, True to cause the "get", "gets",
"get_many" and "gets_many" calls to treat any errors as cache
misses. Defaults to False.
socket_module: socket module to use, e.g. gevent.socket. Defaults to
the standard library's socket module.
Notes:
The constructor does not make a connection to memcached. The first
@@ -257,20 +260,23 @@ class Client(object):
self.timeout = timeout
self.no_delay = no_delay
self.ignore_exc = ignore_exc
self.socket_module = socket_module
self.sock = None
self.buf = ''
def _connect(self):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock = self.socket_module.socket(self.socket_module.AF_INET,
self.socket_module.SOCK_STREAM)
sock.settimeout(self.connect_timeout)
sock.connect(self.server)
sock.settimeout(self.timeout)
if self.no_delay:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setsockopt(self.socket_module.IPPROTO_TCP,
self.socket_module.TCP_NODELAY, 1)
self.sock = sock
def close(self):
"""Close the connetion to memcached, if it is open. The next call to a
"""Close the connection to memcached, if it is open. The next call to a
method that requires a connection will re-open it."""
if self.sock is not None:
try:

View File

@@ -14,14 +14,16 @@
import argparse
import json
import socket
from pymemcache.client import Client, MemcacheClientError, MemcacheUnknownCommandError
from pymemcache.client import (Client, MemcacheClientError,
MemcacheUnknownCommandError)
from pymemcache.client import MemcacheIllegalInputError
from nose import tools
def get_set_test(host, port):
client = Client((host, port))
def get_set_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.get('key')
@@ -42,8 +44,8 @@ def get_set_test(host, port):
tools.assert_equal(result, {})
def add_replace_test(host, port):
client = Client((host, port))
def add_replace_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.add('key', 'value', noreply=False)
@@ -67,8 +69,8 @@ def add_replace_test(host, port):
tools.assert_equal(result, 'value2')
def append_prepend_test(host, port):
client = Client((host, port))
def append_prepend_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.append('key', 'value', noreply=False)
@@ -94,8 +96,8 @@ def append_prepend_test(host, port):
tools.assert_equal(result, 'beforevalueafter')
def cas_test(host, port):
client = Client((host, port))
def cas_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.cas('key', 'value', '1', noreply=False)
@@ -117,8 +119,8 @@ def cas_test(host, port):
tools.assert_equal(result, False)
def gets_test(host, port):
client = Client((host, port))
def gets_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.gets('key')
@@ -130,8 +132,8 @@ def gets_test(host, port):
tools.assert_equal(result[0], 'value')
def delete_test(host, port):
client = Client((host, port))
def delete_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.delete('key', noreply=False)
@@ -147,8 +149,8 @@ def delete_test(host, port):
tools.assert_equal(result, None)
def incr_decr_test(host, port):
client = Client((host, port))
def incr_decr_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
result = client.incr('key', 1, noreply=False)
@@ -173,12 +175,12 @@ def incr_decr_test(host, port):
tools.assert_equal(result, '0')
def misc_test(host, port):
client = Client((host, port))
def misc_test(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
def test_serialization_deserialization(host, port):
def test_serialization_deserialization(host, port, socket_module):
def _ser(key, value):
return json.dumps(value), 1
@@ -187,7 +189,8 @@ def test_serialization_deserialization(host, port):
return json.loads(value)
return value
client = Client((host, port), serializer=_ser, deserializer=_des)
client = Client((host, port), serializer=_ser, deserializer=_des,
socket_module=socket_module)
client.flush_all()
value = {'a': 'b', 'c': ['d']}
@@ -196,8 +199,8 @@ def test_serialization_deserialization(host, port):
tools.assert_equal(result, value)
def test_errors(host, port):
client = Client((host, port))
def test_errors(host, port, socket_module):
client = Client((host, port), socket_module=socket_module)
client.flush_all()
def _key_with_ws():
@@ -238,26 +241,38 @@ def main():
args = parser.parse_args()
print "Testing get and set..."
get_set_test(args.server, args.port)
print "Testing add and replace..."
add_replace_test(args.server, args.port)
print "Testing append and prepend..."
append_prepend_test(args.server, args.port)
print "Testing cas..."
cas_test(args.server, args.port)
print "Testing gets..."
gets_test(args.server, args.port)
print "Testing delete..."
delete_test(args.server, args.port)
print "Testing incr and decr..."
incr_decr_test(args.server, args.port)
print "Testing flush_all..."
misc_test(args.server, args.port)
print "Testing serialization and deserialization..."
test_serialization_deserialization(args.server, args.port)
print "Testing error cases..."
test_errors(args.server, args.port)
socket_modules = [socket]
try:
from gevent import socket as gevent_socket
except ImportError:
print "Skipping gevent (not installed)"
else:
socket_modules.append(gevent_socket)
for socket_module in socket_modules:
print "Testing with socket module:", socket_module.__name__
print "Testing get and set..."
get_set_test(args.server, args.port, socket_module)
print "Testing add and replace..."
add_replace_test(args.server, args.port, socket_module)
print "Testing append and prepend..."
append_prepend_test(args.server, args.port, socket_module)
print "Testing cas..."
cas_test(args.server, args.port, socket_module)
print "Testing gets..."
gets_test(args.server, args.port, socket_module)
print "Testing delete..."
delete_test(args.server, args.port, socket_module)
print "Testing incr and decr..."
incr_decr_test(args.server, args.port, socket_module)
print "Testing flush_all..."
misc_test(args.server, args.port, socket_module)
print "Testing serialization and deserialization..."
test_serialization_deserialization(args.server, args.port,
socket_module)
print "Testing error cases..."
test_errors(args.server, args.port, socket_module)
if __name__ == '__main__':

View File

@@ -14,6 +14,7 @@
import collections
import json
import socket
from nose import tools
from pymemcache.client import Client, MemcacheUnknownCommandError
@@ -26,6 +27,9 @@ class MockSocket(object):
self.recv_bufs = collections.deque(recv_bufs)
self.send_bufs = []
self.closed = False
self.timeouts = []
self.connections = []
self.socket_options = []
def sendall(self, value):
self.send_bufs.append(value)
@@ -39,6 +43,23 @@ class MockSocket(object):
raise value
return value
def settimeout(self, timeout):
self.timeouts.append(timeout)
def connect(self, server):
self.connections.append(server)
def setsockopt(self, level, option, value):
self.socket_options.append((level, option, value))
class MockSocketModule(object):
def socket(self, family, type):
return MockSocket([])
def __getattr__(self, name):
return getattr(socket, name)
def test_set_success():
client = Client(None)
@@ -131,7 +152,7 @@ def test_set_exception():
def test_set_many_success():
client = Client(None)
client.sock = MockSocket(['STORED\r\n'])
result = client.set_many({'key' : 'value'}, noreply=False)
result = client.set_many({'key': 'value'}, noreply=False)
tools.assert_equal(result, True)
tools.assert_equal(client.sock.closed, False)
tools.assert_equal(len(client.sock.send_bufs), 1)
@@ -142,7 +163,7 @@ def test_set_many_exception():
client.sock = MockSocket(['STORED\r\n', Exception('fail')])
def _set():
client.set_many({'key' : 'value', 'other' : 'value'}, noreply=False)
client.set_many({'key': 'value', 'other': 'value'}, noreply=False)
tools.assert_raises(Exception, _set)
tools.assert_equal(client.sock, None)
@@ -284,7 +305,6 @@ def test_cr_nl_boundaries():
result = client.get_many(['key1', 'key2'])
tools.assert_equals(result, {'key1': 'value1', 'key2': 'value2'})
client.sock = MockSocket(['VALUE key1 0 6\r\n',
'value1\r\n',
'VALUE key2 0 6\r\n',
@@ -489,6 +509,7 @@ def test_serialization():
'set key 0 0 20 noreply\r\n{"a": "b", "c": "d"}\r\n'
])
def test_stats():
client = Client(None)
client.sock = MockSocket(['STAT fake_stats 1\r\n', 'END\r\n'])
@@ -498,6 +519,7 @@ def test_stats():
])
tools.assert_equal(result, {'fake_stats': 1})
def test_stats_with_args():
client = Client(None)
client.sock = MockSocket(['STAT fake_stats 1\r\n', 'END\r\n'])
@@ -507,6 +529,7 @@ def test_stats_with_args():
])
tools.assert_equal(result, {'fake_stats': 1})
def test_stats_conversions():
client = Client(None)
client.sock = MockSocket([
@@ -540,3 +563,27 @@ def test_stats_conversions():
'version': '1.4.14',
}
tools.assert_equal(result, expected)
def test_socket_connect():
server = ("example.com", 11211)
client = Client(server, socket_module=MockSocketModule())
client._connect()
tools.assert_equal(client.sock.connections, [server])
timeout = 2
connect_timeout = 3
client = Client(server, connect_timeout=connect_timeout, timeout=timeout,
socket_module=MockSocketModule())
client._connect()
tools.assert_equal(client.sock.timeouts, [connect_timeout, timeout])
client = Client(server, socket_module=MockSocketModule())
client._connect()
tools.assert_equal(client.sock.socket_options, [])
client = Client(server, socket_module=MockSocketModule(), no_delay=True)
client._connect()
tools.assert_equal(client.sock.socket_options, [(socket.IPPROTO_TCP,
socket.TCP_NODELAY, 1)])