diff --git a/congress/api/policy_model.py b/congress/api/policy_model.py index 9e7674df9..f1da870b5 100644 --- a/congress/api/policy_model.py +++ b/congress/api/policy_model.py @@ -26,6 +26,7 @@ from congress.api import base from congress.api import error_codes from congress.api import webservice from congress import exception +from congress.library_service import library_service class PolicyModel(base.APIModel): @@ -90,7 +91,44 @@ class PolicyModel(base.APIModel): Raises: KeyError: ID already exists. DataModelException: Addition cannot be performed. + BadRequest: library_policy parameter and request body both present """ + # case 1: parameter gives library policy UUID + if 'library_policy' in params: + if item is not None: + raise exception.BadRequest( + 'Policy creation reqest with `library_policy` parameter ' + 'must not have body.') + try: + # Note(thread-safety): blocking call + library_policy_object = self.invoke_rpc( + base.LIBRARY_SERVICE_ID, + 'get_policy', {'id_': params['library_policy']}) + + policy_metadata = self.invoke_rpc( + base.ENGINE_SERVICE_ID, + 'persistent_create_policy_with_rules', + {'policy_rules_obj': library_policy_object}) + except exception.CongressException as e: + raise webservice.DataModelException.create(e) + + return (policy_metadata['id'], policy_metadata) + + # case 2: item contains rules + if 'rules' in item: + try: + library_service.validate_policy_item(item) + # Note(thread-safety): blocking call + policy_metadata = self.invoke_rpc( + base.ENGINE_SERVICE_ID, + 'persistent_create_policy_with_rules', + {'policy_rules_obj': item}) + except exception.CongressException as e: + raise webservice.DataModelException.create(e) + + return (policy_metadata['id'], policy_metadata) + + # case 3: item does not contain rules self._check_create_policy(id_, item) name = item['name'] try: diff --git a/congress/db/api.py b/congress/db/api.py index c5568bcf6..31d20b340 100644 --- a/congress/db/api.py +++ b/congress/db/api.py @@ -73,8 +73,8 @@ def lock_tables(session, tables): session.begin(subtransactions=True) if is_mysql(): # Explicitly LOCK TABLES for MySQL session.execute('SET autocommit=0') - for table in tables: - session.execute('LOCK TABLES {} WRITE'.format(table)) + session.execute('LOCK TABLES {}'.format( + ','.join([table + ' WRITE' for table in tables]))) elif is_postgres(): # Explicitly LOCK TABLE for Postgres session.execute('BEGIN TRANSACTION') for table in tables: @@ -83,8 +83,8 @@ def lock_tables(session, tables): def commit_unlock_tables(session): """Commit and unlock tables for supported backends: MySQL and PostgreSQL""" - session.commit() session.execute('COMMIT') # execute COMMIT on DB backend + session.commit() # because sqlalchemy session does not guarantee # exact boundary correspondence to DB backend transactions # We must guarantee DB commits transaction before UNLOCK @@ -100,13 +100,14 @@ def rollback_unlock_tables(session): supported backends: MySQL and PostgreSQL """ - session.rollback() - # unlock if is_mysql(): session.execute('UNLOCK TABLES') + # postgres automatically releases lock at transaction end + session.rollback() + def is_mysql(): """Return true if and only if database backend is mysql""" diff --git a/congress/db/db_policy_rules.py b/congress/db/db_policy_rules.py index 442561e00..3da2ff0db 100644 --- a/congress/db/db_policy_rules.py +++ b/congress/db/db_policy_rules.py @@ -91,7 +91,20 @@ class PolicyDeleted(model_base.BASE, model_base.HasId, model_base.HasAudit): def add_policy(id_, name, abbreviation, description, owner, kind, deleted=False, session=None): - session = session or db.get_session() + if session: + # IMPORTANT: if session provided, do not interrupt existing transaction + # with BEGIN which can drop db locks and change desired transaction + # boundaries for proper commit and rollback + try: + policy = Policy(id_, name, abbreviation, description, owner, + kind, deleted) + session.add(policy) + return policy + except oslo_db_exc.DBDuplicateEntry: + raise KeyError("Policy with name %s already exists" % name) + + # else + session = db.get_session() try: with session.begin(subtransactions=True): policy = Policy(id_, name, abbreviation, description, owner, @@ -206,7 +219,17 @@ class PolicyRule(model_base.BASE, model_base.HasId, model_base.HasAudit): def add_policy_rule(id, policy_name, rule, comment, deleted=False, rule_name="", session=None): - session = session or db.get_session() + if session: + # IMPORTANT: if session provided, do not interrupt existing transaction + # with BEGIN which can drop db locks and change desired transaction + # boundaries for proper commit and rollback + policy_rule = PolicyRule(id, policy_name, rule, comment, + deleted, rule_name=rule_name) + session.add(policy_rule) + return policy_rule + + # else + session = db.get_session() with session.begin(subtransactions=True): policy_rule = PolicyRule(id, policy_name, rule, comment, deleted, rule_name=rule_name) diff --git a/congress/library_service/library_service.py b/congress/library_service/library_service.py index 93778805f..e4504c9f0 100644 --- a/congress/library_service/library_service.py +++ b/congress/library_service/library_service.py @@ -35,6 +35,73 @@ from congress import exception LOG = logging.getLogger(__name__) +def validate_policy_item(item): + schema_json = ''' + { + "id": "PolicyProperties", + "title": "Policy Properties", + "type": "object", + "required": ["name", "rules"], + "properties": { + "name": { + "title": "Policy unique name", + "type": "string", + "minLength": 1, + "maxLength": 255 + }, + "description": { + "title": "Policy description", + "type": "string" + }, + "kind": { + "title": "Policy kind", + "type": "string", + "enum": ["database", "nonrecursive", "action", "materialized", + "delta", "datasource"] + }, + "abbreviation": { + "title": "Policy name abbreviation", + "type": "string", + "minLength": 1, + "maxLength": 5 + }, + "rules": { + "title": "collection of rules", + "type": "array", + "items": { + "type": "object", + "properties": { + "PolicyRule": { + "title": "Policy rule", + "type": "object", + "required": ["rule"], + "properties": { + "rule": { + "title": "Rule definition following policy grammar", + "type": "string" + }, + "name": { + "title": "User-friendly name", + "type": "string" + }, + "comment": { + "title": "User-friendly comment", + "type": "string" + } + } + } + } + } + } + } + } + ''' + try: + jsonschema.validate(item, json.loads(schema_json)) + except jsonschema.exceptions.ValidationError as ve: + raise exception.InvalidPolicyInput(data=str(ve)) + + class LibraryService (data_service.DataService): def __init__(self, name): data_service.DataService.__init__(self, name) @@ -43,7 +110,7 @@ class LibraryService (data_service.DataService): def create_policy(self, policy_dict): policy_dict = copy.deepcopy(policy_dict) - self._validate_policy_item(policy_dict) + validate_policy_item(policy_dict) policy_name = policy_dict['name'] # check name is valid @@ -87,7 +154,7 @@ class LibraryService (data_service.DataService): return db_object.to_dict(include_rules=True) def replace_policy(self, id_, policy_dict): - self._validate_policy_item(policy_dict) + validate_policy_item(policy_dict) policy_name = policy_dict['name'] # check name is valid @@ -108,72 +175,6 @@ class LibraryService (data_service.DataService): id_, policy_dict=policy_dict) return policy.to_dict() - def _validate_policy_item(self, item): - schema_json = ''' - { - "id": "PolicyProperties", - "title": "Policy Properties", - "type": "object", - "required": ["name", "rules"], - "properties": { - "name": { - "title": "Policy unique name", - "type": "string", - "minLength": 1, - "maxLength": 255 - }, - "description": { - "title": "Policy description", - "type": "string" - }, - "kind": { - "title": "Policy kind", - "type": "string", - "enum": ["database", "nonrecursive", "action", "materialized", - "delta", "datasource"] - }, - "abbreviation": { - "title": "Policy name abbreviation", - "type": "string", - "minLength": 1, - "maxLength": 5 - }, - "rules": { - "title": "collection of rules", - "type": "array", - "items": { - "type": "object", - "properties": { - "PolicyRule": { - "title": "Policy rule", - "type": "object", - "required": ["rule"], - "properties": { - "rule": { - "title": "Rule definition following policy grammar", - "type": "string" - }, - "name": { - "title": "User-friendly name", - "type": "string" - }, - "comment": { - "title": "User-friendly comment", - "type": "string" - } - } - } - } - } - } - } - } - ''' - try: - jsonschema.validate(item, json.loads(schema_json)) - except jsonschema.exceptions.ValidationError as ve: - raise exception.InvalidPolicyInput(data=str(ve)) - def load_policies_from_files(self): def _load_library_policy_file(full_path): with open(full_path, "r") as stream: diff --git a/congress/policy_engines/agnostic.py b/congress/policy_engines/agnostic.py index a1023e593..7982cd4a9 100644 --- a/congress/policy_engines/agnostic.py +++ b/congress/policy_engines/agnostic.py @@ -244,9 +244,26 @@ class Runtime (object): ############################################### # Persistence layer ############################################### + # Note(thread-safety): blocking function + def persistent_create_policy_with_rules(self, policy_rules_obj): + rules, policy_metadata = self.persistent_insert_rules( + policy_name=policy_rules_obj['name'], + rules=policy_rules_obj['rules'], + create_policy=True, + abbr=policy_rules_obj.get('abbreviation'), + kind=policy_rules_obj.get('kind'), + desc=policy_rules_obj.get('description')) + + # remove the rule IDs + for rule in rules: + del rule['id'] + + policy_metadata['rules'] = rules + return policy_metadata + # Note(thread-safety): blocking function def persistent_create_policy(self, name, id_=None, abbr=None, kind=None, - desc=None): + desc=None, db_session=None): # validation for name if not compile.string_is_servicename(name): raise exception.PolicyException( @@ -274,7 +291,8 @@ class Runtime (object): obj['abbreviation'], obj['description'], obj['owner_id'], - obj['kind']) + obj['kind'], + session=db_session) except KeyError: raise except Exception: @@ -282,7 +300,13 @@ class Runtime (object): msg = "Error thrown while adding policy %s into DB." % policy_name LOG.exception(msg) raise exception.PolicyException(msg) - self.synchronizer.sync_one_policy(obj['name']) + if db_session: + # stay in current transaction, previous write may not be + # readable by synchronizer + self.add_policy_obj_to_runtime(policy_obj) + else: + self.synchronizer.sync_one_policy(obj['name'], + db_session=db_session) return obj # Note(thread-safety): blocking function @@ -333,15 +357,21 @@ class Runtime (object): def persistent_insert_rule(self, policy_name, str_rule, rule_name, comment): - rule_data = {'str_rule': str_rule, 'rule_name': rule_name, + rule_data = {'rule': str_rule, 'name': rule_name, 'comment': comment} - return_data = self.persistent_insert_rules(policy_name, [rule_data]) + return_data, _ = self.persistent_insert_rules(policy_name, [rule_data]) return (return_data[0]['id'], return_data[0]) # Note(thread-safety): blocking function # acquire lock to avoid periodic sync from undoing insert before persisted + # IMPORTANT: Be very careful to avoid deadlock when + # acquiring locks sequentially. In this case, we will acquire lock A + # then attempt to acquire lock B. We have to make sure no thread will hold + # lock B and attempt to acquire lock A, causing a deadlock + @lockutils.synchronized('congress_synchronize_policies') @lockutils.synchronized('congress_synchronize_rules') - def persistent_insert_rules(self, policy_name, rules): + def persistent_insert_rules(self, policy_name, rules, create_policy=False, + id_=None, abbr=None, kind=None, desc=None): """Insert and persists rule into policy_name.""" def uninsert_rules(rules_inserted): @@ -353,46 +383,59 @@ class Runtime (object): try: rules_to_persist = [] return_data = [] + # get session + db_session = db_api.get_locking_session() + + # lock policy_rules table to prevent conflicting rules + # insertion (say causing unsupported recursion) + # policies and datasources tables locked because + # it's a requirement of MySQL backend to lock all accessed tables + db_api.lock_tables(session=db_session, + tables=['policy_rules', 'policies', + 'datasources']) + if cfg.CONF.replicated_policy_engine: - # get session - db_session = db_api.get_locking_session() - db_session.begin(subtransactions=True) - - # lock policy_rules table to prevent conflicting rules - # insertion (say causing unsupported recursion) - db_api.lock_tables(session=db_session, tables=['policy_rules']) - # synchronize policy rules to get latest state, locked state # non-locking version because lock already acquired, # avoid deadlock self.synchronizer.synchronize_rules_nonlocking( db_session=db_session) - else: - db_session = None - # Reject rules inserted into non-persisted policies - # (i.e. datasource policies) - # Note(thread-safety): blocking call - policy_name = db_policy_rules.policy_name( - policy_name) - # call synchronizer to make sure policy is synchronized in memory - self.synchronizer.sync_one_policy(policy_name) - # Note(thread-safety): blocking call - policies = db_policy_rules.get_policies() - persisted_policies = set([p.name for p in policies]) - if policy_name not in persisted_policies: - if policy_name in self.theory: - LOG.debug( - "insert_persisted_rule error: rule not permitted for " - "policy %s", policy_name) - raise exception.PolicyRuntimeException( - name='rule_not_permitted') + # Note: it's important that this create policy is run after locking + # the policy_rules table, so as to prevent other nodes from + # inserting rules into this policy, which may be removed by an + # undo (delete the policy) later in this method + policy_metadata = None + if create_policy: + policy_metadata = self.persistent_create_policy( + id_=id_, name=policy_name, abbr=abbr, kind=kind, + desc=desc, db_session=db_session) + else: + # Reject rules inserted into non-persisted policies + # (i.e. datasource policies) + + # Note(thread-safety): blocking call + policy_name = db_policy_rules.policy_name( + policy_name, session=db_session) + # call synchronizer to make sure policy is sync'ed in memory + self.synchronizer.sync_one_policy_nonlocking( + policy_name, db_session=db_session) + # Note(thread-safety): blocking call + policies = db_policy_rules.get_policies(session=db_session) + persisted_policies = set([p.name for p in policies]) + if policy_name not in persisted_policies: + if policy_name in self.theory: + LOG.debug( + "insert_persisted_rule error: rule not permitted " + "for policy %s", policy_name) + raise exception.PolicyRuntimeException( + name='rule_not_permitted') rules_to_insert = [] for rule_data in rules: - str_rule = rule_data['str_rule'] - rule_name = rule_data['rule_name'] - comment = rule_data['comment'] + str_rule = rule_data['rule'] + rule_name = rule_data.get('name') + comment = rule_data.get('comment') id_ = uuidutils.generate_uuid() try: @@ -456,7 +499,7 @@ class Runtime (object): # do not begin to avoid implicitly releasing table # lock due to starting new transaction success = True - return return_data + return return_data, policy_metadata except Exception as db_exception: try: # un-insert all rules from engine unless all db inserts @@ -481,6 +524,10 @@ class Runtime (object): db_api.commit_unlock_tables(session=db_session) else: db_api.rollback_unlock_tables(session=db_session) + if create_policy: + # sync the potentially rolled back policy creation + self.synchronizer.sync_one_policy_nonlocking( + policy_name) db_session.close() # Note(thread-safety): blocking function @@ -2224,6 +2271,11 @@ class DseRuntimeEndpoints(object): # Note(thread-safety): blocking call return self.dse.persistent_create_policy(name, id_, abbr, kind, desc) + # Note(thread-safety): blocking function + def persistent_create_policy_with_rules(self, context, policy_rules_obj): + # Note(thread-safety): blocking call + return self.dse.persistent_create_policy_with_rules(policy_rules_obj) + # Note(thread-safety): blocking function def persistent_delete_policy(self, context, name_or_id): # Note(thread-safety): blocking call diff --git a/congress/synchronizer/policy_rule_synchronizer.py b/congress/synchronizer/policy_rule_synchronizer.py index 467685564..a50029d25 100644 --- a/congress/synchronizer/policy_rule_synchronizer.py +++ b/congress/synchronizer/policy_rule_synchronizer.py @@ -108,7 +108,12 @@ class PolicyRuleSynchronizer(object): return active_policies @lockutils.synchronized('congress_synchronize_policies') - def sync_one_policy(self, name, datasource=True): + def sync_one_policy(self, name, datasource=True, db_session=None): + return self.sync_one_policy_nonlocking( + name, datasource=datasource, db_session=db_session) + + def sync_one_policy_nonlocking( + self, name, datasource=True, db_session=None): """Synchronize single policy with DB. :param name: policy name to be synchronized @@ -118,13 +123,15 @@ class PolicyRuleSynchronizer(object): LOG.info("sync %s policy with DB", name) if datasource: - policy_object = datasources.get_datasource_by_name(name) + policy_object = datasources.get_datasource_by_name( + name, session=db_session) if policy_object is not None: if name not in self.engine.policy_names(): self._register_datasource_with_pe(name) return - policy_object = db_policy_rules.get_policy_by_name(name) + policy_object = db_policy_rules.get_policy_by_name( + name, session=db_session) if policy_object is None: if name in self.engine.policy_names(): self.engine.delete_policy(name) diff --git a/congress/tests/api/test_policy_model.py b/congress/tests/api/test_policy_model.py index 5ca8018d4..35047cf37 100644 --- a/congress/tests/api/test_policy_model.py +++ b/congress/tests/api/test_policy_model.py @@ -17,11 +17,13 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import +import copy import mock from oslo_utils import uuidutils from congress.api import error_codes from congress.api import webservice +from congress.datalog import compile from congress.tests.api import base as api_base from congress.tests import base from congress.tests import helper @@ -115,6 +117,84 @@ class TestPolicyModel(base.SqlTestCase): self.assertEqual(expected_ret1, policy_id) self.assertEqual(expected_ret2, policy_obj) + def test_add_item_with_rules(self): + res = self.policy_model.get_items({})['results'] + self.assertEqual(len(res), 4) # built-in-and-setup + + def adjust_for_comparison(rules): + # compile rule string into rule object + # replace dict with tuple for sorting + # 'id' field implicitly dropped if present + rules = [(compile.parse1(rule['rule']), rule['name'], + rule['comment']) for rule in rules] + + # sort lists for comparison + return sorted(rules) + + test_policy = { + "name": "test_rules_policy", + "description": "test policy description", + "kind": "nonrecursive", + "abbreviation": "abbr", + "rules": [{"rule": "p(x) :- q(x)", "comment": "test comment", + "name": "test name"}, + {"rule": "p(x) :- q2(x)", "comment": "test comment2", + "name": "test name2"}] + } + + add_policy_id, add_policy_obj = self.policy_model.add_item( + test_policy, {}) + + test_policy['id'] = add_policy_id + + # adjust for comparison + test_policy['owner_id'] = 'user' + test_policy['rules'] = adjust_for_comparison(test_policy['rules']) + + add_policy_obj['rules'] = adjust_for_comparison( + add_policy_obj['rules']) + + self.assertEqual(add_policy_obj, test_policy) + + res = self.policy_model.get_items({})['results'] + del test_policy['rules'] + self.assertIn(test_policy, res) + + res = self.policy_model.get_items({})['results'] + self.assertEqual(len(res), 5) + + # failure by duplicate policy name + duplicate_name_policy = copy.deepcopy(test_policy) + duplicate_name_policy['description'] = 'diff description' + duplicate_name_policy['abbreviation'] = 'diff' + duplicate_name_policy['rules'] = [] + + self.assertRaises( + KeyError, self.policy_model.add_item, duplicate_name_policy, {}) + + res = self.policy_model.get_items({})['results'] + self.assertEqual(len(res), 5) + + def test_add_item_with_bad_rules(self): + res = self.policy_model.get_items({})['results'] + self.assertEqual(len(res), 4) # two built-in and two setup policies + + test_policy = { + "name": "test_rules_policy", + "description": "test policy description", + "kind": "nonrecursive", + "abbreviation": "abbr", + "rules": [{"rule": "p(x) :- q(x)", "comment": "test comment", + "name": "test name"}, + {"rule": "p(x) ====:- q2(x)", "comment": "test comment2", + "name": "test name2"}] + } + self.assertRaises(webservice.DataModelException, + self.policy_model.add_item, test_policy, {}) + + res = self.policy_model.get_items({})['results'] + self.assertEqual(len(res), 4) # unchanged + def test_add_item_with_id(self): test = { "name": "test", diff --git a/congress/tests/policy_engines/test_agnostic.py b/congress/tests/policy_engines/test_agnostic.py index 2cc079bab..bfacb8c6f 100644 --- a/congress/tests/policy_engines/test_agnostic.py +++ b/congress/tests/policy_engines/test_agnostic.py @@ -259,7 +259,8 @@ class TestRuntime(base.TestCase): policy_name[:5], mock.ANY, 'user', - 'nonrecursive') + 'nonrecursive', + session=mock.ANY) # mock_delete.assert_called_once_with(policy_name) self.assertFalse(mock_delete.called) self.assertFalse(run.synchronizer.sync_one_policy.called) @@ -269,7 +270,8 @@ class TestRuntime(base.TestCase): setattr(mock_db_policy_obj, 'name', 'test_policy') @mock.patch.object(db_policy_rules, 'add_policy_rule') - @mock.patch.object(db_policy_rules, 'policy_name', side_effect=lambda x: x) + @mock.patch.object(db_policy_rules, 'policy_name', + side_effect=lambda x, session: x) @mock.patch.object( db_policy_rules, 'get_policies', return_value=[mock_db_policy_obj]) def test_persistent_insert_rules( @@ -279,17 +281,17 @@ class TestRuntime(base.TestCase): run.create_policy('test_policy') # test empty insert - result = run.persistent_insert_rules('test_policy', []) + result, _ = run.persistent_insert_rules('test_policy', []) self.assertEqual(len(result), 0) self.assertTrue(helper.datalog_equal( run.select('p(x)'), '')) # test duplicated insert, 3 rules, 2 unique - result = run.persistent_insert_rules( + result, _ = run.persistent_insert_rules( 'test_policy', - [{'str_rule': 'p(1)', 'rule_name': '', 'comment': ''}, - {'str_rule': 'p(2)', 'rule_name': '', 'comment': ''}, - {'str_rule': 'p(1)', 'rule_name': '', 'comment': ''}]) + [{'rule': 'p(1)', 'name': '', 'comment': ''}, + {'rule': 'p(2)', 'name': '', 'comment': ''}, + {'rule': 'p(1)', 'name': '', 'comment': ''}]) self.assertEqual(len(result), 2) self.assertTrue(helper.datalog_equal( run.select('p(x)'), 'p(1) p(2)')) diff --git a/congress/tests/test_congress.py b/congress/tests/test_congress.py index 932a49410..ecec8ba57 100644 --- a/congress/tests/test_congress.py +++ b/congress/tests/test_congress.py @@ -29,6 +29,7 @@ from oslo_log import log as logging from congress.api import base as api_base from congress.common import config +from congress.datalog import compile from congress.datasources import neutronv2_driver from congress.datasources import nova_driver from congress.db import db_library_policies @@ -99,6 +100,7 @@ class TestCongress(BaseTestPolicyCongress): def test_startup(self): self.assertIsNotNone(self.services['api']) self.assertIsNotNone(self.services['engine']) + self.assertIsNotNone(self.services['library']) self.assertIsNotNone(self.services['engine'].node) def test_policy(self): @@ -130,8 +132,103 @@ class TestCongress(BaseTestPolicyCongress): # asking for a snapshot and return []. # self.insert_rule('p(x) :- fake:fake_table(x)', 'alpha') + def test_policy_create_from_library(self): + def adjust_for_comparison(rules): + # compile rule string into rule object + # replace dict with tuple for sorting + # 'id' field implicitly dropped if present + rules = [(compile.parse1(rule['rule']), rule['name'], + rule['comment']) for rule in rules] + + # sort lists for comparison + return sorted(rules) + + test_policy = { + "name": "test_policy", + "description": "test policy description", + "kind": "nonrecursive", + "abbreviation": "abbr", + "rules": [{"rule": "p(x) :- q(x)", "comment": "test comment", + "name": "test name"}, + {"rule": "p(x) :- q2(x)", "comment": "test comment2", + "name": "test name2"}] + } + test_policy_id, test_policy_obj = self.api[ + 'api-library-policy'].add_item(test_policy, {}) + + add_policy_id, add_policy_obj = self.api['api-policy'].add_item( + None, {'library_policy': test_policy_id}) + + test_policy['id'] = add_policy_id + + # adjust for comparison + test_policy['owner_id'] = 'user' + test_policy['rules'] = adjust_for_comparison(test_policy['rules']) + + add_policy_obj['rules'] = adjust_for_comparison( + add_policy_obj['rules']) + + self.assertEqual(add_policy_obj, test_policy) + + context = {'policy_id': test_policy['name']} + rules = self.api['api-rule'].get_items({}, context)['results'] + rules = adjust_for_comparison(rules) + self.assertEqual(rules, test_policy['rules']) + + res = self.api['api-policy'].get_items({})['results'] + del test_policy['rules'] + self.assertIn(test_policy, res) + + def test_policy_create_with_rules(self): + def adjust_for_comparison(rules): + # compile rule string into rule object + # replace dict with tuple for sorting + # 'id' field implicitly dropped if present + rules = [(compile.parse1(rule['rule']), rule['name'], + rule['comment']) for rule in rules] + + # sort lists for comparison + return sorted(rules) + + test_policy = { + "name": "test_policy", + "description": "test policy description", + "kind": "nonrecursive", + "abbreviation": "abbr", + "rules": [{"rule": "p(x) :- q(x)", "comment": "test comment", + "name": "test name"}, + {"rule": "p(x) :- q2(x)", "comment": "test comment2", + "name": "test name2"}] + } + + add_policy_id, add_policy_obj = self.api['api-policy'].add_item( + test_policy, {}) + + test_policy['id'] = add_policy_id + + # adjust for comparison + test_policy['owner_id'] = 'user' + test_policy['rules'] = adjust_for_comparison(test_policy['rules']) + + add_policy_obj['rules'] = adjust_for_comparison( + add_policy_obj['rules']) + + self.assertEqual(add_policy_obj, test_policy) + + context = {'policy_id': test_policy['name']} + rules = self.api['api-rule'].get_items({}, context)['results'] + rules = adjust_for_comparison(rules) + self.assertEqual(rules, test_policy['rules']) + + res = self.api['api-policy'].get_items({})['results'] + del test_policy['rules'] + self.assertIn(test_policy, res) + def create_policy(self, name): - self.api['api-policy'].add_item({'name': name}, {}) + return self.api['api-policy'].add_item({'name': name}, {}) + + def create_policy_from_obj(self, policy_obj): + return self.api['api-policy'].add_item(policy_obj, {}) def insert_rule(self, rule, policy): context = {'policy_id': policy}