From 45fd89592616d7278152c3cc3997310929eece80 Mon Sep 17 00:00:00 2001 From: Tim Hinrichs Date: Tue, 20 Aug 2013 15:33:40 -0700 Subject: [PATCH] Added explanation extraction Added ability to extract a full proof object. Issue: # Change-Id: I197c271b289d8ef19027fbdf2b4fda52ba100f3c --- src/policy/compile.py | 27 ++++++++++++ src/policy/runtime.py | 76 +++++++++++++++++++++++++------- src/policy/tests/test_runtime.py | 11 ++++- 3 files changed, 97 insertions(+), 17 deletions(-) diff --git a/src/policy/compile.py b/src/policy/compile.py index f8d3d8ec5..6f547edd4 100755 --- a/src/policy/compile.py +++ b/src/policy/compile.py @@ -147,6 +147,18 @@ class Atom (object): def variable_names(self): return set([x.name for x in self.arguments if x.is_variable()]) + def is_ground(self): + return all(not arg.is_variable() for arg in self.arguments) + + def plug(self, binding): + args = [] + for arg in self.arguments: + if arg.name in binding: + args.append(Term.create_from_python(binding[arg.name])) + else: + args.append(arg) + return Literal(self.table, args) + class Literal(Atom): """ Represents either a negated atom or an atom. """ def __init__(self, table, arguments, negated=False, location=None): @@ -171,6 +183,15 @@ class Literal(Atom): def is_rule(self): return False + def plug(self, binding): + args = [] + for arg in self.arguments: + if arg.name in binding: + args.append(Term.create_from_python(binding[arg.name])) + else: + args.append(arg) + return Literal(self.table, args, negated=self.negated) + class Rule (object): """ Represents a rule, e.g. p(x) :- q(x). """ def __init__(self, head, body, location=None): @@ -195,6 +216,12 @@ class Rule (object): def is_rule(self): return True + def plug(self, binding): + newhead = self.head.plug(binding) + newbody = [lit.plug(binding) for lit in self.body] + return Rule(newhead, newbody) + + ############################################################################## ## Compiler ############################################################################## diff --git a/src/policy/runtime.py b/src/policy/runtime.py index 588b3e6a0..a484683e7 100644 --- a/src/policy/runtime.py +++ b/src/policy/runtime.py @@ -169,6 +169,9 @@ class Database(object): self.contents.append(proof) return self + def __getitem__(self, key): + return self.contents[key] + def __len__(self): return len(self.contents) @@ -210,15 +213,8 @@ class Database(object): logging.debug("Check succeeded with binding {}".format(str(new_binding))) return new_binding - class Schema (object): - def __init__(self, column_names): - self.arguments = column_names - def __str__(self): - return str(self.arguments) - def __init__(self): self.data = {} - self.schemas = {} # not currently used self.tracer = Tracer() def __str__(self): @@ -239,8 +235,6 @@ class Database(object): return '{' + ", ".join(strings) + '}' return hashlist2str(self.data) - # return "".format( - # hashlist2str(self.data), hash2str(self.schemas)) def __eq__(self, other): return self.data == other.data @@ -282,6 +276,14 @@ class Database(object): result.append(new_atom) return result + def explain(self, atom): + if atom.table not in self.data or not atom.is_ground(): + return self.ProofCollection() + args = tuple([x.name for x in atom.arguments]) + for dbtuple in self.data[atom.table]: + if dbtuple.tuple == args: + return dbtuple.proofs + def top_down_eval(self, literals, literal_index, binding): """ Compute all instances of LITERALS (from LITERAL_INDEX and above) that are true in the Database (after applying the dictionary binding @@ -298,7 +300,7 @@ class Database(object): # assume that for negative literals, all vars are bound at this point # if there is a match, data_bindings will contain at least one binding # (possibly including the empty binding) - data_bindings = self.get_matches(lit, binding) + data_bindings = self.matches(lit, binding) self.log(lit.table, "data_bindings: " + str(data_bindings), depth=literal_index) # if not negated, empty data_bindings means failure if len(data_bindings) == 0 : @@ -321,14 +323,14 @@ class Database(object): return results - def get_matches(self, literal, binding): + def matches(self, literal, binding): """ Returns a list of binding lists for the variables in LITERAL not bound in BINDING. If LITERAL is negative, returns either [] meaning the lookup failed or [{}] meaning the lookup succeeded; otherwise, returns one binding list for each tuple in the database matching LITERAL under BINDING. """ # slow--should stop at first match, not find all of them - matches = self.get_matches_atom(literal, binding) + matches = self.matches_atom(literal, binding) if literal.is_negated(): if len(matches) > 0: return [] @@ -337,7 +339,7 @@ class Database(object): else: return matches - def get_matches_atom(self, atom, binding): + def matches_atom(self, atom, binding): """ Returns a list of binding lists for the variables in ATOM not bound in BINDING: one binding list for each tuple in the database matching ATOM under BINDING. """ @@ -398,6 +400,26 @@ class Runtime (object): """ Runtime for the Congress policy language. Only have one instantiation in practice, but using a class is natural and useful for testing. """ + class Proof(object): + """ A single proof. Differs semantically from Database's + Proof in that this verison represents a proof that spans rules, + instead of just a proof for a single rule. """ + def __init__(self, root, children): + self.root = root + self.children = children + + def __str__(self): + return self.str_tree(0) + + def str_tree(self, depth): + s = " " * depth + s += str(self.root) + s += "\n" + for child in self.children: + s += child.str_tree(depth + 1) + return s + + def __init__(self, rules): # rules dictating how an insert/delete to one table # affects other tables @@ -429,8 +451,9 @@ class Runtime (object): def explain(self, query): """ Event handler for explanations. Given a ground query, return - all explanations for it. """ - assert False, "Not yet implemented" + a single proof that it belongs in the database. """ + assert isinstance(query, compile.Atom), "Only have support for literals" + return self.explain_aux(query, 0) def insert(self, formula): """ Event handler for arbitrary insertion (rules and facts). """ @@ -441,6 +464,23 @@ class Runtime (object): return self.modify(formula, is_insert=False) ############### Interface implementation ############### + def explain_aux(self, query, depth): + self.log(query.table, "Explaining {}".format(str(query)), depth) + if query.is_negated(): + return self.Proof(query, []) + # grab first local proof, since they're all equally good + localproofs = self.database.explain(query) + if len(localproofs) == 0: # base fact + return self.Proof(query, []) + localproof = localproofs[0] + rule_instance = localproof.rule.plug(localproof.binding) + subproofs = [] + for lit in rule_instance.body: + subproof = self.explain_aux(lit, depth + 1) + if subproof is None: + return None + subproofs.append(subproof) + return self.Proof(query, subproofs) def modify(self, formula, is_insert=True): """ Event handler for arbitrary insertion/deletion (rules and facts). """ @@ -579,7 +619,11 @@ class StringRuntime(Runtime): def explain(self, query_string): """ Event handler for explanations. Given a ground query, return all explanations for it. """ - assert False, "Not yet implemented" + c = compile.get_compiled([query_string, '--input_string']) + assert len(c.theory) == 1, "Queries can have only 1 statement" + assert c.theory[0].is_atom(), "Queries must be atomic" + results = super(StringRuntime, self).explain(c.theory[0]) + return str(results) def insert(self, policy_string): """ Event handler for arbitrary insertion (rules and/or facts). """ diff --git a/src/policy/tests/test_runtime.py b/src/policy/tests/test_runtime.py index b420c7dbf..d379139da 100644 --- a/src/policy/tests/test_runtime.py +++ b/src/policy/tests/test_runtime.py @@ -426,7 +426,16 @@ class TestRuntime(unittest.TestCase): def test_explanations(self): """ Test the explanation event handler. """ - pass + run = self.prep_runtime("p(x) :- q(x), r(x)", "Explanations") + run.insert("q(1) r(1)") + self.showdb(run) + logging.debug(run.explain("p(1)")) + + run = self.prep_runtime("p(x) :- q(x), r(x) q(x) :- s(x), t(x)", "Explanations") + run.insert("s(1) r(1) t(1)") + self.showdb(run) + logging.debug(run.explain("p(1)")) + self.fail() if __name__ == '__main__': unittest.main()