Merge "Added builtin safety and reordering to policy engine"
This commit is contained in:
commit
cc4fe52eed
@ -22,6 +22,7 @@ from congress.policy.builtin.congressbuiltin \
|
||||
import CongressBuiltinCategoryMap as builtins
|
||||
from congress.policy.builtin.congressbuiltin import CongressBuiltinPred
|
||||
from congress.policy.builtin.congressbuiltin import start_builtin_map
|
||||
from congress.policy import compile
|
||||
from congress.policy import runtime
|
||||
from congress.tests import helper
|
||||
|
||||
@ -94,6 +95,148 @@ class TestBuiltins(unittest.TestCase):
|
||||
self.assertEqual(result, False)
|
||||
|
||||
|
||||
class TestReorder(unittest.TestCase):
|
||||
def check(self, input_string, correct_string, msg):
|
||||
rule = compile.parse1(input_string)
|
||||
actual = compile.reorder_for_safety(rule)
|
||||
correct = compile.parse1(correct_string)
|
||||
if correct != actual:
|
||||
emsg = "Correct: " + str(correct)
|
||||
emsg += "; Actual: " + str(actual)
|
||||
self.fail(msg + " :: " + emsg)
|
||||
|
||||
def check_err(self, input_string, unsafe_lit_strings, msg):
|
||||
rule = compile.parse1(input_string)
|
||||
try:
|
||||
compile.reorder_for_safety(rule)
|
||||
self.fail("Failed to raise exception for " + input_string)
|
||||
except compile.CongressException as e:
|
||||
errmsg = str(e)
|
||||
# parse then print to string so string rep same in err msg
|
||||
unsafe_lits = [str(compile.parse1(x)) for x in unsafe_lit_strings]
|
||||
missing_lits = [m for m in unsafe_lits
|
||||
if m + " (vars" not in errmsg]
|
||||
if len(missing_lits) > 0:
|
||||
self.fail(
|
||||
"Unsafe literals {} not reported in error: {}".format(
|
||||
";".join(missing_lits), errmsg))
|
||||
|
||||
def test_reorder_builtins(self):
|
||||
self.check("p(x, z) :- q(x, y), plus(x, y, z)",
|
||||
"p(x, z) :- q(x, y), plus(x, y, z)",
|
||||
"No reorder")
|
||||
|
||||
self.check("p(x, z) :- plus(x, y, z), q(x, y)",
|
||||
"p(x, z) :- q(x, y), plus(x, y, z)",
|
||||
"Basic reorder")
|
||||
|
||||
self.check("p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)",
|
||||
"p(x, z) :- q(x, y), r(w), plus(x, y, z), plus(z, w, y)",
|
||||
"Chaining: no reorder")
|
||||
|
||||
self.check("p(x, z) :- q(x, y), plus(x, y, z), plus(z, w, y), r(w)",
|
||||
"p(x, z) :- q(x, y), plus(x, y, z), r(w), plus(z, w, y)",
|
||||
"Chaining: reorder")
|
||||
|
||||
self.check("p(x) :- lt(t, v), plus(z, w, t), plus(z, u, v), "
|
||||
" plus(x, y, z), q(y), r(x), s(u), t(w) ",
|
||||
"p(x) :- q(y), r(x), plus(x, y, z), s(u), plus(z, u, v), "
|
||||
" t(w), plus(z, w, t), lt(t, v)",
|
||||
"Partial-order chaining")
|
||||
|
||||
def test_unsafe_builtins(self):
|
||||
# an output
|
||||
self.check_err("p(x) :- q(x), plus(x, y, z)",
|
||||
["plus(x,y,z)"],
|
||||
"Basic Unsafe input")
|
||||
|
||||
self.check_err("p(x) :- q(x), r(z), plus(x, y, z)",
|
||||
["plus(x,y,z)"],
|
||||
"Basic Unsafe input 2")
|
||||
|
||||
self.check_err("p(x, z) :- plus(x, y, z), plus(z, y, x), "
|
||||
" plus(x, z, y)",
|
||||
["plus(x, y, z)", "plus(z, y, x)", "plus(x, z, y)"],
|
||||
"Unsafe with cycle")
|
||||
|
||||
# no outputs
|
||||
self.check_err("p(x) :- q(x), lt(x, y)",
|
||||
["lt(x,y)"],
|
||||
"Basic Unsafe input, no outputs")
|
||||
|
||||
self.check_err("p(x) :- q(y), lt(x, y)",
|
||||
["lt(x,y)"],
|
||||
"Basic Unsafe input, no outputs 2")
|
||||
|
||||
self.check_err("p(x, z) :- lt(x, y), lt(y, x)",
|
||||
["lt(x,y)", "lt(y, x)"],
|
||||
"Unsafe with cycle, no outputs")
|
||||
|
||||
# chaining
|
||||
self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), "
|
||||
" plus(w, t, u)",
|
||||
["plus(w, t, u)"],
|
||||
"Unsafe chaining")
|
||||
|
||||
self.check_err("p(x) :- q(x, y), plus(x, y, z), plus(z, 3, w), "
|
||||
" lt(w, t)",
|
||||
["lt(w, t)"],
|
||||
"Unsafe chaining 2")
|
||||
|
||||
def test_reorder_negation(self):
|
||||
self.check("p(x) :- q(x), not u(x), r(y), not s(x, y)",
|
||||
"p(x) :- q(x), not u(x), r(y), not s(x, y)",
|
||||
"No reordering")
|
||||
|
||||
self.check("p(x) :- not q(x), r(x)",
|
||||
"p(x) :- r(x), not q(x)",
|
||||
"Basic")
|
||||
|
||||
self.check("p(x) :- r(x), not q(x, y), s(y)",
|
||||
"p(x) :- r(x), s(y), not q(x,y)",
|
||||
"Partially safe")
|
||||
|
||||
self.check("p(x) :- not q(x, y), not r(x), not r(x, z), "
|
||||
" t(x, y), u(x), s(z)",
|
||||
"p(x) :- t(x,y), not q(x,y), not r(x), u(x), s(z), "
|
||||
" not r(x, z)",
|
||||
"Complex")
|
||||
|
||||
def test_unsafe_negation(self):
|
||||
self.check_err("p(x) :- not q(x)",
|
||||
["q(x)"],
|
||||
"Basic")
|
||||
|
||||
self.check_err("p(x) :- not q(x), not r(x)",
|
||||
["q(x)", "r(x)"],
|
||||
"Cycle")
|
||||
|
||||
self.check_err("p(x) :- not q(x, y), r(y)",
|
||||
["q(x, y)"],
|
||||
"Partially safe")
|
||||
|
||||
def test_reorder_builtins_negation(self):
|
||||
self.check("p(x) :- not q(z), plus(x, y, z), s(x), s(y)",
|
||||
"p(x) :- s(x), s(y), plus(x, y, z), not q(z)",
|
||||
"Basic")
|
||||
|
||||
self.check("p(x) :- not q(z, w), plus(x, y, z), lt(z, w), "
|
||||
" plus(x, 3, w), s(x, y)",
|
||||
"p(x) :- s(x,y), plus(x, y, z), plus(x, 3, w), "
|
||||
" not q(z, w), lt(z, w)",
|
||||
"Partial order")
|
||||
|
||||
def test_unsafe_builtins_negation(self):
|
||||
self.check_err("p(x) :- plus(x, y, z), not q(x, y)",
|
||||
['plus(x,y,z)', 'q(x,y)'],
|
||||
'Unsafe cycle')
|
||||
|
||||
self.check_err("p(x) :- plus(x, y, z), plus(z, w, t), not q(z, t),"
|
||||
" s(x), t(y)",
|
||||
['plus(z, w, t)', 'q(z, t)'],
|
||||
'Unsafety propagates')
|
||||
|
||||
|
||||
class TestNonrecursive(unittest.TestCase):
|
||||
def prep_runtime(self, code=None, msg=None, target=None):
|
||||
# compile source
|
||||
@ -173,25 +316,6 @@ class TestNonrecursive(unittest.TestCase):
|
||||
self.check_equal(run.select('p(x)', target=th),
|
||||
'p(4)', "Bound output")
|
||||
|
||||
def test_builtins_safety(self):
|
||||
"""Test that the builtins mechanism catches invalid syntax"""
|
||||
def check_err(code, emsg, title):
|
||||
th = NREC_THEORY
|
||||
run = self.prep_runtime()
|
||||
(permitted, errors) = run.insert(code, th)
|
||||
self.assertFalse(permitted, title)
|
||||
self.assertTrue(any(emsg in str(e) for e in errors),
|
||||
"Error msg should include '{}' but received: {}".format(
|
||||
emsg, ";".join(str(e) for e in errors)))
|
||||
|
||||
code = "p(x) :- plus(x,y,z)"
|
||||
emsg = 'y found in builtin input but not in positive literal'
|
||||
check_err(code, emsg, 'Unsafe input variable')
|
||||
|
||||
code = "p(x) :- plus(x,y,z), not q(y)"
|
||||
emsg = 'y found in builtin input but not in positive literal'
|
||||
check_err(code, emsg, 'Unsafe input variable in neg literal')
|
||||
|
||||
def test_builtins_content(self):
|
||||
"""Test the content of the builtins, not the mechanism"""
|
||||
def check_true(code, msg):
|
||||
|
@ -552,6 +552,9 @@ class Rule (object):
|
||||
all(self.body[i] == other.body[i]
|
||||
for i in xrange(0, len(self.body))))
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self.__eq__(other)
|
||||
|
||||
def __repr__(self):
|
||||
return "Rule(head={}, body={}, location={})".format(
|
||||
"[" + ",".join(repr(arg) for arg in self.heads) + "]",
|
||||
@ -706,6 +709,65 @@ def head_to_body_dependency_graph(formulas):
|
||||
return g
|
||||
|
||||
|
||||
def reorder_for_safety(rule):
|
||||
"""Moves builtins/negative literals so that when left-to-right evaluation
|
||||
is performed all of a builtin's inputs are bound by the time that builtin
|
||||
is evaluated. Reordering is stable, meaning that if the rule is
|
||||
properly ordered, no changes are made.
|
||||
"""
|
||||
cbcmapinst = cbcmap(initbuiltin)
|
||||
safe_vars = set()
|
||||
unsafe_literals = []
|
||||
unsafe_variables = {} # dictionary from literal to its unsafe vars
|
||||
new_body = []
|
||||
|
||||
def make_safe(lit):
|
||||
safe_vars.update(lit.variable_names())
|
||||
new_body.append(lit)
|
||||
|
||||
def make_safe_plus(lit):
|
||||
make_safe(lit)
|
||||
found_safe = True
|
||||
while found_safe:
|
||||
found_safe = False
|
||||
for unsafe_lit in unsafe_literals:
|
||||
if unsafe_variables[unsafe_lit] <= safe_vars:
|
||||
unsafe_literals.remove(unsafe_lit)
|
||||
make_safe(unsafe_lit)
|
||||
found_safe = True
|
||||
break # so that we reorder as little as possible
|
||||
|
||||
for lit in rule.body:
|
||||
target_vars = None
|
||||
if lit.is_negated():
|
||||
target_vars = lit.variable_names()
|
||||
elif cbcmapinst.check_if_builtin_by_name(
|
||||
lit.table, len(lit.arguments)):
|
||||
builtin = cbcmapinst.return_builtin_pred(lit.table)
|
||||
target_vars = lit.arguments[0:builtin.num_inputs]
|
||||
target_vars = set([x.name for x in target_vars if x.is_variable()])
|
||||
else:
|
||||
# neither a builtin nor negated
|
||||
make_safe_plus(lit)
|
||||
continue
|
||||
|
||||
new_unsafe_vars = target_vars - safe_vars
|
||||
if new_unsafe_vars:
|
||||
unsafe_literals.append(lit)
|
||||
unsafe_variables[lit] = new_unsafe_vars
|
||||
else:
|
||||
make_safe_plus(lit)
|
||||
|
||||
if len(unsafe_literals) > 0:
|
||||
lit_msgs = [str(lit) + " (vars " + str(unsafe_variables[lit]) + ")"
|
||||
for lit in unsafe_literals]
|
||||
raise CongressException(
|
||||
"Could not reorder rule {}. Unsafe lits: {}".format(
|
||||
str(rule), "; ".join(lit_msgs)))
|
||||
rule.body = new_body
|
||||
return rule
|
||||
|
||||
|
||||
def fact_errors(atom, module_schemas):
|
||||
"""Checks if ATOM is ground."""
|
||||
assert atom.is_atom(), "fact_errors expects an atom"
|
||||
@ -744,39 +806,11 @@ def rule_body_safety(rule):
|
||||
in a builtin input appears in the body. Returns list of exceptions.
|
||||
"""
|
||||
assert not rule.is_atom(), "rule_body_safety expects a rule"
|
||||
errors = []
|
||||
cbcmapinst = cbcmap(initbuiltin)
|
||||
|
||||
# Variables in negative literals must appear in positive literals
|
||||
# Variables as inputs to builtins must appear in positive literals
|
||||
# TODO(thinrichs): relax the builtin restriction so that an output
|
||||
# of a safe builtin is also safe
|
||||
neg_vars = set()
|
||||
pos_vars = set()
|
||||
builtin_vars = set()
|
||||
for lit in rule.body:
|
||||
if lit.is_negated():
|
||||
neg_vars |= lit.variables()
|
||||
elif cbcmapinst.check_if_builtin_by_name(
|
||||
lit.table, len(lit.arguments)):
|
||||
cbc = cbcmapinst.return_builtin_pred(lit.tablename())
|
||||
for i in xrange(0, cbc.num_inputs):
|
||||
if lit.arguments[i].is_variable():
|
||||
builtin_vars.add(lit.arguments[i])
|
||||
else:
|
||||
pos_vars |= lit.variables()
|
||||
for var in neg_vars - pos_vars:
|
||||
errors.append(CongressException(
|
||||
"Variable {} found in negative literal but not in "
|
||||
"positive literal, rule {}".format(str(var), str(rule)),
|
||||
obj=var))
|
||||
for var in builtin_vars - pos_vars:
|
||||
errors.append(CongressException(
|
||||
"Variable {} found in builtin input but not in "
|
||||
"positive literal, rule {}".format(
|
||||
str(var), str(rule)),
|
||||
obj=var))
|
||||
return errors
|
||||
try:
|
||||
reorder_for_safety(rule)
|
||||
return []
|
||||
except CongressException as e:
|
||||
return [e]
|
||||
|
||||
|
||||
def rule_schema_consistency(rule, module_schemas=None):
|
||||
|
@ -18,8 +18,8 @@ import cStringIO
|
||||
import os
|
||||
from unify import bi_unify_lists
|
||||
|
||||
from builtin.congressbuiltin import CongressBuiltinCategoryMap as cbcmap
|
||||
from builtin.congressbuiltin import start_builtin_map as initbuiltin
|
||||
from builtin.congressbuiltin import CongressBuiltinCategoryMap
|
||||
from builtin.congressbuiltin import start_builtin_map
|
||||
|
||||
# FIXME there is a circular import here because compile.py imports runtime.py
|
||||
import compile
|
||||
@ -261,7 +261,7 @@ class Theory(object):
|
||||
self.trace_prefix = self.abbr[0:maxlength]
|
||||
else:
|
||||
self.trace_prefix = self.abbr + " " * (maxlength - len(self.abbr))
|
||||
self.cbcmap = cbcmap(initbuiltin)
|
||||
self.cbcmap = CongressBuiltinCategoryMap(start_builtin_map)
|
||||
|
||||
def set_tracer(self, tracer):
|
||||
self.tracer = tracer
|
||||
|
Loading…
Reference in New Issue
Block a user