diff --git a/octavia/db/repositories.py b/octavia/db/repositories.py index f697f13f02..45532d94c0 100644 --- a/octavia/db/repositories.py +++ b/octavia/db/repositories.py @@ -1533,6 +1533,10 @@ class L7RuleRepository(BaseRepository): with session.begin(subtransactions=True): if not model_kwargs.get('id'): model_kwargs.update(id=uuidutils.generate_uuid()) + if model_kwargs.get('l7policy_id'): + l7policy_db = session.query(models.L7Policy).filter_by( + id=model_kwargs.get('l7policy_id')).first() + model_kwargs.update(l7policy=l7policy_db) l7rule = self.model_class(**model_kwargs) validate.l7rule_data(l7rule) session.add(l7rule) @@ -1681,6 +1685,10 @@ class L7PolicyRepository(BaseRepository): pool_db = session.query(models.Pool).filter_by( id=model_kwargs.get('redirect_pool_id')).first() model_kwargs.update(redirect_pool=pool_db) + if model_kwargs.get('listener_id'): + listener_db = session.query(models.Listener).filter_by( + id=model_kwargs.get('listener_id')).first() + model_kwargs.update(listener=listener_db) l7policy = self.model_class( **validate.sanitize_l7policy_api_args(model_kwargs, create=True)) diff --git a/octavia/tests/functional/db/test_repositories.py b/octavia/tests/functional/db/test_repositories.py index 6071818e20..6a896d6392 100644 --- a/octavia/tests/functional/db/test_repositories.py +++ b/octavia/tests/functional/db/test_repositories.py @@ -3583,6 +3583,14 @@ class L7PolicyRepositoryTest(BaseRepositoryTest): new_l7policy.action) self.assertEqual(1, new_l7policy.position) + def test_l7policy_create_no_listener_id(self): + self.assertRaises( + db_exception.DBError, self.l7policy_repo.create, + self.session, action=constants.L7POLICY_ACTION_REJECT, + operating_status=constants.ONLINE, + provisioning_status=constants.ACTIVE, + enabled=True) + def test_update(self): new_url = 'http://www.example.com/' listener = self.create_listener(uuidutils.generate_uuid(), 80) @@ -3952,6 +3960,16 @@ class L7RuleRepositoryTest(BaseRepositoryTest): self.assertEqual('something', new_l7rule.value) self.assertFalse(new_l7rule.invert) + def test_l7rule_create_wihout_l7policy_id(self): + self.assertRaises( + db_exception.DBError, self.l7rule_repo.create, + self.session, id=None, type=constants.L7RULE_TYPE_PATH, + compare_type=constants.L7RULE_COMPARE_TYPE_CONTAINS, + provisioning_status=constants.ACTIVE, + operating_status=constants.ONLINE, + value='something', + enabled=True) + def test_update(self): l7rule = self.create_l7rule(uuidutils.generate_uuid(), self.l7policy.id,