Rate limits implementation

added unittest for limits
    reverted changes to openstack/common
    removed commented code
    cleaned up unittest
    added int-tests
    updated reference to XMLNS
    removed 1.1 XMLMS in wsgi

Implements: blueprint rate-limits
Change-Id: I842de3a6cae1859cc246264a5836abfd97fb8074
This commit is contained in:
daniel-a-nguyen 2013-02-14 11:22:06 -08:00
parent 8de8507493
commit 0a71ef9e88
21 changed files with 3872 additions and 3 deletions

View File

@ -7,7 +7,7 @@ use = call:reddwarf.common.wsgi:versioned_urlmap
paste.app_factory = reddwarf.versions:app_factory paste.app_factory = reddwarf.versions:app_factory
[pipeline:reddwarfapi] [pipeline:reddwarfapi]
pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp
#pipeline = debug extensions reddwarfapp #pipeline = debug extensions reddwarfapp
[filter:extensions] [filter:extensions]
@ -34,6 +34,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory
[filter:faultwrapper] [filter:faultwrapper]
paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory
[filter:ratelimit]
paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory
[app:reddwarfapp] [app:reddwarfapp]
paste.app_factory = reddwarf.common.api:app_factory paste.app_factory = reddwarf.common.api:app_factory

View File

@ -54,6 +54,12 @@ max_instances_per_user = 5
max_volumes_per_user = 100 max_volumes_per_user = 100
volume_time_out=30 volume_time_out=30
# Config options for rate limits
http_get_rate = 200
http_post_rate = 200
http_put_rate = 200
http_delete_rate = 200
# Reddwarf DNS # Reddwarf DNS
reddwarf_dns_support = False reddwarf_dns_support = False

View File

@ -106,7 +106,7 @@ use = call:reddwarf.common.wsgi:versioned_urlmap
paste.app_factory = reddwarf.versions:app_factory paste.app_factory = reddwarf.versions:app_factory
[pipeline:reddwarfapi] [pipeline:reddwarfapi]
pipeline = faultwrapper tokenauth authorization contextwrapper extensions reddwarfapp pipeline = faultwrapper tokenauth authorization contextwrapper ratelimit extensions reddwarfapp
# pipeline = debug reddwarfapp # pipeline = debug reddwarfapp
[filter:extensions] [filter:extensions]
@ -132,6 +132,9 @@ paste.filter_factory = reddwarf.common.wsgi:ContextMiddleware.factory
[filter:faultwrapper] [filter:faultwrapper]
paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory paste.filter_factory = reddwarf.common.wsgi:FaultWrapper.factory
[filter:ratelimit]
paste.filter_factory = reddwarf.common.limits:RateLimitingMiddleware.factory
[app:reddwarfapp] [app:reddwarfapp]
paste.app_factory = reddwarf.common.api:app_factory paste.app_factory = reddwarf.common.api:app_factory

View File

@ -45,6 +45,24 @@
"is_admin":false, "is_admin":false,
"services": ["reddwarf"] "services": ["reddwarf"]
} }
},
{
"auth_user":"rate_limit",
"auth_key":"password",
"tenant":"4000",
"requirements": {
"is_admin":false,
"services": ["reddwarf"]
}
},
{
"auth_user":"rate_limit_exceeded",
"auth_key":"password",
"tenant":"4050",
"requirements": {
"is_admin":false,
"services": ["reddwarf"]
}
} }
], ],

View File

@ -19,6 +19,7 @@ from reddwarf.common import wsgi
from reddwarf.extensions.mgmt.host.instance import service as hostservice from reddwarf.extensions.mgmt.host.instance import service as hostservice
from reddwarf.flavor.service import FlavorController from reddwarf.flavor.service import FlavorController
from reddwarf.instance.service import InstanceController from reddwarf.instance.service import InstanceController
from reddwarf.limits.service import LimitsController
from reddwarf.openstack.common import log as logging from reddwarf.openstack.common import log as logging
from reddwarf.openstack.common import rpc from reddwarf.openstack.common import rpc
from reddwarf.versions import VersionsController from reddwarf.versions import VersionsController
@ -32,6 +33,7 @@ class API(wsgi.Router):
self._instance_router(mapper) self._instance_router(mapper)
self._flavor_router(mapper) self._flavor_router(mapper)
self._versions_router(mapper) self._versions_router(mapper)
self._limits_router(mapper)
def _versions_router(self, mapper): def _versions_router(self, mapper):
versions_resource = VersionsController().create_resource() versions_resource = VersionsController().create_resource()
@ -48,6 +50,11 @@ class API(wsgi.Router):
path = "/{tenant_id}/flavors" path = "/{tenant_id}/flavors"
mapper.resource("flavor", path, controller=flavor_resource) mapper.resource("flavor", path, controller=flavor_resource)
def _limits_router(self, mapper):
limits_resource = LimitsController().create_resource()
path = "/{tenant_id}/limits"
mapper.resource("limits", path, controller=limits_resource)
def app_factory(global_conf, **local_conf): def app_factory(global_conf, **local_conf):
return API() return API()

View File

@ -100,6 +100,10 @@ common_opts = [
cfg.IntOpt('revert_time_out', default=60 * 10), cfg.IntOpt('revert_time_out', default=60 * 10),
cfg.ListOpt('root_grant', default=['ALL']), cfg.ListOpt('root_grant', default=['ALL']),
cfg.BoolOpt('root_grant_option', default=True), cfg.BoolOpt('root_grant_option', default=True),
cfg.IntOpt('http_get_rate', default=200),
cfg.IntOpt('http_post_rate', default=200),
cfg.IntOpt('http_delete_rate', default=200),
cfg.IntOpt('http_put_rate', default=200),
] ]

464
reddwarf/common/limits.py Normal file
View File

@ -0,0 +1,464 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Module dedicated functions/classes dealing with rate limiting requests.
"""
import collections
import copy
import httplib
import math
import re
import time
import webob.dec
import webob.exc
import xmlutil
from reddwarf.common import cfg
from reddwarf.common import wsgi as base_wsgi
from reddwarf.openstack.common import importutils
from reddwarf.openstack.common import jsonutils
from reddwarf.openstack.common import wsgi
from reddwarf.openstack.common.gettextutils import _
#
# TODO: come back to this later
# Dan Nguyen
#
#from nova import quota
#QUOTAS = quota.QUOTAS
CONF = cfg.CONF
# Convenience constants for the limits dictionary passed to Limiter().
PER_SECOND = 1
PER_MINUTE = 60
PER_HOUR = 60 * 60
PER_DAY = 60 * 60 * 24
limits_nsmap = {None: xmlutil.XMLNS_COMMON_V10, 'atom': xmlutil.XMLNS_ATOM}
class LimitsTemplate(xmlutil.TemplateBuilder):
def construct(self):
root = xmlutil.TemplateElement('limits', selector='limits')
rates = xmlutil.SubTemplateElement(root, 'rates')
rate = xmlutil.SubTemplateElement(rates, 'rate', selector='rate')
rate.set('uri', 'uri')
rate.set('regex', 'regex')
limit = xmlutil.SubTemplateElement(rate, 'limit', selector='limit')
limit.set('value', 'value')
limit.set('verb', 'verb')
limit.set('remaining', 'remaining')
limit.set('unit', 'unit')
limit.set('next-available', 'next-available')
absolute = xmlutil.SubTemplateElement(root, 'absolute',
selector='absolute')
limit = xmlutil.SubTemplateElement(absolute, 'limit',
selector=xmlutil.get_items)
limit.set('name', 0)
limit.set('value', 1)
return xmlutil.MasterTemplate(root, 1, nsmap=limits_nsmap)
class Limit(object):
"""
Stores information about a limit for HTTP requests.
"""
UNITS = {
1: "SECOND",
60: "MINUTE",
60 * 60: "HOUR",
60 * 60 * 24: "DAY",
}
UNIT_MAP = dict([(v, k) for k, v in UNITS.items()])
def __init__(self, verb, uri, regex, value, unit):
"""
Initialize a new `Limit`.
@param verb: HTTP verb (POST, PUT, etc.)
@param uri: Human-readable URI
@param regex: Regular expression format for this limit
@param value: Integer number of requests which can be made
@param unit: Unit of measure for the value parameter
"""
self.verb = verb
self.uri = uri
self.regex = regex
self.value = int(value)
self.unit = unit
self.unit_string = self.display_unit().lower()
self.remaining = int(value)
if value <= 0:
raise ValueError("Limit value must be > 0")
self.last_request = None
self.next_request = None
self.water_level = 0
self.capacity = self.unit
self.request_value = float(self.capacity) / float(self.value)
msg = _("Only %(value)s %(verb)s request(s) can be "
"made to %(uri)s every %(unit_string)s.")
self.error_message = msg % self.__dict__
def __call__(self, verb, url):
"""
Represents a call to this limit from a relevant request.
@param verb: string http verb (POST, GET, etc.)
@param url: string URL
"""
if self.verb != verb or not re.match(self.regex, url):
return
now = self._get_time()
if self.last_request is None:
self.last_request = now
leak_value = now - self.last_request
self.water_level -= leak_value
self.water_level = max(self.water_level, 0)
self.water_level += self.request_value
difference = self.water_level - self.capacity
self.last_request = now
if difference > 0:
self.water_level -= self.request_value
self.next_request = now + difference
return difference
cap = self.capacity
water = self.water_level
val = self.value
self.remaining = math.floor(((cap - water) / cap) * val)
self.next_request = now
def _get_time(self):
"""Retrieve the current time. Broken out for testability."""
return time.time()
def display_unit(self):
"""Display the string name of the unit."""
return self.UNITS.get(self.unit, "UNKNOWN")
def display(self):
"""Return a useful representation of this class."""
return {
"verb": self.verb,
"URI": self.uri,
"regex": self.regex,
"value": self.value,
"remaining": int(self.remaining),
"unit": self.display_unit(),
"resetTime": int(self.next_request or self._get_time()),
}
# "Limit" format is a dictionary with the HTTP verb, human-readable URI,
# a regular-expression to match, value and unit of measure (PER_DAY, etc.)
DEFAULT_LIMITS = [
Limit("POST", "*", ".*", CONF.http_post_rate, PER_MINUTE),
Limit("PUT", "*", ".*", CONF.http_put_rate, PER_MINUTE),
Limit("DELETE", "*", ".*", CONF.http_delete_rate, PER_MINUTE),
Limit("GET", "*", ".*", CONF.http_get_rate, PER_MINUTE),
]
class RateLimitingMiddleware(base_wsgi.ReddwarfMiddleware):
"""
Rate-limits requests passing through this middleware. All limit information
is stored in memory for this implementation.
"""
def __init__(self, application, limits=None, limiter=None, **kwargs):
"""
Initialize new `RateLimitingMiddleware`, which wraps the given WSGI
application and sets up the given limits.
@param application: WSGI application to wrap
@param limits: String describing limits
@param limiter: String identifying class for representing limits
Other parameters are passed to the constructor for the limiter.
"""
base_wsgi.Middleware.__init__(self, application)
# Select the limiter class
if limiter is None:
limiter = Limiter
else:
limiter = importutils.import_class(limiter)
# Parse the limits, if any are provided
if limits is not None:
limits = limiter.parse_limits(limits)
self._limiter = limiter(limits or DEFAULT_LIMITS, **kwargs)
@webob.dec.wsgify(RequestClass=wsgi.Request)
def __call__(self, req):
"""
Represents a single call through this middleware. We should record the
request if we have a limit relevant to it. If no limit is relevant to
the request, ignore it.
If the request should be rate limited, return a fault telling the user
they are over the limit and need to retry later.
"""
verb = req.method
url = req.url
context = req.environ.get(base_wsgi.CONTEXT_KEY)
tenant_id = None
if context:
tenant_id = context.tenant
delay, error = self._limiter.check_for_delay(verb, url, tenant_id)
if delay:
msg = _("This request was rate-limited.")
retry = time.time() + delay
return base_wsgi.OverLimitFault(msg, error, retry)
req.environ["reddwarf.limits"] = self._limiter.get_limits(tenant_id)
return self.application
class Limiter(object):
"""
Rate-limit checking class which handles limits in memory.
"""
def __init__(self, limits, **kwargs):
"""
Initialize the new `Limiter`.
@param limits: List of `Limit` objects
"""
self.limits = copy.deepcopy(limits)
self.levels = collections.defaultdict(lambda: copy.deepcopy(limits))
# Pick up any per-user limit information
for key, value in kwargs.items():
if key.startswith('user:'):
username = key[5:]
self.levels[username] = self.parse_limits(value)
def get_limits(self, username=None):
"""
Return the limits for a given user.
"""
return [limit.display() for limit in self.levels[username]]
def check_for_delay(self, verb, url, username=None):
"""
Check the given verb/user/user triplet for limit.
@return: Tuple of delay (in seconds) and error message (or None, None)
"""
delays = []
for limit in self.levels[username]:
delay = limit(verb, url)
if delay:
delays.append((delay, limit.error_message))
if delays:
delays.sort()
return delays[0]
return None, None
# This was ported from nova.
# Keeping it as a static method for the sake of consistency
#
# Note: This method gets called before the class is instantiated,
# so this must be either a static method or a class method. It is
# used to develop a list of limits to feed to the constructor. We
# put this in the class so that subclasses can override the
# default limit parsing.
@staticmethod
def parse_limits(limits):
"""
Convert a string into a list of Limit instances. This
implementation expects a semicolon-separated sequence of
parenthesized groups, where each group contains a
comma-separated sequence consisting of HTTP method,
user-readable URI, a URI reg-exp, an integer number of
requests which can be made, and a unit of measure. Valid
values for the latter are "SECOND", "MINUTE", "HOUR", and
"DAY".
@return: List of Limit instances.
"""
# Handle empty limit strings
limits = limits.strip()
if not limits:
return []
# Split up the limits by semicolon
result = []
for group in limits.split(';'):
group = group.strip()
if group[:1] != '(' or group[-1:] != ')':
raise ValueError("Limit rules must be surrounded by "
"parentheses")
group = group[1:-1]
# Extract the Limit arguments
args = [a.strip() for a in group.split(',')]
if len(args) != 5:
raise ValueError("Limit rules must contain the following "
"arguments: verb, uri, regex, value, unit")
# Pull out the arguments
verb, uri, regex, value, unit = args
# Upper-case the verb
verb = verb.upper()
# Convert value--raises ValueError if it's not integer
value = int(value)
# Convert unit
unit = unit.upper()
if unit not in Limit.UNIT_MAP:
raise ValueError("Invalid units specified")
unit = Limit.UNIT_MAP[unit]
# Build a limit
result.append(Limit(verb, uri, regex, value, unit))
return result
class WsgiLimiter(object):
"""
Rate-limit checking from a WSGI application. Uses an in-memory `Limiter`.
To use, POST ``/<username>`` with JSON data such as::
{
"verb" : GET,
"path" : "/servers"
}
and receive a 204 No Content, or a 403 Forbidden with an X-Wait-Seconds
header containing the number of seconds to wait before the action would
succeed.
"""
def __init__(self, limits=None):
"""
Initialize the new `WsgiLimiter`.
@param limits: List of `Limit` objects
"""
self._limiter = Limiter(limits or DEFAULT_LIMITS)
@webob.dec.wsgify(RequestClass=wsgi.Request)
def __call__(self, request):
"""
Handles a call to this application. Returns 204 if the request is
acceptable to the limiter, else a 403 is returned with a relevant
header indicating when the request *will* succeed.
"""
if request.method != "POST":
raise webob.exc.HTTPMethodNotAllowed()
try:
info = dict(jsonutils.loads(request.body))
except ValueError:
raise webob.exc.HTTPBadRequest()
username = request.path_info_pop()
verb = info.get("verb")
path = info.get("path")
delay, error = self._limiter.check_for_delay(verb, path, username)
if delay:
headers = {"X-Wait-Seconds": "%.2f" % delay}
return webob.exc.HTTPForbidden(headers=headers, explanation=error)
else:
return webob.exc.HTTPNoContent()
class WsgiLimiterProxy(object):
"""
Rate-limit requests based on answers from a remote source.
"""
def __init__(self, limiter_address):
"""
Initialize the new `WsgiLimiterProxy`.
@param limiter_address: IP/port combination of where to request limit
"""
self.limiter_address = limiter_address
def check_for_delay(self, verb, path, username=None):
body = jsonutils.dumps({"verb": verb, "path": path})
headers = {"Content-Type": "application/json"}
conn = httplib.HTTPConnection(self.limiter_address)
if username:
conn.request("POST", "/%s" % (username), body, headers)
else:
conn.request("POST", "/", body, headers)
resp = conn.getresponse()
if 200 >= resp.status < 300:
return None, None
return resp.getheader("X-Wait-Seconds"), resp.read() or None
# This was ported from nova.
# Keeping it as a static method for the sake of consistency
#
# Note: This method gets called before the class is instantiated,
# so this must be either a static method or a class method. It is
# used to develop a list of limits to feed to the constructor.
# This implementation returns an empty list, since all limit
# decisions are made by a remote server.
@staticmethod
def parse_limits(limits):
"""
Ignore a limits string--simply doesn't apply for the limit
proxy.
@return: Empty list.
"""
return []

View File

@ -0,0 +1,141 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
-*- rnc -*-
RELAX NG Compact Syntax Grammar for the
Atom Format Specification Version 11
-->
<grammar xmlns:xhtml="http://www.w3.org/1999/xhtml" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:s="http://www.ascc.net/xml/schematron" xmlns="http://relaxng.org/ns/structure/1.0" datatypeLibrary="http://www.w3.org/2001/XMLSchema-datatypes">
<start>
<choice>
<ref name="atomLink"/>
</choice>
</start>
<!-- Common attributes -->
<define name="atomCommonAttributes">
<optional>
<attribute name="xml:base">
<ref name="atomUri"/>
</attribute>
</optional>
<optional>
<attribute name="xml:lang">
<ref name="atomLanguageTag"/>
</attribute>
</optional>
<zeroOrMore>
<ref name="undefinedAttribute"/>
</zeroOrMore>
</define>
<!-- atom:link -->
<define name="atomLink">
<element name="atom:link">
<ref name="atomCommonAttributes"/>
<attribute name="href">
<ref name="atomUri"/>
</attribute>
<optional>
<attribute name="rel">
<choice>
<ref name="atomNCName"/>
<ref name="atomUri"/>
</choice>
</attribute>
</optional>
<optional>
<attribute name="type">
<ref name="atomMediaType"/>
</attribute>
</optional>
<optional>
<attribute name="hreflang">
<ref name="atomLanguageTag"/>
</attribute>
</optional>
<optional>
<attribute name="title"/>
</optional>
<optional>
<attribute name="length"/>
</optional>
<ref name="undefinedContent"/>
</element>
</define>
<!-- Low-level simple types -->
<define name="atomNCName">
<data type="string">
<param name="minLength">1</param>
<param name="pattern">[^:]*</param>
</data>
</define>
<!-- Whatever a media type is, it contains at least one slash -->
<define name="atomMediaType">
<data type="string">
<param name="pattern">.+/.+</param>
</data>
</define>
<!-- As defined in RFC 3066 -->
<define name="atomLanguageTag">
<data type="string">
<param name="pattern">[A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})*</param>
</data>
</define>
<!--
Unconstrained; it's not entirely clear how IRI fit into
xsd:anyURI so let's not try to constrain it here
-->
<define name="atomUri">
<text/>
</define>
<!-- Other Extensibility -->
<define name="undefinedAttribute">
<attribute>
<anyName>
<except>
<name>xml:base</name>
<name>xml:lang</name>
<nsName ns=""/>
</except>
</anyName>
</attribute>
</define>
<define name="undefinedContent">
<zeroOrMore>
<choice>
<text/>
<ref name="anyForeignElement"/>
</choice>
</zeroOrMore>
</define>
<define name="anyElement">
<element>
<anyName/>
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</element>
</define>
<define name="anyForeignElement">
<element>
<anyName>
<except>
<nsName ns="http://www.w3.org/2005/Atom"/>
</except>
</anyName>
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</element>
</define>
</grammar>

View File

@ -0,0 +1,597 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
-*- rnc -*-
RELAX NG Compact Syntax Grammar for the
Atom Format Specification Version 11
-->
<grammar xmlns:xhtml="http://www.w3.org/1999/xhtml" xmlns:atom="http://www.w3.org/2005/Atom" xmlns:s="http://www.ascc.net/xml/schematron" xmlns="http://relaxng.org/ns/structure/1.0" datatypeLibrary="http://www.w3.org/2001/XMLSchema-datatypes">
<start>
<choice>
<ref name="atomFeed"/>
<ref name="atomEntry"/>
</choice>
</start>
<!-- Common attributes -->
<define name="atomCommonAttributes">
<optional>
<attribute name="xml:base">
<ref name="atomUri"/>
</attribute>
</optional>
<optional>
<attribute name="xml:lang">
<ref name="atomLanguageTag"/>
</attribute>
</optional>
<zeroOrMore>
<ref name="undefinedAttribute"/>
</zeroOrMore>
</define>
<!-- Text Constructs -->
<define name="atomPlainTextConstruct">
<ref name="atomCommonAttributes"/>
<optional>
<attribute name="type">
<choice>
<value>text</value>
<value>html</value>
</choice>
</attribute>
</optional>
<text/>
</define>
<define name="atomXHTMLTextConstruct">
<ref name="atomCommonAttributes"/>
<attribute name="type">
<value>xhtml</value>
</attribute>
<ref name="xhtmlDiv"/>
</define>
<define name="atomTextConstruct">
<choice>
<ref name="atomPlainTextConstruct"/>
<ref name="atomXHTMLTextConstruct"/>
</choice>
</define>
<!-- Person Construct -->
<define name="atomPersonConstruct">
<ref name="atomCommonAttributes"/>
<interleave>
<element name="atom:name">
<text/>
</element>
<optional>
<element name="atom:uri">
<ref name="atomUri"/>
</element>
</optional>
<optional>
<element name="atom:email">
<ref name="atomEmailAddress"/>
</element>
</optional>
<zeroOrMore>
<ref name="extensionElement"/>
</zeroOrMore>
</interleave>
</define>
<!-- Date Construct -->
<define name="atomDateConstruct">
<ref name="atomCommonAttributes"/>
<data type="dateTime"/>
</define>
<!-- atom:feed -->
<define name="atomFeed">
<element name="atom:feed">
<s:rule context="atom:feed">
<s:assert test="atom:author or not(atom:entry[not(atom:author)])">An atom:feed must have an atom:author unless all of its atom:entry children have an atom:author.</s:assert>
</s:rule>
<ref name="atomCommonAttributes"/>
<interleave>
<zeroOrMore>
<ref name="atomAuthor"/>
</zeroOrMore>
<zeroOrMore>
<ref name="atomCategory"/>
</zeroOrMore>
<zeroOrMore>
<ref name="atomContributor"/>
</zeroOrMore>
<optional>
<ref name="atomGenerator"/>
</optional>
<optional>
<ref name="atomIcon"/>
</optional>
<ref name="atomId"/>
<zeroOrMore>
<ref name="atomLink"/>
</zeroOrMore>
<optional>
<ref name="atomLogo"/>
</optional>
<optional>
<ref name="atomRights"/>
</optional>
<optional>
<ref name="atomSubtitle"/>
</optional>
<ref name="atomTitle"/>
<ref name="atomUpdated"/>
<zeroOrMore>
<ref name="extensionElement"/>
</zeroOrMore>
</interleave>
<zeroOrMore>
<ref name="atomEntry"/>
</zeroOrMore>
</element>
</define>
<!-- atom:entry -->
<define name="atomEntry">
<element name="atom:entry">
<s:rule context="atom:entry">
<s:assert test="atom:link[@rel='alternate'] or atom:link[not(@rel)] or atom:content">An atom:entry must have at least one atom:link element with a rel attribute of 'alternate' or an atom:content.</s:assert>
</s:rule>
<s:rule context="atom:entry">
<s:assert test="atom:author or ../atom:author or atom:source/atom:author">An atom:entry must have an atom:author if its feed does not.</s:assert>
</s:rule>
<ref name="atomCommonAttributes"/>
<interleave>
<zeroOrMore>
<ref name="atomAuthor"/>
</zeroOrMore>
<zeroOrMore>
<ref name="atomCategory"/>
</zeroOrMore>
<optional>
<ref name="atomContent"/>
</optional>
<zeroOrMore>
<ref name="atomContributor"/>
</zeroOrMore>
<ref name="atomId"/>
<zeroOrMore>
<ref name="atomLink"/>
</zeroOrMore>
<optional>
<ref name="atomPublished"/>
</optional>
<optional>
<ref name="atomRights"/>
</optional>
<optional>
<ref name="atomSource"/>
</optional>
<optional>
<ref name="atomSummary"/>
</optional>
<ref name="atomTitle"/>
<ref name="atomUpdated"/>
<zeroOrMore>
<ref name="extensionElement"/>
</zeroOrMore>
</interleave>
</element>
</define>
<!-- atom:content -->
<define name="atomInlineTextContent">
<element name="atom:content">
<ref name="atomCommonAttributes"/>
<optional>
<attribute name="type">
<choice>
<value>text</value>
<value>html</value>
</choice>
</attribute>
</optional>
<zeroOrMore>
<text/>
</zeroOrMore>
</element>
</define>
<define name="atomInlineXHTMLContent">
<element name="atom:content">
<ref name="atomCommonAttributes"/>
<attribute name="type">
<value>xhtml</value>
</attribute>
<ref name="xhtmlDiv"/>
</element>
</define>
<define name="atomInlineOtherContent">
<element name="atom:content">
<ref name="atomCommonAttributes"/>
<optional>
<attribute name="type">
<ref name="atomMediaType"/>
</attribute>
</optional>
<zeroOrMore>
<choice>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</element>
</define>
<define name="atomOutOfLineContent">
<element name="atom:content">
<ref name="atomCommonAttributes"/>
<optional>
<attribute name="type">
<ref name="atomMediaType"/>
</attribute>
</optional>
<attribute name="src">
<ref name="atomUri"/>
</attribute>
<empty/>
</element>
</define>
<define name="atomContent">
<choice>
<ref name="atomInlineTextContent"/>
<ref name="atomInlineXHTMLContent"/>
<ref name="atomInlineOtherContent"/>
<ref name="atomOutOfLineContent"/>
</choice>
</define>
<!-- atom:author -->
<define name="atomAuthor">
<element name="atom:author">
<ref name="atomPersonConstruct"/>
</element>
</define>
<!-- atom:category -->
<define name="atomCategory">
<element name="atom:category">
<ref name="atomCommonAttributes"/>
<attribute name="term"/>
<optional>
<attribute name="scheme">
<ref name="atomUri"/>
</attribute>
</optional>
<optional>
<attribute name="label"/>
</optional>
<ref name="undefinedContent"/>
</element>
</define>
<!-- atom:contributor -->
<define name="atomContributor">
<element name="atom:contributor">
<ref name="atomPersonConstruct"/>
</element>
</define>
<!-- atom:generator -->
<define name="atomGenerator">
<element name="atom:generator">
<ref name="atomCommonAttributes"/>
<optional>
<attribute name="uri">
<ref name="atomUri"/>
</attribute>
</optional>
<optional>
<attribute name="version"/>
</optional>
<text/>
</element>
</define>
<!-- atom:icon -->
<define name="atomIcon">
<element name="atom:icon">
<ref name="atomCommonAttributes"/>
<ref name="atomUri"/>
</element>
</define>
<!-- atom:id -->
<define name="atomId">
<element name="atom:id">
<ref name="atomCommonAttributes"/>
<ref name="atomUri"/>
</element>
</define>
<!-- atom:logo -->
<define name="atomLogo">
<element name="atom:logo">
<ref name="atomCommonAttributes"/>
<ref name="atomUri"/>
</element>
</define>
<!-- atom:link -->
<define name="atomLink">
<element name="atom:link">
<ref name="atomCommonAttributes"/>
<attribute name="href">
<ref name="atomUri"/>
</attribute>
<optional>
<attribute name="rel">
<choice>
<ref name="atomNCName"/>
<ref name="atomUri"/>
</choice>
</attribute>
</optional>
<optional>
<attribute name="type">
<ref name="atomMediaType"/>
</attribute>
</optional>
<optional>
<attribute name="hreflang">
<ref name="atomLanguageTag"/>
</attribute>
</optional>
<optional>
<attribute name="title"/>
</optional>
<optional>
<attribute name="length"/>
</optional>
<ref name="undefinedContent"/>
</element>
</define>
<!-- atom:published -->
<define name="atomPublished">
<element name="atom:published">
<ref name="atomDateConstruct"/>
</element>
</define>
<!-- atom:rights -->
<define name="atomRights">
<element name="atom:rights">
<ref name="atomTextConstruct"/>
</element>
</define>
<!-- atom:source -->
<define name="atomSource">
<element name="atom:source">
<ref name="atomCommonAttributes"/>
<interleave>
<zeroOrMore>
<ref name="atomAuthor"/>
</zeroOrMore>
<zeroOrMore>
<ref name="atomCategory"/>
</zeroOrMore>
<zeroOrMore>
<ref name="atomContributor"/>
</zeroOrMore>
<optional>
<ref name="atomGenerator"/>
</optional>
<optional>
<ref name="atomIcon"/>
</optional>
<optional>
<ref name="atomId"/>
</optional>
<zeroOrMore>
<ref name="atomLink"/>
</zeroOrMore>
<optional>
<ref name="atomLogo"/>
</optional>
<optional>
<ref name="atomRights"/>
</optional>
<optional>
<ref name="atomSubtitle"/>
</optional>
<optional>
<ref name="atomTitle"/>
</optional>
<optional>
<ref name="atomUpdated"/>
</optional>
<zeroOrMore>
<ref name="extensionElement"/>
</zeroOrMore>
</interleave>
</element>
</define>
<!-- atom:subtitle -->
<define name="atomSubtitle">
<element name="atom:subtitle">
<ref name="atomTextConstruct"/>
</element>
</define>
<!-- atom:summary -->
<define name="atomSummary">
<element name="atom:summary">
<ref name="atomTextConstruct"/>
</element>
</define>
<!-- atom:title -->
<define name="atomTitle">
<element name="atom:title">
<ref name="atomTextConstruct"/>
</element>
</define>
<!-- atom:updated -->
<define name="atomUpdated">
<element name="atom:updated">
<ref name="atomDateConstruct"/>
</element>
</define>
<!-- Low-level simple types -->
<define name="atomNCName">
<data type="string">
<param name="minLength">1</param>
<param name="pattern">[^:]*</param>
</data>
</define>
<!-- Whatever a media type is, it contains at least one slash -->
<define name="atomMediaType">
<data type="string">
<param name="pattern">.+/.+</param>
</data>
</define>
<!-- As defined in RFC 3066 -->
<define name="atomLanguageTag">
<data type="string">
<param name="pattern">[A-Za-z]{1,8}(-[A-Za-z0-9]{1,8})*</param>
</data>
</define>
<!--
Unconstrained; it's not entirely clear how IRI fit into
xsd:anyURI so let's not try to constrain it here
-->
<define name="atomUri">
<text/>
</define>
<!-- Whatever an email address is, it contains at least one @ -->
<define name="atomEmailAddress">
<data type="string">
<param name="pattern">.+@.+</param>
</data>
</define>
<!-- Simple Extension -->
<define name="simpleExtensionElement">
<element>
<anyName>
<except>
<nsName ns="http://www.w3.org/2005/Atom"/>
</except>
</anyName>
<text/>
</element>
</define>
<!-- Structured Extension -->
<define name="structuredExtensionElement">
<element>
<anyName>
<except>
<nsName ns="http://www.w3.org/2005/Atom"/>
</except>
</anyName>
<choice>
<group>
<oneOrMore>
<attribute>
<anyName/>
</attribute>
</oneOrMore>
<zeroOrMore>
<choice>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</group>
<group>
<zeroOrMore>
<attribute>
<anyName/>
</attribute>
</zeroOrMore>
<group>
<optional>
<text/>
</optional>
<oneOrMore>
<ref name="anyElement"/>
</oneOrMore>
<zeroOrMore>
<choice>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</group>
</group>
</choice>
</element>
</define>
<!-- Other Extensibility -->
<define name="extensionElement">
<choice>
<ref name="simpleExtensionElement"/>
<ref name="structuredExtensionElement"/>
</choice>
</define>
<define name="undefinedAttribute">
<attribute>
<anyName>
<except>
<name>xml:base</name>
<name>xml:lang</name>
<nsName ns=""/>
</except>
</anyName>
</attribute>
</define>
<define name="undefinedContent">
<zeroOrMore>
<choice>
<text/>
<ref name="anyForeignElement"/>
</choice>
</zeroOrMore>
</define>
<define name="anyElement">
<element>
<anyName/>
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</element>
</define>
<define name="anyForeignElement">
<element>
<anyName>
<except>
<nsName ns="http://www.w3.org/2005/Atom"/>
</except>
</anyName>
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyElement"/>
</choice>
</zeroOrMore>
</element>
</define>
<!-- XHTML -->
<define name="anyXHTML">
<element>
<nsName ns="http://www.w3.org/1999/xhtml"/>
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyXHTML"/>
</choice>
</zeroOrMore>
</element>
</define>
<define name="xhtmlDiv">
<element name="xhtml:div">
<zeroOrMore>
<choice>
<attribute>
<anyName/>
</attribute>
<text/>
<ref name="anyXHTML"/>
</choice>
</zeroOrMore>
</element>
</define>
</grammar>

View File

@ -0,0 +1,28 @@
<element name="limits" ns="http://docs.openstack.org/common/api/v1.0"
xmlns="http://relaxng.org/ns/structure/1.0">
<element name="rates">
<zeroOrMore>
<element name="rate">
<attribute name="uri"> <text/> </attribute>
<attribute name="regex"> <text/> </attribute>
<zeroOrMore>
<element name="limit">
<attribute name="value"> <text/> </attribute>
<attribute name="verb"> <text/> </attribute>
<attribute name="remaining"> <text/> </attribute>
<attribute name="unit"> <text/> </attribute>
<attribute name="next-available"> <text/> </attribute>
</element>
</zeroOrMore>
</element>
</zeroOrMore>
</element>
<element name="absolute">
<zeroOrMore>
<element name="limit">
<attribute name="name"> <text/> </attribute>
<attribute name="value"> <text/> </attribute>
</element>
</zeroOrMore>
</element>
</element>

View File

@ -17,12 +17,15 @@
"""Wsgi helper utilities for reddwarf""" """Wsgi helper utilities for reddwarf"""
import eventlet.wsgi import eventlet.wsgi
import math
import paste.urlmap import paste.urlmap
import re import re
import time
import traceback import traceback
import webob import webob
import webob.dec import webob.dec
import webob.exc import webob.exc
from lxml import etree
from paste import deploy from paste import deploy
from xml.dom import minidom from xml.dom import minidom
@ -30,13 +33,14 @@ from reddwarf.common import context as rd_context
from reddwarf.common import exception from reddwarf.common import exception
from reddwarf.common import utils from reddwarf.common import utils
from reddwarf.openstack.common.gettextutils import _ from reddwarf.openstack.common.gettextutils import _
from reddwarf.openstack.common import jsonutils
from reddwarf.openstack.common import pastedeploy from reddwarf.openstack.common import pastedeploy
from reddwarf.openstack.common import service from reddwarf.openstack.common import service
from reddwarf.openstack.common import wsgi as openstack_wsgi from reddwarf.openstack.common import wsgi as openstack_wsgi
from reddwarf.openstack.common import log as logging from reddwarf.openstack.common import log as logging
from reddwarf.common import cfg from reddwarf.common import cfg
CONTEXT_KEY = 'reddwarf.context' CONTEXT_KEY = 'reddwarf.context'
Router = openstack_wsgi.Router Router = openstack_wsgi.Router
Debug = openstack_wsgi.Debug Debug = openstack_wsgi.Debug
@ -130,6 +134,54 @@ def launch(app_name, port, paste_config_file, data={},
return service.launch(server) return service.launch(server)
# Note: taken from Nova
def serializers(**serializers):
"""Attaches serializers to a method.
This decorator associates a dictionary of serializers with a
method. Note that the function attributes are directly
manipulated; the method is not wrapped.
"""
def decorator(func):
if not hasattr(func, 'wsgi_serializers'):
func.wsgi_serializers = {}
func.wsgi_serializers.update(serializers)
return func
return decorator
class ReddwarfMiddleware(Middleware):
# Note: taken from nova
@classmethod
def factory(cls, global_config, **local_config):
"""Used for paste app factories in paste.deploy config files.
Any local configuration (that is, values under the [filter:APPNAME]
section of the paste config) will be passed into the `__init__` method
as kwargs.
A hypothetical configuration would look like:
[filter:analytics]
redis_host = 127.0.0.1
paste.filter_factory = nova.api.analytics:Analytics.factory
which would result in a call to the `Analytics` class as
import nova.api.analytics
analytics.Analytics(app_from_paste, redis_host='127.0.0.1')
You could of course re-implement the `factory` method in subclasses,
but using the kwarg passing it shouldn't be necessary.
"""
def _factory(app):
return cls(app, **local_config)
return _factory
class VersionedURLMap(object): class VersionedURLMap(object):
def __init__(self, urlmap): def __init__(self, urlmap):
@ -591,3 +643,186 @@ class FaultWrapper(openstack_wsgi.Middleware):
def _factory(app): def _factory(app):
return cls(app) return cls(app)
return _factory return _factory
# ported from Nova
class OverLimitFault(webob.exc.HTTPException):
"""
Rate-limited request response.
"""
def __init__(self, message, details, retry_time):
"""
Initialize new `OverLimitFault` with relevant information.
"""
hdrs = OverLimitFault._retry_after(retry_time)
self.wrapped_exc = webob.exc.HTTPRequestEntityTooLarge(headers=hdrs)
self.content = {"overLimit": {"code": self.wrapped_exc.status_int,
"message": message,
"details": details,
"retryAfter": hdrs['Retry-After'],
},
}
@staticmethod
def _retry_after(retry_time):
delay = int(math.ceil(retry_time - time.time()))
retry_after = delay if delay > 0 else 0
headers = {'Retry-After': '%d' % retry_after}
return headers
@webob.dec.wsgify(RequestClass=Request)
def __call__(self, request):
"""
Return the wrapped exception with a serialized body conforming to our
error format.
"""
content_type = request.best_match_content_type()
metadata = {"attributes": {"overLimit": ["code", "retryAfter"]}}
xml_serializer = XMLDictSerializer(metadata, XMLNS)
serializer = {'application/xml': xml_serializer,
'application/json': JSONDictSerializer(),
}[content_type]
content = serializer.serialize(self.content)
self.wrapped_exc.body = content
self.wrapped_exc.content_type = content_type
return self.wrapped_exc
class ActionDispatcher(object):
"""Maps method name to local methods through action name."""
def dispatch(self, *args, **kwargs):
"""Find and call local method."""
action = kwargs.pop('action', 'default')
action_method = getattr(self, str(action), self.default)
return action_method(*args, **kwargs)
def default(self, data):
raise NotImplementedError()
class DictSerializer(ActionDispatcher):
"""Default request body serialization."""
def serialize(self, data, action='default'):
return self.dispatch(data, action=action)
def default(self, data):
return ""
class JSONDictSerializer(DictSerializer):
"""Default JSON request body serialization."""
def default(self, data):
return jsonutils.dumps(data)
class XMLDictSerializer(DictSerializer):
def __init__(self, metadata=None, xmlns=None):
"""
:param metadata: information needed to deserialize xml into
a dictionary.
:param xmlns: XML namespace to include with serialized xml
"""
super(XMLDictSerializer, self).__init__()
self.metadata = metadata or {}
self.xmlns = xmlns
def default(self, data):
# We expect data to contain a single key which is the XML root.
root_key = data.keys()[0]
doc = minidom.Document()
node = self._to_xml_node(doc, self.metadata, root_key, data[root_key])
return self.to_xml_string(node)
def to_xml_string(self, node, has_atom=False):
self._add_xmlns(node, has_atom)
return node.toxml('UTF-8')
#NOTE (ameade): the has_atom should be removed after all of the
# xml serializers and view builders have been updated to the current
# spec that required all responses include the xmlns:atom, the has_atom
# flag is to prevent current tests from breaking
def _add_xmlns(self, node, has_atom=False):
if self.xmlns is not None:
node.setAttribute('xmlns', self.xmlns)
if has_atom:
node.setAttribute('xmlns:atom', "http://www.w3.org/2005/Atom")
def _to_xml_node(self, doc, metadata, nodename, data):
"""Recursive method to convert data members to XML nodes."""
result = doc.createElement(nodename)
# Set the xml namespace if one is specified
# TODO(justinsb): We could also use prefixes on the keys
xmlns = metadata.get('xmlns', None)
if xmlns:
result.setAttribute('xmlns', xmlns)
#TODO(bcwaldon): accomplish this without a type-check
if isinstance(data, list):
collections = metadata.get('list_collections', {})
if nodename in collections:
metadata = collections[nodename]
for item in data:
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(item))
result.appendChild(node)
return result
singular = metadata.get('plurals', {}).get(nodename, None)
if singular is None:
if nodename.endswith('s'):
singular = nodename[:-1]
else:
singular = 'item'
for item in data:
node = self._to_xml_node(doc, metadata, singular, item)
result.appendChild(node)
#TODO(bcwaldon): accomplish this without a type-check
elif isinstance(data, dict):
collections = metadata.get('dict_collections', {})
if nodename in collections:
metadata = collections[nodename]
for k, v in data.items():
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(k))
text = doc.createTextNode(str(v))
node.appendChild(text)
result.appendChild(node)
return result
attrs = metadata.get('attributes', {}).get(nodename, {})
for k, v in data.items():
if k in attrs:
result.setAttribute(k, str(v))
else:
if k == "deleted":
v = str(bool(v))
node = self._to_xml_node(doc, metadata, k, v)
result.appendChild(node)
else:
# Type is atom
node = doc.createTextNode(str(data))
result.appendChild(node)
return result
def _create_link_nodes(self, xml_doc, links):
link_nodes = []
for link in links:
link_node = xml_doc.createElement('atom:link')
link_node.setAttribute('rel', link['rel'])
link_node.setAttribute('href', link['href'])
if 'type' in link:
link_node.setAttribute('type', link['type'])
link_nodes.append(link_node)
return link_nodes
def _to_xml(self, root):
"""Convert the xml object to an xml string."""
return etree.tostring(root, encoding='UTF-8', xml_declaration=True)

910
reddwarf/common/xmlutil.py Normal file
View File

@ -0,0 +1,910 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os.path
from lxml import etree
XMLNS_V10 = 'http://docs.rackspacecloud.com/servers/api/v1.0'
XMLNS_V11 = 'http://docs.openstack.org/database/api/v1.1'
XMLNS_COMMON_V10 = 'http://docs.openstack.org/common/api/v1.0'
XMLNS_ATOM = 'http://www.w3.org/2005/Atom'
def validate_schema(xml, schema_name):
if isinstance(xml, str):
xml = etree.fromstring(xml)
base_path = 'reddwarf/common/schemas/v1.1/'
if schema_name in ('atom', 'atom-link'):
base_path = 'reddwarf/common/schemas/'
# TODO: need to figure out our schema paths later
import reddwarf
schema_path = os.path.join(os.path.abspath(reddwarf.__file__)
.split('reddwarf/__init__.py')[0],
'%s%s.rng' % (base_path, schema_name))
schema_doc = etree.parse(schema_path)
relaxng = etree.RelaxNG(schema_doc)
relaxng.assertValid(xml)
class Selector(object):
"""Selects datum to operate on from an object."""
def __init__(self, *chain):
"""Initialize the selector.
Each argument is a subsequent index into the object.
"""
self.chain = chain
def __repr__(self):
"""Return a representation of the selector."""
return "Selector" + repr(self.chain)
def __call__(self, obj, do_raise=False):
"""Select a datum to operate on.
Selects the relevant datum within the object.
:param obj: The object from which to select the object.
:param do_raise: If False (the default), return None if the
indexed datum does not exist. Otherwise,
raise a KeyError.
"""
# Walk the selector list
for elem in self.chain:
# If it's callable, call it
if callable(elem):
obj = elem(obj)
else:
# Use indexing
try:
obj = obj[elem]
except (KeyError, IndexError):
# No sense going any further
if do_raise:
# Convert to a KeyError, for consistency
raise KeyError(elem)
return None
# Return the finally-selected object
return obj
def get_items(obj):
"""Get items in obj."""
return list(obj.items())
class EmptyStringSelector(Selector):
"""Returns the empty string if Selector would return None."""
def __call__(self, obj, do_raise=False):
"""Returns empty string if the selected value does not exist."""
try:
return super(EmptyStringSelector, self).__call__(obj, True)
except KeyError:
return ""
class ConstantSelector(object):
"""Returns a constant."""
def __init__(self, value):
"""Initialize the selector.
:param value: The value to return.
"""
self.value = value
def __repr__(self):
"""Return a representation of the selector."""
return repr(self.value)
def __call__(self, _obj, _do_raise=False):
"""Select a datum to operate on.
Returns a constant value. Compatible with
Selector.__call__().
"""
return self.value
class TemplateElement(object):
"""Represent an element in the template."""
def __init__(self, tag, attrib=None, selector=None, subselector=None,
**extra):
"""Initialize an element.
Initializes an element in the template. Keyword arguments
specify attributes to be set on the element; values must be
callables. See TemplateElement.set() for more information.
:param tag: The name of the tag to create.
:param attrib: An optional dictionary of element attributes.
:param selector: An optional callable taking an object and
optional boolean do_raise indicator and
returning the object bound to the element.
:param subselector: An optional callable taking an object and
optional boolean do_raise indicator and
returning the object bound to the element.
This is used to further refine the datum
object returned by selector in the event
that it is a list of objects.
"""
# Convert selector into a Selector
if selector is None:
selector = Selector()
elif not callable(selector):
selector = Selector(selector)
# Convert subselector into a Selector
if subselector is not None and not callable(subselector):
subselector = Selector(subselector)
self.tag = tag
self.selector = selector
self.subselector = subselector
self.attrib = {}
self._text = None
self._children = []
self._childmap = {}
# Run the incoming attributes through set() so that they
# become selectorized
if not attrib:
attrib = {}
attrib.update(extra)
for k, v in attrib.items():
self.set(k, v)
def __repr__(self):
"""Return a representation of the template element."""
return ('<%s.%s %r at %#x>' %
(self.__class__.__module__, self.__class__.__name__,
self.tag, id(self)))
def __len__(self):
"""Return the number of child elements."""
return len(self._children)
def __contains__(self, key):
"""Determine whether a child node named by key exists."""
return key in self._childmap
def __getitem__(self, idx):
"""Retrieve a child node by index or name."""
if isinstance(idx, basestring):
# Allow access by node name
return self._childmap[idx]
else:
return self._children[idx]
def append(self, elem):
"""Append a child to the element."""
# Unwrap templates...
elem = elem.unwrap()
# Avoid duplications
if elem.tag in self._childmap:
raise KeyError(elem.tag)
self._children.append(elem)
self._childmap[elem.tag] = elem
def extend(self, elems):
"""Append children to the element."""
# Pre-evaluate the elements
elemmap = {}
elemlist = []
for elem in elems:
# Unwrap templates...
elem = elem.unwrap()
# Avoid duplications
if elem.tag in self._childmap or elem.tag in elemmap:
raise KeyError(elem.tag)
elemmap[elem.tag] = elem
elemlist.append(elem)
# Update the children
self._children.extend(elemlist)
self._childmap.update(elemmap)
def insert(self, idx, elem):
"""Insert a child element at the given index."""
# Unwrap templates...
elem = elem.unwrap()
# Avoid duplications
if elem.tag in self._childmap:
raise KeyError(elem.tag)
self._children.insert(idx, elem)
self._childmap[elem.tag] = elem
def remove(self, elem):
"""Remove a child element."""
# Unwrap templates...
elem = elem.unwrap()
# Check if element exists
if elem.tag not in self._childmap or self._childmap[elem.tag] != elem:
raise ValueError(_('element is not a child'))
self._children.remove(elem)
del self._childmap[elem.tag]
def get(self, key):
"""Get an attribute.
Returns a callable which performs datum selection.
:param key: The name of the attribute to get.
"""
return self.attrib[key]
def set(self, key, value=None):
"""Set an attribute.
:param key: The name of the attribute to set.
:param value: A callable taking an object and optional boolean
do_raise indicator and returning the datum bound
to the attribute. If None, a Selector() will be
constructed from the key. If a string, a
Selector() will be constructed from the string.
"""
# Convert value to a selector
if value is None:
value = Selector(key)
elif not callable(value):
value = Selector(value)
self.attrib[key] = value
def keys(self):
"""Return the attribute names."""
return self.attrib.keys()
def items(self):
"""Return the attribute names and values."""
return self.attrib.items()
def unwrap(self):
"""Unwraps a template to return a template element."""
# We are a template element
return self
def wrap(self):
"""Wraps a template element to return a template."""
# Wrap in a basic Template
return Template(self)
def apply(self, elem, obj):
"""Apply text and attributes to an etree.Element.
Applies the text and attribute instructions in the template
element to an etree.Element instance.
:param elem: An etree.Element instance.
:param obj: The base object associated with this template
element.
"""
# Start with the text...
if self.text is not None:
elem.text = unicode(self.text(obj))
# Now set up all the attributes...
for key, value in self.attrib.items():
try:
elem.set(key, unicode(value(obj, True)))
except KeyError:
# Attribute has no value, so don't include it
pass
def _render(self, parent, datum, patches, nsmap):
"""Internal rendering.
Renders the template node into an etree.Element object.
Returns the etree.Element object.
:param parent: The parent etree.Element instance.
:param datum: The datum associated with this template element.
:param patches: A list of other template elements that must
also be applied.
:param nsmap: An optional namespace dictionary to be
associated with the etree.Element instance.
"""
# Allocate a node
if callable(self.tag):
tagname = self.tag(datum)
else:
tagname = self.tag
elem = etree.Element(tagname, nsmap=nsmap)
# If we have a parent, append the node to the parent
if parent is not None:
parent.append(elem)
# If the datum is None, do nothing else
if datum is None:
return elem
# Apply this template element to the element
self.apply(elem, datum)
# Additionally, apply the patches
for patch in patches:
patch.apply(elem, datum)
# We have fully rendered the element; return it
return elem
def render(self, parent, obj, patches=[], nsmap=None):
"""Render an object.
Renders an object against this template node. Returns a list
of two-item tuples, where the first item is an etree.Element
instance and the second item is the datum associated with that
instance.
:param parent: The parent for the etree.Element instances.
:param obj: The object to render this template element
against.
:param patches: A list of other template elements to apply
when rendering this template element.
:param nsmap: An optional namespace dictionary to attach to
the etree.Element instances.
"""
# First, get the datum we're rendering
data = None if obj is None else self.selector(obj)
# Check if we should render at all
if not self.will_render(data):
return []
elif data is None:
return [(self._render(parent, None, patches, nsmap), None)]
# Make the data into a list if it isn't already
if not isinstance(data, list):
data = [data]
elif parent is None:
raise ValueError(_('root element selecting a list'))
# Render all the elements
elems = []
for datum in data:
if self.subselector is not None:
datum = self.subselector(datum)
elems.append((self._render(parent, datum, patches, nsmap), datum))
# Return all the elements rendered, as well as the
# corresponding datum for the next step down the tree
return elems
def will_render(self, datum):
"""Hook method.
An overridable hook method to determine whether this template
element will be rendered at all. By default, returns False
(inhibiting rendering) if the datum is None.
:param datum: The datum associated with this template element.
"""
# Don't render if datum is None
return datum is not None
def _text_get(self):
"""Template element text.
Either None or a callable taking an object and optional
boolean do_raise indicator and returning the datum bound to
the text of the template element.
"""
return self._text
def _text_set(self, value):
# Convert value to a selector
if value is not None and not callable(value):
value = Selector(value)
self._text = value
def _text_del(self):
self._text = None
text = property(_text_get, _text_set, _text_del)
def tree(self):
"""Return string representation of the template tree.
Returns a representation of the template rooted at this
element as a string, suitable for inclusion in debug logs.
"""
# Build the inner contents of the tag...
contents = [self.tag, '!selector=%r' % self.selector]
# Add the text...
if self.text is not None:
contents.append('!text=%r' % self.text)
# Add all the other attributes
for key, value in self.attrib.items():
contents.append('%s=%r' % (key, value))
# If there are no children, return it as a closed tag
if len(self) == 0:
return '<%s/>' % ' '.join([str(i) for i in contents])
# OK, recurse to our children
children = [c.tree() for c in self]
# Return the result
return ('<%s>%s</%s>' %
(' '.join(contents), ''.join(children), self.tag))
def SubTemplateElement(parent, tag, attrib=None, selector=None,
subselector=None, **extra):
"""Create a template element as a child of another.
Corresponds to the etree.SubElement interface. Parameters are as
for TemplateElement, with the addition of the parent.
"""
# Convert attributes
attrib = attrib or {}
attrib.update(extra)
# Get a TemplateElement
elem = TemplateElement(tag, attrib=attrib, selector=selector,
subselector=subselector)
# Append the parent safely
if parent is not None:
parent.append(elem)
return elem
class Template(object):
"""Represent a template."""
def __init__(self, root, nsmap=None):
"""Initialize a template.
:param root: The root element of the template.
:param nsmap: An optional namespace dictionary to be
associated with the root element of the
template.
"""
self.root = root.unwrap() if root is not None else None
self.nsmap = nsmap or {}
self.serialize_options = dict(encoding='UTF-8', xml_declaration=True)
def _serialize(self, parent, obj, siblings, nsmap=None):
"""Internal serialization.
Recursive routine to build a tree of etree.Element instances
from an object based on the template. Returns the first
etree.Element instance rendered, or None.
:param parent: The parent etree.Element instance. Can be
None.
:param obj: The object to render.
:param siblings: The TemplateElement instances against which
to render the object.
:param nsmap: An optional namespace dictionary to be
associated with the etree.Element instance
rendered.
"""
# First step, render the element
elems = siblings[0].render(parent, obj, siblings[1:], nsmap)
# Now, recurse to all child elements
seen = set()
for idx, sibling in enumerate(siblings):
for child in sibling:
# Have we handled this child already?
if child.tag in seen:
continue
seen.add(child.tag)
# Determine the child's siblings
nieces = [child]
for sib in siblings[idx + 1:]:
if child.tag in sib:
nieces.append(sib[child.tag])
# Now we recurse for every data element
for elem, datum in elems:
self._serialize(elem, datum, nieces)
# Return the first element; at the top level, this will be the
# root element
if elems:
return elems[0][0]
def serialize(self, obj, *args, **kwargs):
"""Serialize an object.
Serializes an object against the template. Returns a string
with the serialized XML. Positional and keyword arguments are
passed to etree.tostring().
:param obj: The object to serialize.
"""
elem = self.make_tree(obj)
if elem is None:
return ''
for k, v in self.serialize_options.items():
kwargs.setdefault(k, v)
# Serialize it into XML
return etree.tostring(elem, *args, **kwargs)
def make_tree(self, obj):
"""Create a tree.
Serializes an object against the template. Returns an Element
node with appropriate children.
:param obj: The object to serialize.
"""
# If the template is empty, return the empty string
if self.root is None:
return None
# Get the siblings and nsmap of the root element
siblings = self._siblings()
nsmap = self._nsmap()
# Form the element tree
return self._serialize(None, obj, siblings, nsmap)
def _siblings(self):
"""Hook method for computing root siblings.
An overridable hook method to return the siblings of the root
element. By default, this is the root element itself.
"""
return [self.root]
def _nsmap(self):
"""Hook method for computing the namespace dictionary.
An overridable hook method to return the namespace dictionary.
"""
return self.nsmap.copy()
def unwrap(self):
"""Unwraps a template to return a template element."""
# Return the root element
return self.root
def wrap(self):
"""Wraps a template element to return a template."""
# We are a template
return self
def apply(self, master):
"""Hook method for determining slave applicability.
An overridable hook method used to determine if this template
is applicable as a slave to a given master template.
:param master: The master template to test.
"""
return True
def tree(self):
"""Return string representation of the template tree.
Returns a representation of the template as a string, suitable
for inclusion in debug logs.
"""
return "%r: %s" % (self, self.root.tree())
class MasterTemplate(Template):
"""Represent a master template.
Master templates are versioned derivatives of templates that
additionally allow slave templates to be attached. Slave
templates allow modification of the serialized result without
directly changing the master.
"""
def __init__(self, root, version, nsmap=None):
"""Initialize a master template.
:param root: The root element of the template.
:param version: The version number of the template.
:param nsmap: An optional namespace dictionary to be
associated with the root element of the
template.
"""
super(MasterTemplate, self).__init__(root, nsmap)
self.version = version
self.slaves = []
def __repr__(self):
"""Return string representation of the template."""
return ("<%s.%s object version %s at %#x>" %
(self.__class__.__module__, self.__class__.__name__,
self.version, id(self)))
def _siblings(self):
"""Hook method for computing root siblings.
An overridable hook method to return the siblings of the root
element. This is the root element plus the root elements of
all the slave templates.
"""
return [self.root] + [slave.root for slave in self.slaves]
def _nsmap(self):
"""Hook method for computing the namespace dictionary.
An overridable hook method to return the namespace dictionary.
The namespace dictionary is computed by taking the master
template's namespace dictionary and updating it from all the
slave templates.
"""
nsmap = self.nsmap.copy()
for slave in self.slaves:
nsmap.update(slave._nsmap())
return nsmap
def attach(self, *slaves):
"""Attach one or more slave templates.
Attaches one or more slave templates to the master template.
Slave templates must have a root element with the same tag as
the master template. The slave template's apply() method will
be called to determine if the slave should be applied to this
master; if it returns False, that slave will be skipped.
(This allows filtering of slaves based on the version of the
master template.)
"""
slave_list = []
for slave in slaves:
slave = slave.wrap()
# Make sure we have a tree match
if slave.root.tag != self.root.tag:
slavetag = slave.root.tag
mastertag = self.root.tag
msg = _("Template tree mismatch; adding slave %(slavetag)s "
"to master %(mastertag)s") % locals()
raise ValueError(msg)
# Make sure slave applies to this template
if not slave.apply(self):
continue
slave_list.append(slave)
# Add the slaves
self.slaves.extend(slave_list)
def copy(self):
"""Return a copy of this master template."""
# Return a copy of the MasterTemplate
tmp = self.__class__(self.root, self.version, self.nsmap)
tmp.slaves = self.slaves[:]
return tmp
class SlaveTemplate(Template):
"""Represent a slave template.
Slave templates are versioned derivatives of templates. Each
slave has a minimum version and optional maximum version of the
master template to which they can be attached.
"""
def __init__(self, root, min_vers, max_vers=None, nsmap=None):
"""Initialize a slave template.
:param root: The root element of the template.
:param min_vers: The minimum permissible version of the master
template for this slave template to apply.
:param max_vers: An optional upper bound for the master
template version.
:param nsmap: An optional namespace dictionary to be
associated with the root element of the
template.
"""
super(SlaveTemplate, self).__init__(root, nsmap)
self.min_vers = min_vers
self.max_vers = max_vers
def __repr__(self):
"""Return string representation of the template."""
return ("<%s.%s object versions %s-%s at %#x>" %
(self.__class__.__module__, self.__class__.__name__,
self.min_vers, self.max_vers, id(self)))
def apply(self, master):
"""Hook method for determining slave applicability.
An overridable hook method used to determine if this template
is applicable as a slave to a given master template. This
version requires the master template to have a version number
between min_vers and max_vers.
:param master: The master template to test.
"""
# Does the master meet our minimum version requirement?
if master.version < self.min_vers:
return False
# How about our maximum version requirement?
if self.max_vers is not None and master.version > self.max_vers:
return False
return True
class TemplateBuilder(object):
"""Template builder.
This class exists to allow templates to be lazily built without
having to build them each time they are needed. It must be
subclassed, and the subclass must implement the construct()
method, which must return a Template (or subclass) instance. The
constructor will always return the template returned by
construct(), or, if it has a copy() method, a copy of that
template.
"""
_tmpl = None
def __new__(cls, copy=True):
"""Construct and return a template.
:param copy: If True (the default), a copy of the template
will be constructed and returned, if possible.
"""
# Do we need to construct the template?
if cls._tmpl is None:
tmp = super(TemplateBuilder, cls).__new__(cls)
# Construct the template
cls._tmpl = tmp.construct()
# If the template has a copy attribute, return the result of
# calling it
if copy and hasattr(cls._tmpl, 'copy'):
return cls._tmpl.copy()
# Return the template
return cls._tmpl
def construct(self):
"""Construct a template.
Called to construct a template instance, which it must return.
Only called once.
"""
raise NotImplementedError(_("subclasses must implement construct()!"))
def make_links(parent, selector=None):
"""
Attach an Atom <links> element to the parent.
"""
elem = SubTemplateElement(parent, '{%s}link' % XMLNS_ATOM,
selector=selector)
elem.set('rel')
elem.set('type')
elem.set('href')
# Just for completeness...
return elem
def make_flat_dict(name, selector=None, subselector=None, ns=None):
"""
Utility for simple XML templates that traditionally used
XMLDictSerializer with no metadata. Returns a template element
where the top-level element has the given tag name, and where
sub-elements have tag names derived from the object's keys and
text derived from the object's values. This only works for flat
dictionary objects, not dictionaries containing nested lists or
dictionaries.
"""
# Set up the names we need...
if ns is None:
elemname = name
tagname = Selector(0)
else:
elemname = '{%s}%s' % (ns, name)
tagname = lambda obj, do_raise=False: '{%s}%s' % (ns, obj[0])
if selector is None:
selector = name
# Build the root element
root = TemplateElement(elemname, selector=selector,
subselector=subselector)
# Build an element to represent all the keys and values
elem = SubTemplateElement(root, tagname, selector=get_items)
elem.text = 1
# Return the template
return root

View File

View File

@ -0,0 +1,49 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2013 OpenStack LLC
# Copyright 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from reddwarf.common import wsgi as base_wsgi
from reddwarf.common.limits import LimitsTemplate
from reddwarf.limits import views
from reddwarf.openstack.common import wsgi
class LimitsController(base_wsgi.Controller):
"""
Controller for accessing limits in the OpenStack API.
Note: this is a little different than how other controllers are implemented
"""
@base_wsgi.serializers(xml=LimitsTemplate)
def index(self, req, tenant_id):
"""
Return all global and rate limit information.
"""
context = req.environ[base_wsgi.CONTEXT_KEY]
#
# TODO: hook this in later
#quotas = QUOTAS.get_project_quotas(context, context.project_id,
# usages=False)
#abs_limits = dict((k, v['limit']) for k, v in quotas.items())
abs_limits = {}
rate_limits = req.environ.get("reddwarf.limits", [])
builder = self._get_view_builder(req)
return builder.build(rate_limits, abs_limits)
def _get_view_builder(self, req):
return views.ViewBuilder()

98
reddwarf/limits/views.py Normal file
View File

@ -0,0 +1,98 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010-2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import datetime
from reddwarf.openstack.common import timeutils
class ViewBuilder(object):
"""OpenStack API base limits view builder."""
def build(self, rate_limits, absolute_limits):
rate_limits = self._build_rate_limits(rate_limits)
absolute_limits = self._build_absolute_limits(absolute_limits)
output = {
"limits": {
"rate": rate_limits,
"absolute": absolute_limits,
},
}
return output
def _build_absolute_limits(self, absolute_limits):
"""Builder for absolute limits
absolute_limits should be given as a dict of limits.
For example: {"ram": 512, "gigabytes": 1024}.
"""
limit_names = {
"ram": ["maxTotalRAMSize"],
"instances": ["maxTotalInstances"],
"cores": ["maxTotalCores"],
"metadata_items": ["maxServerMeta", "maxImageMeta"],
"injected_files": ["maxPersonality"],
"injected_file_content_bytes": ["maxPersonalitySize"],
"security_groups": ["maxSecurityGroups"],
"security_group_rules": ["maxSecurityGroupRules"],
}
limits = {}
for name, value in absolute_limits.iteritems():
if name in limit_names and value is not None:
for name in limit_names[name]:
limits[name] = value
return limits
def _build_rate_limits(self, rate_limits):
limits = []
for rate_limit in rate_limits:
_rate_limit_key = None
_rate_limit = self._build_rate_limit(rate_limit)
# check for existing key
for limit in limits:
if (limit["uri"] == rate_limit["URI"] and
limit["regex"] == rate_limit["regex"]):
_rate_limit_key = limit
break
# ensure we have a key if we didn't find one
if not _rate_limit_key:
_rate_limit_key = {
"uri": rate_limit["URI"],
"regex": rate_limit["regex"],
"limit": [],
}
limits.append(_rate_limit_key)
_rate_limit_key["limit"].append(_rate_limit)
return limits
def _build_rate_limit(self, rate_limit):
_get_utc = datetime.datetime.utcfromtimestamp
next_avail = _get_utc(rate_limit["resetTime"])
return {
"verb": rate_limit["verb"],
"value": rate_limit["value"],
"remaining": int(rate_limit["remaining"]),
"unit": rate_limit["unit"],
"next-available": timeutils.isotime(at=next_avail),
}

View File

@ -0,0 +1,109 @@
from nose.tools import assert_equal
from nose.tools import assert_false
from nose.tools import assert_true
from proboscis import before_class
from proboscis import test
from reddwarf.openstack.common import timeutils
from reddwarf.tests.util import create_dbaas_client
from reddwarf.tests.util import test_config
from reddwarfclient import exceptions
from datetime import datetime
GROUP = "dbaas.api.limits"
DEFAULT_RATE = 200
# Note: This should not be enabled until rd-client merges
RD_CLIENT_OK = False
@test(groups=[GROUP])
class Limits(object):
@before_class
def setUp(self):
rate_user = self._get_user('rate_limit')
self.rd_client = create_dbaas_client(rate_user)
def _get_user(self, name):
return test_config.users.find_user_by_name(name)
def _get_next_available(self, resource):
return resource.__dict__['next-available']
def __is_available(self, next_available):
dt_next = timeutils.parse_isotime(next_available)
dt_now = datetime.now()
return dt_next.time() < dt_now.time()
@test(enabled=RD_CLIENT_OK)
def test_limits_index(self):
"""test_limits_index"""
r1, r2, r3, r4 = self.rd_client.limits.index()
assert_equal(r1.verb, "POST")
assert_equal(r1.unit, "MINUTE")
assert_true(r1.remaining <= DEFAULT_RATE)
next_available = self._get_next_available(r1)
assert_true(next_available is not None)
assert_equal(r2.verb, "PUT")
assert_equal(r2.unit, "MINUTE")
assert_true(r2.remaining <= DEFAULT_RATE)
next_available = self._get_next_available(r2)
assert_true(next_available is not None)
assert_equal(r3.verb, "DELETE")
assert_equal(r3.unit, "MINUTE")
assert_true(r3.remaining <= DEFAULT_RATE)
next_available = self._get_next_available(r3)
assert_true(next_available is not None)
assert_equal(r4.verb, "GET")
assert_equal(r4.unit, "MINUTE")
assert_true(r4.remaining <= DEFAULT_RATE)
next_available = self._get_next_available(r4)
assert_true(next_available is not None)
@test(enabled=RD_CLIENT_OK)
def test_limits_get_remaining(self):
"""test_limits_get_remaining"""
gets = None
for i in xrange(5):
r1, r2, r3, r4 = self.rd_client.limits.index()
gets = r4
assert_equal(gets.verb, "GET")
assert_equal(gets.unit, "MINUTE")
assert_true(gets.remaining <= DEFAULT_RATE - 5)
next_available = self._get_next_available(gets)
assert_true(next_available is not None)
@test(enabled=RD_CLIENT_OK)
def test_limits_exception(self):
"""test_limits_exception"""
# use a different user to avoid throttling tests run out of order
rate_user_exceeded = self._get_user('rate_limit_exceeded')
rd_client = create_dbaas_client(rate_user_exceeded)
gets = None
encountered = False
for i in xrange(DEFAULT_RATE + 50):
try:
r1, r2, r3, r4 = rd_client.limits.index()
gets = r4
assert_equal(gets.verb, "GET")
assert_equal(gets.unit, "MINUTE")
except exceptions.OverLimit:
encountered = True
assert_true(encountered)
assert_true(gets.remaining <= 50)

View File

@ -0,0 +1,741 @@
# Copyright 2011 OpenStack LLC.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests dealing with HTTP rate-limiting.
"""
import httplib
import StringIO
from xml.dom import minidom
from lxml import etree
import testtools
import webob
from mockito import when
from reddwarf.common import limits
from reddwarf.common import xmlutil
from reddwarf.common.limits import Limit
from reddwarf.limits import views
from reddwarf.openstack.common import jsonutils
from reddwarf.tests.unittests.util.matchers import DictMatches
TEST_LIMITS = [
Limit("GET", "/delayed", "^/delayed", 1, limits.PER_MINUTE),
Limit("POST", "*", ".*", 7, limits.PER_MINUTE),
Limit("POST", "/servers", "^/servers", 3, limits.PER_MINUTE),
Limit("PUT", "*", "", 10, limits.PER_MINUTE),
Limit("PUT", "/servers", "^/servers", 5, limits.PER_MINUTE),
]
NS = {
'atom': 'http://www.w3.org/2005/Atom',
'ns': 'http://docs.openstack.org/common/api/v1.0'
}
class BaseLimitTestSuite(testtools.TestCase):
"""Base test suite which provides relevant stubs and time abstraction."""
def setUp(self):
super(BaseLimitTestSuite, self).setUp()
self.absolute_limits = {}
class LimitsControllerTest(BaseLimitTestSuite):
"""
Tests for `limits.LimitsController` class.
TODO: add test cases once absolute limits are integrated
"""
pass
class TestLimiter(limits.Limiter):
pass
class LimitMiddlewareTest(BaseLimitTestSuite):
"""
Tests for the `limits.RateLimitingMiddleware` class.
"""
@webob.dec.wsgify
def _empty_app(self, request):
"""Do-nothing WSGI app."""
pass
def setUp(self):
"""Prepare middleware for use through fake WSGI app."""
super(LimitMiddlewareTest, self).setUp()
_limits = '(GET, *, .*, 1, MINUTE)'
self.app = limits.RateLimitingMiddleware(self._empty_app, _limits,
"%s.TestLimiter" %
self.__class__.__module__)
def test_limit_class(self):
# Test that middleware selected correct limiter class.
assert isinstance(self.app._limiter, TestLimiter)
def test_good_request(self):
# Test successful GET request through middleware.
request = webob.Request.blank("/")
response = request.get_response(self.app)
self.assertEqual(200, response.status_int)
def test_limited_request_json(self):
# Test a rate-limited (413) GET request through middleware.
request = webob.Request.blank("/")
response = request.get_response(self.app)
self.assertEqual(200, response.status_int)
request = webob.Request.blank("/")
response = request.get_response(self.app)
self.assertEqual(response.status_int, 413)
self.assertTrue('Retry-After' in response.headers)
retry_after = int(response.headers['Retry-After'])
self.assertAlmostEqual(retry_after, 60, 1)
body = jsonutils.loads(response.body)
expected = "Only 1 GET request(s) can be made to * every minute."
value = body["overLimit"]["details"].strip()
self.assertEqual(value, expected)
self.assertTrue("retryAfter" in body["overLimit"])
retryAfter = body["overLimit"]["retryAfter"]
self.assertEqual(retryAfter, "60")
def test_limited_request_xml(self):
# Test a rate-limited (413) response as XML.
request = webob.Request.blank("/")
response = request.get_response(self.app)
self.assertEqual(200, response.status_int)
request = webob.Request.blank("/")
request.accept = "application/xml"
response = request.get_response(self.app)
self.assertEqual(response.status_int, 413)
root = minidom.parseString(response.body).childNodes[0]
expected = "Only 1 GET request(s) can be made to * every minute."
self.assertNotEqual(root.attributes.getNamedItem("retryAfter"), None)
retryAfter = root.attributes.getNamedItem("retryAfter").value
self.assertEqual(retryAfter, "60")
details = root.getElementsByTagName("details")
self.assertEqual(details.length, 1)
value = details.item(0).firstChild.data.strip()
self.assertEqual(value, expected)
class LimitTest(BaseLimitTestSuite):
"""
Tests for the `limits.Limit` class.
"""
def test_GET_no_delay(self):
# Test a limit handles 1 GET per second.
limit = Limit("GET", "*", ".*", 1, 1)
when(limit)._get_time().thenReturn(0.0)
delay = limit("GET", "/anything")
self.assertEqual(None, delay)
self.assertEqual(0, limit.next_request)
self.assertEqual(0, limit.last_request)
def test_GET_delay(self):
# Test two calls to 1 GET per second limit.
limit = Limit("GET", "*", ".*", 1, 1)
when(limit)._get_time().thenReturn(0.0)
delay = limit("GET", "/anything")
self.assertEqual(None, delay)
delay = limit("GET", "/anything")
self.assertEqual(1, delay)
self.assertEqual(1, limit.next_request)
self.assertEqual(0, limit.last_request)
when(limit)._get_time().thenReturn(4.0)
delay = limit("GET", "/anything")
self.assertEqual(None, delay)
self.assertEqual(4, limit.next_request)
self.assertEqual(4, limit.last_request)
class ParseLimitsTest(BaseLimitTestSuite):
"""
Tests for the default limits parser in the in-memory
`limits.Limiter` class.
"""
def test_invalid(self):
# Test that parse_limits() handles invalid input correctly.
self.assertRaises(ValueError, limits.Limiter.parse_limits,
';;;;;')
def test_bad_rule(self):
# Test that parse_limits() handles bad rules correctly.
self.assertRaises(ValueError, limits.Limiter.parse_limits,
'GET, *, .*, 20, minute')
def test_missing_arg(self):
# Test that parse_limits() handles missing args correctly.
self.assertRaises(ValueError, limits.Limiter.parse_limits,
'(GET, *, .*, 20)')
def test_bad_value(self):
# Test that parse_limits() handles bad values correctly.
self.assertRaises(ValueError, limits.Limiter.parse_limits,
'(GET, *, .*, foo, minute)')
def test_bad_unit(self):
# Test that parse_limits() handles bad units correctly.
self.assertRaises(ValueError, limits.Limiter.parse_limits,
'(GET, *, .*, 20, lightyears)')
def test_multiple_rules(self):
# Test that parse_limits() handles multiple rules correctly.
try:
l = limits.Limiter.parse_limits('(get, *, .*, 20, minute);'
'(PUT, /foo*, /foo.*, 10, hour);'
'(POST, /bar*, /bar.*, 5, second);'
'(Say, /derp*, /derp.*, 1, day)')
except ValueError, e:
assert False, str(e)
# Make sure the number of returned limits are correct
self.assertEqual(len(l), 4)
# Check all the verbs...
expected = ['GET', 'PUT', 'POST', 'SAY']
self.assertEqual([t.verb for t in l], expected)
# ...the URIs...
expected = ['*', '/foo*', '/bar*', '/derp*']
self.assertEqual([t.uri for t in l], expected)
# ...the regexes...
expected = ['.*', '/foo.*', '/bar.*', '/derp.*']
self.assertEqual([t.regex for t in l], expected)
# ...the values...
expected = [20, 10, 5, 1]
self.assertEqual([t.value for t in l], expected)
# ...and the units...
expected = [limits.PER_MINUTE, limits.PER_HOUR,
limits.PER_SECOND, limits.PER_DAY]
self.assertEqual([t.unit for t in l], expected)
class LimiterTest(BaseLimitTestSuite):
"""
Tests for the in-memory `limits.Limiter` class.
"""
def update_limits(self, delay):
for l in TEST_LIMITS:
when(l)._get_time().thenReturn(delay)
def setUp(self):
"""Run before each test."""
super(LimiterTest, self).setUp()
userlimits = {'user:user3': ''}
self.update_limits(0.0)
self.limiter = limits.Limiter(TEST_LIMITS, **userlimits)
def _check(self, num, verb, url, username=None):
"""Check and yield results from checks."""
for x in xrange(num):
yield self.limiter.check_for_delay(verb, url, username)[0]
def _check_sum(self, num, verb, url, username=None):
"""Check and sum results from checks."""
results = self._check(num, verb, url, username)
return sum(item for item in results if item)
def test_no_delay_GET(self):
"""
Simple test to ensure no delay on a single call for a limit verb we
didn"t set.
"""
delay = self.limiter.check_for_delay("GET", "/anything")
self.assertEqual(delay, (None, None))
def test_no_delay_PUT(self):
# Simple test to ensure no delay on a single call for a known limit.
delay = self.limiter.check_for_delay("PUT", "/anything")
self.assertEqual(delay, (None, None))
def test_delay_PUT(self):
"""
Ensure the 11th PUT will result in a delay of 6.0 seconds until
the next request will be granced.
"""
expected = [None] * 10 + [6.0]
results = list(self._check(11, "PUT", "/anything"))
self.assertEqual(expected, results)
def test_delay_POST(self):
"""
Ensure the 8th POST will result in a delay of 6.0 seconds until
the next request will be granced.
"""
expected = [None] * 7
results = list(self._check(7, "POST", "/anything"))
self.assertEqual(expected, results)
expected = 60.0 / 7.0
results = self._check_sum(1, "POST", "/anything")
self.failUnlessAlmostEqual(expected, results, 8)
def test_delay_GET(self):
# Ensure the 11th GET will result in NO delay.
expected = [None] * 11
results = list(self._check(11, "GET", "/anything"))
self.assertEqual(expected, results)
def test_delay_PUT_servers(self):
"""
Ensure PUT on /servers limits at 5 requests, and PUT elsewhere is still
OK after 5 requests...but then after 11 total requests, PUT limiting
kicks in.
"""
# First 6 requests on PUT /servers
expected = [None] * 5 + [12.0]
results = list(self._check(6, "PUT", "/servers"))
self.assertEqual(expected, results)
# Next 5 request on PUT /anything
expected = [None] * 4 + [6.0]
results = list(self._check(5, "PUT", "/anything"))
self.assertEqual(expected, results)
def test_delay_PUT_wait(self):
"""
Ensure after hitting the limit and then waiting for the correct
amount of time, the limit will be lifted.
"""
expected = [None] * 10 + [6.0]
results = list(self._check(11, "PUT", "/anything"))
self.assertEqual(expected, results)
# Advance time
self.update_limits(6.0)
expected = [None, 6.0]
results = list(self._check(2, "PUT", "/anything"))
self.assertEqual(expected, results)
def test_multiple_delays(self):
# Ensure multiple requests still get a delay.
expected = [None] * 10 + [6.0] * 10
results = list(self._check(20, "PUT", "/anything"))
self.assertEqual(expected, results)
self.update_limits(1.0)
expected = [5.0] * 10
results = list(self._check(10, "PUT", "/anything"))
self.assertEqual(expected, results)
def test_user_limit(self):
# Test user-specific limits.
self.assertEqual(self.limiter.levels['user3'], [])
def test_multiple_users(self):
# Tests involving multiple users.
# User1
self.update_limits(0.0)
expected = [None] * 10 + [6.0] * 10
results = list(self._check(20, "PUT", "/anything", "user1"))
self.assertEqual(expected, results)
# User2
expected = [None] * 10 + [6.0] * 5
results = list(self._check(15, "PUT", "/anything", "user2"))
self.assertEqual(expected, results)
# User3
expected = [None] * 20
results = list(self._check(20, "PUT", "/anything", "user3"))
self.assertEqual(expected, results)
self.update_limits(1.0)
# User1 again
expected = [5.0] * 10
results = list(self._check(10, "PUT", "/anything", "user1"))
self.assertEqual(expected, results)
self.update_limits(2.0)
# User1 again
expected = [4.0] * 5
results = list(self._check(5, "PUT", "/anything", "user2"))
self.assertEqual(expected, results)
class WsgiLimiterTest(BaseLimitTestSuite):
"""
Tests for `limits.WsgiLimiter` class.
"""
def setUp(self):
"""Run before each test."""
super(WsgiLimiterTest, self).setUp()
self.app = limits.WsgiLimiter(TEST_LIMITS)
def _request_data(self, verb, path):
"""Get data describing a limit request verb/path."""
return jsonutils.dumps({"verb": verb, "path": path})
def _request(self, verb, url, username=None):
"""Make sure that POSTing to the given url causes the given username
to perform the given action. Make the internal rate limiter return
delay and make sure that the WSGI app returns the correct response.
"""
if username:
request = webob.Request.blank("/%s" % username)
else:
request = webob.Request.blank("/")
request.method = "POST"
request.body = self._request_data(verb, url)
response = request.get_response(self.app)
if "X-Wait-Seconds" in response.headers:
self.assertEqual(response.status_int, 403)
return response.headers["X-Wait-Seconds"]
self.assertEqual(response.status_int, 204)
def test_invalid_methods(self):
# Only POSTs should work.
requests = []
for method in ["GET", "PUT", "DELETE", "HEAD", "OPTIONS"]:
request = webob.Request.blank("/", method=method)
response = request.get_response(self.app)
self.assertEqual(response.status_int, 405)
def test_good_url(self):
delay = self._request("GET", "/something")
self.assertEqual(delay, None)
def test_escaping(self):
delay = self._request("GET", "/something/jump%20up")
self.assertEqual(delay, None)
def test_response_to_delays(self):
delay = self._request("GET", "/delayed")
self.assertEqual(delay, None)
delay = self._request("GET", "/delayed")
self.assertEqual(delay, '60.00')
def test_response_to_delays_usernames(self):
delay = self._request("GET", "/delayed", "user1")
self.assertEqual(delay, None)
delay = self._request("GET", "/delayed", "user2")
self.assertEqual(delay, None)
delay = self._request("GET", "/delayed", "user1")
self.assertEqual(delay, '60.00')
delay = self._request("GET", "/delayed", "user2")
self.assertEqual(delay, '60.00')
class FakeHttplibSocket(object):
"""
Fake `httplib.HTTPResponse` replacement.
"""
def __init__(self, response_string):
"""Initialize new `FakeHttplibSocket`."""
self._buffer = StringIO.StringIO(response_string)
def makefile(self, _mode, _other):
"""Returns the socket's internal buffer."""
return self._buffer
class FakeHttplibConnection(object):
"""
Fake `httplib.HTTPConnection`.
"""
def __init__(self, app, host):
"""
Initialize `FakeHttplibConnection`.
"""
self.app = app
self.host = host
def request(self, method, path, body="", headers=None):
"""
Requests made via this connection actually get translated and routed
into our WSGI app, we then wait for the response and turn it back into
an `httplib.HTTPResponse`.
"""
if not headers:
headers = {}
req = webob.Request.blank(path)
req.method = method
req.headers = headers
req.host = self.host
req.body = body
resp = str(req.get_response(self.app))
resp = "HTTP/1.0 %s" % resp
sock = FakeHttplibSocket(resp)
self.http_response = httplib.HTTPResponse(sock)
self.http_response.begin()
def getresponse(self):
"""Return our generated response from the request."""
return self.http_response
def wire_HTTPConnection_to_WSGI(host, app):
"""Monkeypatches HTTPConnection so that if you try to connect to host, you
are instead routed straight to the given WSGI app.
After calling this method, when any code calls
httplib.HTTPConnection(host)
the connection object will be a fake. Its requests will be sent directly
to the given WSGI app rather than through a socket.
Code connecting to hosts other than host will not be affected.
This method may be called multiple times to map different hosts to
different apps.
This method returns the original HTTPConnection object, so that the caller
can restore the default HTTPConnection interface (for all hosts).
"""
class HTTPConnectionDecorator(object):
"""Wraps the real HTTPConnection class so that when you instantiate
the class you might instead get a fake instance."""
def __init__(self, wrapped):
self.wrapped = wrapped
def __call__(self, connection_host, *args, **kwargs):
if connection_host == host:
return FakeHttplibConnection(app, host)
else:
return self.wrapped(connection_host, *args, **kwargs)
oldHTTPConnection = httplib.HTTPConnection
httplib.HTTPConnection = HTTPConnectionDecorator(httplib.HTTPConnection)
return oldHTTPConnection
class WsgiLimiterProxyTest(BaseLimitTestSuite):
"""
Tests for the `limits.WsgiLimiterProxy` class.
"""
def setUp(self):
"""
Do some nifty HTTP/WSGI magic which allows for WSGI to be called
directly by something like the `httplib` library.
"""
super(WsgiLimiterProxyTest, self).setUp()
self.app = limits.WsgiLimiter(TEST_LIMITS)
self.oldHTTPConnection = (
wire_HTTPConnection_to_WSGI("169.254.0.1:80", self.app))
self.proxy = limits.WsgiLimiterProxy("169.254.0.1:80")
def test_200(self):
# Successful request test.
delay = self.proxy.check_for_delay("GET", "/anything")
self.assertEqual(delay, (None, None))
def test_403(self):
# Forbidden request test.
delay = self.proxy.check_for_delay("GET", "/delayed")
self.assertEqual(delay, (None, None))
delay, error = self.proxy.check_for_delay("GET", "/delayed")
error = error.strip()
expected = ("60.00", "403 Forbidden\n\nOnly 1 GET request(s) can be "
"made to /delayed every minute.")
self.assertEqual((delay, error), expected)
def tearDown(self):
# restore original HTTPConnection object
httplib.HTTPConnection = self.oldHTTPConnection
super(WsgiLimiterProxyTest, self).tearDown()
class LimitsViewBuilderTest(testtools.TestCase):
def setUp(self):
super(LimitsViewBuilderTest, self).setUp()
self.view_builder = views.ViewBuilder()
self.rate_limits = [{"URI": "*",
"regex": ".*",
"value": 10,
"verb": "POST",
"remaining": 2,
"unit": "MINUTE",
"resetTime": 1311272226},
{"URI": "*/servers",
"regex": "^/servers",
"value": 50,
"verb": "POST",
"remaining": 10,
"unit": "DAY",
"resetTime": 1311272226}]
self.absolute_limits = {"metadata_items": 1,
"injected_files": 5,
"injected_file_content_bytes": 5}
def test_build_limits(self):
expected_limits = {"limits": {
"rate": [{"uri": "*",
"regex": ".*",
"limit": [{"value": 10,
"verb": "POST",
"remaining": 2,
"unit": "MINUTE",
"next-available": "2011-07-21T18:17:06Z"}]},
{"uri": "*/servers",
"regex": "^/servers",
"limit": [{"value": 50,
"verb": "POST",
"remaining": 10,
"unit": "DAY",
"next-available": "2011-07-21T18:17:06Z"}]}],
"absolute": {
"maxServerMeta": 1,
"maxImageMeta": 1,
"maxPersonality": 5,
"maxPersonalitySize": 5}}}
output = self.view_builder.build(self.rate_limits,
self.absolute_limits)
self.assertThat(output, DictMatches(expected_limits))
def test_build_limits_empty_limits(self):
expected_limits = {"limits": {"rate": [],
"absolute": {}}}
abs_limits = {}
rate_limits = []
output = self.view_builder.build(rate_limits, abs_limits)
self.assertThat(output, DictMatches(expected_limits))
class LimitsXMLSerializationTest(testtools.TestCase):
def test_xml_declaration(self):
serializer = limits.LimitsTemplate()
fixture = {"limits": {
"rate": [],
"absolute": {}}}
output = serializer.serialize(fixture)
has_dec = output.startswith("<?xml version='1.0' encoding='UTF-8'?>")
self.assertTrue(has_dec)
def test_index(self):
serializer = limits.LimitsTemplate()
fixture = {
"limits": {
"rate": [{"uri": "*",
"regex": ".*",
"limit": [
{"value": 10,
"verb": "POST",
"remaining": 2,
"unit": "MINUTE",
"next-available": "2011-12-15T22:42:45Z"}]},
{"uri": "*/servers",
"regex": "^/servers",
"limit": [
{"value": 50,
"verb": "POST",
"remaining": 10,
"unit": "DAY",
"next-available": "2011-12-15T22:42:45Z"}]}],
"absolute": {
"maxServerMeta": 1,
"maxImageMeta": 1,
"maxPersonality": 5,
"maxPersonalitySize": 10240}}}
output = serializer.serialize(fixture)
root = etree.XML(output)
xmlutil.validate_schema(root, 'limits')
#verify absolute limits
absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
self.assertEqual(len(absolutes), 4)
for limit in absolutes:
name = limit.get('name')
value = limit.get('value')
self.assertEqual(value, str(fixture['limits']['absolute'][name]))
#verify rate limits
rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
self.assertEqual(len(rates), 2)
for i, rate in enumerate(rates):
for key in ['uri', 'regex']:
self.assertEqual(rate.get(key),
str(fixture['limits']['rate'][i][key]))
rate_limits = rate.xpath('ns:limit', namespaces=NS)
self.assertEqual(len(rate_limits), 1)
for j, limit in enumerate(rate_limits):
for key in ['verb', 'value', 'remaining', 'unit',
'next-available']:
self.assertEqual(limit.get(key),
str(fixture['limits']['rate'][i]['limit']
[j][key]))
def test_index_no_limits(self):
serializer = limits.LimitsTemplate()
fixture = {"limits": {
"rate": [],
"absolute": {}}}
output = serializer.serialize(fixture)
root = etree.XML(output)
xmlutil.validate_schema(root, 'limits')
#verify absolute limits
absolutes = root.xpath('ns:absolute/ns:limit', namespaces=NS)
self.assertEqual(len(absolutes), 0)
#verify rate limits
rates = root.xpath('ns:rates/ns:rate', namespaces=NS)
self.assertEqual(len(rates), 0)

View File

@ -0,0 +1,454 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# Copyright 2012 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Matcher classes to be used inside of the testtools assertThat framework."""
import pprint
from lxml import etree
class DictKeysMismatch(object):
def __init__(self, d1only, d2only):
self.d1only = d1only
self.d2only = d2only
def describe(self):
return ('Keys in d1 and not d2: %(d1only)s.'
' Keys in d2 and not d1: %(d2only)s' % self.__dict__)
def get_details(self):
return {}
class DictMismatch(object):
def __init__(self, key, d1_value, d2_value):
self.key = key
self.d1_value = d1_value
self.d2_value = d2_value
def describe(self):
return ("Dictionaries do not match at %(key)s."
" d1: %(d1_value)s d2: %(d2_value)s" % self.__dict__)
def get_details(self):
return {}
class DictMatches(object):
def __init__(self, d1, approx_equal=False, tolerance=0.001):
self.d1 = d1
self.approx_equal = approx_equal
self.tolerance = tolerance
def __str__(self):
return 'DictMatches(%s)' % (pprint.pformat(self.d1))
# Useful assertions
def match(self, d2):
"""Assert two dicts are equivalent.
This is a 'deep' match in the sense that it handles nested
dictionaries appropriately.
NOTE:
If you don't care (or don't know) a given value, you can specify
the string DONTCARE as the value. This will cause that dict-item
to be skipped.
"""
d1keys = set(self.d1.keys())
d2keys = set(d2.keys())
if d1keys != d2keys:
d1only = d1keys - d2keys
d2only = d2keys - d1keys
return DictKeysMismatch(d1only, d2only)
for key in d1keys:
d1value = self.d1[key]
d2value = d2[key]
try:
error = abs(float(d1value) - float(d2value))
within_tolerance = error <= self.tolerance
except (ValueError, TypeError):
# If both values aren't convertible to float, just ignore
# ValueError if arg is a str, TypeError if it's something else
# (like None)
within_tolerance = False
if hasattr(d1value, 'keys') and hasattr(d2value, 'keys'):
matcher = DictMatches(d1value)
did_match = matcher.match(d2value)
if did_match is not None:
return did_match
elif 'DONTCARE' in (d1value, d2value):
continue
elif self.approx_equal and within_tolerance:
continue
elif d1value != d2value:
return DictMismatch(key, d1value, d2value)
class ListLengthMismatch(object):
def __init__(self, len1, len2):
self.len1 = len1
self.len2 = len2
def describe(self):
return ('Length mismatch: len(L1)=%(len1)d != '
'len(L2)=%(len2)d' % self.__dict__)
def get_details(self):
return {}
class DictListMatches(object):
def __init__(self, l1, approx_equal=False, tolerance=0.001):
self.l1 = l1
self.approx_equal = approx_equal
self.tolerance = tolerance
def __str__(self):
return 'DictListMatches(%s)' % (pprint.pformat(self.l1))
# Useful assertions
def match(self, l2):
"""Assert a list of dicts are equivalent."""
l1count = len(self.l1)
l2count = len(l2)
if l1count != l2count:
return ListLengthMismatch(l1count, l2count)
for d1, d2 in zip(self.l1, l2):
matcher = DictMatches(d2,
approx_equal=self.approx_equal,
tolerance=self.tolerance)
did_match = matcher.match(d1)
if did_match:
return did_match
class SubDictMismatch(object):
def __init__(self,
key=None,
sub_value=None,
super_value=None,
keys=False):
self.key = key
self.sub_value = sub_value
self.super_value = super_value
self.keys = keys
def describe(self):
if self.keys:
return "Keys between dictionaries did not match"
else:
return("Dictionaries do not match at %s. d1: %s d2: %s"
% (self.key,
self.super_value,
self.sub_value))
def get_details(self):
return {}
class IsSubDictOf(object):
def __init__(self, super_dict):
self.super_dict = super_dict
def __str__(self):
return 'IsSubDictOf(%s)' % (self.super_dict)
def match(self, sub_dict):
"""Assert a sub_dict is subset of super_dict."""
if not set(sub_dict.keys()).issubset(set(self.super_dict.keys())):
return SubDictMismatch(keys=True)
for k, sub_value in sub_dict.items():
super_value = self.super_dict[k]
if isinstance(sub_value, dict):
matcher = IsSubDictOf(super_value)
did_match = matcher.match(sub_value)
if did_match is not None:
return did_match
elif 'DONTCARE' in (sub_value, super_value):
continue
else:
if sub_value != super_value:
return SubDictMismatch(k, sub_value, super_value)
class FunctionCallMatcher(object):
def __init__(self, expected_func_calls):
self.expected_func_calls = expected_func_calls
self.actual_func_calls = []
def call(self, *args, **kwargs):
func_call = {'args': args, 'kwargs': kwargs}
self.actual_func_calls.append(func_call)
def match(self):
dict_list_matcher = DictListMatches(self.expected_func_calls)
return dict_list_matcher.match(self.actual_func_calls)
class XMLMismatch(object):
"""Superclass for XML mismatch."""
def __init__(self, state):
self.path = str(state)
self.expected = state.expected
self.actual = state.actual
def describe(self):
return "%(path)s: XML does not match" % self.__dict__
def get_details(self):
return {
'expected': self.expected,
'actual': self.actual,
}
class XMLTagMismatch(XMLMismatch):
"""XML tags don't match."""
def __init__(self, state, idx, expected_tag, actual_tag):
super(XMLTagMismatch, self).__init__(state)
self.idx = idx
self.expected_tag = expected_tag
self.actual_tag = actual_tag
def describe(self):
return ("%(path)s: XML tag mismatch at index %(idx)d: "
"expected tag <%(expected_tag)s>; "
"actual tag <%(actual_tag)s>" % self.__dict__)
class XMLAttrKeysMismatch(XMLMismatch):
"""XML attribute keys don't match."""
def __init__(self, state, expected_only, actual_only):
super(XMLAttrKeysMismatch, self).__init__(state)
self.expected_only = ', '.join(sorted(expected_only))
self.actual_only = ', '.join(sorted(actual_only))
def describe(self):
return ("%(path)s: XML attributes mismatch: "
"keys only in expected: %(expected_only)s; "
"keys only in actual: %(actual_only)s" % self.__dict__)
class XMLAttrValueMismatch(XMLMismatch):
"""XML attribute values don't match."""
def __init__(self, state, key, expected_value, actual_value):
super(XMLAttrValueMismatch, self).__init__(state)
self.key = key
self.expected_value = expected_value
self.actual_value = actual_value
def describe(self):
return ("%(path)s: XML attribute value mismatch: "
"expected value of attribute %(key)s: %(expected_value)r; "
"actual value: %(actual_value)r" % self.__dict__)
class XMLTextValueMismatch(XMLMismatch):
"""XML text values don't match."""
def __init__(self, state, expected_text, actual_text):
super(XMLTextValueMismatch, self).__init__(state)
self.expected_text = expected_text
self.actual_text = actual_text
def describe(self):
return ("%(path)s: XML text value mismatch: "
"expected text value: %(expected_text)r; "
"actual value: %(actual_text)r" % self.__dict__)
class XMLUnexpectedChild(XMLMismatch):
"""Unexpected child present in XML."""
def __init__(self, state, tag, idx):
super(XMLUnexpectedChild, self).__init__(state)
self.tag = tag
self.idx = idx
def describe(self):
return ("%(path)s: XML unexpected child element <%(tag)s> "
"present at index %(idx)d" % self.__dict__)
class XMLExpectedChild(XMLMismatch):
"""Expected child not present in XML."""
def __init__(self, state, tag, idx):
super(XMLExpectedChild, self).__init__(state)
self.tag = tag
self.idx = idx
def describe(self):
return ("%(path)s: XML expected child element <%(tag)s> "
"not present at index %(idx)d" % self.__dict__)
class XMLMatchState(object):
"""
Maintain some state for matching.
Tracks the XML node path and saves the expected and actual full
XML text, for use by the XMLMismatch subclasses.
"""
def __init__(self, expected, actual):
self.path = []
self.expected = expected
self.actual = actual
def __enter__(self):
pass
def __exit__(self, exc_type, exc_value, exc_tb):
self.path.pop()
return False
def __str__(self):
return '/' + '/'.join(self.path)
def node(self, tag, idx):
"""
Adds tag and index to the path; they will be popped off when
the corresponding 'with' statement exits.
:param tag: The element tag
:param idx: If not None, the integer index of the element
within its parent. Not included in the path
element if None.
"""
if idx is not None:
self.path.append("%s[%d]" % (tag, idx))
else:
self.path.append(tag)
return self
class XMLMatches(object):
"""Compare XML strings. More complete than string comparison."""
def __init__(self, expected):
self.expected_xml = expected
self.expected = etree.fromstring(expected)
def __str__(self):
return 'XMLMatches(%r)' % self.expected_xml
def match(self, actual_xml):
actual = etree.fromstring(actual_xml)
state = XMLMatchState(self.expected_xml, actual_xml)
result = self._compare_node(self.expected, actual, state, None)
if result is False:
return XMLMismatch(state)
elif result is not True:
return result
def _compare_node(self, expected, actual, state, idx):
"""Recursively compares nodes within the XML tree."""
# Start by comparing the tags
if expected.tag != actual.tag:
return XMLTagMismatch(state, idx, expected.tag, actual.tag)
with state.node(expected.tag, idx):
# Compare the attribute keys
expected_attrs = set(expected.attrib.keys())
actual_attrs = set(actual.attrib.keys())
if expected_attrs != actual_attrs:
expected_only = expected_attrs - actual_attrs
actual_only = actual_attrs - expected_attrs
return XMLAttrKeysMismatch(state, expected_only, actual_only)
# Compare the attribute values
for key in expected_attrs:
expected_value = expected.attrib[key]
actual_value = actual.attrib[key]
if 'DONTCARE' in (expected_value, actual_value):
continue
elif expected_value != actual_value:
return XMLAttrValueMismatch(state, key, expected_value,
actual_value)
# Compare the contents of the node
if len(expected) == 0 and len(actual) == 0:
# No children, compare text values
if ('DONTCARE' not in (expected.text, actual.text) and
expected.text != actual.text):
return XMLTextValueMismatch(state, expected.text,
actual.text)
else:
expected_idx = 0
actual_idx = 0
while (expected_idx < len(expected) and
actual_idx < len(actual)):
# Ignore comments and processing instructions
# TODO(Vek): may interpret PIs in the future, to
# allow for, say, arbitrary ordering of some
# elements
if (expected[expected_idx].tag in
(etree.Comment, etree.ProcessingInstruction)):
expected_idx += 1
continue
# Compare the nodes
result = self._compare_node(expected[expected_idx],
actual[actual_idx], state,
actual_idx)
if result is not True:
return result
# Step on to comparing the next nodes...
expected_idx += 1
actual_idx += 1
# Make sure we consumed all nodes in actual
if actual_idx < len(actual):
return XMLUnexpectedChild(state, actual[actual_idx].tag,
actual_idx)
# Make sure we consumed all nodes in expected
if expected_idx < len(expected):
for node in expected[expected_idx:]:
if (node.tag in
(etree.Comment, etree.ProcessingInstruction)):
continue
return XMLExpectedChild(state, node.tag, actual_idx)
# The nodes match
return True

View File

@ -111,6 +111,7 @@ if __name__ == "__main__":
# Initialize the test configuration. # Initialize the test configuration.
CONFIG.load_from_file('etc/tests/localhost.test.conf') CONFIG.load_from_file('etc/tests/localhost.test.conf')
from reddwarf.tests.api import limits
from reddwarf.tests.api import flavors from reddwarf.tests.api import flavors
from reddwarf.tests.api import versions from reddwarf.tests.api import versions
from reddwarf.tests.api import instances from reddwarf.tests.api import instances

View File

@ -17,3 +17,4 @@ testtools>=0.9.22
pexpect pexpect
discover discover
testrepository>=0.0.8 testrepository>=0.0.8
mockito