Added insert/delete-rule to top-level interface

Passing all tests.

Also stubbed in high-level interface for explanations and hypotheticals

Issue: #
Change-Id: Ibe536dfa15236457212eb97351b56ed9e28163cc
This commit is contained in:
Tim Hinrichs 2013-08-20 13:12:30 -07:00
parent ad78238c79
commit fa085d5421
6 changed files with 450 additions and 266 deletions

View File

@ -0,0 +1,25 @@
error :- nova:virtual_machine(vm), nova:network(vm, network),
not neutron:public_network(network),
neutron:owner(network, netowner), nova:owner(vm, vmowner), not same_group(netowner, vmowner)
same_group(user1, user2) :- cms:group(user1, group), cms:group(user2, group)
nova:virtual_machine("vm1")
nova:virtual_machine("vm2")
nova:network("vm1", "net_private")
nova:network("vm2", "net_public")
neutron:public_network("net_public")
nova:owner("vm1", "tim")
nova:owner("vm2", "pete")
neutron:owner("net_private", "martin")
cms:group("pete", "congress")
cms:group("tim", "congress")
cms:group("martin", "congress")
cms:group("pierre", "congress")

View File

@ -14,6 +14,7 @@ tokens {
RPAREN=')';
// Structure
THEORY;
STRUCTURED_NAME;
// Kinds of Formulas
RULE;
@ -45,7 +46,7 @@ formula_terminator
bare_formula
: rule
| atom
| formula
;
rule
@ -92,7 +93,7 @@ variable
;
relation_constant
: ID
: ID (':' ID)* -> ^(STRUCTURED_NAME ID+)
;
propositional_constant

View File

@ -21,6 +21,10 @@ class CongressException (Exception):
s = " at" + s
return Exception.__str__(self) + s
##############################################################################
## Internal representation of policy language
##############################################################################
class Location (object):
""" A location in the program source code. """
def __init__(self, line=None, col=None, obj=None):
@ -191,12 +195,15 @@ class Rule (object):
def is_rule(self):
return True
##############################################################################
## Compiler
##############################################################################
class Compiler (object):
""" Process Congress policy file. """
def __init__(self):
self.raw_syntax_tree = None
self.theory = None
self.theory = []
self.errors = []
self.warnings = []
@ -235,66 +242,78 @@ class Compiler (object):
errors = [str(err) for err in self.errors]
raise CongressException('Compiler found errors:' + '\n'.join(errors))
def eliminate_self_joins(self):
def new_table_name(name, arity, index):
return "___{}_{}_{}".format(name, arity, index)
def n_variables(n):
vars = []
for i in xrange(0, n):
vars.append("x" + str(i))
return vars
# dict from (table name, arity) tuple to
# max num of occurrences of self-joins in any rule
global_self_joins = {}
# dict from (table name, arity) to # of args for
arities = {}
# remove self-joins from rules
for rule in self.theory:
if rule.is_atom():
continue
logging.debug("eliminating self joins from {}".format(rule))
occurrences = {} # for just this rule
for atom in rule.body:
table = atom.table
arity = len(atom.arguments)
tablearity = (table, arity)
if tablearity not in occurrences:
occurrences[tablearity] = 1
else:
# change name of atom
atom.table = new_table_name(table, arity,
occurrences[tablearity])
# update our counters
occurrences[tablearity] += 1
if tablearity not in global_self_joins:
global_self_joins[tablearity] = 1
else:
global_self_joins[tablearity] = \
max(occurrences[tablearity] - 1,
global_self_joins[tablearity])
logging.debug("final rule: {}".format(str(rule)))
# add definitions for new tables
for tablearity in global_self_joins:
table = tablearity[0]
arity = tablearity[1]
for i in xrange(1, global_self_joins[tablearity] + 1):
newtable = new_table_name(table, arity, i)
args = [Variable(var) for var in n_variables(arity)]
head = Atom(newtable, args)
body = [Atom(table, args)]
self.theory.append(Rule(head, body))
logging.debug("Adding rule {}".format(str(self.theory[-1])))
def compute_delta_rules(self):
""" Assumes no self-joins. """
self.delta_rules = []
for rule in self.theory:
if rule.is_atom():
continue
for literal in rule.body:
newbody = [lit for lit in rule.body if lit is not literal]
self.delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody, rule))
# logging.debug("self.theory: {}".format([str(x) for x in self.theory]))
self.delta_rules = compute_delta_rules(self.theory)
def eliminate_self_joins(theory):
""" Modify THEORY so that all self-joins have been eliminated. """
def new_table_name(name, arity, index):
return "___{}_{}_{}".format(name, arity, index)
def n_variables(n):
vars = []
for i in xrange(0, n):
vars.append("x" + str(i))
return vars
# dict from (table name, arity) tuple to
# max num of occurrences of self-joins in any rule
global_self_joins = {}
# dict from (table name, arity) to # of args for
arities = {}
# remove self-joins from rules
for rule in theory:
if rule.is_atom():
continue
logging.debug("eliminating self joins from {}".format(rule))
occurrences = {} # for just this rule
for atom in rule.body:
table = atom.table
arity = len(atom.arguments)
tablearity = (table, arity)
if tablearity not in occurrences:
occurrences[tablearity] = 1
else:
# change name of atom
atom.table = new_table_name(table, arity,
occurrences[tablearity])
# update our counters
occurrences[tablearity] += 1
if tablearity not in global_self_joins:
global_self_joins[tablearity] = 1
else:
global_self_joins[tablearity] = \
max(occurrences[tablearity] - 1,
global_self_joins[tablearity])
logging.debug("final rule: {}".format(str(rule)))
# add definitions for new tables
for tablearity in global_self_joins:
table = tablearity[0]
arity = tablearity[1]
for i in xrange(1, global_self_joins[tablearity] + 1):
newtable = new_table_name(table, arity, i)
args = [Variable(var) for var in n_variables(arity)]
head = Atom(newtable, args)
body = [Atom(table, args)]
theory.append(Rule(head, body))
logging.debug("Adding rule {}".format(str(theory[-1])))
return theory
def compute_delta_rules(theory):
eliminate_self_joins(theory)
delta_rules = []
for rule in theory:
if rule.is_atom():
continue
for literal in rule.body:
newbody = [lit for lit in rule.body if lit is not literal]
delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody, rule))
return delta_rules
##############################################################################
## External syntax: datalog
##############################################################################
class CongressSyntax (object):
""" External syntax and converting it into internal representation. """
@ -393,13 +412,19 @@ class CongressSyntax (object):
@classmethod
def create_atom_aux(cls, antlr):
# (ATOM (TABLE ARG1 ... ARGN))
# (ATOM (TABLENAME ARG1 ... ARGN))
table = cls.create_structured_name(antlr.children[0])
args = []
for i in xrange(1, len(antlr.children)):
args.append(cls.create_term(antlr.children[i]))
loc = Location(line=antlr.children[0].token.line,
col=antlr.children[0].token.charPositionInLine)
return (antlr.children[0].getText(), args, loc)
return (table, args, loc)
@classmethod
def create_structured_name(cls, antlr):
# (STRUCTURED_NAME (ARG1 ... ARGN))
return ":".join([x.getText() for x in antlr.children])
@classmethod
def create_term(cls, antlr):
@ -437,7 +462,12 @@ def print_tree(tree, text, kids, ind=0):
for child in children:
print_tree(child, text, kids, ind + 1)
##############################################################################
## Mains
##############################################################################
def get_compiled(args):
""" Run compiler as per ARGS and return the resulting Compiler instance. """
parser = optparse.OptionParser()
parser.add_option("--input_string", dest="input_string", default=False,
action="store_true",
@ -449,11 +479,12 @@ def get_compiled(args):
compiler = Compiler()
for i in inputs:
compiler.read_source(i, input_string=options.input_string)
compiler.eliminate_self_joins()
compiler.compute_delta_rules()
return compiler
def get_runtime(args):
""" Create runtime by running compiler as per ARGS and initializing runtime
with result of compilation. """
comp = get_compiled(args)
run = runtime.Runtime(comp.delta_rules)
tracer = runtime.Tracer()

View File

@ -2,6 +2,7 @@
import collections
import logging
import compile
class Tracer(object):
def __init__(self):
@ -10,10 +11,18 @@ class Tracer(object):
self.expressions.append(table)
def is_traced(self, table):
return table in self.expressions or '*' in self.expressions
def log(self, table, msg, depth=0):
if self.is_traced(table):
logging.debug("{}{}".format(("| " * depth), msg))
class CongressRuntime (Exception):
pass
##############################################################################
## Delta Rules
##############################################################################
class DeltaRule(object):
def __init__(self, trigger, head, body, original):
self.trigger = trigger # atom
@ -32,6 +41,48 @@ class DeltaRule(object):
all(self.body[i] == other.body[i]
for i in xrange(0, len(self.body))))
class DeltaRuleTheory (object):
""" A collection of DeltaRules. """
def __init__(self, rules=None):
self.contents = {}
if rules is not None:
for rule in rules:
self.insert(rule)
def modify(self, delta, is_insert):
if is_insert is True:
return self.insert(delta)
else:
return self.delete(delta)
def insert(self, delta):
if delta.trigger.table not in self.contents:
self.contents[delta.trigger.table] = [delta]
else:
self.contents[delta.trigger.table].append(delta)
def delete(self, delta):
if delta.trigger.table not in self.contents:
return
self.contents[delta.trigger.table].remove(delta)
def __str__(self):
return str(self.contents)
# for table in self.contents:
# print "{}:".format(table)
# for rule in self.delta_rules[table]:
# print " {}".format(rule)
def rules_with_trigger(self, table):
if table not in self.contents:
return []
else:
return self.contents[table]
##############################################################################
## Events
##############################################################################
class EventQueue(object):
def __init__(self):
self.queue = collections.deque()
@ -54,7 +105,7 @@ class Event(object):
self.table = table
self.tuple = Database.DBTuple(tuple, proofs=proofs)
self.insert = insert
logging.debug("EVENT: created event {}".format(str(self)))
logging.debug("EV: created event {}".format(str(self)))
def is_insert(self):
return self.insert
@ -66,6 +117,10 @@ class Event(object):
sign = '-'
return "{}{}({})".format(self.table, sign, str(self.tuple))
##############################################################################
## Database
##############################################################################
class Database(object):
class Proof(object):
def __init__(self, binding, rule):
@ -211,15 +266,60 @@ class Database(object):
# KEY must be a tablename
return self.data[key]
def __iter__(self):
self.__
def table_names(self):
return self.data.keys()
def log(self, table, msg):
if self.tracer.is_traced(table):
logging.debug("DB: " + msg)
def log(self, table, msg, depth=0):
self.tracer.log(table, "DB: " + msg, depth)
def select(self, atom):
bindings = self.top_down_eval([atom], 0, {})
result = []
for binding in bindings:
new_atom = [atom.table]
new_atom.extend(plug(atom, binding))
if new_atom not in result:
result.append(new_atom)
return result
def top_down_eval(self, literals, literal_index, binding):
""" Compute all instances of LITERALS (from LITERAL_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to LITERALS). Returns a list of dictionary bindings. """
if literal_index > len(literals) - 1:
return [binding]
lit = literals[literal_index]
self.log(lit.table, ("Top_down_eval(literals={}, literal_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in literals) + "]",
literal_index,
str(binding)),
depth=literal_index)
# assume that for negative literals, all vars are bound at this point
# if there is a match, data_bindings will contain at least one binding
# (possibly including the empty binding)
data_bindings = self.get_matches(lit, binding)
self.log(lit.table, "data_bindings: " + str(data_bindings), depth=literal_index)
# if not negated, empty data_bindings means failure
if len(data_bindings) == 0 :
return []
results = []
for data_binding in data_bindings:
# add new binding to current binding
binding.update(data_binding)
if literal_index == len(literals) - 1: # last element
results.append(dict(binding)) # need to copy
else:
results.extend(self.top_down_eval(literals, literal_index + 1,
binding))
# remove new binding from current bindings
for var in data_binding:
del binding[var]
self.log(lit.table, "Top_down_eval return value: {}".format(
'[' + ", ".join([str(x) for x in results]) + ']'), depth=literal_index)
return results
def get_matches(self, literal, binding):
""" Returns a list of binding lists for the variables in LITERAL
@ -290,14 +390,18 @@ class Database(object):
del self.data[table][i]
return
##############################################################################
## Runtime classes
##############################################################################
class Runtime (object):
""" Runtime for the Congress policy language. Only have one instantiation
in practice, but using a class is natural and useful for testing. """
def __init__(self, rules):
# rules dictating how an insert/delete to one table
# effects other tables
self.delta_rules = index_delta_rules(rules)
# affects other tables
self.delta_rules = DeltaRuleTheory(rules)
# queue of events left to process
self.queue = EventQueue()
# collection of all tables
@ -306,33 +410,81 @@ class Runtime (object):
self.tracer = Tracer()
def log(self, table, msg, depth=0):
if self.tracer.is_traced(table):
logging.debug("{}{}".format(("| " * depth), msg))
self.tracer.log(table, "RT: " + msg, depth)
def insert(self, table, tuple):
""" Event handler for an insertion.
TABLE is the name of a table (a string).
TUPLE is a Python tuple. """
if not isinstance(tuple, Database.DBTuple):
tuple = Database.DBTuple(tuple)
self.log(table, "Inserting into queue: {} with {}".format(
table, str(tuple)))
self.queue.enqueue(Event(table, tuple, insert=True))
self.process_queue() # should be running in separate daemon
############### External interface ###############
def select(self, query):
""" Event handler for arbitrary queries. Returns the set of
all instantiated QUERY that are true. """
# should generalize to at least a (conjunction of atoms)
# Need to change compiler a bit, but runtime should be fine.
assert isinstance(query, compile.Atom), "Only have support for atomic queries"
return self.database.select(query)
def delete(self, table, tuple):
""" Event handler for a deletion. TUPLE is a Python tuple.
def select_if(self, query, temporary_data):
""" Event handler for hypothetical queries. Returns the set of
all instantiated QUERYs that would be true IF
TEMPORARY_DATA were true. """
assert False, "Not yet implemented"
def explain(self, query):
""" Event handler for explanations. Given a ground query, return
all explanations for it. """
assert False, "Not yet implemented"
def insert(self, formula):
""" Event handler for arbitrary insertion (rules and facts). """
return self.modify(formula, is_insert=True)
def delete(self, formula):
""" Event handler for arbitrary deletion (rules and facts). """
return self.modify(formula, is_insert=False)
############### Interface implementation ###############
def modify(self, formula, is_insert=True):
""" Event handler for arbitrary insertion/deletion (rules and facts). """
if formula.is_atom():
args = tuple([arg.name for arg in formula.arguments])
self.modify_tuple(formula.table, args, is_insert=is_insert)
else:
self.modify_rule(formula, is_insert=is_insert)
for delta_rule in compile.compute_delta_rules([formula]):
self.delta_rules.modify(delta_rule, is_insert=is_insert)
def modify_rule(self, rule, is_insert):
""" Add rule (not a DeltaRule) to collection and update
tables as appropriate. """
# don't have separate queue since inserting/deleting a rule doesn't generate any
# new rule insertion/deletion events
bindings = self.database.top_down_eval(rule.body, 0, {})
self.log(None, "new bindings after top-down: {}".format(
",".join([str(x) for x in bindings])))
self.process_new_bindings(bindings, rule.head, is_insert, rule)
self.process_queue()
def modify_tuple(self, table, row, is_insert):
""" Event handler for a tuple insertion/deletion.
TABLE is the name of a table (a string).
TUPLE is a Python tuple. """
if not isinstance(tuple, Database.DBTuple):
tuple = Database.DBTuple(tuple)
self.log(table, "Deleting from queue: {} with {}".format(
table, str(tuple)))
self.queue.enqueue(Event(table, tuple, insert=False))
self.process_queue() # should be running in separate daemon
TUPLE is a Python tuple.
IS_INSERT is True or False."""
if is_insert:
text = "Inserting into queue"
else:
text = "Deleting from queue"
self.log(table, "{}: table {} with tuple {}".format(
text, table, str(row)))
if not isinstance(row, Database.DBTuple):
row = Database.DBTuple(row)
self.log(table, "{}: table {} with tuple {}".format(
text, table, str(row)))
self.queue.enqueue(Event(table, row, insert=is_insert))
self.process_queue()
############### Data manipulation ###############
def process_queue(self):
""" Toplevel evaluation routine. """
""" Toplevel data evaluation routine. """
while len(self.queue) > 0:
event = self.queue.dequeue()
if event.is_insert():
@ -346,10 +498,10 @@ class Runtime (object):
""" Computes events generated by EVENT and the DELTA_RULES,
and enqueues them. """
self.log(event.table, "Processing event: {}".format(str(event)))
if event.table not in self.delta_rules.keys():
applicable_rules = self.delta_rules.rules_with_trigger(event.table)
if len(applicable_rules) == 0:
self.log(event.table, "No applicable delta rule")
return
for delta_rule in self.delta_rules[event.table]:
for delta_rule in applicable_rules:
self.propagate_rule(event, delta_rule)
def propagate_rule(self, event, delta_rule):
@ -366,95 +518,90 @@ class Runtime (object):
self.log(event.table,
"binding_list for event-tuple and delta_rule trigger: {}".format(
str(binding_list)))
new_bindings = self.top_down_eval(delta_rule.body, 0, binding_list)
new_bindings = self.database.top_down_eval(delta_rule.body, 0, binding_list)
self.log(event.table, "new bindings after top-down: {}".format(
",".join([str(x) for x in new_bindings])))
# for each binding, compute generated tuple and group bindings
# by the tuple they generated
new_tuples = {}
for new_binding in new_bindings:
new_tuple = tuple(plug(delta_rule.head, new_binding))
if new_tuple not in new_tuples:
new_tuples[new_tuple] = []
new_tuples[new_tuple].append(Database.Proof(
new_binding, delta_rule.original))
self.log(event.table, "new tuples generated: {}".format(
", ".join([str(x) for x in new_tuples])))
# enqueue each distinct generated tuple, recording appropriate bindings
head_table = delta_rule.head.table
if delta_rule.trigger.is_negated():
insert_delete = not event.insert
else:
insert_delete = event.insert
self.process_new_bindings(new_bindings, delta_rule.head, insert_delete,
delta_rule.original)
def process_new_bindings(self, bindings, atom, insert, original_rule):
""" For each of BINDINGS, apply to ATOM, and enqueue it as an insert if
INSERT is True and as a delete otherwise. """
# for each binding, compute generated tuple and group bindings
# by the tuple they generated
new_tuples = {}
for binding in bindings:
new_tuple = tuple(plug(atom, binding))
if new_tuple not in new_tuples:
new_tuples[new_tuple] = []
new_tuples[new_tuple].append(Database.Proof(
binding, original_rule))
self.log(atom.table, "new tuples generated: {}".format(
", ".join([str(x) for x in new_tuples])))
# enqueue each distinct generated tuple, recording appropriate bindings
for new_tuple in new_tuples:
# self.log(event.table,
# "new_tuple {}: {}".format(str(new_tuple), str(new_tuples[new_tuple])))
self.queue.enqueue(Event(table=head_table,
self.queue.enqueue(Event(table=atom.table,
tuple=new_tuple,
proofs=new_tuples[new_tuple],
insert=insert_delete))
insert=insert))
def top_down_eval(self, literals, literal_index, binding):
""" Compute all instances of LITERALS (from LITERAL_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to LITERALS). Returns a list of dictionary bindings. """
if literal_index > len(literals) - 1:
return [binding]
lit = literals[literal_index]
self.log(lit.table, ("Top_down_eval(literals={}, literal_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in literals) + "]",
literal_index,
str(binding)),
depth=literal_index)
# assume that for negative literals, all vars are bound at this point
# if there is a match, data_bindings will contain at least one binding
# (possibly including the empty binding)
data_bindings = self.database.get_matches(lit, binding)
self.log(lit.table, "data_bindings: " + str(data_bindings), depth=literal_index)
# if not negated, empty data_bindings means failure
if len(data_bindings) == 0 :
return []
class StringRuntime(Runtime):
""" Version of Runtime that communicates via strings. """
def select(self, policy_string):
""" Event handler for arbitrary queries. Returns the set of
all instantiated POLICY_STRING that are true. """
def str_tuple_atom (atom):
s = atom[0]
s += '('
s += ', '.join([str(x) for x in atom[1:]])
s += ')'
return s
c = compile.get_compiled([policy_string, '--input_string'])
assert len(c.theory) == 1, "Queries can have only 1 statement"
assert c.theory[0].is_atom(), "Queries must be atomic"
results = super(StringRuntime, self).select(c.theory[0])
return " ".join([str_tuple_atom(x) for x in results])
results = []
for data_binding in data_bindings:
# add new binding to current binding
binding.update(data_binding)
if literal_index == len(literals) - 1: # last element
results.append(dict(binding)) # need to copy
else:
results.extend(self.top_down_eval(literals, literal_index + 1,
binding))
# remove new binding from current bindings
for var in data_binding:
del binding[var]
self.log(lit.table, "Top_down_eval return value: {}".format(
'[' + ", ".join([str(x) for x in results]) + ']'), depth=literal_index)
def select_if(self, query_string, temporary_data):
""" Event handler for hypothetical queries. Returns the set of
all instantiated QUERYs that would be true IF
TEMPORARY_DATA were true. """
assert False, "Not yet implemented"
return results
def explain(self, query_string):
""" Event handler for explanations. Given a ground query, return
all explanations for it. """
assert False, "Not yet implemented"
def print_delta_rules(self):
for table in self.delta_rules:
print "{}:".format(table)
for rule in self.delta_rules[table]:
print " {}".format(rule)
def insert(self, policy_string):
""" Event handler for arbitrary insertion (rules and/or facts). """
c = compile.get_compiled([policy_string, '--input_string'])
for formula in c.theory:
logging.debug("Parsed {}".format(str(formula)))
super(StringRuntime, self).insert(formula)
def delete(self, policy_string):
""" Event handler for arbitrary deletion (rules and/or facts). """
c = compile.get_compiled([policy_string, '--input_string'])
for formula in c.theory:
super(StringRuntime, self).delete(formula)
def index_delta_rules(delta_rules):
indexed_delta_rules = {}
for delta in delta_rules:
if delta.trigger.table not in indexed_delta_rules:
indexed_delta_rules[delta.trigger.table] = [delta]
else:
indexed_delta_rules[delta.trigger.table].append(delta)
return indexed_delta_rules
def plug(atom, binding):
def plug(atom, binding, withtable=False):
""" Returns a tuple representing the arguments to ATOM after having
applied BINDING to the variables in ATOM. """
result = []
if withtable is True:
result = [atom.table]
else:
result = []
for i in xrange(0, len(atom.arguments)):
if atom.arguments[i].is_variable() and atom.arguments[i].name in binding:
result.append(binding[atom.arguments[i].name])
@ -479,50 +626,4 @@ def match(tuple, atom):
binding[arg.name] = tuple[i]
return binding
def eliminate_dups_with_ref_counts(tuples):
refcounts = {}
for tuple in tuples:
if tuple in refcounts:
refcounts[tuple] += 1
else:
refcounts[tuple] = 0
return refcounts
# def atom_arg_names(atom):
# if atom.table not in database.schemas:
# raise CongressRuntime("Table {} has no schema".format(atom.table))
# schema = database.schemas[atom.table]
# if len(atom.arguments) != len(schema.arguments):
# raise CongressRuntime("Atom {} has wrong number of arguments for "
# " schema: {}".format(atom, str(schema)))
# mapping = {}
# for i in xrange(0, len(atom.arguments)):
# mapping[schema.arguments[i]] = atom.arguments[i]
# return mapping
def all_variables(atoms, atom_index):
vars = set()
for i in xrange(atom_index, len(atoms)):
vars |= atoms[i].variable_names()
return vars
# def var_bindings_to_named_bindings(atom, var_bindings):
# new_bindings = {}
# unbound_names = set()
# schema = database.schemas[atom.table]
# print "schema: " + str(schema.arguments)
# assert(len(schema.arguments) == len(atom.arguments))
# for i in xrange(0, len(atom.arguments)):
# term = atom.arguments[i]
# if term.is_object():
# new_bindings[schema.arguments[i]] = term.name
# elif term.name in var_bindings:
# new_bindings[schema.arguments[i]] = var_bindings[term.name]
# else:
# unbound_names.add(schema.arguments[i])
# print "new_bindings: {}, unbound_names: {}".format(new_bindings, unbound_names)
# return (new_bindings, unbound_names)

View File

@ -17,7 +17,7 @@ class TestRuntime(unittest.TestCase):
if msg is not None:
logging.debug(msg)
c = compile.get_compiled([code, '--input_string'])
run = runtime.Runtime(c.delta_rules)
run = runtime.StringRuntime(c.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
@ -25,24 +25,23 @@ class TestRuntime(unittest.TestCase):
return run
def insert(self, run, list):
run.insert(list[0], tuple(list[1:]))
run.modify_tuple(list[0], tuple(list[1:]), is_insert=True)
def delete(self, run, list):
run.delete(list[0], tuple(list[1:]))
run.modify_tuple(list[0], tuple(list[1:]), is_insert=False)
def check(self, run, correct_database_code, msg=None):
# extract correct answer from correct_database_code
logging.debug("** Checking {} **".format(msg))
c = compile.get_compiled([correct_database_code, '--input_string'])
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
[x.name for x in atom.arguments])
def string_to_database(self, string):
c = compile.get_compiled([string, '--input_string'])
database = runtime.Database()
for atom in c.theory:
if atom.is_atom():
database.insert(atom.table,
[x.name for x in atom.arguments])
return database
# compute diffs; should be empty
extra = run.database - correct_database
missing = correct_database - run.database
def check_db_diffs(self, actual, correct, msg):
extra = actual - correct
missing = correct - actual
extra = [e for e in extra if not e[0].startswith("___")]
missing = [m for m in missing if not m[0].startswith("___")]
errmsg = ""
@ -53,9 +52,21 @@ class TestRuntime(unittest.TestCase):
logging.debug("Missing tuples")
logging.debug(", ".join([str(x) for x in missing]))
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
logging.debug(str(run.database))
def check(self, run, correct_database_code, msg=None):
# extract correct answer from correct_database_code
logging.debug("** Checking {} **".format(msg))
correct_database = self.string_to_database(correct_database_code)
self.check_db_diffs(run.database, correct_database, msg)
logging.debug("** Finished {} **".format(msg))
def check_equal(self, actual_database_code, correct_database_code, msg=None):
logging.debug("** Checking equality for {} **".format(msg))
actual = self.string_to_database(actual_database_code)
correct = self.string_to_database(correct_database_code)
self.check_db_diffs(actual, correct, msg)
logging.debug("** Finished for {} **".format(msg))
def check_proofs(self, run, correct, msg=None):
""" Check that the proofs stored in runtime RUN are exactly
those in CORRECT. """
@ -291,20 +302,20 @@ class TestRuntime(unittest.TestCase):
self.delete(run, ['r', 1.2])
self.check(run, 'p(1.2)', "String delete with 1 propagation")
def test_proofs(self):
""" Test if the proof computation is performed correctly. """
def check_table_proofs(run, table, tuple_proof_dict, msg):
for tuple in tuple_proof_dict:
tuple_proof_dict[tuple] = \
Database.ProofCollection(tuple_proof_dict[tuple])
self.check_proofs(run, {table : tuple_proof_dict}, msg)
# def test_proofs(self):
# """ Test if the proof computation is performed correctly. """
# def check_table_proofs(run, table, tuple_proof_dict, msg):
# for tuple in tuple_proof_dict:
# tuple_proof_dict[tuple] = \
# Database.ProofCollection(tuple_proof_dict[tuple])
# self.check_proofs(run, {table : tuple_proof_dict}, msg)
code = ("q(x) :- p(x,y)")
run = self.prep_runtime(code, "**** Proof tests ****")
# code = ("q(x) :- p(x,y)")
# run = self.prep_runtime(code, "**** Proof tests ****")
self.insert(run, ['p', 1, 2])
check_table_proofs(run, 'q', {(1,): [{u'x': 1, u'y': 2}]},
'Simplest proof test')
# self.insert(run, ['p', 1, 2])
# check_table_proofs(run, 'q', {(1,): [{u'x': 1, u'y': 2}]},
# 'Simplest proof test')
def test_negation(self):
""" Test negation """
@ -378,5 +389,44 @@ class TestRuntime(unittest.TestCase):
self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) s(1,3) s(1,4)',
'Insert into non-unary with different propagation')
def test_select(self):
""" Test the SELECT event handler. """
code = ("p(x, y) :- q(x), r(y)")
run = self.prep_runtime(code, "Select")
self.insert(run, ['q', 1])
self.insert(run, ['q', 2])
self.insert(run, ['r', 1])
self.insert(run, ['r', 2])
self.check(run, 'q(1) q(2) r(1) r(2) p(1,1) p(1,2) p(2,1) p(2,2)',
'Prepare for select')
logging.debug(run.select('p(x,y)'))
self.check_equal(run.select('p(x,y)'), 'p(1,1) p(1,2) p(2,1) p(2,2)',
'Select: bound no args')
self.check_equal(run.select('p(1,y)'), 'p(1,1) p(1,2)',
'Select: bound 1st arg')
self.check_equal(run.select('p(x,2)'), 'p(1,2) p(2,2)',
'Select: bound 2nd arg')
self.check_equal(run.select('p(1,2)'), 'p(1,2)',
'Select: bound 1st and 2nd arg')
def test_modify_rules(self):
""" Test the functionality for adding and deleting rules *after* data
has already been entered. """
run = self.prep_runtime("", "Rule modification")
run.insert("q(1) r(1) q(2) r(2)")
self.showdb(run)
self.check(run, 'q(1) r(1) q(2) r(2)', "Installation")
run.insert("p(x) :- q(x), r(x)")
self.check(run, 'q(1) r(1) q(2) r(2) p(1) p(2)', 'Rule insert after data insert')
run.delete("q(1)")
self.check(run, 'r(1) q(2) r(2) p(2)', 'Delete after Rule insert with propagation')
run.insert("q(1)")
run.delete("p(x) :- q(x), r(x)")
self.check(run, 'q(1) r(1) q(2) r(2)', "Delete rule")
def test_explanations(self):
""" Test the explanation event handler. """
pass
if __name__ == '__main__':
unittest.main()

View File

@ -1,24 +0,0 @@
error(vm) :- virtual_machine(vm), network(vm, net), not public(net), owner(vm, owner), not owned_by_some_group(owner, net)
owned_by_some_group(owner, object) :- owner(object, user), member(user, group), member(owner, group)
virtual_machine('vm1')
virtual_machine('vm2')
network('vm1', 'net_private')
network('vm2', 'net_public')
public('net_public')
owner('vm1', 'tim')
owner('vm2', 'pete')
owner('net_private', 'martin')
member('pete', 'congress')
member('tim', 'congress')
member('martin', 'congress')
member('pierre', 'congress')