diff --git a/cassandra/__init__.py b/cassandra/__init__.py
index 3052969d..957d24e9 100644
--- a/cassandra/__init__.py
+++ b/cassandra/__init__.py
@@ -165,7 +165,7 @@ class AlreadyExists(Exception):
     table = None
     """
     The name of the table that already exists, or, if an attempt was
-    make to create a keyspace, ``None``.
+    made to create a keyspace, :const:`None`.
     """
 
     def __init__(self, keyspace=None, table=None):
diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index b9b77cd7..460542fa 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -90,7 +90,7 @@ class Cluster(object):
     compression = True
     """
     Whether or not compression should be enabled when possible. Defaults to
-    ``True`` and attempts to use snappy compression.
+    :const:`True` and attempts to use snappy compression.
     """
 
     auth_provider = None
@@ -99,11 +99,10 @@ class Cluster(object):
     and returns a dict of credentials for that node.
     """
 
-    load_balancing_policy_factory = RoundRobinPolicy
+    load_balancing_policy = RoundRobinPolicy()
     """
-    A factory function which creates instances of subclasses of
-    :class:`policies.LoadBalancingPolicy`. Defaults to
-    :class:`policies.RoundRobinPolicy`.
+    An instance of :class:`.policies.LoadBalancingPolicy` or
+    one of its subclasses. Defaults to :class:`~.RoundRobinPolicy`.
     """
 
     reconnection_policy = ExponentialReconnectionPolicy(1.0, 600.0)
@@ -168,6 +167,7 @@ class Cluster(object):
     scheduler = None
     executor = None
     _is_shutdown = False
+    _is_setup = False
     _prepared_statements = None
 
     def __init__(self,
@@ -175,7 +175,7 @@ class Cluster(object):
                  port=9042,
                  compression=True,
                  auth_provider=None,
-                 load_balancing_policy_factory=None,
+                 load_balancing_policy=None,
                  reconnection_policy=None,
                  retry_policy_factory=None,
                  conviction_policy_factory=None,
@@ -198,10 +198,8 @@ class Cluster(object):
                 raise ValueError("auth_provider must be callable")
             self.auth_provider = auth_provider
 
-        if load_balancing_policy_factory is not None:
-            if not callable(load_balancing_policy_factory):
-                raise ValueError("load_balancing_policy_factory must be callable")
-            self.load_balancing_policy_factory = load_balancing_policy_factory
+        if load_balancing_policy is not None:
+            self.load_balancing_policy = load_balancing_policy
 
         if reconnection_policy is not None:
             self.reconnection_policy = reconnection_policy
@@ -319,6 +317,11 @@ class Cluster(object):
         if self._is_shutdown:
             raise Exception("Cluster is already shut down")
 
+        if not self._is_setup:
+            self.load_balancing_policy.populate(
+                weakref.proxy(self), self.metadata.getAllHosts())
+            self._is_setup = True
+
         if self.control_connection:
             try:
                 self.control_connection.connect()
@@ -550,8 +553,7 @@ class Session(object):
         self._lock = RLock()
         self._pools = {}
 
-        self._load_balancer = cluster.load_balancing_policy_factory()
-        self._load_balancer.populate(weakref.proxy(cluster), hosts)
+        self._load_balancer = cluster.load_balancing_policy
 
         for host in hosts:
             self.add_host(host)
@@ -832,7 +834,7 @@ class ControlConnection(object):
         # use a weak reference to allow the Cluster instance to be GC'ed (and
         # shutdown) since implementing __del__ disables the cycle detector
         self._cluster = weakref.proxy(cluster)
-        self._balancing_policy = cluster.load_balancing_policy_factory()
+        self._balancing_policy = cluster.load_balancing_policy
         self._balancing_policy.populate(cluster, [])
         self._reconnection_policy = cluster.reconnection_policy
         self._connection = None
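Illustrative usage (not part of this patch): with the cluster.py change above, Cluster now takes a ready-made policy instance rather than a factory. A minimal sketch, assuming a local contact point and a datacenter named "DC1"; both names are placeholders.

    # Sketch only: contact point and datacenter name are placeholders.
    from cassandra.cluster import Cluster
    from cassandra.policies import DCAwareRoundRobinPolicy

    # Before: Cluster(load_balancing_policy_factory=RoundRobinPolicy)
    # After:  the caller builds the policy and passes the instance directly.
    cluster = Cluster(
        contact_points=['127.0.0.1'],
        load_balancing_policy=DCAwareRoundRobinPolicy(local_dc='DC1'))
    session = cluster.connect()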
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index f279355c..bf9ecf33 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -424,7 +424,7 @@ class TableMetadata(object):
     def as_cql_query(self, formatted=False):
         """
         Returns a CQL query that can be used to recreate this table (index
-        creations are not included). If `formatted` is set to ``True``,
+        creations are not included). If `formatted` is set to :const:`True`,
         extra whitespace will be added to make the query human readable.
         """
         ret = "CREATE TABLE %s.%s (%s" % (self.keyspace.name, self.name, "\n" if formatted else "")
@@ -553,7 +553,7 @@ class ColumnMetadata(object):
     index = None
     """
     If an index exists on this column, this is an instance of
-    :class:`.IndexMetadata`, otherwise ``None``.
+    :class:`.IndexMetadata`, otherwise :const:`None`.
     """
 
     def __init__(self, table_metadata, column_name, data_type, index_metadata=None):
diff --git a/cassandra/policies.py b/cassandra/policies.py
index 2cd78a98..4593843f 100644
--- a/cassandra/policies.py
+++ b/cassandra/policies.py
@@ -75,7 +75,7 @@ class LoadBalancingPolicy(object):
         order. A generator may work well for custom implementations of
         this method.
 
-        Note that the `query` argument may be ``None`` when preparing
+        Note that the `query` argument may be :const:`None` when preparing
         statements.
         """
         raise NotImplementedError()
@@ -175,8 +175,8 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
         `used_hosts_per_remote_dc` controls how many nodes in
         each remote datacenter will have connections opened
         against them. In other words, `used_hosts_per_remote_dc` hosts
-        will be considered :data:`.HostDistance.REMOTE` and the
-        rest will be considered :data:`.HostDistance.IGNORED`.
+        will be considered :attr:`~.HostDistance.REMOTE` and the
+        rest will be considered :attr:`~.HostDistance.IGNORED`.
         By default, all remote hosts are ignored.
         """
         self.local_dc = local_dc
@@ -241,6 +241,70 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy):
         self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host)
 
 
+class TokenAwarePolicy(LoadBalancingPolicy):
+    """
+    A :class:`.LoadBalancingPolicy` wrapper that adds token awareness to
+    a child policy.
+
+    This alters the child policy's behavior so that it first attempts to
+    send queries to :attr:`~.HostDistance.LOCAL` replicas (as determined
+    by the child policy) based on the :class:`.Query`'s
+    :attr:`~.Query.routing_key`. Once those hosts are exhausted, the
+    remaining hosts in the child policy's query plan will be used.
+
+    If no :attr:`~.Query.routing_key` is set on the query, the child
+    policy's query plan will be used as is.
+ """ + + _child_policy = None + _cluster_metadata = None + + def __init__(self, child_policy): + self.child_policy = child_policy + + def populate(self, cluster, hosts): + self._cluster_metadata = cluster.metadata + self.child_policy.populate(cluster, hosts) + + def distance(self, *args, **kwargs): + return self.child_policy.distance(*args, **kwargs) + + def make_query_plan(self, query=None): + child = self.child_policy + if query is None: + for host in child.make_query_plan(query): + yield host + else: + routing_key = query.routing_key + if routing_key is None: + for host in child.make_query_plan(query): + yield host + else: + replicas = self.metadata.get_replicas(routing_key) + for replica in replicas: + if replica.monitor.is_up and \ + child.distance(replica) == HostDistance.LOCAL: + yield replica + + for host in child.make_query_plan(query): + # skip if we've already listed this host + if host not in replicas or \ + child.distance(replica) == HostDistance.REMOTE: + yield host + + def on_up(self, *args, **kwargs): + return self.child_policy.on_up(*args, **kwargs) + + def on_down(self, *args, **kwargs): + return self.child_policy.on_down(*args, **kwargs) + + def on_add(self, *args, **kwargs): + return self.child_policy.on_add(*args, **kwargs) + + def on_remove(self, *args, **kwargs): + return self.child_policy.on_remove(*args, **kwargs) + + class ConvictionPolicy(object): """ A policy which decides when hosts should be considered down @@ -257,8 +321,8 @@ class ConvictionPolicy(object): def add_failure(self, connection_exc): """ - Implementations should return ``True`` if the host should be - convicted, ``False`` otherwise. + Implementations should return :const:`True` if the host should be + convicted, :const:`False` otherwise. """ raise NotImplementedError() @@ -314,7 +378,7 @@ class ConstantReconnectionPolicy(ReconnectionPolicy): each attempt. `max_attempts` should be a total number of attempts to be made before - giving up, or ``None`` to continue reconnection attempts forever. + giving up, or :const:`None` to continue reconnection attempts forever. The default is 64. """ if delay < 0: @@ -431,7 +495,7 @@ class RetryPolicy(object): perspective (i.e. a replica did not respond to the coordinator in time). It should return a tuple with two items: one of the class enums (such as :attr:`.RETRY`) and a :class:`.ConsistencyLevel` to retry the - operation at or ``None`` to keep the same consistency level. + operation at or :const:`None` to keep the same consistency level. `query` is the :class:`.Query` that timed out. @@ -541,7 +605,7 @@ class DowngradingConsistencyRetryPolicy(RetryPolicy): level than the one initially requested. By doing so, it may break consistency guarantees. In other words, if you use this retry policy, there is cases (documented below) where a read at :attr:`~.QUORUM` - *may not* see a preceding write at :attr`~.QUORUM`. Do not use this + *may not* see a preceding write at :attr:`~.QUORUM`. Do not use this policy unless you have understood the cases where this can happen and are ok with that. It is also recommended to subclass this class so that queries that required a consistency level downgrade can be diff --git a/cassandra/pool.py b/cassandra/pool.py index 6a53dcad..dc8cbe86 100644 --- a/cassandra/pool.py +++ b/cassandra/pool.py @@ -158,9 +158,10 @@ class _ReconnectionHandler(object): number of seconds (as a float) that the handler will wait before attempting to connect again. 
diff --git a/cassandra/pool.py b/cassandra/pool.py
index 6a53dcad..dc8cbe86 100644
--- a/cassandra/pool.py
+++ b/cassandra/pool.py
@@ -158,9 +158,10 @@ class _ReconnectionHandler(object):
         number of seconds (as a float) that the handler will wait before
         attempting to connect again.
 
-        Subclasses should return ``False`` if no more attempts to connection
-        should be made, ``True`` otherwise. The default behavior is to
-        always retry unless the error is an :exc:`.AuthenticationFailed`.
+        Subclasses should return :const:`False` if no more attempts to
+        connect should be made, :const:`True` otherwise. The default
+        behavior is to always retry unless the error is an
+        :exc:`.AuthenticationFailed` instance.
         """
         if isinstance(exc, AuthenticationFailed):
             return False
diff --git a/cassandra/query.py b/cassandra/query.py
index 845c8a18..bd912f5b 100644
--- a/cassandra/query.py
+++ b/cassandra/query.py
@@ -26,7 +26,7 @@ class Query(object):
 
     tracing_enabled = False
     """
-    A boolean flag that may be set to ``True`` to enable tracing on this
+    A boolean flag that may be set to :const:`True` to enable tracing on this
    query only.
 
     **Note**: query tracing is not yet supported by this driver
@@ -45,18 +45,30 @@ class Query(object):
         self.consistency_level = consistency_level
         self._routing_key = routing_key
 
-    @property
-    def routing_key(self):
+    def _get_routing_key(self):
         return self._routing_key
 
-    @routing_key.setter
-    def set_routing_key(self, *key_components):
+    def _set_routing_key(self, key_components):
         if len(key_components) == 1:
             self._routing_key = key_components[0]
         else:
             self._routing_key = "".join(struct.pack("HsB", len(component), component, 0) for component in key_components)
 
+    def _del_routing_key(self):
+        self._routing_key = None
+
+    routing_key = property(
+        _get_routing_key,
+        _set_routing_key,
+        _del_routing_key,
+        """
+        The :attr:`~.TableMetadata.partition_key` portion of the primary key,
+        which can be used to determine which nodes are replicas for the query.
+
+        When setting this attribute, a list or tuple *must* be used.
+        """)
+
 
 class SimpleStatement(Query):
     """
     A simple, un-prepared query. All attributes of :class:`Query` apply
diff --git a/tests/integration/test_cluster.py b/tests/integration/test_cluster.py
index e3477c83..991bc849 100644
--- a/tests/integration/test_cluster.py
+++ b/tests/integration/test_cluster.py
@@ -45,8 +45,8 @@ class ClusterTests(unittest.TestCase):
         def foo(*args, **kwargs):
             return Mock()
 
-        for kw in ('auth_provider', 'load_balancing_policy_factory',
-                   'retry_policy_factory', 'conviction_policy_factory'):
+        for kw in ('auth_provider', 'retry_policy_factory',
+                   'conviction_policy_factory'):
             kwargs = {kw: 123}
             self.assertRaises(ValueError, Cluster, **kwargs)
 
@@ -55,7 +55,8 @@ class ClusterTests(unittest.TestCase):
             self.assertEquals(getattr(c, kw), foo)
 
         for kw in ('contact_points', 'port', 'compression', 'metrics_enabled',
-                   'reconnection_policy', 'sockopts', 'max_schema_agreement_wait'):
+                   'load_balancing_policy', 'reconnection_policy', 'sockopts',
+                   'max_schema_agreement_wait'):
             kwargs = {kw: (1, 2, 3)}
             c = Cluster(**kwargs)
             self.assertEquals(getattr(c, kw), (1, 2, 3))
diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py
index 034d69a4..9314c8b3 100644
--- a/tests/unit/test_control_connection.py
+++ b/tests/unit/test_control_connection.py
@@ -39,7 +39,7 @@ class MockMetadata(object):
 
 class MockCluster(object):
     max_schema_agreement_wait = Cluster.max_schema_agreement_wait
-    load_balancing_policy_factory = RoundRobinPolicy
+    load_balancing_policy = RoundRobinPolicy()
     reconnection_policy = ConstantReconnectionPolicy(2)
 
     def __init__(self):
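Illustrative usage of the reworked Query.routing_key property from cassandra/query.py above (not part of this patch): the setter expects a list or tuple of key components; a single component is stored as-is, while multiple components are packed together. A minimal sketch, assuming SimpleStatement accepts the query string as its first argument and a hypothetical users table with a text user_id partition key.

    # Sketch only: table name and key value are hypothetical.
    from cassandra.query import SimpleStatement

    query = SimpleStatement("SELECT * FROM users WHERE user_id = 'jbellis'")
    query.routing_key = ['jbellis']   # single-component partition key

    # A TokenAwarePolicy can now prefer replicas for this key; deleting
    # the attribute clears the routing key again.
    del query.routing_key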