diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 51d97685..cf5a26b1 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -473,30 +473,29 @@ class NetworkTopologyStrategy(ReplicationStrategy): def make_token_replica_map(self, token_to_host_owner, ring): # note: this does not account for hosts having different racks replica_map = defaultdict(list) - ring_len = len(ring) - ring_len_range = range(ring_len) dc_rf_map = dict((dc, int(rf)) for dc, rf in self.dc_replication_factors.items() if rf > 0) - dcs = dict((h, h.datacenter) for h in set(token_to_host_owner.values())) # build a map of DCs to lists of indexes into `ring` for tokens that # belong to that DC dc_to_token_offset = defaultdict(list) + dc_racks = defaultdict(set) for i, token in enumerate(ring): host = token_to_host_owner[token] - dc_to_token_offset[dcs[host]].append(i) + dc_to_token_offset[host.datacenter].append(i) + if host.datacenter and host.rack: + dc_racks[host.datacenter].add(host.rack) # A map of DCs to an index into the dc_to_token_offset value for that dc. # This is how we keep track of advancing around the ring for each DC. dc_to_current_index = defaultdict(int) - for i in ring_len_range: - remaining = dc_rf_map.copy() + for i in range(len(ring)): replicas = replica_map[ring[i]] # go through each DC and find the replicas in that DC for dc in dc_to_token_offset.keys(): - if dc not in remaining: + if dc not in dc_rf_map: continue # advance our per-DC index until we're up to at least the @@ -508,20 +507,33 @@ class NetworkTopologyStrategy(ReplicationStrategy): index += 1 dc_to_current_index[dc] = index - # now add the next RF distinct token owners to the set of - # replicas for this DC + replicas_remaining = dc_rf_map[dc] + skipped_hosts = [] + racks_placed = set() + racks_this_dc = dc_racks[dc] for token_offset in islice(cycle(token_offsets), index, index + num_tokens): host = token_to_host_owner[ring[token_offset]] + if replicas_remaining == 0: + break + if host in replicas: continue + if host.rack in racks_placed and len(racks_placed) < len(racks_this_dc): + skipped_hosts.append(host) + continue + replicas.append(host) - dc_remaining = remaining[dc] - 1 - if dc_remaining == 0: - del remaining[dc] - break - else: - remaining[dc] = dc_remaining + replicas_remaining -= 1 + racks_placed.add(host.rack) + + if len(racks_placed) == len(racks_this_dc): + for host in skipped_hosts: + if replicas_remaining == 0: + break + replicas.append(host) + replicas_remaining -= 1 + del skipped_hosts[:] return replica_map diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index 51f4f613..fc927a02 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -111,6 +111,48 @@ class StrategiesTest(unittest.TestCase): self.assertItemsEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) + def test_nts_make_token_replica_map_multi_rack(self): + token_to_host_owner = {} + + # (A) not enough distinct racks, first skipped is used + dc1_1 = Host('dc1.1', SimpleConvictionPolicy) + dc1_2 = Host('dc1.2', SimpleConvictionPolicy) + dc1_3 = Host('dc1.3', SimpleConvictionPolicy) + dc1_4 = Host('dc1.4', SimpleConvictionPolicy) + dc1_1.set_location_info('dc1', 'rack1') + dc1_2.set_location_info('dc1', 'rack1') + dc1_3.set_location_info('dc1', 'rack2') + dc1_4.set_location_info('dc1', 'rack2') + token_to_host_owner[MD5Token(0)] = dc1_1 + token_to_host_owner[MD5Token(100)] = dc1_2 + token_to_host_owner[MD5Token(200)] = dc1_3 + token_to_host_owner[MD5Token(300)] = dc1_4 + + # (B) distinct racks, but not contiguous + dc2_1 = Host('dc2.1', SimpleConvictionPolicy) + dc2_2 = Host('dc2.2', SimpleConvictionPolicy) + dc2_3 = Host('dc2.3', SimpleConvictionPolicy) + dc2_1.set_location_info('dc2', 'rack1') + dc2_2.set_location_info('dc2', 'rack1') + dc2_3.set_location_info('dc2', 'rack2') + token_to_host_owner[MD5Token(1)] = dc2_1 + token_to_host_owner[MD5Token(101)] = dc2_2 + token_to_host_owner[MD5Token(201)] = dc2_3 + + ring = [MD5Token(0), + MD5Token(1), + MD5Token(100), + MD5Token(101), + MD5Token(200), + MD5Token(201), + MD5Token(300)] + + nts = NetworkTopologyStrategy({'dc1': 3, 'dc2': 2}) + replica_map = nts.make_token_replica_map(token_to_host_owner, ring) + + token_replicas = replica_map[MD5Token(0)] + self.assertItemsEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) + def test_nts_make_token_replica_map_empty_dc(self): host = Host('1', SimpleConvictionPolicy) host.set_location_info('dc1', 'rack1')