diff --git a/src/policy/Congress.g b/src/policy/Congress.g index 5c40f49db..d883efea6 100644 --- a/src/policy/Congress.g +++ b/src/policy/Congress.g @@ -99,7 +99,7 @@ SIGN : '+' | '-' ; -ID : ('a'..'z'|'A'..'Z'|'_') ('a'..'z'|'A'..'Z'|'0'..'9'|'_')* +ID : ('a'..'z'|'A'..'Z'|'_'|'.') ('a'..'z'|'A'..'Z'|'0'..'9'|'_'|'.')* ; INT : '0'..'9'+ diff --git a/src/policy/compile.py b/src/policy/compile.py index 24f895207..a6017a65b 100755 --- a/src/policy/compile.py +++ b/src/policy/compile.py @@ -120,7 +120,10 @@ class ObjectConstant (Term): self.location = location def __str__(self): - return str(self.name) + if self.type == ObjectConstant.STRING: + return '"' + str(self.name) + '"' + else: + return str(self.name) def __repr__(self): # Use repr to hash rule--can't include location @@ -233,6 +236,36 @@ class Atom (object): """ Does NOT make copy """ return self + def invert_update(self): + """ If end of table name is + or -, return a copy after switching + the copy's sign. + Does not make a copy if table name does not end in + or -. """ + if self.table.endswith('+'): + suffix = '-' + elif self.table.endswith('-'): + suffix = '+' + else: + suffix = None + + if suffix is None: + return self + else: + new = copy.copy(self) + new.table = new.table[:-1] + suffix + return new + + def drop_update(self): + """ If end of table name is + or -, return a copy without the sign. + If table name does not end in + or -, make no copy. """ + if self.table.endswith('+') or self.table.endswith('-'): + new = copy.copy(self) + new.table = new.table[:-1] + return new + else: + return self + + def tablename(self): + return self.table class Literal(Atom): """ Represents either a negated atom or an atom. """ @@ -319,10 +352,8 @@ class Rule (object): def is_rule(self): return True - def plug(self, binding, caller=None): - newhead = self.head.plug(binding, caller=caller) - newbody = [lit.plug(binding, caller=caller) for lit in self.body] - return Rule(newhead, newbody) + def tablename(self): + return self.head.table def variables(self): vs = self.head.variables() @@ -336,6 +367,21 @@ class Rule (object): vs |= lit.variable_names() return vs + def plug(self, binding, caller=None): + newhead = self.head.plug(binding, caller=caller) + newbody = [lit.plug(binding, caller=caller) for lit in self.body] + return Rule(newhead, newbody) + + def invert_update(self): + new = copy.copy(self) + new.head = self.head.invert_update() + return new + + def drop_update(self): + new = copy.copy(self) + new.head = self.head.drop_update() + return new + def formulas_to_string(formulas): """ Takes an iterable of compiler sentence objects and returns a @@ -345,6 +391,17 @@ def formulas_to_string(formulas): return "None" return " ".join([str(formula) for formula in formulas]) +def is_update(x): + """ Returns T iff x is a formula or tablename representing an update. """ + if isinstance(x, basestring): + return x.endswith('+') or x.endswith('-') + elif isinstance(x, Atom): + return is_update(x.table) + elif isinstance(x, Rule): + return is_update(x.head.table) + else: + return False + ############################################################################## ## Compiler ############################################################################## @@ -392,75 +449,6 @@ class Compiler (object): errors = [str(err) for err in self.errors] raise CongressException('Compiler found errors:' + '\n'.join(errors)) - def compute_delta_rules(self): - # logging.debug("self.theory: {}".format([str(x) for x in self.theory])) - self.delta_rules = compute_delta_rules(self.theory) - - -def eliminate_self_joins(theory): - """ Modify THEORY so that all self-joins have been eliminated. """ - def new_table_name(name, arity, index): - return "___{}_{}_{}".format(name, arity, index) - def n_variables(n): - vars = [] - for i in xrange(0, n): - vars.append("x" + str(i)) - return vars - # dict from (table name, arity) tuple to - # max num of occurrences of self-joins in any rule - global_self_joins = {} - # dict from (table name, arity) to # of args for - arities = {} - # remove self-joins from rules - for rule in theory: - if rule.is_atom(): - continue - logging.debug("eliminating self joins from {}".format(rule)) - occurrences = {} # for just this rule - for atom in rule.body: - table = atom.table - arity = len(atom.arguments) - tablearity = (table, arity) - if tablearity not in occurrences: - occurrences[tablearity] = 1 - else: - # change name of atom - atom.table = new_table_name(table, arity, - occurrences[tablearity]) - # update our counters - occurrences[tablearity] += 1 - if tablearity not in global_self_joins: - global_self_joins[tablearity] = 1 - else: - global_self_joins[tablearity] = \ - max(occurrences[tablearity] - 1, - global_self_joins[tablearity]) - logging.debug("final rule: {}".format(str(rule))) - # add definitions for new tables - for tablearity in global_self_joins: - table = tablearity[0] - arity = tablearity[1] - for i in xrange(1, global_self_joins[tablearity] + 1): - newtable = new_table_name(table, arity, i) - args = [Variable(var) for var in n_variables(arity)] - head = Atom(newtable, args) - body = [Atom(table, args)] - theory.append(Rule(head, body)) - logging.debug("Adding rule {}".format(str(theory[-1]))) - return theory - -def compute_delta_rules(theory): - eliminate_self_joins(theory) - delta_rules = [] - for rule in theory: - if rule.is_atom(): - continue - for literal in rule.body: - newbody = [lit for lit in rule.body if lit is not literal] - delta_rules.append( - runtime.DeltaRule(literal, rule.head, newbody, rule)) - return delta_rules - ############################################################################## ## External syntax: datalog ############################################################################## diff --git a/src/policy/runtime.py b/src/policy/runtime.py index c9534bcce..64f73ef80 100644 --- a/src/policy/runtime.py +++ b/src/policy/runtime.py @@ -203,8 +203,8 @@ class TopDownTheory(object): # in QUERY. bindings = self.top_down_evaluation(query.variables(), literals, find_all=find_all) - logging.debug("Top_down_evaluation returned: {}".format( - str(bindings))) + # logging.debug("Top_down_evaluation returned: {}".format( + # str(bindings))) if len(bindings) > 0: logging.debug("Found answer {}".format( "[" + ",".join([str(query.plug(x)) @@ -251,7 +251,24 @@ class TopDownTheory(object): find_all=find_all, save=lambda lit,binding: lit.table in tablenames) results = [compile.Rule(output.plug(abd.binding), abd.support) for abd in abductions] - logging.debug("abduction result: " + iterstr(results)) + logging.debug("abduction result:") + logging.debug("\n".join([str(x) for x in results])) + return results + + def consequences(self, filter=None): + """ Return all the true instances of any table that is defined + in this theory (according to DEFINED_TABLE_NAMES). """ + results = set() + # create queries: need table names and arities + for table in self.defined_table_names(): + if filter is None or filter(table): + arity = self.arity(table) + vs = [] + for i in xrange(0, arity): + vs.append("x" + str(i)) + vs = [compile.Variable(var) for var in vs] + query = compile.Atom(table, vs) + results |= set(self.select(query)) return results def top_down_evaluation(self, variables, literals, @@ -421,25 +438,25 @@ class TopDownTheory(object): return finished def print_call(self, literal, binding, depth): - self.log(literal.table, "{}Call: {} with {}".format("| "*depth, - literal.plug(binding), str(binding))) + self.log(literal.table, "{}Call: {}".format("| "*depth, + literal.plug(binding))) def print_exit(self, literal, binding, depth): - self.log(literal.table, "{}Exit: {} with {}".format("| "*depth, - literal.plug(binding), str(binding))) + self.log(literal.table, "{}Exit: {}".format("| "*depth, + literal.plug(binding))) def print_save(self, literal, binding, depth): - self.log(literal.table, "{}Save: {} with {}".format("| "*depth, - literal.plug(binding), str(binding))) + self.log(literal.table, "{}Save: {}".format("| "*depth, + literal.plug(binding))) def print_fail(self, literal, binding, depth): - self.log(literal.table, "{}Fail: {} with {}".format("| "*depth, - literal.plug(binding), str(binding))) + self.log(literal.table, "{}Fail: {}".format("| "*depth, + literal.plug(binding))) return False def print_redo(self, literal, binding, depth): - self.log(literal.table, "{}Redo: {} with {}".format("| "*depth, - literal.plug(binding), str(binding))) + self.log(literal.table, "{}Redo: {}".format("| "*depth, + literal.plug(binding))) return False def log(self, table, msg, depth=0): @@ -452,6 +469,23 @@ class TopDownTheory(object): # lambda (index): # compile.Variable("x" + str(index)), dictionary=dictionary) + def arity(self, tablename): + """ Return the number of arguments TABLENAME takes or None if + unknown because TABLENAME is not defined here. """ + # assuming a fixed arity for all tables + formulas = self.head_index(tablename) + if len(formulas) == 0: + return None + first = formulas[0] + # should probably have an overridable function for computing + # the arguments of a head. Instead we assume heads have .arguments + return len(self.head(first).arguments) + + def defined_table_names(self): + """ This routine returns the list of all table names that are + defined/written to in this theory. """ + return self.contents.keys() + def head_index(self, table): """ This routine must return all the formulas pertinent for top-down evaluation when a literal with TABLE is at the top @@ -680,6 +714,9 @@ class Database(TopDownTheory): # overloads for TopDownTheory so we can properly use the # top_down_evaluation routines + def defined_table_names(self): + return self.table_names() + def head_index(self, table): if table not in self.data: return [] @@ -705,6 +742,7 @@ class Database(TopDownTheory): table, str(dbtuple))) if table not in self.data: self.data[table] = [dbtuple] + return # self.log(table, "First tuple in table {}".format(table)) else: # self.log(table, "Not first tuple in table {}".format(table)) @@ -759,6 +797,7 @@ class NonrecursiveRuleTheory(TopDownTheory): return str(self.contents) def insert(self, rule): + """ Insert RULE and return True iff the theory changed. """ if isinstance(rule, compile.Atom): rule = compile.Rule(rule, [], rule.location) self.log(rule.head.table, "Insert: {}".format(str(rule))) @@ -766,34 +805,65 @@ class NonrecursiveRuleTheory(TopDownTheory): if table in self.contents: if rule not in self.contents[table]: # eliminate dups self.contents[table].append(rule) + return True + return False else: self.contents[table] = [rule] + return True def delete(self, rule): + """ Delete RULE and return True iff the theory changed. """ if isinstance(rule, compile.Atom): rule = compile.Rule(rule, [], rule.location) self.log(rule.head.table, "Delete: {}".format(str(rule))) table = rule.head.table if table in self.contents: - self.contents[table].remove(rule) + try: + self.contents[table].remove(rule) + return True + except ValueError: + return False + return False + + def define(self, rules): + """ Empties and then inserts RULES. """ + self.empty() + for rule in rules: + self.insert(rule) + + def empty(self): + """ Deletes contents of theory. """ + self.contents = {} def log(self, table, msg, depth=0): self.tracer.log(table, "NRT: " + msg, depth) class DeltaRuleTheory (object): """ A collection of DeltaRules. """ - def __init__(self, rules=None): + def __init__(self): # dictionary from table name to list of rules with that table as trigger self.contents = {} + # dictionary from delta_rule to the rule from which it was derived + self.originals = set() # list of theories implicitly included in this one self.includes = [] # dictionary from table name to number of rules with that table in head self.views = {} - if rules is not None: - for rule in rules: - self.insert(rule) - def insert(self, delta): + def insert(self, rule): + """ Insert a compile.Rule into the theory. + Return True iff the theory changed. """ + assert isinstance(rule, compile.Rule), \ + "DeltaRuleTheory only takes rules" + if rule in self.originals: + return False + for delta in self.compute_delta_rules([rule]): + self.insert_delta(delta) + self.originals.add(rule) + return True + + def insert_delta(self, delta): + """ Insert a delta rule. """ if delta.head.table in self.views: self.views[delta.head.table] += 1 else: @@ -804,7 +874,18 @@ class DeltaRuleTheory (object): else: self.contents[delta.trigger.table].append(delta) - def delete(self, delta): + def delete(self, rule): + """ Delete a compile.Rule from theory. + Assumes that COMPUTE_DELTA_RULES is deterministic. + Returns True iff the theory changed. """ + if rule not in self.originals: + return False + for delta in self.compute_delta_rules([rule]): + self.delete_delta(delta) + self.originals.remove(rule) + return True + + def delete_delta(self, delta): if delta.head.table in self.views: self.views[delta.head.table] -= 1 if self.views[delta.head.table] == 0: @@ -831,6 +912,72 @@ class DeltaRuleTheory (object): def is_view(self, x): return x in self.views + @classmethod + def eliminate_self_joins(cls, theory): + """ Modify THEORY so that all self-joins have been eliminated. """ + def new_table_name(name, arity, index): + return "___{}_{}_{}".format(name, arity, index) + def n_variables(n): + vars = [] + for i in xrange(0, n): + vars.append("x" + str(i)) + return vars + # dict from (table name, arity) tuple to + # max num of occurrences of self-joins in any rule + global_self_joins = {} + # dict from (table name, arity) to # of args for + arities = {} + # remove self-joins from rules + for rule in theory: + if rule.is_atom(): + continue + logging.debug("eliminating self joins from {}".format(rule)) + occurrences = {} # for just this rule + for atom in rule.body: + table = atom.table + arity = len(atom.arguments) + tablearity = (table, arity) + if tablearity not in occurrences: + occurrences[tablearity] = 1 + else: + # change name of atom + atom.table = new_table_name(table, arity, + occurrences[tablearity]) + # update our counters + occurrences[tablearity] += 1 + if tablearity not in global_self_joins: + global_self_joins[tablearity] = 1 + else: + global_self_joins[tablearity] = \ + max(occurrences[tablearity] - 1, + global_self_joins[tablearity]) + logging.debug("final rule: {}".format(str(rule))) + # add definitions for new tables + for tablearity in global_self_joins: + table = tablearity[0] + arity = tablearity[1] + for i in xrange(1, global_self_joins[tablearity] + 1): + newtable = new_table_name(table, arity, i) + args = [compile.Variable(var) for var in n_variables(arity)] + head = compile.Atom(newtable, args) + body = [compile.Atom(table, args)] + theory.append(compile.Rule(head, body)) + logging.debug("Adding rule {}".format(str(theory[-1]))) + return theory + + @classmethod + def compute_delta_rules(cls, theory): + theory = cls.eliminate_self_joins(theory) + delta_rules = [] + for rule in theory: + if rule.is_atom(): + continue + for literal in rule.body: + newbody = [lit for lit in rule.body if lit is not literal] + delta_rules.append( + DeltaRule(literal, rule.head, newbody, rule)) + return delta_rules + class MaterializedRuleTheory(TopDownTheory): """ A theory that stores the table contents explicitly. Recursive rules are allowed. """ @@ -848,24 +995,29 @@ class MaterializedRuleTheory(TopDownTheory): ############### External Interface ############### def select(self, query): + """ Returns list of instances of QUERY true in the theory. """ assert (isinstance(query, compile.Atom) or isinstance(query, compile.Rule)), \ "Select requires a formula" return self.database.select(query) def insert(self, formula): + """ Insert FORMULA. Returns True iff the theory changed. """ assert (isinstance(formula, compile.Atom) or isinstance(formula, compile.Rule)), \ "Insert requires a formula" return self.modify(formula, is_insert=True) def delete(self, formula): + """ Delete FORMULA. Returns True iff the theory changed. """ assert (isinstance(formula, compile.Atom) or isinstance(formula, compile.Rule)), \ "Delete requires a formula" return self.modify(formula, is_insert=False) def explain(self, query, tablenames, find_all): + """ Returns None if QUERY is False in theory. Otherwise returns + a list of proofs that QUERY is true. """ assert isinstance(query, compile.Atom), \ "Explain requires an atom" # ignoring TABLENAMES and FIND_ALL @@ -904,7 +1056,8 @@ class MaterializedRuleTheory(TopDownTheory): return Proof(query, subproofs) def modify(self, formula, is_insert=True): - """ Event handler for arbitrary insertion/deletion (rules and facts). """ + """ Insert or delete a rule or fact. + Returns True iff the theory changed.""" if is_insert: text = "Insert" else: @@ -913,19 +1066,20 @@ class MaterializedRuleTheory(TopDownTheory): assert not self.is_view(formula.table), \ "Cannot directly modify tables computed from other tables" self.log(formula.table, "{}: {}".format(text, str(formula))) - self.modify_tables_with_atom(formula, is_insert=is_insert) - return None + return self.modify_tables_with_atom(formula, is_insert=is_insert) else: self.modify_tables_with_rule( - formula, is_insert=is_insert) + formula, is_insert=is_insert) self.log(formula.head.table, "{}: {}".format(text, str(formula))) - for delta_rule in compile.compute_delta_rules([formula]): - self.delta_rules.modify(delta_rule, is_insert=is_insert) - return None + if is_insert: + return self.delta_rules.insert(formula) + else: + return self.delta_rules.delete(formula) def modify_tables_with_rule(self, rule, is_insert): """ Add rule (not a DeltaRule) to collection and update - tables as appropriate. """ + tables as appropriate. + Returns True iff atoms were added to the tables. """ # don't have separate queue since inserting/deleting a rule doesn't generate any # new rule insertion/deletion events bindings = self.database.top_down_evaluation( @@ -934,35 +1088,23 @@ class MaterializedRuleTheory(TopDownTheory): ",".join([str(x) for x in bindings]))) self.process_new_bindings(bindings, rule.head, is_insert, rule) self.process_queue() - - # def modify_tables_with_tuple(self, table, row, is_insert): - # """ Event handler for a tuple insertion/deletion. - # TABLE is the name of a table (a string). - # TUPLE is a Python tuple. - # IS_INSERT is True or False.""" - # if is_insert: - # text = "Inserting into queue" - # else: - # text = "Deleting from queue" - # self.log(table, "{}: table {} with tuple {}".format( - # text, table, str(row))) - # if not isinstance(row, Database.DBTuple): - # row = Database.DBTuple(row) - # self.log(table, "{}: table {} with tuple {}".format( - # text, table, str(row))) - # self.queue.enqueue(Event(table, row, insert=is_insert)) - # self.process_queue() + return len(bindings) > 0 def modify_tables_with_atom(self, atom, is_insert): """ Event handler for atom insertion/deletion. - IS_INSERT is True or False.""" + IS_INSERT is True or False. + Returns True iff the tables changed.""" if is_insert: text = "Inserting into queue" else: text = "Deleting from queue" self.log(atom.table, "{}: {}".format(text, str(atom))) - self.queue.enqueue(Event(atom=atom, insert=is_insert)) + event = Event(atom=atom, insert=is_insert) + if self.database.is_noop(event): + return False + self.queue.enqueue(event) self.process_queue() + return True ############### Data manipulation ############### @@ -1109,6 +1251,12 @@ class Runtime (object): assert name in self.theory, "Unknown target {}".format(name) return self.theory[name] + def get_actions(self): + """ Return a list of the names of action tables. """ + actionth = self.theory[self.ACTION_THEORY] + actions = actionth.select(compile.parse1('action(x)')) + return [action.arguments[0].name for action in actions] + def log(self, table, msg, depth=0): self.tracer.log(table, "RT: " + msg, depth) @@ -1179,14 +1327,25 @@ class Runtime (object): else: return self.delete_obj(formula, self.get_target(target)) - def remediate(self, formula, target=None): + def remediate(self, formula): """ Event handler for remediation. """ if isinstance(formula, basestring): - return self.remediate_string(formula, self.get_target(target)) + return self.remediate_string(formula) elif isinstance(formula, tuple): - return self.remediate_tuple(formula, self.get_target(target)) + return self.remediate_tuple(formula) else: - return self.remediate_obj(formula, self.get_target(target)) + return self.remediate_obj(formula) + + def project(self, query, sequence): + """ Event handler for projection: the computation of a query given an + action sequence. [Eventually, we will want to support a sequence of + {action, rule insert/delete, atom insert/delete}. Holding + off only because no syntactic way of differentiating those + within the language. May need modals/function constants.] """ + if isinstance(query, basestring) and isinstance(sequence, basestring): + return self.project_string(query, sequence) + else: + return self.project_obj(query, sequence) # Maybe implement one day # def select_if(self, query, temporary_data): @@ -1205,9 +1364,7 @@ class Runtime (object): ## Arguments that are strings are suffixed with _string. ## All other arguments are instances of Theory, Atom, etc. - def select_obj(self, query, theory): - return theory.select(query) - + # select def select_string(self, policy_string, theory): policy = compile.parse(policy_string) assert len(policy) == 1, \ @@ -1219,9 +1376,10 @@ class Runtime (object): def select_tuple(self, tuple, theory): return self.select_obj(compile.Atom.create_from_iter(tuple), theory) - def explain_obj(self, query, tablenames, find_all, theory): - return theory.explain(query, tablenames, find_all) + def select_obj(self, query, theory): + return theory.select(query) + # explain def explain_string(self, query_string, tablenames, find_all, theory): policy = compile.parse(query_string) assert len(policy) == 1, "Queries can have only 1 statement" @@ -1232,9 +1390,10 @@ class Runtime (object): self.explain_obj(compile.Atom.create_from_iter(tuple), tablenames, find_all, theory) - def insert_obj(self, formula, theory): - return theory.insert(formula) + def explain_obj(self, query, tablenames, find_all, theory): + return theory.explain(query, tablenames, find_all) + # insert def insert_string(self, policy_string, theory): policy = compile.parse(policy_string) # TODO: send entire parsed theory so that e.g. self-join elim @@ -1246,9 +1405,10 @@ class Runtime (object): def insert_tuple(self, tuple, theory): self.insert_obj(compile.Atom.create_from_iter(tuple), theory) - def delete_obj(self, formula, theory): - theory.delete(formula) + def insert_obj(self, formula, theory): + return theory.insert(formula) + # delete def delete_string(self, policy_string, theory): policy = compile.parse(policy_string) for formula in policy: @@ -1257,7 +1417,19 @@ class Runtime (object): def delete_tuple(self, tuple, theory): self.delete_obj(compile.Atom.create_from_iter(tuple), theory) - def remediate_obj(self, formula, theory): + def delete_obj(self, formula, theory): + theory.delete(formula) + + # remediate + def remediate_string(self, policy_string): + policy = compile.parse(policy_string) + assert len(policy) == 1, "Queries can have only 1 statement" + return compile.formulas_to_string(self.remediate_obj(policy[0])) + + def remediate_tuple(self, tuple, theory): + self.remediate_obj(compile.Atom.create_from_iter(tuple)) + + def remediate_obj(self, formula): """ Find a collection of action invocations that if executed result in FORMULA becoming false. """ actionth = self.theory[self.ACTION_THEORY] @@ -1284,8 +1456,7 @@ class Runtime (object): leaf.table in base_tables)] logging.debug("Leaves: {}".format(iterstr(leaves))) # Query action theory for abductions of negated base tables - actions = actionth.select(compile.parse1('action(x)')) - actions = [action.arguments[0].name for action in actions] + actions = self.get_actions() results = [] for lit in leaves: goal = lit.make_positive() @@ -1299,11 +1470,101 @@ class Runtime (object): results.append(abduction) return results - def remediate_string(self, policy_string, theory): - policy = compile.parse(policy_string) - assert len(policy) == 1, "Queries can have only 1 statement" - return compile.formulas_to_string(self.remediate_obj(policy[0], theory)) + # project + def project_string(self, query, sequence): + query = compile.parse1(query) + sequence = compile.parse(sequence) + result = self.project_obj(query, sequence) + return compile.formulas_to_string(result) - def remediate_tuple(self, tuple, theory): - self.remediate_obj(compile.Atom.create_from_iter(tuple), theory) + def project_obj(self, query, sequence): + assert (isinstance(query, compile.Rule) or + isinstance(query, compile.Atom)), "Query must be formula" + # Each action is represented as a rule with the actual action + # in the head and its supporting data (e.g. options) in the body + assert all(isinstance(x, compile.Rule) or isinstance(x, compile.Atom) + for x in sequence), "Sequence must be an iterable of Rules" + actth = self.theory[self.ACTION_THEORY] + clsth = self.theory[self.CLASSIFY_THEORY] + # apply changes to the state + newth = NonrecursiveRuleTheory() + newth.tracer.trace('*') + actth.includes.append(newth) + actions = self.get_actions() + self.log(query.tablename(), "Actions: " + str(actions)) + change_sequence = [] # a list of lists of updates + for formula in sequence: + if formula.is_atom(): + tablename = formula.table + else: + tablename = formula.head.table + if tablename in actions: + self.log(tablename, "* Projecting " + str(formula)) + # add action to theory + if formula.is_atom(): + newth.define([formula]) + else: + newth.define([formula.head] + formula.body) + # compute updates caused by action + updates = actth.consequences(compile.is_update) + updates = self.resolve_conflicts(updates) + else: + updates = [formula] + # apply and remember each update-set + changes = [] + for update in updates: + undo = self.update_classifier(update) + if undo is not None: + changes.append(undo) + change_sequence.append(changes) + + # query the resulting state + result = clsth.select(query) + self.log(query.tablename(), "Result of {} is {}".format( + str(query), iterstr(result))) + # rollback the changes: in the reverse order we applied them in + self.log(query.tablename(), "* Rolling back") + actth.includes.remove(newth) + for changes in reversed(change_sequence): + for undo in reversed(changes): + self.update_classifier(undo) + return result + + def update_classifier(self, delta): + """ Takes an atom/rule DELTA with update head table + (i.e. ending in + or -) and inserts/deletes, respectively, + that atom/rule into CLASSIFY_THEORY after stripping + the +/-. Returns None if DELTA had no effect on the + current state or an atom/rule that when given to + MODIFY_STATE will produce the original state. """ + self.log(None, "Applying update {}".format(str(delta))) + clsth = self.theory[self.CLASSIFY_THEORY] + isinsert = delta.tablename().endswith('+') + newdelta = delta.drop_update() + if isinsert: + changed = clsth.insert(newdelta) + else: + changed = clsth.delete(newdelta) + if changed: + return delta.invert_update() + else: + return None + + def resolve_conflicts(self, atoms): + """ If p+(args) and p-(args) are present, removes the p-(args). """ + neg = set() + result = set() + # split atoms into NEG and RESULT + for atom in atoms: + if atom.table.endswith('+'): + result.add(atom) + elif atom.table.endswith('-'): + neg.add(atom) + else: + result.add(atom) + # add elems from NEG only if their inverted version not in RESULT + for atom in neg: + if atom.invert_update() not in result: # slow: copying ATOM here + result.add(atom) + return result diff --git a/src/policy/tests/test_runtime.py b/src/policy/tests/test_runtime.py index 28affbe99..34dc79c19 100644 --- a/src/policy/tests/test_runtime.py +++ b/src/policy/tests/test_runtime.py @@ -7,6 +7,7 @@ from policy import runtime from policy import unify from policy.runtime import Database import logging +import os class TestRuntime(unittest.TestCase): @@ -58,6 +59,9 @@ class TestRuntime(unittest.TestCase): self.assertTrue(len(extra) == 0 and len(missing) == 0, msg) def check(self, run, correct_database_code, msg=None): + """ Check that runtime RUN's classify theory database is + equal to CORRECT_DATABASE_CODE. Should rename this function to + 'check_run_database' or something similar. """ # extract correct answer from correct_database_code self.open(msg) correct_database = self.string_to_database(correct_database_code) @@ -925,6 +929,31 @@ class TestRuntime(unittest.TestCase): 'p+(x) :- q(x), q(x1)', "Existential variables with name collision") + def test_nonrecursive_consequences(self): + """ Test consequence computation for nonrecursive rule theory """ + def check(code, correct, msg): + # We're interacting directly with the runtime's underlying + # theory b/c we haven't decided whether consequences should + # be a top-level API call. + run = self.prep_runtime() + actth = runtime.Runtime.ACTION_THEORY + run.insert(code, target=actth) + actual = run.theory[actth].consequences() + # convert result to string, since check_same expects strings + actual = compile.formulas_to_string(actual) + self.check_same(actual, correct, msg) + + code = ('p+(x) :- q(x)' + 'q(1)' + 'q(2)') + check(code, 'p+(1) p+(2) q(1) q(2)', 'Monadic') + + code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'q(1)' + 'q(2)') + check(code, 'p+(1) p+(2) q(1) q(2)', 'Monadic with empty tables') + def test_remediation(self): """Test remediation computation""" def check(action_code, classify_code, query, correct, msg): @@ -964,7 +993,121 @@ class TestRuntime(unittest.TestCase): 'p-(1) :- a(1) q-(1) :- b(1)', 'Monadic, two conditions, two actions') + def test_projection(self): + """ Test projection: the computation of a query given a sequence of + actions. """ + def create(action_code, class_code): + run = self.prep_runtime() + actth = run.ACTION_THEORY + clsth = run.CLASSIFY_THEORY + run.insert(action_code, target=actth) + run.insert(class_code, target=clsth) + return run + def check(run, action_sequence, query, correct, original_db, msg): + actual = run.project(query, action_sequence) + self.check_equal(actual, correct, msg) + self.check(run, original_db, msg) + # Simple + action_code = ('p+(x) :- q(x)' + 'action("q")') + classify_code = 'p(2)' # just some other data present + run = create(action_code, classify_code) + action_sequence = 'q(1)' + check(run, action_sequence, 'p(x)', 'p(1) p(2)', + classify_code, 'Simple') + + # Noop does not break rollback + action_code = ('p-(x) :- q(x)' + 'action("q")') + classify_code = ('') + run = create(action_code, classify_code) + action_sequence = 'q(1)' + check(run, action_sequence, 'p(x)', '', + classify_code, "Rollback handles Noop") + + # insertion takes precedence over deletion + action_code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'action("q")') + classify_code = ('') + run = create(action_code, classify_code) + # ordered so that consequences will be p+(1) p-(1) + action_sequence = 'q(1) :- r(1)' + check(run, action_sequence, 'p(x)', 'p(1)', + classify_code, "Deletion before insertion") + + # multiple action sequences 1 + action_code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'action("q")' + 'action("r")') + classify_code = ('') + run = create(action_code, classify_code) + action_sequence = 'q(1) r(1)' + check(run, action_sequence, 'p(x)', '', + classify_code, "Multiple actions: inversion from {}") + + # multiple action sequences 2 + action_code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'action("q")' + 'action("r")') + classify_code = ('p(1)') + run = create(action_code, classify_code) + action_sequence = 'q(1) r(1)' + check(run, action_sequence, 'p(x)', '', + classify_code, + "Multiple actions: inversion from p(1), first is noop") + + # multiple action sequences 3 + action_code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'action("q")' + 'action("r")') + classify_code = ('p(1)') + run = create(action_code, classify_code) + action_sequence = 'r(1) q(1)' + check(run, action_sequence, 'p(x)', 'p(1)', + classify_code, + "Multiple actions: inversion from p(1), first is not noop") + + # multiple action sequences 4 + action_code = ('p+(x) :- q(x)' + 'p-(x) :- r(x)' + 'action("q")' + 'action("r")') + classify_code = ('') + run = create(action_code, classify_code) + action_sequence = 'r(1) q(1)' + check(run, action_sequence, 'p(x)', 'p(1)', + classify_code, + "Multiple actions: inversion from {}, first is not noop") + + # Action with additional info + action_code = ('p+(x,z) :- q(x,y), r(y,z)' + 'action("q") action("r")') + classify_code = 'p(1,2)' + run = create(action_code, classify_code) + action_sequence = 'q(1,2) :- r(2,3)' + check(run, action_sequence, 'p(x,y)', 'p(1,2) p(1,3)', + classify_code, 'Action with additional info') + + # State update + action_code = '' + classify_code = 'p(1)' + run = create(action_code, classify_code) + action_sequence = 'p+(2)' + check(run, action_sequence, 'p(x)', 'p(1) p(2)', + classify_code, 'State update') + + # Rule update + action_code = '' + classify_code = 'q(1)' + run = create(action_code, classify_code) + action_sequence = 'p+(x) :- q(x)' + check(run, action_sequence, 'p(x)', 'p(1)', + classify_code, 'Rule update') if __name__ == '__main__': unittest.main()