diff --git a/designate/central/service.py b/designate/central/service.py index 93b880832..861728d9c 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -561,51 +561,55 @@ class Service(service.RPCService): objects.Record(data=r, managed=True) for r in ns_records]) values = { 'name': zone['name'], - 'type': "NS", + 'type': 'NS', 'records': recordlist } ns, zone = self._create_recordset_in_storage( context, zone, objects.RecordSet(**values), - increment_serial=False) + increment_serial=False + ) return ns def _add_ns(self, context, zone, ns_record): # Get NS recordset # If the zone doesn't have an NS recordset yet, create one - recordsets = self.find_recordsets( - context, criterion={'zone_id': zone['id'], 'type': "NS"} - ) - - managed = [] - for rs in recordsets: - if rs.managed: - managed.append(rs) - - if len(managed) == 0: + try: + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) + except exceptions.RecordSetNotFound: self._create_ns(context, zone, [ns_record]) return - elif len(managed) != 1: - raise exceptions.RecordSetNotFound("No valid recordset found") - - ns_recordset = managed[0] # Add new record to recordset based on the new nameserver - ns_recordset.records.append( - objects.Record(data=ns_record, managed=True)) + recordset.records.append( + objects.Record(data=ns_record, managed=True) + ) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) def _delete_ns(self, context, zone, ns_record): - ns_recordset = self.find_recordset( - context, criterion={'zone_id': zone['id'], 'type': "NS"}) + recordset = self.find_recordset( + context, + criterion={ + 'zone_id': zone['id'], + 'name': zone['name'], + 'type': 'NS' + } + ) - for record in copy.deepcopy(ns_recordset.records): + for record in list(recordset.records): if record.data == ns_record: - ns_recordset.records.remove(record) + recordset.records.remove(record) - self._update_recordset_in_storage(context, zone, ns_recordset, + self._update_recordset_in_storage(context, zone, recordset, set_delayed_notify=True) # Quota Enforcement Methods @@ -2538,46 +2542,49 @@ class Service(service.RPCService): @notification('dns.pool.update') @transaction def update_pool(self, context, pool): - policy.check('update_pool', context) # If there is a nameserver, then additional steps need to be done # Since these are treated as mutable objects, we're only going to # be comparing the nameserver.value which is the FQDN - if pool.obj_attr_is_set('ns_records'): - elevated_context = context.elevated(all_tenants=True) + elevated_context = context.elevated(all_tenants=True) - # TODO(kiall): ListObjects should be able to give you their - # original set of values. - original_pool_ns_records = self._get_pool_ns_records(context, - pool.id) - # Find the current NS hostnames - existing_ns = set([n.hostname for n in original_pool_ns_records]) - - # Find the desired NS hostnames - request_ns = set([n.hostname for n in pool.ns_records]) - - # Get the NS's to be created and deleted, ignoring the ones that - # are in both sets, as those haven't changed. - # TODO(kiall): Factor in priority - create_ns = request_ns.difference(existing_ns) - delete_ns = existing_ns.difference(request_ns) + # TODO(kiall): ListObjects should be able to give you their + # original set of values. + original_pool_ns_records = self._get_pool_ns_records( + context, pool.id + ) updated_pool = self.storage.update_pool(context, pool) + if not pool.obj_attr_is_set('ns_records'): + return updated_pool + + # Find the current NS hostnames + existing_ns = set([n.hostname for n in original_pool_ns_records]) + + # Find the desired NS hostnames + request_ns = set([n.hostname for n in pool.ns_records]) + + # Get the NS's to be created and deleted, ignoring the ones that + # are in both sets, as those haven't changed. + # TODO(kiall): Factor in priority + create_ns = request_ns.difference(existing_ns) + delete_ns = existing_ns.difference(request_ns) + # After the update, handle new ns_records - for ns in create_ns: + for ns_record in create_ns: # Create new NS recordsets for every zone zones = self.find_zones( context=elevated_context, criterion={'pool_id': pool.id, 'action': '!DELETE'}) - for z in zones: - self._add_ns(elevated_context, z, ns) + for zone in zones: + self._add_ns(elevated_context, zone, ns_record) # Then handle the ns_records to delete - for ns in delete_ns: + for ns_record in delete_ns: # Cannot delete the last nameserver, so verify that first. - if len(pool.ns_records) == 0: + if not pool.ns_records: raise exceptions.LastServerDeleteNotAllowed( "Not allowed to delete last of servers" ) @@ -2585,9 +2592,10 @@ class Service(service.RPCService): # Delete the NS record for every zone zones = self.find_zones( context=elevated_context, - criterion={'pool_id': pool.id}) - for z in zones: - self._delete_ns(elevated_context, z, ns) + criterion={'pool_id': pool.id} + ) + for zone in zones: + self._delete_ns(elevated_context, zone, ns_record) return updated_pool diff --git a/designate/tests/unit/test_central/test_basic.py b/designate/tests/unit/test_central/test_basic.py index 492a66641..8d11b59df 100644 --- a/designate/tests/unit/test_central/test_basic.py +++ b/designate/tests/unit/test_central/test_basic.py @@ -789,13 +789,13 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns_creation(self): self.service._create_ns = mock.Mock() - self.service.find_recordsets = mock.Mock( - return_value=[] + self.service.find_recordset = mock.Mock( + side_effect=exceptions.RecordSetNotFound() ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, records = self.service._create_ns.call_args[0] @@ -804,16 +804,15 @@ class CentralZoneTestCase(CentralBasic): def test_add_ns(self): self.service._update_recordset_in_storage = mock.Mock() - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True) - ] - self.service.find_recordsets = mock.Mock( - return_value=recordsets + self.service.find_recordset = mock.Mock( + return_value=RoObject( + records=objects.RecordList.from_list([]), managed=True + ) ) self.service._add_ns( self.context, - RoObject(id=CentralZoneTestCase.zone__id), + RoObject(name='foo', id=CentralZoneTestCase.zone__id), RoObject(name='bar') ) ctx, zone, rset = \ @@ -822,29 +821,6 @@ class CentralZoneTestCase(CentralBasic): self.assertTrue(rset.records[0].managed) self.assertEqual('bar', rset.records[0].data.name) - def test_add_ns_with_other_ns_rs(self): - self.service._update_recordset_in_storage = mock.Mock() - - recordsets = [ - RoObject(records=objects.RecordList.from_list([]), managed=True), - RoObject(records=objects.RecordList.from_list([]), managed=False) - ] - - self.service.find_recordsets = mock.Mock( - return_value=recordsets - ) - - self.service._add_ns( - self.context, - RoObject(id=CentralZoneTestCase.zone__id), - RoObject(name='bar') - ) - ctx, zone, rset = \ - self.service._update_recordset_in_storage.call_args[0] - self.assertEqual(1, len(rset.records)) - self.assertTrue(rset.records[0].managed) - self.assertEqual('bar', rset.records[0].data.name) - def test_create_zone_no_servers(self): self.service._enforce_zone_quota = mock.Mock() self.service._is_valid_zone_name = mock.Mock()