Added proof-tracking support

Also added many tests and fixed bugs.

Issue: #
Change-Id: I4da520469c502c166c28b381c082fc2927df4f45
This commit is contained in:
Tim Hinrichs 2013-08-14 16:02:59 -07:00
parent 80eb6a7f5e
commit 0b613444df
2 changed files with 310 additions and 75 deletions

View File

@ -42,9 +42,9 @@ class EventQueue(object):
return "[" + ",".join([str(x) for x in self.queue]) + "]"
class Event(object):
def __init__(self, table=None, tuple=None, insert=True):
def __init__(self, table=None, tuple=None, insert=True, proofs=None):
self.table = table
self.tuple = tuple
self.tuple = Database.DBTuple(tuple, proofs=proofs)
self.insert = insert
def is_insert(self):
@ -55,8 +55,7 @@ class Event(object):
sign = '+'
else:
sign = '-'
return "{}{}({})".format(self.table, sign,
",".join([str(x) for x in self.tuple]))
return "{}{}({})".format(self.table, sign, str(self.tuple))
# class Database(object):
# class DictTuple(object):
@ -157,16 +156,55 @@ class Event(object):
class Database(object):
class ProofCollection(object):
def __init__(self, proofs):
self.contents = list(proofs)
def __str__(self):
return '{' + ",".join(str(x) for x in self.contents) + '}'
def __isub__(self, other):
if other is None:
return
remaining = []
for proof in self.contents:
if proof not in other.contents:
remaining.append(proof)
self.contents = remaining
return self
def __ior__(self, other):
if other is None:
return
for proof in other.contents:
if proof not in self.contents:
self.contents.append(proof)
return self
def __len__(self):
return len(self.contents)
class DBTuple(object):
def __init__(self, tuple):
self.tuple = tuple
def __init__(self, iterable, proofs=None):
self.tuple = tuple(iterable)
if proofs is None:
proofs = []
self.proofs = Database.ProofCollection(proofs)
def __eq__(self, other):
return self.tuple == other.tuple
def __str__(self):
return str(self.tuple)
return str(self.tuple) + str(self.proofs)
def __len__(self):
return len(self.tuple)
def __getitem__(self, index):
return self.tuple[index]
def __setitem__(self, index, value):
self.tuple[index] = value
def match(self, atom, binding):
logging.debug("Checking if tuple {} matches atom {} with binding {}".format(
@ -255,23 +293,39 @@ class Database(object):
result.append(new_binding)
return result
def insert(self, table, tuple):
self.log(table, "Inserting table {} tuple {} into DB".format(table, str(tuple)))
def insert(self, table, dbtuple):
if not isinstance(dbtuple, Database.DBTuple):
dbtuple = Database.DBTuple(dbtuple)
self.log(table, "Inserting table {} tuple {} into DB".format(
table, str(dbtuple)))
if table not in self.data:
self.data[table] = []
if any([dbtuple.tuple == tuple for dbtuple in self.data[table]]):
return
self.data[table].append(self.DBTuple(tuple))
self.data[table] = [dbtuple]
else:
for existingtuple in self.data[table]:
assert(existingtuple.proofs is not None)
if existingtuple.tuple == dbtuple.tuple:
assert(existingtuple.proofs is not None)
existingtuple.proofs |= dbtuple.proofs
assert(existingtuple.proofs is not None)
return
self.data[table].append(dbtuple)
def delete(self, table, tuple):
self.log(table, "Deleting table {} tuple {} from DB".format(table, str(tuple)))
def delete(self, table, dbtuple):
if not isinstance(dbtuple, Database.DBTuple):
dbtuple = Database.DBTuple(dbtuple)
self.log(table, "Deleting table {} tuple {} from DB".format(
table, str(dbtuple)))
if table not in self.data:
return
self.data[table]
locs = [i for i in xrange(0,len(self.data[table]))
if self.data[table][i].tuple == tuple]
for loc in locs:
del self.data[table][loc]
for i in xrange(0, len(self.data[table])):
existingtuple = self.data[table][i]
self.log(table, "Checking tuple {}".format(str(existingtuple)))
if existingtuple.tuple == dbtuple.tuple:
existingtuple.proofs -= dbtuple.proofs
if len(existingtuple.proofs) == 0:
del self.data[table][i]
return
class Runtime (object):
""" Runtime for the Congress policy language. Only have one instantiation
@ -292,37 +346,41 @@ class Runtime (object):
if self.tracer.is_traced(table):
logging.debug(msg)
def insert(self, table, tuple, atomlist=None):
""" Event handler for an insertion. """
# atomlist is used to make tests easier to read/write
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
self.log(table, "Inserting into queue: {} with {}".format(table, str(tuple)))
# insert tuple into actual table before propagating or else self-join bug.
# Self-joins harder to fix when multi-threaded.
def insert(self, table, tuple):
""" Event handler for an insertion.
TABLE is the name of a table (a string).
TUPLE is a Python tuple. """
if not isinstance(tuple, Database.DBTuple):
tuple = Database.DBTuple(tuple)
self.log(table, "Inserting into queue: {} with {}".format(
table, str(tuple)))
self.queue.enqueue(Event(table, tuple, insert=True))
self.process_queue()
self.process_queue() # should be running in separate daemon
def delete(self, table, tuple, atomlist=None):
""" Event handler for a deletion. """
# atomlist is used to make tests easier to read/write
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
self.log(table, "Deleting from queue: {} with {}".format(table, str(tuple)))
def delete(self, table, tuple):
""" Event handler for a deletion. TUPLE is a Python tuple.
TABLE is the name of a table (a string).
TUPLE is a Python tuple. """
if not isinstance(tuple, Database.DBTuple):
tuple = Database.DBTuple(tuple)
self.log(table, "Deleting from queue: {} with {}".format(
table, str(tuple)))
self.queue.enqueue(Event(table, tuple, insert=False))
self.process_queue()
self.process_queue() # should be running in separate daemon
def process_queue(self):
""" Toplevel evaluation routine. """
while len(self.queue) > 0:
event = self.queue.dequeue()
# Note differing order of insert/delete into database.
# Insert happens before propagation; Delete happens after propagation.
# Necessary for correctness on self-joins.
if event.is_insert():
self.database.insert(event.table, event.tuple)
self.propagate(event)
else:
self.propagate(event)
self.database.delete(event.table, event.tuple)
self.propagate(event)
def propagate(self, event):
""" Computes events generated by EVENT and the DELTA_RULES,
@ -341,31 +399,45 @@ class Runtime (object):
str(event), str(delta_rule)))
# compute tuples generated by event (either for insert or delete)
# print "event: {}, event.tuple: {}, event.tuple.rawtuple(): {}".format(
# str(event), str(event.tuple), str(event.tuple.raw_tuple()))
binding_list = match(event.tuple, delta_rule.trigger)
if binding_list is None:
return
self.log(event.table,
"binding_list for event-tuple and delta_rule trigger: {}".format(
str(binding_list)))
# vars_in_head = delta_rule.head.variable_names()
# print "vars_in_head: " + str(vars_in_head)
# needed_vars = set(vars_in_head)
# print "needed_vars: " + str(needed_vars)
new_bindings = self.top_down_eval(delta_rule.body, 0, binding_list)
self.log(event.table, "new bindings after top-down: {}".format(
",".join([str(x) for x in new_bindings])))
# enqueue effects of Event
head_table = delta_rule.head.table
# for each binding, compute generated tuple and group bindings
# by the tuple they generated
new_tuples = {}
for new_binding in new_bindings:
new_tuple = tuple(plug(delta_rule.head, new_binding))
if new_tuple not in new_tuples:
new_tuples[new_tuple] = []
new_tuples[new_tuple].append(new_binding)
self.log(event.table, "new tuples generated: {}".format(
str(new_tuples)))
# enqueue each distinct generated tuple, recording appropriate bindings
head_table = delta_rule.head.table
for new_tuple in new_tuples:
# self.log(event.table,
# "new_tuple {}: {}".format(str(new_tuple), str(new_tuples[new_tuple])))
self.queue.enqueue(Event(table=head_table,
tuple=plug(delta_rule.head, new_binding),
insert=event.insert))
tuple=new_tuple,
proofs=new_tuples[new_tuple],
insert=event.insert))
def top_down_eval(self, atoms, atom_index, binding):
""" Compute all instances of ATOMS (from ATOM_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to ATOMs). Returns a list of dictionary bindings. """
if atom_index > len(atoms) - 1:
return [binding]
atom = atoms[atom_index]
self.log(atom.table, ("Top_down_eval(atoms={}, atom_index={}, "
"bindings={})").format(
@ -378,23 +450,17 @@ class Runtime (object):
return []
results = []
for data_binding in data_bindings:
# add this binding to var_bindings
# add new binding to current binding
binding.update(data_binding)
if atom_index == len(atoms) - 1: # last element in atoms
# construct result
# output_binding = {}
# for var in projection:
# output_binding[var] = binding[var]
# results.append(output_binding)
results.append(dict(binding)) # need to copy
else:
# recurse
results.extend(self.top_down_eval(atoms, atom_index + 1, binding))
# remove this binding from var_bindings
# remove new binding from current bindings
for var in data_binding:
del binding[var]
self.log(atom.table, "Top_down_eval return value: {}".format(
'[' + ", ".join([str(x) for x in results]) + ']'))
# self.log(atom.table, "Top_down_eval return value: {}".format(
# '[' + ", ".join([str(x) for x in results]) + ']'))
return results

View File

@ -12,8 +12,10 @@ class TestRuntime(unittest.TestCase):
pass
def test_runtime(self):
def prep_runtime(code):
def prep_runtime(code, msg=None):
# compile source
if msg is not None:
logging.debug(msg)
c = compile.Compiler()
c.read_source(input_string=code)
c.compute_delta_rules()
@ -31,54 +33,221 @@ class TestRuntime(unittest.TestCase):
run.delete(list[0], tuple(list[1:]))
def check(run, correct_database_code, msg=None):
# extract correct answer from code, represented as a Database
print "** Creating correct Database **"
# extract correct answer from correct_database_code
logging.debug("** Checking {} **".format(msg))
c = compile.Compiler()
c.read_source(input_string=correct_database_code)
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
tuple([x.name for x in atom.arguments]))
print str(correct_database)
[x.name for x in atom.arguments])
# compute diffs; should be empty
extra = run.database - correct_database
missing = correct_database - run.database
errmsg = ""
if len(extra) > 0:
print "Extra tuples"
print ", ".join([str(x) for x in extra])
logging.debug("Extra tuples")
logging.debug(", ".join([str(x) for x in extra]))
if len(missing) > 0:
print "Missing tuples"
print ", ".join([str(x) for x in missing])
logging.debug("Missing tuples")
logging.debug(", ".join([str(x) for x in missing]))
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
logging.debug(str(run.database))
logging.debug("** Finished {} **".format(msg))
def showdb(run):
logging.debug("Resulting DB: " + str(run.database))
# basic tests
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code)
run = prep_runtime(code, "**** Basic tests ****")
check(run, "", "Empty database on init")
logging.debug("** Next test phase **")
insert(run, ['r', 1])
check(run, "r(1)", "Basic insert should Insert into base table with no propagations")
logging.debug("** Next test phase **")
check(run, "r(1)", "Basic insert with no propagations")
insert(run, ['r', 1])
check(run, "r(1)", "Duplicate insert with no propagations")
delete(run, ['r', 1])
check(run, "", "Delete from base table after insert")
logging.debug("** Next test phase **")
check(run, "", "Delete with no propagations")
delete(run, ['r', 1])
check(run, "", "Delete from empty table")
insert(run, ['r', 1])
insert(run, ['p', 1])
check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
logging.debug("** Next test phase **")
showdb(run)
delete(run, ['r', 1])
check(run, "p(1)", "Delete from base table with 1 propagation")
showdb(run)
# multiple rules
code = ("q(x) :- p(x), r(x)"
"q(x) :- s(x)")
insert(run, ['p', 1])
insert(run, ['r', 1])
showdb(run)
check(run, "p(1) r(1) q(1)", "Insert: multiple rules")
insert(run, ['s', 1])
showdb(run)
check(run, "p(1) r(1) s(1) q(1)", "Insert: duplicate conclusions")
# body of length 1
code = ("q(x) :- p(x)")
run = prep_runtime(code, "**** Body length 1 tests ****")
insert(run, ['p', 1])
check(run, "p(1) q(1)", "Insert with body of size 1")
delete(run, ['p', 1])
check(run, "", "Delete with body of size 1")
# existential variables
# multiple rules
code = ("q(x) :- p(x,y)")
run = prep_runtime(code, "**** Existential variable tests ****")
insert(run, ['p', 1, 2])
check(run, "p(1, 2) q(1)", "Insert: existential variable in body of size 1")
delete(run, ['p', 1, 2])
check(run, "", "Delete: existential variable in body of size 1")
code = ("q(x) :- p(x,y), r(y,x)")
run = prep_runtime(code)
insert(run, ['p', 1, 2])
showdb(run)
insert(run, ['r', 2, 1])
showdb(run)
check(run, "p(1, 2) r(2, 1) q(1)", "Insert: join in body of size 2")
delete(run, ['p', 1, 2])
showdb(run)
check(run, "r(2, 1)", "Delete: join in body of size 2")
insert(run, ['p', 1, 2])
showdb(run)
insert(run, ['p', 1, 3])
showdb(run)
insert(run, ['r', 3, 1])
showdb(run)
check(run, "r(2, 1) r(3,1) p(1, 2) p(1, 3) q(1)",
"Insert: multiple existential bindings for same head")
delete(run, ['p', 1, 2])
check(run, "r(2, 1) r(3,1) p(1, 3) q(1)",
"Delete: multiple existential bindings for same head")
code = ("q(x,v) :- p(x,y), r(y,z), s(z,w), t(w,v)")
run = prep_runtime(code)
insert(run, ['p', 1, 10])
insert(run, ['p', 1, 20])
insert(run, ['r', 10, 100])
insert(run, ['r', 20, 200])
insert(run, ['s', 100, 1000])
insert(run, ['s', 200, 2000])
insert(run, ['t', 1000, 10000])
insert(run, ['t', 2000, 20000])
code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
"t(1000, 10000) t(2000,20000) "
"q(1,10000) q(1,20000)")
check(run, code, "Insert: larger join")
delete(run, ['t', 1000, 10000])
code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
"t(2000,20000) "
"q(1,20000)")
check(run, code, "Delete: larger join")
code = ("q(x,y) :- p(x,z), p(z,y)")
run = prep_runtime(code)
insert(run, ['p', 1, 2])
insert(run, ['p', 1, 3])
insert(run, ['p', 2, 4])
insert(run, ['p', 2, 5])
check(run, 'p(1,2) p(1,3) p(2,4) p(2,5) q(1,4) q(1,5)',
"Insert: self-join")
delete(run, ['p', 2, 4])
check(run, 'p(1,2) p(1,3) p(2,5) q(1,5)')
code = ("q(x,w) :- p(x,y), p(y,z), p(z,w)")
run = prep_runtime(code)
insert(run, ['p', 1, 1])
insert(run, ['p', 1, 2])
insert(run, ['p', 2, 2])
insert(run, ['p', 2, 3])
insert(run, ['p', 2, 4])
insert(run, ['p', 2, 5])
insert(run, ['p', 3, 3])
insert(run, ['p', 3, 4])
insert(run, ['p', 3, 5])
insert(run, ['p', 3, 6])
insert(run, ['p', 3, 7])
code = ('p(1,1) p(1,2) p(2,2) p(2,3) p(2,4) p(2,5)'
'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
'q(1,1) q(1,2) q(2,2) q(2,3) q(2,4) q(2,5)'
'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
'q(2,6) q(2,7)')
check(run, code, "Insert: larger self join")
delete(run, ['p', 1, 1])
delete(run, ['p', 2, 2])
code = (' p(1,2) p(2,3) p(2,4) p(2,5)'
'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
' q(2,3) q(2,4) q(2,5)'
'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
'q(2,6) q(2,7)')
check(run, code, "Delete: larger self join")
# Value types: string
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code, "String data type")
insert(run, ['r', 'apple'])
check(run, 'r("apple")', "String insert with no propagations")
insert(run, ['r', 'apple'])
check(run, 'r("apple")', "Duplicate string insert with no propagations")
delete(run, ['r', 'apple'])
check(run, "", "Delete with no propagations")
delete(run, ['r', 'apple'])
check(run, "", "Delete from empty table")
insert(run, ['r', 'apple'])
insert(run, ['p', 'apple'])
check(run, 'r("apple") p("apple") q("apple")',
"String insert with 1 propagation")
showdb(run)
delete(run, ['r', 'apple'])
check(run, 'p("apple")', "String delete with 1 propagation")
showdb(run)
# Value types: floats
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code, "Float data type")
insert(run, ['r', 1.2])
check(run, 'r(1.2)', "String insert with no propagations")
insert(run, ['r', 1.2])
check(run, 'r(1.2)', "Duplicate string insert with no propagations")
delete(run, ['r', 1.2])
check(run, "", "Delete with no propagations")
delete(run, ['r', 1.2])
check(run, "", "Delete from empty table")
insert(run, ['r', 1.2])
insert(run, ['p', 1.2])
check(run, 'r(1.2) p(1.2) q(1.2)',
"String insert with 1 propagation")
showdb(run)
delete(run, ['r', 1.2])
check(run, 'p(1.2)', "String delete with 1 propagation")
showdb(run)
# negation
if __name__ == '__main__':