From 12d3dd2376e87a548e68b3775bc67670f9be7bc8 Mon Sep 17 00:00:00 2001 From: Tim Hinrichs Date: Tue, 14 Oct 2014 16:40:48 -0700 Subject: [PATCH] Added builtin safety and reordering to policy engine Previously, policy writers needed to ensure that left-to-right evaluation of the rule body guaranteed that all the inputs for a builtin were bound before that builtin was evaluated. Similarly for negation. This change automates the reordering process, but does so in a way that leaves the rule unchanged if it is properly ordered. This ensures that people can hand-tweak the ordering if they want, but they do not need to. It constitutes a first step toward query optimization. Change-Id: Ia8e522dcec655df6c49a71f748f6280b9df64eeb --- congress/policy/builtin/tests/test_builtin.py | 162 ++++++++++++++++-- congress/policy/compile.py | 100 +++++++---- congress/policy/runtime.py | 6 +- 3 files changed, 213 insertions(+), 55 deletions(-) diff --git a/congress/policy/builtin/tests/test_builtin.py b/congress/policy/builtin/tests/test_builtin.py index d293b40e7..c7c70bd9e 100755 --- a/congress/policy/builtin/tests/test_builtin.py +++ b/congress/policy/builtin/tests/test_builtin.py @@ -22,6 +22,7 @@ from congress.policy.builtin.congressbuiltin \ import CongressBuiltinCategoryMap as builtins from congress.policy.builtin.congressbuiltin import CongressBuiltinPred from congress.policy.builtin.congressbuiltin import start_builtin_map +from congress.policy import compile from congress.policy import runtime from congress.tests import helper @@ -94,6 +95,148 @@ class TestBuiltins(unittest.TestCase): self.assertEqual(result, False) +class TestReorder(unittest.TestCase): + def check(self, input_string, correct_string, msg): + rule = compile.parse1(input_string) + actual = compile.reorder_for_safety(rule) + correct = compile.parse1(correct_string) + if correct != actual: + emsg = "Correct: " + str(correct) + emsg += "; Actual: " + str(actual) + self.fail(msg + " :: " + emsg) + + def check_err(self, input_string, unsafe_lit_strings, msg): + rule = compile.parse1(input_string) + try: + compile.reorder_for_safety(rule) + self.fail("Failed to raise exception for " + input_string) + except compile.CongressException as e: + errmsg = str(e) + # parse then print to string so string rep same in err msg + unsafe_lits = [str(compile.parse1(x)) for x in unsafe_lit_strings] + missing_lits = [m for m in unsafe_lits + if m + " (vars" not in errmsg] + if len(missing_lits) > 0: + self.fail( + "Unsafe literals {} not reported in error: {}".format( + ";".join(missing_lits), errmsg)) + + def test_reorder_builtins(self): + self.check("p(x, z) :- q(x, y), plus(x, y, z)", + "p(x, z) :- q(x, y), plus(x, y, z)", + "No reorder") + + self.check("p(x, z) :- plus(x, y, z), q(x, y)", + "p(x, z) :- q(x, y), plus(x, y, z)", + "Basic reorder") + + self.check("p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)", + "p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)", + "Chaining: no reorder") + + self.check("p(x, z) :- q(x, y), plus(x, y, z), plus(z, w, y), r(w)", + "p(x, z) :- q(x, y), plus(x, y, z), r(w), plus(z, w, y)", + "Chaining: reorder") + + self.check("p(x) :- lt(t, v), plus(z, w, t), plus(z, u, v), " + " plus(x, y, z), q(y), r(x), s(u), t(w) ", + "p(x) :- q(y), r(x), plus(x, y, z), s(u), plus(z, u, v), " + " t(w), plus(z, w, t), lt(t, v)", + "Partial-order chaining") + + def test_unsafe_builtins(self): + # an output + self.check_err("p(x) :- q(x), plus(x, y, z)", + ["plus(x,y,z)"], + "Basic Unsafe input") + + self.check_err("p(x) :- q(x), r(z), plus(x, y, z)", + ["plus(x,y,z)"], + "Basic Unsafe input 2") + + self.check_err("p(x, z) :- plus(x, y, z), plus(z, y, x), " + " plus(x, z, y)", + ["plus(x, y, z)", "plus(z, y, x)", "plus(x, z, y)"], + "Unsafe with cycle") + + # no outputs + self.check_err("p(x) :- q(x), lt(x, y)", + ["lt(x,y)"], + "Basic Unsafe input, no outputs") + + self.check_err("p(x) :- q(y), lt(x, y)", + ["lt(x,y)"], + "Basic Unsafe input, no outputs 2") + + self.check_err("p(x, z) :- lt(x, y), lt(y, x)", + ["lt(x,y)", "lt(y, x)"], + "Unsafe with cycle, no outputs") + + # chaining + self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), " + " plus(w, t, u)", + ["plus(w, t, u)"], + "Unsafe chaining") + + self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), " + " lt(w, t)", + ["lt(w, t)"], + "Unsafe chaining 2") + + def test_reorder_negation(self): + self.check("p(x) :- q(x), not u(x), r(y), not s(x, y)", + "p(x) :- q(x), not u(x), r(y), not s(x, y)", + "No reordering") + + self.check("p(x) :- not q(x), r(x)", + "p(x) :- r(x), not q(x)", + "Basic") + + self.check("p(x) :- r(x), not q(x, y), s(y)", + "p(x) :- r(x), s(y), not q(x,y)", + "Partially safe") + + self.check("p(x) :- not q(x, y), not r(x), not r(x, z), " + " t(x, y), u(x), s(z)", + "p(x) :- t(x,y), not q(x,y), not r(x), u(x), s(z), " + " not r(x, z)", + "Complex") + + def test_unsafe_negation(self): + self.check_err("p(x) :- not q(x)", + ["q(x)"], + "Basic") + + self.check_err("p(x) :- not q(x), not r(x)", + ["q(x)", "r(x)"], + "Cycle") + + self.check_err("p(x) :- not q(x, y), r(y)", + ["q(x, y)"], + "Partially safe") + + def test_reorder_builtins_negation(self): + self.check("p(x) :- not q(z), plus(x, y, z), s(x), s(y)", + "p(x) :- s(x), s(y), plus(x, y, z), not q(z)", + "Basic") + + self.check("p(x) :- not q(z, w), plus(x, y, z), lt(z, w), " + " plus(x, 3, w), s(x, y)", + "p(x) :- s(x,y), plus(x, y, z), plus(x, 3, w), " + " not q(z, w), lt(z, w)", + "Partial order") + + def test_unsafe_builtins_negation(self): + self.check_err("p(x) :- plus(x, y, z), not q(x, y)", + ['plus(x,y,z)', 'q(x,y)'], + 'Unsafe cycle') + + self.check_err("p(x) :- plus(x, y, z), plus(z, w, t), not q(z, t)," + " s(x), t(y)", + ['plus(z, w, t)', 'q(z, t)'], + 'Unsafety propagates') + + class TestNonrecursive(unittest.TestCase): def prep_runtime(self, code=None, msg=None, target=None): # compile source @@ -173,25 +316,6 @@ class TestNonrecursive(unittest.TestCase): self.check_equal(run.select('p(x)', target=th), 'p(4)', "Bound output") - def test_builtins_safety(self): - """Test that the builtins mechanism catches invalid syntax""" - def check_err(code, emsg, title): - th = NREC_THEORY - run = self.prep_runtime() - (permitted, errors) = run.insert(code, th) - self.assertFalse(permitted, title) - self.assertTrue(any(emsg in str(e) for e in errors), - "Error msg should include '{}' but received: {}".format( - emsg, ";".join(str(e) for e in errors))) - - code = "p(x) :- plus(x,y,z)" - emsg = 'y found in builtin input but not in positive literal' - check_err(code, emsg, 'Unsafe input variable') - - code = "p(x) :- plus(x,y,z), not q(y)" - emsg = 'y found in builtin input but not in positive literal' - check_err(code, emsg, 'Unsafe input variable in neg literal') - def test_builtins_content(self): """Test the content of the builtins, not the mechanism""" def check_true(code, msg): diff --git a/congress/policy/compile.py b/congress/policy/compile.py index 87bd3945c..9ac269f40 100755 --- a/congress/policy/compile.py +++ b/congress/policy/compile.py @@ -552,6 +552,9 @@ class Rule (object): all(self.body[i] == other.body[i] for i in xrange(0, len(self.body)))) + def __ne__(self, other): + return not self.__eq__(other) + def __repr__(self): return "Rule(head={}, body={}, location={})".format( "[" + ",".join(repr(arg) for arg in self.heads) + "]", @@ -706,6 +709,65 @@ def head_to_body_dependency_graph(formulas): return g +def reorder_for_safety(rule): + """Moves builtins/negative literals so that when left-to-right evaluation + is performed all of a builtin's inputs are bound by the time that builtin + is evaluated. Reordering is stable, meaning that if the rule is + properly ordered, no changes are made. + """ + cbcmapinst = cbcmap(initbuiltin) + safe_vars = set() + unsafe_literals = [] + unsafe_variables = {} # dictionary from literal to its unsafe vars + new_body = [] + + def make_safe(lit): + safe_vars.update(lit.variable_names()) + new_body.append(lit) + + def make_safe_plus(lit): + make_safe(lit) + found_safe = True + while found_safe: + found_safe = False + for unsafe_lit in unsafe_literals: + if unsafe_variables[unsafe_lit] <= safe_vars: + unsafe_literals.remove(unsafe_lit) + make_safe(unsafe_lit) + found_safe = True + break # so that we reorder as little as possible + + for lit in rule.body: + target_vars = None + if lit.is_negated(): + target_vars = lit.variable_names() + elif cbcmapinst.check_if_builtin_by_name( + lit.table, len(lit.arguments)): + builtin = cbcmapinst.return_builtin_pred(lit.table) + target_vars = lit.arguments[0:builtin.num_inputs] + target_vars = set([x.name for x in target_vars if x.is_variable()]) + else: + # neither a builtin nor negated + make_safe_plus(lit) + continue + + new_unsafe_vars = target_vars - safe_vars + if new_unsafe_vars: + unsafe_literals.append(lit) + unsafe_variables[lit] = new_unsafe_vars + else: + make_safe_plus(lit) + + if len(unsafe_literals) > 0: + lit_msgs = [str(lit) + " (vars " + str(unsafe_variables[lit]) + ")" + for lit in unsafe_literals] + raise CongressException( + "Could not reorder rule {}. Unsafe lits: {}".format( + str(rule), "; ".join(lit_msgs))) + rule.body = new_body + return rule + + def fact_errors(atom, module_schemas): """Checks if ATOM is ground.""" assert atom.is_atom(), "fact_errors expects an atom" @@ -744,39 +806,11 @@ def rule_body_safety(rule): in a builtin input appears in the body. Returns list of exceptions. """ assert not rule.is_atom(), "rule_body_safety expects a rule" - errors = [] - cbcmapinst = cbcmap(initbuiltin) - - # Variables in negative literals must appear in positive literals - # Variables as inputs to builtins must appear in positive literals - # TODO(thinrichs): relax the builtin restriction so that an output - # of a safe builtin is also safe - neg_vars = set() - pos_vars = set() - builtin_vars = set() - for lit in rule.body: - if lit.is_negated(): - neg_vars |= lit.variables() - elif cbcmapinst.check_if_builtin_by_name( - lit.table, len(lit.arguments)): - cbc = cbcmapinst.return_builtin_pred(lit.tablename()) - for i in xrange(0, cbc.num_inputs): - if lit.arguments[i].is_variable(): - builtin_vars.add(lit.arguments[i]) - else: - pos_vars |= lit.variables() - for var in neg_vars - pos_vars: - errors.append(CongressException( - "Variable {} found in negative literal but not in " - "positive literal, rule {}".format(str(var), str(rule)), - obj=var)) - for var in builtin_vars - pos_vars: - errors.append(CongressException( - "Variable {} found in builtin input but not in " - "positive literal, rule {}".format( - str(var), str(rule)), - obj=var)) - return errors + try: + reorder_for_safety(rule) + return [] + except CongressException as e: + return [e] def rule_schema_consistency(rule, module_schemas=None): diff --git a/congress/policy/runtime.py b/congress/policy/runtime.py index 5d541fc90..338ea3acc 100644 --- a/congress/policy/runtime.py +++ b/congress/policy/runtime.py @@ -18,8 +18,8 @@ import cStringIO import os from unify import bi_unify_lists -from builtin.congressbuiltin import CongressBuiltinCategoryMap as cbcmap -from builtin.congressbuiltin import start_builtin_map as initbuiltin +from builtin.congressbuiltin import CongressBuiltinCategoryMap +from builtin.congressbuiltin import start_builtin_map # FIXME there is a circular import here because compile.py imports runtime.py import compile @@ -261,7 +261,7 @@ class Theory(object): self.trace_prefix = self.abbr[0:maxlength] else: self.trace_prefix = self.abbr + " " * (maxlength - len(self.abbr)) - self.cbcmap = cbcmap(initbuiltin) + self.cbcmap = CongressBuiltinCategoryMap(start_builtin_map) def set_tracer(self, tracer): self.tracer = tracer