congress/src/policy/compile.py

671 lines
22 KiB
Python
Executable File

#! /usr/bin/python
import sys
sys.path.insert(0, '/home/thinrichs/congress/thirdparty')
#sys.path.insert(0, '/opt/python-antlr3')
import optparse
import CongressLexer
import CongressParser
import antlr3
import logging
import copy
import runtime
class CongressException (Exception):
def __init__(self, msg, obj=None, line=None, col=None):
Exception.__init__(self, msg)
self.obj = obj
self.location = Location(line=line, col=col, obj=obj)
def __str__(self):
s = str(self.location)
if len(s) > 0:
s = " at" + s
return Exception.__str__(self) + s
##############################################################################
## Internal representation of policy language
##############################################################################
class Location (object):
""" A location in the program source code. """
def __init__(self, line=None, col=None, obj=None):
self.line = None
self.col = None
try:
self.line = obj.location.line
self.col = obj.location.col
except AttributeError:
pass
self.col = col
self.line = line
def __str__(self):
s = ""
if self.line is not None:
s += " line: {}".format(self.line)
if self.col is not None:
s += " col: {}".format(self.col)
return s
def __repr__(self):
return "Location(line={}, col={})".format(
repr(self.line), repr(self.col))
def __hash__(self):
return hash(self.__repr__())
class Term(object):
""" Represents the union of Variable and ObjectConstant. Should
only be instantiated via factory method. """
def __init__(self):
assert False, "Cannot instantiate Term directly--use factory method"
@classmethod
def create_from_python(cls, value, force_var=False):
""" To create variable, FORCE_VAR needs to be true. There is currently
no way to avoid this since variables are strings. """
if isinstance(value, Term):
return value
elif force_var:
return Variable(str(value))
elif isinstance(value, basestring):
return ObjectConstant(value, ObjectConstant.STRING)
elif isinstance(value, (int, long)):
return ObjectConstant(value, ObjectConstant.INTEGER)
elif isinstance(value, float):
return ObjectConstant(value, ObjectConstant.FLOAT)
else:
assert False, "No Term corresponding to {}".format(repr(value))
class Variable (Term):
""" Represents a term without a fixed value. """
def __init__(self, name, location=None):
self.name = name
self.location = location
def __str__(self):
return str(self.name)
def __eq__(self, other):
return isinstance(other, Variable) and self.name == other.name
def __ne__(self, other):
return not self == other
def __repr__(self):
return "Variable(name={}, location={})".format(
repr(self.name), repr(self.location))
def __hash__(self):
return hash("Variable(name={})".format(repr(self.name)))
def is_variable(self):
return True
def is_object(self):
return False
class ObjectConstant (Term):
""" Represents a term with a fixed value. """
STRING = 'STRING'
FLOAT = 'FLOAT'
INTEGER = 'INTEGER'
def __init__(self, name, type, location=None):
assert(type in [self.STRING, self.FLOAT, self.INTEGER])
self.name = name
self.type = type
self.location = location
def __str__(self):
return str(self.name)
def __repr__(self):
return "ObjectConstant(name={}, type={}, location={})".format(
repr(self.name), repr(self.type), repr(self.location))
def __hash__(self):
return hash("ObjectConstant(name={}, type={})".format(
repr(self.name), repr(self.type)))
def __eq__(self, other):
return (isinstance(other, ObjectConstant) and
self.name == other.name and
self.type == other.type)
def __ne__(self, other):
return not self == other
def is_variable(self):
return False
def is_object(self):
return True
class Atom (object):
""" Represents an atomic statement, e.g. p(a, 17, b) """
def __init__(self, table, arguments, location=None):
self.table = table
self.arguments = arguments
self.location = location
@classmethod
def create_from_table_tuple(cls, table, tuple):
""" LIST is a python list representing an atom, e.g.
['p', 17, "string", 3.14]. Returns the corresponding Atom. """
return cls(table, [Term.create_from_python(x) for x in tuple])
@classmethod
def create_from_iter(cls, list):
""" LIST is a python list representing an atom, e.g.
['p', 17, "string", 3.14]. Returns the corresponding Atom. """
arguments = []
for i in xrange(1, len(list)):
arguments.append(Term.create_from_python(list[i]))
return cls(list[0], arguments)
def __str__(self):
return "{}({})".format(self.table,
", ".join([str(x) for x in self.arguments]))
def __eq__(self, other):
return (isinstance(other, Atom) and
self.table == other.table and
len(self.arguments) == len(other.arguments) and
all(self.arguments[i] == other.arguments[i]
for i in xrange(0, len(self.arguments))))
def __ne__(self, other):
return not self == other
def __repr__(self):
return "Atom(table={}, arguments={}, location={})".format(
repr(self.table),
"[" + ",".join(repr(arg) for arg in self.arguments) + "]",
repr(self.location))
def __hash__(self):
return hash("Atom(table={}, arguments={})".format(
repr(self.table),
"[" + ",".join(repr(arg) for arg in self.arguments) + "]"))
def is_atom(self):
return True
def is_negated(self):
return False
def is_rule(self):
return False
def variable_names(self):
return set([x.name for x in self.arguments if x.is_variable()])
def variables(self):
return set([x for x in self.arguments if x.is_variable()])
def is_ground(self):
return all(not arg.is_variable() for arg in self.arguments)
def plug(self, binding, caller=None):
"Assumes domain of BINDING is Terms"
# logging.debug("Atom.plug({}, {})".format(str(binding), str(caller)))
new = copy.copy(self)
if isinstance(binding, dict):
args = []
for arg in self.arguments:
if arg in binding:
args.append(Term.create_from_python(binding[arg]))
else:
args.append(arg)
new.arguments = args
return new
else:
args = [Term.create_from_python(binding.apply(arg, caller))
for arg in self.arguments]
new.arguments = args
return new
def argument_names(self):
return tuple([arg.name for arg in self.arguments])
def make_positive(self):
""" Does NOT make copy """
return self
class Literal(Atom):
""" Represents either a negated atom or an atom. """
def __init__(self, table, arguments, negated=False, location=None):
Atom.__init__(self, table, arguments, location=location)
self.negated = negated
def __str__(self):
if self.negated:
return "not {}".format(Atom.__str__(self))
else:
return Atom.__str__(self)
def __eq__(self, other):
return (self.negated == other.negated and Atom.__eq__(self, other))
def __repr__(self):
return "Literal(table={}, arguments={}, location={}, negated={})".format(
repr(self.table),
"[" + ",".join(repr(arg) for arg in self.arguments) + "]",
repr(self.location),
repr(self.negated))
def __hash__(self):
return hash("Literal(table={}, arguments={}, negated={})".format(
repr(self.table),
"[" + ",".join(repr(arg) for arg in self.arguments) + "]",
repr(self.negated)))
def is_negated(self):
return self.negated
def is_atom(self):
return not self.negated
def is_rule(self):
return False
def complement(self):
""" Copies SELF and inverts is_negated. """
new = copy.copy(self)
new.negated = not new.negated
return new
def make_positive(self):
""" Copies SELF and makes is_negated False. """
new = copy.copy(self)
new.negated = False
return new
class Rule (object):
""" Represents a rule, e.g. p(x) :- q(x). """
def __init__(self, head, body, location=None):
self.head = head
self.body = body
self.location = location
def __str__(self):
return "{} :- {}".format(
str(self.head),
", ".join([str(atom) for atom in self.body]))
def __eq__(self, other):
return (self.head == other.head and
len(self.body) == len(other.body) and
all(self.body[i] == other.body[i]
for i in xrange(0, len(self.body))))
def __repr__(self):
return "Rule(head={}, body={}, location={})".format(
repr(self.head),
"[" + ",".join(repr(arg) for arg in self.body) + "]",
repr(self.location))
def __hash__(self):
return hash("Rule(head={}, body={})".format(
repr(self.head),
"[" + ",".join(repr(arg) for arg in self.body) + "]"))
def is_atom(self):
return False
def is_rule(self):
return True
def plug(self, binding, caller=None):
newhead = self.head.plug(binding, caller=caller)
newbody = [lit.plug(binding, caller=caller) for lit in self.body]
return Rule(newhead, newbody)
def variables(self):
vs = self.head.variables()
for lit in self.body:
vs |= lit.variables()
return vs
def variable_names(self):
vs = self.head.variable_names()
for lit in self.body:
vs |= lit.variable_names()
return vs
def formulas_to_string(formulas):
""" Takes an iterable of compiler sentence objects and returns a
string representing that iterable, which the compiler will parse
into the original iterable. """
return " ".join([str(formula) for formula in formulas])
##############################################################################
## Compiler
##############################################################################
class Compiler (object):
""" Process Congress policy file. """
def __init__(self):
self.raw_syntax_tree = None
self.theory = []
self.errors = []
self.warnings = []
def __str__ (self):
s = ""
s += '**Theory**\n'
if self.theory is not None:
s += '\n'.join([str(x) for x in self.theory])
else:
s += 'None'
return s
def read_source(self, input, input_string=False):
# parse input file and convert to internal representation
self.raw_syntax_tree = CongressSyntax.parse_file(input,
input_string=input_string)
# self.print_parse_result()
self.theory = CongressSyntax.create(self.raw_syntax_tree)
# print str(self)
def print_parse_result(self):
print_tree(
self.raw_syntax_tree,
lambda x: x.getText(),
lambda x: x.children,
ind=1)
def sigerr(self, error):
self.errors.append(error)
def sigwarn(self, error):
self.warnings.append(error)
def raise_errors(self):
if len(self.errors) > 0:
errors = [str(err) for err in self.errors]
raise CongressException('Compiler found errors:' + '\n'.join(errors))
def compute_delta_rules(self):
# logging.debug("self.theory: {}".format([str(x) for x in self.theory]))
self.delta_rules = compute_delta_rules(self.theory)
def eliminate_self_joins(theory):
""" Modify THEORY so that all self-joins have been eliminated. """
def new_table_name(name, arity, index):
return "___{}_{}_{}".format(name, arity, index)
def n_variables(n):
vars = []
for i in xrange(0, n):
vars.append("x" + str(i))
return vars
# dict from (table name, arity) tuple to
# max num of occurrences of self-joins in any rule
global_self_joins = {}
# dict from (table name, arity) to # of args for
arities = {}
# remove self-joins from rules
for rule in theory:
if rule.is_atom():
continue
logging.debug("eliminating self joins from {}".format(rule))
occurrences = {} # for just this rule
for atom in rule.body:
table = atom.table
arity = len(atom.arguments)
tablearity = (table, arity)
if tablearity not in occurrences:
occurrences[tablearity] = 1
else:
# change name of atom
atom.table = new_table_name(table, arity,
occurrences[tablearity])
# update our counters
occurrences[tablearity] += 1
if tablearity not in global_self_joins:
global_self_joins[tablearity] = 1
else:
global_self_joins[tablearity] = \
max(occurrences[tablearity] - 1,
global_self_joins[tablearity])
logging.debug("final rule: {}".format(str(rule)))
# add definitions for new tables
for tablearity in global_self_joins:
table = tablearity[0]
arity = tablearity[1]
for i in xrange(1, global_self_joins[tablearity] + 1):
newtable = new_table_name(table, arity, i)
args = [Variable(var) for var in n_variables(arity)]
head = Atom(newtable, args)
body = [Atom(table, args)]
theory.append(Rule(head, body))
logging.debug("Adding rule {}".format(str(theory[-1])))
return theory
def compute_delta_rules(theory):
eliminate_self_joins(theory)
delta_rules = []
for rule in theory:
if rule.is_atom():
continue
for literal in rule.body:
newbody = [lit for lit in rule.body if lit is not literal]
delta_rules.append(
runtime.DeltaRule(literal, rule.head, newbody, rule))
return delta_rules
##############################################################################
## External syntax: datalog
##############################################################################
class CongressSyntax (object):
""" External syntax and converting it into internal representation. """
class Lexer(CongressLexer.CongressLexer):
def __init__(self, char_stream, state=None):
self.error_list = []
CongressLexer.CongressLexer.__init__(self, char_stream, state)
def displayRecognitionError(self, token_names, e):
hdr = self.getErrorHeader(e)
msg = self.getErrorMessage(e, token_names)
self.error_list.append(str(hdr) + " " + str(msg))
def getErrorHeader(self, e):
return "line:{},col:{}".format(
e.line, e.charPositionInLine)
class Parser(CongressParser.CongressParser):
def __init__(self, tokens, state=None):
self.error_list = []
CongressParser.CongressParser.__init__(self, tokens, state)
def displayRecognitionError(self, token_names, e):
hdr = self.getErrorHeader(e)
msg = self.getErrorMessage(e, token_names)
self.error_list.append(str(hdr) + " " + str(msg))
def getErrorHeader(self, e):
return "line:{},col:{}".format(
e.line, e.charPositionInLine)
@classmethod
def parse_file(cls, input, input_string=False):
if not input_string:
char_stream = antlr3.ANTLRFileStream(input)
else:
char_stream = antlr3.ANTLRStringStream(input)
lexer = cls.Lexer(char_stream)
tokens = antlr3.CommonTokenStream(lexer)
parser = cls.Parser(tokens)
result = parser.prog()
if len(lexer.error_list) > 0:
raise CongressException("Lex failure.\n" +
"\n".join(lexer.error_list))
if len(parser.error_list) > 0:
raise CongressException("Parse failure.\n" + \
"\n".join(parser.error_list))
return result.tree
@classmethod
def create(cls, antlr):
obj = antlr.getText()
if obj == 'RULE':
return cls.create_rule(antlr)
elif obj == 'NOT':
return cls.create_literal(antlr)
elif obj == 'ATOM': # Note we're creating an ATOM, not a LITERAL
return cls.create_atom(antlr)
elif obj == 'THEORY':
return [cls.create(x) for x in antlr.children]
elif obj == '<EOF>':
return []
else:
raise CongressException(
"Antlr tree with unknown root: {}".format(obj))
@classmethod
def create_rule(cls, antlr):
# (RULE (ATOM LITERAL1 ... LITERALN))
# Makes body a list of literals
head = cls.create_atom(antlr.children[0])
body = []
for i in xrange(1, len(antlr.children)):
body.append(cls.create_literal(antlr.children[i]))
loc = Location(line=antlr.children[0].token.line,
col=antlr.children[0].token.charPositionInLine)
return Rule(head, body, location=loc)
@classmethod
def create_literal(cls, antlr):
# (NOT (ATOM (TABLE ARG1 ... ARGN)))
# (ATOM (TABLE ARG1 ... ARGN))
if antlr.getText() == 'NOT':
negated = True
antlr = antlr.children[0]
else:
negated = False
(table, args, loc) = cls.create_atom_aux(antlr)
return Literal(table, args, negated=negated, location=loc)
@classmethod
def create_atom(cls, antlr):
(table, args, loc) = cls.create_atom_aux(antlr)
return Atom(table, args, location=loc)
@classmethod
def create_atom_aux(cls, antlr):
# (ATOM (TABLENAME ARG1 ... ARGN))
table = cls.create_structured_name(antlr.children[0])
args = []
for i in xrange(1, len(antlr.children)):
args.append(cls.create_term(antlr.children[i]))
loc = Location(line=antlr.children[0].token.line,
col=antlr.children[0].token.charPositionInLine)
return (table, args, loc)
@classmethod
def create_structured_name(cls, antlr):
# (STRUCTURED_NAME (ARG1 ... ARGN))
if antlr.children[-1].getText() in ['+', '-']:
return (":".join([x.getText() for x in antlr.children[:-1]]) +
antlr.children[-1].getText())
else:
return ":".join([x.getText() for x in antlr.children])
@classmethod
def create_term(cls, antlr):
# (TYPE (VALUE))
op = antlr.getText()
loc = Location(line=antlr.children[0].token.line,
col=antlr.children[0].token.charPositionInLine)
if op == 'STRING_OBJ':
value = antlr.children[0].getText()
return ObjectConstant(value[1:len(value) - 1], # prune quotes
ObjectConstant.STRING,
location=loc)
elif op == 'INTEGER_OBJ':
return ObjectConstant(int(antlr.children[0].getText()),
ObjectConstant.INTEGER,
location=loc)
elif op == 'FLOAT_OBJ':
return ObjectConstant(float(antlr.children[0].getText()),
ObjectConstant.FLOAT,
location=loc)
elif op == 'VARIABLE':
return Variable(antlr.children[0].getText(), location=loc)
else:
raise CongressException("Unknown term operator: {}".format(op))
def print_tree(tree, text, kids, ind=0):
""" Print out TREE using function TEXT to extract node description and
function KIDS to compute the children of a given node.
IND is a number representing the indentation level. """
print "|" * ind,
print "{}".format(str(text(tree)))
children = kids(tree)
if children:
for child in children:
print_tree(child, text, kids, ind + 1)
##############################################################################
## Mains
##############################################################################
def parse(policy_string):
""" Run compiler on policy string and return the parsed formulas. """
compiler = get_compiler([policy_string, '--input_string'])
return compiler.theory
def parse1(policy_string):
""" Run compiler on policy string and return 1st parsed formula. """
return parse(policy_string)[0]
def parse_file(filename):
""" Run compiler on policy stored in FILENAME and return the parsed formulas. """
compiler = get_compiler([filename])
return compiler.theory
def get_compiler(args):
""" Run compiler as per ARGS and return the compiler object. """
# assumes script name is not passed
parser = optparse.OptionParser()
parser.add_option("--input_string", dest="input_string", default=False,
action="store_true",
help="Indicates that inputs should be treated not as file names but "
"as the contents to compile")
(options, inputs) = parser.parse_args(args)
compiler = Compiler()
for i in inputs:
compiler.read_source(i, input_string=options.input_string)
return compiler
def get_runtime(args):
""" Create runtime by running compiler as per ARGS and initializing runtime
with result of compilation. """
comp = get_compiled(args)
run = runtime.Runtime(comp.delta_rules)
tracer = runtime.Tracer()
tracer.trace('*')
run.tracer = tracer
run.database.tracer = tracer
return run
def main(args):
c = get_compiled(args)
for formula in c.theory:
print str(c)
if __name__ == '__main__':
main(sys.argv[1:])