From e44f27eaacf6c1189a063ea7e3be875c28c8cbd9 Mon Sep 17 00:00:00 2001 From: Kiall Mac Innes Date: Mon, 30 Jun 2014 17:38:39 +0100 Subject: [PATCH] Update's should use objects We have updated the methods for these resources: * Server * TLD * TsigKey * Domain * RecordSet * Record * Blacklist * Quota (Storage layer only) Change-Id: I73b4c2ca8b08b85e1b7432e3f798164c703a4164 --- designate/api/v1/domains.py | 15 +- designate/api/v1/records.py | 50 ++-- designate/api/v1/servers.py | 15 +- designate/api/v1/tsigkeys.py | 16 +- designate/api/v2/controllers/blacklists.py | 17 +- designate/api/v2/controllers/records.py | 14 +- designate/api/v2/controllers/recordsets.py | 14 +- designate/api/v2/controllers/tlds.py | 13 +- designate/api/v2/controllers/zones.py | 15 +- designate/central/rpcapi.py | 42 +-- designate/central/service.py | 158 ++++++----- designate/quota/impl_storage.py | 18 +- designate/storage/base.py | 54 ++-- designate/storage/impl_sqlalchemy/__init__.py | 88 +++--- designate/tests/test_central/test_service.py | 266 +++++++++--------- designate/tests/test_storage/__init__.py | 191 ++++++------- 16 files changed, 497 insertions(+), 489 deletions(-) diff --git a/designate/api/v1/domains.py b/designate/api/v1/domains.py index 3bb987a1..72570953 100644 --- a/designate/api/v1/domains.py +++ b/designate/api/v1/domains.py @@ -76,12 +76,19 @@ def update_domain(domain_id): context = flask.request.environ.get('context') values = flask.request.json + # Fetch the existing resource domain = get_central_api().get_domain(context, domain_id) - domain = domain_schema.filter(domain) - domain.update(values) - domain_schema.validate(domain) - domain = get_central_api().update_domain(context, domain_id, values) + # Prepare a dict of fields for validation + domain_data = domain_schema.filter(domain) + domain_data.update(values) + + # Validate the new set of data + domain_schema.validate(domain_data) + + # Update and persist the resource + domain.update(values) + domain = get_central_api().update_domain(context, domain) return flask.jsonify(domain_schema.filter(domain)) diff --git a/designate/api/v1/records.py b/designate/api/v1/records.py index 5437d117..2c15ac43 100644 --- a/designate/api/v1/records.py +++ b/designate/api/v1/records.py @@ -170,49 +170,31 @@ def update_record(domain_id, record_id): # return an record not found instead of a domain not found get_central_api().get_domain(context, domain_id) - # Find the record + # Fetch the existing resource + # NOTE(kiall): We use "find_record" rather than "get_record" as we do not + # have the recordset_id. criterion = {'domain_id': domain_id, 'id': record_id} record = get_central_api().find_record(context, criterion) # Find the associated recordset recordset = get_central_api().get_recordset( - context, domain_id, record['recordset_id']) + context, domain_id, record.recordset_id) - # Ensure all the API V1 fields are in place - record = _format_record_v1(record, recordset) + # Prepare a dict of fields for validation + record_data = record_schema.filter(_format_record_v1(record, recordset)) + record_data.update(values) - # Filter out any extra fields from the fetched record - record = record_schema.filter(record) + # Validate the new set of data + record_schema.validate(record_data) - # Name and Type can't be updated on existing records - if 'name' in values and record['name'] != values['name']: - raise exceptions.InvalidOperation('The name field is immutable') + # Update and persist the resource + record.update(_extract_record_values(values)) + record = get_central_api().update_record(context, record) - if 'type' in values and record['type'] != values['type']: - raise exceptions.InvalidOperation('The type field is immutable') - - # TTL Updates should be applied to the RecordSet - update_recordset = False - - if 'ttl' in values and record['ttl'] != values['ttl']: - update_recordset = True - - # Apply the updated fields to the record - record.update(values) - - # Validate the record - record_schema.validate(record) - - # Update the record - record = get_central_api().update_record( - context, domain_id, recordset['id'], record_id, - _extract_record_values(values)) - - # Update the recordset (if necessary) - if update_recordset: - recordset = get_central_api().update_recordset( - context, domain_id, recordset['id'], - _extract_recordset_values(values)) + # Update the recordset resource (if necessary) + recordset.update(_extract_recordset_values(values)) + if len(recordset.obj_what_changed()) > 0: + recordset = get_central_api().update_recordset(context, recordset) # Format and return the response record = _format_record_v1(record, recordset) diff --git a/designate/api/v1/servers.py b/designate/api/v1/servers.py index 19b05328..f57e08bf 100644 --- a/designate/api/v1/servers.py +++ b/designate/api/v1/servers.py @@ -76,12 +76,19 @@ def update_server(server_id): context = flask.request.environ.get('context') values = flask.request.json + # Fetch the existing resource server = get_central_api().get_server(context, server_id) - server = server_schema.filter(server) - server.update(values) - server_schema.validate(server) - server = get_central_api().update_server(context, server_id, values=values) + # Prepare a dict of fields for validation + server_data = server_schema.filter(server) + server_data.update(values) + + # Validate the new set of data + server_schema.validate(server_data) + + # Update and persist the resource + server.update(values) + server = get_central_api().update_server(context, server) return flask.jsonify(server_schema.filter(server)) diff --git a/designate/api/v1/tsigkeys.py b/designate/api/v1/tsigkeys.py index 42900c90..8686e770 100644 --- a/designate/api/v1/tsigkeys.py +++ b/designate/api/v1/tsigkeys.py @@ -76,13 +76,19 @@ def update_tsigkey(tsigkey_id): context = flask.request.environ.get('context') values = flask.request.json + # Fetch the existing resource tsigkey = get_central_api().get_tsigkey(context, tsigkey_id) - tsigkey = tsigkey_schema.filter(tsigkey) - tsigkey.update(values) - tsigkey_schema.validate(tsigkey) - tsigkey = get_central_api().update_tsigkey(context, tsigkey_id, - values=values) + # Prepare a dict of fields for validation + tsigkey_data = tsigkey_schema.filter(tsigkey) + tsigkey_data.update(values) + + # Validate the new set of data + tsigkey_schema.validate(tsigkey_data) + + # Update and persist the resource + tsigkey.update(values) + tsigkey = get_central_api().update_tsigkey(context, tsigkey) return flask.jsonify(tsigkey_schema.filter(tsigkey)) diff --git a/designate/api/v2/controllers/blacklists.py b/designate/api/v2/controllers/blacklists.py index 37ede72f..e092c185 100644 --- a/designate/api/v2/controllers/blacklists.py +++ b/designate/api/v2/controllers/blacklists.py @@ -101,24 +101,23 @@ class BlacklistsController(rest.RestController): body = request.body_dict response = pecan.response - # Fetch the existing blacklisted zone + # Fetch the existing blacklist entry blacklist = self.central_api.get_blacklist(context, blacklist_id) # Convert to APIv2 Format - blacklist = self._view.show(context, request, blacklist) + blacklist_data = self._view.show(context, request, blacklist) if request.content_type == 'application/json-patch+json': raise NotImplemented('json-patch not implemented') else: - blacklist = utils.deep_dict_merge(blacklist, body) + blacklist_data = utils.deep_dict_merge(blacklist_data, body) - # Validate the request conforms to the schema - self._resource_schema.validate(blacklist) + # Validate the new set of data + self._resource_schema.validate(blacklist_data) - values = self._view.load(context, request, body) - - blacklist = self.central_api.update_blacklist(context, - blacklist_id, values) + # Update and persist the resource + blacklist.update(self._view.load(context, request, body)) + blacklist = self.central_api.update_blacklist(context, blacklist) response.status_int = 200 diff --git a/designate/api/v2/controllers/records.py b/designate/api/v2/controllers/records.py index 4ad030d7..2fd9de49 100644 --- a/designate/api/v2/controllers/records.py +++ b/designate/api/v2/controllers/records.py @@ -116,19 +116,19 @@ class RecordsController(rest.RestController): record_id) # Convert to APIv2 Format - record = self._view.show(context, request, record) + record_data = self._view.show(context, request, record) if request.content_type == 'application/json-patch+json': raise NotImplemented('json-patch not implemented') else: - record = utils.deep_dict_merge(record, body) + record_data = utils.deep_dict_merge(record_data, body) - # Validate the request conforms to the schema - self._resource_schema.validate(record) + # Validate the new set of data + self._resource_schema.validate(record_data) - values = self._view.load(context, request, body) - record = self.central_api.update_record( - context, zone_id, recordset_id, record_id, values) + # Update and persist the resource + record.update(self._view.load(context, request, body)) + record = self.central_api.update_record(context, record) if record['status'] == 'PENDING': response.status_int = 202 diff --git a/designate/api/v2/controllers/recordsets.py b/designate/api/v2/controllers/recordsets.py index 8608a19a..b72bd34b 100644 --- a/designate/api/v2/controllers/recordsets.py +++ b/designate/api/v2/controllers/recordsets.py @@ -113,19 +113,19 @@ class RecordSetsController(rest.RestController): recordset_id) # Convert to APIv2 Format - recordset = self._view.show(context, request, recordset) + recordset_data = self._view.show(context, request, recordset) if request.content_type == 'application/json-patch+json': raise NotImplemented('json-patch not implemented') else: - recordset = utils.deep_dict_merge(recordset, body) + recordset_data = utils.deep_dict_merge(recordset_data, body) - # Validate the request conforms to the schema - self._resource_schema.validate(recordset) + # Validate the new set of data + self._resource_schema.validate(recordset_data) - values = self._view.load(context, request, body) - recordset = self.central_api.update_recordset( - context, zone_id, recordset_id, values) + # Update and persist the resource + recordset.update(self._view.load(context, request, body)) + recordset = self.central_api.update_recordset(context, recordset) response.status_int = 200 diff --git a/designate/api/v2/controllers/tlds.py b/designate/api/v2/controllers/tlds.py index 2167b11e..c5702380 100644 --- a/designate/api/v2/controllers/tlds.py +++ b/designate/api/v2/controllers/tlds.py @@ -98,18 +98,19 @@ class TldsController(rest.RestController): tld = self.central_api.get_tld(context, tld_id) # Convert to APIv2 Format - tld = self._view.show(context, request, tld) + tld_data = self._view.show(context, request, tld) if request.content_type == 'application/json-patch+json': raise NotImplemented('json-patch not implemented') else: - tld = utils.deep_dict_merge(tld, body) + tld_data = utils.deep_dict_merge(tld_data, body) - # Validate the request conforms to the schema - self._resource_schema.validate(tld) + # Validate the new set of data + self._resource_schema.validate(tld_data) - values = self._view.load(context, request, body) - tld = self.central_api.update_tld(context, tld_id, values) + # Update and persist the resource + tld.update(self._view.load(context, request, body)) + tld = self.central_api.update_tld(context, tld) response.status_int = 200 diff --git a/designate/api/v2/controllers/zones.py b/designate/api/v2/controllers/zones.py index cf3e1ce1..1fe8199d 100644 --- a/designate/api/v2/controllers/zones.py +++ b/designate/api/v2/controllers/zones.py @@ -200,7 +200,7 @@ class ZonesController(rest.RestController): zone = self.central_api.get_domain(context, zone_id) # Convert to APIv2 Format - zone = self._view.show(context, request, zone) + zone_data = self._view.show(context, request, zone) if request.content_type == 'application/json-patch+json': # Possible pattern: @@ -217,15 +217,16 @@ class ZonesController(rest.RestController): # 3) ...? raise NotImplemented('json-patch not implemented') else: - zone = utils.deep_dict_merge(zone, body) + zone_data = utils.deep_dict_merge(zone_data, body) - # Validate the request conforms to the schema - self._resource_schema.validate(zone) + # Validate the new set of data + self._resource_schema.validate(zone_data) - values = self._view.load(context, request, body) - zone = self.central_api.update_domain(context, zone_id, values) + # Update and persist the resource + zone.update(self._view.load(context, request, body)) + zone = self.central_api.update_domain(context, zone) - if zone['status'] == 'PENDING': + if zone.status == 'PENDING': response.status_int = 202 else: response.status_int = 200 diff --git a/designate/central/rpcapi.py b/designate/central/rpcapi.py index 4f561ecb..8c8de0ac 100644 --- a/designate/central/rpcapi.py +++ b/designate/central/rpcapi.py @@ -99,11 +99,10 @@ class CentralAPI(object): return self.client.call(context, 'get_server', server_id=server_id) - def update_server(self, context, server_id, values): + def update_server(self, context, server): LOG.info(_LI("update_server: Calling central's update_server.")) - return self.client.call(context, 'update_server', server_id=server_id, - values=values) + return self.client.call(context, 'update_server', server=server) def delete_server(self, context, server_id): LOG.info(_LI("delete_server: Calling central's delete_server.")) @@ -126,11 +125,9 @@ class CentralAPI(object): LOG.info(_LI("get_tsigkey: Calling central's get_tsigkey.")) return self.client.call(context, 'get_tsigkey', tsigkey_id=tsigkey_id) - def update_tsigkey(self, context, tsigkey_id, values): + def update_tsigkey(self, context, tsigkey): LOG.info(_LI("update_tsigkey: Calling central's update_tsigkey.")) - return self.client.call(context, 'update_tsigkey', - tsigkey_id=tsigkey_id, - values=values) + return self.client.call(context, 'update_tsigkey', tsigkey=tsigkey) def delete_tsigkey(self, context, tsigkey_id): LOG.info(_LI("delete_tsigkey: Calling central's delete_tsigkey.")) @@ -176,11 +173,10 @@ class CentralAPI(object): LOG.info(_LI("find_domain: Calling central's find_domain.")) return self.client.call(context, 'find_domain', criterion=criterion) - def update_domain(self, context, domain_id, values, increment_serial=True): + def update_domain(self, context, domain, increment_serial=True): LOG.info(_LI("update_domain: Calling central's update_domain.")) - return self.client.call( - context, 'update_domain', domain_id=domain_id, - values=values, increment_serial=increment_serial) + return self.client.call(context, 'update_domain', domain=domain, + increment_serial=increment_serial) def delete_domain(self, context, domain_id): LOG.info(_LI("delete_domain: Calling central's delete_domain.")) @@ -210,10 +206,9 @@ class CentralAPI(object): LOG.info(_LI("get_tld: Calling central's get_tld.")) return self.client.call(context, 'get_tld', tld_id=tld_id) - def update_tld(self, context, tld_id, values): + def update_tld(self, context, tld): LOG.info(_LI("update_tld: Calling central's update_tld.")) - return self.client.call(context, 'update_tld', tld_id=tld_id, - values=values) + return self.client.call(context, 'update_tld', tld=tld) def delete_tld(self, context, tld_id): LOG.info(_LI("delete_tld: Calling central's delete_tld.")) @@ -242,13 +237,10 @@ class CentralAPI(object): LOG.info(_LI("find_recordset: Calling central's find_recordset.")) return self.client.call(context, 'find_recordset', criterion=criterion) - def update_recordset(self, context, domain_id, recordset_id, values, - increment_serial=True): + def update_recordset(self, context, recordset, increment_serial=True): LOG.info(_LI("update_recordset: Calling central's update_recordset.")) return self.client.call(context, 'update_recordset', - domain_id=domain_id, - recordset_id=recordset_id, - values=values, + recordset=recordset, increment_serial=increment_serial) def delete_recordset(self, context, domain_id, recordset_id, @@ -292,14 +284,10 @@ class CentralAPI(object): LOG.info(_LI("find_record: Calling central's find_record.")) return self.client.call(context, 'find_record', criterion=criterion) - def update_record(self, context, domain_id, recordset_id, record_id, - values, increment_serial=True): + def update_record(self, context, record, increment_serial=True): LOG.info(_LI("update_record: Calling central's update_record.")) return self.client.call(context, 'update_record', - domain_id=domain_id, - recordset_id=recordset_id, - record_id=record_id, - values=values, + record=record, increment_serial=increment_serial) def delete_record(self, context, domain_id, recordset_id, record_id, @@ -368,10 +356,10 @@ class CentralAPI(object): LOG.info(_LI("find_blacklist: Calling central's find_blacklist.")) return self.client.call(context, 'find_blacklist', criterion=criterion) - def update_blacklist(self, context, blacklist_id, values): + def update_blacklist(self, context, blacklist): LOG.info(_LI("update_blacklist: Calling central's update_blacklist.")) return self.client.call(context, 'update_blacklist', - blacklist_id=blacklist_id, values=values) + blacklist=blacklist) def delete_blacklist(self, context, blacklist_id): LOG.info(_LI("delete_blacklist: Calling central's delete blacklist.")) diff --git a/designate/central/service.py b/designate/central/service.py index d36b132f..36c3c962 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -273,9 +273,8 @@ class Service(service.Service): domain = self.storage.get_domain(context, domain_id) # Increment the serial number - values = {'serial': utils.increment_serial(domain['serial'])} - - domain = self.storage.update_domain(context, domain_id, values) + domain.serial = utils.increment_serial(domain.serial) + domain = self.storage.update_domain(context, domain) with wrap_backend_call(): self.backend.update_domain(context, domain) @@ -371,10 +370,13 @@ class Service(service.Service): return self.storage.get_server(context, server_id) @transaction - def update_server(self, context, server_id, values): - policy.check('update_server', context, {'server_id': server_id}) + def update_server(self, context, server): + target = { + 'server_id': server.obj_get_original_value('id'), + } + policy.check('update_server', context, target) - server = self.storage.update_server(context, server_id, values) + server = self.storage.update_server(context, server) # Update backend with the new details.. with wrap_backend_call(): @@ -429,10 +431,13 @@ class Service(service.Service): return self.storage.get_tld(context, tld_id) @transaction - def update_tld(self, context, tld_id, values): - policy.check('update_tld', context, {'tld_id': tld_id}) + def update_tld(self, context, tld): + target = { + 'tld_id': tld.obj_get_original_value('id'), + } + policy.check('update_tld', context, target) - tld = self.storage.update_tld(context, tld_id, values) + tld = self.storage.update_tld(context, tld) self.notifier.info(context, 'dns.tld.update', tld) @@ -478,10 +483,13 @@ class Service(service.Service): return self.storage.get_tsigkey(context, tsigkey_id) @transaction - def update_tsigkey(self, context, tsigkey_id, values): - policy.check('update_tsigkey', context, {'tsigkey_id': tsigkey_id}) + def update_tsigkey(self, context, tsigkey): + target = { + 'tsigkey_id': tsigkey.obj_get_original_value('id'), + } + policy.check('update_tsigkey', context, target) - tsigkey = self.storage.update_tsigkey(context, tsigkey_id, values) + tsigkey = self.storage.update_tsigkey(context, tsigkey) with wrap_backend_call(): self.backend.update_tsigkey(context, tsigkey) @@ -622,41 +630,39 @@ class Service(service.Service): return self.storage.find_domain(context, criterion) @transaction - def update_domain(self, context, domain_id, values, increment_serial=True): + def update_domain(self, context, domain, increment_serial=True): # TODO(kiall): Refactor this method into *MUCH* smaller chunks. - domain = self.storage.get_domain(context, domain_id) - target = { - 'domain_id': domain_id, - 'domain_name': domain.name, - 'tenant_id': domain.tenant_id + 'domain_id': domain.obj_get_original_value('id'), + 'domain_name': domain.obj_get_original_value('name'), + 'tenant_id': domain.obj_get_original_value('tenant_id'), } policy.check('update_domain', context, target) - if 'tenant_id' in values: - # NOTE(kiall): Ensure the user is allowed to delete a domain from - # the original tenant. - policy.check('delete_domain', context, target) + changes = domain.obj_get_changes() - # NOTE(kiall): Ensure the user is allowed to create a domain in - # the new tenant. - target = {'domain_id': domain_id, 'tenant_id': values['tenant_id']} - policy.check('create_domain', context, target) + # Ensure immutable fields are not changed + if 'tenant_id' in changes: + # TODO(kiall): Moving between tenants should be allowed, but the + # current code will not take into account that + # RecordSets and Records must also be moved. + raise exceptions.BadRequest('Moving a domain between tenants is ' + 'not allowed') - if 'name' in values and values['name'] != domain.name: + if 'name' in changes: raise exceptions.BadRequest('Renaming a domain is not allowed') # Ensure TTL is above the minimum - ttl = values.get('ttl', None) + ttl = changes.get('ttl', None) if ttl is not None: self._is_valid_ttl(context, ttl) if increment_serial: # Increment the serial number - values['serial'] = utils.increment_serial(domain.serial) + domain.serial = utils.increment_serial(domain.serial) - domain = self.storage.update_domain(context, domain_id, values) + domain = self.storage.update_domain(context, domain) with wrap_backend_call(): self.backend.update_domain(context, domain) @@ -798,50 +804,54 @@ class Service(service.Service): return self.storage.find_recordset(context, criterion) @transaction - def update_recordset(self, context, domain_id, recordset_id, values, - increment_serial=True): + def update_recordset(self, context, recordset, increment_serial=True): + domain_id = recordset.obj_get_original_value('domain_id') domain = self.storage.get_domain(context, domain_id) - recordset = self.storage.get_recordset(context, recordset_id) - # Ensure the domain_id matches the recordset's domain_id - if domain.id != recordset.domain_id: - raise exceptions.RecordSetNotFound() + changes = recordset.obj_get_changes() + + # Ensure immutable fields are not changed + if 'tenant_id' in changes: + raise exceptions.BadRequest('Moving a recordset between tenants ' + 'is not allowed') + + if 'domain_id' in changes: + raise exceptions.BadRequest('Moving a recordset between domains ' + 'is not allowed') + + if 'type' in changes: + raise exceptions.BadRequest('Changing a recordsets type is not ' + 'allowed') target = { - 'domain_id': domain_id, + 'domain_id': recordset.obj_get_original_value('domain_id'), + 'recordset_id': recordset.obj_get_original_value('id'), 'domain_name': domain.name, - 'recordset_id': recordset.id, 'tenant_id': domain.tenant_id } policy.check('update_recordset', context, target) # Ensure the record name is valid - recordset_name = values['name'] if 'name' in values \ - else recordset.name - recordset_type = values['type'] if 'type' in values \ - else recordset.type - - self._is_valid_recordset_name(context, domain, recordset_name) - self._is_valid_recordset_placement(context, domain, recordset_name, - recordset_type, recordset_id) + self._is_valid_recordset_name(context, domain, recordset.name) + self._is_valid_recordset_placement(context, domain, recordset.name, + recordset.type, recordset.id) self._is_valid_recordset_placement_subdomain( - context, domain, recordset_name) + context, domain, recordset.name) # Ensure TTL is above the minimum - ttl = values.get('ttl', None) + ttl = changes.get('ttl', None) if ttl is not None: self._is_valid_ttl(context, ttl) # Update the recordset - recordset = self.storage.update_recordset(context, recordset_id, - values) + recordset = self.storage.update_recordset(context, recordset) with wrap_backend_call(): self.backend.update_recordset(context, domain, recordset) if increment_serial: - self._increment_domain_serial(context, domain_id) + self._increment_domain_serial(context, domain.id) # Send RecordSet update notification self.notifier.info(context, 'dns.recordset.update', recordset) @@ -969,39 +979,47 @@ class Service(service.Service): return self.storage.find_record(context, criterion) @transaction - def update_record(self, context, domain_id, recordset_id, record_id, - values, increment_serial=True): + def update_record(self, context, record, increment_serial=True): + domain_id = record.obj_get_original_value('domain_id') domain = self.storage.get_domain(context, domain_id) + + recordset_id = record.obj_get_original_value('recordset_id') recordset = self.storage.get_recordset(context, recordset_id) - record = self.storage.get_record(context, record_id) - # Ensure the domain_id matches the record's domain_id - if domain.id != record.domain_id: - raise exceptions.RecordNotFound() + changes = record.obj_get_changes() - # Ensure the recordset_id matches the record's recordset_id - if recordset.id != record.recordset_id: - raise exceptions.RecordNotFound() + # Ensure immutable fields are not changed + if 'tenant_id' in changes: + raise exceptions.BadRequest('Moving a recordset between tenants ' + 'is not allowed') + + if 'domain_id' in changes: + raise exceptions.BadRequest('Moving a recordset between domains ' + 'is not allowed') + + if 'recordset_id' in changes: + raise exceptions.BadRequest('Moving a recordset between ' + 'recordsets is not allowed') target = { - 'domain_id': domain_id, + 'domain_id': record.obj_get_original_value('domain_id'), 'domain_name': domain.name, - 'recordset_id': recordset_id, + 'recordset_id': record.obj_get_original_value('recordset_id'), 'recordset_name': recordset.name, - 'record_id': record.id, + 'record_id': record.obj_get_original_value('id'), 'tenant_id': domain.tenant_id } policy.check('update_record', context, target) # Update the record - record = self.storage.update_record(context, record_id, values) + record = self.storage.update_record(context, record) with wrap_backend_call(): self.backend.update_record(context, domain, recordset, record) if increment_serial: - self._increment_domain_serial(context, domain_id) + self._increment_domain_serial(context, domain.id) # Send Record update notification self.notifier.info(context, 'dns.record.update', record) @@ -1482,11 +1500,13 @@ class Service(service.Service): return blacklist @transaction - def update_blacklist(self, context, blacklist_id, values): - policy.check('update_blacklist', context) + def update_blacklist(self, context, blacklist): + target = { + 'blacklist_id': blacklist.id, + } + policy.check('update_blacklist', context, target) - blacklist = self.storage.update_blacklist(context, blacklist_id, - values) + blacklist = self.storage.update_blacklist(context, blacklist) self.notifier.info(context, 'dns.blacklist.update', blacklist) diff --git a/designate/quota/impl_storage.py b/designate/quota/impl_storage.py index 5eb1121a..6e8e1834 100644 --- a/designate/quota/impl_storage.py +++ b/designate/quota/impl_storage.py @@ -17,6 +17,7 @@ from oslo.config import cfg from designate import exceptions from designate import storage +from designate import objects from designate.openstack.common import log as logging from designate.quota.base import Quota @@ -57,18 +58,15 @@ class StorageQuota(Quota): context.all_tenants = True def create_quota(): - values = { - 'tenant_id': tenant_id, - 'resource': resource, - 'hard_limit': hard_limit, - } + quota = objects.Quota( + tenant_id=tenant_id, resource=resource, hard_limit=hard_limit) - self.storage.create_quota(context, values) + self.storage.create_quota(context, quota) - def update_quota(): - values = {'hard_limit': hard_limit} + def update_quota(quota): + quota.hard_limit = hard_limit - self.storage.update_quota(context, quota['id'], values) + self.storage.update_quota(context, quota) if resource not in self.get_default_quotas(context).keys(): raise exceptions.QuotaResourceUnknown("%s is not a valid quota " @@ -82,7 +80,7 @@ class StorageQuota(Quota): except exceptions.NotFound: create_quota() else: - update_quota() + update_quota(quota) return {resource: hard_limit} diff --git a/designate/storage/base.py b/designate/storage/base.py index 18d819aa..0037b916 100644 --- a/designate/storage/base.py +++ b/designate/storage/base.py @@ -71,13 +71,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_quota(self, context, quota_id, values): + def update_quota(self, context, quota): """ - Update a Quota via ID + Update a Quota :param context: RPC Context. - :param quota_id: Quota ID to update. - :param values: Values to update the Quota from + :param quota: Quota to update. """ @abc.abstractmethod @@ -124,13 +123,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_server(self, context, server_id, values): + def update_server(self, context, server): """ - Update a Server via ID + Update a Server :param context: RPC Context. - :param server_id: Server ID to update. - :param values: Values to update the Server from + :param server: Server object """ @abc.abstractmethod @@ -192,13 +190,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_tld(self, context, tld_id, values): + def update_tld(self, context, tld): """ - Update a TLD via ID + Update a TLD :param context: RPC Context. - :param tld_id: TLD ID to update. - :param values: Values to update the TLD from + :param tld: TLD to update. """ @abc.abstractmethod @@ -245,13 +242,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_tsigkey(self, context, tsigkey_id, values): + def update_tsigkey(self, context, tsigkey): """ - Update a TSIG Key via ID + Update a TSIG Key :param context: RPC Context. - :param tsigkey_id: TSIG Key ID to update. - :param values: Values to update the TSIG Key from + :param tsigkey: TSIG Keyto update. """ @abc.abstractmethod @@ -332,13 +328,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_domain(self, context, domain_id, values): + def update_domain(self, context, domain): """ - Update a Domain via ID. + Update a Domain :param context: RPC Context. - :param domain_id: Values to update the Domain with - :param values: Values to update the Domain from. + :param domain: Domain object. """ @abc.abstractmethod @@ -405,12 +400,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_recordset(self, context, recordset_id, values): + def update_recordset(self, context, recordset): """ - Update a recordset via ID + Update a recordset :param context: RPC Context - :param recordset_id: RecordSet ID to update + :param recordset: RecordSet to update """ @abc.abstractmethod @@ -477,12 +472,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_record(self, context, record_id, values): + def update_record(self, context, record): """ - Update a record via ID + Update a record :param context: RPC Context - :param record_id: Record ID to update + :param record: Record to update """ @abc.abstractmethod @@ -547,13 +542,12 @@ class Storage(DriverPlugin): """ @abc.abstractmethod - def update_blacklist(self, context, blacklist_id, values): + def update_blacklist(self, context, blacklist): """ - Update a Blacklist via ID + Update a Blacklist :param context: RPC Context. - :param blacklist_id: Blacklist ID to update. - :param values: Values to update the Blacklist from + :param blacklist: Blacklist to update. """ @abc.abstractmethod diff --git a/designate/storage/impl_sqlalchemy/__init__.py b/designate/storage/impl_sqlalchemy/__init__.py index 0d1a0bd2..b34a2ae2 100644 --- a/designate/storage/impl_sqlalchemy/__init__.py +++ b/designate/storage/impl_sqlalchemy/__init__.py @@ -213,17 +213,17 @@ class SQLAlchemyStorage(base.Storage): return _set_object_from_model(objects.Quota(), quota) - def update_quota(self, context, quota_id, values): - quota = self._find_quotas(context, {'id': quota_id}, one=True) + def update_quota(self, context, quota): + storage_quota = self._find_quotas(context, {'id': quota.id}, one=True) - quota.update(values) + storage_quota.update(quota.obj_get_changes()) try: - quota.save(self.session) + storage_quota.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateQuota() - return _set_object_from_model(objects.Quota(), quota) + return _set_object_from_model(quota, storage_quota) def delete_quota(self, context, quota_id): quota = self._find_quotas(context, {'id': quota_id}, one=True) @@ -265,17 +265,18 @@ class SQLAlchemyStorage(base.Storage): server = self._find_servers(context, {'id': server_id}, one=True) return _set_object_from_model(objects.Server(), server) - def update_server(self, context, server_id, values): - server = self._find_servers(context, {'id': server_id}, one=True) + def update_server(self, context, server): + storage_server = self._find_servers(context, {'id': server.id}, + one=True) - server.update(values) + storage_server.update(server.obj_get_changes()) try: - server.save(self.session) + storage_server.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateServer() - return _set_object_from_model(objects.Server(), server) + return _set_object_from_model(server, storage_server) def delete_server(self, context, server_id): server = self._find_servers(context, {'id': server_id}, one=True) @@ -319,16 +320,16 @@ class SQLAlchemyStorage(base.Storage): tld = self._find_tlds(context, {'id': tld_id}, one=True) return _set_object_from_model(objects.Tld(), tld) - def update_tld(self, context, tld_id, values): - tld = self._find_tlds(context, {'id': tld_id}, one=True) - tld.update(values) + def update_tld(self, context, tld): + storage_tld = self._find_tlds(context, {'id': tld.id}, one=True) + storage_tld.update(tld.obj_get_changes()) try: - tld.save(self.session) + storage_tld.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateTLD() - return _set_object_from_model(objects.Tld(), tld) + return _set_object_from_model(tld, storage_tld) def delete_tld(self, context, tld_id): tld = self._find_tlds(context, {'id': tld_id}, one=True) @@ -371,17 +372,18 @@ class SQLAlchemyStorage(base.Storage): return _set_object_from_model(objects.TsigKey(), tsigkey) - def update_tsigkey(self, context, tsigkey_id, values): - tsigkey = self._find_tsigkeys(context, {'id': tsigkey_id}, one=True) + def update_tsigkey(self, context, tsigkey): + storage_tsigkey = self._find_tsigkeys(context, {'id': tsigkey.id}, + one=True) - tsigkey.update(values) + storage_tsigkey.update(tsigkey.obj_get_changes()) try: - tsigkey.save(self.session) + storage_tsigkey.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateTsigKey() - return _set_object_from_model(objects.TsigKey(), tsigkey) + return _set_object_from_model(tsigkey, storage_tsigkey) def delete_tsigkey(self, context, tsigkey_id): tsigkey = self._find_tsigkeys(context, {'id': tsigkey_id}, one=True) @@ -468,17 +470,18 @@ class SQLAlchemyStorage(base.Storage): domain = self._find_domains(context, criterion, one=True) return _set_object_from_model(objects.Domain(), domain) - def update_domain(self, context, domain_id, values): - domain = self._find_domains(context, {'id': domain_id}, one=True) + def update_domain(self, context, domain): + storage_domain = self._find_domains(context, {'id': domain.id}, + one=True) - domain.update(values) + storage_domain.update(domain.obj_get_changes()) try: - domain.save(self.session) + storage_domain.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateDomain() - return _set_object_from_model(objects.Domain(), domain) + return _set_object_from_model(domain, storage_domain) def delete_domain(self, context, domain_id): domain = self._find_domains(context, {'id': domain_id}, one=True) @@ -543,18 +546,18 @@ class SQLAlchemyStorage(base.Storage): return _set_object_from_model(objects.RecordSet(), recordset) - def update_recordset(self, context, recordset_id, values): - recordset = self._find_recordsets(context, {'id': recordset_id}, - one=True) + def update_recordset(self, context, recordset): + storage_recordset = self._find_recordsets( + context, {'id': recordset.id}, one=True) - recordset.update(values) + storage_recordset.update(recordset.obj_get_changes()) try: - recordset.save(self.session) + storage_recordset.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateRecordSet() - return _set_object_from_model(objects.RecordSet(), recordset) + return _set_object_from_model(recordset, storage_recordset) def delete_recordset(self, context, recordset_id): recordset = self._find_recordsets(context, {'id': recordset_id}, @@ -617,17 +620,18 @@ class SQLAlchemyStorage(base.Storage): return _set_object_from_model(objects.Record(), record) - def update_record(self, context, record_id, values): - record = self._find_records(context, {'id': record_id}, one=True) + def update_record(self, context, record): + storage_record = self._find_records(context, {'id': record.id}, + one=True) - record.update(values) + storage_record.update(record.obj_get_changes()) try: - record.save(self.session) + storage_record.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateRecord() - return _set_object_from_model(objects.Record(), record) + return _set_object_from_model(record, storage_record) def delete_record(self, context, record_id): record = self._find_records(context, {'id': record_id}, one=True) @@ -686,18 +690,18 @@ class SQLAlchemyStorage(base.Storage): return _set_object_from_model(objects.Blacklist(), blacklist) - def update_blacklist(self, context, blacklist_id, values): - blacklist = self._find_blacklist(context, {'id': blacklist_id}, - one=True) + def update_blacklist(self, context, blacklist): + storage_blacklist = self._find_blacklist(context, {'id': blacklist.id}, + one=True) - blacklist.update(values) + storage_blacklist.update(blacklist.obj_get_changes()) try: - blacklist.save(self.session) + storage_blacklist.save(self.session) except exceptions.Duplicate: raise exceptions.DuplicateBlacklist() - return _set_object_from_model(objects.Blacklist(), blacklist) + return _set_object_from_model(blacklist, storage_blacklist) def delete_blacklist(self, context, blacklist_id): diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index d006092f..91250d04 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -210,19 +210,19 @@ class CentralServiceTest(CentralTestCase): def test_update_server(self): # Create a server - expected_server = self.create_server() + server = self.create_server(name='ns1.example.org.') - # Update the server - values = dict(name='prefix.%s' % expected_server['name']) - self.central_service.update_server( - self.admin_context, expected_server['id'], values=values) + # Update the Object + server.name = 'ns2.example.org.' + + # Perform the update + self.central_service.update_server(self.admin_context, server) # Fetch the server again - server = self.central_service.get_server( - self.admin_context, expected_server['id']) + server = self.central_service.get_server(self.admin_context, server.id) - # Ensure the server was updated correctly - self.assertEqual(server['name'], 'prefix.%s' % expected_server['name']) + # Ensure the new value took + self.assertEqual('ns2.example.org.', server.name) def test_delete_server(self): # Create a server @@ -300,19 +300,19 @@ class CentralServiceTest(CentralTestCase): def test_update_tld(self): # Create a tld - expected_tld = self.create_tld(fixture=0) + tld = self.create_tld(name='org.') - # Update the tld - values = dict(name='prefix.%s' % expected_tld['name']) - self.central_service.update_tld( - self.admin_context, expected_tld['id'], values=values) + # Update the Object + tld.name = 'net.' + + # Perform the update + self.central_service.update_tld(self.admin_context, tld) # Fetch the tld again - tld = self.central_service.get_tld( - self.admin_context, expected_tld['id']) + tld = self.central_service.get_tld(self.admin_context, tld.id) # Ensure the tld was updated correctly - self.assertEqual(tld['name'], 'prefix.%s' % expected_tld['name']) + self.assertEqual('net.', tld.name) def test_delete_tld(self): # Create a tld @@ -376,22 +376,21 @@ class CentralServiceTest(CentralTestCase): self.assertEqual(tsigkey['secret'], expected['secret']) def test_update_tsigkey(self): - # Create a tsigkey using default values - expected = self.create_tsigkey() + # Create a tsigkey + tsigkey = self.create_tsigkey(name='test-key') - # Update the tsigkey - fixture = self.get_tsigkey_fixture(fixture=1) - values = dict(name=fixture['name']) + # Update the Object + tsigkey.name = 'test-key-updated' - self.central_service.update_tsigkey( - self.admin_context, expected['id'], values=values) + # Perform the update + self.central_service.update_tsigkey(self.admin_context, tsigkey) # Fetch the tsigkey again tsigkey = self.central_service.get_tsigkey( - self.admin_context, expected['id']) + self.admin_context, tsigkey.id) - # Ensure the tsigkey was updated correctly - self.assertEqual(tsigkey['name'], fixture['name']) + # Ensure the new value took + self.assertEqual('test-key-updated', tsigkey.name) def test_delete_tsigkey(self): # Create a tsigkey @@ -748,24 +747,24 @@ class CentralServiceTest(CentralTestCase): def test_update_domain(self): # Create a domain - expected_domain = self.create_domain() + domain = self.create_domain(email='info@example.org') + original_serial = domain.serial # Reset the list of notifications self.reset_notifications() - # Update the domain - values = dict(email='new@example.com') + # Update the object + domain.email = 'info@example.net' - self.central_service.update_domain( - self.admin_context, expected_domain['id'], values=values) + # Perform the update + self.central_service.update_domain(self.admin_context, domain) # Fetch the domain again - domain = self.central_service.get_domain( - self.admin_context, expected_domain['id']) + domain = self.central_service.get_domain(self.admin_context, domain.id) # Ensure the domain was updated correctly - self.assertTrue(domain['serial'] > expected_domain['serial']) - self.assertEqual(domain['email'], 'new@example.com') + self.assertTrue(domain.serial > original_serial) + self.assertEqual('info@example.net', domain.email) # Ensure we sent exactly 1 notification notifications = self.get_notifications() @@ -780,39 +779,42 @@ class CentralServiceTest(CentralTestCase): # Ensure the notification payload contains the correct info payload = message['payload'] - self.assertEqual(payload['id'], domain['id']) - self.assertEqual(payload['name'], domain['name']) - self.assertEqual(payload['tenant_id'], domain['tenant_id']) + self.assertEqual(domain.id, payload['id']) + self.assertEqual(domain.name, payload['name']) + self.assertEqual(domain.tenant_id, payload['tenant_id']) def test_update_domain_without_incrementing_serial(self): # Create a domain - expected_domain = self.create_domain() + domain = self.create_domain(email='info@example.org') + original_serial = domain.serial - # Update the domain - values = dict(email='new@example.com') + # Reset the list of notifications + self.reset_notifications() + # Update the object + domain.email = 'info@example.net' + + # Perform the update self.central_service.update_domain( - self.admin_context, expected_domain['id'], values=values, - increment_serial=False) + self.admin_context, domain, increment_serial=False) # Fetch the domain again - domain = self.central_service.get_domain( - self.admin_context, expected_domain['id']) + domain = self.central_service.get_domain(self.admin_context, domain.id) # Ensure the domain was updated correctly - self.assertEqual(domain['serial'], expected_domain['serial']) - self.assertEqual(domain['email'], 'new@example.com') + self.assertEqual(original_serial, domain.serial) + self.assertEqual('info@example.net', domain.email) def test_update_domain_name_fail(self): # Create a domain - expected_domain = self.create_domain() + domain = self.create_domain(name='example.org.') - # Update the domain + # Update the Object + domain.name = 'example.net.' + + # Perform the update with testtools.ExpectedException(exceptions.BadRequest): - values = dict(name='renamed-domain.com.') - - self.central_service.update_domain( - self.admin_context, expected_domain['id'], values=values) + self.central_service.update_domain(self.admin_context, domain) def test_delete_domain(self): # Create a domain @@ -1081,67 +1083,70 @@ class CentralServiceTest(CentralTestCase): self.assertEqual(recordset['name'], expected['name']) def test_update_recordset(self): + # Create a domain domain = self.create_domain() # Create a recordset - expected = self.create_recordset(domain) + recordset = self.create_recordset(domain) # Update the recordset - values = dict(ttl=1800) - self.central_service.update_recordset( - self.admin_context, domain['id'], expected['id'], values=values) + recordset.ttl = 1800 - # Fetch the recordset again + # Perform the update + self.central_service.update_recordset(self.admin_context, recordset) + + # Fetch the resource again recordset = self.central_service.get_recordset( - self.admin_context, domain['id'], expected['id']) + self.admin_context, recordset.domain_id, recordset.id) - # Ensure the record was updated correctly - self.assertEqual(recordset['ttl'], 1800) + # Ensure the new value took + self.assertEqual(recordset.ttl, 1800) def test_update_recordset_without_incrementing_serial(self): domain = self.create_domain() # Create a recordset - expected = self.create_recordset(domain) + recordset = self.create_recordset(domain) # Fetch the domain so we have the latest serial number domain_before = self.central_service.get_domain( - self.admin_context, domain['id']) + self.admin_context, domain.id) # Update the recordset - values = dict(ttl=1800) - self.central_service.update_recordset( - self.admin_context, domain['id'], expected['id'], values, - increment_serial=False) + recordset.ttl = 1800 - # Fetch the recordset again + # Perform the update + self.central_service.update_recordset( + self.admin_context, recordset, increment_serial=False) + + # Fetch the resource again recordset = self.central_service.get_recordset( - self.admin_context, domain['id'], expected['id']) + self.admin_context, recordset.domain_id, recordset.id) # Ensure the recordset was updated correctly - self.assertEqual(recordset['ttl'], 1800) + self.assertEqual(recordset.ttl, 1800) # Ensure the domains serial number was not updated domain_after = self.central_service.get_domain( - self.admin_context, domain['id']) + self.admin_context, domain.id) - self.assertEqual(domain_before['serial'], domain_after['serial']) + self.assertEqual(domain_before.serial, domain_after.serial) - def test_update_recordset_incorrect_domain_id(self): + def test_update_recordset_immutable_domain_id(self): domain = self.create_domain() other_domain = self.create_domain(fixture=1) # Create a recordset - expected = self.create_recordset(domain) + recordset = self.create_recordset(domain) # Update the recordset - values = dict(ttl=1800) + recordset.ttl = 1800 + recordset.domain_id = other_domain.id - # Ensure we get a 404 if we use the incorrect domain_id - with testtools.ExpectedException(exceptions.RecordSetNotFound): + # Ensure we get a BadRequest if we change the domain_id + with testtools.ExpectedException(exceptions.BadRequest): self.central_service.update_recordset( - self.admin_context, other_domain['id'], expected['id'], - values=values) + self.admin_context, recordset) def test_delete_recordset(self): domain = self.create_domain() @@ -1377,85 +1382,83 @@ class CentralServiceTest(CentralTestCase): recordset = self.create_recordset(domain, 'A') # Create a record - expected = self.create_record(domain, recordset) + record = self.create_record(domain, recordset) - # Update the record - values = dict(data='127.0.0.2') - self.central_service.update_record( - self.admin_context, domain['id'], recordset['id'], expected['id'], - values=values) + # Update the Object + record.data = '192.0.2.255' - # Fetch the record again + # Perform the update + self.central_service.update_record(self.admin_context, record) + + # Fetch the resource again record = self.central_service.get_record( - self.admin_context, domain['id'], recordset['id'], expected['id']) + self.admin_context, record.domain_id, record.recordset_id, + record.id) - # Ensure the record was updated correctly - self.assertEqual(record['data'], '127.0.0.2') + # Ensure the new value took + self.assertEqual('192.0.2.255', record.data) def test_update_record_without_incrementing_serial(self): domain = self.create_domain() recordset = self.create_recordset(domain, 'A') # Create a record - expected = self.create_record(domain, recordset) + record = self.create_record(domain, recordset) # Fetch the domain so we have the latest serial number domain_before = self.central_service.get_domain( - self.admin_context, domain['id']) + self.admin_context, domain.id) - # Update the record - values = dict(data='127.0.0.2') + # Update the Object + record.data = '192.0.2.255' + # Perform the update self.central_service.update_record( - self.admin_context, domain['id'], recordset['id'], expected['id'], - values, increment_serial=False) + self.admin_context, record, increment_serial=False) - # Fetch the record again + # Fetch the resource again record = self.central_service.get_record( - self.admin_context, domain['id'], recordset['id'], expected['id']) + self.admin_context, record.domain_id, record.recordset_id, + record.id) - # Ensure the record was updated correctly - self.assertEqual(record['data'], '127.0.0.2') + # Ensure the new value took + self.assertEqual('192.0.2.255', record.data) # Ensure the domains serial number was not updated domain_after = self.central_service.get_domain( - self.admin_context, domain['id']) + self.admin_context, domain.id) - self.assertEqual(domain_before['serial'], domain_after['serial']) + self.assertEqual(domain_before.serial, domain_after.serial) - def test_update_record_incorrect_domain_id(self): + def test_update_record_immutable_domain_id(self): domain = self.create_domain() - recordset = self.create_recordset(domain, 'A') + recordset = self.create_recordset(domain) other_domain = self.create_domain(fixture=1) # Create a record - expected = self.create_record(domain, recordset) + record = self.create_record(domain, recordset) # Update the record - values = dict(data='127.0.0.2') + record.domain_id = other_domain.id - # Ensure we get a 404 if we use the incorrect domain_id - with testtools.ExpectedException(exceptions.RecordNotFound): - self.central_service.update_record( - self.admin_context, other_domain['id'], recordset['id'], - expected['id'], values=values) + # Ensure we get a BadRequest if we change the domain_id + with testtools.ExpectedException(exceptions.BadRequest): + self.central_service.update_record(self.admin_context, record) - def test_update_record_incorrect_recordset_id(self): + def test_update_record_immutable_recordset_id(self): domain = self.create_domain() - recordset = self.create_recordset(domain, 'A') - other_recordset = self.create_recordset(domain, 'A', fixture=1) + recordset = self.create_recordset(domain) + other_recordset = self.create_recordset(domain, fixture=1) # Create a record - expected = self.create_record(domain, recordset) + record = self.create_record(domain, recordset) # Update the record - values = dict(data='127.0.0.2') + record.recordset_id = other_recordset.id - # Ensure we get a 404 if we use the incorrect domain_id - with testtools.ExpectedException(exceptions.RecordNotFound): - self.central_service.update_record( - self.admin_context, domain['id'], other_recordset['id'], - expected['id'], values=values) + # Ensure we get a BadRequest if we change the recordset_id + with testtools.ExpectedException(exceptions.BadRequest): + self.central_service.update_record(self.admin_context, record) def test_delete_record(self): domain = self.create_domain() @@ -1885,23 +1888,20 @@ class CentralServiceTest(CentralTestCase): def test_update_blacklist(self): # Create a blacklisted zone - expected = self.create_blacklist(fixture=0) - new_comment = "This is a different comment." + blacklist = self.create_blacklist(fixture=0) - # Update the blacklist - updated_values = dict( - description=new_comment - ) - self.central_service.update_blacklist(self.admin_context, - expected['id'], - updated_values) + # Update the Object + blacklist.description = "New Comment" - # Fetch the blacklist + # Perform the update + self.central_service.update_blacklist(self.admin_context, blacklist) + + # Fetch the resource again blacklist = self.central_service.get_blacklist(self.admin_context, - expected['id']) + blacklist.id) # Verify that the record was updated correctly - self.assertEqual(blacklist['description'], new_comment) + self.assertEqual("New Comment", blacklist.description) def test_delete_blacklist(self): # Create a blacklisted zone diff --git a/designate/tests/test_storage/__init__.py b/designate/tests/test_storage/__init__.py index a7c087ad..a3b17555 100644 --- a/designate/tests/test_storage/__init__.py +++ b/designate/tests/test_storage/__init__.py @@ -216,31 +216,33 @@ class StorageTestCase(object): def test_update_quota(self): # Create a quota - fixture = self.get_quota_fixture() quota = self.create_quota(fixture=1) - updated = self.storage.update_quota(self.admin_context, quota['id'], - fixture) + # Update the Object + quota.hard_limit = 5000 - self.assertEqual(updated['resource'], fixture['resource']) - self.assertEqual(updated['hard_limit'], fixture['hard_limit']) + # Perform the update + quota = self.storage.update_quota(self.admin_context, quota) + + # Ensure the new value took + self.assertEqual(5000, quota.hard_limit) def test_update_quota_duplicate(self): - context = self.get_admin_context() - context.all_tenants = True - # Create two quotas - self.create_quota(fixture=0, tenant_id='1') - quota = self.create_quota(fixture=0, tenant_id='2') + quota_one = self.create_quota(fixture=0) + quota_two = self.create_quota(fixture=1) + + # Update the Q2 object to be a duplicate of Q1 + quota_two.resource = quota_one.resource with testtools.ExpectedException(exceptions.DuplicateQuota): - self.storage.update_quota(context, quota['id'], - values={'tenant_id': '1'}) + self.storage.update_quota(self.admin_context, quota_two) def test_update_quota_missing(self): + quota = objects.Quota(id='caf771fc-6b05-4891-bee1-c2a48621f57b') + with testtools.ExpectedException(exceptions.QuotaNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_quota(self.admin_context, uuid, {}) + self.storage.update_quota(self.admin_context, quota) def test_delete_quota(self): quota = self.create_quota() @@ -334,29 +336,32 @@ class StorageTestCase(object): def test_update_server(self): # Create a server - fixture = self.get_server_fixture() - server = self.create_server(**fixture) + server = self.create_server(name='ns1.example.org.') - updated = self.storage.update_server(self.admin_context, server['id'], - fixture) + # Update the Object + server.name = 'ns2.example.org.' - self.assertEqual(str(updated['name']), str(fixture['name'])) + # Perform the update + server = self.storage.update_server(self.admin_context, server) + + # Ensure the new value took + self.assertEqual('ns2.example.org.', server.name) def test_update_server_duplicate(self): # Create two servers - self.create_server(fixture=0) - server = self.create_server(fixture=1) + server_one = self.create_server(fixture=0) + server_two = self.create_server(fixture=1) - values = self.server_fixtures[0] + # Update the S2 object to be a duplicate of S1 + server_two.name = server_one.name with testtools.ExpectedException(exceptions.DuplicateServer): - self.storage.update_server(self.admin_context, server['id'], - values) + self.storage.update_server(self.admin_context, server_two) def test_update_server_missing(self): + server = objects.Server(id='caf771fc-6b05-4891-bee1-c2a48621f57b') with testtools.ExpectedException(exceptions.ServerNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_server(self.admin_context, uuid, {}) + self.storage.update_server(self.admin_context, server) def test_delete_server(self): server = self.create_server() @@ -459,32 +464,33 @@ class StorageTestCase(object): def test_update_tsigkey(self): # Create a tsigkey - fixture = self.get_tsigkey_fixture() - tsigkey = self.create_tsigkey(**fixture) + tsigkey = self.create_tsigkey(name='test-key') - updated = self.storage.update_tsigkey(self.admin_context, - tsigkey['id'], - fixture) + # Update the Object + tsigkey.name = 'test-key-updated' - self.assertEqual(updated['name'], fixture['name']) - self.assertEqual(updated['algorithm'], fixture['algorithm']) - self.assertEqual(updated['secret'], fixture['secret']) + # Perform the update + tsigkey = self.storage.update_tsigkey(self.admin_context, tsigkey) + + # Ensure the new value took + self.assertEqual('test-key-updated', tsigkey.name) def test_update_tsigkey_duplicate(self): # Create two tsigkeys - self.create_tsigkey(fixture=0) - tsigkey = self.create_tsigkey(fixture=1) + tsigkey_one = self.create_tsigkey(fixture=0) + tsigkey_two = self.create_tsigkey(fixture=1) - values = self.tsigkey_fixtures[0] + # Update the T2 object to be a duplicate of T1 + tsigkey_two.name = tsigkey_one.name with testtools.ExpectedException(exceptions.DuplicateTsigKey): - self.storage.update_tsigkey(self.admin_context, tsigkey['id'], - values) + self.storage.update_tsigkey(self.admin_context, tsigkey_two) def test_update_tsigkey_missing(self): + tsigkey = objects.TsigKey(id='caf771fc-6b05-4891-bee1-c2a48621f57b') + with testtools.ExpectedException(exceptions.TsigKeyNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_tsigkey(self.admin_context, uuid, {}) + self.storage.update_tsigkey(self.admin_context, tsigkey) def test_delete_tsigkey(self): tsigkey = self.create_tsigkey() @@ -736,30 +742,32 @@ class StorageTestCase(object): def test_update_domain(self): # Create a domain - fixture = self.get_domain_fixture() - domain = self.create_domain(**fixture) + domain = self.create_domain(name='example.org.') - updated = self.storage.update_domain(self.admin_context, domain['id'], - fixture) + # Update the Object + domain.name = 'example.net.' - self.assertEqual(updated['name'], fixture['name']) - self.assertEqual(updated['email'], fixture['email']) - self.assertIn('status', updated) + # Perform the update + domain = self.storage.update_domain(self.admin_context, domain) + + # Ensure the new valie took + self.assertEqual('example.net.', domain.name) def test_update_domain_duplicate(self): # Create two domains - fixture = self.get_domain_fixture(fixture=0) - self.create_domain(**fixture) + domain_one = self.create_domain(fixture=0) domain_two = self.create_domain(fixture=1) + # Update the D2 object to be a duplicate of D1 + domain_two.name = domain_one.name + with testtools.ExpectedException(exceptions.DuplicateDomain): - self.storage.update_domain(self.admin_context, domain_two['id'], - fixture) + self.storage.update_domain(self.admin_context, domain_two) def test_update_domain_missing(self): + domain = objects.Domain(id='caf771fc-6b05-4891-bee1-c2a48621f57b') with testtools.ExpectedException(exceptions.DomainNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_domain(self.admin_context, uuid, {}) + self.storage.update_domain(self.admin_context, domain) def test_delete_domain(self): domain = self.create_domain() @@ -930,39 +938,35 @@ class StorageTestCase(object): # Create a recordset recordset = self.create_recordset(domain) - # Get some different values to test the update with - recordset_fixture = self.get_recordset_fixture(domain['name'], - fixture=1) + # Update the Object + recordset.ttl = 1800 - # Update the recordset with the new values... - updated = self.storage.update_recordset(self.admin_context, - recordset['id'], - recordset_fixture) + # Perform the update + recordset = self.storage.update_recordset(self.admin_context, + recordset) - # Ensure the update succeeded - self.assertEqual(updated['id'], recordset['id']) - self.assertEqual(updated['name'], recordset_fixture['name']) - self.assertEqual(updated['type'], recordset_fixture['type']) + # Ensure the new value took + self.assertEqual(1800, recordset.ttl) def test_update_recordset_duplicate(self): domain = self.create_domain() - # Create the first two recordsets - recordset_one_fixture = self.get_recordset_fixture(domain['name']) - self.create_recordset(domain, **recordset_one_fixture) - recordset_two = self.create_recordset(domain, fixture=1) + # Create two recordsets + recordset_one = self.create_recordset(domain, type='A') + recordset_two = self.create_recordset(domain, type='A', fixture=1) + + # Update the R2 object to be a duplicate of R1 + recordset_two.name = recordset_one.name with testtools.ExpectedException(exceptions.DuplicateRecordSet): - # Attempt to update the second recordset, making it a duplicate - # recordset - self.storage.update_recordset(self.admin_context, - recordset_two['id'], - recordset_one_fixture) + self.storage.update_recordset(self.admin_context, recordset_two) def test_update_recordset_missing(self): + recordset = objects.RecordSet( + id='caf771fc-6b05-4891-bee1-c2a48621f57b') + with testtools.ExpectedException(exceptions.RecordSetNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_recordset(self.admin_context, uuid, {}) + self.storage.update_recordset(self.admin_context, recordset) def test_delete_recordset(self): domain = self.create_domain() @@ -1188,42 +1192,39 @@ class StorageTestCase(object): def test_update_record(self): domain = self.create_domain() - recordset = self.create_recordset(domain) + recordset = self.create_recordset(domain, type='A') # Create a record record = self.create_record(domain, recordset) - # Get some different values to test the update with - record_fixture = self.get_record_fixture(recordset['type'], fixture=1) + # Update the Object + record.data = '192.0.2.255' - # Update the record with the new values... - updated = self.storage.update_record(self.admin_context, record['id'], - record_fixture) + # Perform the update + record = self.storage.update_record(self.admin_context, record) - # Ensure the update succeeded - self.assertEqual(updated['id'], record['id']) - self.assertEqual(updated['data'], record_fixture['data']) - self.assertNotEqual(updated['hash'], record['hash']) - self.assertIn('status', updated) + # Ensure the new value took + self.assertEqual('192.0.2.255', record.data) def test_update_record_duplicate(self): domain = self.create_domain() recordset = self.create_recordset(domain) - # Create the first two records - record_one_fixture = self.get_record_fixture(recordset['type']) - self.create_record(domain, recordset, **record_one_fixture) + # Create two records + record_one = self.create_record(domain, recordset) record_two = self.create_record(domain, recordset, fixture=1) + # Update the R2 object to be a duplicate of R1 + record_two.data = record_one.data + with testtools.ExpectedException(exceptions.DuplicateRecord): - # Attempt to update the second record, making it a duplicate record - self.storage.update_record(self.admin_context, record_two['id'], - record_one_fixture) + self.storage.update_record(self.admin_context, record_two) def test_update_record_missing(self): + record = objects.Record(id='caf771fc-6b05-4891-bee1-c2a48621f57b') + with testtools.ExpectedException(exceptions.RecordNotFound): - uuid = 'caf771fc-6b05-4891-bee1-c2a48621f57b' - self.storage.update_record(self.admin_context, uuid, {}) + self.storage.update_record(self.admin_context, record) def test_delete_record(self): domain = self.create_domain()