diff --git a/etc/trove/trove-guestagent.conf.sample b/etc/trove/trove-guestagent.conf.sample index db71cd8e39..5d4c8fea5b 100644 --- a/etc/trove/trove-guestagent.conf.sample +++ b/etc/trove/trove-guestagent.conf.sample @@ -37,9 +37,6 @@ rabbit_password=f7999d1955c5014aa32c # RabbitMQ topic used for OpenStack notifications. (list value) #rabbit_notification_topic = ['notifications'] -# Path to the extensions -api_extensions_path = trove/extensions/routes - # Configuration options for talking to nova via the novaclient. # These options are for an admin user in your keystone config. # It proxies the token received from the user to send to nova via this admin users creds, diff --git a/etc/trove/trove.conf.sample b/etc/trove/trove.conf.sample index db9cee5a69..6b8dedf4b7 100644 --- a/etc/trove/trove.conf.sample +++ b/etc/trove/trove.conf.sample @@ -68,10 +68,6 @@ sql_idle_timeout = 3600 #DB Api Implementation db_api_implementation = "trove.db.sqlalchemy.api" -# Path to the extensions. -# $pybasedir is the path to the installed trove package -api_extensions_path = $pybasedir/extensions/routes - # Configuration options for talking to nova via the novaclient. trove_auth_url = http://0.0.0.0:5000/v2.0 #nova_compute_url = http://localhost:8774/v2 diff --git a/etc/trove/trove.conf.test b/etc/trove/trove.conf.test index 2dd1bf5e1b..793ec5df7a 100644 --- a/etc/trove/trove.conf.test +++ b/etc/trove/trove.conf.test @@ -60,9 +60,6 @@ sql_idle_timeout = 3600 #DB Api Implementation db_api_implementation = trove.db.sqlalchemy.api -# Path to the extensions -api_extensions_path = trove/extensions/routes - # Configuration options for talking to nova via the novaclient. # These options are for an admin user in your keystone config. # It proxy's the token received from the user to send to nova via this admin users creds, diff --git a/requirements.txt b/requirements.txt index a6b51cf7db..e55781668d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -26,3 +26,5 @@ pexpect>=3.1 # ISC License oslo.config>=1.4.0.0a3 MySQL-python Babel>=1.3 +six>=1.7.0 +stevedore>=0.14 diff --git a/setup.cfg b/setup.cfg index 81e16f4822..8affb58544 100644 --- a/setup.cfg +++ b/setup.cfg @@ -23,7 +23,7 @@ packages = trove [entry_points] -console_scripts = +console_scripts = trove-api = trove.cmd.api:main trove-taskmanager = trove.cmd.taskmanager:main trove-mgmt-taskmanager = trove.cmd.taskmanager:mgmt_main @@ -32,6 +32,12 @@ console_scripts = trove-guestagent = trove.cmd.guest:main trove-fake-mode = trove.cmd.fakemode:main +trove.api.extensions = + account = trove.extensions.routes.account:Account + mgmt = trove.extensions.routes.mgmt:Mgmt + mysql = trove.extensions.routes.mysql:Mysql + security_group = trove.extensions.routes.security_group:Security_group + [global] setup-hooks = pbr.hooks.setup_hook diff --git a/test-requirements.txt b/test-requirements.txt index ae3eaa32e1..26c7196cfa 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -17,5 +17,4 @@ mox3>=0.7.0 testtools>=0.9.34 discover testrepository>=0.0.18 -six>=1.7.0 diff --git a/trove/common/cfg.py b/trove/common/cfg.py index a3a08caa28..3b403ef064 100644 --- a/trove/common/cfg.py +++ b/trove/common/cfg.py @@ -37,8 +37,6 @@ common_opts = [ cfg.IntOpt('sql_idle_timeout', default=3600), cfg.BoolOpt('sql_query_log', default=False), cfg.IntOpt('bind_port', default=8779), - cfg.StrOpt('api_extensions_path', default='$pybasedir/extensions/routes', - help='Path to extensions.'), cfg.StrOpt('api_paste_config', default="api-paste.ini", help='File name for the paste.deploy config for trove-api.'), diff --git a/trove/common/extensions.py b/trove/common/extensions.py index 17d9878995..537ab9d6e8 100644 --- a/trove/common/extensions.py +++ b/trove/common/extensions.py @@ -13,22 +13,496 @@ # License for the specific language governing permissions and limitations # under the License. -import os +import abc import routes -from trove.openstack.common import log as logging +import six +import stevedore +import webob.dec +import webob.exc -from trove.openstack.common import extensions +import trove.openstack.common.wsgi as os_wsgi + +from lxml import etree +from trove.openstack.common import log as logging +from trove.openstack.common import exception from trove.common import cfg from trove.common import wsgi +from trove.openstack.common.gettextutils import _ # noqa LOG = logging.getLogger(__name__) - -ExtensionsDescriptor = extensions.ExtensionDescriptor - CONF = cfg.CONF +DEFAULT_XMLNS = "http://docs.openstack.org/trove" +XMLNS_ATOM = "http://www.w3.org/2005/Atom" -class ResourceExtension(extensions.ResourceExtension): +@six.add_metaclass(abc.ABCMeta) +class ExtensionDescriptor(object): + """Base class that defines the contract for extensions. + + Note that you don't have to derive from this class to have a valid + extension; it is purely a convenience. + + """ + @abc.abstractmethod + def get_name(self): + """The name of the extension. + + e.g. 'Fox In Socks' + + """ + pass + + @abc.abstractmethod + def get_alias(self): + """The alias for the extension. + + e.g. 'FOXNSOX' + + """ + pass + + @abc.abstractmethod + def get_description(self): + """Friendly description for the extension. + + e.g. 'The Fox In Socks Extension' + + """ + pass + + @abc.abstractmethod + def get_namespace(self): + """The XML namespace for the extension. + + e.g. 'http://www.fox.in.socks/api/ext/pie/v1.0' + + """ + pass + + @abc.abstractmethod + def get_updated(self): + """The timestamp when the extension was last updated. + + e.g. '2011-01-22T13:25:27-06:00' + + """ + pass + + def get_resources(self): + """List of extensions.ResourceExtension extension objects. + + Resources define new nouns, and are accessible through URLs. + + """ + resources = [] + return resources + + def get_actions(self): + """List of extensions.ActionExtension extension objects. + + Actions are verbs callable from the API. + + """ + actions = [] + return actions + + def get_request_extensions(self): + """List of extensions.RequestException extension objects. + + Request extensions are used to handle custom request data. + + """ + request_exts = [] + return request_exts + + +class ActionExtensionController(object): + def __init__(self, application): + self.application = application + self.action_handlers = {} + + def add_action(self, action_name, handler): + self.action_handlers[action_name] = handler + + def action(self, req, id, body): + for action_name, handler in self.action_handlers.iteritems(): + if action_name in body: + return handler(body, req, id) + # no action handler found (bump to downstream application) + res = self.application + return res + + +class ActionExtensionResource(wsgi.Resource): + + def __init__(self, application): + controller = ActionExtensionController(application) + wsgi.Resource.__init__(self, controller) + + def add_action(self, action_name, handler): + self.controller.add_action(action_name, handler) + + +class RequestExtensionController(object): + + def __init__(self, application): + self.application = application + self.handlers = [] + + def add_handler(self, handler): + self.handlers.append(handler) + + def process(self, req, *args, **kwargs): + res = req.get_response(self.application) + # currently request handlers are un-ordered + for handler in self.handlers: + res = handler(req, res) + return res + + +class RequestExtensionResource(wsgi.Resource): + + def __init__(self, application): + controller = RequestExtensionController(application) + wsgi.Resource.__init__(self, controller) + + def add_handler(self, handler): + self.controller.add_handler(handler) + + +class ExtensionsResource(wsgi.Resource): + + def __init__(self, extension_manager): + self.extension_manager = extension_manager + body_serializers = {'application/xml': ExtensionsXMLSerializer()} + serializer = os_wsgi.ResponseSerializer( + body_serializers=body_serializers) + super(ExtensionsResource, self).__init__(self, None, serializer) + + def _translate(self, ext): + ext_data = {} + ext_data['name'] = ext.get_name() + ext_data['alias'] = ext.get_alias() + ext_data['description'] = ext.get_description() + ext_data['namespace'] = ext.get_namespace() + ext_data['updated'] = ext.get_updated() + ext_data['links'] = [] + return ext_data + + def index(self, req): + extensions = [] + for _alias, ext in self.extension_manager.extensions.iteritems(): + extensions.append(self._translate(ext)) + return dict(extensions=extensions) + + def show(self, req, id): + # NOTE(dprince): the extensions alias is used as the 'id' for show + ext = self.extension_manager.extensions.get(id, None) + if not ext: + raise webob.exc.HTTPNotFound( + _("Extension with alias %s does not exist") % id) + + return dict(extension=self._translate(ext)) + + def delete(self, req, id): + raise webob.exc.HTTPNotFound() + + def create(self, req): + raise webob.exc.HTTPNotFound() + + +class ExtensionMiddleware(wsgi.Middleware): + """Extensions middleware for WSGI.""" + + @classmethod + def factory(cls, global_config, **local_config): + """Paste factory.""" + def _factory(app): + return cls(app, global_config, **local_config) + return _factory + + def _action_ext_resources(self, application, ext_mgr, mapper): + """Return a dict of ActionExtensionResource-s by collection.""" + action_resources = {} + for action in ext_mgr.get_actions(): + if action.collection not in action_resources.keys(): + resource = ActionExtensionResource(application) + mapper.connect("/%s/:(id)/action.:(format)" % + action.collection, + action='action', + controller=resource, + conditions=dict(method=['POST'])) + mapper.connect("/%s/:(id)/action" % + action.collection, + action='action', + controller=resource, + conditions=dict(method=['POST'])) + action_resources[action.collection] = resource + + return action_resources + + def _request_ext_resources(self, application, ext_mgr, mapper): + """Returns a dict of RequestExtensionResource-s by collection.""" + request_ext_resources = {} + for req_ext in ext_mgr.get_request_extensions(): + if req_ext.key not in request_ext_resources.keys(): + resource = RequestExtensionResource(application) + mapper.connect(req_ext.url_route + '.:(format)', + action='process', + controller=resource, + conditions=req_ext.conditions) + + mapper.connect(req_ext.url_route, + action='process', + controller=resource, + conditions=req_ext.conditions) + request_ext_resources[req_ext.key] = resource + + return request_ext_resources + + def __init__(self, application, config, ext_mgr=None): + ext_mgr = (ext_mgr or + ExtensionManager()) + mapper = routes.Mapper() + + # extended resources + for resource_ext in ext_mgr.get_resources(): + LOG.debug('Extended resource: %s', resource_ext.collection) + controller_resource = wsgi.Resource(resource_ext.controller, + resource_ext.deserializer, + resource_ext.serializer) + self._map_custom_collection_actions(resource_ext, mapper, + controller_resource) + kargs = dict(controller=controller_resource, + collection=resource_ext.collection_actions, + member=resource_ext.member_actions) + if resource_ext.parent: + kargs['parent_resource'] = resource_ext.parent + mapper.resource(resource_ext.collection, + resource_ext.collection, **kargs) + + # extended actions + action_resources = self._action_ext_resources(application, ext_mgr, + mapper) + for action in ext_mgr.get_actions(): + LOG.debug('Extended action: %s', action.action_name) + resource = action_resources[action.collection] + resource.add_action(action.action_name, action.handler) + + # extended requests + req_controllers = self._request_ext_resources(application, ext_mgr, + mapper) + for request_ext in ext_mgr.get_request_extensions(): + LOG.debug('Extended request: %s', request_ext.key) + controller = req_controllers[request_ext.key] + controller.add_handler(request_ext.handler) + + self._router = routes.middleware.RoutesMiddleware(self._dispatch, + mapper) + + super(ExtensionMiddleware, self).__init__(application) + + def _map_custom_collection_actions(self, resource_ext, mapper, + controller_resource): + for action, method in resource_ext.collection_actions.iteritems(): + parent = resource_ext.parent + conditions = dict(method=[method]) + path = "/%s/%s" % (resource_ext.collection, action) + + path_prefix = "" + if parent: + path_prefix = "/%s/{%s_id}" % (parent["collection_name"], + parent["member_name"]) + + with mapper.submapper(controller=controller_resource, + action=action, + path_prefix=path_prefix, + conditions=conditions) as submap: + submap.connect(path) + submap.connect("%s.:(format)" % path) + + @webob.dec.wsgify(RequestClass=wsgi.Request) + def __call__(self, req): + """Route the incoming request with router.""" + req.environ['extended.app'] = self.application + return self._router + + @staticmethod + @webob.dec.wsgify(RequestClass=wsgi.Request) + def _dispatch(req): + """Dispatch the request. + + Returns the routed WSGI app's response or defers to the extended + application. + + """ + match = req.environ['wsgiorg.routing_args'][1] + if not match: + return req.environ['extended.app'] + app = match['controller'] + return app + + +class ExtensionManager(object): + + EXT_NAMESPACE = 'trove.api.extensions' + + def __init__(self): + LOG.debug('Initializing extension manager.') + + self.extensions = {} + self._load_all_extensions() + + def get_resources(self): + """Returns a list of ResourceExtension objects.""" + resources = [] + extension_resource = ExtensionsResource(self) + res_ext = ResourceExtension('extensions', + extension_resource, + serializer=extension_resource.serializer) + resources.append(res_ext) + for alias, ext in self.extensions.iteritems(): + try: + resources.extend(ext.get_resources()) + except AttributeError: + pass + return resources + + def get_actions(self): + """Returns a list of ActionExtension objects.""" + actions = [] + for alias, ext in self.extensions.iteritems(): + try: + actions.extend(ext.get_actions()) + except AttributeError: + pass + return actions + + def get_request_extensions(self): + """Returns a list of RequestExtension objects.""" + request_exts = [] + for alias, ext in self.extensions.iteritems(): + try: + request_exts.extend(ext.get_request_extensions()) + except AttributeError: + pass + return request_exts + + def _check_extension(self, extension): + """Checks for required methods in extension objects.""" + try: + LOG.debug('Ext name: %s', extension.get_name()) + LOG.debug('Ext alias: %s', extension.get_alias()) + LOG.debug('Ext description: %s', extension.get_description()) + LOG.debug('Ext namespace: %s', extension.get_namespace()) + LOG.debug('Ext updated: %s', extension.get_updated()) + except AttributeError as ex: + LOG.exception(_("Exception loading extension: %s"), unicode(ex)) + return False + return True + + def _check_load_extension(self, ext): + LOG.debug('Ext: %s', ext.obj) + return isinstance(ext.obj, ExtensionDescriptor) + + def _load_all_extensions(self): + self.api_extension_manager = stevedore.enabled.EnabledExtensionManager( + namespace=self.EXT_NAMESPACE, + check_func=self._check_load_extension, + invoke_on_load=True, + invoke_kwds={}) + self.api_extension_manager.map(self.add_extension) + + def add_extension(self, ext): + ext = ext.obj + # Do nothing if the extension doesn't check out + if not self._check_extension(ext): + return + + alias = ext.get_alias() + LOG.debug('Loaded extension: %s', alias) + + if alias in self.extensions: + raise exception.Error("Found duplicate extension: %s" % alias) + self.extensions[alias] = ext + + +class RequestExtension(object): + + def __init__(self, method, url_route, handler): + self.url_route = url_route + self.handler = handler + self.conditions = dict(method=[method]) + self.key = "%s-%s" % (method, url_route) + + +class ActionExtension(object): + + def __init__(self, collection, action_name, handler): + self.collection = collection + self.action_name = action_name + self.handler = handler + + +class BaseResourceExtension(object): + + def __init__(self, collection, controller, parent=None, + collection_actions=None, member_actions=None, + deserializer=None, serializer=None): + if not collection_actions: + collection_actions = {} + if not member_actions: + member_actions = {} + self.collection = collection + self.controller = controller + self.parent = parent + self.collection_actions = collection_actions + self.member_actions = member_actions + self.deserializer = deserializer + self.serializer = serializer + + +class ExtensionsXMLSerializer(os_wsgi.XMLDictSerializer): + + def __init__(self): + self.nsmap = {None: DEFAULT_XMLNS, 'atom': XMLNS_ATOM} + + def show(self, ext_dict): + ext = etree.Element('extension', nsmap=self.nsmap) + self._populate_ext(ext, ext_dict['extension']) + return self._to_xml(ext) + + def index(self, exts_dict): + exts = etree.Element('extensions', nsmap=self.nsmap) + for ext_dict in exts_dict['extensions']: + ext = etree.SubElement(exts, 'extension') + self._populate_ext(ext, ext_dict) + return self._to_xml(exts) + + def _populate_ext(self, ext_elem, ext_dict): + """Populate an extension xml element from a dict.""" + + ext_elem.set('name', ext_dict['name']) + ext_elem.set('namespace', ext_dict['namespace']) + ext_elem.set('alias', ext_dict['alias']) + ext_elem.set('updated', ext_dict['updated']) + desc = etree.Element('description') + desc.text = ext_dict['description'] + ext_elem.append(desc) + for link in ext_dict.get('links', []): + elem = etree.SubElement(ext_elem, '{%s}link' % XMLNS_ATOM) + elem.set('rel', link['rel']) + elem.set('href', link['href']) + elem.set('type', link['type']) + return ext_elem + + def _to_xml(self, root): + """Convert the xml object to an xml string.""" + + return etree.tostring(root, encoding='UTF-8') + + +class ResourceExtension(BaseResourceExtension): def __init__(self, collection, controller, parent=None, collection_actions=None, member_actions=None, deserializer=None, serializer=None): @@ -41,11 +515,11 @@ class ResourceExtension(extensions.ResourceExtension): serializer=wsgi.TroveResponseSerializer()) -class TroveExtensionMiddleware(extensions.ExtensionMiddleware): +class TroveExtensionMiddleware(ExtensionMiddleware): def __init__(self, application, ext_mgr=None): ext_mgr = (ext_mgr or - extensions.ExtensionManager(CONF.api_extensions_path)) + ExtensionManager()) mapper = routes.Mapper() # extended resources @@ -90,16 +564,12 @@ class TroveExtensionMiddleware(extensions.ExtensionMiddleware): self._router = routes.middleware.RoutesMiddleware(self._dispatch, mapper) - super(extensions.ExtensionMiddleware, self).__init__(application) + super(ExtensionMiddleware, self).__init__(application) def factory(global_config, **local_config): """Paste factory.""" def _factory(app): - extensions.DEFAULT_XMLNS = "http://docs.openstack.org/trove" - if not os.path.exists(CONF.api_extensions_path): - LOG.warning(_('API extensions path does not exist: %s.'), - CONF.api_extensions_path) - ext_mgr = extensions.ExtensionManager(CONF.api_extensions_path) + ext_mgr = ExtensionManager() return TroveExtensionMiddleware(app, ext_mgr) return _factory diff --git a/trove/extensions/routes/account.py b/trove/extensions/routes/account.py index 5f54cebf27..be46521de1 100644 --- a/trove/extensions/routes/account.py +++ b/trove/extensions/routes/account.py @@ -22,7 +22,7 @@ from trove.extensions.account import service LOG = logging.getLogger(__name__) -class Account(extensions.ExtensionsDescriptor): +class Account(extensions.ExtensionDescriptor): def get_name(self): return "Account" diff --git a/trove/extensions/routes/mgmt.py b/trove/extensions/routes/mgmt.py index e23a4b20df..2be6e4d3ea 100644 --- a/trove/extensions/routes/mgmt.py +++ b/trove/extensions/routes/mgmt.py @@ -26,7 +26,7 @@ from trove.extensions.mgmt.upgrade.service import UpgradeController LOG = logging.getLogger(__name__) -class Mgmt(extensions.ExtensionsDescriptor): +class Mgmt(extensions.ExtensionDescriptor): def get_name(self): return "Mgmt" diff --git a/trove/extensions/routes/mysql.py b/trove/extensions/routes/mysql.py index 3e84d0a5f1..500e48b9a4 100644 --- a/trove/extensions/routes/mysql.py +++ b/trove/extensions/routes/mysql.py @@ -22,7 +22,7 @@ from trove.extensions.mysql import service LOG = logging.getLogger(__name__) -class Mysql(extensions.ExtensionsDescriptor): +class Mysql(extensions.ExtensionDescriptor): def get_name(self): return "Mysql" diff --git a/trove/extensions/routes/security_group.py b/trove/extensions/routes/security_group.py index 46755d5ffa..522fd799f0 100644 --- a/trove/extensions/routes/security_group.py +++ b/trove/extensions/routes/security_group.py @@ -28,7 +28,7 @@ CONF = cfg.CONF # The Extensions module from openstack common expects the classname of the # extension to be loaded to be the exact same as the filename, except with # a capital first letter. That's the reason this class has such a funky name. -class Security_group(extensions.ExtensionsDescriptor): +class Security_group(extensions.ExtensionDescriptor): def get_name(self): return "SecurityGroup" diff --git a/trove/openstack/common/extensions.py b/trove/openstack/common/extensions.py deleted file mode 100644 index da1e0e7b30..0000000000 --- a/trove/openstack/common/extensions.py +++ /dev/null @@ -1,538 +0,0 @@ -# vim: tabstop=4 shiftwidth=4 softtabstop=4 - -# Copyright 2011 OpenStack LLC. -# Copyright 2011 Justin Santa Barbara -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import imp -import os -import routes -import webob.dec -import webob.exc -import logging -from lxml import etree - -from trove.openstack.common import exception -from trove.openstack.common import wsgi - -LOG = logging.getLogger('extensions') -DEFAULT_XMLNS = "http://docs.openstack.org/" -XMLNS_ATOM = "http://www.w3.org/2005/Atom" - - -class ExtensionDescriptor(object): - """Base class that defines the contract for extensions. - - Note that you don't have to derive from this class to have a valid - extension; it is purely a convenience. - - """ - - def get_name(self): - """The name of the extension. - - e.g. 'Fox In Socks' - - """ - raise NotImplementedError() - - def get_alias(self): - """The alias for the extension. - - e.g. 'FOXNSOX' - - """ - raise NotImplementedError() - - def get_description(self): - """Friendly description for the extension. - - e.g. 'The Fox In Socks Extension' - - """ - raise NotImplementedError() - - def get_namespace(self): - """The XML namespace for the extension. - - e.g. 'http://www.fox.in.socks/api/ext/pie/v1.0' - - """ - raise NotImplementedError() - - def get_updated(self): - """The timestamp when the extension was last updated. - - e.g. '2011-01-22T13:25:27-06:00' - - """ - # NOTE(justinsb): Not sure of the purpose of this is, vs the XML NS - raise NotImplementedError() - - def get_resources(self): - """List of extensions.ResourceExtension extension objects. - - Resources define new nouns, and are accessible through URLs. - - """ - resources = [] - return resources - - def get_actions(self): - """List of extensions.ActionExtension extension objects. - - Actions are verbs callable from the API. - - """ - actions = [] - return actions - - def get_request_extensions(self): - """List of extensions.RequestException extension objects. - - Request extensions are used to handle custom request data. - - """ - request_exts = [] - return request_exts - - -class ActionExtensionController(object): - def __init__(self, application): - self.application = application - self.action_handlers = {} - - def add_action(self, action_name, handler): - self.action_handlers[action_name] = handler - - def action(self, req, id, body): - for action_name, handler in self.action_handlers.iteritems(): - if action_name in body: - return handler(body, req, id) - # no action handler found (bump to downstream application) - res = self.application - return res - - -class ActionExtensionResource(wsgi.Resource): - - def __init__(self, application): - controller = ActionExtensionController(application) - wsgi.Resource.__init__(self, controller) - - def add_action(self, action_name, handler): - self.controller.add_action(action_name, handler) - - -class RequestExtensionController(object): - - def __init__(self, application): - self.application = application - self.handlers = [] - - def add_handler(self, handler): - self.handlers.append(handler) - - def process(self, req, *args, **kwargs): - res = req.get_response(self.application) - # currently request handlers are un-ordered - for handler in self.handlers: - res = handler(req, res) - return res - - -class RequestExtensionResource(wsgi.Resource): - - def __init__(self, application): - controller = RequestExtensionController(application) - wsgi.Resource.__init__(self, controller) - - def add_handler(self, handler): - self.controller.add_handler(handler) - - -class ExtensionsResource(wsgi.Resource): - - def __init__(self, extension_manager): - self.extension_manager = extension_manager - body_serializers = {'application/xml': ExtensionsXMLSerializer()} - serializer = wsgi.ResponseSerializer(body_serializers=body_serializers) - super(ExtensionsResource, self).__init__(self, None, serializer) - - def _translate(self, ext): - ext_data = {} - ext_data['name'] = ext.get_name() - ext_data['alias'] = ext.get_alias() - ext_data['description'] = ext.get_description() - ext_data['namespace'] = ext.get_namespace() - ext_data['updated'] = ext.get_updated() - ext_data['links'] = [] # TODO(dprince): implement extension links - return ext_data - - def index(self, req): - extensions = [] - for _alias, ext in self.extension_manager.extensions.iteritems(): - extensions.append(self._translate(ext)) - return dict(extensions=extensions) - - def show(self, req, id): - # NOTE(dprince): the extensions alias is used as the 'id' for show - ext = self.extension_manager.extensions.get(id, None) - if not ext: - raise webob.exc.HTTPNotFound( - _("Extension with alias %s does not exist") % id) - - return dict(extension=self._translate(ext)) - - def delete(self, req, id): - raise webob.exc.HTTPNotFound() - - def create(self, req): - raise webob.exc.HTTPNotFound() - - -class ExtensionMiddleware(wsgi.Middleware): - """Extensions middleware for WSGI.""" - - @classmethod - def factory(cls, global_config, **local_config): - """Paste factory.""" - def _factory(app): - return cls(app, global_config, **local_config) - return _factory - - def _action_ext_resources(self, application, ext_mgr, mapper): - """Return a dict of ActionExtensionResource-s by collection.""" - action_resources = {} - for action in ext_mgr.get_actions(): - if not action.collection in action_resources.keys(): - resource = ActionExtensionResource(application) - mapper.connect("/%s/:(id)/action.:(format)" % - action.collection, - action='action', - controller=resource, - conditions=dict(method=['POST'])) - mapper.connect("/%s/:(id)/action" % - action.collection, - action='action', - controller=resource, - conditions=dict(method=['POST'])) - action_resources[action.collection] = resource - - return action_resources - - def _request_ext_resources(self, application, ext_mgr, mapper): - """Returns a dict of RequestExtensionResource-s by collection.""" - request_ext_resources = {} - for req_ext in ext_mgr.get_request_extensions(): - if not req_ext.key in request_ext_resources.keys(): - resource = RequestExtensionResource(application) - mapper.connect(req_ext.url_route + '.:(format)', - action='process', - controller=resource, - conditions=req_ext.conditions) - - mapper.connect(req_ext.url_route, - action='process', - controller=resource, - conditions=req_ext.conditions) - request_ext_resources[req_ext.key] = resource - - return request_ext_resources - - def __init__(self, application, config, ext_mgr=None): - ext_mgr = (ext_mgr or - ExtensionManager(config['api_extensions_path'])) - mapper = routes.Mapper() - - # extended resources - for resource_ext in ext_mgr.get_resources(): - LOG.debug(_('Extended resource: %s'), resource_ext.collection) - controller_resource = wsgi.Resource(resource_ext.controller, - resource_ext.deserializer, - resource_ext.serializer) - self._map_custom_collection_actions(resource_ext, mapper, - controller_resource) - kargs = dict(controller=controller_resource, - collection=resource_ext.collection_actions, - member=resource_ext.member_actions) - if resource_ext.parent: - kargs['parent_resource'] = resource_ext.parent - mapper.resource(resource_ext.collection, - resource_ext.collection, **kargs) - - # extended actions - action_resources = self._action_ext_resources(application, ext_mgr, - mapper) - for action in ext_mgr.get_actions(): - LOG.debug(_('Extended action: %s'), action.action_name) - resource = action_resources[action.collection] - resource.add_action(action.action_name, action.handler) - - # extended requests - req_controllers = self._request_ext_resources(application, ext_mgr, - mapper) - for request_ext in ext_mgr.get_request_extensions(): - LOG.debug(_('Extended request: %s'), request_ext.key) - controller = req_controllers[request_ext.key] - controller.add_handler(request_ext.handler) - - self._router = routes.middleware.RoutesMiddleware(self._dispatch, - mapper) - - super(ExtensionMiddleware, self).__init__(application) - - def _map_custom_collection_actions(self, resource_ext, mapper, - controller_resource): - for action, method in resource_ext.collection_actions.iteritems(): - parent = resource_ext.parent - conditions = dict(method=[method]) - path = "/%s/%s" % (resource_ext.collection, action) - - path_prefix = "" - if parent: - path_prefix = "/%s/{%s_id}" % (parent["collection_name"], - parent["member_name"]) - - with mapper.submapper(controller=controller_resource, - action=action, - path_prefix=path_prefix, - conditions=conditions) as submap: - submap.connect(path) - submap.connect("%s.:(format)" % path) - - @webob.dec.wsgify(RequestClass=wsgi.Request) - def __call__(self, req): - """Route the incoming request with router.""" - req.environ['extended.app'] = self.application - return self._router - - @staticmethod - @webob.dec.wsgify(RequestClass=wsgi.Request) - def _dispatch(req): - """Dispatch the request. - - Returns the routed WSGI app's response or defers to the extended - application. - - """ - match = req.environ['wsgiorg.routing_args'][1] - if not match: - return req.environ['extended.app'] - app = match['controller'] - return app - - -class ExtensionManager(object): - """Load extensions from the configured extension path. - - See nova/tests/api/openstack/extensions/foxinsocks/extension.py for an - example extension implementation. - - """ - - def __init__(self, path): - LOG.debug(_('Initializing extension manager.')) - - self.path = path - self.extensions = {} - self._load_all_extensions() - - def get_resources(self): - """Returns a list of ResourceExtension objects.""" - resources = [] - extension_resource = ExtensionsResource(self) - res_ext = ResourceExtension('extensions', - extension_resource, - serializer=extension_resource.serializer) - resources.append(res_ext) - for alias, ext in self.extensions.iteritems(): - try: - resources.extend(ext.get_resources()) - except AttributeError: - # NOTE(dprince): Extension aren't required to have resource - # extensions - pass - return resources - - def get_actions(self): - """Returns a list of ActionExtension objects.""" - actions = [] - for alias, ext in self.extensions.iteritems(): - try: - actions.extend(ext.get_actions()) - except AttributeError: - # NOTE(dprince): Extension aren't required to have action - # extensions - pass - return actions - - def get_request_extensions(self): - """Returns a list of RequestExtension objects.""" - request_exts = [] - for alias, ext in self.extensions.iteritems(): - try: - request_exts.extend(ext.get_request_extensions()) - except AttributeError: - # NOTE(dprince): Extension aren't required to have request - # extensions - pass - return request_exts - - def _check_extension(self, extension): - """Checks for required methods in extension objects.""" - try: - LOG.debug(_('Ext name: %s'), extension.get_name()) - LOG.debug(_('Ext alias: %s'), extension.get_alias()) - LOG.debug(_('Ext description: %s'), extension.get_description()) - LOG.debug(_('Ext namespace: %s'), extension.get_namespace()) - LOG.debug(_('Ext updated: %s'), extension.get_updated()) - except AttributeError as ex: - LOG.exception(_("Exception loading extension: %s"), unicode(ex)) - return False - return True - - def _load_all_extensions(self): - """Load extensions from the configured path. - - Load extensions from the configured path. The extension name is - constructed from the module_name. If your extension module was named - widgets.py the extension class within that module should be - 'Widgets'. - - In addition, extensions are loaded from the 'contrib' directory. - - See nova/tests/api/openstack/extensions/foxinsocks.py for an example - extension implementation. - - """ - if os.path.exists(self.path): - self._load_all_extensions_from_path(self.path) - - contrib_path = os.path.join(os.path.dirname(__file__), "contrib") - if os.path.exists(contrib_path): - self._load_all_extensions_from_path(contrib_path) - - def _load_all_extensions_from_path(self, path): - for f in os.listdir(path): - LOG.debug(_('Loading extension file: %s'), f) - mod_name, file_ext = os.path.splitext(os.path.split(f)[-1]) - ext_path = os.path.join(path, f) - if file_ext.lower() == '.py' and not mod_name.startswith('_'): - mod = imp.load_source(mod_name, ext_path) - ext_name = mod_name[0].upper() + mod_name[1:] - new_ext_class = getattr(mod, ext_name, None) - if not new_ext_class: - LOG.warn(_('Did not find expected name ' - '"%(ext_name)s" in %(file)s'), - {'ext_name': ext_name, - 'file': ext_path}) - continue - new_ext = new_ext_class() - self.add_extension(new_ext) - - def add_extension(self, ext): - # Do nothing if the extension doesn't check out - if not self._check_extension(ext): - return - - alias = ext.get_alias() - LOG.debug(_('Loaded extension: %s'), alias) - - if alias in self.extensions: - raise exception.Error("Found duplicate extension: %s" % alias) - self.extensions[alias] = ext - - -class RequestExtension(object): - """Extend requests and responses of core nova OpenStack API resources. - - Provide a way to add data to responses and handle custom request data - that is sent to core nova OpenStack API controllers. - - """ - def __init__(self, method, url_route, handler): - self.url_route = url_route - self.handler = handler - self.conditions = dict(method=[method]) - self.key = "%s-%s" % (method, url_route) - - -class ActionExtension(object): - """Add custom actions to core nova OpenStack API resources.""" - - def __init__(self, collection, action_name, handler): - self.collection = collection - self.action_name = action_name - self.handler = handler - - -class ResourceExtension(object): - """Add top level resources to the OpenStack API in nova.""" - - def __init__(self, collection, controller, parent=None, - collection_actions=None, member_actions=None, - deserializer=None, serializer=None): - if not collection_actions: - collection_actions = {} - if not member_actions: - member_actions = {} - self.collection = collection - self.controller = controller - self.parent = parent - self.collection_actions = collection_actions - self.member_actions = member_actions - self.deserializer = deserializer - self.serializer = serializer - - -class ExtensionsXMLSerializer(wsgi.XMLDictSerializer): - - def __init__(self): - self.nsmap = {None: DEFAULT_XMLNS, 'atom': XMLNS_ATOM} - - def show(self, ext_dict): - ext = etree.Element('extension', nsmap=self.nsmap) - self._populate_ext(ext, ext_dict['extension']) - return self._to_xml(ext) - - def index(self, exts_dict): - exts = etree.Element('extensions', nsmap=self.nsmap) - for ext_dict in exts_dict['extensions']: - ext = etree.SubElement(exts, 'extension') - self._populate_ext(ext, ext_dict) - return self._to_xml(exts) - - def _populate_ext(self, ext_elem, ext_dict): - """Populate an extension xml element from a dict.""" - - ext_elem.set('name', ext_dict['name']) - ext_elem.set('namespace', ext_dict['namespace']) - ext_elem.set('alias', ext_dict['alias']) - ext_elem.set('updated', ext_dict['updated']) - desc = etree.Element('description') - desc.text = ext_dict['description'] - ext_elem.append(desc) - for link in ext_dict.get('links', []): - elem = etree.SubElement(ext_elem, '{%s}link' % XMLNS_ATOM) - elem.set('rel', link['rel']) - elem.set('href', link['href']) - elem.set('type', link['type']) - return ext_elem - - def _to_xml(self, root): - """Convert the xml object to an xml string.""" - - return etree.tostring(root, encoding='UTF-8') diff --git a/trove/tests/unittests/api/common/test_extensions.py b/trove/tests/unittests/api/common/test_extensions.py new file mode 100644 index 0000000000..db080096cc --- /dev/null +++ b/trove/tests/unittests/api/common/test_extensions.py @@ -0,0 +1,89 @@ +# Copyright 2014 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import mock +import pkg_resources +import testtools + +from trove.common import extensions +from trove.extensions.routes.account import Account +from trove.extensions.routes.mgmt import Mgmt +from trove.extensions.routes.mysql import Mysql +from trove.extensions.routes.security_group import Security_group + +DEFAULT_EXTENSION_MAP = { + 'Account': [Account, extensions.ExtensionDescriptor], + 'Mgmt': [Mgmt, extensions.ExtensionDescriptor], + 'MYSQL': [Mysql, extensions.ExtensionDescriptor], + 'SecurityGroup': [Security_group, extensions.ExtensionDescriptor] +} + +EP_TEXT = ''' +account = trove.extensions.routes.account:Account +mgmt = trove.extensions.routes.mgmt:Mgmt +mysql = trove.extensions.routes.mysql:Mysql +security_group = trove.extensions.routes.security_group:Security_group +invalid = trove.tests.unittests.api.common.test_extensions:InvalidExtension +''' + + +class InvalidExtension(object): + def get_name(self): + return "Invalid" + + def get_description(self): + return "Invalid Extension" + + def get_alias(self): + return "Invalid" + + def get_namespace(self): + return "http://TBD" + + def get_updated(self): + return "2014-08-14T13:25:27-06:00" + + def get_resources(self): + return [] + + +class TestExtensionLoading(testtools.TestCase): + def setUp(self): + super(TestExtensionLoading, self).setUp() + + def tearDown(self): + super(TestExtensionLoading, self).tearDown() + + def _assert_default_extensions(self, ext_list): + for alias, ext in ext_list.items(): + for clazz in DEFAULT_EXTENSION_MAP[alias]: + self.assertIsInstance(ext, clazz, "Improper extension class") + + def test_default_extensions(self): + extension_mgr = extensions.ExtensionManager() + self.assertEqual(DEFAULT_EXTENSION_MAP.keys().sort(), + extension_mgr.extensions.keys().sort(), + "Invalid extension names") + self._assert_default_extensions(extension_mgr.extensions) + + @mock.patch("pkg_resources.iter_entry_points") + def test_invalid_extension(self, mock_iter_eps): + eps = pkg_resources.EntryPoint.parse_group('mock', EP_TEXT) + mock_iter_eps.return_value = eps.values() + extension_mgr = extensions.ExtensionManager() + self.assertEqual(len(extension_mgr.extensions), + len(DEFAULT_EXTENSION_MAP.keys()), + "Loaded invalid extensions") + self._assert_default_extensions(extension_mgr.extensions)