added tests to ensure the easy api works as a backend for Compute API

This commit is contained in:
Andy Smith
2010-12-22 17:53:42 -08:00
parent 43e9f8727a
commit 70d254c626
8 changed files with 100 additions and 30 deletions

View File

@@ -31,7 +31,6 @@ The general flow of a request is:
""" """
import json
import urllib import urllib
import routes import routes
@@ -39,6 +38,7 @@ import webob
from nova import context from nova import context
from nova import flags from nova import flags
from nova import utils
from nova import wsgi from nova import wsgi
# prxy compute_api in amazon tests # prxy compute_api in amazon tests
@@ -65,7 +65,7 @@ class JsonParamsMiddleware(wsgi.Middleware):
return return
params_json = request.params['json'] params_json = request.params['json']
params_parsed = json.loads(params_json) params_parsed = utils.loads(params_json)
params = {} params = {}
for k, v in params_parsed.iteritems(): for k, v in params_parsed.iteritems():
if k in ('self', 'context'): if k in ('self', 'context'):
@@ -125,7 +125,7 @@ class ServiceWrapper(wsgi.Controller):
method = getattr(self.service_handle, action) method = getattr(self.service_handle, action)
result = method(context, **params) result = method(context, **params)
if type(result) is dict: if type(result) is dict or type(result) is list:
return self._serialize(result, req) return self._serialize(result, req)
else: else:
return result return result
@@ -140,11 +140,11 @@ class Proxy(object):
def __do_request(self, path, context, **kwargs): def __do_request(self, path, context, **kwargs):
req = webob.Request.blank(path) req = webob.Request.blank(path)
req.method = 'POST' req.method = 'POST'
req.body = urllib.urlencode({'json': json.dumps(kwargs)}) req.body = urllib.urlencode({'json': utils.dumps(kwargs)})
req.environ['openstack.context'] = context req.environ['openstack.context'] = context
resp = req.get_response(self.app) resp = req.get_response(self.app)
try: try:
return json.loads(resp.body) return utils.loads(resp.body)
except Exception: except Exception:
return resp.body return resp.body

View File

@@ -118,7 +118,8 @@ class CloudController(object):
def _get_mpi_data(self, context, project_id): def _get_mpi_data(self, context, project_id):
result = {} result = {}
for instance in self.compute_api.get_instances(context, project_id): for instance in self.compute_api.get_instances(context,
project_id=project_id):
if instance['fixed_ip']: if instance['fixed_ip']:
line = '%s slots=%d' % (instance['fixed_ip']['address'], line = '%s slots=%d' % (instance['fixed_ip']['address'],
instance['vcpus']) instance['vcpus'])
@@ -442,7 +443,8 @@ class CloudController(object):
# instance_id is passed in as a list of instances # instance_id is passed in as a list of instances
ec2_id = instance_id[0] ec2_id = instance_id[0]
internal_id = ec2_id_to_internal_id(ec2_id) internal_id = ec2_id_to_internal_id(ec2_id)
instance_ref = self.compute_api.get_instance(context, internal_id) instance_ref = self.compute_api.get_instance(context,
instance_id=internal_id)
output = rpc.call(context, output = rpc.call(context,
'%s.%s' % (FLAGS.compute_topic, '%s.%s' % (FLAGS.compute_topic,
instance_ref['host']), instance_ref['host']),
@@ -541,7 +543,8 @@ class CloudController(object):
if volume_ref['attach_status'] == "attached": if volume_ref['attach_status'] == "attached":
raise exception.ApiError(_("Volume is already attached")) raise exception.ApiError(_("Volume is already attached"))
internal_id = ec2_id_to_internal_id(instance_id) internal_id = ec2_id_to_internal_id(instance_id)
instance_ref = self.compute_api.get_instance(context, internal_id) instance_ref = self.compute_api.get_instance(context,
instance_id=internal_id)
host = instance_ref['host'] host = instance_ref['host']
rpc.cast(context, rpc.cast(context,
db.queue_get_for(context, FLAGS.compute_topic, host), db.queue_get_for(context, FLAGS.compute_topic, host),
@@ -722,14 +725,15 @@ class CloudController(object):
def associate_address(self, context, instance_id, public_ip, **kwargs): def associate_address(self, context, instance_id, public_ip, **kwargs):
internal_id = ec2_id_to_internal_id(instance_id) internal_id = ec2_id_to_internal_id(instance_id)
instance_ref = self.compute_api.get_instance(context, internal_id) instance_ref = self.compute_api.get_instance(context,
instance_id=internal_id)
fixed_address = db.instance_get_fixed_address(context, fixed_address = db.instance_get_fixed_address(context,
instance_ref['id']) instance_ref['id'])
floating_ip_ref = db.floating_ip_get_by_address(context, public_ip) floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
# NOTE(vish): Perhaps we should just pass this on to compute and # NOTE(vish): Perhaps we should just pass this on to compute and
# let compute communicate with network. # let compute communicate with network.
network_topic = self.compute_api.get_network_topic(context, network_topic = self.compute_api.get_network_topic(
internal_id) context, instance_id=internal_id)
rpc.cast(context, rpc.cast(context,
network_topic, network_topic,
{"method": "associate_floating_ip", {"method": "associate_floating_ip",
@@ -754,8 +758,9 @@ class CloudController(object):
def run_instances(self, context, **kwargs): def run_instances(self, context, **kwargs):
max_count = int(kwargs.get('max_count', 1)) max_count = int(kwargs.get('max_count', 1))
instances = self.compute_api.create_instances(context, instances = self.compute_api.create_instances(context,
instance_types.get_by_type(kwargs.get('instance_type', None)), instance_type=instance_types.get_by_type(
kwargs['image_id'], kwargs.get('instance_type', None)),
image_id=kwargs['image_id'],
min_count=int(kwargs.get('min_count', max_count)), min_count=int(kwargs.get('min_count', max_count)),
max_count=max_count, max_count=max_count,
kernel_id=kwargs.get('kernel_id', None), kernel_id=kwargs.get('kernel_id', None),
@@ -765,7 +770,7 @@ class CloudController(object):
key_name=kwargs.get('key_name'), key_name=kwargs.get('key_name'),
user_data=kwargs.get('user_data'), user_data=kwargs.get('user_data'),
security_group=kwargs.get('security_group'), security_group=kwargs.get('security_group'),
generate_hostname=internal_id_to_ec2_id) hostname_format='ec2')
return self._format_run_instances(context, return self._format_run_instances(context,
instances[0]['reservation_id']) instances[0]['reservation_id'])
@@ -775,26 +780,26 @@ class CloudController(object):
logging.debug("Going to start terminating instances") logging.debug("Going to start terminating instances")
for ec2_id in instance_id: for ec2_id in instance_id:
internal_id = ec2_id_to_internal_id(ec2_id) internal_id = ec2_id_to_internal_id(ec2_id)
self.compute_api.delete_instance(context, internal_id) self.compute_api.delete_instance(context, instance_id=internal_id)
return True return True
def reboot_instances(self, context, instance_id, **kwargs): def reboot_instances(self, context, instance_id, **kwargs):
"""instance_id is a list of instance ids""" """instance_id is a list of instance ids"""
for ec2_id in instance_id: for ec2_id in instance_id:
internal_id = ec2_id_to_internal_id(ec2_id) internal_id = ec2_id_to_internal_id(ec2_id)
self.compute_api.reboot(context, internal_id) self.compute_api.reboot(context, instance_id=internal_id)
return True return True
def rescue_instance(self, context, instance_id, **kwargs): def rescue_instance(self, context, instance_id, **kwargs):
"""This is an extension to the normal ec2_api""" """This is an extension to the normal ec2_api"""
internal_id = ec2_id_to_internal_id(instance_id) internal_id = ec2_id_to_internal_id(instance_id)
self.compute_api.rescue(context, internal_id) self.compute_api.rescue(context, instance_id=internal_id)
return True return True
def unrescue_instance(self, context, instance_id, **kwargs): def unrescue_instance(self, context, instance_id, **kwargs):
"""This is an extension to the normal ec2_api""" """This is an extension to the normal ec2_api"""
internal_id = ec2_id_to_internal_id(instance_id) internal_id = ec2_id_to_internal_id(instance_id)
self.compute_api.unrescue(context, internal_id) self.compute_api.unrescue(context, instance_id=internal_id)
return True return True
def update_instance(self, context, ec2_id, **kwargs): def update_instance(self, context, ec2_id, **kwargs):
@@ -805,7 +810,8 @@ class CloudController(object):
changes[field] = kwargs[field] changes[field] = kwargs[field]
if changes: if changes:
internal_id = ec2_id_to_internal_id(ec2_id) internal_id = ec2_id_to_internal_id(ec2_id)
inst = self.compute_api.get_instance(context, internal_id) inst = self.compute_api.get_instance(context,
instance_id=internal_id)
db.instance_update(context, inst['id'], kwargs) db.instance_update(context, inst['id'], kwargs)
return True return True

View File

@@ -36,10 +36,19 @@ from nova.db import base
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
def generate_default_hostname(internal_id): def id_to_default_hostname(internal_id):
"""Default function to generate a hostname given an instance reference.""" """Default function to generate a hostname given an instance reference."""
return str(internal_id) return str(internal_id)
def id_to_ec2_hostname(internal_id):
digits = []
while internal_id != 0:
internal_id, remainder = divmod(internal_id, 36)
digits.append('0123456789abcdefghijklmnopqrstuvwxyz'[remainder])
return "i-%s" % ''.join(reversed(digits))
HOSTNAME_FORMATTERS = {'default': id_to_default_hostname,
'ec2': id_to_ec2_hostname}
class ComputeAPI(base.Base): class ComputeAPI(base.Base):
"""API for interacting with the compute manager.""" """API for interacting with the compute manager."""
@@ -75,7 +84,7 @@ class ComputeAPI(base.Base):
display_name='', description='', key_name=None, display_name='', description='', key_name=None,
key_data=None, security_group='default', key_data=None, security_group='default',
user_data=None, user_data=None,
generate_hostname=generate_default_hostname): hostname_format='default'):
"""Create the number of instances requested if quote and """Create the number of instances requested if quote and
other arguments check out ok.""" other arguments check out ok."""
@@ -144,6 +153,7 @@ class ComputeAPI(base.Base):
elevated = context.elevated() elevated = context.elevated()
instances = [] instances = []
generate_hostname = HOSTNAME_FORMATTERS[hostname_format]
logging.debug(_("Going to run %s instances..."), num_instances) logging.debug(_("Going to run %s instances..."), num_instances)
for num in range(num_instances): for num in range(num_instances):
instance = dict(mac_address=utils.generate_mac(), instance = dict(mac_address=utils.generate_mac(),
@@ -177,7 +187,7 @@ class ComputeAPI(base.Base):
"args": {"topic": FLAGS.compute_topic, "args": {"topic": FLAGS.compute_topic,
"instance_id": instance_id}}) "instance_id": instance_id}})
return instances return [dict(x.iteritems()) for x in instances]
def ensure_default_security_group(self, context): def ensure_default_security_group(self, context):
""" Create security group for the security context if it """ Create security group for the security context if it
@@ -254,7 +264,8 @@ class ComputeAPI(base.Base):
return self.db.instance_get_all(context) return self.db.instance_get_all(context)
def get_instance(self, context, instance_id): def get_instance(self, context, instance_id):
return self.db.instance_get_by_internal_id(context, instance_id) rv = self.db.instance_get_by_internal_id(context, instance_id)
return dict(rv.iteritems())
def reboot(self, context, instance_id): def reboot(self, context, instance_id):
"""Reboot the given instance.""" """Reboot the given instance."""

View File

@@ -22,6 +22,7 @@ import logging
from M2Crypto import BIO from M2Crypto import BIO
from M2Crypto import RSA from M2Crypto import RSA
import os import os
import shutil
import tempfile import tempfile
import time import time
@@ -293,6 +294,7 @@ class CloudTestCase(test.TestCase):
self.assertEqual('Foo Img', img.metadata['description']) self.assertEqual('Foo Img', img.metadata['description'])
self._fake_set_image_description(self.context, 'ami-testing', '') self._fake_set_image_description(self.context, 'ami-testing', '')
self.assertEqual('', img.metadata['description']) self.assertEqual('', img.metadata['description'])
shutil.rmtree(pathdir)
def test_update_of_instance_display_fields(self): def test_update_of_instance_display_fields(self):
inst = db.instance_create(self.context, {}) inst = db.instance_create(self.context, {})

View File

@@ -75,7 +75,7 @@ class ComputeTestCase(test.TestCase):
ref = self.compute_api.create_instances(self.context, ref = self.compute_api.create_instances(self.context,
FLAGS.default_instance_type, None, **instance) FLAGS.default_instance_type, None, **instance)
try: try:
self.assertNotEqual(ref[0].display_name, None) self.assertNotEqual(ref[0]['display_name'], None)
finally: finally:
db.instance_destroy(self.context, ref[0]['id']) db.instance_destroy(self.context, ref[0]['id'])
@@ -87,9 +87,12 @@ class ComputeTestCase(test.TestCase):
'project_id': self.project.id} 'project_id': self.project.id}
group = db.security_group_create(self.context, values) group = db.security_group_create(self.context, values)
ref = self.compute_api.create_instances(self.context, ref = self.compute_api.create_instances(self.context,
FLAGS.default_instance_type, None, security_group=['default']) instance_type=FLAGS.default_instance_type,
image_id=None,
security_group=['default'])
try: try:
self.assertEqual(len(ref[0]['security_groups']), 1) self.assertEqual(len(db.security_group_get_by_instance(
self.context, ref[0]['id'])), 1)
finally: finally:
db.security_group_destroy(self.context, group['id']) db.security_group_destroy(self.context, group['id'])
db.instance_destroy(self.context, ref[0]['id']) db.instance_destroy(self.context, ref[0]['id'])

View File

@@ -28,7 +28,8 @@ from nova import exception
from nova import test from nova import test
from nova import utils from nova import utils
from nova.api import easy from nova.api import easy
from nova.compute import api as compute_api
from nova.tests import cloud_unittest
class FakeService(object): class FakeService(object):
def echo(self, context, data): def echo(self, context, data):
@@ -83,3 +84,19 @@ class EasyTestCase(test.TestCase):
proxy = easy.Proxy(self.router) proxy = easy.Proxy(self.router)
rv = proxy.fake.echo(self.context, data='baz') rv = proxy.fake.echo(self.context, data='baz')
self.assertEqual(rv['data'], 'baz') self.assertEqual(rv['data'], 'baz')
class EasyCloudTestCase(cloud_unittest.CloudTestCase):
def setUp(self):
super(EasyCloudTestCase, self).setUp()
compute_handle = compute_api.ComputeAPI(self.cloud.network_manager,
self.cloud.image_service)
easy.register_service('compute', compute_handle)
self.router = easy.JsonParamsMiddleware(easy.SundayMorning())
proxy = easy.Proxy(self.router)
self.cloud.compute_api = proxy.compute
def tearDown(self):
super(EasyCloudTestCase, self).tearDown()
easy.EASY_ROUTES = {}

View File

@@ -22,6 +22,7 @@ System-level utilities and helper functions.
import datetime import datetime
import inspect import inspect
import json
import logging import logging
import os import os
import random import random
@@ -361,3 +362,33 @@ def utf8(value):
return value.encode("utf-8") return value.encode("utf-8")
assert isinstance(value, str) assert isinstance(value, str)
return value return value
def to_primitive(value):
if type(value) is type([]) or type(value) is type((None,)):
o = []
for v in value:
o.append(to_primitive(v))
return o
elif type(value) is type({}):
o = {}
for k, v in value.iteritems():
o[k] = to_primitive(v)
return o
elif isinstance(value, datetime.datetime):
return str(value)
else:
return value
def dumps(value):
try:
return json.dumps(value)
except TypeError:
pass
return json.dumps(to_primitive(value))
def loads(s):
return json.loads(s)

View File

@@ -21,7 +21,6 @@
Utility methods for working with WSGI servers Utility methods for working with WSGI servers
""" """
import json
import logging import logging
import sys import sys
from xml.dom import minidom from xml.dom import minidom
@@ -35,6 +34,7 @@ import webob
import webob.dec import webob.dec
import webob.exc import webob.exc
from nova import utils
logging.getLogger("routes.middleware").addHandler(logging.StreamHandler()) logging.getLogger("routes.middleware").addHandler(logging.StreamHandler())
@@ -322,7 +322,7 @@ class Serializer(object):
try: try:
is_xml = (datastring[0] == '<') is_xml = (datastring[0] == '<')
if not is_xml: if not is_xml:
return json.loads(datastring) return utils.loads(datastring)
return self._from_xml(datastring) return self._from_xml(datastring)
except: except:
return None return None
@@ -355,7 +355,7 @@ class Serializer(object):
return result return result
def _to_json(self, data): def _to_json(self, data):
return json.dumps(data) return utils.dumps(data)
def _to_xml(self, data): def _to_xml(self, data):
metadata = self.metadata.get('application/xml', {}) metadata = self.metadata.get('application/xml', {})