diff --git a/designate/central/service.py b/designate/central/service.py index 08b29d208..dadd37061 100644 --- a/designate/central/service.py +++ b/designate/central/service.py @@ -2775,6 +2775,9 @@ class Service(service.RPCService): except exceptions.InvalidTTL as e: zone_import.status = 'ERROR' zone_import.message = str(e) + except exceptions.OverQuota: + zone_import.status = 'ERROR' + zone_import.message = 'Quota exceeded during zone import.' except Exception as e: LOG.exception( 'An undefined error occurred during zone import creation' diff --git a/designate/tests/__init__.py b/designate/tests/__init__.py index eb3a02bbc..86ca13810 100644 --- a/designate/tests/__init__.py +++ b/designate/tests/__init__.py @@ -807,34 +807,36 @@ class TestCase(base.BaseTestCase): return self.storage.create_zone_export( context, objects.ZoneExport.from_dict(zone_export)) - def wait_for_import(self, zone_import_id, errorok=False): + def wait_for_import(self, zone_import_id, error_is_ok=False, max_wait=10): """ Zone imports spawn a thread to parse the zone file and insert the data. This waits for this process before continuing """ - attempts = 0 - while attempts < 20: - # Give the import a half second to complete - time.sleep(.5) - + start_time = time.time() + while True: # Retrieve it, and ensure it's the same zone_import = self.central_service.get_zone_import( - self.admin_context_all_tenants, zone_import_id) + self.admin_context_all_tenants, zone_import_id + ) # If the import is done, we're done if zone_import.status == 'COMPLETE': break # If errors are allowed, just make sure that something completed - if errorok: - if zone_import.status != 'PENDING': - break + if error_is_ok and zone_import.status != 'PENDING': + break - attempts += 1 + if (time.time() - start_time) > max_wait: + break - if not errorok: + time.sleep(0.5) + + if not error_is_ok: self.assertEqual('COMPLETE', zone_import.status) + return zone_import + def _ensure_interface(self, interface, implementation): for name in interface.__abstractmethods__: in_arginfo = inspect.getfullargspec(getattr(interface, name)) diff --git a/designate/tests/test_api/test_v2/test_import_export.py b/designate/tests/test_api/test_v2/test_import_export.py index 7855f88b3..8204da5d1 100644 --- a/designate/tests/test_api/test_v2/test_import_export.py +++ b/designate/tests/test_api/test_v2/test_import_export.py @@ -53,7 +53,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id @@ -70,7 +70,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id @@ -86,7 +86,7 @@ class APIV2ZoneImportExportTest(ApiV2TestCase): headers={'Content-type': 'text/dns'}) import_id = response.json_body['id'] - self.wait_for_import(import_id, errorok=True) + self.wait_for_import(import_id, error_is_ok=True) url = '/zones/tasks/imports/%s' % import_id diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py index b0332a120..60d9e15b8 100644 --- a/designate/tests/test_central/test_service.py +++ b/designate/tests/test_central/test_service.py @@ -3395,6 +3395,30 @@ class CentralServiceTest(CentralTestCase): self.wait_for_import(zone_import.id) + def test_create_zone_import_overquota(self): + self.config( + quota_zone_records=5, + quota_zone_recordsets=5, + ) + + # Create a Zone Import + context = self.get_context(project_id=utils.generate_uuid()) + request_body = self.get_zonefile_fixture() + zone_import = self.central_service.create_zone_import(context, + request_body) + + # Ensure all values have been set correctly + self.assertIsNotNone(zone_import['id']) + self.assertEqual('PENDING', zone_import.status) + self.assertIsNone(zone_import.message) + self.assertIsNone(zone_import.zone_id) + + zone_import = self.wait_for_import(zone_import.id, error_is_ok=True) + + self.assertEqual('Quota exceeded during zone import.', + zone_import.message) + self.assertEqual('ERROR', zone_import.status) + def test_find_zone_imports(self): context = self.get_context(project_id=utils.generate_uuid())