diff --git a/openstack-common.conf b/openstack-common.conf index 573ee975..a5431267 100644 --- a/openstack-common.conf +++ b/openstack-common.conf @@ -1,5 +1,5 @@ [DEFAULT] -modules=setup, jsonutils, xmlutils, timeutils, exception, gettextutils, log, local, notifier/api, notifier/log_notifier, notifier/no_op_notifier, notifier/test_notifier, notifier/__init__, importutils, context, uuidutils, version, threadgroup, db, db.sqlalchemy, excutils, cfgfilter,middleware.base, middleware.debug +modules=setup, jsonutils, xmlutils, timeutils, exception, gettextutils, log, local, notifier, importutils, context, uuidutils, version, threadgroup, db, db.sqlalchemy, excutils, cfgfilter, middleware.base, middleware.debug, rpc, service base=savanna # The following code from 'wsgi' is needed: diff --git a/savanna/openstack/common/context.py b/savanna/openstack/common/context.py index febcf43b..8e886011 100644 --- a/savanna/openstack/common/context.py +++ b/savanna/openstack/common/context.py @@ -61,7 +61,7 @@ class RequestContext(object): 'request_id': self.request_id} -def get_admin_context(show_deleted="no"): +def get_admin_context(show_deleted=False): context = RequestContext(None, tenant=None, is_admin=True, diff --git a/savanna/openstack/common/db/exception.py b/savanna/openstack/common/db/exception.py index 65bed21c..68703265 100644 --- a/savanna/openstack/common/db/exception.py +++ b/savanna/openstack/common/db/exception.py @@ -18,7 +18,7 @@ """DB related custom exceptions.""" -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa class DBError(Exception): diff --git a/savanna/openstack/common/db/sqlalchemy/models.py b/savanna/openstack/common/db/sqlalchemy/models.py index f26c1180..320dc62d 100644 --- a/savanna/openstack/common/db/sqlalchemy/models.py +++ b/savanna/openstack/common/db/sqlalchemy/models.py @@ -26,7 +26,7 @@ from sqlalchemy import Column, Integer from sqlalchemy import DateTime from sqlalchemy.orm import object_mapper -from savanna.openstack.common.db.sqlalchemy.session import get_session +from savanna.openstack.common.db.sqlalchemy import session as sa from savanna.openstack.common import timeutils @@ -37,7 +37,7 @@ class ModelBase(object): def save(self, session=None): """Save this object.""" if not session: - session = get_session() + session = sa.get_session() # NOTE(boris-42): This part of code should be look like: # sesssion.add(self) # session.flush() diff --git a/savanna/openstack/common/db/sqlalchemy/session.py b/savanna/openstack/common/db/sqlalchemy/session.py index 0ddeab0b..a858b59e 100644 --- a/savanna/openstack/common/db/sqlalchemy/session.py +++ b/savanna/openstack/common/db/sqlalchemy/session.py @@ -256,12 +256,10 @@ from sqlalchemy.pool import NullPool, StaticPool from sqlalchemy.sql.expression import literal_column from savanna.openstack.common.db import exception -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import log as logging from savanna.openstack.common import timeutils -DEFAULT = 'DEFAULT' - sqlite_db_opts = [ cfg.StrOpt('sqlite_db', default='savanna.sqlite', @@ -278,8 +276,10 @@ database_opts = [ '../', '$sqlite_db')), help='The SQLAlchemy connection string used to connect to the ' 'database', - deprecated_name='sql_connection', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE')], secret=True), cfg.StrOpt('slave_connection', default='', @@ -288,56 +288,71 @@ database_opts = [ secret=True), cfg.IntOpt('idle_timeout', default=3600, - deprecated_name='sql_idle_timeout', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE')], help='timeout before idle sql connections are reaped'), cfg.IntOpt('min_pool_size', default=1, - deprecated_name='sql_min_pool_size', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], help='Minimum number of SQL connections to keep open in a ' 'pool'), cfg.IntOpt('max_pool_size', default=None, - deprecated_name='sql_max_pool_size', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], help='Maximum number of SQL connections to keep open in a ' 'pool'), cfg.IntOpt('max_retries', default=10, - deprecated_name='sql_max_retries', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + group='DATABASE')], help='maximum db connection retries during startup. ' '(setting -1 implies an infinite retry count)'), cfg.IntOpt('retry_interval', default=10, - deprecated_name='sql_retry_interval', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], help='interval between retries of opening a sql connection'), cfg.IntOpt('max_overflow', default=None, - deprecated_name='sql_max_overflow', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], help='If set, use this value for max_overflow with sqlalchemy'), cfg.IntOpt('connection_debug', default=0, - deprecated_name='sql_connection_debug', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], help='Verbosity of SQL debugging information. 0=None, ' '100=Everything'), cfg.BoolOpt('connection_trace', default=False, - deprecated_name='sql_connection_trace', - deprecated_group=DEFAULT, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], help='Add python stack traces to SQL as comment strings'), cfg.IntOpt('pool_timeout', default=None, + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], help='If set, use this value for pool_timeout with sqlalchemy'), ] CONF = cfg.CONF CONF.register_opts(sqlite_db_opts) CONF.register_opts(database_opts, 'database') + LOG = logging.getLogger(__name__) _ENGINE = None diff --git a/savanna/openstack/common/db/sqlalchemy/utils.py b/savanna/openstack/common/db/sqlalchemy/utils.py old mode 100644 new mode 100755 index 6e4a245c..11e8e68c --- a/savanna/openstack/common/db/sqlalchemy/utils.py +++ b/savanna/openstack/common/db/sqlalchemy/utils.py @@ -18,12 +18,28 @@ # License for the specific language governing permissions and limitations # under the License. -"""Implementation of paginate query.""" - import sqlalchemy +from sqlalchemy import Boolean +from sqlalchemy import CheckConstraint +from sqlalchemy import Column +from sqlalchemy.engine import reflection +from sqlalchemy.ext.compiler import compiles +from sqlalchemy import func +from sqlalchemy import Index +from sqlalchemy import Integer +from sqlalchemy import MetaData +from sqlalchemy.sql.expression import literal_column +from sqlalchemy.sql.expression import UpdateBase +from sqlalchemy.sql import select +from sqlalchemy import String +from sqlalchemy import Table +from sqlalchemy.types import NullType -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa + +from savanna.openstack.common import exception from savanna.openstack.common import log as logging +from savanna.openstack.common import timeutils LOG = logging.getLogger(__name__) @@ -85,11 +101,14 @@ def paginate_query(query, model, limit, sort_keys, marker=None, # Add sorting for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): - sort_dir_func = { - 'asc': sqlalchemy.asc, - 'desc': sqlalchemy.desc, - }[current_sort_dir] - + try: + sort_dir_func = { + 'asc': sqlalchemy.asc, + 'desc': sqlalchemy.desc, + }[current_sort_dir] + except KeyError: + raise ValueError(_("Unknown sort direction, " + "must be 'desc' or 'asc'")) try: sort_key_attr = getattr(model, current_sort_key) except AttributeError: @@ -114,11 +133,8 @@ def paginate_query(query, model, limit, sort_keys, marker=None, model_attr = getattr(model, sort_keys[i]) if sort_dirs[i] == 'desc': crit_attrs.append((model_attr < marker_values[i])) - elif sort_dirs[i] == 'asc': - crit_attrs.append((model_attr > marker_values[i])) else: - raise ValueError(_("Unknown sort direction, " - "must be 'desc' or 'asc'")) + crit_attrs.append((model_attr > marker_values[i])) criteria = sqlalchemy.sql.and_(*crit_attrs) criteria_list.append(criteria) @@ -130,3 +146,305 @@ def paginate_query(query, model, limit, sort_keys, marker=None, query = query.limit(limit) return query + + +def get_table(engine, name): + """Returns an sqlalchemy table dynamically from db. + + Needed because the models don't work for us in migrations + as models will be far out of sync with the current data. + """ + metadata = MetaData() + metadata.bind = engine + return Table(name, metadata, autoload=True) + + +class InsertFromSelect(UpdateBase): + """Form the base for `INSERT INTO table (SELECT ... )` statement.""" + def __init__(self, table, select): + self.table = table + self.select = select + + +@compiles(InsertFromSelect) +def visit_insert_from_select(element, compiler, **kw): + """Form the `INSERT INTO table (SELECT ... )` statement.""" + return "INSERT INTO %s %s" % ( + compiler.process(element.table, asfrom=True), + compiler.process(element.select)) + + +def _get_not_supported_column(col_name_col_instance, column_name): + try: + column = col_name_col_instance[column_name] + except KeyError: + msg = _("Please specify column %s in col_name_col_instance " + "param. It is required because column has unsupported " + "type by sqlite).") + raise exception.OpenstackException(message=msg % column_name) + + if not isinstance(column, Column): + msg = _("col_name_col_instance param has wrong type of " + "column instance for column %s It should be instance " + "of sqlalchemy.Column.") + raise exception.OpenstackException(message=msg % column_name) + return column + + +def drop_old_duplicate_entries_from_table(migrate_engine, table_name, + use_soft_delete, *uc_column_names): + """Drop all old rows having the same values for columns in uc_columns. + + This method drop (or mark ad `deleted` if use_soft_delete is True) old + duplicate rows form table with name `table_name`. + + :param migrate_engine: Sqlalchemy engine + :param table_name: Table with duplicates + :param use_soft_delete: If True - values will be marked as `deleted`, + if False - values will be removed from table + :param uc_column_names: Unique constraint columns + """ + meta = MetaData() + meta.bind = migrate_engine + + table = Table(table_name, meta, autoload=True) + columns_for_group_by = [table.c[name] for name in uc_column_names] + + columns_for_select = [func.max(table.c.id)] + columns_for_select.extend(columns_for_group_by) + + duplicated_rows_select = select(columns_for_select, + group_by=columns_for_group_by, + having=func.count(table.c.id) > 1) + + for row in migrate_engine.execute(duplicated_rows_select): + # NOTE(boris-42): Do not remove row that has the biggest ID. + delete_condition = table.c.id != row[0] + is_none = None # workaround for pyflakes + delete_condition &= table.c.deleted_at == is_none + for name in uc_column_names: + delete_condition &= table.c[name] == row[name] + + rows_to_delete_select = select([table.c.id]).where(delete_condition) + for row in migrate_engine.execute(rows_to_delete_select).fetchall(): + LOG.info(_("Deleting duplicated row with id: %(id)s from table: " + "%(table)s") % dict(id=row[0], table=table_name)) + + if use_soft_delete: + delete_statement = table.update().\ + where(delete_condition).\ + values({ + 'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow() + }) + else: + delete_statement = table.delete().where(delete_condition) + migrate_engine.execute(delete_statement) + + +def _get_default_deleted_value(table): + if isinstance(table.c.id.type, Integer): + return 0 + if isinstance(table.c.id.type, String): + return "" + raise exception.OpenstackException( + message=_("Unsupported id columns type")) + + +def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): + table = get_table(migrate_engine, table_name) + + insp = reflection.Inspector.from_engine(migrate_engine) + real_indexes = insp.get_indexes(table_name) + existing_index_names = dict( + [(index['name'], index['column_names']) for index in real_indexes]) + + # NOTE(boris-42): Restore indexes on `deleted` column + for index in indexes: + if 'deleted' not in index['column_names']: + continue + name = index['name'] + if name in existing_index_names: + column_names = [table.c[c] for c in existing_index_names[name]] + old_index = Index(name, *column_names, unique=index["unique"]) + old_index.drop(migrate_engine) + + column_names = [table.c[c] for c in index['column_names']] + new_index = Index(index["name"], *column_names, unique=index["unique"]) + new_index.create(migrate_engine) + + +def change_deleted_column_type_to_boolean(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_boolean_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + old_deleted = Column('old_deleted', Boolean, default=False) + old_deleted.create(table, populate_default=False) + + table.update().\ + where(table.c.deleted == table.c.id).\ + values(old_deleted=True).\ + execute() + + table.c.deleted.drop() + table.c.old_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, + **col_name_col_instance): + insp = reflection.Inspector.from_engine(migrate_engine) + table = get_table(migrate_engine, table_name) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', Boolean, default=0) + columns.append(column_copy) + + constraints = [constraint.copy() for constraint in table.constraints] + + meta = MetaData(bind=migrate_engine) + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + c_select = [] + for c in table.c: + if c.name != "deleted": + c_select.append(c) + else: + c_select.append(table.c.deleted == table.c.id) + + ins = InsertFromSelect(new_table, select(c_select)) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + new_table.update().\ + where(new_table.c.deleted == new_table.c.id).\ + values(deleted=True).\ + execute() + + +def change_deleted_column_type_to_id_type(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_id_type_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + new_deleted = Column('new_deleted', table.c.id.type, + default=_get_default_deleted_value(table)) + new_deleted.create(table, populate_default=True) + + deleted = True # workaround for pyflakes + table.update().\ + where(table.c.deleted == deleted).\ + values(new_deleted=table.c.id).\ + execute() + table.c.deleted.drop() + table.c.new_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name, + **col_name_col_instance): + # NOTE(boris-42): sqlaclhemy-migrate can't drop column with check + # constraints in sqlite DB and our `deleted` column has + # 2 check constraints. So there is only one way to remove + # these constraints: + # 1) Create new table with the same columns, constraints + # and indexes. (except deleted column). + # 2) Copy all data from old to new table. + # 3) Drop old table. + # 4) Rename new table to old table name. + insp = reflection.Inspector.from_engine(migrate_engine) + meta = MetaData(bind=migrate_engine) + table = Table(table_name, meta, autoload=True) + default_deleted_value = _get_default_deleted_value(table) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', table.c.id.type, + default=default_deleted_value) + columns.append(column_copy) + + def is_deleted_column_constraint(constraint): + # NOTE(boris-42): There is no other way to check is CheckConstraint + # associated with deleted column. + if not isinstance(constraint, CheckConstraint): + return False + sqltext = str(constraint.sqltext) + return (sqltext.endswith("deleted in (0, 1)") or + sqltext.endswith("deleted IN (:deleted_1, :deleted_2)")) + + constraints = [] + for constraint in table.constraints: + if not is_deleted_column_constraint(constraint): + constraints.append(constraint.copy()) + + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + ins = InsertFromSelect(new_table, table.select()) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + deleted = True # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=new_table.c.id).\ + execute() + + # NOTE(boris-42): Fix value of deleted column: False -> "" or 0. + deleted = False # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=default_deleted_value).\ + execute() diff --git a/savanna/openstack/common/eventlet_backdoor.py b/savanna/openstack/common/eventlet_backdoor.py new file mode 100644 index 00000000..1e9a5b90 --- /dev/null +++ b/savanna/openstack/common/eventlet_backdoor.py @@ -0,0 +1,146 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2012 OpenStack Foundation. +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import print_function + +import errno +import gc +import os +import pprint +import socket +import sys +import traceback + +import eventlet +import eventlet.backdoor +import greenlet +from oslo.config import cfg + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging + +help_for_backdoor_port = ( + "Acceptable values are 0, , and :, where 0 results " + "in listening on a random tcp port number; results in listening " + "on the specified port number (and not enabling backdoor if that port " + "is in use); and : results in listening on the smallest " + "unused port number within the specified range of port numbers. The " + "chosen port is displayed in the service's log file.") +eventlet_backdoor_opts = [ + cfg.StrOpt('backdoor_port', + default=None, + help="Enable eventlet backdoor. %s" % help_for_backdoor_port) +] + +CONF = cfg.CONF +CONF.register_opts(eventlet_backdoor_opts) +LOG = logging.getLogger(__name__) + + +class EventletBackdoorConfigValueError(Exception): + def __init__(self, port_range, help_msg, ex): + msg = ('Invalid backdoor_port configuration %(range)s: %(ex)s. ' + '%(help)s' % + {'range': port_range, 'ex': ex, 'help': help_msg}) + super(EventletBackdoorConfigValueError, self).__init__(msg) + self.port_range = port_range + + +def _dont_use_this(): + print("Don't use this, just disconnect instead") + + +def _find_objects(t): + return filter(lambda o: isinstance(o, t), gc.get_objects()) + + +def _print_greenthreads(): + for i, gt in enumerate(_find_objects(greenlet.greenlet)): + print(i, gt) + traceback.print_stack(gt.gr_frame) + print() + + +def _print_nativethreads(): + for threadId, stack in sys._current_frames().items(): + print(threadId) + traceback.print_stack(stack) + print() + + +def _parse_port_range(port_range): + if ':' not in port_range: + start, end = port_range, port_range + else: + start, end = port_range.split(':', 1) + try: + start, end = int(start), int(end) + if end < start: + raise ValueError + return start, end + except ValueError as ex: + raise EventletBackdoorConfigValueError(port_range, ex, + help_for_backdoor_port) + + +def _listen(host, start_port, end_port, listen_func): + try_port = start_port + while True: + try: + return listen_func((host, try_port)) + except socket.error as exc: + if (exc.errno != errno.EADDRINUSE or + try_port >= end_port): + raise + try_port += 1 + + +def initialize_if_enabled(): + backdoor_locals = { + 'exit': _dont_use_this, # So we don't exit the entire process + 'quit': _dont_use_this, # So we don't exit the entire process + 'fo': _find_objects, + 'pgt': _print_greenthreads, + 'pnt': _print_nativethreads, + } + + if CONF.backdoor_port is None: + return None + + start_port, end_port = _parse_port_range(str(CONF.backdoor_port)) + + # NOTE(johannes): The standard sys.displayhook will print the value of + # the last expression and set it to __builtin__._, which overwrites + # the __builtin__._ that gettext sets. Let's switch to using pprint + # since it won't interact poorly with gettext, and it's easier to + # read the output too. + def displayhook(val): + if val is not None: + pprint.pprint(val) + sys.displayhook = displayhook + + sock = _listen('localhost', start_port, end_port, eventlet.listen) + + # In the case of backdoor port being zero, a port number is assigned by + # listen(). In any case, pull the port number out here. + port = sock.getsockname()[1] + LOG.info(_('Eventlet backdoor listening on %(port)s for process %(pid)d') % + {'port': port, 'pid': os.getpid()}) + eventlet.spawn_n(eventlet.backdoor.backdoor_server, sock, + locals=backdoor_locals) + return port diff --git a/savanna/openstack/common/exception.py b/savanna/openstack/common/exception.py index 124a7679..d529b4c4 100644 --- a/savanna/openstack/common/exception.py +++ b/savanna/openstack/common/exception.py @@ -21,7 +21,7 @@ Exceptions common to OpenStack projects import logging -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa _FATAL_EXCEPTION_FORMAT_ERRORS = False diff --git a/savanna/openstack/common/excutils.py b/savanna/openstack/common/excutils.py index 84a92521..19896c7d 100644 --- a/savanna/openstack/common/excutils.py +++ b/savanna/openstack/common/excutils.py @@ -19,17 +19,15 @@ Exception related utilities. """ -import contextlib import logging import sys import time import traceback -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa -@contextlib.contextmanager -def save_and_reraise_exception(): +class save_and_reraise_exception(object): """Save current exception, run some code and then re-raise. In some cases the exception context can be cleared, resulting in None @@ -41,15 +39,33 @@ def save_and_reraise_exception(): To work around this, we save the exception state, run handler code, and then re-raise the original exception. If another exception occurs, the saved exception is logged and the new exception is re-raised. - """ - type_, value, tb = sys.exc_info() - try: - yield + + In some cases the caller may not want to re-raise the exception, and + for those circumstances this context provides a reraise flag that + can be used to suppress the exception. For example: + except Exception: - logging.error(_('Original exception being dropped: %s'), - traceback.format_exception(type_, value, tb)) - raise - raise type_, value, tb + with save_and_reraise_exception() as ctxt: + decide_if_need_reraise() + if not should_be_reraised: + ctxt.reraise = False + """ + def __init__(self): + self.reraise = True + + def __enter__(self): + self.type_, self.value, self.tb, = sys.exc_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + logging.error(_('Original exception being dropped: %s'), + traceback.format_exception(self.type_, + self.value, + self.tb)) + return False + if self.reraise: + raise self.type_, self.value, self.tb def forever_retry_uncaught_exceptions(infunc): diff --git a/savanna/openstack/common/fileutils.py b/savanna/openstack/common/fileutils.py index 59c24bf9..c18faddf 100644 --- a/savanna/openstack/common/fileutils.py +++ b/savanna/openstack/common/fileutils.py @@ -21,7 +21,7 @@ import errno import os from savanna.openstack.common import excutils -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import log as logging LOG = logging.getLogger(__name__) diff --git a/savanna/openstack/common/gettextutils.py b/savanna/openstack/common/gettextutils.py index 40c6e400..8407b8bb 100644 --- a/savanna/openstack/common/gettextutils.py +++ b/savanna/openstack/common/gettextutils.py @@ -28,8 +28,11 @@ import copy import gettext import logging.handlers import os +import re import UserString +import six + _localedir = os.environ.get('savanna'.upper() + '_LOCALEDIR') _t = gettext.translation('savanna', localedir=_localedir, fallback=True) @@ -120,7 +123,29 @@ class Message(UserString.UserString, object): if self.params is not None: full_msg = full_msg % self.params - return unicode(full_msg) + return six.text_type(full_msg) + + def _save_dictionary_parameter(self, dict_param): + full_msg = self.data + # look for %(blah) fields in string; + # ignore %% and deal with the + # case where % is first character on the line + keys = re.findall('(?:[^%]|^)%\((\w*)\)[a-z]', full_msg) + + # if we don't find any %(blah) blocks but have a %s + if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg): + # apparently the full dictionary is the parameter + params = copy.deepcopy(dict_param) + else: + params = {} + for key in keys: + try: + params[key] = copy.deepcopy(dict_param[key]) + except TypeError: + # cast uncopyable thing to unicode string + params[key] = unicode(dict_param[key]) + + return params def _save_parameters(self, other): # we check for None later to see if @@ -128,8 +153,16 @@ class Message(UserString.UserString, object): # so encapsulate if our parameter is actually None if other is None: self.params = (other, ) + elif isinstance(other, dict): + self.params = self._save_dictionary_parameter(other) else: - self.params = copy.deepcopy(other) + # fallback to casting to unicode, + # this will handle the problematic python code-like + # objects that cannot be deep-copied + try: + self.params = copy.deepcopy(other) + except TypeError: + self.params = unicode(other) return self diff --git a/savanna/openstack/common/lockutils.py b/savanna/openstack/common/lockutils.py index a2c32e98..e880f4a9 100644 --- a/savanna/openstack/common/lockutils.py +++ b/savanna/openstack/common/lockutils.py @@ -16,11 +16,10 @@ # under the License. +import contextlib import errno import functools import os -import shutil -import tempfile import time import weakref @@ -28,7 +27,7 @@ from eventlet import semaphore from oslo.config import cfg from savanna.openstack.common import fileutils -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import local from savanna.openstack.common import log as logging @@ -40,8 +39,7 @@ util_opts = [ cfg.BoolOpt('disable_process_locking', default=False, help='Whether to disable inter-process locks'), cfg.StrOpt('lock_path', - help=('Directory to use for lock files. Default to a ' - 'temp directory')) + help=('Directory to use for lock files.')) ] @@ -135,7 +133,87 @@ else: _semaphores = weakref.WeakValueDictionary() -def synchronized(name, lock_file_prefix, external=False, lock_path=None): +@contextlib.contextmanager +def lock(name, lock_file_prefix=None, external=False, lock_path=None): + """Context based lock + + This function yields a `semaphore.Semaphore` instance unless external is + True, in which case, it'll yield an InterProcessLock instance. + + :param lock_file_prefix: The lock_file_prefix argument is used to provide + lock files on disk with a meaningful prefix. + + :param external: The external keyword argument denotes whether this lock + should work across multiple processes. This means that if two different + workers both run a a method decorated with @synchronized('mylock', + external=True), only one of them will execute at a time. + + :param lock_path: The lock_path keyword argument is used to specify a + special location for external lock files to live. If nothing is set, then + CONF.lock_path is used as a default. + """ + # NOTE(soren): If we ever go natively threaded, this will be racy. + # See http://stackoverflow.com/questions/5390569/dyn + # amically-allocating-and-destroying-mutexes + sem = _semaphores.get(name, semaphore.Semaphore()) + if name not in _semaphores: + # this check is not racy - we're already holding ref locally + # so GC won't remove the item and there was no IO switch + # (only valid in greenthreads) + _semaphores[name] = sem + + with sem: + LOG.debug(_('Got semaphore "%(lock)s"'), {'lock': name}) + + # NOTE(mikal): I know this looks odd + if not hasattr(local.strong_store, 'locks_held'): + local.strong_store.locks_held = [] + local.strong_store.locks_held.append(name) + + try: + if external and not CONF.disable_process_locking: + LOG.debug(_('Attempting to grab file lock "%(lock)s"'), + {'lock': name}) + + # We need a copy of lock_path because it is non-local + local_lock_path = lock_path or CONF.lock_path + if not local_lock_path: + raise cfg.RequiredOptError('lock_path') + + if not os.path.exists(local_lock_path): + fileutils.ensure_tree(local_lock_path) + LOG.info(_('Created lock path: %s'), local_lock_path) + + def add_prefix(name, prefix): + if not prefix: + return name + sep = '' if prefix.endswith('-') else '-' + return '%s%s%s' % (prefix, sep, name) + + # NOTE(mikal): the lock name cannot contain directory + # separators + lock_file_name = add_prefix(name.replace(os.sep, '_'), + lock_file_prefix) + + lock_file_path = os.path.join(local_lock_path, lock_file_name) + + try: + lock = InterProcessLock(lock_file_path) + with lock as lock: + LOG.debug(_('Got file lock "%(lock)s" at %(path)s'), + {'lock': name, 'path': lock_file_path}) + yield lock + finally: + LOG.debug(_('Released file lock "%(lock)s" at %(path)s'), + {'lock': name, 'path': lock_file_path}) + else: + yield sem + + finally: + local.strong_store.locks_held.remove(name) + + +def synchronized(name, lock_file_prefix=None, external=False, lock_path=None): """Synchronization decorator. Decorating a method like so:: @@ -157,99 +235,18 @@ def synchronized(name, lock_file_prefix, external=False, lock_path=None): ... This way only one of either foo or bar can be executing at a time. - - :param lock_file_prefix: The lock_file_prefix argument is used to provide - lock files on disk with a meaningful prefix. The prefix should end with a - hyphen ('-') if specified. - - :param external: The external keyword argument denotes whether this lock - should work across multiple processes. This means that if two different - workers both run a a method decorated with @synchronized('mylock', - external=True), only one of them will execute at a time. - - :param lock_path: The lock_path keyword argument is used to specify a - special location for external lock files to live. If nothing is set, then - CONF.lock_path is used as a default. """ def wrap(f): @functools.wraps(f) def inner(*args, **kwargs): - # NOTE(soren): If we ever go natively threaded, this will be racy. - # See http://stackoverflow.com/questions/5390569/dyn - # amically-allocating-and-destroying-mutexes - sem = _semaphores.get(name, semaphore.Semaphore()) - if name not in _semaphores: - # this check is not racy - we're already holding ref locally - # so GC won't remove the item and there was no IO switch - # (only valid in greenthreads) - _semaphores[name] = sem + with lock(name, lock_file_prefix, external, lock_path): + LOG.debug(_('Got semaphore / lock "%(function)s"'), + {'function': f.__name__}) + return f(*args, **kwargs) - with sem: - LOG.debug(_('Got semaphore "%(lock)s" for method ' - '"%(method)s"...'), {'lock': name, - 'method': f.__name__}) - - # NOTE(mikal): I know this looks odd - if not hasattr(local.strong_store, 'locks_held'): - local.strong_store.locks_held = [] - local.strong_store.locks_held.append(name) - - try: - if external and not CONF.disable_process_locking: - LOG.debug(_('Attempting to grab file lock "%(lock)s" ' - 'for method "%(method)s"...'), - {'lock': name, 'method': f.__name__}) - cleanup_dir = False - - # We need a copy of lock_path because it is non-local - local_lock_path = lock_path - if not local_lock_path: - local_lock_path = CONF.lock_path - - if not local_lock_path: - cleanup_dir = True - local_lock_path = tempfile.mkdtemp() - - if not os.path.exists(local_lock_path): - fileutils.ensure_tree(local_lock_path) - - # NOTE(mikal): the lock name cannot contain directory - # separators - safe_name = name.replace(os.sep, '_') - lock_file_name = '%s%s' % (lock_file_prefix, safe_name) - lock_file_path = os.path.join(local_lock_path, - lock_file_name) - - try: - lock = InterProcessLock(lock_file_path) - with lock: - LOG.debug(_('Got file lock "%(lock)s" at ' - '%(path)s for method ' - '"%(method)s"...'), - {'lock': name, - 'path': lock_file_path, - 'method': f.__name__}) - retval = f(*args, **kwargs) - finally: - LOG.debug(_('Released file lock "%(lock)s" at ' - '%(path)s for method "%(method)s"...'), - {'lock': name, - 'path': lock_file_path, - 'method': f.__name__}) - # NOTE(vish): This removes the tempdir if we needed - # to create one. This is used to - # cleanup the locks left behind by unit - # tests. - if cleanup_dir: - shutil.rmtree(local_lock_path) - else: - retval = f(*args, **kwargs) - - finally: - local.strong_store.locks_held.remove(name) - - return retval + LOG.debug(_('Semaphore / lock released "%(function)s"'), + {'function': f.__name__}) return inner return wrap @@ -273,7 +270,7 @@ def synchronized_with_prefix(lock_file_prefix): ... The lock_file_prefix argument is used to provide lock files on disk with a - meaningful prefix. The prefix should end with a hyphen ('-') if specified. + meaningful prefix. """ return functools.partial(synchronized, lock_file_prefix=lock_file_prefix) diff --git a/savanna/openstack/common/log.py b/savanna/openstack/common/log.py index 4825c228..0e8473b6 100644 --- a/savanna/openstack/common/log.py +++ b/savanna/openstack/common/log.py @@ -42,7 +42,7 @@ import traceback from oslo.config import cfg -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import importutils from savanna.openstack.common import jsonutils from savanna.openstack.common import local @@ -74,7 +74,8 @@ logging_cli_opts = [ cfg.StrOpt('log-format', default=None, metavar='FORMAT', - help='A logging.Formatter log message format string which may ' + help='DEPRECATED. ' + 'A logging.Formatter log message format string which may ' 'use any of the available logging.LogRecord attributes. ' 'This option is deprecated. Please use ' 'logging_context_format_string and ' diff --git a/savanna/openstack/common/loopingcall.py b/savanna/openstack/common/loopingcall.py index 30d592f7..2c7ccdd2 100644 --- a/savanna/openstack/common/loopingcall.py +++ b/savanna/openstack/common/loopingcall.py @@ -22,7 +22,7 @@ import sys from eventlet import event from eventlet import greenthread -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import log as logging from savanna.openstack.common import timeutils diff --git a/savanna/openstack/common/network_utils.py b/savanna/openstack/common/network_utils.py new file mode 100644 index 00000000..dbed1ceb --- /dev/null +++ b/savanna/openstack/common/network_utils.py @@ -0,0 +1,81 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Network-related utilities and helper functions. +""" + +import urlparse + + +def parse_host_port(address, default_port=None): + """Interpret a string as a host:port pair. + + An IPv6 address MUST be escaped if accompanied by a port, + because otherwise ambiguity ensues: 2001:db8:85a3::8a2e:370:7334 + means both [2001:db8:85a3::8a2e:370:7334] and + [2001:db8:85a3::8a2e:370]:7334. + + >>> parse_host_port('server01:80') + ('server01', 80) + >>> parse_host_port('server01') + ('server01', None) + >>> parse_host_port('server01', default_port=1234) + ('server01', 1234) + >>> parse_host_port('[::1]:80') + ('::1', 80) + >>> parse_host_port('[::1]') + ('::1', None) + >>> parse_host_port('[::1]', default_port=1234) + ('::1', 1234) + >>> parse_host_port('2001:db8:85a3::8a2e:370:7334', default_port=1234) + ('2001:db8:85a3::8a2e:370:7334', 1234) + + """ + if address[0] == '[': + # Escaped ipv6 + _host, _port = address[1:].split(']') + host = _host + if ':' in _port: + port = _port.split(':')[1] + else: + port = default_port + else: + if address.count(':') == 1: + host, port = address.split(':') + else: + # 0 means ipv4, >1 means ipv6. + # We prohibit unescaped ipv6 addresses with port. + host = address + port = default_port + + return (host, None if port is None else int(port)) + + +def urlsplit(url, scheme='', allow_fragments=True): + """Parse a URL using urlparse.urlsplit(), splitting query and fragments. + This function papers over Python issue9374 when needed. + + The parameters are the same as urlparse.urlsplit. + """ + scheme, netloc, path, query, fragment = urlparse.urlsplit( + url, scheme, allow_fragments) + if allow_fragments and '#' in path: + path, fragment = path.split('#', 1) + if '?' in path: + path, query = path.split('?', 1) + return urlparse.SplitResult(scheme, netloc, path, query, fragment) diff --git a/savanna/openstack/common/notifier/api.py b/savanna/openstack/common/notifier/api.py index eaaa982d..bc2b307e 100644 --- a/savanna/openstack/common/notifier/api.py +++ b/savanna/openstack/common/notifier/api.py @@ -13,12 +13,13 @@ # License for the specific language governing permissions and limitations # under the License. +import socket import uuid from oslo.config import cfg from savanna.openstack.common import context -from savanna.openstack.common.gettextutils import _ +from savanna.openstack.common.gettextutils import _ # noqa from savanna.openstack.common import importutils from savanna.openstack.common import jsonutils from savanna.openstack.common import log as logging @@ -35,7 +36,7 @@ notifier_opts = [ default='INFO', help='Default notification level for outgoing notifications'), cfg.StrOpt('default_publisher_id', - default='$host', + default=None, help='Default publisher_id for outgoing notifications'), ] @@ -74,7 +75,7 @@ def notify_decorator(name, fn): ctxt = context.get_context_from_function_and_args(fn, args, kwarg) notify(ctxt, - CONF.default_publisher_id, + CONF.default_publisher_id or socket.gethostname(), name, CONF.default_notification_level, body) @@ -84,7 +85,10 @@ def notify_decorator(name, fn): def publisher_id(service, host=None): if not host: - host = CONF.host + try: + host = CONF.host + except AttributeError: + host = CONF.default_publisher_id or socket.gethostname() return "%s.%s" % (service, host) @@ -153,29 +157,16 @@ def _get_drivers(): if _drivers is None: _drivers = {} for notification_driver in CONF.notification_driver: - add_driver(notification_driver) - + try: + driver = importutils.import_module(notification_driver) + _drivers[notification_driver] = driver + except ImportError: + LOG.exception(_("Failed to load notifier %s. " + "These notifications will not be sent.") % + notification_driver) return _drivers.values() -def add_driver(notification_driver): - """Add a notification driver at runtime.""" - # Make sure the driver list is initialized. - _get_drivers() - if isinstance(notification_driver, basestring): - # Load and add - try: - driver = importutils.import_module(notification_driver) - _drivers[notification_driver] = driver - except ImportError: - LOG.exception(_("Failed to load notifier %s. " - "These notifications will not be sent.") % - notification_driver) - else: - # Driver is already loaded; just add the object. - _drivers[notification_driver] = notification_driver - - def _reset_drivers(): """Used by unit tests to reset the drivers.""" global _drivers diff --git a/savanna/openstack/common/notifier/rpc_notifier.py b/savanna/openstack/common/notifier/rpc_notifier.py new file mode 100644 index 00000000..84dcec05 --- /dev/null +++ b/savanna/openstack/common/notifier/rpc_notifier.py @@ -0,0 +1,46 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo.config import cfg + +from savanna.openstack.common import context as req_context +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging +from savanna.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'notification_topics', default=['notifications', ], + help='AMQP topic used for openstack notifications') + +CONF = cfg.CONF +CONF.register_opt(notification_topic_opt) + + +def notify(context, message): + """Sends a notification via RPC.""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.notification_topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message) + except Exception: + LOG.exception(_("Could not send notification to %(topic)s. " + "Payload=%(message)s"), locals()) diff --git a/savanna/openstack/common/notifier/rpc_notifier2.py b/savanna/openstack/common/notifier/rpc_notifier2.py new file mode 100644 index 00000000..9c9256c5 --- /dev/null +++ b/savanna/openstack/common/notifier/rpc_notifier2.py @@ -0,0 +1,52 @@ +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +'''messaging based notification driver, with message envelopes''' + +from oslo.config import cfg + +from savanna.openstack.common import context as req_context +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging +from savanna.openstack.common import rpc + +LOG = logging.getLogger(__name__) + +notification_topic_opt = cfg.ListOpt( + 'topics', default=['notifications', ], + help='AMQP topic(s) used for openstack notifications') + +opt_group = cfg.OptGroup(name='rpc_notifier2', + title='Options for rpc_notifier2') + +CONF = cfg.CONF +CONF.register_group(opt_group) +CONF.register_opt(notification_topic_opt, opt_group) + + +def notify(context, message): + """Sends a notification via RPC.""" + if not context: + context = req_context.get_admin_context() + priority = message.get('priority', + CONF.default_notification_level) + priority = priority.lower() + for topic in CONF.rpc_notifier2.topics: + topic = '%s.%s' % (topic, priority) + try: + rpc.notify(context, topic, message, envelope=True) + except Exception: + LOG.exception(_("Could not send notification to %(topic)s. " + "Payload=%(message)s"), locals()) diff --git a/savanna/openstack/common/rpc/__init__.py b/savanna/openstack/common/rpc/__init__.py new file mode 100644 index 00000000..5aa6a722 --- /dev/null +++ b/savanna/openstack/common/rpc/__init__.py @@ -0,0 +1,307 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A remote procedure call (rpc) abstraction. + +For some wrappers that add message versioning to rpc, see: + rpc.dispatcher + rpc.proxy +""" + +import inspect + +from oslo.config import cfg + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import importutils +from savanna.openstack.common import local +from savanna.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +rpc_opts = [ + cfg.StrOpt('rpc_backend', + default='%s.impl_kombu' % __package__, + help="The messaging module to use, defaults to kombu."), + cfg.IntOpt('rpc_thread_pool_size', + default=64, + help='Size of RPC thread pool'), + cfg.IntOpt('rpc_conn_pool_size', + default=30, + help='Size of RPC connection pool'), + cfg.IntOpt('rpc_response_timeout', + default=60, + help='Seconds to wait for a response from call or multicall'), + cfg.IntOpt('rpc_cast_timeout', + default=30, + help='Seconds to wait before a cast expires (TTL). ' + 'Only supported by impl_zmq.'), + cfg.ListOpt('allowed_rpc_exception_modules', + default=['savanna.openstack.common.exception', + 'nova.exception', + 'cinder.exception', + 'exceptions', + ], + help='Modules of exceptions that are permitted to be recreated' + 'upon receiving exception data from an rpc call.'), + cfg.BoolOpt('fake_rabbit', + default=False, + help='If passed, use a fake RabbitMQ provider'), + cfg.StrOpt('control_exchange', + default='openstack', + help='AMQP exchange to connect to if using RabbitMQ or Qpid'), +] + +CONF = cfg.CONF +CONF.register_opts(rpc_opts) + + +def set_defaults(control_exchange): + cfg.set_defaults(rpc_opts, + control_exchange=control_exchange) + + +def create_connection(new=True): + """Create a connection to the message bus used for rpc. + + For some example usage of creating a connection and some consumers on that + connection, see nova.service. + + :param new: Whether or not to create a new connection. A new connection + will be created by default. If new is False, the + implementation is free to return an existing connection from a + pool. + + :returns: An instance of openstack.common.rpc.common.Connection + """ + return _get_impl().create_connection(CONF, new=new) + + +def _check_for_lock(): + if not CONF.debug: + return None + + if ((hasattr(local.strong_store, 'locks_held') + and local.strong_store.locks_held)): + stack = ' :: '.join([frame[3] for frame in inspect.stack()]) + LOG.warn(_('A RPC is being made while holding a lock. The locks ' + 'currently held are %(locks)s. This is probably a bug. ' + 'Please report it. Include the following: [%(stack)s].'), + {'locks': local.strong_store.locks_held, + 'stack': stack}) + return True + + return False + + +def call(context, topic, msg, timeout=None, check_for_lock=False): + """Invoke a remote method that returns something. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + :param check_for_lock: if True, a warning is emitted if a RPC call is made + with a lock held. + + :returns: A dict from the remote method. + + :raises: openstack.common.rpc.common.Timeout if a complete response + is not received before the timeout is reached. + """ + if check_for_lock: + _check_for_lock() + return _get_impl().call(CONF, context, topic, msg, timeout) + + +def cast(context, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast(CONF, context, topic, msg) + + +def fanout_cast(context, topic, msg): + """Broadcast a remote method invocation with no return. + + This method will get invoked on all consumers that were set up with this + topic name and fanout=True. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=True. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast(CONF, context, topic, msg) + + +def multicall(context, topic, msg, timeout=None, check_for_lock=False): + """Invoke a remote method and get back an iterator. + + In this case, the remote method will be returning multiple values in + separate messages, so the return values can be processed as the come in via + an iterator. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the rpc message to. This correlates to the + topic argument of + openstack.common.rpc.common.Connection.create_consumer() + and only applies when the consumer was created with + fanout=False. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + :param timeout: int, number of seconds to use for a response timeout. + If set, this overrides the rpc_response_timeout option. + :param check_for_lock: if True, a warning is emitted if a RPC call is made + with a lock held. + + :returns: An iterator. The iterator will yield a tuple (N, X) where N is + an index that starts at 0 and increases by one for each value + returned and X is the Nth value that was returned by the remote + method. + + :raises: openstack.common.rpc.common.Timeout if a complete response + is not received before the timeout is reached. + """ + if check_for_lock: + _check_for_lock() + return _get_impl().multicall(CONF, context, topic, msg, timeout) + + +def notify(context, topic, msg, envelope=False): + """Send notification event. + + :param context: Information that identifies the user that has made this + request. + :param topic: The topic to send the notification to. + :param msg: This is a dict of content of event. + :param envelope: Set to True to enable message envelope for notifications. + + :returns: None + """ + return _get_impl().notify(cfg.CONF, context, topic, msg, envelope) + + +def cleanup(): + """Clean up resoruces in use by implementation. + + Clean up any resources that have been allocated by the RPC implementation. + This is typically open connections to a messaging service. This function + would get called before an application using this API exits to allow + connections to get torn down cleanly. + + :returns: None + """ + return _get_impl().cleanup() + + +def cast_to_server(context, server_params, topic, msg): + """Invoke a remote method that does not return anything. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().cast_to_server(CONF, context, server_params, topic, + msg) + + +def fanout_cast_to_server(context, server_params, topic, msg): + """Broadcast to a remote method invocation with no return. + + :param context: Information that identifies the user that has made this + request. + :param server_params: Connection information + :param topic: The topic to send the notification to. + :param msg: This is a dict in the form { "method" : "method_to_invoke", + "args" : dict_of_kwargs } + + :returns: None + """ + return _get_impl().fanout_cast_to_server(CONF, context, server_params, + topic, msg) + + +def queue_get_for(context, topic, host): + """Get a queue name for a given topic + host. + + This function only works if this naming convention is followed on the + consumer side, as well. For example, in nova, every instance of the + nova-foo service calls create_consumer() for two topics: + + foo + foo. + + Messages sent to the 'foo' topic are distributed to exactly one instance of + the nova-foo service. The services are chosen in a round-robin fashion. + Messages sent to the 'foo.' topic are sent to the nova-foo service on + . + """ + return '%s.%s' % (topic, host) if host else topic + + +_RPCIMPL = None + + +def _get_impl(): + """Delay import of rpc_backend until configuration is loaded.""" + global _RPCIMPL + if _RPCIMPL is None: + try: + _RPCIMPL = importutils.import_module(CONF.rpc_backend) + except ImportError: + # For backwards compatibility with older nova config. + impl = CONF.rpc_backend.replace('nova.rpc', + 'nova.openstack.common.rpc') + _RPCIMPL = importutils.import_module(impl) + return _RPCIMPL diff --git a/savanna/openstack/common/rpc/amqp.py b/savanna/openstack/common/rpc/amqp.py new file mode 100644 index 00000000..7ab56527 --- /dev/null +++ b/savanna/openstack/common/rpc/amqp.py @@ -0,0 +1,610 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Shared code between AMQP based openstack.common.rpc implementations. + +The code in this module is shared between the rpc implemenations based on AMQP. +Specifically, this includes impl_kombu and impl_qpid. impl_carrot also uses +AMQP, but is deprecated and predates this code. +""" + +import collections +import inspect +import sys +import uuid + +from eventlet import greenpool +from eventlet import pools +from eventlet import queue +from eventlet import semaphore +from oslo.config import cfg + +from savanna.openstack.common import excutils +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import local +from savanna.openstack.common import log as logging +from savanna.openstack.common.rpc import common as rpc_common + + +amqp_opts = [ + cfg.BoolOpt('amqp_durable_queues', + default=False, + deprecated_name='rabbit_durable_queues', + deprecated_group='DEFAULT', + help='Use durable queues in amqp.'), + cfg.BoolOpt('amqp_auto_delete', + default=False, + help='Auto-delete queues in amqp.'), +] + +cfg.CONF.register_opts(amqp_opts) + +UNIQUE_ID = '_unique_id' +LOG = logging.getLogger(__name__) + + +class Pool(pools.Pool): + """Class that implements a Pool of Connections.""" + def __init__(self, conf, connection_cls, *args, **kwargs): + self.connection_cls = connection_cls + self.conf = conf + kwargs.setdefault("max_size", self.conf.rpc_conn_pool_size) + kwargs.setdefault("order_as_stack", True) + super(Pool, self).__init__(*args, **kwargs) + self.reply_proxy = None + + # TODO(comstud): Timeout connections not used in a while + def create(self): + LOG.debug(_('Pool creating new connection')) + return self.connection_cls(self.conf) + + def empty(self): + while self.free_items: + self.get().close() + # Force a new connection pool to be created. + # Note that this was added due to failing unit test cases. The issue + # is the above "while loop" gets all the cached connections from the + # pool and closes them, but never returns them to the pool, a pool + # leak. The unit tests hang waiting for an item to be returned to the + # pool. The unit tests get here via the tearDown() method. In the run + # time code, it gets here via cleanup() and only appears in service.py + # just before doing a sys.exit(), so cleanup() only happens once and + # the leakage is not a problem. + self.connection_cls.pool = None + + +_pool_create_sem = semaphore.Semaphore() + + +def get_connection_pool(conf, connection_cls): + with _pool_create_sem: + # Make sure only one thread tries to create the connection pool. + if not connection_cls.pool: + connection_cls.pool = Pool(conf, connection_cls) + return connection_cls.pool + + +class ConnectionContext(rpc_common.Connection): + """The class that is actually returned to the create_connection() caller. + + This is essentially a wrapper around Connection that supports 'with'. + It can also return a new Connection, or one from a pool. + + The function will also catch when an instance of this class is to be + deleted. With that we can return Connections to the pool on exceptions + and so forth without making the caller be responsible for catching them. + If possible the function makes sure to return a connection to the pool. + """ + + def __init__(self, conf, connection_pool, pooled=True, server_params=None): + """Create a new connection, or get one from the pool.""" + self.connection = None + self.conf = conf + self.connection_pool = connection_pool + if pooled: + self.connection = connection_pool.get() + else: + self.connection = connection_pool.connection_cls( + conf, + server_params=server_params) + self.pooled = pooled + + def __enter__(self): + """When with ConnectionContext() is used, return self.""" + return self + + def _done(self): + """If the connection came from a pool, clean it up and put it back. + If it did not come from a pool, close it. + """ + if self.connection: + if self.pooled: + # Reset the connection so it's ready for the next caller + # to grab from the pool + self.connection.reset() + self.connection_pool.put(self.connection) + else: + try: + self.connection.close() + except Exception: + pass + self.connection = None + + def __exit__(self, exc_type, exc_value, tb): + """End of 'with' statement. We're done here.""" + self._done() + + def __del__(self): + """Caller is done with this connection. Make sure we cleaned up.""" + self._done() + + def close(self): + """Caller is done with this connection.""" + self._done() + + def create_consumer(self, topic, proxy, fanout=False): + self.connection.create_consumer(topic, proxy, fanout) + + def create_worker(self, topic, proxy, pool_name): + self.connection.create_worker(topic, proxy, pool_name) + + def join_consumer_pool(self, callback, pool_name, topic, exchange_name, + ack_on_error=True): + self.connection.join_consumer_pool(callback, + pool_name, + topic, + exchange_name, + ack_on_error) + + def consume_in_thread(self): + self.connection.consume_in_thread() + + def __getattr__(self, key): + """Proxy all other calls to the Connection instance.""" + if self.connection: + return getattr(self.connection, key) + else: + raise rpc_common.InvalidRPCConnectionReuse() + + +class ReplyProxy(ConnectionContext): + """Connection class for RPC replies / callbacks.""" + def __init__(self, conf, connection_pool): + self._call_waiters = {} + self._num_call_waiters = 0 + self._num_call_waiters_wrn_threshhold = 10 + self._reply_q = 'reply_' + uuid.uuid4().hex + super(ReplyProxy, self).__init__(conf, connection_pool, pooled=False) + self.declare_direct_consumer(self._reply_q, self._process_data) + self.consume_in_thread() + + def _process_data(self, message_data): + msg_id = message_data.pop('_msg_id', None) + waiter = self._call_waiters.get(msg_id) + if not waiter: + LOG.warn(_('No calling threads waiting for msg_id : %(msg_id)s' + ', message : %(data)s'), {'msg_id': msg_id, + 'data': message_data}) + LOG.warn(_('_call_waiters: %s') % str(self._call_waiters)) + else: + waiter.put(message_data) + + def add_call_waiter(self, waiter, msg_id): + self._num_call_waiters += 1 + if self._num_call_waiters > self._num_call_waiters_wrn_threshhold: + LOG.warn(_('Number of call waiters is greater than warning ' + 'threshhold: %d. There could be a MulticallProxyWaiter ' + 'leak.') % self._num_call_waiters_wrn_threshhold) + self._num_call_waiters_wrn_threshhold *= 2 + self._call_waiters[msg_id] = waiter + + def del_call_waiter(self, msg_id): + self._num_call_waiters -= 1 + del self._call_waiters[msg_id] + + def get_reply_q(self): + return self._reply_q + + +def msg_reply(conf, msg_id, reply_q, connection_pool, reply=None, + failure=None, ending=False, log_failure=True): + """Sends a reply or an error on the channel signified by msg_id. + + Failure should be a sys.exc_info() tuple. + + """ + with ConnectionContext(conf, connection_pool) as conn: + if failure: + failure = rpc_common.serialize_remote_exception(failure, + log_failure) + + msg = {'result': reply, 'failure': failure} + if ending: + msg['ending'] = True + _add_unique_id(msg) + # If a reply_q exists, add the msg_id to the reply and pass the + # reply_q to direct_send() to use it as the response queue. + # Otherwise use the msg_id for backward compatibilty. + if reply_q: + msg['_msg_id'] = msg_id + conn.direct_send(reply_q, rpc_common.serialize_msg(msg)) + else: + conn.direct_send(msg_id, rpc_common.serialize_msg(msg)) + + +class RpcContext(rpc_common.CommonRpcContext): + """Context that supports replying to a rpc.call.""" + def __init__(self, **kwargs): + self.msg_id = kwargs.pop('msg_id', None) + self.reply_q = kwargs.pop('reply_q', None) + self.conf = kwargs.pop('conf') + super(RpcContext, self).__init__(**kwargs) + + def deepcopy(self): + values = self.to_dict() + values['conf'] = self.conf + values['msg_id'] = self.msg_id + values['reply_q'] = self.reply_q + return self.__class__(**values) + + def reply(self, reply=None, failure=None, ending=False, + connection_pool=None, log_failure=True): + if self.msg_id: + msg_reply(self.conf, self.msg_id, self.reply_q, connection_pool, + reply, failure, ending, log_failure) + if ending: + self.msg_id = None + + +def unpack_context(conf, msg): + """Unpack context from msg.""" + context_dict = {} + for key in list(msg.keys()): + # NOTE(vish): Some versions of python don't like unicode keys + # in kwargs. + key = str(key) + if key.startswith('_context_'): + value = msg.pop(key) + context_dict[key[9:]] = value + context_dict['msg_id'] = msg.pop('_msg_id', None) + context_dict['reply_q'] = msg.pop('_reply_q', None) + context_dict['conf'] = conf + ctx = RpcContext.from_dict(context_dict) + rpc_common._safe_log(LOG.debug, _('unpacked context: %s'), ctx.to_dict()) + return ctx + + +def pack_context(msg, context): + """Pack context into msg. + + Values for message keys need to be less than 255 chars, so we pull + context out into a bunch of separate keys. If we want to support + more arguments in rabbit messages, we may want to do the same + for args at some point. + + """ + context_d = dict([('_context_%s' % key, value) + for (key, value) in context.to_dict().iteritems()]) + msg.update(context_d) + + +class _MsgIdCache(object): + """This class checks any duplicate messages.""" + + # NOTE: This value is considered can be a configuration item, but + # it is not necessary to change its value in most cases, + # so let this value as static for now. + DUP_MSG_CHECK_SIZE = 16 + + def __init__(self, **kwargs): + self.prev_msgids = collections.deque([], + maxlen=self.DUP_MSG_CHECK_SIZE) + + def check_duplicate_message(self, message_data): + """AMQP consumers may read same message twice when exceptions occur + before ack is returned. This method prevents doing it. + """ + if UNIQUE_ID in message_data: + msg_id = message_data[UNIQUE_ID] + if msg_id not in self.prev_msgids: + self.prev_msgids.append(msg_id) + else: + raise rpc_common.DuplicateMessageError(msg_id=msg_id) + + +def _add_unique_id(msg): + """Add unique_id for checking duplicate messages.""" + unique_id = uuid.uuid4().hex + msg.update({UNIQUE_ID: unique_id}) + LOG.debug(_('UNIQUE_ID is %s.') % (unique_id)) + + +class _ThreadPoolWithWait(object): + """Base class for a delayed invocation manager. + + Used by the Connection class to start up green threads + to handle incoming messages. + """ + + def __init__(self, conf, connection_pool): + self.pool = greenpool.GreenPool(conf.rpc_thread_pool_size) + self.connection_pool = connection_pool + self.conf = conf + + def wait(self): + """Wait for all callback threads to exit.""" + self.pool.waitall() + + +class CallbackWrapper(_ThreadPoolWithWait): + """Wraps a straight callback. + + Allows it to be invoked in a green thread. + """ + + def __init__(self, conf, callback, connection_pool): + """Initiates CallbackWrapper object. + + :param conf: cfg.CONF instance + :param callback: a callable (probably a function) + :param connection_pool: connection pool as returned by + get_connection_pool() + """ + super(CallbackWrapper, self).__init__( + conf=conf, + connection_pool=connection_pool, + ) + self.callback = callback + + def __call__(self, message_data): + self.pool.spawn_n(self.callback, message_data) + + +class ProxyCallback(_ThreadPoolWithWait): + """Calls methods on a proxy object based on method and args.""" + + def __init__(self, conf, proxy, connection_pool): + super(ProxyCallback, self).__init__( + conf=conf, + connection_pool=connection_pool, + ) + self.proxy = proxy + self.msg_id_cache = _MsgIdCache() + + def __call__(self, message_data): + """Consumer callback to call a method on a proxy object. + + Parses the message for validity and fires off a thread to call the + proxy object method. + + Message data should be a dictionary with two keys: + method: string representing the method to call + args: dictionary of arg: value + + Example: {'method': 'echo', 'args': {'value': 42}} + + """ + # It is important to clear the context here, because at this point + # the previous context is stored in local.store.context + if hasattr(local.store, 'context'): + del local.store.context + rpc_common._safe_log(LOG.debug, _('received %s'), message_data) + self.msg_id_cache.check_duplicate_message(message_data) + ctxt = unpack_context(self.conf, message_data) + method = message_data.get('method') + args = message_data.get('args', {}) + version = message_data.get('version') + namespace = message_data.get('namespace') + if not method: + LOG.warn(_('no method for message: %s') % message_data) + ctxt.reply(_('No method for message: %s') % message_data, + connection_pool=self.connection_pool) + return + self.pool.spawn_n(self._process_data, ctxt, version, method, + namespace, args) + + def _process_data(self, ctxt, version, method, namespace, args): + """Process a message in a new thread. + + If the proxy object we have has a dispatch method + (see rpc.dispatcher.RpcDispatcher), pass it the version, + method, and args and let it dispatch as appropriate. If not, use + the old behavior of magically calling the specified method on the + proxy we have here. + """ + ctxt.update_store() + try: + rval = self.proxy.dispatch(ctxt, version, method, namespace, + **args) + # Check if the result was a generator + if inspect.isgenerator(rval): + for x in rval: + ctxt.reply(x, None, connection_pool=self.connection_pool) + else: + ctxt.reply(rval, None, connection_pool=self.connection_pool) + # This final None tells multicall that it is done. + ctxt.reply(ending=True, connection_pool=self.connection_pool) + except rpc_common.ClientException as e: + LOG.debug(_('Expected exception during message handling (%s)') % + e._exc_info[1]) + ctxt.reply(None, e._exc_info, + connection_pool=self.connection_pool, + log_failure=False) + except Exception: + # sys.exc_info() is deleted by LOG.exception(). + exc_info = sys.exc_info() + LOG.error(_('Exception during message handling'), + exc_info=exc_info) + ctxt.reply(None, exc_info, connection_pool=self.connection_pool) + + +class MulticallProxyWaiter(object): + def __init__(self, conf, msg_id, timeout, connection_pool): + self._msg_id = msg_id + self._timeout = timeout or conf.rpc_response_timeout + self._reply_proxy = connection_pool.reply_proxy + self._done = False + self._got_ending = False + self._conf = conf + self._dataqueue = queue.LightQueue() + # Add this caller to the reply proxy's call_waiters + self._reply_proxy.add_call_waiter(self, self._msg_id) + self.msg_id_cache = _MsgIdCache() + + def put(self, data): + self._dataqueue.put(data) + + def done(self): + if self._done: + return + self._done = True + # Remove this caller from reply proxy's call_waiters + self._reply_proxy.del_call_waiter(self._msg_id) + + def _process_data(self, data): + result = None + self.msg_id_cache.check_duplicate_message(data) + if data['failure']: + failure = data['failure'] + result = rpc_common.deserialize_remote_exception(self._conf, + failure) + elif data.get('ending', False): + self._got_ending = True + else: + result = data['result'] + return result + + def __iter__(self): + """Return a result until we get a reply with an 'ending' flag.""" + if self._done: + raise StopIteration + while True: + try: + data = self._dataqueue.get(timeout=self._timeout) + result = self._process_data(data) + except queue.Empty: + self.done() + raise rpc_common.Timeout() + except Exception: + with excutils.save_and_reraise_exception(): + self.done() + if self._got_ending: + self.done() + raise StopIteration + if isinstance(result, Exception): + self.done() + raise result + yield result + + +def create_connection(conf, new, connection_pool): + """Create a connection.""" + return ConnectionContext(conf, connection_pool, pooled=not new) + + +_reply_proxy_create_sem = semaphore.Semaphore() + + +def multicall(conf, context, topic, msg, timeout, connection_pool): + """Make a call that returns multiple times.""" + LOG.debug(_('Making synchronous call on %s ...'), topic) + msg_id = uuid.uuid4().hex + msg.update({'_msg_id': msg_id}) + LOG.debug(_('MSG_ID is %s') % (msg_id)) + _add_unique_id(msg) + pack_context(msg, context) + + with _reply_proxy_create_sem: + if not connection_pool.reply_proxy: + connection_pool.reply_proxy = ReplyProxy(conf, connection_pool) + msg.update({'_reply_q': connection_pool.reply_proxy.get_reply_q()}) + wait_msg = MulticallProxyWaiter(conf, msg_id, timeout, connection_pool) + with ConnectionContext(conf, connection_pool) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg), timeout) + return wait_msg + + +def call(conf, context, topic, msg, timeout, connection_pool): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout, connection_pool) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg, connection_pool): + """Sends a message on a topic without waiting for a response.""" + LOG.debug(_('Making asynchronous cast on %s...'), topic) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg)) + + +def fanout_cast(conf, context, topic, msg, connection_pool): + """Sends a message on a fanout exchange without waiting for a response.""" + LOG.debug(_('Making asynchronous fanout cast...')) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + conn.fanout_send(topic, rpc_common.serialize_msg(msg)) + + +def cast_to_server(conf, context, server_params, topic, msg, connection_pool): + """Sends a message on a topic to a specific server.""" + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.topic_send(topic, rpc_common.serialize_msg(msg)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg, + connection_pool): + """Sends a message on a fanout exchange to a specific server.""" + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool, pooled=False, + server_params=server_params) as conn: + conn.fanout_send(topic, rpc_common.serialize_msg(msg)) + + +def notify(conf, context, topic, msg, connection_pool, envelope): + """Sends a notification event on a topic.""" + LOG.debug(_('Sending %(event_type)s on %(topic)s'), + dict(event_type=msg.get('event_type'), + topic=topic)) + _add_unique_id(msg) + pack_context(msg, context) + with ConnectionContext(conf, connection_pool) as conn: + if envelope: + msg = rpc_common.serialize_msg(msg) + conn.notify_send(topic, msg) + + +def cleanup(connection_pool): + if connection_pool: + connection_pool.empty() + + +def get_control_exchange(conf): + return conf.control_exchange diff --git a/savanna/openstack/common/rpc/common.py b/savanna/openstack/common/rpc/common.py new file mode 100644 index 00000000..d799a4ea --- /dev/null +++ b/savanna/openstack/common/rpc/common.py @@ -0,0 +1,509 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import copy +import sys +import traceback + +from oslo.config import cfg +import six + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import importutils +from savanna.openstack.common import jsonutils +from savanna.openstack.common import local +from savanna.openstack.common import log as logging + + +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +'''RPC Envelope Version. + +This version number applies to the top level structure of messages sent out. +It does *not* apply to the message payload, which must be versioned +independently. For example, when using rpc APIs, a version number is applied +for changes to the API being exposed over rpc. This version number is handled +in the rpc proxy and dispatcher modules. + +This version number applies to the message envelope that is used in the +serialization done inside the rpc layer. See serialize_msg() and +deserialize_msg(). + +The current message format (version 2.0) is very simple. It is: + + { + 'oslo.version': , + 'oslo.message': + } + +Message format version '1.0' is just considered to be the messages we sent +without a message envelope. + +So, the current message envelope just includes the envelope version. It may +eventually contain additional information, such as a signature for the message +payload. + +We will JSON encode the application message payload. The message envelope, +which includes the JSON encoded application message body, will be passed down +to the messaging libraries as a dict. +''' +_RPC_ENVELOPE_VERSION = '2.0' + +_VERSION_KEY = 'oslo.version' +_MESSAGE_KEY = 'oslo.message' + +_REMOTE_POSTFIX = '_Remote' + + +class RPCException(Exception): + msg_fmt = _("An unknown RPC related exception occurred.") + + def __init__(self, message=None, **kwargs): + self.kwargs = kwargs + + if not message: + try: + message = self.msg_fmt % kwargs + + except Exception: + # kwargs doesn't match a variable in the message + # log the issue and the kwargs + LOG.exception(_('Exception in string format operation')) + for name, value in kwargs.iteritems(): + LOG.error("%s: %s" % (name, value)) + # at least get the core message out if something happened + message = self.msg_fmt + + super(RPCException, self).__init__(message) + + +class RemoteError(RPCException): + """Signifies that a remote class has raised an exception. + + Contains a string representation of the type of the original exception, + the value of the original exception, and the traceback. These are + sent to the parent as a joined string so printing the exception + contains all of the relevant info. + + """ + msg_fmt = _("Remote error: %(exc_type)s %(value)s\n%(traceback)s.") + + def __init__(self, exc_type=None, value=None, traceback=None): + self.exc_type = exc_type + self.value = value + self.traceback = traceback + super(RemoteError, self).__init__(exc_type=exc_type, + value=value, + traceback=traceback) + + +class Timeout(RPCException): + """Signifies that a timeout has occurred. + + This exception is raised if the rpc_response_timeout is reached while + waiting for a response from the remote side. + """ + msg_fmt = _('Timeout while waiting on RPC response - ' + 'topic: "%(topic)s", RPC method: "%(method)s" ' + 'info: "%(info)s"') + + def __init__(self, info=None, topic=None, method=None): + """Initiates Timeout object. + + :param info: Extra info to convey to the user + :param topic: The topic that the rpc call was sent to + :param rpc_method_name: The name of the rpc method being + called + """ + self.info = info + self.topic = topic + self.method = method + super(Timeout, self).__init__( + None, + info=info or _(''), + topic=topic or _(''), + method=method or _('')) + + +class DuplicateMessageError(RPCException): + msg_fmt = _("Found duplicate message(%(msg_id)s). Skipping it.") + + +class InvalidRPCConnectionReuse(RPCException): + msg_fmt = _("Invalid reuse of an RPC connection.") + + +class UnsupportedRpcVersion(RPCException): + msg_fmt = _("Specified RPC version, %(version)s, not supported by " + "this endpoint.") + + +class UnsupportedRpcEnvelopeVersion(RPCException): + msg_fmt = _("Specified RPC envelope version, %(version)s, " + "not supported by this endpoint.") + + +class RpcVersionCapError(RPCException): + msg_fmt = _("Specified RPC version cap, %(version_cap)s, is too low") + + +class Connection(object): + """A connection, returned by rpc.create_connection(). + + This class represents a connection to the message bus used for rpc. + An instance of this class should never be created by users of the rpc API. + Use rpc.create_connection() instead. + """ + def close(self): + """Close the connection. + + This method must be called when the connection will no longer be used. + It will ensure that any resources associated with the connection, such + as a network connection, and cleaned up. + """ + raise NotImplementedError() + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer on this connection. + + A consumer is associated with a message queue on the backend message + bus. The consumer will read messages from the queue, unpack them, and + dispatch them to the proxy object. The contents of the message pulled + off of the queue will determine which method gets called on the proxy + object. + + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. For example, all instances of nova-compute consume + from a queue called "compute". In that case, the + messages will get distributed amongst the consumers in a + round-robin fashion if fanout=False. If fanout=True, + every consumer associated with this topic will get a + copy of every message. + :param proxy: The object that will handle all incoming messages. + :param fanout: Whether or not this is a fanout topic. See the + documentation for the topic parameter for some + additional comments on this. + """ + raise NotImplementedError() + + def create_worker(self, topic, proxy, pool_name): + """Create a worker on this connection. + + A worker is like a regular consumer of messages directed to a + topic, except that it is part of a set of such consumers (the + "pool") which may run in parallel. Every pool of workers will + receive a given message, but only one worker in the pool will + be asked to process it. Load is distributed across the members + of the pool in round-robin fashion. + + :param topic: This is a name associated with what to consume from. + Multiple instances of a service may consume from the same + topic. + :param proxy: The object that will handle all incoming messages. + :param pool_name: String containing the name of the pool of workers + """ + raise NotImplementedError() + + def join_consumer_pool(self, callback, pool_name, topic, exchange_name): + """Register as a member of a group of consumers. + + Uses given topic from the specified exchange. + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + + :param callback: Callable to be invoked for each message. + :type callback: callable accepting one argument + :param pool_name: The name of the consumer pool. + :type pool_name: str + :param topic: The routing topic for desired messages. + :type topic: str + :param exchange_name: The name of the message exchange where + the client should attach. Defaults to + the configured exchange. + :type exchange_name: str + """ + raise NotImplementedError() + + def consume_in_thread(self): + """Spawn a thread to handle incoming messages. + + Spawn a thread that will be responsible for handling all incoming + messages for consumers that were set up on this connection. + + Message dispatching inside of this is expected to be implemented in a + non-blocking manner. An example implementation would be having this + thread pull messages in for all of the consumers, but utilize a thread + pool for dispatching the messages to the proxy objects. + """ + raise NotImplementedError() + + +def _safe_log(log_func, msg, msg_data): + """Sanitizes the msg_data field before logging.""" + SANITIZE = ['_context_auth_token', 'auth_token', 'new_pass'] + + def _fix_passwords(d): + """Sanitizes the password fields in the dictionary.""" + for k in d.iterkeys(): + if k.lower().find('password') != -1: + d[k] = '' + elif k.lower() in SANITIZE: + d[k] = '' + elif isinstance(d[k], dict): + _fix_passwords(d[k]) + return d + + return log_func(msg, _fix_passwords(copy.deepcopy(msg_data))) + + +def serialize_remote_exception(failure_info, log_failure=True): + """Prepares exception data to be sent over rpc. + + Failure_info should be a sys.exc_info() tuple. + + """ + tb = traceback.format_exception(*failure_info) + failure = failure_info[1] + if log_failure: + LOG.error(_("Returning exception %s to caller"), + six.text_type(failure)) + LOG.error(tb) + + kwargs = {} + if hasattr(failure, 'kwargs'): + kwargs = failure.kwargs + + # NOTE(matiu): With cells, it's possible to re-raise remote, remote + # exceptions. Lets turn it back into the original exception type. + cls_name = str(failure.__class__.__name__) + mod_name = str(failure.__class__.__module__) + if (cls_name.endswith(_REMOTE_POSTFIX) and + mod_name.endswith(_REMOTE_POSTFIX)): + cls_name = cls_name[:-len(_REMOTE_POSTFIX)] + mod_name = mod_name[:-len(_REMOTE_POSTFIX)] + + data = { + 'class': cls_name, + 'module': mod_name, + 'message': six.text_type(failure), + 'tb': tb, + 'args': failure.args, + 'kwargs': kwargs + } + + json_data = jsonutils.dumps(data) + + return json_data + + +def deserialize_remote_exception(conf, data): + failure = jsonutils.loads(str(data)) + + trace = failure.get('tb', []) + message = failure.get('message', "") + "\n" + "\n".join(trace) + name = failure.get('class') + module = failure.get('module') + + # NOTE(ameade): We DO NOT want to allow just any module to be imported, in + # order to prevent arbitrary code execution. + if module not in conf.allowed_rpc_exception_modules: + return RemoteError(name, failure.get('message'), trace) + + try: + mod = importutils.import_module(module) + klass = getattr(mod, name) + if not issubclass(klass, Exception): + raise TypeError("Can only deserialize Exceptions") + + failure = klass(*failure.get('args', []), **failure.get('kwargs', {})) + except (AttributeError, TypeError, ImportError): + return RemoteError(name, failure.get('message'), trace) + + ex_type = type(failure) + str_override = lambda self: message + new_ex_type = type(ex_type.__name__ + _REMOTE_POSTFIX, (ex_type,), + {'__str__': str_override, '__unicode__': str_override}) + new_ex_type.__module__ = '%s%s' % (module, _REMOTE_POSTFIX) + try: + # NOTE(ameade): Dynamically create a new exception type and swap it in + # as the new type for the exception. This only works on user defined + # Exceptions and not core python exceptions. This is important because + # we cannot necessarily change an exception message so we must override + # the __str__ method. + failure.__class__ = new_ex_type + except TypeError: + # NOTE(ameade): If a core exception then just add the traceback to the + # first exception argument. + failure.args = (message,) + failure.args[1:] + return failure + + +class CommonRpcContext(object): + def __init__(self, **kwargs): + self.values = kwargs + + def __getattr__(self, key): + try: + return self.values[key] + except KeyError: + raise AttributeError(key) + + def to_dict(self): + return copy.deepcopy(self.values) + + @classmethod + def from_dict(cls, values): + return cls(**values) + + def deepcopy(self): + return self.from_dict(self.to_dict()) + + def update_store(self): + local.store.context = self + + def elevated(self, read_deleted=None, overwrite=False): + """Return a version of this context with admin flag set.""" + # TODO(russellb) This method is a bit of a nova-ism. It makes + # some assumptions about the data in the request context sent + # across rpc, while the rest of this class does not. We could get + # rid of this if we changed the nova code that uses this to + # convert the RpcContext back to its native RequestContext doing + # something like nova.context.RequestContext.from_dict(ctxt.to_dict()) + + context = self.deepcopy() + context.values['is_admin'] = True + + context.values.setdefault('roles', []) + + if 'admin' not in context.values['roles']: + context.values['roles'].append('admin') + + if read_deleted is not None: + context.values['read_deleted'] = read_deleted + + return context + + +class ClientException(Exception): + """Encapsulates actual exception expected to be hit by a RPC proxy object. + + Merely instantiating it records the current exception information, which + will be passed back to the RPC client without exceptional logging. + """ + def __init__(self): + self._exc_info = sys.exc_info() + + +def catch_client_exception(exceptions, func, *args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + if type(e) in exceptions: + raise ClientException() + else: + raise + + +def client_exceptions(*exceptions): + """Decorator for manager methods that raise expected exceptions. + + Marking a Manager method with this decorator allows the declaration + of expected exceptions that the RPC layer should not consider fatal, + and not log as if they were generated in a real error scenario. Note + that this will cause listed exceptions to be wrapped in a + ClientException, which is used internally by the RPC layer. + """ + def outer(func): + def inner(*args, **kwargs): + return catch_client_exception(exceptions, func, *args, **kwargs) + return inner + return outer + + +def version_is_compatible(imp_version, version): + """Determine whether versions are compatible. + + :param imp_version: The version implemented + :param version: The version requested by an incoming message. + """ + version_parts = version.split('.') + imp_version_parts = imp_version.split('.') + if int(version_parts[0]) != int(imp_version_parts[0]): # Major + return False + if int(version_parts[1]) > int(imp_version_parts[1]): # Minor + return False + return True + + +def serialize_msg(raw_msg): + # NOTE(russellb) See the docstring for _RPC_ENVELOPE_VERSION for more + # information about this format. + msg = {_VERSION_KEY: _RPC_ENVELOPE_VERSION, + _MESSAGE_KEY: jsonutils.dumps(raw_msg)} + + return msg + + +def deserialize_msg(msg): + # NOTE(russellb): Hang on to your hats, this road is about to + # get a little bumpy. + # + # Robustness Principle: + # "Be strict in what you send, liberal in what you accept." + # + # At this point we have to do a bit of guessing about what it + # is we just received. Here is the set of possibilities: + # + # 1) We received a dict. This could be 2 things: + # + # a) Inspect it to see if it looks like a standard message envelope. + # If so, great! + # + # b) If it doesn't look like a standard message envelope, it could either + # be a notification, or a message from before we added a message + # envelope (referred to as version 1.0). + # Just return the message as-is. + # + # 2) It's any other non-dict type. Just return it and hope for the best. + # This case covers return values from rpc.call() from before message + # envelopes were used. (messages to call a method were always a dict) + + if not isinstance(msg, dict): + # See #2 above. + return msg + + base_envelope_keys = (_VERSION_KEY, _MESSAGE_KEY) + if not all(map(lambda key: key in msg, base_envelope_keys)): + # See #1.b above. + return msg + + # At this point we think we have the message envelope + # format we were expecting. (#1.a above) + + if not version_is_compatible(_RPC_ENVELOPE_VERSION, msg[_VERSION_KEY]): + raise UnsupportedRpcEnvelopeVersion(version=msg[_VERSION_KEY]) + + raw_msg = jsonutils.loads(msg[_MESSAGE_KEY]) + + return raw_msg diff --git a/savanna/openstack/common/rpc/dispatcher.py b/savanna/openstack/common/rpc/dispatcher.py new file mode 100644 index 00000000..69571cd2 --- /dev/null +++ b/savanna/openstack/common/rpc/dispatcher.py @@ -0,0 +1,178 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Code for rpc message dispatching. + +Messages that come in have a version number associated with them. RPC API +version numbers are in the form: + + Major.Minor + +For a given message with version X.Y, the receiver must be marked as able to +handle messages of version A.B, where: + + A = X + + B >= Y + +The Major version number would be incremented for an almost completely new API. +The Minor version number would be incremented for backwards compatible changes +to an existing API. A backwards compatible change could be something like +adding a new method, adding an argument to an existing method (but not +requiring it), or changing the type for an existing argument (but still +handling the old type as well). + +The conversion over to a versioned API must be done on both the client side and +server side of the API at the same time. However, as the code stands today, +there can be both versioned and unversioned APIs implemented in the same code +base. + +EXAMPLES +======== + +Nova was the first project to use versioned rpc APIs. Consider the compute rpc +API as an example. The client side is in nova/compute/rpcapi.py and the server +side is in nova/compute/manager.py. + + +Example 1) Adding a new method. +------------------------------- + +Adding a new method is a backwards compatible change. It should be added to +nova/compute/manager.py, and RPC_API_VERSION should be bumped from X.Y to +X.Y+1. On the client side, the new method in nova/compute/rpcapi.py should +have a specific version specified to indicate the minimum API version that must +be implemented for the method to be supported. For example:: + + def get_host_uptime(self, ctxt, host): + topic = _compute_topic(self.topic, ctxt, host, None) + return self.call(ctxt, self.make_msg('get_host_uptime'), topic, + version='1.1') + +In this case, version '1.1' is the first version that supported the +get_host_uptime() method. + + +Example 2) Adding a new parameter. +---------------------------------- + +Adding a new parameter to an rpc method can be made backwards compatible. The +RPC_API_VERSION on the server side (nova/compute/manager.py) should be bumped. +The implementation of the method must not expect the parameter to be present.:: + + def some_remote_method(self, arg1, arg2, newarg=None): + # The code needs to deal with newarg=None for cases + # where an older client sends a message without it. + pass + +On the client side, the same changes should be made as in example 1. The +minimum version that supports the new parameter should be specified. +""" + +from savanna.openstack.common.rpc import common as rpc_common +from savanna.openstack.common.rpc import serializer as rpc_serializer + + +class RpcDispatcher(object): + """Dispatch rpc messages according to the requested API version. + + This class can be used as the top level 'manager' for a service. It + contains a list of underlying managers that have an API_VERSION attribute. + """ + + def __init__(self, callbacks, serializer=None): + """Initialize the rpc dispatcher. + + :param callbacks: List of proxy objects that are an instance + of a class with rpc methods exposed. Each proxy + object should have an RPC_API_VERSION attribute. + :param serializer: The Serializer object that will be used to + deserialize arguments before the method call and + to serialize the result after it returns. + """ + self.callbacks = callbacks + if serializer is None: + serializer = rpc_serializer.NoOpSerializer() + self.serializer = serializer + super(RpcDispatcher, self).__init__() + + def _deserialize_args(self, context, kwargs): + """Helper method called to deserialize args before dispatch. + + This calls our serializer on each argument, returning a new set of + args that have been deserialized. + + :param context: The request context + :param kwargs: The arguments to be deserialized + :returns: A new set of deserialized args + """ + new_kwargs = dict() + for argname, arg in kwargs.iteritems(): + new_kwargs[argname] = self.serializer.deserialize_entity(context, + arg) + return new_kwargs + + def dispatch(self, ctxt, version, method, namespace, **kwargs): + """Dispatch a message based on a requested version. + + :param ctxt: The request context + :param version: The requested API version from the incoming message + :param method: The method requested to be called by the incoming + message. + :param namespace: The namespace for the requested method. If None, + the dispatcher will look for a method on a callback + object with no namespace set. + :param kwargs: A dict of keyword arguments to be passed to the method. + + :returns: Whatever is returned by the underlying method that gets + called. + """ + if not version: + version = '1.0' + + had_compatible = False + for proxyobj in self.callbacks: + # Check for namespace compatibility + try: + cb_namespace = proxyobj.RPC_API_NAMESPACE + except AttributeError: + cb_namespace = None + + if namespace != cb_namespace: + continue + + # Check for version compatibility + try: + rpc_api_version = proxyobj.RPC_API_VERSION + except AttributeError: + rpc_api_version = '1.0' + + is_compatible = rpc_common.version_is_compatible(rpc_api_version, + version) + had_compatible = had_compatible or is_compatible + + if not hasattr(proxyobj, method): + continue + if is_compatible: + kwargs = self._deserialize_args(ctxt, kwargs) + result = getattr(proxyobj, method)(ctxt, **kwargs) + return self.serializer.serialize_entity(ctxt, result) + + if had_compatible: + raise AttributeError("No such RPC function '%s'" % method) + else: + raise rpc_common.UnsupportedRpcVersion(version=version) diff --git a/savanna/openstack/common/rpc/impl_fake.py b/savanna/openstack/common/rpc/impl_fake.py new file mode 100644 index 00000000..49591643 --- /dev/null +++ b/savanna/openstack/common/rpc/impl_fake.py @@ -0,0 +1,195 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +"""Fake RPC implementation which calls proxy methods directly with no +queues. Casts will block, but this is very useful for tests. +""" + +import inspect +# NOTE(russellb): We specifically want to use json, not our own jsonutils. +# jsonutils has some extra logic to automatically convert objects to primitive +# types so that they can be serialized. We want to catch all cases where +# non-primitive types make it into this code and treat it as an error. +import json +import time + +import eventlet + +from savanna.openstack.common.rpc import common as rpc_common + +CONSUMERS = {} + + +class RpcContext(rpc_common.CommonRpcContext): + def __init__(self, **kwargs): + super(RpcContext, self).__init__(**kwargs) + self._response = [] + self._done = False + + def deepcopy(self): + values = self.to_dict() + new_inst = self.__class__(**values) + new_inst._response = self._response + new_inst._done = self._done + return new_inst + + def reply(self, reply=None, failure=None, ending=False): + if ending: + self._done = True + if not self._done: + self._response.append((reply, failure)) + + +class Consumer(object): + def __init__(self, topic, proxy): + self.topic = topic + self.proxy = proxy + + def call(self, context, version, method, namespace, args, timeout): + done = eventlet.event.Event() + + def _inner(): + ctxt = RpcContext.from_dict(context.to_dict()) + try: + rval = self.proxy.dispatch(context, version, method, + namespace, **args) + res = [] + # Caller might have called ctxt.reply() manually + for (reply, failure) in ctxt._response: + if failure: + raise failure[0], failure[1], failure[2] + res.append(reply) + # if ending not 'sent'...we might have more data to + # return from the function itself + if not ctxt._done: + if inspect.isgenerator(rval): + for val in rval: + res.append(val) + else: + res.append(rval) + done.send(res) + except rpc_common.ClientException as e: + done.send_exception(e._exc_info[1]) + except Exception as e: + done.send_exception(e) + + thread = eventlet.greenthread.spawn(_inner) + + if timeout: + start_time = time.time() + while not done.ready(): + eventlet.greenthread.sleep(1) + cur_time = time.time() + if (cur_time - start_time) > timeout: + thread.kill() + raise rpc_common.Timeout() + + return done.wait() + + +class Connection(object): + """Connection object.""" + + def __init__(self): + self.consumers = [] + + def create_consumer(self, topic, proxy, fanout=False): + consumer = Consumer(topic, proxy) + self.consumers.append(consumer) + if topic not in CONSUMERS: + CONSUMERS[topic] = [] + CONSUMERS[topic].append(consumer) + + def close(self): + for consumer in self.consumers: + CONSUMERS[consumer.topic].remove(consumer) + self.consumers = [] + + def consume_in_thread(self): + pass + + +def create_connection(conf, new=True): + """Create a connection.""" + return Connection() + + +def check_serialize(msg): + """Make sure a message intended for rpc can be serialized.""" + json.dumps(msg) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + + check_serialize(msg) + + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + namespace = msg.get('namespace', None) + + try: + consumer = CONSUMERS[topic][0] + except (KeyError, IndexError): + return iter([None]) + else: + return consumer.call(context, version, method, namespace, args, + timeout) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + rv = multicall(conf, context, topic, msg, timeout) + # NOTE(vish): return the last result from the multicall + rv = list(rv) + if not rv: + return + return rv[-1] + + +def cast(conf, context, topic, msg): + check_serialize(msg) + try: + call(conf, context, topic, msg) + except Exception: + pass + + +def notify(conf, context, topic, msg, envelope): + check_serialize(msg) + + +def cleanup(): + pass + + +def fanout_cast(conf, context, topic, msg): + """Cast to all consumers of a topic.""" + check_serialize(msg) + method = msg.get('method') + if not method: + return + args = msg.get('args', {}) + version = msg.get('version', None) + namespace = msg.get('namespace', None) + + for consumer in CONSUMERS.get(topic, []): + try: + consumer.call(context, version, method, namespace, args, None) + except Exception: + pass diff --git a/savanna/openstack/common/rpc/impl_kombu.py b/savanna/openstack/common/rpc/impl_kombu.py new file mode 100644 index 00000000..970f7f72 --- /dev/null +++ b/savanna/openstack/common/rpc/impl_kombu.py @@ -0,0 +1,865 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import socket +import ssl +import time +import uuid + +import eventlet +import greenlet +import kombu +import kombu.connection +import kombu.entity +import kombu.messaging +from oslo.config import cfg + +from savanna.openstack.common import excutils +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import network_utils +from savanna.openstack.common.rpc import amqp as rpc_amqp +from savanna.openstack.common.rpc import common as rpc_common +from savanna.openstack.common import sslutils + +kombu_opts = [ + cfg.StrOpt('kombu_ssl_version', + default='', + help='SSL version to use (valid only if SSL enabled). ' + 'valid values are TLSv1, SSLv23 and SSLv3. SSLv2 may ' + 'be available on some distributions' + ), + cfg.StrOpt('kombu_ssl_keyfile', + default='', + help='SSL key file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_certfile', + default='', + help='SSL cert file (valid only if SSL enabled)'), + cfg.StrOpt('kombu_ssl_ca_certs', + default='', + help=('SSL certification authority file ' + '(valid only if SSL enabled)')), + cfg.StrOpt('rabbit_host', + default='localhost', + help='The RabbitMQ broker address where a single node is used'), + cfg.IntOpt('rabbit_port', + default=5672, + help='The RabbitMQ broker port where a single node is used'), + cfg.ListOpt('rabbit_hosts', + default=['$rabbit_host:$rabbit_port'], + help='RabbitMQ HA cluster host:port pairs'), + cfg.BoolOpt('rabbit_use_ssl', + default=False, + help='connect over SSL for RabbitMQ'), + cfg.StrOpt('rabbit_userid', + default='guest', + help='the RabbitMQ userid'), + cfg.StrOpt('rabbit_password', + default='guest', + help='the RabbitMQ password', + secret=True), + cfg.StrOpt('rabbit_virtual_host', + default='/', + help='the RabbitMQ virtual host'), + cfg.IntOpt('rabbit_retry_interval', + default=1, + help='how frequently to retry connecting with RabbitMQ'), + cfg.IntOpt('rabbit_retry_backoff', + default=2, + help='how long to backoff for between retries when connecting ' + 'to RabbitMQ'), + cfg.IntOpt('rabbit_max_retries', + default=0, + help='maximum retries with trying to connect to RabbitMQ ' + '(the default of 0 implies an infinite retry count)'), + cfg.BoolOpt('rabbit_ha_queues', + default=False, + help='use H/A queues in RabbitMQ (x-ha-policy: all).' + 'You need to wipe RabbitMQ database when ' + 'changing this option.'), + +] + +cfg.CONF.register_opts(kombu_opts) + +LOG = rpc_common.LOG + + +def _get_queue_arguments(conf): + """Construct the arguments for declaring a queue. + + If the rabbit_ha_queues option is set, we declare a mirrored queue + as described here: + + http://www.rabbitmq.com/ha.html + + Setting x-ha-policy to all means that the queue will be mirrored + to all nodes in the cluster. + """ + return {'x-ha-policy': 'all'} if conf.rabbit_ha_queues else {} + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, channel, callback, tag, **kwargs): + """Declare a queue on an amqp channel. + + 'channel' is the amqp channel to use + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + queue name, exchange name, and other kombu options are + passed in here as a dictionary. + """ + self.callback = callback + self.tag = str(tag) + self.kwargs = kwargs + self.queue = None + self.ack_on_error = kwargs.get('ack_on_error', True) + self.reconnect(channel) + + def reconnect(self, channel): + """Re-declare the queue after a rabbit reconnect.""" + self.channel = channel + self.kwargs['channel'] = channel + self.queue = kombu.entity.Queue(**self.kwargs) + self.queue.declare() + + def _callback_handler(self, message, callback): + """Call callback with deserialized message. + + Messages that are processed without exception are ack'ed. + + If the message processing generates an exception, it will be + ack'ed if ack_on_error=True. Otherwise it will be .reject()'ed. + Rejection is better than waiting for the message to timeout. + Rejected messages are immediately requeued. + """ + + ack_msg = False + try: + msg = rpc_common.deserialize_msg(message.payload) + callback(msg) + ack_msg = True + except Exception: + if self.ack_on_error: + ack_msg = True + LOG.exception(_("Failed to process message" + " ... skipping it.")) + else: + LOG.exception(_("Failed to process message" + " ... will requeue.")) + finally: + if ack_msg: + message.ack() + else: + message.reject() + + def consume(self, *args, **kwargs): + """Actually declare the consumer on the amqp channel. This will + start the flow of messages from the queue. Using the + Connection.iterconsume() iterator will process the messages, + calling the appropriate callback. + + If a callback is specified in kwargs, use that. Otherwise, + use the callback passed during __init__() + + If kwargs['nowait'] is True, then this call will block until + a message is read. + + """ + + options = {'consumer_tag': self.tag} + options['nowait'] = kwargs.get('nowait', False) + callback = kwargs.get('callback', self.callback) + if not callback: + raise ValueError("No callback defined") + + def _callback(raw_message): + message = self.channel.message_to_python(raw_message) + self._callback_handler(message, callback) + + self.queue.consume(*args, callback=_callback, **options) + + def cancel(self): + """Cancel the consuming from the queue, if it has started.""" + try: + self.queue.cancel(self.tag) + except KeyError as e: + # NOTE(comstud): Kludge to get around a amqplib bug + if str(e) != "u'%s'" % self.tag: + raise + self.queue = None + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'.""" + + def __init__(self, conf, channel, msg_id, callback, tag, **kwargs): + """Init a 'direct' queue. + + 'channel' is the amqp channel to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + # Default options + options = {'durable': False, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + exchange = kombu.entity.Exchange(name=msg_id, + type='direct', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(DirectConsumer, self).__init__(channel, + callback, + tag, + name=msg_id, + exchange=exchange, + routing_key=msg_id, + **options) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'.""" + + def __init__(self, conf, channel, topic, callback, tag, name=None, + exchange_name=None, **kwargs): + """Init a 'topic' queue. + + :param channel: the amqp channel to use + :param topic: the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param tag: a unique ID for the consumer on the channel + :param name: optional queue name, defaults to topic + :paramtype name: str + + Other kombu options may be passed as keyword arguments + """ + # Default options + options = {'durable': conf.amqp_durable_queues, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': conf.amqp_auto_delete, + 'exclusive': False} + options.update(kwargs) + exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) + exchange = kombu.entity.Exchange(name=exchange_name, + type='topic', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(TopicConsumer, self).__init__(channel, + callback, + tag, + name=name or topic, + exchange=exchange, + routing_key=topic, + **options) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'.""" + + def __init__(self, conf, channel, topic, callback, tag, **kwargs): + """Init a 'fanout' queue. + + 'channel' is the amqp channel to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + 'tag' is a unique ID for the consumer on the channel + + Other kombu options may be passed + """ + unique = uuid.uuid4().hex + exchange_name = '%s_fanout' % topic + queue_name = '%s_fanout_%s' % (topic, unique) + + # Default options + options = {'durable': False, + 'queue_arguments': _get_queue_arguments(conf), + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + exchange = kombu.entity.Exchange(name=exchange_name, type='fanout', + durable=options['durable'], + auto_delete=options['auto_delete']) + super(FanoutConsumer, self).__init__(channel, callback, tag, + name=queue_name, + exchange=exchange, + routing_key=topic, + **options) + + +class Publisher(object): + """Base Publisher class.""" + + def __init__(self, channel, exchange_name, routing_key, **kwargs): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.exchange_name = exchange_name + self.routing_key = routing_key + self.kwargs = kwargs + self.reconnect(channel) + + def reconnect(self, channel): + """Re-establish the Producer after a rabbit reconnection.""" + self.exchange = kombu.entity.Exchange(name=self.exchange_name, + **self.kwargs) + self.producer = kombu.messaging.Producer(exchange=self.exchange, + channel=channel, + routing_key=self.routing_key) + + def send(self, msg, timeout=None): + """Send a message.""" + if timeout: + # + # AMQP TTL is in milliseconds when set in the header. + # + self.producer.publish(msg, headers={'ttl': (timeout * 1000)}) + else: + self.producer.publish(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'.""" + def __init__(self, conf, channel, msg_id, **kwargs): + """init a 'direct' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + + options = {'durable': False, + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + super(DirectPublisher, self).__init__(channel, msg_id, msg_id, + type='direct', **options) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'.""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'topic' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': conf.amqp_durable_queues, + 'auto_delete': conf.amqp_auto_delete, + 'exclusive': False} + options.update(kwargs) + exchange_name = rpc_amqp.get_control_exchange(conf) + super(TopicPublisher, self).__init__(channel, + exchange_name, + topic, + type='topic', + **options) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'.""" + def __init__(self, conf, channel, topic, **kwargs): + """init a 'fanout' publisher. + + Kombu options may be passed as keyword args to override defaults + """ + options = {'durable': False, + 'auto_delete': True, + 'exclusive': False} + options.update(kwargs) + super(FanoutPublisher, self).__init__(channel, '%s_fanout' % topic, + None, type='fanout', **options) + + +class NotifyPublisher(TopicPublisher): + """Publisher class for 'notify'.""" + + def __init__(self, conf, channel, topic, **kwargs): + self.durable = kwargs.pop('durable', conf.amqp_durable_queues) + self.queue_arguments = _get_queue_arguments(conf) + super(NotifyPublisher, self).__init__(conf, channel, topic, **kwargs) + + def reconnect(self, channel): + super(NotifyPublisher, self).reconnect(channel) + + # NOTE(jerdfelt): Normally the consumer would create the queue, but + # we do this to ensure that messages don't get dropped if the + # consumer is started after we do + queue = kombu.entity.Queue(channel=channel, + exchange=self.exchange, + durable=self.durable, + name=self.routing_key, + routing_key=self.routing_key, + queue_arguments=self.queue_arguments) + queue.declare() + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + self.consumers = [] + self.consumer_thread = None + self.proxy_callbacks = [] + self.conf = conf + self.max_retries = self.conf.rabbit_max_retries + # Try forever? + if self.max_retries <= 0: + self.max_retries = None + self.interval_start = self.conf.rabbit_retry_interval + self.interval_stepping = self.conf.rabbit_retry_backoff + # max retry-interval = 30 seconds + self.interval_max = 30 + self.memory_transport = False + + if server_params is None: + server_params = {} + # Keys to translate from server_params to kombu params + server_params_to_kombu_params = {'username': 'userid'} + + ssl_params = self._fetch_ssl_params() + params_list = [] + for adr in self.conf.rabbit_hosts: + hostname, port = network_utils.parse_host_port( + adr, default_port=self.conf.rabbit_port) + + params = { + 'hostname': hostname, + 'port': port, + 'userid': self.conf.rabbit_userid, + 'password': self.conf.rabbit_password, + 'virtual_host': self.conf.rabbit_virtual_host, + } + + for sp_key, value in server_params.iteritems(): + p_key = server_params_to_kombu_params.get(sp_key, sp_key) + params[p_key] = value + + if self.conf.fake_rabbit: + params['transport'] = 'memory' + if self.conf.rabbit_use_ssl: + params['ssl'] = ssl_params + + params_list.append(params) + + self.params_list = params_list + + self.memory_transport = self.conf.fake_rabbit + + self.connection = None + self.reconnect() + + def _fetch_ssl_params(self): + """Handles fetching what ssl params should be used for the connection + (if any). + """ + ssl_params = dict() + + # http://docs.python.org/library/ssl.html - ssl.wrap_socket + if self.conf.kombu_ssl_version: + ssl_params['ssl_version'] = sslutils.validate_ssl_version( + self.conf.kombu_ssl_version) + if self.conf.kombu_ssl_keyfile: + ssl_params['keyfile'] = self.conf.kombu_ssl_keyfile + if self.conf.kombu_ssl_certfile: + ssl_params['certfile'] = self.conf.kombu_ssl_certfile + if self.conf.kombu_ssl_ca_certs: + ssl_params['ca_certs'] = self.conf.kombu_ssl_ca_certs + # We might want to allow variations in the + # future with this? + ssl_params['cert_reqs'] = ssl.CERT_REQUIRED + + if not ssl_params: + # Just have the default behavior + return True + else: + # Return the extended behavior + return ssl_params + + def _connect(self, params): + """Connect to rabbit. Re-establish any queues that may have + been declared before if we are reconnecting. Exceptions should + be handled by the caller. + """ + if self.connection: + LOG.info(_("Reconnecting to AMQP server on " + "%(hostname)s:%(port)d") % params) + try: + self.connection.release() + except self.connection_errors: + pass + # Setting this in case the next statement fails, though + # it shouldn't be doing any network operations, yet. + self.connection = None + self.connection = kombu.connection.BrokerConnection(**params) + self.connection_errors = self.connection.connection_errors + if self.memory_transport: + # Kludge to speed up tests. + self.connection.transport.polling_interval = 0.0 + self.consumer_num = itertools.count(1) + self.connection.connect() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + for consumer in self.consumers: + consumer.reconnect(self.channel) + LOG.info(_('Connected to AMQP server on %(hostname)s:%(port)d') % + params) + + def reconnect(self): + """Handles reconnecting and re-establishing queues. + Will retry up to self.max_retries number of times. + self.max_retries = 0 means to retry forever. + Sleep between tries, starting at self.interval_start + seconds, backing off self.interval_stepping number of seconds + each attempt. + """ + + attempt = 0 + while True: + params = self.params_list[attempt % len(self.params_list)] + attempt += 1 + try: + self._connect(params) + return + except (IOError, self.connection_errors) as e: + pass + except Exception as e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + + log_info = {} + log_info['err_str'] = str(e) + log_info['max_retries'] = self.max_retries + log_info.update(params) + + if self.max_retries and attempt == self.max_retries: + msg = _('Unable to connect to AMQP server on ' + '%(hostname)s:%(port)d after %(max_retries)d ' + 'tries: %(err_str)s') % log_info + LOG.error(msg) + raise rpc_common.RPCException(msg) + + if attempt == 1: + sleep_time = self.interval_start or 1 + elif attempt > 1: + sleep_time += self.interval_stepping + if self.interval_max: + sleep_time = min(sleep_time, self.interval_max) + + log_info['sleep_time'] = sleep_time + LOG.error(_('AMQP server on %(hostname)s:%(port)d is ' + 'unreachable: %(err_str)s. Trying again in ' + '%(sleep_time)d seconds.') % log_info) + time.sleep(sleep_time) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (self.connection_errors, socket.timeout, IOError) as e: + if error_callback: + error_callback(e) + except Exception as e: + # NOTE(comstud): Unfortunately it's possible for amqplib + # to return an error not covered by its transport + # connection_errors in the case of a timeout waiting for + # a protocol response. (See paste link in LP888621) + # So, we check all exceptions for 'timeout' in them + # and try to reconnect in this case. + if 'timeout' not in str(e): + raise + if error_callback: + error_callback(e) + self.reconnect() + + def get_channel(self): + """Convenience call for bin/clear_rabbit_queues.""" + return self.channel + + def close(self): + """Close/release this connection.""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.connection.release() + self.connection = None + + def reset(self): + """Reset a connection so it can be used again.""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.channel.close() + self.channel = self.connection.channel() + # work around 'memory' transport bug in 1.1.3 + if self.memory_transport: + self.channel._new_queue('ae.undeliver') + self.consumers = [] + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.channel, topic, callback, + self.consumer_num.next()) + self.consumers.append(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers.""" + + info = {'do_consume': True} + + def _error_callback(exc): + if isinstance(exc, socket.timeout): + LOG.debug(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + info['do_consume'] = True + + def _consume(): + if info['do_consume']: + queues_head = self.consumers[:-1] # not fanout. + queues_tail = self.consumers[-1] # fanout + for queue in queues_head: + queue.consume(nowait=True) + queues_tail.consume(nowait=False) + info['do_consume'] = False + return self.connection.drain_events(timeout=timeout) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread.""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def wait_on_proxy_callbacks(self): + """Wait for all proxy callback threads to exit.""" + for proxy_cb in self.proxy_callbacks: + proxy_cb.wait() + + def publisher_send(self, cls, topic, msg, timeout=None, **kwargs): + """Send to a publisher based on the publisher class.""" + + def _error_callback(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publish(): + publisher = cls(self.conf, self.channel, topic, **kwargs) + publisher.send(msg, timeout) + + self.ensure(_error_callback, _publish) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None, + exchange_name=None, ack_on_error=True): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + exchange_name=exchange_name, + ack_on_error=ack_on_error, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer.""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message.""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg, timeout=None): + """Send a 'topic' message.""" + self.publisher_send(TopicPublisher, topic, msg, timeout) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message.""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic.""" + self.publisher_send(NotifyPublisher, topic, msg, None, **kwargs) + + def consume(self, limit=None): + """Consume from all queues/consumers.""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object.""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + if fanout: + self.declare_fanout_consumer(topic, proxy_cb) + else: + self.declare_topic_consumer(topic, proxy_cb) + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object.""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + self.declare_topic_consumer(topic, proxy_cb, pool_name) + + def join_consumer_pool(self, callback, pool_name, topic, + exchange_name=None, ack_on_error=True): + """Register as a member of a group of consumers for a given topic from + the specified exchange. + + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + """ + callback_wrapper = rpc_amqp.CallbackWrapper( + conf=self.conf, + callback=callback, + connection_pool=rpc_amqp.get_connection_pool(self.conf, + Connection), + ) + self.proxy_callbacks.append(callback_wrapper) + self.declare_topic_consumer( + queue_name=pool_name, + topic=topic, + exchange_name=exchange_name, + callback=callback_wrapper, + ack_on_error=ack_on_error, + ) + + +def create_connection(conf, new=True): + """Create a connection.""" + return rpc_amqp.create_connection( + conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.fanout_cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg, envelope): + """Sends a notification event on a topic.""" + return rpc_amqp.notify( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection), + envelope) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/savanna/openstack/common/rpc/impl_qpid.py b/savanna/openstack/common/rpc/impl_qpid.py new file mode 100644 index 00000000..c9640688 --- /dev/null +++ b/savanna/openstack/common/rpc/impl_qpid.py @@ -0,0 +1,739 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# Copyright 2011 - 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import functools +import itertools +import time +import uuid + +import eventlet +import greenlet +from oslo.config import cfg + +from savanna.openstack.common import excutils +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import importutils +from savanna.openstack.common import jsonutils +from savanna.openstack.common import log as logging +from savanna.openstack.common.rpc import amqp as rpc_amqp +from savanna.openstack.common.rpc import common as rpc_common + +qpid_codec = importutils.try_import("qpid.codec010") +qpid_messaging = importutils.try_import("qpid.messaging") +qpid_exceptions = importutils.try_import("qpid.messaging.exceptions") + +LOG = logging.getLogger(__name__) + +qpid_opts = [ + cfg.StrOpt('qpid_hostname', + default='localhost', + help='Qpid broker hostname'), + cfg.IntOpt('qpid_port', + default=5672, + help='Qpid broker port'), + cfg.ListOpt('qpid_hosts', + default=['$qpid_hostname:$qpid_port'], + help='Qpid HA cluster host:port pairs'), + cfg.StrOpt('qpid_username', + default='', + help='Username for qpid connection'), + cfg.StrOpt('qpid_password', + default='', + help='Password for qpid connection', + secret=True), + cfg.StrOpt('qpid_sasl_mechanisms', + default='', + help='Space separated list of SASL mechanisms to use for auth'), + cfg.IntOpt('qpid_heartbeat', + default=60, + help='Seconds between connection keepalive heartbeats'), + cfg.StrOpt('qpid_protocol', + default='tcp', + help="Transport to use, either 'tcp' or 'ssl'"), + cfg.BoolOpt('qpid_tcp_nodelay', + default=True, + help='Disable Nagle algorithm'), +] + +cfg.CONF.register_opts(qpid_opts) + +JSON_CONTENT_TYPE = 'application/json; charset=utf8' + + +class ConsumerBase(object): + """Consumer base class.""" + + def __init__(self, session, callback, node_name, node_opts, + link_name, link_opts): + """Declare a queue on an amqp session. + + 'session' is the amqp session to use + 'callback' is the callback to call when messages are received + 'node_name' is the first part of the Qpid address string, before ';' + 'node_opts' will be applied to the "x-declare" section of "node" + in the address string. + 'link_name' goes into the "name" field of the "link" in the address + string + 'link_opts' will be applied to the "x-declare" section of "link" + in the address string. + """ + self.callback = callback + self.receiver = None + self.session = None + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": True, + "auto-delete": True, + }, + }, + "link": { + "name": link_name, + "durable": True, + "x-declare": { + "durable": False, + "auto-delete": True, + "exclusive": False, + }, + }, + } + addr_opts["node"]["x-declare"].update(node_opts) + addr_opts["link"]["x-declare"].update(link_opts) + + self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) + + self.connect(session) + + def connect(self, session): + """Declare the reciever on connect.""" + self._declare_receiver(session) + + def reconnect(self, session): + """Re-declare the receiver after a qpid reconnect.""" + self._declare_receiver(session) + + def _declare_receiver(self, session): + self.session = session + self.receiver = session.receiver(self.address) + self.receiver.capacity = 1 + + def _unpack_json_msg(self, msg): + """Load the JSON data in msg if msg.content_type indicates that it + is necessary. Put the loaded data back into msg.content and + update msg.content_type appropriately. + + A Qpid Message containing a dict will have a content_type of + 'amqp/map', whereas one containing a string that needs to be converted + back from JSON will have a content_type of JSON_CONTENT_TYPE. + + :param msg: a Qpid Message object + :returns: None + """ + if msg.content_type == JSON_CONTENT_TYPE: + msg.content = jsonutils.loads(msg.content) + msg.content_type = 'amqp/map' + + def consume(self): + """Fetch the message and pass it to the callback object.""" + message = self.receiver.fetch() + try: + self._unpack_json_msg(message) + msg = rpc_common.deserialize_msg(message.content) + self.callback(msg) + except Exception: + LOG.exception(_("Failed to process message... skipping it.")) + finally: + # TODO(sandy): Need support for optional ack_on_error. + self.session.acknowledge(message) + + def get_receiver(self): + return self.receiver + + def get_node_name(self): + return self.address.split(';')[0] + + +class DirectConsumer(ConsumerBase): + """Queue/consumer class for 'direct'.""" + + def __init__(self, conf, session, msg_id, callback): + """Init a 'direct' queue. + + 'session' is the amqp session to use + 'msg_id' is the msg_id to listen on + 'callback' is the callback to call when messages are received + """ + + super(DirectConsumer, self).__init__( + session, callback, + "%s/%s" % (msg_id, msg_id), + {"type": "direct"}, + msg_id, + { + "auto-delete": conf.amqp_auto_delete, + "exclusive": True, + "durable": conf.amqp_durable_queues, + }) + + +class TopicConsumer(ConsumerBase): + """Consumer class for 'topic'.""" + + def __init__(self, conf, session, topic, callback, name=None, + exchange_name=None): + """Init a 'topic' queue. + + :param session: the amqp session to use + :param topic: is the topic to listen on + :paramtype topic: str + :param callback: the callback to call when messages are received + :param name: optional queue name, defaults to topic + """ + + exchange_name = exchange_name or rpc_amqp.get_control_exchange(conf) + super(TopicConsumer, self).__init__( + session, callback, + "%s/%s" % (exchange_name, topic), + {}, name or topic, + { + "auto-delete": conf.amqp_auto_delete, + "durable": conf.amqp_durable_queues, + }) + + +class FanoutConsumer(ConsumerBase): + """Consumer class for 'fanout'.""" + + def __init__(self, conf, session, topic, callback): + """Init a 'fanout' queue. + + 'session' is the amqp session to use + 'topic' is the topic to listen on + 'callback' is the callback to call when messages are received + """ + self.conf = conf + + super(FanoutConsumer, self).__init__( + session, callback, + "%s_fanout" % topic, + {"durable": False, "type": "fanout"}, + "%s_fanout_%s" % (topic, uuid.uuid4().hex), + {"exclusive": True}) + + def reconnect(self, session): + topic = self.get_node_name() + params = { + 'session': session, + 'topic': topic, + 'callback': self.callback, + } + + self.__init__(conf=self.conf, **params) + + super(FanoutConsumer, self).reconnect(session) + + +class Publisher(object): + """Base Publisher class.""" + + def __init__(self, session, node_name, node_opts=None): + """Init the Publisher class with the exchange_name, routing_key, + and other options + """ + self.sender = None + self.session = session + + addr_opts = { + "create": "always", + "node": { + "type": "topic", + "x-declare": { + "durable": False, + # auto-delete isn't implemented for exchanges in qpid, + # but put in here anyway + "auto-delete": True, + }, + }, + } + if node_opts: + addr_opts["node"]["x-declare"].update(node_opts) + + self.address = "%s ; %s" % (node_name, jsonutils.dumps(addr_opts)) + + self.reconnect(session) + + def reconnect(self, session): + """Re-establish the Sender after a reconnection.""" + self.sender = session.sender(self.address) + + def _pack_json_msg(self, msg): + """Qpid cannot serialize dicts containing strings longer than 65535 + characters. This function dumps the message content to a JSON + string, which Qpid is able to handle. + + :param msg: May be either a Qpid Message object or a bare dict. + :returns: A Qpid Message with its content field JSON encoded. + """ + try: + msg.content = jsonutils.dumps(msg.content) + except AttributeError: + # Need to have a Qpid message so we can set the content_type. + msg = qpid_messaging.Message(jsonutils.dumps(msg)) + msg.content_type = JSON_CONTENT_TYPE + return msg + + def send(self, msg): + """Send a message.""" + try: + # Check if Qpid can encode the message + check_msg = msg + if not hasattr(check_msg, 'content_type'): + check_msg = qpid_messaging.Message(msg) + content_type = check_msg.content_type + enc, dec = qpid_messaging.message.get_codec(content_type) + enc(check_msg.content) + except qpid_codec.CodecException: + # This means the message couldn't be serialized as a dict. + msg = self._pack_json_msg(msg) + self.sender.send(msg) + + +class DirectPublisher(Publisher): + """Publisher class for 'direct'.""" + def __init__(self, conf, session, msg_id): + """Init a 'direct' publisher.""" + super(DirectPublisher, self).__init__(session, msg_id, + {"type": "Direct"}) + + +class TopicPublisher(Publisher): + """Publisher class for 'topic'.""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + exchange_name = rpc_amqp.get_control_exchange(conf) + super(TopicPublisher, self).__init__(session, + "%s/%s" % (exchange_name, topic)) + + +class FanoutPublisher(Publisher): + """Publisher class for 'fanout'.""" + def __init__(self, conf, session, topic): + """init a 'fanout' publisher. + """ + super(FanoutPublisher, self).__init__( + session, + "%s_fanout" % topic, {"type": "fanout"}) + + +class NotifyPublisher(Publisher): + """Publisher class for notifications.""" + def __init__(self, conf, session, topic): + """init a 'topic' publisher. + """ + exchange_name = rpc_amqp.get_control_exchange(conf) + super(NotifyPublisher, self).__init__(session, + "%s/%s" % (exchange_name, topic), + {"durable": True}) + + +class Connection(object): + """Connection object.""" + + pool = None + + def __init__(self, conf, server_params=None): + if not qpid_messaging: + raise ImportError("Failed to import qpid.messaging") + + self.session = None + self.consumers = {} + self.consumer_thread = None + self.proxy_callbacks = [] + self.conf = conf + + if server_params and 'hostname' in server_params: + # NOTE(russellb) This enables support for cast_to_server. + server_params['qpid_hosts'] = [ + '%s:%d' % (server_params['hostname'], + server_params.get('port', 5672)) + ] + + params = { + 'qpid_hosts': self.conf.qpid_hosts, + 'username': self.conf.qpid_username, + 'password': self.conf.qpid_password, + } + params.update(server_params or {}) + + self.brokers = params['qpid_hosts'] + self.username = params['username'] + self.password = params['password'] + self.connection_create(self.brokers[0]) + self.reconnect() + + def connection_create(self, broker): + # Create the connection - this does not open the connection + self.connection = qpid_messaging.Connection(broker) + + # Check if flags are set and if so set them for the connection + # before we call open + self.connection.username = self.username + self.connection.password = self.password + + self.connection.sasl_mechanisms = self.conf.qpid_sasl_mechanisms + # Reconnection is done by self.reconnect() + self.connection.reconnect = False + self.connection.heartbeat = self.conf.qpid_heartbeat + self.connection.transport = self.conf.qpid_protocol + self.connection.tcp_nodelay = self.conf.qpid_tcp_nodelay + + def _register_consumer(self, consumer): + self.consumers[str(consumer.get_receiver())] = consumer + + def _lookup_consumer(self, receiver): + return self.consumers[str(receiver)] + + def reconnect(self): + """Handles reconnecting and re-establishing sessions and queues.""" + attempt = 0 + delay = 1 + while True: + # Close the session if necessary + if self.connection.opened(): + try: + self.connection.close() + except qpid_exceptions.ConnectionError: + pass + + broker = self.brokers[attempt % len(self.brokers)] + attempt += 1 + + try: + self.connection_create(broker) + self.connection.open() + except qpid_exceptions.ConnectionError as e: + msg_dict = dict(e=e, delay=delay) + msg = _("Unable to connect to AMQP server: %(e)s. " + "Sleeping %(delay)s seconds") % msg_dict + LOG.error(msg) + time.sleep(delay) + delay = min(2 * delay, 60) + else: + LOG.info(_('Connected to AMQP server on %s'), broker) + break + + self.session = self.connection.session() + + if self.consumers: + consumers = self.consumers + self.consumers = {} + + for consumer in consumers.itervalues(): + consumer.reconnect(self.session) + self._register_consumer(consumer) + + LOG.debug(_("Re-established AMQP queues")) + + def ensure(self, error_callback, method, *args, **kwargs): + while True: + try: + return method(*args, **kwargs) + except (qpid_exceptions.Empty, + qpid_exceptions.ConnectionError) as e: + if error_callback: + error_callback(e) + self.reconnect() + + def close(self): + """Close/release this connection.""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + try: + self.connection.close() + except Exception: + # NOTE(dripton) Logging exceptions that happen during cleanup just + # causes confusion; there's really nothing useful we can do with + # them. + pass + self.connection = None + + def reset(self): + """Reset a connection so it can be used again.""" + self.cancel_consumer_thread() + self.wait_on_proxy_callbacks() + self.session.close() + self.session = self.connection.session() + self.consumers = {} + + def declare_consumer(self, consumer_cls, topic, callback): + """Create a Consumer using the class that was passed in and + add it to our list of consumers + """ + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.error(_("Failed to declare consumer for topic '%(topic)s': " + "%(err_str)s") % log_info) + + def _declare_consumer(): + consumer = consumer_cls(self.conf, self.session, topic, callback) + self._register_consumer(consumer) + return consumer + + return self.ensure(_connect_error, _declare_consumer) + + def iterconsume(self, limit=None, timeout=None): + """Return an iterator that will consume from all queues/consumers.""" + + def _error_callback(exc): + if isinstance(exc, qpid_exceptions.Empty): + LOG.debug(_('Timed out waiting for RPC response: %s') % + str(exc)) + raise rpc_common.Timeout() + else: + LOG.exception(_('Failed to consume message from queue: %s') % + str(exc)) + + def _consume(): + nxt_receiver = self.session.next_receiver(timeout=timeout) + try: + self._lookup_consumer(nxt_receiver).consume() + except Exception: + LOG.exception(_("Error processing message. Skipping it.")) + + for iteration in itertools.count(0): + if limit and iteration >= limit: + raise StopIteration + yield self.ensure(_error_callback, _consume) + + def cancel_consumer_thread(self): + """Cancel a consumer thread.""" + if self.consumer_thread is not None: + self.consumer_thread.kill() + try: + self.consumer_thread.wait() + except greenlet.GreenletExit: + pass + self.consumer_thread = None + + def wait_on_proxy_callbacks(self): + """Wait for all proxy callback threads to exit.""" + for proxy_cb in self.proxy_callbacks: + proxy_cb.wait() + + def publisher_send(self, cls, topic, msg): + """Send to a publisher based on the publisher class.""" + + def _connect_error(exc): + log_info = {'topic': topic, 'err_str': str(exc)} + LOG.exception(_("Failed to publish message to topic " + "'%(topic)s': %(err_str)s") % log_info) + + def _publisher_send(): + publisher = cls(self.conf, self.session, topic) + publisher.send(msg) + + return self.ensure(_connect_error, _publisher_send) + + def declare_direct_consumer(self, topic, callback): + """Create a 'direct' queue. + In nova's use, this is generally a msg_id queue used for + responses for call/multicall + """ + self.declare_consumer(DirectConsumer, topic, callback) + + def declare_topic_consumer(self, topic, callback=None, queue_name=None, + exchange_name=None): + """Create a 'topic' consumer.""" + self.declare_consumer(functools.partial(TopicConsumer, + name=queue_name, + exchange_name=exchange_name, + ), + topic, callback) + + def declare_fanout_consumer(self, topic, callback): + """Create a 'fanout' consumer.""" + self.declare_consumer(FanoutConsumer, topic, callback) + + def direct_send(self, msg_id, msg): + """Send a 'direct' message.""" + self.publisher_send(DirectPublisher, msg_id, msg) + + def topic_send(self, topic, msg, timeout=None): + """Send a 'topic' message.""" + # + # We want to create a message with attributes, e.g. a TTL. We + # don't really need to keep 'msg' in its JSON format any longer + # so let's create an actual qpid message here and get some + # value-add on the go. + # + # WARNING: Request timeout happens to be in the same units as + # qpid's TTL (seconds). If this changes in the future, then this + # will need to be altered accordingly. + # + qpid_message = qpid_messaging.Message(content=msg, ttl=timeout) + self.publisher_send(TopicPublisher, topic, qpid_message) + + def fanout_send(self, topic, msg): + """Send a 'fanout' message.""" + self.publisher_send(FanoutPublisher, topic, msg) + + def notify_send(self, topic, msg, **kwargs): + """Send a notify message on a topic.""" + self.publisher_send(NotifyPublisher, topic, msg) + + def consume(self, limit=None): + """Consume from all queues/consumers.""" + it = self.iterconsume(limit=limit) + while True: + try: + it.next() + except StopIteration: + return + + def consume_in_thread(self): + """Consumer from all queues/consumers in a greenthread.""" + @excutils.forever_retry_uncaught_exceptions + def _consumer_thread(): + try: + self.consume() + except greenlet.GreenletExit: + return + if self.consumer_thread is None: + self.consumer_thread = eventlet.spawn(_consumer_thread) + return self.consumer_thread + + def create_consumer(self, topic, proxy, fanout=False): + """Create a consumer that calls a method in a proxy object.""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + if fanout: + consumer = FanoutConsumer(self.conf, self.session, topic, proxy_cb) + else: + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb) + + self._register_consumer(consumer) + + return consumer + + def create_worker(self, topic, proxy, pool_name): + """Create a worker that calls a method in a proxy object.""" + proxy_cb = rpc_amqp.ProxyCallback( + self.conf, proxy, + rpc_amqp.get_connection_pool(self.conf, Connection)) + self.proxy_callbacks.append(proxy_cb) + + consumer = TopicConsumer(self.conf, self.session, topic, proxy_cb, + name=pool_name) + + self._register_consumer(consumer) + + return consumer + + def join_consumer_pool(self, callback, pool_name, topic, + exchange_name=None, ack_on_error=True): + """Register as a member of a group of consumers for a given topic from + the specified exchange. + + Exactly one member of a given pool will receive each message. + + A message will be delivered to multiple pools, if more than + one is created. + """ + callback_wrapper = rpc_amqp.CallbackWrapper( + conf=self.conf, + callback=callback, + connection_pool=rpc_amqp.get_connection_pool(self.conf, + Connection), + ) + self.proxy_callbacks.append(callback_wrapper) + + consumer = TopicConsumer(conf=self.conf, + session=self.session, + topic=topic, + callback=callback_wrapper, + name=pool_name, + exchange_name=exchange_name) + + self._register_consumer(consumer) + return consumer + + +def create_connection(conf, new=True): + """Create a connection.""" + return rpc_amqp.create_connection( + conf, new, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def multicall(conf, context, topic, msg, timeout=None): + """Make a call that returns multiple times.""" + return rpc_amqp.multicall( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def call(conf, context, topic, msg, timeout=None): + """Sends a message on a topic and wait for a response.""" + return rpc_amqp.call( + conf, context, topic, msg, timeout, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast(conf, context, topic, msg): + """Sends a message on a topic without waiting for a response.""" + return rpc_amqp.cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast(conf, context, topic, msg): + """Sends a message on a fanout exchange without waiting for a response.""" + return rpc_amqp.fanout_cast( + conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a topic to a specific server.""" + return rpc_amqp.cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def fanout_cast_to_server(conf, context, server_params, topic, msg): + """Sends a message on a fanout exchange to a specific server.""" + return rpc_amqp.fanout_cast_to_server( + conf, context, server_params, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection)) + + +def notify(conf, context, topic, msg, envelope): + """Sends a notification event on a topic.""" + return rpc_amqp.notify(conf, context, topic, msg, + rpc_amqp.get_connection_pool(conf, Connection), + envelope) + + +def cleanup(): + return rpc_amqp.cleanup(Connection.pool) diff --git a/savanna/openstack/common/rpc/impl_zmq.py b/savanna/openstack/common/rpc/impl_zmq.py new file mode 100644 index 00000000..20d9d7ea --- /dev/null +++ b/savanna/openstack/common/rpc/impl_zmq.py @@ -0,0 +1,817 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import pprint +import re +import socket +import sys +import types +import uuid + +import eventlet +import greenlet +from oslo.config import cfg + +from savanna.openstack.common import excutils +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import importutils +from savanna.openstack.common import jsonutils +from savanna.openstack.common.rpc import common as rpc_common + +zmq = importutils.try_import('eventlet.green.zmq') + +# for convenience, are not modified. +pformat = pprint.pformat +Timeout = eventlet.timeout.Timeout +LOG = rpc_common.LOG +RemoteError = rpc_common.RemoteError +RPCException = rpc_common.RPCException + +zmq_opts = [ + cfg.StrOpt('rpc_zmq_bind_address', default='*', + help='ZeroMQ bind address. Should be a wildcard (*), ' + 'an ethernet interface, or IP. ' + 'The "host" option should point or resolve to this ' + 'address.'), + + # The module.Class to use for matchmaking. + cfg.StrOpt( + 'rpc_zmq_matchmaker', + default=('savanna.openstack.common.rpc.' + 'matchmaker.MatchMakerLocalhost'), + help='MatchMaker driver', + ), + + # The following port is unassigned by IANA as of 2012-05-21 + cfg.IntOpt('rpc_zmq_port', default=9501, + help='ZeroMQ receiver listening port'), + + cfg.IntOpt('rpc_zmq_contexts', default=1, + help='Number of ZeroMQ contexts, defaults to 1'), + + cfg.IntOpt('rpc_zmq_topic_backlog', default=None, + help='Maximum number of ingress messages to locally buffer ' + 'per topic. Default is unlimited.'), + + cfg.StrOpt('rpc_zmq_ipc_dir', default='/var/run/openstack', + help='Directory for holding IPC sockets'), + + cfg.StrOpt('rpc_zmq_host', default=socket.gethostname(), + help='Name of this node. Must be a valid hostname, FQDN, or ' + 'IP address. Must match "host" option, if running Nova.') +] + + +CONF = cfg.CONF +CONF.register_opts(zmq_opts) + +ZMQ_CTX = None # ZeroMQ Context, must be global. +matchmaker = None # memoized matchmaker object + + +def _serialize(data): + """Serialization wrapper. + + We prefer using JSON, but it cannot encode all types. + Error if a developer passes us bad data. + """ + try: + return jsonutils.dumps(data, ensure_ascii=True) + except TypeError: + with excutils.save_and_reraise_exception(): + LOG.error(_("JSON serialization failed.")) + + +def _deserialize(data): + """Deserialization wrapper.""" + LOG.debug(_("Deserializing: %s"), data) + return jsonutils.loads(data) + + +class ZmqSocket(object): + """A tiny wrapper around ZeroMQ. + + Simplifies the send/recv protocol and connection management. + Can be used as a Context (supports the 'with' statement). + """ + + def __init__(self, addr, zmq_type, bind=True, subscribe=None): + self.sock = _get_ctxt().socket(zmq_type) + self.addr = addr + self.type = zmq_type + self.subscriptions = [] + + # Support failures on sending/receiving on wrong socket type. + self.can_recv = zmq_type in (zmq.PULL, zmq.SUB) + self.can_send = zmq_type in (zmq.PUSH, zmq.PUB) + self.can_sub = zmq_type in (zmq.SUB, ) + + # Support list, str, & None for subscribe arg (cast to list) + do_sub = { + list: subscribe, + str: [subscribe], + type(None): [] + }[type(subscribe)] + + for f in do_sub: + self.subscribe(f) + + str_data = {'addr': addr, 'type': self.socket_s(), + 'subscribe': subscribe, 'bind': bind} + + LOG.debug(_("Connecting to %(addr)s with %(type)s"), str_data) + LOG.debug(_("-> Subscribed to %(subscribe)s"), str_data) + LOG.debug(_("-> bind: %(bind)s"), str_data) + + try: + if bind: + self.sock.bind(addr) + else: + self.sock.connect(addr) + except Exception: + raise RPCException(_("Could not open socket.")) + + def socket_s(self): + """Get socket type as string.""" + t_enum = ('PUSH', 'PULL', 'PUB', 'SUB', 'REP', 'REQ', 'ROUTER', + 'DEALER') + return dict(map(lambda t: (getattr(zmq, t), t), t_enum))[self.type] + + def subscribe(self, msg_filter): + """Subscribe.""" + if not self.can_sub: + raise RPCException("Cannot subscribe on this socket.") + LOG.debug(_("Subscribing to %s"), msg_filter) + + try: + self.sock.setsockopt(zmq.SUBSCRIBE, msg_filter) + except Exception: + return + + self.subscriptions.append(msg_filter) + + def unsubscribe(self, msg_filter): + """Unsubscribe.""" + if msg_filter not in self.subscriptions: + return + self.sock.setsockopt(zmq.UNSUBSCRIBE, msg_filter) + self.subscriptions.remove(msg_filter) + + def close(self): + if self.sock is None or self.sock.closed: + return + + # We must unsubscribe, or we'll leak descriptors. + if self.subscriptions: + for f in self.subscriptions: + try: + self.sock.setsockopt(zmq.UNSUBSCRIBE, f) + except Exception: + pass + self.subscriptions = [] + + try: + # Default is to linger + self.sock.close() + except Exception: + # While this is a bad thing to happen, + # it would be much worse if some of the code calling this + # were to fail. For now, lets log, and later evaluate + # if we can safely raise here. + LOG.error("ZeroMQ socket could not be closed.") + self.sock = None + + def recv(self, **kwargs): + if not self.can_recv: + raise RPCException(_("You cannot recv on this socket.")) + return self.sock.recv_multipart(**kwargs) + + def send(self, data, **kwargs): + if not self.can_send: + raise RPCException(_("You cannot send on this socket.")) + self.sock.send_multipart(data, **kwargs) + + +class ZmqClient(object): + """Client for ZMQ sockets.""" + + def __init__(self, addr): + self.outq = ZmqSocket(addr, zmq.PUSH, bind=False) + + def cast(self, msg_id, topic, data, envelope): + msg_id = msg_id or 0 + + if not envelope: + self.outq.send(map(bytes, + (msg_id, topic, 'cast', _serialize(data)))) + return + + rpc_envelope = rpc_common.serialize_msg(data[1], envelope) + zmq_msg = reduce(lambda x, y: x + y, rpc_envelope.items()) + self.outq.send(map(bytes, + (msg_id, topic, 'impl_zmq_v2', data[0]) + zmq_msg)) + + def close(self): + self.outq.close() + + +class RpcContext(rpc_common.CommonRpcContext): + """Context that supports replying to a rpc.call.""" + def __init__(self, **kwargs): + self.replies = [] + super(RpcContext, self).__init__(**kwargs) + + def deepcopy(self): + values = self.to_dict() + values['replies'] = self.replies + return self.__class__(**values) + + def reply(self, reply=None, failure=None, ending=False): + if ending: + return + self.replies.append(reply) + + @classmethod + def marshal(self, ctx): + ctx_data = ctx.to_dict() + return _serialize(ctx_data) + + @classmethod + def unmarshal(self, data): + return RpcContext.from_dict(_deserialize(data)) + + +class InternalContext(object): + """Used by ConsumerBase as a private context for - methods.""" + + def __init__(self, proxy): + self.proxy = proxy + self.msg_waiter = None + + def _get_response(self, ctx, proxy, topic, data): + """Process a curried message and cast the result to topic.""" + LOG.debug(_("Running func with context: %s"), ctx.to_dict()) + data.setdefault('version', None) + data.setdefault('args', {}) + + try: + result = proxy.dispatch( + ctx, data['version'], data['method'], + data.get('namespace'), **data['args']) + return ConsumerBase.normalize_reply(result, ctx.replies) + except greenlet.GreenletExit: + # ignore these since they are just from shutdowns + pass + except rpc_common.ClientException as e: + LOG.debug(_("Expected exception during message handling (%s)") % + e._exc_info[1]) + return {'exc': + rpc_common.serialize_remote_exception(e._exc_info, + log_failure=False)} + except Exception: + LOG.error(_("Exception during message handling")) + return {'exc': + rpc_common.serialize_remote_exception(sys.exc_info())} + + def reply(self, ctx, proxy, + msg_id=None, context=None, topic=None, msg=None): + """Reply to a casted call.""" + # NOTE(ewindisch): context kwarg exists for Grizzly compat. + # this may be able to be removed earlier than + # 'I' if ConsumerBase.process were refactored. + if type(msg) is list: + payload = msg[-1] + else: + payload = msg + + response = ConsumerBase.normalize_reply( + self._get_response(ctx, proxy, topic, payload), + ctx.replies) + + LOG.debug(_("Sending reply")) + _multi_send(_cast, ctx, topic, { + 'method': '-process_reply', + 'args': { + 'msg_id': msg_id, # Include for Folsom compat. + 'response': response + } + }, _msg_id=msg_id) + + +class ConsumerBase(object): + """Base Consumer.""" + + def __init__(self): + self.private_ctx = InternalContext(None) + + @classmethod + def normalize_reply(self, result, replies): + #TODO(ewindisch): re-evaluate and document this method. + if isinstance(result, types.GeneratorType): + return list(result) + elif replies: + return replies + else: + return [result] + + def process(self, proxy, ctx, data): + data.setdefault('version', None) + data.setdefault('args', {}) + + # Method starting with - are + # processed internally. (non-valid method name) + method = data.get('method') + if not method: + LOG.error(_("RPC message did not include method.")) + return + + # Internal method + # uses internal context for safety. + if method == '-reply': + self.private_ctx.reply(ctx, proxy, **data['args']) + return + + proxy.dispatch(ctx, data['version'], + data['method'], data.get('namespace'), **data['args']) + + +class ZmqBaseReactor(ConsumerBase): + """A consumer class implementing a centralized casting broker (PULL-PUSH). + + Used for RoundRobin requests. + """ + + def __init__(self, conf): + super(ZmqBaseReactor, self).__init__() + + self.proxies = {} + self.threads = [] + self.sockets = [] + self.subscribe = {} + + self.pool = eventlet.greenpool.GreenPool(conf.rpc_thread_pool_size) + + def register(self, proxy, in_addr, zmq_type_in, + in_bind=True, subscribe=None): + + LOG.info(_("Registering reactor")) + + if zmq_type_in not in (zmq.PULL, zmq.SUB): + raise RPCException("Bad input socktype") + + # Items push in. + inq = ZmqSocket(in_addr, zmq_type_in, bind=in_bind, + subscribe=subscribe) + + self.proxies[inq] = proxy + self.sockets.append(inq) + + LOG.info(_("In reactor registered")) + + def consume_in_thread(self): + def _consume(sock): + LOG.info(_("Consuming socket")) + while True: + self.consume(sock) + + for k in self.proxies.keys(): + self.threads.append( + self.pool.spawn(_consume, k) + ) + + def wait(self): + for t in self.threads: + t.wait() + + def close(self): + for s in self.sockets: + s.close() + + for t in self.threads: + t.kill() + + +class ZmqProxy(ZmqBaseReactor): + """A consumer class implementing a topic-based proxy. + + Forwards to IPC sockets. + """ + + def __init__(self, conf): + super(ZmqProxy, self).__init__(conf) + pathsep = set((os.path.sep or '', os.path.altsep or '', '/', '\\')) + self.badchars = re.compile(r'[%s]' % re.escape(''.join(pathsep))) + + self.topic_proxy = {} + + def consume(self, sock): + ipc_dir = CONF.rpc_zmq_ipc_dir + + data = sock.recv(copy=False) + topic = data[1].bytes + + if topic.startswith('fanout~'): + sock_type = zmq.PUB + topic = topic.split('.', 1)[0] + elif topic.startswith('zmq_replies'): + sock_type = zmq.PUB + else: + sock_type = zmq.PUSH + + if topic not in self.topic_proxy: + def publisher(waiter): + LOG.info(_("Creating proxy for topic: %s"), topic) + + try: + # The topic is received over the network, + # don't trust this input. + if self.badchars.search(topic) is not None: + emsg = _("Topic contained dangerous characters.") + LOG.warn(emsg) + raise RPCException(emsg) + + out_sock = ZmqSocket("ipc://%s/zmq_topic_%s" % + (ipc_dir, topic), + sock_type, bind=True) + except RPCException: + waiter.send_exception(*sys.exc_info()) + return + + self.topic_proxy[topic] = eventlet.queue.LightQueue( + CONF.rpc_zmq_topic_backlog) + self.sockets.append(out_sock) + + # It takes some time for a pub socket to open, + # before we can have any faith in doing a send() to it. + if sock_type == zmq.PUB: + eventlet.sleep(.5) + + waiter.send(True) + + while(True): + data = self.topic_proxy[topic].get() + out_sock.send(data, copy=False) + + wait_sock_creation = eventlet.event.Event() + eventlet.spawn(publisher, wait_sock_creation) + + try: + wait_sock_creation.wait() + except RPCException: + LOG.error(_("Topic socket file creation failed.")) + return + + try: + self.topic_proxy[topic].put_nowait(data) + except eventlet.queue.Full: + LOG.error(_("Local per-topic backlog buffer full for topic " + "%(topic)s. Dropping message.") % {'topic': topic}) + + def consume_in_thread(self): + """Runs the ZmqProxy service.""" + ipc_dir = CONF.rpc_zmq_ipc_dir + consume_in = "tcp://%s:%s" % \ + (CONF.rpc_zmq_bind_address, + CONF.rpc_zmq_port) + consumption_proxy = InternalContext(None) + + try: + os.makedirs(ipc_dir) + except os.error: + if not os.path.isdir(ipc_dir): + with excutils.save_and_reraise_exception(): + LOG.error(_("Required IPC directory does not exist at" + " %s") % (ipc_dir, )) + try: + self.register(consumption_proxy, + consume_in, + zmq.PULL) + except zmq.ZMQError: + if os.access(ipc_dir, os.X_OK): + with excutils.save_and_reraise_exception(): + LOG.error(_("Permission denied to IPC directory at" + " %s") % (ipc_dir, )) + with excutils.save_and_reraise_exception(): + LOG.error(_("Could not create ZeroMQ receiver daemon. " + "Socket may already be in use.")) + + super(ZmqProxy, self).consume_in_thread() + + +def unflatten_envelope(packenv): + """Unflattens the RPC envelope. + + Takes a list and returns a dictionary. + i.e. [1,2,3,4] => {1: 2, 3: 4} + """ + i = iter(packenv) + h = {} + try: + while True: + k = i.next() + h[k] = i.next() + except StopIteration: + return h + + +class ZmqReactor(ZmqBaseReactor): + """A consumer class implementing a consumer for messages. + + Can also be used as a 1:1 proxy + """ + + def __init__(self, conf): + super(ZmqReactor, self).__init__(conf) + + def consume(self, sock): + #TODO(ewindisch): use zero-copy (i.e. references, not copying) + data = sock.recv() + LOG.debug(_("CONSUMER RECEIVED DATA: %s"), data) + + proxy = self.proxies[sock] + + if data[2] == 'cast': # Legacy protocol + packenv = data[3] + + ctx, msg = _deserialize(packenv) + request = rpc_common.deserialize_msg(msg) + ctx = RpcContext.unmarshal(ctx) + elif data[2] == 'impl_zmq_v2': + packenv = data[4:] + + msg = unflatten_envelope(packenv) + request = rpc_common.deserialize_msg(msg) + + # Unmarshal only after verifying the message. + ctx = RpcContext.unmarshal(data[3]) + else: + LOG.error(_("ZMQ Envelope version unsupported or unknown.")) + return + + self.pool.spawn_n(self.process, proxy, ctx, request) + + +class Connection(rpc_common.Connection): + """Manages connections and threads.""" + + def __init__(self, conf): + self.topics = [] + self.reactor = ZmqReactor(conf) + + def create_consumer(self, topic, proxy, fanout=False): + # Register with matchmaker. + _get_matchmaker().register(topic, CONF.rpc_zmq_host) + + # Subscription scenarios + if fanout: + sock_type = zmq.SUB + subscribe = ('', fanout)[type(fanout) == str] + topic = 'fanout~' + topic.split('.', 1)[0] + else: + sock_type = zmq.PULL + subscribe = None + topic = '.'.join((topic.split('.', 1)[0], CONF.rpc_zmq_host)) + + if topic in self.topics: + LOG.info(_("Skipping topic registration. Already registered.")) + return + + # Receive messages from (local) proxy + inaddr = "ipc://%s/zmq_topic_%s" % \ + (CONF.rpc_zmq_ipc_dir, topic) + + LOG.debug(_("Consumer is a zmq.%s"), + ['PULL', 'SUB'][sock_type == zmq.SUB]) + + self.reactor.register(proxy, inaddr, sock_type, + subscribe=subscribe, in_bind=False) + self.topics.append(topic) + + def close(self): + _get_matchmaker().stop_heartbeat() + for topic in self.topics: + _get_matchmaker().unregister(topic, CONF.rpc_zmq_host) + + self.reactor.close() + self.topics = [] + + def wait(self): + self.reactor.wait() + + def consume_in_thread(self): + _get_matchmaker().start_heartbeat() + self.reactor.consume_in_thread() + + +def _cast(addr, context, topic, msg, timeout=None, envelope=False, + _msg_id=None): + timeout_cast = timeout or CONF.rpc_cast_timeout + payload = [RpcContext.marshal(context), msg] + + with Timeout(timeout_cast, exception=rpc_common.Timeout): + try: + conn = ZmqClient(addr) + + # assumes cast can't return an exception + conn.cast(_msg_id, topic, payload, envelope) + except zmq.ZMQError: + raise RPCException("Cast failed. ZMQ Socket Exception") + finally: + if 'conn' in vars(): + conn.close() + + +def _call(addr, context, topic, msg, timeout=None, + envelope=False): + # timeout_response is how long we wait for a response + timeout = timeout or CONF.rpc_response_timeout + + # The msg_id is used to track replies. + msg_id = uuid.uuid4().hex + + # Replies always come into the reply service. + reply_topic = "zmq_replies.%s" % CONF.rpc_zmq_host + + LOG.debug(_("Creating payload")) + # Curry the original request into a reply method. + mcontext = RpcContext.marshal(context) + payload = { + 'method': '-reply', + 'args': { + 'msg_id': msg_id, + 'topic': reply_topic, + # TODO(ewindisch): safe to remove mcontext in I. + 'msg': [mcontext, msg] + } + } + + LOG.debug(_("Creating queue socket for reply waiter")) + + # Messages arriving async. + # TODO(ewindisch): have reply consumer with dynamic subscription mgmt + with Timeout(timeout, exception=rpc_common.Timeout): + try: + msg_waiter = ZmqSocket( + "ipc://%s/zmq_topic_zmq_replies.%s" % + (CONF.rpc_zmq_ipc_dir, + CONF.rpc_zmq_host), + zmq.SUB, subscribe=msg_id, bind=False + ) + + LOG.debug(_("Sending cast")) + _cast(addr, context, topic, payload, envelope) + + LOG.debug(_("Cast sent; Waiting reply")) + # Blocks until receives reply + msg = msg_waiter.recv() + LOG.debug(_("Received message: %s"), msg) + LOG.debug(_("Unpacking response")) + + if msg[2] == 'cast': # Legacy version + raw_msg = _deserialize(msg[-1])[-1] + elif msg[2] == 'impl_zmq_v2': + rpc_envelope = unflatten_envelope(msg[4:]) + raw_msg = rpc_common.deserialize_msg(rpc_envelope) + else: + raise rpc_common.UnsupportedRpcEnvelopeVersion( + _("Unsupported or unknown ZMQ envelope returned.")) + + responses = raw_msg['args']['response'] + # ZMQError trumps the Timeout error. + except zmq.ZMQError: + raise RPCException("ZMQ Socket Error") + except (IndexError, KeyError): + raise RPCException(_("RPC Message Invalid.")) + finally: + if 'msg_waiter' in vars(): + msg_waiter.close() + + # It seems we don't need to do all of the following, + # but perhaps it would be useful for multicall? + # One effect of this is that we're checking all + # responses for Exceptions. + for resp in responses: + if isinstance(resp, types.DictType) and 'exc' in resp: + raise rpc_common.deserialize_remote_exception(CONF, resp['exc']) + + return responses[-1] + + +def _multi_send(method, context, topic, msg, timeout=None, + envelope=False, _msg_id=None): + """Wraps the sending of messages. + + Dispatches to the matchmaker and sends message to all relevant hosts. + """ + conf = CONF + LOG.debug(_("%(msg)s") % {'msg': ' '.join(map(pformat, (topic, msg)))}) + + queues = _get_matchmaker().queues(topic) + LOG.debug(_("Sending message(s) to: %s"), queues) + + # Don't stack if we have no matchmaker results + if not queues: + LOG.warn(_("No matchmaker results. Not casting.")) + # While not strictly a timeout, callers know how to handle + # this exception and a timeout isn't too big a lie. + raise rpc_common.Timeout(_("No match from matchmaker.")) + + # This supports brokerless fanout (addresses > 1) + for queue in queues: + (_topic, ip_addr) = queue + _addr = "tcp://%s:%s" % (ip_addr, conf.rpc_zmq_port) + + if method.__name__ == '_cast': + eventlet.spawn_n(method, _addr, context, + _topic, msg, timeout, envelope, + _msg_id) + return + return method(_addr, context, _topic, msg, timeout, + envelope) + + +def create_connection(conf, new=True): + return Connection(conf) + + +def multicall(conf, *args, **kwargs): + """Multiple calls.""" + return _multi_send(_call, *args, **kwargs) + + +def call(conf, *args, **kwargs): + """Send a message, expect a response.""" + data = _multi_send(_call, *args, **kwargs) + return data[-1] + + +def cast(conf, *args, **kwargs): + """Send a message expecting no reply.""" + _multi_send(_cast, *args, **kwargs) + + +def fanout_cast(conf, context, topic, msg, **kwargs): + """Send a message to all listening and expect no reply.""" + # NOTE(ewindisch): fanout~ is used because it avoid splitting on . + # and acts as a non-subtle hint to the matchmaker and ZmqProxy. + _multi_send(_cast, context, 'fanout~' + str(topic), msg, **kwargs) + + +def notify(conf, context, topic, msg, envelope): + """Send notification event. + + Notifications are sent to topic-priority. + This differs from the AMQP drivers which send to topic.priority. + """ + # NOTE(ewindisch): dot-priority in rpc notifier does not + # work with our assumptions. + topic = topic.replace('.', '-') + cast(conf, context, topic, msg, envelope=envelope) + + +def cleanup(): + """Clean up resources in use by implementation.""" + global ZMQ_CTX + if ZMQ_CTX: + ZMQ_CTX.term() + ZMQ_CTX = None + + global matchmaker + matchmaker = None + + +def _get_ctxt(): + if not zmq: + raise ImportError("Failed to import eventlet.green.zmq") + + global ZMQ_CTX + if not ZMQ_CTX: + ZMQ_CTX = zmq.Context(CONF.rpc_zmq_contexts) + return ZMQ_CTX + + +def _get_matchmaker(*args, **kwargs): + global matchmaker + if not matchmaker: + mm = CONF.rpc_zmq_matchmaker + if mm.endswith('matchmaker.MatchMakerRing'): + mm.replace('matchmaker', 'matchmaker_ring') + LOG.warn(_('rpc_zmq_matchmaker = %(orig)s is deprecated; use' + ' %(new)s instead') % dict( + orig=CONF.rpc_zmq_matchmaker, new=mm)) + matchmaker = importutils.import_object(mm, *args, **kwargs) + return matchmaker diff --git a/savanna/openstack/common/rpc/matchmaker.py b/savanna/openstack/common/rpc/matchmaker.py new file mode 100644 index 00000000..229a4052 --- /dev/null +++ b/savanna/openstack/common/rpc/matchmaker.py @@ -0,0 +1,330 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should except a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +import contextlib + +import eventlet +from oslo.config import cfg + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging + + +matchmaker_opts = [ + cfg.IntOpt('matchmaker_heartbeat_freq', + default=300, + help='Heartbeat frequency'), + cfg.IntOpt('matchmaker_heartbeat_ttl', + default=600, + help='Heartbeat time-to-live.'), +] + +CONF = cfg.CONF +CONF.register_opts(matchmaker_opts) +LOG = logging.getLogger(__name__) +contextmanager = contextlib.contextmanager + + +class MatchMakerException(Exception): + """Signified a match could not be found.""" + message = _("Match not found by MatchMaker.") + + +class Exchange(object): + """Implements lookups. + + Subclass this to support hashtables, dns, etc. + """ + def __init__(self): + pass + + def run(self, key): + raise NotImplementedError() + + +class Binding(object): + """A binding on which to perform a lookup.""" + def __init__(self): + pass + + def test(self, key): + raise NotImplementedError() + + +class MatchMakerBase(object): + """Match Maker Base Class. + + Build off HeartbeatMatchMakerBase if building a heartbeat-capable + MatchMaker. + """ + def __init__(self): + # Array of tuples. Index [2] toggles negation, [3] is last-if-true + self.bindings = [] + + self.no_heartbeat_msg = _('Matchmaker does not implement ' + 'registration or heartbeat.') + + def register(self, key, host): + """Register a host on a backend. + + Heartbeats, if applicable, may keepalive registration. + """ + pass + + def ack_alive(self, key, host): + """Acknowledge that a key.host is alive. + + Used internally for updating heartbeats, but may also be used + publically to acknowledge a system is alive (i.e. rpc message + successfully sent to host) + """ + pass + + def is_alive(self, topic, host): + """Checks if a host is alive.""" + pass + + def expire(self, topic, host): + """Explicitly expire a host's registration.""" + pass + + def send_heartbeats(self): + """Send all heartbeats. + + Use start_heartbeat to spawn a heartbeat greenthread, + which loops this method. + """ + pass + + def unregister(self, key, host): + """Unregister a topic.""" + pass + + def start_heartbeat(self): + """Spawn heartbeat greenthread.""" + pass + + def stop_heartbeat(self): + """Destroys the heartbeat greenthread.""" + pass + + def add_binding(self, binding, rule, last=True): + self.bindings.append((binding, rule, False, last)) + + #NOTE(ewindisch): kept the following method in case we implement the + # underlying support. + #def add_negate_binding(self, binding, rule, last=True): + # self.bindings.append((binding, rule, True, last)) + + def queues(self, key): + workers = [] + + # bit is for negate bindings - if we choose to implement it. + # last stops processing rules if this matches. + for (binding, exchange, bit, last) in self.bindings: + if binding.test(key): + workers.extend(exchange.run(key)) + + # Support last. + if last: + return workers + return workers + + +class HeartbeatMatchMakerBase(MatchMakerBase): + """Base for a heart-beat capable MatchMaker. + + Provides common methods for registering, unregistering, and maintaining + heartbeats. + """ + def __init__(self): + self.hosts = set() + self._heart = None + self.host_topic = {} + + super(HeartbeatMatchMakerBase, self).__init__() + + def send_heartbeats(self): + """Send all heartbeats. + + Use start_heartbeat to spawn a heartbeat greenthread, + which loops this method. + """ + for key, host in self.host_topic: + self.ack_alive(key, host) + + def ack_alive(self, key, host): + """Acknowledge that a host.topic is alive. + + Used internally for updating heartbeats, but may also be used + publically to acknowledge a system is alive (i.e. rpc message + successfully sent to host) + """ + raise NotImplementedError("Must implement ack_alive") + + def backend_register(self, key, host): + """Implements registration logic. + + Called by register(self,key,host) + """ + raise NotImplementedError("Must implement backend_register") + + def backend_unregister(self, key, key_host): + """Implements de-registration logic. + + Called by unregister(self,key,host) + """ + raise NotImplementedError("Must implement backend_unregister") + + def register(self, key, host): + """Register a host on a backend. + + Heartbeats, if applicable, may keepalive registration. + """ + self.hosts.add(host) + self.host_topic[(key, host)] = host + key_host = '.'.join((key, host)) + + self.backend_register(key, key_host) + + self.ack_alive(key, host) + + def unregister(self, key, host): + """Unregister a topic.""" + if (key, host) in self.host_topic: + del self.host_topic[(key, host)] + + self.hosts.discard(host) + self.backend_unregister(key, '.'.join((key, host))) + + LOG.info(_("Matchmaker unregistered: %(key)s, %(host)s"), + {'key': key, 'host': host}) + + def start_heartbeat(self): + """Implementation of MatchMakerBase.start_heartbeat. + + Launches greenthread looping send_heartbeats(), + yielding for CONF.matchmaker_heartbeat_freq seconds + between iterations. + """ + if not self.hosts: + raise MatchMakerException( + _("Register before starting heartbeat.")) + + def do_heartbeat(): + while True: + self.send_heartbeats() + eventlet.sleep(CONF.matchmaker_heartbeat_freq) + + self._heart = eventlet.spawn(do_heartbeat) + + def stop_heartbeat(self): + """Destroys the heartbeat greenthread.""" + if self._heart: + self._heart.kill() + + +class DirectBinding(Binding): + """Specifies a host in the key via a '.' character. + + Although dots are used in the key, the behavior here is + that it maps directly to a host, thus direct. + """ + def test(self, key): + if '.' in key: + return True + return False + + +class TopicBinding(Binding): + """Where a 'bare' key without dots. + + AMQP generally considers topic exchanges to be those *with* dots, + but we deviate here in terminology as the behavior here matches + that of a topic exchange (whereas where there are dots, behavior + matches that of a direct exchange. + """ + def test(self, key): + if '.' not in key: + return True + return False + + +class FanoutBinding(Binding): + """Match on fanout keys, where key starts with 'fanout.' string.""" + def test(self, key): + if key.startswith('fanout~'): + return True + return False + + +class StubExchange(Exchange): + """Exchange that does nothing.""" + def run(self, key): + return [(key, None)] + + +class LocalhostExchange(Exchange): + """Exchange where all direct topics are local.""" + def __init__(self, host='localhost'): + self.host = host + super(Exchange, self).__init__() + + def run(self, key): + return [('.'.join((key.split('.')[0], self.host)), self.host)] + + +class DirectExchange(Exchange): + """Exchange where all topic keys are split, sending to second half. + + i.e. "compute.host" sends a message to "compute.host" running on "host" + """ + def __init__(self): + super(Exchange, self).__init__() + + def run(self, key): + e = key.split('.', 1)[1] + return [(key, e)] + + +class MatchMakerLocalhost(MatchMakerBase): + """Match Maker where all bare topics resolve to localhost. + + Useful for testing. + """ + def __init__(self, host='localhost'): + super(MatchMakerLocalhost, self).__init__() + self.add_binding(FanoutBinding(), LocalhostExchange(host)) + self.add_binding(DirectBinding(), DirectExchange()) + self.add_binding(TopicBinding(), LocalhostExchange(host)) + + +class MatchMakerStub(MatchMakerBase): + """Match Maker where topics are untouched. + + Useful for testing, or for AMQP/brokered queues. + Will not work where knowledge of hosts is known (i.e. zeromq) + """ + def __init__(self): + super(MatchMakerStub, self).__init__() + + self.add_binding(FanoutBinding(), StubExchange()) + self.add_binding(DirectBinding(), StubExchange()) + self.add_binding(TopicBinding(), StubExchange()) diff --git a/savanna/openstack/common/rpc/matchmaker_redis.py b/savanna/openstack/common/rpc/matchmaker_redis.py new file mode 100644 index 00000000..bd5bf0ca --- /dev/null +++ b/savanna/openstack/common/rpc/matchmaker_redis.py @@ -0,0 +1,145 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should accept a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +from oslo.config import cfg + +from savanna.openstack.common import importutils +from savanna.openstack.common import log as logging +from savanna.openstack.common.rpc import matchmaker as mm_common + +redis = importutils.try_import('redis') + + +matchmaker_redis_opts = [ + cfg.StrOpt('host', + default='127.0.0.1', + help='Host to locate redis'), + cfg.IntOpt('port', + default=6379, + help='Use this port to connect to redis host.'), + cfg.StrOpt('password', + default=None, + help='Password for Redis server. (optional)'), +] + +CONF = cfg.CONF +opt_group = cfg.OptGroup(name='matchmaker_redis', + title='Options for Redis-based MatchMaker') +CONF.register_group(opt_group) +CONF.register_opts(matchmaker_redis_opts, opt_group) +LOG = logging.getLogger(__name__) + + +class RedisExchange(mm_common.Exchange): + def __init__(self, matchmaker): + self.matchmaker = matchmaker + self.redis = matchmaker.redis + super(RedisExchange, self).__init__() + + +class RedisTopicExchange(RedisExchange): + """Exchange where all topic keys are split, sending to second half. + + i.e. "compute.host" sends a message to "compute" running on "host" + """ + def run(self, topic): + while True: + member_name = self.redis.srandmember(topic) + + if not member_name: + # If this happens, there are no + # longer any members. + break + + if not self.matchmaker.is_alive(topic, member_name): + continue + + host = member_name.split('.', 1)[1] + return [(member_name, host)] + return [] + + +class RedisFanoutExchange(RedisExchange): + """Return a list of all hosts.""" + def run(self, topic): + topic = topic.split('~', 1)[1] + hosts = self.redis.smembers(topic) + good_hosts = filter( + lambda host: self.matchmaker.is_alive(topic, host), hosts) + + return [(x, x.split('.', 1)[1]) for x in good_hosts] + + +class MatchMakerRedis(mm_common.HeartbeatMatchMakerBase): + """MatchMaker registering and looking-up hosts with a Redis server.""" + def __init__(self): + super(MatchMakerRedis, self).__init__() + + if not redis: + raise ImportError("Failed to import module redis.") + + self.redis = redis.StrictRedis( + host=CONF.matchmaker_redis.host, + port=CONF.matchmaker_redis.port, + password=CONF.matchmaker_redis.password) + + self.add_binding(mm_common.FanoutBinding(), RedisFanoutExchange(self)) + self.add_binding(mm_common.DirectBinding(), mm_common.DirectExchange()) + self.add_binding(mm_common.TopicBinding(), RedisTopicExchange(self)) + + def ack_alive(self, key, host): + topic = "%s.%s" % (key, host) + if not self.redis.expire(topic, CONF.matchmaker_heartbeat_ttl): + # If we could not update the expiration, the key + # might have been pruned. Re-register, creating a new + # key in Redis. + self.register(self.topic_host[host], host) + + def is_alive(self, topic, host): + if self.redis.ttl(host) == -1: + self.expire(topic, host) + return False + return True + + def expire(self, topic, host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.delete(host) + pipe.srem(topic, host) + pipe.execute() + + def backend_register(self, key, key_host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.sadd(key, key_host) + + # No value is needed, we just + # care if it exists. Sets aren't viable + # because only keys can expire. + pipe.set(key_host, '') + + pipe.execute() + + def backend_unregister(self, key, key_host): + with self.redis.pipeline() as pipe: + pipe.multi() + pipe.srem(key, key_host) + pipe.delete(key_host) + pipe.execute() diff --git a/savanna/openstack/common/rpc/matchmaker_ring.py b/savanna/openstack/common/rpc/matchmaker_ring.py new file mode 100644 index 00000000..46e28659 --- /dev/null +++ b/savanna/openstack/common/rpc/matchmaker_ring.py @@ -0,0 +1,110 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011-2013 Cloudscaling Group, Inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +""" +The MatchMaker classes should except a Topic or Fanout exchange key and +return keys for direct exchanges, per (approximate) AMQP parlance. +""" + +import itertools +import json + +from oslo.config import cfg + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging +from savanna.openstack.common.rpc import matchmaker as mm + + +matchmaker_opts = [ + # Matchmaker ring file + cfg.StrOpt('ringfile', + deprecated_name='matchmaker_ringfile', + deprecated_group='DEFAULT', + default='/etc/oslo/matchmaker_ring.json', + help='Matchmaker ring file (JSON)'), +] + +CONF = cfg.CONF +CONF.register_opts(matchmaker_opts, 'matchmaker_ring') +LOG = logging.getLogger(__name__) + + +class RingExchange(mm.Exchange): + """Match Maker where hosts are loaded from a static JSON formatted file. + + __init__ takes optional ring dictionary argument, otherwise + loads the ringfile from CONF.mathcmaker_ringfile. + """ + def __init__(self, ring=None): + super(RingExchange, self).__init__() + + if ring: + self.ring = ring + else: + fh = open(CONF.matchmaker_ring.ringfile, 'r') + self.ring = json.load(fh) + fh.close() + + self.ring0 = {} + for k in self.ring.keys(): + self.ring0[k] = itertools.cycle(self.ring[k]) + + def _ring_has(self, key): + if key in self.ring0: + return True + return False + + +class RoundRobinRingExchange(RingExchange): + """A Topic Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(RoundRobinRingExchange, self).__init__(ring) + + def run(self, key): + if not self._ring_has(key): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (key, ) + ) + return [] + host = next(self.ring0[key]) + return [(key + '.' + host, host)] + + +class FanoutRingExchange(RingExchange): + """Fanout Exchange based on a hashmap.""" + def __init__(self, ring=None): + super(FanoutRingExchange, self).__init__(ring) + + def run(self, key): + # Assume starts with "fanout~", strip it for lookup. + nkey = key.split('fanout~')[1:][0] + if not self._ring_has(nkey): + LOG.warn( + _("No key defining hosts for topic '%s', " + "see ringfile") % (nkey, ) + ) + return [] + return map(lambda x: (key + '.' + x, x), self.ring[nkey]) + + +class MatchMakerRing(mm.MatchMakerBase): + """Match Maker where hosts are loaded from a static hashmap.""" + def __init__(self, ring=None): + super(MatchMakerRing, self).__init__() + self.add_binding(mm.FanoutBinding(), FanoutRingExchange(ring)) + self.add_binding(mm.DirectBinding(), mm.DirectExchange()) + self.add_binding(mm.TopicBinding(), RoundRobinRingExchange(ring)) diff --git a/savanna/openstack/common/rpc/proxy.py b/savanna/openstack/common/rpc/proxy.py new file mode 100644 index 00000000..ef31267f --- /dev/null +++ b/savanna/openstack/common/rpc/proxy.py @@ -0,0 +1,226 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012-2013 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A helper class for proxy objects to remote APIs. + +For more information about rpc API version numbers, see: + rpc/dispatcher.py +""" + + +from savanna.openstack.common import rpc +from savanna.openstack.common.rpc import common as rpc_common +from savanna.openstack.common.rpc import serializer as rpc_serializer + + +class RpcProxy(object): + """A helper class for rpc clients. + + This class is a wrapper around the RPC client API. It allows you to + specify the topic and API version in a single place. This is intended to + be used as a base class for a class that implements the client side of an + rpc API. + """ + + # The default namespace, which can be overriden in a subclass. + RPC_API_NAMESPACE = None + + def __init__(self, topic, default_version, version_cap=None, + serializer=None): + """Initialize an RpcProxy. + + :param topic: The topic to use for all messages. + :param default_version: The default API version to request in all + outgoing messages. This can be overridden on a per-message + basis. + :param version_cap: Optionally cap the maximum version used for sent + messages. + :param serializer: Optionaly (de-)serialize entities with a + provided helper. + """ + self.topic = topic + self.default_version = default_version + self.version_cap = version_cap + if serializer is None: + serializer = rpc_serializer.NoOpSerializer() + self.serializer = serializer + super(RpcProxy, self).__init__() + + def _set_version(self, msg, vers): + """Helper method to set the version in a message. + + :param msg: The message having a version added to it. + :param vers: The version number to add to the message. + """ + v = vers if vers else self.default_version + if (self.version_cap and not + rpc_common.version_is_compatible(self.version_cap, v)): + raise rpc_common.RpcVersionCapError(version_cap=self.version_cap) + msg['version'] = v + + def _get_topic(self, topic): + """Return the topic to use for a message.""" + return topic if topic else self.topic + + def can_send_version(self, version): + """Check to see if a version is compatible with the version cap.""" + return (not self.version_cap or + rpc_common.version_is_compatible(self.version_cap, version)) + + @staticmethod + def make_namespaced_msg(method, namespace, **kwargs): + return {'method': method, 'namespace': namespace, 'args': kwargs} + + def make_msg(self, method, **kwargs): + return self.make_namespaced_msg(method, self.RPC_API_NAMESPACE, + **kwargs) + + def _serialize_msg_args(self, context, kwargs): + """Helper method called to serialize message arguments. + + This calls our serializer on each argument, returning a new + set of args that have been serialized. + + :param context: The request context + :param kwargs: The arguments to serialize + :returns: A new set of serialized arguments + """ + new_kwargs = dict() + for argname, arg in kwargs.iteritems(): + new_kwargs[argname] = self.serializer.serialize_entity(context, + arg) + return new_kwargs + + def call(self, context, msg, topic=None, version=None, timeout=None): + """rpc.call() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + + :returns: The return value from the remote method. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + real_topic = self._get_topic(topic) + try: + result = rpc.call(context, real_topic, msg, timeout) + return self.serializer.deserialize_entity(context, result) + except rpc.common.Timeout as exc: + raise rpc.common.Timeout( + exc.info, real_topic, msg.get('method')) + + def multicall(self, context, msg, topic=None, version=None, timeout=None): + """rpc.multicall() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + :param timeout: (Optional) A timeout to use when waiting for the + response. If no timeout is specified, a default timeout will be + used that is usually sufficient. + + :returns: An iterator that lets you process each of the returned values + from the remote method as they arrive. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + real_topic = self._get_topic(topic) + try: + result = rpc.multicall(context, real_topic, msg, timeout) + return self.serializer.deserialize_entity(context, result) + except rpc.common.Timeout as exc: + raise rpc.common.Timeout( + exc.info, real_topic, msg.get('method')) + + def cast(self, context, msg, topic=None, version=None): + """rpc.cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast() does not wait on any return value from the + remote method. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + rpc.cast(context, self._get_topic(topic), msg) + + def fanout_cast(self, context, msg, topic=None, version=None): + """rpc.fanout_cast() a remote method. + + :param context: The request context + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast() does not wait on any return value + from the remote method. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + rpc.fanout_cast(context, self._get_topic(topic), msg) + + def cast_to_server(self, context, server_params, msg, topic=None, + version=None): + """rpc.cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + rpc.cast_to_server(context, server_params, self._get_topic(topic), msg) + + def fanout_cast_to_server(self, context, server_params, msg, topic=None, + version=None): + """rpc.fanout_cast_to_server() a remote method. + + :param context: The request context + :param server_params: Server parameters. See rpc.cast_to_server() for + details. + :param msg: The message to send, including the method and args. + :param topic: Override the topic for this message. + :param version: (Optional) Override the requested API version in this + message. + + :returns: None. rpc.fanout_cast_to_server() does not wait on any + return values. + """ + self._set_version(msg, version) + msg['args'] = self._serialize_msg_args(context, msg['args']) + rpc.fanout_cast_to_server(context, server_params, + self._get_topic(topic), msg) diff --git a/savanna/openstack/common/rpc/serializer.py b/savanna/openstack/common/rpc/serializer.py new file mode 100644 index 00000000..76c68310 --- /dev/null +++ b/savanna/openstack/common/rpc/serializer.py @@ -0,0 +1,52 @@ +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Provides the definition of an RPC serialization handler""" + +import abc + + +class Serializer(object): + """Generic (de-)serialization definition base class.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def serialize_entity(self, context, entity): + """Serialize something to primitive form. + + :param context: Security context + :param entity: Entity to be serialized + :returns: Serialized form of entity + """ + pass + + @abc.abstractmethod + def deserialize_entity(self, context, entity): + """Deserialize something from primitive form. + + :param context: Security context + :param entity: Primitive to be deserialized + :returns: Deserialized form of entity + """ + pass + + +class NoOpSerializer(Serializer): + """A serializer that does nothing.""" + + def serialize_entity(self, context, entity): + return entity + + def deserialize_entity(self, context, entity): + return entity diff --git a/savanna/openstack/common/rpc/service.py b/savanna/openstack/common/rpc/service.py new file mode 100644 index 00000000..79ecda5f --- /dev/null +++ b/savanna/openstack/common/rpc/service.py @@ -0,0 +1,76 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# Copyright 2011 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import log as logging +from savanna.openstack.common import rpc +from savanna.openstack.common.rpc import dispatcher as rpc_dispatcher +from savanna.openstack.common import service + + +LOG = logging.getLogger(__name__) + + +class Service(service.Service): + """Service object for binaries running on hosts. + + A service enables rpc by listening to queues based on topic and host. + """ + def __init__(self, host, topic, manager=None): + super(Service, self).__init__() + self.host = host + self.topic = topic + if manager is None: + self.manager = self + else: + self.manager = manager + + def start(self): + super(Service, self).start() + + self.conn = rpc.create_connection(new=True) + LOG.debug(_("Creating Consumer connection for Service %s") % + self.topic) + + dispatcher = rpc_dispatcher.RpcDispatcher([self.manager]) + + # Share this same connection for these Consumers + self.conn.create_consumer(self.topic, dispatcher, fanout=False) + + node_topic = '%s.%s' % (self.topic, self.host) + self.conn.create_consumer(node_topic, dispatcher, fanout=False) + + self.conn.create_consumer(self.topic, dispatcher, fanout=True) + + # Hook to allow the manager to do other initializations after + # the rpc connection is created. + if callable(getattr(self.manager, 'initialize_service_hook', None)): + self.manager.initialize_service_hook(self) + + # Consume from all consumers in a thread + self.conn.consume_in_thread() + + def stop(self): + # Try to shut the connection down, but if we get any sort of + # errors, go ahead and ignore them.. as we're shutting down anyway + try: + self.conn.close() + except Exception: + pass + super(Service, self).stop() diff --git a/savanna/openstack/common/rpc/zmq_receiver.py b/savanna/openstack/common/rpc/zmq_receiver.py new file mode 100755 index 00000000..61d8d713 --- /dev/null +++ b/savanna/openstack/common/rpc/zmq_receiver.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import eventlet +eventlet.monkey_patch() + +import contextlib +import sys + +from oslo.config import cfg + +from savanna.openstack.common import log as logging +from savanna.openstack.common import rpc +from savanna.openstack.common.rpc import impl_zmq + +CONF = cfg.CONF +CONF.register_opts(rpc.rpc_opts) +CONF.register_opts(impl_zmq.zmq_opts) + + +def main(): + CONF(sys.argv[1:], project='oslo') + logging.setup("oslo") + + with contextlib.closing(impl_zmq.ZmqProxy(CONF)) as reactor: + reactor.consume_in_thread() + reactor.wait() diff --git a/savanna/openstack/common/service.py b/savanna/openstack/common/service.py new file mode 100644 index 00000000..f4fcaab8 --- /dev/null +++ b/savanna/openstack/common/service.py @@ -0,0 +1,376 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Generic Node base class for all workers that run on hosts.""" + +import errno +import os +import random +import signal +import sys +import time + +import eventlet +from eventlet import event +import logging as std_logging +from oslo.config import cfg + +from savanna.openstack.common import eventlet_backdoor +from savanna.openstack.common.gettextutils import _ # noqa +from savanna.openstack.common import importutils +from savanna.openstack.common import log as logging +from savanna.openstack.common import threadgroup + + +rpc = importutils.try_import('savanna.openstack.common.rpc') +CONF = cfg.CONF +LOG = logging.getLogger(__name__) + + +class Launcher(object): + """Launch one or more services and wait for them to complete.""" + + def __init__(self): + """Initialize the service launcher. + + :returns: None + + """ + self.services = Services() + self.backdoor_port = eventlet_backdoor.initialize_if_enabled() + + def launch_service(self, service): + """Load and start the given service. + + :param service: The service you would like to start. + :returns: None + + """ + service.backdoor_port = self.backdoor_port + self.services.add(service) + + def stop(self): + """Stop all services which are currently running. + + :returns: None + + """ + self.services.stop() + + def wait(self): + """Waits until all services have been stopped, and then returns. + + :returns: None + + """ + self.services.wait() + + +class SignalExit(SystemExit): + def __init__(self, signo, exccode=1): + super(SignalExit, self).__init__(exccode) + self.signo = signo + + +class ServiceLauncher(Launcher): + def _handle_signal(self, signo, frame): + # Allow the process to be killed again and die from natural causes + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + + raise SignalExit(signo) + + def wait(self): + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + status = None + try: + super(ServiceLauncher, self).wait() + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + except SystemExit as exc: + status = exc.code + finally: + self.stop() + if rpc: + try: + rpc.cleanup() + except Exception: + # We're shutting down, so it doesn't matter at this point. + LOG.exception(_('Exception during rpc cleanup.')) + return status + + +class ServiceWrapper(object): + def __init__(self, service, workers): + self.service = service + self.workers = workers + self.children = set() + self.forktimes = [] + + +class ProcessLauncher(object): + def __init__(self): + self.children = {} + self.sigcaught = None + self.running = True + rfd, self.writepipe = os.pipe() + self.readpipe = eventlet.greenio.GreenPipe(rfd, 'r') + + signal.signal(signal.SIGTERM, self._handle_signal) + signal.signal(signal.SIGINT, self._handle_signal) + + def _handle_signal(self, signo, frame): + self.sigcaught = signo + self.running = False + + # Allow the process to be killed again and die from natural causes + signal.signal(signal.SIGTERM, signal.SIG_DFL) + signal.signal(signal.SIGINT, signal.SIG_DFL) + + def _pipe_watcher(self): + # This will block until the write end is closed when the parent + # dies unexpectedly + self.readpipe.read() + + LOG.info(_('Parent process has died unexpectedly, exiting')) + + sys.exit(1) + + def _child_process(self, service): + # Setup child signal handlers differently + def _sigterm(*args): + signal.signal(signal.SIGTERM, signal.SIG_DFL) + raise SignalExit(signal.SIGTERM) + + signal.signal(signal.SIGTERM, _sigterm) + # Block SIGINT and let the parent send us a SIGTERM + signal.signal(signal.SIGINT, signal.SIG_IGN) + + # Reopen the eventlet hub to make sure we don't share an epoll + # fd with parent and/or siblings, which would be bad + eventlet.hubs.use_hub() + + # Close write to ensure only parent has it open + os.close(self.writepipe) + # Create greenthread to watch for parent to close pipe + eventlet.spawn_n(self._pipe_watcher) + + # Reseed random number generator + random.seed() + + launcher = Launcher() + launcher.launch_service(service) + launcher.wait() + + def _start_child(self, wrap): + if len(wrap.forktimes) > wrap.workers: + # Limit ourselves to one process a second (over the period of + # number of workers * 1 second). This will allow workers to + # start up quickly but ensure we don't fork off children that + # die instantly too quickly. + if time.time() - wrap.forktimes[0] < wrap.workers: + LOG.info(_('Forking too fast, sleeping')) + time.sleep(1) + + wrap.forktimes.pop(0) + + wrap.forktimes.append(time.time()) + + pid = os.fork() + if pid == 0: + # NOTE(johannes): All exceptions are caught to ensure this + # doesn't fallback into the loop spawning children. It would + # be bad for a child to spawn more children. + status = 0 + try: + self._child_process(wrap.service) + except SignalExit as exc: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[exc.signo] + LOG.info(_('Caught %s, exiting'), signame) + status = exc.code + except SystemExit as exc: + status = exc.code + except BaseException: + LOG.exception(_('Unhandled exception')) + status = 2 + finally: + wrap.service.stop() + + os._exit(status) + + LOG.info(_('Started child %d'), pid) + + wrap.children.add(pid) + self.children[pid] = wrap + + return pid + + def launch_service(self, service, workers=1): + wrap = ServiceWrapper(service, workers) + + LOG.info(_('Starting %d workers'), wrap.workers) + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + def _wait_child(self): + try: + # Don't block if no child processes have exited + pid, status = os.waitpid(0, os.WNOHANG) + if not pid: + return None + except OSError as exc: + if exc.errno not in (errno.EINTR, errno.ECHILD): + raise + return None + + if os.WIFSIGNALED(status): + sig = os.WTERMSIG(status) + LOG.info(_('Child %(pid)d killed by signal %(sig)d'), + dict(pid=pid, sig=sig)) + else: + code = os.WEXITSTATUS(status) + LOG.info(_('Child %(pid)s exited with status %(code)d'), + dict(pid=pid, code=code)) + + if pid not in self.children: + LOG.warning(_('pid %d not in child list'), pid) + return None + + wrap = self.children.pop(pid) + wrap.children.remove(pid) + return wrap + + def wait(self): + """Loop waiting on children to die and respawning as necessary.""" + + LOG.debug(_('Full set of CONF:')) + CONF.log_opt_values(LOG, std_logging.DEBUG) + + while self.running: + wrap = self._wait_child() + if not wrap: + # Yield to other threads if no children have exited + # Sleep for a short time to avoid excessive CPU usage + # (see bug #1095346) + eventlet.greenthread.sleep(.01) + continue + + while self.running and len(wrap.children) < wrap.workers: + self._start_child(wrap) + + if self.sigcaught: + signame = {signal.SIGTERM: 'SIGTERM', + signal.SIGINT: 'SIGINT'}[self.sigcaught] + LOG.info(_('Caught %s, stopping children'), signame) + + for pid in self.children: + try: + os.kill(pid, signal.SIGTERM) + except OSError as exc: + if exc.errno != errno.ESRCH: + raise + + # Wait for children to die + if self.children: + LOG.info(_('Waiting on %d children to exit'), len(self.children)) + while self.children: + self._wait_child() + + +class Service(object): + """Service object for binaries running on hosts.""" + + def __init__(self, threads=1000): + self.tg = threadgroup.ThreadGroup(threads) + + # signal that the service is done shutting itself down: + self._done = event.Event() + + def start(self): + pass + + def stop(self): + self.tg.stop() + self.tg.wait() + # Signal that service cleanup is done: + if not self._done.ready(): + self._done.send() + + def wait(self): + self._done.wait() + + +class Services(object): + + def __init__(self): + self.services = [] + self.tg = threadgroup.ThreadGroup() + self.done = event.Event() + + def add(self, service): + self.services.append(service) + self.tg.add_thread(self.run_service, service, self.done) + + def stop(self): + # wait for graceful shutdown of services: + for service in self.services: + service.stop() + service.wait() + + # Each service has performed cleanup, now signal that the run_service + # wrapper threads can now die: + if not self.done.ready(): + self.done.send() + + # reap threads: + self.tg.stop() + + def wait(self): + self.tg.wait() + + @staticmethod + def run_service(service, done): + """Service start wrapper. + + :param service: service to run + :param done: event to wait on until a shutdown is triggered + :returns: None + + """ + service.start() + done.wait() + + +def launch(service, workers=None): + if workers: + launcher = ProcessLauncher() + launcher.launch_service(service, workers=workers) + else: + launcher = ServiceLauncher() + launcher.launch_service(service) + return launcher diff --git a/savanna/openstack/common/sslutils.py b/savanna/openstack/common/sslutils.py new file mode 100644 index 00000000..c7d75756 --- /dev/null +++ b/savanna/openstack/common/sslutils.py @@ -0,0 +1,100 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os +import ssl + +from oslo.config import cfg + +from savanna.openstack.common.gettextutils import _ # noqa + + +ssl_opts = [ + cfg.StrOpt('ca_file', + default=None, + help="CA certificate file to use to verify " + "connecting clients"), + cfg.StrOpt('cert_file', + default=None, + help="Certificate file to use when starting " + "the server securely"), + cfg.StrOpt('key_file', + default=None, + help="Private key file to use when starting " + "the server securely"), +] + + +CONF = cfg.CONF +CONF.register_opts(ssl_opts, "ssl") + + +def is_enabled(): + cert_file = CONF.ssl.cert_file + key_file = CONF.ssl.key_file + ca_file = CONF.ssl.ca_file + use_ssl = cert_file or key_file + + if cert_file and not os.path.exists(cert_file): + raise RuntimeError(_("Unable to find cert_file : %s") % cert_file) + + if ca_file and not os.path.exists(ca_file): + raise RuntimeError(_("Unable to find ca_file : %s") % ca_file) + + if key_file and not os.path.exists(key_file): + raise RuntimeError(_("Unable to find key_file : %s") % key_file) + + if use_ssl and (not cert_file or not key_file): + raise RuntimeError(_("When running server in SSL mode, you must " + "specify both a cert_file and key_file " + "option value in your configuration file")) + + return use_ssl + + +def wrap(sock): + ssl_kwargs = { + 'server_side': True, + 'certfile': CONF.ssl.cert_file, + 'keyfile': CONF.ssl.key_file, + 'cert_reqs': ssl.CERT_NONE, + } + + if CONF.ssl.ca_file: + ssl_kwargs['ca_certs'] = CONF.ssl.ca_file + ssl_kwargs['cert_reqs'] = ssl.CERT_REQUIRED + + return ssl.wrap_socket(sock, **ssl_kwargs) + + +_SSL_PROTOCOLS = { + "tlsv1": ssl.PROTOCOL_TLSv1, + "sslv23": ssl.PROTOCOL_SSLv23, + "sslv3": ssl.PROTOCOL_SSLv3 +} + +try: + _SSL_PROTOCOLS["sslv2"] = ssl.PROTOCOL_SSLv2 +except AttributeError: + pass + + +def validate_ssl_version(version): + key = version.lower() + try: + return _SSL_PROTOCOLS[key] + except KeyError: + raise RuntimeError(_("Invalid SSL version : %s") % version) diff --git a/savanna/openstack/common/threadgroup.py b/savanna/openstack/common/threadgroup.py index 88b53ed1..c1def1a9 100644 --- a/savanna/openstack/common/threadgroup.py +++ b/savanna/openstack/common/threadgroup.py @@ -14,7 +14,7 @@ # License for the specific language governing permissions and limitations # under the License. -from eventlet import greenlet +import eventlet from eventlet import greenpool from eventlet import greenthread @@ -105,7 +105,7 @@ class ThreadGroup(object): for x in self.timers: try: x.wait() - except greenlet.GreenletExit: + except eventlet.greenlet.GreenletExit: pass except Exception as ex: LOG.exception(ex) @@ -115,7 +115,7 @@ class ThreadGroup(object): continue try: x.wait() - except greenlet.GreenletExit: + except eventlet.greenlet.GreenletExit: pass except Exception as ex: LOG.exception(ex) diff --git a/savanna/openstack/common/timeutils.py b/savanna/openstack/common/timeutils.py index ac2441bc..bd60489e 100644 --- a/savanna/openstack/common/timeutils.py +++ b/savanna/openstack/common/timeutils.py @@ -23,6 +23,7 @@ import calendar import datetime import iso8601 +import six # ISO 8601 extended time format with microseconds @@ -75,14 +76,14 @@ def normalize_time(timestamp): def is_older_than(before, seconds): """Return True if before is older than seconds.""" - if isinstance(before, basestring): + if isinstance(before, six.string_types): before = parse_strtime(before).replace(tzinfo=None) return utcnow() - before > datetime.timedelta(seconds=seconds) def is_newer_than(after, seconds): """Return True if after is newer than seconds.""" - if isinstance(after, basestring): + if isinstance(after, six.string_types): after = parse_strtime(after).replace(tzinfo=None) return after - utcnow() > datetime.timedelta(seconds=seconds) diff --git a/tools/pip-requires b/tools/pip-requires index 88be908e..e1125b63 100644 --- a/tools/pip-requires +++ b/tools/pip-requires @@ -3,8 +3,8 @@ alembic>=0.4.1 eventlet>=0.9.12 flask==0.9 jsonschema>=1.0.0,!=1.4.0,<2 +kombu>=2.4.8 netaddr -oslo.config>=1.1.0 paramiko>=1.8.0 pbr>=0.5.16,<0.6 pycrypto>=2.6 @@ -14,3 +14,6 @@ python-novaclient>=2.12.0,<3 six sqlalchemy>=0.7,<=0.7.99 webob>=1.2.3,<1.3 + +-f http://tarballs.openstack.org/oslo.config/oslo.config-1.2.0a3.tar.gz#egg=oslo.config-1.2.0a3 +oslo.config>=1.2.0a3