Modifications to the organization of the runtime, to enable testing

Turned the runtime element into a class

Issue: #
Change-Id: I9431814475d0dc58fbe7c73fd8b59e24c381fb15
Tim Hinrichs 2013-08-13 13:32:30 -07:00
parent 1c796ccec4
commit 705bef6d8b
3 changed files with 253 additions and 158 deletions
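For context, the heart of this change is a standard refactoring: the runtime's module-level state (queue, database, delta_rules, tracer) moves onto a Runtime class, so each test can construct its own isolated instance. A minimal sketch of that pattern, using simplified stand-in types rather than the actual Congress classes:

# Sketch of the refactoring pattern this commit applies (simplified names,
# not the real Congress code): module-level state becomes instance state.
class Runtime(object):
    """Holds evaluation state per instance instead of at module level."""
    def __init__(self, delta_rules):
        self.delta_rules = delta_rules   # update rules, indexed by trigger table
        self.database = {}               # table name -> list of tuples
        self.queue = []                  # pending insert/delete events

    def insert(self, table, row):
        self.database.setdefault(table, []).append(row)

# Two runtimes no longer share state, which is what makes the new
# Tester class possible.
run_a = Runtime(delta_rules={})
run_b = Runtime(delta_rules={})
run_a.insert('p', (1,))
assert 'p' not in run_b.database

With per-instance state, the Tester introduced below can compile a policy, build a fresh Runtime from its delta rules, and compare the resulting database against an expected one without cross-test leakage.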

View File

@@ -85,7 +85,6 @@ object_constant
: INT -> ^(INTEGER_OBJ INT)
| FLOAT -> ^(FLOAT_OBJ FLOAT)
| STRING -> ^(STRING_OBJ STRING)
| SYMBOL -> ^(SYMBOL_OBJ SYMBOL)
;
variable

View File

@@ -43,8 +43,25 @@ class Location (object):
s += " col: {}".format(self.col)
return s
class Term(object):
""" Represents the union of Variable and ObjectConstant. Should
only be instantiated via factory method. """
def __init__(self):
assert False, "Cannot instantiate Term directly--use factory method"
class Variable (object):
@classmethod
def create_from_python(cls, value):
if isinstance(value, basestring):
return ObjectConstant(value, ObjectConstant.STRING)
elif isinstance(value, (int, long)):
return ObjectConstant(value, ObjectConstant.INTEGER)
elif isinstance(value, float):
return ObjectConstant(value, ObjectConstant.FLOAT)
else:
return Variable(value)
class Variable (Term):
""" Represents a term without a fixed value. """
def __init__(self, name, location=None):
self.name = name
@@ -59,18 +76,17 @@ class Variable (object):
def is_object(self):
return False
class ObjectConstant (object):
class ObjectConstant (Term):
""" Represents a term with a fixed value. """
STRING = 'STRING'
FLOAT = 'FLOAT'
INTEGER = 'INTEGER'
SYMBOL = 'SYMBOL'
def __init__(self, name, type, location=None):
assert(type in [self.STRING, self.FLOAT, self.INTEGER])
self.name = name
self.type = type
self.location = location
assert(self.type in [self.STRING, self.FLOAT, self.INTEGER, self.SYMBOL])
def __str__(self):
return str(self.name)
@@ -88,6 +104,15 @@ class Atom (object):
self.arguments = arguments
self.location = location
@classmethod
def create_from_list(cls, list):
""" LIST is a python list representing an atom, e.g.
['p', 17, "string", 3.14]. Returns the corresponding Atom. """
arguments = []
for i in xrange(1, len(list)):
arguments.append(Term.create_from_python(list[i]))
return cls(list[0], arguments)
def __str__(self):
return "{}({})".format(self.table,
", ".join([str(x) for x in self.arguments]))
@@ -161,12 +186,14 @@ class Compiler (object):
s += 'None'
return s
def read_source(self, file):
def read_source(self, input_file=None, input_string=None):
assert(input_file is not None or input_string is not None)
# parse input file and convert to internal representation
self.raw_syntax_tree = CongressSyntax.parse_file(file)
self.print_parse_result()
self.raw_syntax_tree = CongressSyntax.parse_file(input_file=input_file,
input_string=input_string)
#self.print_parse_result()
self.theory = CongressSyntax.create(self.raw_syntax_tree)
print str(self)
#print str(self)
def print_parse_result(self):
print_tree(
@@ -194,26 +221,57 @@ class Compiler (object):
self.delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody))
def test_runtime(self):
# init runtime's delta rules.
runtime.delta_rules = {}
for delta in self.delta_rules:
if delta.trigger.table not in runtime.delta_rules:
runtime.delta_rules[delta.trigger.table] = [delta]
else:
runtime.delta_rules[delta.trigger.table].append(delta)
runtime.print_delta_rules()
# insert stuff
# BUG: handle case where only 1 element in body
# BUG: self-joins require inserting data into database before computing updates
runtime.tracer.trace('?')
runtime.handle_insert('p', tuple([1]))
print "**Final State**"
print str(runtime.database)
# print "p's contents: {}".format(runtime.database.data['p'])
# print "q's contents: {}".format(runtime.database.data['q'])
# print "q's contents: {}".format(runtime.database.data['q'])
class Tester (object):
def test_runtime(self):
def prep_runtime(code):
# compile source
c = Compiler()
c.read_source(input_string=code)
c.compute_delta_rules()
run = runtime.Runtime(c.delta_rules)
return run
def insert(run, list):
run.insert(list[0], tuple(list[1:]))
def delete(run, list):
run.delete(list[0], tuple(list[1:]))
def check(run, correct_database_code, msg=None):
# extract correct answer from code, represented as a Database
print "** Reading correct Database **"
c = Compiler()
c.read_source(input_string=correct_database_code)
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
tuple([x.name for x in atom.arguments]))
print "** Correct Database **"
print str(correct_database)
# ensure correct answers is a subset of run.database
extra = run.database - correct_database
missing = correct_database - run.database
if len(extra) > 0 or len(missing) > 0:
print "Test {} failed".format(msg)
if len(extra) > 0:
print "Extra tuples: {}".format(str(extra))
if len(missing) > 0:
print "Missing tuples: {}".format(str(extra))
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code)
print "Finished prep_runtime"
insert(run, ['r', 1])
check(run, "r(1)")
print "Finished first insert"
insert(run, ['p', 1])
check(run, "r(1) p(1) q(1)")
print "Finished second insert"
print "**Final State**"
print str(run.database)
class CongressSyntax (object):
""" External syntax and converting it into internal representation. """
@@ -247,8 +305,12 @@ class CongressSyntax (object):
e.line, e.charPositionInLine)
@classmethod
def parse_file(cls, filename):
char_stream = antlr3.ANTLRFileStream(filename)
def parse_file(cls, input_file=None, input_string=None):
assert(input_file is not None or input_string is not None)
if input_file is not None:
char_stream = antlr3.ANTLRFileStream(input_file)
else:
char_stream = antlr3.ANTLRStringStream(input_string)
lexer = cls.Lexer(char_stream)
tokens = antlr3.CommonTokenStream(lexer)
parser = cls.Parser(tokens)
@@ -322,21 +384,18 @@ class CongressSyntax (object):
loc = Location(line=antlr.children[0].token.line,
col=antlr.children[0].token.charPositionInLine)
if op == 'STRING_OBJ':
return ObjectConstant(antlr.children[0].getText(),
value = antlr.children[0].getText()
return ObjectConstant(value[1:len(value) - 1], # prune quotes
ObjectConstant.STRING,
location=loc)
elif op == 'INTEGER_OBJ':
return ObjectConstant(antlr.children[0].getText(),
return ObjectConstant(int(antlr.children[0].getText()),
ObjectConstant.INTEGER,
location=loc)
elif op == 'FLOAT_OBJ':
return ObjectConstant(antlr.children[0].getText(),
return ObjectConstant(float(antlr.children[0].getText()),
ObjectConstant.FLOAT,
location=loc)
elif op == 'SYMBOL_OBJ':
return ObjectConstant(antlr.children[0].getText(),
ObjectConstant.SYMBOL,
location=loc)
elif op == 'VARIABLE':
return Variable(antlr.children[0].getText(), location=loc)
else:
@@ -362,7 +421,8 @@ def main():
compiler = Compiler()
compiler.read_source(inputs[0])
compiler.compute_delta_rules()
compiler.test_runtime()
test = Tester()
test.test_runtime()
if __name__ == '__main__':
sys.exit(main())
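Before moving on to the runtime changes, here is a hand-worked sketch of what Tester.test_runtime above exercises for the rule q(x) :- p(x), r(x), under the usual delta-rule reading (a change to p joins against the current contents of r, and vice versa). This is illustrative Python, not the Congress implementation:

db = {'p': set(), 'q': set(), 'r': set()}

def insert(table, row):
    db[table].add(row)
    if table == 'p' and row in db['r']:   # delta rule: q(x) from a new p(x), joined with r(x)
        db['q'].add(row)
    if table == 'r' and row in db['p']:   # delta rule: q(x) from a new r(x), joined with p(x)
        db['q'].add(row)

insert('r', (1,))            # q stays empty: p has no matching tuple yet
insert('p', (1,))            # joins with r(1), so q(1) is derived
assert db['q'] == {(1,)}

Inserting r(1) alone derives nothing, matching check(run, "r(1)"); inserting p(1) then joins with r(1) and derives q(1), matching check(run, "r(1) p(1) q(1)").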

View File

@@ -185,10 +185,8 @@ class Database(object):
return str(self.arguments)
def __init__(self):
self.data = {'p': [], 'q': [], 'r': [self.DBTuple((1,))]}
self.schemas = {'p': Database.Schema([1]),
'q': Database.Schema([1]),
'r': Database.Schema([1])}
self.data = {}
self.schemas = {} # not currently used
def __str__(self):
def hash2str (h):
@@ -207,15 +205,30 @@ class Database(object):
strings.append(s)
return '{' + ", ".join(strings) + '}'
return "<data: {}, \nschemas: {}>".format(
hashlist2str(self.data), hash2str(self.schemas))
return hashlist2str(self.data)
# return "<data: {}, \nschemas: {}>".format(
# hashlist2str(self.data), hash2str(self.schemas))
def __eq__(self, other):
return self.data == other.data
def __sub__(self, other):
results = []
for table in self.data:
if table not in other.data:
results.extend(self.data[table])
else:
for dbtuple in self.data[table]:
if dbtuple not in other.data[table]:
results.append(dbtuple)
return results
def get_matches(self, atom, binding):
""" Returns a list of binding lists for the variables in ATOM
not bound in BINDING: one binding list for each tuple in
the database matching ATOM under BINDING. """
if atom.table not in self.data:
raise CongressRuntime("Table not found ".format(table))
return []
result = []
for tuple in self.data[atom.table]:
print "Matching database tuple {}".format(str(tuple))
@@ -227,8 +240,7 @@ class Database(object):
def insert(self, table, tuple):
print "Inserting table {} tuple {} into DB".format(table, str(tuple))
if table not in self.data:
raise CongressRuntime("Table not found ".format(table))
# if already present, ignore
self.data[table] = []
if any([dbtuple.tuple == tuple for dbtuple in self.data[table]]):
return
self.data[table].append(self.DBTuple(tuple))
@@ -236,85 +248,151 @@ class Database(object):
def delete(self, table, binding):
print "Deleting table {} tuple {} from DB".format(table, str(tuple))
if table not in self.data:
raise CongressRuntime("Table not found ".format(table))
return
locs = [i for i in xrange(0,len(self.data[table]))
if self.data[table][i].tuple == tuple]
for loc in locs:
del self.data[loc]
class Runtime (object):
""" Runtime for the Congress policy language. Only have one instantiation
in practice, but using a class is natural and useful for testing. """
# queue of events left to process
queue = EventQueue()
# collection of all tables
database = Database()
# update rules, indexed by trigger table name
delta_rules = {}
# tracing construct
tracer = Tracer()
def __init__(self, rules):
# rules dictating how an insert/delete to one table
# effects other tables
self.delta_rules = index_delta_rules(rules)
# queue of events left to process
self.queue = EventQueue()
# collection of all tables
self.database = Database()
# tracer object
self.tracer = Tracer()
def handle_insert(table, tuple):
""" Event handler for an insertion. """
if tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
# insert tuple into actual table before propagating or else self-join bug.
# Self-joins harder to fix when multi-threaded.
queue.enqueue(Event(table, tuple, insert=True))
process_queue()
def insert(self, table, tuple, atomlist=None):
""" Event handler for an insertion. """
# atomlist is used to make tests easier to read/write
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
if self.tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
# insert tuple into actual table before propagating or else self-join bug.
# Self-joins harder to fix when multi-threaded.
self.queue.enqueue(Event(table, tuple, insert=True))
self.process_queue()
def handle_delete(table, tuple):
""" Event handler for a deletion. """
if tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
queue.enqueue(Event(table, tuple, insert=False))
process_queue()
def delete(self, table, tuple, atomlist=None):
""" Event handler for a deletion. """
# atomlist is used to make tests easier to read/write
if atomlist is not None:
table = atomlist[0]
tuple = atomlist[1:]
if self.tracer.is_traced(table):
print "Inserting into queue: {} with {}".format(table, str(tuple))
self.queue.enqueue(Event(table, tuple, insert=False))
self.process_queue()
def process_queue():
""" Toplevel evaluation routine. """
while len(queue) > 0:
event = queue.dequeue()
if event.is_insert():
database.insert(event.table, event.tuple)
def process_queue(self):
""" Toplevel evaluation routine. """
while len(self.queue) > 0:
event = self.queue.dequeue()
if event.is_insert():
self.database.insert(event.table, event.tuple)
else:
self.database.delete(event.table, event.tuple)
self.propagate(event)
def propagate(self, event):
""" Computes events generated by EVENT and the DELTA_RULES,
and enqueues them. """
if self.tracer.is_traced(event.table):
print "Processing event: {}".format(str(event))
if event.table not in self.delta_rules.keys():
print "event.table: {}".format(event.table)
self.print_delta_rules()
print "No applicable delta rule"
return
for delta_rule in self.delta_rules[event.table]:
self.propagate_rule(event, delta_rule)
def propagate_rule(self, event, delta_rule):
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
assert(not delta_rule.trigger.is_negated())
if self.tracer.is_traced(event.table):
print "Processing event {} with rule {}".format(str(event), str(delta_rule))
# compute tuples generated by event (either for insert or delete)
binding_list = match(event.tuple, delta_rule.trigger)
if binding_list is None:
return
print "binding_list for event-tuple and delta_rule trigger: " + str(binding_list)
# vars_in_head = delta_rule.head.variable_names()
# print "vars_in_head: " + str(vars_in_head)
# needed_vars = set(vars_in_head)
# print "needed_vars: " + str(needed_vars)
new_bindings = self.top_down_eval(delta_rule.body, 0, binding_list)
print "new bindings after top-down: " + ",".join([str(x) for x in new_bindings])
# enqueue effects of Event
head_table = delta_rule.head.table
for new_binding in new_bindings:
self.queue.enqueue(Event(table=head_table,
tuple=plug(delta_rule.head, new_binding),
insert=event.insert))
def top_down_eval(self, atoms, atom_index, binding):
""" Compute all instances of ATOMS (from ATOM_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to ATOMs). Returns a list of dictionary bindings. """
atom = atoms[atom_index]
if self.tracer.is_traced(atom.table):
print ("Top-down eval(atoms={}, atom_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in atoms) + "]",
atom_index,
str(binding))
data_bindings = self.database.get_matches(atom, binding)
print "data_bindings: " + str(data_bindings)
if len(data_bindings) == 0:
return []
results = []
for data_binding in data_bindings:
# add this binding to var_bindings
binding.update(data_binding)
if atom_index == len(atoms) - 1: # last element in atoms
# construct result
# output_binding = {}
# for var in projection:
# output_binding[var] = binding[var]
# results.append(output_binding)
results.append(dict(binding)) # need to copy
else:
# recurse
results.extend(self.top_down_eval(atoms, atom_index + 1, binding))
# remove this binding from var_bindings
for var in data_binding:
del binding[var]
if self.tracer.is_traced(atom.table):
print "Return value: {}".format([str(x) for x in results])
return results
def print_delta_rules(self):
for table in self.delta_rules:
print "{}:".format(table)
for rule in self.delta_rules[table]:
print " {}".format(rule)
def index_delta_rules(delta_rules):
indexed_delta_rules = {}
for delta in delta_rules:
if delta.trigger.table not in indexed_delta_rules:
indexed_delta_rules[delta.trigger.table] = [delta]
else:
database.delete(event.table, event.tuple)
propagate(event)
def propagate(event):
""" Computes events generated by EVENT and the DELTA_RULES,
and enqueues them. """
if tracer.is_traced(event.table):
print "Processing event: {}".format(str(event))
if event.table not in delta_rules.keys():
print "event.table: {}".format(event.table)
print_delta_rules()
print "No applicable delta rule"
return
for delta_rule in delta_rules[event.table]:
propagate_rule(event, delta_rule)
def propagate_rule(event, delta_rule):
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
assert(not delta_rule.trigger.is_negated())
if tracer.is_traced(event.table):
print "Processing event {} with rule {}".format(str(event), str(delta_rule))
# compute tuples generated by event (either for insert or delete)
binding_list = match(event.tuple, delta_rule.trigger)
if binding_list is None:
return
print "binding_list for event-tuple and delta_rule trigger: " + str(binding_list)
# vars_in_head = delta_rule.head.variable_names()
# print "vars_in_head: " + str(vars_in_head)
# needed_vars = set(vars_in_head)
# print "needed_vars: " + str(needed_vars)
new_bindings = top_down_eval(delta_rule.body, 0, binding_list)
print "new bindings after top-down: " + ",".join([str(x) for x in new_bindings])
# enqueue effects of Event
head_table = delta_rule.head.table
for new_binding in new_bindings:
queue.enqueue(Event(table=head_table,
tuple=plug(delta_rule.head, new_binding),
insert=event.insert))
indexed_delta_rules[delta.trigger.table].append(delta)
return indexed_delta_rules
def plug(atom, binding):
""" Returns a tuple representing the arguments to ATOM after having
@@ -353,42 +431,6 @@ def eliminate_dups_with_ref_counts(tuples):
refcounts[tuple] = 0
return refcounts
def top_down_eval(atoms, atom_index, binding):
""" Compute all instances of ATOMS (from ATOM_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to ATOMs). Returns a list of dictionary bindings. """
atom = atoms[atom_index]
if tracer.is_traced(atom.table):
print ("Top-down eval(atoms={}, atom_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in atoms) + "]",
atom_index,
str(binding))
data_bindings = database.get_matches(atom, binding)
print "data_bindings: " + str(data_bindings)
if len(data_bindings) == 0:
return []
results = []
for data_binding in data_bindings:
# add this binding to var_bindings
binding.update(data_binding)
if atom_index == len(atoms) - 1: # last element in atoms
# construct result
# output_binding = {}
# for var in projection:
# output_binding[var] = binding[var]
# results.append(output_binding)
results.append(dict(binding)) # need to copy
else:
# recurse
results.extend(top_down_eval(atoms, atom_index + 1, binding))
# remove this binding from var_bindings
for var in data_binding:
del binding[var]
if tracer.is_traced(atom.table):
print "Return value: {}".format([str(x) for x in results])
return results
# def atom_arg_names(atom):
# if atom.table not in database.schemas:
@@ -425,11 +467,5 @@ def all_variables(atoms, atom_index):
# print "new_bindings: {}, unbound_names: {}".format(new_bindings, unbound_names)
# return (new_bindings, unbound_names)
def print_delta_rules():
print "runtime's delta rules"
for table in delta_rules:
print "{}:".format(table)
for rule in delta_rules[table]:
print " {}".format(rule)