diff --git a/octavia/controller/worker/v2/tasks/database_tasks.py b/octavia/controller/worker/v2/tasks/database_tasks.py index 876e125e8d..7488b01477 100644 --- a/octavia/controller/worker/v2/tasks/database_tasks.py +++ b/octavia/controller/worker/v2/tasks/database_tasks.py @@ -2872,27 +2872,27 @@ class DecrementL7policyQuota(BaseDatabaseTask): :param l7policy: The l7policy to decrement the quota on. :returns: None """ - LOG.debug("Decrementing l7policy quota for " - "project: %s ", l7policy.project_id) - + "project: %s ", l7policy[constants.PROJECT_ID]) lock_session = db_apis.get_session(autocommit=False) try: self.repos.decrement_quota(lock_session, data_models.L7Policy, - l7policy.project_id) + l7policy[constants.PROJECT_ID]) + db_l7policy = self.l7policy_repo.get( + db_apis.get_session(), id=l7policy[constants.L7POLICY_ID]) - if l7policy.l7rules: + if db_l7policy and db_l7policy.l7rules: self.repos.decrement_quota(lock_session, data_models.L7Rule, - l7policy.project_id, - quantity=len(l7policy.l7rules)) + l7policy[constants.PROJECT_ID], + quantity=len(db_l7policy.l7rules)) lock_session.commit() except Exception: with excutils.save_and_reraise_exception(): LOG.error('Failed to decrement l7policy quota for project: ' '%(proj)s the project may have excess quota in use.', - {'proj': l7policy.project_id}) + {'proj': l7policy[constants.L7POLICY_ID]}) lock_session.rollback() def revert(self, l7policy, result, *args, **kwargs): @@ -2901,14 +2901,11 @@ class DecrementL7policyQuota(BaseDatabaseTask): :param l7policy: The l7policy to decrement the quota on. :returns: None """ - LOG.warning('Reverting decrement quota for l7policy on project' ' %(proj)s Project quota counts may be incorrect.', - {'proj': l7policy.project_id}) - + {'proj': l7policy[constants.PROJECT_ID]}) # Increment the quota back if this task wasn't the failure if not isinstance(result, failure.Failure): - try: session = db_apis.get_session() lock_session = db_apis.get_session(autocommit=False) @@ -2916,22 +2913,23 @@ class DecrementL7policyQuota(BaseDatabaseTask): self.repos.check_quota_met(session, lock_session, data_models.L7Policy, - l7policy.project_id) + l7policy[constants.PROJECT_ID]) lock_session.commit() except Exception: lock_session.rollback() - - # Attempt to increment back the L7Rule quota - for i in range(len(l7policy.l7rules)): - lock_session = db_apis.get_session(autocommit=False) - try: - self.repos.check_quota_met(session, - lock_session, - data_models.L7Rule, - l7policy.project_id) - lock_session.commit() - except Exception: - lock_session.rollback() + db_l7policy = self.l7policy_repo.get( + session, id=l7policy[constants.L7POLICY_ID]) + if db_l7policy: + # Attempt to increment back the L7Rule quota + for i in range(len(db_l7policy.l7rules)): + lock_session = db_apis.get_session(autocommit=False) + try: + self.repos.check_quota_met( + session, lock_session, data_models.L7Rule, + db_l7policy.project_id) + lock_session.commit() + except Exception: + lock_session.rollback() except Exception: # Don't fail the revert flow pass @@ -2951,19 +2949,19 @@ class DecrementL7ruleQuota(BaseDatabaseTask): """ LOG.debug("Decrementing l7rule quota for " - "project: %s ", l7rule.project_id) + "project: %s ", l7rule[constants.PROJECT_ID]) lock_session = db_apis.get_session(autocommit=False) try: self.repos.decrement_quota(lock_session, data_models.L7Rule, - l7rule.project_id) + l7rule[constants.PROJECT_ID]) lock_session.commit() except Exception: with excutils.save_and_reraise_exception(): LOG.error('Failed to decrement l7rule quota for project: ' '%(proj)s the project may have excess quota in use.', - {'proj': l7rule.project_id}) + {'proj': l7rule[constants.PROJECT_ID]}) lock_session.rollback() def revert(self, l7rule, result, *args, **kwargs): @@ -2975,7 +2973,7 @@ class DecrementL7ruleQuota(BaseDatabaseTask): LOG.warning('Reverting decrement quota for l7rule on project %(proj)s ' 'Project quota counts may be incorrect.', - {'proj': l7rule.project_id}) + {'proj': l7rule[constants.PROJECT_ID]}) # Increment the quota back if this task wasn't the failure if not isinstance(result, failure.Failure): @@ -2987,7 +2985,7 @@ class DecrementL7ruleQuota(BaseDatabaseTask): self.repos.check_quota_met(session, lock_session, data_models.L7Rule, - l7rule.project_id) + l7rule[constants.PROJECT_ID]) lock_session.commit() except Exception: lock_session.rollback() diff --git a/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks_quota.py b/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks_quota.py index 798c23fab7..d202b86586 100644 --- a/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks_quota.py +++ b/octavia/tests/unit/controller/worker/v2/tasks/test_database_tasks_quota.py @@ -17,6 +17,7 @@ from unittest import mock from oslo_utils import uuidutils from taskflow.types import failure +from octavia.common import constants from octavia.common import data_models from octavia.common import exceptions from octavia.controller.worker.v2.tasks import database_tasks @@ -32,19 +33,25 @@ class TestDatabaseTasksQuota(base.TestCase): super(TestDatabaseTasksQuota, self).setUp() + @mock.patch('octavia.db.repositories.L7PolicyRepository.get') @mock.patch('octavia.db.api.get_session', return_value='TEST') @mock.patch('octavia.db.repositories.Repositories.decrement_quota') @mock.patch('octavia.db.repositories.Repositories.check_quota_met') def _test_decrement_quota(self, task, data_model, mock_check_quota_met, mock_decrement_quota, - mock_get_session, project_id=None): + mock_get_session, mock_l7policy_get, + project_id=None): test_object = None - if project_id: + if project_id and data_model == data_models.L7Rule: + test_object = {constants.PROJECT_ID: project_id} + elif project_id: test_object = project_id else: project_id = uuidutils.generate_uuid() test_object = mock.MagicMock() test_object.project_id = project_id + l7policy_dict = {constants.PROJECT_ID: project_id, + constants.L7POLICY_ID: uuidutils.generate_uuid()} # execute without exception mock_decrement_quota.reset_mock() @@ -55,9 +62,12 @@ class TestDatabaseTasksQuota(base.TestCase): if data_model == data_models.L7Policy: test_object.l7rules = [] + mock_l7policy_get.return_value = test_object if data_model == data_models.Pool: task.execute(test_object, self.zero_pool_child_count) + elif data_model == data_models.L7Policy: + task.execute(l7policy_dict) else: task.execute(test_object) @@ -80,6 +90,10 @@ class TestDatabaseTasksQuota(base.TestCase): task.execute, test_object, self.zero_pool_child_count) + elif data_model == data_models.L7Policy: + self.assertRaises(exceptions.OctaviaException, + task.execute, + l7policy_dict) else: self.assertRaises(exceptions.OctaviaException, task.execute, @@ -97,6 +111,8 @@ class TestDatabaseTasksQuota(base.TestCase): task.revert(test_object, self.zero_pool_child_count, self._tf_failure_mock) + elif data_model == data_models.L7Policy: + task.revert(l7policy_dict, self._tf_failure_mock) else: task.revert(test_object, self._tf_failure_mock) self.assertFalse(mock_get_session.called) @@ -114,6 +130,8 @@ class TestDatabaseTasksQuota(base.TestCase): if data_model == data_models.Pool: task.revert(test_object, self.zero_pool_child_count, None) + elif data_model == data_models.L7Policy: + task.revert(l7policy_dict, None) else: task.revert(test_object, None) @@ -137,6 +155,8 @@ class TestDatabaseTasksQuota(base.TestCase): if data_model == data_models.Pool: task.revert(test_object, self.zero_pool_child_count, None) + elif data_model == data_models.L7Policy: + task.revert(l7policy_dict, None) else: task.revert(test_object, None) @@ -150,6 +170,8 @@ class TestDatabaseTasksQuota(base.TestCase): if data_model == data_models.Pool: task.revert(test_object, self.zero_pool_child_count, None) + elif data_model == data_models.L7Policy: + task.revert(l7policy_dict, None) else: task.revert(test_object, None) @@ -349,17 +371,23 @@ class TestDatabaseTasksQuota(base.TestCase): @mock.patch('octavia.db.repositories.Repositories.decrement_quota') @mock.patch('octavia.db.repositories.Repositories.check_quota_met') + @mock.patch('octavia.db.repositories.L7PolicyRepository.get') def test_decrement_l7policy_quota_with_children(self, + mock_l7policy_get, mock_check_quota_met, mock_decrement_quota): project_id = uuidutils.generate_uuid() + l7_policy_id = uuidutils.generate_uuid() test_l7rule1 = mock.MagicMock() test_l7rule1.project_id = project_id test_l7rule2 = mock.MagicMock() test_l7rule2.project_id = project_id - test_object = mock.MagicMock() - test_object.project_id = project_id - test_object.l7rules = [test_l7rule1, test_l7rule2] + test_object = {constants.PROJECT_ID: project_id, + constants.L7POLICY_ID: l7_policy_id} + db_test_object = mock.MagicMock() + db_test_object.project_id = project_id + db_test_object.l7rules = [test_l7rule1, test_l7rule2] + mock_l7policy_get.return_value = db_test_object task = database_tasks.DecrementL7policyQuota() mock_session = mock.MagicMock() @@ -427,6 +455,8 @@ class TestDatabaseTasksQuota(base.TestCase): self.assertEqual(1, mock_lock_session.rollback.call_count) def test_decrement_l7rule_quota(self): + project_id = uuidutils.generate_uuid() task = database_tasks.DecrementL7ruleQuota() data_model = data_models.L7Rule - self._test_decrement_quota(task, data_model) + self._test_decrement_quota(task, data_model, + project_id=project_id)