From 8aefef51bd666c950124fdf0093e9c61a1236b1c Mon Sep 17 00:00:00 2001 From: Isaku Yamahata Date: Fri, 12 Oct 2012 10:40:35 +0900 Subject: [PATCH] import ovs python binding library From changeset 8087f5ff825cae3a699e5a60ca6dd0deb10fc8e5 dirs.py.template needs to be adopted for Ryu environment. Signed-off-by: Isaku Yamahata --- ryu/contrib/ovs/__init__.py | 1 + ryu/contrib/ovs/daemon.py | 537 +++++++++++ ryu/contrib/ovs/db/__init__.py | 1 + ryu/contrib/ovs/db/data.py | 547 ++++++++++++ ryu/contrib/ovs/db/error.py | 34 + ryu/contrib/ovs/db/idl.py | 1287 +++++++++++++++++++++++++++ ryu/contrib/ovs/db/parser.py | 109 +++ ryu/contrib/ovs/db/schema.py | 271 ++++++ ryu/contrib/ovs/db/types.py | 587 ++++++++++++ ryu/contrib/ovs/dirs.py.template | 17 + ryu/contrib/ovs/fatal_signal.py | 136 +++ ryu/contrib/ovs/json.py | 586 ++++++++++++ ryu/contrib/ovs/jsonrpc.py | 560 ++++++++++++ ryu/contrib/ovs/ovsuuid.py | 70 ++ ryu/contrib/ovs/poller.py | 185 ++++ ryu/contrib/ovs/process.py | 41 + ryu/contrib/ovs/reconnect.py | 588 ++++++++++++ ryu/contrib/ovs/socket_util.py | 192 ++++ ryu/contrib/ovs/stream.py | 361 ++++++++ ryu/contrib/ovs/timeval.py | 26 + ryu/contrib/ovs/unixctl/__init__.py | 83 ++ ryu/contrib/ovs/unixctl/client.py | 70 ++ ryu/contrib/ovs/unixctl/server.py | 247 +++++ ryu/contrib/ovs/util.py | 93 ++ ryu/contrib/ovs/version.py | 2 + ryu/contrib/ovs/vlog.py | 267 ++++++ 26 files changed, 6898 insertions(+) create mode 100644 ryu/contrib/ovs/__init__.py create mode 100644 ryu/contrib/ovs/daemon.py create mode 100644 ryu/contrib/ovs/db/__init__.py create mode 100644 ryu/contrib/ovs/db/data.py create mode 100644 ryu/contrib/ovs/db/error.py create mode 100644 ryu/contrib/ovs/db/idl.py create mode 100644 ryu/contrib/ovs/db/parser.py create mode 100644 ryu/contrib/ovs/db/schema.py create mode 100644 ryu/contrib/ovs/db/types.py create mode 100644 ryu/contrib/ovs/dirs.py.template create mode 100644 ryu/contrib/ovs/fatal_signal.py create mode 100644 ryu/contrib/ovs/json.py create mode 100644 ryu/contrib/ovs/jsonrpc.py create mode 100644 ryu/contrib/ovs/ovsuuid.py create mode 100644 ryu/contrib/ovs/poller.py create mode 100644 ryu/contrib/ovs/process.py create mode 100644 ryu/contrib/ovs/reconnect.py create mode 100644 ryu/contrib/ovs/socket_util.py create mode 100644 ryu/contrib/ovs/stream.py create mode 100644 ryu/contrib/ovs/timeval.py create mode 100644 ryu/contrib/ovs/unixctl/__init__.py create mode 100644 ryu/contrib/ovs/unixctl/client.py create mode 100644 ryu/contrib/ovs/unixctl/server.py create mode 100644 ryu/contrib/ovs/util.py create mode 100644 ryu/contrib/ovs/version.py create mode 100644 ryu/contrib/ovs/vlog.py diff --git a/ryu/contrib/ovs/__init__.py b/ryu/contrib/ovs/__init__.py new file mode 100644 index 00000000..218d8921 --- /dev/null +++ b/ryu/contrib/ovs/__init__.py @@ -0,0 +1 @@ +# This file intentionally left blank. diff --git a/ryu/contrib/ovs/daemon.py b/ryu/contrib/ovs/daemon.py new file mode 100644 index 00000000..650d2504 --- /dev/null +++ b/ryu/contrib/ovs/daemon.py @@ -0,0 +1,537 @@ +# Copyright (c) 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import fcntl +import os +import resource +import signal +import sys +import time + +import ovs.dirs +import ovs.fatal_signal +#import ovs.lockfile +import ovs.process +import ovs.socket_util +import ovs.timeval +import ovs.util +import ovs.vlog + +vlog = ovs.vlog.Vlog("daemon") + +# --detach: Should we run in the background? +_detach = False + +# --pidfile: Name of pidfile (null if none). +_pidfile = None + +# Our pidfile's inode and device, if we have created one. +_pidfile_dev = None +_pidfile_ino = None + +# --overwrite-pidfile: Create pidfile even if one already exists and is locked? +_overwrite_pidfile = False + +# --no-chdir: Should we chdir to "/"? +_chdir = True + +# --monitor: Should a supervisory process monitor the daemon and restart it if +# it dies due to an error signal? +_monitor = False + +# File descriptor used by daemonize_start() and daemonize_complete(). +_daemonize_fd = None + +RESTART_EXIT_CODE = 5 + + +def make_pidfile_name(name): + """Returns the file name that would be used for a pidfile if 'name' were + provided to set_pidfile().""" + if name is None or name == "": + return "%s/%s.pid" % (ovs.dirs.RUNDIR, ovs.util.PROGRAM_NAME) + else: + return ovs.util.abs_file_name(ovs.dirs.RUNDIR, name) + + +def set_pidfile(name): + """Sets up a following call to daemonize() to create a pidfile named + 'name'. If 'name' begins with '/', then it is treated as an absolute path. + Otherwise, it is taken relative to ovs.util.RUNDIR, which is + $(prefix)/var/run by default. + + If 'name' is null, then ovs.util.PROGRAM_NAME followed by ".pid" is + used.""" + global _pidfile + _pidfile = make_pidfile_name(name) + + +def get_pidfile(): + """Returns an absolute path to the configured pidfile, or None if no + pidfile is configured.""" + return _pidfile + + +def set_no_chdir(): + """Sets that we do not chdir to "/".""" + global _chdir + _chdir = False + + +def is_chdir_enabled(): + """Will we chdir to "/" as part of daemonizing?""" + return _chdir + + +def ignore_existing_pidfile(): + """Normally, daemonize() or daemonize_start() will terminate the program + with a message if a locked pidfile already exists. If this function is + called, an existing pidfile will be replaced, with a warning.""" + global _overwrite_pidfile + _overwrite_pidfile = True + + +def set_detach(): + """Sets up a following call to daemonize() to detach from the foreground + session, running this process in the background.""" + global _detach + _detach = True + + +def get_detach(): + """Will daemonize() really detach?""" + return _detach + + +def set_monitor(): + """Sets up a following call to daemonize() to fork a supervisory process to + monitor the daemon and restart it if it dies due to an error signal.""" + global _monitor + _monitor = True + + +def _fatal(msg): + vlog.err(msg) + sys.stderr.write("%s\n" % msg) + sys.exit(1) + + +def _make_pidfile(): + """If a pidfile has been configured, creates it and stores the running + process's pid in it. Ensures that the pidfile will be deleted when the + process exits.""" + pid = os.getpid() + + # Create a temporary pidfile. + tmpfile = "%s.tmp%d" % (_pidfile, pid) + ovs.fatal_signal.add_file_to_unlink(tmpfile) + try: + # This is global to keep Python from garbage-collecting and + # therefore closing our file after this function exits. That would + # unlock the lock for us, and we don't want that. + global file_handle + + file_handle = open(tmpfile, "w") + except IOError, e: + _fatal("%s: create failed (%s)" % (tmpfile, e.strerror)) + + try: + s = os.fstat(file_handle.fileno()) + except IOError, e: + _fatal("%s: fstat failed (%s)" % (tmpfile, e.strerror)) + + try: + file_handle.write("%s\n" % pid) + file_handle.flush() + except OSError, e: + _fatal("%s: write failed: %s" % (tmpfile, e.strerror)) + + try: + fcntl.lockf(file_handle, fcntl.LOCK_EX | fcntl.LOCK_NB) + except IOError, e: + _fatal("%s: fcntl failed: %s" % (tmpfile, e.strerror)) + + # Rename or link it to the correct name. + if _overwrite_pidfile: + try: + os.rename(tmpfile, _pidfile) + except OSError, e: + _fatal("failed to rename \"%s\" to \"%s\" (%s)" + % (tmpfile, _pidfile, e.strerror)) + else: + while True: + try: + os.link(tmpfile, _pidfile) + error = 0 + except OSError, e: + error = e.errno + if error == errno.EEXIST: + _check_already_running() + elif error != errno.EINTR: + break + if error: + _fatal("failed to link \"%s\" as \"%s\" (%s)" + % (tmpfile, _pidfile, os.strerror(error))) + + # Ensure that the pidfile will get deleted on exit. + ovs.fatal_signal.add_file_to_unlink(_pidfile) + + # Delete the temporary pidfile if it still exists. + if not _overwrite_pidfile: + error = ovs.fatal_signal.unlink_file_now(tmpfile) + if error: + _fatal("%s: unlink failed (%s)" % (tmpfile, os.strerror(error))) + + global _pidfile_dev + global _pidfile_ino + _pidfile_dev = s.st_dev + _pidfile_ino = s.st_ino + + +def daemonize(): + """If configured with set_pidfile() or set_detach(), creates the pid file + and detaches from the foreground session.""" + daemonize_start() + daemonize_complete() + + +def _waitpid(pid, options): + while True: + try: + return os.waitpid(pid, options) + except OSError, e: + if e.errno == errno.EINTR: + pass + return -e.errno, 0 + + +def _fork_and_wait_for_startup(): + try: + rfd, wfd = os.pipe() + except OSError, e: + sys.stderr.write("pipe failed: %s\n" % os.strerror(e.errno)) + sys.exit(1) + + try: + pid = os.fork() + except OSError, e: + sys.stderr.write("could not fork: %s\n" % os.strerror(e.errno)) + sys.exit(1) + + if pid > 0: + # Running in parent process. + os.close(wfd) + ovs.fatal_signal.fork() + while True: + try: + s = os.read(rfd, 1) + error = 0 + except OSError, e: + s = "" + error = e.errno + if error != errno.EINTR: + break + if len(s) != 1: + retval, status = _waitpid(pid, 0) + if retval == pid: + if os.WIFEXITED(status) and os.WEXITSTATUS(status): + # Child exited with an error. Convey the same error to + # our parent process as a courtesy. + sys.exit(os.WEXITSTATUS(status)) + else: + sys.stderr.write("fork child failed to signal " + "startup (%s)\n" + % ovs.process.status_msg(status)) + else: + assert retval < 0 + sys.stderr.write("waitpid failed (%s)\n" + % os.strerror(-retval)) + sys.exit(1) + + os.close(rfd) + else: + # Running in parent process. + os.close(rfd) + ovs.timeval.postfork() + #ovs.lockfile.postfork() + + global _daemonize_fd + _daemonize_fd = wfd + return pid + + +def _fork_notify_startup(fd): + if fd is not None: + error, bytes_written = ovs.socket_util.write_fully(fd, "0") + if error: + sys.stderr.write("could not write to pipe\n") + sys.exit(1) + os.close(fd) + + +def _should_restart(status): + global RESTART_EXIT_CODE + + if os.WIFEXITED(status) and os.WEXITSTATUS(status) == RESTART_EXIT_CODE: + return True + + if os.WIFSIGNALED(status): + for signame in ("SIGABRT", "SIGALRM", "SIGBUS", "SIGFPE", "SIGILL", + "SIGPIPE", "SIGSEGV", "SIGXCPU", "SIGXFSZ"): + if os.WTERMSIG(status) == getattr(signal, signame, None): + return True + return False + + +def _monitor_daemon(daemon_pid): + # XXX should log daemon's stderr output at startup time + # XXX should use setproctitle module if available + last_restart = None + while True: + retval, status = _waitpid(daemon_pid, 0) + if retval < 0: + sys.stderr.write("waitpid failed\n") + sys.exit(1) + elif retval == daemon_pid: + status_msg = ("pid %d died, %s" + % (daemon_pid, ovs.process.status_msg(status))) + + if _should_restart(status): + if os.WCOREDUMP(status): + # Disable further core dumps to save disk space. + try: + resource.setrlimit(resource.RLIMIT_CORE, (0, 0)) + except resource.error: + vlog.warn("failed to disable core dumps") + + # Throttle restarts to no more than once every 10 seconds. + if (last_restart is not None and + ovs.timeval.msec() < last_restart + 10000): + vlog.warn("%s, waiting until 10 seconds since last " + "restart" % status_msg) + while True: + now = ovs.timeval.msec() + wakeup = last_restart + 10000 + if now > wakeup: + break + print "sleep %f" % ((wakeup - now) / 1000.0) + time.sleep((wakeup - now) / 1000.0) + last_restart = ovs.timeval.msec() + + vlog.err("%s, restarting" % status_msg) + daemon_pid = _fork_and_wait_for_startup() + if not daemon_pid: + break + else: + vlog.info("%s, exiting" % status_msg) + sys.exit(0) + + # Running in new daemon process. + + +def _close_standard_fds(): + """Close stdin, stdout, stderr. If we're started from e.g. an SSH session, + then this keeps us from holding that session open artificially.""" + null_fd = ovs.socket_util.get_null_fd() + if null_fd >= 0: + os.dup2(null_fd, 0) + os.dup2(null_fd, 1) + os.dup2(null_fd, 2) + + +def daemonize_start(): + """If daemonization is configured, then starts daemonization, by forking + and returning in the child process. The parent process hangs around until + the child lets it know either that it completed startup successfully (by + calling daemon_complete()) or that it failed to start up (by exiting with a + nonzero exit code).""" + + if _detach: + if _fork_and_wait_for_startup() > 0: + # Running in parent process. + sys.exit(0) + # Running in daemon or monitor process. + + if _monitor: + saved_daemonize_fd = _daemonize_fd + daemon_pid = _fork_and_wait_for_startup() + if daemon_pid > 0: + # Running in monitor process. + _fork_notify_startup(saved_daemonize_fd) + _close_standard_fds() + _monitor_daemon(daemon_pid) + # Running in daemon process + + if _pidfile: + _make_pidfile() + + +def daemonize_complete(): + """If daemonization is configured, then this function notifies the parent + process that the child process has completed startup successfully.""" + _fork_notify_startup(_daemonize_fd) + + if _detach: + os.setsid() + if _chdir: + os.chdir("/") + _close_standard_fds() + + +def usage(): + sys.stdout.write(""" +Daemon options: + --detach run in background as daemon + --no-chdir do not chdir to '/' + --pidfile[=FILE] create pidfile (default: %s/%s.pid) + --overwrite-pidfile with --pidfile, start even if already running +""" % (ovs.dirs.RUNDIR, ovs.util.PROGRAM_NAME)) + + +def __read_pidfile(pidfile, delete_if_stale): + if _pidfile_dev is not None: + try: + s = os.stat(pidfile) + if s.st_ino == _pidfile_ino and s.st_dev == _pidfile_dev: + # It's our own pidfile. We can't afford to open it, + # because closing *any* fd for a file that a process + # has locked also releases all the locks on that file. + # + # Fortunately, we know the associated pid anyhow. + return os.getpid() + except OSError: + pass + + try: + file_handle = open(pidfile, "r+") + except IOError, e: + if e.errno == errno.ENOENT and delete_if_stale: + return 0 + vlog.warn("%s: open: %s" % (pidfile, e.strerror)) + return -e.errno + + # Python fcntl doesn't directly support F_GETLK so we have to just try + # to lock it. + try: + fcntl.lockf(file_handle, fcntl.LOCK_EX | fcntl.LOCK_NB) + + # pidfile exists but wasn't locked by anyone. Now we have the lock. + if not delete_if_stale: + file_handle.close() + vlog.warn("%s: pid file is stale" % pidfile) + return -errno.ESRCH + + # Is the file we have locked still named 'pidfile'? + try: + raced = False + s = os.stat(pidfile) + s2 = os.fstat(file_handle.fileno()) + if s.st_ino != s2.st_ino or s.st_dev != s2.st_dev: + raced = True + except IOError: + raced = True + if raced: + vlog.warn("%s: lost race to delete pidfile" % pidfile) + return -errno.EALREADY + + # We won the right to delete the stale pidfile. + try: + os.unlink(pidfile) + except IOError, e: + vlog.warn("%s: failed to delete stale pidfile (%s)" + % (pidfile, e.strerror)) + return -e.errno + else: + vlog.dbg("%s: deleted stale pidfile" % pidfile) + file_handle.close() + return 0 + except IOError, e: + if e.errno not in [errno.EACCES, errno.EAGAIN]: + vlog.warn("%s: fcntl: %s" % (pidfile, e.strerror)) + return -e.errno + + # Someone else has the pidfile locked. + try: + try: + error = int(file_handle.readline()) + except IOError, e: + vlog.warn("%s: read: %s" % (pidfile, e.strerror)) + error = -e.errno + except ValueError: + vlog.warn("%s does not contain a pid" % pidfile) + error = -errno.EINVAL + + return error + finally: + try: + file_handle.close() + except IOError: + pass + + +def read_pidfile(pidfile): + """Opens and reads a PID from 'pidfile'. Returns the positive PID if + successful, otherwise a negative errno value.""" + return __read_pidfile(pidfile, False) + + +def _check_already_running(): + pid = __read_pidfile(_pidfile, True) + if pid > 0: + _fatal("%s: already running as pid %d, aborting" % (_pidfile, pid)) + elif pid < 0: + _fatal("%s: pidfile check failed (%s), aborting" + % (_pidfile, os.strerror(pid))) + + +def add_args(parser): + """Populates 'parser', an ArgumentParser allocated using the argparse + module, with the command line arguments required by the daemon module.""" + + pidfile = make_pidfile_name(None) + + group = parser.add_argument_group(title="Daemon Options") + group.add_argument("--detach", action="store_true", + help="Run in background as a daemon.") + group.add_argument("--no-chdir", action="store_true", + help="Do not chdir to '/'.") + group.add_argument("--monitor", action="store_true", + help="Monitor %s process." % ovs.util.PROGRAM_NAME) + group.add_argument("--pidfile", nargs="?", const=pidfile, + help="Create pidfile (default %s)." % pidfile) + group.add_argument("--overwrite-pidfile", action="store_true", + help="With --pidfile, start even if already running.") + + +def handle_args(args): + """Handles daemon module settings in 'args'. 'args' is an object + containing values parsed by the parse_args() method of ArgumentParser. The + parent ArgumentParser should have been prepared by add_args() before + calling parse_args().""" + + if args.detach: + set_detach() + + if args.no_chdir: + set_no_chdir() + + if args.pidfile: + set_pidfile(args.pidfile) + + if args.overwrite_pidfile: + ignore_existing_pidfile() + + if args.monitor: + set_monitor() diff --git a/ryu/contrib/ovs/db/__init__.py b/ryu/contrib/ovs/db/__init__.py new file mode 100644 index 00000000..218d8921 --- /dev/null +++ b/ryu/contrib/ovs/db/__init__.py @@ -0,0 +1 @@ +# This file intentionally left blank. diff --git a/ryu/contrib/ovs/db/data.py b/ryu/contrib/ovs/db/data.py new file mode 100644 index 00000000..55e7a732 --- /dev/null +++ b/ryu/contrib/ovs/db/data.py @@ -0,0 +1,547 @@ +# Copyright (c) 2009, 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import uuid + +import ovs.poller +import ovs.socket_util +import ovs.json +import ovs.jsonrpc +import ovs.ovsuuid + +import ovs.db.parser +from ovs.db import error +import ovs.db.types + + +class ConstraintViolation(error.Error): + def __init__(self, msg, json=None): + error.Error.__init__(self, msg, json, tag="constraint violation") + + +def escapeCString(src): + dst = [] + for c in src: + if c in "\\\"": + dst.append("\\" + c) + elif ord(c) < 32: + if c == '\n': + dst.append('\\n') + elif c == '\r': + dst.append('\\r') + elif c == '\a': + dst.append('\\a') + elif c == '\b': + dst.append('\\b') + elif c == '\f': + dst.append('\\f') + elif c == '\t': + dst.append('\\t') + elif c == '\v': + dst.append('\\v') + else: + dst.append('\\%03o' % ord(c)) + else: + dst.append(c) + return ''.join(dst) + + +def returnUnchanged(x): + return x + + +class Atom(object): + def __init__(self, type_, value=None): + self.type = type_ + if value is not None: + self.value = value + else: + self.value = type_.default_atom() + + def __cmp__(self, other): + if not isinstance(other, Atom) or self.type != other.type: + return NotImplemented + elif self.value < other.value: + return -1 + elif self.value > other.value: + return 1 + else: + return 0 + + def __hash__(self): + return hash(self.value) + + @staticmethod + def default(type_): + """Returns the default value for the given type_, which must be an + instance of ovs.db.types.AtomicType. + + The default value for each atomic type is; + + - 0, for integer or real atoms. + + - False, for a boolean atom. + + - "", for a string atom. + + - The all-zeros UUID, for a UUID atom.""" + return Atom(type_) + + def is_default(self): + return self == self.default(self.type) + + @staticmethod + def from_json(base, json, symtab=None): + type_ = base.type + json = ovs.db.parser.float_to_int(json) + if ((type_ == ovs.db.types.IntegerType and type(json) in [int, long]) + or (type_ == ovs.db.types.RealType + and type(json) in [int, long, float]) + or (type_ == ovs.db.types.BooleanType and type(json) == bool) + or (type_ == ovs.db.types.StringType + and type(json) in [str, unicode])): + atom = Atom(type_, json) + elif type_ == ovs.db.types.UuidType: + atom = Atom(type_, ovs.ovsuuid.from_json(json, symtab)) + else: + raise error.Error("expected %s" % type_.to_string(), json) + atom.check_constraints(base) + return atom + + @staticmethod + def from_python(base, value): + value = ovs.db.parser.float_to_int(value) + if type(value) in base.type.python_types: + atom = Atom(base.type, value) + else: + raise error.Error("expected %s, got %s" % (base.type, type(value))) + atom.check_constraints(base) + return atom + + def check_constraints(self, base): + """Checks whether 'atom' meets the constraints (if any) defined in + 'base' and raises an ovs.db.error.Error if any constraint is violated. + + 'base' and 'atom' must have the same type. + Checking UUID constraints is deferred to transaction commit time, so + this function does nothing for UUID constraints.""" + assert base.type == self.type + if base.enum is not None and self not in base.enum: + raise ConstraintViolation( + "%s is not one of the allowed values (%s)" + % (self.to_string(), base.enum.to_string())) + elif base.type in [ovs.db.types.IntegerType, ovs.db.types.RealType]: + if ((base.min is None or self.value >= base.min) and + (base.max is None or self.value <= base.max)): + pass + elif base.min is not None and base.max is not None: + raise ConstraintViolation( + "%s is not in the valid range %.15g to %.15g (inclusive)" + % (self.to_string(), base.min, base.max)) + elif base.min is not None: + raise ConstraintViolation( + "%s is less than minimum allowed value %.15g" + % (self.to_string(), base.min)) + else: + raise ConstraintViolation( + "%s is greater than maximum allowed value %.15g" + % (self.to_string(), base.max)) + elif base.type == ovs.db.types.StringType: + # XXX The C version validates that the string is valid UTF-8 here. + # Do we need to do that in Python too? + s = self.value + length = len(s) + if length < base.min_length: + raise ConstraintViolation( + '"%s" length %d is less than minimum allowed length %d' + % (s, length, base.min_length)) + elif length > base.max_length: + raise ConstraintViolation( + '"%s" length %d is greater than maximum allowed ' + 'length %d' % (s, length, base.max_length)) + + def to_json(self): + if self.type == ovs.db.types.UuidType: + return ovs.ovsuuid.to_json(self.value) + else: + return self.value + + def cInitAtom(self, var): + if self.type == ovs.db.types.IntegerType: + return ['%s.integer = %d;' % (var, self.value)] + elif self.type == ovs.db.types.RealType: + return ['%s.real = %.15g;' % (var, self.value)] + elif self.type == ovs.db.types.BooleanType: + if self.value: + return ['%s.boolean = true;'] + else: + return ['%s.boolean = false;'] + elif self.type == ovs.db.types.StringType: + return ['%s.string = xstrdup("%s");' + % (var, escapeCString(self.value))] + elif self.type == ovs.db.types.UuidType: + return ovs.ovsuuid.to_c_assignment(self.value, var) + + def toEnglish(self, escapeLiteral=returnUnchanged): + if self.type == ovs.db.types.IntegerType: + return '%d' % self.value + elif self.type == ovs.db.types.RealType: + return '%.15g' % self.value + elif self.type == ovs.db.types.BooleanType: + if self.value: + return 'true' + else: + return 'false' + elif self.type == ovs.db.types.StringType: + return escapeLiteral(self.value) + elif self.type == ovs.db.types.UuidType: + return self.value.value + + __need_quotes_re = re.compile("$|true|false|[^_a-zA-Z]|.*[^-._a-zA-Z]") + + @staticmethod + def __string_needs_quotes(s): + return Atom.__need_quotes_re.match(s) + + def to_string(self): + if self.type == ovs.db.types.IntegerType: + return '%d' % self.value + elif self.type == ovs.db.types.RealType: + return '%.15g' % self.value + elif self.type == ovs.db.types.BooleanType: + if self.value: + return 'true' + else: + return 'false' + elif self.type == ovs.db.types.StringType: + if Atom.__string_needs_quotes(self.value): + return ovs.json.to_string(self.value) + else: + return self.value + elif self.type == ovs.db.types.UuidType: + return str(self.value) + + @staticmethod + def new(x): + if type(x) in [int, long]: + t = ovs.db.types.IntegerType + elif type(x) == float: + t = ovs.db.types.RealType + elif x in [False, True]: + t = ovs.db.types.BooleanType + elif type(x) in [str, unicode]: + t = ovs.db.types.StringType + elif isinstance(x, uuid): + t = ovs.db.types.UuidType + else: + raise TypeError + return Atom(t, x) + + +class Datum(object): + def __init__(self, type_, values={}): + self.type = type_ + self.values = values + + def __cmp__(self, other): + if not isinstance(other, Datum): + return NotImplemented + elif self.values < other.values: + return -1 + elif self.values > other.values: + return 1 + else: + return 0 + + __hash__ = None + + def __contains__(self, item): + return item in self.values + + def copy(self): + return Datum(self.type, dict(self.values)) + + @staticmethod + def default(type_): + if type_.n_min == 0: + values = {} + elif type_.is_map(): + values = {type_.key.default(): type_.value.default()} + else: + values = {type_.key.default(): None} + return Datum(type_, values) + + def is_default(self): + return self == Datum.default(self.type) + + def check_constraints(self): + """Checks that each of the atoms in 'datum' conforms to the constraints + specified by its 'type' and raises an ovs.db.error.Error. + + This function is not commonly useful because the most ordinary way to + obtain a datum is ultimately via Datum.from_json() or Atom.from_json(), + which check constraints themselves.""" + for keyAtom, valueAtom in self.values.iteritems(): + keyAtom.check_constraints(self.type.key) + if valueAtom is not None: + valueAtom.check_constraints(self.type.value) + + @staticmethod + def from_json(type_, json, symtab=None): + """Parses 'json' as a datum of the type described by 'type'. If + successful, returns a new datum. On failure, raises an + ovs.db.error.Error. + + Violations of constraints expressed by 'type' are treated as errors. + + If 'symtab' is nonnull, then named UUIDs in 'symtab' are accepted. + Refer to ovsdb/SPECS for information about this, and for the syntax + that this function accepts.""" + is_map = type_.is_map() + if (is_map or + (type(json) == list and len(json) > 0 and json[0] == "set")): + if is_map: + class_ = "map" + else: + class_ = "set" + + inner = ovs.db.parser.unwrap_json(json, class_, [list, tuple], + "array") + n = len(inner) + if n < type_.n_min or n > type_.n_max: + raise error.Error("%s must have %d to %d members but %d are " + "present" % (class_, type_.n_min, + type_.n_max, n), + json) + + values = {} + for element in inner: + if is_map: + key, value = ovs.db.parser.parse_json_pair(element) + keyAtom = Atom.from_json(type_.key, key, symtab) + valueAtom = Atom.from_json(type_.value, value, symtab) + else: + keyAtom = Atom.from_json(type_.key, element, symtab) + valueAtom = None + + if keyAtom in values: + if is_map: + raise error.Error("map contains duplicate key") + else: + raise error.Error("set contains duplicate") + + values[keyAtom] = valueAtom + + return Datum(type_, values) + else: + keyAtom = Atom.from_json(type_.key, json, symtab) + return Datum(type_, {keyAtom: None}) + + def to_json(self): + if self.type.is_map(): + return ["map", [[k.to_json(), v.to_json()] + for k, v in sorted(self.values.items())]] + elif len(self.values) == 1: + key = self.values.keys()[0] + return key.to_json() + else: + return ["set", [k.to_json() for k in sorted(self.values.keys())]] + + def to_string(self): + head = tail = None + if self.type.n_max > 1 or len(self.values) == 0: + if self.type.is_map(): + head = "{" + tail = "}" + else: + head = "[" + tail = "]" + + s = [] + if head: + s.append(head) + + for i, key in enumerate(sorted(self.values)): + if i: + s.append(", ") + + s.append(key.to_string()) + if self.type.is_map(): + s.append("=") + s.append(self.values[key].to_string()) + + if tail: + s.append(tail) + return ''.join(s) + + def as_list(self): + if self.type.is_map(): + return [[k.value, v.value] for k, v in self.values.iteritems()] + else: + return [k.value for k in self.values.iterkeys()] + + def as_dict(self): + return dict(self.values) + + def as_scalar(self): + if len(self.values) == 1: + if self.type.is_map(): + k, v = self.values.iteritems()[0] + return [k.value, v.value] + else: + return self.values.keys()[0].value + else: + return None + + def to_python(self, uuid_to_row): + """Returns this datum's value converted into a natural Python + representation of this datum's type, according to the following + rules: + + - If the type has exactly one value and it is not a map (that is, + self.type.is_scalar() returns True), then the value is: + + * An int or long, for an integer column. + + * An int or long or float, for a real column. + + * A bool, for a boolean column. + + * A str or unicode object, for a string column. + + * A uuid.UUID object, for a UUID column without a ref_table. + + * An object represented the referenced row, for a UUID column with + a ref_table. (For the Idl, this object will be an ovs.db.idl.Row + object.) + + If some error occurs (e.g. the database server's idea of the column + is different from the IDL's idea), then the default value for the + scalar type is used (see Atom.default()). + + - Otherwise, if the type is not a map, then the value is a Python list + whose elements have the types described above. + + - Otherwise, the type is a map, and the value is a Python dict that + maps from key to value, with key and value types determined as + described above. + + 'uuid_to_row' must be a function that takes a value and an + ovs.db.types.BaseType and translates UUIDs into row objects.""" + if self.type.is_scalar(): + value = uuid_to_row(self.as_scalar(), self.type.key) + if value is None: + return self.type.key.default() + else: + return value + elif self.type.is_map(): + value = {} + for k, v in self.values.iteritems(): + dk = uuid_to_row(k.value, self.type.key) + dv = uuid_to_row(v.value, self.type.value) + if dk is not None and dv is not None: + value[dk] = dv + return value + else: + s = set() + for k in self.values: + dk = uuid_to_row(k.value, self.type.key) + if dk is not None: + s.add(dk) + return sorted(s) + + @staticmethod + def from_python(type_, value, row_to_uuid): + """Returns a new Datum with the given ovs.db.types.Type 'type_'. The + new datum's value is taken from 'value', which must take the form + described as a valid return value from Datum.to_python() for 'type'. + + Each scalar value within 'value' is initally passed through + 'row_to_uuid', which should convert objects that represent rows (if + any) into uuid.UUID objects and return other data unchanged. + + Raises ovs.db.error.Error if 'value' is not in an appropriate form for + 'type_'.""" + d = {} + if type(value) == dict: + for k, v in value.iteritems(): + ka = Atom.from_python(type_.key, row_to_uuid(k)) + va = Atom.from_python(type_.value, row_to_uuid(v)) + d[ka] = va + elif type(value) in (list, tuple): + for k in value: + ka = Atom.from_python(type_.key, row_to_uuid(k)) + d[ka] = None + else: + ka = Atom.from_python(type_.key, row_to_uuid(value)) + d[ka] = None + + datum = Datum(type_, d) + datum.check_constraints() + if not datum.conforms_to_type(): + raise error.Error("%d values when type requires between %d and %d" + % (len(d), type_.n_min, type_.n_max)) + + return datum + + def __getitem__(self, key): + if not isinstance(key, Atom): + key = Atom.new(key) + if not self.type.is_map(): + raise IndexError + elif key not in self.values: + raise KeyError + else: + return self.values[key].value + + def get(self, key, default=None): + if not isinstance(key, Atom): + key = Atom.new(key) + if key in self.values: + return self.values[key].value + else: + return default + + def __str__(self): + return self.to_string() + + def conforms_to_type(self): + n = len(self.values) + return self.type.n_min <= n <= self.type.n_max + + def cInitDatum(self, var): + if len(self.values) == 0: + return ["ovsdb_datum_init_empty(%s);" % var] + + s = ["%s->n = %d;" % (var, len(self.values))] + s += ["%s->keys = xmalloc(%d * sizeof *%s->keys);" + % (var, len(self.values), var)] + + for i, key in enumerate(sorted(self.values)): + s += key.cInitAtom("%s->keys[%d]" % (var, i)) + + if self.type.value: + s += ["%s->values = xmalloc(%d * sizeof *%s->values);" + % (var, len(self.values), var)] + for i, (key, value) in enumerate(sorted(self.values.items())): + s += value.cInitAtom("%s->values[%d]" % (var, i)) + else: + s += ["%s->values = NULL;" % var] + + if len(self.values) > 1: + s += ["ovsdb_datum_sort_assert(%s, OVSDB_TYPE_%s);" + % (var, self.type.key.type.to_string().upper())] + + return s diff --git a/ryu/contrib/ovs/db/error.py b/ryu/contrib/ovs/db/error.py new file mode 100644 index 00000000..d9217e41 --- /dev/null +++ b/ryu/contrib/ovs/db/error.py @@ -0,0 +1,34 @@ +# Copyright (c) 2009, 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import ovs.json + + +class Error(Exception): + def __init__(self, msg, json=None, tag=None): + self.msg = msg + self.json = json + if tag is None: + if json is None: + self.tag = "ovsdb error" + else: + self.tag = "syntax error" + else: + self.tag = tag + + # Compose message. + syntax = "" + if self.json is not None: + syntax = 'syntax "%s": ' % ovs.json.to_string(self.json) + Exception.__init__(self, "%s%s: %s" % (syntax, self.tag, self.msg)) diff --git a/ryu/contrib/ovs/db/idl.py b/ryu/contrib/ovs/db/idl.py new file mode 100644 index 00000000..9e9bf0f5 --- /dev/null +++ b/ryu/contrib/ovs/db/idl.py @@ -0,0 +1,1287 @@ +# Copyright (c) 2009, 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid + +import ovs.jsonrpc +import ovs.db.parser +import ovs.db.schema +from ovs.db import error +import ovs.ovsuuid +import ovs.poller +import ovs.vlog + +vlog = ovs.vlog.Vlog("idl") + +__pychecker__ = 'no-classattr no-objattrs' + + +class Idl: + """Open vSwitch Database Interface Definition Language (OVSDB IDL). + + The OVSDB IDL maintains an in-memory replica of a database. It issues RPC + requests to an OVSDB database server and parses the responses, converting + raw JSON into data structures that are easier for clients to digest. + + The IDL also assists with issuing database transactions. The client + creates a transaction, manipulates the IDL data structures, and commits or + aborts the transaction. The IDL then composes and issues the necessary + JSON-RPC requests and reports to the client whether the transaction + completed successfully. + + The client is allowed to access the following attributes directly, in a + read-only fashion: + + - 'tables': This is the 'tables' map in the ovs.db.schema.DbSchema provided + to the Idl constructor. Each ovs.db.schema.TableSchema in the map is + annotated with a new attribute 'rows', which is a dict from a uuid.UUID + to a Row object. + + The client may directly read and write the Row objects referenced by the + 'rows' map values. Refer to Row for more details. + + - 'change_seqno': A number that represents the IDL's state. When the IDL + is updated (by Idl.run()), its value changes. The sequence number can + occasionally change even if the database does not. This happens if the + connection to the database drops and reconnects, which causes the + database contents to be reloaded even if they didn't change. (It could + also happen if the database server sends out a "change" that reflects + what the IDL already thought was in the database. The database server is + not supposed to do that, but bugs could in theory cause it to do so.) + + - 'lock_name': The name of the lock configured with Idl.set_lock(), or None + if no lock is configured. + + - 'has_lock': True, if the IDL is configured to obtain a lock and owns that + lock, and False otherwise. + + Locking and unlocking happens asynchronously from the database client's + point of view, so the information is only useful for optimization + (e.g. if the client doesn't have the lock then there's no point in trying + to write to the database). + + - 'is_lock_contended': True, if the IDL is configured to obtain a lock but + the database server has indicated that some other client already owns the + requested lock, and False otherwise. + + - 'txn': The ovs.db.idl.Transaction object for the database transaction + currently being constructed, if there is one, or None otherwise. +""" + + def __init__(self, remote, schema): + """Creates and returns a connection to the database named 'db_name' on + 'remote', which should be in a form acceptable to + ovs.jsonrpc.session.open(). The connection will maintain an in-memory + replica of the remote database. + + 'schema' should be the schema for the remote database. The caller may + have cut it down by removing tables or columns that are not of + interest. The IDL will only replicate the tables and columns that + remain. The caller may also add a attribute named 'alert' to selected + remaining columns, setting its value to False; if so, then changes to + those columns will not be considered changes to the database for the + purpose of the return value of Idl.run() and Idl.change_seqno. This is + useful for columns that the IDL's client will write but not read. + + As a convenience to users, 'schema' may also be an instance of the + SchemaHelper class. + + The IDL uses and modifies 'schema' directly.""" + + assert isinstance(schema, SchemaHelper) + schema = schema.get_idl_schema() + + self.tables = schema.tables + self._db = schema + self._session = ovs.jsonrpc.Session.open(remote) + self._monitor_request_id = None + self._last_seqno = None + self.change_seqno = 0 + + # Database locking. + self.lock_name = None # Name of lock we need, None if none. + self.has_lock = False # Has db server said we have the lock? + self.is_lock_contended = False # Has db server said we can't get lock? + self._lock_request_id = None # JSON-RPC ID of in-flight lock request. + + # Transaction support. + self.txn = None + self._outstanding_txns = {} + + for table in schema.tables.itervalues(): + for column in table.columns.itervalues(): + if not hasattr(column, 'alert'): + column.alert = True + table.need_table = False + table.rows = {} + table.idl = self + + def close(self): + """Closes the connection to the database. The IDL will no longer + update.""" + self._session.close() + + def run(self): + """Processes a batch of messages from the database server. Returns + True if the database as seen through the IDL changed, False if it did + not change. The initial fetch of the entire contents of the remote + database is considered to be one kind of change. If the IDL has been + configured to acquire a database lock (with Idl.set_lock()), then + successfully acquiring the lock is also considered to be a change. + + This function can return occasional false positives, that is, report + that the database changed even though it didn't. This happens if the + connection to the database drops and reconnects, which causes the + database contents to be reloaded even if they didn't change. (It could + also happen if the database server sends out a "change" that reflects + what we already thought was in the database, but the database server is + not supposed to do that.) + + As an alternative to checking the return value, the client may check + for changes in self.change_seqno.""" + assert not self.txn + initial_change_seqno = self.change_seqno + self._session.run() + i = 0 + while i < 50: + i += 1 + if not self._session.is_connected(): + break + + seqno = self._session.get_seqno() + if seqno != self._last_seqno: + self._last_seqno = seqno + self.__txn_abort_all() + self.__send_monitor_request() + if self.lock_name: + self.__send_lock_request() + break + + msg = self._session.recv() + if msg is None: + break + if (msg.type == ovs.jsonrpc.Message.T_NOTIFY + and msg.method == "update" + and len(msg.params) == 2 + and msg.params[0] == None): + # Database contents changed. + self.__parse_update(msg.params[1]) + elif (msg.type == ovs.jsonrpc.Message.T_REPLY + and self._monitor_request_id is not None + and self._monitor_request_id == msg.id): + # Reply to our "monitor" request. + try: + self.change_seqno += 1 + self._monitor_request_id = None + self.__clear() + self.__parse_update(msg.result) + except error.Error, e: + vlog.err("%s: parse error in received schema: %s" + % (self._session.get_name(), e)) + self.__error() + elif (msg.type == ovs.jsonrpc.Message.T_REPLY + and self._lock_request_id is not None + and self._lock_request_id == msg.id): + # Reply to our "lock" request. + self.__parse_lock_reply(msg.result) + elif (msg.type == ovs.jsonrpc.Message.T_NOTIFY + and msg.method == "locked"): + # We got our lock. + self.__parse_lock_notify(msg.params, True) + elif (msg.type == ovs.jsonrpc.Message.T_NOTIFY + and msg.method == "stolen"): + # Someone else stole our lock. + self.__parse_lock_notify(msg.params, False) + elif msg.type == ovs.jsonrpc.Message.T_NOTIFY and msg.id == "echo": + # Reply to our echo request. Ignore it. + pass + elif (msg.type in (ovs.jsonrpc.Message.T_ERROR, + ovs.jsonrpc.Message.T_REPLY) + and self.__txn_process_reply(msg)): + # __txn_process_reply() did everything needed. + pass + else: + # This can happen if a transaction is destroyed before we + # receive the reply, so keep the log level low. + vlog.dbg("%s: received unexpected %s message" + % (self._session.get_name(), + ovs.jsonrpc.Message.type_to_string(msg.type))) + + return initial_change_seqno != self.change_seqno + + def wait(self, poller): + """Arranges for poller.block() to wake up when self.run() has something + to do or when activity occurs on a transaction on 'self'.""" + self._session.wait(poller) + self._session.recv_wait(poller) + + def has_ever_connected(self): + """Returns True, if the IDL successfully connected to the remote + database and retrieved its contents (even if the connection + subsequently dropped and is in the process of reconnecting). If so, + then the IDL contains an atomic snapshot of the database's contents + (but it might be arbitrarily old if the connection dropped). + + Returns False if the IDL has never connected or retrieved the + database's contents. If so, the IDL is empty.""" + return self.change_seqno != 0 + + def force_reconnect(self): + """Forces the IDL to drop its connection to the database and reconnect. + In the meantime, the contents of the IDL will not change.""" + self._session.force_reconnect() + + def set_lock(self, lock_name): + """If 'lock_name' is not None, configures the IDL to obtain the named + lock from the database server and to avoid modifying the database when + the lock cannot be acquired (that is, when another client has the same + lock). + + If 'lock_name' is None, drops the locking requirement and releases the + lock.""" + assert not self.txn + assert not self._outstanding_txns + + if self.lock_name and (not lock_name or lock_name != self.lock_name): + # Release previous lock. + self.__send_unlock_request() + self.lock_name = None + self.is_lock_contended = False + + if lock_name and not self.lock_name: + # Acquire new lock. + self.lock_name = lock_name + self.__send_lock_request() + + def __clear(self): + changed = False + + for table in self.tables.itervalues(): + if table.rows: + changed = True + table.rows = {} + + if changed: + self.change_seqno += 1 + + def __update_has_lock(self, new_has_lock): + if new_has_lock and not self.has_lock: + if self._monitor_request_id is None: + self.change_seqno += 1 + else: + # We're waiting for a monitor reply, so don't signal that the + # database changed. The monitor reply will increment + # change_seqno anyhow. + pass + self.is_lock_contended = False + self.has_lock = new_has_lock + + def __do_send_lock_request(self, method): + self.__update_has_lock(False) + self._lock_request_id = None + if self._session.is_connected(): + msg = ovs.jsonrpc.Message.create_request(method, [self.lock_name]) + msg_id = msg.id + self._session.send(msg) + else: + msg_id = None + return msg_id + + def __send_lock_request(self): + self._lock_request_id = self.__do_send_lock_request("lock") + + def __send_unlock_request(self): + self.__do_send_lock_request("unlock") + + def __parse_lock_reply(self, result): + self._lock_request_id = None + got_lock = type(result) == dict and result.get("locked") is True + self.__update_has_lock(got_lock) + if not got_lock: + self.is_lock_contended = True + + def __parse_lock_notify(self, params, new_has_lock): + if (self.lock_name is not None + and type(params) in (list, tuple) + and params + and params[0] == self.lock_name): + self.__update_has_lock(self, new_has_lock) + if not new_has_lock: + self.is_lock_contended = True + + def __send_monitor_request(self): + monitor_requests = {} + for table in self.tables.itervalues(): + monitor_requests[table.name] = {"columns": table.columns.keys()} + msg = ovs.jsonrpc.Message.create_request( + "monitor", [self._db.name, None, monitor_requests]) + self._monitor_request_id = msg.id + self._session.send(msg) + + def __parse_update(self, update): + try: + self.__do_parse_update(update) + except error.Error, e: + vlog.err("%s: error parsing update: %s" + % (self._session.get_name(), e)) + + def __do_parse_update(self, table_updates): + if type(table_updates) != dict: + raise error.Error(" is not an object", + table_updates) + + for table_name, table_update in table_updates.iteritems(): + table = self.tables.get(table_name) + if not table: + raise error.Error(' includes unknown ' + 'table "%s"' % table_name) + + if type(table_update) != dict: + raise error.Error(' for table "%s" is not ' + 'an object' % table_name, table_update) + + for uuid_string, row_update in table_update.iteritems(): + if not ovs.ovsuuid.is_valid_string(uuid_string): + raise error.Error(' for table "%s" ' + 'contains bad UUID "%s" as member ' + 'name' % (table_name, uuid_string), + table_update) + uuid = ovs.ovsuuid.from_string(uuid_string) + + if type(row_update) != dict: + raise error.Error(' for table "%s" ' + 'contains for %s that ' + 'is not an object' + % (table_name, uuid_string)) + + parser = ovs.db.parser.Parser(row_update, "row-update") + old = parser.get_optional("old", [dict]) + new = parser.get_optional("new", [dict]) + parser.finish() + + if not old and not new: + raise error.Error(' missing "old" and ' + '"new" members', row_update) + + if self.__process_update(table, uuid, old, new): + self.change_seqno += 1 + + def __process_update(self, table, uuid, old, new): + """Returns True if a column changed, False otherwise.""" + row = table.rows.get(uuid) + changed = False + if not new: + # Delete row. + if row: + del table.rows[uuid] + changed = True + else: + # XXX rate-limit + vlog.warn("cannot delete missing row %s from table %s" + % (uuid, table.name)) + elif not old: + # Insert row. + if not row: + row = self.__create_row(table, uuid) + changed = True + else: + # XXX rate-limit + vlog.warn("cannot add existing row %s to table %s" + % (uuid, table.name)) + if self.__row_update(table, row, new): + changed = True + else: + if not row: + row = self.__create_row(table, uuid) + changed = True + # XXX rate-limit + vlog.warn("cannot modify missing row %s in table %s" + % (uuid, table.name)) + if self.__row_update(table, row, new): + changed = True + return changed + + def __row_update(self, table, row, row_json): + changed = False + for column_name, datum_json in row_json.iteritems(): + column = table.columns.get(column_name) + if not column: + # XXX rate-limit + vlog.warn("unknown column %s updating table %s" + % (column_name, table.name)) + continue + + try: + datum = ovs.db.data.Datum.from_json(column.type, datum_json) + except error.Error, e: + # XXX rate-limit + vlog.warn("error parsing column %s in table %s: %s" + % (column_name, table.name, e)) + continue + + if datum != row._data[column_name]: + row._data[column_name] = datum + if column.alert: + changed = True + else: + # Didn't really change but the OVSDB monitor protocol always + # includes every value in a row. + pass + return changed + + def __create_row(self, table, uuid): + data = {} + for column in table.columns.itervalues(): + data[column.name] = ovs.db.data.Datum.default(column.type) + row = table.rows[uuid] = Row(self, table, uuid, data) + return row + + def __error(self): + self._session.force_reconnect() + + def __txn_abort_all(self): + while self._outstanding_txns: + txn = self._outstanding_txns.popitem()[1] + txn._status = Transaction.TRY_AGAIN + + def __txn_process_reply(self, msg): + txn = self._outstanding_txns.pop(msg.id, None) + if txn: + txn._process_reply(msg) + + +def _uuid_to_row(atom, base): + if base.ref_table: + return base.ref_table.rows.get(atom) + else: + return atom + + +def _row_to_uuid(value): + if type(value) == Row: + return value.uuid + else: + return value + + +class Row(object): + """A row within an IDL. + + The client may access the following attributes directly: + + - 'uuid': a uuid.UUID object whose value is the row's database UUID. + + - An attribute for each column in the Row's table, named for the column, + whose values are as returned by Datum.to_python() for the column's type. + + If some error occurs (e.g. the database server's idea of the column is + different from the IDL's idea), then the attribute values is the + "default" value return by Datum.default() for the column's type. (It is + important to know this because the default value may violate constraints + for the column's type, e.g. the default integer value is 0 even if column + contraints require the column's value to be positive.) + + When a transaction is active, column attributes may also be assigned new + values. Committing the transaction will then cause the new value to be + stored into the database. + + *NOTE*: In the current implementation, the value of a column is a *copy* + of the value in the database. This means that modifying its value + directly will have no useful effect. For example, the following: + row.mycolumn["a"] = "b" # don't do this + will not change anything in the database, even after commit. To modify + the column, instead assign the modified column value back to the column: + d = row.mycolumn + d["a"] = "b" + row.mycolumn = d +""" + def __init__(self, idl, table, uuid, data): + # All of the explicit references to self.__dict__ below are required + # to set real attributes with invoking self.__getattr__(). + self.__dict__["uuid"] = uuid + + self.__dict__["_idl"] = idl + self.__dict__["_table"] = table + + # _data is the committed data. It takes the following values: + # + # - A dictionary that maps every column name to a Datum, if the row + # exists in the committed form of the database. + # + # - None, if this row is newly inserted within the active transaction + # and thus has no committed form. + self.__dict__["_data"] = data + + # _changes describes changes to this row within the active transaction. + # It takes the following values: + # + # - {}, the empty dictionary, if no transaction is active or if the + # row has yet not been changed within this transaction. + # + # - A dictionary that maps a column name to its new Datum, if an + # active transaction changes those columns' values. + # + # - A dictionary that maps every column name to a Datum, if the row + # is newly inserted within the active transaction. + # + # - None, if this transaction deletes this row. + self.__dict__["_changes"] = {} + + # A dictionary whose keys are the names of columns that must be + # verified as prerequisites when the transaction commits. The values + # in the dictionary are all None. + self.__dict__["_prereqs"] = {} + + def __getattr__(self, column_name): + assert self._changes is not None + + datum = self._changes.get(column_name) + if datum is None: + if self._data is None: + raise AttributeError("%s instance has no attribute '%s'" % + (self.__class__.__name__, column_name)) + datum = self._data[column_name] + + return datum.to_python(_uuid_to_row) + + def __setattr__(self, column_name, value): + assert self._changes is not None + assert self._idl.txn + + column = self._table.columns[column_name] + try: + datum = ovs.db.data.Datum.from_python(column.type, value, + _row_to_uuid) + except error.Error, e: + # XXX rate-limit + vlog.err("attempting to write bad value to column %s (%s)" + % (column_name, e)) + return + self._idl.txn._write(self, column, datum) + + def verify(self, column_name): + """Causes the original contents of column 'column_name' in this row to + be verified as a prerequisite to completing the transaction. That is, + if 'column_name' changed in this row (or if this row was deleted) + between the time that the IDL originally read its contents and the time + that the transaction commits, then the transaction aborts and + Transaction.commit() returns Transaction.TRY_AGAIN. + + The intention is that, to ensure that no transaction commits based on + dirty reads, an application should call Row.verify() on each data item + read as part of a read-modify-write operation. + + In some cases Row.verify() reduces to a no-op, because the current + value of the column is already known: + + - If this row is a row created by the current transaction (returned + by Transaction.insert()). + + - If the column has already been modified within the current + transaction. + + Because of the latter property, always call Row.verify() *before* + modifying the column, for a given read-modify-write. + + A transaction must be in progress.""" + assert self._idl.txn + assert self._changes is not None + if not self._data or column_name in self._changes: + return + + self._prereqs[column_name] = None + + def delete(self): + """Deletes this row from its table. + + A transaction must be in progress.""" + assert self._idl.txn + assert self._changes is not None + if self._data is None: + del self._idl.txn._txn_rows[self.uuid] + self.__dict__["_changes"] = None + del self._table.rows[self.uuid] + + def increment(self, column_name): + """Causes the transaction, when committed, to increment the value of + 'column_name' within this row by 1. 'column_name' must have an integer + type. After the transaction commits successfully, the client may + retrieve the final (incremented) value of 'column_name' with + Transaction.get_increment_new_value(). + + The client could accomplish something similar by reading and writing + and verify()ing columns. However, increment() will never (by itself) + cause a transaction to fail because of a verify error. + + The intended use is for incrementing the "next_cfg" column in + the Open_vSwitch table.""" + self._idl.txn._increment(self, column_name) + + +def _uuid_name_from_uuid(uuid): + return "row%s" % str(uuid).replace("-", "_") + + +def _where_uuid_equals(uuid): + return [["_uuid", "==", ["uuid", str(uuid)]]] + + +class _InsertedRow(object): + def __init__(self, op_index): + self.op_index = op_index + self.real = None + + +class Transaction(object): + """A transaction may modify the contents of a database by modifying the + values of columns, deleting rows, inserting rows, or adding checks that + columns in the database have not changed ("verify" operations), through + Row methods. + + Reading and writing columns and inserting and deleting rows are all + straightforward. The reasons to verify columns are less obvious. + Verification is the key to maintaining transactional integrity. Because + OVSDB handles multiple clients, it can happen that between the time that + OVSDB client A reads a column and writes a new value, OVSDB client B has + written that column. Client A's write should not ordinarily overwrite + client B's, especially if the column in question is a "map" column that + contains several more or less independent data items. If client A adds a + "verify" operation before it writes the column, then the transaction fails + in case client B modifies it first. Client A will then see the new value + of the column and compose a new transaction based on the new contents + written by client B. + + When a transaction is complete, which must be before the next call to + Idl.run(), call Transaction.commit() or Transaction.abort(). + + The life-cycle of a transaction looks like this: + + 1. Create the transaction and record the initial sequence number: + + seqno = idl.change_seqno(idl) + txn = Transaction(idl) + + 2. Modify the database with Row and Transaction methods. + + 3. Commit the transaction by calling Transaction.commit(). The first call + to this function probably returns Transaction.INCOMPLETE. The client + must keep calling again along as this remains true, calling Idl.run() in + between to let the IDL do protocol processing. (If the client doesn't + have anything else to do in the meantime, it can use + Transaction.commit_block() to avoid having to loop itself.) + + 4. If the final status is Transaction.TRY_AGAIN, wait for Idl.change_seqno + to change from the saved 'seqno' (it's possible that it's already + changed, in which case the client should not wait at all), then start + over from step 1. Only a call to Idl.run() will change the return value + of Idl.change_seqno. (Transaction.commit_block() calls Idl.run().)""" + + # Status values that Transaction.commit() can return. + UNCOMMITTED = "uncommitted" # Not yet committed or aborted. + UNCHANGED = "unchanged" # Transaction didn't include any changes. + INCOMPLETE = "incomplete" # Commit in progress, please wait. + ABORTED = "aborted" # ovsdb_idl_txn_abort() called. + SUCCESS = "success" # Commit successful. + TRY_AGAIN = "try again" # Commit failed because a "verify" operation + # reported an inconsistency, due to a network + # problem, or other transient failure. Wait + # for a change, then try again. + NOT_LOCKED = "not locked" # Server hasn't given us the lock yet. + ERROR = "error" # Commit failed due to a hard error. + + @staticmethod + def status_to_string(status): + """Converts one of the status values that Transaction.commit() can + return into a human-readable string. + + (The status values are in fact such strings already, so + there's nothing to do.)""" + return status + + def __init__(self, idl): + """Starts a new transaction on 'idl' (an instance of ovs.db.idl.Idl). + A given Idl may only have a single active transaction at a time. + + A Transaction may modify the contents of a database by assigning new + values to columns (attributes of Row), deleting rows (with + Row.delete()), or inserting rows (with Transaction.insert()). It may + also check that columns in the database have not changed with + Row.verify(). + + When a transaction is complete (which must be before the next call to + Idl.run()), call Transaction.commit() or Transaction.abort().""" + assert idl.txn is None + + idl.txn = self + self._request_id = None + self.idl = idl + self.dry_run = False + self._txn_rows = {} + self._status = Transaction.UNCOMMITTED + self._error = None + self._comments = [] + self._commit_seqno = self.idl.change_seqno + + self._inc_row = None + self._inc_column = None + + self._inserted_rows = {} # Map from UUID to _InsertedRow + + def add_comment(self, comment): + """Appens 'comment' to the comments that will be passed to the OVSDB + server when this transaction is committed. (The comment will be + committed to the OVSDB log, which "ovsdb-tool show-log" can print in a + relatively human-readable form.)""" + self._comments.append(comment) + + def wait(self, poller): + """Causes poll_block() to wake up if this transaction has completed + committing.""" + if self._status not in (Transaction.UNCOMMITTED, + Transaction.INCOMPLETE): + poller.immediate_wake() + + def _substitute_uuids(self, json): + if type(json) in (list, tuple): + if (len(json) == 2 + and json[0] == 'uuid' + and ovs.ovsuuid.is_valid_string(json[1])): + uuid = ovs.ovsuuid.from_string(json[1]) + row = self._txn_rows.get(uuid, None) + if row and row._data is None: + return ["named-uuid", _uuid_name_from_uuid(uuid)] + else: + return [self._substitute_uuids(elem) for elem in json] + return json + + def __disassemble(self): + self.idl.txn = None + + for row in self._txn_rows.itervalues(): + if row._changes is None: + row._table.rows[row.uuid] = row + elif row._data is None: + del row._table.rows[row.uuid] + row.__dict__["_changes"] = {} + row.__dict__["_prereqs"] = {} + self._txn_rows = {} + + def commit(self): + """Attempts to commit 'txn'. Returns the status of the commit + operation, one of the following constants: + + Transaction.INCOMPLETE: + + The transaction is in progress, but not yet complete. The caller + should call again later, after calling Idl.run() to let the + IDL do OVSDB protocol processing. + + Transaction.UNCHANGED: + + The transaction is complete. (It didn't actually change the + database, so the IDL didn't send any request to the database + server.) + + Transaction.ABORTED: + + The caller previously called Transaction.abort(). + + Transaction.SUCCESS: + + The transaction was successful. The update made by the + transaction (and possibly other changes made by other database + clients) should already be visible in the IDL. + + Transaction.TRY_AGAIN: + + The transaction failed for some transient reason, e.g. because a + "verify" operation reported an inconsistency or due to a network + problem. The caller should wait for a change to the database, + then compose a new transaction, and commit the new transaction. + + Use Idl.change_seqno to wait for a change in the database. It is + important to use its value *before* the initial call to + Transaction.commit() as the baseline for this purpose, because + the change that one should wait for can happen after the initial + call but before the call that returns Transaction.TRY_AGAIN, and + using some other baseline value in that situation could cause an + indefinite wait if the database rarely changes. + + Transaction.NOT_LOCKED: + + The transaction failed because the IDL has been configured to + require a database lock (with Idl.set_lock()) but didn't + get it yet or has already lost it. + + Committing a transaction rolls back all of the changes that it made to + the IDL's copy of the database. If the transaction commits + successfully, then the database server will send an update and, thus, + the IDL will be updated with the committed changes.""" + # The status can only change if we're the active transaction. + # (Otherwise, our status will change only in Idl.run().) + if self != self.idl.txn: + return self._status + + # If we need a lock but don't have it, give up quickly. + if self.idl.lock_name and not self.idl.has_lock(): + self._status = Transaction.NOT_LOCKED + self.__disassemble() + return self._status + + operations = [self.idl._db.name] + + # Assert that we have the required lock (avoiding a race). + if self.idl.lock_name: + operations.append({"op": "assert", + "lock": self.idl.lock_name}) + + # Add prerequisites and declarations of new rows. + for row in self._txn_rows.itervalues(): + if row._prereqs: + rows = {} + columns = [] + for column_name in row._prereqs: + columns.append(column_name) + rows[column_name] = row._data[column_name].to_json() + operations.append({"op": "wait", + "table": row._table.name, + "timeout": 0, + "where": _where_uuid_equals(row.uuid), + "until": "==", + "columns": columns, + "rows": [rows]}) + + # Add updates. + any_updates = False + for row in self._txn_rows.itervalues(): + if row._changes is None: + if row._table.is_root: + operations.append({"op": "delete", + "table": row._table.name, + "where": _where_uuid_equals(row.uuid)}) + any_updates = True + else: + # Let ovsdb-server decide whether to really delete it. + pass + elif row._changes: + op = {"table": row._table.name} + if row._data is None: + op["op"] = "insert" + op["uuid-name"] = _uuid_name_from_uuid(row.uuid) + any_updates = True + + op_index = len(operations) - 1 + self._inserted_rows[row.uuid] = _InsertedRow(op_index) + else: + op["op"] = "update" + op["where"] = _where_uuid_equals(row.uuid) + + row_json = {} + op["row"] = row_json + + for column_name, datum in row._changes.iteritems(): + if row._data is not None or not datum.is_default(): + row_json[column_name] = ( + self._substitute_uuids(datum.to_json())) + + # If anything really changed, consider it an update. + # We can't suppress not-really-changed values earlier + # or transactions would become nonatomic (see the big + # comment inside Transaction._write()). + if (not any_updates and row._data is not None and + row._data[column_name] != datum): + any_updates = True + + if row._data is None or row_json: + operations.append(op) + + # Add increment. + if self._inc_row and any_updates: + self._inc_index = len(operations) - 1 + + operations.append({"op": "mutate", + "table": self._inc_row._table.name, + "where": self._substitute_uuids( + _where_uuid_equals(self._inc_row.uuid)), + "mutations": [[self._inc_column, "+=", 1]]}) + operations.append({"op": "select", + "table": self._inc_row._table.name, + "where": self._substitute_uuids( + _where_uuid_equals(self._inc_row.uuid)), + "columns": [self._inc_column]}) + + # Add comment. + if self._comments: + operations.append({"op": "comment", + "comment": "\n".join(self._comments)}) + + # Dry run? + if self.dry_run: + operations.append({"op": "abort"}) + + if not any_updates: + self._status = Transaction.UNCHANGED + else: + msg = ovs.jsonrpc.Message.create_request("transact", operations) + self._request_id = msg.id + if not self.idl._session.send(msg): + self.idl._outstanding_txns[self._request_id] = self + self._status = Transaction.INCOMPLETE + else: + self._status = Transaction.TRY_AGAIN + + self.__disassemble() + return self._status + + def commit_block(self): + """Attempts to commit this transaction, blocking until the commit + either succeeds or fails. Returns the final commit status, which may + be any Transaction.* value other than Transaction.INCOMPLETE. + + This function calls Idl.run() on this transaction'ss IDL, so it may + cause Idl.change_seqno to change.""" + while True: + status = self.commit() + if status != Transaction.INCOMPLETE: + return status + + self.idl.run() + + poller = ovs.poller.Poller() + self.idl.wait(poller) + self.wait(poller) + poller.block() + + def get_increment_new_value(self): + """Returns the final (incremented) value of the column in this + transaction that was set to be incremented by Row.increment. This + transaction must have committed successfully.""" + assert self._status == Transaction.SUCCESS + return self._inc_new_value + + def abort(self): + """Aborts this transaction. If Transaction.commit() has already been + called then the transaction might get committed anyhow.""" + self.__disassemble() + if self._status in (Transaction.UNCOMMITTED, + Transaction.INCOMPLETE): + self._status = Transaction.ABORTED + + def get_error(self): + """Returns a string representing this transaction's current status, + suitable for use in log messages.""" + if self._status != Transaction.ERROR: + return Transaction.status_to_string(self._status) + elif self._error: + return self._error + else: + return "no error details available" + + def __set_error_json(self, json): + if self._error is None: + self._error = ovs.json.to_string(json) + + def get_insert_uuid(self, uuid): + """Finds and returns the permanent UUID that the database assigned to a + newly inserted row, given the UUID that Transaction.insert() assigned + locally to that row. + + Returns None if 'uuid' is not a UUID assigned by Transaction.insert() + or if it was assigned by that function and then deleted by Row.delete() + within the same transaction. (Rows that are inserted and then deleted + within a single transaction are never sent to the database server, so + it never assigns them a permanent UUID.) + + This transaction must have completed successfully.""" + assert self._status in (Transaction.SUCCESS, + Transaction.UNCHANGED) + inserted_row = self._inserted_rows.get(uuid) + if inserted_row: + return inserted_row.real + return None + + def _increment(self, row, column): + assert not self._inc_row + self._inc_row = row + self._inc_column = column + + def _write(self, row, column, datum): + assert row._changes is not None + + txn = row._idl.txn + + # If this is a write-only column and the datum being written is the + # same as the one already there, just skip the update entirely. This + # is worth optimizing because we have a lot of columns that get + # periodically refreshed into the database but don't actually change + # that often. + # + # We don't do this for read/write columns because that would break + # atomicity of transactions--some other client might have written a + # different value in that column since we read it. (But if a whole + # transaction only does writes of existing values, without making any + # real changes, we will drop the whole transaction later in + # ovsdb_idl_txn_commit().) + if not column.alert and row._data.get(column.name) == datum: + new_value = row._changes.get(column.name) + if new_value is None or new_value == datum: + return + + txn._txn_rows[row.uuid] = row + row._changes[column.name] = datum.copy() + + def insert(self, table, new_uuid=None): + """Inserts and returns a new row in 'table', which must be one of the + ovs.db.schema.TableSchema objects in the Idl's 'tables' dict. + + The new row is assigned a provisional UUID. If 'uuid' is None then one + is randomly generated; otherwise 'uuid' should specify a randomly + generated uuid.UUID not otherwise in use. ovsdb-server will assign a + different UUID when 'txn' is committed, but the IDL will replace any + uses of the provisional UUID in the data to be to be committed by the + UUID assigned by ovsdb-server.""" + assert self._status == Transaction.UNCOMMITTED + if new_uuid is None: + new_uuid = uuid.uuid4() + row = Row(self.idl, table, new_uuid, None) + table.rows[row.uuid] = row + self._txn_rows[row.uuid] = row + return row + + def _process_reply(self, msg): + if msg.type == ovs.jsonrpc.Message.T_ERROR: + self._status = Transaction.ERROR + elif type(msg.result) not in (list, tuple): + # XXX rate-limit + vlog.warn('reply to "transact" is not JSON array') + else: + hard_errors = False + soft_errors = False + lock_errors = False + + ops = msg.result + for op in ops: + if op is None: + # This isn't an error in itself but indicates that some + # prior operation failed, so make sure that we know about + # it. + soft_errors = True + elif type(op) == dict: + error = op.get("error") + if error is not None: + if error == "timed out": + soft_errors = True + elif error == "not owner": + lock_errors = True + elif error == "aborted": + pass + else: + hard_errors = True + self.__set_error_json(op) + else: + hard_errors = True + self.__set_error_json(op) + # XXX rate-limit + vlog.warn("operation reply is not JSON null or object") + + if not soft_errors and not hard_errors and not lock_errors: + if self._inc_row and not self.__process_inc_reply(ops): + hard_errors = True + + for insert in self._inserted_rows.itervalues(): + if not self.__process_insert_reply(insert, ops): + hard_errors = True + + if hard_errors: + self._status = Transaction.ERROR + elif lock_errors: + self._status = Transaction.NOT_LOCKED + elif soft_errors: + self._status = Transaction.TRY_AGAIN + else: + self._status = Transaction.SUCCESS + + @staticmethod + def __check_json_type(json, types, name): + if not json: + # XXX rate-limit + vlog.warn("%s is missing" % name) + return False + elif type(json) not in types: + # XXX rate-limit + vlog.warn("%s has unexpected type %s" % (name, type(json))) + return False + else: + return True + + def __process_inc_reply(self, ops): + if self._inc_index + 2 > len(ops): + # XXX rate-limit + vlog.warn("reply does not contain enough operations for " + "increment (has %d, needs %d)" % + (len(ops), self._inc_index + 2)) + + # We know that this is a JSON object because the loop in + # __process_reply() already checked. + mutate = ops[self._inc_index] + count = mutate.get("count") + if not Transaction.__check_json_type(count, (int, long), + '"mutate" reply "count"'): + return False + if count != 1: + # XXX rate-limit + vlog.warn('"mutate" reply "count" is %d instead of 1' % count) + return False + + select = ops[self._inc_index + 1] + rows = select.get("rows") + if not Transaction.__check_json_type(rows, (list, tuple), + '"select" reply "rows"'): + return False + if len(rows) != 1: + # XXX rate-limit + vlog.warn('"select" reply "rows" has %d elements ' + 'instead of 1' % len(rows)) + return False + row = rows[0] + if not Transaction.__check_json_type(row, (dict,), + '"select" reply row'): + return False + column = row.get(self._inc_column) + if not Transaction.__check_json_type(column, (int, long), + '"select" reply inc column'): + return False + self._inc_new_value = column + return True + + def __process_insert_reply(self, insert, ops): + if insert.op_index >= len(ops): + # XXX rate-limit + vlog.warn("reply does not contain enough operations " + "for insert (has %d, needs %d)" + % (len(ops), insert.op_index)) + return False + + # We know that this is a JSON object because the loop in + # __process_reply() already checked. + reply = ops[insert.op_index] + json_uuid = reply.get("uuid") + if not Transaction.__check_json_type(json_uuid, (tuple, list), + '"insert" reply "uuid"'): + return False + + try: + uuid_ = ovs.ovsuuid.from_json(json_uuid) + except error.Error: + # XXX rate-limit + vlog.warn('"insert" reply "uuid" is not a JSON UUID') + return False + + insert.real = uuid_ + return True + + +class SchemaHelper(object): + """IDL Schema helper. + + This class encapsulates the logic required to generate schemas suitable + for creating 'ovs.db.idl.Idl' objects. Clients should register columns + they are interested in using register_columns(). When finished, the + get_idl_schema() function may be called. + + The location on disk of the schema used may be found in the + 'schema_location' variable.""" + + def __init__(self, location=None, schema_json=None): + """Creates a new Schema object. + + 'location' file path to ovs schema. None means default location + 'schema_json' schema in json preresentation in memory + """ + + if location and schema_json: + raise ValueError("both location and schema_json can't be " + "specified. it's ambiguous.") + if schema_json is None: + if location is None: + location = "%s/vswitch.ovsschema" % ovs.dirs.PKGDATADIR + schema_json = ovs.json.from_file(location) + + self.schema_json = schema_json + self._tables = {} + self._all = False + + def register_columns(self, table, columns): + """Registers interest in the given 'columns' of 'table'. Future calls + to get_idl_schema() will include 'table':column for each column in + 'columns'. This function automatically avoids adding duplicate entries + to the schema. + + 'table' must be a string. + 'columns' must be a list of strings. + """ + + assert type(table) is str + assert type(columns) is list + + columns = set(columns) | self._tables.get(table, set()) + self._tables[table] = columns + + def register_table(self, table): + """Registers interest in the given all columns of 'table'. Future calls + to get_idl_schema() will include all columns of 'table'. + + 'table' must be a string + """ + assert type(table) is str + self._tables[table] = set() # empty set means all columns in the table + + def register_all(self): + """Registers interest in every column of every table.""" + self._all = True + + def get_idl_schema(self): + """Gets a schema appropriate for the creation of an 'ovs.db.id.IDL' + object based on columns registered using the register_columns() + function.""" + + schema = ovs.db.schema.DbSchema.from_json(self.schema_json) + self.schema_json = None + + if not self._all: + schema_tables = {} + for table, columns in self._tables.iteritems(): + schema_tables[table] = ( + self._keep_table_columns(schema, table, columns)) + + schema.tables = schema_tables + return schema + + def _keep_table_columns(self, schema, table_name, columns): + assert table_name in schema.tables + table = schema.tables[table_name] + + if not columns: + # empty set means all columns in the table + return table + + new_columns = {} + for column_name in columns: + assert type(column_name) is str + assert column_name in table.columns + + new_columns[column_name] = table.columns[column_name] + + table.columns = new_columns + return table diff --git a/ryu/contrib/ovs/db/parser.py b/ryu/contrib/ovs/db/parser.py new file mode 100644 index 00000000..2556becc --- /dev/null +++ b/ryu/contrib/ovs/db/parser.py @@ -0,0 +1,109 @@ +# Copyright (c) 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +from ovs.db import error + + +class Parser(object): + def __init__(self, json, name): + self.name = name + self.json = json + if type(json) != dict: + self.__raise_error("Object expected.") + self.used = set() + + def __get(self, name, types, optional, default=None): + if name in self.json: + self.used.add(name) + member = float_to_int(self.json[name]) + if is_identifier(member) and "id" in types: + return member + if len(types) and type(member) not in types: + self.__raise_error("Type mismatch for member '%s'." % name) + return member + else: + if not optional: + self.__raise_error("Required '%s' member is missing." % name) + return default + + def get(self, name, types): + return self.__get(name, types, False) + + def get_optional(self, name, types, default=None): + return self.__get(name, types, True, default) + + def __raise_error(self, message): + raise error.Error("Parsing %s failed: %s" % (self.name, message), + self.json) + + def finish(self): + missing = set(self.json) - set(self.used) + if missing: + name = missing.pop() + if len(missing) > 1: + present = "and %d other members are" % len(missing) + elif missing: + present = "and 1 other member are" + else: + present = "is" + self.__raise_error("Member '%s' %s present but not allowed here" % + (name, present)) + + +def float_to_int(x): + # XXX still needed? + if type(x) == float: + integer = int(x) + if integer == x and -2 ** 53 <= integer < 2 ** 53: + return integer + return x + + +id_re = re.compile("[_a-zA-Z][_a-zA-Z0-9]*$") + + +def is_identifier(s): + return type(s) in [str, unicode] and id_re.match(s) + + +def json_type_to_string(type_): + if type_ == None: + return "null" + elif type_ == bool: + return "boolean" + elif type_ == dict: + return "object" + elif type_ == list: + return "array" + elif type_ in [int, long, float]: + return "number" + elif type_ in [str, unicode]: + return "string" + else: + return "" + + +def unwrap_json(json, name, types, desc): + if (type(json) not in (list, tuple) or len(json) != 2 or json[0] != name or + type(json[1]) not in types): + raise error.Error('expected ["%s", <%s>]' % (name, desc), json) + return json[1] + + +def parse_json_pair(json): + if type(json) != list or len(json) != 2: + raise error.Error("expected 2-element array", json) + return json diff --git a/ryu/contrib/ovs/db/schema.py b/ryu/contrib/ovs/db/schema.py new file mode 100644 index 00000000..1b5a771f --- /dev/null +++ b/ryu/contrib/ovs/db/schema.py @@ -0,0 +1,271 @@ +# Copyright (c) 2009, 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import sys + +from ovs.db import error +import ovs.db.parser +from ovs.db import types + + +def _check_id(name, json): + if name.startswith('_'): + raise error.Error('names beginning with "_" are reserved', json) + elif not ovs.db.parser.is_identifier(name): + raise error.Error("name must be a valid id", json) + + +class DbSchema(object): + """Schema for an OVSDB database.""" + + def __init__(self, name, version, tables): + self.name = name + self.version = version + self.tables = tables + + # "isRoot" was not part of the original schema definition. Before it + # was added, there was no support for garbage collection. So, for + # backward compatibility, if the root set is empty then assume that + # every table is in the root set. + if self.__root_set_size() == 0: + for table in self.tables.itervalues(): + table.is_root = True + + # Find the "ref_table"s referenced by "ref_table_name"s. + # + # Also force certain columns to be persistent, as explained in + # __check_ref_table(). This requires 'is_root' to be known, so this + # must follow the loop updating 'is_root' above. + for table in self.tables.itervalues(): + for column in table.columns.itervalues(): + self.__follow_ref_table(column, column.type.key, "key") + self.__follow_ref_table(column, column.type.value, "value") + + def __root_set_size(self): + """Returns the number of tables in the schema's root set.""" + n_root = 0 + for table in self.tables.itervalues(): + if table.is_root: + n_root += 1 + return n_root + + @staticmethod + def from_json(json): + parser = ovs.db.parser.Parser(json, "database schema") + name = parser.get("name", ['id']) + version = parser.get_optional("version", [str, unicode]) + parser.get_optional("cksum", [str, unicode]) + tablesJson = parser.get("tables", [dict]) + parser.finish() + + if (version is not None and + not re.match('[0-9]+\.[0-9]+\.[0-9]+$', version)): + raise error.Error('schema version "%s" not in format x.y.z' + % version) + + tables = {} + for tableName, tableJson in tablesJson.iteritems(): + _check_id(tableName, json) + tables[tableName] = TableSchema.from_json(tableJson, tableName) + + return DbSchema(name, version, tables) + + def to_json(self): + # "isRoot" was not part of the original schema definition. Before it + # was added, there was no support for garbage collection. So, for + # backward compatibility, if every table is in the root set then do not + # output "isRoot" in table schemas. + default_is_root = self.__root_set_size() == len(self.tables) + + tables = {} + for table in self.tables.itervalues(): + tables[table.name] = table.to_json(default_is_root) + json = {"name": self.name, "tables": tables} + if self.version: + json["version"] = self.version + return json + + def copy(self): + return DbSchema.from_json(self.to_json()) + + def __follow_ref_table(self, column, base, base_name): + if not base or base.type != types.UuidType or not base.ref_table_name: + return + + base.ref_table = self.tables.get(base.ref_table_name) + if not base.ref_table: + raise error.Error("column %s %s refers to undefined table %s" + % (column.name, base_name, base.ref_table_name), + tag="syntax error") + + if base.is_strong_ref() and not base.ref_table.is_root: + # We cannot allow a strong reference to a non-root table to be + # ephemeral: if it is the only reference to a row, then replaying + # the database log from disk will cause the referenced row to be + # deleted, even though it did exist in memory. If there are + # references to that row later in the log (to modify it, to delete + # it, or just to point to it), then this will yield a transaction + # error. + column.persistent = True + + +class IdlSchema(DbSchema): + def __init__(self, name, version, tables, idlPrefix, idlHeader): + DbSchema.__init__(self, name, version, tables) + self.idlPrefix = idlPrefix + self.idlHeader = idlHeader + + @staticmethod + def from_json(json): + parser = ovs.db.parser.Parser(json, "IDL schema") + idlPrefix = parser.get("idlPrefix", [str, unicode]) + idlHeader = parser.get("idlHeader", [str, unicode]) + + subjson = dict(json) + del subjson["idlPrefix"] + del subjson["idlHeader"] + schema = DbSchema.from_json(subjson) + + return IdlSchema(schema.name, schema.version, schema.tables, + idlPrefix, idlHeader) + + +def column_set_from_json(json, columns): + if json is None: + return tuple(columns) + elif type(json) != list: + raise error.Error("array of distinct column names expected", json) + else: + for column_name in json: + if type(column_name) not in [str, unicode]: + raise error.Error("array of distinct column names expected", + json) + elif column_name not in columns: + raise error.Error("%s is not a valid column name" + % column_name, json) + if len(set(json)) != len(json): + # Duplicate. + raise error.Error("array of distinct column names expected", json) + return tuple([columns[column_name] for column_name in json]) + + +class TableSchema(object): + def __init__(self, name, columns, mutable=True, max_rows=sys.maxint, + is_root=True, indexes=[]): + self.name = name + self.columns = columns + self.mutable = mutable + self.max_rows = max_rows + self.is_root = is_root + self.indexes = indexes + + @staticmethod + def from_json(json, name): + parser = ovs.db.parser.Parser(json, "table schema for table %s" % name) + columns_json = parser.get("columns", [dict]) + mutable = parser.get_optional("mutable", [bool], True) + max_rows = parser.get_optional("maxRows", [int]) + is_root = parser.get_optional("isRoot", [bool], False) + indexes_json = parser.get_optional("indexes", [list], []) + parser.finish() + + if max_rows == None: + max_rows = sys.maxint + elif max_rows <= 0: + raise error.Error("maxRows must be at least 1", json) + + if not columns_json: + raise error.Error("table must have at least one column", json) + + columns = {} + for column_name, column_json in columns_json.iteritems(): + _check_id(column_name, json) + columns[column_name] = ColumnSchema.from_json(column_json, + column_name) + + indexes = [] + for index_json in indexes_json: + index = column_set_from_json(index_json, columns) + if not index: + raise error.Error("index must have at least one column", json) + elif len(index) == 1: + index[0].unique = True + for column in index: + if not column.persistent: + raise error.Error("ephemeral columns (such as %s) may " + "not be indexed" % column.name, json) + indexes.append(index) + + return TableSchema(name, columns, mutable, max_rows, is_root, indexes) + + def to_json(self, default_is_root=False): + """Returns this table schema serialized into JSON. + + The "isRoot" member is included in the JSON only if its value would + differ from 'default_is_root'. Ordinarily 'default_is_root' should be + false, because ordinarily a table would be not be part of the root set + if its "isRoot" member is omitted. However, garbage collection was not + orginally included in OVSDB, so in older schemas that do not include + any "isRoot" members, every table is implicitly part of the root set. + To serialize such a schema in a way that can be read by older OVSDB + tools, specify 'default_is_root' as True. + """ + json = {} + if not self.mutable: + json["mutable"] = False + if default_is_root != self.is_root: + json["isRoot"] = self.is_root + + json["columns"] = columns = {} + for column in self.columns.itervalues(): + if not column.name.startswith("_"): + columns[column.name] = column.to_json() + + if self.max_rows != sys.maxint: + json["maxRows"] = self.max_rows + + if self.indexes: + json["indexes"] = [] + for index in self.indexes: + json["indexes"].append([column.name for column in index]) + + return json + + +class ColumnSchema(object): + def __init__(self, name, mutable, persistent, type_): + self.name = name + self.mutable = mutable + self.persistent = persistent + self.type = type_ + self.unique = False + + @staticmethod + def from_json(json, name): + parser = ovs.db.parser.Parser(json, "schema for column %s" % name) + mutable = parser.get_optional("mutable", [bool], True) + ephemeral = parser.get_optional("ephemeral", [bool], False) + type_ = types.Type.from_json(parser.get("type", [dict, str, unicode])) + parser.finish() + + return ColumnSchema(name, mutable, not ephemeral, type_) + + def to_json(self): + json = {"type": self.type.to_json()} + if not self.mutable: + json["mutable"] = False + if not self.persistent: + json["ephemeral"] = True + return json diff --git a/ryu/contrib/ovs/db/types.py b/ryu/contrib/ovs/db/types.py new file mode 100644 index 00000000..5865acd7 --- /dev/null +++ b/ryu/contrib/ovs/db/types.py @@ -0,0 +1,587 @@ +# Copyright (c) 2009, 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import uuid + +from ovs.db import error +import ovs.db.parser +import ovs.db.data +import ovs.ovsuuid + + +class AtomicType(object): + def __init__(self, name, default, python_types): + self.name = name + self.default = default + self.python_types = python_types + + @staticmethod + def from_string(s): + if s != "void": + for atomic_type in ATOMIC_TYPES: + if s == atomic_type.name: + return atomic_type + raise error.Error('"%s" is not an atomic-type' % s, s) + + @staticmethod + def from_json(json): + if type(json) not in [str, unicode]: + raise error.Error("atomic-type expected", json) + else: + return AtomicType.from_string(json) + + def __str__(self): + return self.name + + def to_string(self): + return self.name + + def to_json(self): + return self.name + + def default_atom(self): + return ovs.db.data.Atom(self, self.default) + +VoidType = AtomicType("void", None, ()) +IntegerType = AtomicType("integer", 0, (int, long)) +RealType = AtomicType("real", 0.0, (int, long, float)) +BooleanType = AtomicType("boolean", False, (bool,)) +StringType = AtomicType("string", "", (str, unicode)) +UuidType = AtomicType("uuid", ovs.ovsuuid.zero(), (uuid.UUID,)) + +ATOMIC_TYPES = [VoidType, IntegerType, RealType, BooleanType, StringType, + UuidType] + + +def escapeCString(src): + dst = "" + for c in src: + if c in "\\\"": + dst += "\\" + c + elif ord(c) < 32: + if c == '\n': + dst += '\\n' + elif c == '\r': + dst += '\\r' + elif c == '\a': + dst += '\\a' + elif c == '\b': + dst += '\\b' + elif c == '\f': + dst += '\\f' + elif c == '\t': + dst += '\\t' + elif c == '\v': + dst += '\\v' + else: + dst += '\\%03o' % ord(c) + else: + dst += c + return dst + + +def commafy(x): + """Returns integer x formatted in decimal with thousands set off by + commas.""" + return _commafy("%d" % x) + + +def _commafy(s): + if s.startswith('-'): + return '-' + _commafy(s[1:]) + elif len(s) <= 3: + return s + else: + return _commafy(s[:-3]) + ',' + _commafy(s[-3:]) + + +def returnUnchanged(x): + return x + + +class BaseType(object): + def __init__(self, type_, enum=None, min=None, max=None, + min_length=0, max_length=sys.maxint, ref_table_name=None): + assert isinstance(type_, AtomicType) + self.type = type_ + self.enum = enum + self.min = min + self.max = max + self.min_length = min_length + self.max_length = max_length + self.ref_table_name = ref_table_name + if ref_table_name: + self.ref_type = 'strong' + else: + self.ref_type = None + self.ref_table = None + + def default(self): + return ovs.db.data.Atom.default(self.type) + + def __eq__(self, other): + if not isinstance(other, BaseType): + return NotImplemented + return (self.type == other.type and self.enum == other.enum and + self.min == other.min and self.max == other.max and + self.min_length == other.min_length and + self.max_length == other.max_length and + self.ref_table_name == other.ref_table_name) + + def __ne__(self, other): + if not isinstance(other, BaseType): + return NotImplemented + else: + return not (self == other) + + @staticmethod + def __parse_uint(parser, name, default): + value = parser.get_optional(name, [int, long]) + if value is None: + value = default + else: + max_value = 2 ** 32 - 1 + if not (0 <= value <= max_value): + raise error.Error("%s out of valid range 0 to %d" + % (name, max_value), value) + return value + + @staticmethod + def from_json(json): + if type(json) in [str, unicode]: + return BaseType(AtomicType.from_json(json)) + + parser = ovs.db.parser.Parser(json, "ovsdb type") + atomic_type = AtomicType.from_json(parser.get("type", [str, unicode])) + + base = BaseType(atomic_type) + + enum = parser.get_optional("enum", []) + if enum is not None: + base.enum = ovs.db.data.Datum.from_json( + BaseType.get_enum_type(base.type), enum) + elif base.type == IntegerType: + base.min = parser.get_optional("minInteger", [int, long]) + base.max = parser.get_optional("maxInteger", [int, long]) + if (base.min is not None and base.max is not None + and base.min > base.max): + raise error.Error("minInteger exceeds maxInteger", json) + elif base.type == RealType: + base.min = parser.get_optional("minReal", [int, long, float]) + base.max = parser.get_optional("maxReal", [int, long, float]) + if (base.min is not None and base.max is not None + and base.min > base.max): + raise error.Error("minReal exceeds maxReal", json) + elif base.type == StringType: + base.min_length = BaseType.__parse_uint(parser, "minLength", 0) + base.max_length = BaseType.__parse_uint(parser, "maxLength", + sys.maxint) + if base.min_length > base.max_length: + raise error.Error("minLength exceeds maxLength", json) + elif base.type == UuidType: + base.ref_table_name = parser.get_optional("refTable", ['id']) + if base.ref_table_name: + base.ref_type = parser.get_optional("refType", [str, unicode], + "strong") + if base.ref_type not in ['strong', 'weak']: + raise error.Error('refType must be "strong" or "weak" ' + '(not "%s")' % base.ref_type) + parser.finish() + + return base + + def to_json(self): + if not self.has_constraints(): + return self.type.to_json() + + json = {'type': self.type.to_json()} + + if self.enum: + json['enum'] = self.enum.to_json() + + if self.type == IntegerType: + if self.min is not None: + json['minInteger'] = self.min + if self.max is not None: + json['maxInteger'] = self.max + elif self.type == RealType: + if self.min is not None: + json['minReal'] = self.min + if self.max is not None: + json['maxReal'] = self.max + elif self.type == StringType: + if self.min_length != 0: + json['minLength'] = self.min_length + if self.max_length != sys.maxint: + json['maxLength'] = self.max_length + elif self.type == UuidType: + if self.ref_table_name: + json['refTable'] = self.ref_table_name + if self.ref_type != 'strong': + json['refType'] = self.ref_type + return json + + def copy(self): + base = BaseType(self.type, self.enum.copy(), self.min, self.max, + self.min_length, self.max_length, self.ref_table_name) + base.ref_table = self.ref_table + return base + + def is_valid(self): + if self.type in (VoidType, BooleanType, UuidType): + return True + elif self.type in (IntegerType, RealType): + return self.min is None or self.max is None or self.min <= self.max + elif self.type == StringType: + return self.min_length <= self.max_length + else: + return False + + def has_constraints(self): + return (self.enum is not None or self.min is not None or + self.max is not None or + self.min_length != 0 or self.max_length != sys.maxint or + self.ref_table_name is not None) + + def without_constraints(self): + return BaseType(self.type) + + @staticmethod + def get_enum_type(atomic_type): + """Returns the type of the 'enum' member for a BaseType whose + 'type' is 'atomic_type'.""" + return Type(BaseType(atomic_type), None, 1, sys.maxint) + + def is_ref(self): + return self.type == UuidType and self.ref_table_name is not None + + def is_strong_ref(self): + return self.is_ref() and self.ref_type == 'strong' + + def is_weak_ref(self): + return self.is_ref() and self.ref_type == 'weak' + + def toEnglish(self, escapeLiteral=returnUnchanged): + if self.type == UuidType and self.ref_table_name: + s = escapeLiteral(self.ref_table_name) + if self.ref_type == 'weak': + s = "weak reference to " + s + return s + else: + return self.type.to_string() + + def constraintsToEnglish(self, escapeLiteral=returnUnchanged, + escapeNumber=returnUnchanged): + if self.enum: + literals = [value.toEnglish(escapeLiteral) + for value in self.enum.values] + if len(literals) == 2: + english = 'either %s or %s' % (literals[0], literals[1]) + else: + english = 'one of %s, %s, or %s' % (literals[0], + ', '.join(literals[1:-1]), + literals[-1]) + elif self.min is not None and self.max is not None: + if self.type == IntegerType: + english = 'in range %s to %s' % ( + escapeNumber(commafy(self.min)), + escapeNumber(commafy(self.max))) + else: + english = 'in range %s to %s' % ( + escapeNumber("%g" % self.min), + escapeNumber("%g" % self.max)) + elif self.min is not None: + if self.type == IntegerType: + english = 'at least %s' % escapeNumber(commafy(self.min)) + else: + english = 'at least %s' % escapeNumber("%g" % self.min) + elif self.max is not None: + if self.type == IntegerType: + english = 'at most %s' % escapeNumber(commafy(self.max)) + else: + english = 'at most %s' % escapeNumber("%g" % self.max) + elif self.min_length != 0 and self.max_length != sys.maxint: + if self.min_length == self.max_length: + english = ('exactly %s characters long' + % commafy(self.min_length)) + else: + english = ('between %s and %s characters long' + % (commafy(self.min_length), + commafy(self.max_length))) + elif self.min_length != 0: + return 'at least %s characters long' % commafy(self.min_length) + elif self.max_length != sys.maxint: + english = 'at most %s characters long' % commafy(self.max_length) + else: + english = '' + + return english + + def toCType(self, prefix): + if self.ref_table_name: + return "struct %s%s *" % (prefix, self.ref_table_name.lower()) + else: + return {IntegerType: 'int64_t ', + RealType: 'double ', + UuidType: 'struct uuid ', + BooleanType: 'bool ', + StringType: 'char *'}[self.type] + + def toAtomicType(self): + return "OVSDB_TYPE_%s" % self.type.to_string().upper() + + def copyCValue(self, dst, src): + args = {'dst': dst, 'src': src} + if self.ref_table_name: + return ("%(dst)s = %(src)s->header_.uuid;") % args + elif self.type == StringType: + return "%(dst)s = xstrdup(%(src)s);" % args + else: + return "%(dst)s = %(src)s;" % args + + def initCDefault(self, var, is_optional): + if self.ref_table_name: + return "%s = NULL;" % var + elif self.type == StringType and not is_optional: + return '%s = "";' % var + else: + pattern = {IntegerType: '%s = 0;', + RealType: '%s = 0.0;', + UuidType: 'uuid_zero(&%s);', + BooleanType: '%s = false;', + StringType: '%s = NULL;'}[self.type] + return pattern % var + + def cInitBaseType(self, indent, var): + stmts = [] + stmts.append('ovsdb_base_type_init(&%s, %s);' % ( + var, self.toAtomicType())) + if self.enum: + stmts.append("%s.enum_ = xmalloc(sizeof *%s.enum_);" + % (var, var)) + stmts += self.enum.cInitDatum("%s.enum_" % var) + if self.type == IntegerType: + if self.min is not None: + stmts.append('%s.u.integer.min = INT64_C(%d);' + % (var, self.min)) + if self.max is not None: + stmts.append('%s.u.integer.max = INT64_C(%d);' + % (var, self.max)) + elif self.type == RealType: + if self.min is not None: + stmts.append('%s.u.real.min = %d;' % (var, self.min)) + if self.max is not None: + stmts.append('%s.u.real.max = %d;' % (var, self.max)) + elif self.type == StringType: + if self.min_length is not None: + stmts.append('%s.u.string.minLen = %d;' + % (var, self.min_length)) + if self.max_length != sys.maxint: + stmts.append('%s.u.string.maxLen = %d;' + % (var, self.max_length)) + elif self.type == UuidType: + if self.ref_table_name is not None: + stmts.append('%s.u.uuid.refTableName = "%s";' + % (var, escapeCString(self.ref_table_name))) + stmts.append('%s.u.uuid.refType = OVSDB_REF_%s;' + % (var, self.ref_type.upper())) + return '\n'.join([indent + stmt for stmt in stmts]) + + +class Type(object): + DEFAULT_MIN = 1 + DEFAULT_MAX = 1 + + def __init__(self, key, value=None, n_min=DEFAULT_MIN, n_max=DEFAULT_MAX): + self.key = key + self.value = value + self.n_min = n_min + self.n_max = n_max + + def copy(self): + if self.value is None: + value = None + else: + value = self.value.copy() + return Type(self.key.copy(), value, self.n_min, self.n_max) + + def __eq__(self, other): + if not isinstance(other, Type): + return NotImplemented + return (self.key == other.key and self.value == other.value and + self.n_min == other.n_min and self.n_max == other.n_max) + + def __ne__(self, other): + if not isinstance(other, Type): + return NotImplemented + else: + return not (self == other) + + def is_valid(self): + return (self.key.type != VoidType and self.key.is_valid() and + (self.value is None or + (self.value.type != VoidType and self.value.is_valid())) and + self.n_min <= 1 <= self.n_max) + + def is_scalar(self): + return self.n_min == 1 and self.n_max == 1 and not self.value + + def is_optional(self): + return self.n_min == 0 and self.n_max == 1 + + def is_composite(self): + return self.n_max > 1 + + def is_set(self): + return self.value is None and (self.n_min != 1 or self.n_max != 1) + + def is_map(self): + return self.value is not None + + def is_smap(self): + return (self.is_map() + and self.key.type == StringType + and self.value.type == StringType) + + def is_optional_pointer(self): + return (self.is_optional() and not self.value + and (self.key.type == StringType or self.key.ref_table_name)) + + @staticmethod + def __n_from_json(json, default): + if json is None: + return default + elif type(json) == int and 0 <= json <= sys.maxint: + return json + else: + raise error.Error("bad min or max value", json) + + @staticmethod + def from_json(json): + if type(json) in [str, unicode]: + return Type(BaseType.from_json(json)) + + parser = ovs.db.parser.Parser(json, "ovsdb type") + key_json = parser.get("key", [dict, str, unicode]) + value_json = parser.get_optional("value", [dict, str, unicode]) + min_json = parser.get_optional("min", [int]) + max_json = parser.get_optional("max", [int, str, unicode]) + parser.finish() + + key = BaseType.from_json(key_json) + if value_json: + value = BaseType.from_json(value_json) + else: + value = None + + n_min = Type.__n_from_json(min_json, Type.DEFAULT_MIN) + + if max_json == 'unlimited': + n_max = sys.maxint + else: + n_max = Type.__n_from_json(max_json, Type.DEFAULT_MAX) + + type_ = Type(key, value, n_min, n_max) + if not type_.is_valid(): + raise error.Error("ovsdb type fails constraint checks", json) + return type_ + + def to_json(self): + if self.is_scalar() and not self.key.has_constraints(): + return self.key.to_json() + + json = {"key": self.key.to_json()} + if self.value is not None: + json["value"] = self.value.to_json() + if self.n_min != Type.DEFAULT_MIN: + json["min"] = self.n_min + if self.n_max == sys.maxint: + json["max"] = "unlimited" + elif self.n_max != Type.DEFAULT_MAX: + json["max"] = self.n_max + return json + + def toEnglish(self, escapeLiteral=returnUnchanged): + keyName = self.key.toEnglish(escapeLiteral) + if self.value: + valueName = self.value.toEnglish(escapeLiteral) + + if self.is_scalar(): + return keyName + elif self.is_optional(): + if self.value: + return "optional %s-%s pair" % (keyName, valueName) + else: + return "optional %s" % keyName + else: + if self.n_max == sys.maxint: + if self.n_min: + quantity = "%s or more " % commafy(self.n_min) + else: + quantity = "" + elif self.n_min: + quantity = "%s to %s " % (commafy(self.n_min), + commafy(self.n_max)) + else: + quantity = "up to %s " % commafy(self.n_max) + + if self.value: + return "map of %s%s-%s pairs" % (quantity, keyName, valueName) + else: + if keyName.endswith('s'): + plural = keyName + "es" + else: + plural = keyName + "s" + return "set of %s%s" % (quantity, plural) + + def constraintsToEnglish(self, escapeLiteral=returnUnchanged, + escapeNumber=returnUnchanged): + constraints = [] + keyConstraints = self.key.constraintsToEnglish(escapeLiteral, + escapeNumber) + if keyConstraints: + if self.value: + constraints.append('key %s' % keyConstraints) + else: + constraints.append(keyConstraints) + + if self.value: + valueConstraints = self.value.constraintsToEnglish(escapeLiteral, + escapeNumber) + if valueConstraints: + constraints.append('value %s' % valueConstraints) + + return ', '.join(constraints) + + def cDeclComment(self): + if self.n_min == 1 and self.n_max == 1 and self.key.type == StringType: + return "\t/* Always nonnull. */" + else: + return "" + + def cInitType(self, indent, var): + initKey = self.key.cInitBaseType(indent, "%s.key" % var) + if self.value: + initValue = self.value.cInitBaseType(indent, "%s.value" % var) + else: + initValue = ('%sovsdb_base_type_init(&%s.value, ' + 'OVSDB_TYPE_VOID);' % (indent, var)) + initMin = "%s%s.n_min = %s;" % (indent, var, self.n_min) + if self.n_max == sys.maxint: + n_max = "UINT_MAX" + else: + n_max = self.n_max + initMax = "%s%s.n_max = %s;" % (indent, var, n_max) + return "\n".join((initKey, initValue, initMin, initMax)) diff --git a/ryu/contrib/ovs/dirs.py.template b/ryu/contrib/ovs/dirs.py.template new file mode 100644 index 00000000..370c69f4 --- /dev/null +++ b/ryu/contrib/ovs/dirs.py.template @@ -0,0 +1,17 @@ +## The @variables@ in this file are replaced by default directories for +## use in python/ovs/dirs.py in the source directory and replaced by the +## configured directories for use in the installed python/ovs/dirs.py. +## +import os +PKGDATADIR = os.environ.get("OVS_PKGDATADIR", """@pkgdatadir@""") +RUNDIR = os.environ.get("OVS_RUNDIR", """@RUNDIR@""") +LOGDIR = os.environ.get("OVS_LOGDIR", """@LOGDIR@""") +BINDIR = os.environ.get("OVS_BINDIR", """@bindir@""") + +DBDIR = os.environ.get("OVS_DBDIR") +if not DBDIR: + sysconfdir = os.environ.get("OVS_SYSCONFDIR") + if sysconfdir: + DBDIR = "%s/openvswitch" % sysconfdir + else: + DBDIR = """@DBDIR@""" diff --git a/ryu/contrib/ovs/fatal_signal.py b/ryu/contrib/ovs/fatal_signal.py new file mode 100644 index 00000000..e6fe7838 --- /dev/null +++ b/ryu/contrib/ovs/fatal_signal.py @@ -0,0 +1,136 @@ +# Copyright (c) 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import atexit +import os +import signal + +import ovs.vlog + +_hooks = [] +vlog = ovs.vlog.Vlog("fatal-signal") + + +def add_hook(hook, cancel, run_at_exit): + _init() + _hooks.append((hook, cancel, run_at_exit)) + + +def fork(): + """Clears all of the fatal signal hooks without executing them. If any of + the hooks passed a 'cancel' function to add_hook(), then those functions + will be called, allowing them to free resources, etc. + + Following a fork, one of the resulting processes can call this function to + allow it to terminate without calling the hooks registered before calling + this function. New hooks registered after calling this function will take + effect normally.""" + global _hooks + for hook, cancel, run_at_exit in _hooks: + if cancel: + cancel() + + _hooks = [] + +_added_hook = False +_files = {} + + +def add_file_to_unlink(file): + """Registers 'file' to be unlinked when the program terminates via + sys.exit() or a fatal signal.""" + global _added_hook + if not _added_hook: + _added_hook = True + add_hook(_unlink_files, _cancel_files, True) + _files[file] = None + + +def remove_file_to_unlink(file): + """Unregisters 'file' from being unlinked when the program terminates via + sys.exit() or a fatal signal.""" + if file in _files: + del _files[file] + + +def unlink_file_now(file): + """Like fatal_signal_remove_file_to_unlink(), but also unlinks 'file'. + Returns 0 if successful, otherwise a positive errno value.""" + error = _unlink(file) + if error: + vlog.warn("could not unlink \"%s\" (%s)" % (file, os.strerror(error))) + remove_file_to_unlink(file) + return error + + +def _unlink_files(): + for file_ in _files: + _unlink(file_) + + +def _cancel_files(): + global _added_hook + global _files + _added_hook = False + _files = {} + + +def _unlink(file_): + try: + os.unlink(file_) + return 0 + except OSError, e: + return e.errno + + +def _signal_handler(signr, _): + _call_hooks(signr) + + # Re-raise the signal with the default handling so that the program + # termination status reflects that we were killed by this signal. + signal.signal(signr, signal.SIG_DFL) + os.kill(os.getpid(), signr) + + +def _atexit_handler(): + _call_hooks(0) + + +recurse = False + + +def _call_hooks(signr): + global recurse + if recurse: + return + recurse = True + + for hook, cancel, run_at_exit in _hooks: + if signr != 0 or run_at_exit: + hook() + + +_inited = False + + +def _init(): + global _inited + if not _inited: + _inited = True + + for signr in (signal.SIGTERM, signal.SIGINT, + signal.SIGHUP, signal.SIGALRM): + if signal.getsignal(signr) == signal.SIG_DFL: + signal.signal(signr, _signal_handler) + atexit.register(_atexit_handler) diff --git a/ryu/contrib/ovs/json.py b/ryu/contrib/ovs/json.py new file mode 100644 index 00000000..d329ee41 --- /dev/null +++ b/ryu/contrib/ovs/json.py @@ -0,0 +1,586 @@ +# Copyright (c) 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import StringIO +import sys + +__pychecker__ = 'no-stringiter' + +escapes = {ord('"'): u"\\\"", + ord("\\"): u"\\\\", + ord("\b"): u"\\b", + ord("\f"): u"\\f", + ord("\n"): u"\\n", + ord("\r"): u"\\r", + ord("\t"): u"\\t"} +for esc in range(32): + if esc not in escapes: + escapes[esc] = u"\\u%04x" % esc + +SPACES_PER_LEVEL = 2 + + +class _Serializer(object): + def __init__(self, stream, pretty, sort_keys): + self.stream = stream + self.pretty = pretty + self.sort_keys = sort_keys + self.depth = 0 + + def __serialize_string(self, s): + self.stream.write(u'"%s"' % ''.join(escapes.get(ord(c), c) for c in s)) + + def __indent_line(self): + if self.pretty: + self.stream.write('\n') + self.stream.write(' ' * (SPACES_PER_LEVEL * self.depth)) + + def serialize(self, obj): + if obj is None: + self.stream.write(u"null") + elif obj is False: + self.stream.write(u"false") + elif obj is True: + self.stream.write(u"true") + elif type(obj) in (int, long): + self.stream.write(u"%d" % obj) + elif type(obj) == float: + self.stream.write("%.15g" % obj) + elif type(obj) == unicode: + self.__serialize_string(obj) + elif type(obj) == str: + self.__serialize_string(unicode(obj)) + elif type(obj) == dict: + self.stream.write(u"{") + + self.depth += 1 + self.__indent_line() + + if self.sort_keys: + items = sorted(obj.items()) + else: + items = obj.iteritems() + for i, (key, value) in enumerate(items): + if i > 0: + self.stream.write(u",") + self.__indent_line() + self.__serialize_string(unicode(key)) + self.stream.write(u":") + if self.pretty: + self.stream.write(u' ') + self.serialize(value) + + self.stream.write(u"}") + self.depth -= 1 + elif type(obj) in (list, tuple): + self.stream.write(u"[") + self.depth += 1 + + if obj: + self.__indent_line() + + for i, value in enumerate(obj): + if i > 0: + self.stream.write(u",") + self.__indent_line() + self.serialize(value) + + self.depth -= 1 + self.stream.write(u"]") + else: + raise Exception("can't serialize %s as JSON" % obj) + + +def to_stream(obj, stream, pretty=False, sort_keys=True): + _Serializer(stream, pretty, sort_keys).serialize(obj) + + +def to_file(obj, name, pretty=False, sort_keys=True): + stream = open(name, "w") + try: + to_stream(obj, stream, pretty, sort_keys) + finally: + stream.close() + + +def to_string(obj, pretty=False, sort_keys=True): + output = StringIO.StringIO() + to_stream(obj, output, pretty, sort_keys) + s = output.getvalue() + output.close() + return s + + +def from_stream(stream): + p = Parser(check_trailer=True) + while True: + buf = stream.read(4096) + if buf == "" or p.feed(buf) != len(buf): + break + return p.finish() + + +def from_file(name): + stream = open(name, "r") + try: + return from_stream(stream) + finally: + stream.close() + + +def from_string(s): + try: + s = unicode(s, 'utf-8') + except UnicodeDecodeError, e: + seq = ' '.join(["0x%2x" % ord(c) + for c in e.object[e.start:e.end] if ord(c) >= 0x80]) + return ("not a valid UTF-8 string: invalid UTF-8 sequence %s" % seq) + p = Parser(check_trailer=True) + p.feed(s) + return p.finish() + + +class Parser(object): + ## Maximum height of parsing stack. ## + MAX_HEIGHT = 1000 + + def __init__(self, check_trailer=False): + self.check_trailer = check_trailer + + # Lexical analysis. + self.lex_state = Parser.__lex_start + self.buffer = "" + self.line_number = 0 + self.column_number = 0 + self.byte_number = 0 + + # Parsing. + self.parse_state = Parser.__parse_start + self.stack = [] + self.member_name = None + + # Parse status. + self.done = False + self.error = None + + def __lex_start_space(self, c): + pass + + def __lex_start_alpha(self, c): + self.buffer = c + self.lex_state = Parser.__lex_keyword + + def __lex_start_token(self, c): + self.__parser_input(c) + + def __lex_start_number(self, c): + self.buffer = c + self.lex_state = Parser.__lex_number + + def __lex_start_string(self, _): + self.lex_state = Parser.__lex_string + + def __lex_start_error(self, c): + if ord(c) >= 32 and ord(c) < 128: + self.__error("invalid character '%s'" % c) + else: + self.__error("invalid character U+%04x" % ord(c)) + + __lex_start_actions = {} + for c in " \t\n\r": + __lex_start_actions[c] = __lex_start_space + for c in "abcdefghijklmnopqrstuvwxyz": + __lex_start_actions[c] = __lex_start_alpha + for c in "[{]}:,": + __lex_start_actions[c] = __lex_start_token + for c in "-0123456789": + __lex_start_actions[c] = __lex_start_number + __lex_start_actions['"'] = __lex_start_string + + def __lex_start(self, c): + Parser.__lex_start_actions.get( + c, Parser.__lex_start_error)(self, c) + return True + + __lex_alpha = {} + for c in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ": + __lex_alpha[c] = True + + def __lex_finish_keyword(self): + if self.buffer == "false": + self.__parser_input(False) + elif self.buffer == "true": + self.__parser_input(True) + elif self.buffer == "null": + self.__parser_input(None) + else: + self.__error("invalid keyword '%s'" % self.buffer) + + def __lex_keyword(self, c): + if c in Parser.__lex_alpha: + self.buffer += c + return True + else: + self.__lex_finish_keyword() + return False + + __number_re = re.compile("(-)?(0|[1-9][0-9]*)" + "(?:\.([0-9]+))?(?:[eE]([-+]?[0-9]+))?$") + + def __lex_finish_number(self): + s = self.buffer + m = Parser.__number_re.match(s) + if m: + sign, integer, fraction, exp = m.groups() + if (exp is not None and + (long(exp) > sys.maxint or long(exp) < -sys.maxint - 1)): + self.__error("exponent outside valid range") + return + + if fraction is not None and len(fraction.lstrip('0')) == 0: + fraction = None + + sig_string = integer + if fraction is not None: + sig_string += fraction + significand = int(sig_string) + + pow10 = 0 + if fraction is not None: + pow10 -= len(fraction) + if exp is not None: + pow10 += long(exp) + + if significand == 0: + self.__parser_input(0) + return + elif significand <= 2 ** 63: + while pow10 > 0 and significand <= 2 ** 63: + significand *= 10 + pow10 -= 1 + while pow10 < 0 and significand % 10 == 0: + significand /= 10 + pow10 += 1 + if (pow10 == 0 and + ((not sign and significand < 2 ** 63) or + (sign and significand <= 2 ** 63))): + if sign: + self.__parser_input(-significand) + else: + self.__parser_input(significand) + return + + value = float(s) + if value == float("inf") or value == float("-inf"): + self.__error("number outside valid range") + return + if value == 0: + # Suppress negative zero. + value = 0 + self.__parser_input(value) + elif re.match("-?0[0-9]", s): + self.__error("leading zeros not allowed") + elif re.match("-([^0-9]|$)", s): + self.__error("'-' must be followed by digit") + elif re.match("-?(0|[1-9][0-9]*)\.([^0-9]|$)", s): + self.__error("decimal point must be followed by digit") + elif re.search("e[-+]?([^0-9]|$)", s): + self.__error("exponent must contain at least one digit") + else: + self.__error("syntax error in number") + + def __lex_number(self, c): + if c in ".0123456789eE-+": + self.buffer += c + return True + else: + self.__lex_finish_number() + return False + + __4hex_re = re.compile("[0-9a-fA-F]{4}") + + def __lex_4hex(self, s): + if len(s) < 4: + self.__error("quoted string ends within \\u escape") + elif not Parser.__4hex_re.match(s): + self.__error("malformed \\u escape") + elif s == "0000": + self.__error("null bytes not supported in quoted strings") + else: + return int(s, 16) + + @staticmethod + def __is_leading_surrogate(c): + """Returns true if 'c' is a Unicode code point for a leading + surrogate.""" + return c >= 0xd800 and c <= 0xdbff + + @staticmethod + def __is_trailing_surrogate(c): + """Returns true if 'c' is a Unicode code point for a trailing + surrogate.""" + return c >= 0xdc00 and c <= 0xdfff + + @staticmethod + def __utf16_decode_surrogate_pair(leading, trailing): + """Returns the unicode code point corresponding to leading surrogate + 'leading' and trailing surrogate 'trailing'. The return value will not + make any sense if 'leading' or 'trailing' are not in the correct ranges + for leading or trailing surrogates.""" + # Leading surrogate: 110110wwwwxxxxxx + # Trailing surrogate: 110111xxxxxxxxxx + # Code point: 000uuuuuxxxxxxxxxxxxxxxx + w = (leading >> 6) & 0xf + u = w + 1 + x0 = leading & 0x3f + x1 = trailing & 0x3ff + return (u << 16) | (x0 << 10) | x1 + __unescape = {'"': u'"', + "\\": u"\\", + "/": u"/", + "b": u"\b", + "f": u"\f", + "n": u"\n", + "r": u"\r", + "t": u"\t"} + + def __lex_finish_string(self): + inp = self.buffer + out = u"" + while len(inp): + backslash = inp.find('\\') + if backslash == -1: + out += inp + break + out += inp[:backslash] + inp = inp[backslash + 1:] + if inp == "": + self.__error("quoted string may not end with backslash") + return + + replacement = Parser.__unescape.get(inp[0]) + if replacement is not None: + out += replacement + inp = inp[1:] + continue + elif inp[0] != u'u': + self.__error("bad escape \\%s" % inp[0]) + return + + c0 = self.__lex_4hex(inp[1:5]) + if c0 is None: + return + inp = inp[5:] + + if Parser.__is_leading_surrogate(c0): + if inp[:2] != u'\\u': + self.__error("malformed escaped surrogate pair") + return + c1 = self.__lex_4hex(inp[2:6]) + if c1 is None: + return + if not Parser.__is_trailing_surrogate(c1): + self.__error("second half of escaped surrogate pair is " + "not trailing surrogate") + return + code_point = Parser.__utf16_decode_surrogate_pair(c0, c1) + inp = inp[6:] + else: + code_point = c0 + out += unichr(code_point) + self.__parser_input('string', out) + + def __lex_string_escape(self, c): + self.buffer += c + self.lex_state = Parser.__lex_string + return True + + def __lex_string(self, c): + if c == '\\': + self.buffer += c + self.lex_state = Parser.__lex_string_escape + elif c == '"': + self.__lex_finish_string() + elif ord(c) >= 0x20: + self.buffer += c + else: + self.__error("U+%04X must be escaped in quoted string" % ord(c)) + return True + + def __lex_input(self, c): + eat = self.lex_state(self, c) + assert eat is True or eat is False + return eat + + def __parse_start(self, token, unused_string): + if token == '{': + self.__push_object() + elif token == '[': + self.__push_array() + else: + self.__error("syntax error at beginning of input") + + def __parse_end(self, unused_token, unused_string): + self.__error("trailing garbage at end of input") + + def __parse_object_init(self, token, string): + if token == '}': + self.__parser_pop() + else: + self.__parse_object_name(token, string) + + def __parse_object_name(self, token, string): + if token == 'string': + self.member_name = string + self.parse_state = Parser.__parse_object_colon + else: + self.__error("syntax error parsing object expecting string") + + def __parse_object_colon(self, token, unused_string): + if token == ":": + self.parse_state = Parser.__parse_object_value + else: + self.__error("syntax error parsing object expecting ':'") + + def __parse_object_value(self, token, string): + self.__parse_value(token, string, Parser.__parse_object_next) + + def __parse_object_next(self, token, unused_string): + if token == ",": + self.parse_state = Parser.__parse_object_name + elif token == "}": + self.__parser_pop() + else: + self.__error("syntax error expecting '}' or ','") + + def __parse_array_init(self, token, string): + if token == ']': + self.__parser_pop() + else: + self.__parse_array_value(token, string) + + def __parse_array_value(self, token, string): + self.__parse_value(token, string, Parser.__parse_array_next) + + def __parse_array_next(self, token, unused_string): + if token == ",": + self.parse_state = Parser.__parse_array_value + elif token == "]": + self.__parser_pop() + else: + self.__error("syntax error expecting ']' or ','") + + def __parser_input(self, token, string=None): + self.lex_state = Parser.__lex_start + self.buffer = "" + self.parse_state(self, token, string) + + def __put_value(self, value): + top = self.stack[-1] + if type(top) == dict: + top[self.member_name] = value + else: + top.append(value) + + def __parser_push(self, new_json, next_state): + if len(self.stack) < Parser.MAX_HEIGHT: + if len(self.stack) > 0: + self.__put_value(new_json) + self.stack.append(new_json) + self.parse_state = next_state + else: + self.__error("input exceeds maximum nesting depth %d" % + Parser.MAX_HEIGHT) + + def __push_object(self): + self.__parser_push({}, Parser.__parse_object_init) + + def __push_array(self): + self.__parser_push([], Parser.__parse_array_init) + + def __parser_pop(self): + if len(self.stack) == 1: + self.parse_state = Parser.__parse_end + if not self.check_trailer: + self.done = True + else: + self.stack.pop() + top = self.stack[-1] + if type(top) == list: + self.parse_state = Parser.__parse_array_next + else: + self.parse_state = Parser.__parse_object_next + + def __parse_value(self, token, string, next_state): + if token in [False, None, True] or type(token) in [int, long, float]: + self.__put_value(token) + elif token == 'string': + self.__put_value(string) + else: + if token == '{': + self.__push_object() + elif token == '[': + self.__push_array() + else: + self.__error("syntax error expecting value") + return + self.parse_state = next_state + + def __error(self, message): + if self.error is None: + self.error = ("line %d, column %d, byte %d: %s" + % (self.line_number, self.column_number, + self.byte_number, message)) + self.done = True + + def feed(self, s): + i = 0 + while True: + if self.done or i >= len(s): + return i + + c = s[i] + if self.__lex_input(c): + self.byte_number += 1 + if c == '\n': + self.column_number = 0 + self.line_number += 1 + else: + self.column_number += 1 + + i += 1 + + def is_done(self): + return self.done + + def finish(self): + if self.lex_state == Parser.__lex_start: + pass + elif self.lex_state in (Parser.__lex_string, + Parser.__lex_string_escape): + self.__error("unexpected end of input in quoted string") + else: + self.__lex_input(" ") + + if self.parse_state == Parser.__parse_start: + self.__error("empty input stream") + elif self.parse_state != Parser.__parse_end: + self.__error("unexpected end of input") + + if self.error == None: + assert len(self.stack) == 1 + return self.stack.pop() + else: + return self.error diff --git a/ryu/contrib/ovs/jsonrpc.py b/ryu/contrib/ovs/jsonrpc.py new file mode 100644 index 00000000..c1540eb7 --- /dev/null +++ b/ryu/contrib/ovs/jsonrpc.py @@ -0,0 +1,560 @@ +# Copyright (c) 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import os + +import ovs.json +import ovs.poller +import ovs.reconnect +import ovs.stream +import ovs.timeval +import ovs.util +import ovs.vlog + +EOF = ovs.util.EOF +vlog = ovs.vlog.Vlog("jsonrpc") + + +class Message(object): + T_REQUEST = 0 # Request. + T_NOTIFY = 1 # Notification. + T_REPLY = 2 # Successful reply. + T_ERROR = 3 # Error reply. + + __types = {T_REQUEST: "request", + T_NOTIFY: "notification", + T_REPLY: "reply", + T_ERROR: "error"} + + def __init__(self, type_, method, params, result, error, id): + self.type = type_ + self.method = method + self.params = params + self.result = result + self.error = error + self.id = id + + _next_id = 0 + + @staticmethod + def _create_id(): + this_id = Message._next_id + Message._next_id += 1 + return this_id + + @staticmethod + def create_request(method, params): + return Message(Message.T_REQUEST, method, params, None, None, + Message._create_id()) + + @staticmethod + def create_notify(method, params): + return Message(Message.T_NOTIFY, method, params, None, None, + None) + + @staticmethod + def create_reply(result, id): + return Message(Message.T_REPLY, None, None, result, None, id) + + @staticmethod + def create_error(error, id): + return Message(Message.T_ERROR, None, None, None, error, id) + + @staticmethod + def type_to_string(type_): + return Message.__types[type_] + + def __validate_arg(self, value, name, must_have): + if (value is not None) == (must_have != 0): + return None + else: + type_name = Message.type_to_string(self.type) + if must_have: + verb = "must" + else: + verb = "must not" + return "%s %s have \"%s\"" % (type_name, verb, name) + + def is_valid(self): + if self.params is not None and type(self.params) != list: + return "\"params\" must be JSON array" + + pattern = {Message.T_REQUEST: 0x11001, + Message.T_NOTIFY: 0x11000, + Message.T_REPLY: 0x00101, + Message.T_ERROR: 0x00011}.get(self.type) + if pattern is None: + return "invalid JSON-RPC message type %s" % self.type + + return ( + self.__validate_arg(self.method, "method", pattern & 0x10000) or + self.__validate_arg(self.params, "params", pattern & 0x1000) or + self.__validate_arg(self.result, "result", pattern & 0x100) or + self.__validate_arg(self.error, "error", pattern & 0x10) or + self.__validate_arg(self.id, "id", pattern & 0x1)) + + @staticmethod + def from_json(json): + if type(json) != dict: + return "message is not a JSON object" + + # Make a copy to avoid modifying the caller's dict. + json = dict(json) + + if "method" in json: + method = json.pop("method") + if type(method) not in [str, unicode]: + return "method is not a JSON string" + else: + method = None + + params = json.pop("params", None) + result = json.pop("result", None) + error = json.pop("error", None) + id_ = json.pop("id", None) + if len(json): + return "message has unexpected member \"%s\"" % json.popitem()[0] + + if result is not None: + msg_type = Message.T_REPLY + elif error is not None: + msg_type = Message.T_ERROR + elif id_ is not None: + msg_type = Message.T_REQUEST + else: + msg_type = Message.T_NOTIFY + + msg = Message(msg_type, method, params, result, error, id_) + validation_error = msg.is_valid() + if validation_error is not None: + return validation_error + else: + return msg + + def to_json(self): + json = {} + + if self.method is not None: + json["method"] = self.method + + if self.params is not None: + json["params"] = self.params + + if self.result is not None or self.type == Message.T_ERROR: + json["result"] = self.result + + if self.error is not None or self.type == Message.T_REPLY: + json["error"] = self.error + + if self.id is not None or self.type == Message.T_NOTIFY: + json["id"] = self.id + + return json + + def __str__(self): + s = [Message.type_to_string(self.type)] + if self.method is not None: + s.append("method=\"%s\"" % self.method) + if self.params is not None: + s.append("params=" + ovs.json.to_string(self.params)) + if self.result is not None: + s.append("result=" + ovs.json.to_string(self.result)) + if self.error is not None: + s.append("error=" + ovs.json.to_string(self.error)) + if self.id is not None: + s.append("id=" + ovs.json.to_string(self.id)) + return ", ".join(s) + + +class Connection(object): + def __init__(self, stream): + self.name = stream.name + self.stream = stream + self.status = 0 + self.input = "" + self.output = "" + self.parser = None + self.received_bytes = 0 + + def close(self): + self.stream.close() + self.stream = None + + def run(self): + if self.status: + return + + while len(self.output): + retval = self.stream.send(self.output) + if retval >= 0: + self.output = self.output[retval:] + else: + if retval != -errno.EAGAIN: + vlog.warn("%s: send error: %s" % + (self.name, os.strerror(-retval))) + self.error(-retval) + break + + def wait(self, poller): + if not self.status: + self.stream.run_wait(poller) + if len(self.output): + self.stream.send_wait(poller) + + def get_status(self): + return self.status + + def get_backlog(self): + if self.status != 0: + return 0 + else: + return len(self.output) + + def get_received_bytes(self): + return self.received_bytes + + def __log_msg(self, title, msg): + vlog.dbg("%s: %s %s" % (self.name, title, msg)) + + def send(self, msg): + if self.status: + return self.status + + self.__log_msg("send", msg) + + was_empty = len(self.output) == 0 + self.output += ovs.json.to_string(msg.to_json()) + if was_empty: + self.run() + return self.status + + def send_block(self, msg): + error = self.send(msg) + if error: + return error + + while True: + self.run() + if not self.get_backlog() or self.get_status(): + return self.status + + poller = ovs.poller.Poller() + self.wait(poller) + poller.block() + + def recv(self): + if self.status: + return self.status, None + + while True: + if not self.input: + error, data = self.stream.recv(4096) + if error: + if error == errno.EAGAIN: + return error, None + else: + # XXX rate-limit + vlog.warn("%s: receive error: %s" + % (self.name, os.strerror(error))) + self.error(error) + return self.status, None + elif not data: + self.error(EOF) + return EOF, None + else: + self.input += data + self.received_bytes += len(data) + else: + if self.parser is None: + self.parser = ovs.json.Parser() + self.input = self.input[self.parser.feed(self.input):] + if self.parser.is_done(): + msg = self.__process_msg() + if msg: + return 0, msg + else: + return self.status, None + + def recv_block(self): + while True: + error, msg = self.recv() + if error != errno.EAGAIN: + return error, msg + + self.run() + + poller = ovs.poller.Poller() + self.wait(poller) + self.recv_wait(poller) + poller.block() + + def transact_block(self, request): + id_ = request.id + + error = self.send(request) + reply = None + while not error: + error, reply = self.recv_block() + if (reply + and (reply.type == Message.T_REPLY + or reply.type == Message.T_ERROR) + and reply.id == id_): + break + return error, reply + + def __process_msg(self): + json = self.parser.finish() + self.parser = None + if type(json) in [str, unicode]: + # XXX rate-limit + vlog.warn("%s: error parsing stream: %s" % (self.name, json)) + self.error(errno.EPROTO) + return + + msg = Message.from_json(json) + if not isinstance(msg, Message): + # XXX rate-limit + vlog.warn("%s: received bad JSON-RPC message: %s" + % (self.name, msg)) + self.error(errno.EPROTO) + return + + self.__log_msg("received", msg) + return msg + + def recv_wait(self, poller): + if self.status or self.input: + poller.immediate_wake() + else: + self.stream.recv_wait(poller) + + def error(self, error): + if self.status == 0: + self.status = error + self.stream.close() + self.output = "" + + +class Session(object): + """A JSON-RPC session with reconnection.""" + + def __init__(self, reconnect, rpc): + self.reconnect = reconnect + self.rpc = rpc + self.stream = None + self.pstream = None + self.seqno = 0 + + @staticmethod + def open(name): + """Creates and returns a Session that maintains a JSON-RPC session to + 'name', which should be a string acceptable to ovs.stream.Stream or + ovs.stream.PassiveStream's initializer. + + If 'name' is an active connection method, e.g. "tcp:127.1.2.3", the new + session connects and reconnects, with back-off, to 'name'. + + If 'name' is a passive connection method, e.g. "ptcp:", the new session + listens for connections to 'name'. It maintains at most one connection + at any given time. Any new connection causes the previous one (if any) + to be dropped.""" + reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec()) + reconnect.set_name(name) + reconnect.enable(ovs.timeval.msec()) + + if ovs.stream.PassiveStream.is_valid_name(name): + reconnect.set_passive(True, ovs.timeval.msec()) + + if ovs.stream.stream_or_pstream_needs_probes(name): + reconnect.set_probe_interval(0) + + return Session(reconnect, None) + + @staticmethod + def open_unreliably(jsonrpc): + reconnect = ovs.reconnect.Reconnect(ovs.timeval.msec()) + reconnect.set_quiet(True) + reconnect.set_name(jsonrpc.name) + reconnect.set_max_tries(0) + reconnect.connected(ovs.timeval.msec()) + return Session(reconnect, jsonrpc) + + def close(self): + if self.rpc is not None: + self.rpc.close() + self.rpc = None + if self.stream is not None: + self.stream.close() + self.stream = None + if self.pstream is not None: + self.pstream.close() + self.pstream = None + + def __disconnect(self): + if self.rpc is not None: + self.rpc.error(EOF) + self.rpc.close() + self.rpc = None + self.seqno += 1 + elif self.stream is not None: + self.stream.close() + self.stream = None + self.seqno += 1 + + def __connect(self): + self.__disconnect() + + name = self.reconnect.get_name() + if not self.reconnect.is_passive(): + error, self.stream = ovs.stream.Stream.open(name) + if not error: + self.reconnect.connecting(ovs.timeval.msec()) + else: + self.reconnect.connect_failed(ovs.timeval.msec(), error) + elif self.pstream is not None: + error, self.pstream = ovs.stream.PassiveStream.open(name) + if not error: + self.reconnect.listening(ovs.timeval.msec()) + else: + self.reconnect.connect_failed(ovs.timeval.msec(), error) + + self.seqno += 1 + + def run(self): + if self.pstream is not None: + error, stream = self.pstream.accept() + if error == 0: + if self.rpc or self.stream: + # XXX rate-limit + vlog.info("%s: new connection replacing active " + "connection" % self.reconnect.get_name()) + self.__disconnect() + self.reconnect.connected(ovs.timeval.msec()) + self.rpc = Connection(stream) + elif error != errno.EAGAIN: + self.reconnect.listen_error(ovs.timeval.msec(), error) + self.pstream.close() + self.pstream = None + + if self.rpc: + backlog = self.rpc.get_backlog() + self.rpc.run() + if self.rpc.get_backlog() < backlog: + # Data previously caught in a queue was successfully sent (or + # there's an error, which we'll catch below). + # + # We don't count data that is successfully sent immediately as + # activity, because there's a lot of queuing downstream from + # us, which means that we can push a lot of data into a + # connection that has stalled and won't ever recover. + self.reconnect.activity(ovs.timeval.msec()) + + error = self.rpc.get_status() + if error != 0: + self.reconnect.disconnected(ovs.timeval.msec(), error) + self.__disconnect() + elif self.stream is not None: + self.stream.run() + error = self.stream.connect() + if error == 0: + self.reconnect.connected(ovs.timeval.msec()) + self.rpc = Connection(self.stream) + self.stream = None + elif error != errno.EAGAIN: + self.reconnect.connect_failed(ovs.timeval.msec(), error) + self.stream.close() + self.stream = None + + action = self.reconnect.run(ovs.timeval.msec()) + if action == ovs.reconnect.CONNECT: + self.__connect() + elif action == ovs.reconnect.DISCONNECT: + self.reconnect.disconnected(ovs.timeval.msec(), 0) + self.__disconnect() + elif action == ovs.reconnect.PROBE: + if self.rpc: + request = Message.create_request("echo", []) + request.id = "echo" + self.rpc.send(request) + else: + assert action == None + + def wait(self, poller): + if self.rpc is not None: + self.rpc.wait(poller) + elif self.stream is not None: + self.stream.run_wait(poller) + self.stream.connect_wait(poller) + if self.pstream is not None: + self.pstream.wait(poller) + self.reconnect.wait(poller, ovs.timeval.msec()) + + def get_backlog(self): + if self.rpc is not None: + return self.rpc.get_backlog() + else: + return 0 + + def get_name(self): + return self.reconnect.get_name() + + def send(self, msg): + if self.rpc is not None: + return self.rpc.send(msg) + else: + return errno.ENOTCONN + + def recv(self): + if self.rpc is not None: + received_bytes = self.rpc.get_received_bytes() + error, msg = self.rpc.recv() + if received_bytes != self.rpc.get_received_bytes(): + # Data was successfully received. + # + # Previously we only counted receiving a full message as + # activity, but with large messages or a slow connection that + # policy could time out the session mid-message. + self.reconnect.activity(ovs.timeval.msec()) + + if not error: + if msg.type == Message.T_REQUEST and msg.method == "echo": + # Echo request. Send reply. + self.send(Message.create_reply(msg.params, msg.id)) + elif msg.type == Message.T_REPLY and msg.id == "echo": + # It's a reply to our echo request. Suppress it. + pass + else: + return msg + return None + + def recv_wait(self, poller): + if self.rpc is not None: + self.rpc.recv_wait(poller) + + def is_alive(self): + if self.rpc is not None or self.stream is not None: + return True + else: + max_tries = self.reconnect.get_max_tries() + return max_tries is None or max_tries > 0 + + def is_connected(self): + return self.rpc is not None + + def get_seqno(self): + return self.seqno + + def force_reconnect(self): + self.reconnect.force_reconnect(ovs.timeval.msec()) diff --git a/ryu/contrib/ovs/ovsuuid.py b/ryu/contrib/ovs/ovsuuid.py new file mode 100644 index 00000000..56fdad05 --- /dev/null +++ b/ryu/contrib/ovs/ovsuuid.py @@ -0,0 +1,70 @@ +# Copyright (c) 2009, 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +import uuid + +from ovs.db import error +import ovs.db.parser + +uuidRE = re.compile("^xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx$" + .replace('x', '[0-9a-fA-F]')) + + +def zero(): + return uuid.UUID(int=0) + + +def is_valid_string(s): + return uuidRE.match(s) is not None + + +def from_string(s): + if not is_valid_string(s): + raise error.Error("%s is not a valid UUID" % s) + return uuid.UUID(s) + + +def from_json(json, symtab=None): + try: + s = ovs.db.parser.unwrap_json(json, "uuid", [str, unicode], "string") + if not uuidRE.match(s): + raise error.Error("\"%s\" is not a valid UUID" % s, json) + return uuid.UUID(s) + except error.Error, e: + if not symtab: + raise e + try: + name = ovs.db.parser.unwrap_json(json, "named-uuid", + [str, unicode], "string") + except error.Error: + raise e + + if name not in symtab: + symtab[name] = uuid.uuid4() + return symtab[name] + + +def to_json(uuid_): + return ["uuid", str(uuid_)] + + +def to_c_assignment(uuid_, var): + """Returns an array of strings, each of which contain a C statement. The + statements assign 'uuid_' to a "struct uuid" as defined in Open vSwitch + lib/uuid.h.""" + + hex_string = uuid_.hex + return ["%s.parts[%d] = 0x%s;" % (var, x, hex_string[x * 8:(x + 1) * 8]) + for x in range(4)] diff --git a/ryu/contrib/ovs/poller.py b/ryu/contrib/ovs/poller.py new file mode 100644 index 00000000..c04c9b36 --- /dev/null +++ b/ryu/contrib/ovs/poller.py @@ -0,0 +1,185 @@ +# Copyright (c) 2010 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import ovs.timeval +import ovs.vlog +import select +import socket + +vlog = ovs.vlog.Vlog("poller") + + +# eventlet/gevent doesn't support select.poll. If select.poll is used, +# python interpreter is blocked as a whole instead of switching from the +# current thread that is about to block to other runnable thread. +# So emulate select.poll by select.select because using python means that +# performance isn't so important. +class _SelectSelect(object): + """ select.poll emulation by using select.select. + Only register and poll are needed at the moment. + """ + def __init__(self): + self.rlist = [] + self.wlist = [] + self.xlist = [] + + def register(self, fd, events): + if isinstance(fd, socket.socket): + fd = fd.fileno() + assert isinstance(fd, int) + if events & select.POLLIN: + self.rlist.append(fd) + events &= ~select.POLLIN + if events & select.POLLOUT: + self.wlist.append(fd) + events &= ~select.POLLOUT + if events: + self.xlist.append(fd) + + def poll(self, timeout): + if timeout == -1: + # epoll uses -1 for infinite timeout, select uses None. + timeout = None + else: + timeout = float(timeout) / 1000 + + rlist, wlist, xlist = select.select(self.rlist, self.wlist, self.xlist, + timeout) + # collections.defaultdict is introduced by python 2.5 and + # XenServer uses python 2.4. We don't use it for XenServer. + # events_dict = collections.defaultdict(int) + # events_dict[fd] |= event + events_dict = {} + for fd in rlist: + events_dict[fd] = events_dict.get(fd, 0) | select.POLLIN + for fd in wlist: + events_dict[fd] = events_dict.get(fd, 0) | select.POLLOUT + for fd in xlist: + events_dict[fd] = events_dict.get(fd, 0) | (select.POLLERR | + select.POLLHUP | + select.POLLNVAL) + return events_dict.items() + + +SelectPoll = _SelectSelect +# If eventlet/gevent isn't used, we can use select.poll by replacing +# _SelectPoll with select.poll class +# _SelectPoll = select.poll + + +class Poller(object): + """High-level wrapper around the "poll" system call. + + Intended usage is for the program's main loop to go about its business + servicing whatever events it needs to. Then, when it runs out of immediate + tasks, it calls each subordinate module or object's "wait" function, which + in turn calls one (or more) of the functions Poller.fd_wait(), + Poller.immediate_wake(), and Poller.timer_wait() to register to be awakened + when the appropriate event occurs. Then the main loop calls + Poller.block(), which blocks until one of the registered events happens.""" + + def __init__(self): + self.__reset() + + def fd_wait(self, fd, events): + """Registers 'fd' as waiting for the specified 'events' (which should + be select.POLLIN or select.POLLOUT or their bitwise-OR). The following + call to self.block() will wake up when 'fd' becomes ready for one or + more of the requested events. + + The event registration is one-shot: only the following call to + self.block() is affected. The event will need to be re-registered + after self.block() is called if it is to persist. + + 'fd' may be an integer file descriptor or an object with a fileno() + method that returns an integer file descriptor.""" + self.poll.register(fd, events) + + def __timer_wait(self, msec): + if self.timeout < 0 or msec < self.timeout: + self.timeout = msec + + def timer_wait(self, msec): + """Causes the following call to self.block() to block for no more than + 'msec' milliseconds. If 'msec' is nonpositive, the following call to + self.block() will not block at all. + + The timer registration is one-shot: only the following call to + self.block() is affected. The timer will need to be re-registered + after self.block() is called if it is to persist.""" + if msec <= 0: + self.immediate_wake() + else: + self.__timer_wait(msec) + + def timer_wait_until(self, msec): + """Causes the following call to self.block() to wake up when the + current time, as returned by ovs.timeval.msec(), reaches 'msec' or + later. If 'msec' is earlier than the current time, the following call + to self.block() will not block at all. + + The timer registration is one-shot: only the following call to + self.block() is affected. The timer will need to be re-registered + after self.block() is called if it is to persist.""" + now = ovs.timeval.msec() + if msec <= now: + self.immediate_wake() + else: + self.__timer_wait(msec - now) + + def immediate_wake(self): + """Causes the following call to self.block() to wake up immediately, + without blocking.""" + self.timeout = 0 + + def block(self): + """Blocks until one or more of the events registered with + self.fd_wait() occurs, or until the minimum duration registered with + self.timer_wait() elapses, or not at all if self.immediate_wake() has + been called.""" + try: + try: + events = self.poll.poll(self.timeout) + self.__log_wakeup(events) + except select.error, e: + # XXX rate-limit + error, msg = e + if error != errno.EINTR: + vlog.err("poll: %s" % e[1]) + finally: + self.__reset() + + def __log_wakeup(self, events): + if not events: + vlog.dbg("%d-ms timeout" % self.timeout) + else: + for fd, revents in events: + if revents != 0: + s = "" + if revents & select.POLLIN: + s += "[POLLIN]" + if revents & select.POLLOUT: + s += "[POLLOUT]" + if revents & select.POLLERR: + s += "[POLLERR]" + if revents & select.POLLHUP: + s += "[POLLHUP]" + if revents & select.POLLNVAL: + s += "[POLLNVAL]" + vlog.dbg("%s on fd %d" % (s, fd)) + + def __reset(self): + self.poll = SelectPoll() + self.timeout = -1 diff --git a/ryu/contrib/ovs/process.py b/ryu/contrib/ovs/process.py new file mode 100644 index 00000000..d7561310 --- /dev/null +++ b/ryu/contrib/ovs/process.py @@ -0,0 +1,41 @@ +# Copyright (c) 2010, 2011 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import signal + + +def _signal_status_msg(type_, signr): + s = "%s by signal %d" % (type_, signr) + for name in signal.__dict__: + if name.startswith("SIG") and getattr(signal, name) == signr: + return "%s (%s)" % (s, name) + return s + + +def status_msg(status): + """Given 'status', which is a process status in the form reported by + waitpid(2) and returned by process_status(), returns a string describing + how the process terminated.""" + if os.WIFEXITED(status): + s = "exit status %d" % os.WEXITSTATUS(status) + elif os.WIFSIGNALED(status): + s = _signal_status_msg("killed", os.WTERMSIG(status)) + elif os.WIFSTOPPED(status): + s = _signal_status_msg("stopped", os.WSTOPSIG(status)) + else: + s = "terminated abnormally (%x)" % status + if os.WCOREDUMP(status): + s += ", core dumped" + return s diff --git a/ryu/contrib/ovs/reconnect.py b/ryu/contrib/ovs/reconnect.py new file mode 100644 index 00000000..39dd556d --- /dev/null +++ b/ryu/contrib/ovs/reconnect.py @@ -0,0 +1,588 @@ +# Copyright (c) 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import ovs.vlog +import ovs.util + +# Values returned by Reconnect.run() +CONNECT = 'connect' +DISCONNECT = 'disconnect' +PROBE = 'probe' + +EOF = ovs.util.EOF +vlog = ovs.vlog.Vlog("reconnect") + + +class Reconnect(object): + """A finite-state machine for connecting and reconnecting to a network + resource with exponential backoff. It also provides optional support for + detecting a connection on which the peer is no longer responding. + + The library does not implement anything networking related, only an FSM for + networking code to use. + + Many Reconnect methods take a "now" argument. This makes testing easier + since there is no hidden state. When not testing, just pass the return + value of ovs.time.msec(). (Perhaps this design should be revisited + later.)""" + + class Void(object): + name = "VOID" + is_connected = False + + @staticmethod + def deadline(fsm): + return None + + @staticmethod + def run(fsm, now): + return None + + class Listening(object): + name = "LISTENING" + is_connected = False + + @staticmethod + def deadline(fsm): + return None + + @staticmethod + def run(fsm, now): + return None + + class Backoff(object): + name = "BACKOFF" + is_connected = False + + @staticmethod + def deadline(fsm): + return fsm.state_entered + fsm.backoff + + @staticmethod + def run(fsm, now): + return CONNECT + + class ConnectInProgress(object): + name = "CONNECTING" + is_connected = False + + @staticmethod + def deadline(fsm): + return fsm.state_entered + max(1000, fsm.backoff) + + @staticmethod + def run(fsm, now): + return DISCONNECT + + class Active(object): + name = "ACTIVE" + is_connected = True + + @staticmethod + def deadline(fsm): + if fsm.probe_interval: + base = max(fsm.last_activity, fsm.state_entered) + return base + fsm.probe_interval + return None + + @staticmethod + def run(fsm, now): + vlog.dbg("%s: idle %d ms, sending inactivity probe" + % (fsm.name, + now - max(fsm.last_activity, fsm.state_entered))) + fsm._transition(now, Reconnect.Idle) + return PROBE + + class Idle(object): + name = "IDLE" + is_connected = True + + @staticmethod + def deadline(fsm): + if fsm.probe_interval: + return fsm.state_entered + fsm.probe_interval + return None + + @staticmethod + def run(fsm, now): + vlog.err("%s: no response to inactivity probe after %.3g " + "seconds, disconnecting" + % (fsm.name, (now - fsm.state_entered) / 1000.0)) + return DISCONNECT + + class Reconnect(object): + name = "RECONNECT" + is_connected = False + + @staticmethod + def deadline(fsm): + return fsm.state_entered + + @staticmethod + def run(fsm, now): + return DISCONNECT + + def __init__(self, now): + """Creates and returns a new reconnect FSM with default settings. The + FSM is initially disabled. The caller will likely want to call + self.enable() and self.set_name() on the returned object.""" + + self.name = "void" + self.min_backoff = 1000 + self.max_backoff = 8000 + self.probe_interval = 5000 + self.passive = False + self.info_level = vlog.info + + self.state = Reconnect.Void + self.state_entered = now + self.backoff = 0 + self.last_activity = now + self.last_connected = None + self.last_disconnected = None + self.max_tries = None + + self.creation_time = now + self.n_attempted_connections = 0 + self.n_successful_connections = 0 + self.total_connected_duration = 0 + self.seqno = 0 + + def set_quiet(self, quiet): + """If 'quiet' is true, this object will log informational messages at + debug level, by default keeping them out of log files. This is + appropriate if the connection is one that is expected to be + short-lived, so that the log messages are merely distracting. + + If 'quiet' is false, this object logs informational messages at info + level. This is the default. + + This setting has no effect on the log level of debugging, warning, or + error messages.""" + if quiet: + self.info_level = vlog.dbg + else: + self.info_level = vlog.info + + def get_name(self): + return self.name + + def set_name(self, name): + """Sets this object's name to 'name'. If 'name' is None, then "void" + is used instead. + + The name is used in log messages.""" + if name is None: + self.name = "void" + else: + self.name = name + + def get_min_backoff(self): + """Return the minimum number of milliseconds to back off between + consecutive connection attempts. The default is 1000 ms.""" + return self.min_backoff + + def get_max_backoff(self): + """Return the maximum number of milliseconds to back off between + consecutive connection attempts. The default is 8000 ms.""" + return self.max_backoff + + def get_probe_interval(self): + """Returns the "probe interval" in milliseconds. If this is zero, it + disables the connection keepalive feature. If it is nonzero, then if + the interval passes while the FSM is connected and without + self.activity() being called, self.run() returns ovs.reconnect.PROBE. + If the interval passes again without self.activity() being called, + self.run() returns ovs.reconnect.DISCONNECT.""" + return self.probe_interval + + def set_max_tries(self, max_tries): + """Limits the maximum number of times that this object will ask the + client to try to reconnect to 'max_tries'. None (the default) means an + unlimited number of tries. + + After the number of tries has expired, the FSM will disable itself + instead of backing off and retrying.""" + self.max_tries = max_tries + + def get_max_tries(self): + """Returns the current remaining number of connection attempts, + None if the number is unlimited.""" + return self.max_tries + + def set_backoff(self, min_backoff, max_backoff): + """Configures the backoff parameters for this FSM. 'min_backoff' is + the minimum number of milliseconds, and 'max_backoff' is the maximum, + between connection attempts. + + 'min_backoff' must be at least 1000, and 'max_backoff' must be greater + than or equal to 'min_backoff'.""" + self.min_backoff = max(min_backoff, 1000) + if self.max_backoff: + self.max_backoff = max(max_backoff, 1000) + else: + self.max_backoff = 8000 + if self.min_backoff > self.max_backoff: + self.max_backoff = self.min_backoff + + if (self.state == Reconnect.Backoff and + self.backoff > self.max_backoff): + self.backoff = self.max_backoff + + def set_probe_interval(self, probe_interval): + """Sets the "probe interval" to 'probe_interval', in milliseconds. If + this is zero, it disables the connection keepalive feature. If it is + nonzero, then if the interval passes while this FSM is connected and + without self.activity() being called, self.run() returns + ovs.reconnect.PROBE. If the interval passes again without + self.activity() being called, self.run() returns + ovs.reconnect.DISCONNECT. + + If 'probe_interval' is nonzero, then it will be forced to a value of at + least 1000 ms.""" + if probe_interval: + self.probe_interval = max(1000, probe_interval) + else: + self.probe_interval = 0 + + def is_passive(self): + """Returns true if 'fsm' is in passive mode, false if 'fsm' is in + active mode (the default).""" + return self.passive + + def set_passive(self, passive, now): + """Configures this FSM for active or passive mode. In active mode (the + default), the FSM is attempting to connect to a remote host. In + passive mode, the FSM is listening for connections from a remote + host.""" + if self.passive != passive: + self.passive = passive + + if ((passive and self.state in (Reconnect.ConnectInProgress, + Reconnect.Reconnect)) or + (not passive and self.state == Reconnect.Listening + and self.__may_retry())): + self._transition(now, Reconnect.Backoff) + self.backoff = 0 + + def is_enabled(self): + """Returns true if this FSM has been enabled with self.enable(). + Calling another function that indicates a change in connection state, + such as self.disconnected() or self.force_reconnect(), will also enable + a reconnect FSM.""" + return self.state != Reconnect.Void + + def enable(self, now): + """If this FSM is disabled (the default for newly created FSMs), + enables it, so that the next call to reconnect_run() for 'fsm' will + return ovs.reconnect.CONNECT. + + If this FSM is not disabled, this function has no effect.""" + if self.state == Reconnect.Void and self.__may_retry(): + self._transition(now, Reconnect.Backoff) + self.backoff = 0 + + def disable(self, now): + """Disables this FSM. Until 'fsm' is enabled again, self.run() will + always return 0.""" + if self.state != Reconnect.Void: + self._transition(now, Reconnect.Void) + + def force_reconnect(self, now): + """If this FSM is enabled and currently connected (or attempting to + connect), forces self.run() to return ovs.reconnect.DISCONNECT the next + time it is called, which should cause the client to drop the connection + (or attempt), back off, and then reconnect.""" + if self.state in (Reconnect.ConnectInProgress, + Reconnect.Active, + Reconnect.Idle): + self._transition(now, Reconnect.Reconnect) + + def disconnected(self, now, error): + """Tell this FSM that the connection dropped or that a connection + attempt failed. 'error' specifies the reason: a positive value + represents an errno value, EOF indicates that the connection was closed + by the peer (e.g. read() returned 0), and 0 indicates no specific + error. + + The FSM will back off, then reconnect.""" + if self.state not in (Reconnect.Backoff, Reconnect.Void): + # Report what happened + if self.state in (Reconnect.Active, Reconnect.Idle): + if error > 0: + vlog.warn("%s: connection dropped (%s)" + % (self.name, os.strerror(error))) + elif error == EOF: + self.info_level("%s: connection closed by peer" + % self.name) + else: + self.info_level("%s: connection dropped" % self.name) + elif self.state == Reconnect.Listening: + if error > 0: + vlog.warn("%s: error listening for connections (%s)" + % (self.name, os.strerror(error))) + else: + self.info_level("%s: error listening for connections" + % self.name) + else: + if self.passive: + type_ = "listen" + else: + type_ = "connection" + if error > 0: + vlog.warn("%s: %s attempt failed (%s)" + % (self.name, type_, os.strerror(error))) + else: + self.info_level("%s: %s attempt timed out" + % (self.name, type_)) + + if (self.state in (Reconnect.Active, Reconnect.Idle)): + self.last_disconnected = now + + # Back off + if (self.state in (Reconnect.Active, Reconnect.Idle) and + (self.last_activity - self.last_connected >= self.backoff or + self.passive)): + if self.passive: + self.backoff = 0 + else: + self.backoff = self.min_backoff + else: + if self.backoff < self.min_backoff: + self.backoff = self.min_backoff + elif self.backoff >= self.max_backoff / 2: + self.backoff = self.max_backoff + else: + self.backoff *= 2 + + if self.passive: + self.info_level("%s: waiting %.3g seconds before trying " + "to listen again" + % (self.name, self.backoff / 1000.0)) + else: + self.info_level("%s: waiting %.3g seconds before reconnect" + % (self.name, self.backoff / 1000.0)) + + if self.__may_retry(): + self._transition(now, Reconnect.Backoff) + else: + self._transition(now, Reconnect.Void) + + def connecting(self, now): + """Tell this FSM that a connection or listening attempt is in progress. + + The FSM will start a timer, after which the connection or listening + attempt will be aborted (by returning ovs.reconnect.DISCONNECT from + self.run()).""" + if self.state != Reconnect.ConnectInProgress: + if self.passive: + self.info_level("%s: listening..." % self.name) + else: + self.info_level("%s: connecting..." % self.name) + self._transition(now, Reconnect.ConnectInProgress) + + def listening(self, now): + """Tell this FSM that the client is listening for connection attempts. + This state last indefinitely until the client reports some change. + + The natural progression from this state is for the client to report + that a connection has been accepted or is in progress of being + accepted, by calling self.connecting() or self.connected(). + + The client may also report that listening failed (e.g. accept() + returned an unexpected error such as ENOMEM) by calling + self.listen_error(), in which case the FSM will back off and eventually + return ovs.reconnect.CONNECT from self.run() to tell the client to try + listening again.""" + if self.state != Reconnect.Listening: + self.info_level("%s: listening..." % self.name) + self._transition(now, Reconnect.Listening) + + def listen_error(self, now, error): + """Tell this FSM that the client's attempt to accept a connection + failed (e.g. accept() returned an unexpected error such as ENOMEM). + + If the FSM is currently listening (self.listening() was called), it + will back off and eventually return ovs.reconnect.CONNECT from + self.run() to tell the client to try listening again. If there is an + active connection, this will be delayed until that connection drops.""" + if self.state == Reconnect.Listening: + self.disconnected(now, error) + + def connected(self, now): + """Tell this FSM that the connection was successful. + + The FSM will start the probe interval timer, which is reset by + self.activity(). If the timer expires, a probe will be sent (by + returning ovs.reconnect.PROBE from self.run(). If the timer expires + again without being reset, the connection will be aborted (by returning + ovs.reconnect.DISCONNECT from self.run().""" + if not self.state.is_connected: + self.connecting(now) + + self.info_level("%s: connected" % self.name) + self._transition(now, Reconnect.Active) + self.last_connected = now + + def connect_failed(self, now, error): + """Tell this FSM that the connection attempt failed. + + The FSM will back off and attempt to reconnect.""" + self.connecting(now) + self.disconnected(now, error) + + def activity(self, now): + """Tell this FSM that some activity occurred on the connection. This + resets the probe interval timer, so that the connection is known not to + be idle.""" + if self.state != Reconnect.Active: + self._transition(now, Reconnect.Active) + self.last_activity = now + + def _transition(self, now, state): + if self.state == Reconnect.ConnectInProgress: + self.n_attempted_connections += 1 + if state == Reconnect.Active: + self.n_successful_connections += 1 + + connected_before = self.state.is_connected + connected_now = state.is_connected + if connected_before != connected_now: + if connected_before: + self.total_connected_duration += now - self.last_connected + self.seqno += 1 + + vlog.dbg("%s: entering %s" % (self.name, state.name)) + self.state = state + self.state_entered = now + + def run(self, now): + """Assesses whether any action should be taken on this FSM. The return + value is one of: + + - None: The client need not take any action. + + - Active client, ovs.reconnect.CONNECT: The client should start a + connection attempt and indicate this by calling + self.connecting(). If the connection attempt has definitely + succeeded, it should call self.connected(). If the connection + attempt has definitely failed, it should call + self.connect_failed(). + + The FSM is smart enough to back off correctly after successful + connections that quickly abort, so it is OK to call + self.connected() after a low-level successful connection + (e.g. connect()) even if the connection might soon abort due to a + failure at a high-level (e.g. SSL negotiation failure). + + - Passive client, ovs.reconnect.CONNECT: The client should try to + listen for a connection, if it is not already listening. It + should call self.listening() if successful, otherwise + self.connecting() or reconnected_connect_failed() if the attempt + is in progress or definitely failed, respectively. + + A listening passive client should constantly attempt to accept a + new connection and report an accepted connection with + self.connected(). + + - ovs.reconnect.DISCONNECT: The client should abort the current + connection or connection attempt or listen attempt and call + self.disconnected() or self.connect_failed() to indicate it. + + - ovs.reconnect.PROBE: The client should send some kind of request + to the peer that will elicit a response, to ensure that the + connection is indeed in working order. (This will only be + returned if the "probe interval" is nonzero--see + self.set_probe_interval()).""" + + deadline = self.state.deadline(self) + if deadline is not None and now >= deadline: + return self.state.run(self, now) + else: + return None + + def wait(self, poller, now): + """Causes the next call to poller.block() to wake up when self.run() + should be called.""" + timeout = self.timeout(now) + if timeout >= 0: + poller.timer_wait(timeout) + + def timeout(self, now): + """Returns the number of milliseconds after which self.run() should be + called if nothing else notable happens in the meantime, or None if this + is currently unnecessary.""" + deadline = self.state.deadline(self) + if deadline is not None: + remaining = deadline - now + return max(0, remaining) + else: + return None + + def is_connected(self): + """Returns True if this FSM is currently believed to be connected, that + is, if self.connected() was called more recently than any call to + self.connect_failed() or self.disconnected() or self.disable(), and + False otherwise.""" + return self.state.is_connected + + def get_last_connect_elapsed(self, now): + """Returns the number of milliseconds since 'fsm' was last connected + to its peer. Returns None if never connected.""" + if self.last_connected: + return now - self.last_connected + else: + return None + + def get_last_disconnect_elapsed(self, now): + """Returns the number of milliseconds since 'fsm' was last disconnected + from its peer. Returns None if never disconnected.""" + if self.last_disconnected: + return now - self.last_disconnected + else: + return None + + def get_stats(self, now): + class Stats(object): + pass + stats = Stats() + stats.creation_time = self.creation_time + stats.last_connected = self.last_connected + stats.last_disconnected = self.last_disconnected + stats.last_activity = self.last_activity + stats.backoff = self.backoff + stats.seqno = self.seqno + stats.is_connected = self.is_connected() + stats.msec_since_connect = self.get_last_connect_elapsed(now) + stats.msec_since_disconnect = self.get_last_disconnect_elapsed(now) + stats.total_connected_duration = self.total_connected_duration + if self.is_connected(): + stats.total_connected_duration += ( + self.get_last_connect_elapsed(now)) + stats.n_attempted_connections = self.n_attempted_connections + stats.n_successful_connections = self.n_successful_connections + stats.state = self.state.name + stats.state_elapsed = now - self.state_entered + return stats + + def __may_retry(self): + if self.max_tries is None: + return True + elif self.max_tries > 0: + self.max_tries -= 1 + return True + else: + return False diff --git a/ryu/contrib/ovs/socket_util.py b/ryu/contrib/ovs/socket_util.py new file mode 100644 index 00000000..dd45fe4b --- /dev/null +++ b/ryu/contrib/ovs/socket_util.py @@ -0,0 +1,192 @@ +# Copyright (c) 2010, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import os +import select +import socket +import sys + +import ovs.fatal_signal +import ovs.poller +import ovs.vlog + +vlog = ovs.vlog.Vlog("socket_util") + + +def make_unix_socket(style, nonblock, bind_path, connect_path): + """Creates a Unix domain socket in the given 'style' (either + socket.SOCK_DGRAM or socket.SOCK_STREAM) that is bound to 'bind_path' (if + 'bind_path' is not None) and connected to 'connect_path' (if 'connect_path' + is not None). If 'nonblock' is true, the socket is made non-blocking. + + Returns (error, socket): on success 'error' is 0 and 'socket' is a new + socket object, on failure 'error' is a positive errno value and 'socket' is + None.""" + + try: + sock = socket.socket(socket.AF_UNIX, style) + except socket.error, e: + return get_exception_errno(e), None + + try: + if nonblock: + set_nonblocking(sock) + if bind_path is not None: + # Delete bind_path but ignore ENOENT. + try: + os.unlink(bind_path) + except OSError, e: + if e.errno != errno.ENOENT: + return e.errno, None + + ovs.fatal_signal.add_file_to_unlink(bind_path) + sock.bind(bind_path) + + try: + if sys.hexversion >= 0x02060000: + os.fchmod(sock.fileno(), 0700) + else: + os.chmod("/dev/fd/%d" % sock.fileno(), 0700) + except OSError, e: + pass + if connect_path is not None: + try: + sock.connect(connect_path) + except socket.error, e: + if get_exception_errno(e) != errno.EINPROGRESS: + raise + return 0, sock + except socket.error, e: + sock.close() + if bind_path is not None: + ovs.fatal_signal.unlink_file_now(bind_path) + return get_exception_errno(e), None + + +def check_connection_completion(sock): + p = ovs.poller.SelectPoll() + p.register(sock, select.POLLOUT) + if len(p.poll(0)) == 1: + return get_socket_error(sock) + else: + return errno.EAGAIN + + +def inet_parse_active(target, default_port): + address = target.split(":") + host_name = address[0] + if not host_name: + raise ValueError("%s: bad peer name format" % target) + if len(address) >= 2: + port = int(address[1]) + elif default_port: + port = default_port + else: + raise ValueError("%s: port number must be specified" % target) + return (host_name, port) + + +def inet_open_active(style, target, default_port, dscp): + address = inet_parse_active(target, default_port) + try: + sock = socket.socket(socket.AF_INET, style, 0) + except socket.error, e: + return get_exception_errno(e), None + + try: + set_nonblocking(sock) + set_dscp(sock, dscp) + try: + sock.connect(address) + except socket.error, e: + if get_exception_errno(e) != errno.EINPROGRESS: + raise + return 0, sock + except socket.error, e: + sock.close() + return get_exception_errno(e), None + + +def get_socket_error(sock): + """Returns the errno value associated with 'socket' (0 if no error) and + resets the socket's error status.""" + return sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR) + + +def get_exception_errno(e): + """A lot of methods on Python socket objects raise socket.error, but that + exception is documented as having two completely different forms of + arguments: either a string or a (errno, string) tuple. We only want the + errno.""" + if type(e.args) == tuple: + return e.args[0] + else: + return errno.EPROTO + + +null_fd = -1 + + +def get_null_fd(): + """Returns a readable and writable fd for /dev/null, if successful, + otherwise a negative errno value. The caller must not close the returned + fd (because the same fd will be handed out to subsequent callers).""" + global null_fd + if null_fd < 0: + try: + null_fd = os.open("/dev/null", os.O_RDWR) + except OSError, e: + vlog.err("could not open /dev/null: %s" % os.strerror(e.errno)) + return -e.errno + return null_fd + + +def write_fully(fd, buf): + """Returns an (error, bytes_written) tuple where 'error' is 0 on success, + otherwise a positive errno value, and 'bytes_written' is the number of + bytes that were written before the error occurred. 'error' is 0 if and + only if 'bytes_written' is len(buf).""" + bytes_written = 0 + if len(buf) == 0: + return 0, 0 + while True: + try: + retval = os.write(fd, buf) + assert retval >= 0 + if retval == len(buf): + return 0, bytes_written + len(buf) + elif retval == 0: + vlog.warn("write returned 0") + return errno.EPROTO, bytes_written + else: + bytes_written += retval + buf = buf[:retval] + except OSError, e: + return e.errno, bytes_written + + +def set_nonblocking(sock): + try: + sock.setblocking(0) + except socket.error, e: + vlog.err("could not set nonblocking mode on socket: %s" + % os.strerror(get_socket_error(e))) + + +def set_dscp(sock, dscp): + if dscp > 63: + raise ValueError("Invalid dscp %d" % dscp) + val = dscp << 2 + sock.setsockopt(socket.IPPROTO_IP, socket.IP_TOS, val) diff --git a/ryu/contrib/ovs/stream.py b/ryu/contrib/ovs/stream.py new file mode 100644 index 00000000..dad68483 --- /dev/null +++ b/ryu/contrib/ovs/stream.py @@ -0,0 +1,361 @@ +# Copyright (c) 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import os +import select +import socket + +import ovs.poller +import ovs.socket_util +import ovs.vlog + +vlog = ovs.vlog.Vlog("stream") + + +def stream_or_pstream_needs_probes(name): + """ 1 if the stream or pstream specified by 'name' needs periodic probes to + verify connectivity. For [p]streams which need probes, it can take a long + time to notice the connection was dropped. Returns 0 if probes aren't + needed, and -1 if 'name' is invalid""" + + if PassiveStream.is_valid_name(name) or Stream.is_valid_name(name): + # Only unix and punix are supported currently. + return 0 + else: + return -1 + + +class Stream(object): + """Bidirectional byte stream. Currently only Unix domain sockets + are implemented.""" + + # States. + __S_CONNECTING = 0 + __S_CONNECTED = 1 + __S_DISCONNECTED = 2 + + # Kinds of events that one might wait for. + W_CONNECT = 0 # Connect complete (success or failure). + W_RECV = 1 # Data received. + W_SEND = 2 # Send buffer room available. + + _SOCKET_METHODS = {} + + @staticmethod + def register_method(method, cls): + Stream._SOCKET_METHODS[method + ":"] = cls + + @staticmethod + def _find_method(name): + for method, cls in Stream._SOCKET_METHODS.items(): + if name.startswith(method): + return cls + return None + + @staticmethod + def is_valid_name(name): + """Returns True if 'name' is a stream name in the form "TYPE:ARGS" and + TYPE is a supported stream type (currently only "unix:" and "tcp:"), + otherwise False.""" + return bool(Stream._find_method(name)) + + def __init__(self, socket, name, status): + self.socket = socket + self.name = name + if status == errno.EAGAIN: + self.state = Stream.__S_CONNECTING + elif status == 0: + self.state = Stream.__S_CONNECTED + else: + self.state = Stream.__S_DISCONNECTED + + self.error = 0 + + # Default value of dscp bits for connection between controller and manager. + # Value of IPTOS_PREC_INTERNETCONTROL = 0xc0 which is defined + # in is used. + IPTOS_PREC_INTERNETCONTROL = 0xc0 + DSCP_DEFAULT = IPTOS_PREC_INTERNETCONTROL >> 2 + + @staticmethod + def open(name, dscp=DSCP_DEFAULT): + """Attempts to connect a stream to a remote peer. 'name' is a + connection name in the form "TYPE:ARGS", where TYPE is an active stream + class's name and ARGS are stream class-specific. Currently the only + supported TYPEs are "unix" and "tcp". + + Returns (error, stream): on success 'error' is 0 and 'stream' is the + new Stream, on failure 'error' is a positive errno value and 'stream' + is None. + + Never returns errno.EAGAIN or errno.EINPROGRESS. Instead, returns 0 + and a new Stream. The connect() method can be used to check for + successful connection completion.""" + cls = Stream._find_method(name) + if not cls: + return errno.EAFNOSUPPORT, None + + suffix = name.split(":", 1)[1] + error, sock = cls._open(suffix, dscp) + if error: + return error, None + else: + status = ovs.socket_util.check_connection_completion(sock) + return 0, Stream(sock, name, status) + + @staticmethod + def _open(suffix, dscp): + raise NotImplementedError("This method must be overrided by subclass") + + @staticmethod + def open_block((error, stream)): + """Blocks until a Stream completes its connection attempt, either + succeeding or failing. (error, stream) should be the tuple returned by + Stream.open(). Returns a tuple of the same form. + + Typical usage: + error, stream = Stream.open_block(Stream.open("unix:/tmp/socket"))""" + + if not error: + while True: + error = stream.connect() + if error != errno.EAGAIN: + break + stream.run() + poller = ovs.poller.Poller() + stream.run_wait(poller) + stream.connect_wait(poller) + poller.block() + assert error != errno.EINPROGRESS + + if error and stream: + stream.close() + stream = None + return error, stream + + def close(self): + self.socket.close() + + def __scs_connecting(self): + retval = ovs.socket_util.check_connection_completion(self.socket) + assert retval != errno.EINPROGRESS + if retval == 0: + self.state = Stream.__S_CONNECTED + elif retval != errno.EAGAIN: + self.state = Stream.__S_DISCONNECTED + self.error = retval + + def connect(self): + """Tries to complete the connection on this stream. If the connection + is complete, returns 0 if the connection was successful or a positive + errno value if it failed. If the connection is still in progress, + returns errno.EAGAIN.""" + last_state = -1 # Always differs from initial self.state + while self.state != last_state: + last_state = self.state + if self.state == Stream.__S_CONNECTING: + self.__scs_connecting() + elif self.state == Stream.__S_CONNECTED: + return 0 + elif self.state == Stream.__S_DISCONNECTED: + return self.error + + def recv(self, n): + """Tries to receive up to 'n' bytes from this stream. Returns a + (error, string) tuple: + + - If successful, 'error' is zero and 'string' contains between 1 + and 'n' bytes of data. + + - On error, 'error' is a positive errno value. + + - If the connection has been closed in the normal fashion or if 'n' + is 0, the tuple is (0, ""). + + The recv function will not block waiting for data to arrive. If no + data have been received, it returns (errno.EAGAIN, "") immediately.""" + + retval = self.connect() + if retval != 0: + return (retval, "") + elif n == 0: + return (0, "") + + try: + return (0, self.socket.recv(n)) + except socket.error, e: + return (ovs.socket_util.get_exception_errno(e), "") + + def send(self, buf): + """Tries to send 'buf' on this stream. + + If successful, returns the number of bytes sent, between 1 and + len(buf). 0 is only a valid return value if len(buf) is 0. + + On error, returns a negative errno value. + + Will not block. If no bytes can be immediately accepted for + transmission, returns -errno.EAGAIN immediately.""" + + retval = self.connect() + if retval != 0: + return -retval + elif len(buf) == 0: + return 0 + + try: + return self.socket.send(buf) + except socket.error, e: + return -ovs.socket_util.get_exception_errno(e) + + def run(self): + pass + + def run_wait(self, poller): + pass + + def wait(self, poller, wait): + assert wait in (Stream.W_CONNECT, Stream.W_RECV, Stream.W_SEND) + + if self.state == Stream.__S_DISCONNECTED: + poller.immediate_wake() + return + + if self.state == Stream.__S_CONNECTING: + wait = Stream.W_CONNECT + if wait == Stream.W_RECV: + poller.fd_wait(self.socket, select.POLLIN) + else: + poller.fd_wait(self.socket, select.POLLOUT) + + def connect_wait(self, poller): + self.wait(poller, Stream.W_CONNECT) + + def recv_wait(self, poller): + self.wait(poller, Stream.W_RECV) + + def send_wait(self, poller): + self.wait(poller, Stream.W_SEND) + + def __del__(self): + # Don't delete the file: we might have forked. + self.socket.close() + + +class PassiveStream(object): + @staticmethod + def is_valid_name(name): + """Returns True if 'name' is a passive stream name in the form + "TYPE:ARGS" and TYPE is a supported passive stream type (currently only + "punix:"), otherwise False.""" + return name.startswith("punix:") + + def __init__(self, sock, name, bind_path): + self.name = name + self.socket = sock + self.bind_path = bind_path + + @staticmethod + def open(name): + """Attempts to start listening for remote stream connections. 'name' + is a connection name in the form "TYPE:ARGS", where TYPE is an passive + stream class's name and ARGS are stream class-specific. Currently the + only supported TYPE is "punix". + + Returns (error, pstream): on success 'error' is 0 and 'pstream' is the + new PassiveStream, on failure 'error' is a positive errno value and + 'pstream' is None.""" + if not PassiveStream.is_valid_name(name): + return errno.EAFNOSUPPORT, None + + bind_path = name[6:] + error, sock = ovs.socket_util.make_unix_socket(socket.SOCK_STREAM, + True, bind_path, None) + if error: + return error, None + + try: + sock.listen(10) + except socket.error, e: + vlog.err("%s: listen: %s" % (name, os.strerror(e.error))) + sock.close() + return e.error, None + + return 0, PassiveStream(sock, name, bind_path) + + def close(self): + """Closes this PassiveStream.""" + self.socket.close() + if self.bind_path is not None: + ovs.fatal_signal.unlink_file_now(self.bind_path) + self.bind_path = None + + def accept(self): + """Tries to accept a new connection on this passive stream. Returns + (error, stream): if successful, 'error' is 0 and 'stream' is the new + Stream object, and on failure 'error' is a positive errno value and + 'stream' is None. + + Will not block waiting for a connection. If no connection is ready to + be accepted, returns (errno.EAGAIN, None) immediately.""" + + while True: + try: + sock, addr = self.socket.accept() + ovs.socket_util.set_nonblocking(sock) + return 0, Stream(sock, "unix:%s" % addr, 0) + except socket.error, e: + error = ovs.socket_util.get_exception_errno(e) + if error != errno.EAGAIN: + # XXX rate-limit + vlog.dbg("accept: %s" % os.strerror(error)) + return error, None + + def wait(self, poller): + poller.fd_wait(self.socket, select.POLLIN) + + def __del__(self): + # Don't delete the file: we might have forked. + self.socket.close() + + +def usage(name): + return """ +Active %s connection methods: + unix:FILE Unix domain socket named FILE + tcp:IP:PORT TCP socket to IP with port no of PORT + +Passive %s connection methods: + punix:FILE Listen on Unix domain socket FILE""" % (name, name) + + +class UnixStream(Stream): + @staticmethod + def _open(suffix, dscp): + connect_path = suffix + return ovs.socket_util.make_unix_socket(socket.SOCK_STREAM, + True, None, connect_path) +Stream.register_method("unix", UnixStream) + + +class TCPStream(Stream): + @staticmethod + def _open(suffix, dscp): + error, sock = ovs.socket_util.inet_open_active(socket.SOCK_STREAM, + suffix, 0, dscp) + if not error: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + return error, sock +Stream.register_method("tcp", TCPStream) diff --git a/ryu/contrib/ovs/timeval.py b/ryu/contrib/ovs/timeval.py new file mode 100644 index 00000000..ba0e54e9 --- /dev/null +++ b/ryu/contrib/ovs/timeval.py @@ -0,0 +1,26 @@ +# Copyright (c) 2009, 2010 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + + +def msec(): + """Returns the current time, as the amount of time since the epoch, in + milliseconds, as a float.""" + return time.time() * 1000.0 + + +def postfork(): + # Just a stub for now + pass diff --git a/ryu/contrib/ovs/unixctl/__init__.py b/ryu/contrib/ovs/unixctl/__init__.py new file mode 100644 index 00000000..715f2db5 --- /dev/null +++ b/ryu/contrib/ovs/unixctl/__init__.py @@ -0,0 +1,83 @@ +# Copyright (c) 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import types + +import ovs.util + +commands = {} +strtypes = types.StringTypes + + +class _UnixctlCommand(object): + def __init__(self, usage, min_args, max_args, callback, aux): + self.usage = usage + self.min_args = min_args + self.max_args = max_args + self.callback = callback + self.aux = aux + + +def _unixctl_help(conn, unused_argv, unused_aux): + reply = "The available commands are:\n" + command_names = sorted(commands.keys()) + for name in command_names: + reply += " " + usage = commands[name].usage + if usage: + reply += "%-23s %s" % (name, usage) + else: + reply += name + reply += "\n" + conn.reply(reply) + + +def command_register(name, usage, min_args, max_args, callback, aux): + """ Registers a command with the given 'name' to be exposed by the + UnixctlServer. 'usage' describes the arguments to the command; it is used + only for presentation to the user in "help" output. + + 'callback' is called when the command is received. It is passed a + UnixctlConnection object, the list of arguments as unicode strings, and + 'aux'. Normally 'callback' should reply by calling + UnixctlConnection.reply() or UnixctlConnection.reply_error() before it + returns, but if the command cannot be handled immediately, then it can + defer the reply until later. A given connection can only process a single + request at a time, so a reply must be made eventually to avoid blocking + that connection.""" + + assert isinstance(name, strtypes) + assert isinstance(usage, strtypes) + assert isinstance(min_args, int) + assert isinstance(max_args, int) + assert isinstance(callback, types.FunctionType) + + if name not in commands: + commands[name] = _UnixctlCommand(usage, min_args, max_args, callback, + aux) + +def socket_name_from_target(target): + assert isinstance(target, strtypes) + + if target.startswith("/"): + return 0, target + + pidfile_name = "%s/%s.pid" % (ovs.dirs.RUNDIR, target) + pid = ovs.daemon.read_pidfile(pidfile_name) + if pid < 0: + return -pid, "cannot read pidfile \"%s\"" % pidfile_name + + return 0, "%s/%s.%d.ctl" % (ovs.dirs.RUNDIR, target, pid) + +command_register("help", "", 0, 0, _unixctl_help, None) diff --git a/ryu/contrib/ovs/unixctl/client.py b/ryu/contrib/ovs/unixctl/client.py new file mode 100644 index 00000000..2176009a --- /dev/null +++ b/ryu/contrib/ovs/unixctl/client.py @@ -0,0 +1,70 @@ +# Copyright (c) 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import errno +import os +import types + +import ovs.jsonrpc +import ovs.stream +import ovs.util + + +vlog = ovs.vlog.Vlog("unixctl_client") +strtypes = types.StringTypes + + +class UnixctlClient(object): + def __init__(self, conn): + assert isinstance(conn, ovs.jsonrpc.Connection) + self._conn = conn + + def transact(self, command, argv): + assert isinstance(command, strtypes) + assert isinstance(argv, list) + for arg in argv: + assert isinstance(arg, strtypes) + + request = ovs.jsonrpc.Message.create_request(command, argv) + error, reply = self._conn.transact_block(request) + + if error: + vlog.warn("error communicating with %s: %s" + % (self._conn.name, os.strerror(error))) + return error, None, None + + if reply.error is not None: + return 0, str(reply.error), None + else: + assert reply.result is not None + return 0, None, str(reply.result) + + def close(self): + self._conn.close() + self.conn = None + + @staticmethod + def create(path): + assert isinstance(path, str) + + unix = "unix:%s" % ovs.util.abs_file_name(ovs.dirs.RUNDIR, path) + error, stream = ovs.stream.Stream.open_block( + ovs.stream.Stream.open(unix)) + + if error: + vlog.warn("failed to connect to %s" % path) + return error, None + + return 0, UnixctlClient(ovs.jsonrpc.Connection(stream)) diff --git a/ryu/contrib/ovs/unixctl/server.py b/ryu/contrib/ovs/unixctl/server.py new file mode 100644 index 00000000..18e1cf20 --- /dev/null +++ b/ryu/contrib/ovs/unixctl/server.py @@ -0,0 +1,247 @@ +# Copyright (c) 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import errno +import os +import types + +import ovs.dirs +import ovs.jsonrpc +import ovs.stream +import ovs.unixctl +import ovs.util +import ovs.version +import ovs.vlog + +Message = ovs.jsonrpc.Message +vlog = ovs.vlog.Vlog("unixctl_server") +strtypes = types.StringTypes + + +class UnixctlConnection(object): + def __init__(self, rpc): + assert isinstance(rpc, ovs.jsonrpc.Connection) + self._rpc = rpc + self._request_id = None + + def run(self): + self._rpc.run() + error = self._rpc.get_status() + if error or self._rpc.get_backlog(): + return error + + for _ in range(10): + if error or self._request_id: + break + + error, msg = self._rpc.recv() + if msg: + if msg.type == Message.T_REQUEST: + self._process_command(msg) + else: + # XXX: rate-limit + vlog.warn("%s: received unexpected %s message" + % (self._rpc.name, + Message.type_to_string(msg.type))) + error = errno.EINVAL + + if not error: + error = self._rpc.get_status() + + return error + + def reply(self, body): + self._reply_impl(True, body) + + def reply_error(self, body): + self._reply_impl(False, body) + + # Called only by unixctl classes. + def _close(self): + self._rpc.close() + self._request_id = None + + def _wait(self, poller): + self._rpc.wait(poller) + if not self._rpc.get_backlog(): + self._rpc.recv_wait(poller) + + def _reply_impl(self, success, body): + assert isinstance(success, bool) + assert body is None or isinstance(body, strtypes) + + assert self._request_id is not None + + if body is None: + body = "" + + if body and not body.endswith("\n"): + body += "\n" + + if success: + reply = Message.create_reply(body, self._request_id) + else: + reply = Message.create_error(body, self._request_id) + + self._rpc.send(reply) + self._request_id = None + + def _process_command(self, request): + assert isinstance(request, ovs.jsonrpc.Message) + assert request.type == ovs.jsonrpc.Message.T_REQUEST + + self._request_id = request.id + + error = None + params = request.params + method = request.method + command = ovs.unixctl.commands.get(method) + if command is None: + error = '"%s" is not a valid command' % method + elif len(params) < command.min_args: + error = '"%s" command requires at least %d arguments' \ + % (method, command.min_args) + elif len(params) > command.max_args: + error = '"%s" command takes at most %d arguments' \ + % (method, command.max_args) + else: + for param in params: + if not isinstance(param, strtypes): + error = '"%s" command has non-string argument' % method + break + + if error is None: + unicode_params = [unicode(p) for p in params] + command.callback(self, unicode_params, command.aux) + + if error: + self.reply_error(error) + + +def _unixctl_version(conn, unused_argv, version): + assert isinstance(conn, UnixctlConnection) + version = "%s (Open vSwitch) %s" % (ovs.util.PROGRAM_NAME, version) + conn.reply(version) + +class UnixctlServer(object): + def __init__(self, listener): + assert isinstance(listener, ovs.stream.PassiveStream) + self._listener = listener + self._conns = [] + + def run(self): + for _ in range(10): + error, stream = self._listener.accept() + if not error: + rpc = ovs.jsonrpc.Connection(stream) + self._conns.append(UnixctlConnection(rpc)) + elif error == errno.EAGAIN: + break + else: + # XXX: rate-limit + vlog.warn("%s: accept failed: %s" % (self._listener.name, + os.strerror(error))) + + for conn in copy.copy(self._conns): + error = conn.run() + if error and error != errno.EAGAIN: + conn._close() + self._conns.remove(conn) + + def wait(self, poller): + self._listener.wait(poller) + for conn in self._conns: + conn._wait(poller) + + def close(self): + for conn in self._conns: + conn._close() + self._conns = None + + self._listener.close() + self._listener = None + + @staticmethod + def create(path, version=None): + """Creates a new UnixctlServer which listens on a unixctl socket + created at 'path'. If 'path' is None, the default path is chosen. + 'version' contains the version of the server as reported by the unixctl + version command. If None, ovs.version.VERSION is used.""" + + assert path is None or isinstance(path, strtypes) + + if path is not None: + path = "punix:%s" % ovs.util.abs_file_name(ovs.dirs.RUNDIR, path) + else: + path = "punix:%s/%s.%d.ctl" % (ovs.dirs.RUNDIR, + ovs.util.PROGRAM_NAME, os.getpid()) + + if version is None: + version = ovs.version.VERSION + + error, listener = ovs.stream.PassiveStream.open(path) + if error: + ovs.util.ovs_error(error, "could not initialize control socket %s" + % path) + return error, None + + ovs.unixctl.command_register("version", "", 0, 0, _unixctl_version, + version) + + return 0, UnixctlServer(listener) + + +class UnixctlClient(object): + def __init__(self, conn): + assert isinstance(conn, ovs.jsonrpc.Connection) + self._conn = conn + + def transact(self, command, argv): + assert isinstance(command, strtypes) + assert isinstance(argv, list) + for arg in argv: + assert isinstance(arg, strtypes) + + request = Message.create_request(command, argv) + error, reply = self._conn.transact_block(request) + + if error: + vlog.warn("error communicating with %s: %s" + % (self._conn.name, os.strerror(error))) + return error, None, None + + if reply.error is not None: + return 0, str(reply.error), None + else: + assert reply.result is not None + return 0, None, str(reply.result) + + def close(self): + self._conn.close() + self.conn = None + + @staticmethod + def create(path): + assert isinstance(path, str) + + unix = "unix:%s" % ovs.util.abs_file_name(ovs.dirs.RUNDIR, path) + error, stream = ovs.stream.Stream.open_block( + ovs.stream.Stream.open(unix)) + + if error: + vlog.warn("failed to connect to %s" % path) + return error, None + + return 0, UnixctlClient(ovs.jsonrpc.Connection(stream)) diff --git a/ryu/contrib/ovs/util.py b/ryu/contrib/ovs/util.py new file mode 100644 index 00000000..cb0574bf --- /dev/null +++ b/ryu/contrib/ovs/util.py @@ -0,0 +1,93 @@ +# Copyright (c) 2010, 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import os.path +import sys + +PROGRAM_NAME = os.path.basename(sys.argv[0]) +EOF = -1 + + +def abs_file_name(dir_, file_name): + """If 'file_name' starts with '/', returns a copy of 'file_name'. + Otherwise, returns an absolute path to 'file_name' considering it relative + to 'dir_', which itself must be absolute. 'dir_' may be None or the empty + string, in which case the current working directory is used. + + Returns None if 'dir_' is None and getcwd() fails. + + This differs from os.path.abspath() in that it will never change the + meaning of a file name.""" + if file_name.startswith('/'): + return file_name + else: + if dir_ is None or dir_ == "": + try: + dir_ = os.getcwd() + except OSError: + return None + + if dir_.endswith('/'): + return dir_ + file_name + else: + return "%s/%s" % (dir_, file_name) + + +def ovs_retval_to_string(retval): + """Many OVS functions return an int which is one of: + - 0: no error yet + - >0: errno value + - EOF: end of file (not necessarily an error; depends on the function + called) + + Returns the appropriate human-readable string.""" + + if not retval: + return "" + if retval > 0: + return os.strerror(retval) + if retval == EOF: + return "End of file" + return "***unknown return value: %s***" % retval + + +def ovs_error(err_no, message, vlog=None): + """Prints 'message' on stderr and emits an ERROR level log message to + 'vlog' if supplied. If 'err_no' is nonzero, then it is formatted with + ovs_retval_to_string() and appended to the message inside parentheses. + + 'message' should not end with a new-line, because this function will add + one itself.""" + + err_msg = "%s: %s" % (PROGRAM_NAME, message) + if err_no: + err_msg += " (%s)" % ovs_retval_to_string(err_no) + + sys.stderr.write("%s\n" % err_msg) + if vlog: + vlog.err(err_msg) + + +def ovs_fatal(*args, **kwargs): + """Prints 'message' on stderr and emits an ERROR level log message to + 'vlog' if supplied. If 'err_no' is nonzero, then it is formatted with + ovs_retval_to_string() and appended to the message inside parentheses. + Then, terminates with exit code 1 (indicating a failure). + + 'message' should not end with a new-line, because this function will add + one itself.""" + + ovs_error(*args, **kwargs) + sys.exit(1) diff --git a/ryu/contrib/ovs/version.py b/ryu/contrib/ovs/version.py new file mode 100644 index 00000000..aa9c9eb3 --- /dev/null +++ b/ryu/contrib/ovs/version.py @@ -0,0 +1,2 @@ +# Generated automatically -- do not modify! -*- buffer-read-only: t -*- +VERSION = "1.7.90" diff --git a/ryu/contrib/ovs/vlog.py b/ryu/contrib/ovs/vlog.py new file mode 100644 index 00000000..f7ace66f --- /dev/null +++ b/ryu/contrib/ovs/vlog.py @@ -0,0 +1,267 @@ + +# Copyright (c) 2011, 2012 Nicira, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import logging +import logging.handlers +import re +import socket +import sys + +import ovs.dirs +import ovs.unixctl +import ovs.util + +FACILITIES = {"console": "info", "file": "info", "syslog": "info"} +LEVELS = { + "dbg": logging.DEBUG, + "info": logging.INFO, + "warn": logging.WARNING, + "err": logging.ERROR, + "emer": logging.CRITICAL, + "off": logging.CRITICAL +} + + +def get_level(level_str): + return LEVELS.get(level_str.lower()) + + +class Vlog: + __inited = False + __msg_num = 0 + __mfl = {} # Module -> facility -> level + __log_file = None + __file_handler = None + + def __init__(self, name): + """Creates a new Vlog object representing a module called 'name'. The + created Vlog object will do nothing until the Vlog.init() static method + is called. Once called, no more Vlog objects may be created.""" + + assert not Vlog.__inited + self.name = name.lower() + if name not in Vlog.__mfl: + Vlog.__mfl[self.name] = FACILITIES.copy() + + def __log(self, level, message, **kwargs): + if not Vlog.__inited: + return + + now = datetime.datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%SZ") + message = ("%s|%s|%s|%s|%s" + % (now, Vlog.__msg_num, self.name, level, message)) + + level = LEVELS.get(level.lower(), logging.DEBUG) + Vlog.__msg_num += 1 + + for f, f_level in Vlog.__mfl[self.name].iteritems(): + f_level = LEVELS.get(f_level, logging.CRITICAL) + if level >= f_level: + logging.getLogger(f).log(level, message, **kwargs) + + def emer(self, message, **kwargs): + self.__log("EMER", message, **kwargs) + + def err(self, message, **kwargs): + self.__log("ERR", message, **kwargs) + + def warn(self, message, **kwargs): + self.__log("WARN", message, **kwargs) + + def info(self, message, **kwargs): + self.__log("INFO", message, **kwargs) + + def dbg(self, message, **kwargs): + self.__log("DBG", message, **kwargs) + + def exception(self, message): + """Logs 'message' at ERR log level. Includes a backtrace when in + exception context.""" + self.err(message, exc_info=True) + + @staticmethod + def init(log_file=None): + """Intializes the Vlog module. Causes Vlog to write to 'log_file' if + not None. Should be called after all Vlog objects have been created. + No logging will occur until this function is called.""" + + if Vlog.__inited: + return + + Vlog.__inited = True + logging.raiseExceptions = False + Vlog.__log_file = log_file + for f in FACILITIES: + logger = logging.getLogger(f) + logger.setLevel(logging.DEBUG) + + try: + if f == "console": + logger.addHandler(logging.StreamHandler(sys.stderr)) + elif f == "syslog": + logger.addHandler(logging.handlers.SysLogHandler( + address="/dev/log", + facility=logging.handlers.SysLogHandler.LOG_DAEMON)) + elif f == "file" and Vlog.__log_file: + Vlog.__file_handler = logging.FileHandler(Vlog.__log_file) + logger.addHandler(Vlog.__file_handler) + except (IOError, socket.error): + logger.setLevel(logging.CRITICAL) + + ovs.unixctl.command_register("vlog/reopen", "", 0, 0, + Vlog._unixctl_vlog_reopen, None) + ovs.unixctl.command_register("vlog/set", "spec", 1, sys.maxint, + Vlog._unixctl_vlog_set, None) + ovs.unixctl.command_register("vlog/list", "", 0, 0, + Vlog._unixctl_vlog_list, None) + + @staticmethod + def set_level(module, facility, level): + """ Sets the log level of the 'module'-'facility' tuple to 'level'. + All three arguments are strings which are interpreted the same as + arguments to the --verbose flag. Should be called after all Vlog + objects have already been created.""" + + module = module.lower() + facility = facility.lower() + level = level.lower() + + if facility != "any" and facility not in FACILITIES: + return + + if module != "any" and module not in Vlog.__mfl: + return + + if level not in LEVELS: + return + + if module == "any": + modules = Vlog.__mfl.keys() + else: + modules = [module] + + if facility == "any": + facilities = FACILITIES.keys() + else: + facilities = [facility] + + for m in modules: + for f in facilities: + Vlog.__mfl[m][f] = level + + @staticmethod + def set_levels_from_string(s): + module = None + level = None + facility = None + + for word in [w.lower() for w in re.split('[ :]', s)]: + if word == "any": + pass + elif word in FACILITIES: + if facility: + return "cannot specify multiple facilities" + facility = word + elif word in LEVELS: + if level: + return "cannot specify multiple levels" + level = word + elif word in Vlog.__mfl: + if module: + return "cannot specify multiple modules" + module = word + else: + return "no facility, level, or module \"%s\"" % word + + Vlog.set_level(module or "any", facility or "any", level or "any") + + @staticmethod + def get_levels(): + lines = [" console syslog file\n", + " ------- ------ ------\n"] + lines.extend(sorted(["%-16s %4s %4s %4s\n" + % (m, + Vlog.__mfl[m]["console"], + Vlog.__mfl[m]["syslog"], + Vlog.__mfl[m]["file"]) for m in Vlog.__mfl])) + return ''.join(lines) + + @staticmethod + def reopen_log_file(): + """Closes and then attempts to re-open the current log file. (This is + useful just after log rotation, to ensure that the new log file starts + being used.)""" + + if Vlog.__log_file: + logger = logging.getLogger("file") + logger.removeHandler(Vlog.__file_handler) + Vlog.__file_handler = logging.FileHandler(Vlog.__log_file) + logger.addHandler(Vlog.__file_handler) + + @staticmethod + def _unixctl_vlog_reopen(conn, unused_argv, unused_aux): + if Vlog.__log_file: + Vlog.reopen_log_file() + conn.reply(None) + else: + conn.reply("Logging to file not configured") + + @staticmethod + def _unixctl_vlog_set(conn, argv, unused_aux): + for arg in argv: + msg = Vlog.set_levels_from_string(arg) + if msg: + conn.reply(msg) + return + conn.reply(None) + + @staticmethod + def _unixctl_vlog_list(conn, unused_argv, unused_aux): + conn.reply(Vlog.get_levels()) + +def add_args(parser): + """Adds vlog related options to 'parser', an ArgumentParser object. The + resulting arguments parsed by 'parser' should be passed to handle_args.""" + + group = parser.add_argument_group(title="Logging Options") + group.add_argument("--log-file", nargs="?", const="default", + help="Enables logging to a file. Default log file" + " is used if LOG_FILE is omitted.") + group.add_argument("-v", "--verbose", nargs="*", + help="Sets logging levels, see ovs-vswitchd(8)." + " Defaults to dbg.") + + +def handle_args(args): + """ Handles command line arguments ('args') parsed by an ArgumentParser. + The ArgumentParser should have been primed by add_args(). Also takes care + of initializing the Vlog module.""" + + log_file = args.log_file + if log_file == "default": + log_file = "%s/%s.log" % (ovs.dirs.LOGDIR, ovs.util.PROGRAM_NAME) + + if args.verbose is None: + args.verbose = [] + elif args.verbose == []: + args.verbose = ["any:any:dbg"] + + for verbose in args.verbose: + msg = Vlog.set_levels_from_string(verbose) + if msg: + ovs.util.ovs_fatal(0, "processing \"%s\": %s" % (verbose, msg)) + + Vlog.init(log_file)