Adds xml support for quantum v2 API.

Fixes bug 1007998.
blueprint quantum-v2-api-xml

Author: gongysh <gongysh@linux.vnet.ibm.com>

Change-Id: I5f5407e767f8203f980b77075109845bc1553ed9
This commit is contained in:
gongysh 2012-08-05 20:04:24 -03:00 committed by Gerrit Code Review
parent 5840627a48
commit 3f013e370d
27 changed files with 1301 additions and 909 deletions

View File

@ -25,11 +25,12 @@ import routes
import webob.dec import webob.dec
import webob.exc import webob.exc
from quantum.api.v2 import attributes
from quantum.common import constants
from quantum.common import exceptions from quantum.common import exceptions
import quantum.extensions import quantum.extensions
from quantum.manager import QuantumManager from quantum.manager import QuantumManager
from quantum.openstack.common import cfg from quantum.openstack.common import cfg
from quantum.openstack.common import importutils
from quantum.openstack.common import log as logging from quantum.openstack.common import log as logging
from quantum import wsgi from quantum import wsgi
@ -436,6 +437,8 @@ class ExtensionManager(object):
attr_map[resource].update(resource_attrs) attr_map[resource].update(resource_attrs)
else: else:
attr_map[resource] = resource_attrs attr_map[resource] = resource_attrs
if extended_attrs:
attributes.EXT_NSES[ext.get_alias()] = ext.get_namespace()
except AttributeError: except AttributeError:
LOG.exception(_("Error fetching extended attributes for " LOG.exception(_("Error fetching extended attributes for "
"extension '%s'"), ext.get_name()) "extension '%s'"), ext.get_name())
@ -536,7 +539,6 @@ class PluginAwareExtensionManager(ExtensionManager):
"supported_extension_aliases") and "supported_extension_aliases") and
alias in plugin.supported_extension_aliases) alias in plugin.supported_extension_aliases)
for plugin in self.plugins.values()) for plugin in self.plugins.values())
plugin_provider = cfg.CONF.core_plugin
if not supports_extension: if not supports_extension:
LOG.warn(_("Extension %s not supported by any of loaded plugins"), LOG.warn(_("Extension %s not supported by any of loaded plugins"),
alias) alias)

View File

@ -18,6 +18,7 @@
import netaddr import netaddr
import re import re
from quantum.common import constants
from quantum.common import exceptions as q_exc from quantum.common import exceptions as q_exc
from quantum.openstack.common import log as logging from quantum.openstack.common import log as logging
from quantum.openstack.common import uuidutils from quantum.openstack.common import uuidutils
@ -536,3 +537,19 @@ RESOURCE_HIERARCHY_MAP = {
'ports': {'parent': 'networks', 'identified_by': 'network_id'}, 'ports': {'parent': 'networks', 'identified_by': 'network_id'},
'subnets': {'parent': 'networks', 'identified_by': 'network_id'} 'subnets': {'parent': 'networks', 'identified_by': 'network_id'}
} }
PLURALS = {'networks': 'network',
'ports': 'port',
'subnets': 'subnet',
'dns_nameservers': 'dns_nameserver',
'host_routes': 'host_route',
'allocation_pools': 'allocation_pool',
'fixed_ips': 'fixed_ip',
'extensions': 'extension'}
EXT_NSES = {}
def get_attr_metadata():
return {'plurals': PLURALS,
'xmlns': constants.XML_NS_V20,
constants.EXT_NS: EXT_NSES}

View File

@ -26,8 +26,8 @@ from quantum.openstack.common.notifier import api as notifier_api
from quantum import policy from quantum import policy
from quantum import quota from quantum import quota
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
XML_NS_V20 = 'http://openstack.org/quantum/api/v2.0'
FAULT_MAP = {exceptions.NotFound: webob.exc.HTTPNotFound, FAULT_MAP = {exceptions.NotFound: webob.exc.HTTPNotFound,
exceptions.Conflict: webob.exc.HTTPConflict, exceptions.Conflict: webob.exc.HTTPConflict,
@ -527,13 +527,4 @@ def create_resource(collection, resource, plugin, params, allow_bulk=False,
controller = Controller(plugin, collection, resource, params, allow_bulk, controller = Controller(plugin, collection, resource, params, allow_bulk,
member_actions=member_actions, parent=parent) member_actions=member_actions, parent=parent)
# NOTE(jkoelker) To anyone wishing to add "proper" xml support return wsgi_resource.Resource(controller, FAULT_MAP)
# this is where you do it
serializers = {}
# 'application/xml': wsgi.XMLDictSerializer(metadata, XML_NS_V20),
deserializers = {}
# 'application/xml': wsgi.XMLDeserializer(metadata),
return wsgi_resource.Resource(controller, FAULT_MAP, deserializers,
serializers)

View File

@ -18,13 +18,11 @@ Utility methods for working with WSGI servers redux
""" """
import netaddr import netaddr
import webob
import webob.dec import webob.dec
import webob.exc import webob.exc
from quantum.api.v2 import attributes
from quantum.common import exceptions from quantum.common import exceptions
from quantum import context
from quantum.openstack.common import jsonutils as json
from quantum.openstack.common import log as logging from quantum.openstack.common import log as logging
from quantum import wsgi from quantum import wsgi
@ -32,31 +30,20 @@ from quantum import wsgi
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class Request(webob.Request): class Request(wsgi.Request):
"""Add some Openstack API-specific logic to the base webob.Request.""" pass
def best_match_content_type(self):
supported = ('application/json', )
return self.accept.best_match(supported,
default_match='application/json')
@property
def context(self):
#Eventually the Auth[NZ] code will supply this. (mdragon)
#when that happens this if block should raise instead.
if 'quantum.context' not in self.environ:
self.environ['quantum.context'] = context.get_admin_context()
return self.environ['quantum.context']
def Resource(controller, faults=None, deserializers=None, serializers=None): def Resource(controller, faults=None, deserializers=None, serializers=None):
"""Represents an API entity resource and the associated serialization and """Represents an API entity resource and the associated serialization and
deserialization logic deserialization logic
""" """
default_deserializers = {'application/xml': wsgi.XMLDeserializer(), xml_deserializer = wsgi.XMLDeserializer(attributes.get_attr_metadata())
'application/json': lambda x: json.loads(x)} default_deserializers = {'application/xml': xml_deserializer,
default_serializers = {'application/xml': wsgi.XMLDictSerializer(), 'application/json': wsgi.JSONDeserializer()}
'application/json': lambda x: json.dumps(x)} xml_serializer = wsgi.XMLDictSerializer(attributes.get_attr_metadata())
default_serializers = {'application/xml': xml_serializer,
'application/json': wsgi.JSONDictSerializer()}
format_types = {'xml': 'application/xml', format_types = {'xml': 'application/xml',
'json': 'application/json'} 'json': 'application/json'}
action_status = dict(create=201, delete=204) action_status = dict(create=201, delete=204)
@ -81,7 +68,6 @@ def Resource(controller, faults=None, deserializers=None, serializers=None):
args.pop('controller', None) args.pop('controller', None)
fmt = args.pop('format', None) fmt = args.pop('format', None)
action = args.pop('action', None) action = args.pop('action', None)
content_type = format_types.get(fmt, content_type = format_types.get(fmt,
request.best_match_content_type()) request.best_match_content_type())
deserializer = deserializers.get(content_type) deserializer = deserializers.get(content_type)
@ -89,7 +75,7 @@ def Resource(controller, faults=None, deserializers=None, serializers=None):
try: try:
if request.body: if request.body:
args['body'] = deserializer(request.body) args['body'] = deserializer.deserialize(request.body)['body']
method = getattr(controller, action) method = getattr(controller, action)
@ -98,7 +84,7 @@ def Resource(controller, faults=None, deserializers=None, serializers=None):
exceptions.QuantumException, exceptions.QuantumException,
netaddr.AddrFormatError) as e: netaddr.AddrFormatError) as e:
LOG.exception(_('%s failed'), action) LOG.exception(_('%s failed'), action)
body = serializer({'QuantumError': str(e)}) body = serializer.serialize({'QuantumError': str(e)})
kwargs = {'body': body, 'content_type': content_type} kwargs = {'body': body, 'content_type': content_type}
for fault in faults: for fault in faults:
if isinstance(e, fault): if isinstance(e, fault):
@ -106,7 +92,7 @@ def Resource(controller, faults=None, deserializers=None, serializers=None):
raise webob.exc.HTTPInternalServerError(**kwargs) raise webob.exc.HTTPInternalServerError(**kwargs)
except webob.exc.HTTPException as e: except webob.exc.HTTPException as e:
LOG.exception(_('%s failed'), action) LOG.exception(_('%s failed'), action)
e.body = serializer({'QuantumError': str(e)}) e.body = serializer.serialize({'QuantumError': str(e)})
e.content_type = content_type e.content_type = content_type
raise raise
except Exception as e: except Exception as e:
@ -115,12 +101,12 @@ def Resource(controller, faults=None, deserializers=None, serializers=None):
# Do not expose details of 500 error to clients. # Do not expose details of 500 error to clients.
msg = _('Request Failed: internal server error while ' msg = _('Request Failed: internal server error while '
'processing your request.') 'processing your request.')
body = serializer({'QuantumError': msg}) body = serializer.serialize({'QuantumError': msg})
kwargs = {'body': body, 'content_type': content_type} kwargs = {'body': body, 'content_type': content_type}
raise webob.exc.HTTPInternalServerError(**kwargs) raise webob.exc.HTTPInternalServerError(**kwargs)
status = action_status.get(action, 200) status = action_status.get(action, 200)
body = serializer(result) body = serializer.serialize(result)
# NOTE(jkoelker) Comply with RFC2616 section 9.7 # NOTE(jkoelker) Comply with RFC2616 section 9.7
if status == 204: if status == 204:
content_type = '' content_type = ''

View File

@ -33,3 +33,19 @@ INTERFACE_KEY = '_interfaces'
IPv4 = 'IPv4' IPv4 = 'IPv4'
IPv6 = 'IPv6' IPv6 = 'IPv6'
EXT_NS = '_extension_ns'
XML_NS_V20 = 'http://openstack.org/quantum/api/v2.0'
XSI_NAMESPACE = "http://www.w3.org/2001/XMLSchema-instance"
XSI_ATTR = "xsi:nil"
XSI_NIL_ATTR = "xmlns:xsi"
TYPE_XMLNS = "xmlns:quantum"
TYPE_ATTR = "quantum:type"
VIRTUAL_ROOT_KEY = "_v_root"
TYPE_BOOL = "bool"
TYPE_INT = "int"
TYPE_LONG = "long"
TYPE_FLOAT = "float"
TYPE_LIST = "list"
TYPE_DICT = "dict"

View File

@ -103,7 +103,7 @@ class DbQuotaDriver(object):
tenant_quota[quota['resource']] = quota['limit'] tenant_quota[quota['resource']] = quota['limit']
return all_tenant_quotas.itervalues() return all_tenant_quotas.values()
@staticmethod @staticmethod
def update_quota_limit(context, tenant_id, resource, limit): def update_quota_limit(context, tenant_id, resource, limit):

View File

@ -185,6 +185,8 @@ class L3(extensions.ExtensionDescriptor):
@classmethod @classmethod
def get_resources(cls): def get_resources(cls):
""" Returns Ext Resources """ """ Returns Ext Resources """
my_plurals = [(key, key[:-1]) for key in RESOURCE_ATTRIBUTE_MAP.keys()]
attr.PLURALS.update(dict(my_plurals))
exts = [] exts = []
plugin = manager.QuantumManager.get_plugin() plugin = manager.QuantumManager.get_plugin()
for resource_name in ['router', 'floatingip']: for resource_name in ['router', 'floatingip']:

View File

@ -257,6 +257,8 @@ class Loadbalancer(extensions.ExtensionDescriptor):
@classmethod @classmethod
def get_resources(cls): def get_resources(cls):
my_plurals = [(key, key[:-1]) for key in RESOURCE_ATTRIBUTE_MAP.keys()]
attr.PLURALS.update(dict(my_plurals))
resources = [] resources = []
plugin = manager.QuantumManager.get_service_plugins()[ plugin = manager.QuantumManager.get_service_plugins()[
constants.LOADBALANCER] constants.LOADBALANCER]

View File

@ -51,7 +51,8 @@ class QuotaSetsController(wsgi.Controller):
self._driver = importutils.import_class(DB_QUOTA_DRIVER) self._driver = importutils.import_class(DB_QUOTA_DRIVER)
def _get_body(self, request): def _get_body(self, request):
body = self._deserialize(request.body, request.get_content_type()) body = self._deserialize(request.body,
request.best_match_content_type())
attr_info = EXTENDED_ATTRIBUTES_2_0[RESOURCE_COLLECTION] attr_info = EXTENDED_ATTRIBUTES_2_0[RESOURCE_COLLECTION]
req_body = base.Controller.prepare_request_body( req_body = base.Controller.prepare_request_body(
request.context, body, False, self._resource_name, attr_info) request.context, body, False, self._resource_name, attr_info)

View File

@ -289,6 +289,8 @@ class Securitygroup(extensions.ExtensionDescriptor):
@classmethod @classmethod
def get_resources(cls): def get_resources(cls):
""" Returns Ext Resources """ """ Returns Ext Resources """
my_plurals = [(key, key[:-1]) for key in RESOURCE_ATTRIBUTE_MAP.keys()]
attr.PLURALS.update(dict(my_plurals))
exts = [] exts = []
plugin = manager.QuantumManager.get_plugin() plugin = manager.QuantumManager.get_plugin()
for resource_name in ['security_group', 'security_group_rule']: for resource_name in ['security_group', 'security_group_rule']:

View File

@ -183,6 +183,11 @@ class Servicetype(extensions.ExtensionDescriptor):
@classmethod @classmethod
def get_resources(cls): def get_resources(cls):
""" Returns Extended Resource for service type management """ """ Returns Extended Resource for service type management """
my_plurals = [(key.replace('-', '_'),
key[:-1].replace('-', '_')) for
key in RESOURCE_ATTRIBUTE_MAP.keys()]
my_plurals.append(('service_definitions', 'service_definition'))
attributes.PLURALS.update(dict(my_plurals))
controller = base.create_resource( controller = base.create_resource(
COLLECTION_NAME, COLLECTION_NAME,
RESOURCE_NAME, RESOURCE_NAME,

View File

@ -87,7 +87,7 @@ class TestCiscoPortsV2(CiscoNetworkPluginV2TestCase,
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
with self.network() as net: with self.network() as net:
res = self._create_port_bulk('json', 2, res = self._create_port_bulk(self.fmt, 2,
net['network']['id'], net['network']['id'],
'test', 'test',
True) True)
@ -109,7 +109,7 @@ class TestCiscoPortsV2(CiscoNetworkPluginV2TestCase,
*args, **kwargs) *args, **kwargs)
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
res = self._create_port_bulk('json', 2, net['network']['id'], res = self._create_port_bulk(self.fmt, 2, net['network']['id'],
'test', True, context=ctx) 'test', True, context=ctx)
# We expect a 500 as we injected a fault in the plugin # We expect a 500 as we injected a fault in the plugin
self._validate_behavior_on_bulk_failure(res, 'ports') self._validate_behavior_on_bulk_failure(res, 'ports')
@ -137,7 +137,7 @@ class TestCiscoNetworksV2(CiscoNetworkPluginV2TestCase,
return self._do_side_effect(patched_plugin, orig, return self._do_side_effect(patched_plugin, orig,
*args, **kwargs) *args, **kwargs)
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
res = self._create_network_bulk('json', 2, 'test', True) res = self._create_network_bulk(self.fmt, 2, 'test', True)
LOG.debug("response is %s" % res) LOG.debug("response is %s" % res)
# We expect a 500 as we injected a fault in the plugin # We expect a 500 as we injected a fault in the plugin
self._validate_behavior_on_bulk_failure(res, 'networks') self._validate_behavior_on_bulk_failure(res, 'networks')
@ -155,7 +155,7 @@ class TestCiscoNetworksV2(CiscoNetworkPluginV2TestCase,
*args, **kwargs) *args, **kwargs)
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
res = self._create_network_bulk('json', 2, 'test', True) res = self._create_network_bulk(self.fmt, 2, 'test', True)
# We expect a 500 as we injected a fault in the plugin # We expect a 500 as we injected a fault in the plugin
self._validate_behavior_on_bulk_failure(res, 'networks') self._validate_behavior_on_bulk_failure(res, 'networks')
@ -185,7 +185,7 @@ class TestCiscoSubnetsV2(CiscoNetworkPluginV2TestCase,
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
with self.network() as net: with self.network() as net:
res = self._create_subnet_bulk('json', 2, res = self._create_subnet_bulk(self.fmt, 2,
net['network']['id'], net['network']['id'],
'test') 'test')
# We expect a 500 as we injected a fault in the plugin # We expect a 500 as we injected a fault in the plugin
@ -204,9 +204,21 @@ class TestCiscoSubnetsV2(CiscoNetworkPluginV2TestCase,
patched_plugin.side_effect = side_effect patched_plugin.side_effect = side_effect
with self.network() as net: with self.network() as net:
res = self._create_subnet_bulk('json', 2, res = self._create_subnet_bulk(self.fmt, 2,
net['network']['id'], net['network']['id'],
'test') 'test')
# We expect a 500 as we injected a fault in the plugin # We expect a 500 as we injected a fault in the plugin
self._validate_behavior_on_bulk_failure(res, 'subnets') self._validate_behavior_on_bulk_failure(res, 'subnets')
class TestCiscoPortsV2XML(TestCiscoPortsV2):
fmt = 'xml'
class TestCiscoNetworksV2XML(TestCiscoNetworksV2):
fmt = 'xml'
class TestCiscoSubnetsV2XML(TestCiscoSubnetsV2):
fmt = 'xml'

View File

@ -17,30 +17,24 @@ import contextlib
import logging import logging
import os import os
import unittest2
import webob.exc import webob.exc
import quantum
from quantum.api.extensions import ExtensionMiddleware from quantum.api.extensions import ExtensionMiddleware
from quantum.api.extensions import PluginAwareExtensionManager from quantum.api.extensions import PluginAwareExtensionManager
from quantum.api.v2 import attributes from quantum.api.v2 import attributes
from quantum.api.v2.router import APIRouter from quantum.api.v2.router import APIRouter
from quantum.common import config from quantum.common import config
from quantum.common import exceptions as q_exc
from quantum.common.test_lib import test_config from quantum.common.test_lib import test_config
from quantum import context
from quantum.db import api as db from quantum.db import api as db
from quantum.db import db_base_plugin_v2 import quantum.extensions
from quantum.db import models_v2
from quantum.extensions import loadbalancer from quantum.extensions import loadbalancer
from quantum.manager import QuantumManager from quantum.manager import QuantumManager
from quantum.openstack.common import cfg from quantum.openstack.common import cfg
from quantum.openstack.common import timeutils
from quantum.plugins.common import constants from quantum.plugins.common import constants
from quantum.plugins.services.loadbalancer import loadbalancerPlugin from quantum.plugins.services.loadbalancer import loadbalancerPlugin
from quantum.tests.unit import test_extensions from quantum.tests.unit import testlib_api
from quantum.tests.unit.testlib_api import create_request from quantum.tests.unit.testlib_api import create_request
from quantum.wsgi import Serializer, JSONDeserializer from quantum import wsgi
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -60,7 +54,7 @@ def etcdir(*p):
return os.path.join(ETCDIR, *p) return os.path.join(ETCDIR, *p)
class LoadBalancerPluginDbTestCase(unittest2.TestCase): class LoadBalancerPluginDbTestCase(testlib_api.WebTestCase):
def setUp(self, core_plugin=None, lb_plugin=None): def setUp(self, core_plugin=None, lb_plugin=None):
super(LoadBalancerPluginDbTestCase, self).setUp() super(LoadBalancerPluginDbTestCase, self).setUp()
@ -75,11 +69,6 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
self._tenant_id = "test-tenant" self._tenant_id = "test-tenant"
self._subnet_id = "0c798ed8-33ba-11e2-8b28-000c291c4d14" self._subnet_id = "0c798ed8-33ba-11e2-8b28-000c291c4d14"
json_deserializer = JSONDeserializer()
self._deserializers = {
'application/json': json_deserializer,
}
if not core_plugin: if not core_plugin:
core_plugin = test_config.get('plugin_name_v2', core_plugin = test_config.get('plugin_name_v2',
DB_CORE_PLUGIN_KLASS) DB_CORE_PLUGIN_KLASS)
@ -103,11 +92,11 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
) )
app = config.load_paste_app('extensions_test_app') app = config.load_paste_app('extensions_test_app')
self.ext_api = ExtensionMiddleware(app, ext_mgr=ext_mgr) self.ext_api = ExtensionMiddleware(app, ext_mgr=ext_mgr)
super(LoadBalancerPluginDbTestCase, self).setUp()
def tearDown(self): def tearDown(self):
super(LoadBalancerPluginDbTestCase, self).tearDown() super(LoadBalancerPluginDbTestCase, self).tearDown()
self.api = None self.api = None
self._deserializers = None
self._skip_native_bulk = None self._skip_native_bulk = None
self.ext_api = None self.ext_api = None
@ -118,8 +107,10 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
# Restore the original attribute map # Restore the original attribute map
loadbalancer.RESOURCE_ATTRIBUTE_MAP = self._attribute_map_bk loadbalancer.RESOURCE_ATTRIBUTE_MAP = self._attribute_map_bk
def _req(self, method, resource, data=None, fmt='json', def _req(self, method, resource, data=None, fmt=None,
id=None, subresource=None, sub_id=None, params=None, action=None): id=None, subresource=None, sub_id=None, params=None, action=None):
if not fmt:
fmt = self.fmt
if id and action: if id and action:
path = '/lb/%(resource)s/%(id)s/%(action)s.%(fmt)s' % locals() path = '/lb/%(resource)s/%(id)s/%(action)s.%(fmt)s' % locals()
elif id and subresource and sub_id: elif id and subresource and sub_id:
@ -138,7 +129,8 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
content_type = 'application/%s' % fmt content_type = 'application/%s' % fmt
body = None body = None
if data is not None: # empty dict is valid if data is not None: # empty dict is valid
body = Serializer().serialize(data, content_type) body = wsgi.Serializer(
attributes.get_attr_metadata()).serialize(data, content_type)
req = create_request(path, req = create_request(path,
body, body,
@ -147,32 +139,27 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
query_string=params) query_string=params)
return req return req
def new_create_request(self, resource, data, fmt='json', id=None, def new_create_request(self, resource, data, fmt=None, id=None,
subresource=None): subresource=None):
return self._req('POST', resource, data, fmt, id=id, return self._req('POST', resource, data, fmt, id=id,
subresource=subresource) subresource=subresource)
def new_list_request(self, resource, fmt='json', params=None): def new_list_request(self, resource, fmt=None, params=None):
return self._req('GET', resource, None, fmt, params=params) return self._req('GET', resource, None, fmt, params=params)
def new_show_request(self, resource, id, fmt='json', action=None, def new_show_request(self, resource, id, fmt=None, action=None,
subresource=None, sub_id=None): subresource=None, sub_id=None):
return self._req('GET', resource, None, fmt, id=id, action=action, return self._req('GET', resource, None, fmt, id=id, action=action,
subresource=subresource, sub_id=sub_id) subresource=subresource, sub_id=sub_id)
def new_delete_request(self, resource, id, fmt='json', def new_delete_request(self, resource, id, fmt=None,
subresource=None, sub_id=None): subresource=None, sub_id=None):
return self._req('DELETE', resource, None, fmt, id=id, return self._req('DELETE', resource, None, fmt, id=id,
subresource=subresource, sub_id=sub_id) subresource=subresource, sub_id=sub_id)
def new_update_request(self, resource, data, id, fmt='json'): def new_update_request(self, resource, data, id, fmt=None):
return self._req('PUT', resource, data, fmt, id=id) return self._req('PUT', resource, data, fmt, id=id)
def deserialize(self, content_type, response):
ctype = 'application/%s' % content_type
data = self._deserializers[ctype].deserialize(response.body)['body']
return data
def _create_vip(self, fmt, name, pool_id, protocol, port, admin_status_up, def _create_vip(self, fmt, name, pool_id, protocol, port, admin_status_up,
expected_res_status=None, **kwargs): expected_res_status=None, **kwargs):
data = {'vip': {'name': name, data = {'vip': {'name': name,
@ -267,25 +254,27 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
req = self.new_show_request(resource, id) req = self.new_show_request(resource, id)
res = req.get_response(self._api_for_resource(resource)) res = req.get_response(self._api_for_resource(resource))
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
return self.deserialize('json', res) return self.deserialize(res)
def _update(self, resource, id, new_data, def _update(self, resource, id, new_data,
expected_code=webob.exc.HTTPOk.code): expected_code=webob.exc.HTTPOk.code):
req = self.new_update_request(resource, new_data, id) req = self.new_update_request(resource, new_data, id)
res = req.get_response(self._api_for_resource(resource)) res = req.get_response(self._api_for_resource(resource))
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
return self.deserialize('json', res) return self.deserialize(res)
def _list(self, resource, fmt='json', query_params=None): def _list(self, resource, fmt=None, query_params=None):
req = self.new_list_request(resource, fmt, query_params) req = self.new_list_request(resource, fmt, query_params)
res = req.get_response(self._api_for_resource(resource)) res = req.get_response(self._api_for_resource(resource))
self.assertEqual(res.status_int, webob.exc.HTTPOk.code) self.assertEqual(res.status_int, webob.exc.HTTPOk.code)
return self.deserialize('json', res) return self.deserialize(res)
@contextlib.contextmanager @contextlib.contextmanager
def vip(self, fmt='json', name='vip1', pool=None, def vip(self, fmt=None, name='vip1', pool=None,
protocol='HTTP', port=80, admin_status_up=True, no_delete=False, protocol='HTTP', port=80, admin_status_up=True, no_delete=False,
address="172.16.1.123", **kwargs): address="172.16.1.123", **kwargs):
if not fmt:
fmt = self.fmt
if not pool: if not pool:
with self.pool() as pool: with self.pool() as pool:
pool_id = pool['pool']['id'] pool_id = pool['pool']['id']
@ -297,7 +286,7 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
admin_status_up, admin_status_up,
address=address, address=address,
**kwargs) **kwargs)
vip = self.deserialize(fmt, res) vip = self.deserialize(res)
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
yield vip yield vip
@ -313,7 +302,7 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
admin_status_up, admin_status_up,
address=address, address=address,
**kwargs) **kwargs)
vip = self.deserialize(fmt, res) vip = self.deserialize(res)
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
yield vip yield vip
@ -321,16 +310,18 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
self._delete('vips', vip['vip']['id']) self._delete('vips', vip['vip']['id'])
@contextlib.contextmanager @contextlib.contextmanager
def pool(self, fmt='json', name='pool1', lb_method='ROUND_ROBIN', def pool(self, fmt=None, name='pool1', lb_method='ROUND_ROBIN',
protocol='HTTP', admin_status_up=True, no_delete=False, protocol='HTTP', admin_status_up=True, no_delete=False,
**kwargs): **kwargs):
if not fmt:
fmt = self.fmt
res = self._create_pool(fmt, res = self._create_pool(fmt,
name, name,
lb_method, lb_method,
protocol, protocol,
admin_status_up, admin_status_up,
**kwargs) **kwargs)
pool = self.deserialize(fmt, res) pool = self.deserialize(res)
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
yield pool yield pool
@ -338,15 +329,17 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
self._delete('pools', pool['pool']['id']) self._delete('pools', pool['pool']['id'])
@contextlib.contextmanager @contextlib.contextmanager
def member(self, fmt='json', address='192.168.1.100', def member(self, fmt=None, address='192.168.1.100',
port=80, admin_status_up=True, no_delete=False, port=80, admin_status_up=True, no_delete=False,
**kwargs): **kwargs):
if not fmt:
fmt = self.fmt
res = self._create_member(fmt, res = self._create_member(fmt,
address, address,
port, port,
admin_status_up, admin_status_up,
**kwargs) **kwargs)
member = self.deserialize(fmt, res) member = self.deserialize(res)
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
yield member yield member
@ -354,10 +347,12 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
self._delete('members', member['member']['id']) self._delete('members', member['member']['id'])
@contextlib.contextmanager @contextlib.contextmanager
def health_monitor(self, fmt='json', type='TCP', def health_monitor(self, fmt=None, type='TCP',
delay=30, timeout=10, max_retries=3, delay=30, timeout=10, max_retries=3,
admin_status_up=True, admin_status_up=True,
no_delete=False, **kwargs): no_delete=False, **kwargs):
if not fmt:
fmt = self.fmt
res = self._create_health_monitor(fmt, res = self._create_health_monitor(fmt,
type, type,
delay, delay,
@ -365,7 +360,7 @@ class LoadBalancerPluginDbTestCase(unittest2.TestCase):
max_retries, max_retries,
admin_status_up, admin_status_up,
**kwargs) **kwargs)
health_monitor = self.deserialize(fmt, res) health_monitor = self.deserialize(res)
the_health_monitor = health_monitor['health_monitor'] the_health_monitor = health_monitor['health_monitor']
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
@ -452,7 +447,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
# Try resetting session_persistence # Try resetting session_persistence
req = self.new_update_request('vips', update_info, v['vip']['id']) req = self.new_update_request('vips', update_info, v['vip']['id'])
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
# If session persistence has been removed, it won't be present in # If session persistence has been removed, it won't be present in
# the response. # the response.
@ -476,7 +471,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
'cookie_name': "jesssionId"}, 'cookie_name': "jesssionId"},
'admin_state_up': False}} 'admin_state_up': False}}
req = self.new_update_request('vips', data, vip['vip']['id']) req = self.new_update_request('vips', data, vip['vip']['id'])
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['vip'][k], v) self.assertEqual(res['vip'][k], v)
@ -501,7 +496,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
with self.vip(name=name) as vip: with self.vip(name=name) as vip:
req = self.new_show_request('vips', req = self.new_show_request('vips',
vip['vip']['id']) vip['vip']['id'])
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['vip'][k], v) self.assertEqual(res['vip'][k], v)
@ -517,7 +512,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
('status', 'PENDING_CREATE')] ('status', 'PENDING_CREATE')]
with self.vip(name=name): with self.vip(name=name):
req = self.new_list_request('vips') req = self.new_list_request('vips')
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['vips'][0][k], v) self.assertEqual(res['vips'][0][k], v)
@ -548,7 +543,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
name = "pool2" name = "pool2"
with self.pool(name=name) as pool: with self.pool(name=name) as pool:
pool_id = pool['pool']['id'] pool_id = pool['pool']['id']
res1 = self._create_member('json', res1 = self._create_member(self.fmt,
'192.168.1.100', '192.168.1.100',
'80', '80',
True, True,
@ -556,11 +551,10 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
weight=1) weight=1)
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool_id, pool_id,
fmt='json') fmt=self.fmt)
pool_updated = self.deserialize('json', pool_updated = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api))
member1 = self.deserialize('json', res1) member1 = self.deserialize(res1)
self.assertEqual(member1['member']['id'], self.assertEqual(member1['member']['id'],
pool_updated['pool']['members'][0]) pool_updated['pool']['members'][0])
self.assertEqual(len(pool_updated['pool']['members']), 1) self.assertEqual(len(pool_updated['pool']['members']), 1)
@ -596,8 +590,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
with self.pool(name=name) as pool: with self.pool(name=name) as pool:
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool['pool']['id'], pool['pool']['id'],
fmt='json') fmt=self.fmt)
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['pool'][k], v) self.assertEqual(res['pool'][k], v)
@ -612,9 +606,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
pool_id=pool_id) as member2: pool_id=pool_id) as member2:
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool_id, pool_id,
fmt='json') fmt=self.fmt)
pool_update = self.deserialize( pool_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
self.assertIn(member1['member']['id'], self.assertIn(member1['member']['id'],
pool_update['pool']['members']) pool_update['pool']['members'])
@ -634,17 +627,15 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
with self.member(pool_id=pool1['pool']['id']) as member: with self.member(pool_id=pool1['pool']['id']) as member:
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool1['pool']['id'], pool1['pool']['id'],
fmt='json') fmt=self.fmt)
pool1_update = self.deserialize( pool1_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
self.assertEqual(len(pool1_update['pool']['members']), 1) self.assertEqual(len(pool1_update['pool']['members']), 1)
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool2['pool']['id'], pool2['pool']['id'],
fmt='json') fmt=self.fmt)
pool2_update = self.deserialize( pool2_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
self.assertEqual(len(pool1_update['pool']['members']), 1) self.assertEqual(len(pool1_update['pool']['members']), 1)
self.assertEqual(len(pool2_update['pool']['members']), 0) self.assertEqual(len(pool2_update['pool']['members']), 0)
@ -655,23 +646,20 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_update_request('members', req = self.new_update_request('members',
data, data,
member['member']['id']) member['member']['id'])
res = self.deserialize('json', res = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['member'][k], v) self.assertEqual(res['member'][k], v)
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool1['pool']['id'], pool1['pool']['id'],
fmt='json') fmt=self.fmt)
pool1_update = self.deserialize( pool1_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool2['pool']['id'], pool2['pool']['id'],
fmt='json') fmt=self.fmt)
pool2_update = self.deserialize( pool2_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
self.assertEqual(len(pool2_update['pool']['members']), 1) self.assertEqual(len(pool2_update['pool']['members']), 1)
@ -689,9 +677,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool_id, pool_id,
fmt='json') fmt=self.fmt)
pool_update = self.deserialize( pool_update = self.deserialize(
'json',
req.get_response(self.ext_api)) req.get_response(self.ext_api))
self.assertEqual(len(pool_update['pool']['members']), 0) self.assertEqual(len(pool_update['pool']['members']), 0)
@ -707,8 +694,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
with self.member(pool_id=pool['pool']['id']) as member: with self.member(pool_id=pool['pool']['id']) as member:
req = self.new_show_request('members', req = self.new_show_request('members',
member['member']['id'], member['member']['id'],
fmt='json') fmt=self.fmt)
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['member'][k], v) self.assertEqual(res['member'][k], v)
@ -740,7 +727,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_update_request("health_monitors", req = self.new_update_request("health_monitors",
data, data,
monitor['health_monitor']['id']) monitor['health_monitor']['id'])
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['health_monitor'][k], v) self.assertEqual(res['health_monitor'][k], v)
@ -762,8 +749,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
('status', 'PENDING_CREATE')] ('status', 'PENDING_CREATE')]
req = self.new_show_request('health_monitors', req = self.new_show_request('health_monitors',
monitor['health_monitor']['id'], monitor['health_monitor']['id'],
fmt='json') fmt=self.fmt)
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['health_monitor'][k], v) self.assertEqual(res['health_monitor'][k], v)
@ -776,8 +763,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_show_request("pools", req = self.new_show_request("pools",
pool['pool']['id'], pool['pool']['id'],
subresource="stats", subresource="stats",
fmt='json') fmt=self.fmt)
res = self.deserialize('json', req.get_response(self.ext_api)) res = self.deserialize(req.get_response(self.ext_api))
for k, v in keys: for k, v in keys:
self.assertEqual(res['stats'][k], v) self.assertEqual(res['stats'][k], v)
@ -791,7 +778,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_create_request( req = self.new_create_request(
"pools", "pools",
data, data,
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
subresource="health_monitors") subresource="health_monitors")
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
@ -803,7 +790,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_create_request( req = self.new_create_request(
"pools", "pools",
data, data,
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
subresource="health_monitors") subresource="health_monitors")
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
@ -812,9 +799,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_show_request( req = self.new_show_request(
'pools', 'pools',
pool['pool']['id'], pool['pool']['id'],
fmt='json') fmt=self.fmt)
res = self.deserialize('json', res = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api))
self.assertIn(monitor1['health_monitor']['id'], self.assertIn(monitor1['health_monitor']['id'],
res['pool']['health_monitors']) res['pool']['health_monitors'])
self.assertIn(monitor2['health_monitor']['id'], self.assertIn(monitor2['health_monitor']['id'],
@ -831,7 +817,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_create_request( req = self.new_create_request(
"pools", "pools",
data, data,
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
subresource="health_monitors") subresource="health_monitors")
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
@ -843,7 +829,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_create_request( req = self.new_create_request(
"pools", "pools",
data, data,
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
subresource="health_monitors") subresource="health_monitors")
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
@ -852,7 +838,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
# remove one of healthmonitor from the pool # remove one of healthmonitor from the pool
req = self.new_delete_request( req = self.new_delete_request(
"pools", "pools",
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
sub_id=monitor1['health_monitor']['id'], sub_id=monitor1['health_monitor']['id'],
subresource="health_monitors") subresource="health_monitors")
@ -862,9 +848,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_show_request( req = self.new_show_request(
'pools', 'pools',
pool['pool']['id'], pool['pool']['id'],
fmt='json') fmt=self.fmt)
res = self.deserialize('json', res = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api))
self.assertNotIn(monitor1['health_monitor']['id'], self.assertNotIn(monitor1['health_monitor']['id'],
res['pool']['health_monitors']) res['pool']['health_monitors'])
self.assertIn(monitor2['health_monitor']['id'], self.assertIn(monitor2['health_monitor']['id'],
@ -879,26 +864,26 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
pool_id = pool['pool']['id'] pool_id = pool['pool']['id']
vip_id = vip['vip']['id'] vip_id = vip['vip']['id']
# Add two members # Add two members
res1 = self._create_member('json', res1 = self._create_member(self.fmt,
'192.168.1.100', '192.168.1.100',
'80', '80',
True, True,
pool_id=pool_id, pool_id=pool_id,
weight=1) weight=1)
res2 = self._create_member('json', res2 = self._create_member(self.fmt,
'192.168.1.101', '192.168.1.101',
'80', '80',
True, True,
pool_id=pool_id, pool_id=pool_id,
weight=2) weight=2)
# Add a health_monitor # Add a health_monitor
req = self._create_health_monitor('json', req = self._create_health_monitor(self.fmt,
'HTTP', 'HTTP',
'10', '10',
'10', '10',
'3', '3',
True) True)
health_monitor = self.deserialize('json', req) health_monitor = self.deserialize(req)
self.assertEqual(req.status_int, 201) self.assertEqual(req.status_int, 201)
# Associate the health_monitor to the pool # Associate the health_monitor to the pool
@ -907,7 +892,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
'tenant_id': self._tenant_id}} 'tenant_id': self._tenant_id}}
req = self.new_create_request("pools", req = self.new_create_request("pools",
data, data,
fmt='json', fmt=self.fmt,
id=pool['pool']['id'], id=pool['pool']['id'],
subresource="health_monitors") subresource="health_monitors")
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
@ -916,11 +901,10 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
# Get pool and vip # Get pool and vip
req = self.new_show_request('pools', req = self.new_show_request('pools',
pool_id, pool_id,
fmt='json') fmt=self.fmt)
pool_updated = self.deserialize('json', pool_updated = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api)) member1 = self.deserialize(res1)
member1 = self.deserialize('json', res1) member2 = self.deserialize(res2)
member2 = self.deserialize('json', res2)
self.assertIn(member1['member']['id'], self.assertIn(member1['member']['id'],
pool_updated['pool']['members']) pool_updated['pool']['members'])
self.assertIn(member2['member']['id'], self.assertIn(member2['member']['id'],
@ -930,9 +914,8 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
req = self.new_show_request('vips', req = self.new_show_request('vips',
vip_id, vip_id,
fmt='json') fmt=self.fmt)
vip_updated = self.deserialize('json', vip_updated = self.deserialize(req.get_response(self.ext_api))
req.get_response(self.ext_api))
self.assertEqual(vip_updated['vip']['pool_id'], self.assertEqual(vip_updated['vip']['pool_id'],
pool_updated['pool']['id']) pool_updated['pool']['id'])
@ -941,3 +924,7 @@ class TestLoadBalancer(LoadBalancerPluginDbTestCase):
health_monitor['health_monitor']['id']) health_monitor['health_monitor']['id'])
self._delete('members', member1['member']['id']) self._delete('members', member1['member']['id'])
self._delete('members', member2['member']['id']) self._delete('members', member2['member']['id'])
class TestLoadBalancerXML(TestLoadBalancer):
fmt = 'xml'

View File

@ -83,8 +83,8 @@ class TestLinuxBridgeSecurityGroups(LinuxBridgeSecurityGroupsTestCase,
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
security_group_id = sg['security_group']['id'] security_group_id = sg['security_group']['id']
res = self._create_port('json', n['network']['id']) res = self._create_port(self.fmt, n['network']['id'])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
data = {'port': {'fixed_ips': port['port']['fixed_ips'], data = {'port': {'fixed_ips': port['port']['fixed_ips'],
'name': port['port']['name'], 'name': port['port']['name'],
@ -93,7 +93,8 @@ class TestLinuxBridgeSecurityGroups(LinuxBridgeSecurityGroupsTestCase,
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
self.assertEquals(res['port'][ext_sg.SECURITYGROUPS][0], self.assertEquals(res['port'][ext_sg.SECURITYGROUPS][0],
security_group_id) security_group_id)
self._delete('ports', port['port']['id']) self._delete('ports', port['port']['id'])
@ -104,14 +105,18 @@ class TestLinuxBridgeSecurityGroups(LinuxBridgeSecurityGroupsTestCase,
mock.ANY, [security_group_id])]) mock.ANY, [security_group_id])])
class TestLinuxBridgeSecurityGroupsXML(TestLinuxBridgeSecurityGroups):
fmt = 'xml'
class TestLinuxBridgeSecurityGroupsDB(LinuxBridgeSecurityGroupsTestCase): class TestLinuxBridgeSecurityGroupsDB(LinuxBridgeSecurityGroupsTestCase):
def test_security_group_get_port_from_device(self): def test_security_group_get_port_from_device(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
security_group_id = sg['security_group']['id'] security_group_id = sg['security_group']['id']
res = self._create_port('json', n['network']['id']) res = self._create_port(self.fmt, n['network']['id'])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
fixed_ips = port['port']['fixed_ips'] fixed_ips = port['port']['fixed_ips']
data = {'port': {'fixed_ips': fixed_ips, data = {'port': {'fixed_ips': fixed_ips,
'name': port['port']['name'], 'name': port['port']['name'],
@ -120,7 +125,8 @@ class TestLinuxBridgeSecurityGroupsDB(LinuxBridgeSecurityGroupsTestCase):
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
port_id = res['port']['id'] port_id = res['port']['id']
device_id = port_id[:8] device_id = port_id[:8]
port_dict = lb_db.get_port_from_device(device_id) port_dict = lb_db.get_port_from_device(device_id)
@ -135,3 +141,7 @@ class TestLinuxBridgeSecurityGroupsDB(LinuxBridgeSecurityGroupsTestCase):
def test_security_group_get_port_from_device_with_no_port(self): def test_security_group_get_port_from_device_with_no_port(self):
port_dict = lb_db.get_port_from_device('bad_device_id') port_dict = lb_db.get_port_from_device('bad_device_id')
self.assertEqual(None, port_dict) self.assertEqual(None, port_dict)
class TestLinuxBridgeSecurityGroupsDBXML(TestLinuxBridgeSecurityGroupsDB):
fmt = 'xml'

View File

@ -28,12 +28,15 @@ from quantum.api.v2 import attributes
from quantum.api.v2 import base from quantum.api.v2 import base
from quantum.api.v2 import router from quantum.api.v2 import router
from quantum.common import config from quantum.common import config
from quantum.common import constants
from quantum.common import exceptions as q_exc from quantum.common import exceptions as q_exc
from quantum import context from quantum import context
from quantum.manager import QuantumManager from quantum.manager import QuantumManager
from quantum.openstack.common import cfg from quantum.openstack.common import cfg
from quantum.openstack.common.notifier import api as notifer_api from quantum.openstack.common.notifier import api as notifer_api
from quantum.openstack.common import uuidutils from quantum.openstack.common import uuidutils
from quantum.tests.unit import testlib_api
from quantum import wsgi
ROOTDIR = os.path.dirname(os.path.dirname(__file__)) ROOTDIR = os.path.dirname(os.path.dirname(__file__))
@ -105,6 +108,7 @@ class APIv2TestBase(unittest.TestCase):
api = router.APIRouter() api = router.APIRouter()
self.api = webtest.TestApp(api) self.api = webtest.TestApp(api)
super(APIv2TestBase, self).setUp()
def tearDown(self): def tearDown(self):
self._plugin_patcher.stop() self._plugin_patcher.stop()
@ -270,7 +274,9 @@ class APIv2TestCase(APIv2TestBase):
# Note: since all resources use the same controller and validation # Note: since all resources use the same controller and validation
# logic, we actually get really good coverage from testing just networks. # logic, we actually get really good coverage from testing just networks.
class JSONV2TestCase(APIv2TestBase): class JSONV2TestCase(APIv2TestBase, testlib_api.WebTestCase):
def setUp(self):
super(JSONV2TestCase, self).setUp()
def _test_list(self, req_tenant_id, real_tenant_id): def _test_list(self, req_tenant_id, real_tenant_id):
env = {} env = {}
@ -287,19 +293,21 @@ class JSONV2TestCase(APIv2TestBase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_networks.return_value = return_value instance.get_networks.return_value = return_value
res = self.api.get(_get_path('networks'), extra_environ=env) res = self.api.get(_get_path('networks',
self.assertTrue('networks' in res.json) fmt=self.fmt), extra_environ=env)
res = self.deserialize(res)
self.assertTrue('networks' in res)
if not req_tenant_id or req_tenant_id == real_tenant_id: if not req_tenant_id or req_tenant_id == real_tenant_id:
# expect full list returned # expect full list returned
self.assertEqual(len(res.json['networks']), 1) self.assertEqual(len(res['networks']), 1)
output_dict = res.json['networks'][0] output_dict = res['networks'][0]
input_dict['shared'] = False input_dict['shared'] = False
self.assertEqual(len(input_dict), len(output_dict)) self.assertEqual(len(input_dict), len(output_dict))
for k, v in input_dict.iteritems(): for k, v in input_dict.iteritems():
self.assertEqual(v, output_dict[k]) self.assertEqual(v, output_dict[k])
else: else:
# expect no results # expect no results
self.assertEqual(len(res.json['networks']), 0) self.assertEqual(len(res['networks']), 0)
def test_list_noauth(self): def test_list_noauth(self):
self._test_list(None, _uuid()) self._test_list(None, _uuid())
@ -324,11 +332,13 @@ class JSONV2TestCase(APIv2TestBase):
instance.create_network.return_value = return_value instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0 instance.get_networks_count.return_value = 0
res = self.api.post_json(_get_path('networks'), data) res = self.api.post(_get_path('networks', fmt=self.fmt),
self.serialize(data),
content_type='application/' + self.fmt)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('network' in res.json) res = self.deserialize(res)
net = res.json['network'] self.assertTrue('network' in res)
net = res['network']
self.assertEqual(net['id'], net_id) self.assertEqual(net['id'], net_id)
self.assertEqual(net['status'], "ACTIVE") self.assertEqual(net['status'], "ACTIVE")
@ -346,21 +356,25 @@ class JSONV2TestCase(APIv2TestBase):
instance.create_network.return_value = return_value instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0 instance.get_networks_count.return_value = 0
res = self.api.post_json(_get_path('networks'), initial_input) res = self.api.post(_get_path('networks', fmt=self.fmt),
self.serialize(initial_input),
content_type='application/' + self.fmt)
instance.create_network.assert_called_with(mock.ANY, instance.create_network.assert_called_with(mock.ANY,
network=full_input) network=full_input)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('network' in res.json) res = self.deserialize(res)
net = res.json['network'] self.assertIn('network', res)
net = res['network']
self.assertEqual(net['id'], net_id) self.assertEqual(net['id'], net_id)
self.assertEqual(net['admin_state_up'], True) self.assertEqual(net['admin_state_up'], True)
self.assertEqual(net['status'], "ACTIVE") self.assertEqual(net['status'], "ACTIVE")
def test_create_no_keystone_env(self): def test_create_no_keystone_env(self):
data = {'name': 'net1'} data = {'name': 'net1'}
res = self.api.post_json(_get_path('networks'), data, res = self.api.post(_get_path('networks', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
def test_create_with_keystone_env(self): def test_create_with_keystone_env(self):
@ -380,8 +394,10 @@ class JSONV2TestCase(APIv2TestBase):
instance.create_network.return_value = return_value instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0 instance.get_networks_count.return_value = 0
res = self.api.post_json(_get_path('networks'), initial_input, res = self.api.post(_get_path('networks', fmt=self.fmt),
extra_environ=env) self.serialize(initial_input),
content_type='application/' + self.fmt,
extra_environ=env)
instance.create_network.assert_called_with(mock.ANY, instance.create_network.assert_called_with(mock.ANY,
network=full_input) network=full_input)
@ -391,33 +407,44 @@ class JSONV2TestCase(APIv2TestBase):
tenant_id = _uuid() tenant_id = _uuid()
data = {'network': {'name': 'net1', 'tenant_id': tenant_id}} data = {'network': {'name': 'net1', 'tenant_id': tenant_id}}
env = {'quantum.context': context.Context('', tenant_id + "bad")} env = {'quantum.context': context.Context('', tenant_id + "bad")}
res = self.api.post_json(_get_path('networks'), data, res = self.api.post(_get_path('networks', fmt=self.fmt),
expect_errors=True, self.serialize(data),
extra_environ=env) content_type='application/' + self.fmt,
expect_errors=True,
extra_environ=env)
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
def test_create_no_body(self): def test_create_no_body(self):
data = {'whoa': None} data = {'whoa': None}
res = self.api.post_json(_get_path('networks'), data, res = self.api.post(_get_path('networks', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
def test_create_no_resource(self): def test_create_no_resource(self):
res = self.api.post_json(_get_path('networks'), dict(), data = {}
expect_errors=True) res = self.api.post(_get_path('networks', fmt=self.fmt),
self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
def test_create_missing_attr(self): def test_create_missing_attr(self):
data = {'port': {'what': 'who', 'tenant_id': _uuid()}} data = {'port': {'what': 'who', 'tenant_id': _uuid()}}
res = self.api.post_json(_get_path('ports'), data, res = self.api.post(_get_path('ports', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_readonly_attr(self): def test_create_readonly_attr(self):
data = {'network': {'name': 'net1', 'tenant_id': _uuid(), data = {'network': {'name': 'net1', 'tenant_id': _uuid(),
'status': "ACTIVE"}} 'status': "ACTIVE"}}
res = self.api.post_json(_get_path('networks'), data, res = self.api.post(_get_path('networks', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_bulk(self): def test_create_bulk(self):
@ -436,28 +463,35 @@ class JSONV2TestCase(APIv2TestBase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_network.side_effect = side_effect instance.create_network.side_effect = side_effect
instance.get_networks_count.return_value = 0 instance.get_networks_count.return_value = 0
res = self.api.post(_get_path('networks', fmt=self.fmt),
res = self.api.post_json(_get_path('networks'), data) self.serialize(data),
content_type='application/' + self.fmt)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
def test_create_bulk_no_networks(self): def test_create_bulk_no_networks(self):
data = {'networks': []} data = {'networks': []}
res = self.api.post_json(_get_path('networks'), data, res = self.api.post(_get_path('networks', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
def test_create_bulk_missing_attr(self): def test_create_bulk_missing_attr(self):
data = {'ports': [{'what': 'who', 'tenant_id': _uuid()}]} data = {'ports': [{'what': 'who', 'tenant_id': _uuid()}]}
res = self.api.post_json(_get_path('ports'), data, res = self.api.post(_get_path('ports', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_bulk_partial_body(self): def test_create_bulk_partial_body(self):
data = {'ports': [{'device_id': 'device_1', data = {'ports': [{'device_id': 'device_1',
'tenant_id': _uuid()}, 'tenant_id': _uuid()},
{'tenant_id': _uuid()}]} {'tenant_id': _uuid()}]}
res = self.api.post_json(_get_path('ports'), data, res = self.api.post(_get_path('ports', fmt=self.fmt),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_attr_not_specified(self): def test_create_attr_not_specified(self):
@ -484,12 +518,14 @@ class JSONV2TestCase(APIv2TestBase):
instance.get_network.return_value = {'tenant_id': unicode(tenant_id)} instance.get_network.return_value = {'tenant_id': unicode(tenant_id)}
instance.get_ports_count.return_value = 1 instance.get_ports_count.return_value = 1
instance.create_port.return_value = return_value instance.create_port.return_value = return_value
res = self.api.post_json(_get_path('ports'), initial_input) res = self.api.post(_get_path('ports', fmt=self.fmt),
self.serialize(initial_input),
content_type='application/' + self.fmt)
instance.create_port.assert_called_with(mock.ANY, port=full_input) instance.create_port.assert_called_with(mock.ANY, port=full_input)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('port' in res.json) res = self.deserialize(res)
port = res.json['port'] self.assertIn('port', res)
port = res['port']
self.assertEqual(port['network_id'], net_id) self.assertEqual(port['network_id'], net_id)
self.assertEqual(port['mac_address'], 'ca:fe:de:ad:be:ef') self.assertEqual(port['mac_address'], 'ca:fe:de:ad:be:ef')
@ -505,11 +541,13 @@ class JSONV2TestCase(APIv2TestBase):
instance.create_network.return_value = return_value instance.create_network.return_value = return_value
instance.get_networks_count.return_value = 0 instance.get_networks_count.return_value = 0
res = self.api.post_json(_get_path('networks'), data) res = self.api.post(_get_path('networks', fmt=self.fmt),
self.serialize(data),
content_type='application/' + self.fmt)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('network' in res.json) res = self.deserialize(res)
net = res.json['network'] self.assertIn('network', res)
net = res['network']
self.assertEqual(net['id'], net_id) self.assertEqual(net['id'], net_id)
self.assertEqual(net['status'], "ACTIVE") self.assertEqual(net['status'], "ACTIVE")
self.assertFalse('v2attrs:something' in net) self.assertFalse('v2attrs:something' in net)
@ -521,7 +559,9 @@ class JSONV2TestCase(APIv2TestBase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_network.return_value = return_value instance.get_network.return_value = return_value
self.api.get(_get_path('networks', id=uuidutils.generate_uuid())) self.api.get(_get_path('networks',
id=uuidutils.generate_uuid(),
fmt=self.fmt))
def _test_delete(self, req_tenant_id, real_tenant_id, expected_code, def _test_delete(self, req_tenant_id, real_tenant_id, expected_code,
expect_errors=False): expect_errors=False):
@ -534,7 +574,8 @@ class JSONV2TestCase(APIv2TestBase):
instance.delete_network.return_value = None instance.delete_network.return_value = None
res = self.api.delete(_get_path('networks', res = self.api.delete(_get_path('networks',
id=uuidutils.generate_uuid()), id=uuidutils.generate_uuid(),
fmt=self.fmt),
extra_environ=env, extra_environ=env,
expect_errors=expect_errors) expect_errors=expect_errors)
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
@ -566,7 +607,8 @@ class JSONV2TestCase(APIv2TestBase):
instance.get_network.return_value = data instance.get_network.return_value = data
res = self.api.get(_get_path('networks', res = self.api.get(_get_path('networks',
id=uuidutils.generate_uuid()), id=uuidutils.generate_uuid(),
fmt=self.fmt),
extra_environ=env, extra_environ=env,
expect_errors=expect_errors) expect_errors=expect_errors)
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
@ -602,11 +644,12 @@ class JSONV2TestCase(APIv2TestBase):
'shared': False} 'shared': False}
instance.update_network.return_value = return_value instance.update_network.return_value = return_value
res = self.api.put_json(_get_path('networks', res = self.api.put(_get_path('networks',
id=uuidutils.generate_uuid()), id=uuidutils.generate_uuid(),
data, fmt=self.fmt),
extra_environ=env, self.serialize(data),
expect_errors=expect_errors) extra_environ=env,
expect_errors=expect_errors)
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
def test_update_noauth(self): def test_update_noauth(self):
@ -623,8 +666,10 @@ class JSONV2TestCase(APIv2TestBase):
def test_update_readonly_field(self): def test_update_readonly_field(self):
data = {'network': {'status': "NANANA"}} data = {'network': {'status': "NANANA"}}
res = self.api.put_json(_get_path('networks', id=_uuid()), data, res = self.api.put(_get_path('networks', id=_uuid()),
expect_errors=True) self.serialize(data),
content_type='application/' + self.fmt,
expect_errors=True)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
@ -705,6 +750,12 @@ class SubresourceTest(unittest.TestCase):
network_id='id1') network_id='id1')
# Note: since all resources use the same controller and validation
# logic, we actually get really good coverage from testing just networks.
class XMLV2TestCase(JSONV2TestCase):
fmt = 'xml'
class V2Views(unittest.TestCase): class V2Views(unittest.TestCase):
def _view(self, keys, collection, resource): def _view(self, keys, collection, resource):
data = dict((key, 'value') for key in keys) data = dict((key, 'value') for key in keys)
@ -778,7 +829,7 @@ class NotificationTest(APIv2TestBase):
def test_network_update_notifer(self): def test_network_update_notifer(self):
self._resource_op_notifier('update', 'network') self._resource_op_notifier('update', 'network')
def test_network_create_notifer(self): def test_network_create_notifer_with_log_level(self):
cfg.CONF.set_override('default_notification_level', 'DEBUG') cfg.CONF.set_override('default_notification_level', 'DEBUG')
self._resource_op_notifier('create', 'network', self._resource_op_notifier('create', 'network',
notification_level='DEBUG') notification_level='DEBUG')

View File

@ -27,15 +27,68 @@ import webtest
from quantum.api.v2 import resource as wsgi_resource from quantum.api.v2 import resource as wsgi_resource
from quantum.common import exceptions as q_exc from quantum.common import exceptions as q_exc
from quantum import context from quantum import context
from quantum import wsgi
class RequestTestCase(unittest.TestCase): class RequestTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.req = wsgi_resource.Request({'foo': 'bar'}) self.req = wsgi_resource.Request({'foo': 'bar'})
def test_best_match_content_type(self): def test_content_type_missing(self):
self.assertEqual(self.req.best_match_content_type(), request = wsgi.Request.blank('/tests/123', method='POST')
'application/json') request.body = "<body />"
self.assertEqual(None, request.get_content_type())
def test_content_type_with_charset(self):
request = wsgi.Request.blank('/tests/123')
request.headers["Content-Type"] = "application/json; charset=UTF-8"
result = request.get_content_type()
self.assertEqual(result, "application/json")
def test_content_type_from_accept(self):
for content_type in ('application/xml',
'application/json'):
request = wsgi.Request.blank('/tests/123')
request.headers["Accept"] = content_type
result = request.best_match_content_type()
self.assertEqual(result, content_type)
def test_content_type_from_accept_best(self):
request = wsgi.Request.blank('/tests/123')
request.headers["Accept"] = "application/xml, application/json"
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
request = wsgi.Request.blank('/tests/123')
request.headers["Accept"] = ("application/json; q=0.3, "
"application/xml; q=0.9")
result = request.best_match_content_type()
self.assertEqual(result, "application/xml")
def test_content_type_from_query_extension(self):
request = wsgi.Request.blank('/tests/123.xml')
result = request.best_match_content_type()
self.assertEqual(result, "application/xml")
request = wsgi.Request.blank('/tests/123.json')
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
request = wsgi.Request.blank('/tests/123.invalid')
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
def test_content_type_accept_and_query_extension(self):
request = wsgi.Request.blank('/tests/123.xml')
request.headers["Accept"] = "application/json"
result = request.best_match_content_type()
self.assertEqual(result, "application/xml")
def test_content_type_accept_default(self):
request = wsgi.Request.blank('/tests/123.unsupported')
request.headers["Accept"] = "application/unsupported1"
result = request.best_match_content_type()
self.assertEqual(result, "application/json")
def test_context_with_quantum_context(self): def test_context_with_quantum_context(self):
ctxt = context.Context('fake_user', 'fake_tenant') ctxt = context.Context('fake_user', 'fake_tenant')

File diff suppressed because it is too large Load Diff

View File

@ -113,14 +113,16 @@ class SecurityGroupsTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
return self.deserialize(fmt, res) return self.deserialize(fmt, res)
def _make_security_group_rule(self, fmt, rules, **kwargs): def _make_security_group_rule(self, fmt, rules, **kwargs):
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
if res.status_int >= 400: if res.status_int >= 400:
raise webob.exc.HTTPClientError(code=res.status_int) raise webob.exc.HTTPClientError(code=res.status_int)
return self.deserialize(fmt, res) return self.deserialize(fmt, res)
@contextlib.contextmanager @contextlib.contextmanager
def security_group(self, name='webservers', description='webservers', def security_group(self, name='webservers', description='webservers',
external_id=None, fmt='json', no_delete=False): external_id=None, fmt=None, no_delete=False):
if not fmt:
fmt = self.fmt
security_group = self._make_security_group(fmt, name, description, security_group = self._make_security_group(fmt, name, description,
external_id) external_id)
try: try:
@ -136,8 +138,10 @@ class SecurityGroupsTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
direction='ingress', protocol='tcp', direction='ingress', protocol='tcp',
port_range_min='22', port_range_max='22', port_range_min='22', port_range_max='22',
source_ip_prefix=None, source_group_id=None, source_ip_prefix=None, source_group_id=None,
external_id=None, fmt='json', no_delete=False, external_id=None, fmt=None, no_delete=False,
ethertype='IPv4'): ethertype='IPv4'):
if not fmt:
fmt = self.fmt
rule = self._build_security_group_rule(security_group_id, rule = self._build_security_group_rule(security_group_id,
direction, direction,
protocol, port_range_min, protocol, port_range_min,
@ -146,7 +150,7 @@ class SecurityGroupsTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
source_group_id, source_group_id,
external_id, external_id,
ethertype=ethertype) ethertype=ethertype)
security_group_rule = self._make_security_group_rule('json', rule) security_group_rule = self._make_security_group_rule(self.fmt, rule)
try: try:
yield security_group_rule yield security_group_rule
finally: finally:
@ -155,6 +159,10 @@ class SecurityGroupsTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
security_group_rule['security_group_rule']['id']) security_group_rule['security_group_rule']['id'])
class SecurityGroupsTestCaseXML(SecurityGroupsTestCase):
fmt = 'xml'
class SecurityGroupTestPlugin(db_base_plugin_v2.QuantumDbPluginV2, class SecurityGroupTestPlugin(db_base_plugin_v2.QuantumDbPluginV2,
securitygroups_db.SecurityGroupDbMixin): securitygroups_db.SecurityGroupDbMixin):
""" Test plugin that implements necessary calls on create/delete port for """ Test plugin that implements necessary calls on create/delete port for
@ -231,36 +239,36 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
def test_default_security_group(self): def test_default_security_group(self):
with self.network(): with self.network():
res = self.new_list_request('security-groups') res = self.new_list_request('security-groups')
groups = self.deserialize('json', res.get_response(self.ext_api)) groups = self.deserialize(self.fmt, res.get_response(self.ext_api))
self.assertEqual(len(groups['security_groups']), 1) self.assertEqual(len(groups['security_groups']), 1)
def test_create_security_group_proxy_mode_not_admin(self): def test_create_security_group_proxy_mode_not_admin(self):
cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP') cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
res = self._create_security_group('json', 'webservers', res = self._create_security_group(self.fmt, 'webservers',
'webservers', '1', 'webservers', '1',
tenant_id='bad_tenant', tenant_id='bad_tenant',
set_context=True) set_context=True)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 403) self.assertEqual(res.status_int, 403)
def test_create_security_group_no_external_id_proxy_mode(self): def test_create_security_group_no_external_id_proxy_mode(self):
cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP') cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
res = self._create_security_group('json', 'webservers', res = self._create_security_group(self.fmt, 'webservers',
'webservers') 'webservers')
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_no_external_id_not_proxy_mode(self): def test_create_security_group_no_external_id_not_proxy_mode(self):
res = self._create_security_group('json', 'webservers', res = self._create_security_group(self.fmt, 'webservers',
'webservers', '1') 'webservers', '1')
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_default_security_group_fail(self): def test_create_default_security_group_fail(self):
name = 'default' name = 'default'
description = 'my webservers' description = 'my webservers'
res = self._create_security_group('json', name, description) res = self._create_security_group(self.fmt, name, description)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_duplicate_external_id(self): def test_create_security_group_duplicate_external_id(self):
@ -269,9 +277,9 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
description = 'my webservers' description = 'my webservers'
external_id = 1 external_id = 1
with self.security_group(name, description, external_id): with self.security_group(name, description, external_id):
res = self._create_security_group('json', name, description, res = self._create_security_group(self.fmt, name, description,
external_id) external_id)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_list_security_groups(self): def test_list_security_groups(self):
@ -279,7 +287,7 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
description = 'my webservers' description = 'my webservers'
with self.security_group(name, description): with self.security_group(name, description):
res = self.new_list_request('security-groups') res = self.new_list_request('security-groups')
groups = self.deserialize('json', res.get_response(self.ext_api)) groups = self.deserialize(self.fmt, res.get_response(self.ext_api))
self.assertEqual(len(groups['security_groups']), 2) self.assertEqual(len(groups['security_groups']), 2)
for group in groups['security_groups']: for group in groups['security_groups']:
if group['name'] == 'default': if group['name'] == 'default':
@ -342,7 +350,6 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group(name, description) as sg: with self.security_group(name, description) as sg:
source_group_id = sg['security_group']['id'] source_group_id = sg['security_group']['id']
res = self.new_show_request('security-groups', source_group_id) res = self.new_show_request('security-groups', source_group_id)
security_group_id = sg['security_group']['id'] security_group_id = sg['security_group']['id']
direction = "ingress" direction = "ingress"
source_ip_prefix = "10.0.0.0/24" source_ip_prefix = "10.0.0.0/24"
@ -361,7 +368,7 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
source_ip_prefix): source_ip_prefix):
group = self.deserialize( group = self.deserialize(
'json', res.get_response(self.ext_api)) self.fmt, res.get_response(self.ext_api))
sg_rule = group['security_group']['security_group_rules'] sg_rule = group['security_group']['security_group_rules']
self.assertEqual(group['security_group']['id'], self.assertEqual(group['security_group']['id'],
source_group_id) source_group_id)
@ -379,17 +386,17 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
def test_delete_default_security_group_fail(self): def test_delete_default_security_group_fail(self):
with self.network(): with self.network():
res = self.new_list_request('security-groups') res = self.new_list_request('security-groups')
sg = self.deserialize('json', res.get_response(self.ext_api)) sg = self.deserialize(self.fmt, res.get_response(self.ext_api))
self._delete('security-groups', sg['security_groups'][0]['id'], self._delete('security-groups', sg['security_groups'][0]['id'],
409) 409)
def test_default_security_group_rules(self): def test_default_security_group_rules(self):
with self.network(): with self.network():
res = self.new_list_request('security-groups') res = self.new_list_request('security-groups')
groups = self.deserialize('json', res.get_response(self.ext_api)) groups = self.deserialize(self.fmt, res.get_response(self.ext_api))
self.assertEqual(len(groups['security_groups']), 1) self.assertEqual(len(groups['security_groups']), 1)
res = self.new_list_request('security-group-rules') res = self.new_list_request('security-group-rules')
rules = self.deserialize('json', res.get_response(self.ext_api)) rules = self.deserialize(self.fmt, res.get_response(self.ext_api))
self.assertEqual(len(rules['security_group_rules']), 2) self.assertEqual(len(rules['security_group_rules']), 2)
# just generic rules to allow default egress and # just generic rules to allow default egress and
# intergroup communicartion # intergroup communicartion
@ -459,8 +466,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
port_range_max, port_range_max,
source_ip_prefix, source_ip_prefix,
source_group_id) source_group_id)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_bad_security_group_id(self): def test_create_security_group_rule_bad_security_group_id(self):
@ -474,8 +481,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
protocol, port_range_min, protocol, port_range_min,
port_range_max, port_range_max,
source_ip_prefix) source_ip_prefix)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_bad_tenant(self): def test_create_security_group_rule_bad_tenant(self):
@ -488,8 +495,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'port_range_max': '22', 'port_range_max': '22',
'tenant_id': "bad_tenant"}} 'tenant_id': "bad_tenant"}}
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_exteral_id_proxy_mode(self): def test_create_security_group_rule_exteral_id_proxy_mode(self):
@ -505,8 +512,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'tenant_id': 'test_tenant', 'tenant_id': 'test_tenant',
'source_group_id': sg['security_group']['id']}} 'source_group_id': sg['security_group']['id']}}
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 201) self.assertEqual(res.status_int, 201)
def test_create_security_group_rule_exteral_id_not_proxy_mode(self): def test_create_security_group_rule_exteral_id_not_proxy_mode(self):
@ -521,8 +528,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'tenant_id': 'test_tenant', 'tenant_id': 'test_tenant',
'source_group_id': sg['security_group']['id']}} 'source_group_id': sg['security_group']['id']}}
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_not_admin(self): def test_create_security_group_rule_not_admin(self):
@ -538,18 +545,18 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'external_id': 1, 'external_id': 1,
'source_group_id': sg['security_group']['id']}} 'source_group_id': sg['security_group']['id']}}
res = self._create_security_group_rule('json', rule, res = self._create_security_group_rule(self.fmt, rule,
tenant_id='bad_tenant', tenant_id='bad_tenant',
set_context=True) set_context=True)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 403) self.assertEqual(res.status_int, 403)
def test_create_security_group_rule_bad_tenant_source_group_id(self): def test_create_security_group_rule_bad_tenant_source_group_id(self):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_security_group('json', 'webservers', res = self._create_security_group(self.fmt, 'webservers',
'webservers', 'webservers',
tenant_id='bad_tenant') tenant_id='bad_tenant')
sg2 = self.deserialize('json', res) sg2 = self.deserialize(self.fmt, res)
rule = {'security_group_rule': rule = {'security_group_rule':
{'security_group_id': sg2['security_group']['id'], {'security_group_id': sg2['security_group']['id'],
'direction': 'ingress', 'direction': 'ingress',
@ -559,18 +566,18 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'tenant_id': 'bad_tenant', 'tenant_id': 'bad_tenant',
'source_group_id': sg['security_group']['id']}} 'source_group_id': sg['security_group']['id']}}
res = self._create_security_group_rule('json', rule, res = self._create_security_group_rule(self.fmt, rule,
tenant_id='bad_tenant', tenant_id='bad_tenant',
set_context=True) set_context=True)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_bad_tenant_security_group_rule(self): def test_create_security_group_rule_bad_tenant_security_group_rule(self):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_security_group('json', 'webservers', res = self._create_security_group(self.fmt, 'webservers',
'webservers', 'webservers',
tenant_id='bad_tenant') tenant_id='bad_tenant')
self.deserialize('json', res) self.deserialize(self.fmt, res)
rule = {'security_group_rule': rule = {'security_group_rule':
{'security_group_id': sg['security_group']['id'], {'security_group_id': sg['security_group']['id'],
'direction': 'ingress', 'direction': 'ingress',
@ -579,10 +586,10 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'port_range_max': '22', 'port_range_max': '22',
'tenant_id': 'bad_tenant'}} 'tenant_id': 'bad_tenant'}}
res = self._create_security_group_rule('json', rule, res = self._create_security_group_rule(self.fmt, rule,
tenant_id='bad_tenant', tenant_id='bad_tenant',
set_context=True) set_context=True)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_bad_source_group_id(self): def test_create_security_group_rule_bad_source_group_id(self):
@ -599,8 +606,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
protocol, port_range_min, protocol, port_range_min,
port_range_max, port_range_max,
source_group_id=source_group_id) source_group_id=source_group_id)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_create_security_group_rule_duplicate_rules(self): def test_create_security_group_rule_duplicate_rules(self):
@ -611,9 +618,9 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group_rule(security_group_id): with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule( rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'tcp', '22', '22') sg['security_group']['id'], 'ingress', 'tcp', '22', '22')
self._create_security_group_rule('json', rule) self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_min_port_greater_max(self): def test_create_security_group_rule_min_port_greater_max(self):
@ -624,9 +631,9 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group_rule(security_group_id): with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule( rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', 'tcp', '50', '22') sg['security_group']['id'], 'ingress', 'tcp', '50', '22')
self._create_security_group_rule('json', rule) self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_ports_but_no_protocol(self): def test_create_security_group_rule_ports_but_no_protocol(self):
@ -637,17 +644,17 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group_rule(security_group_id): with self.security_group_rule(security_group_id):
rule = self._build_security_group_rule( rule = self._build_security_group_rule(
sg['security_group']['id'], 'ingress', None, '22', '22') sg['security_group']['id'], 'ingress', None, '22', '22')
self._create_security_group_rule('json', rule) self._create_security_group_rule(self.fmt, rule)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_update_port_with_security_group(self): def test_update_port_with_security_group(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_port('json', n['network']['id']) res = self._create_port(self.fmt, n['network']['id'])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
data = {'port': {'fixed_ips': port['port']['fixed_ips'], data = {'port': {'fixed_ips': port['port']['fixed_ips'],
'name': port['port']['name'], 'name': port['port']['name'],
@ -656,7 +663,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
self.assertEqual(res['port'][ext_sg.SECURITYGROUPS][0], self.assertEqual(res['port'][ext_sg.SECURITYGROUPS][0],
sg['security_group']['id']) sg['security_group']['id'])
@ -666,7 +674,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
self.assertEqual(res['port'][ext_sg.SECURITYGROUPS][0], self.assertEqual(res['port'][ext_sg.SECURITYGROUPS][0],
sg['security_group']['id']) sg['security_group']['id'])
@ -678,10 +687,10 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.security_group() as sg1: with self.security_group() as sg1:
with self.security_group() as sg2: with self.security_group() as sg2:
res = self._create_port( res = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg1['security_group']['id'], security_groups=[sg1['security_group']['id'],
sg2['security_group']['id']]) sg2['security_group']['id']])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
self.assertEqual(len( self.assertEqual(len(
port['port'][ext_sg.SECURITYGROUPS]), 2) port['port'][ext_sg.SECURITYGROUPS]), 2)
self._delete('ports', port['port']['id']) self._delete('ports', port['port']['id'])
@ -690,10 +699,10 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_port('json', n['network']['id'], res = self._create_port(self.fmt, n['network']['id'],
security_groups=( security_groups=(
[sg['security_group']['id']])) [sg['security_group']['id']]))
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
data = {'port': {'fixed_ips': port['port']['fixed_ips'], data = {'port': {'fixed_ips': port['port']['fixed_ips'],
'name': port['port']['name'], 'name': port['port']['name'],
@ -701,7 +710,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
self.assertEqual(res['port'].get(ext_sg.SECURITYGROUPS), self.assertEqual(res['port'].get(ext_sg.SECURITYGROUPS),
[]) [])
self._delete('ports', port['port']['id']) self._delete('ports', port['port']['id'])
@ -710,10 +720,10 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_port('json', n['network']['id'], res = self._create_port(self.fmt, n['network']['id'],
security_groups=( security_groups=(
[sg['security_group']['id']])) [sg['security_group']['id']]))
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
data = {'port': {'fixed_ips': port['port']['fixed_ips'], data = {'port': {'fixed_ips': port['port']['fixed_ips'],
'name': port['port']['name'], 'name': port['port']['name'],
@ -721,7 +731,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt,
req.get_response(self.api))
self.assertEqual(res['port'].get(ext_sg.SECURITYGROUPS), self.assertEqual(res['port'].get(ext_sg.SECURITYGROUPS),
[]) [])
self._delete('ports', port['port']['id']) self._delete('ports', port['port']['id'])
@ -729,20 +740,20 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
def test_create_port_with_bad_security_group(self): def test_create_port_with_bad_security_group(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
res = self._create_port('json', n['network']['id'], res = self._create_port(self.fmt, n['network']['id'],
security_groups=['bad_id']) security_groups=['bad_id'])
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_delete_security_group_port_in_use(self): def test_create_delete_security_group_port_in_use(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
with self.security_group() as sg: with self.security_group() as sg:
res = self._create_port('json', n['network']['id'], res = self._create_port(self.fmt, n['network']['id'],
security_groups=( security_groups=(
[sg['security_group']['id']])) [sg['security_group']['id']]))
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
self.assertEqual(port['port'][ext_sg.SECURITYGROUPS][0], self.assertEqual(port['port'][ext_sg.SECURITYGROUPS][0],
sg['security_group']['id']) sg['security_group']['id'])
# try to delete security group that's in use # try to delete security group that's in use
@ -764,8 +775,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'23', '10.0.0.1/24') '23', '10.0.0.1/24')
rules = {'security_group_rules': [rule1['security_group_rule'], rules = {'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']]} rule2['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 201) self.assertEqual(res.status_int, 201)
def test_create_security_group_rule_bulk_emulated(self): def test_create_security_group_rule_bulk_emulated(self):
@ -789,8 +800,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
rules = {'security_group_rules': [rule1['security_group_rule'], rules = {'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']] rule2['security_group_rule']]
} }
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 201) self.assertEqual(res.status_int, 201)
def test_create_security_group_rule_duplicate_rule_in_post(self): def test_create_security_group_rule_duplicate_rule_in_post(self):
@ -803,8 +814,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'22', '10.0.0.1/24') '22', '10.0.0.1/24')
rules = {'security_group_rules': [rule['security_group_rule'], rules = {'security_group_rules': [rule['security_group_rule'],
rule['security_group_rule']]} rule['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
rule = self.deserialize('json', res) rule = self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_duplicate_rule_in_post_emulated(self): def test_create_security_group_rule_duplicate_rule_in_post_emulated(self):
@ -825,8 +836,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'10.0.0.1/24') '10.0.0.1/24')
rules = {'security_group_rules': [rule['security_group_rule'], rules = {'security_group_rules': [rule['security_group_rule'],
rule['security_group_rule']]} rule['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
rule = self.deserialize('json', res) rule = self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_duplicate_rule_db(self): def test_create_security_group_rule_duplicate_rule_db(self):
@ -838,9 +849,9 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
'ingress', 'tcp', '22', 'ingress', 'tcp', '22',
'22', '10.0.0.1/24') '22', '10.0.0.1/24')
rules = {'security_group_rules': [rule]} rules = {'security_group_rules': [rule]}
self._create_security_group_rule('json', rules) self._create_security_group_rule(self.fmt, rules)
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
rule = self.deserialize('json', res) rule = self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_duplicate_rule_db_emulated(self): def test_create_security_group_rule_duplicate_rule_db_emulated(self):
@ -859,9 +870,9 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
sg['security_group']['id'], 'ingress', 'tcp', '22', '22', sg['security_group']['id'], 'ingress', 'tcp', '22', '22',
'10.0.0.1/24') '10.0.0.1/24')
rules = {'security_group_rules': [rule]} rules = {'security_group_rules': [rule]}
self._create_security_group_rule('json', rules) self._create_security_group_rule(self.fmt, rules)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 409) self.assertEqual(res.status_int, 409)
def test_create_security_group_rule_differnt_security_group_ids(self): def test_create_security_group_rule_differnt_security_group_ids(self):
@ -880,8 +891,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
rules = {'security_group_rules': [rule1['security_group_rule'], rules = {'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']] rule2['security_group_rule']]
} }
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_with_invalid_ethertype(self): def test_create_security_group_rule_with_invalid_ethertype(self):
@ -898,8 +909,8 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
source_ip_prefix, source_ip_prefix,
source_group_id, source_group_id,
ethertype='IPv5') ethertype='IPv5')
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_security_group_rule_with_invalid_protocol(self): def test_create_security_group_rule_with_invalid_protocol(self):
@ -915,23 +926,25 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
port_range_max, port_range_max,
source_ip_prefix, source_ip_prefix,
source_group_id) source_group_id)
res = self._create_security_group_rule('json', rule) res = self._create_security_group_rule(self.fmt, rule)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_validate_port_external_id_quantum_id(self): def test_validate_port_external_id_quantum_id(self):
cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP') cfg.CONF.set_override('proxy_mode', True, 'SECURITYGROUP')
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
sg1 = (self.deserialize('json', sg1 = (self.deserialize(self.fmt,
self._create_security_group('json', 'foo', 'bar', '1'))) self._create_security_group(self.fmt,
sg2 = (self.deserialize('json', 'foo', 'bar', '1')))
self._create_security_group('json', 'foo', 'bar', '2'))) sg2 = (self.deserialize(self.fmt,
self._create_security_group(self.fmt,
'foo', 'bar', '2')))
res = self._create_port( res = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg1['security_group']['id']]) security_groups=[sg1['security_group']['id']])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
# This request updates the port sending the quantum security # This request updates the port sending the quantum security
# group id in and a nova security group id. # group id in and a nova security group id.
data = {'port': {'fixed_ips': port['port']['fixed_ips'], data = {'port': {'fixed_ips': port['port']['fixed_ips'],
@ -941,7 +954,7 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
sg2['security_group']['id']]}} sg2['security_group']['id']]}}
req = self.new_update_request('ports', data, req = self.new_update_request('ports', data,
port['port']['id']) port['port']['id'])
res = self.deserialize('json', req.get_response(self.api)) res = self.deserialize(self.fmt, req.get_response(self.api))
self.assertEquals(len(res['port'][ext_sg.SECURITYGROUPS]), 2) self.assertEquals(len(res['port'][ext_sg.SECURITYGROUPS]), 2)
for sg_id in res['port'][ext_sg.SECURITYGROUPS]: for sg_id in res['port'][ext_sg.SECURITYGROUPS]:
# only security group id's should be # only security group id's should be
@ -956,25 +969,27 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
string_id = '1' string_id = '1'
int_id = 2 int_id = 2
self.deserialize( self.deserialize(
'json', self._create_security_group('json', 'foo', 'bar', self.fmt, self._create_security_group(self.fmt,
string_id)) 'foo', 'bar',
string_id))
self.deserialize( self.deserialize(
'json', self._create_security_group('json', 'foo', 'bar', self.fmt, self._create_security_group(self.fmt,
int_id)) 'foo', 'bar',
int_id))
res = self._create_port( res = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[string_id, int_id]) security_groups=[string_id, int_id])
port = self.deserialize('json', res) port = self.deserialize(self.fmt, res)
self._delete('ports', port['port']['id']) self._delete('ports', port['port']['id'])
def test_create_port_with_non_uuid_or_int(self): def test_create_port_with_non_uuid_or_int(self):
with self.network() as n: with self.network() as n:
with self.subnet(n): with self.subnet(n):
res = self._create_port('json', n['network']['id'], res = self._create_port(self.fmt, n['network']['id'],
security_groups=['not_valid']) security_groups=['not_valid'])
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_validate_port_external_id_fail(self): def test_validate_port_external_id_fail(self):
@ -983,8 +998,12 @@ class TestSecurityGroups(SecurityGroupDBTestCase):
with self.subnet(n): with self.subnet(n):
bad_id = 1 bad_id = 1
res = self._create_port( res = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[bad_id]) security_groups=[bad_id])
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
class TestSecurityGroupsXML(TestSecurityGroups):
fmt = 'xml'

View File

@ -30,6 +30,7 @@ from quantum.openstack.common import log as logging
from quantum.plugins.common import constants from quantum.plugins.common import constants
from quantum.tests.unit import extension_stubs as ext_stubs from quantum.tests.unit import extension_stubs as ext_stubs
import quantum.tests.unit.extensions import quantum.tests.unit.extensions
from quantum.tests.unit import testlib_api
from quantum import wsgi from quantum import wsgi
@ -523,33 +524,40 @@ class PluginAwareExtensionManagerTest(unittest.TestCase):
self.assertTrue("e1" in ext_mgr.extensions) self.assertTrue("e1" in ext_mgr.extensions)
class ExtensionControllerTest(unittest.TestCase): class ExtensionControllerTest(testlib_api.WebTestCase):
def setUp(self): def setUp(self):
super(ExtensionControllerTest, self).setUp() super(ExtensionControllerTest, self).setUp()
self.test_app = _setup_extensions_test_app() self.test_app = _setup_extensions_test_app()
def test_index_gets_all_registerd_extensions(self): def test_index_gets_all_registerd_extensions(self):
response = self.test_app.get("/extensions") response = self.test_app.get("/extensions." + self.fmt)
foxnsox = response.json["extensions"][0] res_body = self.deserialize(response)
foxnsox = res_body["extensions"][0]
self.assertEqual(foxnsox["alias"], "FOXNSOX") self.assertEqual(foxnsox["alias"], "FOXNSOX")
self.assertEqual(foxnsox["namespace"], self.assertEqual(foxnsox["namespace"],
"http://www.fox.in.socks/api/ext/pie/v1.0") "http://www.fox.in.socks/api/ext/pie/v1.0")
def test_extension_can_be_accessed_by_alias(self): def test_extension_can_be_accessed_by_alias(self):
foxnsox_extension = self.test_app.get("/extensions/FOXNSOX").json response = self.test_app.get("/extensions/FOXNSOX." + self.fmt)
foxnsox_extension = self.deserialize(response)
foxnsox_extension = foxnsox_extension['extension'] foxnsox_extension = foxnsox_extension['extension']
self.assertEqual(foxnsox_extension["alias"], "FOXNSOX") self.assertEqual(foxnsox_extension["alias"], "FOXNSOX")
self.assertEqual(foxnsox_extension["namespace"], self.assertEqual(foxnsox_extension["namespace"],
"http://www.fox.in.socks/api/ext/pie/v1.0") "http://www.fox.in.socks/api/ext/pie/v1.0")
def test_show_returns_not_found_for_non_existent_extension(self): def test_show_returns_not_found_for_non_existent_extension(self):
response = self.test_app.get("/extensions/non_existent", status="*") response = self.test_app.get("/extensions/non_existent" + self.fmt,
status="*")
self.assertEqual(response.status_int, 404) self.assertEqual(response.status_int, 404)
class ExtensionControllerTestXML(ExtensionControllerTest):
fmt = 'xml'
def app_factory(global_conf, **local_conf): def app_factory(global_conf, **local_conf):
conf = global_conf.copy() conf = global_conf.copy()
conf.update(local_conf) conf.update(local_conf)

View File

@ -48,6 +48,7 @@ from quantum.openstack.common import uuidutils
from quantum.tests.unit import test_api_v2 from quantum.tests.unit import test_api_v2
from quantum.tests.unit import test_db_plugin from quantum.tests.unit import test_db_plugin
from quantum.tests.unit import test_extensions from quantum.tests.unit import test_extensions
from quantum.tests.unit import testlib_api
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -68,7 +69,8 @@ class L3TestExtensionManager(object):
return [] return []
class L3NatExtensionTestCase(unittest.TestCase): class L3NatExtensionTestCase(testlib_api.WebTestCase):
fmt = 'json'
def setUp(self): def setUp(self):
@ -100,6 +102,7 @@ class L3NatExtensionTestCase(unittest.TestCase):
ext_mgr = L3TestExtensionManager() ext_mgr = L3TestExtensionManager()
self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr) self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr)
self.api = webtest.TestApp(self.ext_mdw) self.api = webtest.TestApp(self.ext_mdw)
super(L3NatExtensionTestCase, self).setUp()
def tearDown(self): def tearDown(self):
self._plugin_patcher.stop() self._plugin_patcher.stop()
@ -121,12 +124,15 @@ class L3NatExtensionTestCase(unittest.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_router.return_value = return_value instance.create_router.return_value = return_value
instance.get_routers_count.return_value = 0 instance.get_routers_count.return_value = 0
res = self.api.post_json(_get_path('routers'), data) res = self.api.post(_get_path('routers', fmt=self.fmt),
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_router.assert_called_with(mock.ANY, instance.create_router.assert_called_with(mock.ANY,
router=data) router=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('router' in res.json) res = self.deserialize(res)
router = res.json['router'] self.assertTrue('router' in res)
router = res['router']
self.assertEqual(router['id'], router_id) self.assertEqual(router['id'], router_id)
self.assertEqual(router['status'], "ACTIVE") self.assertEqual(router['status'], "ACTIVE")
self.assertEqual(router['admin_state_up'], True) self.assertEqual(router['admin_state_up'], True)
@ -139,14 +145,15 @@ class L3NatExtensionTestCase(unittest.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_routers.return_value = return_value instance.get_routers.return_value = return_value
res = self.api.get(_get_path('routers')) res = self.api.get(_get_path('routers', fmt=self.fmt))
instance.get_routers.assert_called_with(mock.ANY, fields=mock.ANY, instance.get_routers.assert_called_with(mock.ANY, fields=mock.ANY,
filters=mock.ANY) filters=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('routers' in res.json) res = self.deserialize(res)
self.assertEqual(1, len(res.json['routers'])) self.assertTrue('routers' in res)
self.assertEqual(router_id, res.json['routers'][0]['id']) self.assertEqual(1, len(res['routers']))
self.assertEqual(router_id, res['routers'][0]['id'])
def test_router_update(self): def test_router_update(self):
router_id = _uuid() router_id = _uuid()
@ -158,14 +165,16 @@ class L3NatExtensionTestCase(unittest.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.update_router.return_value = return_value instance.update_router.return_value = return_value
res = self.api.put_json(_get_path('routers', id=router_id), res = self.api.put(_get_path('routers', id=router_id,
update_data) fmt=self.fmt),
self.serialize(update_data))
instance.update_router.assert_called_with(mock.ANY, router_id, instance.update_router.assert_called_with(mock.ANY, router_id,
router=update_data) router=update_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('router' in res.json) res = self.deserialize(res)
router = res.json['router'] self.assertTrue('router' in res)
router = res['router']
self.assertEqual(router['id'], router_id) self.assertEqual(router['id'], router_id)
self.assertEqual(router['status'], "ACTIVE") self.assertEqual(router['status'], "ACTIVE")
self.assertEqual(router['admin_state_up'], False) self.assertEqual(router['admin_state_up'], False)
@ -179,13 +188,15 @@ class L3NatExtensionTestCase(unittest.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_router.return_value = return_value instance.get_router.return_value = return_value
res = self.api.get(_get_path('routers', id=router_id)) res = self.api.get(_get_path('routers', id=router_id,
fmt=self.fmt))
instance.get_router.assert_called_with(mock.ANY, router_id, instance.get_router.assert_called_with(mock.ANY, router_id,
fields=mock.ANY) fields=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('router' in res.json) res = self.deserialize(res)
router = res.json['router'] self.assertTrue('router' in res)
router = res['router']
self.assertEqual(router['id'], router_id) self.assertEqual(router['id'], router_id)
self.assertEqual(router['status'], "ACTIVE") self.assertEqual(router['status'], "ACTIVE")
self.assertEqual(router['admin_state_up'], False) self.assertEqual(router['admin_state_up'], False)
@ -212,15 +223,21 @@ class L3NatExtensionTestCase(unittest.TestCase):
instance.add_router_interface.return_value = return_value instance.add_router_interface.return_value = return_value
path = _get_path('routers', id=router_id, path = _get_path('routers', id=router_id,
action="add_router_interface") action="add_router_interface",
res = self.api.put_json(path, interface_data) fmt=self.fmt)
res = self.api.put(path, self.serialize(interface_data))
instance.add_router_interface.assert_called_with(mock.ANY, router_id, instance.add_router_interface.assert_called_with(mock.ANY, router_id,
interface_data) interface_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('port_id' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['port_id'], port_id) self.assertTrue('port_id' in res)
self.assertEqual(res.json['subnet_id'], subnet_id) self.assertEqual(res['port_id'], port_id)
self.assertEqual(res['subnet_id'], subnet_id)
class L3NatExtensionTestCaseXML(L3NatExtensionTestCase):
fmt = 'xml'
# This plugin class is just for testing # This plugin class is just for testing
@ -355,12 +372,12 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
"%s_router_interface" % action) "%s_router_interface" % action)
res = req.get_response(self.ext_api) res = req.get_response(self.ext_api)
self.assertEqual(res.status_int, expected_code) self.assertEqual(res.status_int, expected_code)
return self.deserialize('json', res) return self.deserialize(self.fmt, res)
@contextlib.contextmanager @contextlib.contextmanager
def router(self, name='router1', admin_status_up=True, def router(self, name='router1', admin_status_up=True,
fmt='json', tenant_id=_uuid(), set_context=False): fmt=None, tenant_id=_uuid(), set_context=False):
router = self._make_router(fmt, tenant_id, name, router = self._make_router(fmt or self.fmt, tenant_id, name,
admin_status_up, set_context) admin_status_up, set_context)
try: try:
yield router yield router
@ -385,9 +402,9 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
data['router']['name'] = 'router1' data['router']['name'] = 'router1'
data['router']['external_gateway_info'] = { data['router']['external_gateway_info'] = {
'network_id': s['subnet']['network_id']} 'network_id': s['subnet']['network_id']}
router_req = self.new_create_request('routers', data, 'json') router_req = self.new_create_request('routers', data, self.fmt)
res = router_req.get_response(self.ext_api) res = router_req.get_response(self.ext_api)
router = self.deserialize('json', res) router = self.deserialize(self.fmt, res)
self.assertEqual( self.assertEqual(
s['subnet']['network_id'], s['subnet']['network_id'],
router['router']['external_gateway_info']['network_id']) router['router']['external_gateway_info']['network_id'])
@ -844,8 +861,8 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
None, None,
p['port']['id']) p['port']['id'])
# create another port for testing failure case # create another port for testing failure case
res = self._create_port('json', p['port']['network_id']) res = self._create_port(self.fmt, p['port']['network_id'])
p2 = self.deserialize('json', res) p2 = self.deserialize(self.fmt, res)
self._router_interface_action('remove', self._router_interface_action('remove',
r['router']['id'], r['router']['id'],
None, None,
@ -867,10 +884,9 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
self.assertEqual(res.status_int, 404) self.assertEqual(res.status_int, 404)
def test_router_delete_with_port_existed_returns_409(self): def test_router_delete_with_port_existed_returns_409(self):
fmt = 'json'
with self.subnet() as subnet: with self.subnet() as subnet:
res = self._create_router(fmt, _uuid()) res = self._create_router(self.fmt, _uuid())
router = self.deserialize(fmt, res) router = self.deserialize(self.fmt, res)
self._router_interface_action('add', self._router_interface_action('add',
router['router']['id'], router['router']['id'],
subnet['subnet']['id'], subnet['subnet']['id'],
@ -888,10 +904,9 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
private_sub = {'subnet': {'id': private_sub = {'subnet': {'id':
p['port']['fixed_ips'][0]['subnet_id']}} p['port']['fixed_ips'][0]['subnet_id']}}
with self.subnet(cidr='12.0.0.0/24') as public_sub: with self.subnet(cidr='12.0.0.0/24') as public_sub:
fmt = 'json'
self._set_net_external(public_sub['subnet']['network_id']) self._set_net_external(public_sub['subnet']['network_id'])
res = self._create_router(fmt, _uuid()) res = self._create_router(self.fmt, _uuid())
r = self.deserialize(fmt, res) r = self.deserialize(self.fmt, res)
self._add_external_gateway_to_router( self._add_external_gateway_to_router(
r['router']['id'], r['router']['id'],
public_sub['subnet']['network_id']) public_sub['subnet']['network_id'])
@ -899,10 +914,10 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
private_sub['subnet']['id'], private_sub['subnet']['id'],
None) None)
res = self._create_floatingip( res = self._create_floatingip(
fmt, public_sub['subnet']['network_id'], self.fmt, public_sub['subnet']['network_id'],
port_id=p['port']['id']) port_id=p['port']['id'])
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
floatingip = self.deserialize(fmt, res) floatingip = self.deserialize(self.fmt, res)
self._delete('routers', r['router']['id'], self._delete('routers', r['router']['id'],
expected_code=exc.HTTPConflict.code) expected_code=exc.HTTPConflict.code)
# Cleanup # Cleanup
@ -990,7 +1005,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
fip['floatingip']['id']) fip['floatingip']['id'])
@contextlib.contextmanager @contextlib.contextmanager
def floatingip_with_assoc(self, port_id=None, fmt='json', def floatingip_with_assoc(self, port_id=None, fmt=None,
set_context=False): set_context=False):
with self.subnet(cidr='11.0.0.0/24') as public_sub: with self.subnet(cidr='11.0.0.0/24') as public_sub:
self._set_net_external(public_sub['subnet']['network_id']) self._set_net_external(public_sub['subnet']['network_id'])
@ -1008,7 +1023,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
private_sub['subnet']['id'], None) private_sub['subnet']['id'], None)
floatingip = self._make_floatingip( floatingip = self._make_floatingip(
fmt, fmt or self.fmt,
public_sub['subnet']['network_id'], public_sub['subnet']['network_id'],
port_id=private_port['port']['id'], port_id=private_port['port']['id'],
set_context=False) set_context=False)
@ -1025,7 +1040,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
public_sub['subnet']['network_id']) public_sub['subnet']['network_id'])
@contextlib.contextmanager @contextlib.contextmanager
def floatingip_no_assoc(self, private_sub, fmt='json', set_context=False): def floatingip_no_assoc(self, private_sub, fmt=None, set_context=False):
with self.subnet(cidr='12.0.0.0/24') as public_sub: with self.subnet(cidr='12.0.0.0/24') as public_sub:
self._set_net_external(public_sub['subnet']['network_id']) self._set_net_external(public_sub['subnet']['network_id'])
with self.router() as r: with self.router() as r:
@ -1039,7 +1054,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
None) None)
floatingip = self._make_floatingip( floatingip = self._make_floatingip(
fmt, fmt or self.fmt,
public_sub['subnet']['network_id'], public_sub['subnet']['network_id'],
set_context=set_context) set_context=set_context)
yield floatingip yield floatingip
@ -1066,7 +1081,6 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
expected_code=exc.HTTPNotFound.code) expected_code=exc.HTTPNotFound.code)
def test_floatingip_with_assoc_fails(self): def test_floatingip_with_assoc_fails(self):
fmt = 'json'
with self.subnet(cidr='200.0.0.1/24') as public_sub: with self.subnet(cidr='200.0.0.1/24') as public_sub:
self._set_net_external(public_sub['subnet']['network_id']) self._set_net_external(public_sub['subnet']['network_id'])
with self.port() as private_port: with self.port() as private_port:
@ -1086,7 +1100,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
resource='floatingip', resource='floatingip',
msg='fake_error') msg='fake_error')
res = self._create_floatingip( res = self._create_floatingip(
fmt, self.fmt,
public_sub['subnet']['network_id'], public_sub['subnet']['network_id'],
port_id=private_port['port']['id']) port_id=private_port['port']['id'])
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
@ -1151,7 +1165,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
def test_two_fips_one_port_invalid_return_409(self): def test_two_fips_one_port_invalid_return_409(self):
with self.floatingip_with_assoc() as fip1: with self.floatingip_with_assoc() as fip1:
res = self._create_floatingip( res = self._create_floatingip(
'json', self.fmt,
fip1['floatingip']['floating_network_id'], fip1['floatingip']['floating_network_id'],
fip1['floatingip']['port_id']) fip1['floatingip']['port_id'])
self.assertEqual(res.status_int, exc.HTTPConflict.code) self.assertEqual(res.status_int, exc.HTTPConflict.code)
@ -1172,7 +1186,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
with self.port() as private_port: with self.port() as private_port:
with self.router() as r: with self.router() as r:
res = self._create_floatingip( res = self._create_floatingip(
'json', self.fmt,
public_sub['subnet']['network_id'], public_sub['subnet']['network_id'],
port_id=private_port['port']['id']) port_id=private_port['port']['id'])
# this should be some kind of error # this should be some kind of error
@ -1185,7 +1199,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
# that is not the case # that is not the case
with self.router() as r: with self.router() as r:
res = self._create_floatingip( res = self._create_floatingip(
'json', self.fmt,
public_sub['subnet']['network_id']) public_sub['subnet']['network_id'])
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
@ -1200,7 +1214,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
None) None)
res = self._create_floatingip( res = self._create_floatingip(
'json', self.fmt,
public_network['network']['id'], public_network['network']['id'],
port_id=private_port['port']['id']) port_id=private_port['port']['id'])
self.assertEqual(res.status_int, exc.HTTPBadRequest.code) self.assertEqual(res.status_int, exc.HTTPBadRequest.code)
@ -1212,19 +1226,19 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
def test_create_floatingip_invalid_floating_network_id_returns_400(self): def test_create_floatingip_invalid_floating_network_id_returns_400(self):
# API-level test - no need to create all objects for l3 plugin # API-level test - no need to create all objects for l3 plugin
res = self._create_floatingip('json', 'iamnotanuuid', res = self._create_floatingip(self.fmt, 'iamnotanuuid',
uuidutils.generate_uuid(), '192.168.0.1') uuidutils.generate_uuid(), '192.168.0.1')
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_floatingip_invalid_floating_port_id_returns_400(self): def test_create_floatingip_invalid_floating_port_id_returns_400(self):
# API-level test - no need to create all objects for l3 plugin # API-level test - no need to create all objects for l3 plugin
res = self._create_floatingip('json', uuidutils.generate_uuid(), res = self._create_floatingip(self.fmt, uuidutils.generate_uuid(),
'iamnotanuuid', '192.168.0.1') 'iamnotanuuid', '192.168.0.1')
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
def test_create_floatingip_invalid_fixed_ip_address_returns_400(self): def test_create_floatingip_invalid_fixed_ip_address_returns_400(self):
# API-level test - no need to create all objects for l3 plugin # API-level test - no need to create all objects for l3 plugin
res = self._create_floatingip('json', uuidutils.generate_uuid(), res = self._create_floatingip(self.fmt, uuidutils.generate_uuid(),
uuidutils.generate_uuid(), 'iamnotnanip') uuidutils.generate_uuid(), 'iamnotnanip')
self.assertEqual(res.status_int, 400) self.assertEqual(res.status_int, 400)
@ -1466,3 +1480,7 @@ class L3NatDBTestCase(test_db_plugin.QuantumDbPluginV2TestCase):
fip['floatingip']['port_id']) fip['floatingip']['port_id'])
self.assertTrue(floatingips[0]['fixed_ip_address'] is not None) self.assertTrue(floatingips[0]['fixed_ip_address'] is not None)
self.assertTrue(floatingips[0]['router_id'] is not None) self.assertTrue(floatingips[0]['router_id'] is not None)
class L3NatDBTestCaseXML(L3NatDBTestCase):
fmt = 'xml'

View File

@ -22,6 +22,7 @@ from webob import exc
import webtest import webtest
from quantum.api import extensions from quantum.api import extensions
from quantum.api.v2 import attributes
from quantum.common import config from quantum.common import config
from quantum.extensions import loadbalancer from quantum.extensions import loadbalancer
from quantum import manager from quantum import manager
@ -30,6 +31,7 @@ from quantum.openstack.common import uuidutils
from quantum.plugins.common import constants from quantum.plugins.common import constants
from quantum.tests.unit import test_api_v2 from quantum.tests.unit import test_api_v2
from quantum.tests.unit import test_extensions from quantum.tests.unit import test_extensions
from quantum.tests.unit import testlib_api
_uuid = uuidutils.generate_uuid _uuid = uuidutils.generate_uuid
@ -48,7 +50,8 @@ class LoadBalancerTestExtensionManager(object):
return [] return []
class LoadBalancerExtensionTestCase(unittest2.TestCase): class LoadBalancerExtensionTestCase(testlib_api.WebTestCase):
fmt = 'json'
def setUp(self): def setUp(self):
@ -75,6 +78,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
ext_mgr = LoadBalancerTestExtensionManager() ext_mgr = LoadBalancerTestExtensionManager()
self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr) self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr)
self.api = webtest.TestApp(self.ext_mdw) self.api = webtest.TestApp(self.ext_mdw)
super(LoadBalancerExtensionTestCase, self).setUp()
def tearDown(self): def tearDown(self):
self._plugin_patcher.stop() self._plugin_patcher.stop()
@ -100,12 +104,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_vip.return_value = return_value instance.create_vip.return_value = return_value
res = self.api.post_json(_get_path('lb/vips'), data) res = self.api.post(_get_path('lb/vips', fmt=self.fmt),
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_vip.assert_called_with(mock.ANY, instance.create_vip.assert_called_with(mock.ANY,
vip=data) vip=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('vip' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['vip'], return_value) self.assertIn('vip', res)
self.assertEqual(res['vip'], return_value)
def test_vip_list(self): def test_vip_list(self):
vip_id = _uuid() vip_id = _uuid()
@ -117,7 +124,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_vips.return_value = return_value instance.get_vips.return_value = return_value
res = self.api.get(_get_path('lb/vips')) res = self.api.get(_get_path('lb/vips', fmt=self.fmt))
instance.get_vips.assert_called_with(mock.ANY, fields=mock.ANY, instance.get_vips.assert_called_with(mock.ANY, fields=mock.ANY,
filters=mock.ANY) filters=mock.ANY)
@ -135,14 +142,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.update_vip.return_value = return_value instance.update_vip.return_value = return_value
res = self.api.put_json(_get_path('lb/vips', id=vip_id), res = self.api.put(_get_path('lb/vips', id=vip_id, fmt=self.fmt),
update_data) self.serialize(update_data))
instance.update_vip.assert_called_with(mock.ANY, vip_id, instance.update_vip.assert_called_with(mock.ANY, vip_id,
vip=update_data) vip=update_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('vip' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['vip'], return_value) self.assertIn('vip', res)
self.assertEqual(res['vip'], return_value)
def test_vip_get(self): def test_vip_get(self):
vip_id = _uuid() vip_id = _uuid()
@ -155,18 +163,20 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_vip.return_value = return_value instance.get_vip.return_value = return_value
res = self.api.get(_get_path('lb/vips', id=vip_id)) res = self.api.get(_get_path('lb/vips', id=vip_id, fmt=self.fmt))
instance.get_vip.assert_called_with(mock.ANY, vip_id, instance.get_vip.assert_called_with(mock.ANY, vip_id,
fields=mock.ANY) fields=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('vip' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['vip'], return_value) self.assertIn('vip', res)
self.assertEqual(res['vip'], return_value)
def _test_entity_delete(self, entity): def _test_entity_delete(self, entity):
""" does the entity deletion based on naming convention """ """ does the entity deletion based on naming convention """
entity_id = _uuid() entity_id = _uuid()
res = self.api.delete(_get_path('lb/' + entity + 's', id=entity_id)) res = self.api.delete(_get_path('lb/' + entity + 's', id=entity_id,
fmt=self.fmt))
delete_entity = getattr(self.plugin.return_value, "delete_" + entity) delete_entity = getattr(self.plugin.return_value, "delete_" + entity)
delete_entity.assert_called_with(mock.ANY, entity_id) delete_entity.assert_called_with(mock.ANY, entity_id)
self.assertEqual(res.status_int, exc.HTTPNoContent.code) self.assertEqual(res.status_int, exc.HTTPNoContent.code)
@ -190,12 +200,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_pool.return_value = return_value instance.create_pool.return_value = return_value
res = self.api.post_json(_get_path('lb/pools'), data) res = self.api.post(_get_path('lb/pools', fmt=self.fmt),
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_pool.assert_called_with(mock.ANY, instance.create_pool.assert_called_with(mock.ANY,
pool=data) pool=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('pool' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['pool'], return_value) self.assertIn('pool', res)
self.assertEqual(res['pool'], return_value)
def test_pool_list(self): def test_pool_list(self):
pool_id = _uuid() pool_id = _uuid()
@ -207,7 +220,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_pools.return_value = return_value instance.get_pools.return_value = return_value
res = self.api.get(_get_path('lb/pools')) res = self.api.get(_get_path('lb/pools', fmt=self.fmt))
instance.get_pools.assert_called_with(mock.ANY, fields=mock.ANY, instance.get_pools.assert_called_with(mock.ANY, fields=mock.ANY,
filters=mock.ANY) filters=mock.ANY)
@ -225,14 +238,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.update_pool.return_value = return_value instance.update_pool.return_value = return_value
res = self.api.put_json(_get_path('lb/pools', id=pool_id), res = self.api.put(_get_path('lb/pools', id=pool_id, fmt=self.fmt),
update_data) self.serialize(update_data))
instance.update_pool.assert_called_with(mock.ANY, pool_id, instance.update_pool.assert_called_with(mock.ANY, pool_id,
pool=update_data) pool=update_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('pool' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['pool'], return_value) self.assertIn('pool', res)
self.assertEqual(res['pool'], return_value)
def test_pool_get(self): def test_pool_get(self):
pool_id = _uuid() pool_id = _uuid()
@ -245,13 +259,14 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_pool.return_value = return_value instance.get_pool.return_value = return_value
res = self.api.get(_get_path('lb/pools', id=pool_id)) res = self.api.get(_get_path('lb/pools', id=pool_id, fmt=self.fmt))
instance.get_pool.assert_called_with(mock.ANY, pool_id, instance.get_pool.assert_called_with(mock.ANY, pool_id,
fields=mock.ANY) fields=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('pool' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['pool'], return_value) self.assertIn('pool', res)
self.assertEqual(res['pool'], return_value)
def test_pool_delete(self): def test_pool_delete(self):
self._test_entity_delete('pool') self._test_entity_delete('pool')
@ -264,13 +279,14 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance.stats.return_value = stats instance.stats.return_value = stats
path = _get_path('lb/pools', id=pool_id, path = _get_path('lb/pools', id=pool_id,
action="stats") action="stats", fmt=self.fmt)
res = self.api.get(path) res = self.api.get(path)
instance.stats.assert_called_with(mock.ANY, pool_id) instance.stats.assert_called_with(mock.ANY, pool_id)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('stats' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['stats'], stats['stats']) self.assertIn('stats', res)
self.assertEqual(res['stats'], stats['stats'])
def test_member_create(self): def test_member_create(self):
member_id = _uuid() member_id = _uuid()
@ -285,12 +301,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_member.return_value = return_value instance.create_member.return_value = return_value
res = self.api.post_json(_get_path('lb/members'), data) res = self.api.post(_get_path('lb/members', fmt=self.fmt),
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_member.assert_called_with(mock.ANY, instance.create_member.assert_called_with(mock.ANY,
member=data) member=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('member' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['member'], return_value) self.assertIn('member', res)
self.assertEqual(res['member'], return_value)
def test_member_list(self): def test_member_list(self):
member_id = _uuid() member_id = _uuid()
@ -302,7 +321,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_members.return_value = return_value instance.get_members.return_value = return_value
res = self.api.get(_get_path('lb/members')) res = self.api.get(_get_path('lb/members', fmt=self.fmt))
instance.get_members.assert_called_with(mock.ANY, fields=mock.ANY, instance.get_members.assert_called_with(mock.ANY, fields=mock.ANY,
filters=mock.ANY) filters=mock.ANY)
@ -319,14 +338,16 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.update_member.return_value = return_value instance.update_member.return_value = return_value
res = self.api.put_json(_get_path('lb/members', id=member_id), res = self.api.put(_get_path('lb/members', id=member_id,
update_data) fmt=self.fmt),
self.serialize(update_data))
instance.update_member.assert_called_with(mock.ANY, member_id, instance.update_member.assert_called_with(mock.ANY, member_id,
member=update_data) member=update_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('member' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['member'], return_value) self.assertIn('member', res)
self.assertEqual(res['member'], return_value)
def test_member_get(self): def test_member_get(self):
member_id = _uuid() member_id = _uuid()
@ -338,13 +359,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_member.return_value = return_value instance.get_member.return_value = return_value
res = self.api.get(_get_path('lb/members', id=member_id)) res = self.api.get(_get_path('lb/members', id=member_id,
fmt=self.fmt))
instance.get_member.assert_called_with(mock.ANY, member_id, instance.get_member.assert_called_with(mock.ANY, member_id,
fields=mock.ANY) fields=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('member' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['member'], return_value) self.assertIn('member', res)
self.assertEqual(res['member'], return_value)
def test_member_delete(self): def test_member_delete(self):
self._test_entity_delete('member') self._test_entity_delete('member')
@ -365,12 +388,16 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_health_monitor.return_value = return_value instance.create_health_monitor.return_value = return_value
res = self.api.post_json(_get_path('lb/health_monitors'), data) res = self.api.post(_get_path('lb/health_monitors',
fmt=self.fmt),
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_health_monitor.assert_called_with(mock.ANY, instance.create_health_monitor.assert_called_with(mock.ANY,
health_monitor=data) health_monitor=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('health_monitor' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['health_monitor'], return_value) self.assertIn('health_monitor', res)
self.assertEqual(res['health_monitor'], return_value)
def test_health_monitor_list(self): def test_health_monitor_list(self):
health_monitor_id = _uuid() health_monitor_id = _uuid()
@ -382,7 +409,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.get_health_monitors.return_value = return_value instance.get_health_monitors.return_value = return_value
res = self.api.get(_get_path('lb/health_monitors')) res = self.api.get(_get_path('lb/health_monitors', fmt=self.fmt))
instance.get_health_monitors.assert_called_with( instance.get_health_monitors.assert_called_with(
mock.ANY, fields=mock.ANY, filters=mock.ANY) mock.ANY, fields=mock.ANY, filters=mock.ANY)
@ -400,15 +427,17 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance = self.plugin.return_value instance = self.plugin.return_value
instance.update_health_monitor.return_value = return_value instance.update_health_monitor.return_value = return_value
res = self.api.put_json(_get_path('lb/health_monitors', res = self.api.put(_get_path('lb/health_monitors',
id=health_monitor_id), id=health_monitor_id,
update_data) fmt=self.fmt),
self.serialize(update_data))
instance.update_health_monitor.assert_called_with( instance.update_health_monitor.assert_called_with(
mock.ANY, health_monitor_id, health_monitor=update_data) mock.ANY, health_monitor_id, health_monitor=update_data)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('health_monitor' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['health_monitor'], return_value) self.assertIn('health_monitor', res)
self.assertEqual(res['health_monitor'], return_value)
def test_health_monitor_get(self): def test_health_monitor_get(self):
health_monitor_id = _uuid() health_monitor_id = _uuid()
@ -422,13 +451,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance.get_health_monitor.return_value = return_value instance.get_health_monitor.return_value = return_value
res = self.api.get(_get_path('lb/health_monitors', res = self.api.get(_get_path('lb/health_monitors',
id=health_monitor_id)) id=health_monitor_id,
fmt=self.fmt))
instance.get_health_monitor.assert_called_with( instance.get_health_monitor.assert_called_with(
mock.ANY, health_monitor_id, fields=mock.ANY) mock.ANY, health_monitor_id, fields=mock.ANY)
self.assertEqual(res.status_int, exc.HTTPOk.code) self.assertEqual(res.status_int, exc.HTTPOk.code)
self.assertTrue('health_monitor' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['health_monitor'], return_value) self.assertIn('health_monitor', res)
self.assertEqual(res['health_monitor'], return_value)
def test_health_monitor_delete(self): def test_health_monitor_delete(self):
self._test_entity_delete('health_monitor') self._test_entity_delete('health_monitor')
@ -441,12 +472,15 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
return_value = copy.copy(data['health_monitor']) return_value = copy.copy(data['health_monitor'])
instance = self.plugin.return_value instance = self.plugin.return_value
instance.create_pool_health_monitor.return_value = return_value instance.create_pool_health_monitor.return_value = return_value
res = self.api.post_json('/lb/pools/id1/health_monitors', data) res = self.api.post('/lb/pools/id1/health_monitors',
self.serialize(data),
content_type='application/%s' % self.fmt)
instance.create_pool_health_monitor.assert_called_with( instance.create_pool_health_monitor.assert_called_with(
mock.ANY, pool_id='id1', health_monitor=data) mock.ANY, pool_id='id1', health_monitor=data)
self.assertEqual(res.status_int, exc.HTTPCreated.code) self.assertEqual(res.status_int, exc.HTTPCreated.code)
self.assertTrue('health_monitor' in res.json) res = self.deserialize(res)
self.assertEqual(res.json['health_monitor'], return_value) self.assertIn('health_monitor', res)
self.assertEqual(res['health_monitor'], return_value)
def test_delete_pool_health_monitor(self): def test_delete_pool_health_monitor(self):
health_monitor_id = _uuid() health_monitor_id = _uuid()
@ -458,3 +492,7 @@ class LoadBalancerExtensionTestCase(unittest2.TestCase):
instance.delete_pool_health_monitor.assert_called_with( instance.delete_pool_health_monitor.assert_called_with(
mock.ANY, health_monitor_id, pool_id='id1') mock.ANY, health_monitor_id, pool_id='id1')
self.assertEqual(res.status_int, exc.HTTPNoContent.code) self.assertEqual(res.status_int, exc.HTTPNoContent.code)
class LoadBalancerExtensionTestCaseXML(LoadBalancerExtensionTestCase):
fmt = 'xml'

View File

@ -15,16 +15,16 @@ from quantum.plugins.linuxbridge.db import l2network_db_v2
from quantum import quota from quantum import quota
from quantum.tests.unit import test_api_v2 from quantum.tests.unit import test_api_v2
from quantum.tests.unit import test_extensions from quantum.tests.unit import test_extensions
from quantum.tests.unit import testlib_api
TARGET_PLUGIN = ('quantum.plugins.linuxbridge.lb_quantum_plugin' TARGET_PLUGIN = ('quantum.plugins.linuxbridge.lb_quantum_plugin'
'.LinuxBridgePluginV2') '.LinuxBridgePluginV2')
_get_path = test_api_v2._get_path _get_path = test_api_v2._get_path
class QuotaExtensionTestCase(unittest.TestCase): class QuotaExtensionTestCase(testlib_api.WebTestCase):
fmt = 'json'
def setUp(self): def setUp(self):
db._ENGINE = None db._ENGINE = None
@ -67,6 +67,7 @@ class QuotaExtensionTestCase(unittest.TestCase):
app = config.load_paste_app('extensions_test_app') app = config.load_paste_app('extensions_test_app')
ext_middleware = extensions.ExtensionMiddleware(app, ext_mgr=ext_mgr) ext_middleware = extensions.ExtensionMiddleware(app, ext_mgr=ext_mgr)
self.api = webtest.TestApp(ext_middleware) self.api = webtest.TestApp(ext_middleware)
super(QuotaExtensionTestCase, self).setUp()
def tearDown(self): def tearDown(self):
self._plugin_patcher.stop() self._plugin_patcher.stop()
@ -80,24 +81,27 @@ class QuotaExtensionTestCase(unittest.TestCase):
attributes.RESOURCE_ATTRIBUTE_MAP = self.saved_attr_map attributes.RESOURCE_ATTRIBUTE_MAP = self.saved_attr_map
def test_quotas_loaded_right(self): def test_quotas_loaded_right(self):
res = self.api.get(_get_path('quotas')) res = self.api.get(_get_path('quotas', fmt=self.fmt))
quota = self.deserialize(res)
self.assertEqual([], quota['quotas'])
self.assertEqual(200, res.status_int) self.assertEqual(200, res.status_int)
def test_quotas_default_values(self): def test_quotas_default_values(self):
tenant_id = 'tenant_id1' tenant_id = 'tenant_id1'
env = {'quantum.context': context.Context('', tenant_id)} env = {'quantum.context': context.Context('', tenant_id)}
res = self.api.get(_get_path('quotas', id=tenant_id), res = self.api.get(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env) extra_environ=env)
self.assertEqual(10, res.json['quota']['network']) quota = self.deserialize(res)
self.assertEqual(10, res.json['quota']['subnet']) self.assertEqual(10, quota['quota']['network'])
self.assertEqual(50, res.json['quota']['port']) self.assertEqual(10, quota['quota']['subnet'])
self.assertEqual(-1, res.json['quota']['extra1']) self.assertEqual(50, quota['quota']['port'])
self.assertEqual(-1, quota['quota']['extra1'])
def test_show_quotas_with_admin(self): def test_show_quotas_with_admin(self):
tenant_id = 'tenant_id1' tenant_id = 'tenant_id1'
env = {'quantum.context': context.Context('', tenant_id + '2', env = {'quantum.context': context.Context('', tenant_id + '2',
is_admin=True)} is_admin=True)}
res = self.api.get(_get_path('quotas', id=tenant_id), res = self.api.get(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env) extra_environ=env)
self.assertEqual(200, res.status_int) self.assertEqual(200, res.status_int)
@ -105,7 +109,7 @@ class QuotaExtensionTestCase(unittest.TestCase):
tenant_id = 'tenant_id1' tenant_id = 'tenant_id1'
env = {'quantum.context': context.Context('', tenant_id + '2', env = {'quantum.context': context.Context('', tenant_id + '2',
is_admin=False)} is_admin=False)}
res = self.api.get(_get_path('quotas', id=tenant_id), res = self.api.get(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env, expect_errors=True) extra_environ=env, expect_errors=True)
self.assertEqual(403, res.status_int) self.assertEqual(403, res.status_int)
@ -114,10 +118,9 @@ class QuotaExtensionTestCase(unittest.TestCase):
env = {'quantum.context': context.Context('', tenant_id, env = {'quantum.context': context.Context('', tenant_id,
is_admin=False)} is_admin=False)}
quotas = {'quota': {'network': 100}} quotas = {'quota': {'network': 100}}
res = self.api.put_json(_get_path('quotas', id=tenant_id, res = self.api.put(_get_path('quotas', id=tenant_id, fmt=self.fmt),
fmt='json'), self.serialize(quotas), extra_environ=env,
quotas, extra_environ=env, expect_errors=True)
expect_errors=True)
self.assertEqual(403, res.status_int) self.assertEqual(403, res.status_int)
def test_update_quotas_with_admin(self): def test_update_quotas_with_admin(self):
@ -125,19 +128,20 @@ class QuotaExtensionTestCase(unittest.TestCase):
env = {'quantum.context': context.Context('', tenant_id + '2', env = {'quantum.context': context.Context('', tenant_id + '2',
is_admin=True)} is_admin=True)}
quotas = {'quota': {'network': 100}} quotas = {'quota': {'network': 100}}
res = self.api.put_json(_get_path('quotas', id=tenant_id, fmt='json'), res = self.api.put(_get_path('quotas', id=tenant_id, fmt=self.fmt),
quotas, extra_environ=env) self.serialize(quotas), extra_environ=env)
self.assertEqual(200, res.status_int) self.assertEqual(200, res.status_int)
env2 = {'quantum.context': context.Context('', tenant_id)} env2 = {'quantum.context': context.Context('', tenant_id)}
res = self.api.get(_get_path('quotas', id=tenant_id), res = self.api.get(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env2).json extra_environ=env2)
self.assertEqual(100, res['quota']['network']) quota = self.deserialize(res)
self.assertEqual(100, quota['quota']['network'])
def test_delete_quotas_with_admin(self): def test_delete_quotas_with_admin(self):
tenant_id = 'tenant_id1' tenant_id = 'tenant_id1'
env = {'quantum.context': context.Context('', tenant_id + '2', env = {'quantum.context': context.Context('', tenant_id + '2',
is_admin=True)} is_admin=True)}
res = self.api.delete(_get_path('quotas', id=tenant_id, fmt='json'), res = self.api.delete(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env) extra_environ=env)
self.assertEqual(204, res.status_int) self.assertEqual(204, res.status_int)
@ -145,7 +149,7 @@ class QuotaExtensionTestCase(unittest.TestCase):
tenant_id = 'tenant_id1' tenant_id = 'tenant_id1'
env = {'quantum.context': context.Context('', tenant_id, env = {'quantum.context': context.Context('', tenant_id,
is_admin=False)} is_admin=False)}
res = self.api.delete(_get_path('quotas', id=tenant_id, fmt='json'), res = self.api.delete(_get_path('quotas', id=tenant_id, fmt=self.fmt),
extra_environ=env, expect_errors=True) extra_environ=env, expect_errors=True)
self.assertEqual(403, res.status_int) self.assertEqual(403, res.status_int)
@ -161,8 +165,9 @@ class QuotaExtensionTestCase(unittest.TestCase):
env = {'quantum.context': context.Context('', tenant_id, env = {'quantum.context': context.Context('', tenant_id,
is_admin=True)} is_admin=True)}
quotas = {'quota': {'network': 5}} quotas = {'quota': {'network': 5}}
res = self.api.put_json(_get_path('quotas', id=tenant_id, fmt='json'), res = self.api.put(_get_path('quotas', id=tenant_id,
quotas, extra_environ=env) fmt=self.fmt),
self.serialize(quotas), extra_environ=env)
self.assertEqual(200, res.status_int) self.assertEqual(200, res.status_int)
quota.QUOTAS.limit_check(context.Context('', tenant_id), quota.QUOTAS.limit_check(context.Context('', tenant_id),
tenant_id, tenant_id,
@ -173,8 +178,9 @@ class QuotaExtensionTestCase(unittest.TestCase):
env = {'quantum.context': context.Context('', tenant_id, env = {'quantum.context': context.Context('', tenant_id,
is_admin=True)} is_admin=True)}
quotas = {'quota': {'network': 5}} quotas = {'quota': {'network': 5}}
res = self.api.put_json(_get_path('quotas', id=tenant_id, fmt='json'), res = self.api.put(_get_path('quotas', id=tenant_id,
quotas, extra_environ=env) fmt=self.fmt),
self.serialize(quotas), extra_environ=env)
self.assertEqual(200, res.status_int) self.assertEqual(200, res.status_int)
with self.assertRaises(exceptions.OverQuota): with self.assertRaises(exceptions.OverQuota):
quota.QUOTAS.limit_check(context.Context('', tenant_id), quota.QUOTAS.limit_check(context.Context('', tenant_id),
@ -187,3 +193,7 @@ class QuotaExtensionTestCase(unittest.TestCase):
quota.QUOTAS.limit_check(context.Context('', tenant_id), quota.QUOTAS.limit_check(context.Context('', tenant_id),
tenant_id, tenant_id,
network=-1) network=-1)
class QuotaExtensionTestCaseXML(QuotaExtensionTestCase):
fmt = 'xml'

View File

@ -67,14 +67,14 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
rules = { rules = {
'security_group_rules': [rule1['security_group_rule'], 'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']]} rule2['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg1_id]) security_groups=[sg1_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
@ -116,14 +116,14 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
rules = { rules = {
'security_group_rules': [rule1['security_group_rule'], 'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']]} rule2['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg1_id]) security_groups=[sg1_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
@ -162,23 +162,23 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
'25', source_group_id=sg2['security_group']['id']) '25', source_group_id=sg2['security_group']['id'])
rules = { rules = {
'security_group_rules': [rule1['security_group_rule']]} 'security_group_rules': [rule1['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg1_id, security_groups=[sg1_id,
sg2_id]) sg2_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
res2 = self._create_port( res2 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
security_groups=[sg2_id]) security_groups=[sg2_id])
ports_rest2 = self.deserialize('json', res2) ports_rest2 = self.deserialize(self.fmt, res2)
port_id2 = ports_rest2['port']['id'] port_id2 = ports_rest2['port']['id']
ctx = context.get_admin_context() ctx = context.get_admin_context()
ports_rpc = self.rpc.security_group_rules_for_devices( ports_rpc = self.rpc.security_group_rules_for_devices(
@ -219,15 +219,15 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
rules = { rules = {
'security_group_rules': [rule1['security_group_rule'], 'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']]} rule2['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}], fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}],
security_groups=[sg1_id]) security_groups=[sg1_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
@ -273,15 +273,15 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
rules = { rules = {
'security_group_rules': [rule1['security_group_rule'], 'security_group_rules': [rule1['security_group_rule'],
rule2['security_group_rule']]} rule2['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}], fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}],
security_groups=[sg1_id]) security_groups=[sg1_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
@ -325,25 +325,25 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
source_group_id=sg2['security_group']['id']) source_group_id=sg2['security_group']['id'])
rules = { rules = {
'security_group_rules': [rule1['security_group_rule']]} 'security_group_rules': [rule1['security_group_rule']]}
res = self._create_security_group_rule('json', rules) res = self._create_security_group_rule(self.fmt, rules)
self.deserialize('json', res) self.deserialize(self.fmt, res)
self.assertEquals(res.status_int, 201) self.assertEquals(res.status_int, 201)
res1 = self._create_port( res1 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}], fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}],
security_groups=[sg1_id, security_groups=[sg1_id,
sg2_id]) sg2_id])
ports_rest1 = self.deserialize('json', res1) ports_rest1 = self.deserialize(self.fmt, res1)
port_id1 = ports_rest1['port']['id'] port_id1 = ports_rest1['port']['id']
self.rpc.devices = {port_id1: ports_rest1['port']} self.rpc.devices = {port_id1: ports_rest1['port']}
devices = [port_id1, 'no_exist_device'] devices = [port_id1, 'no_exist_device']
res2 = self._create_port( res2 = self._create_port(
'json', n['network']['id'], self.fmt, n['network']['id'],
fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}], fixed_ips=[{'subnet_id': subnet_v6['subnet']['id']}],
security_groups=[sg2_id]) security_groups=[sg2_id])
ports_rest2 = self.deserialize('json', res2) ports_rest2 = self.deserialize(self.fmt, res2)
port_id2 = ports_rest2['port']['id'] port_id2 = ports_rest2['port']['id']
ctx = context.get_admin_context() ctx = context.get_admin_context()
@ -364,6 +364,10 @@ class SGServerRpcCallBackMixinTestCase(test_sg.SecurityGroupDBTestCase):
self._delete('ports', port_id2) self._delete('ports', port_id2)
class SGServerRpcCallBackMixinTestCaseXML(SGServerRpcCallBackMixinTestCase):
fmt = 'xml'
class SGAgentRpcCallBackMixinTestCase(unittest.TestCase): class SGAgentRpcCallBackMixinTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.rpc = sg_rpc.SecurityGroupAgentRpcCallbackMixin() self.rpc = sg_rpc.SecurityGroupAgentRpcCallbackMixin()

View File

@ -26,9 +26,9 @@ import webob.exc as webexc
import webtest import webtest
from quantum.api import extensions from quantum.api import extensions
from quantum.api.v2 import attributes
from quantum import context from quantum import context
from quantum.db import api as db_api from quantum.db import api as db_api
from quantum.db import models_v2
from quantum.db import servicetype_db from quantum.db import servicetype_db
from quantum.extensions import servicetype from quantum.extensions import servicetype
from quantum import manager from quantum import manager
@ -38,6 +38,7 @@ from quantum.tests.unit import dummy_plugin as dp
from quantum.tests.unit import test_api_v2 from quantum.tests.unit import test_api_v2
from quantum.tests.unit import test_db_plugin from quantum.tests.unit import test_db_plugin
from quantum.tests.unit import test_extensions from quantum.tests.unit import test_extensions
from quantum.tests.unit import testlib_api
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -62,7 +63,8 @@ class TestServiceTypeExtensionManager(object):
return [] return []
class ServiceTypeTestCaseBase(unittest.TestCase): class ServiceTypeTestCaseBase(testlib_api.WebTestCase):
fmt = 'json'
def setUp(self): def setUp(self):
# This is needed because otherwise a failure will occur due to # This is needed because otherwise a failure will occur due to
@ -79,6 +81,7 @@ class ServiceTypeTestCaseBase(unittest.TestCase):
self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr) self.ext_mdw = test_extensions.setup_extensions_middleware(ext_mgr)
self.api = webtest.TestApp(self.ext_mdw) self.api = webtest.TestApp(self.ext_mdw)
self.resource_name = servicetype.RESOURCE_NAME.replace('-', '_') self.resource_name = servicetype.RESOURCE_NAME.replace('-', '_')
super(ServiceTypeTestCaseBase, self).setUp()
def tearDown(self): def tearDown(self):
self.api = None self.api = None
@ -119,15 +122,18 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
instance = self.mock_mgr.return_value instance = self.mock_mgr.return_value
instance.create_service_type.return_value = return_value instance.create_service_type.return_value = return_value
expect_errors = expected_status >= webexc.HTTPBadRequest.code expect_errors = expected_status >= webexc.HTTPBadRequest.code
res = self.api.post_json(_get_path('service-types'), data, res = self.api.post(_get_path('service-types', fmt=self.fmt),
extra_environ=env, self.serialize(data),
expect_errors=expect_errors) extra_environ=env,
expect_errors=expect_errors,
content_type='application/%s' % self.fmt)
self.assertEqual(res.status_int, expected_status) self.assertEqual(res.status_int, expected_status)
if not expect_errors: if not expect_errors:
instance.create_service_type.assert_called_with(mock.ANY, instance.create_service_type.assert_called_with(mock.ANY,
service_type=data) service_type=data)
self.assertTrue(self.resource_name in res.json) res = self.deserialize(res)
svc_type = res.json[self.resource_name] self.assertTrue(self.resource_name in res)
svc_type = res[self.resource_name]
self.assertEqual(svc_type['id'], svc_type_id) self.assertEqual(svc_type['id'], svc_type_id)
# NOTE(salvatore-orlando): The following two checks are # NOTE(salvatore-orlando): The following two checks are
# probably not essential # probably not essential
@ -149,15 +155,17 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
instance = self.mock_mgr.return_value instance = self.mock_mgr.return_value
expect_errors = expected_status >= webexc.HTTPBadRequest.code expect_errors = expected_status >= webexc.HTTPBadRequest.code
instance.update_service_type.return_value = return_value instance.update_service_type.return_value = return_value
res = self.api.put_json(_get_path('service-types/%s' % svc_type_id), res = self.api.put(_get_path('service-types/%s' % svc_type_id,
data) fmt=self.fmt),
self.serialize(data))
if not expect_errors: if not expect_errors:
instance.update_service_type.assert_called_with(mock.ANY, instance.update_service_type.assert_called_with(mock.ANY,
svc_type_id, svc_type_id,
service_type=data) service_type=data)
self.assertEqual(res.status_int, webexc.HTTPOk.code) self.assertEqual(res.status_int, webexc.HTTPOk.code)
self.assertTrue(self.resource_name in res.json) res = self.deserialize(res)
svc_type = res.json[self.resource_name] self.assertTrue(self.resource_name in res)
svc_type = res[self.resource_name]
self.assertEqual(svc_type['id'], svc_type_id) self.assertEqual(svc_type['id'], svc_type_id)
self.assertEqual(svc_type['name'], self.assertEqual(svc_type['name'],
data[self.resource_name]['name']) data[self.resource_name]['name'])
@ -171,7 +179,8 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
def test_service_type_delete(self): def test_service_type_delete(self):
svctype_id = _uuid() svctype_id = _uuid()
instance = self.mock_mgr.return_value instance = self.mock_mgr.return_value
res = self.api.delete(_get_path('service-types/%s' % svctype_id)) res = self.api.delete(_get_path('service-types/%s' % svctype_id,
fmt=self.fmt))
instance.delete_service_type.assert_called_with(mock.ANY, instance.delete_service_type.assert_called_with(mock.ANY,
svctype_id) svctype_id)
self.assertEqual(res.status_int, webexc.HTTPNoContent.code) self.assertEqual(res.status_int, webexc.HTTPNoContent.code)
@ -185,7 +194,8 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
instance = self.mock_mgr.return_value instance = self.mock_mgr.return_value
instance.get_service_type.return_value = return_value instance.get_service_type.return_value = return_value
res = self.api.get(_get_path('service-types/%s' % svctype_id)) res = self.api.get(_get_path('service-types/%s' % svctype_id,
fmt=self.fmt))
instance.get_service_type.assert_called_with(mock.ANY, instance.get_service_type.assert_called_with(mock.ANY,
svctype_id, svctype_id,
@ -201,7 +211,8 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
instance = self.mock_mgr.return_value instance = self.mock_mgr.return_value
instance.get_service_types.return_value = return_value instance.get_service_types.return_value = return_value
res = self.api.get(_get_path('service-types')) res = self.api.get(_get_path('service-types',
fmt=self.fmt))
instance.get_service_types.assert_called_with(mock.ANY, instance.get_service_types.assert_called_with(mock.ANY,
fields=mock.ANY, fields=mock.ANY,
@ -231,6 +242,10 @@ class ServiceTypeExtensionTestCase(ServiceTypeTestCaseBase):
self._test_service_type_update(env=env) self._test_service_type_update(env=env)
class ServiceTypeExtensionTestCaseXML(ServiceTypeExtensionTestCase):
fmt = 'xml'
class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase): class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
def setUp(self): def setUp(self):
@ -256,7 +271,7 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
service_defs = [{'service_class': constants.DUMMY, service_defs = [{'service_class': constants.DUMMY,
'plugin': dp.DUMMY_PLUGIN_NAME}] 'plugin': dp.DUMMY_PLUGIN_NAME}]
res = self._create_service_type(name, service_defs) res = self._create_service_type(name, service_defs)
svc_type = res.json svc_type = self.deserialize(res)
if res.status_int >= 400: if res.status_int >= 400:
raise webexc.HTTPClientError(code=res.status_int) raise webexc.HTTPClientError(code=res.status_int)
yield svc_type yield svc_type
@ -269,10 +284,11 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
self._delete_service_type(svc_type[self.resource_name]['id']) self._delete_service_type(svc_type[self.resource_name]['id'])
def _list_service_types(self): def _list_service_types(self):
return self.api.get(_get_path('service-types')) return self.api.get(_get_path('service-types', fmt=self.fmt))
def _show_service_type(self, svctype_id, expect_errors=False): def _show_service_type(self, svctype_id, expect_errors=False):
return self.api.get(_get_path('service-types/%s' % str(svctype_id)), return self.api.get(_get_path('service-types/%s' % str(svctype_id),
fmt=self.fmt),
expect_errors=expect_errors) expect_errors=expect_errors)
def _create_service_type(self, name, service_defs, def _create_service_type(self, name, service_defs,
@ -285,14 +301,19 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
data[self.resource_name]['default'] = default data[self.resource_name]['default'] = default
if 'tenant_id' not in data[self.resource_name]: if 'tenant_id' not in data[self.resource_name]:
data[self.resource_name]['tenant_id'] = 'fake' data[self.resource_name]['tenant_id'] = 'fake'
return self.api.post_json(_get_path('service-types'), data, return self.api.post(_get_path('service-types', fmt=self.fmt),
expect_errors=expect_errors) self.serialize(data),
expect_errors=expect_errors,
content_type='application/%s' % self.fmt)
def _create_dummy(self, dummyname='dummyobject'): def _create_dummy(self, dummyname='dummyobject'):
data = {'dummy': {'name': dummyname, data = {'dummy': {'name': dummyname,
'tenant_id': 'fake'}} 'tenant_id': 'fake'}}
dummy_res = self.api.post_json(_get_path('dummys'), data) dummy_res = self.api.post(_get_path('dummys', fmt=self.fmt),
return dummy_res.json['dummy'] self.serialize(data),
content_type='application/%s' % self.fmt)
dummy_res = self.deserialize(dummy_res)
return dummy_res['dummy']
def _update_service_type(self, svc_type_id, name, service_defs, def _update_service_type(self, svc_type_id, name, service_defs,
default=None, expect_errors=False): default=None, expect_errors=False):
@ -303,28 +324,34 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
# set this attribute only if True # set this attribute only if True
if default: if default:
data[self.resource_name]['default'] = default data[self.resource_name]['default'] = default
return self.api.put_json( return self.api.put(
_get_path('service-types/%s' % str(svc_type_id)), data, _get_path('service-types/%s' % str(svc_type_id), fmt=self.fmt),
self.serialize(data),
expect_errors=expect_errors) expect_errors=expect_errors)
def _delete_service_type(self, svctype_id, expect_errors=False): def _delete_service_type(self, svctype_id, expect_errors=False):
return self.api.delete(_get_path('service-types/%s' % str(svctype_id)), return self.api.delete(_get_path('service-types/%s' % str(svctype_id),
fmt=self.fmt),
expect_errors=expect_errors) expect_errors=expect_errors)
def _validate_service_type(self, res, name, service_defs, def _validate_service_type(self, res, name, service_defs,
svc_type_id=None): svc_type_id=None):
self.assertTrue(self.resource_name in res.json) res = self.deserialize(res)
svc_type = res.json[self.resource_name] self.assertTrue(self.resource_name in res)
svc_type = res[self.resource_name]
if svc_type_id: if svc_type_id:
self.assertEqual(svc_type['id'], svc_type_id) self.assertEqual(svc_type['id'], svc_type_id)
if name: if name:
self.assertEqual(svc_type['name'], name) self.assertEqual(svc_type['name'], name)
if service_defs: if service_defs:
target_defs = []
# unspecified drivers will value None in response # unspecified drivers will value None in response
for svc_def in service_defs: for svc_def in service_defs:
svc_def['driver'] = svc_def.get('driver') new_svc_def = svc_def.copy()
new_svc_def['driver'] = svc_def.get('driver')
target_defs.append(new_svc_def)
self.assertEqual(svc_type['service_definitions'], self.assertEqual(svc_type['service_definitions'],
service_defs) target_defs)
self.assertEqual(svc_type['default'], False) self.assertEqual(svc_type['default'], False)
def _test_service_type_create(self, name='test', def _test_service_type_create(self, name='test',
@ -390,7 +417,7 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
self.service_type('st2')): self.service_type('st2')):
res = self._list_service_types() res = self._list_service_types()
self.assertEqual(res.status_int, webexc.HTTPOk.code) self.assertEqual(res.status_int, webexc.HTTPOk.code)
data = res.json data = self.deserialize(res)
self.assertTrue('service_types' in data) self.assertTrue('service_types' in data)
# it must be 3 because we have the default service type too! # it must be 3 because we have the default service type too!
self.assertEquals(len(data['service_types']), 3) self.assertEquals(len(data['service_types']), 3)
@ -398,7 +425,7 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
def test_get_default_service_type(self): def test_get_default_service_type(self):
res = self._list_service_types() res = self._list_service_types()
self.assertEqual(res.status_int, webexc.HTTPOk.code) self.assertEqual(res.status_int, webexc.HTTPOk.code)
data = res.json data = self.deserialize(res)
self.assertTrue('service_types' in data) self.assertTrue('service_types' in data)
self.assertEquals(len(data['service_types']), 1) self.assertEquals(len(data['service_types']), 1)
def_svc_type = data['service_types'][0] def_svc_type = data['service_types'][0]
@ -426,15 +453,23 @@ class ServiceTypeManagerTestCase(ServiceTypeTestCaseBase):
def test_create_dummy_increases_service_type_refcount(self): def test_create_dummy_increases_service_type_refcount(self):
dummy = self._create_dummy() dummy = self._create_dummy()
svc_type_res = self._show_service_type(dummy['service_type']) svc_type_res = self._show_service_type(dummy['service_type'])
svc_type = svc_type_res.json[self.resource_name] svc_type_res = self.deserialize(svc_type_res)
svc_type = svc_type_res[self.resource_name]
self.assertEquals(svc_type['num_instances'], 1) self.assertEquals(svc_type['num_instances'], 1)
def test_delete_dummy_decreases_service_type_refcount(self): def test_delete_dummy_decreases_service_type_refcount(self):
dummy = self._create_dummy() dummy = self._create_dummy()
svc_type_res = self._show_service_type(dummy['service_type']) svc_type_res = self._show_service_type(dummy['service_type'])
svc_type = svc_type_res.json[self.resource_name] svc_type_res = self.deserialize(svc_type_res)
svc_type = svc_type_res[self.resource_name]
self.assertEquals(svc_type['num_instances'], 1) self.assertEquals(svc_type['num_instances'], 1)
self.api.delete(_get_path('dummys/%s' % str(dummy['id']))) self.api.delete(_get_path('dummys/%s' % str(dummy['id']),
fmt=self.fmt))
svc_type_res = self._show_service_type(dummy['service_type']) svc_type_res = self._show_service_type(dummy['service_type'])
svc_type = svc_type_res.json[self.resource_name] svc_type_res = self.deserialize(svc_type_res)
svc_type = svc_type_res[self.resource_name]
self.assertEquals(svc_type['num_instances'], 0) self.assertEquals(svc_type['num_instances'], 0)
class ServiceTypeManagerTestCaseXML(ServiceTypeManagerTestCase):
fmt = 'xml'

View File

@ -15,11 +15,13 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import mock
import socket import socket
import mock
import unittest2 as unittest import unittest2 as unittest
from quantum.api.v2 import attributes
from quantum.common import constants
from quantum.common import exceptions as exception from quantum.common import exceptions as exception
from quantum import wsgi from quantum import wsgi
@ -256,3 +258,116 @@ class ResourceTest(unittest.TestCase):
request = FakeRequest() request = FakeRequest()
result = resource(request) result = resource(request)
self.assertEqual(400, result.status_int) self.assertEqual(400, result.status_int)
class XMLDictSerializerTest(unittest.TestCase):
def test_xml(self):
NETWORK = {'network': {'test': None,
'tenant_id': 'test-tenant',
'name': 'net1',
'admin_state_up': True,
'subnets': [],
'dict': {},
'int': 3,
'long': 4L,
'float': 5.0,
'prefix:external': True,
'tests': [{'test1': 'value1'},
{'test2': 2, 'test3': 3}]}}
# XML is:
# <network xmlns="http://openstack.org/quantum/api/v2.0"
# xmlns:prefix="http://xxxx.yy.com"
# xmlns:quantum="http://openstack.org/quantum/api/v2.0"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
# <subnets quantum:type="list" /> # Empty List
# <int quantum:type="int">3</int> # Integer text
# <int quantum:type="long">4</int> # Long text
# <int quantum:type="float">5.0</int> # Float text
# <dict quantum:type="dict" /> # Empty Dict
# <name>net1</name>
# <admin_state_up quantum:type="bool">True</admin_state_up> # Bool
# <test xsi:nil="true" /> # None
# <tenant_id>test-tenant</tenant_id>
# # We must have a namespace defined in root for prefix:external
# <prefix:external quantum:type="bool">True</prefix:external>
# <tests> # List
# <test><test1>value1</test1></test>
# <test><test3 quantum:type="int">3</test3>
# <test2 quantum:type="int">2</test2>
# </test></tests>
# </network>
metadata = attributes.get_attr_metadata()
ns = {'prefix': 'http://xxxx.yy.com'}
metadata[constants.EXT_NS] = ns
metadata['plurals'] = {'tests': 'test'}
serializer = wsgi.XMLDictSerializer(metadata)
result = serializer.serialize(NETWORK)
deserializer = wsgi.XMLDeserializer(metadata)
new_net = deserializer.deserialize(result)['body']
self.assertEqual(NETWORK, new_net)
def test_None(self):
data = None
# Since it is None, we use xsi:nil='true'.
# In addition, we use an
# virtual XML root _v_root to wrap the XML doc.
# XML is:
# <_v_root xsi:nil="true"
# xmlns="http://openstack.org/quantum/api/v2.0"
# xmlns:quantum="http://openstack.org/quantum/api/v2.0"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" />
serializer = wsgi.XMLDictSerializer(attributes.get_attr_metadata())
result = serializer.serialize(data)
deserializer = wsgi.XMLDeserializer(attributes.get_attr_metadata())
new_data = deserializer.deserialize(result)['body']
self.assertIsNone(new_data)
def test_empty_dic_xml(self):
data = {}
# Since it is an empty dict, we use quantum:type='dict' and
# an empty XML element to represent it. In addition, we use an
# virtual XML root _v_root to wrap the XML doc.
# XML is:
# <_v_root quantum:type="dict"
# xmlns="http://openstack.org/quantum/api/v2.0"
# xmlns:quantum="http://openstack.org/quantum/api/v2.0"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" />
serializer = wsgi.XMLDictSerializer(attributes.get_attr_metadata())
result = serializer.serialize(data)
deserializer = wsgi.XMLDeserializer(attributes.get_attr_metadata())
new_data = deserializer.deserialize(result)['body']
self.assertEqual(data, new_data)
def test_non_root_one_item_dic_xml(self):
data = {'test1': 1}
# We have a key in this dict, and its value is an integer.
# XML is:
# <test1 quantum:type="int"
# xmlns="http://openstack.org/quantum/api/v2.0"
# xmlns:quantum="http://openstack.org/quantum/api/v2.0"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
# 1</test1>
serializer = wsgi.XMLDictSerializer(attributes.get_attr_metadata())
result = serializer.serialize(data)
deserializer = wsgi.XMLDeserializer(attributes.get_attr_metadata())
new_data = deserializer.deserialize(result)['body']
self.assertEqual(data, new_data)
def test_non_root_two_items_dic_xml(self):
data = {'test1': 1, 'test2': '2'}
# We have no root element in this data, We will use a virtual
# root element _v_root to wrap the doct.
# The XML is:
# <_v_root xmlns="http://openstack.org/quantum/api/v2.0"
# xmlns:quantum="http://openstack.org/quantum/api/v2.0"
# xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
# <test1 quantum:type="int">1</test1><test2>2</test2>
# </_v_root>
serializer = wsgi.XMLDictSerializer(attributes.get_attr_metadata())
result = serializer.serialize(data)
deserializer = wsgi.XMLDeserializer(attributes.get_attr_metadata())
new_data = deserializer.deserialize(result)['body']
self.assertEqual(data, new_data)

View File

@ -13,6 +13,9 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import unittest2 as unittest
from quantum.api.v2 import attributes
from quantum import wsgi from quantum import wsgi
@ -30,3 +33,28 @@ def create_request(path, body, content_type, method='GET',
if context: if context:
req.environ['quantum.context'] = context req.environ['quantum.context'] = context
return req return req
class WebTestCase(unittest.TestCase):
fmt = 'json'
def setUp(self):
json_deserializer = wsgi.JSONDeserializer()
xml_deserializer = wsgi.XMLDeserializer(
attributes.get_attr_metadata())
self._deserializers = {
'application/json': json_deserializer,
'application/xml': xml_deserializer,
}
super(WebTestCase, self).setUp()
def deserialize(self, response):
ctype = 'application/%s' % self.fmt
data = self._deserializers[ctype].deserialize(response.body)['body']
return data
def serialize(self, data):
ctype = 'application/%s' % self.fmt
result = wsgi.Serializer(
attributes.get_attr_metadata()).serialize(data, ctype)
return result

View File

@ -20,7 +20,7 @@ Utility methods for working with WSGI servers
""" """
import socket import socket
import sys import sys
from xml.dom import minidom from xml.etree import ElementTree as etree
from xml.parsers import expat from xml.parsers import expat
import eventlet.wsgi import eventlet.wsgi
@ -29,6 +29,7 @@ import routes.middleware
import webob.dec import webob.dec
import webob.exc import webob.exc
from quantum.common import constants
from quantum.common import exceptions as exception from quantum.common import exceptions as exception
from quantum import context from quantum import context
from quantum.openstack.common import jsonutils from quantum.openstack.common import jsonutils
@ -173,12 +174,12 @@ class Request(webob.Request):
2) Content-type header 2) Content-type header
3) Accept* headers 3) Accept* headers
""" """
# First lookup http request # First lookup http request path
parts = self.path.rsplit('.', 1) parts = self.path.rsplit('.', 1)
if len(parts) > 1: if len(parts) > 1:
format = parts[1] _format = parts[1]
if format in ['json', 'xml']: if _format in ['json', 'xml']:
return 'application/{0}'.format(parts[1]) return 'application/{0}'.format(_format)
#Then look up content header #Then look up content header
type_from_header = self.get_content_type() type_from_header = self.get_content_type()
@ -195,9 +196,9 @@ class Request(webob.Request):
if "Content-Type" not in self.headers: if "Content-Type" not in self.headers:
LOG.debug(_("Missing Content-Type")) LOG.debug(_("Missing Content-Type"))
return None return None
type = self.content_type _type = self.content_type
if type in allowed_types: if _type in allowed_types:
return type return _type
return None return None
@property @property
@ -247,15 +248,31 @@ class XMLDictSerializer(DictSerializer):
""" """
super(XMLDictSerializer, self).__init__() super(XMLDictSerializer, self).__init__()
self.metadata = metadata or {} self.metadata = metadata or {}
if not xmlns:
xmlns = self.metadata.get('xmlns')
if not xmlns:
xmlns = constants.XML_NS_V20
self.xmlns = xmlns self.xmlns = xmlns
def default(self, data): def default(self, data):
# We expect data to contain a single key which is the XML root. # We expect data to contain a single key which is the XML root or
root_key = data.keys()[0] # non root
doc = minidom.Document() try:
node = self._to_xml_node(doc, self.metadata, root_key, data[root_key]) key_len = data and len(data.keys()) or 0
if (key_len == 1):
return self.to_xml_string(node) root_key = data.keys()[0]
root_value = data[root_key]
else:
root_key = constants.VIRTUAL_ROOT_KEY
root_value = data
doc = etree.Element("_temp_root")
used_prefixes = []
self._to_xml_node(doc, self.metadata, root_key,
root_value, used_prefixes)
return self.to_xml_string(list(doc)[0], used_prefixes)
except AttributeError as e:
LOG.exception(str(e))
return ''
def __call__(self, data): def __call__(self, data):
# Provides a migration path to a cleaner WSGI layer, this # Provides a migration path to a cleaner WSGI layer, this
@ -263,39 +280,36 @@ class XMLDictSerializer(DictSerializer):
# like originally intended # like originally intended
return self.default(data) return self.default(data)
def to_xml_string(self, node, has_atom=False): def to_xml_string(self, node, used_prefixes, has_atom=False):
self._add_xmlns(node, has_atom) self._add_xmlns(node, used_prefixes, has_atom)
return node.toxml('UTF-8') return etree.tostring(node, encoding='UTF-8')
#NOTE (ameade): the has_atom should be removed after all of the #NOTE (ameade): the has_atom should be removed after all of the
# xml serializers and view builders have been updated to the current # xml serializers and view builders have been updated to the current
# spec that required all responses include the xmlns:atom, the has_atom # spec that required all responses include the xmlns:atom, the has_atom
# flag is to prevent current tests from breaking # flag is to prevent current tests from breaking
def _add_xmlns(self, node, has_atom=False): def _add_xmlns(self, node, used_prefixes, has_atom=False):
if self.xmlns is not None: node.set('xmlns', self.xmlns)
node.setAttribute('xmlns', self.xmlns) node.set(constants.TYPE_XMLNS, self.xmlns)
if has_atom: if has_atom:
node.setAttribute('xmlns:atom', "http://www.w3.org/2005/Atom") node.set('xmlns:atom', "http://www.w3.org/2005/Atom")
node.set(constants.XSI_NIL_ATTR, constants.XSI_NAMESPACE)
ext_ns = self.metadata.get(constants.EXT_NS, {})
for prefix in used_prefixes:
if prefix in ext_ns:
node.set('xmlns:' + prefix, ext_ns[prefix])
def _to_xml_node(self, doc, metadata, nodename, data): def _to_xml_node(self, parent, metadata, nodename, data, used_prefixes):
"""Recursive method to convert data members to XML nodes.""" """Recursive method to convert data members to XML nodes."""
result = doc.createElement(nodename) result = etree.SubElement(parent, nodename)
if ":" in nodename:
# Set the xml namespace if one is specified used_prefixes.append(nodename.split(":", 1)[0])
# TODO(justinsb): We could also use prefixes on the keys
xmlns = metadata.get('xmlns', None)
if xmlns:
result.setAttribute('xmlns', xmlns)
#TODO(bcwaldon): accomplish this without a type-check #TODO(bcwaldon): accomplish this without a type-check
if isinstance(data, list): if isinstance(data, list):
collections = metadata.get('list_collections', {}) if not data:
if nodename in collections: result.set(
metadata = collections[nodename] constants.TYPE_ATTR,
for item in data: constants.TYPE_LIST)
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(item))
result.appendChild(node)
return result return result
singular = metadata.get('plurals', {}).get(nodename, None) singular = metadata.get('plurals', {}).get(nodename, None)
if singular is None: if singular is None:
@ -304,41 +318,55 @@ class XMLDictSerializer(DictSerializer):
else: else:
singular = 'item' singular = 'item'
for item in data: for item in data:
node = self._to_xml_node(doc, metadata, singular, item) self._to_xml_node(result, metadata, singular, item,
result.appendChild(node) used_prefixes)
#TODO(bcwaldon): accomplish this without a type-check #TODO(bcwaldon): accomplish this without a type-check
elif isinstance(data, dict): elif isinstance(data, dict):
collections = metadata.get('dict_collections', {}) if not data:
if nodename in collections: result.set(
metadata = collections[nodename] constants.TYPE_ATTR,
for k, v in data.items(): constants.TYPE_DICT)
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(k))
text = doc.createTextNode(str(v))
node.appendChild(text)
result.appendChild(node)
return result return result
attrs = metadata.get('attributes', {}).get(nodename, {}) attrs = metadata.get('attributes', {}).get(nodename, {})
for k, v in data.items(): for k, v in data.items():
if k in attrs: if k in attrs:
result.setAttribute(k, str(v)) result.set(k, str(v))
else: else:
node = self._to_xml_node(doc, metadata, k, v) self._to_xml_node(result, metadata, k, v,
result.appendChild(node) used_prefixes)
elif data is None:
result.set(constants.XSI_ATTR, 'true')
else: else:
# Type is atom if isinstance(data, bool):
node = doc.createTextNode(str(data)) result.set(
result.appendChild(node) constants.TYPE_ATTR,
constants.TYPE_BOOL)
elif isinstance(data, int):
result.set(
constants.TYPE_ATTR,
constants.TYPE_INT)
elif isinstance(data, long):
result.set(
constants.TYPE_ATTR,
constants.TYPE_LONG)
elif isinstance(data, float):
result.set(
constants.TYPE_ATTR,
constants.TYPE_FLOAT)
LOG.debug(_("Data %(data)s type is %(type)s"),
{'data': data,
'type': type(data)})
result.text = str(data)
return result return result
def _create_link_nodes(self, xml_doc, links): def _create_link_nodes(self, xml_doc, links):
link_nodes = [] link_nodes = []
for link in links: for link in links:
link_node = xml_doc.createElement('atom:link') link_node = xml_doc.createElement('atom:link')
link_node.setAttribute('rel', link['rel']) link_node.set('rel', link['rel'])
link_node.setAttribute('href', link['href']) link_node.set('href', link['href'])
if 'type' in link: if 'type' in link:
link_node.setAttribute('type', link['type']) link_node.set('type', link['type'])
link_nodes.append(link_node) link_nodes.append(link_node)
return link_nodes return link_nodes
@ -426,15 +454,51 @@ class XMLDeserializer(TextDeserializer):
""" """
super(XMLDeserializer, self).__init__() super(XMLDeserializer, self).__init__()
self.metadata = metadata or {} self.metadata = metadata or {}
xmlns = self.metadata.get('xmlns')
if not xmlns:
xmlns = constants.XML_NS_V20
self.xmlns = xmlns
def _get_key(self, tag):
tags = tag.split("}", 1)
if len(tags) == 2:
ns = tags[0][1:]
bare_tag = tags[1]
ext_ns = self.metadata.get(constants.EXT_NS, {})
if ns == self.xmlns:
return bare_tag
for prefix, _ns in ext_ns.items():
if ns == _ns:
return prefix + ":" + bare_tag
else:
return tag
def _from_xml(self, datastring): def _from_xml(self, datastring):
if datastring is None:
return None
plurals = set(self.metadata.get('plurals', {})) plurals = set(self.metadata.get('plurals', {}))
try: try:
node = minidom.parseString(datastring).childNodes[0] node = etree.fromstring(datastring)
return {node.nodeName: self._from_xml_node(node, plurals)} result = self._from_xml_node(node, plurals)
except expat.ExpatError: root_tag = self._get_key(node.tag)
msg = _("Cannot understand XML") if root_tag == constants.VIRTUAL_ROOT_KEY:
raise exception.MalformedRequestBody(reason=msg) return result
else:
return {root_tag: result}
except Exception as e:
parseError = False
# Python2.7
if (hasattr(etree, 'ParseError') and
isinstance(e, getattr(etree, 'ParseError'))):
parseError = True
# Python2.6
elif isinstance(e, expat.ExpatError):
parseError = True
if parseError:
msg = _("Cannot understand XML")
raise exception.MalformedRequestBody(reason=msg)
else:
raise
def _from_xml_node(self, node, listnames): def _from_xml_node(self, node, listnames):
"""Convert a minidom node to a simple Python type. """Convert a minidom node to a simple Python type.
@ -443,18 +507,46 @@ class XMLDeserializer(TextDeserializer):
be considered list items. be considered list items.
""" """
if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3: attrNil = node.get(str(etree.QName(constants.XSI_NAMESPACE, "nil")))
return node.childNodes[0].nodeValue attrType = node.get(str(etree.QName(
elif node.nodeName in listnames: self.metadata.get('xmlns'), "type")))
return [self._from_xml_node(n, listnames) for n in node.childNodes] if (attrNil and attrNil.lower() == 'true'):
return None
elif not len(node) and not node.text:
if (attrType and attrType == constants.TYPE_DICT):
return {}
elif (attrType and attrType == constants.TYPE_LIST):
return []
else:
return ''
elif (len(node) == 0 and node.text):
converters = {constants.TYPE_BOOL:
lambda x: x.lower() == 'true',
constants.TYPE_INT:
lambda x: int(x),
constants.TYPE_LONG:
lambda x: long(x),
constants.TYPE_FLOAT:
lambda x: float(x)}
if attrType and attrType in converters:
return converters[attrType](node.text)
else:
return node.text
elif self._get_key(node.tag) in listnames:
return [self._from_xml_node(n, listnames) for n in node]
else: else:
result = dict() result = dict()
for attr in node.attributes.keys(): for attr in node.keys():
result[attr] = node.attributes[attr].nodeValue if (attr == 'xmlns' or
for child in node.childNodes: attr.startswith('xmlns:') or
if child.nodeType != node.TEXT_NODE: attr == constants.XSI_ATTR or
result[child.nodeName] = self._from_xml_node(child, attr == constants.TYPE_ATTR):
listnames) continue
result[self._get_key(attr)] = node.get[attr]
children = list(node)
for child in children:
result[self._get_key(child.tag)] = self._from_xml_node(
child, listnames)
return result return result
def find_first_child_named(self, parent, name): def find_first_child_named(self, parent, name):
@ -960,7 +1052,7 @@ class Controller(object):
""" """
_metadata = getattr(type(self), '_serialization_metadata', {}) _metadata = getattr(type(self), '_serialization_metadata', {})
serializer = Serializer(_metadata) serializer = Serializer(_metadata)
return serializer.deserialize(data, content_type) return serializer.deserialize(data, content_type)['body']
def get_default_xmlns(self, req): def get_default_xmlns(self, req):
"""Provide the XML namespace to use if none is otherwise specified.""" """Provide the XML namespace to use if none is otherwise specified."""
@ -984,8 +1076,8 @@ class Serializer(object):
def _get_serialize_handler(self, content_type): def _get_serialize_handler(self, content_type):
handlers = { handlers = {
'application/json': self._to_json, 'application/json': JSONDictSerializer(),
'application/xml': self._to_xml, 'application/xml': XMLDictSerializer(self.metadata),
} }
try: try:
@ -995,7 +1087,7 @@ class Serializer(object):
def serialize(self, data, content_type): def serialize(self, data, content_type):
"""Serialize a dictionary into the specified content type.""" """Serialize a dictionary into the specified content type."""
return self._get_serialize_handler(content_type)(data) return self._get_serialize_handler(content_type).serialize(data)
def deserialize(self, datastring, content_type): def deserialize(self, datastring, content_type):
"""Deserialize a string to a dictionary. """Deserialize a string to a dictionary.
@ -1004,115 +1096,18 @@ class Serializer(object):
""" """
try: try:
return self.get_deserialize_handler(content_type)(datastring) return self.get_deserialize_handler(content_type).deserialize(
datastring)
except Exception: except Exception:
raise webob.exc.HTTPBadRequest(_("Could not deserialize data")) raise webob.exc.HTTPBadRequest(_("Could not deserialize data"))
def get_deserialize_handler(self, content_type): def get_deserialize_handler(self, content_type):
handlers = { handlers = {
'application/json': self._from_json, 'application/json': JSONDeserializer(),
'application/xml': self._from_xml, 'application/xml': XMLDeserializer(self.metadata),
} }
try: try:
return handlers[content_type] return handlers[content_type]
except Exception: except Exception:
raise exception.InvalidContentType(content_type=content_type) raise exception.InvalidContentType(content_type=content_type)
def _from_json(self, datastring):
return jsonutils.loads(datastring)
def _from_xml(self, datastring):
xmldata = self.metadata.get('application/xml', {})
plurals = set(xmldata.get('plurals', {}))
node = minidom.parseString(datastring).childNodes[0]
return {node.nodeName: self._from_xml_node(node, plurals)}
def _from_xml_node(self, node, listnames):
"""Convert a minidom node to a simple Python type.
listnames is a collection of names of XML nodes whose subnodes should
be considered list items.
"""
if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3:
return node.childNodes[0].nodeValue
elif node.nodeName in listnames:
return [self._from_xml_node(n, listnames)
for n in node.childNodes if n.nodeType != node.TEXT_NODE]
else:
result = dict()
for attr in node.attributes.keys():
result[attr] = node.attributes[attr].nodeValue
for child in node.childNodes:
if child.nodeType != node.TEXT_NODE:
result[child.nodeName] = self._from_xml_node(child,
listnames)
return result
def _to_json(self, data):
return jsonutils.dumps(data)
def _to_xml(self, data):
metadata = self.metadata.get('application/xml', {})
# We expect data to contain a single key which is the XML root.
root_key = data.keys()[0]
doc = minidom.Document()
node = self._to_xml_node(doc, metadata, root_key, data[root_key])
xmlns = node.getAttribute('xmlns')
if not xmlns and self.default_xmlns:
node.setAttribute('xmlns', self.default_xmlns)
return node.toprettyxml(indent='', newl='')
def _to_xml_node(self, doc, metadata, nodename, data):
"""Recursive method to convert data members to XML nodes."""
result = doc.createElement(nodename)
# Set the xml namespace if one is specified
# TODO(justinsb): We could also use prefixes on the keys
xmlns = metadata.get('xmlns', None)
if xmlns:
result.setAttribute('xmlns', xmlns)
if isinstance(data, list):
collections = metadata.get('list_collections', {})
if nodename in collections:
metadata = collections[nodename]
for item in data:
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(item))
result.appendChild(node)
return result
singular = metadata.get('plurals', {}).get(nodename, None)
if singular is None:
if nodename.endswith('s'):
singular = nodename[:-1]
else:
singular = 'item'
for item in data:
node = self._to_xml_node(doc, metadata, singular, item)
result.appendChild(node)
elif isinstance(data, dict):
collections = metadata.get('dict_collections', {})
if nodename in collections:
metadata = collections[nodename]
for k, v in data.items():
node = doc.createElement(metadata['item_name'])
node.setAttribute(metadata['item_key'], str(k))
text = doc.createTextNode(str(v))
node.appendChild(text)
result.appendChild(node)
return result
attrs = metadata.get('attributes', {}).get(nodename, {})
for k, v in data.items():
if k in attrs:
result.setAttribute(k, str(v))
else:
node = self._to_xml_node(doc, metadata, k, v)
result.appendChild(node)
else:
# Type is atom.
node = doc.createTextNode(str(data))
result.appendChild(node)
return result