Adding runtime tests

Issue: #

Change-Id: Iddf142d48b7b9c9ac2a78130ea46374fc92bb3d5
This commit is contained in:
Tim Hinrichs 2013-08-13 13:43:45 -07:00
parent 705bef6d8b
commit 80eb6a7f5e
3 changed files with 138 additions and 87 deletions

View File

@ -35,8 +35,6 @@ class Location (object):
def __str__(self):
s = ""
if self.file is not None:
s += " file: {}".format(self.file)
if self.line is not None:
s += " line: {}".format(self.line)
if self.col is not None:
@ -221,58 +219,6 @@ class Compiler (object):
self.delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody))
class Tester (object):
def test_runtime(self):
def prep_runtime(code):
# compile source
c = Compiler()
c.read_source(input_string=code)
c.compute_delta_rules()
run = runtime.Runtime(c.delta_rules)
return run
def insert(run, list):
run.insert(list[0], tuple(list[1:]))
def delete(run, list):
run.delete(list[0], tuple(list[1:]))
def check(run, correct_database_code, msg=None):
# extract correct answer from code, represented as a Database
print "** Reading correct Database **"
c = Compiler()
c.read_source(input_string=correct_database_code)
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
tuple([x.name for x in atom.arguments]))
print "** Correct Database **"
print str(correct_database)
# ensure correct answers is a subset of run.database
extra = run.database - correct_database
missing = correct_database - run.database
if len(extra) > 0 or len(missing) > 0:
print "Test {} failed".format(msg)
if len(extra) > 0:
print "Extra tuples: {}".format(str(extra))
if len(missing) > 0:
print "Missing tuples: {}".format(str(extra))
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code)
print "Finished prep_runtime"
insert(run, ['r', 1])
check(run, "r(1)")
print "Finished first insert"
insert(run, ['p', 1])
check(run, "r(1) p(1) q(1)")
print "Finished second insert"
print "**Final State**"
print str(run.database)
class CongressSyntax (object):
""" External syntax and converting it into internal representation. """
@ -334,6 +280,8 @@ class CongressSyntax (object):
return cls.create_atom(antlr)
elif obj == 'THEORY':
return [cls.create(x) for x in antlr.children]
elif obj == '<EOF>':
return []
else:
raise CongressException(
"Antlr tree with unknown root: {}".format(obj))
@ -421,8 +369,6 @@ def main():
compiler = Compiler()
compiler.read_source(inputs[0])
compiler.compute_delta_rules()
test = Tester()
test.test_runtime()
if __name__ == '__main__':
sys.exit(main())

View File

@ -1,6 +1,7 @@
#! /usr/bin/python
import collections
import logging
class Tracer(object):
def __init__(self):
@ -8,7 +9,7 @@ class Tracer(object):
def trace(self, table):
self.expressions.append(table)
def is_traced(self, table):
return table in self.expressions or '?' in self.expressions
return table in self.expressions or '*' in self.expressions
class CongressRuntime (Exception):
pass
@ -50,7 +51,11 @@ class Event(object):
return self.insert
def __str__(self):
return "{}({})".format(self.table,
if self.is_insert():
sign = '+'
else:
sign = '-'
return "{}{}({})".format(self.table, sign,
",".join([str(x) for x in self.tuple]))
# class Database(object):
@ -159,12 +164,13 @@ class Database(object):
def __eq__(self, other):
return self.tuple == other.tuple
def __str__(self):
return str(self.tuple)
def match(self, atom, binding):
print "Checking if tuple {} matches atom {} with binding {}".format(
str(self), str(atom), str(binding))
logging.debug("Checking if tuple {} matches atom {} with binding {}".format(
str(self), str(atom), str(binding)))
if len(self.tuple) != len(atom.arguments):
return None
new_binding = {}
@ -175,7 +181,7 @@ class Database(object):
return None
else:
new_binding[atom.arguments[i].name] = self.tuple[i]
print "Check succeeded with binding {}".format(str(new_binding))
logging.debug("Check succeeded with binding {}".format(str(new_binding)))
return new_binding
class Schema (object):
@ -187,6 +193,7 @@ class Database(object):
def __init__(self):
self.data = {}
self.schemas = {} # not currently used
self.tracer = Tracer()
def __str__(self):
def hash2str (h):
@ -213,16 +220,27 @@ class Database(object):
return self.data == other.data
def __sub__(self, other):
def add_tuple(table, dbtuple):
new = [table]
new.extend(dbtuple.tuple)
results.append(new)
results = []
for table in self.data:
if table not in other.data:
results.extend(self.data[table])
for dbtuple in self.data[table]:
add_tuple(table, dbtuple)
else:
for dbtuple in self.data[table]:
if dbtuple not in other.data[table]:
results.append(dbtuple)
add_tuple(table, dbtuple)
return results
def log(self, table, msg):
if self.tracer.is_traced(table):
logging.debug(msg)
def get_matches(self, atom, binding):
""" Returns a list of binding lists for the variables in ATOM
not bound in BINDING: one binding list for each tuple in
@ -231,28 +249,29 @@ class Database(object):
return []
result = []
for tuple in self.data[atom.table]:
print "Matching database tuple {}".format(str(tuple))
logging.debug("Matching database tuple {}".format(str(tuple)))
new_binding = tuple.match(atom, binding)
if new_binding is not None:
result.append(new_binding)
return result
def insert(self, table, tuple):
print "Inserting table {} tuple {} into DB".format(table, str(tuple))
self.log(table, "Inserting table {} tuple {} into DB".format(table, str(tuple)))
if table not in self.data:
self.data[table] = []
if any([dbtuple.tuple == tuple for dbtuple in self.data[table]]):
return
self.data[table].append(self.DBTuple(tuple))
def delete(self, table, binding):
print "Deleting table {} tuple {} from DB".format(table, str(tuple))
def delete(self, table, tuple):
self.log(table, "Deleting table {} tuple {} from DB".format(table, str(tuple)))
if table not in self.data:
return
self.data[table]
locs = [i for i in xrange(0,len(self.data[table]))
if self.data[table][i].tuple == tuple]
for loc in locs:
del self.data[loc]
del self.data[table][loc]
class Runtime (object):
""" Runtime for the Congress policy language. Only have one instantiation
@ -269,14 +288,17 @@ class Runtime (object):
# tracer object
self.tracer = Tracer()
def log(self, table, msg):
if self.tracer.is_traced(table):
logging.debug(msg)
def insert(self, table, tuple, atomlist=None):
""" Event handler for an insertion. """
# atomlist is used to make tests easier to read/write
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
if self.tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
self.log(table, "Inserting into queue: {} with {}".format(table, str(tuple)))
# insert tuple into actual table before propagating or else self-join bug.
# Self-joins harder to fix when multi-threaded.
self.queue.enqueue(Event(table, tuple, insert=True))
@ -288,8 +310,7 @@ class Runtime (object):
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
if self.tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
self.log(table, "Deleting from queue: {} with {}".format(table, str(tuple)))
self.queue.enqueue(Event(table, tuple, insert=False))
self.process_queue()
@ -306,12 +327,9 @@ class Runtime (object):
def propagate(self, event):
""" Computes events generated by EVENT and the DELTA_RULES,
and enqueues them. """
if self.tracer.is_traced(event.table):
print "Processing event: {}".format(str(event))
self.log(event.table, "Processing event: {}".format(str(event)))
if event.table not in self.delta_rules.keys():
print "event.table: {}".format(event.table)
self.print_delta_rules()
print "No applicable delta rule"
self.log(event.table, "No applicable delta rule")
return
for delta_rule in self.delta_rules[event.table]:
self.propagate_rule(event, delta_rule)
@ -319,20 +337,23 @@ class Runtime (object):
def propagate_rule(self, event, delta_rule):
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
assert(not delta_rule.trigger.is_negated())
if self.tracer.is_traced(event.table):
print "Processing event {} with rule {}".format(str(event), str(delta_rule))
self.log(event.table, "Processing event {} with rule {}".format(
str(event), str(delta_rule)))
# compute tuples generated by event (either for insert or delete)
binding_list = match(event.tuple, delta_rule.trigger)
if binding_list is None:
return
print "binding_list for event-tuple and delta_rule trigger: " + str(binding_list)
self.log(event.table,
"binding_list for event-tuple and delta_rule trigger: {}".format(
str(binding_list)))
# vars_in_head = delta_rule.head.variable_names()
# print "vars_in_head: " + str(vars_in_head)
# needed_vars = set(vars_in_head)
# print "needed_vars: " + str(needed_vars)
new_bindings = self.top_down_eval(delta_rule.body, 0, binding_list)
print "new bindings after top-down: " + ",".join([str(x) for x in new_bindings])
self.log(event.table, "new bindings after top-down: {}".format(
",".join([str(x) for x in new_bindings])))
# enqueue effects of Event
head_table = delta_rule.head.table
@ -346,14 +367,13 @@ class Runtime (object):
are true in the Database (after applying the dictionary binding
BINDING to ATOMs). Returns a list of dictionary bindings. """
atom = atoms[atom_index]
if self.tracer.is_traced(atom.table):
print ("Top-down eval(atoms={}, atom_index={}, "
self.log(atom.table, ("Top_down_eval(atoms={}, atom_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in atoms) + "]",
atom_index,
str(binding))
str(binding)))
data_bindings = self.database.get_matches(atom, binding)
print "data_bindings: " + str(data_bindings)
self.log(atom.table, "data_bindings: " + str(data_bindings))
if len(data_bindings) == 0:
return []
results = []
@ -373,8 +393,8 @@ class Runtime (object):
# remove this binding from var_bindings
for var in data_binding:
del binding[var]
if self.tracer.is_traced(atom.table):
print "Return value: {}".format([str(x) for x in results])
self.log(atom.table, "Top_down_eval return value: {}".format(
'[' + ", ".join([str(x) for x in results]) + ']'))
return results

View File

@ -0,0 +1,85 @@
# Copyright (c) 2013 VMware, Inc. All rights reserved.
#
import unittest
from policy import compile
from policy import runtime
import logging
class TestRuntime(unittest.TestCase):
def setUp(self):
pass
def test_runtime(self):
def prep_runtime(code):
# compile source
c = compile.Compiler()
c.read_source(input_string=code)
c.compute_delta_rules()
run = runtime.Runtime(c.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
run.database.tracer = tracer
return run
def insert(run, list):
run.insert(list[0], tuple(list[1:]))
def delete(run, list):
run.delete(list[0], tuple(list[1:]))
def check(run, correct_database_code, msg=None):
# extract correct answer from code, represented as a Database
print "** Creating correct Database **"
c = compile.Compiler()
c.read_source(input_string=correct_database_code)
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
tuple([x.name for x in atom.arguments]))
print str(correct_database)
# compute diffs; should be empty
extra = run.database - correct_database
missing = correct_database - run.database
errmsg = ""
if len(extra) > 0:
print "Extra tuples"
print ", ".join([str(x) for x in extra])
if len(missing) > 0:
print "Missing tuples"
print ", ".join([str(x) for x in missing])
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code)
check(run, "", "Empty database on init")
logging.debug("** Next test phase **")
insert(run, ['r', 1])
check(run, "r(1)", "Basic insert should Insert into base table with no propagations")
logging.debug("** Next test phase **")
delete(run, ['r', 1])
check(run, "", "Delete from base table after insert")
logging.debug("** Next test phase **")
insert(run, ['r', 1])
insert(run, ['p', 1])
check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
logging.debug("** Next test phase **")
delete(run, ['r', 1])
check(run, "p(1)", "Delete from base table with 1 propagation")
# body of length 1
# existential variables
# multiple rules
# negation
if __name__ == '__main__':
unittest.main()