Merge "Added builtin safety and reordering to policy engine"

This commit is contained in:
Jenkins 2014-10-23 17:44:03 +00:00 committed by Gerrit Code Review
commit cc4fe52eed
3 changed files with 213 additions and 55 deletions

View File

@ -22,6 +22,7 @@ from congress.policy.builtin.congressbuiltin \
import CongressBuiltinCategoryMap as builtins
from congress.policy.builtin.congressbuiltin import CongressBuiltinPred
from congress.policy.builtin.congressbuiltin import start_builtin_map
from congress.policy import compile
from congress.policy import runtime
from congress.tests import helper
@ -94,6 +95,148 @@ class TestBuiltins(unittest.TestCase):
self.assertEqual(result, False)
class TestReorder(unittest.TestCase):
    """Tests for compile.reorder_for_safety.

    Each case parses a rule from a string, runs reorder_for_safety on it
    and checks either the resulting rule or the error that was raised.
    """

    def check(self, input_string, correct_string, msg):
        """Fail unless reordering INPUT_STRING yields CORRECT_STRING.

        :param input_string: source text of the rule to reorder
        :param correct_string: source text of the expected reordered rule
        :param msg: label for this case, prefixed to the failure message
        """
        rule = compile.parse1(input_string)
        actual = compile.reorder_for_safety(rule)
        correct = compile.parse1(correct_string)
        if correct != actual:
            emsg = "Correct: " + str(correct)
            emsg += "; Actual: " + str(actual)
            self.fail(msg + " :: " + emsg)

    def check_err(self, input_string, unsafe_lit_strings, msg):
        """Fail unless reordering INPUT_STRING raises a CongressException
        whose message reports every literal in UNSAFE_LIT_STRINGS.

        :param input_string: source text of the (unsafe) rule to reorder
        :param unsafe_lit_strings: source text of each literal the error
            message is expected to report
        :param msg: label for this case, prefixed to the failure message
        """
        rule = compile.parse1(input_string)
        try:
            compile.reorder_for_safety(rule)
        except compile.CongressException as e:
            errmsg = str(e)
        else:
            # Kept out of the try body so that only the call under test
            # is guarded by the except clause.
            self.fail(msg + " :: Failed to raise exception for " +
                      input_string)
        # parse then print to string so string rep same in err msg
        unsafe_lits = [str(compile.parse1(x)) for x in unsafe_lit_strings]
        missing_lits = [m for m in unsafe_lits
                        if m + " (vars" not in errmsg]
        if missing_lits:
            self.fail(
                "{} :: Unsafe literals {} not reported in error: {}".format(
                    msg, ";".join(missing_lits), errmsg))

    def test_reorder_builtins(self):
        """Builtins are moved after the literals that bind their inputs."""
        self.check("p(x, z) :- q(x, y), plus(x, y, z)",
                   "p(x, z) :- q(x, y), plus(x, y, z)",
                   "No reorder")

        self.check("p(x, z) :- plus(x, y, z), q(x, y)",
                   "p(x, z) :- q(x, y), plus(x, y, z)",
                   "Basic reorder")

        self.check("p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)",
                   "p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)",
                   "Chaining: no reorder")

        self.check("p(x, z) :- q(x, y), plus(x, y, z), plus(z, w, y), r(w)",
                   "p(x, z) :- q(x, y), plus(x, y, z), r(w), plus(z, w, y)",
                   "Chaining: reorder")

        self.check("p(x) :- lt(t, v), plus(z, w, t), plus(z, u, v), "
                   " plus(x, y, z), q(y), r(x), s(u), t(w) ",
                   "p(x) :- q(y), r(x), plus(x, y, z), s(u), plus(z, u, v), "
                   " t(w), plus(z, w, t), lt(t, v)",
                   "Partial-order chaining")

    def test_unsafe_builtins(self):
        """Builtins whose inputs can never be bound are reported."""
        # an output
        self.check_err("p(x) :- q(x), plus(x, y, z)",
                       ["plus(x,y,z)"],
                       "Basic Unsafe input")

        self.check_err("p(x) :- q(x), r(z), plus(x, y, z)",
                       ["plus(x,y,z)"],
                       "Basic Unsafe input 2")

        self.check_err("p(x, z) :- plus(x, y, z), plus(z, y, x), "
                       " plus(x, z, y)",
                       ["plus(x, y, z)", "plus(z, y, x)", "plus(x, z, y)"],
                       "Unsafe with cycle")

        # no outputs
        self.check_err("p(x) :- q(x), lt(x, y)",
                       ["lt(x,y)"],
                       "Basic Unsafe input, no outputs")

        self.check_err("p(x) :- q(y), lt(x, y)",
                       ["lt(x,y)"],
                       "Basic Unsafe input, no outputs 2")

        self.check_err("p(x, z) :- lt(x, y), lt(y, x)",
                       ["lt(x,y)", "lt(y, x)"],
                       "Unsafe with cycle, no outputs")

        # chaining
        self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), "
                       " plus(w, t, u)",
                       ["plus(w, t, u)"],
                       "Unsafe chaining")

        self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), "
                       " lt(w, t)",
                       ["lt(w, t)"],
                       "Unsafe chaining 2")

    def test_reorder_negation(self):
        """Negated literals are moved after the literals binding their
        variables.
        """
        self.check("p(x) :- q(x), not u(x), r(y), not s(x, y)",
                   "p(x) :- q(x), not u(x), r(y), not s(x, y)",
                   "No reordering")

        self.check("p(x) :- not q(x), r(x)",
                   "p(x) :- r(x), not q(x)",
                   "Basic")

        self.check("p(x) :- r(x), not q(x, y), s(y)",
                   "p(x) :- r(x), s(y), not q(x,y)",
                   "Partially safe")

        self.check("p(x) :- not q(x, y), not r(x), not r(x, z), "
                   " t(x, y), u(x), s(z)",
                   "p(x) :- t(x,y), not q(x,y), not r(x), u(x), s(z), "
                   " not r(x, z)",
                   "Complex")

    def test_unsafe_negation(self):
        """Negated literals with variables that are never bound are
        reported.
        """
        self.check_err("p(x) :- not q(x)",
                       ["q(x)"],
                       "Basic")

        self.check_err("p(x) :- not q(x), not r(x)",
                       ["q(x)", "r(x)"],
                       "Cycle")

        self.check_err("p(x) :- not q(x, y), r(y)",
                       ["q(x, y)"],
                       "Partially safe")

    def test_reorder_builtins_negation(self):
        """Reordering of rules mixing builtins and negated literals."""
        self.check("p(x) :- not q(z), plus(x, y, z), s(x), s(y)",
                   "p(x) :- s(x), s(y), plus(x, y, z), not q(z)",
                   "Basic")

        self.check("p(x) :- not q(z, w), plus(x, y, z), lt(z, w), "
                   " plus(x, 3, w), s(x, y)",
                   "p(x) :- s(x,y), plus(x, y, z), plus(x, 3, w), "
                   " not q(z, w), lt(z, w)",
                   "Partial order")

    def test_unsafe_builtins_negation(self):
        """Unsafe rules mixing builtins and negated literals are reported."""
        self.check_err("p(x) :- plus(x, y, z), not q(x, y)",
                       ['plus(x,y,z)', 'q(x,y)'],
                       'Unsafe cycle')

        self.check_err("p(x) :- plus(x, y, z), plus(z, w, t), not q(z, t),"
                       " s(x), t(y)",
                       ['plus(z, w, t)', 'q(z, t)'],
                       'Unsafety propagates')
class TestNonrecursive(unittest.TestCase):
def prep_runtime(self, code=None, msg=None, target=None):
# compile source
@ -173,25 +316,6 @@ class TestNonrecursive(unittest.TestCase):
self.check_equal(run.select('p(x)', target=th),
'p(4)', "Bound output")
def test_builtins_safety(self):
    """Check that inserting rules with unsafe builtin inputs is rejected.

    Every case is inserted into a freshly prepared runtime; the insert
    must be refused and at least one returned error must contain the
    expected text.
    """
    cases = (
        ("p(x) :- plus(x,y,z)",
         'y found in builtin input but not in positive literal',
         'Unsafe input variable'),
        ("p(x) :- plus(x,y,z), not q(y)",
         'y found in builtin input but not in positive literal',
         'Unsafe input variable in neg literal'),
    )
    for code, emsg, title in cases:
        runtime_under_test = self.prep_runtime()
        permitted, errors = runtime_under_test.insert(code, NREC_THEORY)
        self.assertFalse(permitted, title)
        # Render the errors once; used for both the check and the report.
        rendered = [str(err) for err in errors]
        found = any(emsg in text for text in rendered)
        report = "Error msg should include '{}' but received: {}".format(
            emsg, ";".join(rendered))
        self.assertTrue(found, report)
def test_builtins_content(self):
"""Test the content of the builtins, not the mechanism"""
def check_true(code, msg):

View File

@ -552,6 +552,9 @@ class Rule (object):
all(self.body[i] == other.body[i]
for i in xrange(0, len(self.body))))
def __ne__(self, other):
    """Return the logical negation of this object's __eq__ for OTHER."""
    is_equal = self.__eq__(other)
    return not is_equal
def __repr__(self):
return "Rule(head={}, body={}, location={})".format(
"[" + ",".join(repr(arg) for arg in self.heads) + "]",
@ -706,6 +709,65 @@ def head_to_body_dependency_graph(formulas):
return g
def reorder_for_safety(rule):
    """Reorder RULE's body so left-to-right evaluation is safe.

    Builtins and negated literals are moved later in the body, as needed,
    so that by the time one is evaluated all of a builtin's input
    variables, or all of a negated literal's variables, already appear
    in an earlier literal.  Reordering is stable, meaning that if the rule
    is properly ordered, no changes are made.

    RULE is modified in place (its body is replaced) and also returned.

    Raises CongressException if some literal can never be made safe; the
    message lists each such literal together with the variables that are
    still unbound for it.
    """
    cbcmapinst = cbcmap(initbuiltin)
    safe_vars = set()       # names of variables used by literals placed so far
    unsafe_literals = []    # deferred literals, in their original order
    unsafe_variables = {}   # deferred literal -> vars unbound when deferred
    new_body = []

    def make_safe(lit):
        # Place LIT next in the body; all of its variables become safe.
        safe_vars.update(lit.variable_names())
        new_body.append(lit)

    def make_safe_plus(lit):
        # Place LIT, then place every deferred literal that this unblocks.
        make_safe(lit)
        found_safe = True
        while found_safe:
            found_safe = False
            for unsafe_lit in unsafe_literals:
                if unsafe_variables[unsafe_lit] <= safe_vars:
                    unsafe_literals.remove(unsafe_lit)
                    make_safe(unsafe_lit)
                    found_safe = True
                    # Restart the scan from the earliest deferred literal
                    # so that we reorder as little as possible.
                    break

    for lit in rule.body:
        if lit.is_negated():
            # every variable of a negated literal must already be bound
            target_vars = lit.variable_names()
        elif cbcmapinst.check_if_builtin_by_name(
                lit.table, len(lit.arguments)):
            # only the builtin's input arguments must already be bound
            builtin = cbcmapinst.return_builtin_pred(lit.table)
            inputs = lit.arguments[0:builtin.num_inputs]
            target_vars = set(x.name for x in inputs if x.is_variable())
        else:
            # neither a builtin nor negated
            make_safe_plus(lit)
            continue

        new_unsafe_vars = target_vars - safe_vars
        if new_unsafe_vars:
            unsafe_literals.append(lit)
            unsafe_variables[lit] = new_unsafe_vars
        else:
            make_safe_plus(lit)

    if unsafe_literals:
        # Report only the variables that are still unbound: the set stored
        # when a literal was deferred may include variables that were
        # bound by literals placed afterwards.
        lit_msgs = [
            str(lit) + " (vars " +
            str(unsafe_variables[lit] - safe_vars) + ")"
            for lit in unsafe_literals]
        raise CongressException(
            "Could not reorder rule {}.  Unsafe lits: {}".format(
                str(rule), "; ".join(lit_msgs)))
    rule.body = new_body
    return rule
def fact_errors(atom, module_schemas):
"""Checks if ATOM is ground."""
assert atom.is_atom(), "fact_errors expects an atom"
@ -744,39 +806,11 @@ def rule_body_safety(rule):
in a builtin input appears in the body. Returns list of exceptions.
"""
assert not rule.is_atom(), "rule_body_safety expects a rule"
errors = []
cbcmapinst = cbcmap(initbuiltin)
# Variables in negative literals must appear in positive literals
# Variables as inputs to builtins must appear in positive literals
# TODO(thinrichs): relax the builtin restriction so that an output
# of a safe builtin is also safe
neg_vars = set()
pos_vars = set()
builtin_vars = set()
for lit in rule.body:
if lit.is_negated():
neg_vars |= lit.variables()
elif cbcmapinst.check_if_builtin_by_name(
lit.table, len(lit.arguments)):
cbc = cbcmapinst.return_builtin_pred(lit.tablename())
for i in xrange(0, cbc.num_inputs):
if lit.arguments[i].is_variable():
builtin_vars.add(lit.arguments[i])
else:
pos_vars |= lit.variables()
for var in neg_vars - pos_vars:
errors.append(CongressException(
"Variable {} found in negative literal but not in "
"positive literal, rule {}".format(str(var), str(rule)),
obj=var))
for var in builtin_vars - pos_vars:
errors.append(CongressException(
"Variable {} found in builtin input but not in "
"positive literal, rule {}".format(
str(var), str(rule)),
obj=var))
return errors
try:
reorder_for_safety(rule)
return []
except CongressException as e:
return [e]
def rule_schema_consistency(rule, module_schemas=None):

View File

@ -18,8 +18,8 @@ import cStringIO
import os
from unify import bi_unify_lists
from builtin.congressbuiltin import CongressBuiltinCategoryMap as cbcmap
from builtin.congressbuiltin import start_builtin_map as initbuiltin
from builtin.congressbuiltin import CongressBuiltinCategoryMap
from builtin.congressbuiltin import start_builtin_map
# FIXME there is a circular import here because compile.py imports runtime.py
import compile
@ -261,7 +261,7 @@ class Theory(object):
self.trace_prefix = self.abbr[0:maxlength]
else:
self.trace_prefix = self.abbr + " " * (maxlength - len(self.abbr))
self.cbcmap = cbcmap(initbuiltin)
self.cbcmap = CongressBuiltinCategoryMap(start_builtin_map)
def set_tracer(self, tracer):
self.tracer = tracer