Merge pull request #10 from krallin/custom-socket-module
Let users provide Client with a Custom socket module
This commit is contained in:
@@ -226,7 +226,8 @@ class Client(object):
|
||||
connect_timeout=None,
|
||||
timeout=None,
|
||||
no_delay=False,
|
||||
ignore_exc=False):
|
||||
ignore_exc=False,
|
||||
socket_module=socket):
|
||||
"""
|
||||
Constructor.
|
||||
|
||||
@@ -245,6 +246,8 @@ class Client(object):
|
||||
ignore_exc: optional bool, True to cause the "get", "gets",
|
||||
"get_many" and "gets_many" calls to treat any errors as cache
|
||||
misses. Defaults to False.
|
||||
socket_module: socket module to use, e.g. gevent.socket. Defaults to
|
||||
the standard library's socket module.
|
||||
|
||||
Notes:
|
||||
The constructor does not make a connection to memcached. The first
|
||||
@@ -257,20 +260,23 @@ class Client(object):
|
||||
self.timeout = timeout
|
||||
self.no_delay = no_delay
|
||||
self.ignore_exc = ignore_exc
|
||||
self.socket_module = socket_module
|
||||
self.sock = None
|
||||
self.buf = ''
|
||||
|
||||
def _connect(self):
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock = self.socket_module.socket(self.socket_module.AF_INET,
|
||||
self.socket_module.SOCK_STREAM)
|
||||
sock.settimeout(self.connect_timeout)
|
||||
sock.connect(self.server)
|
||||
sock.settimeout(self.timeout)
|
||||
if self.no_delay:
|
||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
sock.setsockopt(self.socket_module.IPPROTO_TCP,
|
||||
self.socket_module.TCP_NODELAY, 1)
|
||||
self.sock = sock
|
||||
|
||||
def close(self):
|
||||
"""Close the connetion to memcached, if it is open. The next call to a
|
||||
"""Close the connection to memcached, if it is open. The next call to a
|
||||
method that requires a connection will re-open it."""
|
||||
if self.sock is not None:
|
||||
try:
|
||||
|
||||
@@ -14,14 +14,16 @@
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import socket
|
||||
|
||||
from pymemcache.client import Client, MemcacheClientError, MemcacheUnknownCommandError
|
||||
from pymemcache.client import (Client, MemcacheClientError,
|
||||
MemcacheUnknownCommandError)
|
||||
from pymemcache.client import MemcacheIllegalInputError
|
||||
from nose import tools
|
||||
|
||||
|
||||
def get_set_test(host, port):
|
||||
client = Client((host, port))
|
||||
def get_set_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.get('key')
|
||||
@@ -42,8 +44,8 @@ def get_set_test(host, port):
|
||||
tools.assert_equal(result, {})
|
||||
|
||||
|
||||
def add_replace_test(host, port):
|
||||
client = Client((host, port))
|
||||
def add_replace_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.add('key', 'value', noreply=False)
|
||||
@@ -67,8 +69,8 @@ def add_replace_test(host, port):
|
||||
tools.assert_equal(result, 'value2')
|
||||
|
||||
|
||||
def append_prepend_test(host, port):
|
||||
client = Client((host, port))
|
||||
def append_prepend_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.append('key', 'value', noreply=False)
|
||||
@@ -94,8 +96,8 @@ def append_prepend_test(host, port):
|
||||
tools.assert_equal(result, 'beforevalueafter')
|
||||
|
||||
|
||||
def cas_test(host, port):
|
||||
client = Client((host, port))
|
||||
def cas_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.cas('key', 'value', '1', noreply=False)
|
||||
@@ -117,8 +119,8 @@ def cas_test(host, port):
|
||||
tools.assert_equal(result, False)
|
||||
|
||||
|
||||
def gets_test(host, port):
|
||||
client = Client((host, port))
|
||||
def gets_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.gets('key')
|
||||
@@ -130,8 +132,8 @@ def gets_test(host, port):
|
||||
tools.assert_equal(result[0], 'value')
|
||||
|
||||
|
||||
def delete_test(host, port):
|
||||
client = Client((host, port))
|
||||
def delete_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.delete('key', noreply=False)
|
||||
@@ -147,8 +149,8 @@ def delete_test(host, port):
|
||||
tools.assert_equal(result, None)
|
||||
|
||||
|
||||
def incr_decr_test(host, port):
|
||||
client = Client((host, port))
|
||||
def incr_decr_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
result = client.incr('key', 1, noreply=False)
|
||||
@@ -173,12 +175,12 @@ def incr_decr_test(host, port):
|
||||
tools.assert_equal(result, '0')
|
||||
|
||||
|
||||
def misc_test(host, port):
|
||||
client = Client((host, port))
|
||||
def misc_test(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
|
||||
def test_serialization_deserialization(host, port):
|
||||
def test_serialization_deserialization(host, port, socket_module):
|
||||
def _ser(key, value):
|
||||
return json.dumps(value), 1
|
||||
|
||||
@@ -187,7 +189,8 @@ def test_serialization_deserialization(host, port):
|
||||
return json.loads(value)
|
||||
return value
|
||||
|
||||
client = Client((host, port), serializer=_ser, deserializer=_des)
|
||||
client = Client((host, port), serializer=_ser, deserializer=_des,
|
||||
socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
value = {'a': 'b', 'c': ['d']}
|
||||
@@ -196,8 +199,8 @@ def test_serialization_deserialization(host, port):
|
||||
tools.assert_equal(result, value)
|
||||
|
||||
|
||||
def test_errors(host, port):
|
||||
client = Client((host, port))
|
||||
def test_errors(host, port, socket_module):
|
||||
client = Client((host, port), socket_module=socket_module)
|
||||
client.flush_all()
|
||||
|
||||
def _key_with_ws():
|
||||
@@ -238,26 +241,38 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print "Testing get and set..."
|
||||
get_set_test(args.server, args.port)
|
||||
print "Testing add and replace..."
|
||||
add_replace_test(args.server, args.port)
|
||||
print "Testing append and prepend..."
|
||||
append_prepend_test(args.server, args.port)
|
||||
print "Testing cas..."
|
||||
cas_test(args.server, args.port)
|
||||
print "Testing gets..."
|
||||
gets_test(args.server, args.port)
|
||||
print "Testing delete..."
|
||||
delete_test(args.server, args.port)
|
||||
print "Testing incr and decr..."
|
||||
incr_decr_test(args.server, args.port)
|
||||
print "Testing flush_all..."
|
||||
misc_test(args.server, args.port)
|
||||
print "Testing serialization and deserialization..."
|
||||
test_serialization_deserialization(args.server, args.port)
|
||||
print "Testing error cases..."
|
||||
test_errors(args.server, args.port)
|
||||
socket_modules = [socket]
|
||||
try:
|
||||
from gevent import socket as gevent_socket
|
||||
except ImportError:
|
||||
print "Skipping gevent (not installed)"
|
||||
else:
|
||||
socket_modules.append(gevent_socket)
|
||||
|
||||
for socket_module in socket_modules:
|
||||
print "Testing with socket module:", socket_module.__name__
|
||||
|
||||
print "Testing get and set..."
|
||||
get_set_test(args.server, args.port, socket_module)
|
||||
print "Testing add and replace..."
|
||||
add_replace_test(args.server, args.port, socket_module)
|
||||
print "Testing append and prepend..."
|
||||
append_prepend_test(args.server, args.port, socket_module)
|
||||
print "Testing cas..."
|
||||
cas_test(args.server, args.port, socket_module)
|
||||
print "Testing gets..."
|
||||
gets_test(args.server, args.port, socket_module)
|
||||
print "Testing delete..."
|
||||
delete_test(args.server, args.port, socket_module)
|
||||
print "Testing incr and decr..."
|
||||
incr_decr_test(args.server, args.port, socket_module)
|
||||
print "Testing flush_all..."
|
||||
misc_test(args.server, args.port, socket_module)
|
||||
print "Testing serialization and deserialization..."
|
||||
test_serialization_deserialization(args.server, args.port,
|
||||
socket_module)
|
||||
print "Testing error cases..."
|
||||
test_errors(args.server, args.port, socket_module)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import collections
|
||||
import json
|
||||
import socket
|
||||
|
||||
from nose import tools
|
||||
from pymemcache.client import Client, MemcacheUnknownCommandError
|
||||
@@ -26,6 +27,9 @@ class MockSocket(object):
|
||||
self.recv_bufs = collections.deque(recv_bufs)
|
||||
self.send_bufs = []
|
||||
self.closed = False
|
||||
self.timeouts = []
|
||||
self.connections = []
|
||||
self.socket_options = []
|
||||
|
||||
def sendall(self, value):
|
||||
self.send_bufs.append(value)
|
||||
@@ -39,6 +43,23 @@ class MockSocket(object):
|
||||
raise value
|
||||
return value
|
||||
|
||||
def settimeout(self, timeout):
|
||||
self.timeouts.append(timeout)
|
||||
|
||||
def connect(self, server):
|
||||
self.connections.append(server)
|
||||
|
||||
def setsockopt(self, level, option, value):
|
||||
self.socket_options.append((level, option, value))
|
||||
|
||||
|
||||
class MockSocketModule(object):
|
||||
def socket(self, family, type):
|
||||
return MockSocket([])
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(socket, name)
|
||||
|
||||
|
||||
def test_set_success():
|
||||
client = Client(None)
|
||||
@@ -131,7 +152,7 @@ def test_set_exception():
|
||||
def test_set_many_success():
|
||||
client = Client(None)
|
||||
client.sock = MockSocket(['STORED\r\n'])
|
||||
result = client.set_many({'key' : 'value'}, noreply=False)
|
||||
result = client.set_many({'key': 'value'}, noreply=False)
|
||||
tools.assert_equal(result, True)
|
||||
tools.assert_equal(client.sock.closed, False)
|
||||
tools.assert_equal(len(client.sock.send_bufs), 1)
|
||||
@@ -142,7 +163,7 @@ def test_set_many_exception():
|
||||
client.sock = MockSocket(['STORED\r\n', Exception('fail')])
|
||||
|
||||
def _set():
|
||||
client.set_many({'key' : 'value', 'other' : 'value'}, noreply=False)
|
||||
client.set_many({'key': 'value', 'other': 'value'}, noreply=False)
|
||||
|
||||
tools.assert_raises(Exception, _set)
|
||||
tools.assert_equal(client.sock, None)
|
||||
@@ -284,7 +305,6 @@ def test_cr_nl_boundaries():
|
||||
result = client.get_many(['key1', 'key2'])
|
||||
tools.assert_equals(result, {'key1': 'value1', 'key2': 'value2'})
|
||||
|
||||
|
||||
client.sock = MockSocket(['VALUE key1 0 6\r\n',
|
||||
'value1\r\n',
|
||||
'VALUE key2 0 6\r\n',
|
||||
@@ -489,6 +509,7 @@ def test_serialization():
|
||||
'set key 0 0 20 noreply\r\n{"a": "b", "c": "d"}\r\n'
|
||||
])
|
||||
|
||||
|
||||
def test_stats():
|
||||
client = Client(None)
|
||||
client.sock = MockSocket(['STAT fake_stats 1\r\n', 'END\r\n'])
|
||||
@@ -498,6 +519,7 @@ def test_stats():
|
||||
])
|
||||
tools.assert_equal(result, {'fake_stats': 1})
|
||||
|
||||
|
||||
def test_stats_with_args():
|
||||
client = Client(None)
|
||||
client.sock = MockSocket(['STAT fake_stats 1\r\n', 'END\r\n'])
|
||||
@@ -507,6 +529,7 @@ def test_stats_with_args():
|
||||
])
|
||||
tools.assert_equal(result, {'fake_stats': 1})
|
||||
|
||||
|
||||
def test_stats_conversions():
|
||||
client = Client(None)
|
||||
client.sock = MockSocket([
|
||||
@@ -540,3 +563,27 @@ def test_stats_conversions():
|
||||
'version': '1.4.14',
|
||||
}
|
||||
tools.assert_equal(result, expected)
|
||||
|
||||
|
||||
def test_socket_connect():
|
||||
server = ("example.com", 11211)
|
||||
|
||||
client = Client(server, socket_module=MockSocketModule())
|
||||
client._connect()
|
||||
tools.assert_equal(client.sock.connections, [server])
|
||||
|
||||
timeout = 2
|
||||
connect_timeout = 3
|
||||
client = Client(server, connect_timeout=connect_timeout, timeout=timeout,
|
||||
socket_module=MockSocketModule())
|
||||
client._connect()
|
||||
tools.assert_equal(client.sock.timeouts, [connect_timeout, timeout])
|
||||
|
||||
client = Client(server, socket_module=MockSocketModule())
|
||||
client._connect()
|
||||
tools.assert_equal(client.sock.socket_options, [])
|
||||
|
||||
client = Client(server, socket_module=MockSocketModule(), no_delay=True)
|
||||
client._connect()
|
||||
tools.assert_equal(client.sock.socket_options, [(socket.IPPROTO_TCP,
|
||||
socket.TCP_NODELAY, 1)])
|
||||
|
||||
Reference in New Issue
Block a user