diff --git a/mistral/api/__init__.py b/mistral/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/api/controllers/__init__.py b/mistral/api/controllers/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/cmd/__init__.py b/mistral/cmd/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/cmd/api.py b/mistral/cmd/api.py
new file mode 100644
index 00000000..e16dd930
--- /dev/null
+++ b/mistral/cmd/api.py
@@ -0,0 +1,36 @@
+#!/usr/bin/env python
+#
+# Copyright (c) 2013 Mirantis, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import os
+import sys
+
+possible_topdir = os.path.normpath(os.path.join(os.path.abspath(__file__),
+                                                os.pardir,
+                                                os.pardir,
+                                                os.pardir))
+if os.path.exists(os.path.join(possible_topdir, 'mistral', '__init__.py')):
+    sys.path.insert(0, possible_topdir)
+
+#from mistral import config
+#from mistral.openstack.common import log
+
+
+def main():
+    raise NotImplementedError('Mistral API is not implemented yet.')
+
+
+if __name__ == '__main__':
+    main()
diff --git a/mistral/db/__init__.py b/mistral/db/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/engine/__init__.py b/mistral/engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/openstack/__init__.py b/mistral/openstack/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/openstack/common/__init__.py b/mistral/openstack/common/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/openstack/common/apiclient/__init__.py b/mistral/openstack/common/apiclient/__init__.py
new file mode 100644
index 00000000..d5d00222
--- /dev/null
+++ b/mistral/openstack/common/apiclient/__init__.py
@@ -0,0 +1,16 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
diff --git a/mistral/openstack/common/apiclient/auth.py b/mistral/openstack/common/apiclient/auth.py
new file mode 100644
index 00000000..f1df136c
--- /dev/null
+++ b/mistral/openstack/common/apiclient/auth.py
@@ -0,0 +1,227 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2013 OpenStack Foundation
+# Copyright 2013 Spanish National Research Council.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import abc +import argparse +import logging +import os + +import six +from stevedore import extension + +from mistral.openstack.common.apiclient import exceptions + + +logger = logging.getLogger(__name__) + + +_discovered_plugins = {} + + +def discover_auth_systems(): + """Discover the available auth-systems. + + This won't take into account the old style auth-systems. + """ + global _discovered_plugins + _discovered_plugins = {} + + def add_plugin(ext): + _discovered_plugins[ext.name] = ext.plugin + + ep_namespace = "mistral.openstack.common.apiclient.auth" + mgr = extension.ExtensionManager(ep_namespace) + mgr.map(add_plugin) + + +def load_auth_system_opts(parser): + """Load options needed by the available auth-systems into a parser. + + This function will try to populate the parser with options from the + available plugins. + """ + group = parser.add_argument_group("Common auth options") + BaseAuthPlugin.add_common_opts(group) + for name, auth_plugin in _discovered_plugins.iteritems(): + group = parser.add_argument_group( + "Auth-system '%s' options" % name, + conflict_handler="resolve") + auth_plugin.add_opts(group) + + +def load_plugin(auth_system): + try: + plugin_class = _discovered_plugins[auth_system] + except KeyError: + raise exceptions.AuthSystemNotFound(auth_system) + return plugin_class(auth_system=auth_system) + + +def load_plugin_from_args(args): + """Load required plugin and populate it with options. + + Try to guess auth system if it is not specified. Systems are tried in + alphabetical order. + + :type args: argparse.Namespace + :raises: AuthorizationFailure + """ + auth_system = args.os_auth_system + if auth_system: + plugin = load_plugin(auth_system) + plugin.parse_opts(args) + plugin.sufficient_options() + return plugin + + for plugin_auth_system in sorted(_discovered_plugins.iterkeys()): + plugin_class = _discovered_plugins[plugin_auth_system] + plugin = plugin_class() + plugin.parse_opts(args) + try: + plugin.sufficient_options() + except exceptions.AuthPluginOptionsMissing: + continue + return plugin + raise exceptions.AuthPluginOptionsMissing(["auth_system"]) + + +@six.add_metaclass(abc.ABCMeta) +class BaseAuthPlugin(object): + """Base class for authentication plugins. + + An authentication plugin needs to override at least the authenticate + method to be a valid plugin. + """ + + auth_system = None + opt_names = [] + common_opt_names = [ + "auth_system", + "username", + "password", + "tenant_name", + "token", + "auth_url", + ] + + def __init__(self, auth_system=None, **kwargs): + self.auth_system = auth_system or self.auth_system + self.opts = dict((name, kwargs.get(name)) + for name in self.opt_names) + + @staticmethod + def _parser_add_opt(parser, opt): + """Add an option to parser in two variants. + + :param opt: option name (with underscores) + """ + dashed_opt = opt.replace("_", "-") + env_var = "OS_%s" % opt.upper() + arg_default = os.environ.get(env_var, "") + arg_help = "Defaults to env[%s]." 
% env_var + parser.add_argument( + "--os-%s" % dashed_opt, + metavar="<%s>" % dashed_opt, + default=arg_default, + help=arg_help) + parser.add_argument( + "--os_%s" % opt, + metavar="<%s>" % dashed_opt, + help=argparse.SUPPRESS) + + @classmethod + def add_opts(cls, parser): + """Populate the parser with the options for this plugin. + """ + for opt in cls.opt_names: + # use `BaseAuthPlugin.common_opt_names` since it is never + # changed in child classes + if opt not in BaseAuthPlugin.common_opt_names: + cls._parser_add_opt(parser, opt) + + @classmethod + def add_common_opts(cls, parser): + """Add options that are common for several plugins. + """ + for opt in cls.common_opt_names: + cls._parser_add_opt(parser, opt) + + @staticmethod + def get_opt(opt_name, args): + """Return option name and value. + + :param opt_name: name of the option, e.g., "username" + :param args: parsed arguments + """ + return (opt_name, getattr(args, "os_%s" % opt_name, None)) + + def parse_opts(self, args): + """Parse the actual auth-system options if any. + + This method is expected to populate the attribute `self.opts` with a + dict containing the options and values needed to make authentication. + """ + self.opts.update(dict(self.get_opt(opt_name, args) + for opt_name in self.opt_names)) + + def authenticate(self, http_client): + """Authenticate using plugin defined method. + + The method usually analyses `self.opts` and performs + a request to authentication server. + + :param http_client: client object that needs authentication + :type http_client: HTTPClient + :raises: AuthorizationFailure + """ + self.sufficient_options() + self._do_authenticate(http_client) + + @abc.abstractmethod + def _do_authenticate(self, http_client): + """Protected method for authentication. + """ + + def sufficient_options(self): + """Check if all required options are present. + + :raises: AuthPluginOptionsMissing + """ + missing = [opt + for opt in self.opt_names + if not self.opts.get(opt)] + if missing: + raise exceptions.AuthPluginOptionsMissing(missing) + + @abc.abstractmethod + def token_and_endpoint(self, endpoint_type, service_type): + """Return token and endpoint. + + :param service_type: Service type of the endpoint + :type service_type: string + :param endpoint_type: Type of endpoint. + Possible values: public or publicURL, + internal or internalURL, + admin or adminURL + :type endpoint_type: string + :returns: tuple of token and endpoint strings + :raises: EndpointException + """ diff --git a/mistral/openstack/common/apiclient/base.py b/mistral/openstack/common/apiclient/base.py new file mode 100644 index 00000000..75dae495 --- /dev/null +++ b/mistral/openstack/common/apiclient/base.py @@ -0,0 +1,493 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 OpenStack Foundation +# Copyright 2012 Grid Dynamics +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
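Since auth.py above defines the whole plugin contract (common option handling plus the two abstract methods), a hedged sketch of a concrete plugin may help. It is illustrative only: the `StaticTokenPlugin` name, its option subset, and the static token/endpoint behaviour are assumptions, not part of this patch.

```python
from mistral.openstack.common.apiclient import auth


class StaticTokenPlugin(auth.BaseAuthPlugin):
    """Hypothetical plugin: a statically configured token and endpoint."""

    auth_system = "static-token"       # made-up auth-system name
    opt_names = ["token", "auth_url"]  # subset of the common options

    def _do_authenticate(self, http_client):
        # Nothing to exchange with a server: the token is supplied directly.
        pass

    def token_and_endpoint(self, endpoint_type, service_type):
        # Ignore the filters and hand back the configured pair.
        return self.opts["token"], self.opts["auth_url"]


plugin = StaticTokenPlugin(token="secret",
                           auth_url="http://keystone:5000/v2.0")
plugin.sufficient_options()  # raises AuthPluginOptionsMissing if one is unset
```

If such a class were registered under the `mistral.openstack.common.apiclient.auth` entry-point namespace, `discover_auth_systems()` would pick it up and `load_plugin("static-token")` would instantiate it.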
+ +""" +Base utilities to build API operation managers and objects on top of. +""" + +# E1102: %s is not callable +# pylint: disable=E1102 + +import abc +import urllib + +import six + +from mistral.openstack.common.apiclient import exceptions +from mistral.openstack.common import strutils + + +def getid(obj): + """Return id if argument is a Resource. + + Abstracts the common pattern of allowing both an object or an object's ID + (UUID) as a parameter when dealing with relationships. + """ + try: + if obj.uuid: + return obj.uuid + except AttributeError: + pass + try: + return obj.id + except AttributeError: + return obj + + +# TODO(aababilov): call run_hooks() in HookableMixin's child classes +class HookableMixin(object): + """Mixin so classes can register and run hooks.""" + _hooks_map = {} + + @classmethod + def add_hook(cls, hook_type, hook_func): + """Add a new hook of specified type. + + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param hook_func: hook function + """ + if hook_type not in cls._hooks_map: + cls._hooks_map[hook_type] = [] + + cls._hooks_map[hook_type].append(hook_func) + + @classmethod + def run_hooks(cls, hook_type, *args, **kwargs): + """Run all hooks of specified type. + + :param cls: class that registers hooks + :param hook_type: hook type, e.g., '__pre_parse_args__' + :param **args: args to be passed to every hook function + :param **kwargs: kwargs to be passed to every hook function + """ + hook_funcs = cls._hooks_map.get(hook_type) or [] + for hook_func in hook_funcs: + hook_func(*args, **kwargs) + + +class BaseManager(HookableMixin): + """Basic manager type providing common operations. + + Managers interact with a particular type of API (servers, flavors, images, + etc.) and provide CRUD operations for them. + """ + resource_class = None + + def __init__(self, client): + """Initializes BaseManager with `client`. + + :param client: instance of BaseClient descendant for HTTP requests + """ + super(BaseManager, self).__init__() + self.client = client + + def _list(self, url, response_key, obj_class=None, json=None): + """List the collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'servers' + :param obj_class: class for constructing the returned objects + (self.resource_class will be used by default) + :param json: data that will be encoded as JSON and passed in POST + request (GET will be sent by default) + """ + if json: + body = self.client.post(url, json=json).json() + else: + body = self.client.get(url).json() + + if obj_class is None: + obj_class = self.resource_class + + data = body[response_key] + # NOTE(ja): keystone returns values as list as {'values': [ ... ]} + # unlike other services which just return the list... + try: + data = data['values'] + except (KeyError, TypeError): + pass + + return [obj_class(self, res, loaded=True) for res in data if res] + + def _get(self, url, response_key): + """Get an object from collection. + + :param url: a partial URL, e.g., '/servers' + :param response_key: the key to be looked up in response dictionary, + e.g., 'server' + """ + body = self.client.get(url).json() + return self.resource_class(self, body[response_key], loaded=True) + + def _head(self, url): + """Retrieve request headers for an object. 
+
+        :param url: a partial URL, e.g., '/servers'
+        """
+        resp = self.client.head(url)
+        return resp.status_code == 204
+
+    def _post(self, url, json, response_key, return_raw=False):
+        """Create an object.
+
+        :param url: a partial URL, e.g., '/servers'
+        :param json: data that will be encoded as JSON and passed in
+            the POST request body
+        :param response_key: the key to be looked up in response dictionary,
+            e.g., 'servers'
+        :param return_raw: flag to force returning raw JSON instead of
+            Python object of self.resource_class
+        """
+        body = self.client.post(url, json=json).json()
+        if return_raw:
+            return body[response_key]
+        return self.resource_class(self, body[response_key])
+
+    def _put(self, url, json=None, response_key=None):
+        """Update an object with PUT method.
+
+        :param url: a partial URL, e.g., '/servers'
+        :param json: data that will be encoded as JSON and passed in
+            the PUT request body
+        :param response_key: the key to be looked up in response dictionary,
+            e.g., 'servers'
+        """
+        resp = self.client.put(url, json=json)
+        # PUT requests may not return a body
+        if resp.content:
+            body = resp.json()
+            if response_key is not None:
+                return self.resource_class(self, body[response_key])
+            else:
+                return self.resource_class(self, body)
+
+    def _patch(self, url, json=None, response_key=None):
+        """Update an object with PATCH method.
+
+        :param url: a partial URL, e.g., '/servers'
+        :param json: data that will be encoded as JSON and passed in
+            the PATCH request body
+        :param response_key: the key to be looked up in response dictionary,
+            e.g., 'servers'
+        """
+        body = self.client.patch(url, json=json).json()
+        if response_key is not None:
+            return self.resource_class(self, body[response_key])
+        else:
+            return self.resource_class(self, body)
+
+    def _delete(self, url):
+        """Delete an object.
+
+        :param url: a partial URL, e.g., '/servers/my-server'
+        """
+        return self.client.delete(url)
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ManagerWithFind(BaseManager):
+    """Manager with additional `find()`/`findall()` methods."""
+
+    @abc.abstractmethod
+    def list(self):
+        pass
+
+    def find(self, **kwargs):
+        """Find a single item with attributes matching ``**kwargs``.
+
+        This isn't very efficient: it loads the entire list then filters on
+        the Python side.
+        """
+        matches = self.findall(**kwargs)
+        num_matches = len(matches)
+        if num_matches == 0:
+            msg = "No %s matching %s." % (self.resource_class.__name__, kwargs)
+            raise exceptions.NotFound(msg)
+        elif num_matches > 1:
+            raise exceptions.NoUniqueMatch()
+        else:
+            return matches[0]
+
+    def findall(self, **kwargs):
+        """Find all items with attributes matching ``**kwargs``.
+
+        This isn't very efficient: it loads the entire list then filters on
+        the Python side.
+        """
+        found = []
+        searches = kwargs.items()
+
+        for obj in self.list():
+            try:
+                if all(getattr(obj, attr) == value
+                       for (attr, value) in searches):
+                    found.append(obj)
+            except AttributeError:
+                continue
+
+        return found
+
+
+class CrudManager(BaseManager):
+    """Base manager class for manipulating entities.
+
+    Children of this class are expected to define a `collection_key` and
+    `key`.
+
+    - `collection_key`: Usually a plural noun by convention (e.g. `entities`);
+      used to refer to collections in both URLs (e.g. `/v3/entities`) and JSON
+      objects containing a list of member resources (e.g. `{'entities': [{},
+      {}, {}]}`).
+    - `key`: Usually a singular noun by convention (e.g. `entity`); used to
+      refer to an individual member of the collection.
+
+    """
+    collection_key = None
+    key = None
+
+    def build_url(self, base_url=None, **kwargs):
+        """Builds a resource URL for the given kwargs.
+
+        Given an example collection where `collection_key = 'entities'` and
+        `key = 'entity'`, the following URLs could be generated.
+
+        By default, the URL will represent a collection of entities, e.g.::
+
+            /entities
+
+        If kwargs contains an `entity_id`, then the URL will represent a
+        specific member, e.g.::
+
+            /entities/{entity_id}
+
+        :param base_url: if provided, the generated URL will be appended to it
+        """
+        url = base_url if base_url is not None else ''
+
+        url += '/%s' % self.collection_key
+
+        # do we have a specific entity?
+        entity_id = kwargs.get('%s_id' % self.key)
+        if entity_id is not None:
+            url += '/%s' % entity_id
+
+        return url
+
+    def _filter_kwargs(self, kwargs):
+        """Drop null values and handle ids."""
+        for key, ref in kwargs.copy().iteritems():
+            if ref is None:
+                kwargs.pop(key)
+            else:
+                if isinstance(ref, Resource):
+                    kwargs.pop(key)
+                    kwargs['%s_id' % key] = getid(ref)
+        return kwargs
+
+    def create(self, **kwargs):
+        kwargs = self._filter_kwargs(kwargs)
+        return self._post(
+            self.build_url(**kwargs),
+            {self.key: kwargs},
+            self.key)
+
+    def get(self, **kwargs):
+        kwargs = self._filter_kwargs(kwargs)
+        return self._get(
+            self.build_url(**kwargs),
+            self.key)
+
+    def head(self, **kwargs):
+        kwargs = self._filter_kwargs(kwargs)
+        return self._head(self.build_url(**kwargs))
+
+    def list(self, base_url=None, **kwargs):
+        """List the collection.
+
+        :param base_url: if provided, the generated URL will be appended to it
+        """
+        kwargs = self._filter_kwargs(kwargs)
+
+        return self._list(
+            '%(base_url)s%(query)s' % {
+                'base_url': self.build_url(base_url=base_url, **kwargs),
+                'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '',
+            },
+            self.collection_key)
+
+    def put(self, base_url=None, **kwargs):
+        """Update an element.
+
+        :param base_url: if provided, the generated URL will be appended to it
+        """
+        kwargs = self._filter_kwargs(kwargs)
+
+        return self._put(self.build_url(base_url=base_url, **kwargs))
+
+    def update(self, **kwargs):
+        kwargs = self._filter_kwargs(kwargs)
+        params = kwargs.copy()
+        params.pop('%s_id' % self.key)
+
+        return self._patch(
+            self.build_url(**kwargs),
+            {self.key: params},
+            self.key)
+
+    def delete(self, **kwargs):
+        kwargs = self._filter_kwargs(kwargs)
+
+        return self._delete(
+            self.build_url(**kwargs))
+
+    def find(self, base_url=None, **kwargs):
+        """Find a single item with attributes matching ``**kwargs``.
+
+        :param base_url: if provided, the generated URL will be appended to it
+        """
+        kwargs = self._filter_kwargs(kwargs)
+
+        rl = self._list(
+            '%(base_url)s%(query)s' % {
+                'base_url': self.build_url(base_url=base_url, **kwargs),
+                'query': '?%s' % urllib.urlencode(kwargs) if kwargs else '',
+            },
+            self.collection_key)
+        num = len(rl)
+
+        if num == 0:
+            msg = "No %s matching %s." 
% (self.resource_class.__name__, kwargs)
+            raise exceptions.NotFound(404, msg)
+        elif num > 1:
+            raise exceptions.NoUniqueMatch
+        else:
+            return rl[0]
+
+
+class Extension(HookableMixin):
+    """Extension descriptor."""
+
+    SUPPORTED_HOOKS = ('__pre_parse_args__', '__post_parse_args__')
+    manager_class = None
+
+    def __init__(self, name, module):
+        super(Extension, self).__init__()
+        self.name = name
+        self.module = module
+        self._parse_extension_module()
+
+    def _parse_extension_module(self):
+        self.manager_class = None
+        for attr_name, attr_value in self.module.__dict__.items():
+            if attr_name in self.SUPPORTED_HOOKS:
+                self.add_hook(attr_name, attr_value)
+            else:
+                try:
+                    if issubclass(attr_value, BaseManager):
+                        self.manager_class = attr_value
+                except TypeError:
+                    pass
+
+    def __repr__(self):
+        return "<Extension '%s'>" % self.name
+
+
+class Resource(object):
+    """Base class for OpenStack resources (tenant, user, etc.).
+
+    This is pretty much just a bag for attributes.
+    """
+
+    HUMAN_ID = False
+    NAME_ATTR = 'name'
+
+    def __init__(self, manager, info, loaded=False):
+        """Populate and bind to a manager.
+
+        :param manager: BaseManager object
+        :param info: dictionary representing resource attributes
+        :param loaded: prevent lazy-loading if set to True
+        """
+        self.manager = manager
+        self._info = info
+        self._add_details(info)
+        self._loaded = loaded
+
+    def __repr__(self):
+        reprkeys = sorted(k
+                          for k in self.__dict__.keys()
+                          if k[0] != '_' and k != 'manager')
+        info = ", ".join("%s=%s" % (k, getattr(self, k)) for k in reprkeys)
+        return "<%s %s>" % (self.__class__.__name__, info)
+
+    @property
+    def human_id(self):
+        """Human-readable ID which can be used for bash completion.
+        """
+        if self.NAME_ATTR in self.__dict__ and self.HUMAN_ID:
+            return strutils.to_slug(getattr(self, self.NAME_ATTR))
+        return None
+
+    def _add_details(self, info):
+        for (k, v) in info.iteritems():
+            try:
+                setattr(self, k, v)
+                self._info[k] = v
+            except AttributeError:
+                # In this case we already defined the attribute on the class
+                pass
+
+    def __getattr__(self, k):
+        if k not in self.__dict__:
+            #NOTE(bcwaldon): disallow lazy-loading if already loaded once
+            if not self.is_loaded():
+                self.get()
+                return self.__getattr__(k)
+
+            raise AttributeError(k)
+        else:
+            return self.__dict__[k]
+
+    def get(self):
+        # set_loaded() first ... so if we have to bail, we know we tried.
+        self.set_loaded(True)
+        if not hasattr(self.manager, 'get'):
+            return
+
+        new = self.manager.get(self.id)
+        if new:
+            self._add_details(new._info)
+
+    def __eq__(self, other):
+        if not isinstance(other, Resource):
+            return NotImplemented
+        # two resources of different types are not equal
+        if not isinstance(other, self.__class__):
+            return False
+        if hasattr(self, 'id') and hasattr(other, 'id'):
+            return self.id == other.id
+        return self._info == other._info
+
+    def is_loaded(self):
+        return self._loaded
+
+    def set_loaded(self, val):
+        self._loaded = val
diff --git a/mistral/openstack/common/apiclient/client.py b/mistral/openstack/common/apiclient/client.py
new file mode 100644
index 00000000..814f7f92
--- /dev/null
+++ b/mistral/openstack/common/apiclient/client.py
@@ -0,0 +1,360 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 Jacob Kaplan-Moss
+# Copyright 2011 OpenStack Foundation
+# Copyright 2011 Piston Cloud Computing, Inc.
+# Copyright 2013 Alessio Ababilov
+# Copyright 2013 Grid Dynamics
+# Copyright 2013 OpenStack Foundation
+# All Rights Reserved.
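The `CrudManager` docstrings above describe URL generation only in prose, so a minimal sketch may help; the `WorkbookManager` name is a hypothetical stand-in (any `collection_key`/`key` pair behaves the same), and `build_url()` never touches the client, so `None` suffices here.

```python
from mistral.openstack.common.apiclient import base


class WorkbookManager(base.CrudManager):  # hypothetical manager
    collection_key = 'workbooks'
    key = 'workbook'


mgr = WorkbookManager(None)  # build_url() never uses self.client
print(mgr.build_url())                                   # /workbooks
print(mgr.build_url(workbook_id='wb1'))                  # /workbooks/wb1
print(mgr.build_url(base_url='/v1', workbook_id='wb1'))  # /v1/workbooks/wb1
```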
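Likewise, a short illustration of `getid()` and the id-based `Resource` equality defined above; the `Server` subclass is made up for the example, and no manager is needed.

```python
from mistral.openstack.common.apiclient import base


class Server(base.Resource):  # made-up resource type
    pass


a = Server(None, {"id": "42", "name": "web-1"}, loaded=True)
b = Server(None, {"id": "42", "name": "web-1-renamed"}, loaded=True)

print(base.getid(a))     # 42   -- unwraps a Resource to its id
print(base.getid("42"))  # 42   -- plain ids pass through unchanged
print(a == b)            # True -- equality compares ids, not attributes
```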
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +OpenStack Client interface. Handles the REST calls and responses. +""" + +# E0202: An attribute inherited from %s hide this method +# pylint: disable=E0202 + +import logging +import time + +try: + import simplejson as json +except ImportError: + import json + +import requests + +from mistral.openstack.common.apiclient import exceptions +from mistral.openstack.common import importutils + + +_logger = logging.getLogger(__name__) + + +class HTTPClient(object): + """This client handles sending HTTP requests to OpenStack servers. + + Features: + - share authentication information between several clients to different + services (e.g., for compute and image clients); + - reissue authentication request for expired tokens; + - encode/decode JSON bodies; + - raise exceptions on HTTP errors; + - pluggable authentication; + - store authentication information in a keyring; + - store time spent for requests; + - register clients for particular services, so one can use + `http_client.identity` or `http_client.compute`; + - log requests and responses in a format that is easy to copy-and-paste + into terminal and send the same request with curl. + """ + + user_agent = "mistral.openstack.common.apiclient" + + def __init__(self, + auth_plugin, + region_name=None, + endpoint_type="publicURL", + original_ip=None, + verify=True, + cert=None, + timeout=None, + timings=False, + keyring_saver=None, + debug=False, + user_agent=None, + http=None): + self.auth_plugin = auth_plugin + + self.endpoint_type = endpoint_type + self.region_name = region_name + + self.original_ip = original_ip + self.timeout = timeout + self.verify = verify + self.cert = cert + + self.keyring_saver = keyring_saver + self.debug = debug + self.user_agent = user_agent or self.user_agent + + self.times = [] # [("item", starttime, endtime), ...] 
+        self.timings = timings
+
+        # requests within the same session can reuse TCP connections from pool
+        self.http = http or requests.Session()
+
+        self.cached_token = None
+
+    def _http_log_req(self, method, url, kwargs):
+        if not self.debug:
+            return
+
+        string_parts = [
+            "curl -i",
+            "-X '%s'" % method,
+            "'%s'" % url,
+        ]
+
+        for element in kwargs['headers']:
+            header = "-H '%s: %s'" % (element, kwargs['headers'][element])
+            string_parts.append(header)
+
+        _logger.debug("REQ: %s" % " ".join(string_parts))
+        if 'data' in kwargs:
+            _logger.debug("REQ BODY: %s\n" % (kwargs['data']))
+
+    def _http_log_resp(self, resp):
+        if not self.debug:
+            return
+        _logger.debug(
+            "RESP: [%s] %s\n",
+            resp.status_code,
+            resp.headers)
+        if resp._content_consumed:
+            _logger.debug(
+                "RESP BODY: %s\n",
+                resp.text)
+
+    def serialize(self, kwargs):
+        if kwargs.get('json') is not None:
+            kwargs['headers']['Content-Type'] = 'application/json'
+            kwargs['data'] = json.dumps(kwargs['json'])
+        try:
+            del kwargs['json']
+        except KeyError:
+            pass
+
+    def get_timings(self):
+        return self.times
+
+    def reset_timings(self):
+        self.times = []
+
+    def request(self, method, url, **kwargs):
+        """Send an http request with the specified characteristics.
+
+        Wrapper around `requests.Session.request` to handle tasks such as
+        setting headers, JSON encoding/decoding, and error handling.
+
+        :param method: method of HTTP request
+        :param url: URL of HTTP request
+        :param kwargs: any other parameter that can be passed to
+            requests.Session.request (such as `headers`) or `json`
+            that will be encoded as JSON and used as `data` argument
+        """
+        kwargs.setdefault("headers", kwargs.get("headers", {}))
+        kwargs["headers"]["User-Agent"] = self.user_agent
+        if self.original_ip:
+            kwargs["headers"]["Forwarded"] = "for=%s;by=%s" % (
+                self.original_ip, self.user_agent)
+        if self.timeout is not None:
+            kwargs.setdefault("timeout", self.timeout)
+        kwargs.setdefault("verify", self.verify)
+        if self.cert is not None:
+            kwargs.setdefault("cert", self.cert)
+        self.serialize(kwargs)
+
+        self._http_log_req(method, url, kwargs)
+        if self.timings:
+            start_time = time.time()
+        resp = self.http.request(method, url, **kwargs)
+        if self.timings:
+            self.times.append(("%s %s" % (method, url),
+                               start_time, time.time()))
+        self._http_log_resp(resp)
+
+        if resp.status_code >= 400:
+            _logger.debug(
+                "Request returned failure status: %s",
+                resp.status_code)
+            raise exceptions.from_response(resp, method, url)
+
+        return resp
+
+    @staticmethod
+    def concat_url(endpoint, url):
+        """Concatenate endpoint and final URL.
+
+        E.g., "http://keystone/v2.0/" and "/tokens" are concatenated to
+        "http://keystone/v2.0/tokens".
+
+        :param endpoint: the base URL
+        :param url: the final URL
+        """
+        return "%s/%s" % (endpoint.rstrip("/"), url.strip("/"))
+
+    def client_request(self, client, method, url, **kwargs):
+        """Send an http request using `client`'s endpoint and specified `url`.
+
+        If request was rejected as unauthorized (possibly because the token is
+        expired), issue one authorization attempt and send the request once
+        again.
+
+        :param client: instance of BaseClient descendant
+        :param method: method of HTTP request
+        :param url: URL of HTTP request
+        :param kwargs: any other parameter that can be passed to
+            `HTTPClient.request`
+        """
+
+        filter_args = {
+            "endpoint_type": client.endpoint_type or self.endpoint_type,
+            "service_type": client.service_type,
+        }
+        token, endpoint = (self.cached_token, client.cached_endpoint)
+        just_authenticated = False
+        if not (token and endpoint):
+            try:
+                token, endpoint = self.auth_plugin.token_and_endpoint(
+                    **filter_args)
+            except exceptions.EndpointException:
+                pass
+            if not (token and endpoint):
+                self.authenticate()
+                just_authenticated = True
+                token, endpoint = self.auth_plugin.token_and_endpoint(
+                    **filter_args)
+                if not (token and endpoint):
+                    raise exceptions.AuthorizationFailure(
+                        "Cannot find endpoint or token for request")
+
+        old_token_endpoint = (token, endpoint)
+        kwargs.setdefault("headers", {})["X-Auth-Token"] = token
+        self.cached_token = token
+        client.cached_endpoint = endpoint
+        # Perform the request once. If we get Unauthorized, then it
+        # might be because the auth token expired, so try to
+        # re-authenticate and try again. If it still fails, bail.
+        try:
+            return self.request(
+                method, self.concat_url(endpoint, url), **kwargs)
+        except exceptions.Unauthorized as unauth_ex:
+            if just_authenticated:
+                raise
+            self.cached_token = None
+            client.cached_endpoint = None
+            self.authenticate()
+            try:
+                token, endpoint = self.auth_plugin.token_and_endpoint(
+                    **filter_args)
+            except exceptions.EndpointException:
+                raise unauth_ex
+            if (not (token and endpoint) or
+                    old_token_endpoint == (token, endpoint)):
+                raise unauth_ex
+            self.cached_token = token
+            client.cached_endpoint = endpoint
+            kwargs["headers"]["X-Auth-Token"] = token
+            return self.request(
+                method, self.concat_url(endpoint, url), **kwargs)
+
+    def add_client(self, base_client_instance):
+        """Add a new instance of :class:`BaseClient` descendant.
+
+        `self` will store a reference to `base_client_instance`.
+
+        Example:
+
+        >>> def test_clients():
+        ...     from keystoneclient.auth import keystone
+        ...     from openstack.common.apiclient import client
+        ...     auth = keystone.KeystoneAuthPlugin(
+        ...         username="user", password="pass", tenant_name="tenant",
+        ...         auth_url="http://auth:5000/v2.0")
+        ...     openstack_client = client.HTTPClient(auth)
+        ...     # create nova client
+        ...     from novaclient.v1_1 import client
+        ...     client.Client(openstack_client)
+        ...     # create keystone client
+        ...     from keystoneclient.v2_0 import client
+        ...     client.Client(openstack_client)
+        ...     # use them
+        ...     openstack_client.identity.tenants.list()
+        ...     openstack_client.compute.servers.list()
+        """
+        service_type = base_client_instance.service_type
+        if service_type and not hasattr(self, service_type):
+            setattr(self, service_type, base_client_instance)
+
+    def authenticate(self):
+        self.auth_plugin.authenticate(self)
+        # Store the authentication results in the keyring for later requests
+        if self.keyring_saver:
+            self.keyring_saver.save(self)
+
+
+class BaseClient(object):
+    """Top-level object to access the OpenStack API.
+
+    This client uses :class:`HTTPClient` to send requests. :class:`HTTPClient`
+    will handle a bunch of issues such as authentication.
+    """
+
+    service_type = None
+    endpoint_type = None  # "publicURL" will be used
+    cached_endpoint = None
+
+    def __init__(self, http_client, extensions=None):
+        self.http_client = http_client
+        http_client.add_client(self)
+
+        # Add in any extensions...
+ if extensions: + for extension in extensions: + if extension.manager_class: + setattr(self, extension.name, + extension.manager_class(self)) + + def client_request(self, method, url, **kwargs): + return self.http_client.client_request( + self, method, url, **kwargs) + + def head(self, url, **kwargs): + return self.client_request("HEAD", url, **kwargs) + + def get(self, url, **kwargs): + return self.client_request("GET", url, **kwargs) + + def post(self, url, **kwargs): + return self.client_request("POST", url, **kwargs) + + def put(self, url, **kwargs): + return self.client_request("PUT", url, **kwargs) + + def delete(self, url, **kwargs): + return self.client_request("DELETE", url, **kwargs) + + def patch(self, url, **kwargs): + return self.client_request("PATCH", url, **kwargs) + + @staticmethod + def get_class(api_name, version, version_map): + """Returns the client class for the requested API version + + :param api_name: the name of the API, e.g. 'compute', 'image', etc + :param version: the requested API version + :param version_map: a dict of client classes keyed by version + :rtype: a client class for the requested API version + """ + try: + client_path = version_map[str(version)] + except (KeyError, ValueError): + msg = "Invalid %s client version '%s'. must be one of: %s" % ( + (api_name, version, ', '.join(version_map.keys()))) + raise exceptions.UnsupportedVersion(msg) + + return importutils.import_class(client_path) diff --git a/mistral/openstack/common/apiclient/exceptions.py b/mistral/openstack/common/apiclient/exceptions.py new file mode 100644 index 00000000..fec8ade1 --- /dev/null +++ b/mistral/openstack/common/apiclient/exceptions.py @@ -0,0 +1,441 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 Jacob Kaplan-Moss +# Copyright 2011 Nebula, Inc. +# Copyright 2013 Alessio Ababilov +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Exception definitions. +""" + +import inspect +import sys + +import six + + +class ClientException(Exception): + """The base exception class for all exceptions this library raises. + """ + pass + + +class MissingArgs(ClientException): + """Supplied arguments are not sufficient for calling a function.""" + def __init__(self, missing): + self.missing = missing + msg = "Missing argument(s): %s" % ", ".join(missing) + super(MissingArgs, self).__init__(msg) + + +class ValidationError(ClientException): + """Error in validation on API client side.""" + pass + + +class UnsupportedVersion(ClientException): + """User is trying to use an unsupported version of the API.""" + pass + + +class CommandError(ClientException): + """Error in CLI tool.""" + pass + + +class AuthorizationFailure(ClientException): + """Cannot authorize API client.""" + pass + + +class AuthPluginOptionsMissing(AuthorizationFailure): + """Auth plugin misses some options.""" + def __init__(self, opt_names): + super(AuthPluginOptionsMissing, self).__init__( + "Authentication failed. 
Missing options: %s" % + ", ".join(opt_names)) + self.opt_names = opt_names + + +class AuthSystemNotFound(AuthorizationFailure): + """User has specified a AuthSystem that is not installed.""" + def __init__(self, auth_system): + super(AuthSystemNotFound, self).__init__( + "AuthSystemNotFound: %s" % repr(auth_system)) + self.auth_system = auth_system + + +class NoUniqueMatch(ClientException): + """Multiple entities found instead of one.""" + pass + + +class EndpointException(ClientException): + """Something is rotten in Service Catalog.""" + pass + + +class EndpointNotFound(EndpointException): + """Could not find requested endpoint in Service Catalog.""" + pass + + +class AmbiguousEndpoints(EndpointException): + """Found more than one matching endpoint in Service Catalog.""" + def __init__(self, endpoints=None): + super(AmbiguousEndpoints, self).__init__( + "AmbiguousEndpoints: %s" % repr(endpoints)) + self.endpoints = endpoints + + +class HttpError(ClientException): + """The base exception class for all HTTP exceptions. + """ + http_status = 0 + message = "HTTP Error" + + def __init__(self, message=None, details=None, + response=None, request_id=None, + url=None, method=None, http_status=None): + self.http_status = http_status or self.http_status + self.message = message or self.message + self.details = details + self.request_id = request_id + self.response = response + self.url = url + self.method = method + formatted_string = "%s (HTTP %s)" % (self.message, self.http_status) + if request_id: + formatted_string += " (Request-ID: %s)" % request_id + super(HttpError, self).__init__(formatted_string) + + +class HTTPClientError(HttpError): + """Client-side HTTP error. + + Exception for cases in which the client seems to have erred. + """ + message = "HTTP Client Error" + + +class HttpServerError(HttpError): + """Server-side HTTP error. + + Exception for cases in which the server is aware that it has + erred or is incapable of performing the request. + """ + message = "HTTP Server Error" + + +class BadRequest(HTTPClientError): + """HTTP 400 - Bad Request. + + The request cannot be fulfilled due to bad syntax. + """ + http_status = 400 + message = "Bad Request" + + +class Unauthorized(HTTPClientError): + """HTTP 401 - Unauthorized. + + Similar to 403 Forbidden, but specifically for use when authentication + is required and has failed or has not yet been provided. + """ + http_status = 401 + message = "Unauthorized" + + +class PaymentRequired(HTTPClientError): + """HTTP 402 - Payment Required. + + Reserved for future use. + """ + http_status = 402 + message = "Payment Required" + + +class Forbidden(HTTPClientError): + """HTTP 403 - Forbidden. + + The request was a valid request, but the server is refusing to respond + to it. + """ + http_status = 403 + message = "Forbidden" + + +class NotFound(HTTPClientError): + """HTTP 404 - Not Found. + + The requested resource could not be found but may be available again + in the future. + """ + http_status = 404 + message = "Not Found" + + +class MethodNotAllowed(HTTPClientError): + """HTTP 405 - Method Not Allowed. + + A request was made of a resource using a request method not supported + by that resource. + """ + http_status = 405 + message = "Method Not Allowed" + + +class NotAcceptable(HTTPClientError): + """HTTP 406 - Not Acceptable. + + The requested resource is only capable of generating content not + acceptable according to the Accept headers sent in the request. 
+ """ + http_status = 406 + message = "Not Acceptable" + + +class ProxyAuthenticationRequired(HTTPClientError): + """HTTP 407 - Proxy Authentication Required. + + The client must first authenticate itself with the proxy. + """ + http_status = 407 + message = "Proxy Authentication Required" + + +class RequestTimeout(HTTPClientError): + """HTTP 408 - Request Timeout. + + The server timed out waiting for the request. + """ + http_status = 408 + message = "Request Timeout" + + +class Conflict(HTTPClientError): + """HTTP 409 - Conflict. + + Indicates that the request could not be processed because of conflict + in the request, such as an edit conflict. + """ + http_status = 409 + message = "Conflict" + + +class Gone(HTTPClientError): + """HTTP 410 - Gone. + + Indicates that the resource requested is no longer available and will + not be available again. + """ + http_status = 410 + message = "Gone" + + +class LengthRequired(HTTPClientError): + """HTTP 411 - Length Required. + + The request did not specify the length of its content, which is + required by the requested resource. + """ + http_status = 411 + message = "Length Required" + + +class PreconditionFailed(HTTPClientError): + """HTTP 412 - Precondition Failed. + + The server does not meet one of the preconditions that the requester + put on the request. + """ + http_status = 412 + message = "Precondition Failed" + + +class RequestEntityTooLarge(HTTPClientError): + """HTTP 413 - Request Entity Too Large. + + The request is larger than the server is willing or able to process. + """ + http_status = 413 + message = "Request Entity Too Large" + + def __init__(self, *args, **kwargs): + try: + self.retry_after = int(kwargs.pop('retry_after')) + except (KeyError, ValueError): + self.retry_after = 0 + + super(RequestEntityTooLarge, self).__init__(*args, **kwargs) + + +class RequestUriTooLong(HTTPClientError): + """HTTP 414 - Request-URI Too Long. + + The URI provided was too long for the server to process. + """ + http_status = 414 + message = "Request-URI Too Long" + + +class UnsupportedMediaType(HTTPClientError): + """HTTP 415 - Unsupported Media Type. + + The request entity has a media type which the server or resource does + not support. + """ + http_status = 415 + message = "Unsupported Media Type" + + +class RequestedRangeNotSatisfiable(HTTPClientError): + """HTTP 416 - Requested Range Not Satisfiable. + + The client has asked for a portion of the file, but the server cannot + supply that portion. + """ + http_status = 416 + message = "Requested Range Not Satisfiable" + + +class ExpectationFailed(HTTPClientError): + """HTTP 417 - Expectation Failed. + + The server cannot meet the requirements of the Expect request-header field. + """ + http_status = 417 + message = "Expectation Failed" + + +class UnprocessableEntity(HTTPClientError): + """HTTP 422 - Unprocessable Entity. + + The request was well-formed but was unable to be followed due to semantic + errors. + """ + http_status = 422 + message = "Unprocessable Entity" + + +class InternalServerError(HttpServerError): + """HTTP 500 - Internal Server Error. + + A generic error message, given when no more specific message is suitable. + """ + http_status = 500 + message = "Internal Server Error" + + +# NotImplemented is a python keyword. +class HttpNotImplemented(HttpServerError): + """HTTP 501 - Not Implemented. + + The server either does not recognize the request method, or it lacks + the ability to fulfill the request. 
+ """ + http_status = 501 + message = "Not Implemented" + + +class BadGateway(HttpServerError): + """HTTP 502 - Bad Gateway. + + The server was acting as a gateway or proxy and received an invalid + response from the upstream server. + """ + http_status = 502 + message = "Bad Gateway" + + +class ServiceUnavailable(HttpServerError): + """HTTP 503 - Service Unavailable. + + The server is currently unavailable. + """ + http_status = 503 + message = "Service Unavailable" + + +class GatewayTimeout(HttpServerError): + """HTTP 504 - Gateway Timeout. + + The server was acting as a gateway or proxy and did not receive a timely + response from the upstream server. + """ + http_status = 504 + message = "Gateway Timeout" + + +class HttpVersionNotSupported(HttpServerError): + """HTTP 505 - HttpVersion Not Supported. + + The server does not support the HTTP protocol version used in the request. + """ + http_status = 505 + message = "HTTP Version Not Supported" + + +# _code_map contains all the classes that have http_status attribute. +_code_map = dict( + (getattr(obj, 'http_status', None), obj) + for name, obj in six.iteritems(vars(sys.modules[__name__])) + if inspect.isclass(obj) and getattr(obj, 'http_status', False) +) + + +def from_response(response, method, url): + """Returns an instance of :class:`HttpError` or subclass based on response. + + :param response: instance of `requests.Response` class + :param method: HTTP method used for request + :param url: URL used for request + """ + kwargs = { + "http_status": response.status_code, + "response": response, + "method": method, + "url": url, + "request_id": response.headers.get("x-compute-request-id"), + } + if "retry-after" in response.headers: + kwargs["retry_after"] = response.headers["retry-after"] + + content_type = response.headers.get("Content-Type", "") + if content_type.startswith("application/json"): + try: + body = response.json() + except ValueError: + pass + else: + if hasattr(body, "keys"): + error = body[body.keys()[0]] + kwargs["message"] = error.get("message", None) + kwargs["details"] = error.get("details", None) + elif content_type.startswith("text/"): + kwargs["details"] = response.text + + try: + cls = _code_map[response.status_code] + except KeyError: + if 500 <= response.status_code < 600: + cls = HttpServerError + elif 400 <= response.status_code < 500: + cls = HTTPClientError + else: + cls = HttpError + return cls(**kwargs) diff --git a/mistral/openstack/common/apiclient/fake_client.py b/mistral/openstack/common/apiclient/fake_client.py new file mode 100644 index 00000000..52b9866e --- /dev/null +++ b/mistral/openstack/common/apiclient/fake_client.py @@ -0,0 +1,172 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +A fake server that "responds" to API methods with pre-canned responses. + +All of these responses come from the spec, so if for some reason the spec's +wrong the tests might raise AssertionError. 
I've indicated in comments the
+places where actual behavior differs from the spec.
+"""
+
+# W0102: Dangerous default value %s as argument
+# pylint: disable=W0102
+
+import json
+import urlparse
+
+import requests
+
+from mistral.openstack.common.apiclient import client
+
+
+def assert_has_keys(dct, required=[], optional=[]):
+    for k in required:
+        try:
+            assert k in dct
+        except AssertionError:
+            extra_keys = set(dct.keys()).difference(set(required + optional))
+            raise AssertionError("found unexpected keys: %s" %
+                                 list(extra_keys))
+
+
+class TestResponse(requests.Response):
+    """Wrap requests.Response and provide a convenient initialization.
+    """
+
+    def __init__(self, data):
+        super(TestResponse, self).__init__()
+        self._content_consumed = True
+        if isinstance(data, dict):
+            self.status_code = data.get('status_code', 200)
+            # Fake the text attribute to streamline Response creation
+            text = data.get('text', "")
+            if isinstance(text, (dict, list)):
+                self._content = json.dumps(text)
+                default_headers = {
+                    "Content-Type": "application/json",
+                }
+            else:
+                self._content = text
+                default_headers = {}
+            self.headers = data.get('headers') or default_headers
+        else:
+            self.status_code = data
+
+    def __eq__(self, other):
+        return (self.status_code == other.status_code and
+                self.headers == other.headers and
+                self._content == other._content)
+
+
+class FakeHTTPClient(client.HTTPClient):
+
+    def __init__(self, *args, **kwargs):
+        self.callstack = []
+        self.fixtures = kwargs.pop("fixtures", None) or {}
+        if not args and "auth_plugin" not in kwargs:
+            args = (None, )
+        super(FakeHTTPClient, self).__init__(*args, **kwargs)
+
+    def assert_called(self, method, url, body=None, pos=-1):
+        """Assert that an API method was just called.
+        """
+        expected = (method, url)
+        called = self.callstack[pos][0:2]
+        assert self.callstack, \
+            "Expected %s %s but no calls were made." % expected
+
+        assert expected == called, 'Expected %s %s; got %s %s' % \
+            (expected + called)
+
+        if body is not None:
+            if self.callstack[pos][3] != body:
+                raise AssertionError('%r != %r' %
+                                     (self.callstack[pos][3], body))
+
+    def assert_called_anytime(self, method, url, body=None):
+        """Assert that an API method was called anytime in the test.
+        """
+        expected = (method, url)
+
+        assert self.callstack, \
+            "Expected %s %s but no calls were made." 
% expected + + found = False + entry = None + for entry in self.callstack: + if expected == entry[0:2]: + found = True + break + + assert found, 'Expected %s %s; got %s' % \ + (method, url, self.callstack) + if body is not None: + assert entry[3] == body, "%s != %s" % (entry[3], body) + + self.callstack = [] + + def clear_callstack(self): + self.callstack = [] + + def authenticate(self): + pass + + def client_request(self, client, method, url, **kwargs): + # Check that certain things are called correctly + if method in ["GET", "DELETE"]: + assert "json" not in kwargs + + # Note the call + self.callstack.append( + (method, + url, + kwargs.get("headers") or {}, + kwargs.get("json") or kwargs.get("data"))) + try: + fixture = self.fixtures[url][method] + except KeyError: + pass + else: + return TestResponse({"headers": fixture[0], + "text": fixture[1]}) + + # Call the method + args = urlparse.parse_qsl(urlparse.urlparse(url)[4]) + kwargs.update(args) + munged_url = url.rsplit('?', 1)[0] + munged_url = munged_url.strip('/').replace('/', '_').replace('.', '_') + munged_url = munged_url.replace('-', '_') + + callback = "%s_%s" % (method.lower(), munged_url) + + if not hasattr(self, callback): + raise AssertionError('Called unknown API method: %s %s, ' + 'expected fakes method name: %s' % + (method, url, callback)) + + resp = getattr(self, callback)(**kwargs) + if len(resp) == 3: + status, headers, body = resp + else: + status, body = resp + headers = {} + return TestResponse({ + "status_code": status, + "text": body, + "headers": headers, + }) diff --git a/mistral/openstack/common/cliutils.py b/mistral/openstack/common/cliutils.py new file mode 100644 index 00000000..74c12965 --- /dev/null +++ b/mistral/openstack/common/cliutils.py @@ -0,0 +1,214 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +# W0603: Using the global statement +# W0621: Redefining name %s from outer scope +# pylint: disable=W0603,W0621 + +import getpass +import inspect +import os +import sys +import textwrap + +import prettytable +import six + +from mistral.openstack.common.apiclient import exceptions +from mistral.openstack.common import strutils + + +def validate_args(fn, *args, **kwargs): + """Check that the supplied args are sufficient for calling a function. + + >>> validate_args(lambda a: None) + Traceback (most recent call last): + ... + MissingArgs: Missing argument(s): a + >>> validate_args(lambda a, b, c, d: None, 0, c=1) + Traceback (most recent call last): + ... 
+    MissingArgs: Missing argument(s): b, d
+
+    :param fn: the function to check
+    :param args: the positional arguments supplied
+    :param kwargs: the keyword arguments supplied
+    """
+    argspec = inspect.getargspec(fn)
+
+    num_defaults = len(argspec.defaults or [])
+    required_args = argspec.args[:len(argspec.args) - num_defaults]
+
+    def isbound(method):
+        return getattr(method, 'im_self', None) is not None
+
+    if isbound(fn):
+        required_args.pop(0)
+
+    missing = [arg for arg in required_args if arg not in kwargs]
+    missing = missing[len(args):]
+    if missing:
+        raise exceptions.MissingArgs(missing)
+
+
+def arg(*args, **kwargs):
+    """Decorator for CLI args.
+
+    Example:
+
+    >>> @arg("name", help="Name of the new entity")
+    ... def entity_create(args):
+    ...     pass
+    """
+    def _decorator(func):
+        add_arg(func, *args, **kwargs)
+        return func
+    return _decorator
+
+
+def env(*args, **kwargs):
+    """Returns the first environment variable set.
+
+    If all are empty, defaults to '' or keyword arg `default`.
+    """
+    for arg in args:
+        value = os.environ.get(arg, None)
+        if value:
+            return value
+    return kwargs.get('default', '')
+
+
+def add_arg(func, *args, **kwargs):
+    """Bind CLI arguments to a shell.py `do_foo` function."""
+
+    if not hasattr(func, 'arguments'):
+        func.arguments = []
+
+    # NOTE(sirp): avoid dups that can occur when the module is shared across
+    # tests.
+    if (args, kwargs) not in func.arguments:
+        # Because of the semantics of decorator composition if we just append
+        # to the options list positional options will appear to be backwards.
+        func.arguments.insert(0, (args, kwargs))
+
+
+def unauthenticated(func):
+    """Adds 'unauthenticated' attribute to decorated function.
+
+    Usage:
+
+    >>> @unauthenticated
+    ... def mymethod(f):
+    ...     pass
+    """
+    func.unauthenticated = True
+    return func
+
+
+def isunauthenticated(func):
+    """Checks if the function does not require authentication.
+
+    Mark such functions with the `@unauthenticated` decorator.
+
+    :returns: bool
+    """
+    return getattr(func, 'unauthenticated', False)
+
+
+def print_list(objs, fields, formatters=None, sortby_index=0,
+               mixed_case_fields=None):
+    """Print a list of objects as a table, one row per object.
+
+    :param objs: iterable of :class:`Resource`
+    :param fields: attributes that correspond to columns, in order
+    :param formatters: `dict` of callables for field formatting
+    :param sortby_index: index of the field for sorting table rows
+    :param mixed_case_fields: fields corresponding to object attributes that
+        have mixed case names (e.g., 'serverId')
+    """
+    formatters = formatters or {}
+    mixed_case_fields = mixed_case_fields or []
+    if sortby_index is None:
+        sortby = None
+    else:
+        sortby = fields[sortby_index]
+    pt = prettytable.PrettyTable(fields, caching=False)
+    pt.align = 'l'
+
+    for o in objs:
+        row = []
+        for field in fields:
+            if field in formatters:
+                row.append(formatters[field](o))
+            else:
+                if field in mixed_case_fields:
+                    field_name = field.replace(' ', '_')
+                else:
+                    field_name = field.lower().replace(' ', '_')
+                data = getattr(o, field_name, '')
+                row.append(data)
+        pt.add_row(row)
+
+    print(strutils.safe_encode(pt.get_string(sortby=sortby)))
+
+
+def print_dict(dct, dict_property="Property", wrap=0):
+    """Print a `dict` as a table of two columns.
+ + :param dct: `dict` to print + :param dict_property: name of the first column + :param wrap: wrapping for the second column + """ + pt = prettytable.PrettyTable([dict_property, 'Value'], caching=False) + pt.align = 'l' + for k, v in dct.iteritems(): + # convert dict to str to check length + if isinstance(v, dict): + v = str(v) + if wrap > 0: + v = textwrap.fill(str(v), wrap) + # if value has a newline, add in multiple rows + # e.g. fault with stacktrace + if v and isinstance(v, six.string_types) and r'\n' in v: + lines = v.strip().split(r'\n') + col1 = k + for line in lines: + pt.add_row([col1, line]) + col1 = '' + else: + pt.add_row([k, v]) + print(strutils.safe_encode(pt.get_string())) + + +def get_password(max_password_prompts=3): + """Read password from TTY.""" + verify = strutils.bool_from_string(env("OS_VERIFY_PASSWORD")) + pw = None + if hasattr(sys.stdin, "isatty") and sys.stdin.isatty(): + # Check for Ctrl-D + try: + for _ in xrange(max_password_prompts): + pw1 = getpass.getpass("OS Password: ") + if verify: + pw2 = getpass.getpass("Please verify: ") + else: + pw2 = pw1 + if pw1 == pw2 and pw1: + pw = pw1 + break + except EOFError: + pass + return pw diff --git a/mistral/openstack/common/db/__init__.py b/mistral/openstack/common/db/__init__.py new file mode 100644 index 00000000..1b9b60de --- /dev/null +++ b/mistral/openstack/common/db/__init__.py @@ -0,0 +1,16 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Cloudscaling Group, Inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/mistral/openstack/common/db/api.py b/mistral/openstack/common/db/api.py new file mode 100644 index 00000000..458e67eb --- /dev/null +++ b/mistral/openstack/common/db/api.py @@ -0,0 +1,106 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2013 Rackspace Hosting +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Multiple DB API backend support. + +Supported configuration options: + +The following two parameters are in the 'database' group: +`backend`: DB backend name or full module path to DB backend module. +`use_tpool`: Enable thread pooling of DB API calls. + +A DB backend module should implement a method named 'get_backend' which +takes no arguments. The method can return any object that implements DB +API methods. + +*NOTE*: There are bugs in eventlet when using tpool combined with +threading locks. The python logging module happens to use such locks. 
To +work around this issue, be sure to specify thread=False with +eventlet.monkey_patch(). + +A bug for eventlet has been filed here: + +https://bitbucket.org/eventlet/eventlet/issue/137/ +""" +import functools + +from oslo.config import cfg + +from mistral.openstack.common import importutils +from mistral.openstack.common import lockutils + + +db_opts = [ + cfg.StrOpt('backend', + default='sqlalchemy', + deprecated_name='db_backend', + deprecated_group='DEFAULT', + help='The backend to use for db'), + cfg.BoolOpt('use_tpool', + default=False, + deprecated_name='dbapi_use_tpool', + deprecated_group='DEFAULT', + help='Enable the experimental use of thread pooling for ' + 'all DB API calls') +] + +CONF = cfg.CONF +CONF.register_opts(db_opts, 'database') + + +class DBAPI(object): + def __init__(self, backend_mapping=None): + if backend_mapping is None: + backend_mapping = {} + self.__backend = None + self.__backend_mapping = backend_mapping + + @lockutils.synchronized('dbapi_backend', 'mistral-') + def __get_backend(self): + """Get the actual backend. May be a module or an instance of + a class. Doesn't matter to us. We do this synchronized as it's + possible multiple greenthreads started very quickly trying to do + DB calls and eventlet can switch threads before self.__backend gets + assigned. + """ + if self.__backend: + # Another thread assigned it + return self.__backend + backend_name = CONF.database.backend + self.__use_tpool = CONF.database.use_tpool + if self.__use_tpool: + from eventlet import tpool + self.__tpool = tpool + # Import the untranslated name if we don't have a + # mapping. + backend_path = self.__backend_mapping.get(backend_name, + backend_name) + backend_mod = importutils.import_module(backend_path) + self.__backend = backend_mod.get_backend() + return self.__backend + + def __getattr__(self, key): + backend = self.__backend or self.__get_backend() + attr = getattr(backend, key) + if not self.__use_tpool or not hasattr(attr, '__call__'): + return attr + + def tpool_wrapper(*args, **kwargs): + return self.__tpool.execute(attr, *args, **kwargs) + + functools.update_wrapper(tpool_wrapper, attr) + return tpool_wrapper diff --git a/mistral/openstack/common/db/exception.py b/mistral/openstack/common/db/exception.py new file mode 100644 index 00000000..d372ea82 --- /dev/null +++ b/mistral/openstack/common/db/exception.py @@ -0,0 +1,51 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
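
Stepping back to db/api.py for a moment: the only contract a backend module has
to satisfy is exposing a get_backend() callable. A minimal sketch with a
hypothetical backend module path (nothing below is defined by this patch):

    # mistral/db/sqlalchemy/api.py (hypothetical backend module)
    import sys

    def workbook_get(workbook_id):
        pass  # real DB work would go here

    def get_backend():
        # Any object exposing the DB API methods works; the module itself
        # is the simplest choice.
        return sys.modules[__name__]

    # caller side
    from mistral.openstack.common.db import api as db_api
    IMPL = db_api.DBAPI(backend_mapping={
        'sqlalchemy': 'mistral.db.sqlalchemy.api'})
    IMPL.workbook_get(42)  # resolved lazily, optionally via eventlet's tpool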
+ +"""DB related custom exceptions.""" + +from mistral.openstack.common.gettextutils import _ # noqa + + +class DBError(Exception): + """Wraps an implementation specific exception.""" + def __init__(self, inner_exception=None): + self.inner_exception = inner_exception + super(DBError, self).__init__(str(inner_exception)) + + +class DBDuplicateEntry(DBError): + """Wraps an implementation specific exception.""" + def __init__(self, columns=[], inner_exception=None): + self.columns = columns + super(DBDuplicateEntry, self).__init__(inner_exception) + + +class DBDeadlock(DBError): + def __init__(self, inner_exception=None): + super(DBDeadlock, self).__init__(inner_exception) + + +class DBInvalidUnicodeParameter(Exception): + message = _("Invalid Parameter: " + "Unicode is not supported by the current database.") + + +class DbMigrationError(DBError): + """Wraps migration specific exception.""" + def __init__(self, message=None): + super(DbMigrationError, self).__init__(str(message)) diff --git a/mistral/openstack/common/db/sqlalchemy/__init__.py b/mistral/openstack/common/db/sqlalchemy/__init__.py new file mode 100644 index 00000000..1b9b60de --- /dev/null +++ b/mistral/openstack/common/db/sqlalchemy/__init__.py @@ -0,0 +1,16 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Cloudscaling Group, Inc +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. diff --git a/mistral/openstack/common/db/sqlalchemy/migration.py b/mistral/openstack/common/db/sqlalchemy/migration.py new file mode 100644 index 00000000..b76eb450 --- /dev/null +++ b/mistral/openstack/common/db/sqlalchemy/migration.py @@ -0,0 +1,278 @@ +# coding: utf-8 +# +# Copyright (c) 2013 OpenStack Foundation +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+#
+# Based on code in migrate/changeset/databases/sqlite.py which is under
+# the following license:
+#
+# The MIT License
+#
+# Copyright (c) 2009 Evan Rosson, Jan Dittberner, Domen Kožar
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+# IN THE SOFTWARE.
+
+import distutils.version as dist_version
+import os
+import re
+
+import migrate
+from migrate.changeset import ansisql
+from migrate.changeset.databases import sqlite
+from migrate.versioning import util as migrate_util
+import sqlalchemy
+from sqlalchemy.schema import UniqueConstraint
+
+from mistral.openstack.common.db import exception
+from mistral.openstack.common.db.sqlalchemy import session as db_session
+from mistral.openstack.common.gettextutils import _  # noqa
+
+
+@migrate_util.decorator
+def patched_with_engine(f, *a, **kw):
+    url = a[0]
+    engine = migrate_util.construct_engine(url, **kw)
+
+    try:
+        kw['engine'] = engine
+        return f(*a, **kw)
+    finally:
+        if isinstance(engine, migrate_util.Engine) and engine is not url:
+            migrate_util.log.debug('Disposing SQLAlchemy engine %s', engine)
+            engine.dispose()
+
+
+# TODO(jkoelker) When migrate 0.7.3 is released and nova depends
+#                on that version or higher, this can be removed
+MIN_PKG_VERSION = dist_version.StrictVersion('0.7.3')
+if (not hasattr(migrate, '__version__') or
+        dist_version.StrictVersion(migrate.__version__) < MIN_PKG_VERSION):
+    migrate_util.with_engine = patched_with_engine
+
+
+# NOTE(jkoelker) Delay importing migrate until we are patched
+from migrate import exceptions as versioning_exceptions
+from migrate.versioning import api as versioning_api
+from migrate.versioning.repository import Repository
+
+_REPOSITORY = None
+
+get_engine = db_session.get_engine
+
+
+def _get_unique_constraints(self, table):
+    """Retrieve information about existing unique constraints of the table.
+
+    This feature is needed for _recreate_table() to work properly.
+    Unfortunately, it's not available in sqlalchemy 0.7.x/0.8.x.
+ + """ + + data = table.metadata.bind.execute( + """SELECT sql + FROM sqlite_master + WHERE + type='table' AND + name=:table_name""", + table_name=table.name + ).fetchone()[0] + + UNIQUE_PATTERN = "CONSTRAINT (\w+) UNIQUE \(([^\)]+)\)" + return [ + UniqueConstraint( + *[getattr(table.columns, c.strip(' "')) for c in cols.split(",")], + name=name + ) + for name, cols in re.findall(UNIQUE_PATTERN, data) + ] + + +def _recreate_table(self, table, column=None, delta=None, omit_uniques=None): + """Recreate the table properly + + Unlike the corresponding original method of sqlalchemy-migrate this one + doesn't drop existing unique constraints when creating a new one. + + """ + + table_name = self.preparer.format_table(table) + + # we remove all indexes so as not to have + # problems during copy and re-create + for index in table.indexes: + index.drop() + + # reflect existing unique constraints + for uc in self._get_unique_constraints(table): + table.append_constraint(uc) + # omit given unique constraints when creating a new table if required + table.constraints = set([ + cons for cons in table.constraints + if omit_uniques is None or cons.name not in omit_uniques + ]) + + self.append('ALTER TABLE %s RENAME TO migration_tmp' % table_name) + self.execute() + + insertion_string = self._modify_table(table, column, delta) + + table.create(bind=self.connection) + self.append(insertion_string % {'table_name': table_name}) + self.execute() + self.append('DROP TABLE migration_tmp') + self.execute() + + +def _visit_migrate_unique_constraint(self, *p, **k): + """Drop the given unique constraint + + The corresponding original method of sqlalchemy-migrate just + raises NotImplemented error + + """ + + self.recreate_table(p[0].table, omit_uniques=[p[0].name]) + + +def patch_migrate(): + """A workaround for SQLite's inability to alter things + + SQLite abilities to alter tables are very limited (please read + http://www.sqlite.org/lang_altertable.html for more details). + E. g. one can't drop a column or a constraint in SQLite. The + workaround for this is to recreate the original table omitting + the corresponding constraint (or column). + + sqlalchemy-migrate library has recreate_table() method that + implements this workaround, but it does it wrong: + + - information about unique constraints of a table + is not retrieved. So if you have a table with one + unique constraint and a migration adding another one + you will end up with a table that has only the + latter unique constraint, and the former will be lost + + - dropping of unique constraints is not supported at all + + The proper way to fix this is to provide a pull-request to + sqlalchemy-migrate, but the project seems to be dead. So we + can go on with monkey-patching of the lib at least for now. + + """ + + # this patch is needed to ensure that recreate_table() doesn't drop + # existing unique constraints of the table when creating a new one + helper_cls = sqlite.SQLiteHelper + helper_cls.recreate_table = _recreate_table + helper_cls._get_unique_constraints = _get_unique_constraints + + # this patch is needed to be able to drop existing unique constraints + constraint_cls = sqlite.SQLiteConstraintDropper + constraint_cls.visit_migrate_unique_constraint = \ + _visit_migrate_unique_constraint + constraint_cls.__bases__ = (ansisql.ANSIColumnDropper, + sqlite.SQLiteConstraintGenerator) + + +def db_sync(abs_path, version=None, init_version=0): + """Upgrade or downgrade a database. 
+
+
+def db_sync(abs_path, version=None, init_version=0):
+    """Upgrade or downgrade a database.
+
+    Function runs the upgrade() or downgrade() functions in change scripts.
+
+    :param abs_path: Absolute path to migrate repository.
+    :param version: Database will upgrade/downgrade until this version.
+                    If None - database will update to the latest
+                    available version.
+    :param init_version: Initial database version
+    """
+    if version is not None:
+        try:
+            version = int(version)
+        except ValueError:
+            raise exception.DbMigrationError(
+                message=_("version should be an integer"))
+
+    current_version = db_version(abs_path, init_version)
+    repository = _find_migrate_repo(abs_path)
+    if version is None or version > current_version:
+        return versioning_api.upgrade(get_engine(), repository, version)
+    else:
+        return versioning_api.downgrade(get_engine(), repository,
+                                        version)
+
+
+def db_version(abs_path, init_version):
+    """Show the current version of the repository.
+
+    :param abs_path: Absolute path to migrate repository
+    :param init_version: Initial database version
+    """
+    repository = _find_migrate_repo(abs_path)
+    try:
+        return versioning_api.db_version(get_engine(), repository)
+    except versioning_exceptions.DatabaseNotControlledError:
+        meta = sqlalchemy.MetaData()
+        engine = get_engine()
+        meta.reflect(bind=engine)
+        tables = meta.tables
+        if len(tables) == 0:
+            db_version_control(abs_path, init_version)
+            return versioning_api.db_version(get_engine(), repository)
+        else:
+            # Some pre-Essex DB's may not be version controlled.
+            # Require them to upgrade using Essex first.
+            raise exception.DbMigrationError(
+                message=_("Upgrade DB using Essex release first."))
+
+
+def db_version_control(abs_path, version=None):
+    """Mark a database as under this repository's version control.
+
+    Once a database is under version control, schema changes should
+    only be done via change scripts in this repository.
+
+    :param abs_path: Absolute path to migrate repository
+    :param version: Initial database version
+    """
+    repository = _find_migrate_repo(abs_path)
+    versioning_api.version_control(get_engine(), repository, version)
+    return version
+
+
+def _find_migrate_repo(abs_path):
+    """Get the project's change script repository.
+
+    :param abs_path: Absolute path to migrate repository
+    """
+    global _REPOSITORY
+    if not os.path.exists(abs_path):
+        raise exception.DbMigrationError("Path %s not found" % abs_path)
+    if _REPOSITORY is None:
+        _REPOSITORY = Repository(abs_path)
+    return _REPOSITORY
diff --git a/mistral/openstack/common/db/sqlalchemy/models.py b/mistral/openstack/common/db/sqlalchemy/models.py
new file mode 100644
index 00000000..883eb418
--- /dev/null
+++ b/mistral/openstack/common/db/sqlalchemy/models.py
@@ -0,0 +1,110 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright (c) 2011 X.commerce, a business unit of eBay Inc.
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2011 Piston Cloud Computing, Inc.
+# Copyright 2012 Cloudscaling Group, Inc.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+"""
+SQLAlchemy models.
+"""
+
+import six
+
+from sqlalchemy import Column, Integer
+from sqlalchemy import DateTime
+from sqlalchemy.orm import object_mapper
+
+from mistral.openstack.common.db.sqlalchemy import session as sa
+from mistral.openstack.common import timeutils
+
+
+class ModelBase(object):
+    """Base class for models."""
+    __table_initialized__ = False
+
+    def save(self, session=None):
+        """Save this object."""
+        if not session:
+            session = sa.get_session()
+        # NOTE(boris-42): This part of code should look like:
+        #                     session.add(self)
+        #                     session.flush()
+        #                 But there is a bug in sqlalchemy and eventlet that
+        #                 raises a NoneType exception if there is no running
+        #                 transaction and rollback is called. As long as
+        #                 sqlalchemy has this bug we have to create the
+        #                 transaction explicitly.
+        with session.begin(subtransactions=True):
+            session.add(self)
+            session.flush()
+
+    def __setitem__(self, key, value):
+        setattr(self, key, value)
+
+    def __getitem__(self, key):
+        return getattr(self, key)
+
+    def get(self, key, default=None):
+        return getattr(self, key, default)
+
+    def _get_extra_keys(self):
+        return []
+
+    def __iter__(self):
+        columns = dict(object_mapper(self).columns).keys()
+        # NOTE(russellb): Allow models to specify other keys that can be looked
+        # up, beyond the actual db columns. An example would be the 'name'
+        # property for an Instance.
+        columns.extend(self._get_extra_keys())
+        self._i = iter(columns)
+        return self
+
+    def next(self):
+        n = six.advance_iterator(self._i)
+        return n, getattr(self, n)
+
+    def update(self, values):
+        """Make the model object behave like a dict."""
+        for k, v in six.iteritems(values):
+            setattr(self, k, v)
+
+    def iteritems(self):
+        """Make the model object behave like a dict.
+
+        Includes attributes from joins.
+        """
+        local = dict(self)
+        joined = dict([(k, v) for k, v in six.iteritems(self.__dict__)
+                      if not k[0] == '_'])
+        local.update(joined)
+        return local.iteritems()
+
+
+class TimestampMixin(object):
+    created_at = Column(DateTime, default=timeutils.utcnow)
+    updated_at = Column(DateTime, onupdate=timeutils.utcnow)
+
+
+class SoftDeleteMixin(object):
+    deleted_at = Column(DateTime)
+    deleted = Column(Integer, default=0)
+
+    def soft_delete(self, session=None):
+        """Mark this object as deleted."""
+        self.deleted = self.id
+        self.deleted_at = timeutils.utcnow()
+        self.save(session=session)
diff --git a/mistral/openstack/common/db/sqlalchemy/session.py b/mistral/openstack/common/db/sqlalchemy/session.py
new file mode 100644
index 00000000..7aec94eb
--- /dev/null
+++ b/mistral/openstack/common/db/sqlalchemy/session.py
@@ -0,0 +1,797 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
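
Before moving into session.py: a quick sketch of how the model mixins above
compose into a concrete declarative model. The Workbook table and the
declarative Base are hypothetical, made up for this illustration:

    from sqlalchemy import Column, Integer, String
    from sqlalchemy.ext.declarative import declarative_base

    from mistral.openstack.common.db.sqlalchemy import models

    Base = declarative_base()

    class Workbook(Base, models.SoftDeleteMixin,
                   models.TimestampMixin, models.ModelBase):
        __tablename__ = 'workbooks'   # hypothetical table
        id = Column(Integer, primary_key=True)
        name = Column(String(80))

    wb = Workbook()
    wb.update({'name': 'demo'})  # dict-style access from ModelBase
    wb.save()                    # flushes inside an explicit transaction
    wb.soft_delete()             # sets deleted=id and deleted_at=utcnow()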
+ +"""Session Handling for SQLAlchemy backend. + +Initializing: + +* Call set_defaults with the minimal of the following kwargs: + sql_connection, sqlite_db + + Example: + + session.set_defaults( + sql_connection="sqlite:///var/lib/mistral/sqlite.db", + sqlite_db="/var/lib/mistral/sqlite.db") + +Recommended ways to use sessions within this framework: + +* Don't use them explicitly; this is like running with AUTOCOMMIT=1. + model_query() will implicitly use a session when called without one + supplied. This is the ideal situation because it will allow queries + to be automatically retried if the database connection is interrupted. + + Note: Automatic retry will be enabled in a future patch. + + It is generally fine to issue several queries in a row like this. Even though + they may be run in separate transactions and/or separate sessions, each one + will see the data from the prior calls. If needed, undo- or rollback-like + functionality should be handled at a logical level. For an example, look at + the code around quotas and reservation_rollback(). + + Examples: + + def get_foo(context, foo): + return model_query(context, models.Foo).\ + filter_by(foo=foo).\ + first() + + def update_foo(context, id, newfoo): + model_query(context, models.Foo).\ + filter_by(id=id).\ + update({'foo': newfoo}) + + def create_foo(context, values): + foo_ref = models.Foo() + foo_ref.update(values) + foo_ref.save() + return foo_ref + + +* Within the scope of a single method, keeping all the reads and writes within + the context managed by a single session. In this way, the session's __exit__ + handler will take care of calling flush() and commit() for you. + If using this approach, you should not explicitly call flush() or commit(). + Any error within the context of the session will cause the session to emit + a ROLLBACK. If the connection is dropped before this is possible, the + database will implicitly rollback the transaction. + + Note: statements in the session scope will not be automatically retried. + + If you create models within the session, they need to be added, but you + do not need to call model.save() + + def create_many_foo(context, foos): + session = get_session() + with session.begin(): + for foo in foos: + foo_ref = models.Foo() + foo_ref.update(foo) + session.add(foo_ref) + + def update_bar(context, foo_id, newbar): + session = get_session() + with session.begin(): + foo_ref = model_query(context, models.Foo, session).\ + filter_by(id=foo_id).\ + first() + model_query(context, models.Bar, session).\ + filter_by(id=foo_ref['bar_id']).\ + update({'bar': newbar}) + + Note: update_bar is a trivially simple example of using "with session.begin". + Whereas create_many_foo is a good example of when a transaction is needed, + it is always best to use as few queries as possible. The two queries in + update_bar can be better expressed using a single query which avoids + the need for an explicit transaction. It can be expressed like so: + + def update_bar(context, foo_id, newbar): + subq = model_query(context, models.Foo.id).\ + filter_by(id=foo_id).\ + limit(1).\ + subquery() + model_query(context, models.Bar).\ + filter_by(id=subq.as_scalar()).\ + update({'bar': newbar}) + + For reference, this emits approximagely the following SQL statement: + + UPDATE bar SET bar = ${newbar} + WHERE id=(SELECT bar_id FROM foo WHERE id = ${foo_id} LIMIT 1); + +* Passing an active session between methods. Sessions should only be passed + to private methods. 
+  The private method must use a subtransaction; otherwise
+  SQLAlchemy will throw an error when you call session.begin() on an existing
+  transaction. Public methods should not accept a session parameter and should
+  not be involved in sessions within the caller's scope.
+
+  Note that this incurs more overhead in SQLAlchemy than the approaches above
+  due to nesting transactions, and it is not possible to implicitly retry
+  failed database operations when using this approach.
+
+  This also makes code somewhat more difficult to read and debug, because a
+  single database transaction spans more than one method. Error handling
+  becomes less clear in this situation. When this is needed for code clarity,
+  it should be clearly documented.
+
+    def myfunc(foo):
+        session = get_session()
+        with session.begin():
+            # do some database things
+            bar = _private_func(foo, session)
+        return bar
+
+    def _private_func(foo, session=None):
+        if not session:
+            session = get_session()
+        with session.begin(subtransactions=True):
+            # do some other database things
+        return bar
+
+
+There are some things which it is best to avoid:
+
+* Don't keep a transaction open any longer than necessary.
+
+  This means that your "with session.begin()" block should be as short
+  as possible, while still containing all the related calls for that
+  transaction.
+
+* Avoid "with_lockmode('UPDATE')" when possible.
+
+  In MySQL/InnoDB, when a "SELECT ... FOR UPDATE" query does not match
+  any rows, it will take a gap-lock. This is a form of write-lock on the
+  "gap" where no rows exist, and prevents any other writes to that space.
+  This can effectively prevent any INSERT into a table by locking the gap
+  at the end of the index. Similar problems will occur if the SELECT FOR UPDATE
+  has an overly broad WHERE clause, or doesn't properly use an index.
+
+  One idea proposed at ODS Fall '12 was to use a normal SELECT to test the
+  number of rows matching a query, and if only one row is returned,
+  then issue the SELECT FOR UPDATE.
+
+  The better long-term solution is to use INSERT .. ON DUPLICATE KEY UPDATE.
+  However, this cannot be done until the "deleted" columns are removed and
+  proper UNIQUE constraints are added to the tables.
+
+
+Enabling soft deletes:
+
+* To use/enable soft-deletes, the SoftDeleteMixin must be added
+  to your model class. For example:
+
+      class NovaBase(models.SoftDeleteMixin, models.ModelBase):
+          pass
+
+
+Efficient use of soft deletes:
+
+* There are two possible ways to mark a record as deleted:
+  model.soft_delete() and query.soft_delete().
+
+  The model.soft_delete() method works with a single, already-fetched entry.
+  query.soft_delete() makes only one db request for all entries that
+  correspond to the query.
+
+* In almost all cases you should use query.soft_delete(). Some examples:
+
+        def soft_delete_bar():
+            count = model_query(BarModel).find(some_condition).soft_delete()
+            if count == 0:
+                raise Exception("0 entries were soft deleted")
+
+        def complex_soft_delete_with_synchronization_bar(session=None):
+            if session is None:
+                session = get_session()
+            with session.begin(subtransactions=True):
+                count = model_query(BarModel).\
+                            find(some_condition).\
+                            soft_delete(synchronize_session=True)
+                # Here synchronize_session is required, because we
+                # don't know what is going on in outer session.
+ if count == 0: + raise Exception("0 entries were soft deleted") + +* There is only one situation where model.soft_delete() is appropriate: when + you fetch a single record, work with it, and mark it as deleted in the same + transaction. + + def soft_delete_bar_model(): + session = get_session() + with session.begin(): + bar_ref = model_query(BarModel).find(some_condition).first() + # Work with bar_ref + bar_ref.soft_delete(session=session) + + However, if you need to work with all entries that correspond to query and + then soft delete them you should use query.soft_delete() method: + + def soft_delete_multi_models(): + session = get_session() + with session.begin(): + query = model_query(BarModel, session=session).\ + find(some_condition) + model_refs = query.all() + # Work with model_refs + query.soft_delete(synchronize_session=False) + # synchronize_session=False should be set if there is no outer + # session and these entries are not used after this. + + When working with many rows, it is very important to use query.soft_delete, + which issues a single query. Using model.soft_delete(), as in the following + example, is very inefficient. + + for bar_ref in bar_refs: + bar_ref.soft_delete(session=session) + # This will produce count(bar_refs) db requests. +""" + +import functools +import os.path +import re +import time + +from oslo.config import cfg +import six +from sqlalchemy import exc as sqla_exc +import sqlalchemy.interfaces +from sqlalchemy.interfaces import PoolListener +import sqlalchemy.orm +from sqlalchemy.pool import NullPool, StaticPool +from sqlalchemy.sql.expression import literal_column + +from mistral.openstack.common.db import exception +from mistral.openstack.common.gettextutils import _ # noqa +from mistral.openstack.common import log as logging +from mistral.openstack.common import timeutils + +sqlite_db_opts = [ + cfg.StrOpt('sqlite_db', + default='mistral.sqlite', + help='the filename to use with sqlite'), + cfg.BoolOpt('sqlite_synchronous', + default=True, + help='If true, use synchronous mode for sqlite'), +] + +database_opts = [ + cfg.StrOpt('connection', + default='sqlite:///' + + os.path.abspath(os.path.join(os.path.dirname(__file__), + '../', '$sqlite_db')), + help='The SQLAlchemy connection string used to connect to the ' + 'database', + deprecated_opts=[cfg.DeprecatedOpt('sql_connection', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_connection', + group='DATABASE'), + cfg.DeprecatedOpt('connection', + group='sql'), ]), + cfg.StrOpt('slave_connection', + default='', + help='The SQLAlchemy connection string used to connect to the ' + 'slave database'), + cfg.IntOpt('idle_timeout', + default=3600, + deprecated_opts=[cfg.DeprecatedOpt('sql_idle_timeout', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_idle_timeout', + group='DATABASE')], + help='timeout before idle sql connections are reaped'), + cfg.IntOpt('min_pool_size', + default=1, + deprecated_opts=[cfg.DeprecatedOpt('sql_min_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_min_pool_size', + group='DATABASE')], + help='Minimum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_pool_size', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_pool_size', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_pool_size', + group='DATABASE')], + help='Maximum number of SQL connections to keep open in a ' + 'pool'), + cfg.IntOpt('max_retries', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_retries', + group='DEFAULT'), + cfg.DeprecatedOpt('sql_max_retries', + 
group='DATABASE')], + help='maximum db connection retries during startup. ' + '(setting -1 implies an infinite retry count)'), + cfg.IntOpt('retry_interval', + default=10, + deprecated_opts=[cfg.DeprecatedOpt('sql_retry_interval', + group='DEFAULT'), + cfg.DeprecatedOpt('reconnect_interval', + group='DATABASE')], + help='interval between retries of opening a sql connection'), + cfg.IntOpt('max_overflow', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sql_max_overflow', + group='DEFAULT'), + cfg.DeprecatedOpt('sqlalchemy_max_overflow', + group='DATABASE')], + help='If set, use this value for max_overflow with sqlalchemy'), + cfg.IntOpt('connection_debug', + default=0, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_debug', + group='DEFAULT')], + help='Verbosity of SQL debugging information. 0=None, ' + '100=Everything'), + cfg.BoolOpt('connection_trace', + default=False, + deprecated_opts=[cfg.DeprecatedOpt('sql_connection_trace', + group='DEFAULT')], + help='Add python stack traces to SQL as comment strings'), + cfg.IntOpt('pool_timeout', + default=None, + deprecated_opts=[cfg.DeprecatedOpt('sqlalchemy_pool_timeout', + group='DATABASE')], + help='If set, use this value for pool_timeout with sqlalchemy'), +] + +CONF = cfg.CONF +CONF.register_opts(sqlite_db_opts) +CONF.register_opts(database_opts, 'database') + +LOG = logging.getLogger(__name__) + +_ENGINE = None +_MAKER = None +_SLAVE_ENGINE = None +_SLAVE_MAKER = None + + +def set_defaults(sql_connection, sqlite_db, max_pool_size=None, + max_overflow=None, pool_timeout=None): + """Set defaults for configuration variables.""" + cfg.set_defaults(database_opts, + connection=sql_connection) + cfg.set_defaults(sqlite_db_opts, + sqlite_db=sqlite_db) + # Update the QueuePool defaults + if max_pool_size is not None: + cfg.set_defaults(database_opts, + max_pool_size=max_pool_size) + if max_overflow is not None: + cfg.set_defaults(database_opts, + max_overflow=max_overflow) + if pool_timeout is not None: + cfg.set_defaults(database_opts, + pool_timeout=pool_timeout) + + +def cleanup(): + global _ENGINE, _MAKER + global _SLAVE_ENGINE, _SLAVE_MAKER + + if _MAKER: + _MAKER.close_all() + _MAKER = None + if _ENGINE: + _ENGINE.dispose() + _ENGINE = None + if _SLAVE_MAKER: + _SLAVE_MAKER.close_all() + _SLAVE_MAKER = None + if _SLAVE_ENGINE: + _SLAVE_ENGINE.dispose() + _SLAVE_ENGINE = None + + +class SqliteForeignKeysListener(PoolListener): + """Ensures that the foreign key constraints are enforced in SQLite. 
+
+    The foreign key constraints are disabled by default in SQLite,
+    so the foreign key constraints will be enabled here for every
+    database connection
+    """
+    def connect(self, dbapi_con, con_record):
+        dbapi_con.execute('pragma foreign_keys=ON')
+
+
+def get_session(autocommit=True, expire_on_commit=False,
+                sqlite_fk=False, slave_session=False):
+    """Return a SQLAlchemy session."""
+    global _MAKER
+    global _SLAVE_MAKER
+    maker = _MAKER
+
+    if slave_session:
+        maker = _SLAVE_MAKER
+
+    if maker is None:
+        engine = get_engine(sqlite_fk=sqlite_fk, slave_engine=slave_session)
+        maker = get_maker(engine, autocommit, expire_on_commit)
+
+    if slave_session:
+        _SLAVE_MAKER = maker
+    else:
+        _MAKER = maker
+
+    session = maker()
+    return session
+
+
+# note(boris-42): In current versions of DB backends unique constraint
+# violation messages follow the structure:
+#
+# sqlite:
+# 1 column - (IntegrityError) column c1 is not unique
+# N columns - (IntegrityError) column c1, c2, ..., N are not unique
+#
+# postgres:
+# 1 column - (IntegrityError) duplicate key value violates unique
+#               constraint "users_c1_key"
+# N columns - (IntegrityError) duplicate key value violates unique
+#               constraint "name_of_our_constraint"
+#
+# mysql:
+# 1 column - (IntegrityError) (1062, "Duplicate entry 'value_of_c1' for key
+#               'c1'")
+# N columns - (IntegrityError) (1062, "Duplicate entry 'values joined
+#               with -' for key 'name_of_our_constraint'")
+_DUP_KEY_RE_DB = {
+    "sqlite": re.compile(r"^.*columns?([^)]+)(is|are)\s+not\s+unique$"),
+    "postgresql": re.compile(r"^.*duplicate\s+key.*\"([^\"]+)\"\s*\n.*$"),
+    "mysql": re.compile(r"^.*\(1062,.*'([^\']+)'\"\)$")
+}
+
+
+def _raise_if_duplicate_entry_error(integrity_error, engine_name):
+    """Raise exception if two entries are duplicated.
+
+    A DBDuplicateEntry exception is raised if the integrity error wraps
+    a unique constraint violation.
+    """
+
+    def get_columns_from_uniq_cons_or_name(columns):
+        # note(vsergeyev): UniqueConstraint name convention: "uniq_t0c10c2"
+        #                  where `t` is the table name and `c1`, `c2` are the
+        #                  columns in the UniqueConstraint.
+        uniqbase = "uniq_"
+        if not columns.startswith(uniqbase):
+            if engine_name == "postgresql":
+                return [columns[columns.index("_") + 1:columns.rindex("_")]]
+            return [columns]
+        return columns[len(uniqbase):].split("0")[1:]
+
+    if engine_name not in ["mysql", "sqlite", "postgresql"]:
+        return
+
+    # FIXME(johannes): The usage of the .message attribute has been
+    # deprecated since Python 2.6. However, the exceptions raised by
+    # SQLAlchemy can differ when using unicode() and accessing .message.
+    # An audit across all three supported engines will be necessary to
+    # ensure there are no regressions.
+    m = _DUP_KEY_RE_DB[engine_name].match(integrity_error.message)
+    if not m:
+        return
+    columns = m.group(1)
+
+    if engine_name == "sqlite":
+        columns = columns.strip().split(", ")
+    else:
+        columns = get_columns_from_uniq_cons_or_name(columns)
+    raise exception.DBDuplicateEntry(columns, integrity_error)
+
+
+# NOTE(comstud): In current versions of DB backends, Deadlock violation
+# messages follow the structure:
+#
+# mysql:
+# (OperationalError) (1213, 'Deadlock found when trying to get lock; try '
+#                     'restarting transaction')
+_DEADLOCK_RE_DB = {
+    "mysql": re.compile(r"^.*\(1213, 'Deadlock.*")
+}
+
+
+def _raise_if_deadlock_error(operational_error, engine_name):
+    """Raise exception on deadlock condition.
+
+    Raise DBDeadlock exception if OperationalError contains a Deadlock
+    condition.
+ """ + re = _DEADLOCK_RE_DB.get(engine_name) + if re is None: + return + # FIXME(johannes): The usage of the .message attribute has been + # deprecated since Python 2.6. However, the exceptions raised by + # SQLAlchemy can differ when using unicode() and accessing .message. + # An audit across all three supported engines will be necessary to + # ensure there are no regressions. + m = re.match(operational_error.message) + if not m: + return + raise exception.DBDeadlock(operational_error) + + +def _wrap_db_error(f): + @functools.wraps(f) + def _wrap(*args, **kwargs): + try: + return f(*args, **kwargs) + except UnicodeEncodeError: + raise exception.DBInvalidUnicodeParameter() + # note(boris-42): We should catch unique constraint violation and + # wrap it by our own DBDuplicateEntry exception. Unique constraint + # violation is wrapped by IntegrityError. + except sqla_exc.OperationalError as e: + _raise_if_deadlock_error(e, get_engine().name) + # NOTE(comstud): A lot of code is checking for OperationalError + # so let's not wrap it for now. + raise + except sqla_exc.IntegrityError as e: + # note(boris-42): SqlAlchemy doesn't unify errors from different + # DBs so we must do this. Also in some tables (for example + # instance_types) there are more than one unique constraint. This + # means we should get names of columns, which values violate + # unique constraint, from error message. + _raise_if_duplicate_entry_error(e, get_engine().name) + raise exception.DBError(e) + except Exception as e: + LOG.exception(_('DB exception wrapped.')) + raise exception.DBError(e) + return _wrap + + +def get_engine(sqlite_fk=False, slave_engine=False): + """Return a SQLAlchemy engine.""" + global _ENGINE + global _SLAVE_ENGINE + engine = _ENGINE + db_uri = CONF.database.connection + + if slave_engine: + engine = _SLAVE_ENGINE + db_uri = CONF.database.slave_connection + + if engine is None: + engine = create_engine(db_uri, + sqlite_fk=sqlite_fk) + if slave_engine: + _SLAVE_ENGINE = engine + else: + _ENGINE = engine + + return engine + + +def _synchronous_switch_listener(dbapi_conn, connection_rec): + """Switch sqlite connections to non-synchronous mode.""" + dbapi_conn.execute("PRAGMA synchronous = OFF") + + +def _add_regexp_listener(dbapi_con, con_record): + """Add REGEXP function to sqlite connections.""" + + def regexp(expr, item): + reg = re.compile(expr) + return reg.search(six.text_type(item)) is not None + dbapi_con.create_function('regexp', 2, regexp) + + +def _thread_yield(dbapi_con, con_record): + """Ensure other greenthreads get a chance to be executed. + + If we use eventlet.monkey_patch(), eventlet.greenthread.sleep(0) will + execute instead of time.sleep(0). + Force a context switch. With common database backends (eg MySQLdb and + sqlite), there is no implicit yield caused by network I/O since they are + implemented by C libraries that eventlet cannot monkey patch. + """ + time.sleep(0) + + +def _ping_listener(dbapi_conn, connection_rec, connection_proxy): + """Ensures that MySQL connections checked out of the pool are alive. 
+ + Borrowed from: + http://groups.google.com/group/sqlalchemy/msg/a4ce563d802c929f + """ + try: + dbapi_conn.cursor().execute('select 1') + except dbapi_conn.OperationalError as ex: + if ex.args[0] in (2006, 2013, 2014, 2045, 2055): + LOG.warn(_('Got mysql server has gone away: %s'), ex) + raise sqla_exc.DisconnectionError("Database server went away") + else: + raise + + +def _is_db_connection_error(args): + """Return True if error in connecting to db.""" + # NOTE(adam_g): This is currently MySQL specific and needs to be extended + # to support Postgres and others. + # For the db2, the error code is -30081 since the db2 is still not ready + conn_err_codes = ('2002', '2003', '2006', '-30081') + for err_code in conn_err_codes: + if args.find(err_code) != -1: + return True + return False + + +def create_engine(sql_connection, sqlite_fk=False): + """Return a new SQLAlchemy engine.""" + # NOTE(geekinutah): At this point we could be connecting to the normal + # db handle or the slave db handle. Things like + # _wrap_db_error aren't going to work well if their + # backends don't match. Let's check. + _assert_matching_drivers() + connection_dict = sqlalchemy.engine.url.make_url(sql_connection) + + engine_args = { + "pool_recycle": CONF.database.idle_timeout, + "echo": False, + 'convert_unicode': True, + } + + # Map our SQL debug level to SQLAlchemy's options + if CONF.database.connection_debug >= 100: + engine_args['echo'] = 'debug' + elif CONF.database.connection_debug >= 50: + engine_args['echo'] = True + + if "sqlite" in connection_dict.drivername: + if sqlite_fk: + engine_args["listeners"] = [SqliteForeignKeysListener()] + engine_args["poolclass"] = NullPool + + if CONF.database.connection == "sqlite://": + engine_args["poolclass"] = StaticPool + engine_args["connect_args"] = {'check_same_thread': False} + else: + if CONF.database.max_pool_size is not None: + engine_args['pool_size'] = CONF.database.max_pool_size + if CONF.database.max_overflow is not None: + engine_args['max_overflow'] = CONF.database.max_overflow + if CONF.database.pool_timeout is not None: + engine_args['pool_timeout'] = CONF.database.pool_timeout + + engine = sqlalchemy.create_engine(sql_connection, **engine_args) + + sqlalchemy.event.listen(engine, 'checkin', _thread_yield) + + if 'mysql' in connection_dict.drivername: + sqlalchemy.event.listen(engine, 'checkout', _ping_listener) + elif 'sqlite' in connection_dict.drivername: + if not CONF.sqlite_synchronous: + sqlalchemy.event.listen(engine, 'connect', + _synchronous_switch_listener) + sqlalchemy.event.listen(engine, 'connect', _add_regexp_listener) + + if (CONF.database.connection_trace and + engine.dialect.dbapi.__name__ == 'MySQLdb'): + _patch_mysqldb_with_stacktrace_comments() + + try: + engine.connect() + except sqla_exc.OperationalError as e: + if not _is_db_connection_error(e.args[0]): + raise + + remaining = CONF.database.max_retries + if remaining == -1: + remaining = 'infinite' + while True: + msg = _('SQL connection failed. 
%s attempts left.') + LOG.warn(msg % remaining) + if remaining != 'infinite': + remaining -= 1 + time.sleep(CONF.database.retry_interval) + try: + engine.connect() + break + except sqla_exc.OperationalError as e: + if (remaining != 'infinite' and remaining == 0) or \ + not _is_db_connection_error(e.args[0]): + raise + return engine + + +class Query(sqlalchemy.orm.query.Query): + """Subclass of sqlalchemy.query with soft_delete() method.""" + def soft_delete(self, synchronize_session='evaluate'): + return self.update({'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow()}, + synchronize_session=synchronize_session) + + +class Session(sqlalchemy.orm.session.Session): + """Custom Session class to avoid SqlAlchemy Session monkey patching.""" + @_wrap_db_error + def query(self, *args, **kwargs): + return super(Session, self).query(*args, **kwargs) + + @_wrap_db_error + def flush(self, *args, **kwargs): + return super(Session, self).flush(*args, **kwargs) + + @_wrap_db_error + def execute(self, *args, **kwargs): + return super(Session, self).execute(*args, **kwargs) + + +def get_maker(engine, autocommit=True, expire_on_commit=False): + """Return a SQLAlchemy sessionmaker using the given engine.""" + return sqlalchemy.orm.sessionmaker(bind=engine, + class_=Session, + autocommit=autocommit, + expire_on_commit=expire_on_commit, + query_cls=Query) + + +def _patch_mysqldb_with_stacktrace_comments(): + """Adds current stack trace as a comment in queries. + + Patches MySQLdb.cursors.BaseCursor._do_query. + """ + import MySQLdb.cursors + import traceback + + old_mysql_do_query = MySQLdb.cursors.BaseCursor._do_query + + def _do_query(self, q): + stack = '' + for filename, line, method, function in traceback.extract_stack(): + # exclude various common things from trace + if filename.endswith('session.py') and method == '_do_query': + continue + if filename.endswith('api.py') and method == 'wrapper': + continue + if filename.endswith('utils.py') and method == '_inner': + continue + if filename.endswith('exception.py') and method == '_wrap': + continue + # db/api is just a wrapper around db/sqlalchemy/api + if filename.endswith('db/api.py'): + continue + # only trace inside mistral + index = filename.rfind('mistral') + if index == -1: + continue + stack += "File:%s:%s Method:%s() Line:%s | " \ + % (filename[index:], line, method, function) + + # strip trailing " | " from stack + if stack: + stack = stack[:-3] + qq = "%s /* %s */" % (q, stack) + else: + qq = q + old_mysql_do_query(self, qq) + + setattr(MySQLdb.cursors.BaseCursor, '_do_query', _do_query) + + +def _assert_matching_drivers(): + """Make sure slave handle and normal handle have the same driver.""" + # NOTE(geekinutah): There's no use case for writing to one backend and + # reading from another. Who knows what the future holds? + if CONF.database.slave_connection == '': + return + + normal = sqlalchemy.engine.url.make_url(CONF.database.connection) + slave = sqlalchemy.engine.url.make_url(CONF.database.slave_connection) + assert normal.drivername == slave.drivername diff --git a/mistral/openstack/common/db/sqlalchemy/test_migrations.py b/mistral/openstack/common/db/sqlalchemy/test_migrations.py new file mode 100644 index 00000000..2f60f3b5 --- /dev/null +++ b/mistral/openstack/common/db/sqlalchemy/test_migrations.py @@ -0,0 +1,289 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010-2011 OpenStack Foundation +# Copyright 2012-2013 IBM Corp. +# All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+
+import commands
+import ConfigParser
+import os
+import urlparse
+
+import sqlalchemy
+import sqlalchemy.exc
+
+from mistral.openstack.common import lockutils
+from mistral.openstack.common import log as logging
+from mistral.openstack.common import test
+
+LOG = logging.getLogger(__name__)
+
+
+def _get_connect_string(backend, user, passwd, database):
+    """Build a database connection string.
+
+    Try to get a connection with a very specific set of values. If we get
+    these then we'll run the tests, otherwise they are skipped.
+    """
+    if backend == "postgres":
+        backend = "postgresql+psycopg2"
+    elif backend == "mysql":
+        backend = "mysql+mysqldb"
+    else:
+        raise Exception("Unrecognized backend: '%s'" % backend)
+
+    return ("%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s"
+            % {'backend': backend, 'user': user, 'passwd': passwd,
+               'database': database})
+
+
+def _is_backend_avail(backend, user, passwd, database):
+    try:
+        connect_uri = _get_connect_string(backend, user, passwd, database)
+        engine = sqlalchemy.create_engine(connect_uri)
+        connection = engine.connect()
+    except Exception:
+        # intentionally catch all to handle exceptions even if we don't
+        # have any backend code loaded.
+        return False
+    else:
+        connection.close()
+        engine.dispose()
+        return True
+
+
+def _have_mysql(user, passwd, database):
+    present = os.environ.get('TEST_MYSQL_PRESENT')
+    if present is None:
+        return _is_backend_avail('mysql', user, passwd, database)
+    return present.lower() in ('', 'true')
+
+
+def _have_postgresql(user, passwd, database):
+    present = os.environ.get('TEST_POSTGRESQL_PRESENT')
+    if present is None:
+        return _is_backend_avail('postgres', user, passwd, database)
+    return present.lower() in ('', 'true')
+
+
+def get_db_connection_info(conn_pieces):
+    database = conn_pieces.path.strip('/')
+    loc_pieces = conn_pieces.netloc.split('@')
+    host = loc_pieces[1]
+
+    auth_pieces = loc_pieces[0].split(':')
+    user = auth_pieces[0]
+    password = ""
+    if len(auth_pieces) > 1:
+        password = auth_pieces[1].strip()
+
+    return (user, password, database, host)
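
A small sketch of how these helpers fit together (the `openstack_citest`
credentials are an assumption borrowed from the usual OpenStack CI
convention, not something this patch defines):

    import urlparse

    uri = _get_connect_string("mysql", "openstack_citest",
                              "openstack_citest", "openstack_citest")
    # -> mysql+mysqldb://openstack_citest:openstack_citest@localhost/...
    pieces = urlparse.urlparse(uri)
    user, password, database, host = get_db_connection_info(pieces)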
+
+
+class BaseMigrationTestCase(test.BaseTestCase):
+    """Base class for testing of migration utils."""
+
+    def __init__(self, *args, **kwargs):
+        super(BaseMigrationTestCase, self).__init__(*args, **kwargs)
+
+        self.DEFAULT_CONFIG_FILE = os.path.join(os.path.dirname(__file__),
+                                                'test_migrations.conf')
+        # Test machines can set the TEST_MIGRATIONS_CONF variable
+        # to override the location of the config file for migration testing
+        self.CONFIG_FILE_PATH = os.environ.get('TEST_MIGRATIONS_CONF',
+                                               self.DEFAULT_CONFIG_FILE)
+        self.test_databases = {}
+        self.migration_api = None
+
+    def setUp(self):
+        super(BaseMigrationTestCase, self).setUp()
+
+        # Load test databases from the config file. Only do this
+        # once. No need to re-run this on each test...
+        LOG.debug('config_path is %s' % self.CONFIG_FILE_PATH)
+        if os.path.exists(self.CONFIG_FILE_PATH):
+            cp = ConfigParser.RawConfigParser()
+            try:
+                cp.read(self.CONFIG_FILE_PATH)
+                defaults = cp.defaults()
+                for key, value in defaults.items():
+                    self.test_databases[key] = value
+            except ConfigParser.ParsingError as e:
+                self.fail("Failed to read test_migrations.conf config "
+                          "file. Got error: %s" % e)
+        else:
+            self.fail("Failed to find test_migrations.conf config "
+                      "file.")
+
+        self.engines = {}
+        for key, value in self.test_databases.items():
+            self.engines[key] = sqlalchemy.create_engine(value)
+
+        # We start each test case with a completely blank slate.
+        self._reset_databases()
+
+    def tearDown(self):
+        # We destroy the test data store between each test case,
+        # and recreate it, which ensures that we have no side-effects
+        # from the tests
+        self._reset_databases()
+        super(BaseMigrationTestCase, self).tearDown()
+
+    def execute_cmd(self, cmd=None):
+        status, output = commands.getstatusoutput(cmd)
+        LOG.debug(output)
+        self.assertEqual(0, status,
+                         "Failed to run: %s\n%s" % (cmd, output))
+
+    @lockutils.synchronized('pgadmin', 'tests-', external=True)
+    def _reset_pg(self, conn_pieces):
+        (user, password, database, host) = get_db_connection_info(conn_pieces)
+        os.environ['PGPASSWORD'] = password
+        os.environ['PGUSER'] = user
+        # note(boris-42): We must create and drop the database; we can't
+        #                 drop the database we are connected to, so for such
+        #                 operations there is a special database, template1.
+        sqlcmd = ("psql -w -U %(user)s -h %(host)s -c"
+                  " '%(sql)s' -d template1")
+
+        sql = ("drop database if exists %s;") % database
+        droptable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
+        self.execute_cmd(droptable)
+
+        sql = ("create database %s;") % database
+        createtable = sqlcmd % {'user': user, 'host': host, 'sql': sql}
+        self.execute_cmd(createtable)
+
+        os.unsetenv('PGPASSWORD')
+        os.unsetenv('PGUSER')
+
+    def _reset_databases(self):
+        for key, engine in self.engines.items():
+            conn_string = self.test_databases[key]
+            conn_pieces = urlparse.urlparse(conn_string)
+            engine.dispose()
+            if conn_string.startswith('sqlite'):
+                # We can just delete the SQLite database, which is
+                # the easiest and cleanest solution
+                db_path = conn_pieces.path.strip('/')
+                if os.path.exists(db_path):
+                    os.unlink(db_path)
+                # No need to recreate the SQLite DB. SQLite will
+                # create it for us if it's not there...
+            elif conn_string.startswith('mysql'):
+                # We can execute the MySQL client to destroy and re-create
+                # the MYSQL database, which is easier and less error-prone
+                # than using SQLAlchemy to do this via MetaData...trust me.
+                (user, password, database, host) = \
+                    get_db_connection_info(conn_pieces)
+                sql = ("drop database if exists %(db)s; "
+                       "create database %(db)s;") % {'db': database}
+                cmd = ("mysql -u \"%(user)s\" -p\"%(password)s\" -h %(host)s "
+                       "-e \"%(sql)s\"") % {'user': user, 'password': password,
+                                            'host': host, 'sql': sql}
+                self.execute_cmd(cmd)
+            elif conn_string.startswith('postgresql'):
+                self._reset_pg(conn_pieces)
+
+
+class WalkVersionsMixin(object):
+    def _walk_versions(self, engine=None, snake_walk=False, downgrade=True):
+        # Determine latest version script from the repo, then
+        # upgrade from 1 through to the latest, with no data
+        # in the databases. This just checks that the schema itself
+        # upgrades successfully.
+
+        # Place the database under version control
+        self.migration_api.version_control(engine, self.REPOSITORY,
+                                           self.INIT_VERSION)
+        self.assertEqual(self.INIT_VERSION,
+                         self.migration_api.db_version(engine,
+                                                       self.REPOSITORY))
+
+        LOG.debug('latest version is %s' % self.REPOSITORY.latest)
+        versions = range(self.INIT_VERSION + 1, self.REPOSITORY.latest + 1)
+
+        for version in versions:
+            # upgrade -> downgrade -> upgrade
+            self._migrate_up(engine, version, with_data=True)
+            if snake_walk:
+                downgraded = self._migrate_down(
+                    engine, version - 1, with_data=True)
+                if downgraded:
+                    self._migrate_up(engine, version)
+
+        if downgrade:
+            # Now walk it back down to 0 from the latest, testing
+            # the downgrade paths.
+            for version in reversed(versions):
+                # downgrade -> upgrade -> downgrade
+                downgraded = self._migrate_down(engine, version - 1)
+
+                if snake_walk and downgraded:
+                    self._migrate_up(engine, version)
+                    self._migrate_down(engine, version - 1)
+
+    def _migrate_down(self, engine, version, with_data=False):
+        try:
+            self.migration_api.downgrade(engine, self.REPOSITORY, version)
+        except NotImplementedError:
+            # NOTE(sirp): some migrations, namely release-level
+            # migrations, don't support a downgrade.
+            return False
+
+        self.assertEqual(
+            version, self.migration_api.db_version(engine, self.REPOSITORY))
+
+        # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target'
+        # version). So if we have any downgrade checks, they need to be run for
+        # the previous (higher numbered) migration.
+        if with_data:
+            post_downgrade = getattr(
+                self, "_post_downgrade_%03d" % (version + 1), None)
+            if post_downgrade:
+                post_downgrade(engine)
+
+        return True
+
+    def _migrate_up(self, engine, version, with_data=False):
+        """Migrate up to a new version of the db.
+
+        We allow for data insertion and post checks at every
+        migration version with special _pre_upgrade_### and
+        _check_### functions in the main test.
+        """
+        # NOTE(sdague): try block is here because it's impossible to debug
+        # where a failed data migration happens otherwise
+        try:
+            if with_data:
+                data = None
+                pre_upgrade = getattr(
+                    self, "_pre_upgrade_%03d" % version, None)
+                if pre_upgrade:
+                    data = pre_upgrade(engine)
+
+            self.migration_api.upgrade(engine, self.REPOSITORY, version)
+            self.assertEqual(version,
+                             self.migration_api.db_version(engine,
+                                                           self.REPOSITORY))
+            if with_data:
+                check = getattr(self, "_check_%03d" % version, None)
+                if check:
+                    check(engine, data)
+        except Exception:
+            LOG.error("Failed to migrate to version %s on engine %s" %
+                      (version, engine))
+            raise
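
A hypothetical concrete test tying the two base classes together; the class
name, repository path, and the versioning_api/Repository wiring are all
assumptions, not part of this patch:

    from migrate.versioning import api as versioning_api
    from migrate.versioning.repository import Repository

    class TestMistralMigrations(BaseMigrationTestCase, WalkVersionsMixin):
        def setUp(self):
            super(TestMistralMigrations, self).setUp()
            self.INIT_VERSION = 0
            self.REPOSITORY = Repository('path/to/migrate_repo')
            self.migration_api = versioning_api

        def test_walk_versions(self):
            for engine in self.engines.values():
                self._walk_versions(engine, snake_walk=True, downgrade=True)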
diff --git a/mistral/openstack/common/db/sqlalchemy/utils.py b/mistral/openstack/common/db/sqlalchemy/utils.py
new file mode 100644
index 00000000..b8aafb1b
--- /dev/null
+++ b/mistral/openstack/common/db/sqlalchemy/utils.py
@@ -0,0 +1,501 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2010 United States Government as represented by the
+# Administrator of the National Aeronautics and Space Administration.
+# Copyright 2010-2011 OpenStack Foundation.
+# Copyright 2012 Justin Santa Barbara
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+import re
+
+from migrate.changeset import UniqueConstraint
+import sqlalchemy
+from sqlalchemy import Boolean
+from sqlalchemy import CheckConstraint
+from sqlalchemy import Column
+from sqlalchemy.engine import reflection
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy import func
+from sqlalchemy import Index
+from sqlalchemy import Integer
+from sqlalchemy import MetaData
+from sqlalchemy.sql.expression import literal_column
+from sqlalchemy.sql.expression import UpdateBase
+from sqlalchemy.sql import select
+from sqlalchemy import String
+from sqlalchemy import Table
+from sqlalchemy.types import NullType
+
+from mistral.openstack.common.gettextutils import _  # noqa
+
+from mistral.openstack.common import log as logging
+from mistral.openstack.common import timeutils
+
+
+LOG = logging.getLogger(__name__)
+
+_DBURL_REGEX = re.compile(r"[^:]+://([^:]+):([^@]+)@.+")
+
+
+def sanitize_db_url(url):
+    match = _DBURL_REGEX.match(url)
+    if match:
+        return '%s****:****%s' % (url[:match.start(1)], url[match.end(2):])
+    return url
+
+
+class InvalidSortKey(Exception):
+    message = _("Sort key supplied was not valid.")
+
+
+# copy from glance/db/sqlalchemy/api.py
+def paginate_query(query, model, limit, sort_keys, marker=None,
+                   sort_dir=None, sort_dirs=None):
+    """Returns a query with sorting / pagination criteria added.
+
+    Pagination works by requiring a unique sort_key, specified by sort_keys.
+    (If sort_keys is not unique, then we risk looping through values.)
+    We use the last row in the previous page as the 'marker' for pagination.
+    So we must return values that follow the passed marker in the order.
+    With a single-valued sort_key, this would be easy: sort_key > X.
+    With a compound-valued sort_key, (k1, k2, k3) we must do this to repeat
+    the lexicographical ordering:
+    (k1 > X1) or (k1 == X1 && k2 > X2) or (k1 == X1 && k2 == X2 && k3 > X3)
+
+    We also have to cope with different sort_directions.
+
+    Typically, the id of the last row is used as the client-facing pagination
+    marker, then the actual marker object must be fetched from the db and
+    passed in to us as marker.
+
+    :param query: the query object to which we should add paging/sorting
+    :param model: the ORM model class
+    :param limit: maximum number of items to return
+    :param sort_keys: array of attributes by which results should be sorted
+    :param marker: the last item of the previous page; we return the next
+                   results after this value.
+    :param sort_dir: direction in which results should be sorted (asc, desc)
+    :param sort_dirs: per-column array of sort_dirs, corresponding to sort_keys
+
+    :rtype: sqlalchemy.orm.query.Query
+    :return: The query with sorting/pagination added.
+    """ + +    if 'id' not in sort_keys: +        # TODO(justinsb): If this ever gives a false-positive, check +        # the actual primary key, rather than assuming it is id +        LOG.warn(_('Id not in sort_keys; is sort_keys unique?')) + +    assert(not (sort_dir and sort_dirs)) + +    # Default the sort direction to ascending +    if sort_dirs is None and sort_dir is None: +        sort_dir = 'asc' + +    # Ensure a per-column sort direction +    if sort_dirs is None: +        sort_dirs = [sort_dir for _sort_key in sort_keys] + +    assert(len(sort_dirs) == len(sort_keys)) + +    # Add sorting +    for current_sort_key, current_sort_dir in zip(sort_keys, sort_dirs): +        try: +            sort_dir_func = { +                'asc': sqlalchemy.asc, +                'desc': sqlalchemy.desc, +            }[current_sort_dir] +        except KeyError: +            raise ValueError(_("Unknown sort direction, " +                               "must be 'desc' or 'asc'")) +        try: +            sort_key_attr = getattr(model, current_sort_key) +        except AttributeError: +            raise InvalidSortKey() +        query = query.order_by(sort_dir_func(sort_key_attr)) + +    # Add pagination +    if marker is not None: +        marker_values = [] +        for sort_key in sort_keys: +            v = getattr(marker, sort_key) +            marker_values.append(v) + +        # Build up an array of sort criteria as in the docstring +        criteria_list = [] +        for i in range(0, len(sort_keys)): +            crit_attrs = [] +            for j in range(0, i): +                model_attr = getattr(model, sort_keys[j]) +                crit_attrs.append((model_attr == marker_values[j])) + +            model_attr = getattr(model, sort_keys[i]) +            if sort_dirs[i] == 'desc': +                crit_attrs.append((model_attr < marker_values[i])) +            else: +                crit_attrs.append((model_attr > marker_values[i])) + +            criteria = sqlalchemy.sql.and_(*crit_attrs) +            criteria_list.append(criteria) + +        f = sqlalchemy.sql.or_(*criteria_list) +        query = query.filter(f) + +    if limit is not None: +        query = query.limit(limit) + +    return query + + +def get_table(engine, name): +    """Returns an sqlalchemy table dynamically from db. + +    Needed because the models don't work for us in migrations +    as models will be far out of sync with the current data. +    """ +    metadata = MetaData() +    metadata.bind = engine +    return Table(name, metadata, autoload=True) + + +class InsertFromSelect(UpdateBase): +    """Form the base for `INSERT INTO table (SELECT ... )` statement.""" +    def __init__(self, table, select): +        self.table = table +        self.select = select + + +@compiles(InsertFromSelect) +def visit_insert_from_select(element, compiler, **kw): +    """Form the `INSERT INTO table (SELECT ... )` statement.""" +    return "INSERT INTO %s %s" % ( +        compiler.process(element.table, asfrom=True), +        compiler.process(element.select)) + + +class ColumnError(Exception): +    """Error raised when no column or an invalid column is found.""" + + +def _get_not_supported_column(col_name_col_instance, column_name): +    try: +        column = col_name_col_instance[column_name] +    except KeyError: +        msg = _("Please specify column %s in col_name_col_instance " +                "param. It is required because the column has a type " +                "unsupported by sqlite.") +        raise ColumnError(msg % column_name) + +    if not isinstance(column, Column): +        msg = _("col_name_col_instance param has the wrong type of " +                "column instance for column %s. It should be an instance " +                "of sqlalchemy.Column.") +        raise ColumnError(msg % column_name) +    return column + + +def drop_unique_constraint(migrate_engine, table_name, uc_name, *columns, +                           **col_name_col_instance): +    """Drop unique constraint from table. + +    This method drops the unique constraint from a table and works for +    mysql, postgresql and sqlite. In mysql and postgresql we can use the +    "alter table" construction.
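+ +    For example, a hypothetical migration dropping the constraint +    `uniq_instances0name` over the `name` column of an `instances` table +    (all names here are illustrative): + +        drop_unique_constraint(migrate_engine, 'instances', +                               'uniq_instances0name', 'name') +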
+    Sqlalchemy doesn't support some sqlite column types and replaces their +    type with NullType in metadata. We process these columns and replace +    NullType with the correct column type. + +    :param migrate_engine: sqlalchemy engine +    :param table_name: name of table that contains the unique constraint. +    :param uc_name: name of the unique constraint that will be dropped. +    :param columns: columns that are in the unique constraint. +    :param col_name_col_instance: contains pairs column_name=column_instance. +                        column_instance is an instance of Column. These params +                        are required only for columns that have types +                        unsupported by sqlite. For example BigInteger. +    """ + +    meta = MetaData() +    meta.bind = migrate_engine +    t = Table(table_name, meta, autoload=True) + +    if migrate_engine.name == "sqlite": +        override_cols = [ +            _get_not_supported_column(col_name_col_instance, col.name) +            for col in t.columns +            if isinstance(col.type, NullType) +        ] +        for col in override_cols: +            t.columns.replace(col) + +    uc = UniqueConstraint(*columns, table=t, name=uc_name) +    uc.drop() + + +def drop_old_duplicate_entries_from_table(migrate_engine, table_name, +                                          use_soft_delete, *uc_column_names): +    """Drop all old rows having the same values for columns in uc_columns. + +    This method drops (or marks as `deleted` if use_soft_delete is True) old +    duplicate rows from the table with name `table_name`. + +    :param migrate_engine: Sqlalchemy engine +    :param table_name: Table with duplicates +    :param use_soft_delete: If True - values will be marked as `deleted`, +                            if False - values will be removed from table +    :param uc_column_names: Unique constraint columns +    """ +    meta = MetaData() +    meta.bind = migrate_engine + +    table = Table(table_name, meta, autoload=True) +    columns_for_group_by = [table.c[name] for name in uc_column_names] + +    columns_for_select = [func.max(table.c.id)] +    columns_for_select.extend(columns_for_group_by) + +    duplicated_rows_select = select(columns_for_select, +                                    group_by=columns_for_group_by, +                                    having=func.count(table.c.id) > 1) + +    for row in migrate_engine.execute(duplicated_rows_select): +        # NOTE(boris-42): Do not remove the row that has the biggest ID.
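+        # For example, if rows with id 1 and id 3 share the same values in +        # the unique constraint columns, only the row with id 3 survives; +        # the condition below matches every other duplicate.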
+ delete_condition = table.c.id != row[0] + is_none = None # workaround for pyflakes + delete_condition &= table.c.deleted_at == is_none + for name in uc_column_names: + delete_condition &= table.c[name] == row[name] + + rows_to_delete_select = select([table.c.id]).where(delete_condition) + for row in migrate_engine.execute(rows_to_delete_select).fetchall(): + LOG.info(_("Deleting duplicated row with id: %(id)s from table: " + "%(table)s") % dict(id=row[0], table=table_name)) + + if use_soft_delete: + delete_statement = table.update().\ + where(delete_condition).\ + values({ + 'deleted': literal_column('id'), + 'updated_at': literal_column('updated_at'), + 'deleted_at': timeutils.utcnow() + }) + else: + delete_statement = table.delete().where(delete_condition) + migrate_engine.execute(delete_statement) + + +def _get_default_deleted_value(table): + if isinstance(table.c.id.type, Integer): + return 0 + if isinstance(table.c.id.type, String): + return "" + raise ColumnError(_("Unsupported id columns type")) + + +def _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes): + table = get_table(migrate_engine, table_name) + + insp = reflection.Inspector.from_engine(migrate_engine) + real_indexes = insp.get_indexes(table_name) + existing_index_names = dict( + [(index['name'], index['column_names']) for index in real_indexes]) + + # NOTE(boris-42): Restore indexes on `deleted` column + for index in indexes: + if 'deleted' not in index['column_names']: + continue + name = index['name'] + if name in existing_index_names: + column_names = [table.c[c] for c in existing_index_names[name]] + old_index = Index(name, *column_names, unique=index["unique"]) + old_index.drop(migrate_engine) + + column_names = [table.c[c] for c in index['column_names']] + new_index = Index(index["name"], *column_names, unique=index["unique"]) + new_index.create(migrate_engine) + + +def change_deleted_column_type_to_boolean(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_boolean_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + old_deleted = Column('old_deleted', Boolean, default=False) + old_deleted.create(table, populate_default=False) + + table.update().\ + where(table.c.deleted == table.c.id).\ + values(old_deleted=True).\ + execute() + + table.c.deleted.drop() + table.c.old_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_boolean_sqlite(migrate_engine, table_name, + **col_name_col_instance): + insp = reflection.Inspector.from_engine(migrate_engine) + table = get_table(migrate_engine, table_name) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', Boolean, default=0) + columns.append(column_copy) + + constraints = [constraint.copy() for constraint in table.constraints] + + meta = table.metadata + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + 
indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + c_select = [] + for c in table.c: + if c.name != "deleted": + c_select.append(c) + else: + c_select.append(table.c.deleted == table.c.id) + + ins = InsertFromSelect(new_table, select(c_select)) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + new_table.update().\ + where(new_table.c.deleted == new_table.c.id).\ + values(deleted=True).\ + execute() + + +def change_deleted_column_type_to_id_type(migrate_engine, table_name, + **col_name_col_instance): + if migrate_engine.name == "sqlite": + return _change_deleted_column_type_to_id_type_sqlite( + migrate_engine, table_name, **col_name_col_instance) + insp = reflection.Inspector.from_engine(migrate_engine) + indexes = insp.get_indexes(table_name) + + table = get_table(migrate_engine, table_name) + + new_deleted = Column('new_deleted', table.c.id.type, + default=_get_default_deleted_value(table)) + new_deleted.create(table, populate_default=True) + + deleted = True # workaround for pyflakes + table.update().\ + where(table.c.deleted == deleted).\ + values(new_deleted=table.c.id).\ + execute() + table.c.deleted.drop() + table.c.new_deleted.alter(name="deleted") + + _restore_indexes_on_deleted_columns(migrate_engine, table_name, indexes) + + +def _change_deleted_column_type_to_id_type_sqlite(migrate_engine, table_name, + **col_name_col_instance): + # NOTE(boris-42): sqlaclhemy-migrate can't drop column with check + # constraints in sqlite DB and our `deleted` column has + # 2 check constraints. So there is only one way to remove + # these constraints: + # 1) Create new table with the same columns, constraints + # and indexes. (except deleted column). + # 2) Copy all data from old to new table. + # 3) Drop old table. + # 4) Rename new table to old table name. + insp = reflection.Inspector.from_engine(migrate_engine) + meta = MetaData(bind=migrate_engine) + table = Table(table_name, meta, autoload=True) + default_deleted_value = _get_default_deleted_value(table) + + columns = [] + for column in table.columns: + column_copy = None + if column.name != "deleted": + if isinstance(column.type, NullType): + column_copy = _get_not_supported_column(col_name_col_instance, + column.name) + else: + column_copy = column.copy() + else: + column_copy = Column('deleted', table.c.id.type, + default=default_deleted_value) + columns.append(column_copy) + + def is_deleted_column_constraint(constraint): + # NOTE(boris-42): There is no other way to check is CheckConstraint + # associated with deleted column. 
+ if not isinstance(constraint, CheckConstraint): + return False + sqltext = str(constraint.sqltext) + return (sqltext.endswith("deleted in (0, 1)") or + sqltext.endswith("deleted IN (:deleted_1, :deleted_2)")) + + constraints = [] + for constraint in table.constraints: + if not is_deleted_column_constraint(constraint): + constraints.append(constraint.copy()) + + new_table = Table(table_name + "__tmp__", meta, + *(columns + constraints)) + new_table.create() + + indexes = [] + for index in insp.get_indexes(table_name): + column_names = [new_table.c[c] for c in index['column_names']] + indexes.append(Index(index["name"], *column_names, + unique=index["unique"])) + + ins = InsertFromSelect(new_table, table.select()) + migrate_engine.execute(ins) + + table.drop() + [index.create(migrate_engine) for index in indexes] + + new_table.rename(table_name) + deleted = True # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=new_table.c.id).\ + execute() + + # NOTE(boris-42): Fix value of deleted column: False -> "" or 0. + deleted = False # workaround for pyflakes + new_table.update().\ + where(new_table.c.deleted == deleted).\ + values(deleted=default_deleted_value).\ + execute() diff --git a/mistral/openstack/common/excutils.py b/mistral/openstack/common/excutils.py new file mode 100644 index 00000000..92326151 --- /dev/null +++ b/mistral/openstack/common/excutils.py @@ -0,0 +1,101 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2012, Red Hat, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Exception related utilities. +""" + +import logging +import sys +import time +import traceback + +import six + +from mistral.openstack.common.gettextutils import _ # noqa + + +class save_and_reraise_exception(object): + """Save current exception, run some code and then re-raise. + + In some cases the exception context can be cleared, resulting in None + being attempted to be re-raised after an exception handler is run. This + can happen when eventlet switches greenthreads or when running an + exception handler, code raises and catches an exception. In both + cases the exception context will be cleared. + + To work around this, we save the exception state, run handler code, and + then re-raise the original exception. If another exception occurs, the + saved exception is logged and the new exception is re-raised. + + In some cases the caller may not want to re-raise the exception, and + for those circumstances this context provides a reraise flag that + can be used to suppress the exception. 
For example: + + except Exception: + with save_and_reraise_exception() as ctxt: + decide_if_need_reraise() + if not should_be_reraised: + ctxt.reraise = False + """ + def __init__(self): + self.reraise = True + + def __enter__(self): + self.type_, self.value, self.tb, = sys.exc_info() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + logging.error(_('Original exception being dropped: %s'), + traceback.format_exception(self.type_, + self.value, + self.tb)) + return False + if self.reraise: + six.reraise(self.type_, self.value, self.tb) + + +def forever_retry_uncaught_exceptions(infunc): + def inner_func(*args, **kwargs): + last_log_time = 0 + last_exc_message = None + exc_count = 0 + while True: + try: + return infunc(*args, **kwargs) + except Exception as exc: + this_exc_message = six.u(str(exc)) + if this_exc_message == last_exc_message: + exc_count += 1 + else: + exc_count = 1 + # Do not log any more frequently than once a minute unless + # the exception message changes + cur_time = int(time.time()) + if (cur_time - last_log_time > 60 or + this_exc_message != last_exc_message): + logging.exception( + _('Unexpected exception occurred %d time(s)... ' + 'retrying.') % exc_count) + last_log_time = cur_time + last_exc_message = this_exc_message + exc_count = 0 + # This should be a very rare event. In case it isn't, do + # a sleep. + time.sleep(1) + return inner_func diff --git a/mistral/openstack/common/fileutils.py b/mistral/openstack/common/fileutils.py new file mode 100644 index 00000000..96104d01 --- /dev/null +++ b/mistral/openstack/common/fileutils.py @@ -0,0 +1,139 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import contextlib +import errno +import os +import tempfile + +from mistral.openstack.common import excutils +from mistral.openstack.common.gettextutils import _ # noqa +from mistral.openstack.common import log as logging + +LOG = logging.getLogger(__name__) + +_FILE_CACHE = {} + + +def ensure_tree(path): + """Create a directory (and any ancestor directories required) + + :param path: Directory to create + """ + try: + os.makedirs(path) + except OSError as exc: + if exc.errno == errno.EEXIST: + if not os.path.isdir(path): + raise + else: + raise + + +def read_cached_file(filename, force_reload=False): + """Read from a file if it has been modified. + + :param force_reload: Whether to reload the file. + :returns: A tuple with a boolean specifying if the data is fresh + or not. 
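+ +    A usage sketch (the file path and callback below are illustrative): + +        reloaded, data = read_cached_file('/etc/mistral/policy.json') +        if reloaded: +            reparse_policy(data)  # hypothetical caller-side refresh +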
+ """ + global _FILE_CACHE + + if force_reload and filename in _FILE_CACHE: + del _FILE_CACHE[filename] + + reloaded = False + mtime = os.path.getmtime(filename) + cache_info = _FILE_CACHE.setdefault(filename, {}) + + if not cache_info or mtime > cache_info.get('mtime', 0): + LOG.debug(_("Reloading cached file %s") % filename) + with open(filename) as fap: + cache_info['data'] = fap.read() + cache_info['mtime'] = mtime + reloaded = True + return (reloaded, cache_info['data']) + + +def delete_if_exists(path, remove=os.unlink): + """Delete a file, but ignore file not found error. + + :param path: File to delete + :param remove: Optional function to remove passed path + """ + + try: + remove(path) + except OSError as e: + if e.errno != errno.ENOENT: + raise + + +@contextlib.contextmanager +def remove_path_on_error(path, remove=delete_if_exists): + """Protect code that wants to operate on PATH atomically. + Any exception will cause PATH to be removed. + + :param path: File to work with + :param remove: Optional function to remove passed path + """ + + try: + yield + except Exception: + with excutils.save_and_reraise_exception(): + remove(path) + + +def file_open(*args, **kwargs): + """Open file + + see built-in file() documentation for more details + + Note: The reason this is kept in a separate module is to easily + be able to provide a stub module that doesn't alter system + state at all (for unit tests) + """ + return file(*args, **kwargs) + + +def write_to_tempfile(content, path=None, suffix='', prefix='tmp'): + """Create temporary file or use existing file. + + This util is needed for creating temporary file with + specified content, suffix and prefix. If path is not None, + it will be used for writing content. If the path doesn't + exist it'll be created. + + :param content: content for temporary file. + :param path: same as parameter 'dir' for mkstemp + :param suffix: same as parameter 'suffix' for mkstemp + :param prefix: same as parameter 'prefix' for mkstemp + + For example: it can be used in database tests for creating + configuration files. + """ + if path: + ensure_tree(path) + + (fd, path) = tempfile.mkstemp(suffix=suffix, dir=path, prefix=prefix) + try: + os.write(fd, content) + finally: + os.close(fd) + return path diff --git a/mistral/openstack/common/gettextutils.py b/mistral/openstack/common/gettextutils.py new file mode 100644 index 00000000..f5108e97 --- /dev/null +++ b/mistral/openstack/common/gettextutils.py @@ -0,0 +1,373 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2012 Red Hat, Inc. +# Copyright 2013 IBM Corp. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +gettext for openstack-common modules. 
+ +Usual usage in an openstack.common module: + + from mistral.openstack.common.gettextutils import _ +""" + +import copy +import gettext +import logging +import os +import re +try: + import UserString as _userString +except ImportError: + import collections as _userString + +from babel import localedata +import six + +_localedir = os.environ.get('mistral'.upper() + '_LOCALEDIR') +_t = gettext.translation('mistral', localedir=_localedir, fallback=True) + +_AVAILABLE_LANGUAGES = {} +USE_LAZY = False + + +def enable_lazy(): + """Convenience function for configuring _() to use lazy gettext + + Call this at the start of execution to enable the gettextutils._ + function to use lazy gettext functionality. This is useful if + your project is importing _ directly instead of using the + gettextutils.install() way of importing the _ function. + """ + global USE_LAZY + USE_LAZY = True + + +def _(msg): + if USE_LAZY: + return Message(msg, 'mistral') + else: + if six.PY3: + return _t.gettext(msg) + return _t.ugettext(msg) + + +def install(domain, lazy=False): + """Install a _() function using the given translation domain. + + Given a translation domain, install a _() function using gettext's + install() function. + + The main difference from gettext.install() is that we allow + overriding the default localedir (e.g. /usr/share/locale) using + a translation-domain-specific environment variable (e.g. + NOVA_LOCALEDIR). + + :param domain: the translation domain + :param lazy: indicates whether or not to install the lazy _() function. + The lazy _() introduces a way to do deferred translation + of messages by installing a _ that builds Message objects, + instead of strings, which can then be lazily translated into + any available locale. + """ + if lazy: + # NOTE(mrodden): Lazy gettext functionality. + # + # The following introduces a deferred way to do translations on + # messages in OpenStack. We override the standard _() function + # and % (format string) operation to build Message objects that can + # later be translated when we have more information. + # + # Also included below is an example LocaleHandler that translates + # Messages to an associated locale, effectively allowing many logs, + # each with their own locale. + + def _lazy_gettext(msg): + """Create and return a Message object. + + Lazy gettext function for a given domain, it is a factory method + for a project/module to get a lazy gettext function for its own + translation domain (i.e. nova, glance, cinder, etc.) + + Message encapsulates a string so that we can translate + it later when needed. 
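+ +            A sketch of the intended behaviour (illustrative only): + +                msg = _('Hello %s')        # a Message, not unicode +                msg = msg % 'world'        # parameters saved for later +                text = six.text_type(msg)  # translated on rendering +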
+ """ + return Message(msg, domain) + + from six import moves + moves.builtins.__dict__['_'] = _lazy_gettext + else: + localedir = '%s_LOCALEDIR' % domain.upper() + if six.PY3: + gettext.install(domain, + localedir=os.environ.get(localedir)) + else: + gettext.install(domain, + localedir=os.environ.get(localedir), + unicode=True) + + +class Message(_userString.UserString, object): + """Class used to encapsulate translatable messages.""" + def __init__(self, msg, domain): + # _msg is the gettext msgid and should never change + self._msg = msg + self._left_extra_msg = '' + self._right_extra_msg = '' + self._locale = None + self.params = None + self.domain = domain + + @property + def data(self): + # NOTE(mrodden): this should always resolve to a unicode string + # that best represents the state of the message currently + + localedir = os.environ.get(self.domain.upper() + '_LOCALEDIR') + if self.locale: + lang = gettext.translation(self.domain, + localedir=localedir, + languages=[self.locale], + fallback=True) + else: + # use system locale for translations + lang = gettext.translation(self.domain, + localedir=localedir, + fallback=True) + + if six.PY3: + ugettext = lang.gettext + else: + ugettext = lang.ugettext + + full_msg = (self._left_extra_msg + + ugettext(self._msg) + + self._right_extra_msg) + + if self.params is not None: + full_msg = full_msg % self.params + + return six.text_type(full_msg) + + @property + def locale(self): + return self._locale + + @locale.setter + def locale(self, value): + self._locale = value + if not self.params: + return + + # This Message object may have been constructed with one or more + # Message objects as substitution parameters, given as a single + # Message, or a tuple or Map containing some, so when setting the + # locale for this Message we need to set it for those Messages too. 
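+        # For example, a message built as +        #     _('%(greeting)s, world') % {'greeting': _('Hello')} +        # carries a nested Message whose locale must be kept in sync.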
+ if isinstance(self.params, Message): + self.params.locale = value + return + if isinstance(self.params, tuple): + for param in self.params: + if isinstance(param, Message): + param.locale = value + return + if isinstance(self.params, dict): + for param in self.params.values(): + if isinstance(param, Message): + param.locale = value + + def _save_dictionary_parameter(self, dict_param): + full_msg = self.data + # look for %(blah) fields in string; + # ignore %% and deal with the + # case where % is first character on the line + keys = re.findall('(?:[^%]|^)?%\((\w*)\)[a-z]', full_msg) + + # if we don't find any %(blah) blocks but have a %s + if not keys and re.findall('(?:[^%]|^)%[a-z]', full_msg): + # apparently the full dictionary is the parameter + params = copy.deepcopy(dict_param) + else: + params = {} + for key in keys: + try: + params[key] = copy.deepcopy(dict_param[key]) + except TypeError: + # cast uncopyable thing to unicode string + params[key] = six.text_type(dict_param[key]) + + return params + + def _save_parameters(self, other): + # we check for None later to see if + # we actually have parameters to inject, + # so encapsulate if our parameter is actually None + if other is None: + self.params = (other, ) + elif isinstance(other, dict): + self.params = self._save_dictionary_parameter(other) + else: + # fallback to casting to unicode, + # this will handle the problematic python code-like + # objects that cannot be deep-copied + try: + self.params = copy.deepcopy(other) + except TypeError: + self.params = six.text_type(other) + + return self + + # overrides to be more string-like + def __unicode__(self): + return self.data + + def __str__(self): + if six.PY3: + return self.__unicode__() + return self.data.encode('utf-8') + + def __getstate__(self): + to_copy = ['_msg', '_right_extra_msg', '_left_extra_msg', + 'domain', 'params', '_locale'] + new_dict = self.__dict__.fromkeys(to_copy) + for attr in to_copy: + new_dict[attr] = copy.deepcopy(self.__dict__[attr]) + + return new_dict + + def __setstate__(self, state): + for (k, v) in state.items(): + setattr(self, k, v) + + # operator overloads + def __add__(self, other): + copied = copy.deepcopy(self) + copied._right_extra_msg += other.__str__() + return copied + + def __radd__(self, other): + copied = copy.deepcopy(self) + copied._left_extra_msg += other.__str__() + return copied + + def __mod__(self, other): + # do a format string to catch and raise + # any possible KeyErrors from missing parameters + self.data % other + copied = copy.deepcopy(self) + return copied._save_parameters(other) + + def __mul__(self, other): + return self.data * other + + def __rmul__(self, other): + return other * self.data + + def __getitem__(self, key): + return self.data[key] + + def __getslice__(self, start, end): + return self.data.__getslice__(start, end) + + def __getattribute__(self, name): + # NOTE(mrodden): handle lossy operations that we can't deal with yet + # These override the UserString implementation, since UserString + # uses our __class__ attribute to try and build a new message + # after running the inner data string through the operation. + # At that point, we have lost the gettext message id and can just + # safely resolve to a string instead. 
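+        # For example, calling .upper() on a Message yields a plain +        # uppercase string, not a new translatable Message.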
+ ops = ['capitalize', 'center', 'decode', 'encode', + 'expandtabs', 'ljust', 'lstrip', 'replace', 'rjust', 'rstrip', + 'strip', 'swapcase', 'title', 'translate', 'upper', 'zfill'] + if name in ops: + return getattr(self.data, name) + else: + return _userString.UserString.__getattribute__(self, name) + + +def get_available_languages(domain): + """Lists the available languages for the given translation domain. + + :param domain: the domain to get languages for + """ + if domain in _AVAILABLE_LANGUAGES: + return copy.copy(_AVAILABLE_LANGUAGES[domain]) + + localedir = '%s_LOCALEDIR' % domain.upper() + find = lambda x: gettext.find(domain, + localedir=os.environ.get(localedir), + languages=[x]) + + # NOTE(mrodden): en_US should always be available (and first in case + # order matters) since our in-line message strings are en_US + language_list = ['en_US'] + # NOTE(luisg): Babel <1.0 used a function called list(), which was + # renamed to locale_identifiers() in >=1.0, the requirements master list + # requires >=0.9.6, uncapped, so defensively work with both. We can remove + # this check when the master list updates to >=1.0, and update all projects + list_identifiers = (getattr(localedata, 'list', None) or + getattr(localedata, 'locale_identifiers')) + locale_identifiers = list_identifiers() + for i in locale_identifiers: + if find(i) is not None: + language_list.append(i) + _AVAILABLE_LANGUAGES[domain] = language_list + return copy.copy(language_list) + + +def get_localized_message(message, user_locale): + """Gets a localized version of the given message in the given locale. + + If the message is not a Message object the message is returned as-is. + If the locale is None the message is translated to the default locale. + + :returns: the translated message in unicode, or the original message if + it could not be translated + """ + translated = message + if isinstance(message, Message): + original_locale = message.locale + message.locale = user_locale + translated = six.text_type(message) + message.locale = original_locale + return translated + + +class LocaleHandler(logging.Handler): + """Handler that can have a locale associated to translate Messages. + + A quick example of how to utilize the Message class above. + LocaleHandler takes a locale and a target logging.Handler object + to forward LogRecord objects to after translating the internal Message. + """ + + def __init__(self, locale, target): + """Initialize a LocaleHandler + + :param locale: locale to use for translating messages + :param target: logging.Handler object to forward + LogRecord objects to after translation + """ + logging.Handler.__init__(self) + self.locale = locale + self.target = target + + def emit(self, record): + if isinstance(record.msg, Message): + # set the locale and resolve to a string + record.msg.locale = self.locale + + self.target.emit(record) diff --git a/mistral/openstack/common/importutils.py b/mistral/openstack/common/importutils.py new file mode 100644 index 00000000..7a303f93 --- /dev/null +++ b/mistral/openstack/common/importutils.py @@ -0,0 +1,68 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Import related utilities and helper functions. +""" + +import sys +import traceback + + +def import_class(import_str): +    """Returns a class from a string including module and class.""" +    mod_str, _sep, class_str = import_str.rpartition('.') +    try: +        __import__(mod_str) +        return getattr(sys.modules[mod_str], class_str) +    except (ValueError, AttributeError): +        raise ImportError('Class %s cannot be found (%s)' % +                          (class_str, +                           traceback.format_exception(*sys.exc_info()))) + + +def import_object(import_str, *args, **kwargs): +    """Import a class and return an instance of it.""" +    return import_class(import_str)(*args, **kwargs) + + +def import_object_ns(name_space, import_str, *args, **kwargs): +    """Tries to import object from default namespace. + +    Imports a class and returns an instance of it, first by trying +    to find the class in a default namespace, then falling back to +    the full path if it is not found in the default namespace. +    """ +    import_value = "%s.%s" % (name_space, import_str) +    try: +        return import_class(import_value)(*args, **kwargs) +    except ImportError: +        return import_class(import_str)(*args, **kwargs) + + +def import_module(import_str): +    """Import a module.""" +    __import__(import_str) +    return sys.modules[import_str] + + +def try_import(import_str, default=None): +    """Try to import a module and if it fails return default.""" +    try: +        return import_module(import_str) +    except ImportError: +        return default diff --git a/mistral/openstack/common/jsonutils.py b/mistral/openstack/common/jsonutils.py new file mode 100644 index 00000000..7222af5a --- /dev/null +++ b/mistral/openstack/common/jsonutils.py @@ -0,0 +1,180 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2011 Justin Santa Barbara +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +#    http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +''' +JSON related utilities. + +This module provides a few things: + +    1) A handy function for getting an object down to something that can be +    JSON serialized.  See to_primitive(). + +    2) Wrappers around loads() and dumps().  The dumps() wrapper will +    automatically use to_primitive() for you if needed. + +    3) This sets up anyjson to use the loads() and dumps() wrappers if anyjson +    is available.
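+ +A short usage sketch (illustrative; assumes `datetime` is imported in the +caller's scope): + +    from mistral.openstack.common import jsonutils + +    body = jsonutils.dumps({'when': datetime.datetime.utcnow()}) +    data = jsonutils.loads(body) +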
+''' + + +import datetime +import functools +import inspect +import itertools +import json +try: + import xmlrpclib +except ImportError: + # NOTE(jd): xmlrpclib is not shipped with Python 3 + xmlrpclib = None + +import six + +from mistral.openstack.common import gettextutils +from mistral.openstack.common import importutils +from mistral.openstack.common import timeutils + +netaddr = importutils.try_import("netaddr") + +_nasty_type_tests = [inspect.ismodule, inspect.isclass, inspect.ismethod, + inspect.isfunction, inspect.isgeneratorfunction, + inspect.isgenerator, inspect.istraceback, inspect.isframe, + inspect.iscode, inspect.isbuiltin, inspect.isroutine, + inspect.isabstract] + +_simple_types = (six.string_types + six.integer_types + + (type(None), bool, float)) + + +def to_primitive(value, convert_instances=False, convert_datetime=True, + level=0, max_depth=3): + """Convert a complex object into primitives. + + Handy for JSON serialization. We can optionally handle instances, + but since this is a recursive function, we could have cyclical + data structures. + + To handle cyclical data structures we could track the actual objects + visited in a set, but not all objects are hashable. Instead we just + track the depth of the object inspections and don't go too deep. + + Therefore, convert_instances=True is lossy ... be aware. + + """ + # handle obvious types first - order of basic types determined by running + # full tests on nova project, resulting in the following counts: + # 572754 + # 460353 + # 379632 + # 274610 + # 199918 + # 114200 + # 51817 + # 26164 + # 6491 + # 283 + # 19 + if isinstance(value, _simple_types): + return value + + if isinstance(value, datetime.datetime): + if convert_datetime: + return timeutils.strtime(value) + else: + return value + + # value of itertools.count doesn't get caught by nasty_type_tests + # and results in infinite loop when list(value) is called. + if type(value) == itertools.count: + return six.text_type(value) + + # FIXME(vish): Workaround for LP bug 852095. Without this workaround, + # tests that raise an exception in a mocked method that + # has a @wrap_exception with a notifier will fail. If + # we up the dependency to 0.5.4 (when it is released) we + # can remove this workaround. + if getattr(value, '__module__', None) == 'mox': + return 'mock' + + if level > max_depth: + return '?' + + # The try block may not be necessary after the class check above, + # but just in case ... + try: + recursive = functools.partial(to_primitive, + convert_instances=convert_instances, + convert_datetime=convert_datetime, + level=level, + max_depth=max_depth) + if isinstance(value, dict): + return dict((k, recursive(v)) for k, v in value.iteritems()) + elif isinstance(value, (list, tuple)): + return [recursive(lv) for lv in value] + + # It's not clear why xmlrpclib created their own DateTime type, but + # for our purposes, make it a datetime type which is explicitly + # handled + if xmlrpclib and isinstance(value, xmlrpclib.DateTime): + value = datetime.datetime(*tuple(value.timetuple())[:6]) + + if convert_datetime and isinstance(value, datetime.datetime): + return timeutils.strtime(value) + elif isinstance(value, gettextutils.Message): + return value.data + elif hasattr(value, 'iteritems'): + return recursive(dict(value.iteritems()), level=level + 1) + elif hasattr(value, '__iter__'): + return recursive(list(value)) + elif convert_instances and hasattr(value, '__dict__'): + # Likely an instance of something. Watch for cycles. + # Ignore class member vars. 
+ return recursive(value.__dict__, level=level + 1) + elif netaddr and isinstance(value, netaddr.IPAddress): + return six.text_type(value) + else: + if any(test(value) for test in _nasty_type_tests): + return six.text_type(value) + return value + except TypeError: + # Class objects are tricky since they may define something like + # __iter__ defined but it isn't callable as list(). + return six.text_type(value) + + +def dumps(value, default=to_primitive, **kwargs): + return json.dumps(value, default=default, **kwargs) + + +def loads(s): + return json.loads(s) + + +def load(s): + return json.load(s) + + +try: + import anyjson +except ImportError: + pass +else: + anyjson._modules.append((__name__, 'dumps', TypeError, + 'loads', ValueError, 'load')) + anyjson.force_implementation(__name__) diff --git a/mistral/openstack/common/local.py b/mistral/openstack/common/local.py new file mode 100644 index 00000000..e82f17d0 --- /dev/null +++ b/mistral/openstack/common/local.py @@ -0,0 +1,47 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Local storage of variables using weak references""" + +import threading +import weakref + + +class WeakLocal(threading.local): + def __getattribute__(self, attr): + rval = super(WeakLocal, self).__getattribute__(attr) + if rval: + # NOTE(mikal): this bit is confusing. What is stored is a weak + # reference, not the value itself. We therefore need to lookup + # the weak reference and return the inner value here. + rval = rval() + return rval + + def __setattr__(self, attr, value): + value = weakref.ref(value) + return super(WeakLocal, self).__setattr__(attr, value) + + +# NOTE(mikal): the name "store" should be deprecated in the future +store = WeakLocal() + +# A "weak" store uses weak references and allows an object to fall out of scope +# when it falls out of scope in the code that uses the thread local storage. A +# "strong" store will hold a reference to the object so that it never falls out +# of scope. +weak_store = WeakLocal() +strong_store = threading.local() diff --git a/mistral/openstack/common/lockutils.py b/mistral/openstack/common/lockutils.py new file mode 100644 index 00000000..132e4917 --- /dev/null +++ b/mistral/openstack/common/lockutils.py @@ -0,0 +1,305 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ + +import contextlib +import errno +import functools +import os +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import weakref + +from oslo.config import cfg + +from mistral.openstack.common import fileutils +from mistral.openstack.common.gettextutils import _ # noqa +from mistral.openstack.common import local +from mistral.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + + +util_opts = [ + cfg.BoolOpt('disable_process_locking', default=False, + help='Whether to disable inter-process locks'), + cfg.StrOpt('lock_path', + default=os.environ.get("MISTRAL_LOCK_PATH"), + help=('Directory to use for lock files.')) +] + + +CONF = cfg.CONF +CONF.register_opts(util_opts) + + +def set_defaults(lock_path): + cfg.set_defaults(util_opts, lock_path=lock_path) + + +class _InterProcessLock(object): + """Lock implementation which allows multiple locks, working around + issues like bugs.debian.org/cgi-bin/bugreport.cgi?bug=632857 and does + not require any cleanup. Since the lock is always held on a file + descriptor rather than outside of the process, the lock gets dropped + automatically if the process crashes, even if __exit__ is not executed. + + There are no guarantees regarding usage by multiple green threads in a + single process here. This lock works only between processes. Exclusive + access between local threads should be achieved using the semaphores + in the @synchronized decorator. + + Note these locks are released when the descriptor is closed, so it's not + safe to close the file descriptor while another green thread holds the + lock. Just opening and closing the lock file can break synchronisation, + so lock files must be accessed only using this abstraction. + """ + + def __init__(self, name): + self.lockfile = None + self.fname = name + + def __enter__(self): + self.lockfile = open(self.fname, 'w') + + while True: + try: + # Using non-blocking locks since green threads are not + # patched to deal with blocking locking calls. + # Also upon reading the MSDN docs for locking(), it seems + # to have a laughable 10 attempts "blocking" mechanism. 
+                self.trylock() +                return self +            except IOError as e: +                if e.errno in (errno.EACCES, errno.EAGAIN): +                    # external locks synchronise things like iptables +                    # updates - give it some time to prevent busy spinning +                    time.sleep(0.01) +                else: +                    raise + +    def __exit__(self, exc_type, exc_val, exc_tb): +        try: +            self.unlock() +            self.lockfile.close() +        except IOError: +            LOG.exception(_("Could not release the acquired lock `%s`"), +                          self.fname) + +    def trylock(self): +        raise NotImplementedError() + +    def unlock(self): +        raise NotImplementedError() + + +class _WindowsLock(_InterProcessLock): +    def trylock(self): +        msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_NBLCK, 1) + +    def unlock(self): +        msvcrt.locking(self.lockfile.fileno(), msvcrt.LK_UNLCK, 1) + + +class _PosixLock(_InterProcessLock): +    def trylock(self): +        fcntl.lockf(self.lockfile, fcntl.LOCK_EX | fcntl.LOCK_NB) + +    def unlock(self): +        fcntl.lockf(self.lockfile, fcntl.LOCK_UN) + + +if os.name == 'nt': +    import msvcrt +    InterProcessLock = _WindowsLock +else: +    import fcntl +    InterProcessLock = _PosixLock + +_semaphores = weakref.WeakValueDictionary() +_semaphores_lock = threading.Lock() + + +@contextlib.contextmanager +def lock(name, lock_file_prefix=None, external=False, lock_path=None): +    """Context based lock + +    This function yields a `threading.Semaphore` instance (if we don't use +    eventlet.monkey_patch(), else `semaphore.Semaphore`) unless external is +    True, in which case, it'll yield an InterProcessLock instance. + +    :param lock_file_prefix: The lock_file_prefix argument is used to provide +    lock files on disk with a meaningful prefix. + +    :param external: The external keyword argument denotes whether this lock +    should work across multiple processes. This means that if two different +    workers both run a method decorated with @synchronized('mylock', +    external=True), only one of them will execute at a time. + +    :param lock_path: The lock_path keyword argument is used to specify a +    special location for external lock files to live. If nothing is set, then +    CONF.lock_path is used as a default.
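+ +    A usage sketch (the lock name and body are illustrative; external +    locking assumes lock_path is configured): + +        with lock('my-resource', lock_file_prefix='mistral-', external=True): +            do_exclusive_work()  # hypothetical critical section +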
+ """ + with _semaphores_lock: + try: + sem = _semaphores[name] + except KeyError: + sem = threading.Semaphore() + _semaphores[name] = sem + + with sem: + LOG.debug(_('Got semaphore "%(lock)s"'), {'lock': name}) + + # NOTE(mikal): I know this looks odd + if not hasattr(local.strong_store, 'locks_held'): + local.strong_store.locks_held = [] + local.strong_store.locks_held.append(name) + + try: + if external and not CONF.disable_process_locking: + LOG.debug(_('Attempting to grab file lock "%(lock)s"'), + {'lock': name}) + + # We need a copy of lock_path because it is non-local + local_lock_path = lock_path or CONF.lock_path + if not local_lock_path: + raise cfg.RequiredOptError('lock_path') + + if not os.path.exists(local_lock_path): + fileutils.ensure_tree(local_lock_path) + LOG.info(_('Created lock path: %s'), local_lock_path) + + def add_prefix(name, prefix): + if not prefix: + return name + sep = '' if prefix.endswith('-') else '-' + return '%s%s%s' % (prefix, sep, name) + + # NOTE(mikal): the lock name cannot contain directory + # separators + lock_file_name = add_prefix(name.replace(os.sep, '_'), + lock_file_prefix) + + lock_file_path = os.path.join(local_lock_path, lock_file_name) + + try: + lock = InterProcessLock(lock_file_path) + with lock as lock: + LOG.debug(_('Got file lock "%(lock)s" at %(path)s'), + {'lock': name, 'path': lock_file_path}) + yield lock + finally: + LOG.debug(_('Released file lock "%(lock)s" at %(path)s'), + {'lock': name, 'path': lock_file_path}) + else: + yield sem + + finally: + local.strong_store.locks_held.remove(name) + + +def synchronized(name, lock_file_prefix=None, external=False, lock_path=None): + """Synchronization decorator. + + Decorating a method like so:: + + @synchronized('mylock') + def foo(self, *args): + ... + + ensures that only one thread will execute the foo method at a time. + + Different methods can share the same lock:: + + @synchronized('mylock') + def foo(self, *args): + ... + + @synchronized('mylock') + def bar(self, *args): + ... + + This way only one of either foo or bar can be executing at a time. + """ + + def wrap(f): + @functools.wraps(f) + def inner(*args, **kwargs): + try: + with lock(name, lock_file_prefix, external, lock_path): + LOG.debug(_('Got semaphore / lock "%(function)s"'), + {'function': f.__name__}) + return f(*args, **kwargs) + finally: + LOG.debug(_('Semaphore / lock released "%(function)s"'), + {'function': f.__name__}) + return inner + return wrap + + +def synchronized_with_prefix(lock_file_prefix): + """Partial object generator for the synchronization decorator. + + Redefine @synchronized in each project like so:: + + (in nova/utils.py) + from nova.openstack.common import lockutils + + synchronized = lockutils.synchronized_with_prefix('nova-') + + + (in nova/foo.py) + from nova import utils + + @utils.synchronized('mylock') + def bar(self, *args): + ... + + The lock_file_prefix argument is used to provide lock files on disk with a + meaningful prefix. + """ + + return functools.partial(synchronized, lock_file_prefix=lock_file_prefix) + + +def main(argv): + """Create a dir for locks and pass it to command from arguments + + If you run this: + python -m openstack.common.lockutils python setup.py testr + + a temporary directory will be created for all your locks and passed to all + your tests in an environment variable. The temporary dir will be deleted + afterwards and the return value will be preserved. 
+ """ + + lock_dir = tempfile.mkdtemp() + os.environ["MISTRAL_LOCK_PATH"] = lock_dir + try: + ret_val = subprocess.call(argv[1:]) + finally: + shutil.rmtree(lock_dir, ignore_errors=True) + return ret_val + + +if __name__ == '__main__': + sys.exit(main(sys.argv)) diff --git a/mistral/openstack/common/log.py b/mistral/openstack/common/log.py new file mode 100644 index 00000000..9b369a3b --- /dev/null +++ b/mistral/openstack/common/log.py @@ -0,0 +1,626 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Openstack logging handler. + +This module adds to logging functionality by adding the option to specify +a context object when calling the various log methods. If the context object +is not specified, default formatting is used. Additionally, an instance uuid +may be passed as part of the log message, which is intended to make it easier +for admins to find messages related to a specific instance. + +It also allows setting of formatting information through conf. + +""" + +import inspect +import itertools +import logging +import logging.config +import logging.handlers +import os +import re +import sys +import traceback + +from oslo.config import cfg +import six +from six import moves + +from mistral.openstack.common.gettextutils import _ # noqa +from mistral.openstack.common import importutils +from mistral.openstack.common import jsonutils +from mistral.openstack.common import local + + +_DEFAULT_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S" + +_SANITIZE_KEYS = ['adminPass', 'admin_pass', 'password', 'admin_password'] + +# NOTE(ldbragst): Let's build a list of regex objects using the list of +# _SANITIZE_KEYS we already have. This way, we only have to add the new key +# to the list of _SANITIZE_KEYS and we can generate regular expressions +# for XML and JSON automatically. +_SANITIZE_PATTERNS = [] +_FORMAT_PATTERNS = [r'(%(key)s\s*[=]\s*[\"\']).*?([\"\'])', + r'(<%(key)s>).*?()', + r'([\"\']%(key)s[\"\']\s*:\s*[\"\']).*?([\"\'])', + r'([\'"].*?%(key)s[\'"]\s*:\s*u?[\'"]).*?([\'"])'] + +for key in _SANITIZE_KEYS: + for pattern in _FORMAT_PATTERNS: + reg_ex = re.compile(pattern % {'key': key}, re.DOTALL) + _SANITIZE_PATTERNS.append(reg_ex) + + +common_cli_opts = [ + cfg.BoolOpt('debug', + short='d', + default=False, + help='Print debugging output (set logging level to ' + 'DEBUG instead of default WARNING level).'), + cfg.BoolOpt('verbose', + short='v', + default=False, + help='Print more verbose output (set logging level to ' + 'INFO instead of default WARNING level).'), +] + +logging_cli_opts = [ + cfg.StrOpt('log-config-append', + metavar='PATH', + deprecated_name='log-config', + help='The name of logging configuration file. 
It does not ' + 'disable existing loggers, but just appends specified ' + 'logging configuration to any other existing logging ' + 'options. Please see the Python logging module ' + 'documentation for details on logging configuration ' + 'files.'), + cfg.StrOpt('log-format', + default=None, + metavar='FORMAT', + help='DEPRECATED. ' + 'A logging.Formatter log message format string which may ' + 'use any of the available logging.LogRecord attributes. ' + 'This option is deprecated. Please use ' + 'logging_context_format_string and ' + 'logging_default_format_string instead.'), + cfg.StrOpt('log-date-format', + default=_DEFAULT_LOG_DATE_FORMAT, + metavar='DATE_FORMAT', + help='Format string for %%(asctime)s in log records. ' + 'Default: %(default)s'), + cfg.StrOpt('log-file', + metavar='PATH', + deprecated_name='logfile', + help='(Optional) Name of log file to output to. ' + 'If no default is set, logging will go to stdout.'), + cfg.StrOpt('log-dir', + deprecated_name='logdir', + help='(Optional) The base directory used for relative ' + '--log-file paths'), + cfg.BoolOpt('use-syslog', + default=False, + help='Use syslog for logging.'), + cfg.StrOpt('syslog-log-facility', + default='LOG_USER', + help='syslog facility to receive log lines') +] + +generic_log_opts = [ + cfg.BoolOpt('use_stderr', + default=True, + help='Log output to standard error') +] + +log_opts = [ + cfg.StrOpt('logging_context_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [%(request_id)s %(user)s %(tenant)s] ' + '%(instance)s%(message)s', + help='format string to use for log messages with context'), + cfg.StrOpt('logging_default_format_string', + default='%(asctime)s.%(msecs)03d %(process)d %(levelname)s ' + '%(name)s [-] %(instance)s%(message)s', + help='format string to use for log messages without context'), + cfg.StrOpt('logging_debug_format_suffix', + default='%(funcName)s %(pathname)s:%(lineno)d', + help='data to append to log format when level is DEBUG'), + cfg.StrOpt('logging_exception_prefix', + default='%(asctime)s.%(msecs)03d %(process)d TRACE %(name)s ' + '%(instance)s', + help='prefix each line of exception output with this format'), + cfg.ListOpt('default_log_levels', + default=[ + 'amqp=WARN', + 'amqplib=WARN', + 'boto=WARN', + 'keystone=INFO', + 'qpid=WARN', + 'sqlalchemy=WARN', + 'suds=INFO', + 'iso8601=WARN', + ], + help='list of logger=LEVEL pairs'), + cfg.BoolOpt('publish_errors', + default=False, + help='publish error events'), + cfg.BoolOpt('fatal_deprecations', + default=False, + help='make deprecations fatal'), + + # NOTE(mikal): there are two options here because sometimes we are handed + # a full instance (and could include more information), and other times we + # are just handed a UUID for the instance. + cfg.StrOpt('instance_format', + default='[instance: %(uuid)s] ', + help='If an instance is passed with the log message, format ' + 'it like this'), + cfg.StrOpt('instance_uuid_format', + default='[instance: %(uuid)s] ', + help='If an instance UUID is passed with the log message, ' + 'format it like this'), +] + +CONF = cfg.CONF +CONF.register_cli_opts(common_cli_opts) +CONF.register_cli_opts(logging_cli_opts) +CONF.register_opts(generic_log_opts) +CONF.register_opts(log_opts) + +# our new audit level +# NOTE(jkoelker) Since we synthesized an audit level, make the logging +# module aware of it so it acts like other levels. 
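+# With AUDIT registered below, adapters can emit records at a severity +# between INFO and WARNING, e.g. LOG.audit(_('user action')) through the +# audit() helper on BaseLoggerAdapter.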
+logging.AUDIT = logging.INFO + 1 +logging.addLevelName(logging.AUDIT, 'AUDIT') + + +try: + NullHandler = logging.NullHandler +except AttributeError: # NOTE(jkoelker) NullHandler added in Python 2.7 + class NullHandler(logging.Handler): + def handle(self, record): + pass + + def emit(self, record): + pass + + def createLock(self): + self.lock = None + + +def _dictify_context(context): + if context is None: + return None + if not isinstance(context, dict) and getattr(context, 'to_dict', None): + context = context.to_dict() + return context + + +def _get_binary_name(): + return os.path.basename(inspect.stack()[-1][1]) + + +def _get_log_file_path(binary=None): + logfile = CONF.log_file + logdir = CONF.log_dir + + if logfile and not logdir: + return logfile + + if logfile and logdir: + return os.path.join(logdir, logfile) + + if logdir: + binary = binary or _get_binary_name() + return '%s.log' % (os.path.join(logdir, binary),) + + return None + + +def mask_password(message, secret="***"): + """Replace password with 'secret' in message. + + :param message: The string which includes security information. + :param secret: value with which to replace passwords, defaults to "***". + :returns: The unicode value of message with the password fields masked. + + For example: + >>> mask_password("'adminPass' : 'aaaaa'") + "'adminPass' : '***'" + >>> mask_password("'admin_pass' : 'aaaaa'") + "'admin_pass' : '***'" + >>> mask_password('"password" : "aaaaa"') + '"password" : "***"' + >>> mask_password("'original_password' : 'aaaaa'") + "'original_password' : '***'" + >>> mask_password("u'original_password' : u'aaaaa'") + "u'original_password' : u'***'" + """ + message = six.text_type(message) + + # NOTE(ldbragst): Check to see if anything in message contains any key + # specified in _SANITIZE_KEYS, if not then just return the message since + # we don't have to mask any passwords. 
+    if not any(key in message for key in _SANITIZE_KEYS):
+        return message
+
+    secret = r'\g<1>' + secret + r'\g<2>'
+    for pattern in _SANITIZE_PATTERNS:
+        message = re.sub(pattern, secret, message)
+    return message
+
+
+class BaseLoggerAdapter(logging.LoggerAdapter):
+
+    def audit(self, msg, *args, **kwargs):
+        self.log(logging.AUDIT, msg, *args, **kwargs)
+
+
+class LazyAdapter(BaseLoggerAdapter):
+    def __init__(self, name='unknown', version='unknown'):
+        self._logger = None
+        self.extra = {}
+        self.name = name
+        self.version = version
+
+    @property
+    def logger(self):
+        if not self._logger:
+            self._logger = getLogger(self.name, self.version)
+        return self._logger
+
+
+class ContextAdapter(BaseLoggerAdapter):
+    warn = logging.LoggerAdapter.warning
+
+    def __init__(self, logger, project_name, version_string):
+        self.logger = logger
+        self.project = project_name
+        self.version = version_string
+
+    @property
+    def handlers(self):
+        return self.logger.handlers
+
+    def deprecated(self, msg, *args, **kwargs):
+        stdmsg = _("Deprecated: %s") % msg
+        if CONF.fatal_deprecations:
+            self.critical(stdmsg, *args, **kwargs)
+            raise DeprecatedConfig(msg=stdmsg)
+        else:
+            self.warn(stdmsg, *args, **kwargs)
+
+    def process(self, msg, kwargs):
+        # NOTE(mrodden): catch any Message/other object and
+        #                coerce to unicode before they can get
+        #                to the python logging and possibly
+        #                cause string encoding trouble
+        if not isinstance(msg, six.string_types):
+            msg = six.text_type(msg)
+
+        if 'extra' not in kwargs:
+            kwargs['extra'] = {}
+        extra = kwargs['extra']
+
+        context = kwargs.pop('context', None)
+        if not context:
+            context = getattr(local.store, 'context', None)
+        if context:
+            extra.update(_dictify_context(context))
+
+        instance = kwargs.pop('instance', None)
+        instance_uuid = (extra.get('instance_uuid', None) or
+                         kwargs.pop('instance_uuid', None))
+        instance_extra = ''
+        if instance:
+            instance_extra = CONF.instance_format % instance
+        elif instance_uuid:
+            instance_extra = (CONF.instance_uuid_format
+                              % {'uuid': instance_uuid})
+        extra.update({'instance': instance_extra})
+
+        extra.update({"project": self.project})
+        extra.update({"version": self.version})
+        extra['extra'] = extra.copy()
+        return msg, kwargs
+
+
+class JSONFormatter(logging.Formatter):
+    def __init__(self, fmt=None, datefmt=None):
+        # NOTE(jkoelker) we ignore the fmt argument, but it's still there
+        #                since logging.config.fileConfig passes it.
+ self.datefmt = datefmt + + def formatException(self, ei, strip_newlines=True): + lines = traceback.format_exception(*ei) + if strip_newlines: + lines = [itertools.ifilter( + lambda x: x, + line.rstrip().splitlines()) for line in lines] + lines = list(itertools.chain(*lines)) + return lines + + def format(self, record): + message = {'message': record.getMessage(), + 'asctime': self.formatTime(record, self.datefmt), + 'name': record.name, + 'msg': record.msg, + 'args': record.args, + 'levelname': record.levelname, + 'levelno': record.levelno, + 'pathname': record.pathname, + 'filename': record.filename, + 'module': record.module, + 'lineno': record.lineno, + 'funcname': record.funcName, + 'created': record.created, + 'msecs': record.msecs, + 'relative_created': record.relativeCreated, + 'thread': record.thread, + 'thread_name': record.threadName, + 'process_name': record.processName, + 'process': record.process, + 'traceback': None} + + if hasattr(record, 'extra'): + message['extra'] = record.extra + + if record.exc_info: + message['traceback'] = self.formatException(record.exc_info) + + return jsonutils.dumps(message) + + +def _create_logging_excepthook(product_name): + def logging_excepthook(exc_type, value, tb): + extra = {} + if CONF.verbose: + extra['exc_info'] = (exc_type, value, tb) + getLogger(product_name).critical(str(value), **extra) + return logging_excepthook + + +class LogConfigError(Exception): + + message = _('Error loading logging config %(log_config)s: %(err_msg)s') + + def __init__(self, log_config, err_msg): + self.log_config = log_config + self.err_msg = err_msg + + def __str__(self): + return self.message % dict(log_config=self.log_config, + err_msg=self.err_msg) + + +def _load_log_config(log_config_append): + try: + logging.config.fileConfig(log_config_append, + disable_existing_loggers=False) + except moves.configparser.Error as exc: + raise LogConfigError(log_config_append, str(exc)) + + +def setup(product_name): + """Setup logging.""" + if CONF.log_config_append: + _load_log_config(CONF.log_config_append) + else: + _setup_logging_from_conf() + sys.excepthook = _create_logging_excepthook(product_name) + + +def set_defaults(logging_context_format_string): + cfg.set_defaults(log_opts, + logging_context_format_string= + logging_context_format_string) + + +def _find_facility_from_conf(): + facility_names = logging.handlers.SysLogHandler.facility_names + facility = getattr(logging.handlers.SysLogHandler, + CONF.syslog_log_facility, + None) + + if facility is None and CONF.syslog_log_facility in facility_names: + facility = facility_names.get(CONF.syslog_log_facility) + + if facility is None: + valid_facilities = facility_names.keys() + consts = ['LOG_AUTH', 'LOG_AUTHPRIV', 'LOG_CRON', 'LOG_DAEMON', + 'LOG_FTP', 'LOG_KERN', 'LOG_LPR', 'LOG_MAIL', 'LOG_NEWS', + 'LOG_AUTH', 'LOG_SYSLOG', 'LOG_USER', 'LOG_UUCP', + 'LOG_LOCAL0', 'LOG_LOCAL1', 'LOG_LOCAL2', 'LOG_LOCAL3', + 'LOG_LOCAL4', 'LOG_LOCAL5', 'LOG_LOCAL6', 'LOG_LOCAL7'] + valid_facilities.extend(consts) + raise TypeError(_('syslog facility must be one of: %s') % + ', '.join("'%s'" % fac + for fac in valid_facilities)) + + return facility + + +def _setup_logging_from_conf(): + log_root = getLogger(None).logger + for handler in log_root.handlers: + log_root.removeHandler(handler) + + if CONF.use_syslog: + facility = _find_facility_from_conf() + syslog = logging.handlers.SysLogHandler(address='/dev/log', + facility=facility) + log_root.addHandler(syslog) + + logpath = _get_log_file_path() + if logpath: + filelog = 
logging.handlers.WatchedFileHandler(logpath) + log_root.addHandler(filelog) + + if CONF.use_stderr: + streamlog = ColorHandler() + log_root.addHandler(streamlog) + + elif not CONF.log_file: + # pass sys.stdout as a positional argument + # python2.6 calls the argument strm, in 2.7 it's stream + streamlog = logging.StreamHandler(sys.stdout) + log_root.addHandler(streamlog) + + if CONF.publish_errors: + handler = importutils.import_object( + "mistral.openstack.common.log_handler.PublishErrorsHandler", + logging.ERROR) + log_root.addHandler(handler) + + datefmt = CONF.log_date_format + for handler in log_root.handlers: + # NOTE(alaski): CONF.log_format overrides everything currently. This + # should be deprecated in favor of context aware formatting. + if CONF.log_format: + handler.setFormatter(logging.Formatter(fmt=CONF.log_format, + datefmt=datefmt)) + log_root.info('Deprecated: log_format is now deprecated and will ' + 'be removed in the next release') + else: + handler.setFormatter(ContextFormatter(datefmt=datefmt)) + + if CONF.debug: + log_root.setLevel(logging.DEBUG) + elif CONF.verbose: + log_root.setLevel(logging.INFO) + else: + log_root.setLevel(logging.WARNING) + + for pair in CONF.default_log_levels: + mod, _sep, level_name = pair.partition('=') + level = logging.getLevelName(level_name) + logger = logging.getLogger(mod) + logger.setLevel(level) + +_loggers = {} + + +def getLogger(name='unknown', version='unknown'): + if name not in _loggers: + _loggers[name] = ContextAdapter(logging.getLogger(name), + name, + version) + return _loggers[name] + + +def getLazyLogger(name='unknown', version='unknown'): + """Returns lazy logger. + + Creates a pass-through logger that does not create the real logger + until it is really needed and delegates all calls to the real logger + once it is created. + """ + return LazyAdapter(name, version) + + +class WritableLogger(object): + """A thin wrapper that responds to `write` and logs.""" + + def __init__(self, logger, level=logging.INFO): + self.logger = logger + self.level = level + + def write(self, msg): + self.logger.log(self.level, msg) + + +class ContextFormatter(logging.Formatter): + """A context.RequestContext aware formatter configured through flags. + + The flags used to set format strings are: logging_context_format_string + and logging_default_format_string. You can also specify + logging_debug_format_suffix to append extra formatting if the log level is + debug. 
+
+    For information about what variables are available for the formatter see:
+    http://docs.python.org/library/logging.html#formatter
+
+    """
+
+    def format(self, record):
+        """Uses the context string if request_id is set, otherwise default."""
+        # NOTE(sdague): default the fancier formatting params
+        #               to an empty string so we don't throw an exception if
+        #               they get used
+        for key in ('instance', 'color'):
+            if key not in record.__dict__:
+                record.__dict__[key] = ''
+
+        if record.__dict__.get('request_id', None):
+            self._fmt = CONF.logging_context_format_string
+        else:
+            self._fmt = CONF.logging_default_format_string
+
+        if (record.levelno == logging.DEBUG and
+                CONF.logging_debug_format_suffix):
+            self._fmt += " " + CONF.logging_debug_format_suffix
+
+        # Cache this on the record, Logger will respect our formatted copy
+        if record.exc_info:
+            record.exc_text = self.formatException(record.exc_info, record)
+        return logging.Formatter.format(self, record)
+
+    def formatException(self, exc_info, record=None):
+        """Format exception output with CONF.logging_exception_prefix."""
+        if not record:
+            return logging.Formatter.formatException(self, exc_info)
+
+        stringbuffer = moves.StringIO()
+        traceback.print_exception(exc_info[0], exc_info[1], exc_info[2],
+                                  None, stringbuffer)
+        lines = stringbuffer.getvalue().split('\n')
+        stringbuffer.close()
+
+        if CONF.logging_exception_prefix.find('%(asctime)') != -1:
+            record.asctime = self.formatTime(record, self.datefmt)
+
+        formatted_lines = []
+        for line in lines:
+            pl = CONF.logging_exception_prefix % record.__dict__
+            fl = '%s%s' % (pl, line)
+            formatted_lines.append(fl)
+        return '\n'.join(formatted_lines)
+
+
+class ColorHandler(logging.StreamHandler):
+    LEVEL_COLORS = {
+        logging.DEBUG: '\033[00;32m',  # GREEN
+        logging.INFO: '\033[00;36m',  # CYAN
+        logging.AUDIT: '\033[01;36m',  # BOLD CYAN
+        logging.WARN: '\033[01;33m',  # BOLD YELLOW
+        logging.ERROR: '\033[01;31m',  # BOLD RED
+        logging.CRITICAL: '\033[01;31m',  # BOLD RED
+    }
+
+    def format(self, record):
+        record.color = self.LEVEL_COLORS[record.levelno]
+        return logging.StreamHandler.format(self, record)
+
+
+class DeprecatedConfig(Exception):
+    message = _("Fatal call to deprecated config: %(msg)s")
+
+    def __init__(self, msg):
+        super(Exception, self).__init__(self.message % dict(msg=msg))
diff --git a/mistral/openstack/common/strutils.py b/mistral/openstack/common/strutils.py
new file mode 100644
index 00000000..10227081
--- /dev/null
+++ b/mistral/openstack/common/strutils.py
@@ -0,0 +1,218 @@
+# vim: tabstop=4 shiftwidth=4 softtabstop=4
+
+# Copyright 2011 OpenStack Foundation.
+# All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""
+System-level utilities and helper functions.
+""" + +import re +import sys +import unicodedata + +import six + +from mistral.openstack.common.gettextutils import _ # noqa + + +# Used for looking up extensions of text +# to their 'multiplied' byte amount +BYTE_MULTIPLIERS = { + '': 1, + 't': 1024 ** 4, + 'g': 1024 ** 3, + 'm': 1024 ** 2, + 'k': 1024, +} +BYTE_REGEX = re.compile(r'(^-?\d+)(\D*)') + +TRUE_STRINGS = ('1', 't', 'true', 'on', 'y', 'yes') +FALSE_STRINGS = ('0', 'f', 'false', 'off', 'n', 'no') + +SLUGIFY_STRIP_RE = re.compile(r"[^\w\s-]") +SLUGIFY_HYPHENATE_RE = re.compile(r"[-\s]+") + + +def int_from_bool_as_string(subject): + """Interpret a string as a boolean and return either 1 or 0. + + Any string value in: + + ('True', 'true', 'On', 'on', '1') + + is interpreted as a boolean True. + + Useful for JSON-decoded stuff and config file parsing + """ + return bool_from_string(subject) and 1 or 0 + + +def bool_from_string(subject, strict=False): + """Interpret a string as a boolean. + + A case-insensitive match is performed such that strings matching 't', + 'true', 'on', 'y', 'yes', or '1' are considered True and, when + `strict=False`, anything else is considered False. + + Useful for JSON-decoded stuff and config file parsing. + + If `strict=True`, unrecognized values, including None, will raise a + ValueError which is useful when parsing values passed in from an API call. + Strings yielding False are 'f', 'false', 'off', 'n', 'no', or '0'. + """ + if not isinstance(subject, six.string_types): + subject = str(subject) + + lowered = subject.strip().lower() + + if lowered in TRUE_STRINGS: + return True + elif lowered in FALSE_STRINGS: + return False + elif strict: + acceptable = ', '.join( + "'%s'" % s for s in sorted(TRUE_STRINGS + FALSE_STRINGS)) + msg = _("Unrecognized value '%(val)s', acceptable values are:" + " %(acceptable)s") % {'val': subject, + 'acceptable': acceptable} + raise ValueError(msg) + else: + return False + + +def safe_decode(text, incoming=None, errors='strict'): + """Decodes incoming str using `incoming` if they're not already unicode. + + :param incoming: Text's current encoding + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: text or a unicode `incoming` encoded + representation of it. + :raises TypeError: If text is not an instance of str + """ + if not isinstance(text, six.string_types): + raise TypeError("%s can't be decoded" % type(text)) + + if isinstance(text, six.text_type): + return text + + if not incoming: + incoming = (sys.stdin.encoding or + sys.getdefaultencoding()) + + try: + return text.decode(incoming, errors) + except UnicodeDecodeError: + # Note(flaper87) If we get here, it means that + # sys.stdin.encoding / sys.getdefaultencoding + # didn't return a suitable encoding to decode + # text. This happens mostly when global LANG + # var is not set correctly and there's no + # default encoding. In this case, most likely + # python will use ASCII or ANSI encoders as + # default encodings but they won't be capable + # of decoding non-ASCII characters. + # + # Also, UTF-8 is being used since it's an ASCII + # extension. + return text.decode('utf-8', errors) + + +def safe_encode(text, incoming=None, + encoding='utf-8', errors='strict'): + """Encodes incoming str/unicode using `encoding`. + + If incoming is not specified, text is expected to be encoded with + current python's default encoding. 
(`sys.getdefaultencoding`) + + :param incoming: Text's current encoding + :param encoding: Expected encoding for text (Default UTF-8) + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: text or a bytestring `encoding` encoded + representation of it. + :raises TypeError: If text is not an instance of str + """ + if not isinstance(text, six.string_types): + raise TypeError("%s can't be encoded" % type(text)) + + if not incoming: + incoming = (sys.stdin.encoding or + sys.getdefaultencoding()) + + if isinstance(text, six.text_type): + return text.encode(encoding, errors) + elif text and encoding != incoming: + # Decode text before encoding it with `encoding` + text = safe_decode(text, incoming, errors) + return text.encode(encoding, errors) + + return text + + +def to_bytes(text, default=0): + """Converts a string into an integer of bytes. + + Looks at the last characters of the text to determine + what conversion is needed to turn the input text into a byte number. + Supports "B, K(B), M(B), G(B), and T(B)". (case insensitive) + + :param text: String input for bytes size conversion. + :param default: Default return value when text is blank. + + """ + match = BYTE_REGEX.search(text) + if match: + magnitude = int(match.group(1)) + mult_key_org = match.group(2) + if not mult_key_org: + return magnitude + elif text: + msg = _('Invalid string format: %s') % text + raise TypeError(msg) + else: + return default + mult_key = mult_key_org.lower().replace('b', '', 1) + multiplier = BYTE_MULTIPLIERS.get(mult_key) + if multiplier is None: + msg = _('Unknown byte multiplier: %s') % mult_key_org + raise TypeError(msg) + return magnitude * multiplier + + +def to_slug(value, incoming=None, errors="strict"): + """Normalize string. + + Convert to lowercase, remove non-word characters, and convert spaces + to hyphens. + + Inspired by Django's `slugify` filter. + + :param value: Text to slugify + :param incoming: Text's current encoding + :param errors: Errors handling policy. See here for valid + values http://docs.python.org/2/library/codecs.html + :returns: slugified unicode representation of `value` + :raises TypeError: If text is not an instance of str + """ + value = safe_decode(value, incoming, errors) + # NOTE(aababilov): no need to use safe_(encode|decode) here: + # encodings are always "ascii", error handling is always "ignore" + # and types are always known (first: unicode; second: str) + value = unicodedata.normalize("NFKD", value).encode( + "ascii", "ignore").decode("ascii") + value = SLUGIFY_STRIP_RE.sub("", value).strip().lower() + return SLUGIFY_HYPHENATE_RE.sub("-", value) diff --git a/mistral/openstack/common/test.py b/mistral/openstack/common/test.py new file mode 100644 index 00000000..a1a736f7 --- /dev/null +++ b/mistral/openstack/common/test.py @@ -0,0 +1,54 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (c) 2013 Hewlett-Packard Development Company, L.P. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Common utilities used in testing""" + +import os + +import fixtures +import testtools + +_TRUE_VALUES = ('True', 'true', '1', 'yes') + + +class BaseTestCase(testtools.TestCase): + + def setUp(self): + super(BaseTestCase, self).setUp() + self._set_timeout() + self._fake_output() + self.useFixture(fixtures.FakeLogger('mistral.openstack.common')) + self.useFixture(fixtures.NestedTempfile()) + self.useFixture(fixtures.TempHomeDir()) + + def _set_timeout(self): + test_timeout = os.environ.get('OS_TEST_TIMEOUT', 0) + try: + test_timeout = int(test_timeout) + except ValueError: + # If timeout value is invalid do not set a timeout. + test_timeout = 0 + if test_timeout > 0: + self.useFixture(fixtures.Timeout(test_timeout, gentle=True)) + + def _fake_output(self): + if os.environ.get('OS_STDOUT_CAPTURE') in _TRUE_VALUES: + stdout = self.useFixture(fixtures.StringStream('stdout')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stdout', stdout)) + if os.environ.get('OS_STDERR_CAPTURE') in _TRUE_VALUES: + stderr = self.useFixture(fixtures.StringStream('stderr')).stream + self.useFixture(fixtures.MonkeyPatch('sys.stderr', stderr)) diff --git a/mistral/openstack/common/timeutils.py b/mistral/openstack/common/timeutils.py new file mode 100644 index 00000000..b79ebf37 --- /dev/null +++ b/mistral/openstack/common/timeutils.py @@ -0,0 +1,197 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack Foundation. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Time related utilities and helper functions. 
+""" + +import calendar +import datetime +import time + +import iso8601 +import six + + +# ISO 8601 extended time format with microseconds +_ISO8601_TIME_FORMAT_SUBSECOND = '%Y-%m-%dT%H:%M:%S.%f' +_ISO8601_TIME_FORMAT = '%Y-%m-%dT%H:%M:%S' +PERFECT_TIME_FORMAT = _ISO8601_TIME_FORMAT_SUBSECOND + + +def isotime(at=None, subsecond=False): + """Stringify time in ISO 8601 format.""" + if not at: + at = utcnow() + st = at.strftime(_ISO8601_TIME_FORMAT + if not subsecond + else _ISO8601_TIME_FORMAT_SUBSECOND) + tz = at.tzinfo.tzname(None) if at.tzinfo else 'UTC' + st += ('Z' if tz == 'UTC' else tz) + return st + + +def parse_isotime(timestr): + """Parse time from ISO 8601 format.""" + try: + return iso8601.parse_date(timestr) + except iso8601.ParseError as e: + raise ValueError(six.text_type(e)) + except TypeError as e: + raise ValueError(six.text_type(e)) + + +def strtime(at=None, fmt=PERFECT_TIME_FORMAT): + """Returns formatted utcnow.""" + if not at: + at = utcnow() + return at.strftime(fmt) + + +def parse_strtime(timestr, fmt=PERFECT_TIME_FORMAT): + """Turn a formatted time back into a datetime.""" + return datetime.datetime.strptime(timestr, fmt) + + +def normalize_time(timestamp): + """Normalize time in arbitrary timezone to UTC naive object.""" + offset = timestamp.utcoffset() + if offset is None: + return timestamp + return timestamp.replace(tzinfo=None) - offset + + +def is_older_than(before, seconds): + """Return True if before is older than seconds.""" + if isinstance(before, six.string_types): + before = parse_strtime(before).replace(tzinfo=None) + return utcnow() - before > datetime.timedelta(seconds=seconds) + + +def is_newer_than(after, seconds): + """Return True if after is newer than seconds.""" + if isinstance(after, six.string_types): + after = parse_strtime(after).replace(tzinfo=None) + return after - utcnow() > datetime.timedelta(seconds=seconds) + + +def utcnow_ts(): + """Timestamp version of our utcnow function.""" + if utcnow.override_time is None: + # NOTE(kgriffs): This is several times faster + # than going through calendar.timegm(...) + return int(time.time()) + + return calendar.timegm(utcnow().timetuple()) + + +def utcnow(): + """Overridable version of utils.utcnow.""" + if utcnow.override_time: + try: + return utcnow.override_time.pop(0) + except AttributeError: + return utcnow.override_time + return datetime.datetime.utcnow() + + +def iso8601_from_timestamp(timestamp): + """Returns a iso8601 formated date from timestamp.""" + return isotime(datetime.datetime.utcfromtimestamp(timestamp)) + + +utcnow.override_time = None + + +def set_time_override(override_time=None): + """Overrides utils.utcnow. + + Make it return a constant time or a list thereof, one at a time. + + :param override_time: datetime instance or list thereof. If not + given, defaults to the current UTC time. + """ + utcnow.override_time = override_time or datetime.datetime.utcnow() + + +def advance_time_delta(timedelta): + """Advance overridden time using a datetime.timedelta.""" + assert(not utcnow.override_time is None) + try: + for dt in utcnow.override_time: + dt += timedelta + except TypeError: + utcnow.override_time += timedelta + + +def advance_time_seconds(seconds): + """Advance overridden time by seconds.""" + advance_time_delta(datetime.timedelta(0, seconds)) + + +def clear_time_override(): + """Remove the overridden time.""" + utcnow.override_time = None + + +def marshall_now(now=None): + """Make an rpc-safe datetime with microseconds. 
+
+    Note: tzinfo is stripped, but not required for relative times.
+    """
+    if not now:
+        now = utcnow()
+    return dict(day=now.day, month=now.month, year=now.year, hour=now.hour,
+                minute=now.minute, second=now.second,
+                microsecond=now.microsecond)
+
+
+def unmarshall_time(tyme):
+    """Unmarshall a datetime dict."""
+    return datetime.datetime(day=tyme['day'],
+                             month=tyme['month'],
+                             year=tyme['year'],
+                             hour=tyme['hour'],
+                             minute=tyme['minute'],
+                             second=tyme['second'],
+                             microsecond=tyme['microsecond'])
+
+
+def delta_seconds(before, after):
+    """Return the difference between two timing objects.
+
+    Compute the difference in seconds between two date, time, or
+    datetime objects (as a float, to microsecond resolution).
+    """
+    delta = after - before
+    try:
+        return delta.total_seconds()
+    except AttributeError:
+        return ((delta.days * 24 * 3600) + delta.seconds +
+                float(delta.microseconds) / (10 ** 6))
+
+
+def is_soon(dt, window):
+    """Determines if time is going to happen in the next window seconds.
+
+    :param dt: the time
+    :param window: minimum seconds to remain to consider the time not soon
+
+    :return: True if expiration is within the given duration
+    """
+    soon = (utcnow() + datetime.timedelta(seconds=window))
+    return normalize_time(dt) <= soon
diff --git a/mistral/tests/__init__.py b/mistral/tests/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/mistral/utils.py b/mistral/utils.py
new file mode 100644
index 00000000..7d93825c
--- /dev/null
+++ b/mistral/utils.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2013 Mirantis, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
diff --git a/openstack-common.conf b/openstack-common.conf
new file mode 100644
index 00000000..fd81b62d
--- /dev/null
+++ b/openstack-common.conf
@@ -0,0 +1,11 @@
+[DEFAULT]
+
+# The list of modules to copy from oslo-incubator.git
+module=cliutils
+module=db
+module=db.sqlalchemy
+module=log
+module=test
+
+# The base module to hold the copy of openstack.common
+base=mistral
diff --git a/requirements.txt b/requirements.txt
index 6207a8af..40d7e39b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,7 @@
 pbr>=0.5.21,<1.0
 taskflow
 pyyaml
+pecan>=0.2.0
+WSME>=0.5b6
+amqplib>=0.6.1
+argparse
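
The sketches below are editorial usage notes, not part of the patch; they assume the patched mistral tree and its dependencies (oslo.config, six, iso8601, testtools, fixtures) are importable under Python 2, which this code targets. First, the _SANITIZE_PATTERNS built in log.py feed mask_password(), so both JSON-shaped and XML-shaped payloads are masked:

# Sketch only: exercises mask_password() from log.py above.
from mistral.openstack.common import log as logging

print(logging.mask_password('"password" : "s3kr3t"'))
# -> "password" : "***"
print(logging.mask_password('<adminPass>s3kr3t</adminPass>', secret='XXX'))
# -> <adminPass>XXX</adminPass>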
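
How the logging pieces fit together at service startup, as a minimal sketch; the project name 'mistral' and the CLI arguments are illustrative:

import sys

from oslo.config import cfg

from mistral.openstack.common import log as logging

# Parse --debug/--verbose/--log-file and friends (registered in log.py),
# then attach handlers and formatters to the root logger.
cfg.CONF(sys.argv[1:], project='mistral')
logging.setup('mistral')

LOG = logging.getLogger(__name__)
LOG.info('service starting')
LOG.audit('the synthesized AUDIT level logs like any other')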
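
JSONFormatter can be attached to any stdlib handler; it ignores a fmt string and serializes the whole LogRecord instead. A sketch:

import logging

from mistral.openstack.common import log as mistral_log

handler = logging.StreamHandler()
handler.setFormatter(mistral_log.JSONFormatter())
demo = logging.getLogger('demo')
demo.addHandler(handler)
demo.warning('quota exceeded for %s', 'tenant-1')  # emits one JSON document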
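
The strutils helpers are strict about inputs in specific, documented ways; expected results, following the code above:

from mistral.openstack.common import strutils

assert strutils.bool_from_string('YES') is True
assert strutils.bool_from_string('nope') is False      # lenient by default
try:
    strutils.bool_from_string('nope', strict=True)     # unknown value raises
except ValueError:
    pass
assert strutils.to_bytes('2KB') == 2048                # trailing 'b' optional
assert strutils.to_slug(u'Hello  World!') == 'hello-world'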
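
safe_decode() and safe_encode() round-trip between bytes and unicode; the `incoming` argument names the encoding the input is currently in. A sketch:

from mistral.openstack.common import strutils

data = u'caf\xe9'.encode('utf-8')
text = strutils.safe_decode(data, incoming='utf-8')    # bytes -> unicode
assert text == u'caf\xe9'
assert strutils.safe_encode(text, encoding='latin-1') == u'caf\xe9'.encode('latin-1')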
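
test.BaseTestCase is the intended base class for the still-empty mistral/tests package; a hypothetical test module (the class and test names are invented for illustration):

from mistral.openstack.common import test
from mistral.openstack.common import timeutils


class TimeutilsTest(test.BaseTestCase):
    # BaseTestCase provides a temp HOME, captured logging for
    # 'mistral.openstack.common', and an optional OS_TEST_TIMEOUT guard.

    def test_marshall_roundtrip(self):
        now = timeutils.utcnow()
        marshalled = timeutils.marshall_now(now)
        self.assertEqual(now, timeutils.unmarshall_time(marshalled))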
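
The override hooks in timeutils are what make time-dependent tests deterministic; a sketch of freezing and advancing the clock:

import datetime

from mistral.openstack.common import timeutils

start = datetime.datetime(2013, 9, 1, 12, 0, 0)
timeutils.set_time_override(start)                 # freeze utcnow()
assert timeutils.isotime() == '2013-09-01T12:00:00Z'

timeutils.advance_time_seconds(90)                 # nudge the frozen clock
assert timeutils.delta_seconds(start, timeutils.utcnow()) == 90.0

timeutils.clear_time_override()                    # back to the real clock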