diff --git a/cassandra/policies.py b/cassandra/policies.py index fcffc41b..6ffe34d1 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -186,25 +186,31 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): self.used_hosts_per_remote_dc = used_hosts_per_remote_dc self._dc_live_hosts = {} + def _dc(self, host): + return host.datacenter or self.local_dc + def populate(self, cluster, hosts): - for dc, dc_hosts in groupby(hosts, lambda h: h.datacenter): + for dc, dc_hosts in groupby(hosts, lambda h: self._dc(h)): self._dc_live_hosts[dc] = set(dc_hosts) # position is currently only used for local hosts local_live = self._dc_live_hosts.get(self.local_dc) - if len(local_live) == 1: + if not local_live: + self._position = 0 + elif len(local_live) == 1: self._position = 0 else: self._position = randint(0, len(local_live) - 1) def distance(self, host): - if host.datacenter == self.local_dc: + dc = self._dc(host) + if dc == self.local_dc: return HostDistance.LOCAL if not self.used_hosts_per_remote_dc: return HostDistance.IGNORED else: - dc_hosts = self._dc_live_hosts.get(host.datacenter) + dc_hosts = self._dc_live_hosts.get(dc) if not dc_hosts: return HostDistance.IGNORED @@ -219,8 +225,8 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): pos = self._position self._position += 1 - local_live = list(self._dc_live_hosts.get(self.local_dc)) - pos %= len(local_live) + local_live = list(self._dc_live_hosts.get(self.local_dc, ())) + pos = (pos % len(local_live)) if local_live else 0 for host in islice(cycle(local_live), pos, pos + len(local_live)): yield host @@ -232,16 +238,16 @@ class DCAwareRoundRobinPolicy(LoadBalancingPolicy): yield host def on_up(self, host): - self._dc_live_hosts.setdefault(host.datacenter, set()).add(host) + self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) def on_down(self, host): - self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host) + self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) def on_add(self, host): - self._dc_live_hosts.setdefault(host.datacenter, set()).add(host) + self._dc_live_hosts.setdefault(self._dc(host), set()).add(host) def on_remove(self, host): - self._dc_live_hosts.setdefault(host.datacenter, set()).discard(host) + self._dc_live_hosts.setdefault(self._dc(host), set()).discard(host) class TokenAwarePolicy(LoadBalancingPolicy):