Added negation, proofs to policy engine runtime.

Passes all tests, except I haven't yet figured out how to write reasonable tests for the
proofs.

Issue: #
Change-Id: I713c10481afc59f9872fc98866145a2b594efd64
This commit is contained in:
Tim Hinrichs 2013-08-19 13:36:19 -07:00
parent 0b613444df
commit ad78238c79
4 changed files with 530 additions and 318 deletions

View File

@ -8,6 +8,7 @@ import CongressLexer
import CongressParser
import antlr3
import runtime
import logging
class CongressException (Exception):
def __init__(self, msg, obj=None, line=None, col=None):
@ -68,6 +69,9 @@ class Variable (Term):
def __str__(self):
return str(self.name)
def __eq__(self, other):
return isinstance(other, Variable) and self.name == other.name
def is_variable(self):
return True
@ -89,6 +93,11 @@ class ObjectConstant (Term):
def __str__(self):
return str(self.name)
def __eq__(self, other):
return (isinstance(other, ObjectConstant) and
self.name == other.name and
self.type == other.type)
def is_variable(self):
return False
@ -115,6 +124,13 @@ class Atom (object):
return "{}({})".format(self.table,
", ".join([str(x) for x in self.arguments]))
def __eq__(self, other):
return (isinstance(other, Atom) and
self.table == other.table and
len(self.arguments) == len(other.arguments) and
all(self.arguments[i] == other.arguments[i]
for i in xrange(0, len(self.arguments))))
def is_atom(self):
return True
@ -139,6 +155,9 @@ class Literal(Atom):
else:
return Atom.__str__(self)
def __eq__(self, other):
return (self.negated == other.negated and Atom.__eq__(self, other))
def is_negated(self):
return self.negated
@ -160,6 +179,12 @@ class Rule (object):
str(self.head),
", ".join([str(atom) for atom in self.body]))
def __eq__(self, other):
return (self.head == other.head and
len(self.body) == len(other.body) and
all(self.body[i] == other.body[i]
for i in xrange(0, len(self.body))))
def is_atom(self):
return False
@ -184,10 +209,9 @@ class Compiler (object):
s += 'None'
return s
def read_source(self, input_file=None, input_string=None):
assert(input_file is not None or input_string is not None)
def read_source(self, input, input_string=False):
# parse input file and convert to internal representation
self.raw_syntax_tree = CongressSyntax.parse_file(input_file=input_file,
self.raw_syntax_tree = CongressSyntax.parse_file(input,
input_string=input_string)
#self.print_parse_result()
self.theory = CongressSyntax.create(self.raw_syntax_tree)
@ -211,13 +235,66 @@ class Compiler (object):
errors = [str(err) for err in self.errors]
raise CongressException('Compiler found errors:' + '\n'.join(errors))
def eliminate_self_joins(self):
def new_table_name(name, arity, index):
return "___{}_{}_{}".format(name, arity, index)
def n_variables(n):
vars = []
for i in xrange(0, n):
vars.append("x" + str(i))
return vars
# dict from (table name, arity) tuple to
# max num of occurrences of self-joins in any rule
global_self_joins = {}
# dict from (table name, arity) to # of args for
arities = {}
# remove self-joins from rules
for rule in self.theory:
if rule.is_atom():
continue
logging.debug("eliminating self joins from {}".format(rule))
occurrences = {} # for just this rule
for atom in rule.body:
table = atom.table
arity = len(atom.arguments)
tablearity = (table, arity)
if tablearity not in occurrences:
occurrences[tablearity] = 1
else:
# change name of atom
atom.table = new_table_name(table, arity,
occurrences[tablearity])
# update our counters
occurrences[tablearity] += 1
if tablearity not in global_self_joins:
global_self_joins[tablearity] = 1
else:
global_self_joins[tablearity] = \
max(occurrences[tablearity] - 1,
global_self_joins[tablearity])
logging.debug("final rule: {}".format(str(rule)))
# add definitions for new tables
for tablearity in global_self_joins:
table = tablearity[0]
arity = tablearity[1]
for i in xrange(1, global_self_joins[tablearity] + 1):
newtable = new_table_name(table, arity, i)
args = [Variable(var) for var in n_variables(arity)]
head = Atom(newtable, args)
body = [Atom(table, args)]
self.theory.append(Rule(head, body))
logging.debug("Adding rule {}".format(str(self.theory[-1])))
def compute_delta_rules(self):
""" Assumes no self-joins. """
self.delta_rules = []
for rule in self.theory:
if rule.is_atom():
continue
for literal in rule.body:
newbody = [lit for lit in rule.body if lit is not literal]
self.delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody))
runtime.DeltaRule(literal, rule.head, newbody, rule))
class CongressSyntax (object):
""" External syntax and converting it into internal representation. """
@ -251,12 +328,11 @@ class CongressSyntax (object):
e.line, e.charPositionInLine)
@classmethod
def parse_file(cls, input_file=None, input_string=None):
assert(input_file is not None or input_string is not None)
if input_file is not None:
char_stream = antlr3.ANTLRFileStream(input_file)
def parse_file(cls, input, input_string=False):
if not input_string:
char_stream = antlr3.ANTLRFileStream(input)
else:
char_stream = antlr3.ANTLRStringStream(input_string)
char_stream = antlr3.ANTLRStringStream(input)
lexer = cls.Lexer(char_stream)
tokens = antlr3.CommonTokenStream(lexer)
parser = cls.Parser(tokens)
@ -361,16 +437,32 @@ def print_tree(tree, text, kids, ind=0):
for child in children:
print_tree(child, text, kids, ind + 1)
def main():
def get_compiled(args):
parser = optparse.OptionParser()
(options, inputs) = parser.parse_args(sys.argv[1:])
parser.add_option("--input_string", dest="input_string", default=False,
action="store_true",
help="Indicates that inputs should be treated not as file names but "
"as the contents to compile")
(options, inputs) = parser.parse_args(args)
if len(inputs) != 1:
parser.error("Usage: %prog [options] policy-file")
compiler = Compiler()
compiler.read_source(inputs[0])
for i in inputs:
compiler.read_source(i, input_string=options.input_string)
compiler.eliminate_self_joins()
compiler.compute_delta_rules()
return compiler
if __name__ == '__main__':
sys.exit(main())
def get_runtime(args):
comp = get_compiled(args)
run = runtime.Runtime(comp.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
run.database.tracer = tracer
return run
# if __name__ == '__main__':
# main()

View File

@ -15,15 +15,23 @@ class CongressRuntime (Exception):
pass
class DeltaRule(object):
def __init__(self, trigger, head, body):
def __init__(self, trigger, head, body, original):
self.trigger = trigger # atom
self.head = head # atom
self.body = body # list of atoms with is_negated()
self.original = original # Rule from which derived
def __str__(self):
return "<trigger: {}, head: {}, body: {}>".format(
str(self.trigger), str(self.head), [str(lit) for lit in self.body])
def __eq__(self, other):
return (self.trigger == other.trigger and
self.head == other.head and
len(self.body) == len(other.body) and
all(self.body[i] == other.body[i]
for i in xrange(0, len(self.body))))
class EventQueue(object):
def __init__(self):
self.queue = collections.deque()
@ -46,6 +54,7 @@ class Event(object):
self.table = table
self.tuple = Database.DBTuple(tuple, proofs=proofs)
self.insert = insert
logging.debug("EVENT: created event {}".format(str(self)))
def is_insert(self):
return self.insert
@ -57,105 +66,26 @@ class Event(object):
sign = '-'
return "{}{}({})".format(self.table, sign, str(self.tuple))
# class Database(object):
# class DictTuple(object):
# def __init__(self, binding, refcount=1):
# self.binding = binding
# self.refcount = refcount
# def __eq__(self, other):
# return self.binding == other.binding
# def __str__(self):
# return "<binding: {}, refcount: {}>".format(
# str(self.binding), self.refcount)
# def matches(self, binding):
# print "Checking if tuple {} matches binding {}".format(str(self), str(binding))
# for column_name in self.binding.keys():
# if column_name not in binding:
# return False
# if self.binding[column_name] != binding[column_name]:
# return False
# print "Check succeeded with binding {}".format(str(binding))
# return True
# class Schema (object):
# def __init__(self, column_names):
# self.arguments = column_names
# def __str__(self):
# return str(self.arguments)
# def __init__(self):
# self.data = {'p': [], 'q': [], 'r': [self.DictTuple({1: 1})]}
# # self.data = {'p': [self.DictTuple({1: 'a'}),
# # self.DictTuple({1: 'b'}),
# # self.DictTuple({1, 'c'})],
# # 'q': [self.DictTuple({1: 'b'}),
# # self.DictTuple({1: 'c'}),
# # self.DictTuple({1, 'd'})],
# # 'r': [self.DictTuple({1: 'c'}),
# # self.DictTuple({1: 'd'}),
# # self.DictTuple({1, 'e'})]
# # }
# self.schemas = {'p': Database.Schema([1]),
# 'q': Database.Schema([1]),
# 'r': Database.Schema([1])}
# def __str__(self):
# def hash2str (h):
# s = "{"
# s += ", ".join(["{} : {}".format(str(key), str(h[key]))
# for key in h])
# return s
# def hashlist2str (h):
# strings = []
# for key in h:
# s = "{} : ".format(key)
# s += '['
# s += ', '.join([str(val) for val in h[key]])
# s += ']'
# strings.append(s)
# return '{' + ", ".join(strings) + '}'
# return "<data: {}, \nschemas: {}>".format(
# hashlist2str(self.data), hash2str(self.schemas))
# def get_matches(self, table, binding, columns=None):
# print "Getting matches for table {} with binding {}".format(
# str(table), str(binding))
# if table not in self.data:
# raise CongressRuntime("Table not found ".format(table))
# result = []
# for dicttuple in self.data[table]:
# print "Matching database tuple {}".format(str(dicttuple))
# if dicttuple.matches(binding):
# result.append(dicttuple)
# return result
# def insert(self, table, binding, refcount=1):
# if table not in self.data:
# raise CongressRuntime("Table not found ".format(table))
# for dicttuple in self.data[table]:
# if dicttuple.binding == binding:
# dicttuple.refcount += refcount
# return
# self.data[table].append(self.DictTuple(binding, refcount))
# def delete(self, table, binding, refcount=1):
# if table not in self.data:
# raise CongressRuntime("Table not found ".format(table))
# for dicttuple in self.data[table]:
# if dicttuple.binding == binding:
# dicttuple.refcount -= refcount
# if dicttuple.refcount < 0:
# raise CongressRuntime("Deleted more tuples than existed")
# return
# raise CongressRuntime("Deleted tuple that didn't exist")
class Database(object):
class Proof(object):
def __init__(self, binding, rule):
self.binding = binding
self.rule = rule
def __str__(self):
return "apply({}, {})".format(str(self.binding), str(self.rule))
def __eq__(self, other):
result = (self.binding == other.binding and
self.rule == other.rule)
# logging.debug("Pf: Comparing {} and {}: {}".format(
# str(self), str(other), result))
# logging.debug("Pf: {} == {} is {}".format(
# str(self.binding), str(other.binding), self.binding == other.binding))
# logging.debug("Pf: {} == {} is {}".format(
# str(self.rule), str(other.rule), self.rule == other.rule))
return result
class ProofCollection(object):
def __init__(self, proofs):
self.contents = list(proofs)
@ -166,6 +96,7 @@ class Database(object):
def __isub__(self, other):
if other is None:
return
# logging.debug("PC: Subtracting {} and {}".format(str(self), str(other)))
remaining = []
for proof in self.contents:
if proof not in other.contents:
@ -176,7 +107,9 @@ class Database(object):
def __ior__(self, other):
if other is None:
return
# logging.debug("PC: Unioning {} and {}".format(str(self), str(other)))
for proof in other.contents:
# logging.debug("PC: Considering {}".format(str(proof)))
if proof not in self.contents:
self.contents.append(proof)
return self
@ -272,14 +205,39 @@ class Database(object):
for dbtuple in self.data[table]:
if dbtuple not in other.data[table]:
add_tuple(table, dbtuple)
return results
def __getitem__(self, key):
# KEY must be a tablename
return self.data[key]
def __iter__(self):
self.__
def table_names(self):
return self.data.keys()
def log(self, table, msg):
if self.tracer.is_traced(table):
logging.debug(msg)
logging.debug("DB: " + msg)
def get_matches(self, atom, binding):
def get_matches(self, literal, binding):
""" Returns a list of binding lists for the variables in LITERAL
not bound in BINDING. If LITERAL is negative, returns
either [] meaning the lookup failed or [{}] meaning the lookup
succeeded; otherwise, returns one binding list for each tuple in
the database matching LITERAL under BINDING. """
# slow--should stop at first match, not find all of them
matches = self.get_matches_atom(literal, binding)
if literal.is_negated():
if len(matches) > 0:
return []
else:
return [{}]
else:
return matches
def get_matches_atom(self, atom, binding):
""" Returns a list of binding lists for the variables in ATOM
not bound in BINDING: one binding list for each tuple in
the database matching ATOM under BINDING. """
@ -296,16 +254,21 @@ class Database(object):
def insert(self, table, dbtuple):
if not isinstance(dbtuple, Database.DBTuple):
dbtuple = Database.DBTuple(dbtuple)
self.log(table, "Inserting table {} tuple {} into DB".format(
self.log(table, "Inserting table {} tuple {}".format(
table, str(dbtuple)))
if table not in self.data:
self.data[table] = [dbtuple]
# self.log(table, "First tuple in table {}".format(table))
else:
# self.log(table, "Not first tuple in table {}".format(table))
for existingtuple in self.data[table]:
assert(existingtuple.proofs is not None)
if existingtuple.tuple == dbtuple.tuple:
# self.log(table, "Found existing tuple: {}".format(
# str(existingtuple)))
assert(existingtuple.proofs is not None)
existingtuple.proofs |= dbtuple.proofs
# self.log(table, "Updated tuple: {}".format(str(existingtuple)))
assert(existingtuple.proofs is not None)
return
self.data[table].append(dbtuple)
@ -342,9 +305,9 @@ class Runtime (object):
# tracer object
self.tracer = Tracer()
def log(self, table, msg):
def log(self, table, msg, depth=0):
if self.tracer.is_traced(table):
logging.debug(msg)
logging.debug("{}{}".format(("| " * depth), msg))
def insert(self, table, tuple):
""" Event handler for an insertion.
@ -372,12 +335,9 @@ class Runtime (object):
""" Toplevel evaluation routine. """
while len(self.queue) > 0:
event = self.queue.dequeue()
# Note differing order of insert/delete into database.
# Insert happens before propagation; Delete happens after propagation.
# Necessary for correctness on self-joins.
if event.is_insert():
self.database.insert(event.table, event.tuple)
self.propagate(event)
self.database.insert(event.table, event.tuple)
else:
self.propagate(event)
self.database.delete(event.table, event.tuple)
@ -394,7 +354,6 @@ class Runtime (object):
def propagate_rule(self, event, delta_rule):
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
assert(not delta_rule.trigger.is_negated())
self.log(event.table, "Processing event {} with rule {}".format(
str(event), str(delta_rule)))
@ -418,49 +377,61 @@ class Runtime (object):
new_tuple = tuple(plug(delta_rule.head, new_binding))
if new_tuple not in new_tuples:
new_tuples[new_tuple] = []
new_tuples[new_tuple].append(new_binding)
new_tuples[new_tuple].append(Database.Proof(
new_binding, delta_rule.original))
self.log(event.table, "new tuples generated: {}".format(
str(new_tuples)))
", ".join([str(x) for x in new_tuples])))
# enqueue each distinct generated tuple, recording appropriate bindings
head_table = delta_rule.head.table
if delta_rule.trigger.is_negated():
insert_delete = not event.insert
else:
insert_delete = event.insert
for new_tuple in new_tuples:
# self.log(event.table,
# "new_tuple {}: {}".format(str(new_tuple), str(new_tuples[new_tuple])))
self.queue.enqueue(Event(table=head_table,
tuple=new_tuple,
proofs=new_tuples[new_tuple],
insert=event.insert))
insert=insert_delete))
def top_down_eval(self, atoms, atom_index, binding):
""" Compute all instances of ATOMS (from ATOM_INDEX and above) that
def top_down_eval(self, literals, literal_index, binding):
""" Compute all instances of LITERALS (from LITERAL_INDEX and above) that
are true in the Database (after applying the dictionary binding
BINDING to ATOMs). Returns a list of dictionary bindings. """
if atom_index > len(atoms) - 1:
BINDING to LITERALS). Returns a list of dictionary bindings. """
if literal_index > len(literals) - 1:
return [binding]
atom = atoms[atom_index]
self.log(atom.table, ("Top_down_eval(atoms={}, atom_index={}, "
lit = literals[literal_index]
self.log(lit.table, ("Top_down_eval(literals={}, literal_index={}, "
"bindings={})").format(
"[" + ",".join(str(x) for x in atoms) + "]",
atom_index,
str(binding)))
data_bindings = self.database.get_matches(atom, binding)
self.log(atom.table, "data_bindings: " + str(data_bindings))
if len(data_bindings) == 0:
"[" + ",".join(str(x) for x in literals) + "]",
literal_index,
str(binding)),
depth=literal_index)
# assume that for negative literals, all vars are bound at this point
# if there is a match, data_bindings will contain at least one binding
# (possibly including the empty binding)
data_bindings = self.database.get_matches(lit, binding)
self.log(lit.table, "data_bindings: " + str(data_bindings), depth=literal_index)
# if not negated, empty data_bindings means failure
if len(data_bindings) == 0 :
return []
results = []
for data_binding in data_bindings:
# add new binding to current binding
binding.update(data_binding)
if atom_index == len(atoms) - 1: # last element in atoms
if literal_index == len(literals) - 1: # last element
results.append(dict(binding)) # need to copy
else:
results.extend(self.top_down_eval(atoms, atom_index + 1, binding))
results.extend(self.top_down_eval(literals, literal_index + 1,
binding))
# remove new binding from current bindings
for var in data_binding:
del binding[var]
# self.log(atom.table, "Top_down_eval return value: {}".format(
# '[' + ", ".join([str(x) for x in results]) + ']'))
self.log(lit.table, "Top_down_eval return value: {}".format(
'[' + ", ".join([str(x) for x in results]) + ']'), depth=literal_index)
return results

View File

@ -4,6 +4,7 @@
import unittest
from policy import compile
from policy import runtime
from policy.runtime import Database
import logging
class TestRuntime(unittest.TestCase):
@ -11,244 +12,371 @@ class TestRuntime(unittest.TestCase):
def setUp(self):
pass
def test_runtime(self):
def prep_runtime(code, msg=None):
# compile source
if msg is not None:
logging.debug(msg)
c = compile.Compiler()
c.read_source(input_string=code)
c.compute_delta_rules()
run = runtime.Runtime(c.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
run.database.tracer = tracer
return run
def prep_runtime(self, code, msg=None):
# compile source
if msg is not None:
logging.debug(msg)
c = compile.get_compiled([code, '--input_string'])
run = runtime.Runtime(c.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
run.database.tracer = tracer
return run
def insert(run, list):
run.insert(list[0], tuple(list[1:]))
def insert(self, run, list):
run.insert(list[0], tuple(list[1:]))
def delete(run, list):
run.delete(list[0], tuple(list[1:]))
def delete(self, run, list):
run.delete(list[0], tuple(list[1:]))
def check(run, correct_database_code, msg=None):
# extract correct answer from correct_database_code
logging.debug("** Checking {} **".format(msg))
c = compile.Compiler()
c.read_source(input_string=correct_database_code)
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
[x.name for x in atom.arguments])
def check(self, run, correct_database_code, msg=None):
# extract correct answer from correct_database_code
logging.debug("** Checking {} **".format(msg))
c = compile.get_compiled([correct_database_code, '--input_string'])
correct = c.theory
correct_database = runtime.Database()
for atom in correct:
correct_database.insert(atom.table,
[x.name for x in atom.arguments])
# compute diffs; should be empty
extra = run.database - correct_database
missing = correct_database - run.database
errmsg = ""
if len(extra) > 0:
logging.debug("Extra tuples")
logging.debug(", ".join([str(x) for x in extra]))
if len(missing) > 0:
logging.debug("Missing tuples")
logging.debug(", ".join([str(x) for x in missing]))
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
logging.debug(str(run.database))
logging.debug("** Finished {} **".format(msg))
# compute diffs; should be empty
extra = run.database - correct_database
missing = correct_database - run.database
extra = [e for e in extra if not e[0].startswith("___")]
missing = [m for m in missing if not m[0].startswith("___")]
errmsg = ""
if len(extra) > 0:
logging.debug("Extra tuples")
logging.debug(", ".join([str(x) for x in extra]))
if len(missing) > 0:
logging.debug("Missing tuples")
logging.debug(", ".join([str(x) for x in missing]))
self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
logging.debug(str(run.database))
logging.debug("** Finished {} **".format(msg))
def showdb(run):
logging.debug("Resulting DB: " + str(run.database))
def check_proofs(self, run, correct, msg=None):
""" Check that the proofs stored in runtime RUN are exactly
those in CORRECT. """
# example
# check_proofs(run, {'q': {(1,):
# Database.ProofCollection([{'x': 1, 'y': 2}])}})
errs = []
checked_tables = set()
for table in run.database.table_names():
if table in correct:
checked_tables.add(table)
for dbtuple in run.database[table]:
if dbtuple.tuple in correct[table]:
if dbtuple.proofs != correct[table][dbtuple.tuple]:
errs.append("For table {} tuple {}\n "
"Computed: {}\n "
"Correct: {}".format(table, str(dbtuple),
str(dbtuple.proofs),
str(correct[table][dbtuple.tuple])))
for table in set(correct.keys()) - checked_tables:
errs.append("Table {} had a correct answer but did not exist "
"in the database".format(table))
if len(errs) > 0:
# logging.debug("Check_proof errors:\n{}".format("\n".join(errs)))
self.fail("\n".join(errs))
# basic tests
def showdb(self, run):
logging.debug("Resulting DB: " + str(run.database))
def test_database(self):
code = ("")
run = self.prep_runtime(code, "**** Database tests ****")
self.check(run, "", "Empty database on init")
self.insert(run, ['r', 1])
self.check(run, "r(1)", "Basic insert with no propagations")
self.insert(run, ['r', 1])
self.check(run, "r(1)", "Duplicate insert with no propagations")
self.delete(run, ['r', 1])
self.check(run, "", "Delete with no propagations")
self.delete(run, ['r', 1])
self.check(run, "", "Delete from empty table")
def test_unary_tables(self):
""" Test rules for tables with one argument """
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code, "**** Basic tests ****")
run = self.prep_runtime(code, "**** Basic propagation tests ****")
self.insert(run, ['r', 1])
self.insert(run, ['p', 1])
self.check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
check(run, "", "Empty database on init")
insert(run, ['r', 1])
check(run, "r(1)", "Basic insert with no propagations")
insert(run, ['r', 1])
check(run, "r(1)", "Duplicate insert with no propagations")
delete(run, ['r', 1])
check(run, "", "Delete with no propagations")
delete(run, ['r', 1])
check(run, "", "Delete from empty table")
insert(run, ['r', 1])
insert(run, ['p', 1])
check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
showdb(run)
delete(run, ['r', 1])
check(run, "p(1)", "Delete from base table with 1 propagation")
showdb(run)
self.delete(run, ['r', 1])
self.check(run, "p(1)", "Delete from base table with 1 propagation")
# multiple rules
code = ("q(x) :- p(x), r(x)"
"q(x) :- s(x)")
insert(run, ['p', 1])
insert(run, ['r', 1])
showdb(run)
check(run, "p(1) r(1) q(1)", "Insert: multiple rules")
insert(run, ['s', 1])
showdb(run)
check(run, "p(1) r(1) s(1) q(1)", "Insert: duplicate conclusions")
self.insert(run, ['p', 1])
self.insert(run, ['r', 1])
self.check(run, "p(1) r(1) q(1)", "Insert: multiple rules")
self.insert(run, ['s', 1])
self.check(run, "p(1) r(1) s(1) q(1)", "Insert: duplicate conclusions")
# body of length 1
code = ("q(x) :- p(x)")
run = prep_runtime(code, "**** Body length 1 tests ****")
run = self.prep_runtime(code, "**** Body length 1 tests ****")
insert(run, ['p', 1])
check(run, "p(1) q(1)", "Insert with body of size 1")
delete(run, ['p', 1])
check(run, "", "Delete with body of size 1")
self.insert(run, ['p', 1])
self.check(run, "p(1) q(1)", "Insert with body of size 1")
self.showdb(run)
self.delete(run, ['p', 1])
self.showdb(run)
self.check(run, "", "Delete with body of size 1")
# existential variables
code = ("q(x) :- p(x,y)")
run = prep_runtime(code, "**** Existential variable tests ****")
code = ("q(x) :- p(x), r(y)")
run = self.prep_runtime(code, "**** Unary tables with existential ****")
self.insert(run, ['p', 1])
self.insert(run, ['r', 2])
self.insert(run, ['r', 3])
self.showdb(run)
self.check(run, "p(1) r(2) r(3) q(1)",
"Insert with unary table and existential")
self.delete(run, ['r', 2])
self.check(run, "p(1) r(3) q(1)",
"Delete 1 with unary table and existential")
self.delete(run, ['r', 3])
self.check(run, "p(1)",
"Delete all with unary table and existential")
insert(run, ['p', 1, 2])
check(run, "p(1, 2) q(1)", "Insert: existential variable in body of size 1")
delete(run, ['p', 1, 2])
check(run, "", "Delete: existential variable in body of size 1")
def test_multi_arity_tables(self):
""" Test rules whose tables have more than 1 argument """
code = ("q(x) :- p(x,y)")
run = self.prep_runtime(code, "**** Multiple-arity table tests ****")
self.insert(run, ['p', 1, 2])
self.check(run, "p(1, 2) q(1)", "Insert: existential variable in body of size 1")
self.delete(run, ['p', 1, 2])
self.check(run, "", "Delete: existential variable in body of size 1")
code = ("q(x) :- p(x,y), r(y,x)")
run = prep_runtime(code)
insert(run, ['p', 1, 2])
showdb(run)
insert(run, ['r', 2, 1])
showdb(run)
check(run, "p(1, 2) r(2, 1) q(1)", "Insert: join in body of size 2")
delete(run, ['p', 1, 2])
showdb(run)
check(run, "r(2, 1)", "Delete: join in body of size 2")
insert(run, ['p', 1, 2])
showdb(run)
insert(run, ['p', 1, 3])
showdb(run)
insert(run, ['r', 3, 1])
showdb(run)
check(run, "r(2, 1) r(3,1) p(1, 2) p(1, 3) q(1)",
run = self.prep_runtime(code)
self.insert(run, ['p', 1, 2])
self.insert(run, ['r', 2, 1])
self.check(run, "p(1, 2) r(2, 1) q(1)", "Insert: join in body of size 2")
self.delete(run, ['p', 1, 2])
self.check(run, "r(2, 1)", "Delete: join in body of size 2")
self.insert(run, ['p', 1, 2])
self.insert(run, ['p', 1, 3])
self.insert(run, ['r', 3, 1])
self.check(run, "r(2, 1) r(3,1) p(1, 2) p(1, 3) q(1)",
"Insert: multiple existential bindings for same head")
delete(run, ['p', 1, 2])
check(run, "r(2, 1) r(3,1) p(1, 3) q(1)",
self.delete(run, ['p', 1, 2])
self.check(run, "r(2, 1) r(3,1) p(1, 3) q(1)",
"Delete: multiple existential bindings for same head")
code = ("q(x,v) :- p(x,y), r(y,z), s(z,w), t(w,v)")
run = prep_runtime(code)
insert(run, ['p', 1, 10])
insert(run, ['p', 1, 20])
insert(run, ['r', 10, 100])
insert(run, ['r', 20, 200])
insert(run, ['s', 100, 1000])
insert(run, ['s', 200, 2000])
insert(run, ['t', 1000, 10000])
insert(run, ['t', 2000, 20000])
run = self.prep_runtime(code)
self.insert(run, ['p', 1, 10])
self.insert(run, ['p', 1, 20])
self.insert(run, ['r', 10, 100])
self.insert(run, ['r', 20, 200])
self.insert(run, ['s', 100, 1000])
self.insert(run, ['s', 200, 2000])
self.insert(run, ['t', 1000, 10000])
self.insert(run, ['t', 2000, 20000])
code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
"t(1000, 10000) t(2000,20000) "
"q(1,10000) q(1,20000)")
check(run, code, "Insert: larger join")
delete(run, ['t', 1000, 10000])
self.check(run, code, "Insert: larger join")
self.delete(run, ['t', 1000, 10000])
code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
"t(2000,20000) "
"q(1,20000)")
check(run, code, "Delete: larger join")
self.check(run, code, "Delete: larger join")
code = ("q(x,y) :- p(x,z), p(z,y)")
run = prep_runtime(code)
insert(run, ['p', 1, 2])
insert(run, ['p', 1, 3])
insert(run, ['p', 2, 4])
insert(run, ['p', 2, 5])
check(run, 'p(1,2) p(1,3) p(2,4) p(2,5) q(1,4) q(1,5)',
run = self.prep_runtime(code)
self.insert(run, ['p', 1, 2])
self.insert(run, ['p', 1, 3])
self.insert(run, ['p', 2, 4])
self.insert(run, ['p', 2, 5])
self.check(run, 'p(1,2) p(1,3) p(2,4) p(2,5) q(1,4) q(1,5)',
"Insert: self-join")
delete(run, ['p', 2, 4])
check(run, 'p(1,2) p(1,3) p(2,5) q(1,5)')
self.delete(run, ['p', 2, 4])
self.check(run, 'p(1,2) p(1,3) p(2,5) q(1,5)')
code = ("q(x,z) :- p(x,y), p(y,z)")
run = self.prep_runtime(code)
self.insert(run, ['p', 1, 1])
self.check(run, 'p(1,1) q(1,1)', "Insert: self-join on same data")
code = ("q(x,w) :- p(x,y), p(y,z), p(z,w)")
run = prep_runtime(code)
insert(run, ['p', 1, 1])
insert(run, ['p', 1, 2])
insert(run, ['p', 2, 2])
insert(run, ['p', 2, 3])
insert(run, ['p', 2, 4])
insert(run, ['p', 2, 5])
insert(run, ['p', 3, 3])
insert(run, ['p', 3, 4])
insert(run, ['p', 3, 5])
insert(run, ['p', 3, 6])
insert(run, ['p', 3, 7])
run = self.prep_runtime(code)
self.insert(run, ['p', 1, 1])
self.insert(run, ['p', 1, 2])
self.insert(run, ['p', 2, 2])
self.insert(run, ['p', 2, 3])
self.insert(run, ['p', 2, 4])
self.insert(run, ['p', 2, 5])
self.insert(run, ['p', 3, 3])
self.insert(run, ['p', 3, 4])
self.insert(run, ['p', 3, 5])
self.insert(run, ['p', 3, 6])
self.insert(run, ['p', 3, 7])
code = ('p(1,1) p(1,2) p(2,2) p(2,3) p(2,4) p(2,5)'
'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
'q(1,1) q(1,2) q(2,2) q(2,3) q(2,4) q(2,5)'
'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
'q(2,6) q(2,7)')
check(run, code, "Insert: larger self join")
delete(run, ['p', 1, 1])
delete(run, ['p', 2, 2])
self.check(run, code, "Insert: larger self join")
self.delete(run, ['p', 1, 1])
self.delete(run, ['p', 2, 2])
code = (' p(1,2) p(2,3) p(2,4) p(2,5)'
'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
' q(2,3) q(2,4) q(2,5)'
'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
'q(2,6) q(2,7)')
check(run, code, "Delete: larger self join")
self.check(run, code, "Delete: larger self join")
# Value types: string
def test_value_types(self):
""" Test the different value types """
# string
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code, "String data type")
run = self.prep_runtime(code, "String data type")
insert(run, ['r', 'apple'])
check(run, 'r("apple")', "String insert with no propagations")
insert(run, ['r', 'apple'])
check(run, 'r("apple")', "Duplicate string insert with no propagations")
self.insert(run, ['r', 'apple'])
self.check(run, 'r("apple")', "String insert with no propagations")
self.insert(run, ['r', 'apple'])
self.check(run, 'r("apple")', "Duplicate string insert with no propagations")
delete(run, ['r', 'apple'])
check(run, "", "Delete with no propagations")
delete(run, ['r', 'apple'])
check(run, "", "Delete from empty table")
self.delete(run, ['r', 'apple'])
self.check(run, "", "Delete with no propagations")
self.delete(run, ['r', 'apple'])
self.check(run, "", "Delete from empty table")
insert(run, ['r', 'apple'])
insert(run, ['p', 'apple'])
check(run, 'r("apple") p("apple") q("apple")',
self.insert(run, ['r', 'apple'])
self.insert(run, ['p', 'apple'])
self.check(run, 'r("apple") p("apple") q("apple")',
"String insert with 1 propagation")
showdb(run)
delete(run, ['r', 'apple'])
check(run, 'p("apple")', "String delete with 1 propagation")
showdb(run)
self.delete(run, ['r', 'apple'])
self.check(run, 'p("apple")', "String delete with 1 propagation")
# Value types: floats
# float
code = ("q(x) :- p(x), r(x)")
run = prep_runtime(code, "Float data type")
run = self.prep_runtime(code, "Float data type")
insert(run, ['r', 1.2])
check(run, 'r(1.2)', "String insert with no propagations")
insert(run, ['r', 1.2])
check(run, 'r(1.2)', "Duplicate string insert with no propagations")
self.insert(run, ['r', 1.2])
self.check(run, 'r(1.2)', "String insert with no propagations")
self.insert(run, ['r', 1.2])
self.check(run, 'r(1.2)', "Duplicate string insert with no propagations")
delete(run, ['r', 1.2])
check(run, "", "Delete with no propagations")
delete(run, ['r', 1.2])
check(run, "", "Delete from empty table")
self.delete(run, ['r', 1.2])
self.check(run, "", "Delete with no propagations")
self.delete(run, ['r', 1.2])
self.check(run, "", "Delete from empty table")
insert(run, ['r', 1.2])
insert(run, ['p', 1.2])
check(run, 'r(1.2) p(1.2) q(1.2)',
"String insert with 1 propagation")
showdb(run)
self.insert(run, ['r', 1.2])
self.insert(run, ['p', 1.2])
self.check(run, 'r(1.2) p(1.2) q(1.2)',
"String self.insert with 1 propagation")
delete(run, ['r', 1.2])
check(run, 'p(1.2)', "String delete with 1 propagation")
showdb(run)
self.delete(run, ['r', 1.2])
self.check(run, 'p(1.2)', "String delete with 1 propagation")
# negation
def test_proofs(self):
""" Test if the proof computation is performed correctly. """
def check_table_proofs(run, table, tuple_proof_dict, msg):
for tuple in tuple_proof_dict:
tuple_proof_dict[tuple] = \
Database.ProofCollection(tuple_proof_dict[tuple])
self.check_proofs(run, {table : tuple_proof_dict}, msg)
code = ("q(x) :- p(x,y)")
run = self.prep_runtime(code, "**** Proof tests ****")
self.insert(run, ['p', 1, 2])
check_table_proofs(run, 'q', {(1,): [{u'x': 1, u'y': 2}]},
'Simplest proof test')
def test_negation(self):
""" Test negation """
# Unary, single join
code = ("q(x) :- p(x), not r(x)")
run = self.prep_runtime(code, "Unary, single join")
self.insert(run, ['p', 2])
self.check(run, 'p(2) q(2)',
"Insert into positive literal with propagation")
self.delete(run, ['p', 2])
self.check(run, '',
"Delete from positive literal with propagation")
self.insert(run, ['r', 2])
self.check(run, 'r(2)',
"Insert into negative literal without propagation")
self.delete(run, ['r', 2])
self.check(run, '',
"Delete from negative literal without propagation")
self.insert(run, ['p', 2])
self.insert(run, ['r', 2])
self.check(run, 'p(2) r(2)',
"Insert into negative literal with propagation")
self.delete(run, ['r', 2])
self.check(run, 'q(2) p(2)',
"Delete from negative literal with propagation")
# Unary, multiple joins
code = ("s(x) :- p(x), not r(x), q(y), not t(y)")
run = self.prep_runtime(code, "Unary, multiple join")
self.insert(run, ['p', 1])
self.insert(run, ['q', 2])
self.check(run, 'p(1) q(2) s(1)',
'Insert with two negative literals')
self.insert(run, ['r', 3])
self.check(run, 'p(1) q(2) s(1) r(3)',
'Ineffectual insert with 2 negative literals')
self.insert(run, ['r', 1])
self.check(run, 'p(1) q(2) r(3) r(1)',
'Insert into existentially quantified negative literal with propagation. ')
self.insert(run, ['t', 2])
self.check(run, 'p(1) q(2) r(3) r(1) t(2)',
'Insert into negative literal producing extra blocker for proof.')
self.delete(run, ['t', 2])
self.check(run, 'p(1) q(2) r(3) r(1)',
'Delete first blocker from proof')
self.delete(run, ['r', 1])
self.check(run, 'p(1) q(2) r(3) s(1)',
'Delete second blocker from proof')
# Non-unary
code = ("p(x, v) :- q(x,z), r(z, w), not s(x, w), u(w,v)")
run = self.prep_runtime(code, "Non-unary")
self.insert(run, ['q', 1, 2])
self.insert(run, ['r', 2, 3])
self.insert(run, ['r', 2, 4])
self.insert(run, ['u', 3, 5])
self.insert(run, ['u', 4, 6])
self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) p(1,5) p(1,6)',
'Insert with non-unary negative literal')
self.insert(run, ['s', 1, 3])
self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) s(1,3) p(1,6)',
'Insert into non-unary negative with propagation')
self.insert(run, ['s', 1, 4])
self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) s(1,3) s(1,4)',
'Insert into non-unary with different propagation')
if __name__ == '__main__':
unittest.main()

View File

@ -1,3 +1,24 @@
q(x) :- p(x), r(x)
error(vm) :- virtual_machine(vm), network(vm, net), not public(net), owner(vm, owner), not owned_by_some_group(owner, net)
owned_by_some_group(owner, object) :- owner(object, user), member(user, group), member(owner, group)
virtual_machine('vm1')
virtual_machine('vm2')
network('vm1', 'net_private')
network('vm2', 'net_public')
public('net_public')
owner('vm1', 'tim')
owner('vm2', 'pete')
owner('net_private', 'martin')
member('pete', 'congress')
member('tim', 'congress')
member('martin', 'congress')
member('pierre', 'congress')