diff --git a/cassandra/cluster.py b/cassandra/cluster.py index db08c75c..985e3a9b 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -1177,6 +1177,10 @@ class Session(object): remaining_callbacks = set(self._pools.values()) errors = {} + if not remaining_callbacks: + callback(errors) + return + def pool_finished_setting_keyspace(pool, host_errors): remaining_callbacks.remove(pool) if host_errors: diff --git a/cassandra/connection.py b/cassandra/connection.py index 047ecf97..a736a4da 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -322,6 +322,7 @@ class Connection(object): occurred, otherwise :const:`None`. """ if not keyspace or keyspace == self.keyspace: + callback(self, None) return query = QueryMessage(query='USE "%s"' % (keyspace,), diff --git a/cassandra/pool.py b/cassandra/pool.py index d3a4aa91..ffae74af 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -532,6 +532,10 @@ class HostConnectionPool(object): remaining_callbacks = set(self._connections) errors = [] + if not remaining_callbacks: + callback(self, errors) + return + def connection_finished_setting_keyspace(conn, error): remaining_callbacks.remove(conn) if error: diff --git a/tests/integration/test_cluster.py b/tests/integration/test_cluster.py index 085a17bd..e7676b3f 100644 --- a/tests/integration/test_cluster.py +++ b/tests/integration/test_cluster.py @@ -69,6 +69,12 @@ class ClusterTests(unittest.TestCase): result2 = session2.execute("SELECT * FROM test") self.assertEquals(result, result2) + def test_set_keyspace_twice(self): + cluster = Cluster() + session = cluster.connect() + session.execute("USE system") + session.execute("USE system") + def test_default_connections(self): """ Ensure errors are not thrown when using non-default policies