diff --git a/keystone/limit/backends/base.py b/keystone/limit/backends/base.py index f20b974fe8..2c48b14cc5 100644 --- a/keystone/limit/backends/base.py +++ b/keystone/limit/backends/base.py @@ -39,7 +39,7 @@ class UnifiedLimitDriverBase(object): :param registered_limits: a list of dictionaries representing limits to create. - :returns: all the registered limits. + :returns: all the newly created registered limits. :raises keystone.exception.Conflict: If a duplicate registered limit exists. @@ -106,7 +106,7 @@ class UnifiedLimitDriverBase(object): :param limits: a list of dictionaries representing limits to create. - :returns: all the limits. + :returns: all the newly created limits. :raises keystone.exception.Conflict: If a duplicate limit exists. :raises keystone.exception.NoLimitReference: If no reference registered limit exists. diff --git a/keystone/limit/backends/sql.py b/keystone/limit/backends/sql.py index b35be840c3..5d3ef958b3 100644 --- a/keystone/limit/backends/sql.py +++ b/keystone/limit/backends/sql.py @@ -135,11 +135,14 @@ class UnifiedLimit(base.UnifiedLimitDriverBase): @sql.handle_conflicts(conflict_type='registered_limit') def create_registered_limits(self, registered_limits): with sql.session_for_write() as session: + new_registered_limits = [] for registered_limit in registered_limits: if registered_limit.get('region_id') is None: self._check_unified_limit_without_region(registered_limit) ref = RegisteredLimitModel.from_dict(registered_limit) session.add(ref) + new_registered_limits.append(ref.to_dict()) + return new_registered_limits @sql.handle_conflicts(conflict_type='registered_limit') def update_registered_limits(self, registered_limits): @@ -206,12 +209,15 @@ class UnifiedLimit(base.UnifiedLimitDriverBase): def create_limits(self, limits): try: with sql.session_for_write() as session: + new_limits = [] for limit in limits: if limit.get('region_id') is None: self._check_unified_limit_without_region( limit, is_registered_limit=False) ref = LimitModel.from_dict(limit) session.add(ref) + new_limits.append(ref.to_dict()) + return new_limits except db_exception.DBReferenceError: raise exception.NoLimitReference() diff --git a/keystone/limit/core.py b/keystone/limit/core.py index c8d8a8496a..bd8c03ac5d 100644 --- a/keystone/limit/core.py +++ b/keystone/limit/core.py @@ -59,8 +59,7 @@ class Manager(manager.Manager): def create_registered_limits(self, registered_limits): for registered_limit in registered_limits: self._assert_resource_exist(registered_limit, 'registered_limit') - self.driver.create_registered_limits(registered_limits) - return self.list_registered_limits() + return self.driver.create_registered_limits(registered_limits) def update_registered_limits(self, registered_limits): for registered_limit in registered_limits: @@ -86,8 +85,7 @@ class Manager(manager.Manager): def create_limits(self, limits): for limit in limits: self._assert_resource_exist(limit, 'limit') - self.driver.create_limits(limits) - return self.list_limits() + return self.driver.create_limits(limits) def update_limits(self, limits): for limit in limits: diff --git a/keystone/tests/unit/limit/test_backends.py b/keystone/tests/unit/limit/test_backends.py index 315771ee0d..491b15a1af 100644 --- a/keystone/tests/unit/limit/test_backends.py +++ b/keystone/tests/unit/limit/test_backends.py @@ -24,28 +24,32 @@ PROVIDERS = provider_api.ProviderAPIs class RegisteredLimitTests(object): def test_create_registered_limit_crud(self): - # create one, return all registered_limits + # create one, return it. registered_limit_1 = unit.new_registered_limit_ref( service_id=self.service_one['id'], region_id=self.region_one['id'], resource_name='volume', default_limit=10, id=uuid.uuid4().hex) - res1 = PROVIDERS.unified_limit_api.create_registered_limits( + reg_limits = PROVIDERS.unified_limit_api.create_registered_limits( [registered_limit_1]) - self.assertDictEqual(registered_limit_1, res1[0]) + self.assertDictEqual(registered_limit_1, reg_limits[0]) - # create another, return all registered_limits + # create another two, return them. registered_limit_2 = unit.new_registered_limit_ref( service_id=self.service_one['id'], region_id=self.region_one['id'], resource_name='snapshot', default_limit=5, id=uuid.uuid4().hex) - res2 = PROVIDERS.unified_limit_api.create_registered_limits( - [registered_limit_2]) - self.assertEqual(2, len(res2)) - for re in res2: - if re['id'] == registered_limit_1['id']: - self.assertDictEqual(registered_limit_1, re) - if re['id'] == registered_limit_2['id']: - self.assertDictEqual(registered_limit_2, re) + registered_limit_3 = unit.new_registered_limit_ref( + service_id=self.service_one['id'], + region_id=self.region_one['id'], + resource_name='backup', default_limit=5, id=uuid.uuid4().hex) + reg_limits = PROVIDERS.unified_limit_api.create_registered_limits( + [registered_limit_2, registered_limit_3]) + self.assertEqual(2, len(reg_limits)) + for reg_limit in reg_limits: + if reg_limit['id'] == registered_limit_2['id']: + self.assertDictEqual(registered_limit_2, reg_limit) + if reg_limit['id'] == registered_limit_3['id']: + self.assertDictEqual(registered_limit_3, reg_limit) def test_create_registered_limit_duplicate(self): registered_limit_1 = unit.new_registered_limit_ref( @@ -371,27 +375,33 @@ class RegisteredLimitTests(object): class LimitTests(object): def test_create_limit(self): - # create one, return all limits + # create one, return it. limit_1 = unit.new_limit_ref( project_id=self.tenant_bar['id'], service_id=self.service_one['id'], region_id=self.region_one['id'], resource_name='volume', resource_limit=10, id=uuid.uuid4().hex) - res1 = PROVIDERS.unified_limit_api.create_limits([limit_1]) - self.assertDictEqual(limit_1, res1[0]) + limits = PROVIDERS.unified_limit_api.create_limits([limit_1]) + self.assertDictEqual(limit_1, limits[0]) - # create another, return all limits + # create another two, return them. limit_2 = unit.new_limit_ref( project_id=self.tenant_bar['id'], service_id=self.service_one['id'], region_id=self.region_two['id'], resource_name='snapshot', resource_limit=5, id=uuid.uuid4().hex) - res2 = PROVIDERS.unified_limit_api.create_limits([limit_2]) - for re in res2: - if re['id'] == limit_1['id']: - self.assertDictEqual(limit_1, re) - if re['id'] == limit_2['id']: - self.assertDictEqual(limit_2, re) + limit_3 = unit.new_limit_ref( + project_id=self.tenant_bar['id'], + service_id=self.service_one['id'], + region_id=self.region_two['id'], + resource_name='backup', resource_limit=5, id=uuid.uuid4().hex) + + limits = PROVIDERS.unified_limit_api.create_limits([limit_2, limit_3]) + for limit in limits: + if limit['id'] == limit_2['id']: + self.assertDictEqual(limit_2, limit) + if limit['id'] == limit_3['id']: + self.assertDictEqual(limit_3, limit) def test_create_limit_duplicate(self): limit_1 = unit.new_limit_ref( diff --git a/keystone/tests/unit/test_limits.py b/keystone/tests/unit/test_limits.py index 7ebfdbb787..1e6fccea35 100644 --- a/keystone/tests/unit/test_limits.py +++ b/keystone/tests/unit/test_limits.py @@ -84,6 +84,26 @@ class RegisteredLimitsTestCase(test_v3.RestfulTestCase): self.assertEqual(registered_limits[0]['region_id'], ref1['region_id']) self.assertIsNone(registered_limits[1].get('region_id')) + def test_create_registered_limit_return_count(self): + ref1 = unit.new_registered_limit_ref(service_id=self.service_id, + region_id=self.region_id) + r = self.post( + '/registered_limits', + body={'registered_limits': [ref1]}, + expected_status=http_client.CREATED) + registered_limits = r.result['registered_limits'] + self.assertEqual(1, len(registered_limits)) + + ref2 = unit.new_registered_limit_ref(service_id=self.service_id2, + region_id=self.region_id2) + ref3 = unit.new_registered_limit_ref(service_id=self.service_id2) + r = self.post( + '/registered_limits', + body={'registered_limits': [ref2, ref3]}, + expected_status=http_client.CREATED) + registered_limits = r.result['registered_limits'] + self.assertEqual(2, len(registered_limits)) + def test_create_registered_limit_with_invalid_input(self): ref1 = unit.new_registered_limit_ref() ref2 = unit.new_registered_limit_ref(default_limit='not_int') @@ -438,6 +458,33 @@ class LimitsTestCase(test_v3.RestfulTestCase): self.assertEqual(limits[0]['region_id'], ref1['region_id']) self.assertIsNone(limits[1].get('region_id')) + def test_create_limit_return_count(self): + ref1 = unit.new_limit_ref(project_id=self.project_id, + service_id=self.service_id, + region_id=self.region_id, + resource_name='volume') + r = self.post( + '/limits', + body={'limits': [ref1]}, + expected_status=http_client.CREATED) + limits = r.result['limits'] + self.assertEqual(1, len(limits)) + + ref2 = unit.new_limit_ref(project_id=self.project_id, + service_id=self.service_id, + region_id=self.region_id, + resource_name='snapshot') + ref3 = unit.new_limit_ref(project_id=self.project_id, + service_id=self.service_id, + region_id=self.region_id, + resource_name='backup') + r = self.post( + '/limits', + body={'limits': [ref2, ref3]}, + expected_status=http_client.CREATED) + limits = r.result['limits'] + self.assertEqual(2, len(limits)) + def test_create_limit_with_invalid_input(self): ref1 = unit.new_limit_ref(project_id=self.project_id, resource_limit='not_int')