Adding runtime tests
Issue: # Change-Id: Iddf142d48b7b9c9ac2a78130ea46374fc92bb3d5
This commit is contained in:
parent
705bef6d8b
commit
80eb6a7f5e
|
@ -35,8 +35,6 @@ class Location (object):
|
|||
|
||||
def __str__(self):
|
||||
s = ""
|
||||
if self.file is not None:
|
||||
s += " file: {}".format(self.file)
|
||||
if self.line is not None:
|
||||
s += " line: {}".format(self.line)
|
||||
if self.col is not None:
|
||||
|
@ -221,58 +219,6 @@ class Compiler (object):
|
|||
self.delta_rules.append(
|
||||
runtime.DeltaRule(literal, rule.head, newbody))
|
||||
|
||||
class Tester (object):
|
||||
|
||||
def test_runtime(self):
|
||||
def prep_runtime(code):
|
||||
# compile source
|
||||
c = Compiler()
|
||||
c.read_source(input_string=code)
|
||||
c.compute_delta_rules()
|
||||
run = runtime.Runtime(c.delta_rules)
|
||||
return run
|
||||
|
||||
def insert(run, list):
|
||||
run.insert(list[0], tuple(list[1:]))
|
||||
|
||||
def delete(run, list):
|
||||
run.delete(list[0], tuple(list[1:]))
|
||||
|
||||
def check(run, correct_database_code, msg=None):
|
||||
# extract correct answer from code, represented as a Database
|
||||
print "** Reading correct Database **"
|
||||
c = Compiler()
|
||||
c.read_source(input_string=correct_database_code)
|
||||
correct = c.theory
|
||||
correct_database = runtime.Database()
|
||||
for atom in correct:
|
||||
correct_database.insert(atom.table,
|
||||
tuple([x.name for x in atom.arguments]))
|
||||
print "** Correct Database **"
|
||||
print str(correct_database)
|
||||
|
||||
# ensure correct answers is a subset of run.database
|
||||
extra = run.database - correct_database
|
||||
missing = correct_database - run.database
|
||||
if len(extra) > 0 or len(missing) > 0:
|
||||
print "Test {} failed".format(msg)
|
||||
if len(extra) > 0:
|
||||
print "Extra tuples: {}".format(str(extra))
|
||||
if len(missing) > 0:
|
||||
print "Missing tuples: {}".format(str(extra))
|
||||
|
||||
code = ("q(x) :- p(x), r(x)")
|
||||
run = prep_runtime(code)
|
||||
print "Finished prep_runtime"
|
||||
insert(run, ['r', 1])
|
||||
check(run, "r(1)")
|
||||
print "Finished first insert"
|
||||
insert(run, ['p', 1])
|
||||
check(run, "r(1) p(1) q(1)")
|
||||
print "Finished second insert"
|
||||
print "**Final State**"
|
||||
print str(run.database)
|
||||
|
||||
class CongressSyntax (object):
|
||||
""" External syntax and converting it into internal representation. """
|
||||
|
||||
|
@ -334,6 +280,8 @@ class CongressSyntax (object):
|
|||
return cls.create_atom(antlr)
|
||||
elif obj == 'THEORY':
|
||||
return [cls.create(x) for x in antlr.children]
|
||||
elif obj == '<EOF>':
|
||||
return []
|
||||
else:
|
||||
raise CongressException(
|
||||
"Antlr tree with unknown root: {}".format(obj))
|
||||
|
@ -421,8 +369,6 @@ def main():
|
|||
compiler = Compiler()
|
||||
compiler.read_source(inputs[0])
|
||||
compiler.compute_delta_rules()
|
||||
test = Tester()
|
||||
test.test_runtime()
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
|
@ -1,6 +1,7 @@
|
|||
#! /usr/bin/python
|
||||
|
||||
import collections
|
||||
import logging
|
||||
|
||||
class Tracer(object):
|
||||
def __init__(self):
|
||||
|
@ -8,7 +9,7 @@ class Tracer(object):
|
|||
def trace(self, table):
|
||||
self.expressions.append(table)
|
||||
def is_traced(self, table):
|
||||
return table in self.expressions or '?' in self.expressions
|
||||
return table in self.expressions or '*' in self.expressions
|
||||
|
||||
class CongressRuntime (Exception):
|
||||
pass
|
||||
|
@ -50,7 +51,11 @@ class Event(object):
|
|||
return self.insert
|
||||
|
||||
def __str__(self):
|
||||
return "{}({})".format(self.table,
|
||||
if self.is_insert():
|
||||
sign = '+'
|
||||
else:
|
||||
sign = '-'
|
||||
return "{}{}({})".format(self.table, sign,
|
||||
",".join([str(x) for x in self.tuple]))
|
||||
|
||||
# class Database(object):
|
||||
|
@ -159,12 +164,13 @@ class Database(object):
|
|||
def __eq__(self, other):
|
||||
return self.tuple == other.tuple
|
||||
|
||||
|
||||
def __str__(self):
|
||||
return str(self.tuple)
|
||||
|
||||
def match(self, atom, binding):
|
||||
print "Checking if tuple {} matches atom {} with binding {}".format(
|
||||
str(self), str(atom), str(binding))
|
||||
logging.debug("Checking if tuple {} matches atom {} with binding {}".format(
|
||||
str(self), str(atom), str(binding)))
|
||||
if len(self.tuple) != len(atom.arguments):
|
||||
return None
|
||||
new_binding = {}
|
||||
|
@ -175,7 +181,7 @@ class Database(object):
|
|||
return None
|
||||
else:
|
||||
new_binding[atom.arguments[i].name] = self.tuple[i]
|
||||
print "Check succeeded with binding {}".format(str(new_binding))
|
||||
logging.debug("Check succeeded with binding {}".format(str(new_binding)))
|
||||
return new_binding
|
||||
|
||||
class Schema (object):
|
||||
|
@ -187,6 +193,7 @@ class Database(object):
|
|||
def __init__(self):
|
||||
self.data = {}
|
||||
self.schemas = {} # not currently used
|
||||
self.tracer = Tracer()
|
||||
|
||||
def __str__(self):
|
||||
def hash2str (h):
|
||||
|
@ -213,16 +220,27 @@ class Database(object):
|
|||
return self.data == other.data
|
||||
|
||||
def __sub__(self, other):
|
||||
def add_tuple(table, dbtuple):
|
||||
new = [table]
|
||||
new.extend(dbtuple.tuple)
|
||||
results.append(new)
|
||||
|
||||
results = []
|
||||
for table in self.data:
|
||||
if table not in other.data:
|
||||
results.extend(self.data[table])
|
||||
for dbtuple in self.data[table]:
|
||||
add_tuple(table, dbtuple)
|
||||
else:
|
||||
for dbtuple in self.data[table]:
|
||||
if dbtuple not in other.data[table]:
|
||||
results.append(dbtuple)
|
||||
add_tuple(table, dbtuple)
|
||||
|
||||
return results
|
||||
|
||||
def log(self, table, msg):
|
||||
if self.tracer.is_traced(table):
|
||||
logging.debug(msg)
|
||||
|
||||
def get_matches(self, atom, binding):
|
||||
""" Returns a list of binding lists for the variables in ATOM
|
||||
not bound in BINDING: one binding list for each tuple in
|
||||
|
@ -231,28 +249,29 @@ class Database(object):
|
|||
return []
|
||||
result = []
|
||||
for tuple in self.data[atom.table]:
|
||||
print "Matching database tuple {}".format(str(tuple))
|
||||
logging.debug("Matching database tuple {}".format(str(tuple)))
|
||||
new_binding = tuple.match(atom, binding)
|
||||
if new_binding is not None:
|
||||
result.append(new_binding)
|
||||
return result
|
||||
|
||||
def insert(self, table, tuple):
|
||||
print "Inserting table {} tuple {} into DB".format(table, str(tuple))
|
||||
self.log(table, "Inserting table {} tuple {} into DB".format(table, str(tuple)))
|
||||
if table not in self.data:
|
||||
self.data[table] = []
|
||||
if any([dbtuple.tuple == tuple for dbtuple in self.data[table]]):
|
||||
return
|
||||
self.data[table].append(self.DBTuple(tuple))
|
||||
|
||||
def delete(self, table, binding):
|
||||
print "Deleting table {} tuple {} from DB".format(table, str(tuple))
|
||||
def delete(self, table, tuple):
|
||||
self.log(table, "Deleting table {} tuple {} from DB".format(table, str(tuple)))
|
||||
if table not in self.data:
|
||||
return
|
||||
self.data[table]
|
||||
locs = [i for i in xrange(0,len(self.data[table]))
|
||||
if self.data[table][i].tuple == tuple]
|
||||
for loc in locs:
|
||||
del self.data[loc]
|
||||
del self.data[table][loc]
|
||||
|
||||
class Runtime (object):
|
||||
""" Runtime for the Congress policy language. Only have one instantiation
|
||||
|
@ -269,14 +288,17 @@ class Runtime (object):
|
|||
# tracer object
|
||||
self.tracer = Tracer()
|
||||
|
||||
def log(self, table, msg):
|
||||
if self.tracer.is_traced(table):
|
||||
logging.debug(msg)
|
||||
|
||||
def insert(self, table, tuple, atomlist=None):
|
||||
""" Event handler for an insertion. """
|
||||
# atomlist is used to make tests easier to read/write
|
||||
if atomlist is not None:
|
||||
table = atomlist[0]
|
||||
tuple = atomlist[1:]
|
||||
if self.tracer.is_traced(table):
|
||||
print "Inserting into queue: {} with {}".format(table, str(tuple))
|
||||
self.log(table, "Inserting into queue: {} with {}".format(table, str(tuple)))
|
||||
# insert tuple into actual table before propagating or else self-join bug.
|
||||
# Self-joins harder to fix when multi-threaded.
|
||||
self.queue.enqueue(Event(table, tuple, insert=True))
|
||||
|
@ -288,8 +310,7 @@ class Runtime (object):
|
|||
if atomlist is not None:
|
||||
table = atomlist[0]
|
||||
tuple = atomlist[1:]
|
||||
if self.tracer.is_traced(table):
|
||||
print "Inserting into queue: {} with {}".format(table, str(tuple))
|
||||
self.log(table, "Deleting from queue: {} with {}".format(table, str(tuple)))
|
||||
self.queue.enqueue(Event(table, tuple, insert=False))
|
||||
self.process_queue()
|
||||
|
||||
|
@ -306,12 +327,9 @@ class Runtime (object):
|
|||
def propagate(self, event):
|
||||
""" Computes events generated by EVENT and the DELTA_RULES,
|
||||
and enqueues them. """
|
||||
if self.tracer.is_traced(event.table):
|
||||
print "Processing event: {}".format(str(event))
|
||||
self.log(event.table, "Processing event: {}".format(str(event)))
|
||||
if event.table not in self.delta_rules.keys():
|
||||
print "event.table: {}".format(event.table)
|
||||
self.print_delta_rules()
|
||||
print "No applicable delta rule"
|
||||
self.log(event.table, "No applicable delta rule")
|
||||
return
|
||||
for delta_rule in self.delta_rules[event.table]:
|
||||
self.propagate_rule(event, delta_rule)
|
||||
|
@ -319,20 +337,23 @@ class Runtime (object):
|
|||
def propagate_rule(self, event, delta_rule):
|
||||
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
|
||||
assert(not delta_rule.trigger.is_negated())
|
||||
if self.tracer.is_traced(event.table):
|
||||
print "Processing event {} with rule {}".format(str(event), str(delta_rule))
|
||||
self.log(event.table, "Processing event {} with rule {}".format(
|
||||
str(event), str(delta_rule)))
|
||||
|
||||
# compute tuples generated by event (either for insert or delete)
|
||||
binding_list = match(event.tuple, delta_rule.trigger)
|
||||
if binding_list is None:
|
||||
return
|
||||
print "binding_list for event-tuple and delta_rule trigger: " + str(binding_list)
|
||||
self.log(event.table,
|
||||
"binding_list for event-tuple and delta_rule trigger: {}".format(
|
||||
str(binding_list)))
|
||||
# vars_in_head = delta_rule.head.variable_names()
|
||||
# print "vars_in_head: " + str(vars_in_head)
|
||||
# needed_vars = set(vars_in_head)
|
||||
# print "needed_vars: " + str(needed_vars)
|
||||
new_bindings = self.top_down_eval(delta_rule.body, 0, binding_list)
|
||||
print "new bindings after top-down: " + ",".join([str(x) for x in new_bindings])
|
||||
self.log(event.table, "new bindings after top-down: {}".format(
|
||||
",".join([str(x) for x in new_bindings])))
|
||||
|
||||
# enqueue effects of Event
|
||||
head_table = delta_rule.head.table
|
||||
|
@ -346,14 +367,13 @@ class Runtime (object):
|
|||
are true in the Database (after applying the dictionary binding
|
||||
BINDING to ATOMs). Returns a list of dictionary bindings. """
|
||||
atom = atoms[atom_index]
|
||||
if self.tracer.is_traced(atom.table):
|
||||
print ("Top-down eval(atoms={}, atom_index={}, "
|
||||
self.log(atom.table, ("Top_down_eval(atoms={}, atom_index={}, "
|
||||
"bindings={})").format(
|
||||
"[" + ",".join(str(x) for x in atoms) + "]",
|
||||
atom_index,
|
||||
str(binding))
|
||||
str(binding)))
|
||||
data_bindings = self.database.get_matches(atom, binding)
|
||||
print "data_bindings: " + str(data_bindings)
|
||||
self.log(atom.table, "data_bindings: " + str(data_bindings))
|
||||
if len(data_bindings) == 0:
|
||||
return []
|
||||
results = []
|
||||
|
@ -373,8 +393,8 @@ class Runtime (object):
|
|||
# remove this binding from var_bindings
|
||||
for var in data_binding:
|
||||
del binding[var]
|
||||
if self.tracer.is_traced(atom.table):
|
||||
print "Return value: {}".format([str(x) for x in results])
|
||||
self.log(atom.table, "Top_down_eval return value: {}".format(
|
||||
'[' + ", ".join([str(x) for x in results]) + ']'))
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# Copyright (c) 2013 VMware, Inc. All rights reserved.
|
||||
#
|
||||
|
||||
import unittest
|
||||
from policy import compile
|
||||
from policy import runtime
|
||||
import logging
|
||||
|
||||
class TestRuntime(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_runtime(self):
|
||||
def prep_runtime(code):
|
||||
# compile source
|
||||
c = compile.Compiler()
|
||||
c.read_source(input_string=code)
|
||||
c.compute_delta_rules()
|
||||
run = runtime.Runtime(c.delta_rules)
|
||||
tracer = runtime.Tracer()
|
||||
tracer.trace('*')
|
||||
run.tracer = tracer
|
||||
run.database.tracer = tracer
|
||||
return run
|
||||
|
||||
def insert(run, list):
|
||||
run.insert(list[0], tuple(list[1:]))
|
||||
|
||||
def delete(run, list):
|
||||
run.delete(list[0], tuple(list[1:]))
|
||||
|
||||
def check(run, correct_database_code, msg=None):
|
||||
# extract correct answer from code, represented as a Database
|
||||
print "** Creating correct Database **"
|
||||
c = compile.Compiler()
|
||||
c.read_source(input_string=correct_database_code)
|
||||
correct = c.theory
|
||||
correct_database = runtime.Database()
|
||||
for atom in correct:
|
||||
correct_database.insert(atom.table,
|
||||
tuple([x.name for x in atom.arguments]))
|
||||
print str(correct_database)
|
||||
|
||||
# compute diffs; should be empty
|
||||
extra = run.database - correct_database
|
||||
missing = correct_database - run.database
|
||||
errmsg = ""
|
||||
if len(extra) > 0:
|
||||
print "Extra tuples"
|
||||
print ", ".join([str(x) for x in extra])
|
||||
if len(missing) > 0:
|
||||
print "Missing tuples"
|
||||
print ", ".join([str(x) for x in missing])
|
||||
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
|
||||
|
||||
|
||||
code = ("q(x) :- p(x), r(x)")
|
||||
run = prep_runtime(code)
|
||||
check(run, "", "Empty database on init")
|
||||
logging.debug("** Next test phase **")
|
||||
|
||||
insert(run, ['r', 1])
|
||||
check(run, "r(1)", "Basic insert should Insert into base table with no propagations")
|
||||
logging.debug("** Next test phase **")
|
||||
|
||||
delete(run, ['r', 1])
|
||||
check(run, "", "Delete from base table after insert")
|
||||
logging.debug("** Next test phase **")
|
||||
|
||||
insert(run, ['r', 1])
|
||||
insert(run, ['p', 1])
|
||||
check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
|
||||
logging.debug("** Next test phase **")
|
||||
|
||||
delete(run, ['r', 1])
|
||||
check(run, "p(1)", "Delete from base table with 1 propagation")
|
||||
|
||||
# body of length 1
|
||||
# existential variables
|
||||
# multiple rules
|
||||
# negation
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in New Issue