merge trunk... yay...

This commit is contained in:
Trey Morris
2011-06-02 17:46:16 -05:00
18 changed files with 862 additions and 276 deletions

View File

@@ -30,6 +30,7 @@ Gabe Westmaas <gabe.westmaas@rackspace.com>
Hisaharu Ishii <ishii.hisaharu@lab.ntt.co.jp>
Hisaki Ohara <hisaki.ohara@intel.com>
Ilya Alekseyev <ialekseev@griddynamics.com>
Isaku Yamahata <yamahata@valinux.co.jp>
Jason Koelker <jason@koelker.net>
Jay Pipes <jaypipes@gmail.com>
Jesse Andrews <anotherjesse@gmail.com>
@@ -58,6 +59,7 @@ Mark Washenberger <mark.washenberger@rackspace.com>
Masanori Itoh <itoumsn@nttdata.co.jp>
Matt Dietz <matt.dietz@rackspace.com>
Michael Gundlach <michael.gundlach@rackspace.com>
Mike Scherbakov <mihgen@gmail.com>
Monsyne Dragon <mdragon@rackspace.com>
Monty Taylor <mordred@inaugust.com>
MORITA Kazutaka <morita.kazutaka@gmail.com>
@@ -83,6 +85,7 @@ Trey Morris <trey.morris@rackspace.com>
Tushar Patil <tushar.vitthal.patil@gmail.com>
Vasiliy Shlykov <vash@vasiliyshlykov.org>
Vishvananda Ishaya <vishvananda@gmail.com>
Vivek Y S <vivek.ys@gmail.com>
William Wolf <throughnothing@gmail.com>
Yoshiaki Tamura <yoshi@midokura.jp>
Youcef Laribi <Youcef.Laribi@eu.citrix.com>

View File

@@ -97,7 +97,7 @@ flags.DECLARE('vlan_start', 'nova.network.manager')
flags.DECLARE('vpn_start', 'nova.network.manager')
flags.DECLARE('fixed_range_v6', 'nova.network.manager')
flags.DECLARE('images_path', 'nova.image.local')
flags.DECLARE('libvirt_type', 'nova.virt.libvirt_conn')
flags.DECLARE('libvirt_type', 'nova.virt.libvirt.connection')
flags.DEFINE_flag(flags.HelpFlag())
flags.DEFINE_flag(flags.HelpshortFlag())
flags.DEFINE_flag(flags.HelpXMLFlag())
@@ -423,12 +423,16 @@ class ProjectCommands(object):
arguments: project_id [key] [value]"""
ctxt = context.get_admin_context()
if key:
if value.lower() == 'unlimited':
value = None
try:
db.quota_update(ctxt, project_id, key, value)
except exception.ProjectQuotaNotFound:
db.quota_create(ctxt, project_id, key, value)
project_quota = quota.get_quota(ctxt, project_id)
project_quota = quota.get_project_quotas(ctxt, project_id)
for key, value in project_quota.iteritems():
if value is None:
value = 'unlimited'
print '%s: %s' % (key, value)
def remove(self, project_id, user_id):
@@ -539,7 +543,7 @@ class FloatingIpCommands(object):
for floating_ip in floating_ips:
instance = None
if floating_ip['fixed_ip']:
instance = floating_ip['fixed_ip']['instance']['ec2_id']
instance = floating_ip['fixed_ip']['instance']['hostname']
print "%s\t%s\t%s" % (floating_ip['host'],
floating_ip['address'],
instance)

View File

@@ -1,4 +1,6 @@
NOVA_KEY_DIR=$(pushd $(dirname $BASH_SOURCE)>/dev/null; pwd; popd>/dev/null)
NOVARC=$(readlink -f "${BASH_SOURCE:-${0}}" 2>/dev/null) ||
NOVARC=$(python -c 'import os,sys; print os.path.abspath(os.path.realpath(sys.argv[1]))' "${BASH_SOURCE:-${0}}")
NOVA_KEY_DIR=${NOVARC%/*}
export EC2_ACCESS_KEY="%(access)s:%(project)s"
export EC2_SECRET_KEY="%(secret)s"
export EC2_URL="%(ec2)s"

View File

@@ -31,6 +31,7 @@ LOG = logging.getLogger("nova.fakerabbit")
EXCHANGES = {}
QUEUES = {}
CONSUMERS = {}
class Message(base.BaseMessage):
@@ -96,17 +97,29 @@ class Backend(base.BaseBackend):
' key %(routing_key)s') % locals())
EXCHANGES[exchange].bind(QUEUES[queue].push, routing_key)
def declare_consumer(self, queue, callback, *args, **kwargs):
self.current_queue = queue
self.current_callback = callback
def declare_consumer(self, queue, callback, consumer_tag, *args, **kwargs):
global CONSUMERS
LOG.debug("Adding consumer %s", consumer_tag)
CONSUMERS[consumer_tag] = (queue, callback)
def cancel(self, consumer_tag):
global CONSUMERS
LOG.debug("Removing consumer %s", consumer_tag)
del CONSUMERS[consumer_tag]
def consume(self, limit=None):
global CONSUMERS
num = 0
while True:
item = self.get(self.current_queue)
if item:
self.current_callback(item)
raise StopIteration()
greenthread.sleep(0)
for (queue, callback) in CONSUMERS.itervalues():
item = self.get(queue)
if item:
callback(item)
num += 1
yield
if limit and num == limit:
raise StopIteration()
greenthread.sleep(0.1)
def get(self, queue, no_ack=False):
global QUEUES
@@ -134,5 +147,7 @@ class Backend(base.BaseBackend):
def reset_all():
global EXCHANGES
global QUEUES
global CONSUMERS
EXCHANGES = {}
QUEUES = {}
CONSUMERS = {}
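For reference: the old consume() drained one hard-wired queue and stopped after the first message; the new one is a generator that polls every registered consumer. A minimal sketch against the methods shown above (handle_message is hypothetical, and queue/exchange setup is elided):

    def handle_message(msg):
        print msg

    backend = Backend()
    backend.declare_consumer('test_queue', handle_message, 'tag-1')
    for _ in backend.consume(limit=1):
        pass  # one yield per delivered message
    backend.cancel('tag-1')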

View File

@@ -296,6 +296,7 @@ DEFINE_bool('fake_network', False,
'should we use fake network devices and addresses')
DEFINE_string('rabbit_host', 'localhost', 'rabbit host')
DEFINE_integer('rabbit_port', 5672, 'rabbit port')
DEFINE_bool('rabbit_use_ssl', False, 'connect over SSL')
DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')
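The new flag is threaded straight through to carrot (note the ssl=FLAGS.rabbit_use_ssl parameter in rpc.py below). A hedged sketch of enabling it in code; the port value is a deployment-specific assumption:

    from nova import flags

    FLAGS = flags.FLAGS
    FLAGS.rabbit_use_ssl = True  # carrot opens the AMQP connection over SSL
    FLAGS.rabbit_port = 5671     # conventional AMQPS port; site-specific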

View File

@@ -35,6 +35,7 @@ import os
import sys
import traceback
import nova
from nova import flags
from nova import version
@@ -63,6 +64,7 @@ flags.DEFINE_list('default_log_levels',
'eventlet.wsgi.server=WARN'],
'list of logger=LEVEL pairs')
flags.DEFINE_bool('use_syslog', False, 'output to syslog')
flags.DEFINE_bool('publish_errors', False, 'publish error events')
flags.DEFINE_string('logfile', None, 'output to named file')
@@ -258,12 +260,20 @@ class NovaRootLogger(NovaLogger):
else:
self.removeHandler(self.filelog)
self.addHandler(self.streamlog)
if FLAGS.publish_errors:
self.addHandler(PublishErrorsHandler(ERROR))
if FLAGS.verbose:
self.setLevel(DEBUG)
else:
self.setLevel(INFO)
class PublishErrorsHandler(logging.Handler):
def emit(self, record):
nova.notifier.api.notify('nova.error.publisher', 'error_notification',
nova.notifier.api.ERROR, dict(error=record.msg))
def handle_exception(type, value, tb):
extra = {}
if FLAGS.verbose:

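Going by test_error_notification added later in this commit, wiring publish_errors up looks roughly like this, with every ERROR-level record re-emitted through the notifier:

    from nova import flags
    from nova import log

    flags.FLAGS.publish_errors = True
    LOG = log.getLogger('nova')
    LOG.setup_from_flags()  # re-reads flags, attaches PublishErrorsHandler
    LOG.error('foo')
    # resulting notification, per the new test:
    #   publisher_id: 'nova.error.publisher'
    #   event_type:   'error_notification'
    #   priority:     'ERROR'
    #   payload:      {'error': 'foo'}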
View File

@@ -28,12 +28,15 @@ import json
import sys
import time
import traceback
import types
import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenpool
from eventlet import greenthread
from eventlet import pools
from eventlet import queue
import greenlet
from nova import context
from nova import exception
@@ -47,7 +50,10 @@ LOG = logging.getLogger('nova.rpc')
FLAGS = flags.FLAGS
flags.DEFINE_integer('rpc_thread_pool_size', 1024, 'Size of RPC thread pool')
flags.DEFINE_integer('rpc_thread_pool_size', 1024,
'Size of RPC thread pool')
flags.DEFINE_integer('rpc_conn_pool_size', 30,
'Size of RPC connection pool')
class Connection(carrot_connection.BrokerConnection):
@@ -59,6 +65,7 @@ class Connection(carrot_connection.BrokerConnection):
if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port,
ssl=FLAGS.rabbit_use_ssl,
userid=FLAGS.rabbit_userid,
password=FLAGS.rabbit_password,
virtual_host=FLAGS.rabbit_virtual_host)
@@ -90,6 +97,22 @@ class Connection(carrot_connection.BrokerConnection):
return cls.instance()
class Pool(pools.Pool):
"""Class that implements a Pool of Connections."""
# TODO(comstud): Timeout connections not used in a while
def create(self):
LOG.debug('Creating new connection')
return Connection.instance(new=True)
# Create a ConnectionPool to use for RPC calls. We'll order the
# pool as a stack (LIFO), so that we can potentially loop through and
# timeout old unused connections at some point
ConnectionPool = Pool(
max_size=FLAGS.rpc_conn_pool_size,
order_as_stack=True)
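Connections now come out of this pool in the two ways visible later in the diff: an explicit get()/put() pair (multicall) and eventlet's context-manager form (msg_reply, cast, fanout_cast). A minimal sketch:

    conn = ConnectionPool.get()  # explicit checkout
    try:
        pass  # ... publish or consume on conn ...
    finally:
        ConnectionPool.put(conn)

    with ConnectionPool.item() as conn:  # returned to the pool on exit
        pass  # ... publish or consume on conn ...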
class Consumer(messaging.Consumer):
"""Consumer base class.
@@ -131,7 +154,9 @@ class Consumer(messaging.Consumer):
self.connection = Connection.recreate()
self.backend = self.connection.create_backend()
self.declare()
super(Consumer, self).fetch(no_ack, auto_ack, enable_callbacks)
return super(Consumer, self).fetch(no_ack,
auto_ack,
enable_callbacks)
if self.failed_connection:
LOG.error(_('Reconnected to queue'))
self.failed_connection = False
@@ -159,13 +184,13 @@ class AdapterConsumer(Consumer):
self.pool = greenpool.GreenPool(FLAGS.rpc_thread_pool_size)
super(AdapterConsumer, self).__init__(connection=connection,
topic=topic)
self.register_callback(self.process_data)
def receive(self, *args, **kwargs):
self.pool.spawn_n(self._receive, *args, **kwargs)
def process_data(self, message_data, message):
"""Consumer callback to call a method on a proxy object.
@exception.wrap_exception
def _receive(self, message_data, message):
"""Magically looks for a method on the proxy object and calls it.
Parses the message for validity and fires off a thread to call the
proxy object method.
Message data should be a dictionary with two keys:
method: string representing the method to call
@@ -175,8 +200,8 @@ class AdapterConsumer(Consumer):
"""
LOG.debug(_('received %s') % message_data)
msg_id = message_data.pop('_msg_id', None)
# This will be popped off in _unpack_context
msg_id = message_data.get('_msg_id', None)
ctxt = _unpack_context(message_data)
method = message_data.get('method')
@@ -188,8 +213,17 @@ class AdapterConsumer(Consumer):
# we just log the message and send an error string
# back to the caller
LOG.warn(_('no method for message: %s') % message_data)
msg_reply(msg_id, _('No method for message: %s') % message_data)
if msg_id:
msg_reply(msg_id,
_('No method for message: %s') % message_data)
return
self.pool.spawn_n(self._process_data, msg_id, ctxt, method, args)
@exception.wrap_exception
def _process_data(self, msg_id, ctxt, method, args):
"""Thread that maigcally looks for a method on the proxy
object and calls it.
"""
node_func = getattr(self.proxy, str(method))
node_args = dict((str(k), v) for k, v in args.iteritems())
@@ -197,7 +231,18 @@ class AdapterConsumer(Consumer):
try:
rval = node_func(context=ctxt, **node_args)
if msg_id:
msg_reply(msg_id, rval, None)
# Check if the result was a generator
if isinstance(rval, types.GeneratorType):
for x in rval:
msg_reply(msg_id, x, None)
else:
msg_reply(msg_id, rval, None)
# This final None tells multicall that it is done.
msg_reply(msg_id, None, None)
elif isinstance(rval, types.GeneratorType):
# NOTE(vish): this iterates through the generator
list(rval)
except Exception as e:
logging.exception('Exception during message handling')
if msg_id:
@@ -205,11 +250,6 @@ class AdapterConsumer(Consumer):
return
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicAdapterConsumer(AdapterConsumer):
"""Consumes messages on a specific topic."""
@@ -242,6 +282,58 @@ class FanoutAdapterConsumer(AdapterConsumer):
topic=topic, proxy=proxy)
class ConsumerSet(object):
"""Groups consumers to listen on together on a single connection."""
def __init__(self, connection, consumer_list):
self.consumer_list = set(consumer_list)
self.consumer_set = None
self.enabled = True
self.init(connection)
def init(self, conn):
if not conn:
conn = Connection.instance(new=True)
if self.consumer_set:
self.consumer_set.close()
self.consumer_set = messaging.ConsumerSet(conn)
for consumer in self.consumer_list:
consumer.connection = conn
# consumer.backend is set for us
self.consumer_set.add_consumer(consumer)
def reconnect(self):
self.init(None)
def wait(self, limit=None):
running = True
while running:
it = self.consumer_set.iterconsume(limit=limit)
if not it:
break
while True:
try:
it.next()
except StopIteration:
return
except greenlet.GreenletExit:
running = False
break
except Exception as e:
LOG.exception(_("Exception while processing consumer"))
self.reconnect()
# Break to outer loop
break
def close(self):
self.consumer_set.close()
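A rough sketch of how a service might drive a ConsumerSet; the proxy object and consumer wiring are assumptions, only the ConsumerSet methods come from the code above:

    class ComputeManagerStub(object):
        """Hypothetical proxy whose methods rpc should dispatch to."""
        def noop(self, context):
            return None

    proxy = ComputeManagerStub()
    conn = Connection.instance(new=True)
    consumers = [TopicAdapterConsumer(connection=conn, topic='compute',
                                      proxy=proxy),
                 FanoutAdapterConsumer(connection=conn, topic='compute',
                                       proxy=proxy)]
    cset = ConsumerSet(conn, consumers)
    cset.wait()   # loops on iterconsume(); reconnects on error, exits on GreenletExit
    cset.close()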
class Publisher(messaging.Publisher):
"""Publisher base class."""
pass
class TopicPublisher(Publisher):
"""Publishes messages on a specific topic."""
@@ -306,16 +398,18 @@ def msg_reply(msg_id, reply=None, failure=None):
LOG.error(_("Returning exception %s to caller"), message)
LOG.error(tb)
failure = (failure[0].__name__, str(failure[1]), tb)
conn = Connection.instance()
publisher = DirectPublisher(connection=conn, msg_id=msg_id)
try:
publisher.send({'result': reply, 'failure': failure})
except TypeError:
publisher.send(
{'result': dict((k, repr(v))
for k, v in reply.__dict__.iteritems()),
'failure': failure})
publisher.close()
with ConnectionPool.item() as conn:
publisher = DirectPublisher(connection=conn, msg_id=msg_id)
try:
publisher.send({'result': reply, 'failure': failure})
except TypeError:
publisher.send(
{'result': dict((k, repr(v))
for k, v in reply.__dict__.iteritems()),
'failure': failure})
publisher.close()
class RemoteError(exception.Error):
@@ -347,8 +441,9 @@ def _unpack_context(msg):
if key.startswith('_context_'):
value = msg.pop(key)
context_dict[key[9:]] = value
context_dict['msg_id'] = msg.pop('_msg_id', None)
LOG.debug(_('unpacked context: %s'), context_dict)
return context.RequestContext.from_dict(context_dict)
return RpcContext.from_dict(context_dict)
def _pack_context(msg, context):
@@ -360,70 +455,112 @@ def _pack_context(msg, context):
for args at some point.
"""
context = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()])
msg.update(context)
context_d = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()])
msg.update(context_d)
def call(context, topic, msg):
"""Sends a message on a topic and wait for a response."""
class RpcContext(context.RequestContext):
def __init__(self, *args, **kwargs):
msg_id = kwargs.pop('msg_id', None)
self.msg_id = msg_id
super(RpcContext, self).__init__(*args, **kwargs)
def reply(self, *args, **kwargs):
msg_reply(self.msg_id, *args, **kwargs)
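RpcContext gives a proxy method a direct back-channel to the caller; the TestReceiver added later in this diff uses it exactly this way:

    @staticmethod
    def echo_three_times(context, value):
        context.reply(value)
        context.reply(value + 1)
        context.reply(value + 2)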
def multicall(context, topic, msg):
"""Make a call that returns multiple times."""
LOG.debug(_('Making asynchronous call on %s ...'), topic)
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug(_('MSG_ID is %s') % (msg_id))
_pack_context(msg, context)
class WaitMessage(object):
def __call__(self, data, message):
"""Acks message and sets result."""
message.ack()
if data['failure']:
self.result = RemoteError(*data['failure'])
else:
self.result = data['result']
wait_msg = WaitMessage()
conn = Connection.instance()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
con_conn = ConnectionPool.get()
consumer = DirectConsumer(connection=con_conn, msg_id=msg_id)
wait_msg = MulticallWaiter(consumer)
consumer.register_callback(wait_msg)
conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic)
publisher = TopicPublisher(connection=con_conn, topic=topic)
publisher.send(msg)
publisher.close()
try:
consumer.wait(limit=1)
except StopIteration:
pass
consumer.close()
# NOTE(termie): this is a little bit of a change from the original
# non-eventlet code where returning a Failure
# instance from a deferred call is very similar to
# raising an exception
if isinstance(wait_msg.result, Exception):
raise wait_msg.result
return wait_msg.result
return wait_msg
class MulticallWaiter(object):
def __init__(self, consumer):
self._consumer = consumer
self._results = queue.Queue()
self._closed = False
def close(self):
self._closed = True
self._consumer.close()
ConnectionPool.put(self._consumer.connection)
def __call__(self, data, message):
"""Acks message and sets result."""
message.ack()
if data['failure']:
self._results.put(RemoteError(*data['failure']))
else:
self._results.put(data['result'])
def __iter__(self):
return self.wait()
def wait(self):
while True:
rv = None
while rv is None and not self._closed:
try:
rv = self._consumer.fetch(enable_callbacks=True)
except Exception:
self.close()
raise
time.sleep(0.01)
result = self._results.get()
if isinstance(result, Exception):
self.close()
raise result
if result is None:
self.close()
raise StopIteration
yield result
def call(context, topic, msg):
"""Sends a message on a topic and wait for a response."""
rv = multicall(context, topic, msg)
# NOTE(vish): return the last result from the multicall
rv = list(rv)
if not rv:
return
return rv[-1]
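_process_data, MulticallWaiter, and call() fit together as in the rpc tests later in this diff: a proxy method may yield (or context.reply()) several values, multicall() iterates them, and call() keeps only the last. A sketch, with the topic/proxy wiring and an admin context assumed in scope:

    class Proxy(object):
        @staticmethod
        def echo_three_times_yield(context, value):
            yield value
            yield value + 1
            yield value + 2

    msg = {'method': 'echo_three_times_yield', 'args': {'value': 42}}
    for x in multicall(ctxt, 'test', msg):
        print x                    # 42, then 43, then 44
    print call(ctxt, 'test', msg)  # 44 -- only the last reply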
def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response."""
LOG.debug(_('Making asynchronous cast on %s...'), topic)
_pack_context(msg, context)
conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg)
publisher.close()
with ConnectionPool.item() as conn:
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg)
publisher.close()
def fanout_cast(context, topic, msg):
"""Sends a message on a fanout exchange without waiting for a response."""
LOG.debug(_('Making asynchronous fanout cast...'))
_pack_context(msg, context)
conn = Connection.instance()
publisher = FanoutPublisher(topic, connection=conn)
publisher.send(msg)
publisher.close()
with ConnectionPool.item() as conn:
publisher = FanoutPublisher(topic, connection=conn)
publisher.send(msg)
publisher.close()
def generic_response(message_data, message):
@@ -459,6 +596,7 @@ def send_message(topic, message, wait=True):
if wait:
consumer.wait()
consumer.close()
if __name__ == '__main__':

View File

@@ -14,8 +14,8 @@
# under the License.
"""
Host Filter is a driver mechanism for requesting instance resources.
Three drivers are included: AllHosts, Flavor & JSON. AllHosts just
Host Filter is a mechanism for requesting instance resources.
Three filters are included: AllHosts, Flavor & JSON. AllHosts just
returns the full, unfiltered list of hosts. Flavor is a hard coded
matching mechanism based on flavor criteria and JSON is an ad-hoc
filter grammar.
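For concreteness, a query in that JSON grammar selecting hosts with at least 50 MB of free RAM and 500 GB of free disk (the values the unit tests use) looks like:

    import json

    query = json.dumps(['and',
                        ['>=', '$compute.host_memory_free', 50],
                        ['>=', '$compute.disk_available', 500]])
    # hosts = JsonFilter().filter_hosts(zone_manager, query)
    # (zone_manager assumed in scope)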
@@ -42,17 +42,18 @@ from nova import exception
from nova import flags
from nova import log as logging
from nova import utils
from nova.scheduler import zone_aware_scheduler
LOG = logging.getLogger('nova.scheduler.host_filter')
FLAGS = flags.FLAGS
flags.DEFINE_string('default_host_filter_driver',
flags.DEFINE_string('default_host_filter',
'nova.scheduler.host_filter.AllHostsFilter',
'Which driver to use for filtering hosts.')
'Which filter to use for filtering hosts.')
class HostFilter(object):
"""Base class for host filter drivers."""
"""Base class for host filters."""
def instance_type_to_filter(self, instance_type):
"""Convert instance_type into a filter for most common use-case."""
@@ -63,14 +64,15 @@ class HostFilter(object):
raise NotImplementedError()
def _full_name(self):
"""module.classname of the filter driver"""
"""module.classname of the filter."""
return "%s.%s" % (self.__module__, self.__class__.__name__)
class AllHostsFilter(HostFilter):
"""NOP host filter driver. Returns all hosts in ZoneManager.
""" NOP host filter. Returns all hosts in ZoneManager.
This essentially does what the old Scheduler+Chance used
to give us."""
to give us.
"""
def instance_type_to_filter(self, instance_type):
"""Return anything to prevent base-class from raising
@@ -83,8 +85,8 @@ class AllHostsFilter(HostFilter):
for host, services in zone_manager.service_states.iteritems()]
class FlavorFilter(HostFilter):
"""HostFilter driver hard-coded to work with flavors."""
class InstanceTypeFilter(HostFilter):
"""HostFilter hard-coded to work with InstanceType records."""
def instance_type_to_filter(self, instance_type):
"""Use instance_type to filter hosts."""
@@ -98,9 +100,10 @@ class FlavorFilter(HostFilter):
capabilities = services.get('compute', {})
host_ram_mb = capabilities['host_memory_free']
disk_bytes = capabilities['disk_available']
if host_ram_mb >= instance_type['memory_mb'] and \
disk_bytes >= instance_type['local_gb']:
selected_hosts.append((host, capabilities))
spec_ram = instance_type['memory_mb']
spec_disk = instance_type['local_gb']
if host_ram_mb >= spec_ram and disk_bytes >= spec_disk:
selected_hosts.append((host, capabilities))
return selected_hosts
#host entries (currently) are like:
@@ -109,15 +112,15 @@ class FlavorFilter(HostFilter):
# 'host_memory_total': 8244539392,
# 'host_memory_overhead': 184225792,
# 'host_memory_free': 3868327936,
# 'host_memory_free_computed': 3840843776},
# 'host_other-config': {},
# 'host_memory_free_computed': 3840843776,
# 'host_other_config': {},
# 'host_ip_address': '192.168.1.109',
# 'host_cpu_info': {},
# 'disk_available': 32954957824,
# 'disk_total': 50394562560,
# 'disk_used': 17439604736},
# 'disk_used': 17439604736,
# 'host_uuid': 'cedb9b39-9388-41df-8891-c5c9a0c0fe5f',
# 'host_name-label': 'xs-mini'}
# 'host_name_label': 'xs-mini'}
# instance_type table has:
#name = Column(String(255), unique=True)
@@ -131,8 +134,9 @@ class FlavorFilter(HostFilter):
class JsonFilter(HostFilter):
"""Host Filter driver to allow simple JSON-based grammar for
selecting hosts."""
"""Host Filter to allow simple JSON-based grammar for
selecting hosts.
"""
def _equals(self, args):
"""First term is == all the other terms."""
@@ -222,13 +226,14 @@ class JsonFilter(HostFilter):
required_disk = instance_type['local_gb']
query = ['and',
['>=', '$compute.host_memory_free', required_ram],
['>=', '$compute.disk_available', required_disk]
['>=', '$compute.disk_available', required_disk],
]
return (self._full_name(), json.dumps(query))
def _parse_string(self, string, host, services):
"""Strings prefixed with $ are capability lookups in the
form '$service.capability[.subcap*]'"""
form '$service.capability[.subcap*]'
"""
if not string:
return None
if string[0] != '$':
@@ -271,18 +276,48 @@ class JsonFilter(HostFilter):
return hosts
DRIVERS = [AllHostsFilter, FlavorFilter, JsonFilter]
FILTERS = [AllHostsFilter, InstanceTypeFilter, JsonFilter]
def choose_driver(driver_name=None):
"""Since the caller may specify which driver to use we need
to have an authoritative list of what is permissible. This
function checks the driver name against a predefined set
of acceptable drivers."""
def choose_host_filter(filter_name=None):
"""Since the caller may specify which filter to use we need
to have an authoritative list of what is permissible. This
function checks the filter name against a predefined set
of acceptable filters.
"""
if not driver_name:
driver_name = FLAGS.default_host_filter_driver
for driver in DRIVERS:
if "%s.%s" % (driver.__module__, driver.__name__) == driver_name:
return driver()
raise exception.SchedulerHostFilterDriverNotFound(driver_name=driver_name)
if not filter_name:
filter_name = FLAGS.default_host_filter
for filter_class in FILTERS:
host_match = "%s.%s" % (filter_class.__module__, filter_class.__name__)
if host_match == filter_name:
return filter_class()
raise exception.SchedulerHostFilterNotFound(filter_name=filter_name)
class HostFilterScheduler(zone_aware_scheduler.ZoneAwareScheduler):
"""The HostFilterScheduler uses the HostFilter to filter
hosts for weighing. The particular filter used may be passed in
as an argument or the default will be used.
request_spec = {'filter': <Filter name>,
'instance_type': <InstanceType dict>}
"""
def filter_hosts(self, num, request_spec):
"""Filter the full host list (from the ZoneManager)"""
filter_name = request_spec.get('filter', None)
host_filter = choose_host_filter(filter_name)
# TODO(sandy): We're only using InstanceType-based specs
# currently. Later we'll need to snoop for more detailed
# host filter requests.
instance_type = request_spec['instance_type']
name, query = host_filter.instance_type_to_filter(instance_type)
return host_filter.filter_hosts(self.zone_manager, query)
def weigh_hosts(self, num, request_spec, hosts):
"""Derived classes must override this method and return
a list of hosts in [{weight, hostname}] format.
"""
return [dict(weight=1, hostname=host) for host, caps in hosts]
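End to end, and assuming the scheduler service has attached a zone_manager the way the scheduler manager normally does, a request in the documented shape flows through roughly as:

    request_spec = {'filter': 'nova.scheduler.host_filter.InstanceTypeFilter',
                    'instance_type': {'memory_mb': 50, 'local_gb': 500}}
    scheduler = HostFilterScheduler()  # zone_manager attachment assumed
    hosts = scheduler.filter_hosts(1, request_spec)
    weighted = scheduler.weigh_hosts(1, request_spec, hosts)
    # weighted: [{'weight': 1, 'hostname': 'host05'}, ...]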

View File

@@ -22,7 +22,9 @@ across zones. There are two expansion points to this class for:
import operator
from nova import db
from nova import log as logging
from nova import rpc
from nova.scheduler import api
from nova.scheduler import driver
@@ -36,7 +38,7 @@ class ZoneAwareScheduler(driver.Scheduler):
"""Call novaclient zone method. Broken out for testing."""
return api.call_zone_method(context, method, specs=specs)
def schedule_run_instance(self, context, topic='compute', specs={},
def schedule_run_instance(self, context, instance_id, request_spec,
*args, **kwargs):
"""This method is called from nova.compute.api to provision
an instance. However we need to look at the parameters being
@@ -44,56 +46,83 @@ class ZoneAwareScheduler(driver.Scheduler):
1. Create a Build Plan and then provision, or
2. Use the Build Plan information in the request parameters
to simply create the instance (either in this zone or
a child zone)."""
a child zone).
"""
if 'blob' in specs:
return self.provision_instance(context, topic, specs)
# TODO(sandy): We'll have to look for richer specs at some point.
if 'blob' in request_spec:
self.provision_resource(context, request_spec, instance_id, kwargs)
return None
# Create build plan and provision ...
build_plan = self.select(context, specs)
build_plan = self.select(context, request_spec)
if not build_plan:
raise driver.NoValidHost(_('No hosts were available'))
for item in build_plan:
self.provision_instance(context, topic, item)
self.provision_resource(context, item, instance_id, kwargs)
def provision_instance(context, topic, item):
"""Create the requested instance in this Zone or a child zone."""
pass
# Returning None short-circuits the routing to Compute (since
# we've already done it here)
return None
def select(self, context, *args, **kwargs):
def provision_resource(self, context, item, instance_id, kwargs):
"""Create the requested resource in this Zone or a child zone."""
if "hostname" in item:
host = item['hostname']
kwargs['instance_id'] = instance_id
rpc.cast(context,
db.queue_get_for(context, "compute", host),
{"method": "run_instance",
"args": kwargs})
LOG.debug(_("Casted to compute %(host)s for run_instance")
% locals())
else:
# TODO(sandy) Provision in child zone ...
LOG.warning(_("Provision to Child Zone not supported (yet)"))
pass
def select(self, context, request_spec, *args, **kwargs):
"""Select returns a list of weights and zone/host information
corresponding to the best hosts to service the request. Any
child zone information has been encrypted so as not to reveal
anything about the children."""
return self._schedule(context, "compute", *args, **kwargs)
anything about the children.
"""
return self._schedule(context, "compute", request_spec,
*args, **kwargs)
def schedule(self, context, topic, *args, **kwargs):
# TODO(sandy): We're only focused on compute instances right now,
# so we don't implement the default "schedule()" method required
# of Schedulers.
def schedule(self, context, topic, request_spec, *args, **kwargs):
"""The schedule() contract requires we return the one
best-suited host for this request.
"""
res = self._schedule(context, topic, *args, **kwargs)
# TODO(sirp): should this be a host object rather than a weight-dict?
if not res:
raise driver.NoValidHost(_('No hosts were available'))
return res[0]
raise driver.NoValidHost(_('No hosts were available'))
def _schedule(self, context, topic, *args, **kwargs):
def _schedule(self, context, topic, request_spec, *args, **kwargs):
"""Returns a list of hosts that meet the required specs,
ordered by their fitness.
"""
#TODO(sandy): extract these from args.
if topic != "compute":
raise NotImplementedError(_("Zone Aware Scheduler only understands "
"Compute nodes (for now)"))
#TODO(sandy): how to infer this from OS API params?
num_instances = 1
specs = {}
# Filter local hosts based on requirements ...
host_list = self.filter_hosts(num_instances, specs)
host_list = self.filter_hosts(num_instances, request_spec)
# then weigh the selected hosts.
# weighted = [{weight=weight, name=hostname}, ...]
weighted = self.weigh_hosts(num_instances, specs, host_list)
weighted = self.weigh_hosts(num_instances, request_spec, host_list)
# Next, tack on the best weights from the child zones ...
child_results = self._call_zone_method(context, "select",
specs=specs)
specs=request_spec)
for child_zone, result in child_results:
for weighting in result:
# Remember the child_zone so we can get back to
@@ -108,12 +137,14 @@ class ZoneAwareScheduler(driver.Scheduler):
weighted.sort(key=operator.itemgetter('weight'))
return weighted
def filter_hosts(self, num, specs):
def filter_hosts(self, num, request_spec):
"""Derived classes must override this method and return
a list of hosts in [(hostname, capability_dict)] format."""
a list of hosts in [(hostname, capability_dict)] format.
"""
raise NotImplementedError()
def weigh_hosts(self, num, specs, hosts):
def weigh_hosts(self, num, request_spec, hosts):
"""Derived classes must override this method and return
a list of hosts in [{weight, hostname}] format."""
a list of hosts in [{weight, hostname}] format.
"""
raise NotImplementedError()

View File

@@ -17,13 +17,9 @@
# under the License.
from base64 import b64decode
import json
from M2Crypto import BIO
from M2Crypto import RSA
import os
import shutil
import tempfile
import time
from eventlet import greenthread
@@ -33,12 +29,10 @@ from nova import db
from nova import flags
from nova import log as logging
from nova import rpc
from nova import service
from nova import test
from nova import utils
from nova import exception
from nova.auth import manager
from nova.compute import power_state
from nova.api.ec2 import cloud
from nova.api.ec2 import ec2utils
from nova.image import local
@@ -79,14 +73,21 @@ class CloudTestCase(test.TestCase):
self.stubs.Set(local.LocalImageService, 'show', fake_show)
self.stubs.Set(local.LocalImageService, 'show_by_name', fake_show)
# NOTE(vish): set up a manual wait so rpc.cast has a chance to finish
rpc_cast = rpc.cast
def finish_cast(*args, **kwargs):
rpc_cast(*args, **kwargs)
greenthread.sleep(0.2)
self.stubs.Set(rpc, 'cast', finish_cast)
def tearDown(self):
network_ref = db.project_get_network(self.context,
self.project.id)
db.network_disassociate(self.context, network_ref['id'])
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
self.compute.kill()
self.network.kill()
super(CloudTestCase, self).tearDown()
def _create_key(self, name):
@@ -113,7 +114,6 @@ class CloudTestCase(test.TestCase):
self.cloud.describe_addresses(self.context)
self.cloud.release_address(self.context,
public_ip=address)
greenthread.sleep(0.3)
db.floating_ip_destroy(self.context, address)
def test_associate_disassociate_address(self):
@@ -129,12 +129,10 @@ class CloudTestCase(test.TestCase):
self.cloud.associate_address(self.context,
instance_id=ec2_id,
public_ip=address)
greenthread.sleep(0.3)
self.cloud.disassociate_address(self.context,
public_ip=address)
self.cloud.release_address(self.context,
public_ip=address)
greenthread.sleep(0.3)
self.network.deallocate_fixed_ip(self.context, fixed)
db.instance_destroy(self.context, inst['id'])
db.floating_ip_destroy(self.context, address)
@@ -171,6 +169,25 @@ class CloudTestCase(test.TestCase):
db.volume_destroy(self.context, vol1['id'])
db.volume_destroy(self.context, vol2['id'])
def test_create_volume_from_snapshot(self):
"""Makes sure create_volume works when we specify a snapshot."""
vol = db.volume_create(self.context, {'size': 1})
snap = db.snapshot_create(self.context, {'volume_id': vol['id'],
'volume_size': vol['size'],
'status': "available"})
snapshot_id = ec2utils.id_to_ec2_id(snap['id'], 'snap-%08x')
result = self.cloud.create_volume(self.context,
snapshot_id=snapshot_id)
volume_id = result['volumeId']
result = self.cloud.describe_volumes(self.context)
self.assertEqual(len(result['volumeSet']), 2)
self.assertEqual(result['volumeSet'][1]['volumeId'], volume_id)
db.volume_destroy(self.context, ec2utils.ec2_id_to_id(volume_id))
db.snapshot_destroy(self.context, snap['id'])
db.volume_destroy(self.context, vol['id'])
def test_describe_availability_zones(self):
"""Makes sure describe_availability_zones works and filters results."""
service1 = db.service_create(self.context, {'host': 'host1_zones',
@@ -188,6 +205,52 @@ class CloudTestCase(test.TestCase):
db.service_destroy(self.context, service1['id'])
db.service_destroy(self.context, service2['id'])
def test_describe_snapshots(self):
"""Makes sure describe_snapshots works and filters results."""
vol = db.volume_create(self.context, {})
snap1 = db.snapshot_create(self.context, {'volume_id': vol['id']})
snap2 = db.snapshot_create(self.context, {'volume_id': vol['id']})
result = self.cloud.describe_snapshots(self.context)
self.assertEqual(len(result['snapshotSet']), 2)
snapshot_id = ec2utils.id_to_ec2_id(snap2['id'], 'snap-%08x')
result = self.cloud.describe_snapshots(self.context,
snapshot_id=[snapshot_id])
self.assertEqual(len(result['snapshotSet']), 1)
self.assertEqual(
ec2utils.ec2_id_to_id(result['snapshotSet'][0]['snapshotId']),
snap2['id'])
db.snapshot_destroy(self.context, snap1['id'])
db.snapshot_destroy(self.context, snap2['id'])
db.volume_destroy(self.context, vol['id'])
def test_create_snapshot(self):
"""Makes sure create_snapshot works."""
vol = db.volume_create(self.context, {'status': "available"})
volume_id = ec2utils.id_to_ec2_id(vol['id'], 'vol-%08x')
result = self.cloud.create_snapshot(self.context,
volume_id=volume_id)
snapshot_id = result['snapshotId']
result = self.cloud.describe_snapshots(self.context)
self.assertEqual(len(result['snapshotSet']), 1)
self.assertEqual(result['snapshotSet'][0]['snapshotId'], snapshot_id)
db.snapshot_destroy(self.context, ec2utils.ec2_id_to_id(snapshot_id))
db.volume_destroy(self.context, vol['id'])
def test_delete_snapshot(self):
"""Makes sure delete_snapshot works."""
vol = db.volume_create(self.context, {'status': "available"})
snap = db.snapshot_create(self.context, {'volume_id': vol['id'],
'status': "available"})
snapshot_id = ec2utils.id_to_ec2_id(snap['id'], 'snap-%08x')
result = self.cloud.delete_snapshot(self.context,
snapshot_id=snapshot_id)
self.assertTrue(result)
db.volume_destroy(self.context, vol['id'])
def test_describe_instances(self):
"""Makes sure describe_instances works and filters results."""
inst1 = db.instance_create(self.context, {'reservation_id': 'a',
@@ -306,31 +369,25 @@ class CloudTestCase(test.TestCase):
'instance_type': instance_type,
'max_count': max_count}
rv = self.cloud.run_instances(self.context, **kwargs)
greenthread.sleep(0.3)
instance_id = rv['instancesSet'][0]['instanceId']
output = self.cloud.get_console_output(context=self.context,
instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE?OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_ajax_console(self):
kwargs = {'image_id': 'ami-1'}
rv = self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId']
greenthread.sleep(0.3)
output = self.cloud.get_ajax_console(context=self.context,
instance_id=[instance_id])
self.assertEquals(output['url'],
'%s/?token=FAKETOKEN' % FLAGS.ajax_console_proxy_url)
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
greenthread.sleep(0.3)
rv = self.cloud.terminate_instances(self.context, [instance_id])
greenthread.sleep(0.3)
def test_key_generation(self):
result = self._create_key('test')

View File

@@ -13,7 +13,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Scheduler Host Filter Drivers.
Tests For Scheduler Host Filters.
"""
import json
@@ -31,7 +31,7 @@ class FakeZoneManager:
class HostFilterTestCase(test.TestCase):
"""Test case for host filter drivers."""
"""Test case for host filters."""
def _host_caps(self, multiplier):
# Returns host capabilities in the following way:
@@ -57,8 +57,8 @@ class HostFilterTestCase(test.TestCase):
'host_name-label': 'xs-%s' % multiplier}
def setUp(self):
self.old_flag = FLAGS.default_host_filter_driver
FLAGS.default_host_filter_driver = \
self.old_flag = FLAGS.default_host_filter
FLAGS.default_host_filter = \
'nova.scheduler.host_filter.AllHostsFilter'
self.instance_type = dict(name='tiny',
memory_mb=50,
@@ -76,51 +76,52 @@ class HostFilterTestCase(test.TestCase):
self.zone_manager.service_states = states
def tearDown(self):
FLAGS.default_host_filter_driver = self.old_flag
FLAGS.default_host_filter = self.old_flag
def test_choose_driver(self):
# Test default driver ...
driver = host_filter.choose_driver()
self.assertEquals(driver._full_name(),
def test_choose_filter(self):
# Test default filter ...
hf = host_filter.choose_host_filter()
self.assertEquals(hf._full_name(),
'nova.scheduler.host_filter.AllHostsFilter')
# Test valid driver ...
driver = host_filter.choose_driver(
'nova.scheduler.host_filter.FlavorFilter')
self.assertEquals(driver._full_name(),
'nova.scheduler.host_filter.FlavorFilter')
# Test invalid driver ...
# Test valid filter ...
hf = host_filter.choose_host_filter(
'nova.scheduler.host_filter.InstanceTypeFilter')
self.assertEquals(hf._full_name(),
'nova.scheduler.host_filter.InstanceTypeFilter')
# Test invalid filter ...
try:
host_filter.choose_driver('does not exist')
self.fail("Should not find driver")
except exception.SchedulerHostFilterDriverNotFound:
host_filter.choose_host_filter('does not exist')
self.fail("Should not find host filter.")
except exception.SchedulerHostFilterNotFound:
pass
def test_all_host_driver(self):
driver = host_filter.AllHostsFilter()
cooked = driver.instance_type_to_filter(self.instance_type)
hosts = driver.filter_hosts(self.zone_manager, cooked)
def test_all_host_filter(self):
hf = host_filter.AllHostsFilter()
cooked = hf.instance_type_to_filter(self.instance_type)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(10, len(hosts))
for host, capabilities in hosts:
self.assertTrue(host.startswith('host'))
def test_flavor_driver(self):
driver = host_filter.FlavorFilter()
def test_instance_type_filter(self):
hf = host_filter.InstanceTypeFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = driver.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.FlavorFilter', name)
hosts = driver.filter_hosts(self.zone_manager, cooked)
name, cooked = hf.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.InstanceTypeFilter',
name)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
self.assertEquals('host05', just_hosts[0])
self.assertEquals('host10', just_hosts[5])
def test_json_driver(self):
driver = host_filter.JsonFilter()
def test_json_filter(self):
hf = host_filter.JsonFilter()
# filter all hosts that can support 50 ram and 500 disk
name, cooked = driver.instance_type_to_filter(self.instance_type)
name, cooked = hf.instance_type_to_filter(self.instance_type)
self.assertEquals('nova.scheduler.host_filter.JsonFilter', name)
hosts = driver.filter_hosts(self.zone_manager, cooked)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(6, len(hosts))
just_hosts = [host for host, caps in hosts]
just_hosts.sort()
@@ -132,15 +133,16 @@ class HostFilterTestCase(test.TestCase):
raw = ['or',
['and',
['<', '$compute.host_memory_free', 30],
['<', '$compute.disk_available', 300]
['<', '$compute.disk_available', 300],
],
['and',
['>', '$compute.host_memory_free', 70],
['>', '$compute.disk_available', 700]
]
['>', '$compute.disk_available', 700],
],
]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
@@ -152,7 +154,7 @@ class HostFilterTestCase(test.TestCase):
['=', '$compute.host_memory_free', 30],
]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(9, len(hosts))
just_hosts = [host for host, caps in hosts]
@@ -162,7 +164,7 @@ class HostFilterTestCase(test.TestCase):
raw = ['in', '$compute.host_memory_free', 20, 40, 60, 80, 100]
cooked = json.dumps(raw)
hosts = driver.filter_hosts(self.zone_manager, cooked)
hosts = hf.filter_hosts(self.zone_manager, cooked)
self.assertEquals(5, len(hosts))
just_hosts = [host for host, caps in hosts]
@@ -174,35 +176,30 @@ class HostFilterTestCase(test.TestCase):
raw = ['unknown command', ]
cooked = json.dumps(raw)
try:
driver.filter_hosts(self.zone_manager, cooked)
hf.filter_hosts(self.zone_manager, cooked)
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps([])))
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps({})))
self.assertTrue(driver.filter_hosts(self.zone_manager, json.dumps(
['not', True, False, True, False]
)))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps([])))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps({})))
self.assertTrue(hf.filter_hosts(self.zone_manager, json.dumps(
['not', True, False, True, False])))
try:
driver.filter_hosts(self.zone_manager, json.dumps(
'not', True, False, True, False
))
hf.filter_hosts(self.zone_manager, json.dumps(
'not', True, False, True, False))
self.fail("Should give KeyError")
except KeyError, e:
pass
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', '$foo', 100]
)))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', '$.....', 100]
)))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['>', ['and', ['or', ['not', ['<', ['>=', ['<=', ['in', ]]]]]]]]
)))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', '$foo', 100])))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', '$.....', 100])))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(
['>', ['and', ['or', ['not', ['<', ['>=', ['<=', ['in', ]]]]]]]])))
self.assertFalse(driver.filter_hosts(self.zone_manager, json.dumps(
['=', {}, ['>', '$missing....foo']]
)))
self.assertFalse(hf.filter_hosts(self.zone_manager,
json.dumps(['=', {}, ['>', '$missing....foo']])))

View File

@@ -18,6 +18,7 @@ import eventlet
import mox
import os
import re
import shutil
import sys
from xml.etree.ElementTree import fromstring as xml_to_tree
@@ -32,7 +33,8 @@ from nova import utils
from nova.api.ec2 import cloud
from nova.auth import manager
from nova.compute import power_state
from nova.virt import libvirt_conn
from nova.virt.libvirt import connection
from nova.virt.libvirt import firewall
libvirt = None
FLAGS = flags.FLAGS
@@ -83,7 +85,7 @@ class CacheConcurrencyTestCase(test.TestCase):
def test_same_fname_concurrency(self):
"""Ensures that the same fname cache runs at a sequentially"""
conn = libvirt_conn.LibvirtConnection
conn = connection.LibvirtConnection
wait1 = eventlet.event.Event()
done1 = eventlet.event.Event()
eventlet.spawn(conn._cache_image, _concurrency,
@@ -104,7 +106,7 @@ class CacheConcurrencyTestCase(test.TestCase):
def test_different_fname_concurrency(self):
"""Ensures that two different fname caches are concurrent"""
conn = libvirt_conn.LibvirtConnection
conn = connection.LibvirtConnection
wait1 = eventlet.event.Event()
done1 = eventlet.event.Event()
eventlet.spawn(conn._cache_image, _concurrency,
@@ -125,7 +127,7 @@ class CacheConcurrencyTestCase(test.TestCase):
class LibvirtConnTestCase(test.TestCase):
def setUp(self):
super(LibvirtConnTestCase, self).setUp()
libvirt_conn._late_load_cheetah()
connection._late_load_cheetah()
self.flags(fake_call=True)
self.manager = manager.AuthManager()
@@ -159,6 +161,7 @@ class LibvirtConnTestCase(test.TestCase):
'vcpus': 2,
'project_id': 'fake',
'bridge': 'br101',
'image_id': '123456',
'instance_type_id': '5'} # m1.small
def lazy_load_library_exists(self):
@@ -171,8 +174,8 @@ class LibvirtConnTestCase(test.TestCase):
return False
global libvirt
libvirt = __import__('libvirt')
libvirt_conn.libvirt = __import__('libvirt')
libvirt_conn.libxml2 = __import__('libxml2')
connection.libvirt = __import__('libvirt')
connection.libxml2 = __import__('libxml2')
return True
def create_fake_libvirt_mock(self, **kwargs):
@@ -182,7 +185,7 @@ class LibvirtConnTestCase(test.TestCase):
class FakeLibvirtConnection(object):
pass
# A fake libvirt_conn.IptablesFirewallDriver
# A fake connection.IptablesFirewallDriver
class FakeIptablesFirewallDriver(object):
def __init__(self, **kwargs):
@@ -198,11 +201,11 @@ class LibvirtConnTestCase(test.TestCase):
for key, val in kwargs.items():
fake.__setattr__(key, val)
# Inevitable mocks for libvirt_conn.LibvirtConnection
self.mox.StubOutWithMock(libvirt_conn.utils, 'import_class')
libvirt_conn.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip)
self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection, '_conn')
libvirt_conn.LibvirtConnection._conn = fake
# Inevitable mocks for connection.LibvirtConnection
self.mox.StubOutWithMock(connection.utils, 'import_class')
connection.utils.import_class(mox.IgnoreArg()).AndReturn(fakeip)
self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
connection.LibvirtConnection._conn = fake
def create_service(self, **kwargs):
service_ref = {'host': kwargs.get('host', 'dummy'),
@@ -214,7 +217,7 @@ class LibvirtConnTestCase(test.TestCase):
return db.service_create(context.get_admin_context(), service_ref)
def test_preparing_xml_info(self):
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, self.test_instance)
result = conn._prepare_xml_info(instance_ref, False)
@@ -229,7 +232,7 @@ class LibvirtConnTestCase(test.TestCase):
self.assertTrue(len(result['nics']) == 2)
def test_get_nic_for_xml_v4(self):
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=False)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
@@ -237,7 +240,7 @@ class LibvirtConnTestCase(test.TestCase):
self.assertTrue(params.find('PROJMASKV6') == -1)
def test_get_nic_for_xml_v6(self):
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
network, mapping = _create_network_info()[0]
self.flags(use_ipv6=True)
params = conn._get_nic_for_xml(network, mapping)['extra_params']
@@ -279,10 +282,72 @@ class LibvirtConnTestCase(test.TestCase):
instance_data = dict(self.test_instance)
self._check_xml_and_container(instance_data)
def test_snapshot(self):
FLAGS.image_service = 'nova.image.fake.FakeImageService'
# Only file-based instance storages are supported at the moment
test_xml = """
<domain type='kvm'>
<devices>
<disk type='file'>
<source file='filename'/>
</disk>
</devices>
</domain>
"""
class FakeVirtDomain(object):
def __init__(self):
pass
def snapshotCreateXML(self, *args):
return None
def XMLDesc(self, *args):
return test_xml
def fake_lookup(instance_name):
if instance_name == instance_ref.name:
return FakeVirtDomain()
def fake_execute(*args):
# Touch filename to pass 'with open(out_path)'
open(args[-1], "a").close()
# Start test
image_service = utils.import_object(FLAGS.image_service)
# Assuming that base image already exists in image_service
instance_ref = db.instance_create(self.context, self.test_instance)
properties = {'instance_id': instance_ref['id'],
'user_id': str(self.context.user_id)}
snapshot_name = 'test-snap'
sent_meta = {'name': snapshot_name, 'is_public': False,
'status': 'creating', 'properties': properties}
# Create new image. It will be updated in snapshot method
# To work with it from snapshot, the single image_service is needed
recv_meta = image_service.create(context, sent_meta)
self.mox.StubOutWithMock(connection.LibvirtConnection, '_conn')
connection.LibvirtConnection._conn.lookupByName = fake_lookup
self.mox.StubOutWithMock(connection.utils, 'execute')
connection.utils.execute = fake_execute
self.mox.ReplayAll()
conn = connection.LibvirtConnection(False)
conn.snapshot(instance_ref, recv_meta['id'])
snapshot = image_service.show(context, recv_meta['id'])
self.assertEquals(snapshot['properties']['image_state'], 'available')
self.assertEquals(snapshot['status'], 'active')
self.assertEquals(snapshot['name'], snapshot_name)
def test_multi_nic(self):
instance_data = dict(self.test_instance)
network_info = _create_network_info(2)
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
instance_ref = db.instance_create(self.context, instance_data)
xml = conn.to_xml(instance_ref, False, network_info)
tree = xml_to_tree(xml)
@@ -313,7 +378,7 @@ class LibvirtConnTestCase(test.TestCase):
'instance_id': instance_ref['id']})
self.flags(libvirt_type='lxc')
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
uri = conn.get_uri()
self.assertEquals(uri, 'lxc:///')
@@ -419,7 +484,7 @@ class LibvirtConnTestCase(test.TestCase):
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
uri = conn.get_uri()
self.assertEquals(uri, expected_uri)
@@ -446,7 +511,7 @@ class LibvirtConnTestCase(test.TestCase):
FLAGS.libvirt_uri = testuri
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True)
conn = connection.LibvirtConnection(True)
uri = conn.get_uri()
self.assertEquals(uri, testuri)
db.instance_destroy(user_context, instance_ref['id'])
@@ -470,13 +535,13 @@ class LibvirtConnTestCase(test.TestCase):
self.create_fake_libvirt_mock(getVersion=getVersion,
getType=getType,
listDomainsID=listDomainsID)
self.mox.StubOutWithMock(libvirt_conn.LibvirtConnection,
self.mox.StubOutWithMock(connection.LibvirtConnection,
'get_cpu_info')
libvirt_conn.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo')
connection.LibvirtConnection.get_cpu_info().AndReturn('cpuinfo')
# Start test
self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
conn.update_available_resource(self.context, 'dummy')
service_ref = db.service_get(self.context, service_ref['id'])
compute_node = service_ref['compute_node'][0]
@@ -510,7 +575,7 @@ class LibvirtConnTestCase(test.TestCase):
self.create_fake_libvirt_mock()
self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
self.assertRaises(exception.ComputeServiceUnavailable,
conn.update_available_resource,
self.context, 'dummy')
@@ -545,7 +610,7 @@ class LibvirtConnTestCase(test.TestCase):
# Start test
self.mox.ReplayAll()
try:
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
conn.firewall_driver.setattr('instance_filter_exists', fake_none)
@@ -594,7 +659,7 @@ class LibvirtConnTestCase(test.TestCase):
# Start test
self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
self.assertRaises(libvirt.libvirtError,
conn._live_migration,
self.context, instance_ref, 'dest', '',
@@ -623,7 +688,7 @@ class LibvirtConnTestCase(test.TestCase):
# Start test
self.mox.ReplayAll()
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
conn.firewall_driver.setattr('setup_basic_filtering', fake_none)
conn.firewall_driver.setattr('prepare_instance_filter', fake_none)
@@ -644,10 +709,12 @@ class LibvirtConnTestCase(test.TestCase):
except Exception, e:
count = (0 <= str(e.message).find('Unexpected method call'))
shutil.rmtree(os.path.join(FLAGS.instances_path, instance.name))
self.assertTrue(count)
def test_get_host_ip_addr(self):
conn = libvirt_conn.LibvirtConnection(False)
conn = connection.LibvirtConnection(False)
ip = conn.get_host_ip_addr()
self.assertEquals(ip, FLAGS.my_ip)
@@ -671,7 +738,7 @@ class IptablesFirewallTestCase(test.TestCase):
class FakeLibvirtConnection(object):
pass
self.fake_libvirt_connection = FakeLibvirtConnection()
self.fw = libvirt_conn.IptablesFirewallDriver(
self.fw = firewall.IptablesFirewallDriver(
get_connection=lambda: self.fake_libvirt_connection)
def tearDown(self):
@@ -895,7 +962,7 @@ class NWFilterTestCase(test.TestCase):
self.fake_libvirt_connection = Mock()
self.fw = libvirt_conn.NWFilterFirewall(
self.fw = firewall.NWFilterFirewall(
lambda: self.fake_libvirt_connection)
def tearDown(self):

View File

@@ -13,10 +13,12 @@
# License for the specific language governing permissions and limitations
# under the License.
import nova
import stubout
import nova
from nova import context
from nova import flags
from nova import log
from nova import rpc
import nova.notifier.api
from nova.notifier.api import notify
@@ -24,8 +26,6 @@ from nova.notifier import no_op_notifier
from nova.notifier import rabbit_notifier
from nova import test
import stubout
class NotifierTestCase(test.TestCase):
"""Test case for notifications"""
@@ -115,3 +115,22 @@ class NotifierTestCase(test.TestCase):
notify('publisher_id',
'event_type', 'DEBUG', dict(a=3))
self.assertEqual(self.test_topic, 'testnotify.debug')
def test_error_notification(self):
self.stubs.Set(nova.flags.FLAGS, 'notification_driver',
'nova.notifier.rabbit_notifier')
self.stubs.Set(nova.flags.FLAGS, 'publish_errors', True)
LOG = log.getLogger('nova')
LOG.setup_from_flags()
msgs = []
def mock_cast(context, topic, data):
msgs.append(data)
self.stubs.Set(nova.rpc, 'cast', mock_cast)
LOG.error('foo')
self.assertEqual(1, len(msgs))
msg = msgs[0]
self.assertEqual(msg['event_type'], 'error_notification')
self.assertEqual(msg['priority'], 'ERROR')
self.assertEqual(msg['payload']['error'], 'foo')

View File

@@ -31,7 +31,6 @@ LOG = logging.getLogger('nova.tests.rpc')
class RpcTestCase(test.TestCase):
"""Test cases for rpc"""
def setUp(self):
super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance(True)
@@ -43,14 +42,55 @@ class RpcTestCase(test.TestCase):
self.context = context.get_admin_context()
def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42
result = rpc.call(self.context, 'test', {"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)
def test_call_succeed_despite_multiple_returns(self):
value = 42
result = rpc.call(self.context, 'test', {"method": "echo_three_times",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_call_succeed_despite_multiple_returns_yield(self):
value = 42
result = rpc.call(self.context, 'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
self.assertEqual(value + 2, result)
def test_multicall_succeed_once(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo",
"args": {"value": value}})
for i, x in enumerate(result):
if i > 0:
self.fail('should only receive one response')
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_multicall_succeed_three_times_yield(self):
value = 42
result = rpc.multicall(self.context,
'test',
{"method": "echo_three_times_yield",
"args": {"value": value}})
for i, x in enumerate(result):
self.assertEqual(value + i, x)
def test_context_passed(self):
"""Makes sure a context is passed through rpc call"""
"""Makes sure a context is passed through rpc call."""
value = 42
result = rpc.call(self.context,
'test', {"method": "context",
@@ -58,11 +98,12 @@ class RpcTestCase(test.TestCase):
self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self):
"""Test that exception gets passed back properly
"""Test that exception gets passed back properly.
rpc.call returns a RemoteError object. The value of the
exception is converted to a string, so we convert it back
to an int in the test.
"""
value = 42
self.assertRaises(rpc.RemoteError,
@@ -81,7 +122,7 @@ class RpcTestCase(test.TestCase):
self.assertEqual(int(exc.value), value)
def test_nested_calls(self):
"""Test that we can do an rpc.call inside another call"""
"""Test that we can do an rpc.call inside another call."""
class Nested(object):
@staticmethod
def echo(context, queue, value):
@@ -108,25 +149,80 @@ class RpcTestCase(test.TestCase):
"value": value}})
self.assertEqual(value, result)
def test_connectionpool_single(self):
"""Test that ConnectionPool recycles a single connection."""
conn1 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn1)
conn2 = rpc.ConnectionPool.get()
rpc.ConnectionPool.put(conn2)
self.assertEqual(conn1, conn2)
def test_connectionpool_double(self):
"""Test that ConnectionPool returns and reuses separate connections.
When called consecutively we should get separate connections and upon
returning them those connections should be reused for future calls
before generating a new connection.
"""
conn1 = rpc.ConnectionPool.get()
conn2 = rpc.ConnectionPool.get()
self.assertNotEqual(conn1, conn2)
rpc.ConnectionPool.put(conn1)
rpc.ConnectionPool.put(conn2)
conn3 = rpc.ConnectionPool.get()
conn4 = rpc.ConnectionPool.get()
self.assertEqual(conn1, conn3)
self.assertEqual(conn2, conn4)
def test_connectionpool_limit(self):
"""Test connection pool limit and connection uniqueness."""
max_size = FLAGS.rpc_conn_pool_size
conns = []
for i in xrange(max_size):
conns.append(rpc.ConnectionPool.get())
self.assertFalse(rpc.ConnectionPool.free_items)
self.assertEqual(rpc.ConnectionPool.current_size,
rpc.ConnectionPool.max_size)
self.assertEqual(len(set(conns)), max_size)
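# The three connection-pool tests above pin down the observable
# behavior (FIFO recycling of returned connections, a hard size cap)
# without showing the pool itself. A minimal sketch with the same
# semantics, offered as an editorial illustration only; the real
# nova.rpc pool is eventlet-based and blocks rather than raising when
# exhausted:
import collections


class SimplePool(object):
    def __init__(self, create, max_size):
        self.create = create                # factory for new connections
        self.max_size = max_size
        self.current_size = 0
        self.free_items = collections.deque()

    def get(self):
        """Reuse a returned connection, or make one while under the cap."""
        if self.free_items:
            return self.free_items.popleft()
        if self.current_size < self.max_size:
            self.current_size += 1
            return self.create()
        raise RuntimeError('pool exhausted')

    def put(self, item):
        """Hand a connection back for reuse by later get() calls."""
        self.free_items.append(item)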
class TestReceiver(object):
"""Simple Proxy class so the consumer has methods to call
"""Simple Proxy class so the consumer has methods to call.
Uses static methods because we aren't actually storing any state"""
Uses static methods because we aren't actually storing any state.
"""
@staticmethod
def echo(context, value):
"""Simply returns whatever value is sent in"""
"""Simply returns whatever value is sent in."""
LOG.debug(_("Received %s"), value)
return value
@staticmethod
def context(context, value):
"""Returns dictionary version of context"""
"""Returns dictionary version of context."""
LOG.debug(_("Received %s"), context)
return context.to_dict()
@staticmethod
def echo_three_times(context, value):
"""Sends three responses via explicit context.reply() calls."""
context.reply(value)
context.reply(value + 1)
context.reply(value + 2)
@staticmethod
def echo_three_times_yield(context, value):
"""Sends three responses by yielding three values."""
yield value
yield value + 1
yield value + 2
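# Editorial sketch of a dispatcher that supports both reply styles used
# by TestReceiver: explicit context.reply() calls (echo_three_times) and
# generators whose yielded values each become one response
# (echo_three_times_yield). This is an assumption for illustration, not
# the actual nova.rpc consumer code.
import inspect


def dispatch(context, method, **kwargs):
    result = method(context, **kwargs)
    if inspect.isgenerator(result):
        for item in result:
            context.reply(item)         # one response per yielded value
    elif result is not None:
        context.reply(result)           # plain return, single response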
@staticmethod
def fail(context, value):
"""Raises an exception with the value sent in"""
"""Raises an exception with the value sent in."""
raise Exception(value)

View File

@@ -45,10 +45,11 @@ class VolumeTestCase(test.TestCase):
self.context = context.get_admin_context()
@staticmethod
def _create_volume(size='0'):
def _create_volume(size='0', snapshot_id=None):
"""Create a volume object."""
vol = {}
vol['size'] = size
vol['snapshot_id'] = snapshot_id
vol['user_id'] = 'fake'
vol['project_id'] = 'fake'
vol['availability_zone'] = FLAGS.storage_availability_zone
@@ -69,6 +70,25 @@ class VolumeTestCase(test.TestCase):
self.context,
volume_id)
def test_create_volume_from_snapshot(self):
"""Test volume can be created from a snapshot."""
volume_src_id = self._create_volume()
self.volume.create_volume(self.context, volume_src_id)
snapshot_id = self._create_snapshot(volume_src_id)
self.volume.create_snapshot(self.context, volume_src_id, snapshot_id)
volume_dst_id = self._create_volume(0, snapshot_id)
self.volume.create_volume(self.context, volume_dst_id, snapshot_id)
self.assertEqual(volume_dst_id, db.volume_get(
context.get_admin_context(),
volume_dst_id).id)
self.assertEqual(snapshot_id, db.volume_get(
context.get_admin_context(),
volume_dst_id).snapshot_id)
self.volume.delete_volume(self.context, volume_dst_id)
self.volume.delete_snapshot(self.context, snapshot_id)
self.volume.delete_volume(self.context, volume_src_id)
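# A plausible manager-side branch for the snapshot_id argument this
# test passes to create_volume, sketched as an assumption; the driver
# method names here are illustrative, not the actual nova.volume
# interface.
def create_volume(self, context, volume_id, snapshot_id=None):
    """Create a blank volume, or seed it from a snapshot when given."""
    volume_ref = self.db.volume_get(context, volume_id)
    if snapshot_id is None:
        self.driver.create_volume(volume_ref)
    else:
        snapshot_ref = self.db.snapshot_get(context, snapshot_id)
        self.driver.create_volume_from_snapshot(volume_ref, snapshot_ref)
    self.db.volume_update(context, volume_id, {'status': 'available'})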
def test_too_big_volume(self):
"""Ensure failure if a too large of a volume is requested."""
# FIXME(vish): validation needs to move into the data layer in
@@ -176,6 +196,34 @@ class VolumeTestCase(test.TestCase):
# This will allow us to test cross-node interactions
pass
@staticmethod
def _create_snapshot(volume_id, size='0'):
"""Create a snapshot object."""
snap = {}
snap['volume_size'] = size
snap['user_id'] = 'fake'
snap['project_id'] = 'fake'
snap['volume_id'] = volume_id
snap['status'] = "creating"
return db.snapshot_create(context.get_admin_context(), snap)['id']
def test_create_delete_snapshot(self):
"""Test snapshot can be created and deleted."""
volume_id = self._create_volume()
self.volume.create_volume(self.context, volume_id)
snapshot_id = self._create_snapshot(volume_id)
self.volume.create_snapshot(self.context, volume_id, snapshot_id)
self.assertEqual(snapshot_id,
db.snapshot_get(context.get_admin_context(),
snapshot_id).id)
self.volume.delete_snapshot(self.context, snapshot_id)
self.assertRaises(exception.NotFound,
db.snapshot_get,
self.context,
snapshot_id)
self.volume.delete_volume(self.context, volume_id)
class DriverTestCase(test.TestCase):
"""Base Test class for Drivers."""

View File

@@ -395,6 +395,29 @@ class XenAPIVMTestCase(test.TestCase):
os_type="linux")
self.check_vm_params_for_linux()
def test_spawn_vhd_glance_swapdisk(self):
"""Spawn a VHD image for which glance also returns a swap disk."""
# Change the default host_call_plugin to one that'll return
# a swap disk
orig_func = stubs.FakeSessionForVMTests.host_call_plugin
stubs.FakeSessionForVMTests.host_call_plugin = \
stubs.FakeSessionForVMTests.host_call_plugin_swap
try:
# We'll steal the above glance linux test
self.test_spawn_vhd_glance_linux()
finally:
# Make sure to put this back
stubs.FakeSessionForVMTests.host_call_plugin = orig_func
# We should have 2 VBDs.
self.assertEqual(len(self.vm['VBDs']), 2)
# Now test that we have 1.
self.tearDown()
self.setUp()
self.test_spawn_vhd_glance_linux()
self.assertEqual(len(self.vm['VBDs']), 1)
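# The try/finally above swaps a class attribute by hand. The same
# pattern as a reusable helper, an editorial sketch rather than part of
# this commit:
import contextlib


@contextlib.contextmanager
def patched_attr(obj, name, replacement):
    """Temporarily replace an attribute, restoring it on exit."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)

# Equivalent usage for the swap-disk test:
# with patched_attr(stubs.FakeSessionForVMTests, 'host_call_plugin',
#                   stubs.FakeSessionForVMTests.host_call_plugin_swap):
#     self.test_spawn_vhd_glance_linux()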
def test_spawn_vhd_glance_windows(self):
FLAGS.xenapi_image_service = 'glance'
self._test_spawn(glance_stubs.FakeGlance.IMAGE_VHD, None, None,
@@ -569,11 +592,29 @@ class XenAPIDiffieHellmanTestCase(test.TestCase):
bob_shared = self.bob.compute_shared(alice_pub)
self.assertEquals(alice_shared, bob_shared)
def test_encryption(self):
msg = "This is a top-secret message"
enc = self.alice.encrypt(msg)
def _test_encryption(self, message):
enc = self.alice.encrypt(message)
self.assertFalse(enc.endswith('\n'))
dec = self.bob.decrypt(enc)
self.assertEquals(dec, msg)
self.assertEquals(dec, message)
def test_encrypt_simple_message(self):
self._test_encryption('This is a simple message.')
def test_encrypt_message_with_newlines_at_end(self):
self._test_encryption('This message has a newline at the end.\n')
def test_encrypt_many_newlines_at_end(self):
self._test_encryption('Message with lotsa newlines.\n\n\n')
def test_encrypt_newlines_inside_message(self):
self._test_encryption('Message\nwith\ninterior\nnewlines.')
def test_encrypt_with_leading_newlines(self):
self._test_encryption('\n\nMessage with leading newlines.')
def test_encrypt_really_long_message(self):
self._test_encryption(''.join(['abcd' for i in xrange(1024)]))
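# Why the assertFalse(enc.endswith('\n')) check in _test_encryption
# matters, assuming the ciphertext is base64 text (the DH helper itself
# is not shown in this diff): Python 2's two encoders differ on
# trailing newlines.
import base64

base64.encodestring('x' * 60)   # wraps lines and appends a final '\n'
base64.b64encode('x' * 60)      # one line, no trailing newline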
def tearDown(self):
super(XenAPIDiffieHellmanTestCase, self).tearDown()

View File

@@ -38,16 +38,16 @@ class FakeZoneAwareScheduler(zone_aware_scheduler.ZoneAwareScheduler):
class FakeZoneManager(zone_manager.ZoneManager):
def __init__(self):
self.service_states = {
'host1': {
'compute': {'ram': 1000}
},
'host2': {
'compute': {'ram': 2000}
},
'host3': {
'compute': {'ram': 3000}
}
}
'host1': {
'compute': {'ram': 1000},
},
'host2': {
'compute': {'ram': 2000},
},
'host3': {
'compute': {'ram': 3000},
},
}
class FakeEmptyZoneManager(zone_manager.ZoneManager):
@@ -116,4 +116,6 @@ class ZoneAwareSchedulerTestCase(test.TestCase):
sched.set_zone_manager(zm)
fake_context = {}
self.assertRaises(driver.NoValidHost, sched.schedule, fake_context, {})
self.assertRaises(driver.NoValidHost, sched.schedule_run_instance,
fake_context, 1,
dict(host_filter=None, instance_type={}))
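# FakeZoneManager above publishes per-host capabilities shaped like
# {'compute': {'ram': N}}. A least-loaded pick over that structure
# could look like this sketch; the selection logic is an assumption
# for illustration, not the actual zone-aware weighing code.
def pick_host_with_most_ram(service_states):
    """Return the host whose compute service reports the most RAM."""
    best_host, best_ram = None, -1
    for host, services in service_states.items():
        ram = services.get('compute', {}).get('ram', 0)
        if ram > best_ram:
            best_host, best_ram = host, ram
    return best_host

# pick_host_with_most_ram(FakeZoneManager().service_states) == 'host3'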

View File

@@ -17,6 +17,7 @@
"""Stubouts, mocks and fixtures for the test suite"""
import eventlet
import json
from nova.virt import xenapi_conn
from nova.virt.xenapi import fake
from nova.virt.xenapi import volume_utils
@@ -37,7 +38,7 @@ def stubout_instance_snapshot(stubs):
sr_ref=sr_ref, sharable=False)
vdi_rec = session.get_xenapi().VDI.get_record(vdi_ref)
vdi_uuid = vdi_rec['uuid']
return vdi_uuid
return [dict(vdi_type='os', vdi_uuid=vdi_uuid)]
stubs.Set(vm_utils.VMHelper, 'fetch_image', fake_fetch_image)
@@ -132,11 +133,30 @@ class FakeSessionForVMTests(fake.SessionBase):
def __init__(self, uri):
super(FakeSessionForVMTests, self).__init__(uri)
def host_call_plugin(self, _1, _2, _3, _4, _5):
def host_call_plugin(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
return '<string>%s</string>' % vdi_rec['uuid']
if plugin == "glance" and method == "download_vhd":
ret_str = json.dumps([dict(vdi_type='os',
vdi_uuid=vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
def host_call_plugin_swap(self, _1, _2, plugin, method, _5):
sr_ref = fake.get_all('SR')[0]
vdi_ref = fake.create_vdi('', False, sr_ref, False)
vdi_rec = fake.get_record('VDI', vdi_ref)
if plugin == "glance" and method == "download_vhd":
swap_vdi_ref = fake.create_vdi('', False, sr_ref, False)
swap_vdi_rec = fake.get_record('VDI', swap_vdi_ref)
ret_str = json.dumps(
[dict(vdi_type='os', vdi_uuid=vdi_rec['uuid']),
dict(vdi_type='swap', vdi_uuid=swap_vdi_rec['uuid'])])
else:
ret_str = vdi_rec['uuid']
return '<string>%s</string>' % ret_str
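# Caller-side sketch (editorial, not in this commit) of unpacking the
# fake glance download_vhd response built above: strip the
# <string>...</string> wrapper, then decode the JSON list of
# {'vdi_type': ..., 'vdi_uuid': ...} descriptors.
def parse_fake_plugin_response(xml_str):
    payload = xml_str[len('<string>'):-len('</string>')]
    vdis = json.loads(payload)      # only the glance branch is JSON
    return dict((vdi['vdi_type'], vdi['vdi_uuid']) for vdi in vdis)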
def VM_start(self, _1, ref, _2, _3):
vm = fake.get_record('VM', ref)