Added negation, proofs to policy engine runtime.
Passes all tests, except I haven't yet figured out how to write reasonable tests for the proofs. Issue: # Change-Id: I713c10481afc59f9872fc98866145a2b594efd64
This commit is contained in:
parent
0b613444df
commit
ad78238c79
|
@ -8,6 +8,7 @@ import CongressLexer
|
|||
import CongressParser
|
||||
import antlr3
|
||||
import runtime
|
||||
import logging
|
||||
|
||||
class CongressException (Exception):
|
||||
def __init__(self, msg, obj=None, line=None, col=None):
|
||||
|
@ -68,6 +69,9 @@ class Variable (Term):
|
|||
def __str__(self):
|
||||
return str(self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, Variable) and self.name == other.name
|
||||
|
||||
def is_variable(self):
|
||||
return True
|
||||
|
||||
|
@ -89,6 +93,11 @@ class ObjectConstant (Term):
|
|||
def __str__(self):
|
||||
return str(self.name)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, ObjectConstant) and
|
||||
self.name == other.name and
|
||||
self.type == other.type)
|
||||
|
||||
def is_variable(self):
|
||||
return False
|
||||
|
||||
|
@ -115,6 +124,13 @@ class Atom (object):
|
|||
return "{}({})".format(self.table,
|
||||
", ".join([str(x) for x in self.arguments]))
|
||||
|
||||
def __eq__(self, other):
    """ Structural equality for atoms: same table name and
    pairwise-equal argument lists.

    Rewritten to compare arguments with zip() instead of index-based
    xrange() iteration: behaviorally identical on Python 2 and also
    valid on Python 3, where xrange no longer exists. """
    return (isinstance(other, Atom) and
            self.table == other.table and
            len(self.arguments) == len(other.arguments) and
            all(mine == theirs
                for mine, theirs in zip(self.arguments, other.arguments)))
|
||||
|
||||
def is_atom(self):
|
||||
return True
|
||||
|
||||
|
@ -139,6 +155,9 @@ class Literal(Atom):
|
|||
else:
|
||||
return Atom.__str__(self)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (self.negated == other.negated and Atom.__eq__(self, other))
|
||||
|
||||
def is_negated(self):
|
||||
return self.negated
|
||||
|
||||
|
@ -160,6 +179,12 @@ class Rule (object):
|
|||
str(self.head),
|
||||
", ".join([str(atom) for atom in self.body]))
|
||||
|
||||
def __eq__(self, other):
    """ Structural equality for rules: equal heads and pairwise-equal
    body literals.

    Uses zip() rather than index-based xrange() iteration; same
    behavior on Python 2 and also runs on Python 3. """
    return (self.head == other.head and
            len(self.body) == len(other.body) and
            all(mine == theirs
                for mine, theirs in zip(self.body, other.body)))
|
||||
|
||||
def is_atom(self):
|
||||
return False
|
||||
|
||||
|
@ -184,10 +209,9 @@ class Compiler (object):
|
|||
s += 'None'
|
||||
return s
|
||||
|
||||
def read_source(self, input, input_string=False):
    """ Parse INPUT and store the result in self.raw_syntax_tree and
    self.theory (internal representation).

    If input_string is True, INPUT is the policy text itself;
    otherwise it is a file name.
    NOTE(review): parameter name 'input' shadows the builtin; kept
    unchanged for caller compatibility. """
    # parse input file and convert to internal representation
    self.raw_syntax_tree = CongressSyntax.parse_file(input,
        input_string=input_string)
    #self.print_parse_result()
    self.theory = CongressSyntax.create(self.raw_syntax_tree)
|
||||
|
@ -211,13 +235,66 @@ class Compiler (object):
|
|||
errors = [str(err) for err in self.errors]
|
||||
raise CongressException('Compiler found errors:' + '\n'.join(errors))
|
||||
|
||||
def eliminate_self_joins(self):
    """ Rewrite self.theory so that no rule body joins a table with
    itself.

    Extra occurrences of a (table, arity) within one rule body are
    renamed to generated alias tables; for each alias an equivalence
    rule  alias(x0,...) :- table(x0,...)  is appended to the theory so
    semantics are preserved. """
    def new_table_name(name, arity, index):
        # "___" prefix marks generated alias tables
        return "___{}_{}_{}".format(name, arity, index)
    def n_variables(n):
        # fresh variable names x0..x(n-1)
        vars = []
        for i in xrange(0, n):
            vars.append("x" + str(i))
        return vars
    # dict from (table name, arity) tuple to
    # max num of occurrences of self-joins in any rule
    global_self_joins = {}
    # dict from (table name, arity) to # of args for
    # NOTE(review): 'arities' is never used below and the comment above
    # appears truncated -- confirm whether it can be removed.
    arities = {}
    # remove self-joins from rules
    for rule in self.theory:
        if rule.is_atom():
            continue
        logging.debug("eliminating self joins from {}".format(rule))
        occurrences = {}  # occurrence counts for just this rule
        for atom in rule.body:
            table = atom.table
            arity = len(atom.arguments)
            tablearity = (table, arity)
            if tablearity not in occurrences:
                # first occurrence keeps the original table name
                occurrences[tablearity] = 1
            else:
                # change name of atom (second and later occurrences)
                atom.table = new_table_name(table, arity,
                    occurrences[tablearity])
                # update our counters
                occurrences[tablearity] += 1
                if tablearity not in global_self_joins:
                    global_self_joins[tablearity] = 1
                else:
                    global_self_joins[tablearity] = \
                        max(occurrences[tablearity] - 1,
                            global_self_joins[tablearity])
        logging.debug("final rule: {}".format(str(rule)))
    # add definitions for new tables
    for tablearity in global_self_joins:
        table = tablearity[0]
        arity = tablearity[1]
        for i in xrange(1, global_self_joins[tablearity] + 1):
            newtable = new_table_name(table, arity, i)
            args = [Variable(var) for var in n_variables(arity)]
            head = Atom(newtable, args)
            body = [Atom(table, args)]
            self.theory.append(Rule(head, body))
            logging.debug("Adding rule {}".format(str(self.theory[-1])))
|
||||
|
||||
def compute_delta_rules(self):
    """ Build self.delta_rules: one DeltaRule per (rule, body literal)
    pair, triggered by that literal with the remaining body.
    Assumes no self-joins (run eliminate_self_joins first). """
    self.delta_rules = []
    for rule in self.theory:
        if rule.is_atom():
            continue
        for literal in rule.body:
            # identity comparison ('is not') removes exactly this
            # occurrence -- which is why self-join freedom matters
            newbody = [lit for lit in rule.body if lit is not literal]
            self.delta_rules.append(
                runtime.DeltaRule(literal, rule.head, newbody, rule))
|
||||
|
||||
class CongressSyntax (object):
|
||||
""" External syntax and converting it into internal representation. """
|
||||
|
@ -251,12 +328,11 @@ class CongressSyntax (object):
|
|||
e.line, e.charPositionInLine)
|
||||
|
||||
@classmethod
|
||||
def parse_file(cls, input_file=None, input_string=None):
|
||||
assert(input_file is not None or input_string is not None)
|
||||
if input_file is not None:
|
||||
char_stream = antlr3.ANTLRFileStream(input_file)
|
||||
def parse_file(cls, input, input_string=False):
|
||||
if not input_string:
|
||||
char_stream = antlr3.ANTLRFileStream(input)
|
||||
else:
|
||||
char_stream = antlr3.ANTLRStringStream(input_string)
|
||||
char_stream = antlr3.ANTLRStringStream(input)
|
||||
lexer = cls.Lexer(char_stream)
|
||||
tokens = antlr3.CommonTokenStream(lexer)
|
||||
parser = cls.Parser(tokens)
|
||||
|
@ -361,16 +437,32 @@ def print_tree(tree, text, kids, ind=0):
|
|||
for child in children:
|
||||
print_tree(child, text, kids, ind + 1)
|
||||
|
||||
def main():
|
||||
def get_compiled(args):
    """ Compile the policy named by ARGS (an argv-style list) and return
    the Compiler, with self-joins eliminated and delta rules computed.

    With --input_string, positional arguments are treated as policy
    text rather than file names. """
    parser = optparse.OptionParser()
    parser.add_option("--input_string", dest="input_string", default=False,
        action="store_true",
        help="Indicates that inputs should be treated not as file names but "
        "as the contents to compile")
    (options, inputs) = parser.parse_args(args)
    if len(inputs) != 1:
        parser.error("Usage: %prog [options] policy-file")
    compiler = Compiler()
    # NOTE(review): the loop suggests multiple inputs are supported, but
    # the check above forbids more than one -- confirm intended arity.
    for i in inputs:
        compiler.read_source(i, input_string=options.input_string)
    compiler.eliminate_self_joins()
    compiler.compute_delta_rules()
    return compiler
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main())
|
||||
def get_runtime(args):
    """ Compile ARGS (argv-style list) and return a Runtime loaded with
    the resulting delta rules, with tracing enabled on all tables. """
    comp = get_compiled(args)
    run = runtime.Runtime(comp.delta_rules)
    tracer = runtime.Tracer()
    tracer.trace('*')  # '*' traces every table
    run.tracer = tracer
    run.database.tracer = tracer
    return run
||||
|
||||
# if __name__ == '__main__':
|
||||
# main()
|
||||
|
||||
|
||||
|
|
|
@ -15,15 +15,23 @@ class CongressRuntime (Exception):
|
|||
pass
|
||||
|
||||
class DeltaRule(object):
    """ A rule specialized to one triggering literal, for incremental
    (delta) evaluation.  Derived from ORIGINAL by removing TRIGGER
    from its body. """

    def __init__(self, trigger, head, body, original):
        self.trigger = trigger    # atom whose change fires this rule
        self.head = head          # atom concluded when body matches
        self.body = body          # remaining literals (may be negated)
        self.original = original  # Rule from which this was derived

    def __str__(self):
        return "<trigger: {}, head: {}, body: {}>".format(
            str(self.trigger), str(self.head), [str(lit) for lit in self.body])

    def __eq__(self, other):
        # NOTE: 'original' is deliberately excluded from equality.
        # zip() replaces the index-based xrange() loop: identical
        # behavior on Python 2 and also valid on Python 3.
        return (self.trigger == other.trigger and
                self.head == other.head and
                len(self.body) == len(other.body) and
                all(mine == theirs
                    for mine, theirs in zip(self.body, other.body)))
|
||||
|
||||
class EventQueue(object):
|
||||
def __init__(self):
|
||||
self.queue = collections.deque()
|
||||
|
@ -46,6 +54,7 @@ class Event(object):
|
|||
self.table = table
|
||||
self.tuple = Database.DBTuple(tuple, proofs=proofs)
|
||||
self.insert = insert
|
||||
logging.debug("EVENT: created event {}".format(str(self)))
|
||||
|
||||
def is_insert(self):
|
||||
return self.insert
|
||||
|
@ -57,105 +66,26 @@ class Event(object):
|
|||
sign = '-'
|
||||
return "{}{}({})".format(self.table, sign, str(self.tuple))
|
||||
|
||||
# class Database(object):
|
||||
# class DictTuple(object):
|
||||
# def __init__(self, binding, refcount=1):
|
||||
# self.binding = binding
|
||||
# self.refcount = refcount
|
||||
|
||||
# def __eq__(self, other):
|
||||
# return self.binding == other.binding
|
||||
|
||||
# def __str__(self):
|
||||
# return "<binding: {}, refcount: {}>".format(
|
||||
# str(self.binding), self.refcount)
|
||||
|
||||
# def matches(self, binding):
|
||||
# print "Checking if tuple {} matches binding {}".format(str(self), str(binding))
|
||||
# for column_name in self.binding.keys():
|
||||
# if column_name not in binding:
|
||||
# return False
|
||||
# if self.binding[column_name] != binding[column_name]:
|
||||
# return False
|
||||
# print "Check succeeded with binding {}".format(str(binding))
|
||||
# return True
|
||||
|
||||
# class Schema (object):
|
||||
# def __init__(self, column_names):
|
||||
# self.arguments = column_names
|
||||
# def __str__(self):
|
||||
# return str(self.arguments)
|
||||
|
||||
# def __init__(self):
|
||||
# self.data = {'p': [], 'q': [], 'r': [self.DictTuple({1: 1})]}
|
||||
# # self.data = {'p': [self.DictTuple({1: 'a'}),
|
||||
# # self.DictTuple({1: 'b'}),
|
||||
# # self.DictTuple({1, 'c'})],
|
||||
# # 'q': [self.DictTuple({1: 'b'}),
|
||||
# # self.DictTuple({1: 'c'}),
|
||||
# # self.DictTuple({1, 'd'})],
|
||||
# # 'r': [self.DictTuple({1: 'c'}),
|
||||
# # self.DictTuple({1: 'd'}),
|
||||
# # self.DictTuple({1, 'e'})]
|
||||
# # }
|
||||
# self.schemas = {'p': Database.Schema([1]),
|
||||
# 'q': Database.Schema([1]),
|
||||
# 'r': Database.Schema([1])}
|
||||
|
||||
# def __str__(self):
|
||||
# def hash2str (h):
|
||||
# s = "{"
|
||||
# s += ", ".join(["{} : {}".format(str(key), str(h[key]))
|
||||
# for key in h])
|
||||
# return s
|
||||
|
||||
# def hashlist2str (h):
|
||||
# strings = []
|
||||
# for key in h:
|
||||
# s = "{} : ".format(key)
|
||||
# s += '['
|
||||
# s += ', '.join([str(val) for val in h[key]])
|
||||
# s += ']'
|
||||
# strings.append(s)
|
||||
# return '{' + ", ".join(strings) + '}'
|
||||
|
||||
# return "<data: {}, \nschemas: {}>".format(
|
||||
# hashlist2str(self.data), hash2str(self.schemas))
|
||||
|
||||
# def get_matches(self, table, binding, columns=None):
|
||||
# print "Getting matches for table {} with binding {}".format(
|
||||
# str(table), str(binding))
|
||||
# if table not in self.data:
|
||||
# raise CongressRuntime("Table not found ".format(table))
|
||||
# result = []
|
||||
# for dicttuple in self.data[table]:
|
||||
# print "Matching database tuple {}".format(str(dicttuple))
|
||||
# if dicttuple.matches(binding):
|
||||
# result.append(dicttuple)
|
||||
# return result
|
||||
|
||||
# def insert(self, table, binding, refcount=1):
|
||||
# if table not in self.data:
|
||||
# raise CongressRuntime("Table not found ".format(table))
|
||||
# for dicttuple in self.data[table]:
|
||||
# if dicttuple.binding == binding:
|
||||
# dicttuple.refcount += refcount
|
||||
# return
|
||||
# self.data[table].append(self.DictTuple(binding, refcount))
|
||||
|
||||
# def delete(self, table, binding, refcount=1):
|
||||
# if table not in self.data:
|
||||
# raise CongressRuntime("Table not found ".format(table))
|
||||
# for dicttuple in self.data[table]:
|
||||
# if dicttuple.binding == binding:
|
||||
# dicttuple.refcount -= refcount
|
||||
# if dicttuple.refcount < 0:
|
||||
# raise CongressRuntime("Deleted more tuples than existed")
|
||||
# return
|
||||
# raise CongressRuntime("Deleted tuple that didn't exist")
|
||||
|
||||
|
||||
class Database(object):
|
||||
class Proof(object):
    """ Record of how a tuple was derived: a rule plus the variable
    binding under which that rule was applied. """

    def __init__(self, binding, rule):
        self.binding = binding
        self.rule = rule

    def __str__(self):
        return "apply({}, {})".format(str(self.binding), str(self.rule))

    def __eq__(self, other):
        # Two proofs are equal when both the binding and the rule match.
        same_binding = self.binding == other.binding
        same_rule = self.rule == other.rule
        return same_binding and same_rule
||||
|
||||
class ProofCollection(object):
|
||||
def __init__(self, proofs):
|
||||
self.contents = list(proofs)
|
||||
|
@ -166,6 +96,7 @@ class Database(object):
|
|||
def __isub__(self, other):
|
||||
if other is None:
|
||||
return
|
||||
# logging.debug("PC: Subtracting {} and {}".format(str(self), str(other)))
|
||||
remaining = []
|
||||
for proof in self.contents:
|
||||
if proof not in other.contents:
|
||||
|
@ -176,7 +107,9 @@ class Database(object):
|
|||
def __ior__(self, other):
    """ In-place union: append each proof from OTHER not already in
    self.contents.  OTHER of None is a no-op.

    Bug fix: the original returned None (bare 'return') when OTHER is
    None.  Augmented assignment rebinds the target to this method's
    return value, so `pc |= None` silently replaced pc with None --
    which would then trip the `proofs is not None` asserts in
    Database.insert.  Always return self. """
    if other is None:
        return self
    for proof in other.contents:
        if proof not in self.contents:
            self.contents.append(proof)
    return self
||||
|
@ -272,14 +205,39 @@ class Database(object):
|
|||
for dbtuple in self.data[table]:
|
||||
if dbtuple not in other.data[table]:
|
||||
add_tuple(table, dbtuple)
|
||||
|
||||
return results
|
||||
|
||||
def __getitem__(self, key):
    """ Index the database by table name: db[table] -> stored tuples.
    Raises KeyError if the table has no data. """
    # KEY must be a tablename
    return self.data[key]
|
||||
|
||||
def __iter__(self):
    # NOTE(review): this body is an unfinished stub.  `self.__` is a
    # bare attribute access (of an attribute literally named "__"),
    # which raises AttributeError at call time unless such an attribute
    # exists; even then the method returns None, so iter(db) fails.
    # Needs a real implementation (presumably iterating table names or
    # tuples) -- confirm intent before fixing.
    self.__
|
||||
|
||||
def table_names(self):
    """ Return the names of all tables currently holding data.
    (dict.keys(): a list on Python 2, a view on Python 3.) """
    return self.data.keys()
|
||||
|
||||
def log(self, table, msg):
    """ Emit MSG at debug level if TABLE is being traced.
    The "DB: " prefix distinguishes database traces from Runtime's. """
    if self.tracer.is_traced(table):
        logging.debug("DB: " + msg)
|
||||
|
||||
def get_matches(self, literal, binding):
    """ Returns a list of binding lists for the variables in LITERAL
    not bound in BINDING. If LITERAL is negative, returns
    either [] meaning the lookup failed or [{}] meaning the lookup
    succeeded; otherwise, returns one binding list for each tuple in
    the database matching LITERAL under BINDING. """
    # slow--should stop at first match, not find all of them
    matches = self.get_matches_atom(literal, binding)
    if literal.is_negated():
        # negation as failure: any positive match falsifies the literal
        if len(matches) > 0:
            return []
        else:
            return [{}]
    else:
        return matches
|
||||
|
||||
def get_matches_atom(self, atom, binding):
|
||||
""" Returns a list of binding lists for the variables in ATOM
|
||||
not bound in BINDING: one binding list for each tuple in
|
||||
the database matching ATOM under BINDING. """
|
||||
|
@ -296,16 +254,21 @@ class Database(object):
|
|||
def insert(self, table, dbtuple):
    """ Insert DBTUPLE into TABLE, wrapping a raw tuple in DBTuple if
    needed.  If an equal tuple is already stored, merge the new
    tuple's proofs into the existing one instead of duplicating. """
    if not isinstance(dbtuple, Database.DBTuple):
        dbtuple = Database.DBTuple(dbtuple)
    self.log(table, "Inserting table {} tuple {}".format(
        table, str(dbtuple)))
    if table not in self.data:
        # first tuple for this table
        self.data[table] = [dbtuple]
    else:
        for existingtuple in self.data[table]:
            assert(existingtuple.proofs is not None)
            if existingtuple.tuple == dbtuple.tuple:
                # accumulate proofs on the existing tuple; |= must
                # return the collection (see ProofCollection.__ior__)
                existingtuple.proofs |= dbtuple.proofs
                assert(existingtuple.proofs is not None)
                return
        self.data[table].append(dbtuple)
|
||||
|
@ -342,9 +305,9 @@ class Runtime (object):
|
|||
# tracer object
|
||||
self.tracer = Tracer()
|
||||
|
||||
def log(self, table, msg, depth=0):
    """ Emit MSG at debug level if TABLE is being traced.
    DEPTH prefixes '| ' per level to visualize recursion depth. """
    if self.tracer.is_traced(table):
        logging.debug("{}{}".format(("| " * depth), msg))
|
||||
|
||||
def insert(self, table, tuple):
|
||||
""" Event handler for an insertion.
|
||||
|
@ -372,12 +335,9 @@ class Runtime (object):
|
|||
""" Toplevel evaluation routine. """
|
||||
while len(self.queue) > 0:
|
||||
event = self.queue.dequeue()
|
||||
# Note differing order of insert/delete into database.
|
||||
# Insert happens before propagation; Delete happens after propagation.
|
||||
# Necessary for correctness on self-joins.
|
||||
if event.is_insert():
|
||||
self.database.insert(event.table, event.tuple)
|
||||
self.propagate(event)
|
||||
self.database.insert(event.table, event.tuple)
|
||||
else:
|
||||
self.propagate(event)
|
||||
self.database.delete(event.table, event.tuple)
|
||||
|
@ -394,7 +354,6 @@ class Runtime (object):
|
|||
|
||||
def propagate_rule(self, event, delta_rule):
|
||||
""" Compute and enqueue new events generated by EVENT and DELTA_RULE. """
|
||||
assert(not delta_rule.trigger.is_negated())
|
||||
self.log(event.table, "Processing event {} with rule {}".format(
|
||||
str(event), str(delta_rule)))
|
||||
|
||||
|
@ -418,49 +377,61 @@ class Runtime (object):
|
|||
new_tuple = tuple(plug(delta_rule.head, new_binding))
|
||||
if new_tuple not in new_tuples:
|
||||
new_tuples[new_tuple] = []
|
||||
new_tuples[new_tuple].append(new_binding)
|
||||
new_tuples[new_tuple].append(Database.Proof(
|
||||
new_binding, delta_rule.original))
|
||||
self.log(event.table, "new tuples generated: {}".format(
|
||||
str(new_tuples)))
|
||||
", ".join([str(x) for x in new_tuples])))
|
||||
|
||||
# enqueue each distinct generated tuple, recording appropriate bindings
|
||||
head_table = delta_rule.head.table
|
||||
if delta_rule.trigger.is_negated():
|
||||
insert_delete = not event.insert
|
||||
else:
|
||||
insert_delete = event.insert
|
||||
for new_tuple in new_tuples:
|
||||
# self.log(event.table,
|
||||
# "new_tuple {}: {}".format(str(new_tuple), str(new_tuples[new_tuple])))
|
||||
self.queue.enqueue(Event(table=head_table,
|
||||
tuple=new_tuple,
|
||||
proofs=new_tuples[new_tuple],
|
||||
insert=event.insert))
|
||||
insert=insert_delete))
|
||||
|
||||
def top_down_eval(self, literals, literal_index, binding):
    """ Compute all instances of LITERALS (from LITERAL_INDEX and above)
    that are true in the Database (after applying the dictionary
    binding BINDING to LITERALS). Returns a list of dictionary
    bindings.

    BINDING is mutated during recursion but restored before each
    return, so the caller's dict is unchanged overall. """
    if literal_index > len(literals) - 1:
        # past the end: the (possibly empty) remainder is trivially true
        return [binding]
    lit = literals[literal_index]
    self.log(lit.table, ("Top_down_eval(literals={}, literal_index={}, "
        "bindings={})").format(
            "[" + ",".join(str(x) for x in literals) + "]",
            literal_index,
            str(binding)),
        depth=literal_index)
    # assume that for negative literals, all vars are bound at this point
    # if there is a match, data_bindings will contain at least one binding
    # (possibly including the empty binding)
    data_bindings = self.database.get_matches(lit, binding)
    self.log(lit.table, "data_bindings: " + str(data_bindings),
        depth=literal_index)
    # if not negated, empty data_bindings means failure
    if len(data_bindings) == 0:
        return []

    results = []
    for data_binding in data_bindings:
        # add new binding to current binding
        binding.update(data_binding)
        if literal_index == len(literals) - 1:  # last element
            results.append(dict(binding))  # need to copy
        else:
            results.extend(self.top_down_eval(literals, literal_index + 1,
                binding))
        # remove new binding from current bindings
        for var in data_binding:
            del binding[var]
    self.log(lit.table, "Top_down_eval return value: {}".format(
        '[' + ", ".join([str(x) for x in results]) + ']'),
        depth=literal_index)

    return results
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
import unittest
|
||||
from policy import compile
|
||||
from policy import runtime
|
||||
from policy.runtime import Database
|
||||
import logging
|
||||
|
||||
class TestRuntime(unittest.TestCase):
|
||||
|
@ -11,244 +12,371 @@ class TestRuntime(unittest.TestCase):
|
|||
def setUp(self):
|
||||
pass
|
||||
|
||||
def test_runtime(self):
|
||||
def prep_runtime(code, msg=None):
|
||||
# compile source
|
||||
if msg is not None:
|
||||
logging.debug(msg)
|
||||
c = compile.Compiler()
|
||||
c.read_source(input_string=code)
|
||||
c.compute_delta_rules()
|
||||
run = runtime.Runtime(c.delta_rules)
|
||||
tracer = runtime.Tracer()
|
||||
tracer.trace('*')
|
||||
run.tracer = tracer
|
||||
run.database.tracer = tracer
|
||||
return run
|
||||
def prep_runtime(self, code, msg=None):
    """ Compile CODE (policy text) and return a Runtime loaded with its
    delta rules, with tracing enabled on all tables.  MSG, if given,
    is logged as a test-section banner. """
    # compile source
    if msg is not None:
        logging.debug(msg)
    c = compile.get_compiled([code, '--input_string'])
    run = runtime.Runtime(c.delta_rules)
    tracer = runtime.Tracer()
    tracer.trace('*')
    run.tracer = tracer
    run.database.tracer = tracer
    return run
|
||||
|
||||
def insert(run, list):
|
||||
run.insert(list[0], tuple(list[1:]))
|
||||
def insert(self, run, list):
    """ Insert into RUN: list[0] is the table name, the remaining
    elements form the tuple. """
    run.insert(list[0], tuple(list[1:]))
|
||||
|
||||
def delete(run, list):
|
||||
run.delete(list[0], tuple(list[1:]))
|
||||
def delete(self, run, list):
    """ Delete from RUN: list[0] is the table name, the remaining
    elements form the tuple. """
    run.delete(list[0], tuple(list[1:]))
|
||||
|
||||
def check(run, correct_database_code, msg=None):
|
||||
# extract correct answer from correct_database_code
|
||||
logging.debug("** Checking {} **".format(msg))
|
||||
c = compile.Compiler()
|
||||
c.read_source(input_string=correct_database_code)
|
||||
correct = c.theory
|
||||
correct_database = runtime.Database()
|
||||
for atom in correct:
|
||||
correct_database.insert(atom.table,
|
||||
[x.name for x in atom.arguments])
|
||||
def check(self, run, correct_database_code, msg=None):
    """ Assert that RUN's database holds exactly the ground atoms in
    CORRECT_DATABASE_CODE (policy text of facts).  Tables generated by
    self-join elimination ("___" prefix) are ignored. """
    # extract correct answer from correct_database_code
    logging.debug("** Checking {} **".format(msg))
    c = compile.get_compiled([correct_database_code, '--input_string'])
    correct = c.theory
    correct_database = runtime.Database()
    for atom in correct:
        correct_database.insert(atom.table,
            [x.name for x in atom.arguments])

    # compute diffs; should be empty
    extra = run.database - correct_database
    missing = correct_database - run.database
    # ignore alias tables introduced by eliminate_self_joins
    extra = [e for e in extra if not e[0].startswith("___")]
    missing = [m for m in missing if not m[0].startswith("___")]
    # NOTE(review): 'errmsg' is assigned but never used.
    errmsg = ""
    if len(extra) > 0:
        logging.debug("Extra tuples")
        logging.debug(", ".join([str(x) for x in extra]))
    if len(missing) > 0:
        logging.debug("Missing tuples")
        logging.debug(", ".join([str(x) for x in missing]))
    self.assertTrue(len(extra) == 0 and len(missing) == 0, msg)
    logging.debug(str(run.database))
    logging.debug("** Finished {} **".format(msg))
||||
|
||||
def showdb(run):
|
||||
logging.debug("Resulting DB: " + str(run.database))
|
||||
def check_proofs(self, run, correct, msg=None):
    """ Check that the proofs stored in runtime RUN are exactly
    those in CORRECT. """
    # example
    # check_proofs(run, {'q': {(1,):
    #     Database.ProofCollection([{'x': 1, 'y': 2}])}})

    errs = []
    checked_tables = set()
    for table in run.database.table_names():
        if table in correct:
            checked_tables.add(table)
            for dbtuple in run.database[table]:
                if dbtuple.tuple in correct[table]:
                    # compare the stored ProofCollection to the expected one
                    if dbtuple.proofs != correct[table][dbtuple.tuple]:
                        errs.append("For table {} tuple {}\n "
                            "Computed: {}\n "
                            "Correct: {}".format(table, str(dbtuple),
                            str(dbtuple.proofs),
                            str(correct[table][dbtuple.tuple])))
    # every table with an expected answer must exist in the database
    for table in set(correct.keys()) - checked_tables:
        errs.append("Table {} had a correct answer but did not exist "
            "in the database".format(table))
    if len(errs) > 0:
        self.fail("\n".join(errs))
||||
|
||||
|
||||
# basic tests
|
||||
|
||||
def showdb(self, run):
    """ Log RUN's current database contents at debug level. """
    logging.debug("Resulting DB: " + str(run.database))
|
||||
|
||||
def test_database(self):
    """ Insert/delete with no rules: pure database behavior,
    no propagation. """
    code = ("")
    run = self.prep_runtime(code, "**** Database tests ****")

    self.check(run, "", "Empty database on init")

    self.insert(run, ['r', 1])
    self.check(run, "r(1)", "Basic insert with no propagations")
    self.insert(run, ['r', 1])
    self.check(run, "r(1)", "Duplicate insert with no propagations")

    self.delete(run, ['r', 1])
    self.check(run, "", "Delete with no propagations")
    self.delete(run, ['r', 1])
    self.check(run, "", "Delete from empty table")
||||
def test_unary_tables(self):
|
||||
""" Test rules for tables with one argument """
|
||||
code = ("q(x) :- p(x), r(x)")
|
||||
run = prep_runtime(code, "**** Basic tests ****")
|
||||
run = self.prep_runtime(code, "**** Basic propagation tests ****")
|
||||
self.insert(run, ['r', 1])
|
||||
self.insert(run, ['p', 1])
|
||||
self.check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
|
||||
|
||||
check(run, "", "Empty database on init")
|
||||
|
||||
insert(run, ['r', 1])
|
||||
check(run, "r(1)", "Basic insert with no propagations")
|
||||
insert(run, ['r', 1])
|
||||
check(run, "r(1)", "Duplicate insert with no propagations")
|
||||
|
||||
delete(run, ['r', 1])
|
||||
check(run, "", "Delete with no propagations")
|
||||
delete(run, ['r', 1])
|
||||
check(run, "", "Delete from empty table")
|
||||
|
||||
insert(run, ['r', 1])
|
||||
insert(run, ['p', 1])
|
||||
check(run, "r(1) p(1) q(1)", "Insert into base table with 1 propagation")
|
||||
showdb(run)
|
||||
|
||||
delete(run, ['r', 1])
|
||||
check(run, "p(1)", "Delete from base table with 1 propagation")
|
||||
showdb(run)
|
||||
self.delete(run, ['r', 1])
|
||||
self.check(run, "p(1)", "Delete from base table with 1 propagation")
|
||||
|
||||
# multiple rules
|
||||
code = ("q(x) :- p(x), r(x)"
|
||||
"q(x) :- s(x)")
|
||||
insert(run, ['p', 1])
|
||||
insert(run, ['r', 1])
|
||||
showdb(run)
|
||||
check(run, "p(1) r(1) q(1)", "Insert: multiple rules")
|
||||
insert(run, ['s', 1])
|
||||
showdb(run)
|
||||
check(run, "p(1) r(1) s(1) q(1)", "Insert: duplicate conclusions")
|
||||
self.insert(run, ['p', 1])
|
||||
self.insert(run, ['r', 1])
|
||||
self.check(run, "p(1) r(1) q(1)", "Insert: multiple rules")
|
||||
self.insert(run, ['s', 1])
|
||||
self.check(run, "p(1) r(1) s(1) q(1)", "Insert: duplicate conclusions")
|
||||
|
||||
# body of length 1
|
||||
code = ("q(x) :- p(x)")
|
||||
run = prep_runtime(code, "**** Body length 1 tests ****")
|
||||
run = self.prep_runtime(code, "**** Body length 1 tests ****")
|
||||
|
||||
insert(run, ['p', 1])
|
||||
check(run, "p(1) q(1)", "Insert with body of size 1")
|
||||
|
||||
delete(run, ['p', 1])
|
||||
check(run, "", "Delete with body of size 1")
|
||||
self.insert(run, ['p', 1])
|
||||
self.check(run, "p(1) q(1)", "Insert with body of size 1")
|
||||
self.showdb(run)
|
||||
self.delete(run, ['p', 1])
|
||||
self.showdb(run)
|
||||
self.check(run, "", "Delete with body of size 1")
|
||||
|
||||
# existential variables
|
||||
code = ("q(x) :- p(x,y)")
|
||||
run = prep_runtime(code, "**** Existential variable tests ****")
|
||||
code = ("q(x) :- p(x), r(y)")
|
||||
run = self.prep_runtime(code, "**** Unary tables with existential ****")
|
||||
self.insert(run, ['p', 1])
|
||||
self.insert(run, ['r', 2])
|
||||
self.insert(run, ['r', 3])
|
||||
self.showdb(run)
|
||||
self.check(run, "p(1) r(2) r(3) q(1)",
|
||||
"Insert with unary table and existential")
|
||||
self.delete(run, ['r', 2])
|
||||
self.check(run, "p(1) r(3) q(1)",
|
||||
"Delete 1 with unary table and existential")
|
||||
self.delete(run, ['r', 3])
|
||||
self.check(run, "p(1)",
|
||||
"Delete all with unary table and existential")
|
||||
|
||||
insert(run, ['p', 1, 2])
|
||||
check(run, "p(1, 2) q(1)", "Insert: existential variable in body of size 1")
|
||||
delete(run, ['p', 1, 2])
|
||||
check(run, "", "Delete: existential variable in body of size 1")
|
||||
|
||||
def test_multi_arity_tables(self):
    """ Test rules whose tables have more than 1 argument """
    # Body of size 1 with an existential variable (y appears only in body).
    code = ("q(x) :- p(x,y)")
    run = self.prep_runtime(code, "**** Multiple-arity table tests ****")

    self.insert(run, ['p', 1, 2])
    self.check(run, "p(1, 2) q(1)",
               "Insert: existential variable in body of size 1")
    self.delete(run, ['p', 1, 2])
    self.check(run, "",
               "Delete: existential variable in body of size 1")

    # Join of two binary tables sharing both variables.
    code = ("q(x) :- p(x,y), r(y,x)")
    run = self.prep_runtime(code)
    self.insert(run, ['p', 1, 2])
    self.insert(run, ['r', 2, 1])
    self.check(run, "p(1, 2) r(2, 1) q(1)", "Insert: join in body of size 2")
    self.delete(run, ['p', 1, 2])
    self.check(run, "r(2, 1)", "Delete: join in body of size 2")
    # Two different existential bindings (y=2 and y=3) deriving the same
    # head q(1): deleting one binding must not retract the head.
    self.insert(run, ['p', 1, 2])
    self.insert(run, ['p', 1, 3])
    self.insert(run, ['r', 3, 1])
    self.check(run, "r(2, 1) r(3,1) p(1, 2) p(1, 3) q(1)",
               "Insert: multiple existential bindings for same head")
    self.delete(run, ['p', 1, 2])
    self.check(run, "r(2, 1) r(3,1) p(1, 3) q(1)",
               "Delete: multiple existential bindings for same head")

    # Longer join chain: p -> r -> s -> t.
    code = ("q(x,v) :- p(x,y), r(y,z), s(z,w), t(w,v)")
    run = self.prep_runtime(code)
    self.insert(run, ['p', 1, 10])
    self.insert(run, ['p', 1, 20])
    self.insert(run, ['r', 10, 100])
    self.insert(run, ['r', 20, 200])
    self.insert(run, ['s', 100, 1000])
    self.insert(run, ['s', 200, 2000])
    self.insert(run, ['t', 1000, 10000])
    self.insert(run, ['t', 2000, 20000])
    code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
            "t(1000, 10000) t(2000,20000) "
            "q(1,10000) q(1,20000)")
    self.check(run, code, "Insert: larger join")
    # Removing one link of the chain should retract exactly one conclusion.
    self.delete(run, ['t', 1000, 10000])
    code = ("p(1,10) p(1,20) r(10,100) r(20,200) s(100,1000) s(200,2000)"
            "t(2000,20000) "
            "q(1,20000)")
    self.check(run, code, "Delete: larger join")

    # Self-join: the same table appears twice in the body.
    code = ("q(x,y) :- p(x,z), p(z,y)")
    run = self.prep_runtime(code)
    self.insert(run, ['p', 1, 2])
    self.insert(run, ['p', 1, 3])
    self.insert(run, ['p', 2, 4])
    self.insert(run, ['p', 2, 5])
    self.check(run, 'p(1,2) p(1,3) p(2,4) p(2,5) q(1,4) q(1,5)',
               "Insert: self-join")
    self.delete(run, ['p', 2, 4])
    # NOTE(review): message argument was missing here; added for
    # consistency with every other check() call in this class.
    self.check(run, 'p(1,2) p(1,3) p(2,5) q(1,5)', "Delete: self-join")

    # Self-join where a single tuple joins with itself.
    code = ("q(x,z) :- p(x,y), p(y,z)")
    run = self.prep_runtime(code)
    self.insert(run, ['p', 1, 1])
    self.check(run, 'p(1,1) q(1,1)', "Insert: self-join on same data")

    # Triple self-join over a larger fact base.
    code = ("q(x,w) :- p(x,y), p(y,z), p(z,w)")
    run = self.prep_runtime(code)
    self.insert(run, ['p', 1, 1])
    self.insert(run, ['p', 1, 2])
    self.insert(run, ['p', 2, 2])
    self.insert(run, ['p', 2, 3])
    self.insert(run, ['p', 2, 4])
    self.insert(run, ['p', 2, 5])
    self.insert(run, ['p', 3, 3])
    self.insert(run, ['p', 3, 4])
    self.insert(run, ['p', 3, 5])
    self.insert(run, ['p', 3, 6])
    self.insert(run, ['p', 3, 7])
    code = ('p(1,1) p(1,2) p(2,2) p(2,3) p(2,4) p(2,5)'
            'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
            'q(1,1) q(1,2) q(2,2) q(2,3) q(2,4) q(2,5)'
            'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
            'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
            'q(2,6) q(2,7)')
    self.check(run, code, "Insert: larger self join")
    self.delete(run, ['p', 1, 1])
    self.delete(run, ['p', 2, 2])
    code = (' p(1,2) p(2,3) p(2,4) p(2,5)'
            'p(3,3) p(3,4) p(3,5) p(3,6) p(3,7)'
            ' q(2,3) q(2,4) q(2,5)'
            'q(3,3) q(3,4) q(3,5) q(3,6) q(3,7)'
            'q(1,3) q(1,4) q(1,5) q(1,6) q(1,7)'
            'q(2,6) q(2,7)')
    self.check(run, code, "Delete: larger self join")
|
||||
# Value types: string
|
||||
def test_value_types(self):
    """ Test the different value types """
    # string
    code = ("q(x) :- p(x), r(x)")
    run = self.prep_runtime(code, "String data type")

    self.insert(run, ['r', 'apple'])
    self.check(run, 'r("apple")', "String insert with no propagations")
    # Inserting a duplicate must not create a second row.
    self.insert(run, ['r', 'apple'])
    self.check(run, 'r("apple")',
               "Duplicate string insert with no propagations")

    self.delete(run, ['r', 'apple'])
    self.check(run, "", "Delete with no propagations")
    # Deleting from an already-empty table must be a no-op.
    self.delete(run, ['r', 'apple'])
    self.check(run, "", "Delete from empty table")

    self.insert(run, ['r', 'apple'])
    self.insert(run, ['p', 'apple'])
    self.check(run, 'r("apple") p("apple") q("apple")',
               "String insert with 1 propagation")

    self.delete(run, ['r', 'apple'])
    self.check(run, 'p("apple")', "String delete with 1 propagation")

    # float
    code = ("q(x) :- p(x), r(x)")
    run = self.prep_runtime(code, "Float data type")

    # NOTE(review): messages below said "String ..." (copy-paste from the
    # string section); corrected to "Float ..." to match the data tested.
    self.insert(run, ['r', 1.2])
    self.check(run, 'r(1.2)', "Float insert with no propagations")
    self.insert(run, ['r', 1.2])
    self.check(run, 'r(1.2)',
               "Duplicate float insert with no propagations")

    self.delete(run, ['r', 1.2])
    self.check(run, "", "Delete with no propagations")
    self.delete(run, ['r', 1.2])
    self.check(run, "", "Delete from empty table")

    self.insert(run, ['r', 1.2])
    self.insert(run, ['p', 1.2])
    self.check(run, 'r(1.2) p(1.2) q(1.2)',
               "Float insert with 1 propagation")

    self.delete(run, ['r', 1.2])
    self.check(run, 'p(1.2)', "Float delete with 1 propagation")
||||
# negation
|
||||
def test_proofs(self):
    """ Test if the proof computation is performed correctly. """
    def check_table_proofs(run, table, tuple_proof_dict, msg):
        # Wrap each raw proof (a list of variable bindings) in a
        # ProofCollection before handing the whole table to check_proofs.
        # NOTE(review): loop variable renamed from `tuple` to avoid
        # shadowing the builtin.
        for tup in tuple_proof_dict:
            tuple_proof_dict[tup] = \
                Database.ProofCollection(tuple_proof_dict[tup])
        self.check_proofs(run, {table: tuple_proof_dict}, msg)

    code = ("q(x) :- p(x,y)")
    run = self.prep_runtime(code, "**** Proof tests ****")

    self.insert(run, ['p', 1, 2])
    # q(1) must be supported by exactly one proof: the binding x=1, y=2.
    check_table_proofs(run, 'q', {(1,): [{u'x': 1, u'y': 2}]},
        'Simplest proof test')
|
||||
|
||||
def test_negation(self):
    """ Test negation """
    # Unary, single join
    code = ("q(x) :- p(x), not r(x)")
    run = self.prep_runtime(code, "Unary, single join")

    self.insert(run, ['p', 2])
    self.check(run, 'p(2) q(2)',
               "Insert into positive literal with propagation")
    self.delete(run, ['p', 2])
    self.check(run, '',
               "Delete from positive literal with propagation")

    # r(2) alone derives nothing, since p(2) is absent.
    self.insert(run, ['r', 2])
    self.check(run, 'r(2)',
               "Insert into negative literal without propagation")
    self.delete(run, ['r', 2])
    self.check(run, '',
               "Delete from negative literal without propagation")

    # r(2) blocks the negated literal, so q(2) must be retracted.
    self.insert(run, ['p', 2])
    self.insert(run, ['r', 2])
    self.check(run, 'p(2) r(2)',
               "Insert into negative literal with propagation")

    # Removing the blocker re-derives q(2).
    self.delete(run, ['r', 2])
    self.check(run, 'q(2) p(2)',
               "Delete from negative literal with propagation")

    # Unary, multiple joins
    code = ("s(x) :- p(x), not r(x), q(y), not t(y)")
    run = self.prep_runtime(code, "Unary, multiple join")
    self.insert(run, ['p', 1])
    self.insert(run, ['q', 2])
    self.check(run, 'p(1) q(2) s(1)',
               'Insert with two negative literals')

    # r(3) matches neither negated literal's binding, so s(1) survives.
    self.insert(run, ['r', 3])
    self.check(run, 'p(1) q(2) s(1) r(3)',
               'Ineffectual insert with 2 negative literals')
    self.insert(run, ['r', 1])
    self.check(run, 'p(1) q(2) r(3) r(1)',
               'Insert into existentially quantified negative literal with propagation. ')
    # A second blocker for the same (already retracted) conclusion.
    self.insert(run, ['t', 2])
    self.check(run, 'p(1) q(2) r(3) r(1) t(2)',
               'Insert into negative literal producing extra blocker for proof.')
    # Both blockers must be gone before s(1) reappears.
    self.delete(run, ['t', 2])
    self.check(run, 'p(1) q(2) r(3) r(1)',
               'Delete first blocker from proof')
    self.delete(run, ['r', 1])
    self.check(run, 'p(1) q(2) r(3) s(1)',
               'Delete second blocker from proof')

    # Non-unary
    code = ("p(x, v) :- q(x,z), r(z, w), not s(x, w), u(w,v)")
    run = self.prep_runtime(code, "Non-unary")
    self.insert(run, ['q', 1, 2])
    self.insert(run, ['r', 2, 3])
    self.insert(run, ['r', 2, 4])
    self.insert(run, ['u', 3, 5])
    self.insert(run, ['u', 4, 6])
    self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) p(1,5) p(1,6)',
               'Insert with non-unary negative literal')

    # s(1,3) blocks the w=3 derivation only, so p(1,5) is retracted
    # while p(1,6) remains.
    self.insert(run, ['s', 1, 3])
    self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) s(1,3) p(1,6)',
               'Insert into non-unary negative with propagation')

    self.insert(run, ['s', 1, 4])
    self.check(run, 'q(1,2) r(2,3) r(2,4) u(3,5) u(4,6) s(1,3) s(1,4)',
               'Insert into non-unary with different propagation')
|
||||
|
||||
# Run the whole suite when executed as a script (no effect on import).
if __name__ == '__main__':
    unittest.main()
|
||||
|
|
|
@ -1,3 +1,24 @@
|
|||
q(x) :- p(x), r(x)
|
||||
|
||||
error(vm) :- virtual_machine(vm), network(vm, net), not public(net), owner(vm, owner), not owned_by_some_group(owner, net)
|
||||
|
||||
owned_by_some_group(owner, object) :- owner(object, user), member(user, group), member(owner, group)
|
||||
|
||||
virtual_machine('vm1')
|
||||
virtual_machine('vm2')
|
||||
network('vm1', 'net_private')
|
||||
network('vm2', 'net_public')
|
||||
|
||||
public('net_public')
|
||||
|
||||
owner('vm1', 'tim')
|
||||
owner('vm2', 'pete')
|
||||
owner('net_private', 'martin')
|
||||
|
||||
member('pete', 'congress')
|
||||
member('tim', 'congress')
|
||||
member('martin', 'congress')
|
||||
member('pierre', 'congress')
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue