From 1553c77dca4d6682b311155d7724bab83590d4f9 Mon Sep 17 00:00:00 2001 From: Adam Holmberg Date: Mon, 17 Aug 2015 11:55:49 -0500 Subject: [PATCH] Use connect_timeout for socket connect in addition to negotiation PYTHON-381 --- cassandra/connection.py | 12 ++++++++---- cassandra/io/twistedreactor.py | 3 ++- tests/integration/standard/test_connection.py | 12 ++++++++++-- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index a09d5a87..e3e0c7dd 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -242,7 +242,7 @@ class Connection(object): def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression=True, cql_version=None, protocol_version=MAX_SUPPORTED_VERSION, is_control_connection=False, - user_type_map=None): + user_type_map=None, connect_timeout=None): self.host = host self.port = port self.authenticator = authenticator @@ -253,6 +253,7 @@ class Connection(object): self.protocol_version = protocol_version self.is_control_connection = is_control_connection self.user_type_map = user_type_map + self.connect_timeout = connect_timeout self._push_watchers = defaultdict(set) self._requests = {} self._iobuf = io.BytesIO() @@ -298,8 +299,11 @@ class Connection(object): succeeded in connecting and are ready for service (or raises an exception otherwise). """ + start = time.time() + kwargs['connect_timeout'] = timeout conn = cls(host, *args, **kwargs) - conn.connected_event.wait(timeout) + elapsed = time.time() - start + conn.connected_event.wait(timeout - elapsed) if conn.last_error: if conn.is_unsupported_proto_version: raise ProtocolVersionUnsupported(host, conn.protocol_version) @@ -320,7 +324,7 @@ class Connection(object): if not self._ssl_impl: raise Exception("This version of Python was not compiled with SSL support") self._socket = self._ssl_impl.wrap_socket(self._socket, **self.ssl_options) - self._socket.settimeout(1.0) + self._socket.settimeout(self.connect_timeout) self._socket.connect(sockaddr) sockerr = None break @@ -331,7 +335,7 @@ class Connection(object): sockerr = err if sockerr: - raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror)) + raise socket.error(sockerr.errno, "Tried connecting to %s. Last error: %s" % ([a[4] for a in addresses], sockerr.strerror or sockerr)) if self.sockopts: for args in self.sockopts: diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index 967a968f..cc74bc6c 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -200,7 +200,8 @@ class TwistedConnection(Connection): """ self.connector = reactor.connectTCP( host=self.host, port=self.port, - factory=TwistedConnectionClientFactory(self)) + factory=TwistedConnectionClientFactory(self), + timeout=self.connect_timeout) def client_connection_made(self): """ diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 3261adc3..eabf2a4a 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -19,7 +19,9 @@ except ImportError: from functools import partial from six.moves import range +import sys from threading import Thread, Event +import time from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cluster import NoHostAvailable @@ -46,7 +48,7 @@ class ConnectionTests(object): def setUp(self): self.klass.initialize_reactor() - def get_connection(self): + def get_connection(self, timeout=5): """ Helper method to solve automated testing issues within Jenkins. Officially patched under the 2.0 branch through @@ -58,7 +60,7 @@ class ConnectionTests(object): e = None for i in range(5): try: - conn = self.klass.factory(host='127.0.0.1', timeout=5, protocol_version=PROTOCOL_VERSION) + conn = self.klass.factory(host='127.0.0.1', timeout=timeout, protocol_version=PROTOCOL_VERSION) break except (OperationTimedOut, NoHostAvailable) as e: continue @@ -224,6 +226,12 @@ class ConnectionTests(object): for t in threads: t.join() + def test_connect_timeout(self): + start = time.time() + self.assertRaises(Exception, self.get_connection, timeout=sys.float_info.min) + end = time.time() + self.assertAlmostEqual(start, end, 1) + class AsyncoreConnectionTests(ConnectionTests, unittest.TestCase):