From bff3d5f6e31fe595a77143ec4ac779c187bf72a8 Mon Sep 17 00:00:00 2001 From: Erik Olof Gunnar Andersson Date: Thu, 23 Dec 2021 23:14:46 -0800 Subject: [PATCH] Fix designate-manage pool update bugs This patch addresses a few problems with the manage command for pools and NS records. - Fixed an issue where having multiple NS records would break the pool command. - Fixed a scenario where manually (non-managed) NS records could break the pool command. - Fixed a potential edge case that could break the pool command. The biggest change is that we now only manage the NS record for the zone itself. This was always the case, but because we didn't check for this specifically, other NS records would conflict with the command. Change-Id: I4e6ea0b6b717d2a1b5cc420874d1bb8fb290e04b --- designate/central/service.py | 108 ++++++++++-------- .../tests/unit/test_central/test_basic.py | 40 ++----- 2 files changed, 66 insertions(+), 82 deletions(-) diff --git a/designate/central/service.py b/designate/central/service.py index 93b880832..861728d9c 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -561,51 +561,55 @@ class Service(service.RPCService): objects.Record(data=r, managed=True) for r in ns_records]) values = { 'name': zone['name'], - 'type': "NS", + 'type': 'NS', 'records': recordlist } ns, zone = self._create_recordset_in_storage( context, zone, objects.RecordSet(**values), - increment_serial=False) + increment_serial=False + ) return ns def _add_ns(self, context, zone, ns_record): # Get NS recordset # If the zone doesn't have an NS recordset yet, create one - recordsets = self.find_recordsets( - context, criterion={'zone_id': zone['id'], 'type': "NS"} - ) - - managed = [] - for rs in recordsets: - if rs.managed: - managed.append(rs) - - if len(managed) == 0: + try: + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) + except exceptions.RecordSetNotFound: self._create_ns(context, zone, [ns_record]) return - elif len(managed) != 1: - raise exceptions.RecordSetNotFound("No valid recordset found") - - ns_recordset = managed[0] # Add new record to recordset based on the new nameserver - ns_recordset.records.append( - objects.Record(data=ns_record, managed=True)) + recordset.records.append( + objects.Record(data=ns_record, managed=True) + ) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) def _delete_ns(self, context, zone, ns_record): - ns_recordset = self.find_recordset( - context, criterion={'zone_id': zone['id'], 'type': "NS"}) + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) - for record in copy.deepcopy(ns_recordset.records): + for record in list(recordset.records): if record.data == ns_record: - ns_recordset.records.remove(record) + recordset.records.remove(record) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) # Quota Enforcement Methods @@ -2538,46 +2542,49 @@ class Service(service.RPCService): @notification('dns.pool.update') @transaction def update_pool(self, context, pool): - policy.check('update_pool', context) # If there is a nameserver, then additional steps need to be done # Since these are treated as mutable objects, we're only going to # be comparing the nameserver.value which is the FQDN - if pool.obj_attr_is_set('ns_records'): - elevated_context = context.elevated(all_tenants=True) + elevated_context = context.elevated(all_tenants=True) - # TODO(kiall): ListObjects should be able to give you their - # original set of values. - original_pool_ns_records = self._get_pool_ns_records(context, - pool.id) - # Find the current NS hostnames - existing_ns = set([n.hostname for n in original_pool_ns_records]) - - # Find the desired NS hostnames - request_ns = set([n.hostname for n in pool.ns_records]) - - # Get the NS's to be created and deleted, ignoring the ones that - # are in both sets, as those haven't changed. - # TODO(kiall): Factor in priority - create_ns = request_ns.difference(existing_ns) - delete_ns = existing_ns.difference(request_ns) + # TODO(kiall): ListObjects should be able to give you their + # original set of values. + original_pool_ns_records = self._get_pool_ns_records( + context, pool.id + ) updated_pool = self.storage.update_pool(context, pool) + if not pool.obj_attr_is_set('ns_records'): + return updated_pool + + # Find the current NS hostnames + existing_ns = set([n.hostname for n in original_pool_ns_records]) + + # Find the desired NS hostnames + request_ns = set([n.hostname for n in pool.ns_records]) + + # Get the NS's to be created and deleted, ignoring the ones that + # are in both sets, as those haven't changed. + # TODO(kiall): Factor in priority + create_ns = request_ns.difference(existing_ns) + delete_ns = existing_ns.difference(request_ns) + # After the update, handle new ns_records - for ns in create_ns: + for ns_record in create_ns: # Create new NS recordsets for every zone zones = self.find_zones( context=elevated_context, criterion={'pool_id': pool.id, 'action': '!DELETE'}) - for z in zones: - self._add_ns(elevated_context, z, ns) + for zone in zones: + self._add_ns(elevated_context, zone, ns_record) # Then handle the ns_records to delete - for ns in delete_ns: + for ns_record in delete_ns: # Cannot delete the last nameserver, so verify that first. - if len(pool.ns_records) == 0: + if not pool.ns_records: raise exceptions.LastServerDeleteNotAllowed( "Not allowed to delete last of servers" ) @@ -2585,9 +2592,10 @@ class Service(service.RPCService): # Delete the NS record for every zone zones = self.find_zones( context=elevated_context, - criterion={'pool_id': pool.id}) - for z in zones: - self._delete_ns(elevated_context, z, ns) + criterion={'pool_id': pool.id} + ) + for zone in zones: + self._delete_ns(elevated_context, zone, ns_record) return updated_pool diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py index 492a66641..8d11b59df 100644 --- a/designate/tests/unit/test_central/test_basic.py +++ b/designate/tests/unit/test_central/test_basic.py @@ -789,13 +789,13 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns_creation(self): self.service._create_ns = mock.Mock() - self.service.find_recordsets = mock.Mock( - return_value=[] + self.service.find_recordset = mock.Mock( + side_effect=exceptions.RecordSetNotFound() ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, records = self.service._create_ns.call_args[0] @@ -804,16 +804,15 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns(self): self.service._update_recordset_in_storage = mock.Mock() - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True) - ] - self.service.find_recordsets = mock.Mock( - return_value=recordsets + self.service.find_recordset = mock.Mock( + return_value=RoObject( + records=objects.RecordList.from_list([]), managed=True + ) ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, rset = \ @@ -822,29 +821,6 @@ class CentralZoneTestCase(CentralBasic): self.assertTrue(rset.records[0].managed) self.assertEqual('bar', rset.records[0].data.name) - def test_add_ns_with_other_ns_rs(self): - self.service._update_recordset_in_storage = mock.Mock() - - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True), - RoObject(records=objects.RecordList.from_list([]), managed=False) - ] - - self.service.find_recordsets = mock.Mock( - return_value=recordsets - ) - - self.service._add_ns( - self.context, - RoObject(id=CentralZoneTestCase.zone__id), - RoObject(name='bar') - ) - ctx, zone, rset = \ - self.service._update_recordset_in_storage.call_args[0] - self.assertEqual(1, len(rset.records)) - self.assertTrue(rset.records[0].managed) - self.assertEqual('bar', rset.records[0].data.name) - def test_create_zone_no_servers(self): self.service._enforce_zone_quota = mock.Mock() self.service._is_valid_zone_name = mock.Mock()