From 0a71ef9e880996597f6f94d06fa2d0934051d420 Mon Sep 17 00:00:00 2001 From: daniel-a-nguyen Date: Thu, 14 Feb 2013 11:22:06 -0800 Subject: [PATCH] Rate limits implementation added unittest for limits reverted changes to openstack/common removed commented code cleaned up unittest added int-tests updated reference to XMLNS removed 1.1 XMLMS in wsgi Implements: blueprint rate-limits Change-Id: I842de3a6cae1859cc246264a5836abfd97fb8074 --- etc/reddwarf/api-paste.ini | 5 +- etc/reddwarf/reddwarf.conf.sample | 6 + etc/reddwarf/reddwarf.conf.test | 5 +- etc/tests/localhost.test.conf | 18 + reddwarf/common/api.py | 7 + reddwarf/common/cfg.py | 4 + reddwarf/common/limits.py | 464 +++++++++ reddwarf/common/schemas/atom-link.rng | 141 +++ reddwarf/common/schemas/atom.rng | 597 ++++++++++++ reddwarf/common/schemas/v1.1/limits.rng | 28 + reddwarf/common/wsgi.py | 237 ++++- reddwarf/common/xmlutil.py | 910 ++++++++++++++++++ reddwarf/limits/__init__.py | 0 reddwarf/limits/service.py | 49 + reddwarf/limits/views.py | 98 ++ reddwarf/tests/api/limits.py | 109 +++ .../tests/unittests/api/common/__init__.py | 0 .../tests/unittests/api/common/test_limits.py | 741 ++++++++++++++ reddwarf/tests/unittests/util/matchers.py | 454 +++++++++ run_tests.py | 1 + tools/test-requires | 1 + 21 files changed, 3872 insertions(+), 3 deletions(-) create mode 100644 reddwarf/common/limits.py create mode 100644 reddwarf/common/schemas/atom-link.rng create mode 100644 reddwarf/common/schemas/atom.rng create mode 100644 reddwarf/common/schemas/v1.1/limits.rng create mode 100644 reddwarf/common/xmlutil.py create mode 100644 reddwarf/limits/__init__.py create mode 100644 reddwarf/limits/service.py create mode 100644 reddwarf/limits/views.py create mode 100644 reddwarf/tests/api/limits.py create mode 100644 reddwarf/tests/unittests/api/common/__init__.py create mode 100644 reddwarf/tests/unittests/api/common/test_limits.py create mode 100644 reddwarf/tests/unittests/util/matchers.py diff --git a/etc/reddwarf/api-paste.ini b/etc/reddwarf/api-paste.ini index 896c6c18a5..7f197ca5ae 100644 --- a/etc/reddwarf/api-paste.ini +++ b/etc/reddwarf/api-paste.ini @@ -7,7 +7,7 @@ use = call:reddwarf.common.wsgi:versioned_urlmap paste.app_factory = reddwarf.versions:app_factory [pipeline:reddwarfapi] -pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp +pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp #pipeline = debug extensions reddwarfapp [filter:extensions] @@ -34,6 +34,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory [filter:faultwrapper] paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory +[filter:ratelimit] +paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory + [app:reddwarfapp] paste.app_factory = reddwarf.common.api:app_factory diff --git a/etc/reddwarf/reddwarf.conf.sample b/etc/reddwarf/reddwarf.conf.sample index 7399fbc6a0..89b4cf679c 100644 --- a/etc/reddwarf/reddwarf.conf.sample +++ b/etc/reddwarf/reddwarf.conf.sample @@ -54,6 +54,12 @@ max_instances_per_user = 5 max_volumes_per_user = 100 volume_time_out=30 +# Config options for rate limits +http_get_rate = 200 +http_post_rate = 200 +http_put_rate = 200 +http_delete_rate = 200 + # Reddwarf DNS reddwarf_dns_support = False diff --git a/etc/reddwarf/reddwarf.conf.test b/etc/reddwarf/reddwarf.conf.test index 03a0569d01..96867724ad 100644 --- a/etc/reddwarf/reddwarf.conf.test +++ b/etc/reddwarf/reddwarf.conf.test @@ -106,7 +106,7 @@ use = 
call:reddwarf.common.wsgi:versioned_urlmap paste.app_factory = reddwarf.versions:app_factory [pipeline:reddwarfapi] -pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp +pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp # pipeline = debug reddwarfapp [filter:extensions] @@ -132,6 +132,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory [filter:faultwrapper] paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory +[filter:ratelimit] +paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory + [app:reddwarfapp] paste.app_factory = reddwarf.common.api:app_factory diff --git a/etc/tests/localhost.test.conf b/etc/tests/localhost.test.conf index cae44a17e2..d9efcd2757 100644 --- a/etc/tests/localhost.test.conf +++ b/etc/tests/localhost.test.conf @@ -45,6 +45,24 @@ "is_admin":false, "services": ["reddwarf"] } + }, + { + "auth_user":"rate_limit", + "auth_key":"password", + "tenant":"4000", + "requirements": { + "is_admin":false, + "services": ["reddwarf"] + } + }, + { + "auth_user":"rate_limit_exceeded", + "auth_key":"password", + "tenant":"4050", + "requirements": { + "is_admin":false, + "services": ["reddwarf"] + } } ], diff --git a/reddwarf/common/api.py b/reddwarf/common/api.py index cb80ace4fb..3bc2ab5054 100644 --- a/reddwarf/common/api.py +++ b/reddwarf/common/api.py @@ -19,6 +19,7 @@ from reddwarf.common import wsgi from reddwarf.extensions.mgmt.host.instance import service as hostservice from reddwarf.flavor.service import FlavorController from reddwarf.instance.service import InstanceController +from reddwarf.limits.service import LimitsController from reddwarf.openstack.common import log as logging from reddwarf.openstack.common import rpc from reddwarf.versions import VersionsController @@ -32,6 +33,7 @@ class API(wsgi.Router): self._instance_router(mapper) self._flavor_router(mapper) self._versions_router(mapper) + self._limits_router(mapper) def _versions_router(self, mapper): versions_resource = VersionsController().create_resource() @@ -48,6 +50,11 @@ class API(wsgi.Router): path = "/{tenant_id}/flavors" mapper.resource("flavor", path, controller=flavor_resource) + def _limits_router(self, mapper): + limits_resource = LimitsController().create_resource() + path = "/{tenant_id}/limits" + mapper.resource("limits", path, controller=limits_resource) + def app_factory(global_conf, **local_conf): return API() diff --git a/reddwarf/common/cfg.py b/reddwarf/common/cfg.py index 937ba7501b..f6523a1d55 100644 --- a/reddwarf/common/cfg.py +++ b/reddwarf/common/cfg.py @@ -100,6 +100,10 @@ common_opts = [ cfg.IntOpt('revert_time_out', default=60 * 10), cfg.ListOpt('root_grant', default=['ALL']), cfg.BoolOpt('root_grant_option', default=True), + cfg.IntOpt('http_get_rate', default=200), + cfg.IntOpt('http_post_rate', default=200), + cfg.IntOpt('http_delete_rate', default=200), + cfg.IntOpt('http_put_rate', default=200), ] diff --git a/reddwarf/common/limits.py b/reddwarf/common/limits.py new file mode 100644 index 0000000000..bfe233f6f6 --- /dev/null +++ b/reddwarf/common/limits.py @@ -0,0 +1,464 @@ +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Module dedicated functions/classes dealing with rate limiting requests. +""" + +import collections +import copy +import httplib +import math +import re +import time +import webob.dec +import webob.exc +import xmlutil + +from reddwarf.common import cfg +from reddwarf.common import wsgi as base_wsgi +from reddwarf.openstack.common import importutils +from reddwarf.openstack.common import jsonutils +from reddwarf.openstack.common import wsgi +from reddwarf.openstack.common.gettextutils import _ + +# +# TODO: come back to this later +# Dan Nguyen +# +#from nova import quota +#QUOTAS = quota.QUOTAS + + +CONF = cfg.CONF + +# Convenience constants for the limits dictionary passed to Limiter(). +PER_SECOND = 1 +PER_MINUTE = 60 +PER_HOUR = 60 * 60 +PER_DAY = 60 * 60 * 24 + + +limits_nsmap = {None: xmlutil.XMLNS_COMMON_V10, 'atom': xmlutil.XMLNS_ATOM} + + +class LimitsTemplate(xmlutil.TemplateBuilder): + def construct(self): + root = xmlutil.TemplateElement('limits', selector='limits') + + rates = xmlutil.SubTemplateElement(root, 'rates') + rate = xmlutil.SubTemplateElement(rates, 'rate', selector='rate') + rate.set('uri', 'uri') + rate.set('regex', 'regex') + limit = xmlutil.SubTemplateElement(rate, 'limit', selector='limit') + limit.set('value', 'value') + limit.set('verb', 'verb') + limit.set('remaining', 'remaining') + limit.set('unit', 'unit') + limit.set('next-available', 'next-available') + + absolute = xmlutil.SubTemplateElement(root, 'absolute', + selector='absolute') + limit = xmlutil.SubTemplateElement(absolute, 'limit', + selector=xmlutil.get_items) + limit.set('name', 0) + limit.set('value', 1) + + return xmlutil.MasterTemplate(root, 1, nsmap=limits_nsmap) + + +class Limit(object): + """ + Stores information about a limit for HTTP requests. + """ + + UNITS = { + 1: "SECOND", + 60: "MINUTE", + 60 * 60: "HOUR", + 60 * 60 * 24: "DAY", + } + + UNIT_MAP = dict([(v, k) for k, v in UNITS.items()]) + + def __init__(self, verb, uri, regex, value, unit): + """ + Initialize a new `Limit`. + + @param verb: HTTP verb (POST, PUT, etc.) + @param uri: Human-readable URI + @param regex: Regular expression format for this limit + @param value: Integer number of requests which can be made + @param unit: Unit of measure for the value parameter + """ + self.verb = verb + self.uri = uri + self.regex = regex + self.value = int(value) + self.unit = unit + self.unit_string = self.display_unit().lower() + self.remaining = int(value) + + if value <= 0: + raise ValueError("Limit value must be > 0") + + self.last_request = None + self.next_request = None + + self.water_level = 0 + self.capacity = self.unit + self.request_value = float(self.capacity) / float(self.value) + msg = _("Only %(value)s %(verb)s request(s) can be " + "made to %(uri)s every %(unit_string)s.") + self.error_message = msg % self.__dict__ + + def __call__(self, verb, url): + """ + Represents a call to this limit from a relevant request. + + @param verb: string http verb (POST, GET, etc.) 
+ @param url: string URL + """ + if self.verb != verb or not re.match(self.regex, url): + return + + now = self._get_time() + + if self.last_request is None: + self.last_request = now + + leak_value = now - self.last_request + + self.water_level -= leak_value + self.water_level = max(self.water_level, 0) + self.water_level += self.request_value + + difference = self.water_level - self.capacity + + self.last_request = now + + if difference > 0: + self.water_level -= self.request_value + self.next_request = now + difference + return difference + + cap = self.capacity + water = self.water_level + val = self.value + + self.remaining = math.floor(((cap - water) / cap) * val) + self.next_request = now + + def _get_time(self): + """Retrieve the current time. Broken out for testability.""" + return time.time() + + def display_unit(self): + """Display the string name of the unit.""" + return self.UNITS.get(self.unit, "UNKNOWN") + + def display(self): + """Return a useful representation of this class.""" + return { + "verb": self.verb, + "URI": self.uri, + "regex": self.regex, + "value": self.value, + "remaining": int(self.remaining), + "unit": self.display_unit(), + "resetTime": int(self.next_request or self._get_time()), + } + +# "Limit" format is a dictionary with the HTTP verb, human-readable URI, +# a regular-expression to match, value and unit of measure (PER_DAY, etc.) +DEFAULT_LIMITS = [ + Limit("POST", "*", ".*", CONF.http_post_rate, PER_MINUTE), + Limit("PUT", "*", ".*", CONF.http_put_rate, PER_MINUTE), + Limit("DELETE", "*", ".*", CONF.http_delete_rate, PER_MINUTE), + Limit("GET", "*", ".*", CONF.http_get_rate, PER_MINUTE), +] + + +class RateLimitingMiddleware(base_wsgi.ReddwarfMiddleware): + """ + Rate-limits requests passing through this middleware. All limit information + is stored in memory for this implementation. + """ + + def __init__(self, application, limits=None, limiter=None, **kwargs): + """ + Initialize new `RateLimitingMiddleware`, which wraps the given WSGI + application and sets up the given limits. + + @param application: WSGI application to wrap + @param limits: String describing limits + @param limiter: String identifying class for representing limits + + Other parameters are passed to the constructor for the limiter. + """ + base_wsgi.Middleware.__init__(self, application) + + # Select the limiter class + if limiter is None: + limiter = Limiter + else: + limiter = importutils.import_class(limiter) + + # Parse the limits, if any are provided + if limits is not None: + limits = limiter.parse_limits(limits) + + self._limiter = limiter(limits or DEFAULT_LIMITS, **kwargs) + + @webob.dec.wsgify(RequestClass=wsgi.Request) + def __call__(self, req): + """ + Represents a single call through this middleware. We should record the + request if we have a limit relevant to it. If no limit is relevant to + the request, ignore it. + + If the request should be rate limited, return a fault telling the user + they are over the limit and need to retry later. 
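Whether a request is over the limit is decided by the `Limit` objects above, which implement a simple leaky bucket. As a rough illustration using the default of 200 requests per MINUTE (so capacity is 60 and each call adds 60/200 = 0.3 to the water level; the URL is hypothetical)::

    limit = Limit("GET", "*", ".*", 200, PER_MINUTE)
    limit("GET", "http://example/v1.0/1234/instances")
    # water_level is now 0.3 and limit.remaining is 199; after 200
    # immediate calls the water level reaches the capacity of 60, and
    # the next call returns a positive delay in seconds instead of None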
+ """ + verb = req.method + url = req.url + context = req.environ.get(base_wsgi.CONTEXT_KEY) + + tenant_id = None + if context: + tenant_id = context.tenant + + delay, error = self._limiter.check_for_delay(verb, url, tenant_id) + + if delay: + msg = _("This request was rate-limited.") + retry = time.time() + delay + return base_wsgi.OverLimitFault(msg, error, retry) + + req.environ["reddwarf.limits"] = self._limiter.get_limits(tenant_id) + + return self.application + + +class Limiter(object): + """ + Rate-limit checking class which handles limits in memory. + """ + + def __init__(self, limits, **kwargs): + """ + Initialize the new `Limiter`. + + @param limits: List of `Limit` objects + """ + self.limits = copy.deepcopy(limits) + self.levels = collections.defaultdict(lambda: copy.deepcopy(limits)) + + # Pick up any per-user limit information + for key, value in kwargs.items(): + if key.startswith('user:'): + username = key[5:] + self.levels[username] = self.parse_limits(value) + + def get_limits(self, username=None): + """ + Return the limits for a given user. + """ + return [limit.display() for limit in self.levels[username]] + + def check_for_delay(self, verb, url, username=None): + """ + Check the given verb/user/user triplet for limit. + + @return: Tuple of delay (in seconds) and error message (or None, None) + """ + delays = [] + + for limit in self.levels[username]: + delay = limit(verb, url) + if delay: + delays.append((delay, limit.error_message)) + + if delays: + delays.sort() + return delays[0] + + return None, None + + # This was ported from nova. + # Keeping it as a static method for the sake of consistency + # + # Note: This method gets called before the class is instantiated, + # so this must be either a static method or a class method. It is + # used to develop a list of limits to feed to the constructor. We + # put this in the class so that subclasses can override the + # default limit parsing. + @staticmethod + def parse_limits(limits): + """ + Convert a string into a list of Limit instances. This + implementation expects a semicolon-separated sequence of + parenthesized groups, where each group contains a + comma-separated sequence consisting of HTTP method, + user-readable URI, a URI reg-exp, an integer number of + requests which can be made, and a unit of measure. Valid + values for the latter are "SECOND", "MINUTE", "HOUR", and + "DAY". + + @return: List of Limit instances. + """ + + # Handle empty limit strings + limits = limits.strip() + if not limits: + return [] + + # Split up the limits by semicolon + result = [] + for group in limits.split(';'): + group = group.strip() + if group[:1] != '(' or group[-1:] != ')': + raise ValueError("Limit rules must be surrounded by " + "parentheses") + group = group[1:-1] + + # Extract the Limit arguments + args = [a.strip() for a in group.split(',')] + if len(args) != 5: + raise ValueError("Limit rules must contain the following " + "arguments: verb, uri, regex, value, unit") + + # Pull out the arguments + verb, uri, regex, value, unit = args + + # Upper-case the verb + verb = verb.upper() + + # Convert value--raises ValueError if it's not integer + value = int(value) + + # Convert unit + unit = unit.upper() + if unit not in Limit.UNIT_MAP: + raise ValueError("Invalid units specified") + unit = Limit.UNIT_MAP[unit] + + # Build a limit + result.append(Limit(verb, uri, regex, value, unit)) + + return result + + +class WsgiLimiter(object): + """ + Rate-limit checking from a WSGI application. Uses an in-memory `Limiter`. 
+ + To use, POST ``/`` with JSON data such as:: + + { + "verb" : GET, + "path" : "/servers" + } + + and receive a 204 No Content, or a 403 Forbidden with an X-Wait-Seconds + header containing the number of seconds to wait before the action would + succeed. + """ + + def __init__(self, limits=None): + """ + Initialize the new `WsgiLimiter`. + + @param limits: List of `Limit` objects + """ + self._limiter = Limiter(limits or DEFAULT_LIMITS) + + @webob.dec.wsgify(RequestClass=wsgi.Request) + def __call__(self, request): + """ + Handles a call to this application. Returns 204 if the request is + acceptable to the limiter, else a 403 is returned with a relevant + header indicating when the request *will* succeed. + """ + if request.method != "POST": + raise webob.exc.HTTPMethodNotAllowed() + + try: + info = dict(jsonutils.loads(request.body)) + except ValueError: + raise webob.exc.HTTPBadRequest() + + username = request.path_info_pop() + verb = info.get("verb") + path = info.get("path") + + delay, error = self._limiter.check_for_delay(verb, path, username) + + if delay: + headers = {"X-Wait-Seconds": "%.2f" % delay} + return webob.exc.HTTPForbidden(headers=headers, explanation=error) + else: + return webob.exc.HTTPNoContent() + + +class WsgiLimiterProxy(object): + """ + Rate-limit requests based on answers from a remote source. + """ + + def __init__(self, limiter_address): + """ + Initialize the new `WsgiLimiterProxy`. + + @param limiter_address: IP/port combination of where to request limit + """ + self.limiter_address = limiter_address + + def check_for_delay(self, verb, path, username=None): + body = jsonutils.dumps({"verb": verb, "path": path}) + headers = {"Content-Type": "application/json"} + + conn = httplib.HTTPConnection(self.limiter_address) + + if username: + conn.request("POST", "/%s" % (username), body, headers) + else: + conn.request("POST", "/", body, headers) + + resp = conn.getresponse() + + if 200 >= resp.status < 300: + return None, None + + return resp.getheader("X-Wait-Seconds"), resp.read() or None + + # This was ported from nova. + # Keeping it as a static method for the sake of consistency + # + # Note: This method gets called before the class is instantiated, + # so this must be either a static method or a class method. It is + # used to develop a list of limits to feed to the constructor. + # This implementation returns an empty list, since all limit + # decisions are made by a remote server. + @staticmethod + def parse_limits(limits): + """ + Ignore a limits string--simply doesn't apply for the limit + proxy. + + @return: Empty list. 
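A typical (hypothetical) deployment would run `WsgiLimiter` as a small internal service and point the proxy at it::

    proxy = WsgiLimiterProxy("127.0.0.1:9000")   # address is illustrative
    delay, error = proxy.check_for_delay("GET", "/instances", "tenant-1")
    # delay is None when the remote limiter answers 2xx, otherwise the
    # X-Wait-Seconds header value; error carries the response body, if any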
+ """ + + return [] diff --git a/reddwarf/common/schemas/atom-link.rng b/reddwarf/common/schemas/atom-link.rng new file mode 100644 index 0000000000..edba5eee6c --- /dev/null +++ b/reddwarf/common/schemas/atom-link.rng @@ -0,0 +1,141 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + [^:]* + + + + + + .+/.+ + + + + + + [A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})* + + + + + + + + + + + + xml:base + xml:lang + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/reddwarf/common/schemas/atom.rng b/reddwarf/common/schemas/atom.rng new file mode 100644 index 0000000000..c2df4e4101 --- /dev/null +++ b/reddwarf/common/schemas/atom.rng @@ -0,0 +1,597 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text + html + + + + + + + + + xhtml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + An atom:feed must have an atom:author unless all of its atom:entry children have an atom:author. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + An atom:entry must have at least one atom:link element with a rel attribute of 'alternate' or an atom:content. + + + An atom:entry must have an atom:author if its feed does not. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text + html + + + + + + + + + + + + + xhtml + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1 + [^:]* + + + + + + .+/.+ + + + + + + [A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})* + + + + + + + + + + .+@.+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + xml:base + xml:lang + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/reddwarf/common/schemas/v1.1/limits.rng b/reddwarf/common/schemas/v1.1/limits.rng new file mode 100644 index 0000000000..a66af4b9c4 --- /dev/null +++ b/reddwarf/common/schemas/v1.1/limits.rng @@ -0,0 +1,28 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/reddwarf/common/wsgi.py b/reddwarf/common/wsgi.py index 673d6f3855..d68562c75f 100644 --- a/reddwarf/common/wsgi.py +++ b/reddwarf/common/wsgi.py @@ -17,12 +17,15 @@ """Wsgi helper utilities for reddwarf""" import eventlet.wsgi +import math import paste.urlmap import re +import time import traceback import webob import webob.dec import webob.exc +from lxml import etree from paste import deploy from xml.dom import minidom @@ -30,13 +33,14 @@ from reddwarf.common import context as rd_context from reddwarf.common import exception from reddwarf.common import utils from reddwarf.openstack.common.gettextutils import _ +from reddwarf.openstack.common import jsonutils + from reddwarf.openstack.common import pastedeploy from reddwarf.openstack.common import service from reddwarf.openstack.common import wsgi as openstack_wsgi from reddwarf.openstack.common import log as logging from reddwarf.common import cfg - CONTEXT_KEY = 
'reddwarf.context' Router = openstack_wsgi.Router Debug = openstack_wsgi.Debug @@ -130,6 +134,54 @@ def launch(app_name, port, paste_config_file, data={}, return service.launch(server) +# Note: taken from Nova +def serializers(**serializers): + """Attaches serializers to a method. + + This decorator associates a dictionary of serializers with a + method. Note that the function attributes are directly + manipulated; the method is not wrapped. + """ + + def decorator(func): + if not hasattr(func, 'wsgi_serializers'): + func.wsgi_serializers = {} + func.wsgi_serializers.update(serializers) + return func + return decorator + + +class ReddwarfMiddleware(Middleware): + + # Note: taken from nova + @classmethod + def factory(cls, global_config, **local_config): + """Used for paste app factories in paste.deploy config files. + + Any local configuration (that is, values under the [filter:APPNAME] + section of the paste config) will be passed into the `__init__` method + as kwargs. + + A hypothetical configuration would look like: + + [filter:analytics] + redis_host = 127.0.0.1 + paste.filter_factory = nova.api.analytics:Analytics.factory + + which would result in a call to the `Analytics` class as + + import nova.api.analytics + analytics.Analytics(app_from_paste, redis_host='127.0.0.1') + + You could of course re-implement the `factory` method in subclasses, + but using the kwarg passing it shouldn't be necessary. + + """ + def _factory(app): + return cls(app, **local_config) + return _factory + + class VersionedURLMap(object): def __init__(self, urlmap): @@ -591,3 +643,186 @@ class FaultWrapper(openstack_wsgi.Middleware): def _factory(app): return cls(app) return _factory + + +# ported from Nova +class OverLimitFault(webob.exc.HTTPException): + """ + Rate-limited request response. + """ + + def __init__(self, message, details, retry_time): + """ + Initialize new `OverLimitFault` with relevant information. + """ + hdrs = OverLimitFault._retry_after(retry_time) + self.wrapped_exc = webob.exc.HTTPRequestEntityTooLarge(headers=hdrs) + self.content = {"overLimit": {"code": self.wrapped_exc.status_int, + "message": message, + "details": details, + "retryAfter": hdrs['Retry-After'], + }, + } + + @staticmethod + def _retry_after(retry_time): + delay = int(math.ceil(retry_time - time.time())) + retry_after = delay if delay > 0 else 0 + headers = {'Retry-After': '%d' % retry_after} + return headers + + @webob.dec.wsgify(RequestClass=Request) + def __call__(self, request): + """ + Return the wrapped exception with a serialized body conforming to our + error format. 
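In its JSON form the body looks roughly like this (values illustrative)::

    {
        "overLimit": {
            "code": 413,
            "message": "This request was rate-limited.",
            "details": "Only 200 GET request(s) can be made to * every minute.",
            "retryAfter": "42"
        }
    }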
+ """ + content_type = request.best_match_content_type() + metadata = {"attributes": {"overLimit": ["code", "retryAfter"]}} + + xml_serializer = XMLDictSerializer(metadata, XMLNS) + serializer = {'application/xml': xml_serializer, + 'application/json': JSONDictSerializer(), + }[content_type] + + content = serializer.serialize(self.content) + self.wrapped_exc.body = content + self.wrapped_exc.content_type = content_type + + return self.wrapped_exc + + +class ActionDispatcher(object): + """Maps method name to local methods through action name.""" + + def dispatch(self, *args, **kwargs): + """Find and call local method.""" + action = kwargs.pop('action', 'default') + action_method = getattr(self, str(action), self.default) + return action_method(*args, **kwargs) + + def default(self, data): + raise NotImplementedError() + + +class DictSerializer(ActionDispatcher): + """Default request body serialization.""" + + def serialize(self, data, action='default'): + return self.dispatch(data, action=action) + + def default(self, data): + return "" + + +class JSONDictSerializer(DictSerializer): + """Default JSON request body serialization.""" + + def default(self, data): + return jsonutils.dumps(data) + + +class XMLDictSerializer(DictSerializer): + + def __init__(self, metadata=None, xmlns=None): + """ + :param metadata: information needed to deserialize xml into + a dictionary. + :param xmlns: XML namespace to include with serialized xml + """ + super(XMLDictSerializer, self).__init__() + self.metadata = metadata or {} + self.xmlns = xmlns + + def default(self, data): + # We expect data to contain a single key which is the XML root. + root_key = data.keys()[0] + doc = minidom.Document() + node = self._to_xml_node(doc, self.metadata, root_key, data[root_key]) + + return self.to_xml_string(node) + + def to_xml_string(self, node, has_atom=False): + self._add_xmlns(node, has_atom) + return node.toxml('UTF-8') + + #NOTE (ameade): the has_atom should be removed after all of the + # xml serializers and view builders have been updated to the current + # spec that required all responses include the xmlns:atom, the has_atom + # flag is to prevent current tests from breaking + def _add_xmlns(self, node, has_atom=False): + if self.xmlns is not None: + node.setAttribute('xmlns', self.xmlns) + if has_atom: + node.setAttribute('xmlns:atom', "http://www.w3.org/2005/Atom") + + def _to_xml_node(self, doc, metadata, nodename, data): + """Recursive method to convert data members to XML nodes.""" + result = doc.createElement(nodename) + + # Set the xml namespace if one is specified + # TODO(justinsb): We could also use prefixes on the keys + xmlns = metadata.get('xmlns', None) + if xmlns: + result.setAttribute('xmlns', xmlns) + + #TODO(bcwaldon): accomplish this without a type-check + if isinstance(data, list): + collections = metadata.get('list_collections', {}) + if nodename in collections: + metadata = collections[nodename] + for item in data: + node = doc.createElement(metadata['item_name']) + node.setAttribute(metadata['item_key'], str(item)) + result.appendChild(node) + return result + singular = metadata.get('plurals', {}).get(nodename, None) + if singular is None: + if nodename.endswith('s'): + singular = nodename[:-1] + else: + singular = 'item' + for item in data: + node = self._to_xml_node(doc, metadata, singular, item) + result.appendChild(node) + #TODO(bcwaldon): accomplish this without a type-check + elif isinstance(data, dict): + collections = metadata.get('dict_collections', {}) + if nodename in 
collections: + metadata = collections[nodename] + for k, v in data.items(): + node = doc.createElement(metadata['item_name']) + node.setAttribute(metadata['item_key'], str(k)) + text = doc.createTextNode(str(v)) + node.appendChild(text) + result.appendChild(node) + return result + attrs = metadata.get('attributes', {}).get(nodename, {}) + for k, v in data.items(): + if k in attrs: + result.setAttribute(k, str(v)) + else: + if k == "deleted": + v = str(bool(v)) + node = self._to_xml_node(doc, metadata, k, v) + result.appendChild(node) + else: + # Type is atom + node = doc.createTextNode(str(data)) + result.appendChild(node) + return result + + def _create_link_nodes(self, xml_doc, links): + link_nodes = [] + for link in links: + link_node = xml_doc.createElement('atom:link') + link_node.setAttribute('rel', link['rel']) + link_node.setAttribute('href', link['href']) + if 'type' in link: + link_node.setAttribute('type', link['type']) + link_nodes.append(link_node) + return link_nodes + + def _to_xml(self, root): + """Convert the xml object to an xml string.""" + return etree.tostring(root, encoding='UTF-8', xml_declaration=True) diff --git a/reddwarf/common/xmlutil.py b/reddwarf/common/xmlutil.py new file mode 100644 index 0000000000..934da12ca1 --- /dev/null +++ b/reddwarf/common/xmlutil.py @@ -0,0 +1,910 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import os.path + +from lxml import etree + + +XMLNS_V10 = 'http://docs.rackspacecloud.com/servers/api/v1.0' +XMLNS_V11 = 'http://docs.openstack.org/database/api/v1.1' +XMLNS_COMMON_V10 = 'http://docs.openstack.org/common/api/v1.0' +XMLNS_ATOM = 'http://www.w3.org/2005/Atom' + + +def validate_schema(xml, schema_name): + if isinstance(xml, str): + xml = etree.fromstring(xml) + base_path = 'reddwarf/common/schemas/v1.1/' + if schema_name in ('atom', 'atom-link'): + base_path = 'reddwarf/common/schemas/' + + # TODO: need to figure out our schema paths later + import reddwarf + schema_path = os.path.join(os.path.abspath(reddwarf.__file__) + .split('reddwarf/__init__.py')[0], + '%s%s.rng' % (base_path, schema_name)) + + schema_doc = etree.parse(schema_path) + relaxng = etree.RelaxNG(schema_doc) + relaxng.assertValid(xml) + + +class Selector(object): + """Selects datum to operate on from an object.""" + + def __init__(self, *chain): + """Initialize the selector. + + Each argument is a subsequent index into the object. + """ + + self.chain = chain + + def __repr__(self): + """Return a representation of the selector.""" + + return "Selector" + repr(self.chain) + + def __call__(self, obj, do_raise=False): + """Select a datum to operate on. + + Selects the relevant datum within the object. + + :param obj: The object from which to select the object. + :param do_raise: If False (the default), return None if the + indexed datum does not exist. Otherwise, + raise a KeyError. 
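For instance (illustrative data)::

    selector = Selector('limits', 'rate')
    selector({'limits': {'rate': [1, 2]}})    # -> [1, 2]
    selector({'limits': {}})                  # -> None
    selector({'limits': {}}, do_raise=True)   # raises KeyError('rate')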
+ """ + + # Walk the selector list + for elem in self.chain: + # If it's callable, call it + if callable(elem): + obj = elem(obj) + else: + # Use indexing + try: + obj = obj[elem] + except (KeyError, IndexError): + # No sense going any further + if do_raise: + # Convert to a KeyError, for consistency + raise KeyError(elem) + return None + + # Return the finally-selected object + return obj + + +def get_items(obj): + """Get items in obj.""" + + return list(obj.items()) + + +class EmptyStringSelector(Selector): + """Returns the empty string if Selector would return None.""" + def __call__(self, obj, do_raise=False): + """Returns empty string if the selected value does not exist.""" + + try: + return super(EmptyStringSelector, self).__call__(obj, True) + except KeyError: + return "" + + +class ConstantSelector(object): + """Returns a constant.""" + + def __init__(self, value): + """Initialize the selector. + + :param value: The value to return. + """ + + self.value = value + + def __repr__(self): + """Return a representation of the selector.""" + + return repr(self.value) + + def __call__(self, _obj, _do_raise=False): + """Select a datum to operate on. + + Returns a constant value. Compatible with + Selector.__call__(). + """ + + return self.value + + +class TemplateElement(object): + """Represent an element in the template.""" + + def __init__(self, tag, attrib=None, selector=None, subselector=None, + **extra): + """Initialize an element. + + Initializes an element in the template. Keyword arguments + specify attributes to be set on the element; values must be + callables. See TemplateElement.set() for more information. + + :param tag: The name of the tag to create. + :param attrib: An optional dictionary of element attributes. + :param selector: An optional callable taking an object and + optional boolean do_raise indicator and + returning the object bound to the element. + :param subselector: An optional callable taking an object and + optional boolean do_raise indicator and + returning the object bound to the element. + This is used to further refine the datum + object returned by selector in the event + that it is a list of objects. + """ + + # Convert selector into a Selector + if selector is None: + selector = Selector() + elif not callable(selector): + selector = Selector(selector) + + # Convert subselector into a Selector + if subselector is not None and not callable(subselector): + subselector = Selector(subselector) + + self.tag = tag + self.selector = selector + self.subselector = subselector + self.attrib = {} + self._text = None + self._children = [] + self._childmap = {} + + # Run the incoming attributes through set() so that they + # become selectorized + if not attrib: + attrib = {} + attrib.update(extra) + for k, v in attrib.items(): + self.set(k, v) + + def __repr__(self): + """Return a representation of the template element.""" + + return ('<%s.%s %r at %#x>' % + (self.__class__.__module__, self.__class__.__name__, + self.tag, id(self))) + + def __len__(self): + """Return the number of child elements.""" + + return len(self._children) + + def __contains__(self, key): + """Determine whether a child node named by key exists.""" + + return key in self._childmap + + def __getitem__(self, idx): + """Retrieve a child node by index or name.""" + + if isinstance(idx, basestring): + # Allow access by node name + return self._childmap[idx] + else: + return self._children[idx] + + def append(self, elem): + """Append a child to the element.""" + + # Unwrap templates... 
+ elem = elem.unwrap() + + # Avoid duplications + if elem.tag in self._childmap: + raise KeyError(elem.tag) + + self._children.append(elem) + self._childmap[elem.tag] = elem + + def extend(self, elems): + """Append children to the element.""" + + # Pre-evaluate the elements + elemmap = {} + elemlist = [] + for elem in elems: + # Unwrap templates... + elem = elem.unwrap() + + # Avoid duplications + if elem.tag in self._childmap or elem.tag in elemmap: + raise KeyError(elem.tag) + + elemmap[elem.tag] = elem + elemlist.append(elem) + + # Update the children + self._children.extend(elemlist) + self._childmap.update(elemmap) + + def insert(self, idx, elem): + """Insert a child element at the given index.""" + + # Unwrap templates... + elem = elem.unwrap() + + # Avoid duplications + if elem.tag in self._childmap: + raise KeyError(elem.tag) + + self._children.insert(idx, elem) + self._childmap[elem.tag] = elem + + def remove(self, elem): + """Remove a child element.""" + + # Unwrap templates... + elem = elem.unwrap() + + # Check if element exists + if elem.tag not in self._childmap or self._childmap[elem.tag] != elem: + raise ValueError(_('element is not a child')) + + self._children.remove(elem) + del self._childmap[elem.tag] + + def get(self, key): + """Get an attribute. + + Returns a callable which performs datum selection. + + :param key: The name of the attribute to get. + """ + + return self.attrib[key] + + def set(self, key, value=None): + """Set an attribute. + + :param key: The name of the attribute to set. + + :param value: A callable taking an object and optional boolean + do_raise indicator and returning the datum bound + to the attribute. If None, a Selector() will be + constructed from the key. If a string, a + Selector() will be constructed from the string. + """ + + # Convert value to a selector + if value is None: + value = Selector(key) + elif not callable(value): + value = Selector(value) + + self.attrib[key] = value + + def keys(self): + """Return the attribute names.""" + + return self.attrib.keys() + + def items(self): + """Return the attribute names and values.""" + + return self.attrib.items() + + def unwrap(self): + """Unwraps a template to return a template element.""" + + # We are a template element + return self + + def wrap(self): + """Wraps a template element to return a template.""" + + # Wrap in a basic Template + return Template(self) + + def apply(self, elem, obj): + """Apply text and attributes to an etree.Element. + + Applies the text and attribute instructions in the template + element to an etree.Element instance. + + :param elem: An etree.Element instance. + :param obj: The base object associated with this template + element. + """ + + # Start with the text... + if self.text is not None: + elem.text = unicode(self.text(obj)) + + # Now set up all the attributes... + for key, value in self.attrib.items(): + try: + elem.set(key, unicode(value(obj, True))) + except KeyError: + # Attribute has no value, so don't include it + pass + + def _render(self, parent, datum, patches, nsmap): + """Internal rendering. + + Renders the template node into an etree.Element object. + Returns the etree.Element object. + + :param parent: The parent etree.Element instance. + :param datum: The datum associated with this template element. + :param patches: A list of other template elements that must + also be applied. + :param nsmap: An optional namespace dictionary to be + associated with the etree.Element instance. 
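Attribute values set on a template element are selectors, so rendering pulls them from the datum; for instance (illustrative names)::

    elem = TemplateElement('limit', selector='limit')
    elem.set('verb')                          # renders datum['verb']
    elem.set('unit', 'unit_string')           # renders datum['unit_string']
    elem.set('value', ConstantSelector('1'))  # always renders "1"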
+ """ + + # Allocate a node + if callable(self.tag): + tagname = self.tag(datum) + else: + tagname = self.tag + elem = etree.Element(tagname, nsmap=nsmap) + + # If we have a parent, append the node to the parent + if parent is not None: + parent.append(elem) + + # If the datum is None, do nothing else + if datum is None: + return elem + + # Apply this template element to the element + self.apply(elem, datum) + + # Additionally, apply the patches + for patch in patches: + patch.apply(elem, datum) + + # We have fully rendered the element; return it + return elem + + def render(self, parent, obj, patches=[], nsmap=None): + """Render an object. + + Renders an object against this template node. Returns a list + of two-item tuples, where the first item is an etree.Element + instance and the second item is the datum associated with that + instance. + + :param parent: The parent for the etree.Element instances. + :param obj: The object to render this template element + against. + :param patches: A list of other template elements to apply + when rendering this template element. + :param nsmap: An optional namespace dictionary to attach to + the etree.Element instances. + """ + + # First, get the datum we're rendering + data = None if obj is None else self.selector(obj) + + # Check if we should render at all + if not self.will_render(data): + return [] + elif data is None: + return [(self._render(parent, None, patches, nsmap), None)] + + # Make the data into a list if it isn't already + if not isinstance(data, list): + data = [data] + elif parent is None: + raise ValueError(_('root element selecting a list')) + + # Render all the elements + elems = [] + for datum in data: + if self.subselector is not None: + datum = self.subselector(datum) + elems.append((self._render(parent, datum, patches, nsmap), datum)) + + # Return all the elements rendered, as well as the + # corresponding datum for the next step down the tree + return elems + + def will_render(self, datum): + """Hook method. + + An overridable hook method to determine whether this template + element will be rendered at all. By default, returns False + (inhibiting rendering) if the datum is None. + + :param datum: The datum associated with this template element. + """ + + # Don't render if datum is None + return datum is not None + + def _text_get(self): + """Template element text. + + Either None or a callable taking an object and optional + boolean do_raise indicator and returning the datum bound to + the text of the template element. + """ + + return self._text + + def _text_set(self, value): + # Convert value to a selector + if value is not None and not callable(value): + value = Selector(value) + + self._text = value + + def _text_del(self): + self._text = None + + text = property(_text_get, _text_set, _text_del) + + def tree(self): + """Return string representation of the template tree. + + Returns a representation of the template rooted at this + element as a string, suitable for inclusion in debug logs. + """ + + # Build the inner contents of the tag... + contents = [self.tag, '!selector=%r' % self.selector] + + # Add the text... 
+ if self.text is not None: + contents.append('!text=%r' % self.text) + + # Add all the other attributes + for key, value in self.attrib.items(): + contents.append('%s=%r' % (key, value)) + + # If there are no children, return it as a closed tag + if len(self) == 0: + return '<%s/>' % ' '.join([str(i) for i in contents]) + + # OK, recurse to our children + children = [c.tree() for c in self] + + # Return the result + return ('<%s>%s' % + (' '.join(contents), ''.join(children), self.tag)) + + +def SubTemplateElement(parent, tag, attrib=None, selector=None, + subselector=None, **extra): + """Create a template element as a child of another. + + Corresponds to the etree.SubElement interface. Parameters are as + for TemplateElement, with the addition of the parent. + """ + + # Convert attributes + attrib = attrib or {} + attrib.update(extra) + + # Get a TemplateElement + elem = TemplateElement(tag, attrib=attrib, selector=selector, + subselector=subselector) + + # Append the parent safely + if parent is not None: + parent.append(elem) + + return elem + + +class Template(object): + """Represent a template.""" + + def __init__(self, root, nsmap=None): + """Initialize a template. + + :param root: The root element of the template. + :param nsmap: An optional namespace dictionary to be + associated with the root element of the + template. + """ + + self.root = root.unwrap() if root is not None else None + self.nsmap = nsmap or {} + self.serialize_options = dict(encoding='UTF-8', xml_declaration=True) + + def _serialize(self, parent, obj, siblings, nsmap=None): + """Internal serialization. + + Recursive routine to build a tree of etree.Element instances + from an object based on the template. Returns the first + etree.Element instance rendered, or None. + + :param parent: The parent etree.Element instance. Can be + None. + :param obj: The object to render. + :param siblings: The TemplateElement instances against which + to render the object. + :param nsmap: An optional namespace dictionary to be + associated with the etree.Element instance + rendered. + """ + + # First step, render the element + elems = siblings[0].render(parent, obj, siblings[1:], nsmap) + + # Now, recurse to all child elements + seen = set() + for idx, sibling in enumerate(siblings): + for child in sibling: + # Have we handled this child already? + if child.tag in seen: + continue + seen.add(child.tag) + + # Determine the child's siblings + nieces = [child] + for sib in siblings[idx + 1:]: + if child.tag in sib: + nieces.append(sib[child.tag]) + + # Now we recurse for every data element + for elem, datum in elems: + self._serialize(elem, datum, nieces) + + # Return the first element; at the top level, this will be the + # root element + if elems: + return elems[0][0] + + def serialize(self, obj, *args, **kwargs): + """Serialize an object. + + Serializes an object against the template. Returns a string + with the serialized XML. Positional and keyword arguments are + passed to etree.tostring(). + + :param obj: The object to serialize. + """ + + elem = self.make_tree(obj) + if elem is None: + return '' + + for k, v in self.serialize_options.items(): + kwargs.setdefault(k, v) + + # Serialize it into XML + return etree.tostring(elem, *args, **kwargs) + + def make_tree(self, obj): + """Create a tree. + + Serializes an object against the template. Returns an Element + node with appropriate children. + + :param obj: The object to serialize. 
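For example, a minimal template serializes a nested dict like this (illustrative data)::

    root = TemplateElement('limits', selector='limits')
    rate = SubTemplateElement(root, 'rate', selector='rate')
    rate.set('verb')
    Template(root).serialize({'limits': {'rate': [{'verb': 'GET'}]}})
    # -> an XML document with a <limits> root and one <rate verb="GET"/> child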
+ """ + + # If the template is empty, return the empty string + if self.root is None: + return None + + # Get the siblings and nsmap of the root element + siblings = self._siblings() + nsmap = self._nsmap() + + # Form the element tree + return self._serialize(None, obj, siblings, nsmap) + + def _siblings(self): + """Hook method for computing root siblings. + + An overridable hook method to return the siblings of the root + element. By default, this is the root element itself. + """ + + return [self.root] + + def _nsmap(self): + """Hook method for computing the namespace dictionary. + + An overridable hook method to return the namespace dictionary. + """ + + return self.nsmap.copy() + + def unwrap(self): + """Unwraps a template to return a template element.""" + + # Return the root element + return self.root + + def wrap(self): + """Wraps a template element to return a template.""" + + # We are a template + return self + + def apply(self, master): + """Hook method for determining slave applicability. + + An overridable hook method used to determine if this template + is applicable as a slave to a given master template. + + :param master: The master template to test. + """ + + return True + + def tree(self): + """Return string representation of the template tree. + + Returns a representation of the template as a string, suitable + for inclusion in debug logs. + """ + + return "%r: %s" % (self, self.root.tree()) + + +class MasterTemplate(Template): + """Represent a master template. + + Master templates are versioned derivatives of templates that + additionally allow slave templates to be attached. Slave + templates allow modification of the serialized result without + directly changing the master. + """ + + def __init__(self, root, version, nsmap=None): + """Initialize a master template. + + :param root: The root element of the template. + :param version: The version number of the template. + :param nsmap: An optional namespace dictionary to be + associated with the root element of the + template. + """ + + super(MasterTemplate, self).__init__(root, nsmap) + self.version = version + self.slaves = [] + + def __repr__(self): + """Return string representation of the template.""" + + return ("<%s.%s object version %s at %#x>" % + (self.__class__.__module__, self.__class__.__name__, + self.version, id(self))) + + def _siblings(self): + """Hook method for computing root siblings. + + An overridable hook method to return the siblings of the root + element. This is the root element plus the root elements of + all the slave templates. + """ + + return [self.root] + [slave.root for slave in self.slaves] + + def _nsmap(self): + """Hook method for computing the namespace dictionary. + + An overridable hook method to return the namespace dictionary. + The namespace dictionary is computed by taking the master + template's namespace dictionary and updating it from all the + slave templates. + """ + + nsmap = self.nsmap.copy() + for slave in self.slaves: + nsmap.update(slave._nsmap()) + return nsmap + + def attach(self, *slaves): + """Attach one or more slave templates. + + Attaches one or more slave templates to the master template. + Slave templates must have a root element with the same tag as + the master template. The slave template's apply() method will + be called to determine if the slave should be applied to this + master; if it returns False, that slave will be skipped. + (This allows filtering of slaves based on the version of the + master template.) 
+ """ + + slave_list = [] + for slave in slaves: + slave = slave.wrap() + + # Make sure we have a tree match + if slave.root.tag != self.root.tag: + slavetag = slave.root.tag + mastertag = self.root.tag + msg = _("Template tree mismatch; adding slave %(slavetag)s " + "to master %(mastertag)s") % locals() + raise ValueError(msg) + + # Make sure slave applies to this template + if not slave.apply(self): + continue + + slave_list.append(slave) + + # Add the slaves + self.slaves.extend(slave_list) + + def copy(self): + """Return a copy of this master template.""" + + # Return a copy of the MasterTemplate + tmp = self.__class__(self.root, self.version, self.nsmap) + tmp.slaves = self.slaves[:] + return tmp + + +class SlaveTemplate(Template): + """Represent a slave template. + + Slave templates are versioned derivatives of templates. Each + slave has a minimum version and optional maximum version of the + master template to which they can be attached. + """ + + def __init__(self, root, min_vers, max_vers=None, nsmap=None): + """Initialize a slave template. + + :param root: The root element of the template. + :param min_vers: The minimum permissible version of the master + template for this slave template to apply. + :param max_vers: An optional upper bound for the master + template version. + :param nsmap: An optional namespace dictionary to be + associated with the root element of the + template. + """ + + super(SlaveTemplate, self).__init__(root, nsmap) + self.min_vers = min_vers + self.max_vers = max_vers + + def __repr__(self): + """Return string representation of the template.""" + + return ("<%s.%s object versions %s-%s at %#x>" % + (self.__class__.__module__, self.__class__.__name__, + self.min_vers, self.max_vers, id(self))) + + def apply(self, master): + """Hook method for determining slave applicability. + + An overridable hook method used to determine if this template + is applicable as a slave to a given master template. This + version requires the master template to have a version number + between min_vers and max_vers. + + :param master: The master template to test. + """ + + # Does the master meet our minimum version requirement? + if master.version < self.min_vers: + return False + + # How about our maximum version requirement? + if self.max_vers is not None and master.version > self.max_vers: + return False + + return True + + +class TemplateBuilder(object): + """Template builder. + + This class exists to allow templates to be lazily built without + having to build them each time they are needed. It must be + subclassed, and the subclass must implement the construct() + method, which must return a Template (or subclass) instance. The + constructor will always return the template returned by + construct(), or, if it has a copy() method, a copy of that + template. + """ + + _tmpl = None + + def __new__(cls, copy=True): + """Construct and return a template. + + :param copy: If True (the default), a copy of the template + will be constructed and returned, if possible. + """ + + # Do we need to construct the template? + if cls._tmpl is None: + tmp = super(TemplateBuilder, cls).__new__(cls) + + # Construct the template + cls._tmpl = tmp.construct() + + # If the template has a copy attribute, return the result of + # calling it + if copy and hasattr(cls._tmpl, 'copy'): + return cls._tmpl.copy() + + # Return the template + return cls._tmpl + + def construct(self): + """Construct a template. + + Called to construct a template instance, which it must return. + Only called once. 
+ """ + + raise NotImplementedError(_("subclasses must implement construct()!")) + + +def make_links(parent, selector=None): + """ + Attach an Atom element to the parent. + """ + + elem = SubTemplateElement(parent, '{%s}link' % XMLNS_ATOM, + selector=selector) + elem.set('rel') + elem.set('type') + elem.set('href') + + # Just for completeness... + return elem + + +def make_flat_dict(name, selector=None, subselector=None, ns=None): + """ + Utility for simple XML templates that traditionally used + XMLDictSerializer with no metadata. Returns a template element + where the top-level element has the given tag name, and where + sub-elements have tag names derived from the object's keys and + text derived from the object's values. This only works for flat + dictionary objects, not dictionaries containing nested lists or + dictionaries. + """ + + # Set up the names we need... + if ns is None: + elemname = name + tagname = Selector(0) + else: + elemname = '{%s}%s' % (ns, name) + tagname = lambda obj, do_raise=False: '{%s}%s' % (ns, obj[0]) + + if selector is None: + selector = name + + # Build the root element + root = TemplateElement(elemname, selector=selector, + subselector=subselector) + + # Build an element to represent all the keys and values + elem = SubTemplateElement(root, tagname, selector=get_items) + elem.text = 1 + + # Return the template + return root diff --git a/reddwarf/limits/__init__.py b/reddwarf/limits/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/reddwarf/limits/service.py b/reddwarf/limits/service.py new file mode 100644 index 0000000000..3bb8a7b3e8 --- /dev/null +++ b/reddwarf/limits/service.py @@ -0,0 +1,49 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2013 OpenStack LLC +# Copyright 2013 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from reddwarf.common import wsgi as base_wsgi +from reddwarf.common.limits import LimitsTemplate +from reddwarf.limits import views +from reddwarf.openstack.common import wsgi + + +class LimitsController(base_wsgi.Controller): + """ + Controller for accessing limits in the OpenStack API. + Note: this is a little different than how other controllers are implemented + """ + + @base_wsgi.serializers(xml=LimitsTemplate) + def index(self, req, tenant_id): + """ + Return all global and rate limit information. 
+ """ + context = req.environ[base_wsgi.CONTEXT_KEY] + + # + # TODO: hook this in later + #quotas = QUOTAS.get_project_quotas(context, context.project_id, + # usages=False) + #abs_limits = dict((k, v['limit']) for k, v in quotas.items()) + abs_limits = {} + rate_limits = req.environ.get("reddwarf.limits", []) + + builder = self._get_view_builder(req) + return builder.build(rate_limits, abs_limits) + + def _get_view_builder(self, req): + return views.ViewBuilder() diff --git a/reddwarf/limits/views.py b/reddwarf/limits/views.py new file mode 100644 index 0000000000..6158cc2bb7 --- /dev/null +++ b/reddwarf/limits/views.py @@ -0,0 +1,98 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010-2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import datetime + +from reddwarf.openstack.common import timeutils + + +class ViewBuilder(object): + """OpenStack API base limits view builder.""" + + def build(self, rate_limits, absolute_limits): + rate_limits = self._build_rate_limits(rate_limits) + absolute_limits = self._build_absolute_limits(absolute_limits) + + output = { + "limits": { + "rate": rate_limits, + "absolute": absolute_limits, + }, + } + + return output + + def _build_absolute_limits(self, absolute_limits): + """Builder for absolute limits + + absolute_limits should be given as a dict of limits. + For example: {"ram": 512, "gigabytes": 1024}. 
+ + """ + limit_names = { + "ram": ["maxTotalRAMSize"], + "instances": ["maxTotalInstances"], + "cores": ["maxTotalCores"], + "metadata_items": ["maxServerMeta", "maxImageMeta"], + "injected_files": ["maxPersonality"], + "injected_file_content_bytes": ["maxPersonalitySize"], + "security_groups": ["maxSecurityGroups"], + "security_group_rules": ["maxSecurityGroupRules"], + } + limits = {} + for name, value in absolute_limits.iteritems(): + if name in limit_names and value is not None: + for name in limit_names[name]: + limits[name] = value + return limits + + def _build_rate_limits(self, rate_limits): + limits = [] + for rate_limit in rate_limits: + _rate_limit_key = None + _rate_limit = self._build_rate_limit(rate_limit) + + # check for existing key + for limit in limits: + if (limit["uri"] == rate_limit["URI"] and + limit["regex"] == rate_limit["regex"]): + _rate_limit_key = limit + break + + # ensure we have a key if we didn't find one + if not _rate_limit_key: + _rate_limit_key = { + "uri": rate_limit["URI"], + "regex": rate_limit["regex"], + "limit": [], + } + limits.append(_rate_limit_key) + + _rate_limit_key["limit"].append(_rate_limit) + + return limits + + def _build_rate_limit(self, rate_limit): + _get_utc = datetime.datetime.utcfromtimestamp + next_avail = _get_utc(rate_limit["resetTime"]) + return { + "verb": rate_limit["verb"], + "value": rate_limit["value"], + "remaining": int(rate_limit["remaining"]), + "unit": rate_limit["unit"], + "next-available": timeutils.isotime(at=next_avail), + } diff --git a/reddwarf/tests/api/limits.py b/reddwarf/tests/api/limits.py new file mode 100644 index 0000000000..c22fa26ff7 --- /dev/null +++ b/reddwarf/tests/api/limits.py @@ -0,0 +1,109 @@ +from nose.tools import assert_equal +from nose.tools import assert_false +from nose.tools import assert_true + +from proboscis import before_class +from proboscis import test + +from reddwarf.openstack.common import timeutils +from reddwarf.tests.util import create_dbaas_client +from reddwarf.tests.util import test_config +from reddwarfclient import exceptions + +from datetime import datetime + +GROUP = "dbaas.api.limits" +DEFAULT_RATE = 200 +# Note: This should not be enabled until rd-client merges +RD_CLIENT_OK = False + + +@test(groups=[GROUP]) +class Limits(object): + + @before_class + def setUp(self): + rate_user = self._get_user('rate_limit') + self.rd_client = create_dbaas_client(rate_user) + + def _get_user(self, name): + return test_config.users.find_user_by_name(name) + + def _get_next_available(self, resource): + return resource.__dict__['next-available'] + + def __is_available(self, next_available): + dt_next = timeutils.parse_isotime(next_available) + dt_now = datetime.now() + return dt_next.time() < dt_now.time() + + @test(enabled=RD_CLIENT_OK) + def test_limits_index(self): + """test_limits_index""" + r1, r2, r3, r4 = self.rd_client.limits.index() + + assert_equal(r1.verb, "POST") + assert_equal(r1.unit, "MINUTE") + assert_true(r1.remaining <= DEFAULT_RATE) + + next_available = self._get_next_available(r1) + assert_true(next_available is not None) + + assert_equal(r2.verb, "PUT") + assert_equal(r2.unit, "MINUTE") + assert_true(r2.remaining <= DEFAULT_RATE) + + next_available = self._get_next_available(r2) + assert_true(next_available is not None) + + assert_equal(r3.verb, "DELETE") + assert_equal(r3.unit, "MINUTE") + assert_true(r3.remaining <= DEFAULT_RATE) + + next_available = self._get_next_available(r3) + assert_true(next_available is not None) + + assert_equal(r4.verb, "GET") + 
assert_equal(r4.unit, "MINUTE") + assert_true(r4.remaining <= DEFAULT_RATE) + + next_available = self._get_next_available(r4) + assert_true(next_available is not None) + + @test(enabled=RD_CLIENT_OK) + def test_limits_get_remaining(self): + """test_limits_get_remaining""" + gets = None + for i in xrange(5): + r1, r2, r3, r4 = self.rd_client.limits.index() + gets = r4 + + assert_equal(gets.verb, "GET") + assert_equal(gets.unit, "MINUTE") + assert_true(gets.remaining <= DEFAULT_RATE - 5) + + next_available = self._get_next_available(gets) + assert_true(next_available is not None) + + @test(enabled=RD_CLIENT_OK) + def test_limits_exception(self): + """test_limits_exception""" + + # use a different user to avoid throttling tests run out of order + rate_user_exceeded = self._get_user('rate_limit_exceeded') + rd_client = create_dbaas_client(rate_user_exceeded) + + gets = None + encountered = False + for i in xrange(DEFAULT_RATE + 50): + try: + r1, r2, r3, r4 = rd_client.limits.index() + gets = r4 + assert_equal(gets.verb, "GET") + assert_equal(gets.unit, "MINUTE") + + except exceptions.OverLimit: + encountered = True + + assert_true(encountered) + assert_true(gets.remaining <= 50) diff --git a/reddwarf/tests/unittests/api/common/__init__.py b/reddwarf/tests/unittests/api/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/reddwarf/tests/unittests/api/common/test_limits.py b/reddwarf/tests/unittests/api/common/test_limits.py new file mode 100644 index 0000000000..a7b19c7482 --- /dev/null +++ b/reddwarf/tests/unittests/api/common/test_limits.py @@ -0,0 +1,741 @@ +# Copyright 2011 OpenStack LLC. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +""" +Tests dealing with HTTP rate-limiting. +""" + +import httplib +import StringIO +from xml.dom import minidom +from lxml import etree +import testtools +import webob + +from mockito import when + +from reddwarf.common import limits +from reddwarf.common import xmlutil +from reddwarf.common.limits import Limit +from reddwarf.limits import views +from reddwarf.openstack.common import jsonutils + +from reddwarf.tests.unittests.util.matchers import DictMatches + +TEST_LIMITS = [ + Limit("GET", "/delayed", "^/delayed", 1, limits.PER_MINUTE), + Limit("POST", "*", ".*", 7, limits.PER_MINUTE), + Limit("POST", "/servers", "^/servers", 3, limits.PER_MINUTE), + Limit("PUT", "*", "", 10, limits.PER_MINUTE), + Limit("PUT", "/servers", "^/servers", 5, limits.PER_MINUTE), +] +NS = { + 'atom': 'http://www.w3.org/2005/Atom', + 'ns': 'http://docs.openstack.org/common/api/v1.0' +} + + +class BaseLimitTestSuite(testtools.TestCase): + """Base test suite which provides relevant stubs and time abstraction.""" + + def setUp(self): + super(BaseLimitTestSuite, self).setUp() + + self.absolute_limits = {} + + +class LimitsControllerTest(BaseLimitTestSuite): + """ + Tests for `limits.LimitsController` class. 
+ TODO: add test cases once absolute limits are integrated + """ + pass + + +class TestLimiter(limits.Limiter): + pass + + +class LimitMiddlewareTest(BaseLimitTestSuite): + """ + Tests for the `limits.RateLimitingMiddleware` class. + """ + + @webob.dec.wsgify + def _empty_app(self, request): + """Do-nothing WSGI app.""" + pass + + def setUp(self): + """Prepare middleware for use through fake WSGI app.""" + super(LimitMiddlewareTest, self).setUp() + _limits = '(GET, *, .*, 1, MINUTE)' + self.app = limits.RateLimitingMiddleware(self._empty_app, _limits, + "%s.TestLimiter" % + self.__class__.__module__) + + def test_limit_class(self): + # Test that middleware selected correct limiter class. + assert isinstance(self.app._limiter, TestLimiter) + + def test_good_request(self): + # Test successful GET request through middleware. + request = webob.Request.blank("/") + response = request.get_response(self.app) + self.assertEqual(200, response.status_int) + + def test_limited_request_json(self): + # Test a rate-limited (413) GET request through middleware. + request = webob.Request.blank("/") + response = request.get_response(self.app) + self.assertEqual(200, response.status_int) + + request = webob.Request.blank("/") + response = request.get_response(self.app) + self.assertEqual(response.status_int, 413) + + self.assertTrue('Retry-After' in response.headers) + retry_after = int(response.headers['Retry-After']) + self.assertAlmostEqual(retry_after, 60, 1) + + body = jsonutils.loads(response.body) + expected = "Only 1 GET request(s) can be made to * every minute." + value = body["overLimit"]["details"].strip() + self.assertEqual(value, expected) + + self.assertTrue("retryAfter" in body["overLimit"]) + retryAfter = body["overLimit"]["retryAfter"] + self.assertEqual(retryAfter, "60") + + def test_limited_request_xml(self): + # Test a rate-limited (413) response as XML. + request = webob.Request.blank("/") + response = request.get_response(self.app) + self.assertEqual(200, response.status_int) + + request = webob.Request.blank("/") + request.accept = "application/xml" + response = request.get_response(self.app) + self.assertEqual(response.status_int, 413) + + root = minidom.parseString(response.body).childNodes[0] + expected = "Only 1 GET request(s) can be made to * every minute." + + self.assertNotEqual(root.attributes.getNamedItem("retryAfter"), None) + retryAfter = root.attributes.getNamedItem("retryAfter").value + self.assertEqual(retryAfter, "60") + + details = root.getElementsByTagName("details") + self.assertEqual(details.length, 1) + + value = details.item(0).firstChild.data.strip() + self.assertEqual(value, expected) + + +class LimitTest(BaseLimitTestSuite): + """ + Tests for the `limits.Limit` class. + """ + + def test_GET_no_delay(self): + # Test a limit handles 1 GET per second. + limit = Limit("GET", "*", ".*", 1, 1) + when(limit)._get_time().thenReturn(0.0) + delay = limit("GET", "/anything") + self.assertEqual(None, delay) + self.assertEqual(0, limit.next_request) + self.assertEqual(0, limit.last_request) + + def test_GET_delay(self): + # Test two calls to 1 GET per second limit. 
+ limit = Limit("GET", "*", ".*", 1, 1) + when(limit)._get_time().thenReturn(0.0) + + delay = limit("GET", "/anything") + self.assertEqual(None, delay) + + delay = limit("GET", "/anything") + self.assertEqual(1, delay) + self.assertEqual(1, limit.next_request) + self.assertEqual(0, limit.last_request) + + when(limit)._get_time().thenReturn(4.0) + + delay = limit("GET", "/anything") + self.assertEqual(None, delay) + self.assertEqual(4, limit.next_request) + self.assertEqual(4, limit.last_request) + + +class ParseLimitsTest(BaseLimitTestSuite): + """ + Tests for the default limits parser in the in-memory + `limits.Limiter` class. + """ + + def test_invalid(self): + # Test that parse_limits() handles invalid input correctly. + self.assertRaises(ValueError, limits.Limiter.parse_limits, + ';;;;;') + + def test_bad_rule(self): + # Test that parse_limits() handles bad rules correctly. + self.assertRaises(ValueError, limits.Limiter.parse_limits, + 'GET, *, .*, 20, minute') + + def test_missing_arg(self): + # Test that parse_limits() handles missing args correctly. + self.assertRaises(ValueError, limits.Limiter.parse_limits, + '(GET, *, .*, 20)') + + def test_bad_value(self): + # Test that parse_limits() handles bad values correctly. + self.assertRaises(ValueError, limits.Limiter.parse_limits, + '(GET, *, .*, foo, minute)') + + def test_bad_unit(self): + # Test that parse_limits() handles bad units correctly. + self.assertRaises(ValueError, limits.Limiter.parse_limits, + '(GET, *, .*, 20, lightyears)') + + def test_multiple_rules(self): + # Test that parse_limits() handles multiple rules correctly. + try: + l = limits.Limiter.parse_limits('(get, *, .*, 20, minute);' + '(PUT, /foo*, /foo.*, 10, hour);' + '(POST, /bar*, /bar.*, 5, second);' + '(Say, /derp*, /derp.*, 1, day)') + except ValueError, e: + assert False, str(e) + + # Make sure the number of returned limits are correct + self.assertEqual(len(l), 4) + + # Check all the verbs... + expected = ['GET', 'PUT', 'POST', 'SAY'] + self.assertEqual([t.verb for t in l], expected) + + # ...the URIs... + expected = ['*', '/foo*', '/bar*', '/derp*'] + self.assertEqual([t.uri for t in l], expected) + + # ...the regexes... + expected = ['.*', '/foo.*', '/bar.*', '/derp.*'] + self.assertEqual([t.regex for t in l], expected) + + # ...the values... + expected = [20, 10, 5, 1] + self.assertEqual([t.value for t in l], expected) + + # ...and the units... + expected = [limits.PER_MINUTE, limits.PER_HOUR, + limits.PER_SECOND, limits.PER_DAY] + self.assertEqual([t.unit for t in l], expected) + + +class LimiterTest(BaseLimitTestSuite): + """ + Tests for the in-memory `limits.Limiter` class. + """ + + def update_limits(self, delay): + for l in TEST_LIMITS: + when(l)._get_time().thenReturn(delay) + + def setUp(self): + """Run before each test.""" + super(LimiterTest, self).setUp() + userlimits = {'user:user3': ''} + + self.update_limits(0.0) + + self.limiter = limits.Limiter(TEST_LIMITS, **userlimits) + + def _check(self, num, verb, url, username=None): + """Check and yield results from checks.""" + for x in xrange(num): + yield self.limiter.check_for_delay(verb, url, username)[0] + + def _check_sum(self, num, verb, url, username=None): + """Check and sum results from checks.""" + results = self._check(num, verb, url, username) + return sum(item for item in results if item) + + def test_no_delay_GET(self): + """ + Simple test to ensure no delay on a single call for a limit verb we + didn"t set. 
+ """ + delay = self.limiter.check_for_delay("GET", "/anything") + self.assertEqual(delay, (None, None)) + + def test_no_delay_PUT(self): + # Simple test to ensure no delay on a single call for a known limit. + delay = self.limiter.check_for_delay("PUT", "/anything") + self.assertEqual(delay, (None, None)) + + def test_delay_PUT(self): + """ + Ensure the 11th PUT will result in a delay of 6.0 seconds until + the next request will be granced. + """ + expected = [None] * 10 + [6.0] + results = list(self._check(11, "PUT", "/anything")) + + self.assertEqual(expected, results) + + def test_delay_POST(self): + """ + Ensure the 8th POST will result in a delay of 6.0 seconds until + the next request will be granced. + """ + expected = [None] * 7 + results = list(self._check(7, "POST", "/anything")) + self.assertEqual(expected, results) + + expected = 60.0 / 7.0 + results = self._check_sum(1, "POST", "/anything") + self.failUnlessAlmostEqual(expected, results, 8) + + def test_delay_GET(self): + # Ensure the 11th GET will result in NO delay. + expected = [None] * 11 + results = list(self._check(11, "GET", "/anything")) + + self.assertEqual(expected, results) + + def test_delay_PUT_servers(self): + """ + Ensure PUT on /servers limits at 5 requests, and PUT elsewhere is still + OK after 5 requests...but then after 11 total requests, PUT limiting + kicks in. + """ + # First 6 requests on PUT /servers + expected = [None] * 5 + [12.0] + results = list(self._check(6, "PUT", "/servers")) + self.assertEqual(expected, results) + + # Next 5 request on PUT /anything + expected = [None] * 4 + [6.0] + results = list(self._check(5, "PUT", "/anything")) + self.assertEqual(expected, results) + + def test_delay_PUT_wait(self): + """ + Ensure after hitting the limit and then waiting for the correct + amount of time, the limit will be lifted. + """ + expected = [None] * 10 + [6.0] + results = list(self._check(11, "PUT", "/anything")) + self.assertEqual(expected, results) + + # Advance time + self.update_limits(6.0) + + expected = [None, 6.0] + results = list(self._check(2, "PUT", "/anything")) + self.assertEqual(expected, results) + + def test_multiple_delays(self): + # Ensure multiple requests still get a delay. + expected = [None] * 10 + [6.0] * 10 + results = list(self._check(20, "PUT", "/anything")) + self.assertEqual(expected, results) + + self.update_limits(1.0) + + expected = [5.0] * 10 + results = list(self._check(10, "PUT", "/anything")) + self.assertEqual(expected, results) + + def test_user_limit(self): + # Test user-specific limits. + self.assertEqual(self.limiter.levels['user3'], []) + + def test_multiple_users(self): + # Tests involving multiple users. 
+ # User1 + self.update_limits(0.0) + expected = [None] * 10 + [6.0] * 10 + results = list(self._check(20, "PUT", "/anything", "user1")) + self.assertEqual(expected, results) + + # User2 + expected = [None] * 10 + [6.0] * 5 + results = list(self._check(15, "PUT", "/anything", "user2")) + self.assertEqual(expected, results) + + # User3 + expected = [None] * 20 + results = list(self._check(20, "PUT", "/anything", "user3")) + self.assertEqual(expected, results) + + self.update_limits(1.0) + # User1 again + expected = [5.0] * 10 + results = list(self._check(10, "PUT", "/anything", "user1")) + self.assertEqual(expected, results) + + self.update_limits(2.0) + + # User1 again + expected = [4.0] * 5 + results = list(self._check(5, "PUT", "/anything", "user2")) + self.assertEqual(expected, results) + + +class WsgiLimiterTest(BaseLimitTestSuite): + """ + Tests for `limits.WsgiLimiter` class. + """ + + def setUp(self): + """Run before each test.""" + super(WsgiLimiterTest, self).setUp() + self.app = limits.WsgiLimiter(TEST_LIMITS) + + def _request_data(self, verb, path): + """Get data describing a limit request verb/path.""" + return jsonutils.dumps({"verb": verb, "path": path}) + + def _request(self, verb, url, username=None): + """Make sure that POSTing to the given url causes the given username + to perform the given action. Make the internal rate limiter return + delay and make sure that the WSGI app returns the correct response. + """ + if username: + request = webob.Request.blank("/%s" % username) + else: + request = webob.Request.blank("/") + + request.method = "POST" + request.body = self._request_data(verb, url) + response = request.get_response(self.app) + + if "X-Wait-Seconds" in response.headers: + self.assertEqual(response.status_int, 403) + return response.headers["X-Wait-Seconds"] + + self.assertEqual(response.status_int, 204) + + def test_invalid_methods(self): + # Only POSTs should work. + requests = [] + for method in ["GET", "PUT", "DELETE", "HEAD", "OPTIONS"]: + request = webob.Request.blank("/", method=method) + response = request.get_response(self.app) + self.assertEqual(response.status_int, 405) + + def test_good_url(self): + delay = self._request("GET", "/something") + self.assertEqual(delay, None) + + def test_escaping(self): + delay = self._request("GET", "/something/jump%20up") + self.assertEqual(delay, None) + + def test_response_to_delays(self): + delay = self._request("GET", "/delayed") + self.assertEqual(delay, None) + + delay = self._request("GET", "/delayed") + self.assertEqual(delay, '60.00') + + def test_response_to_delays_usernames(self): + delay = self._request("GET", "/delayed", "user1") + self.assertEqual(delay, None) + + delay = self._request("GET", "/delayed", "user2") + self.assertEqual(delay, None) + + delay = self._request("GET", "/delayed", "user1") + self.assertEqual(delay, '60.00') + + delay = self._request("GET", "/delayed", "user2") + self.assertEqual(delay, '60.00') + + +class FakeHttplibSocket(object): + """ + Fake `httplib.HTTPResponse` replacement. + """ + + def __init__(self, response_string): + """Initialize new `FakeHttplibSocket`.""" + self._buffer = StringIO.StringIO(response_string) + + def makefile(self, _mode, _other): + """Returns the socket's internal buffer.""" + return self._buffer + + +class FakeHttplibConnection(object): + """ + Fake `httplib.HTTPConnection`. + """ + + def __init__(self, app, host): + """ + Initialize `FakeHttplibConnection`. 
+ """ + self.app = app + self.host = host + + def request(self, method, path, body="", headers=None): + """ + Requests made via this connection actually get translated and routed + into our WSGI app, we then wait for the response and turn it back into + an `httplib.HTTPResponse`. + """ + if not headers: + headers = {} + + req = webob.Request.blank(path) + req.method = method + req.headers = headers + req.host = self.host + req.body = body + + resp = str(req.get_response(self.app)) + resp = "HTTP/1.0 %s" % resp + sock = FakeHttplibSocket(resp) + self.http_response = httplib.HTTPResponse(sock) + self.http_response.begin() + + def getresponse(self): + """Return our generated response from the request.""" + return self.http_response + + +def wire_HTTPConnection_to_WSGI(host, app): + """Monkeypatches HTTPConnection so that if you try to connect to host, you + are instead routed straight to the given WSGI app. + + After calling this method, when any code calls + + httplib.HTTPConnection(host) + + the connection object will be a fake. Its requests will be sent directly + to the given WSGI app rather than through a socket. + + Code connecting to hosts other than host will not be affected. + + This method may be called multiple times to map different hosts to + different apps. + + This method returns the original HTTPConnection object, so that the caller + can restore the default HTTPConnection interface (for all hosts). + """ + + class HTTPConnectionDecorator(object): + """Wraps the real HTTPConnection class so that when you instantiate + the class you might instead get a fake instance.""" + + def __init__(self, wrapped): + self.wrapped = wrapped + + def __call__(self, connection_host, *args, **kwargs): + if connection_host == host: + return FakeHttplibConnection(app, host) + else: + return self.wrapped(connection_host, *args, **kwargs) + + oldHTTPConnection = httplib.HTTPConnection + httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection) + return oldHTTPConnection + + +class WsgiLimiterProxyTest(BaseLimitTestSuite): + """ + Tests for the `limits.WsgiLimiterProxy` class. + """ + + def setUp(self): + """ + Do some nifty HTTP/WSGI magic which allows for WSGI to be called + directly by something like the `httplib` library. + """ + super(WsgiLimiterProxyTest, self).setUp() + self.app = limits.WsgiLimiter(TEST_LIMITS) + self.oldHTTPConnection = ( + wire_HTTPConnection_to_WSGI("169.254.0.1:80", self.app)) + self.proxy = limits.WsgiLimiterProxy("169.254.0.1:80") + + def test_200(self): + # Successful request test. + delay = self.proxy.check_for_delay("GET", "/anything") + self.assertEqual(delay, (None, None)) + + def test_403(self): + # Forbidden request test. 
+        delay = self.proxy.check_for_delay("GET", "/delayed")
+        self.assertEqual(delay, (None, None))
+
+        delay, error = self.proxy.check_for_delay("GET", "/delayed")
+        error = error.strip()
+
+        expected = ("60.00", "403 Forbidden\n\nOnly 1 GET request(s) can be "
+                    "made to /delayed every minute.")
+
+        self.assertEqual((delay, error), expected)
+
+    def tearDown(self):
+        # restore original HTTPConnection object
+        httplib.HTTPConnection = self.oldHTTPConnection
+        super(WsgiLimiterProxyTest, self).tearDown()
+
+
+class LimitsViewBuilderTest(testtools.TestCase):
+    def setUp(self):
+        super(LimitsViewBuilderTest, self).setUp()
+        self.view_builder = views.ViewBuilder()
+        self.rate_limits = [{"URI": "*",
+                             "regex": ".*",
+                             "value": 10,
+                             "verb": "POST",
+                             "remaining": 2,
+                             "unit": "MINUTE",
+                             "resetTime": 1311272226},
+                            {"URI": "*/servers",
+                             "regex": "^/servers",
+                             "value": 50,
+                             "verb": "POST",
+                             "remaining": 10,
+                             "unit": "DAY",
+                             "resetTime": 1311272226}]
+        self.absolute_limits = {"metadata_items": 1,
+                                "injected_files": 5,
+                                "injected_file_content_bytes": 5}
+
+    def test_build_limits(self):
+        expected_limits = {"limits": {
+            "rate": [{"uri": "*",
+                      "regex": ".*",
+                      "limit": [{"value": 10,
+                                 "verb": "POST",
+                                 "remaining": 2,
+                                 "unit": "MINUTE",
+                                 "next-available": "2011-07-21T18:17:06Z"}]},
+                     {"uri": "*/servers",
+                      "regex": "^/servers",
+                      "limit": [{"value": 50,
+                                 "verb": "POST",
+                                 "remaining": 10,
+                                 "unit": "DAY",
+                                 "next-available": "2011-07-21T18:17:06Z"}]}],
+            "absolute": {
+                "maxServerMeta": 1,
+                "maxImageMeta": 1,
+                "maxPersonality": 5,
+                "maxPersonalitySize": 5}}}
+
+        output = self.view_builder.build(self.rate_limits,
+                                         self.absolute_limits)
+        self.assertThat(output, DictMatches(expected_limits))
+
+    def test_build_limits_empty_limits(self):
+        expected_limits = {"limits": {"rate": [],
+                                      "absolute": {}}}
+
+        abs_limits = {}
+        rate_limits = []
+        output = self.view_builder.build(rate_limits, abs_limits)
+        self.assertThat(output, DictMatches(expected_limits))
+
+
+class LimitsXMLSerializationTest(testtools.TestCase):
+    def test_xml_declaration(self):
+        serializer = limits.LimitsTemplate()
+
+        fixture = {"limits": {
+                   "rate": [],
+                   "absolute": {}}}
+
+        output = serializer.serialize(fixture)
+        has_dec = output.startswith("<?xml version='1.0' encoding='UTF-8'?>")
+        self.assertTrue(has_dec)
+
+    def test_index(self):
+        serializer = limits.LimitsTemplate()
+        fixture = {
+            "limits": {
+                "rate": [{"uri": "*",
+                          "regex": ".*",
+                          "limit": [
+                              {"value": 10,
+                               "verb": "POST",
+                               "remaining": 2,
+                               "unit": "MINUTE",
+                               "next-available": "2011-12-15T22:42:45Z"}]},
+                         {"uri": "*/servers",
+                          "regex": "^/servers",
+                          "limit": [
+                              {"value": 50,
+                               "verb": "POST",
+                               "remaining": 10,
+                               "unit": "DAY",
+                               "next-available": "2011-12-15T22:42:45Z"}]}],
+                "absolute": {
+                    "maxServerMeta": 1,
+                    "maxImageMeta": 1,
+                    "maxPersonality": 5,
+                    "maxPersonalitySize": 10240}}}
+
+        output = serializer.serialize(fixture)
+        root = etree.XML(output)
+        xmlutil.validate_schema(root, 'limits')
+
+        #verify absolute limits
+        absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
+        self.assertEqual(len(absolutes), 4)
+        for limit in absolutes:
+            name = limit.get('name')
+            value = limit.get('value')
+            self.assertEqual(value, str(fixture['limits']['absolute'][name]))
+
+        #verify rate limits
+        rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
+        self.assertEqual(len(rates), 2)
+        for i, rate in enumerate(rates):
+            for key in ['uri', 'regex']:
+                self.assertEqual(rate.get(key),
+                                 str(fixture['limits']['rate'][i][key]))
+            rate_limits = rate.xpath('ns:limit', namespaces=NS)
+
self.assertEqual(len(rate_limits), 1) + for j, limit in enumerate(rate_limits): + for key in ['verb', 'value', 'remaining', 'unit', + 'next-available']: + self.assertEqual(limit.get(key), + str(fixture['limits']['rate'][i]['limit'] + [j][key])) + + def test_index_no_limits(self): + serializer = limits.LimitsTemplate() + + fixture = {"limits": { + "rate": [], + "absolute": {}}} + + output = serializer.serialize(fixture) + root = etree.XML(output) + xmlutil.validate_schema(root, 'limits') + + #verify absolute limits + absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS) + self.assertEqual(len(absolutes), 0) + + #verify rate limits + rates = root.xpath('ns:rates/ns:rate', namespaces=NS) + self.assertEqual(len(rates), 0) diff --git a/reddwarf/tests/unittests/util/matchers.py b/reddwarf/tests/unittests/util/matchers.py new file mode 100644 index 0000000000..be65da823a --- /dev/null +++ b/reddwarf/tests/unittests/util/matchers.py @@ -0,0 +1,454 @@ +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright 2010 United States Government as represented by the +# Administrator of the National Aeronautics and Space Administration. +# Copyright 2012 Hewlett-Packard Development Company, L.P. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +"""Matcher classes to be used inside of the testtools assertThat framework.""" + +import pprint + +from lxml import etree + + +class DictKeysMismatch(object): + def __init__(self, d1only, d2only): + self.d1only = d1only + self.d2only = d2only + + def describe(self): + return ('Keys in d1 and not d2: %(d1only)s.' + ' Keys in d2 and not d1: %(d2only)s' % self.__dict__) + + def get_details(self): + return {} + + +class DictMismatch(object): + def __init__(self, key, d1_value, d2_value): + self.key = key + self.d1_value = d1_value + self.d2_value = d2_value + + def describe(self): + return ("Dictionaries do not match at %(key)s." + " d1: %(d1_value)s d2: %(d2_value)s" % self.__dict__) + + def get_details(self): + return {} + + +class DictMatches(object): + + def __init__(self, d1, approx_equal=False, tolerance=0.001): + self.d1 = d1 + self.approx_equal = approx_equal + self.tolerance = tolerance + + def __str__(self): + return 'DictMatches(%s)' % (pprint.pformat(self.d1)) + + # Useful assertions + def match(self, d2): + """Assert two dicts are equivalent. + + This is a 'deep' match in the sense that it handles nested + dictionaries appropriately. + + NOTE: + + If you don't care (or don't know) a given value, you can specify + the string DONTCARE as the value. This will cause that dict-item + to be skipped. 
+ + """ + + d1keys = set(self.d1.keys()) + d2keys = set(d2.keys()) + if d1keys != d2keys: + d1only = d1keys - d2keys + d2only = d2keys - d1keys + return DictKeysMismatch(d1only, d2only) + + for key in d1keys: + d1value = self.d1[key] + d2value = d2[key] + try: + error = abs(float(d1value) - float(d2value)) + within_tolerance = error <= self.tolerance + except (ValueError, TypeError): + # If both values aren't convertible to float, just ignore + # ValueError if arg is a str, TypeError if it's something else + # (like None) + within_tolerance = False + + if hasattr(d1value, 'keys') and hasattr(d2value, 'keys'): + matcher = DictMatches(d1value) + did_match = matcher.match(d2value) + if did_match is not None: + return did_match + elif 'DONTCARE' in (d1value, d2value): + continue + elif self.approx_equal and within_tolerance: + continue + elif d1value != d2value: + return DictMismatch(key, d1value, d2value) + + +class ListLengthMismatch(object): + def __init__(self, len1, len2): + self.len1 = len1 + self.len2 = len2 + + def describe(self): + return ('Length mismatch: len(L1)=%(len1)d != ' + 'len(L2)=%(len2)d' % self.__dict__) + + def get_details(self): + return {} + + +class DictListMatches(object): + + def __init__(self, l1, approx_equal=False, tolerance=0.001): + self.l1 = l1 + self.approx_equal = approx_equal + self.tolerance = tolerance + + def __str__(self): + return 'DictListMatches(%s)' % (pprint.pformat(self.l1)) + + # Useful assertions + def match(self, l2): + """Assert a list of dicts are equivalent.""" + + l1count = len(self.l1) + l2count = len(l2) + if l1count != l2count: + return ListLengthMismatch(l1count, l2count) + + for d1, d2 in zip(self.l1, l2): + matcher = DictMatches(d2, + approx_equal=self.approx_equal, + tolerance=self.tolerance) + did_match = matcher.match(d1) + if did_match: + return did_match + + +class SubDictMismatch(object): + def __init__(self, + key=None, + sub_value=None, + super_value=None, + keys=False): + self.key = key + self.sub_value = sub_value + self.super_value = super_value + self.keys = keys + + def describe(self): + if self.keys: + return "Keys between dictionaries did not match" + else: + return("Dictionaries do not match at %s. 
d1: %s d2: %s" + % (self.key, + self.super_value, + self.sub_value)) + + def get_details(self): + return {} + + +class IsSubDictOf(object): + + def __init__(self, super_dict): + self.super_dict = super_dict + + def __str__(self): + return 'IsSubDictOf(%s)' % (self.super_dict) + + def match(self, sub_dict): + """Assert a sub_dict is subset of super_dict.""" + if not set(sub_dict.keys()).issubset(set(self.super_dict.keys())): + return SubDictMismatch(keys=True) + for k, sub_value in sub_dict.items(): + super_value = self.super_dict[k] + if isinstance(sub_value, dict): + matcher = IsSubDictOf(super_value) + did_match = matcher.match(sub_value) + if did_match is not None: + return did_match + elif 'DONTCARE' in (sub_value, super_value): + continue + else: + if sub_value != super_value: + return SubDictMismatch(k, sub_value, super_value) + + +class FunctionCallMatcher(object): + + def __init__(self, expected_func_calls): + self.expected_func_calls = expected_func_calls + self.actual_func_calls = [] + + def call(self, *args, **kwargs): + func_call = {'args': args, 'kwargs': kwargs} + self.actual_func_calls.append(func_call) + + def match(self): + dict_list_matcher = DictListMatches(self.expected_func_calls) + return dict_list_matcher.match(self.actual_func_calls) + + +class XMLMismatch(object): + """Superclass for XML mismatch.""" + + def __init__(self, state): + self.path = str(state) + self.expected = state.expected + self.actual = state.actual + + def describe(self): + return "%(path)s: XML does not match" % self.__dict__ + + def get_details(self): + return { + 'expected': self.expected, + 'actual': self.actual, + } + + +class XMLTagMismatch(XMLMismatch): + """XML tags don't match.""" + + def __init__(self, state, idx, expected_tag, actual_tag): + super(XMLTagMismatch, self).__init__(state) + self.idx = idx + self.expected_tag = expected_tag + self.actual_tag = actual_tag + + def describe(self): + return ("%(path)s: XML tag mismatch at index %(idx)d: " + "expected tag <%(expected_tag)s>; " + "actual tag <%(actual_tag)s>" % self.__dict__) + + +class XMLAttrKeysMismatch(XMLMismatch): + """XML attribute keys don't match.""" + + def __init__(self, state, expected_only, actual_only): + super(XMLAttrKeysMismatch, self).__init__(state) + self.expected_only = ', '.join(sorted(expected_only)) + self.actual_only = ', '.join(sorted(actual_only)) + + def describe(self): + return ("%(path)s: XML attributes mismatch: " + "keys only in expected: %(expected_only)s; " + "keys only in actual: %(actual_only)s" % self.__dict__) + + +class XMLAttrValueMismatch(XMLMismatch): + """XML attribute values don't match.""" + + def __init__(self, state, key, expected_value, actual_value): + super(XMLAttrValueMismatch, self).__init__(state) + self.key = key + self.expected_value = expected_value + self.actual_value = actual_value + + def describe(self): + return ("%(path)s: XML attribute value mismatch: " + "expected value of attribute %(key)s: %(expected_value)r; " + "actual value: %(actual_value)r" % self.__dict__) + + +class XMLTextValueMismatch(XMLMismatch): + """XML text values don't match.""" + + def __init__(self, state, expected_text, actual_text): + super(XMLTextValueMismatch, self).__init__(state) + self.expected_text = expected_text + self.actual_text = actual_text + + def describe(self): + return ("%(path)s: XML text value mismatch: " + "expected text value: %(expected_text)r; " + "actual value: %(actual_text)r" % self.__dict__) + + +class XMLUnexpectedChild(XMLMismatch): + """Unexpected child present in 
XML.""" + + def __init__(self, state, tag, idx): + super(XMLUnexpectedChild, self).__init__(state) + self.tag = tag + self.idx = idx + + def describe(self): + return ("%(path)s: XML unexpected child element <%(tag)s> " + "present at index %(idx)d" % self.__dict__) + + +class XMLExpectedChild(XMLMismatch): + """Expected child not present in XML.""" + + def __init__(self, state, tag, idx): + super(XMLExpectedChild, self).__init__(state) + self.tag = tag + self.idx = idx + + def describe(self): + return ("%(path)s: XML expected child element <%(tag)s> " + "not present at index %(idx)d" % self.__dict__) + + +class XMLMatchState(object): + """ + Maintain some state for matching. + + Tracks the XML node path and saves the expected and actual full + XML text, for use by the XMLMismatch subclasses. + """ + + def __init__(self, expected, actual): + self.path = [] + self.expected = expected + self.actual = actual + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, exc_tb): + self.path.pop() + return False + + def __str__(self): + return '/' + '/'.join(self.path) + + def node(self, tag, idx): + """ + Adds tag and index to the path; they will be popped off when + the corresponding 'with' statement exits. + + :param tag: The element tag + :param idx: If not None, the integer index of the element + within its parent. Not included in the path + element if None. + """ + + if idx is not None: + self.path.append("%s[%d]" % (tag, idx)) + else: + self.path.append(tag) + return self + + +class XMLMatches(object): + """Compare XML strings. More complete than string comparison.""" + + def __init__(self, expected): + self.expected_xml = expected + self.expected = etree.fromstring(expected) + + def __str__(self): + return 'XMLMatches(%r)' % self.expected_xml + + def match(self, actual_xml): + actual = etree.fromstring(actual_xml) + + state = XMLMatchState(self.expected_xml, actual_xml) + result = self._compare_node(self.expected, actual, state, None) + + if result is False: + return XMLMismatch(state) + elif result is not True: + return result + + def _compare_node(self, expected, actual, state, idx): + """Recursively compares nodes within the XML tree.""" + + # Start by comparing the tags + if expected.tag != actual.tag: + return XMLTagMismatch(state, idx, expected.tag, actual.tag) + + with state.node(expected.tag, idx): + # Compare the attribute keys + expected_attrs = set(expected.attrib.keys()) + actual_attrs = set(actual.attrib.keys()) + if expected_attrs != actual_attrs: + expected_only = expected_attrs - actual_attrs + actual_only = actual_attrs - expected_attrs + return XMLAttrKeysMismatch(state, expected_only, actual_only) + + # Compare the attribute values + for key in expected_attrs: + expected_value = expected.attrib[key] + actual_value = actual.attrib[key] + + if 'DONTCARE' in (expected_value, actual_value): + continue + elif expected_value != actual_value: + return XMLAttrValueMismatch(state, key, expected_value, + actual_value) + + # Compare the contents of the node + if len(expected) == 0 and len(actual) == 0: + # No children, compare text values + if ('DONTCARE' not in (expected.text, actual.text) and + expected.text != actual.text): + return XMLTextValueMismatch(state, expected.text, + actual.text) + else: + expected_idx = 0 + actual_idx = 0 + while (expected_idx < len(expected) and + actual_idx < len(actual)): + # Ignore comments and processing instructions + # TODO(Vek): may interpret PIs in the future, to + # allow for, say, arbitrary ordering of some + # elements + if 
(expected[expected_idx].tag in + (etree.Comment, etree.ProcessingInstruction)): + expected_idx += 1 + continue + + # Compare the nodes + result = self._compare_node(expected[expected_idx], + actual[actual_idx], state, + actual_idx) + if result is not True: + return result + + # Step on to comparing the next nodes... + expected_idx += 1 + actual_idx += 1 + + # Make sure we consumed all nodes in actual + if actual_idx < len(actual): + return XMLUnexpectedChild(state, actual[actual_idx].tag, + actual_idx) + + # Make sure we consumed all nodes in expected + if expected_idx < len(expected): + for node in expected[expected_idx:]: + if (node.tag in + (etree.Comment, etree.ProcessingInstruction)): + continue + + return XMLExpectedChild(state, node.tag, actual_idx) + + # The nodes match + return True diff --git a/run_tests.py b/run_tests.py index 71701f83ff..7e7fdbe7e6 100644 --- a/run_tests.py +++ b/run_tests.py @@ -111,6 +111,7 @@ if __name__ == "__main__": # Initialize the test configuration. CONFIG.load_from_file('etc/tests/localhost.test.conf') + from reddwarf.tests.api import limits from reddwarf.tests.api import flavors from reddwarf.tests.api import versions from reddwarf.tests.api import instances diff --git a/tools/test-requires b/tools/test-requires index a095ae4f1f..2096d766bd 100644 --- a/tools/test-requires +++ b/tools/test-requires @@ -17,3 +17,4 @@ testtools>=0.9.22 pexpect discover testrepository>=0.0.8 +mockito
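Reviewer note (not part of the patch): the "(verb, uri, regex, value, unit)" rule strings consumed by RateLimitingMiddleware and Limiter.parse_limits, exercised in ParseLimitsTest and LimiterTest above, can be tried on their own. A minimal sketch, assuming the default in-memory Limiter; the rule string and the "/instances" URL are illustrative only.

from reddwarf.common import limits

# Parse a rule string of the same form the middleware is configured with.
rules = limits.Limiter.parse_limits('(PUT, *, .*, 10, MINUTE)')
limiter = limits.Limiter(rules)

# The first ten PUTs in a minute are allowed (delay is None); the eleventh
# is throttled and check_for_delay() returns the seconds to wait plus an
# error message.
for i in xrange(11):
    delay, error = limiter.check_for_delay("PUT", "/instances")
    if delay is not None:
        print "throttled for %.1f seconds: %s" % (delay, error)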
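Reviewer note (not part of the patch): a small sketch of the payload shape returned by the new GET /{tenant_id}/limits call, as produced by reddwarf.limits.views.ViewBuilder and mirrored by the LimitsViewBuilderTest fixtures above; the resetTime and counter values are illustrative only.

from reddwarf.limits import views

rate_limits = [{"URI": "*",
                "regex": ".*",
                "value": 10,
                "verb": "POST",
                "remaining": 2,
                "unit": "MINUTE",
                "resetTime": 1311272226}]

# Build the view with no absolute limits (quota wiring is still a TODO).
output = views.ViewBuilder().build(rate_limits, {})
# output == {"limits": {"rate": [{"uri": "*",
#                                 "regex": ".*",
#                                 "limit": [{"value": 10,
#                                            "verb": "POST",
#                                            "remaining": 2,
#                                            "unit": "MINUTE",
#                                            "next-available":
#                                                "2011-07-21T18:17:06Z"}]}],
#                       "absolute": {}}}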
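Reviewer note (not part of the patch): the new matchers module honours a DONTCARE wildcard, as its match() docstring above describes. A minimal sketch of using DictMatches the way the limits unit tests do; the dictionaries are illustrative only.

from reddwarf.tests.unittests.util.matchers import DictMatches

expected = {"verb": "GET", "next-available": "DONTCARE"}
actual = {"verb": "GET", "next-available": "2011-07-21T18:17:06Z"}

# match() returns None when the dictionaries are considered equivalent, so
# self.assertThat(actual, DictMatches(expected)) would pass in a test case.
assert DictMatches(expected).match(actual) is None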