merge lp:nova

This commit is contained in:
Jesse Andrews
2010-10-24 17:41:53 -07:00
73 changed files with 3316 additions and 3071 deletions

21
Authors Normal file
View File

@@ -0,0 +1,21 @@
Andy Smith <code@term.ie>
Anne Gentle <anne@openstack.org>
Chris Behrens <cbehrens@codestud.com>
Devin Carlen <devin.carlen@gmail.com>
Eric Day <eday@oddments.org>
Ewan Mellor <ewan.mellor@citrix.com>
Hisaki Ohara <hisaki.ohara@intel.com>
Jay Pipes <jaypipes@gmail.com>
Jesse Andrews <anotherjesse@gmail.com>
Joe Heck <heckj@mac.com>
Joel Moore <joelbm24@gmail.com>
Joshua McKenty <jmckenty@gmail.com>
Justin Santa Barbara <justin@fathomdb.com>
Matt Dietz <matt.dietz@rackspace.com>
Michael Gundlach <michael.gundlach@rackspace.com>
Monty Taylor <mordred@inaugust.com>
Paul Voccio <paul@openstack.org>
Rick Clark <rick@openstack.org>
Soren Hansen <soren.hansen@rackspace.com>
Todd Willey <todd@ansolabs.com>
Vishvananda Ishaya <vishvananda@gmail.com>

View File

@@ -1,22 +1,30 @@
include HACKING LICENSE run_tests.py run_tests.sh
include HACKING LICENSE run_tests.py run_tests.sh
include README builddeb.sh exercise_rsapi.py
include ChangeLog
include ChangeLog MANIFEST.in pylintrc Authors
graft CA
graft doc
graft smoketests
graft tools
include nova/api/openstack/notes.txt
include nova/auth/novarc.template
include nova/auth/slap.sh
include nova/cloudpipe/bootscript.sh
include nova/cloudpipe/client.ovpn.template
include nova/compute/fakevirtinstance.xml
include nova/compute/interfaces.template
include nova/compute/libvirt.xml.template
include nova/virt/interfaces.template
include nova/virt/libvirt.qemu.xml.template
include nova/virt/libvirt.uml.xml.template
include nova/virt/libvirt.xen.xml.template
include nova/tests/CA/
include nova/tests/CA/cacert.pem
include nova/tests/CA/private/
include nova/tests/CA/private/cakey.pem
include nova/tests/bundle/
include nova/tests/bundle/1mb.manifest.xml
include nova/tests/bundle/1mb.no_kernel_or_ramdisk.manifest.xml
include nova/tests/bundle/1mb.part.0
include nova/tests/bundle/1mb.part.1
include plugins/xenapi/README
include plugins/xenapi/etc/xapi.d/plugins/objectstore
include plugins/xenapi/etc/xapi.d/plugins/pluginlib_nova.py

View File

@@ -1,53 +1,49 @@
#!/usr/bin/env python
# pylint: disable-msg=C0103
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tornado daemon for the main API endpoint.
Nova API daemon.
"""
import logging
from tornado import httpserver
from tornado import ioloop
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import server
from nova import utils
from nova.endpoint import admin
from nova.endpoint import api
from nova.endpoint import cloud
from nova import server
FLAGS = flags.FLAGS
flags.DEFINE_integer('api_port', 8773, 'API port')
def main(_argv):
"""Load the controllers and start the tornado I/O loop."""
controllers = {
'Cloud': cloud.CloudController(),
'Admin': admin.AdminController()}
_app = api.APIServerApplication(controllers)
io_inst = ioloop.IOLoop.instance()
http_server = httpserver.HTTPServer(_app)
http_server.listen(FLAGS.cc_port)
logging.debug('Started HTTP server on %s', FLAGS.cc_port)
io_inst.start()
def main(_args):
from nova import api
from nova import wsgi
wsgi.run_server(api.API(), FLAGS.api_port)
if __name__ == '__main__':
utils.default_flagfile()

View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python
# pylint: disable-msg=C0103
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Nova API daemon.
"""
from nova import api
from nova import flags
from nova import utils
from nova import wsgi
FLAGS = flags.FLAGS
flags.DEFINE_integer('api_port', 8773, 'API port')
if __name__ == '__main__':
utils.default_flagfile()
wsgi.run_server(api.API(), FLAGS.api_port)

View File

@@ -21,6 +21,17 @@
Twistd daemon for the nova compute nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd

View File

@@ -25,62 +25,74 @@ import logging
import os
import sys
# TODO(joshua): there is concern that the user dnsmasq runs under will not
# have nova in the path. This should be verified and if it is
# not true the ugly line below can be removed
sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import context
from nova import db
from nova import flags
from nova import rpc
from nova import utils
from nova import datastore # for redis_db flag
from nova.auth import manager # for auth flags
from nova.network import linux_net
from nova.network import manager # for network flags
FLAGS = flags.FLAGS
flags.DECLARE('auth_driver', 'nova.auth.manager')
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('update_dhcp_on_disassociate', 'nova.network.manager')
def add_lease(_mac, ip_address, _hostname, _interface):
def add_lease(mac, ip_address, _hostname, _interface):
"""Set the IP that was assigned by the DHCP server."""
if FLAGS.fake_rabbit:
logging.debug("leasing ip")
network_manager = utils.import_object(FLAGS.network_manager)
network_manager.lease_fixed_ip(None, ip_address)
network_manager.lease_fixed_ip(context.get_admin_context(),
mac,
ip_address)
else:
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.node_name),
rpc.cast(context.get_admin_context(),
"%s.%s" % (FLAGS.network_topic, FLAGS.host),
{"method": "lease_fixed_ip",
"args": {"context": None,
"args": {"mac": mac,
"address": ip_address}})
def old_lease(_mac, _ip_address, _hostname, _interface):
"""Do nothing, just an old lease update."""
def old_lease(mac, ip_address, hostname, interface):
"""Update just as add lease."""
logging.debug("Adopted old lease or got a change of mac/hostname")
add_lease(mac, ip_address, hostname, interface)
def del_lease(_mac, ip_address, _hostname, _interface):
def del_lease(mac, ip_address, _hostname, _interface):
"""Called when a lease expires."""
if FLAGS.fake_rabbit:
logging.debug("releasing ip")
network_manager = utils.import_object(FLAGS.network_manager)
network_manager.release_fixed_ip(None, ip_address)
network_manager.release_fixed_ip(context.get_admin_context(),
mac,
ip_address)
else:
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.node_name),
rpc.cast(context.get_admin_context(),
"%s.%s" % (FLAGS.network_topic, FLAGS.host),
{"method": "release_fixed_ip",
"args": {"context": None,
"args": {"mac": mac,
"address": ip_address}})
def init_leases(interface):
"""Get the list of hosts for an interface."""
network_ref = db.network_get_by_bridge(None, interface)
return linux_net.get_dhcp_hosts(None, network_ref['id'])
ctxt = context.get_admin_context()
network_ref = db.network_get_by_bridge(ctxt, interface)
return linux_net.get_dhcp_hosts(ctxt, network_ref['id'])
def main():
global network_manager
"""Parse environment and arguments and call the appropriate action."""
flagfile = os.environ.get('FLAGFILE', FLAGS.dhcpbridge_flagfile)
utils.default_flagfile(flagfile)
@@ -88,18 +100,16 @@ def main():
interface = os.environ.get('DNSMASQ_INTERFACE', 'br0')
if int(os.environ.get('TESTING', '0')):
FLAGS.fake_rabbit = True
FLAGS.redis_db = 8
FLAGS.network_size = 16
FLAGS.connection_type = 'fake'
FLAGS.fake_network = True
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
FLAGS.num_networks = 5
path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'..',
'_trial_temp',
'nova.sqlite'))
FLAGS.sql_connection = 'sqlite:///%s' % path
#FLAGS.sql_connection = 'mysql://root@localhost/test'
action = argv[1]
if action in ['add', 'del', 'old']:
mac = argv[2]

View File

@@ -29,6 +29,14 @@ import subprocess
import sys
import urllib2
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import utils
from nova.objectstore import image

View File

@@ -21,9 +21,19 @@
Daemon for Nova RRD based instance resource monitoring.
"""
import os
import logging
import sys
from twisted.application import service
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import twistd
from nova.compute import monitor

View File

@@ -17,23 +17,73 @@
# License for the specific language governing permissions and limitations
# under the License.
# Interactive shell based on Django:
#
# Copyright (c) 2005, the Lawrence Journal-World
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of Django nor the names of its contributors may be
# used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
CLI interface for nova management.
Connects to the running ADMIN api in the api daemon.
"""
import logging
import os
import sys
import time
import IPy
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import context
from nova import db
from nova import exception
from nova import flags
from nova import quota
from nova import utils
from nova.auth import manager
from nova.cloudpipe import pipelib
from nova.endpoint import cloud
FLAGS = flags.FLAGS
flags.DECLARE('fixed_range', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('vlan_start', 'nova.network.manager')
flags.DECLARE('vpn_start', 'nova.network.manager')
class VpnCommands(object):
@@ -41,16 +91,21 @@ class VpnCommands(object):
def __init__(self):
self.manager = manager.AuthManager()
self.pipe = pipelib.CloudPipe(cloud.CloudController())
self.pipe = pipelib.CloudPipe()
def list(self):
"""Print a listing of the VPNs for all projects."""
print "%-12s\t" % 'project',
print "%-12s\t" % 'ip:port',
print "%-20s\t" % 'ip:port',
print "%s" % 'state'
for project in self.manager.get_projects():
print "%-12s\t" % project.name,
print "%s:%s\t" % (project.vpn_ip, project.vpn_port),
try:
s = "%s:%s" % (project.vpn_ip, project.vpn_port)
except exception.NotFound:
s = "None"
print "%-20s\t" % s,
vpn = self._vpn_for(project.id)
if vpn:
@@ -72,7 +127,7 @@ class VpnCommands(object):
def _vpn_for(self, project_id):
"""Get the VPN instance for a project ID."""
for instance in db.instance_get_all():
for instance in db.instance_get_all(context.get_admin_context()):
if (instance['image_id'] == FLAGS.vpn_image_id
and not instance['state_description'] in
['shutting_down', 'shutdown']
@@ -92,6 +147,68 @@ class VpnCommands(object):
self.pipe.launch_vpn_instance(project_id)
class ShellCommands(object):
def bpython(self):
"""Runs a bpython shell.
Falls back to Ipython/python shell if unavailable"""
self.run('bpython')
def ipython(self):
"""Runs an Ipython shell.
Falls back to Python shell if unavailable"""
self.run('ipython')
def python(self):
"""Runs a python shell.
Falls back to Python shell if unavailable"""
self.run('python')
def run(self, shell=None):
"""Runs a Python interactive interpreter.
args: [shell=bpython]"""
if not shell:
shell = 'bpython'
if shell == 'bpython':
try:
import bpython
bpython.embed()
except ImportError:
shell = 'ipython'
if shell == 'ipython':
try:
import IPython
# Explicitly pass an empty list as arguments, because
# otherwise IPython would use sys.argv from this script.
shell = IPython.Shell.IPShell(argv=[])
shell.mainloop()
except ImportError:
shell = 'python'
if shell == 'python':
import code
try:
# Try activating rlcompleter, because it's handy.
import readline
except ImportError:
pass
else:
# We don't have to wrap the following import in a 'try',
# because we already know 'readline' was imported successfully.
import rlcompleter
readline.parse_and_bind("tab:complete")
code.interact()
def script(self, path):
"""Runs the script from the specified path with flags set properly.
arguments: path"""
exec(compile(open(path).read(), path, 'exec'), locals(), globals())
class RoleCommands(object):
"""Class for managing roles."""
@@ -121,6 +238,12 @@ class RoleCommands(object):
class UserCommands(object):
"""Class for managing users."""
@staticmethod
def _print_export(user):
"""Print export variables to use with API."""
print 'export EC2_ACCESS_KEY=%s' % user.access
print 'export EC2_SECRET_KEY=%s' % user.secret
def __init__(self):
self.manager = manager.AuthManager()
@@ -128,13 +251,13 @@ class UserCommands(object):
"""creates a new admin and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, True)
print_export(user)
self._print_export(user)
def create(self, name, access=None, secret=None):
"""creates a new user and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, False)
print_export(user)
self._print_export(user)
def delete(self, name):
"""deletes an existing user
@@ -146,7 +269,7 @@ class UserCommands(object):
arguments: name"""
user = self.manager.get_user(name)
if user:
print_export(user)
self._print_export(user)
else:
print "User %s doesn't exist" % name
@@ -156,11 +279,18 @@ class UserCommands(object):
for user in self.manager.get_users():
print user.name
def print_export(user):
"""Print export variables to use with API."""
print 'export EC2_ACCESS_KEY=%s' % user.access
print 'export EC2_SECRET_KEY=%s' % user.secret
def modify(self, name, access_key, secret_key, is_admin):
"""update a users keys & admin flag
arguments: accesskey secretkey admin
leave any field blank to ignore it, admin should be 'T', 'F', or blank
"""
if not is_admin:
is_admin = None
elif is_admin.upper()[0] == 'T':
is_admin = True
else:
is_admin = False
self.manager.modify_user(name, access_key, secret_key, is_admin)
class ProjectCommands(object):
@@ -169,10 +299,10 @@ class ProjectCommands(object):
def __init__(self):
self.manager = manager.AuthManager()
def add(self, project, user):
def add(self, project_id, user_id):
"""Adds user to project
arguments: project user"""
self.manager.add_to_project(user, project)
arguments: project_id user_id"""
self.manager.add_to_project(user_id, project_id)
def create(self, name, project_manager, description=None):
"""Creates a new project
@@ -187,7 +317,7 @@ class ProjectCommands(object):
def environment(self, project_id, user_id, filename='novarc'):
"""Exports environment variables to an sourcable file
arguments: project_id user_id [filename='novarc]"""
rc = self.manager.get_environment_rc(project_id, user_id)
rc = self.manager.get_environment_rc(user_id, project_id)
with open(filename, 'w') as f:
f.write(rc)
@@ -197,10 +327,34 @@ class ProjectCommands(object):
for project in self.manager.get_projects():
print project.name
def remove(self, project, user):
def quota(self, project_id, key=None, value=None):
"""Set or display quotas for project
arguments: project_id [key] [value]"""
ctxt = context.get_admin_context()
if key:
quo = {'project_id': project_id, key: value}
try:
db.quota_update(ctxt, project_id, quo)
except exception.NotFound:
db.quota_create(ctxt, quo)
project_quota = quota.get_quota(ctxt, project_id)
for key, value in project_quota.iteritems():
print '%s: %s' % (key, value)
def remove(self, project_id, user_id):
"""Removes user from project
arguments: project user"""
self.manager.remove_from_project(user, project)
arguments: project_id user_id"""
self.manager.remove_from_project(user_id, project_id)
def scrub(self, project_id):
"""Deletes data associated with project
arguments: project_id"""
ctxt = context.get_admin_context()
network_ref = db.project_get_network(ctxt, project_id)
db.network_disassociate(ctxt, network_ref['id'])
groups = db.security_group_get_by_project(ctxt, project_id)
for group in groups:
db.security_group_destroy(ctxt, group['id'])
def zipfile(self, project_id, user_id, filename='nova.zip'):
"""Exports credentials for project to a zip file
@@ -210,12 +364,74 @@ class ProjectCommands(object):
f.write(zip_file)
class FloatingIpCommands(object):
"""Class for managing floating ip."""
def create(self, host, range):
"""Creates floating ips for host by range
arguments: host ip_range"""
for address in IPy.IP(range):
db.floating_ip_create(context.get_admin_context(),
{'address': str(address),
'host': host})
def delete(self, ip_range):
"""Deletes floating ips by range
arguments: range"""
for address in IPy.IP(ip_range):
db.floating_ip_destroy(context.get_admin_context(),
str(address))
def list(self, host=None):
"""Lists all floating ips (optionally by host)
arguments: [host]"""
ctxt = context.get_admin_context()
if host == None:
floating_ips = db.floating_ip_get_all(ctxt)
else:
floating_ips = db.floating_ip_get_all_by_host(ctxt, host)
for floating_ip in floating_ips:
instance = None
if floating_ip['fixed_ip']:
instance = floating_ip['fixed_ip']['instance']['ec2_id']
print "%s\t%s\t%s" % (floating_ip['host'],
floating_ip['address'],
instance)
class NetworkCommands(object):
"""Class for managing networks."""
def create(self, fixed_range=None, num_networks=None,
network_size=None, vlan_start=None, vpn_start=None):
"""Creates fixed ips for host by range
arguments: [fixed_range=FLAG], [num_networks=FLAG],
[network_size=FLAG], [vlan_start=FLAG],
[vpn_start=FLAG]"""
if not fixed_range:
fixed_range = FLAGS.fixed_range
if not num_networks:
num_networks = FLAGS.num_networks
if not network_size:
network_size = FLAGS.network_size
if not vlan_start:
vlan_start = FLAGS.vlan_start
if not vpn_start:
vpn_start = FLAGS.vpn_start
net_manager = utils.import_object(FLAGS.network_manager)
net_manager.create_networks(context.get_admin_context(),
fixed_range, int(num_networks),
int(network_size), int(vlan_start),
int(vpn_start))
CATEGORIES = [
('user', UserCommands),
('project', ProjectCommands),
('role', RoleCommands),
('shell', ShellCommands),
('vpn', VpnCommands),
]
('floating', FloatingIpCommands),
('network', NetworkCommands)]
def lazy_match(name, key_value_tuples):
@@ -253,6 +469,10 @@ def main():
"""Parse options and call the appropriate class/method."""
utils.default_flagfile('/etc/nova/nova-manage.conf')
argv = FLAGS(sys.argv)
if FLAGS.verbose:
logging.getLogger().setLevel(logging.DEBUG)
script_name = argv.pop(0)
if len(argv) < 1:
print script_name + " category action [<args>]"
@@ -280,9 +500,9 @@ def main():
fn(*argv)
sys.exit(0)
except TypeError:
print "Wrong number of arguments supplied"
print "Possible wrong number of arguments supplied"
print "%s %s: %s" % (category, action, fn.__doc__)
sys.exit(2)
raise
if __name__ == '__main__':
main()

View File

@@ -21,6 +21,17 @@
Twistd daemon for the nova network nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd

View File

@@ -21,6 +21,17 @@
Twisted daemon for nova objectstore. Supports S3 API.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import utils
from nova import twistd

View File

@@ -21,6 +21,17 @@
Twistd daemon for the nova scheduler nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd

View File

@@ -21,6 +21,17 @@
Twistd daemon for the nova volume nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd

View File

@@ -17,13 +17,14 @@
import bzrlib.log
from bzrlib.osutils import format_date
#
# This is mostly stolen from bzrlib.log.GnuChangelogLogFormatter
# The difference is that it logs the author rather than the committer
# which for Nova always is Tarmac.
#
class NovaLogFormat(bzrlib.log.GnuChangelogLogFormatter):
"""This is mostly stolen from bzrlib.log.GnuChangelogLogFormatter
The difference is that it logs the author rather than the committer
which for Nova always is Tarmac."""
preferred_levels = 1
def log_revision(self, revision):
"""Log a revision, either merged or not."""
to_file = self.to_file
@@ -38,13 +39,14 @@ class NovaLogFormat(bzrlib.log.GnuChangelogLogFormatter):
to_file.write('%s %s\n\n' % (date_str, ", ".join(authors)))
if revision.delta is not None and revision.delta.has_changed():
for c in revision.delta.added + revision.delta.removed + revision.delta.modified:
for c in revision.delta.added + revision.delta.removed + \
revision.delta.modified:
path, = c[:1]
to_file.write('\t* %s:\n' % (path,))
for c in revision.delta.renamed:
oldpath,newpath = c[:2]
oldpath, newpath = c[:2]
# For renamed files, show both the old and the new path
to_file.write('\t* %s:\n\t* %s:\n' % (oldpath,newpath))
to_file.write('\t* %s:\n\t* %s:\n' % (oldpath, newpath))
to_file.write('\n')
if not revision.rev.message:
@@ -56,4 +58,3 @@ class NovaLogFormat(bzrlib.log.GnuChangelogLogFormatter):
to_file.write('\n')
bzrlib.log.register_formatter('novalog', NovaLogFormat)

View File

@@ -172,14 +172,6 @@ Further Challenges
The :mod:`rbac` Module
--------------------------
.. automodule:: nova.auth.rbac
:members:
:undoc-members:
:show-inheritance:
The :mod:`signer` Module
------------------------

View File

@@ -47,9 +47,9 @@ copyright = u'2010, United States Government as represented by the Administrator
# built documents.
#
# The short X.Y version.
version = '0.9'
version = '2010.1'
# The full version, including alpha/beta/rc tags.
release = '0.9.1'
release = '2010.1'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.

View File

@@ -18,34 +18,30 @@
Getting Started with Nova
=========================
This code base is continually changing so dependencies also change.
GOTTA HAVE A nova.pth file added or it WONT WORK (will write setup.py file soon)
Create a file named nova.pth in your python libraries directory
(usually /usr/local/lib/python2.6/dist-packages) with a single line that points
to the directory where you checked out the source (that contains the nova/
directory).
DEPENDENCIES
Dependencies
------------
Related servers we rely on
* RabbitMQ: messaging queue, used for all communication between components
* OpenLDAP: users, groups (maybe cut)
* ReDIS: Remote Dictionary Store (for fast, shared state data)
* nginx: HTTP server to handle serving large files (because Tornado can't)
Optional servers
* OpenLDAP: By default, the auth server uses the RDBMS-backed datastore by setting FLAGS.auth_driver to 'nova.auth.dbdriver.DbDriver'. But OpenLDAP (or LDAP) could be configured.
* ReDIS: By default, this is not enabled as the auth driver.
Python libraries we don't vendor
* M2Crypto: python library interface for openssl
* curl
* XenAPI: Needed only for Xen Cloud Platform or XenServer support. Available from http://wiki.xensource.com/xenwiki/XCP_SDK or http://community.citrix.com/cdn/xs/sdks.
* XenAPI: Needed only for Xen Cloud Platform or XenServer support. Available from http://wiki.xensource.com/xenwiki/XCP_SDK or http://community.citrix.com/cdn/xs/sdks.
Vendored python libraries (don't require any installation)
* Tornado: scalable non blocking web server for api requests
* Twisted: just for the twisted.internet.defer package
* Tornado: scalable non blocking web server for api requests
* boto: python api for aws api
* IPy: library for managing ip addresses
@@ -58,40 +54,19 @@ Recommended
Installation
--------------
::
# system libraries and tools
apt-get install -y aoetools vlan curl
modprobe aoe
# python libraries
apt-get install -y python-setuptools python-dev python-pycurl python-m2crypto
# ON THE CLOUD CONTROLLER
apt-get install -y rabbitmq-server dnsmasq nginx
# build redis from 2.0.0-rc1 source
# setup ldap (slap.sh as root will remove ldap and reinstall it)
NOVA_PATH/nova/auth/slap.sh
/etc/init.d/rabbitmq-server start
# ON VOLUME NODE:
apt-get install -y vblade-persist
# ON THE COMPUTE NODE:
apt-get install -y python-libvirt
apt-get install -y kpartx kvm libvirt-bin
modprobe kvm
# optional packages
apt-get install -y euca2ools
Due to many changes it's best to rely on the `OpenStack wiki <http://wiki.openstack.org>`_ for installation instructions.
Configuration
---------------
ON CLOUD CONTROLLER
These instructions are incomplete, but we are actively updating the `OpenStack wiki <http://wiki.openstack.org>`_ with more configuration information.
On the cloud controller
* Add yourself to the libvirtd group, log out, and log back in
* fix hardcoded ec2 metadata/userdata uri ($IP is the IP of the cloud), and masquerade all traffic from launched instances
* Fix hardcoded ec2 metadata/userdata uri ($IP is the IP of the cloud), and masquerade all traffic from launched instances
::
iptables -t nat -A PREROUTING -s 0.0.0.0/0 -d 169.254.169.254/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination $IP:8773
@@ -119,9 +94,9 @@ ON CLOUD CONTROLLER
}
}
ON VOLUME NODE
On the volume node
* create a filesystem (you can use an actual disk if you have one spare, default is /dev/sdb)
* Create a filesystem (you can use an actual disk if you have one spare, default is /dev/sdb)
::
@@ -137,9 +112,7 @@ Running
Launch servers
* rabbitmq
* redis
* slapd
* nginx
* redis (optional)
Launch nova components

View File

@@ -15,18 +15,22 @@
License for the specific language governing permissions and limitations
under the License.
Welcome to nova's documentation!
Welcome to Nova's documentation!
================================
Nova is a cloud computing fabric controller (the main part of an IaaS system) built to match the popular AWS EC2 and S3 APIs.
It is written in Python, using the Tornado and Twisted frameworks, and relies on the standard AMQP messaging protocol,
and the Redis distributed KVS.
Nova is intended to be easy to extend, and adapt. For example, it currently uses
an LDAP server for users and groups, but also includes a fake LDAP server,
that stores data in Redis. It has extensive test coverage, and uses the
Sphinx toolkit (the same as Python itself) for code and user documentation.
Nova is a cloud computing fabric controller (the main part of an IaaS system).
It is written in Python and relies on the standard AMQP messaging protocol, uses the Twisted framework,
and optionally uses the Redis distributed key value store for authorization.
Nova is intended to be easy to extend and adapt. For example, authentication and authorization
requests by default use an RDBMS-backed datastore driver. However, there is already support
for using LDAP backing authentication (slapd) and if you wish to "fake" LDAP, there is a module
available that uses Redis to store authentication information in an LDAP-like backing datastore.
It has extensive test coverage, and uses the Sphinx toolkit (the same as Python itself) for code
and developer documentation. Additional documentation is available on the
`OpenStack wiki <http://wiki.openstack.org>`_.
While Nova is currently in Beta use within several organizations, the codebase
is very much under active development - there are bugs!
is very much under active development - please test it and log bugs!
Contents:

View File

@@ -20,8 +20,8 @@ Nova User API client library.
"""
import base64
import boto
import httplib
from boto.ec2.regioninfo import RegionInfo
class ConsoleInfo(object):
@@ -38,6 +38,12 @@ class ConsoleInfo(object):
if name == 'kind':
self.url = str(value)
DEFAULT_CLC_URL = 'http://127.0.0.1:8773'
DEFAULT_REGION = 'nova'
DEFAULT_ACCESS_KEY = 'admin'
DEFAULT_SECRET_KEY = 'admin'
class UserInfo(object):
"""
Information about a Nova user, as parsed through SAX
@@ -81,13 +87,13 @@ class UserRole(object):
def __init__(self, connection=None):
self.connection = connection
self.role = None
def __repr__(self):
return 'UserRole:%s' % self.role
def startElement(self, name, attrs, connection):
return None
def endElement(self, name, value, connection):
    """SAX end-of-element callback; records the role name when seen."""
    if name != 'role':
        return
    self.role = value
@@ -141,20 +147,20 @@ class ProjectMember(object):
def __init__(self, connection=None):
self.connection = connection
self.memberId = None
def __repr__(self):
return 'ProjectMember:%s' % self.memberId
def startElement(self, name, attrs, connection):
return None
def endElement(self, name, value, connection):
    """SAX end-of-element callback.

    A 'member' element carries the member id verbatim; every other
    element is stored as a stringified attribute of the same name.
    """
    if name == 'member':
        self.memberId = value
        return
    setattr(self, name, str(value))
class HostInfo(object):
"""
Information about a Nova Host, as parsed through SAX:
@@ -184,58 +190,78 @@ class HostInfo(object):
class NovaAdminClient(object):
def __init__(self, clc_ip='127.0.0.1', region='nova', access_key='admin',
secret_key='admin', **kwargs):
self.clc_ip = clc_ip
def __init__(self, clc_url=DEFAULT_CLC_URL, region=DEFAULT_REGION,
access_key=DEFAULT_ACCESS_KEY, secret_key=DEFAULT_SECRET_KEY,
**kwargs):
parts = self.split_clc_url(clc_url)
self.clc_url = clc_url
self.region = region
self.access = access_key
self.secret = secret_key
self.apiconn = boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
is_secure=False,
region=RegionInfo(None, region, clc_ip),
port=8773,
is_secure=parts['is_secure'],
region=RegionInfo(None,
region,
parts['ip']),
port=parts['port'],
path='/services/Admin',
**kwargs)
self.apiconn.APIVersion = 'nova'
def connection_for(self, username, project, **kwargs):
"""
Returns a boto ec2 connection for the given username.
"""
def connection_for(self, username, project, clc_url=None, region=None,
**kwargs):
"""Returns a boto ec2 connection for the given username."""
if not clc_url:
clc_url = self.clc_url
if not region:
region = self.region
parts = self.split_clc_url(clc_url)
user = self.get_user(username)
access_key = '%s:%s' % (user.accesskey, project)
return boto.connect_ec2(
aws_access_key_id=access_key,
aws_secret_access_key=user.secretkey,
is_secure=False,
region=RegionInfo(None, self.region, self.clc_ip),
port=8773,
path='/services/Cloud',
**kwargs)
return boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=user.secretkey,
is_secure=parts['is_secure'],
region=RegionInfo(None,
self.region,
parts['ip']),
port=parts['port'],
path='/services/Cloud',
**kwargs)
def split_clc_url(self, clc_url):
    """Splits a cloud controller endpoint url into its components.

    Accepts urls of the form scheme://host:port[/path], e.g.
    'http://127.0.0.1:8773', and returns a dict with keys
    'ip' (str), 'port' (int) and 'is_secure' (True for https).
    """
    # NOTE(review): the previous code called httplib.urlsplit(), which
    # only works because py2's httplib happens to do
    # "from urlparse import urlsplit" internally -- it is not a public
    # httplib API.  Parse the pieces we need by hand instead.
    scheme, _sep, rest = clc_url.partition('://')
    netloc = rest.split('/', 1)[0]
    ip, port = netloc.split(':')
    return {'ip': ip, 'port': int(port), 'is_secure': scheme == 'https'}
def get_users(self):
""" grabs the list of all users """
"""Grabs the list of all users."""
return self.apiconn.get_list('DescribeUsers', {}, [('item', UserInfo)])
def get_user(self, name):
""" grab a single user by name """
user = self.apiconn.get_object('DescribeUser', {'Name': name}, UserInfo)
"""Grab a single user by name."""
user = self.apiconn.get_object('DescribeUser', {'Name': name},
UserInfo)
if user.username != None:
return user
def has_user(self, username):
""" determine if user exists """
"""Determine if user exists."""
return self.get_user(username) != None
def create_user(self, username):
""" creates a new user, returning the userinfo object with access/secret """
return self.apiconn.get_object('RegisterUser', {'Name': username}, UserInfo)
"""Creates a new user, returning the userinfo object with
access/secret."""
return self.apiconn.get_object('RegisterUser', {'Name': username},
UserInfo)
def delete_user(self, username):
""" deletes a user """
return self.apiconn.get_object('DeregisterUser', {'Name': username}, UserInfo)
"""Deletes a user."""
return self.apiconn.get_object('DeregisterUser', {'Name': username},
UserInfo)
def get_roles(self, project_roles=True):
"""Returns a list of available roles."""
@@ -244,11 +270,10 @@ class NovaAdminClient(object):
[('item', UserRole)])
def get_user_roles(self, user, project=None):
"""Returns a list of roles for the given user.
Omitting project will return any global roles that the user has.
Specifying project will return only project specific roles.
"""
params = {'User':user}
"""Returns a list of roles for the given user. Omitting project will
return any global roles that the user has. Specifying project will
return only project specific roles."""
params = {'User': user}
if project:
params['Project'] = project
return self.apiconn.get_list('DescribeUserRoles',
@@ -256,24 +281,19 @@ class NovaAdminClient(object):
[('item', UserRole)])
def add_user_role(self, user, role, project=None):
"""
Add a role to a user either globally or for a specific project.
"""
"""Add a role to a user either globally or for a specific project."""
return self.modify_user_role(user, role, project=project,
operation='add')
def remove_user_role(self, user, role, project=None):
"""
Remove a role from a user either globally or for a specific project.
"""
"""Remove a role from a user either globally or for a specific
project."""
return self.modify_user_role(user, role, project=project,
operation='remove')
def modify_user_role(self, user, role, project=None, operation='add',
**kwargs):
"""
Add or remove a role for a user and project.
"""
"""Add or remove a role for a user and project."""
params = {'User': user,
'Role': role,
'Project': project,
@@ -281,9 +301,7 @@ class NovaAdminClient(object):
return self.apiconn.get_status('ModifyUserRole', params)
def get_projects(self, user=None):
"""
Returns a list of all projects.
"""
"""Returns a list of all projects."""
if user:
params = {'User': user}
else:
@@ -293,21 +311,17 @@ class NovaAdminClient(object):
[('item', ProjectInfo)])
def get_project(self, name):
"""
Returns a single project with the specified name.
"""
"""Returns a single project with the specified name."""
project = self.apiconn.get_object('DescribeProject',
{'Name': name},
ProjectInfo)
if project.projectname != None:
return project
def create_project(self, projectname, manager_user, description=None,
member_users=None):
"""
Creates a new project.
"""
"""Creates a new project."""
params = {'Name': projectname,
'ManagerUser': manager_user,
'Description': description,
@@ -315,46 +329,35 @@ class NovaAdminClient(object):
return self.apiconn.get_object('RegisterProject', params, ProjectInfo)
def delete_project(self, projectname):
"""
Permanently deletes the specified project.
"""
"""Permanently deletes the specified project."""
return self.apiconn.get_object('DeregisterProject',
{'Name': projectname},
ProjectInfo)
def get_project_members(self, name):
"""
Returns a list of members of a project.
"""
"""Returns a list of members of a project."""
return self.apiconn.get_list('DescribeProjectMembers',
{'Name': name},
[('item', ProjectMember)])
def add_project_member(self, user, project):
"""
Adds a user to a project.
"""
"""Adds a user to a project."""
return self.modify_project_member(user, project, operation='add')
def remove_project_member(self, user, project):
"""
Removes a user from a project.
"""
"""Removes a user from a project."""
return self.modify_project_member(user, project, operation='remove')
def modify_project_member(self, user, project, operation='add'):
"""
Adds or removes a user from a project.
"""
"""Adds or removes a user from a project."""
params = {'User': user,
'Project': project,
'Operation': operation}
return self.apiconn.get_status('ModifyProjectMember', params)
def get_zip(self, user, project):
"""
Returns the content of a zip file containing novarc and access credentials.
"""
"""Returns the content of a zip file containing novarc and access
credentials."""
params = {'Name': user, 'Project': project}
zip = self.apiconn.get_object('GenerateX509ForUser', params, UserInfo)
return zip.file

250
nova/auth/dbdriver.py Normal file
View File

@@ -0,0 +1,250 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Auth driver using the DB as its backend.
"""
import logging
import sys
from nova import context
from nova import exception
from nova import db
class DbDriver(object):
    """Authentication driver backed by the nova database.

    Presents the same interface as nova.auth.ldapdriver.LdapDriver so
    that AuthManager can select a backend via the --auth_driver flag.
    Defines __enter__ and __exit__ and therefore supports the with/as
    syntax.

    Users and projects are returned as plain dicts (see
    _db_user_to_auth_user / _db_project_to_auth_projectuser) rather
    than DB rows, keeping AuthManager independent of the db layer.
    """

    def __init__(self):
        """No setup required; every call goes straight to the db layer."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Nothing to clean up; any exception propagates to the caller.
        pass

    def get_user(self, uid):
        """Retrieve user by id"""
        user = db.user_get(context.get_admin_context(), uid)
        return self._db_user_to_auth_user(user)

    def get_user_from_access_key(self, access):
        """Retrieve user by access key"""
        user = db.user_get_by_access_key(context.get_admin_context(), access)
        return self._db_user_to_auth_user(user)

    def get_project(self, pid):
        """Retrieve project by id"""
        project = db.project_get(context.get_admin_context(), pid)
        return self._db_project_to_auth_projectuser(project)

    def get_users(self):
        """Retrieve list of all users"""
        return [self._db_user_to_auth_user(user)
                for user in db.user_get_all(context.get_admin_context())]

    def get_projects(self, uid=None):
        """Retrieve list of projects (all, or just those uid belongs to)"""
        if uid:
            result = db.project_get_by_user(context.get_admin_context(), uid)
        else:
            result = db.project_get_all(context.get_admin_context())
        return [self._db_project_to_auth_projectuser(proj) for proj in result]

    def create_user(self, name, access_key, secret_key, is_admin):
        """Create a user

        Raises exception.Duplicate (with a friendlier message) when a
        user with this name already exists.
        """
        # NOTE: the user's name doubles as its primary key.
        values = {'id': name,
                  'access_key': access_key,
                  'secret_key': secret_key,
                  'is_admin': is_admin}
        try:
            user_ref = db.user_create(context.get_admin_context(), values)
        except exception.Duplicate:
            raise exception.Duplicate('User %s already exists' % name)
        return self._db_user_to_auth_user(user_ref)

    def _db_user_to_auth_user(self, user_ref):
        """Convert a db user row to the dict shape AuthManager expects."""
        return {'id': user_ref['id'],
                'name': user_ref['id'],
                'access': user_ref['access_key'],
                'secret': user_ref['secret_key'],
                'admin': user_ref['is_admin']}

    def _db_project_to_auth_projectuser(self, project_ref):
        """Convert a db project row to the dict shape AuthManager expects."""
        member_ids = [member['id'] for member in project_ref['members']]
        return {'id': project_ref['id'],
                'name': project_ref['name'],
                'project_manager_id': project_ref['project_manager'],
                'description': project_ref['description'],
                'member_ids': member_ids}

    def create_project(self, name, manager_uid,
                       description=None, member_uids=None):
        """Create a project

        The manager is always made a member.  Raises exception.NotFound
        if the manager or any member does not exist, and
        exception.Duplicate if the project name is taken.
        """
        manager = db.user_get(context.get_admin_context(), manager_uid)
        if not manager:
            raise exception.NotFound("Project can't be created because "
                                     "manager %s doesn't exist" % manager_uid)

        # description is a required attribute
        if description is None:
            description = name

        # First, we ensure that all the given users exist before we go
        # on to create the project. This way we won't have to destroy
        # the project again because a user turns out to be invalid.
        members = set([manager])
        if member_uids is not None:
            for member_uid in member_uids:
                member = db.user_get(context.get_admin_context(), member_uid)
                if not member:
                    raise exception.NotFound("Project can't be created "
                                             "because user %s doesn't exist"
                                             % member_uid)
                members.add(member)

        values = {'id': name,
                  'name': name,
                  'project_manager': manager['id'],
                  'description': description}

        try:
            project = db.project_create(context.get_admin_context(), values)
        except exception.Duplicate:
            raise exception.Duplicate("Project can't be created because "
                                      "project %s already exists" % name)

        for member in members:
            db.project_add_member(context.get_admin_context(),
                                  project['id'],
                                  member['id'])

        # This looks silly, but ensures that the members element has been
        # correctly populated
        project_ref = db.project_get(context.get_admin_context(),
                                     project['id'])
        return self._db_project_to_auth_projectuser(project_ref)

    def modify_project(self, project_id, manager_uid=None, description=None):
        """Modify an existing project's manager and/or description"""
        if not manager_uid and not description:
            # Nothing to change.
            return
        values = {}
        if manager_uid:
            manager = db.user_get(context.get_admin_context(), manager_uid)
            if not manager:
                raise exception.NotFound("Project can't be modified because "
                                         "manager %s doesn't exist" %
                                         manager_uid)
            values['project_manager'] = manager['id']
        if description:
            values['description'] = description

        db.project_update(context.get_admin_context(), project_id, values)

    def add_to_project(self, uid, project_id):
        """Add user to project"""
        user, project = self._validate_user_and_project(uid, project_id)
        db.project_add_member(context.get_admin_context(),
                              project['id'],
                              user['id'])

    def remove_from_project(self, uid, project_id):
        """Remove user from project"""
        user, project = self._validate_user_and_project(uid, project_id)
        db.project_remove_member(context.get_admin_context(),
                                 project['id'],
                                 user['id'])

    def is_in_project(self, uid, project_id):
        """Check if user is a member of the project"""
        user, project = self._validate_user_and_project(uid, project_id)
        return user in project.members

    def has_role(self, uid, role, project_id=None):
        """Check if user has role

        If project is specified, it checks for local role, otherwise it
        checks for global role
        """
        return role in self.get_user_roles(uid, project_id)

    def add_role(self, uid, role, project_id=None):
        """Add role for user (or user and project)"""
        if not project_id:
            db.user_add_role(context.get_admin_context(), uid, role)
            return
        db.user_add_project_role(context.get_admin_context(),
                                 uid, project_id, role)

    def remove_role(self, uid, role, project_id=None):
        """Remove role for user (or user and project)"""
        if not project_id:
            db.user_remove_role(context.get_admin_context(), uid, role)
            return
        db.user_remove_project_role(context.get_admin_context(),
                                    uid, project_id, role)

    def get_user_roles(self, uid, project_id=None):
        """Retrieve list of roles for user (or user and project)"""
        admin_context = context.get_admin_context()
        if project_id is None:
            # Global roles only.
            return db.user_get_roles(admin_context, uid)
        return db.user_get_roles_for_project(admin_context, uid, project_id)

    def delete_user(self, id):
        """Delete a user"""
        # NOTE: parameter name `id` kept for interface compatibility,
        # even though it shadows the builtin.
        user = db.user_get(context.get_admin_context(), id)
        db.user_delete(context.get_admin_context(), user['id'])

    def delete_project(self, project_id):
        """Delete a project"""
        db.project_delete(context.get_admin_context(), project_id)

    def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
        """Modify an existing user's keys and/or admin flag"""
        if not access_key and not secret_key and admin is None:
            # Nothing to change.
            return
        values = {}
        if access_key:
            values['access_key'] = access_key
        if secret_key:
            values['secret_key'] = secret_key
        if admin is not None:
            values['is_admin'] = admin
        db.user_update(context.get_admin_context(), uid, values)

    def _validate_user_and_project(self, user_id, project_id):
        """Fetch user and project rows, raising NotFound if either is missing."""
        user = db.user_get(context.get_admin_context(), user_id)
        if not user:
            raise exception.NotFound('User "%s" not found' % user_id)
        project = db.project_get(context.get_admin_context(), project_id)
        if not project:
            raise exception.NotFound('Project "%s" not found' % project_id)
        return user, project

View File

@@ -24,23 +24,47 @@ library to work with nova.
"""
import json
import redis
from nova import datastore
from nova import flags
FLAGS = flags.FLAGS
flags.DEFINE_string('redis_host', '127.0.0.1',
'Host that redis is running on.')
flags.DEFINE_integer('redis_port', 6379,
'Port that redis is running on.')
flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away')
class Redis(object):
    """Process-wide singleton holder for the shared redis client.

    Use Redis.instance() to obtain the connection; direct construction
    is rejected once the singleton has been created.
    """

    def __init__(self):
        # Guard against accidental construction after the shared
        # client exists -- callers must go through instance().
        if hasattr(type(self), '_instance'):
            raise Exception('Attempted to instantiate singleton')

    @classmethod
    def instance(cls):
        """Return the shared redis client, creating it lazily."""
        if not hasattr(cls, '_instance'):
            cls._instance = redis.Redis(host=FLAGS.redis_host,
                                        port=FLAGS.redis_port,
                                        db=FLAGS.redis_db)
        return cls._instance
SCOPE_BASE = 0
SCOPE_ONELEVEL = 1 # not implemented
SCOPE_ONELEVEL = 1 # Not implemented
SCOPE_SUBTREE = 2
MOD_ADD = 0
MOD_DELETE = 1
MOD_REPLACE = 2
class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module."""
pass
class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103
class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module."""
pass
@@ -163,11 +187,11 @@ class FakeLDAP(object):
key = "%s%s" % (self.__redis_prefix, dn)
value_dict = dict([(k, _to_json(v)) for k, v in attr])
datastore.Redis.instance().hmset(key, value_dict)
Redis.instance().hmset(key, value_dict)
def delete_s(self, dn):
"""Remove the ldap object at specified dn."""
datastore.Redis.instance().delete("%s%s" % (self.__redis_prefix, dn))
Redis.instance().delete("%s%s" % (self.__redis_prefix, dn))
def modify_s(self, dn, attrs):
"""Modify the object at dn using the attribute list.
@@ -175,16 +199,18 @@ class FakeLDAP(object):
Args:
dn -- a dn
attrs -- a list of tuples in the following form:
([MOD_ADD | MOD_DELETE], attribute, value)
([MOD_ADD | MOD_DELETE | MOD_REPLACE], attribute, value)
"""
redis = datastore.Redis.instance()
redis = Redis.instance()
key = "%s%s" % (self.__redis_prefix, dn)
for cmd, k, v in attrs:
values = _from_json(redis.hget(key, k))
if cmd == MOD_ADD:
values.append(v)
elif cmd == MOD_REPLACE:
values = [v]
else:
values.remove(v)
values = redis.hset(key, k, _to_json(values))
@@ -201,7 +227,7 @@ class FakeLDAP(object):
"""
if scope != SCOPE_BASE and scope != SCOPE_SUBTREE:
raise NotImplementedError(str(scope))
redis = datastore.Redis.instance()
redis = Redis.instance()
if scope == SCOPE_BASE:
keys = ["%s%s" % (self.__redis_prefix, dn)]
else:
@@ -226,6 +252,6 @@ class FakeLDAP(object):
return objects
@property
def __redis_prefix(self): # pylint: disable-msg=R0201
def __redis_prefix(self): # pylint: disable-msg=R0201
"""Get the prefix to use for all redis keys."""
return 'ldap:'

View File

@@ -99,13 +99,6 @@ class LdapDriver(object):
dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query))
def get_key_pair(self, uid, key_name):
"""Retrieve key pair by uid and key name"""
dn = 'cn=%s,%s' % (key_name,
self.__uid_to_dn(uid))
attr = self.__find_object(dn, '(objectclass=novaKeyPair)')
return self.__to_key_pair(uid, attr)
def get_project(self, pid):
"""Retrieve project by id"""
dn = 'cn=%s,%s' % (pid,
@@ -119,12 +112,6 @@ class LdapDriver(object):
'(objectclass=novaUser)')
return [self.__to_user(attr) for attr in attrs]
def get_key_pairs(self, uid):
"""Retrieve list of key pairs"""
attrs = self.__find_objects(self.__uid_to_dn(uid),
'(objectclass=novaKeyPair)')
return [self.__to_key_pair(uid, attr) for attr in attrs]
def get_projects(self, uid=None):
"""Retrieve list of projects"""
pattern = '(objectclass=novaProject)'
@@ -154,21 +141,6 @@ class LdapDriver(object):
self.conn.add_s(self.__uid_to_dn(name), attr)
return self.__to_user(dict(attr))
def create_key_pair(self, uid, key_name, public_key, fingerprint):
"""Create a key pair"""
# TODO(vish): possibly refactor this to store keys in their own ou
# and put dn reference in the user object
attr = [
('objectclass', ['novaKeyPair']),
('cn', [key_name]),
('sshPublicKey', [public_key]),
('keyFingerprint', [fingerprint]),
]
self.conn.add_s('cn=%s,%s' % (key_name,
self.__uid_to_dn(uid)),
attr)
return self.__to_key_pair(uid, dict(attr))
def create_project(self, name, manager_uid,
description=None, member_uids=None):
"""Create a project"""
@@ -202,6 +174,24 @@ class LdapDriver(object):
self.conn.add_s('cn=%s,%s' % (name, FLAGS.ldap_project_subtree), attr)
return self.__to_project(dict(attr))
def modify_project(self, project_id, manager_uid=None, description=None):
    """Update the manager and/or description of an existing project."""
    if not (manager_uid or description):
        # Nothing to change.
        return
    mods = []
    if manager_uid:
        if not self.__user_exists(manager_uid):
            raise exception.NotFound("Project can't be modified because "
                                     "manager %s doesn't exist" %
                                     manager_uid)
        mods.append((self.ldap.MOD_REPLACE, 'projectManager',
                     self.__uid_to_dn(manager_uid)))
    if description:
        mods.append((self.ldap.MOD_REPLACE, 'description', description))
    project_dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
    self.conn.modify_s(project_dn, mods)
def add_to_project(self, uid, project_id):
"""Add user to project"""
dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
@@ -265,18 +255,8 @@ class LdapDriver(object):
"""Delete a user"""
if not self.__user_exists(uid):
raise exception.NotFound("User %s doesn't exist" % uid)
self.__delete_key_pairs(uid)
self.__remove_from_all(uid)
self.conn.delete_s('uid=%s,%s' % (uid,
FLAGS.ldap_user_subtree))
def delete_key_pair(self, uid, key_name):
"""Delete a key pair"""
if not self.__key_pair_exists(uid, key_name):
raise exception.NotFound("Key Pair %s doesn't exist for user %s" %
(key_name, uid))
self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid,
FLAGS.ldap_user_subtree))
self.conn.delete_s(self.__uid_to_dn(uid))
def delete_project(self, project_id):
"""Delete a project"""
@@ -284,14 +264,23 @@ class LdapDriver(object):
self.__delete_roles(project_dn)
self.__delete_group(project_dn)
def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
    """Modify an existing user

    Replaces the user's access key, secret key and/or admin flag in
    LDAP via MOD_REPLACE.  No-op when nothing is given to change.
    (Previous docstring wrongly said "project".)
    """
    if not access_key and not secret_key and admin is None:
        return
    attr = []
    if access_key:
        attr.append((self.ldap.MOD_REPLACE, 'accessKey', access_key))
    if secret_key:
        attr.append((self.ldap.MOD_REPLACE, 'secretKey', secret_key))
    if admin is not None:
        # LDAP stores booleans as the strings 'TRUE' / 'FALSE'.
        attr.append((self.ldap.MOD_REPLACE, 'isAdmin', str(admin).upper()))
    self.conn.modify_s(self.__uid_to_dn(uid), attr)
def __user_exists(self, uid):
"""Check if user exists"""
return self.get_user(uid) != None
def __key_pair_exists(self, uid, key_name):
"""Check if key pair exists"""
return self.get_key_pair(uid, key_name) != None
def __project_exists(self, project_id):
"""Check if project exists"""
return self.get_project(project_id) != None
@@ -305,24 +294,26 @@ class LdapDriver(object):
def __find_dns(self, dn, query=None, scope=None):
"""Find dns by query"""
if scope is None: # one of the flags is 0!!
if scope is None:
# One of the flags is 0!
scope = self.ldap.SCOPE_SUBTREE
try:
res = self.conn.search_s(dn, scope, query)
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the DNs
# Just return the DNs
return [dn for dn, _attributes in res]
def __find_objects(self, dn, query=None, scope=None):
"""Find objects by query"""
if scope is None: # one of the flags is 0!!
if scope is None:
# One of the flags is 0!
scope = self.ldap.SCOPE_SUBTREE
try:
res = self.conn.search_s(dn, scope, query)
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the attributes
# Just return the attributes
return [attributes for dn, attributes in res]
def __find_role_dns(self, tree):
@@ -341,13 +332,6 @@ class LdapDriver(object):
"""Check if group exists"""
return self.__find_object(dn, '(objectclass=groupOfNames)') != None
def __delete_key_pairs(self, uid):
"""Delete all key pairs for user"""
keys = self.get_key_pairs(uid)
if keys != None:
for key in keys:
self.delete_key_pair(uid, key['name'])
@staticmethod
def __role_to_dn(role, project_id=None):
"""Convert role to corresponding dn"""
@@ -472,18 +456,6 @@ class LdapDriver(object):
'secret': attr['secretKey'][0],
'admin': (attr['isAdmin'][0] == 'TRUE')}
@staticmethod
def __to_key_pair(owner, attr):
"""Convert ldap attributes to KeyPair object"""
if attr == None:
return None
return {
'id': attr['cn'][0],
'name': attr['cn'][0],
'owner_id': owner,
'public_key': attr['sshPublicKey'][0],
'fingerprint': attr['keyFingerprint'][0]}
def __to_project(self, attr):
"""Convert ldap attributes to Project object"""
if attr == None:
@@ -510,6 +482,6 @@ class LdapDriver(object):
class FakeLdapDriver(LdapDriver):
"""Fake Ldap Auth driver"""
def __init__(self): # pylint: disable-msg=W0231
def __init__(self): # pylint: disable-msg=W0231
__import__('nova.auth.fakeldap')
self.ldap = sys.modules['nova.auth.fakeldap']

View File

@@ -23,11 +23,12 @@ Nova authentication management
import logging
import os
import shutil
import string # pylint: disable-msg=W0402
import string # pylint: disable-msg=W0402
import tempfile
import uuid
import zipfile
from nova import context
from nova import crypto
from nova import db
from nova import exception
@@ -44,7 +45,7 @@ flags.DEFINE_list('allowed_roles',
# NOTE(vish): a user with one of these roles will be a superuser and
# have access to all api commands
flags.DEFINE_list('superuser_roles', ['cloudadmin'],
'Roles that ignore rbac checking completely')
'Roles that ignore authorization checking completely')
# NOTE(vish): a user with one of these roles will have it for every
# project, even if he or she is not a member of the project
@@ -69,7 +70,7 @@ flags.DEFINE_string('credential_cert_subject',
'/C=US/ST=California/L=MountainView/O=AnsoLabs/'
'OU=NovaDev/CN=%s-%s',
'Subject for certificate for users')
flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver',
flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver',
'Driver that auth manager uses')
@@ -128,24 +129,6 @@ class User(AuthBase):
def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project)
def generate_key_pair(self, name):
return AuthManager().generate_key_pair(self.id, name)
def create_key_pair(self, name, public_key, fingerprint):
return AuthManager().create_key_pair(self.id,
name,
public_key,
fingerprint)
def get_key_pair(self, name):
return AuthManager().get_key_pair(self.id, name)
def delete_key_pair(self, name):
return AuthManager().delete_key_pair(self.id, name)
def get_key_pairs(self):
return AuthManager().get_key_pairs(self.id)
def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
@@ -154,29 +137,6 @@ class User(AuthBase):
self.admin)
class KeyPair(AuthBase):
"""Represents an ssh key returned from the datastore
Even though this object is named KeyPair, only the public key and
fingerprint is stored. The user's private key is not saved.
"""
def __init__(self, id, name, owner_id, public_key, fingerprint):
AuthBase.__init__(self)
self.id = id
self.name = name
self.owner_id = owner_id
self.public_key = public_key
self.fingerprint = fingerprint
def __repr__(self):
return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id,
self.name,
self.owner_id,
self.public_key,
self.fingerprint)
class Project(AuthBase):
"""Represents a Project returned from the datastore"""
@@ -242,7 +202,7 @@ class AuthManager(object):
def __new__(cls, *args, **kwargs):
"""Returns the AuthManager singleton"""
if not cls._instance:
if not cls._instance or ('new' in kwargs and kwargs['new']):
cls._instance = super(AuthManager, cls).__new__(cls)
return cls._instance
@@ -307,7 +267,7 @@ class AuthManager(object):
# NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user
if project_id is '':
if project_id == '':
project_id = user.name
project = self.get_project(project_id)
@@ -345,7 +305,7 @@ class AuthManager(object):
return "%s:%s" % (user.access, Project.safe_id(project))
def is_superuser(self, user):
"""Checks for superuser status, allowing user to bypass rbac
"""Checks for superuser status, allowing user to bypass authorization
@type user: User or uid
@param user: User to check.
@@ -495,7 +455,7 @@ class AuthManager(object):
return [Project(**project_dict) for project_dict in project_list]
def create_project(self, name, manager_user, description=None,
member_users=None, context=None):
member_users=None):
"""Create a project
@type name: str
@@ -525,14 +485,28 @@ class AuthManager(object):
member_users)
if project_dict:
project = Project(**project_dict)
try:
self.network_manager.allocate_network(context,
project.id)
except:
drv.delete_project(project.id)
raise
return project
def modify_project(self, project, manager_user=None, description=None):
    """Modify a project

    @type project: Project or project_id
    @param project: The project to modify.

    @type manager_user: User or uid
    @param manager_user: This user will be the new project manager.

    @type description: str
    @param description: This will be the new description of the project.
    """
    if manager_user:
        manager_user = User.safe_id(manager_user)
    with self.driver() as drv:
        drv.modify_project(Project.safe_id(project),
                           manager_user,
                           description)
def add_to_project(self, user, project):
"""Add user to project"""
with self.driver() as drv:
@@ -558,7 +532,7 @@ class AuthManager(object):
Project.safe_id(project))
@staticmethod
def get_project_vpn_data(project, context=None):
def get_project_vpn_data(project):
"""Gets vpn ip and port for project
@type project: Project or project_id
@@ -569,7 +543,7 @@ class AuthManager(object):
not been allocated for user.
"""
network_ref = db.project_get_network(context,
network_ref = db.project_get_network(context.get_admin_context(),
Project.safe_id(project))
if not network_ref['vpn_public_port']:
@@ -577,15 +551,8 @@ class AuthManager(object):
return (network_ref['vpn_public_address'],
network_ref['vpn_public_port'])
def delete_project(self, project, context=None):
def delete_project(self, project):
"""Deletes a project"""
try:
network_ref = db.project_get_network(context,
Project.safe_id(project))
db.network_destroy(context, network_ref['id'])
except:
logging.exception('Could not destroy network for %s',
project)
with self.driver() as drv:
drv.delete_project(Project.safe_id(project))
@@ -643,67 +610,20 @@ class AuthManager(object):
return User(**user_dict)
def delete_user(self, user):
"""Deletes a user"""
"""Deletes a user
Additionally deletes all users key_pairs"""
uid = User.safe_id(user)
db.key_pair_destroy_all_by_user(context.get_admin_context(),
uid)
with self.driver() as drv:
drv.delete_user(User.safe_id(user))
drv.delete_user(uid)
def generate_key_pair(self, user, key_name):
"""Generates a key pair for a user
Generates a public and private key, stores the public key using the
key_name, and returns the private key and fingerprint.
@type user: User or uid
@param user: User for which to create key pair.
@type key_name: str
@param key_name: Name to use for the generated KeyPair.
@rtype: tuple (private_key, fingerprint)
@return: A tuple containing the private_key and fingerprint.
"""
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
def modify_user(self, user, access_key=None, secret_key=None, admin=None):
"""Modify credentials for a user"""
uid = User.safe_id(user)
with self.driver() as drv:
if not drv.get_user(uid):
raise exception.NotFound("User %s doesn't exist" % user)
if drv.get_key_pair(uid, key_name):
raise exception.Duplicate("The keypair %s already exists"
% key_name)
private_key, public_key, fingerprint = crypto.generate_key_pair()
self.create_key_pair(uid, key_name, public_key, fingerprint)
return private_key, fingerprint
def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user"""
with self.driver() as drv:
kp_dict = drv.create_key_pair(User.safe_id(user),
key_name,
public_key,
fingerprint)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user"""
with self.driver() as drv:
kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pairs(self, user):
"""Retrieves all key pairs for user"""
with self.driver() as drv:
kp_list = drv.get_key_pairs(User.safe_id(user))
if not kp_list:
return []
return [KeyPair(**kp_dict) for kp_dict in kp_list]
def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user"""
with self.driver() as drv:
drv.delete_key_pair(User.safe_id(user), key_name)
drv.modify_user(uid, access_key, secret_key, admin)
def get_credentials(self, user, project=None):
"""Get credential zip for user in project"""
@@ -722,7 +642,10 @@ class AuthManager(object):
zippy.writestr(FLAGS.credential_key_file, private_key)
zippy.writestr(FLAGS.credential_cert_file, signed_cert)
(vpn_ip, vpn_port) = self.get_project_vpn_data(project)
try:
(vpn_ip, vpn_port) = self.get_project_vpn_data(project)
except exception.NotFound:
vpn_ip = None
if vpn_ip:
configfile = open(FLAGS.vpn_client_template, "r")
s = string.Template(configfile.read())

View File

@@ -1,69 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Role-based access control decorators for wrapping other
methods with."""
from nova import exception
def allow(*roles):
    """Decorator granting access to the wrapped method.

    The call goes through when the requesting user is a superuser or
    holds at least one of *roles* within the request's project;
    otherwise exception.NotAuthorized is raised.
    """
    def wrap(func):  # pylint: disable-msg=C0111
        def wrapped_func(self, context, *args,
                         **kwargs):  # pylint: disable-msg=C0111
            # Superusers bypass the role check entirely; role matching
            # is lazy, so it stops at the first role that applies.
            permitted = context.user.is_superuser() or any(
                __matches_role(context, role) for role in roles)
            if not permitted:
                raise exception.NotAuthorized()
            return func(self, context, *args, **kwargs)
        return wrapped_func
    return wrap
def deny(*roles):
    """Decorator refusing access to the wrapped method.

    Superusers are always let through.  Any other caller holding one
    of *roles* within the request's project is rejected with
    exception.NotAuthorized; everyone else runs the function normally.
    """
    def wrap(func):  # pylint: disable-msg=C0111
        def wrapped_func(self, context, *args,
                         **kwargs):  # pylint: disable-msg=C0111
            if not context.user.is_superuser():
                for banned in roles:
                    if __matches_role(context, banned):
                        raise exception.NotAuthorized()
            return func(self, context, *args, **kwargs)
        return wrapped_func
    return wrap
def __matches_role(context, role):
    """Return whether *role* applies to the request's user and project.

    'all' always matches and 'none' never does; any other role is
    looked up on the project for the requesting user.
    """
    special = {'all': True, 'none': False}
    if role in special:
        return special[role]
    return context.project.has_role(context.user.id, role)

View File

@@ -1,53 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Datastore:
Make sure that Redis is running, and your flags are set properly,
before trying to run this.
"""
import logging
import redis
from nova import flags
FLAGS = flags.FLAGS
# Connection settings for the redis server this module wraps.
flags.DEFINE_string('redis_host', '127.0.0.1',
                    'Host that redis is running on.')
flags.DEFINE_integer('redis_port', 6379,
                     'Port that redis is running on.')
flags.DEFINE_integer('redis_db', 0, 'Multiple DB keeps tests away')
class Redis(object):
    """Process-wide singleton holder for a redis client connection."""

    def __init__(self):
        # Direct construction is forbidden once the singleton has been
        # created; callers should always go through instance().
        if hasattr(self.__class__, '_instance'):
            raise Exception('Attempted to instantiate singleton')

    @classmethod
    def instance(cls):
        """Return the shared redis client, creating it on first use."""
        if not hasattr(cls, '_instance'):
            cls._instance = redis.Redis(host=FLAGS.redis_host,
                                        port=FLAGS.redis_port,
                                        db=FLAGS.redis_db)
        return cls._instance

View File

@@ -1,243 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Admin API controller, exposed through http via the api worker.
"""
import base64
import uuid
import subprocess
import random
from nova import db
from nova import exception
from nova.auth import manager
from utils import novadir
def user_dict(user, base64_file=None):
    """Convert a user object to the API result dict.

    Returns an empty dict for a falsy *user*.  *base64_file* is an
    optional base64-encoded credentials payload to embed.
    """
    if not user:
        return {}
    return {'username': user.id,
            'accesskey': user.access,
            'secretkey': user.secret,
            'file': base64_file}
def project_dict(project):
    """Convert a project object to the API result dict.

    Returns an empty dict for a falsy *project*.
    """
    if not project:
        return {}
    return {'projectname': project.id,
            'project_manager_id': project.project_manager_id,
            'description': project.description}
def host_dict(host):
    """Convert a host model object to a result dict.

    Currently just exposes the raw ``state`` field (see FIXME).
    """
    if not host:
        return {}
    # FIXME(vish)
    return host.state
def admin_only(target):
    """Decorator restricting an API call to admin users.

    The wrapped method runs only when ``context.user.is_admin()`` is
    true; other callers silently receive an empty dict so responses
    still serialize cleanly.
    """
    # Local import keeps this module's top-level imports untouched.
    import functools

    @functools.wraps(target)
    def wrapper(*args, **kwargs):
        """Internal wrapper method for admin-only API calls"""
        # Convention: (self, context, ...) - context is the second arg.
        context = args[1]
        if context.user.is_admin():
            return target(*args, **kwargs)
        # Fail closed, but quietly, for non-admins.
        return {}
    return wrapper
class AdminController(object):
    """API Controller for users, hosts, nodes, and workers.

    Trivial admin_only wrapper will be replaced with RBAC,
    allowing project managers to administer project users.
    """

    def __str__(self):
        return 'AdminController'

    @admin_only
    def describe_user(self, _context, name, **_kwargs):
        """Returns user data, including access and secret keys."""
        return user_dict(manager.AuthManager().get_user(name))

    @admin_only
    def describe_users(self, _context, **_kwargs):
        """Returns all users - should be changed to deal with a list."""
        return {'userSet':
                [user_dict(u) for u in manager.AuthManager().get_users()]}

    @admin_only
    def register_user(self, _context, name, **_kwargs):
        """Creates a new user, and returns generated credentials."""
        return user_dict(manager.AuthManager().create_user(name))

    @admin_only
    def deregister_user(self, _context, name, **_kwargs):
        """Deletes a single user (NOT undoable.)

        Should throw an exception if the user has instances,
        volumes, or buckets remaining.
        """
        manager.AuthManager().delete_user(name)
        return True

    @admin_only
    def describe_roles(self, context, project_roles=True, **kwargs):
        """Returns a list of allowed roles."""
        roles = manager.AuthManager().get_roles(project_roles)
        return {'roles': [{'role': r} for r in roles]}

    @admin_only
    def describe_user_roles(self, context, user, project=None, **kwargs):
        """Returns a list of roles for the given user.

        Omitting project will return any global roles that the user has.
        Specifying project will return only project specific roles.
        """
        roles = manager.AuthManager().get_user_roles(user, project=project)
        return {'roles': [{'role': r} for r in roles]}

    @admin_only
    def modify_user_role(self, context, user, role, project=None,
                         operation='add', **kwargs):
        """Add or remove a role for a user and project."""
        if operation == 'add':
            manager.AuthManager().add_role(user, role, project)
        elif operation == 'remove':
            manager.AuthManager().remove_role(user, role, project)
        else:
            raise exception.ApiError('operation must be add or remove')
        return True

    @admin_only
    def generate_x509_for_user(self, _context, name, project=None, **kwargs):
        """Generates and returns an x509 certificate for a single user.

        Is usually called from a client that will wrap this with
        access and secret key info, and return a zip file.
        """
        if project is None:
            # Default to the user's self-named project.
            project = name
        project = manager.AuthManager().get_project(project)
        user = manager.AuthManager().get_user(name)
        return user_dict(user, base64.b64encode(project.get_credentials(user)))

    @admin_only
    def describe_project(self, context, name, **kwargs):
        """Returns project data, including member ids."""
        return project_dict(manager.AuthManager().get_project(name))

    @admin_only
    def describe_projects(self, context, user=None, **kwargs):
        """Returns all projects - should be changed to deal with a list."""
        return {'projectSet':
                [project_dict(u) for u in
                manager.AuthManager().get_projects(user=user)]}

    @admin_only
    def register_project(self, context, name, manager_user, description=None,
                         member_users=None, **kwargs):
        """Creates a new project"""
        # NOTE: this previously passed literal Nones, silently dropping
        # the caller-supplied description and member list.
        return project_dict(
            manager.AuthManager().create_project(
                name,
                manager_user,
                description=description,
                member_users=member_users))

    @admin_only
    def deregister_project(self, context, name):
        """Permanently deletes a project."""
        manager.AuthManager().delete_project(name)
        return True

    @admin_only
    def describe_project_members(self, context, name, **kwargs):
        """Returns the member ids of the named project."""
        project = manager.AuthManager().get_project(name)
        result = {
            'members': [{'member': m} for m in project.member_ids]}
        return result

    @admin_only
    def modify_project_member(self, context, user, project, operation,
                              **kwargs):
        """Add or remove a user from a project."""
        if operation == 'add':
            manager.AuthManager().add_to_project(user, project)
        elif operation == 'remove':
            manager.AuthManager().remove_from_project(user, project)
        else:
            raise exception.ApiError('operation must be add or remove')
        return True

    @admin_only
    def describe_hosts(self, _context, **_kwargs):
        """Returns status info for all nodes. Includes:
            * Disk Space
            * Instance List
            * RAM used
            * CPU used
            * DHCP servers running
            * Iptables / bridges
        """
        return {'hostSet': [host_dict(h) for h in db.host_get_all()]}

    @admin_only
    def describe_host(self, _context, name, **_kwargs):
        """Returns status info for single node."""
        return host_dict(db.host_get(name))

    @admin_only
    def create_console(self, _context, kind, instance_id, **_kwargs):
        """Create a Console"""
        #instance = db.instance_get(_context, instance_id)
        def get_port():
            """Probe random local ports with netcat until one is free."""
            for i in xrange(0, 100):  # don't loop forever
                port = int(random.uniform(10000, 12000))
                cmd = "netcat 0.0.0.0 " + str(port) + " -w 2 < /dev/null"
                # this Popen will exit with 0 only if the port is in use,
                # so a nonzero return value implies it is unused
                port_is_unused = subprocess.Popen(cmd, shell=True).wait()
                if port_is_unused:
                    return port
            # NOTE: raising a bare string is illegal; use a real exception.
            raise exception.Error('Unable to find an open port')
        port = str(get_port())
        token = str(uuid.uuid4())
        host = '127.0.0.1'  # TODO add actual host
        cmd = novadir() + "tools/ajaxterm//ajaxterm.py --command 'ssh root@" \
            + host + "' -t " + token + " -p " + port
        port_is_unused = subprocess.Popen(cmd, shell=True)  # TODO error check
        return {'url': 'http://tonbuntu:' + port + '/?token=' + token}  # TODO - s/tonbuntu/api_server_public_ip

View File

@@ -1,344 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tornado REST API Request Handlers for Nova functions
Most calls are proxied into the responsible controller.
"""
import logging
import multiprocessing
import random
import re
import urllib
# TODO(termie): replace minidom with etree
from xml.dom import minidom
import tornado.web
from twisted.internet import defer
from nova import crypto
from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
import nova.cloudpipe.api
from nova.endpoint import cloud
FLAGS = flags.FLAGS
flags.DEFINE_integer('cc_port', 8773, 'cloud controller port')

_log = logging.getLogger("api")
_log.setLevel(logging.DEBUG)

# Matches each internal capital-letter boundary of a CamelCase word so
# the name can be split into underscore-separated pieces.
_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
def _camelcase_to_underscore(str):
    """Convert a CamelCase API name to underscore_separated form."""
    return _c2u.sub(r'_\1', str).lower().strip('_')
def _underscore_to_camelcase(str):
    """Convert an underscore_separated name back to CamelCase."""
    return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
def _underscore_to_xmlcase(str):
    """Convert underscore_separated to xmlCase (lower first letter)."""
    res = _underscore_to_camelcase(str)
    return res[:1].lower() + res[1:]
class APIRequestContext(object):
    """Per-request state: caller identity plus a random request id."""

    def __init__(self, handler, user, project):
        self.handler = handler
        self.user = user
        self.project = project
        # 20 random characters from the EC2-style request-id alphabet.
        alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-'
        picks = [random.choice(alphabet) for _ in range(20)]
        self.request_id = ''.join(picks)
class APIRequest(object):
    """Maps one EC2-style API action onto a controller method call and
    renders the result as an EC2 XML response."""

    def __init__(self, controller, action):
        self.controller = controller
        self.action = action

    def send(self, context, **kwargs):
        """Dispatch the request; returns a Deferred of the XML reply.

        Query args arrive CamelCased, possibly dotted (``Key.1``) for
        EC2 list syntax; they are converted to underscore names and
        numbered sub-dicts become index-ordered lists.  Raises a plain
        Exception when the controller has no handler for the action.
        """
        try:
            method = getattr(self.controller,
                             _camelcase_to_underscore(self.action))
        except AttributeError:
            _error = ('Unsupported API request: controller = %s,'
                      'action = %s') % (self.controller, self.action)
            _log.warning(_error)
            # TODO: Raise custom exception, trap in apiserver,
            #       and reraise as 400 error.
            raise Exception(_error)
        args = {}
        for key, value in kwargs.items():
            parts = key.split(".")
            key = _camelcase_to_underscore(parts[0])
            if len(parts) > 1:
                # Dotted keys ("Foo.1") collect into a sub-dict.
                d = args.get(key, {})
                d[parts[1]] = value[0]
                value = d
            else:
                value = value[0]
            args[key] = value
        for key in args.keys():
            # Numbered parameters become a list ordered by their index.
            # NOTE: relies on Python 2 list-returning keys()/items().
            if isinstance(args[key], dict):
                if args[key] != {} and args[key].keys()[0].isdigit():
                    s = args[key].items()
                    s.sort()
                    args[key] = [v for k, v in s]
        d = defer.maybeDeferred(method, context, **args)
        d.addCallback(self._render_response, context.request_id)
        return d

    def _render_response(self, response_data, request_id):
        """Serialize a controller result into an EC2 XML response."""
        xml = minidom.Document()
        response_el = xml.createElement(self.action + 'Response')
        response_el.setAttribute('xmlns',
                                 'http://ec2.amazonaws.com/doc/2009-11-30/')
        request_id_el = xml.createElement('requestId')
        request_id_el.appendChild(xml.createTextNode(request_id))
        response_el.appendChild(request_id_el)
        # A bare True result is rendered as <return>true</return>.
        if(response_data == True):
            self._render_dict(xml, response_el, {'return': 'true'})
        else:
            self._render_dict(xml, response_el, response_data)
        xml.appendChild(response_el)
        response = xml.toxml()
        xml.unlink()
        _log.debug(response)
        return response

    def _render_dict(self, xml, el, data):
        """Append one child element per key of *data* to *el*."""
        try:
            for key in data.keys():
                val = data[key]
                el.appendChild(self._render_data(xml, key, val))
        except:
            # Log the offending payload before re-raising.
            _log.debug(data)
            raise

    def _render_data(self, xml, el_name, data):
        """Recursively render *data* as an element named *el_name*."""
        el_name = _underscore_to_xmlcase(el_name)
        data_el = xml.createElement(el_name)
        if isinstance(data, list):
            for item in data:
                data_el.appendChild(self._render_data(xml, 'item', item))
        elif isinstance(data, dict):
            self._render_dict(xml, data_el, data)
        elif hasattr(data, '__dict__'):
            # Arbitrary objects are rendered from their attribute dict.
            self._render_dict(xml, data_el, data.__dict__)
        elif isinstance(data, bool):
            data_el.appendChild(xml.createTextNode(str(data).lower()))
        elif data != None:
            data_el.appendChild(xml.createTextNode(str(data)))
        return data_el
class RootRequestHandler(tornado.web.RequestHandler):
    """Serves the list of supported API versions at the root path."""

    def get(self):
        # available api versions
        supported = ('1.0',
                     '2007-01-19',
                     '2007-03-01',
                     '2007-08-29',
                     '2007-10-10',
                     '2007-12-15',
                     '2008-02-01',
                     '2008-09-01',
                     '2009-04-04')
        self.write(''.join('%s\n' % version for version in supported))
        self.finish()
class MetadataRequestHandler(tornado.web.RequestHandler):
    """Serves EC2-style instance metadata keyed by the caller's IP."""

    def print_data(self, data):
        """Write *data* in the flat text format metadata clients expect."""
        if isinstance(data, dict):
            lines = []
            for key in data:
                if key == '_name':
                    continue
                entry = key
                if isinstance(data[key], dict):
                    # Named sub-dicts render as "key=name"; anonymous
                    # ones as a traversable "key/" directory entry.
                    if '_name' in data[key]:
                        entry += '=' + str(data[key]['_name'])
                    else:
                        entry += '/'
                lines.append(entry)
            self.write('\n'.join(lines))
        elif isinstance(data, list):
            self.write('\n'.join(data))
        else:
            self.write(str(data))

    def lookup(self, path, data):
        """Walk the slash-separated *path* down into nested dict *data*.

        Returns None when a component is missing, and returns the
        current value early when it is not a dict.
        """
        node = data
        for part in path.split('/'):
            if not part:
                continue
            if not isinstance(node, dict):
                return node
            if part not in node:
                return None
            node = node[part]
        return node

    def get(self, path):
        cc = self.application.controllers['Cloud']
        meta_data = cc.get_metadata(self.request.remote_ip)
        if meta_data is None:
            _log.error('Failed to get metadata for ip: %s' %
                       self.request.remote_ip)
            raise tornado.web.HTTPError(404)
        found = self.lookup(path, meta_data)
        if found is None:
            raise tornado.web.HTTPError(404)
        self.print_data(found)
        self.finish()
class APIRequestHandler(tornado.web.RequestHandler):
    """Entry point for signed EC2-style API calls: authenticates the
    request and dispatches it to the named controller."""

    def get(self, controller_name):
        self.execute(controller_name)

    @tornado.web.asynchronous
    def execute(self, controller_name):
        """Authenticate and run one API request asynchronously.

        Responds 400 for malformed requests and 403 when signature
        authentication fails.
        """
        # Obtain the appropriate controller for this request.
        try:
            controller = self.application.controllers[controller_name]
        except KeyError:
            self._error('unhandled', 'no controller named %s' % controller_name)
            return
        args = self.request.arguments
        # Read request signature.
        try:
            signature = args.pop('Signature')[0]
        except:
            raise tornado.web.HTTPError(400)
        # Make a copy of args for authentication and signature verification.
        auth_params = {}
        for key, value in args.items():
            auth_params[key] = value[0]
        # Get requested action and remove authentication args for final request.
        try:
            action = args.pop('Action')[0]
            access = args.pop('AWSAccessKeyId')[0]
            args.pop('SignatureMethod')
            args.pop('SignatureVersion')
            args.pop('Version')
            args.pop('Timestamp')
        except:
            raise tornado.web.HTTPError(400)
        # Authenticate the request.
        try:
            (user, project) = manager.AuthManager().authenticate(
                access,
                signature,
                auth_params,
                self.request.method,
                self.request.host,
                self.request.path
            )
        except exception.Error, ex:
            logging.debug("Authentication Failure: %s" % ex)
            raise tornado.web.HTTPError(403)
        _log.debug('action: %s' % action)
        for key, value in args.items():
            _log.debug('arg: %s\t\tval: %s' % (key, value))
        request = APIRequest(controller, action)
        context = APIRequestContext(self, user, project)
        d = request.send(context, **args)
        # d.addCallback(utils.debug)
        # TODO: Wrap response in AWS XML format
        d.addCallbacks(self._write_callback, self._error_callback)

    def _write_callback(self, data):
        """Success path: send the rendered XML back to the client."""
        self.set_header('Content-Type', 'text/xml')
        self.write(data)
        self.finish()

    def _error_callback(self, failure):
        """Failure path: translate exceptions into EC2 error XML."""
        try:
            failure.raiseException()
        except exception.ApiError as ex:
            self._error(type(ex).__name__ + "." + ex.code, ex.message)
        # TODO(vish): do something more useful with unknown exceptions
        except Exception as ex:
            self._error(type(ex).__name__, str(ex))
            raise

    def post(self, controller_name):
        self.execute(controller_name)

    def _error(self, code, message):
        """Write a 400 response with an EC2-style error body."""
        self._status_code = 400
        self.set_header('Content-Type', 'text/xml')
        self.write('<?xml version="1.0"?>\n')
        self.write('<Response><Errors><Error><Code>%s</Code>'
                   '<Message>%s</Message></Error></Errors>'
                   '<RequestID>?</RequestID></Response>' % (code, message))
        self.finish()
class APIServerApplication(tornado.web.Application):
    """Tornado application wiring URL routes to the API handlers.

    Route order matters: earlier patterns win.  Each dated prefix maps
    one metadata API version to the same MetadataRequestHandler.  The
    process pool is exposed via settings for CPU-heavy work.
    """

    def __init__(self, controllers):
        tornado.web.Application.__init__(self, [
            (r'/', RootRequestHandler),
            (r'/cloudpipe/(.*)', nova.cloudpipe.api.CloudPipeRequestHandler),
            (r'/cloudpipe', nova.cloudpipe.api.CloudPipeRequestHandler),
            (r'/services/([A-Za-z0-9]+)/', APIRequestHandler),
            (r'/latest/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2009-04-04/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2008-09-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2008-02-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2007-12-15/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2007-10-10/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2007-08-29/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2007-03-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/2007-01-19/([-A-Za-z0-9/]*)', MetadataRequestHandler),
            (r'/1.0/([-A-Za-z0-9/]*)', MetadataRequestHandler),
        ], pool=multiprocessing.Pool(4))
        # Controllers are looked up by name from the request handlers.
        self.controllers = controllers

View File

@@ -1,695 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Cloud Controller: Implementation of EC2 REST API calls, which are
dispatched to other nodes via AMQP RPC. State is via distributed
datastore.
"""
import base64
import logging
import os
import time
from twisted.internet import defer
from nova import db
from nova import exception
from nova import flags
from nova import rpc
from nova import utils
from nova.auth import rbac
from nova.auth import manager
from nova.compute.instance_types import INSTANCE_TYPES
from nova.endpoint import images
FLAGS = flags.FLAGS
def _gen_key(user_id, key_name):
    """Generate a key pair for *user_id* (tuck this into AuthManager).

    Returns a dict with the private key and fingerprint on success, or
    a dict carrying the exception on failure.
    """
    try:
        private_key, fingerprint = manager.AuthManager().generate_key_pair(
            user_id, key_name)
    except Exception as ex:
        # This runs in a worker pool, so exceptions are smuggled back
        # in the result dict rather than raised across the boundary.
        return {'exception': ex}
    return {'private_key': private_key, 'fingerprint': fingerprint}
class CloudController(object):
""" CloudController provides the critical dispatch between
inbound API calls through the endpoint and messages
sent to the other nodes.
"""
    def __init__(self):
        # The network manager implementation is configurable via flags.
        self.network_manager = utils.import_object(FLAGS.network_manager)
        self.setup()
    def __str__(self):
        """Name used when registering this controller with the API server."""
        return 'CloudController'
def setup(self):
""" Ensure the keychains and folders exist. """
# FIXME(ja): this should be moved to a nova-manage command,
# if not setup throw exceptions instead of running
# Create keys folder, if it doesn't exist
if not os.path.exists(FLAGS.keys_path):
os.makedirs(FLAGS.keys_path)
# Gen root CA, if we don't have one
root_ca_path = os.path.join(FLAGS.ca_path, FLAGS.ca_file)
if not os.path.exists(root_ca_path):
start = os.getcwd()
os.chdir(FLAGS.ca_path)
# TODO: Do this with M2Crypto instead
utils.runthis("Generating root CA: %s", "sh genrootca.sh")
os.chdir(start)
def _get_mpi_data(self, project_id):
result = {}
for instance in db.instance_get_by_project(project_id):
line = '%s slots=%d' % (instance.fixed_ip['str_id'],
INSTANCE_TYPES[instance['instance_type']]['vcpus'])
if instance['key_name'] in result:
result[instance['key_name']].append(line)
else:
result[instance['key_name']] = [line]
return result
    def get_metadata(self, ipaddress):
        """Build the EC2-compatible metadata dict for the instance that
        owns *ipaddress*, or None when no instance matches."""
        i = db.fixed_ip_get_instance(ipaddress)
        if i is None:
            return None
        mpi = self._get_mpi_data(i['project_id'])
        if i['key_name']:
            # public-keys is an index-keyed dict of named keys.
            keys = {
                '0': {
                    '_name': i['key_name'],
                    'openssh-key': i['key_data']
                }
            }
        else:
            keys = ''
        hostname = i['hostname']
        data = {
            'user-data': base64.b64decode(i['user_data']),
            'meta-data': {
                'ami-id': i['image_id'],
                'ami-launch-index': i['ami_launch_index'],
                'ami-manifest-path': 'FIXME',  # image property
                'block-device-mapping': {  # TODO: replace with real data
                    'ami': 'sda1',
                    'ephemeral0': 'sda2',
                    'root': '/dev/sda1',
                    'swap': 'sda3'
                },
                'hostname': hostname,
                'instance-action': 'none',
                'instance-id': i['instance_id'],
                'instance-type': i.get('instance_type', ''),
                'local-hostname': hostname,
                'local-ipv4': i['private_dns_name'],  # TODO: switch to IP
                'kernel-id': i.get('kernel_id', ''),
                'placement': {
                    # NOTE(review): key is misspelled; kept as-is since
                    # it is part of the served wire format.
                    'availaibility-zone': i.get('availability_zone', 'nova'),
                },
                'public-hostname': hostname,
                'public-ipv4': i.get('dns_name', ''),  # TODO: switch to IP
                'public-keys': keys,
                'ramdisk-id': i.get('ramdisk_id', ''),
                'reservation-id': i['reservation_id'],
                'security-groups': i.get('groups', ''),
                'mpi': mpi
            }
        }
        if False:  # TODO: store ancestor ids
            data['ancestor-ami-ids'] = []
        if i.get('product_codes', None):
            data['product-codes'] = i['product_codes']
        return data
    @rbac.allow('all')
    def describe_availability_zones(self, context, **kwargs):
        """Return the single static 'nova' availability zone."""
        return {'availabilityZoneInfo': [{'zoneName': 'nova',
                                          'zoneState': 'available'}]}
    @rbac.allow('all')
    def describe_regions(self, context, region_name=None, **kwargs):
        """Return the single static 'nova' region."""
        # TODO(vish): region_name is an array.  Support filtering
        return {'regionInfo': [{'regionName': 'nova',
                                'regionUrl': FLAGS.ec2_url}]}
    @rbac.allow('all')
    def describe_snapshots(self,
                           context,
                           snapshot_id=None,
                           owner=None,
                           restorable_by=None,
                           **kwargs):
        """Snapshots are not implemented yet; returns placeholder data."""
        return {'snapshotSet': [{'snapshotId': 'fixme',
                                 'volumeId': 'fixme',
                                 'status': 'fixme',
                                 'startTime': 'fixme',
                                 'progress': 'fixme',
                                 'ownerId': 'fixme',
                                 'volumeSize': 0,
                                 'description': 'fixme'}]}
@rbac.allow('all')
def describe_key_pairs(self, context, key_name=None, **kwargs):
key_pairs = context.user.get_key_pairs()
if not key_name is None:
key_pairs = [x for x in key_pairs if x.name in key_name]
result = []
for key_pair in key_pairs:
# filter out the vpn keys
suffix = FLAGS.vpn_key_suffix
if context.user.is_admin() or not key_pair.name.endswith(suffix):
result.append({
'keyName': key_pair.name,
'keyFingerprint': key_pair.fingerprint,
})
return {'keypairsSet': result}
    @rbac.allow('all')
    def create_key_pair(self, context, key_name, **kwargs):
        """Generate a key pair for the caller; returns a Deferred.

        Key generation is CPU-heavy, so it runs in the shared process
        pool; the Deferred fires when the pool callback delivers either
        the result or the smuggled exception (see _gen_key).
        """
        dcall = defer.Deferred()
        pool = context.handler.application.settings.get('pool')
        def _complete(kwargs):
            if 'exception' in kwargs:
                dcall.errback(kwargs['exception'])
                return
            dcall.callback({'keyName': key_name,
                            'keyFingerprint': kwargs['fingerprint'],
                            'keyMaterial': kwargs['private_key']})
        pool.apply_async(_gen_key, [context.user.id, key_name],
                         callback=_complete)
        return dcall
    @rbac.allow('all')
    def delete_key_pair(self, context, key_name, **kwargs):
        """Delete the caller's named key pair."""
        context.user.delete_key_pair(key_name)
        # aws returns true even if the key doesn't exist
        return True
    @rbac.allow('all')
    def describe_security_groups(self, context, group_names, **kwargs):
        """Security groups are not implemented; returns an empty set."""
        groups = {'securityGroupSet': []}

        # Stubbed for now to unblock other things.
        return groups
    @rbac.allow('netadmin')
    def create_security_group(self, context, group_name, **kwargs):
        """Stub: pretends to create a security group."""
        return True
    @rbac.allow('netadmin')
    def delete_security_group(self, context, group_name, **kwargs):
        """Stub: pretends to delete a security group."""
        return True
    @rbac.allow('projectmanager', 'sysadmin')
    def get_console_output(self, context, instance_id, **kwargs):
        """Fetch console output via RPC to the instance's compute host."""
        # instance_id is passed in as a list of instances
        instance_ref = db.instance_get_by_str(context, instance_id[0])
        return rpc.call('%s.%s' % (FLAGS.compute_topic,
                                   instance_ref['host']),
                        {"method": "get_console_output",
                         "args": {"context": None,
                                  "instance_id": instance_ref['id']}})
@rbac.allow('projectmanager', 'sysadmin')
def describe_volumes(self, context, **kwargs):
if context.user.is_admin():
volumes = db.volume_get_all(context)
else:
volumes = db.volume_get_by_project(context, context.project.id)
volumes = [self._format_volume(context, v) for v in volumes]
return {'volumeSet': volumes}
    def _format_volume(self, context, volume):
        """Convert a volume model into the EC2 DescribeVolumes dict."""
        v = {}
        v['volumeId'] = volume['str_id']
        v['status'] = volume['status']
        v['size'] = volume['size']
        v['availabilityZone'] = volume['availability_zone']
        # v['createTime'] = volume['create_time']
        if context.user.is_admin():
            # Admins get extra provenance folded into the status string.
            v['status'] = '%s (%s, %s, %s, %s)' % (
                volume['status'],
                volume['user_id'],
                'host',
                volume['instance_id'],
                volume['mountpoint'])
        if volume['attach_status'] == 'attached':
            v['attachmentSet'] = [{'attachTime': volume['attach_time'],
                                   'deleteOnTermination': volume['delete_on_termination'],
                                   'device': volume['mountpoint'],
                                   'instanceId': volume['instance_id'],
                                   'status': 'attached',
                                   'volume_id': volume['volume_id']}]
        else:
            v['attachmentSet'] = [{}]
        return v
    @rbac.allow('projectmanager', 'sysadmin')
    def create_volume(self, context, size, **kwargs):
        """Create a volume record and cast the actual work to the volume
        service; returns the new volume still in 'creating' state."""
        vol = {}
        vol['size'] = size
        vol['user_id'] = context.user.id
        vol['project_id'] = context.project.id
        vol['availability_zone'] = FLAGS.storage_availability_zone
        vol['status'] = "creating"
        vol['attach_status'] = "detached"
        volume_ref = db.volume_create(context, vol)
        rpc.cast(FLAGS.volume_topic, {"method": "create_volume",
                                      "args": {"context": None,
                                               "volume_id": volume_ref['id']}})
        return {'volumeSet': [self._format_volume(context, volume_ref)]}
    @rbac.allow('projectmanager', 'sysadmin')
    def attach_volume(self, context, volume_id, instance_id, device, **kwargs):
        """Cast an attach request to the compute host owning the instance.

        Raises ApiError when the volume is already attached.
        """
        volume_ref = db.volume_get_by_str(context, volume_id)
        # TODO(vish): abstract status checking?
        if volume_ref['status'] == "attached":
            raise exception.ApiError("Volume is already attached")
        #volume.start_attach(instance_id, device)
        instance_ref = db.instance_get_by_str(context, instance_id)
        host = db.instance_get_host(context, instance_ref['id'])
        rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
                 {"method": "attach_volume",
                  "args": {"context": None,
                           "volume_id": volume_ref['id'],
                           "instance_id": instance_ref['id'],
                           "mountpoint": device}})
        return defer.succeed({'attachTime': volume_ref['attach_time'],
                              'device': volume_ref['mountpoint'],
                              'instanceId': instance_ref['id_str'],
                              'requestId': context.request_id,
                              'status': volume_ref['attach_status'],
                              'volumeId': volume_ref['id']})
    @rbac.allow('projectmanager', 'sysadmin')
    def detach_volume(self, context, volume_id, **kwargs):
        """Cast a detach request to the compute host holding the volume.

        Raises Error when the volume is not attached or already
        detached.
        """
        volume_ref = db.volume_get_by_str(context, volume_id)
        instance_ref = db.volume_get_instance(context, volume_ref['id'])
        if not instance_ref:
            raise exception.Error("Volume isn't attached to anything!")
        # TODO(vish): abstract status checking?
        if volume_ref['status'] == "available":
            raise exception.Error("Volume is already detached")
        try:
            #volume.start_detach()
            host = db.instance_get_host(context, instance_ref['id'])
            rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
                     {"method": "detach_volume",
                      "args": {"context": None,
                               "instance_id": instance_ref['id'],
                               "volume_id": volume_ref['id']}})
        except exception.NotFound:
            # If the instance doesn't exist anymore,
            # then we need to call detach blind
            db.volume_detached(context)
        return defer.succeed({'attachTime': volume_ref['attach_time'],
                              'device': volume_ref['mountpoint'],
                              'instanceId': instance_ref['id_str'],
                              'requestId': context.request_id,
                              'status': volume_ref['attach_status'],
                              'volumeId': volume_ref['id']})
def _convert_to_set(self, lst, label):
if lst == None or lst == []:
return None
if not isinstance(lst, list):
lst = [lst]
return [{label: x} for x in lst]
@rbac.allow('all')
def describe_instances(self, context, **kwargs):
    """Return the caller's reservations as an already-fired deferred."""
    formatted = self._format_describe_instances(context)
    return defer.succeed(formatted)
def _format_describe_instances(self, context):
    """Wrap the formatted reservation list in a reservationSet dict."""
    reservations = self._format_instances(context)
    return {'reservationSet': reservations}
def _format_run_instances(self, context, reservation_id):
    """Return the single reservation that run_instances just created."""
    formatted = self._format_instances(context, reservation_id)
    assert len(formatted) == 1
    return formatted[0]
def _format_instances(self, context, reservation_id=None):
    """Build EC2-style reservation dicts for the caller's instances.

    Admins see every instance (with project/host annotated into keyName);
    regular users see only their project's instances, minus the private
    vpn instances.
    """
    reservations = {}
    if reservation_id:
        instances = db.instance_get_by_reservation(context,
                                                   reservation_id)
    else:
        # BUG FIX: the admin test was inverted -- admins are the ones
        # who should see all instances, everyone else is limited to
        # their own project (consistent with the per-instance vpn
        # filtering below, which only applies to non-admins).
        if context.user.is_admin():
            instances = db.instance_get_all(context)
        else:
            instances = db.instance_get_by_project(context,
                                                   context.project.id)
    for instance in instances:
        if not context.user.is_admin():
            if instance['image_id'] == FLAGS.vpn_image_id:
                continue
        i = {}
        i['instanceId'] = instance['str_id']
        i['imageId'] = instance['image_id']
        i['instanceState'] = {
            'code': instance['state'],
            'name': instance['state_description']
        }
        floating_addr = db.instance_get_floating_address(context,
                                                         instance['id'])
        i['publicDnsName'] = floating_addr
        fixed_addr = db.instance_get_fixed_address(context,
                                                   instance['id'])
        i['privateDnsName'] = fixed_addr
        # Fall back to the private address when no floating ip is set.
        if not i['publicDnsName']:
            i['publicDnsName'] = i['privateDnsName']
        i['dnsName'] = None
        i['keyName'] = instance['key_name']
        if context.user.is_admin():
            i['keyName'] = '%s (%s, %s)' % (i['keyName'],
                                            instance['project_id'],
                                            instance['host'])
        i['productCodesSet'] = self._convert_to_set([], 'product_codes')
        i['instanceType'] = instance['instance_type']
        i['launchTime'] = instance['created_at']
        i['amiLaunchIndex'] = instance['launch_index']
        # Group instances by reservation id (has_key replaced with 'in').
        if instance['reservation_id'] not in reservations:
            r = {}
            r['reservationId'] = instance['reservation_id']
            r['ownerId'] = instance['project_id']
            r['groupSet'] = self._convert_to_set([], 'groups')
            r['instancesSet'] = []
            reservations[instance['reservation_id']] = r
        reservations[instance['reservation_id']]['instancesSet'].append(i)
    return list(reservations.values())
@rbac.allow('all')
def describe_addresses(self, context, **kwargs):
    """Describe the floating ip addresses visible to the caller."""
    return self.format_addresses(context)
def format_addresses(self, context):
    """Build the EC2 addressesSet response for describe_addresses.

    Admins see every floating ip (instance id annotated with project);
    regular users see only their project's floating ips.
    """
    addresses = []
    if context.user.is_admin():
        iterator = db.floating_ip_get_all(context)
    else:
        iterator = db.floating_ip_get_by_project(context,
                                                 context.project.id)
    for floating_ip_ref in iterator:
        address = floating_ip_ref['id_str']
        # BUG FIX: every other db accessor in this file takes context as
        # its first argument; floating_ip_get_instance was being called
        # without it.
        instance_ref = db.floating_ip_get_instance(context, address)
        address_rv = {
            'public_ip': address,
            'instance_id': instance_ref['id_str']
        }
        if context.user.is_admin():
            address_rv['instance_id'] = "%s (%s)" % (
                address_rv['instance_id'],
                floating_ip_ref['project_id'],
            )
        addresses.append(address_rv)
    return {'addressesSet': addresses}
@rbac.allow('netadmin')
@defer.inlineCallbacks
def allocate_address(self, context, **kwargs):
    """Allocate a floating ip for the caller's project."""
    network_topic = yield self._get_network_topic(context)
    allocate_msg = {"method": "allocate_floating_ip",
                    "args": {"context": None,
                             "project_id": context.project.id}}
    public_ip = yield rpc.call(network_topic, allocate_msg)
    defer.returnValue({'addressSet': [{'publicIp': public_ip}]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def release_address(self, context, public_ip, **kwargs):
    """Deallocate a floating ip via the network worker."""
    # NOTE(vish): Should we make sure this works?
    floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
    network_topic = yield self._get_network_topic(context)
    release_msg = {"method": "deallocate_floating_ip",
                   "args": {"context": None,
                            "floating_ip": floating_ip_ref['str_id']}}
    rpc.cast(network_topic, release_msg)
    defer.returnValue({'releaseResponse': ["Address released."]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def associate_address(self, context, instance_id, public_ip, **kwargs):
    """Associate a floating ip with an instance's fixed ip."""
    instance_ref = db.instance_get_by_str(context, instance_id)
    fixed_ip_ref = db.fixed_ip_get_by_instance(context, instance_ref['id'])
    floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
    network_topic = yield self._get_network_topic(context)
    associate_msg = {"method": "associate_floating_ip",
                     "args": {"context": None,
                              "floating_ip": floating_ip_ref['str_id'],
                              "fixed_ip": fixed_ip_ref['str_id'],
                              "instance_id": instance_ref['id']}}
    rpc.cast(network_topic, associate_msg)
    defer.returnValue({'associateResponse': ["Address associated."]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def disassociate_address(self, context, public_ip, **kwargs):
    """Detach a floating ip from whatever fixed ip it points at."""
    floating_ip_ref = db.floating_ip_get_by_address(context, public_ip)
    network_topic = yield self._get_network_topic(context)
    disassociate_msg = {"method": "disassociate_floating_ip",
                        "args": {"context": None,
                                 "floating_ip": floating_ip_ref['str_id']}}
    rpc.cast(network_topic, disassociate_msg)
    defer.returnValue({'disassociateResponse': ["Address disassociated."]})
@defer.inlineCallbacks
def _get_network_topic(self, context):
    """Retrieves the network host for a project"""
    network_ref = db.project_get_network(context, context.project.id)
    host = db.network_get_host(context, network_ref['id'])
    if not host:
        # No host assigned yet; ask the network service to pick one.
        set_host_msg = {"method": "set_network_host",
                        "args": {"context": None,
                                 "project_id": context.project.id}}
        host = yield rpc.call(FLAGS.network_topic, set_host_msg)
    defer.returnValue(db.queue_get_for(context, FLAGS.network_topic,
                                       host))
@rbac.allow('projectmanager', 'sysadmin')
@defer.inlineCallbacks
def run_instances(self, context, **kwargs):
    """Launch max_count instances of the requested image.

    Validates image/kernel/ramdisk access, creates the instance records,
    allocates fixed ips, sets up networking, and casts the actual run to
    the scheduler.
    """
    # make sure user can access the image
    # vpn image is private so it doesn't show up on lists
    vpn = kwargs['image_id'] == FLAGS.vpn_image_id
    if not vpn:
        image = images.get(context, kwargs['image_id'])
        # get defaults from imagestore
        image_id = image['imageId']
        kernel_id = image.get('kernelId', FLAGS.default_kernel)
        ramdisk_id = image.get('ramdiskId', FLAGS.default_ramdisk)
    else:
        # BUG FIX: the vpn image is private and never fetched above, so
        # the three names below were unbound for vpn launches (the old
        # FIXME noted exactly this breakage).  Fall back to the request
        # id and the flag-defined kernel/ramdisk defaults.
        image_id = kwargs['image_id']
        kernel_id = FLAGS.default_kernel
        ramdisk_id = FLAGS.default_ramdisk

    # API parameters overrides of defaults
    kernel_id = kwargs.get('kernel_id', kernel_id)
    ramdisk_id = kwargs.get('ramdisk_id', ramdisk_id)
    # make sure we have access to kernel and ramdisk
    images.get(context, kernel_id)
    images.get(context, ramdisk_id)

    logging.debug("Going to run instances...")
    # NOTE(review): launch_time is computed but never stored -- it looks
    # like it was meant to go into the instance record; confirm.
    launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
    key_data = None
    if 'key_name' in kwargs:
        key_pair = context.user.get_key_pair(kwargs['key_name'])
        if not key_pair:
            raise exception.ApiError('Key Pair %s not found' %
                                     kwargs['key_name'])
        key_data = key_pair.public_key

    # TODO: Get the real security group of launch in here
    security_group = "default"
    reservation_id = utils.generate_uid('r')
    base_options = {
        'image_id': image_id,
        'kernel_id': kernel_id,
        'ramdisk_id': ramdisk_id,
        'reservation_id': reservation_id,
        'key_data': key_data,
        'key_name': kwargs.get('key_name', None),
        'user_id': context.user.id,
        'project_id': context.project.id,
        'user_data': kwargs.get('user_data', ''),
        'instance_type': kwargs.get('instance_type', 'm1.small'),
        'security_group': security_group}

    for num in range(int(kwargs['max_count'])):
        inst_id = db.instance_create(context, base_options)

        inst = {}
        inst['mac_address'] = utils.generate_mac()
        inst['launch_index'] = num
        inst['hostname'] = inst_id
        db.instance_update(context, inst_id, inst)
        address = self.network_manager.allocate_fixed_ip(context,
                                                         inst_id,
                                                         vpn)

        # TODO(vish): This probably should be done in the scheduler
        #             network is setup when host is assigned
        network_topic = yield self._get_network_topic(context)
        rpc.call(network_topic,
                 {"method": "setup_fixed_ip",
                  "args": {"context": None,
                           "address": address}})

        rpc.cast(FLAGS.scheduler_topic,
                 {"method": "run_instance",
                  "args": {"context": None,
                           "topic": FLAGS.compute_topic,
                           "instance_id": inst_id}})
        logging.debug("Casting to scheduler for %s/%s's instance %s" %
                      (context.project.name, context.user.name, inst_id))
    defer.returnValue(self._format_run_instances(context,
                                                 reservation_id))
@rbac.allow('projectmanager', 'sysadmin')
@defer.inlineCallbacks
def terminate_instances(self, context, instance_id, **kwargs):
    """Terminate each instance in the ``instance_id`` list.

    For every id: disassociate any floating ip, deallocate the fixed ip,
    then cast the terminate to the instance's compute host -- or destroy
    the db record directly if the instance was never assigned a host.
    """
    logging.debug("Going to start terminating instances")
    # network_topic = yield self._get_network_topic(context)
    for id_str in instance_id:
        logging.debug("Going to try and terminate %s" % id_str)
        try:
            instance_ref = db.instance_get_by_str(context, id_str)
        except exception.NotFound:
            # Already gone -- warn and continue with the rest of the
            # batch instead of failing the whole request.
            logging.warning("Instance %s was not found during terminate"
                            % id_str)
            continue

        # FIXME(ja): where should network deallocate occur?
        address = db.instance_get_floating_address(context,
                                                   instance_ref['id'])
        if address:
            logging.debug("Disassociating address %s" % address)
            # NOTE(vish): Right now we don't really care if the ip is
            #             disassociated.  We may need to worry about
            #             checking this later.  Perhaps in the scheduler?
            network_topic = yield self._get_network_topic(context)
            rpc.cast(network_topic,
                     {"method": "disassociate_floating_ip",
                      "args": {"context": None,
                               "address": address}})

        address = db.instance_get_fixed_address(context,
                                                instance_ref['id'])
        if address:
            logging.debug("Deallocating address %s" % address)
            # NOTE(vish): Currently, nothing needs to be done on the
            #             network node until release. If this changes,
            #             we will need to cast here.
            db.fixed_ip_deallocate(context, address)

        host = db.instance_get_host(context, instance_ref['id'])
        if host:
            # Instance is running somewhere; its compute host tears it
            # down (and is expected to clean up the db record).
            rpc.cast(db.queue_get_for(context, FLAGS.compute_topic, host),
                     {"method": "terminate_instance",
                      "args": {"context": None,
                               "instance_id": instance_ref['id']}})
        else:
            # Never scheduled to a host; just remove the record.
            db.instance_destroy(context, instance_ref['id'])
    defer.returnValue(True)
@rbac.allow('projectmanager', 'sysadmin')
def reboot_instances(self, context, instance_id, **kwargs):
    """instance_id is a list of instance ids"""
    for id_str in instance_id:
        instance_ref = db.instance_get_by_str(context, id_str)
        host = db.instance_get_host(context, instance_ref['id'])
        queue = db.queue_get_for(context, FLAGS.compute_topic, host)
        reboot_msg = {"method": "reboot_instance",
                      "args": {"context": None,
                               "instance_id": instance_ref['id']}}
        rpc.cast(queue, reboot_msg)
    return defer.succeed(True)
@rbac.allow('projectmanager', 'sysadmin')
def delete_volume(self, context, volume_id, **kwargs):
    """Cast volume deletion to the volume host owning the volume."""
    # TODO: return error if not authorized
    volume_ref = db.volume_get_by_str(context, volume_id)
    host = db.volume_get_host(context, volume_ref['id'])
    # BUG FIX: the queue must be derived from the volume topic -- the
    # host here is a volume host, not a compute host, and delete_volume
    # is a volume-worker method.
    rpc.cast(db.queue_get_for(context, FLAGS.volume_topic, host),
             {"method": "delete_volume",
              "args": {"context": None,
                       "volume_id": volume_id}})
    return defer.succeed(True)
@rbac.allow('all')
def describe_images(self, context, image_id=None, **kwargs):
    """List images visible to the caller."""
    # The objectstore does its own authorization for describe
    image_set = images.list(context, image_id)
    return defer.succeed({'imagesSet': image_set})
@rbac.allow('projectmanager', 'sysadmin')
def deregister_image(self, context, image_id, **kwargs):
    """Remove an image registration from the objectstore."""
    # FIXME: should the objectstore be doing these authorization checks?
    images.deregister(context, image_id)
    return defer.succeed({'imageId': image_id})
@rbac.allow('projectmanager', 'sysadmin')
def register_image(self, context, image_location=None, **kwargs):
    """Register an image manifest with the objectstore.

    Falls back to the 'name' kwarg when no explicit location is given.
    """
    # FIXME: should the objectstore be doing these authorization checks?
    # Idiom fix: 'in' instead of the deprecated dict.has_key().
    if image_location is None and 'name' in kwargs:
        image_location = kwargs['name']
    image_id = images.register(context, image_location)
    logging.debug("Registered %s as %s" % (image_location, image_id))
    return defer.succeed({'imageId': image_id})
@rbac.allow('all')
def describe_image_attribute(self, context, image_id, attribute, **kwargs):
    """Describe launchPermission for an image (the only attribute
    currently supported)."""
    if attribute != 'launchPermission':
        raise exception.ApiError('attribute not supported: %s' % attribute)
    matches = images.list(context, image_id)
    if not matches:
        raise exception.ApiError('invalid id: %s' % image_id)
    image = matches[0]
    result = {'image_id': image_id, 'launchPermission': []}
    if image['isPublic']:
        result['launchPermission'].append({'group': 'all'})
    return defer.succeed(result)
@rbac.allow('projectmanager', 'sysadmin')
def modify_image_attribute(self, context, image_id, attribute,
                           operation_type, **kwargs):
    """Add or remove the 'all' launch permission on an image."""
    # TODO(devcamcar): Support users and groups other than 'all'.
    if attribute != 'launchPermission':
        raise exception.ApiError('attribute not supported: %s' % attribute)
    if not 'user_group' in kwargs:
        raise exception.ApiError('user or group not specified')
    # BUG FIX: the two checks were joined with 'and', so a request was
    # only rejected when it was wrong in *both* ways.  Anything other
    # than exactly ['all'] must be rejected, hence 'or'.
    if len(kwargs['user_group']) != 1 or kwargs['user_group'][0] != 'all':
        raise exception.ApiError('only group "all" is supported')
    if not operation_type in ['add', 'remove']:
        raise exception.ApiError('operation_type must be add or remove')
    result = images.modify(context, image_id, operation_type)
    return defer.succeed(result)

View File

@@ -1,108 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Proxy AMI-related calls from the cloud controller, to the running
objectstore service.
"""
import json
import urllib
import boto.s3.connection
from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
FLAGS = flags.FLAGS
def modify(context, image_id, operation):
    """Ask the objectstore to apply *operation* to an image."""
    query = qs({'image_id': image_id, 'operation': operation})
    conn(context).make_request(method='POST',
                               bucket='_images',
                               query_args=query)
    return True
def register(context, image_location):
    """ rpc call to register a new image based from a manifest """
    image_id = utils.generate_uid('ami')
    query = qs({'image_location': image_location,
                'image_id': image_id})
    conn(context).make_request(method='PUT',
                               bucket='_images',
                               query_args=query)
    return image_id
def list(context, filter_list=None):
    """Return the images a user can see, optionally filtered by image id.

    NOTE: shadows the ``list`` builtin; name kept for API compatibility.
    """
    # FIXME: send along the list of only_images to check for
    response = conn(context).make_request(
        method='GET',
        bucket='_images')
    result = json.loads(response.read())
    # BUG FIX: the old signature used a mutable default of [] and then
    # checked 'is not None', so a bare list(context) filtered against an
    # empty list and always returned nothing.  Treat both None and an
    # empty filter as "no filtering".
    if filter_list:
        return [i for i in result if i['imageId'] in filter_list]
    return result
def get(context, image_id):
    """return a image object if the context has permissions"""
    matches = list(context, [image_id])
    if not matches:
        raise exception.NotFound('Image %s could not be found' % image_id)
    return matches[0]
def deregister(context, image_id):
    """ unregister an image """
    query = qs({'image_id': image_id})
    conn(context).make_request(method='DELETE',
                               bucket='_images',
                               query_args=query)
def conn(context):
    """Return an S3 connection to the objectstore for this user/project.

    The access key is scoped to the (user, project) pair; plain HTTP is
    used since the objectstore is a local trusted service.
    """
    access = manager.AuthManager().get_access_key(context.user,
                                                 context.project)
    secret = str(context.user.secret)
    calling = boto.s3.connection.OrdinaryCallingFormat()
    return boto.s3.connection.S3Connection(aws_access_key_id=access,
                                           aws_secret_access_key=secret,
                                           is_secure=False,
                                           calling_format=calling,
                                           port=FLAGS.s3_port,
                                           host=FLAGS.s3_host)
def qs(params):
    """Build a query string from a dict, url-quoting each value."""
    return '&'.join(key + '=' + urllib.quote(params[key])
                    for key in params.keys())

View File

@@ -26,6 +26,18 @@ import sys
import traceback
class ProcessExecutionError(IOError):
def __init__(self, stdout=None, stderr=None, exit_code=None, cmd=None,
description=None):
if description is None:
description = "Unexpected error while running command."
if exit_code is None:
exit_code = '-'
message = "%s\nCommand: %s\nExit code: %s\nStdout: %r\nStderr: %r" % (
description, cmd, exit_code, stdout, stderr)
IOError.__init__(self, message)
class Error(Exception):
def __init__(self, message=None):
super(Error, self).__init__(message)
@@ -35,7 +47,7 @@ class ApiError(Error):
def __init__(self, message='Unknown', code='Unknown'):
self.message = message
self.code = code
super(ApiError, self).__init__('%s: %s'% (code, message))
super(ApiError, self).__init__('%s: %s' % (code, message))
class NotFound(Error):
@@ -58,6 +70,10 @@ class Invalid(Error):
pass
class InvalidInputException(Error):
pass
def wrap_exception(f):
def _wrap(*args, **kw):
try:
@@ -71,5 +87,3 @@ def wrap_exception(f):
raise
_wrap.func_name = f.func_name
return _wrap

View File

@@ -22,6 +22,7 @@ import logging
import Queue as queue
from carrot.backends import base
from eventlet import greenthread
class Message(base.BaseMessage):
@@ -38,6 +39,7 @@ class Exchange(object):
def publish(self, message, routing_key=None):
logging.debug('(%s) publish (key: %s) %s',
self.name, routing_key, message)
routing_key = routing_key.split('.')[0]
if routing_key in self._routes:
for f in self._routes[routing_key]:
logging.debug('Publishing to route %s', f)
@@ -94,6 +96,18 @@ class Backend(object):
self._exchanges[exchange].bind(self._queues[queue].push,
routing_key)
def declare_consumer(self, queue, callback, *args, **kwargs):
self.current_queue = queue
self.current_callback = callback
def consume(self, *args, **kwargs):
while True:
item = self.get(self.current_queue)
if item:
self.current_callback(item)
raise StopIteration()
greenthread.sleep(0)
def get(self, queue, no_ack=False):
if not queue in self._queues or not self._queues[queue].size():
return None
@@ -102,6 +116,7 @@ class Backend(object):
message = Message(backend=self, body=message_data,
content_type=content_type,
content_encoding=content_encoding)
message.result = True
logging.debug('Getting from %s: %s', queue, message)
return message
@@ -115,7 +130,6 @@ class Backend(object):
self._exchanges[exchange].publish(
message, routing_key=routing_key)
__instance = None
def __init__(self, *args, **kwargs):

View File

@@ -90,6 +90,12 @@ class FlagValues(gflags.FlagValues):
self.ClearDirty()
return args
def Reset(self):
gflags.FlagValues.Reset(self)
self.__dict__['__dirty'] = []
self.__dict__['__was_already_parsed'] = False
self.__dict__['__stored_argv'] = []
def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name)
@@ -167,11 +173,15 @@ def DECLARE(name, module_string, flag_values=FLAGS):
# Define any app-specific flags in their own files, docs at:
# http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#39
DEFINE_list('region_list',
[],
'list of region=url pairs separated by commas')
DEFINE_string('connection_type', 'libvirt', 'libvirt, xenapi or fake')
DEFINE_integer('s3_port', 3333, 's3 port')
DEFINE_string('s3_host', '127.0.0.1', 's3 host')
DEFINE_string('compute_topic', 'compute', 'the topic compute nodes listen on')
DEFINE_string('scheduler_topic', 'scheduler', 'the topic scheduler nodes listen on')
DEFINE_string('scheduler_topic', 'scheduler',
'the topic scheduler nodes listen on')
DEFINE_string('volume_topic', 'volume', 'the topic volume nodes listen on')
DEFINE_string('network_topic', 'network', 'the topic network nodes listen on')
@@ -185,6 +195,8 @@ DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')
DEFINE_string('control_exchange', 'nova', 'the main exchange to connect to')
DEFINE_string('cc_host', '127.0.0.1', 'ip of api server')
DEFINE_integer('cc_port', 8773, 'cloud controller port')
DEFINE_string('ec2_url', 'http://127.0.0.1:8773/services/Cloud',
'Url to ec2 api server')
@@ -204,12 +216,6 @@ DEFINE_string('vpn_key_suffix',
DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger')
# UNUSED
DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node')
DEFINE_string('host', socket.gethostname(),
'name of this node')
DEFINE_string('sql_connection',
'sqlite:///%s/nova.sqlite' % os.path.abspath("./"),
'connection string for sql database')
@@ -223,4 +229,13 @@ DEFINE_string('volume_manager', 'nova.volume.manager.AOEManager',
DEFINE_string('scheduler_manager', 'nova.scheduler.manager.SchedulerManager',
'Manager for scheduler')
# The service to use for image search and retrieval
DEFINE_string('image_service', 'nova.image.service.LocalImageService',
'The service to use for retrieving and searching for images.')
DEFINE_string('host', socket.gethostname(),
'name of this node')
# UNUSED
DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node')

View File

@@ -22,6 +22,7 @@ Base class for managers of different parts of the system
from nova import utils
from nova import flags
from twisted.internet import defer
FLAGS = flags.FLAGS
flags.DEFINE_string('db_driver', 'nova.db.api',
@@ -37,3 +38,13 @@ class Manager(object):
if not db_driver:
db_driver = FLAGS.db_driver
self.db = utils.import_object(db_driver) # pylint: disable-msg=C0103
@defer.inlineCallbacks
def periodic_tasks(self, context=None):
"""Tasks to be run at a periodic interval"""
yield
def init_host(self):
"""Do any initialization that needs to be run if this is a standalone
service. Child classes should override this method."""
pass

View File

@@ -18,9 +18,10 @@
# under the License.
"""
Process pool, still buggy right now.
Process pool using twisted threading
"""
import logging
import StringIO
from twisted.internet import defer
@@ -29,30 +30,15 @@ from twisted.internet import protocol
from twisted.internet import reactor
from nova import flags
from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
flags.DEFINE_integer('process_pool_size', 4,
'Number of processes to use in the process pool')
# NOTE(termie): this is copied from twisted.internet.utils but since
# they don't export it I've copied and modified
class UnexpectedErrorOutput(IOError):
"""
Standard error data was received where it was not expected. This is a
subclass of L{IOError} to preserve backward compatibility with the previous
error behavior of L{getProcessOutput}.
@ivar processEnded: A L{Deferred} which will fire when the process which
produced the data on stderr has ended (exited and all file descriptors
closed).
"""
def __init__(self, stdout=None, stderr=None):
IOError.__init__(self, "got stdout: %r\nstderr: %r" % (stdout, stderr))
# This is based on _BackRelay from twister.internal.utils, but modified to
# capture both stdout and stderr, without odd stderr handling, and also to
# This is based on _BackRelay from twister.internal.utils, but modified to
# capture both stdout and stderr, without odd stderr handling, and also to
# handle stdin
class BackRelayWithInput(protocol.ProcessProtocol):
"""
@@ -62,22 +48,23 @@ class BackRelayWithInput(protocol.ProcessProtocol):
@ivar deferred: A L{Deferred} which will be called back with all of stdout
and all of stderr as well (as a tuple). C{terminate_on_stderr} is true
and any bytes are received over stderr, this will fire with an
L{_UnexpectedErrorOutput} instance and the attribute will be set to
L{_ProcessExecutionError} instance and the attribute will be set to
C{None}.
@ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
received over stderr, this attribute will refer to a L{Deferred} which
will be called back when the process ends. This C{Deferred} is also
associated with the L{_UnexpectedErrorOutput} which C{deferred} fires
with earlier in this case so that users can determine when the process
has actually ended, in addition to knowing when bytes have been received
via stderr.
@ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
received over stderr, this attribute will refer to a L{Deferred} which
will be called back when the process ends. This C{Deferred} is also
associated with the L{_ProcessExecutionError} which C{deferred} fires
with earlier in this case so that users can determine when the process
has actually ended, in addition to knowing when bytes have been
received via stderr.
"""
def __init__(self, deferred, started_deferred=None,
terminate_on_stderr=False, check_exit_code=True,
process_input=None):
def __init__(self, deferred, cmd, started_deferred=None,
terminate_on_stderr=False, check_exit_code=True,
process_input=None):
self.deferred = deferred
self.cmd = cmd
self.stdout = StringIO.StringIO()
self.stderr = StringIO.StringIO()
self.started_deferred = started_deferred
@@ -85,14 +72,18 @@ class BackRelayWithInput(protocol.ProcessProtocol):
self.check_exit_code = check_exit_code
self.process_input = process_input
self.on_process_ended = None
def _build_execution_error(self, exit_code=None):
return ProcessExecutionError(cmd=self.cmd,
exit_code=exit_code,
stdout=self.stdout.getvalue(),
stderr=self.stderr.getvalue())
def errReceived(self, text):
self.stderr.write(text)
if self.terminate_on_stderr and (self.deferred is not None):
self.on_process_ended = defer.Deferred()
self.deferred.errback(UnexpectedErrorOutput(
stdout=self.stdout.getvalue(),
stderr=self.stderr.getvalue()))
self.deferred.errback(self._build_execution_error())
self.deferred = None
self.transport.loseConnection()
@@ -102,28 +93,34 @@ class BackRelayWithInput(protocol.ProcessProtocol):
def processEnded(self, reason):
if self.deferred is not None:
stdout, stderr = self.stdout.getvalue(), self.stderr.getvalue()
try:
if self.check_exit_code:
reason.trap(error.ProcessDone)
self.deferred.callback((stdout, stderr))
except:
# NOTE(justinsb): This logic is a little suspicious to me...
# If the callback throws an exception, then errback will be
# called also. However, this is what the unit tests test for...
self.deferred.errback(UnexpectedErrorOutput(stdout, stderr))
exit_code = reason.value.exitCode
if self.check_exit_code and exit_code != 0:
self.deferred.errback(self._build_execution_error(exit_code))
else:
try:
if self.check_exit_code:
reason.trap(error.ProcessDone)
self.deferred.callback((stdout, stderr))
except:
# NOTE(justinsb): This logic is a little suspicious to me.
# If the callback throws an exception, then errback will
# be called also. However, this is what the unit tests
# test for.
exec_error = self._build_execution_error(exit_code)
self.deferred.errback(exec_error)
elif self.on_process_ended is not None:
self.on_process_ended.errback(reason)
def connectionMade(self):
if self.started_deferred:
self.started_deferred.callback(self)
if self.process_input:
self.transport.write(self.process_input)
self.transport.write(str(self.process_input))
self.transport.closeStdin()
def get_process_output(executable, args=None, env=None, path=None,
process_reactor=None, check_exit_code=True,
def get_process_output(executable, args=None, env=None, path=None,
process_reactor=None, check_exit_code=True,
process_input=None, started_deferred=None,
terminate_on_stderr=False):
if process_reactor is None:
@@ -131,10 +128,15 @@ def get_process_output(executable, args=None, env=None, path=None,
args = args and args or ()
env = env and env and {}
deferred = defer.Deferred()
cmd = executable
if args:
cmd = " ".join([cmd] + args)
logging.debug("Running cmd: %s", cmd)
process_handler = BackRelayWithInput(
deferred,
started_deferred=started_deferred,
check_exit_code=check_exit_code,
deferred,
cmd,
started_deferred=started_deferred,
check_exit_code=check_exit_code,
process_input=process_input,
terminate_on_stderr=terminate_on_stderr)
# NOTE(vish): commands come in as unicode, but self.executes needs
@@ -142,8 +144,8 @@ def get_process_output(executable, args=None, env=None, path=None,
executable = str(executable)
if not args is None:
args = [str(x) for x in args]
process_reactor.spawnProcess( process_handler, executable,
(executable,)+tuple(args), env, path)
process_reactor.spawnProcess(process_handler, executable,
(executable,) + tuple(args), env, path)
return deferred
@@ -194,9 +196,11 @@ class ProcessPool(object):
class SharedPool(object):
_instance = None
def __init__(self):
if SharedPool._instance is None:
self.__class__._instance = ProcessPool()
def __getattr__(self, key):
return getattr(self._instance, key)

96
nova/quota.py Normal file
View File

@@ -0,0 +1,96 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Quotas for instances, volumes, and floating ips
"""
from nova import db
from nova import exception
from nova import flags
from nova.compute import instance_types
FLAGS = flags.FLAGS
flags.DEFINE_integer('quota_instances', 10,
'number of instances allowed per project')
flags.DEFINE_integer('quota_cores', 20,
'number of instance cores allowed per project')
flags.DEFINE_integer('quota_volumes', 10,
'number of volumes allowed per project')
flags.DEFINE_integer('quota_gigabytes', 1000,
'number of volume gigabytes allowed per project')
flags.DEFINE_integer('quota_floating_ips', 10,
'number of floating ips allowed per project')
def get_quota(context, project_id):
    """Return the effective quota for a project.

    Starts from the flag-defined defaults and overlays any per-project
    overrides stored in the database; missing overrides mean defaults.
    """
    rval = {'instances': FLAGS.quota_instances,
            'cores': FLAGS.quota_cores,
            'volumes': FLAGS.quota_volumes,
            'gigabytes': FLAGS.quota_gigabytes,
            'floating_ips': FLAGS.quota_floating_ips}
    try:
        quota = db.quota_get(context, project_id)
        for key in rval.keys():
            override = quota[key]
            if override is not None:
                rval[key] = override
    except exception.NotFound:
        # No per-project quota row; flag defaults apply.
        pass
    return rval
def allowed_instances(context, num_instances, instance_type):
    """Check quota and return min(num_instances, allowed_instances)"""
    project_id = context.project_id
    # Read project totals with an elevated context.
    context = context.elevated()
    used_instances, used_cores = db.instance_data_get_for_project(context,
                                                                  project_id)
    quota = get_quota(context, project_id)
    free_instances = quota['instances'] - used_instances
    free_cores = quota['cores'] - used_cores
    cores_per_instance = instance_types.INSTANCE_TYPES[instance_type]['vcpus']
    # Whichever runs out first -- instance slots or cpu cores -- caps us.
    allowed = min(free_instances, int(free_cores // cores_per_instance))
    return min(num_instances, allowed)
def allowed_volumes(context, num_volumes, size):
    """Check quota and return min(num_volumes, allowed_volumes)"""
    project_id = context.project_id
    # Project totals are read with an elevated (admin) context so they
    # are visible regardless of the caller's role.
    context = context.elevated()
    used_volumes, used_gigabytes = db.volume_data_get_for_project(context,
                                                                  project_id)
    quota = get_quota(context, project_id)
    allowed_volumes = quota['volumes'] - used_volumes
    allowed_gigabytes = quota['gigabytes'] - used_gigabytes
    size = int(size)
    # NOTE(review): a requested size of 0 would raise ZeroDivisionError
    # in the floor division below -- confirm callers validate size >= 1.
    num_gigabytes = num_volumes * size
    allowed_volumes = min(allowed_volumes,
                          int(allowed_gigabytes // size))
    return min(num_volumes, allowed_volumes)
def allowed_floating_ips(context, num_floating_ips):
    """Check quota and return min(num_floating_ips, allowed_floating_ips)"""
    project_id = context.project_id
    # Read project totals with an elevated context.
    context = context.elevated()
    used = db.floating_ip_count_by_project(context, project_id)
    quota = get_quota(context, project_id)
    remaining = quota['floating_ips'] - used
    return min(num_floating_ips, remaining)

View File

@@ -28,13 +28,14 @@ import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenthread
from twisted.internet import defer
from twisted.internet import task
from nova import exception
from nova import fakerabbit
from nova import flags
from nova import context
FLAGS = flags.FLAGS
@@ -46,9 +47,9 @@ LOG.setLevel(logging.DEBUG)
class Connection(carrot_connection.BrokerConnection):
"""Connection instance object"""
@classmethod
def instance(cls):
def instance(cls, new=False):
"""Returns the instance"""
if not hasattr(cls, '_instance'):
if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port,
userid=FLAGS.rabbit_userid,
@@ -60,7 +61,10 @@ class Connection(carrot_connection.BrokerConnection):
# NOTE(vish): magic is fun!
# pylint: disable-msg=W0142
cls._instance = cls(**params)
if new:
return cls(**params)
else:
cls._instance = cls(**params)
return cls._instance
@classmethod
@@ -81,21 +85,6 @@ class Consumer(messaging.Consumer):
self.failed_connection = False
super(Consumer, self).__init__(*args, **kwargs)
# TODO(termie): it would be nice to give these some way of automatically
# cleaning up after themselves
def attach_to_tornado(self, io_inst=None):
"""Attach a callback to tornado that fires 10 times a second"""
from tornado import ioloop
if io_inst is None:
io_inst = ioloop.IOLoop.instance()
injected = ioloop.PeriodicCallback(
lambda: self.fetch(enable_callbacks=True), 100, io_loop=io_inst)
injected.start()
return injected
attachToTornado = attach_to_tornado
def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
"""Wraps the parent fetch with some logic for failed connections"""
# TODO(vish): the logic for failed connections and logging should be
@@ -119,10 +108,19 @@ class Consumer(messaging.Consumer):
logging.exception("Failed to fetch message from queue")
self.failed_connection = True
def attach_to_eventlet(self):
"""Only needed for unit tests!"""
def fetch_repeatedly():
while True:
self.fetch(enable_callbacks=True)
greenthread.sleep(0.1)
greenthread.spawn(fetch_repeatedly)
def attach_to_twisted(self):
"""Attach a callback to twisted that fires 10 times a second"""
loop = task.LoopingCall(self.fetch, enable_callbacks=True)
loop.start(interval=0.1)
return loop
class Publisher(messaging.Publisher):
@@ -163,6 +161,8 @@ class AdapterConsumer(TopicConsumer):
LOG.debug('received %s' % (message_data))
msg_id = message_data.pop('_msg_id', None)
ctxt = _unpack_context(message_data)
method = message_data.get('method')
args = message_data.get('args', {})
message.ack()
@@ -179,7 +179,7 @@ class AdapterConsumer(TopicConsumer):
node_args = dict((str(k), v) for k, v in args.iteritems())
# NOTE(vish): magic is fun!
# pylint: disable-msg=W0142
d = defer.maybeDeferred(node_func, **node_args)
d = defer.maybeDeferred(node_func, context=ctxt, **node_args)
if msg_id:
d.addCallback(lambda rval: msg_reply(msg_id, rval, None))
d.addErrback(lambda e: msg_reply(msg_id, None, e))
@@ -258,12 +258,73 @@ class RemoteError(exception.Error):
traceback))
def call(topic, msg):
def _unpack_context(msg):
    """Unpack context from msg.

    Pops every '_context_'-prefixed key out of msg (mutating it) and
    rebuilds a RequestContext from the stripped key/value pairs.
    """
    prefix = '_context_'
    # Snapshot the keys first: we pop from msg while walking them.
    packed_keys = [key for key in list(msg.keys()) if key.startswith(prefix)]
    context_dict = {}
    for key in packed_keys:
        context_dict[key[len(prefix):]] = msg.pop(key)
    LOG.debug('unpacked context: %s', context_dict)
    return context.RequestContext.from_dict(context_dict)
def _pack_context(msg, context):
"""Pack context into msg.
Values for message keys need to be less than 255 chars, so we pull
context out into a bunch of separate keys. If we want to support
more arguments in rabbit messages, we may want to do the same
for args at some point.
"""
context = dict([('_context_%s' % key, value)
for (key, value) in context.to_dict().iteritems()])
msg.update(context)
def call(context, topic, msg):
    """Sends a message on a topic and wait for a response"""
    LOG.debug("Making asynchronous call...")
    # Tag the message with a unique id so the remote side can route its
    # reply back to the DirectConsumer created below.
    msg_id = uuid.uuid4().hex
    msg.update({'_msg_id': msg_id})
    LOG.debug("MSG_ID is %s" % (msg_id))
    # Flatten the request context into '_context_*' keys on the message.
    _pack_context(msg, context)
    class WaitMessage(object):
        def __call__(self, data, message):
            """Acks message and sets result."""
            message.ack()
            if data['failure']:
                self.result = RemoteError(*data['failure'])
            else:
                self.result = data['result']
    wait_msg = WaitMessage()
    # NOTE(review): a fresh connection (instance(True)) is used for the
    # reply consumer, while the shared singleton connection publishes the
    # request.  The consumer is registered *before* publishing so the
    # reply cannot be missed.
    conn = Connection.instance(True)
    consumer = DirectConsumer(connection=conn, msg_id=msg_id)
    consumer.register_callback(wait_msg)
    conn = Connection.instance()
    publisher = TopicPublisher(connection=conn, topic=topic)
    publisher.send(msg)
    publisher.close()
    try:
        # Block until exactly one reply arrives; carrot signals the limit
        # being reached by raising StopIteration, which we treat as done.
        consumer.wait(limit=1)
    except StopIteration:
        pass
    consumer.close()
    # NOTE(review): if wait() returns without any message having been
    # consumed, wait_msg.result was never set and this raises
    # AttributeError — presumably wait() only returns after a delivery;
    # confirm against carrot's Consumer.wait semantics.
    return wait_msg.result
def call_twisted(context, topic, msg):
"""Sends a message on a topic and wait for a response"""
LOG.debug("Making asynchronous call...")
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug("MSG_ID is %s" % (msg_id))
_pack_context(msg, context)
conn = Connection.instance()
d = defer.Deferred()
@@ -278,7 +339,7 @@ def call(topic, msg):
return d.callback(data['result'])
consumer.register_callback(deferred_receive)
injected = consumer.attach_to_tornado()
injected = consumer.attach_to_twisted()
# clean up after the injected listened and return x
d.addCallback(lambda x: injected.stop() and x or x)
@@ -289,9 +350,10 @@ def call(topic, msg):
return d
def cast(topic, msg):
def cast(context, topic, msg):
"""Sends a message on a topic without waiting for a response"""
LOG.debug("Making asynchronous cast...")
_pack_context(msg, context)
conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg)

View File

@@ -27,14 +27,10 @@ from nova.scheduler import driver
class ChanceScheduler(driver.Scheduler):
"""
Implements Scheduler as a random node selector
"""
"""Implements Scheduler as a random node selector."""
def schedule(self, context, topic, *_args, **_kwargs):
"""
Picks a host that is up at random
"""
"""Picks a host that is up at random."""
hosts = self.hosts_up(context, topic)
if not hosts:

View File

@@ -28,34 +28,28 @@ from nova import exception
from nova import flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('service_down_time',
60,
'seconds without heartbeat that determines a '
'compute node to be down')
flags.DEFINE_integer('service_down_time', 60,
'maximum time since last checkin for up service')
class NoValidHost(exception.Error):
"""There is no valid host for the command"""
"""There is no valid host for the command."""
pass
class Scheduler(object):
"""
The base class that all Scheduler clases should inherit from
"""
"""The base class that all Scheduler clases should inherit from."""
@staticmethod
def service_is_up(service):
"""
Given a service, return whether the service is considered 'up' by
if it's sent a heartbeat recently
"""
"""Check whether a service is up based on last heartbeat."""
last_heartbeat = service['updated_at'] or service['created_at']
elapsed = datetime.datetime.now() - last_heartbeat
# Timestamps in DB are UTC.
elapsed = datetime.datetime.utcnow() - last_heartbeat
return elapsed < datetime.timedelta(seconds=FLAGS.service_down_time)
def hosts_up(self, context, topic):
"""
Return the list of hosts that have a running service for topic
"""
"""Return the list of hosts that have a running service for topic."""
services = db.service_get_all_by_topic(context, topic)
return [service.host
@@ -63,7 +57,5 @@ class Scheduler(object):
if self.service_is_up(service)]
def schedule(self, context, topic, *_args, **_kwargs):
"""
Must override at least this method for scheduler to work
"""
"""Must override at least this method for scheduler to work."""
raise NotImplementedError("Must implement a fallback schedule")

View File

@@ -37,9 +37,7 @@ flags.DEFINE_string('scheduler_driver',
class SchedulerManager(manager.Manager):
"""
Chooses a host to run instances on.
"""
"""Chooses a host to run instances on."""
def __init__(self, scheduler_driver=None, *args, **kwargs):
if not scheduler_driver:
scheduler_driver = FLAGS.scheduler_driver
@@ -56,13 +54,15 @@ class SchedulerManager(manager.Manager):
Falls back to schedule(context, topic) if method doesn't exist.
"""
driver_method = 'schedule_%s' % method
elevated = context.elevated()
try:
host = getattr(self.driver, driver_method)(context, *args, **kwargs)
host = getattr(self.driver, driver_method)(elevated, *args,
**kwargs)
except AttributeError:
host = self.driver.schedule(context, topic, *args, **kwargs)
host = self.driver.schedule(elevated, topic, *args, **kwargs)
kwargs.update({"context": None})
rpc.cast(db.queue_get_for(context, topic, host),
rpc.cast(context,
db.queue_get_for(context, topic, host),
{"method": method,
"args": kwargs})
logging.debug("Casting to %s %s for %s", topic, host, self.method)
logging.debug("Casting to %s %s for %s", topic, host, method)

View File

@@ -21,56 +21,65 @@
Simple Scheduler
"""
import datetime
from nova import db
from nova import flags
from nova.scheduler import driver
from nova.scheduler import chance
FLAGS = flags.FLAGS
flags.DEFINE_integer("max_instances", 16,
"maximum number of instances to allow per host")
flags.DEFINE_integer("max_volumes", 100,
"maximum number of volumes to allow per host")
flags.DEFINE_integer("max_cores", 16,
"maximum number of instance cores to allow per host")
flags.DEFINE_integer("max_gigabytes", 10000,
"maximum number of volume gigabytes to allow per host")
flags.DEFINE_integer("max_networks", 1000,
"maximum number of networks to allow per host")
class SimpleScheduler(chance.ChanceScheduler):
"""
Implements Naive Scheduler that tries to find least loaded host
"""
def schedule_run_instance(self, context, _instance_id, *_args, **_kwargs):
"""
Picks a host that is up and has the fewest running instances
"""
"""Implements Naive Scheduler that tries to find least loaded host."""
def schedule_run_instance(self, context, instance_id, *_args, **_kwargs):
"""Picks a host that is up and has the fewest running instances."""
instance_ref = db.instance_get(context, instance_id)
results = db.service_get_all_compute_sorted(context)
for result in results:
(service, instance_count) = result
if instance_count >= FLAGS.max_instances:
raise driver.NoValidHost("All hosts have too many instances")
(service, instance_cores) = result
if instance_cores + instance_ref['vcpus'] > FLAGS.max_cores:
raise driver.NoValidHost("All hosts have too many cores")
if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow
now = datetime.datetime.utcnow()
db.instance_update(context,
instance_id,
{'host': service['host'],
'scheduled_at': now})
return service['host']
raise driver.NoValidHost("No hosts found")
def schedule_create_volume(self, context, _volume_id, *_args, **_kwargs):
"""
Picks a host that is up and has the fewest volumes
"""
def schedule_create_volume(self, context, volume_id, *_args, **_kwargs):
"""Picks a host that is up and has the fewest volumes."""
volume_ref = db.volume_get(context, volume_id)
results = db.service_get_all_volume_sorted(context)
for result in results:
(service, instance_count) = result
if instance_count >= FLAGS.max_volumes:
raise driver.NoValidHost("All hosts have too many volumes")
(service, volume_gigabytes) = result
if volume_gigabytes + volume_ref['size'] > FLAGS.max_gigabytes:
raise driver.NoValidHost("All hosts have too many gigabytes")
if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow
now = datetime.datetime.utcnow()
db.volume_update(context,
volume_id,
{'host': service['host'],
'scheduled_at': now})
return service['host']
raise driver.NoValidHost("No hosts found")
def schedule_set_network_host(self, context, _network_id, *_args, **_kwargs):
"""
Picks a host that is up and has the fewest networks
"""
def schedule_set_network_host(self, context, *_args, **_kwargs):
"""Picks a host that is up and has the fewest networks."""
results = db.service_get_all_network_sorted(context)
for result in results:

View File

@@ -54,11 +54,11 @@ def stop(pidfile):
"""
# Get the pid from the pidfile
try:
pid = int(open(pidfile,'r').read().strip())
pid = int(open(pidfile, 'r').read().strip())
except IOError:
message = "pidfile %s does not exist. Daemon not running?\n"
sys.stderr.write(message % pidfile)
return # not an error in a restart
return
# Try killing the daemon process
try:
@@ -106,6 +106,7 @@ def serve(name, main):
def daemonize(args, name, main):
"""Does the work of daemonizing the process"""
logging.getLogger('amqplib').setLevel(logging.WARN)
files_to_keep = []
if FLAGS.daemonize:
logger = logging.getLogger()
formatter = logging.Formatter(
@@ -114,12 +115,14 @@ def daemonize(args, name, main):
syslog = logging.handlers.SysLogHandler(address='/dev/log')
syslog.setFormatter(formatter)
logger.addHandler(syslog)
files_to_keep.append(syslog.socket)
else:
if not FLAGS.logfile:
FLAGS.logfile = '%s.log' % name
logfile = logging.FileHandler(FLAGS.logfile)
logfile.setFormatter(formatter)
logger.addHandler(logfile)
files_to_keep.append(logfile.stream)
stdin, stdout, stderr = None, None, None
else:
stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr
@@ -139,6 +142,6 @@ def daemonize(args, name, main):
stdout=stdout,
stderr=stderr,
uid=FLAGS.uid,
gid=FLAGS.gid
):
gid=FLAGS.gid,
files_preserve=files_to_keep):
main(args)

View File

@@ -28,6 +28,7 @@ from twisted.internet import defer
from twisted.internet import task
from twisted.application import service
from nova import context
from nova import db
from nova import exception
from nova import flags
@@ -37,37 +38,75 @@ from nova import utils
FLAGS = flags.FLAGS
flags.DEFINE_integer('report_interval', 10,
'seconds between nodes reporting state to cloud',
'seconds between nodes reporting state to datastore',
lower_bound=1)
flags.DEFINE_integer('periodic_interval', 60,
'seconds between running periodic tasks',
lower_bound=1)
class Service(object, service.Service):
"""Base class for workers that run on hosts."""
def __init__(self, host, binary, topic, manager, *args, **kwargs):
def __init__(self, host, binary, topic, manager, report_interval=None,
periodic_interval=None, *args, **kwargs):
self.host = host
self.binary = binary
self.topic = topic
manager_class = utils.import_class(manager)
self.manager = manager_class(host=host, *args, **kwargs)
self.model_disconnected = False
self.manager_class_name = manager
self.report_interval = report_interval
self.periodic_interval = periodic_interval
super(Service, self).__init__(*args, **kwargs)
self.saved_args, self.saved_kwargs = args, kwargs
def startService(self): # pylint: disable-msg C0103
manager_class = utils.import_class(self.manager_class_name)
self.manager = manager_class(host=self.host, *self.saved_args,
**self.saved_kwargs)
self.manager.init_host()
self.model_disconnected = False
ctxt = context.get_admin_context()
try:
service_ref = db.service_get_by_args(None,
self.host,
self.binary)
service_ref = db.service_get_by_args(ctxt,
self.host,
self.binary)
self.service_id = service_ref['id']
except exception.NotFound:
self.service_id = db.service_create(None, {'host': self.host,
'binary': self.binary,
'topic': self.topic,
'report_count': 0})
self._create_service_ref(ctxt)
conn = rpc.Connection.instance()
if self.report_interval:
consumer_all = rpc.AdapterConsumer(
connection=conn,
topic=self.topic,
proxy=self)
consumer_node = rpc.AdapterConsumer(
connection=conn,
topic='%s.%s' % (self.topic, self.host),
proxy=self)
consumer_all.attach_to_twisted()
consumer_node.attach_to_twisted()
pulse = task.LoopingCall(self.report_state)
pulse.start(interval=self.report_interval, now=False)
if self.periodic_interval:
pulse = task.LoopingCall(self.periodic_tasks)
pulse.start(interval=self.periodic_interval, now=False)
def _create_service_ref(self, context):
service_ref = db.service_create(context,
{'host': self.host,
'binary': self.binary,
'topic': self.topic,
'report_count': 0})
self.service_id = service_ref['id']
def __getattr__(self, key):
try:
return super(Service, self).__getattr__(key)
except AttributeError:
return getattr(self.manager, key)
manager = self.__dict__.get('manager', None)
return getattr(manager, key)
@classmethod
def create(cls,
@@ -75,7 +114,8 @@ class Service(object, service.Service):
binary=None,
topic=None,
manager=None,
report_interval=None):
report_interval=None,
periodic_interval=None):
"""Instantiates class and passes back application object.
Args:
@@ -84,6 +124,7 @@ class Service(object, service.Service):
topic, defaults to bin_name - "nova-" part
manager, defaults to FLAGS.<topic>_manager
report_interval, defaults to FLAGS.report_interval
periodic_interval, defaults to FLAGS.periodic_interval
"""
if not host:
host = FLAGS.host
@@ -95,23 +136,11 @@ class Service(object, service.Service):
manager = FLAGS.get('%s_manager' % topic, None)
if not report_interval:
report_interval = FLAGS.report_interval
if not periodic_interval:
periodic_interval = FLAGS.periodic_interval
logging.warn("Starting %s node", topic)
service_obj = cls(host, binary, topic, manager)
conn = rpc.Connection.instance()
consumer_all = rpc.AdapterConsumer(
connection=conn,
topic=topic,
proxy=service_obj)
consumer_node = rpc.AdapterConsumer(
connection=conn,
topic='%s.%s' % (topic, host),
proxy=service_obj)
pulse = task.LoopingCall(service_obj.report_state)
pulse.start(interval=report_interval, now=False)
consumer_all.attach_to_twisted()
consumer_node.attach_to_twisted()
service_obj = cls(host, binary, topic, manager,
report_interval, periodic_interval)
# This is the parent service that twistd will be looking for when it
# parses this file, return it so that we can get it into globals.
@@ -119,23 +148,32 @@ class Service(object, service.Service):
service_obj.setServiceParent(application)
return application
def kill(self, context=None):
def kill(self):
"""Destroy the service object in the datastore"""
try:
service_ref = db.service_get_by_args(context,
self.host,
self.binary)
service_id = service_ref['id']
db.service_destroy(context, self.service_id)
db.service_destroy(context.get_admin_context(), self.service_id)
except exception.NotFound:
logging.warn("Service killed that has no database entry")
@defer.inlineCallbacks
def report_state(self, context=None):
def periodic_tasks(self):
"""Tasks to be run at a periodic interval"""
yield self.manager.periodic_tasks(context.get_admin_context())
@defer.inlineCallbacks
def report_state(self):
"""Update the state of this service in the datastore."""
ctxt = context.get_admin_context()
try:
service_ref = db.service_get(context, self.service_id)
db.service_update(context,
try:
service_ref = db.service_get(ctxt, self.service_id)
except exception.NotFound:
logging.debug("The service database object disappeared, "
"Recreating it.")
self._create_service_ref(ctxt)
service_ref = db.service_get(ctxt, self.service_id)
db.service_update(ctxt,
self.service_id,
{'report_count': service_ref['report_count'] + 1})
@@ -145,7 +183,7 @@ class Service(object, service.Service):
logging.error("Recovered model server connection!")
# TODO(vish): this should probably only catch connection errors
except: # pylint: disable-msg=W0702
except Exception: # pylint: disable-msg=W0702
if not getattr(self, "model_disconnected", False):
self.model_disconnected = True
logging.exception("model server went away")

View File

@@ -22,6 +22,7 @@ Allows overriding of flags for use of fakes,
and some black magic for inline callbacks.
"""
import datetime
import sys
import time
@@ -31,20 +32,18 @@ from tornado import ioloop
from twisted.internet import defer
from twisted.trial import unittest
from nova import context
from nova import db
from nova import fakerabbit
from nova import flags
from nova import rpc
from nova.network import manager as network_manager
FLAGS = flags.FLAGS
flags.DEFINE_bool('fake_tests', True,
'should we use everything for testing')
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
engine = create_engine('sqlite:///:memory:', echo=True)
Base = declarative_base()
Base.metadata.create_all(engine)
def skip_if_fake(func):
"""Decorator that skips a test if running in fake mode"""
@@ -59,27 +58,57 @@ def skip_if_fake(func):
class TrialTestCase(unittest.TestCase):
"""Test case base class for all unit tests"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
"""Run before each test method to initialize test environment"""
super(TrialTestCase, self).setUp()
# NOTE(vish): We need a better method for creating fixtures for tests
# now that we have some required db setup for the system
# to work properly.
self.start = datetime.datetime.utcnow()
ctxt = context.get_admin_context()
if db.network_count(ctxt) != 5:
network_manager.VlanManager().create_networks(ctxt,
FLAGS.fixed_range,
5, 16,
FLAGS.vlan_start,
FLAGS.vpn_start)
# emulate some of the mox stuff, we can't use the metaclass
# because it screws with our generators
self.mox = mox.Mox()
self.stubs = stubout.StubOutForTesting()
self.flag_overrides = {}
self.injected = []
self._monkey_patch_attach()
self._original_flags = FLAGS.FlagValuesDict()
def tearDown(self): # pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
super(TrialTestCase, self).tearDown()
self.reset_flags()
self.mox.UnsetStubs()
self.stubs.UnsetAll()
self.stubs.SmartUnsetAll()
self.mox.VerifyAll()
def tearDown(self):
"""Runs after each test method to finalize/tear down test
environment."""
try:
self.mox.UnsetStubs()
self.stubs.UnsetAll()
self.stubs.SmartUnsetAll()
self.mox.VerifyAll()
# NOTE(vish): Clean up any ips associated during the test.
ctxt = context.get_admin_context()
db.fixed_ip_disassociate_all_by_timeout(ctxt, FLAGS.host,
self.start)
db.network_disassociate_all(ctxt)
rpc.Consumer.attach_to_twisted = self.originalAttach
for x in self.injected:
try:
x.stop()
except AssertionError:
pass
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
db.security_group_destroy_all(ctxt)
super(TrialTestCase, self).tearDown()
finally:
self.reset_flags()
def flags(self, **kw):
"""Override flag variables for a test"""
@@ -93,38 +122,69 @@ class TrialTestCase(unittest.TestCase):
def reset_flags(self):
"""Resets all flag variables for the test. Runs after each test"""
for k, v in self.flag_overrides.iteritems():
FLAGS.Reset()
for k, v in self._original_flags.iteritems():
setattr(FLAGS, k, v)
def run(self, result=None):
test_method = getattr(self, self._testMethodName)
setattr(self,
self._testMethodName,
self._maybeInlineCallbacks(test_method, result))
rv = super(TrialTestCase, self).run(result)
setattr(self, self._testMethodName, test_method)
return rv
def _maybeInlineCallbacks(self, func, result):
def _wrapped():
g = func()
if isinstance(g, defer.Deferred):
return g
if not hasattr(g, 'send'):
return defer.succeed(g)
inlined = defer.inlineCallbacks(func)
d = inlined()
return d
_wrapped.func_name = func.func_name
return _wrapped
def _monkey_patch_attach(self):
self.originalAttach = rpc.Consumer.attach_to_twisted
def _wrapped(innerSelf):
rv = self.originalAttach(innerSelf)
self.injected.append(rv)
return rv
_wrapped.func_name = self.originalAttach.func_name
rpc.Consumer.attach_to_twisted = _wrapped
class BaseTestCase(TrialTestCase):
# TODO(jaypipes): Can this be moved into the TrialTestCase class?
"""Base test case class for all unit tests."""
def setUp(self): # pylint: disable-msg=C0103
"""Base test case class for all unit tests.
DEPRECATED: This is being removed once Tornado is gone, use TrialTestCase.
"""
def setUp(self):
"""Run before each test method to initialize test environment"""
super(BaseTestCase, self).setUp()
# TODO(termie): we could possibly keep a more global registry of
# the injected listeners... this is fine for now though
self.injected = []
self.ioloop = ioloop.IOLoop.instance()
self._waiting = None
self._done_waiting = False
self._timed_out = False
def tearDown(self):# pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
super(BaseTestCase, self).tearDown()
for x in self.injected:
x.stop()
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
def _wait_for_test(self, timeout=60):
""" Push the ioloop along to wait for our test to complete. """
self._waiting = self.ioloop.add_timeout(time.time() + timeout,
self._timeout)
def _wait():
"""Wrapped wait function. Called on timeout."""
if self._timed_out:
self.fail('test timed out')
@@ -143,7 +203,7 @@ class BaseTestCase(TrialTestCase):
if self._waiting:
try:
self.ioloop.remove_timeout(self._waiting)
except Exception: # pylint: disable-msg=W0703
except Exception: # pylint: disable-msg=W0703
# TODO(jaypipes): This produces a pylint warning. Should
# we really be catching Exception and then passing here?
pass
@@ -164,9 +224,11 @@ class BaseTestCase(TrialTestCase):
Example (callback chain, ugly):
d = self.compute.terminate_instance(instance_id) # a Deferred instance
# A deferred instance
d = self.compute.terminate_instance(instance_id)
def _describe(_):
d_desc = self.compute.describe_instances() # another Deferred instance
# Another deferred instance
d_desc = self.compute.describe_instances()
return d_desc
def _checkDescribe(rv):
self.assertEqual(rv, [])

View File

@@ -18,66 +18,60 @@
import unittest
import logging
import webob
from nova import context
from nova import exception
from nova import flags
from nova import test
from nova.api import ec2
from nova.auth import manager
from nova.auth import rbac
FLAGS = flags.FLAGS
class Context(object):
pass
class AccessTestCase(test.BaseTestCase):
class AccessTestCase(test.TrialTestCase):
def setUp(self):
super(AccessTestCase, self).setUp()
FLAGS.connection_type = 'fake'
FLAGS.fake_storage = True
um = manager.AuthManager()
self.context = context.get_admin_context()
# Make test users
try:
self.testadmin = um.create_user('testadmin')
except Exception, err:
logging.error(str(err))
try:
self.testpmsys = um.create_user('testpmsys')
except: pass
try:
self.testnet = um.create_user('testnet')
except: pass
try:
self.testsys = um.create_user('testsys')
except: pass
self.testadmin = um.create_user('testadmin')
self.testpmsys = um.create_user('testpmsys')
self.testnet = um.create_user('testnet')
self.testsys = um.create_user('testsys')
# Assign some rules
try:
um.add_role('testadmin', 'cloudadmin')
except: pass
try:
um.add_role('testpmsys', 'sysadmin')
except: pass
try:
um.add_role('testnet', 'netadmin')
except: pass
try:
um.add_role('testsys', 'sysadmin')
except: pass
um.add_role('testadmin', 'cloudadmin')
um.add_role('testpmsys', 'sysadmin')
um.add_role('testnet', 'netadmin')
um.add_role('testsys', 'sysadmin')
# Make a test project
try:
self.project = um.create_project('testproj', 'testpmsys', 'a test project', ['testpmsys', 'testnet', 'testsys'])
except: pass
try:
self.project.add_role(self.testnet, 'netadmin')
except: pass
try:
self.project.add_role(self.testsys, 'sysadmin')
except: pass
self.context = Context()
self.context.project = self.project
self.project = um.create_project('testproj',
'testpmsys',
'a test project',
['testpmsys', 'testnet', 'testsys'])
self.project.add_role(self.testnet, 'netadmin')
self.project.add_role(self.testsys, 'sysadmin')
#user is set in each test
def noopWSGIApp(environ, start_response):
start_response('200 OK', [])
return ['']
self.mw = ec2.Authorizer(noopWSGIApp)
self.mw.action_roles = {'str': {
'_allow_all': ['all'],
'_allow_none': [],
'_allow_project_manager': ['projectmanager'],
'_allow_sys_and_net': ['sysadmin', 'netadmin'],
'_allow_sysadmin': ['sysadmin']}}
def tearDown(self):
um = manager.AuthManager()
# Delete the test project
@@ -89,76 +83,44 @@ class AccessTestCase(test.BaseTestCase):
um.delete_user('testsys')
super(AccessTestCase, self).tearDown()
def response_status(self, user, methodName):
ctxt = context.RequestContext(user, self.project)
environ = {'ec2.context': ctxt,
'ec2.controller': 'some string',
'ec2.action': methodName}
req = webob.Request.blank('/', environ)
resp = req.get_response(self.mw)
return resp.status_int
def shouldAllow(self, user, methodName):
self.assertEqual(200, self.response_status(user, methodName))
def shouldDeny(self, user, methodName):
self.assertEqual(401, self.response_status(user, methodName))
def test_001_allow_all(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testpmsys
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testnet
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testsys
self.assertTrue(self._allow_all(self.context))
users = [self.testadmin, self.testpmsys, self.testnet, self.testsys]
for user in users:
self.shouldAllow(user, '_allow_all')
def test_002_allow_none(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_none(self.context))
self.context.user = self.testpmsys
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.context.user = self.testsys
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.shouldAllow(self.testadmin, '_allow_none')
users = [self.testpmsys, self.testnet, self.testsys]
for user in users:
self.shouldDeny(user, '_allow_none')
def test_003_allow_project_manager(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_project_manager(self.context))
self.context.user = self.testpmsys
self.assertTrue(self._allow_project_manager(self.context))
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
self.context.user = self.testsys
self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
for user in [self.testadmin, self.testpmsys]:
self.shouldAllow(user, '_allow_project_manager')
for user in [self.testnet, self.testsys]:
self.shouldDeny(user, '_allow_project_manager')
def test_004_allow_sys_and_net(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_sys_and_net(self.context))
self.context.user = self.testpmsys # doesn't have the per project sysadmin
self.assertRaises(exception.NotAuthorized, self._allow_sys_and_net, self.context)
self.context.user = self.testnet
self.assertTrue(self._allow_sys_and_net(self.context))
self.context.user = self.testsys
self.assertTrue(self._allow_sys_and_net(self.context))
def test_005_allow_sys_no_pm(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_sys_no_pm(self.context))
self.context.user = self.testpmsys
self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
self.context.user = self.testsys
self.assertTrue(self._allow_sys_no_pm(self.context))
@rbac.allow('all')
def _allow_all(self, context):
return True
@rbac.allow('none')
def _allow_none(self, context):
return True
@rbac.allow('projectmanager')
def _allow_project_manager(self, context):
return True
@rbac.allow('sysadmin', 'netadmin')
def _allow_sys_and_net(self, context):
return True
@rbac.allow('sysadmin')
@rbac.deny('projectmanager')
def _allow_sys_no_pm(self, context):
return True
for user in [self.testadmin, self.testnet, self.testsys]:
self.shouldAllow(user, '_allow_sys_and_net')
# denied because it doesn't have the per project sysadmin
for user in [self.testpmsys]:
self.shouldDeny(user, '_allow_sys_and_net')
if __name__ == "__main__":
# TODO: Implement use_fake as an option

View File

@@ -28,16 +28,17 @@ CLC_IP = '127.0.0.1'
CLC_PORT = 8773
REGION = 'test'
def get_connection():
return boto.connect_ec2 (
return boto.connect_ec2(
aws_access_key_id=ACCESS_KEY,
aws_secret_access_key=SECRET_KEY,
is_secure=False,
region=RegionInfo(None, REGION, CLC_IP),
port=CLC_PORT,
path='/services/Cloud',
debug=99
)
debug=99)
class APIIntegrationTests(unittest.TestCase):
def test_001_get_all_images(self):
@@ -51,4 +52,3 @@ if __name__ == '__main__':
#print conn.get_all_key_pairs()
#print conn.create_key_pair
#print conn.create_security_group('name', 'description')

View File

@@ -23,60 +23,19 @@ from boto.ec2 import regioninfo
import httplib
import random
import StringIO
from tornado import httpserver
from twisted.internet import defer
import webob
from nova import context
from nova import flags
from nova import test
from nova import api
from nova.api.ec2 import cloud
from nova.api.ec2 import apirequest
from nova.auth import manager
from nova.endpoint import api
from nova.endpoint import cloud
FLAGS = flags.FLAGS
# NOTE(termie): These are a bunch of helper methods and classes to short
# circuit boto calls and feed them into our tornado handlers,
# it's pretty damn circuitous so apologies if you have to fix
# a bug in it
# NOTE(jaypipes) The pylint disables here are for R0913 (too many args) which
# isn't controllable since boto's HTTPRequest needs that many
# args, and for the version-differentiated import of tornado's
# httputil.
# NOTE(jaypipes): The disable-msg=E1101 and E1103 below is because pylint is
# unable to introspect the deferred's return value properly
def boto_to_tornado(method, path, headers, data, # pylint: disable-msg=R0913
host, connection=None):
""" translate boto requests into tornado requests
connection should be a FakeTornadoHttpConnection instance
"""
try:
headers = httpserver.HTTPHeaders()
except AttributeError:
from tornado import httputil # pylint: disable-msg=E0611
headers = httputil.HTTPHeaders()
for k, v in headers.iteritems():
headers[k] = v
req = httpserver.HTTPRequest(method=method,
uri=path,
headers=headers,
body=data,
host=host,
remote_ip='127.0.0.1',
connection=connection)
return req
def raw_to_httpresponse(response_string):
"""translate a raw tornado http response into an httplib.HTTPResponse"""
sock = FakeHttplibSocket(response_string)
resp = httplib.HTTPResponse(sock)
resp.begin()
return resp
FLAGS.FAKE_subdomain = 'ec2'
class FakeHttplibSocket(object):
@@ -89,116 +48,82 @@ class FakeHttplibSocket(object):
return self._buffer
class FakeTornadoStream(object):
"""a fake stream to satisfy tornado's assumptions, trivial"""
def set_close_callback(self, _func):
"""Dummy callback for stream"""
pass
class FakeTornadoConnection(object):
"""A fake connection object for tornado to pass to its handlers
web requests are expected to write to this as they get data and call
finish when they are done with the request, we buffer the writes and
kick off a callback when it is done so that we can feed the result back
into boto.
"""
def __init__(self, deferred):
self._deferred = deferred
self._buffer = StringIO.StringIO()
def write(self, chunk):
"""Writes a chunk of data to the internal buffer"""
self._buffer.write(chunk)
def finish(self):
"""Finalizes the connection and returns the buffered data via the
deferred callback.
"""
data = self._buffer.getvalue()
self._deferred.callback(data)
xheaders = None
@property
def stream(self): # pylint: disable-msg=R0201
"""Required property for interfacing with tornado"""
return FakeTornadoStream()
class FakeHttplibConnection(object):
"""A fake httplib.HTTPConnection for boto to use
requests made via this connection actually get translated and routed into
our tornado app, we then wait for the response and turn it back into
our WSGI app, we then wait for the response and turn it back into
the httplib.HTTPResponse that boto expects.
"""
def __init__(self, app, host, is_secure=False):
self.app = app
self.host = host
self.deferred = defer.Deferred()
def request(self, method, path, data, headers):
"""Creates a connection to a fake tornado and sets
up a deferred request with the supplied data and
headers"""
conn = FakeTornadoConnection(self.deferred)
request = boto_to_tornado(connection=conn,
method=method,
path=path,
headers=headers,
data=data,
host=self.host)
self.app(request)
self.deferred.addCallback(raw_to_httpresponse)
req = webob.Request.blank(path)
req.method = method
req.body = data
req.headers = headers
req.headers['Accept'] = 'text/html'
req.host = self.host
# Call the WSGI app, get the HTTP response
resp = str(req.get_response(self.app))
# For some reason, the response doesn't have "HTTP/1.0 " prepended; I
# guess that's a function the web server usually provides.
resp = "HTTP/1.0 %s" % resp
sock = FakeHttplibSocket(resp)
self.http_response = httplib.HTTPResponse(sock)
self.http_response.begin()
def getresponse(self):
"""A bit of deferred magic for catching the response
from the previously deferred request"""
@defer.inlineCallbacks
def _waiter():
"""Callback that simply yields the deferred's
return value."""
result = yield self.deferred
defer.returnValue(result)
d = _waiter()
# NOTE(termie): defer.returnValue above should ensure that
# this deferred has already been called by the time
# we get here, we are going to cheat and return
# the result of the callback
return d.result # pylint: disable-msg=E1101
return self.http_response
def close(self):
"""Required for compatibility with boto/tornado"""
pass
class XmlConversionTestCase(test.BaseTestCase):
"""Unit test api xml conversion"""
def test_number_conversion(self):
conv = apirequest._try_convert
self.assertEqual(conv('None'), None)
self.assertEqual(conv('True'), True)
self.assertEqual(conv('False'), False)
self.assertEqual(conv('0'), 0)
self.assertEqual(conv('42'), 42)
self.assertEqual(conv('3.14'), 3.14)
self.assertEqual(conv('-57.12'), -57.12)
self.assertEqual(conv('0x57'), 0x57)
self.assertEqual(conv('-0x57'), -0x57)
self.assertEqual(conv('-'), '-')
self.assertEqual(conv('-0'), 0)
class ApiEc2TestCase(test.BaseTestCase):
"""Unit test for the cloud controller on an EC2 API"""
def setUp(self): # pylint: disable-msg=C0103,C0111
def setUp(self):
super(ApiEc2TestCase, self).setUp()
self.manager = manager.AuthManager()
self.cloud = cloud.CloudController()
self.host = '127.0.0.1'
self.app = api.APIServerApplication({'Cloud': self.cloud})
self.app = api.API()
def expect_http(self, host=None, is_secure=False):
"""Returns a new EC2 connection"""
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
is_secure=False,
region=regioninfo.RegionInfo(None, 'test', self.host),
port=FLAGS.cc_port,
port=8773,
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
def expect_http(self, host=None, is_secure=False):
"""Returns a new EC2 connection"""
http = FakeHttplibConnection(
self.app, '%s:%d' % (self.host, FLAGS.cc_port), False)
self.app, '%s:8773' % (self.host), False)
# pylint: disable-msg=E1103
self.ec2.new_http_connection(host, is_secure).AndReturn(http)
return http
@@ -214,7 +139,6 @@ class ApiEc2TestCase(test.BaseTestCase):
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_get_all_key_pairs(self):
"""Test that, after creating a user and project and generating
a key pair, that the API call to list key pairs works properly"""
@@ -224,10 +148,195 @@ class ApiEc2TestCase(test.BaseTestCase):
for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
self.manager.generate_key_pair(user.id, keyname)
# NOTE(vish): create depends on pool, so call helper directly
cloud._gen_key(context.get_admin_context(), user.id, keyname)
rv = self.ec2.get_all_key_pairs()
results = [k for k in rv if k.name == keyname]
self.assertEquals(len(results), 1)
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_get_all_security_groups(self):
"""Test that we can retrieve security groups"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
rv = self.ec2.get_all_security_groups()
self.assertEquals(len(rv), 1)
self.assertEquals(rv[0].name, 'default')
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_create_delete_security_group(self):
"""Test that we can create a security group"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8)))
self.ec2.create_security_group(security_group_name, 'test group')
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
self.assertEquals(len(rv), 2)
self.assertTrue(security_group_name in [group.name for group in rv])
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_authorize_revoke_security_group_cidr(self):
"""
Test that we can add and remove CIDR based rules
to a security group
"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd")
for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name,
'test group')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.authorize('tcp', 80, 81, '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
# I don't bother checkng that we actually find it here,
# because the create/delete unit test further up should
# be good enough for that.
for group in rv:
if group.name == security_group_name:
self.assertEquals(len(group.rules), 1)
self.assertEquals(int(group.rules[0].from_port), 80)
self.assertEquals(int(group.rules[0].to_port), 81)
self.assertEquals(len(group.rules[0].grants), 1)
self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.revoke('tcp', 80, 81, '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
rv = self.ec2.get_all_security_groups()
self.assertEqual(len(rv), 1)
self.assertEqual(rv[0].name, 'default')
self.manager.delete_project(project)
self.manager.delete_user(user)
return
def test_authorize_revoke_security_group_foreign_group(self):
"""
Test that we can grant and revoke another security group access
to a security group
"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
rand_string = 'sdiuisudfsdcnpaqwertasd'
security_group_name = "".join(random.choice(rand_string)
for x in range(random.randint(4, 8)))
other_security_group_name = "".join(random.choice(rand_string)
for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name,
'test group')
self.expect_http()
self.mox.ReplayAll()
other_group = self.ec2.create_security_group(other_security_group_name,
'some other group')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.authorize(src_group=other_group)
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
# I don't bother checkng that we actually find it here,
# because the create/delete unit test further up should
# be good enough for that.
for group in rv:
if group.name == security_group_name:
self.assertEquals(len(group.rules), 1)
self.assertEquals(len(group.rules[0].grants), 1)
self.assertEquals(str(group.rules[0].grants[0]), '%s-%s' %
(other_security_group_name, 'fake'))
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
for group in rv:
if group.name == security_group_name:
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.revoke(src_group=other_group)
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.manager.delete_project(project)
self.manager.delete_user(user)
return

View File

@@ -17,8 +17,6 @@
# under the License.
import logging
from M2Crypto import BIO
from M2Crypto import RSA
from M2Crypto import X509
import unittest
@@ -26,33 +24,83 @@ from nova import crypto
from nova import flags
from nova import test
from nova.auth import manager
from nova.endpoint import cloud
from nova.api.ec2 import cloud
FLAGS = flags.FLAGS
class AuthTestCase(test.BaseTestCase):
class user_generator(object):
def __init__(self, manager, **user_state):
if 'name' not in user_state:
user_state['name'] = 'test1'
self.manager = manager
self.user = manager.create_user(**user_state)
def __enter__(self):
return self.user
def __exit__(self, value, type, trace):
self.manager.delete_user(self.user)
class project_generator(object):
def __init__(self, manager, **project_state):
if 'name' not in project_state:
project_state['name'] = 'testproj'
if 'manager_user' not in project_state:
project_state['manager_user'] = 'test1'
self.manager = manager
self.project = manager.create_project(**project_state)
def __enter__(self):
return self.project
def __exit__(self, value, type, trace):
self.manager.delete_project(self.project)
class user_and_project_generator(object):
def __init__(self, manager, user_state={}, project_state={}):
self.manager = manager
if 'name' not in user_state:
user_state['name'] = 'test1'
if 'name' not in project_state:
project_state['name'] = 'testproj'
if 'manager_user' not in project_state:
project_state['manager_user'] = 'test1'
self.user = manager.create_user(**user_state)
self.project = manager.create_project(**project_state)
def __enter__(self):
return (self.user, self.project)
def __exit__(self, value, type, trace):
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
class AuthManagerTestCase(object):
def setUp(self):
super(AuthTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
self.manager = manager.AuthManager()
FLAGS.auth_driver = self.auth_driver
super(AuthManagerTestCase, self).setUp()
self.flags(connection_type='fake')
self.manager = manager.AuthManager(new=True)
def test_001_can_create_users(self):
self.manager.create_user('test1', 'access', 'secret')
self.manager.create_user('test2')
def test_create_and_find_user(self):
with user_generator(self.manager):
self.assert_(self.manager.get_user('test1'))
def test_002_can_get_user(self):
user = self.manager.get_user('test1')
def test_003_can_retreive_properties(self):
user = self.manager.get_user('test1')
self.assertEqual('test1', user.id)
self.assertEqual('access', user.access)
self.assertEqual('secret', user.secret)
def test_create_and_find_with_properties(self):
with user_generator(self.manager, name="herbert", secret="classified",
access="private-party"):
u = self.manager.get_user('herbert')
self.assertEqual('herbert', u.id)
self.assertEqual('herbert', u.name)
self.assertEqual('classified', u.secret)
self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? ))
#self.assertTrue(self.manager.authenticate(**boto.generate_url ...? ))
pass
#raise NotImplementedError
@@ -66,156 +114,237 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n')
def test_006_test_key_storage(self):
user = self.manager.get_user('test1')
user.create_key_pair('public', 'key', 'fingerprint')
key = user.get_key_pair('public')
self.assertEqual('key', key.public_key)
self.assertEqual('fingerprint', key.fingerprint)
def test_can_list_users(self):
with user_generator(self.manager):
with user_generator(self.manager, name="test2"):
users = self.manager.get_users()
self.assert_(filter(lambda u: u.id == 'test1', users))
self.assert_(filter(lambda u: u.id == 'test2', users))
self.assert_(not filter(lambda u: u.id == 'test3', users))
def test_007_test_key_generation(self):
user = self.manager.get_user('test1')
private_key, fingerprint = user.generate_key_pair('public2')
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = user.get_key_pair('public2').public_key
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_can_add_and_remove_user_role(self):
with user_generator(self.manager):
self.assertFalse(self.manager.has_role('test1', 'itsec'))
self.manager.add_role('test1', 'itsec')
self.assertTrue(self.manager.has_role('test1', 'itsec'))
self.manager.remove_role('test1', 'itsec')
self.assertFalse(self.manager.has_role('test1', 'itsec'))
def test_008_can_list_key_pairs(self):
keys = self.manager.get_user('test1').get_key_pairs()
self.assertTrue(filter(lambda k: k.name == 'public', keys))
self.assertTrue(filter(lambda k: k.name == 'public2', keys))
def test_can_create_and_get_project(self):
with user_and_project_generator(self.manager) as (u, p):
self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_project('testproj'))
def test_009_can_delete_key_pair(self):
self.manager.get_user('test1').delete_key_pair('public')
keys = self.manager.get_user('test1').get_key_pairs()
self.assertFalse(filter(lambda k: k.name == 'public', keys))
def test_can_list_projects(self):
with user_and_project_generator(self.manager):
with project_generator(self.manager, name="testproj2"):
projects = self.manager.get_projects()
self.assert_(filter(lambda p: p.name == 'testproj', projects))
self.assert_(filter(lambda p: p.name == 'testproj2', projects))
self.assert_(not filter(lambda p: p.name == 'testproj3',
projects))
def test_010_can_list_users(self):
users = self.manager.get_users()
logging.warn(users)
self.assertTrue(filter(lambda u: u.id == 'test1', users))
def test_can_create_and_get_project_with_attributes(self):
with user_generator(self.manager):
with project_generator(self.manager, description='A test project'):
project = self.manager.get_project('testproj')
self.assertEqual('A test project', project.description)
def test_101_can_add_user_role(self):
self.assertFalse(self.manager.has_role('test1', 'itsec'))
self.manager.add_role('test1', 'itsec')
self.assertTrue(self.manager.has_role('test1', 'itsec'))
def test_can_create_project_with_manager(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertEqual('test1', project.project_manager_id)
self.assertTrue(self.manager.is_project_manager(user, project))
def test_199_can_remove_user_role(self):
self.assertTrue(self.manager.has_role('test1', 'itsec'))
self.manager.remove_role('test1', 'itsec')
self.assertFalse(self.manager.has_role('test1', 'itsec'))
def test_create_project_assigns_manager_to_members(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertTrue(self.manager.is_project_member(user, project))
def test_201_can_create_project(self):
project = self.manager.create_project('testproj', 'test1', 'A test project', ['test1'])
self.assertTrue(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
self.assertEqual(project.name, 'testproj')
self.assertEqual(project.description, 'A test project')
self.assertEqual(project.project_manager_id, 'test1')
self.assertTrue(project.has_member('test1'))
def test_no_extra_project_members(self):
with user_generator(self.manager, name='test2') as baduser:
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.is_project_member(baduser,
project))
def test_202_user1_is_project_member(self):
self.assertTrue(self.manager.get_user('test1').is_project_member('testproj'))
def test_no_extra_project_managers(self):
with user_generator(self.manager, name='test2') as baduser:
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.is_project_manager(baduser,
project))
def test_203_user2_is_not_project_member(self):
self.assertFalse(self.manager.get_user('test2').is_project_member('testproj'))
def test_can_add_user_to_project(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.assertTrue(self.manager.is_project_member(user, project))
def test_204_user1_is_project_manager(self):
self.assertTrue(self.manager.get_user('test1').is_project_manager('testproj'))
def test_can_remove_user_from_project(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.assertTrue(self.manager.is_project_member(user, project))
self.manager.remove_from_project(user, project)
project = self.manager.get_project('testproj')
self.assertFalse(self.manager.is_project_member(user, project))
def test_205_user2_is_not_project_manager(self):
self.assertFalse(self.manager.get_user('test2').is_project_manager('testproj'))
def test_can_add_remove_user_with_role(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
# NOTE(todd): after modifying users you must reload project
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.manager.add_role(user, 'developer', project)
self.assertTrue(self.manager.is_project_member(user, project))
self.manager.remove_from_project(user, project)
project = self.manager.get_project('testproj')
self.assertFalse(self.manager.has_role(user, 'developer',
project))
self.assertFalse(self.manager.is_project_member(user, project))
def test_206_can_add_user_to_project(self):
self.manager.add_to_project('test2', 'testproj')
self.assertTrue(self.manager.get_project('testproj').has_member('test2'))
def test_can_generate_x509(self):
# NOTE(todd): this doesn't assert against the auth manager
# so it probably belongs in crypto_unittest
# but I'm leaving it where I found it.
with user_and_project_generator(self.manager) as (user, project):
# NOTE(todd): Should mention why we must setup controller first
# (somebody please clue me in)
cloud_controller = cloud.CloudController()
cloud_controller.setup()
_key, cert_str = self.manager._generate_x509_cert('test1',
'testproj')
logging.debug(cert_str)
def test_207_can_remove_user_from_project(self):
self.manager.remove_from_project('test2', 'testproj')
self.assertFalse(self.manager.get_project('testproj').has_member('test2'))
# Need to verify that it's signed by the right intermediate CA
full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
cloud_cert = crypto.fetch_ca()
logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
signed_cert = X509.load_cert_string(cert_str)
chain_cert = X509.load_cert_string(full_chain)
int_cert = X509.load_cert_string(int_cert)
cloud_cert = X509.load_cert_string(cloud_cert)
self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
if not FLAGS.use_intermediate_ca:
self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
else:
self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
def test_208_can_remove_add_user_with_role(self):
self.manager.add_to_project('test2', 'testproj')
self.manager.add_role('test2', 'developer', 'testproj')
self.manager.remove_from_project('test2', 'testproj')
self.assertFalse(self.manager.has_role('test2', 'developer', 'testproj'))
self.manager.add_to_project('test2', 'testproj')
self.manager.remove_from_project('test2', 'testproj')
def test_adding_role_to_project_is_ignored_unless_added_to_user(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin', project)
# NOTE(todd): it will still show up in get_user_roles(u, project)
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin')
self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
def test_209_can_generate_x509(self):
# MUST HAVE RUN CLOUD SETUP BY NOW
self.cloud = cloud.CloudController()
self.cloud.setup()
_key, cert_str = self.manager._generate_x509_cert('test1', 'testproj')
logging.debug(cert_str)
def test_add_user_role_doesnt_infect_project_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
# Need to verify that it's signed by the right intermediate CA
full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
cloud_cert = crypto.fetch_ca()
logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
signed_cert = X509.load_cert_string(cert_str)
chain_cert = X509.load_cert_string(full_chain)
int_cert = X509.load_cert_string(int_cert)
cloud_cert = X509.load_cert_string(cloud_cert)
self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
def test_can_list_user_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
roles = self.manager.get_user_roles(user)
self.assertTrue('sysadmin' in roles)
self.assertFalse('netadmin' in roles)
if not FLAGS.use_intermediate_ca:
self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
else:
self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
def test_can_list_project_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.manager.add_role(user, 'netadmin', project)
project_roles = self.manager.get_user_roles(user, project)
self.assertTrue('sysadmin' in project_roles)
self.assertTrue('netadmin' in project_roles)
# has role should be false user-level role is missing
self.assertFalse(self.manager.has_role(user, 'netadmin', project))
def test_210_can_add_project_role(self):
project = self.manager.get_project('testproj')
self.assertFalse(project.has_role('test1', 'sysadmin'))
self.manager.add_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
project.add_role('test1', 'sysadmin')
self.assertTrue(project.has_role('test1', 'sysadmin'))
def test_can_remove_user_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
self.manager.remove_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin'))
def test_211_can_list_project_roles(self):
project = self.manager.get_project('testproj')
user = self.manager.get_user('test1')
self.manager.add_role(user, 'netadmin', project)
roles = self.manager.get_user_roles(user)
self.assertTrue('sysadmin' in roles)
self.assertFalse('netadmin' in roles)
project_roles = self.manager.get_user_roles(user, project)
self.assertTrue('sysadmin' in project_roles)
self.assertTrue('netadmin' in project_roles)
# has role should be false because global role is missing
self.assertFalse(self.manager.has_role(user, 'netadmin', project))
def test_removing_user_role_hides_it_from_project(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
self.manager.remove_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
def test_can_remove_project_role_but_keep_user_role(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
self.manager.remove_role(user, 'sysadmin', project)
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
def test_212_can_remove_project_role(self):
project = self.manager.get_project('testproj')
self.assertTrue(project.has_role('test1', 'sysadmin'))
project.remove_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
self.manager.remove_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
def test_can_retrieve_project_by_user(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertEqual(1, len(self.manager.get_projects('test1')))
def test_214_can_retrieve_project_by_user(self):
project = self.manager.create_project('testproj2', 'test2', 'Another test project', ['test2'])
self.assert_(len(self.manager.get_projects()) > 1)
self.assertEqual(len(self.manager.get_projects('test2')), 1)
def test_can_modify_project(self):
with user_and_project_generator(self.manager):
with user_generator(self.manager, name='test2'):
self.manager.modify_project('testproj', 'test2', 'new desc')
project = self.manager.get_project('testproj')
self.assertEqual('test2', project.project_manager_id)
self.assertEqual('new desc', project.description)
def test_299_can_delete_project(self):
self.manager.delete_project('testproj')
self.assertFalse(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
self.manager.delete_project('testproj2')
def test_can_delete_project(self):
with user_generator(self.manager):
self.manager.create_project('testproj', 'test1')
self.assert_(self.manager.get_project('testproj'))
self.manager.delete_project('testproj')
projectlist = self.manager.get_projects()
self.assert_(not filter(lambda p: p.name == 'testproj',
projectlist))
def test_999_can_delete_users(self):
def test_can_delete_user(self):
self.manager.create_user('test1')
self.assert_(self.manager.get_user('test1'))
self.manager.delete_user('test1')
users = self.manager.get_users()
self.assertFalse(filter(lambda u: u.id == 'test1', users))
self.manager.delete_user('test2')
self.assertEqual(self.manager.get_user('test2'), None)
userlist = self.manager.get_users()
self.assert_(not filter(lambda u: u.id == 'test1', userlist))
def test_can_modify_users(self):
with user_generator(self.manager):
self.manager.modify_user('test1', 'access', 'secret', True)
user = self.manager.get_user('test1')
self.assertEqual('access', user.access)
self.assertEqual('secret', user.secret)
self.assertTrue(user.is_admin())
class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
def __init__(self, *args, **kwargs):
AuthManagerTestCase.__init__(self)
test.TrialTestCase.__init__(self, *args, **kwargs)
import nova.auth.fakeldap as fakeldap
FLAGS.redis_db = 8
if FLAGS.flush_db:
logging.info("Flushing redis datastore")
try:
r = fakeldap.Redis.instance()
r.flushdb()
except:
self.skip = True
class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.dbdriver.DbDriver'
if __name__ == "__main__":

View File

@@ -16,31 +16,47 @@
# License for the specific language governing permissions and limitations
# under the License.
from base64 import b64decode
import json
import logging
from M2Crypto import BIO
from M2Crypto import RSA
import os
import StringIO
import tempfile
import time
from tornado import ioloop
from eventlet import greenthread
from twisted.internet import defer
import unittest
from xml.etree import ElementTree
from nova import context
from nova import crypto
from nova import db
from nova import flags
from nova import rpc
from nova import test
from nova import utils
from nova.auth import manager
from nova.endpoint import api
from nova.endpoint import cloud
from nova.compute import power_state
from nova.api.ec2 import cloud
from nova.objectstore import image
FLAGS = flags.FLAGS
# Temp dirs for working with image attributes through the cloud controller
# (stole this from objectstore_unittest.py)
OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-')
IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images')
os.makedirs(IMAGES_PATH)
class CloudTestCase(test.BaseTestCase):
class CloudTestCase(test.TrialTestCase):
def setUp(self):
super(CloudTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
self.flags(connection_type='fake', images_path=IMAGES_PATH)
self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG)
@@ -49,33 +65,74 @@ class CloudTestCase(test.BaseTestCase):
self.cloud = cloud.CloudController()
# set up a service
self.compute = utils.import_class(FLAGS.compute_manager)
self.compute = utils.import_object(FLAGS.compute_manager)
self.compute_consumer = rpc.AdapterConsumer(connection=self.conn,
topic=FLAGS.compute_topic,
proxy=self.compute)
self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
topic=FLAGS.compute_topic,
proxy=self.compute)
self.compute_consumer.attach_to_eventlet()
self.network = utils.import_object(FLAGS.network_manager)
self.network_consumer = rpc.AdapterConsumer(connection=self.conn,
topic=FLAGS.network_topic,
proxy=self.network)
self.network_consumer.attach_to_eventlet()
try:
manager.AuthManager().create_user('admin', 'admin', 'admin')
except: pass
admin = manager.AuthManager().get_user('admin')
project = manager.AuthManager().create_project('proj', 'admin', 'proj')
self.context = api.APIRequestContext(handler=None,project=project,user=admin)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('proj', 'admin', 'proj')
self.context = context.RequestContext(user=self.user,
project=self.project)
def tearDown(self):
manager.AuthManager().delete_project('proj')
manager.AuthManager().delete_user('admin')
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
super(CloudTestCase, self).tearDown()
def _create_key(self, name):
    # NOTE(vish): create depends on pool, so just call helper directly
    # Returns the key-pair dict (includes 'private_key') registered
    # under *name* for the current context's user.
    return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self):
if FLAGS.connection_type == 'fake':
logging.debug("Can't test instances without a real virtual env.")
return
instance_id = 'foo'
inst = yield self.compute.run_instance(instance_id)
output = yield self.cloud.get_console_output(self.context, [instance_id])
logging.debug(output)
self.assert_(output)
rv = yield self.compute.terminate_instance(instance_id)
image_id = FLAGS.default_image
instance_type = FLAGS.default_instance_type
max_count = 1
kwargs = {'image_id': image_id,
'instance_type': instance_type,
'max_count': max_count}
rv = yield self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId']
output = yield self.cloud.get_console_output(context=self.context,
instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT')
# TODO(soren): We need this until we can stop polling in the rpc code
# for unit tests.
greenthread.sleep(0.3)
rv = yield self.cloud.terminate_instances(self.context, [instance_id])
def test_key_generation(self):
    """The stored SSH public key must match the generated private key."""
    result = self._create_key('test')
    private_key = result['private_key']
    # The callback is M2Crypto's passphrase prompt; a no-op lambda
    # keeps key loading non-interactive.
    key = RSA.load_key_string(private_key, callback=lambda: None)
    bio = BIO.MemoryBuffer()
    public_key = db.key_pair_get(self.context,
                                 self.context.user.id,
                                 'test')['public_key']
    key.save_pub_key_bio(bio)
    # Convert the PEM public key to OpenSSH format so the two
    # representations can be compared field-by-field.
    converted = crypto.ssl_pub_to_ssh_pub(bio.read())
    # assert key fields are equal
    self.assertEqual(public_key.split(" ")[1].strip(),
                     converted.split(" ")[1].strip())
def test_describe_key_pairs(self):
    """describe_key_pairs must list every key registered for the user."""
    self._create_key('test1')
    self._create_key('test2')
    listing = self.cloud.describe_key_pairs(self.context)
    names = [k['keyName'] for k in listing["keypairsSet"]]
    self.assertTrue('test1' in names)
    self.assertTrue('test2' in names)
def test_delete_key_pair(self):
    """Deleting a freshly created key pair must succeed without raising."""
    self._create_key('test')
    self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self):
if FLAGS.connection_type == 'fake':
@@ -89,25 +146,28 @@ class CloudTestCase(test.BaseTestCase):
'max_count': max_count}
rv = yield self.cloud.run_instances(self.context, **kwargs)
# TODO: check for proper response
instance = rv['reservationSet'][0][rv['reservationSet'][0].keys()[0]][0]
logging.debug("Need to watch instance %s until it's running..." % instance['instance_id'])
instance_id = rv['reservationSet'][0].keys()[0]
instance = rv['reservationSet'][0][instance_id][0]
logging.debug("Need to watch instance %s until it's running..." %
instance['instance_id'])
while True:
rv = yield defer.succeed(time.sleep(1))
info = self.cloud._get_instance(instance['instance_id'])
logging.debug(info['state'])
if info['state'] == node.Instance.RUNNING:
if info['state'] == power_state.RUNNING:
break
self.assert_(rv)
if connection_type != 'fake':
time.sleep(45) # Should use boto for polling here
time.sleep(45) # Should use boto for polling here
for reservations in rv['reservationSet']:
# for res_id in reservations.keys():
# logging.debug(reservations[res_id])
# for instance in reservations[res_id]:
for instance in reservations[reservations.keys()[0]]:
logging.debug("Terminating instance %s" % instance['instance_id'])
rv = yield self.compute.terminate_instance(instance['instance_id'])
# logging.debug(reservations[res_id])
# for instance in reservations[res_id]:
for instance in reservations[reservations.keys()[0]]:
instance_id = instance['instance_id']
logging.debug("Terminating instance %s" % instance_id)
rv = yield self.compute.terminate_instance(instance_id)
def test_instance_update_state(self):
def instance(num):
@@ -126,8 +186,7 @@ class CloudTestCase(test.BaseTestCase):
'groups': ['default'],
'product_codes': None,
'state': 0x01,
'user_data': ''
}
'user_data': ''}
rv = self.cloud._format_describe_instances(self.context)
self.assert_(len(rv['reservationSet']) == 0)
@@ -142,7 +201,9 @@ class CloudTestCase(test.BaseTestCase):
#self.assert_(len(rv['reservationSet'][0]['instances_set']) == 5)
# report 4 nodes each having 1 of the instances
#for i in xrange(4):
# self.cloud.update_state('instances', {('node-%s' % i): {('i-%s' % i): instance(i)}})
# self.cloud.update_state('instances',
# {('node-%s' % i): {('i-%s' % i):
# instance(i)}})
# one instance should be pending still
#self.assert_(len(self.cloud.instances['pending'].keys()) == 1)
@@ -156,3 +217,70 @@ class CloudTestCase(test.BaseTestCase):
#for i in xrange(4):
# data = self.cloud.get_metadata(instance(i)['private_dns_name'])
# self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i)
@staticmethod
def _fake_set_image_description(ctxt, image_id, description):
    """Drive the objectstore image handler directly, faking just
    enough of a twisted request object to satisfy render_POST."""
    from nova.objectstore import handler

    class FakeRequest:
        pass

    fake = FakeRequest()
    fake.context = ctxt
    fake.args = {'image_id': [image_id],
                 'description': [description]}
    handler.ImagesResource().render_POST(fake)
def test_user_editable_image_endpoint(self):
    """Image descriptions set through the endpoint must round-trip."""
    pathdir = os.path.join(FLAGS.images_path, 'ami-testing')
    os.mkdir(pathdir)
    # Seed a minimal on-disk record so the Image class can load it.
    info = {'isPublic': False}
    with open(os.path.join(pathdir, 'info.json'), 'w') as f:
        json.dump(info, f)
    img = image.Image('ami-testing')
    # self.cloud.set_image_description(self.context, 'ami-testing',
    #                                  'Foo Img')
    # NOTE(vish): Above won't work unless we start objectstore or create
    #             a fake version of api/ec2/images.py conn that can
    #             call methods directly instead of going through boto.
    #             for now, just cheat and call the method directly
    self._fake_set_image_description(self.context, 'ami-testing',
                                     'Foo Img')
    self.assertEqual('Foo Img', img.metadata['description'])
    # An empty description must be persisted too, not ignored.
    self._fake_set_image_description(self.context, 'ami-testing', '')
    self.assertEqual('', img.metadata['description'])
def test_update_of_instance_display_fields(self):
    """update_instance must persist a change to display_name."""
    inst = db.instance_create(self.context, {})
    ec2_id = cloud.internal_id_to_ec2_id(inst['internal_id'])
    self.cloud.update_instance(self.context, ec2_id,
                               display_name='c00l 1m4g3')
    refreshed = db.instance_get(self.context, inst['id'])
    self.assertEqual('c00l 1m4g3', refreshed['display_name'])
    db.instance_destroy(self.context, inst['id'])
def test_update_of_instance_wont_update_private_fields(self):
    """Non-display fields such as mac_address must be ignored."""
    inst = db.instance_create(self.context, {})
    self.cloud.update_instance(self.context, inst['id'],
                               mac_address='DE:AD:BE:EF')
    refreshed = db.instance_get(self.context, inst['id'])
    self.assertEqual(None, refreshed['mac_address'])
    db.instance_destroy(self.context, inst['id'])
def test_update_of_volume_display_fields(self):
    """update_volume must persist a change to display_name."""
    vol = db.volume_create(self.context, {})
    self.cloud.update_volume(self.context, vol['id'],
                             display_name='c00l v0lum3')
    refreshed = db.volume_get(self.context, vol['id'])
    self.assertEqual('c00l v0lum3', refreshed['display_name'])
    db.volume_destroy(self.context, vol['id'])
def test_update_of_volume_wont_update_private_fields(self):
    """Non-display fields such as mountpoint must be ignored."""
    vol = db.volume_create(self.context, {})
    self.cloud.update_volume(self.context, vol['id'],
                             mountpoint='/not/here')
    refreshed = db.volume_get(self.context, vol['id'])
    self.assertEqual(None, refreshed['mountpoint'])
    db.volume_destroy(self.context, vol['id'])

View File

@@ -18,10 +18,13 @@
"""
Tests For Compute
"""
import datetime
import logging
from twisted.internet import defer
from nova import context
from nova import db
from nova import exception
from nova import flags
@@ -29,26 +32,26 @@ from nova import test
from nova import utils
from nova.auth import manager
FLAGS = flags.FLAGS
class ComputeTestCase(test.TrialTestCase):
"""Test case for compute"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
logging.getLogger().setLevel(logging.DEBUG)
super(ComputeTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
network_manager='nova.network.manager.FlatManager')
self.compute = utils.import_object(FLAGS.compute_manager)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake')
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = None
self.context = context.get_admin_context()
def tearDown(self): # pylint: disable-msg=C0103
def tearDown(self):
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
super(ComputeTestCase, self).tearDown()
def _create_instance(self):
"""Create a test instance"""
@@ -61,7 +64,7 @@ class ComputeTestCase(test.TrialTestCase):
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
return db.instance_create(self.context, inst)
return db.instance_create(self.context, inst)['id']
@defer.inlineCallbacks
def test_run_terminate(self):
@@ -70,16 +73,35 @@ class ComputeTestCase(test.TrialTestCase):
yield self.compute.run_instance(self.context, instance_id)
instances = db.instance_get_all(None)
instances = db.instance_get_all(context.get_admin_context())
logging.info("Running instances: %s", instances)
self.assertEqual(len(instances), 1)
yield self.compute.terminate_instance(self.context, instance_id)
instances = db.instance_get_all(None)
instances = db.instance_get_all(context.get_admin_context())
logging.info("After terminating instances: %s", instances)
self.assertEqual(len(instances), 0)
@defer.inlineCallbacks
def test_run_terminate_timestamps(self):
    """Make sure timestamps are set for launched and destroyed"""
    instance_id = self._create_instance()
    # Freshly created: neither timestamp may be set yet.
    instance_ref = db.instance_get(self.context, instance_id)
    self.assertEqual(instance_ref['launched_at'], None)
    self.assertEqual(instance_ref['deleted_at'], None)
    launch = datetime.datetime.utcnow()
    yield self.compute.run_instance(self.context, instance_id)
    instance_ref = db.instance_get(self.context, instance_id)
    self.assert_(instance_ref['launched_at'] > launch)
    self.assertEqual(instance_ref['deleted_at'], None)
    terminate = datetime.datetime.utcnow()
    yield self.compute.terminate_instance(self.context, instance_id)
    # NOTE(review): elevated(True) presumably allows reading deleted
    # rows so the terminated instance is still visible — confirm
    # against nova.context.RequestContext.elevated.
    self.context = self.context.elevated(True)
    instance_ref = db.instance_get(self.context, instance_id)
    self.assert_(instance_ref['launched_at'] < terminate)
    self.assert_(instance_ref['deleted_at'] > terminate)
@defer.inlineCallbacks
def test_reboot(self):
"""Ensure instance can be rebooted"""

View File

@@ -20,11 +20,11 @@ from nova import flags
FLAGS = flags.FLAGS
flags.DECLARE('fake_storage', 'nova.volume.manager')
FLAGS.fake_storage = True
flags.DECLARE('volume_driver', 'nova.volume.manager')
FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver'
FLAGS.connection_type = 'fake'
FLAGS.fake_rabbit = True
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('fake_network', 'nova.network.manager')
@@ -37,4 +37,3 @@ FLAGS.num_shelves = 2
FLAGS.blades_per_shelf = 4
FLAGS.verbose = True
FLAGS.sql_connection = 'sqlite:///nova.sqlite'
#FLAGS.sql_connection = 'mysql://root@localhost/test'

View File

@@ -20,8 +20,12 @@ from nova import exception
from nova import flags
from nova import test
FLAGS = flags.FLAGS
flags.DEFINE_string('flags_unittest', 'foo', 'for testing purposes only')
class FlagsTestCase(test.TrialTestCase):
def setUp(self):
super(FlagsTestCase, self).setUp()
self.FLAGS = flags.FlagValues()
@@ -33,7 +37,8 @@ class FlagsTestCase(test.TrialTestCase):
self.assert_('false' not in self.FLAGS)
self.assert_('true' not in self.FLAGS)
flags.DEFINE_string('string', 'default', 'desc', flag_values=self.FLAGS)
flags.DEFINE_string('string', 'default', 'desc',
flag_values=self.FLAGS)
flags.DEFINE_integer('int', 1, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('false', False, 'desc', flag_values=self.FLAGS)
flags.DEFINE_bool('true', True, 'desc', flag_values=self.FLAGS)
@@ -85,3 +90,13 @@ class FlagsTestCase(test.TrialTestCase):
self.assert_('runtime_answer' in self.global_FLAGS)
self.assertEqual(self.global_FLAGS.runtime_answer, 60)
def test_flag_leak_left(self):
    # Deliberately mutates the global flag without restoring it.
    # Together with its identical twin below, this appears intended to
    # verify the test framework resets global FLAGS between tests
    # regardless of which of the two runs first.
    self.assertEqual(FLAGS.flags_unittest, 'foo')
    FLAGS.flags_unittest = 'bar'
    self.assertEqual(FLAGS.flags_unittest, 'bar')
def test_flag_leak_right(self):
    # Identical twin of test_flag_leak_left: whichever order the two
    # run in, the second one fails if the first test's mutation of the
    # global flag leaked through.
    self.assertEqual(FLAGS.flags_unittest, 'foo')
    FLAGS.flags_unittest = 'bar'
    self.assertEqual(FLAGS.flags_unittest, 'bar')

View File

@@ -22,6 +22,7 @@ import IPy
import os
import logging
from nova import context
from nova import db
from nova import exception
from nova import flags
@@ -34,12 +35,11 @@ FLAGS = flags.FLAGS
class NetworkTestCase(test.TrialTestCase):
"""Test cases for network code"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
super(NetworkTestCase, self).setUp()
# NOTE(vish): if you change these flags, make sure to change the
# flags in the corresponding section in nova-dhcpbridge
self.flags(connection_type='fake',
fake_storage=True,
fake_network=True,
auth_driver='nova.auth.ldapdriver.FakeLdapDriver',
network_size=16,
@@ -49,68 +49,91 @@ class NetworkTestCase(test.TrialTestCase):
self.user = self.manager.create_user('netuser', 'netuser', 'netuser')
self.projects = []
self.network = utils.import_object(FLAGS.network_manager)
self.context = None
self.context = context.RequestContext(project=None, user=self.user)
for i in range(5):
name = 'project%s' % i
self.projects.append(self.manager.create_project(name,
'netuser',
name))
project = self.manager.create_project(name, 'netuser', name)
self.projects.append(project)
# create the necessary network data for the project
self.network.set_network_host(self.context, self.projects[i].id)
instance_id = db.instance_create(None,
{'mac_address': utils.generate_mac()})
self.instance_id = instance_id
instance_id = db.instance_create(None,
{'mac_address': utils.generate_mac()})
self.instance2_id = instance_id
user_context = context.RequestContext(project=self.projects[i],
user=self.user)
network_ref = self.network.get_network(user_context)
self.network.set_network_host(context.get_admin_context(),
network_ref['id'])
instance_ref = self._create_instance(0)
self.instance_id = instance_ref['id']
instance_ref = self._create_instance(1)
self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
def tearDown(self):
super(NetworkTestCase, self).tearDown()
# TODO(termie): this should really be instantiating clean datastores
# in between runs, one failure kills all the tests
db.instance_destroy(None, self.instance_id)
db.instance_destroy(None, self.instance2_id)
db.instance_destroy(context.get_admin_context(), self.instance_id)
db.instance_destroy(context.get_admin_context(), self.instance2_id)
for project in self.projects:
self.manager.delete_project(project)
self.manager.delete_user(self.user)
def _create_instance(self, project_num, mac=None):
    """Insert an instance row owned by project number *project_num*,
    generating a MAC address when none is supplied."""
    project = self.projects[project_num]
    self.context._project = project
    self.context.project_id = project.id
    values = {'project_id': project.id,
              'mac_address': mac if mac else utils.generate_mac()}
    return db.instance_create(self.context, values)
def _create_address(self, project_num, instance_id=None):
"""Create an address in given project num"""
net = db.project_get_network(None, self.projects[project_num].id)
address = db.fixed_ip_allocate(None, net['id'])
if instance_id is None:
instance_id = self.instance_id
db.fixed_ip_instance_associate(None, address, instance_id)
return address
self.context._project = self.projects[project_num]
self.context.project_id = self.projects[project_num].id
return self.network.allocate_fixed_ip(self.context, instance_id)
def _deallocate_address(self, project_num, address):
    """Release *address* while acting as the given project."""
    project = self.projects[project_num]
    self.context._project = project
    self.context.project_id = project.id
    self.network.deallocate_fixed_ip(self.context, address)
def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
# TODO(vish): better way of adding floating ips
pubnet = IPy.IP(flags.FLAGS.public_range)
ip_str = str(pubnet[0])
self.context._project = self.projects[0]
self.context.project_id = self.projects[0].id
pubnet = IPy.IP(flags.FLAGS.floating_range)
address = str(pubnet[0])
try:
db.floating_ip_get_by_address(None, ip_str)
db.floating_ip_get_by_address(context.get_admin_context(), address)
except exception.NotFound:
db.floating_ip_create(None, ip_str, FLAGS.host)
db.floating_ip_create(context.get_admin_context(),
{'address': address,
'host': FLAGS.host})
float_addr = self.network.allocate_floating_ip(self.context,
self.projects[0].id)
fix_addr = self._create_address(0)
lease_ip(fix_addr)
self.assertEqual(float_addr, str(pubnet[0]))
self.network.associate_floating_ip(self.context, float_addr, fix_addr)
address = db.instance_get_floating_address(None, self.instance_id)
address = db.instance_get_floating_address(context.get_admin_context(),
self.instance_id)
self.assertEqual(address, float_addr)
self.network.disassociate_floating_ip(self.context, float_addr)
address = db.instance_get_floating_address(None, self.instance_id)
address = db.instance_get_floating_address(context.get_admin_context(),
self.instance_id)
self.assertEqual(address, None)
self.network.deallocate_floating_ip(self.context, float_addr)
db.fixed_ip_deallocate(None, fix_addr)
self.network.deallocate_fixed_ip(self.context, fix_addr)
release_ip(fix_addr)
def test_allocate_deallocate_fixed_ip(self):
"""Makes sure that we can allocate and deallocate a fixed ip"""
address = self._create_address(0)
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
lease_ip(address)
db.fixed_ip_deallocate(None, address)
self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
@@ -131,14 +154,14 @@ class NetworkTestCase(test.TrialTestCase):
lease_ip(address)
lease_ip(address2)
db.fixed_ip_deallocate(None, address)
self._deallocate_address(0, address)
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
db.fixed_ip_deallocate(None, address2)
self._deallocate_address(1, address2)
release_ip(address2)
self.assertFalse(is_allocated_in_project(address2,
self.projects[1].id))
@@ -147,27 +170,41 @@ class NetworkTestCase(test.TrialTestCase):
"""Makes sure that private ips don't overlap"""
first = self._create_address(0)
lease_ip(first)
instance_ids = []
for i in range(1, 5):
address = self._create_address(i)
address2 = self._create_address(i)
address3 = self._create_address(i)
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address = self._create_address(i, instance_ref['id'])
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address2 = self._create_address(i, instance_ref['id'])
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address3 = self._create_address(i, instance_ref['id'])
lease_ip(address)
lease_ip(address2)
lease_ip(address3)
self.context._project = self.projects[i]
self.context.project_id = self.projects[i].id
self.assertFalse(is_allocated_in_project(address,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address2,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address3,
self.projects[0].id))
db.fixed_ip_deallocate(None, address)
db.fixed_ip_deallocate(None, address2)
db.fixed_ip_deallocate(None, address3)
self.network.deallocate_fixed_ip(self.context, address)
self.network.deallocate_fixed_ip(self.context, address2)
self.network.deallocate_fixed_ip(self.context, address3)
release_ip(address)
release_ip(address2)
release_ip(address3)
for instance_id in instance_ids:
db.instance_destroy(context.get_admin_context(), instance_id)
self.context._project = self.projects[0]
self.context.project_id = self.projects[0].id
self.network.deallocate_fixed_ip(self.context, first)
self._deallocate_address(0, first)
release_ip(first)
db.fixed_ip_deallocate(None, first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
@@ -179,14 +216,18 @@ class NetworkTestCase(test.TrialTestCase):
def test_too_many_networks(self):
"""Ensure error is raised if we run out of networks"""
projects = []
networks_left = FLAGS.num_networks - db.network_count(None)
networks_left = (FLAGS.num_networks -
db.network_count(context.get_admin_context()))
for i in range(networks_left):
project = self.manager.create_project('many%s' % i, self.user)
projects.append(project)
db.project_get_network(context.get_admin_context(), project.id)
project = self.manager.create_project('last', self.user)
projects.append(project)
self.assertRaises(db.NoMoreNetworks,
self.manager.create_project,
'boom',
self.user)
db.project_get_network,
context.get_admin_context(),
project.id)
for project in projects:
self.manager.delete_project(project)
@@ -194,12 +235,14 @@ class NetworkTestCase(test.TrialTestCase):
"""Makes sure that ip addresses that are deallocated get reused"""
address = self._create_address(0)
lease_ip(address)
db.fixed_ip_deallocate(None, address)
self.network.deallocate_fixed_ip(self.context, address)
release_ip(address)
address2 = self._create_address(0)
self.assertEqual(address, address2)
db.fixed_ip_deallocate(None, address2)
lease_ip(address)
self.network.deallocate_fixed_ip(self.context, address2)
release_ip(address)
def test_available_ips(self):
"""Make sure the number of available ips for the network is correct
@@ -212,45 +255,57 @@ class NetworkTestCase(test.TrialTestCase):
There are ips reserved at the bottom and top of the range.
services (network, gateway, CloudPipe, broadcast)
"""
network = db.project_get_network(None, self.projects[0].id)
network = db.project_get_network(context.get_admin_context(),
self.projects[0].id)
net_size = flags.FLAGS.network_size
total_ips = (db.network_count_available_ips(None, network['id']) +
db.network_count_reserved_ips(None, network['id']) +
db.network_count_allocated_ips(None, network['id']))
admin_context = context.get_admin_context()
total_ips = (db.network_count_available_ips(admin_context,
network['id']) +
db.network_count_reserved_ips(admin_context,
network['id']) +
db.network_count_allocated_ips(admin_context,
network['id']))
self.assertEqual(total_ips, net_size)
def test_too_many_addresses(self):
"""Test for a NoMoreAddresses exception when all fixed ips are used.
"""
network = db.project_get_network(None, self.projects[0].id)
num_available_ips = db.network_count_available_ips(None,
admin_context = context.get_admin_context()
network = db.project_get_network(admin_context, self.projects[0].id)
num_available_ips = db.network_count_available_ips(admin_context,
network['id'])
addresses = []
instance_ids = []
for i in range(num_available_ips):
address = self._create_address(0)
instance_ref = self._create_instance(0)
instance_ids.append(instance_ref['id'])
address = self._create_address(0, instance_ref['id'])
addresses.append(address)
lease_ip(address)
self.assertEqual(db.network_count_available_ips(None,
network['id']), 0)
ip_count = db.network_count_available_ips(context.get_admin_context(),
network['id'])
self.assertEqual(ip_count, 0)
self.assertRaises(db.NoMoreAddresses,
db.fixed_ip_allocate,
None,
network['id'])
self.network.allocate_fixed_ip,
self.context,
'foo')
for i in range(len(addresses)):
db.fixed_ip_deallocate(None, addresses[i])
for i in range(num_available_ips):
self.network.deallocate_fixed_ip(self.context, addresses[i])
release_ip(addresses[i])
self.assertEqual(db.network_count_available_ips(None,
network['id']),
num_available_ips)
db.instance_destroy(context.get_admin_context(), instance_ids[i])
ip_count = db.network_count_available_ips(context.get_admin_context(),
network['id'])
self.assertEqual(ip_count, num_available_ips)
def is_allocated_in_project(address, project_id):
"""Returns true if address is in specified project"""
project_net = db.project_get_network(None, project_id)
network = db.fixed_ip_get_network(None, address)
instance = db.fixed_ip_get_instance(None, address)
project_net = db.project_get_network(context.get_admin_context(),
project_id)
network = db.fixed_ip_get_network(context.get_admin_context(), address)
instance = db.fixed_ip_get_instance(context.get_admin_context(), address)
# instance exists until release
return instance is not None and network['id'] == project_net['id']
@@ -262,8 +317,13 @@ def binpath(script):
def lease_ip(private_ip):
"""Run add command on dhcpbridge"""
network_ref = db.fixed_ip_get_network(None, private_ip)
cmd = "%s add fake %s fake" % (binpath('nova-dhcpbridge'), private_ip)
network_ref = db.fixed_ip_get_network(context.get_admin_context(),
private_ip)
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(),
private_ip)
cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'],
private_ip)
env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}
@@ -273,8 +333,13 @@ def lease_ip(private_ip):
def release_ip(private_ip):
"""Run del command on dhcpbridge"""
network_ref = db.fixed_ip_get_network(None, private_ip)
cmd = "%s del fake %s fake" % (binpath('nova-dhcpbridge'), private_ip)
network_ref = db.fixed_ip_get_network(context.get_admin_context(),
private_ip)
instance_ref = db.fixed_ip_get_instance(context.get_admin_context(),
private_ip)
cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'],
private_ip)
env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}

View File

@@ -32,6 +32,7 @@ from boto.s3.connection import S3Connection, OrdinaryCallingFormat
from twisted.internet import reactor, threads, defer
from twisted.web import http, server
from nova import context
from nova import flags
from nova import objectstore
from nova import test
@@ -53,10 +54,10 @@ os.makedirs(os.path.join(OSS_TEMPDIR, 'images'))
os.makedirs(os.path.join(OSS_TEMPDIR, 'buckets'))
class ObjectStoreTestCase(test.BaseTestCase):
class ObjectStoreTestCase(test.TrialTestCase):
"""Test objectstore API directly."""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
"""Setup users and projects."""
super(ObjectStoreTestCase, self).setUp()
self.flags(buckets_path=os.path.join(OSS_TEMPDIR, 'buckets'),
@@ -70,15 +71,9 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.auth_manager.create_user('admin_user', admin=True)
self.auth_manager.create_project('proj1', 'user1', 'a proj', ['user1'])
self.auth_manager.create_project('proj2', 'user2', 'a proj', ['user2'])
self.context = context.RequestContext('user1', 'proj1')
class Context(object):
"""Dummy context for running tests."""
user = None
project = None
self.context = Context()
def tearDown(self): # pylint: disable-msg=C0103
def tearDown(self):
"""Tear down users and projects."""
self.auth_manager.delete_project('proj1')
self.auth_manager.delete_project('proj2')
@@ -89,8 +84,6 @@ class ObjectStoreTestCase(test.BaseTestCase):
def test_buckets(self):
"""Test the bucket API."""
self.context.user = self.auth_manager.get_user('user1')
self.context.project = self.auth_manager.get_project('proj1')
objectstore.bucket.Bucket.create('new_bucket', self.context)
bucket = objectstore.bucket.Bucket('new_bucket')
@@ -98,14 +91,12 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.assert_(bucket.is_authorized(self.context))
# another user is not authorized
self.context.user = self.auth_manager.get_user('user2')
self.context.project = self.auth_manager.get_project('proj2')
self.assertFalse(bucket.is_authorized(self.context))
context2 = context.RequestContext('user2', 'proj2')
self.assertFalse(bucket.is_authorized(context2))
# admin is authorized to use bucket
self.context.user = self.auth_manager.get_user('admin_user')
self.context.project = None
self.assertTrue(bucket.is_authorized(self.context))
admin_context = context.RequestContext('admin_user', None)
self.assertTrue(bucket.is_authorized(admin_context))
# new buckets are empty
self.assertTrue(bucket.list_keys()['Contents'] == [])
@@ -133,13 +124,20 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.assertRaises(NotFound, objectstore.bucket.Bucket, 'new_bucket')
def test_images(self):
    # Bundle whose manifest declares a kernel and ramdisk; the second
    # argument tells do_test_images to expect both in the metadata.
    self.do_test_images('1mb.manifest.xml', True,
                        'image_bucket1', 'i-testing1')
def test_images_no_kernel_or_ramdisk(self):
    # Bundle whose manifest omits kernel/ramdisk; the defaults must
    # not be silently embedded in the registered image's metadata.
    self.do_test_images('1mb.no_kernel_or_ramdisk.manifest.xml',
                        False, 'image_bucket2', 'i-testing2')
def do_test_images(self, manifest_file, expect_kernel_and_ramdisk,
image_bucket, image_name):
"Test the image API."
self.context.user = self.auth_manager.get_user('user1')
self.context.project = self.auth_manager.get_project('proj1')
# create a bucket for our bundle
objectstore.bucket.Bucket.create('image_bucket', self.context)
bucket = objectstore.bucket.Bucket('image_bucket')
objectstore.bucket.Bucket.create(image_bucket, self.context)
bucket = objectstore.bucket.Bucket(image_bucket)
# upload an image manifest/parts
bundle_path = os.path.join(os.path.dirname(__file__), 'bundle')
@@ -147,28 +145,43 @@ class ObjectStoreTestCase(test.BaseTestCase):
bucket[os.path.basename(path)] = open(path, 'rb').read()
# register an image
image.Image.register_aws_image('i-testing',
'image_bucket/1mb.manifest.xml',
image.Image.register_aws_image(image_name,
'%s/%s' % (image_bucket, manifest_file),
self.context)
# verify image
my_img = image.Image('i-testing')
my_img = image.Image(image_name)
result_image_file = os.path.join(my_img.path, 'image')
self.assertEqual(os.stat(result_image_file).st_size, 1048576)
sha = hashlib.sha1(open(result_image_file).read()).hexdigest()
self.assertEqual(sha, '3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3')
if expect_kernel_and_ramdisk:
# Verify the default kernel and ramdisk are set
self.assertEqual(my_img.metadata['kernelId'], 'aki-test')
self.assertEqual(my_img.metadata['ramdiskId'], 'ari-test')
else:
# Verify that the default kernel and ramdisk (the one from FLAGS)
# doesn't get embedded in the metadata
self.assertFalse('kernelId' in my_img.metadata)
self.assertFalse('ramdiskId' in my_img.metadata)
# verify image permissions
self.context.user = self.auth_manager.get_user('user2')
self.context.project = self.auth_manager.get_project('proj2')
self.assertFalse(my_img.is_authorized(self.context))
context2 = context.RequestContext('user2', 'proj2')
self.assertFalse(my_img.is_authorized(context2))
# change user-editable fields
my_img.update_user_editable_fields({'display_name': 'my cool image'})
self.assertEqual('my cool image', my_img.metadata['displayName'])
my_img.update_user_editable_fields({'display_name': ''})
self.assert_(not my_img.metadata['displayName'])
class TestHTTPChannel(http.HTTPChannel):
"""Dummy site required for twisted.web"""
def checkPersistence(self, _, __): # pylint: disable-msg=C0103
def checkPersistence(self, _, __): # pylint: disable-msg=C0103
"""Otherwise we end up with an unclean reactor."""
return False
@@ -181,11 +194,11 @@ class TestSite(server.Site):
class S3APITestCase(test.TrialTestCase):
"""Test objectstore through S3 API."""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
"""Setup users, projects, and start a test server."""
super(S3APITestCase, self).setUp()
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver',
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets')
self.auth_manager = manager.AuthManager()
@@ -204,7 +217,6 @@ class S3APITestCase(test.TrialTestCase):
# pylint: enable-msg=E1101
self.tcp_port = self.listening_port.getHost().port
if not boto.config.has_section('Boto'):
boto.config.add_section('Boto')
boto.config.set('Boto', 'num_retries', '0')
@@ -221,11 +233,11 @@ class S3APITestCase(test.TrialTestCase):
self.conn.get_http_connection = get_http_connection
def _ensure_no_buckets(self, buckets): # pylint: disable-msg=C0111
def _ensure_no_buckets(self, buckets): # pylint: disable-msg=C0111
self.assertEquals(len(buckets), 0, "Bucket list was not empty")
return True
def _ensure_one_bucket(self, buckets, name): # pylint: disable-msg=C0111
def _ensure_one_bucket(self, buckets, name): # pylint: disable-msg=C0111
self.assertEquals(len(buckets), 1,
"Bucket list didn't have exactly one element in it")
self.assertEquals(buckets[0].name, name, "Wrong name")
@@ -296,7 +308,7 @@ class S3APITestCase(test.TrialTestCase):
deferred.addCallback(self._ensure_no_buckets)
return deferred
def tearDown(self): # pylint: disable-msg=C0103
def tearDown(self):
"""Tear down auth and test server."""
self.auth_manager.delete_user('admin')
self.auth_manager.delete_project('admin')

View File

@@ -38,6 +38,7 @@ class ProcessTestCase(test.TrialTestCase):
def test_execute_stdout(self):
pool = process.ProcessPool(2)
d = pool.simple_execute('echo test')
def _check(rv):
self.assertEqual(rv[0], 'test\n')
self.assertEqual(rv[1], '')
@@ -49,6 +50,7 @@ class ProcessTestCase(test.TrialTestCase):
def test_execute_stderr(self):
pool = process.ProcessPool(2)
d = pool.simple_execute('cat BAD_FILE', check_exit_code=False)
def _check(rv):
self.assertEqual(rv[0], '')
self.assert_('No such file' in rv[1])
@@ -72,6 +74,7 @@ class ProcessTestCase(test.TrialTestCase):
d4 = pool.simple_execute('sleep 0.005')
called = []
def _called(rv, name):
called.append(name)

View File

@@ -0,0 +1,153 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nova import context
from nova import db
from nova import exception
from nova import flags
from nova import quota
from nova import test
from nova import utils
from nova.auth import manager
from nova.api.ec2 import cloud
FLAGS = flags.FLAGS
class QuotaTestCase(test.TrialTestCase):
    """Tests that per-project quota limits (instances, cores, volumes,
    gigabytes, floating ips) are enforced by the cloud controller."""

    def setUp(self):
        """Install deliberately small quotas and create an admin user/project
        to run the quota checks against."""
        logging.getLogger().setLevel(logging.DEBUG)
        super(QuotaTestCase, self).setUp()
        # Small limits so each test can exceed them with only a few objects.
        self.flags(connection_type='fake',
                   quota_instances=2,
                   quota_cores=4,
                   quota_volumes=2,
                   quota_gigabytes=20,
                   quota_floating_ips=1)
        self.cloud = cloud.CloudController()
        self.manager = manager.AuthManager()
        self.user = self.manager.create_user('admin', 'admin', 'admin', True)
        self.project = self.manager.create_project('admin', 'admin', 'admin')
        self.network = utils.import_object(FLAGS.network_manager)
        self.context = context.RequestContext(project=self.project,
                                              user=self.user)

    def tearDown(self):
        """Remove the user and project created in setUp."""
        manager.AuthManager().delete_project(self.project)
        manager.AuthManager().delete_user(self.user)
        super(QuotaTestCase, self).tearDown()

    def _create_instance(self, cores=2):
        """Create a test instance

        :param cores: number of vcpus the instance consumes against the
                      cores quota.
        :returns: the id of the created instance record.
        """
        inst = {}
        inst['image_id'] = 'ami-test'
        inst['reservation_id'] = 'r-fakeres'
        inst['user_id'] = self.user.id
        inst['project_id'] = self.project.id
        inst['instance_type'] = 'm1.large'
        inst['vcpus'] = cores
        inst['mac_address'] = utils.generate_mac()
        return db.instance_create(self.context, inst)['id']

    def _create_volume(self, size=10):
        """Create a test volume

        :param size: size in gigabytes, counted against the gigabytes quota.
        :returns: the id of the created volume record.
        """
        vol = {}
        vol['user_id'] = self.user.id
        vol['project_id'] = self.project.id
        vol['size'] = size
        return db.volume_create(self.context, vol)['id']

    def test_quota_overrides(self):
        """Make sure overriding a projects quotas works"""
        num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
        self.assertEqual(num_instances, 2)
        # Raise the instance quota for this project only.
        db.quota_create(self.context, {'project_id': self.project.id,
                                       'instances': 10})
        # NOTE(review): still capped at 4 here — presumably by the cores
        # quota (quota_cores=4), since only the instances quota was raised.
        num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
        self.assertEqual(num_instances, 4)
        db.quota_update(self.context, self.project.id, {'cores': 100})
        num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
        self.assertEqual(num_instances, 10)
        db.quota_destroy(self.context, self.project.id)

    def test_too_many_instances(self):
        """Filling the instance quota makes run_instances raise QuotaError."""
        instance_ids = []
        for i in range(FLAGS.quota_instances):
            instance_id = self._create_instance()
            instance_ids.append(instance_id)
        self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
                          self.context,
                          min_count=1,
                          max_count=1,
                          instance_type='m1.small')
        for instance_id in instance_ids:
            db.instance_destroy(self.context, instance_id)

    def test_too_many_cores(self):
        """A single instance using all cores blocks further run_instances."""
        instance_ids = []
        instance_id = self._create_instance(cores=4)
        instance_ids.append(instance_id)
        self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
                          self.context,
                          min_count=1,
                          max_count=1,
                          instance_type='m1.small')
        for instance_id in instance_ids:
            db.instance_destroy(self.context, instance_id)

    def test_too_many_volumes(self):
        """Filling the volume quota makes create_volume raise QuotaError."""
        volume_ids = []
        for i in range(FLAGS.quota_volumes):
            volume_id = self._create_volume()
            volume_ids.append(volume_id)
        self.assertRaises(cloud.QuotaError, self.cloud.create_volume,
                          self.context,
                          size=10)
        for volume_id in volume_ids:
            db.volume_destroy(self.context, volume_id)

    def test_too_many_gigabytes(self):
        """Exhausting the gigabytes quota blocks further create_volume."""
        volume_ids = []
        volume_id = self._create_volume(size=20)
        volume_ids.append(volume_id)
        self.assertRaises(cloud.QuotaError,
                          self.cloud.create_volume,
                          self.context,
                          size=10)
        for volume_id in volume_ids:
            db.volume_destroy(self.context, volume_id)

    def test_too_many_addresses(self):
        """Allocating past the floating-ip quota raises QuotaError."""
        address = '192.168.0.100'
        # Ensure the floating ip exists before allocating it (idempotent
        # setup: create it only if lookup fails).
        try:
            db.floating_ip_get_by_address(context.get_admin_context(), address)
        except exception.NotFound:
            db.floating_ip_create(context.get_admin_context(),
                                  {'address': address, 'host': FLAGS.host})
        float_addr = self.network.allocate_floating_ip(self.context,
                                                       self.project.id)
        # NOTE(vish): This assert never fails. When cloud attempts to
        #             make an rpc.call, the test just finishes with OK. It
        #             appears to be something in the magic inline callbacks
        #             that is breaking.
        self.assertRaises(cloud.QuotaError, self.cloud.allocate_address,
                          self.context)

View File

@@ -21,7 +21,6 @@ from nova import flags
FLAGS = flags.FLAGS
FLAGS.connection_type = 'libvirt'
FLAGS.fake_storage = False
FLAGS.fake_rabbit = False
FLAGS.fake_network = False
FLAGS.verbose = False

View File

@@ -22,6 +22,7 @@ import logging
from twisted.internet import defer
from nova import context
from nova import flags
from nova import rpc
from nova import test
@@ -30,25 +31,34 @@ from nova import test
FLAGS = flags.FLAGS
class RpcTestCase(test.BaseTestCase):
class RpcTestCase(test.TrialTestCase):
"""Test cases for rpc"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
super(RpcTestCase, self).setUp()
self.conn = rpc.Connection.instance()
self.receiver = TestReceiver()
self.consumer = rpc.AdapterConsumer(connection=self.conn,
topic='test',
proxy=self.receiver)
self.injected.append(self.consumer.attach_to_tornado(self.ioloop))
self.consumer.attach_to_twisted()
self.context = context.get_admin_context()
def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42
result = yield rpc.call('test', {"method": "echo",
"args": {"value": value}})
result = yield rpc.call_twisted(self.context,
'test', {"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)
def test_context_passed(self):
"""Makes sure a context is passed through rpc call"""
value = 42
result = yield rpc.call_twisted(self.context,
'test', {"method": "context",
"args": {"value": value}})
self.assertEqual(self.context.to_dict(), result)
def test_call_exception(self):
"""Test that exception gets passed back properly
@@ -57,12 +67,14 @@ class RpcTestCase(test.BaseTestCase):
to an int in the test.
"""
value = 42
self.assertFailure(rpc.call('test', {"method": "fail",
self.assertFailure(rpc.call_twisted(self.context, 'test',
{"method": "fail",
"args": {"value": value}}),
rpc.RemoteError)
try:
yield rpc.call('test', {"method": "fail",
"args": {"value": value}})
yield rpc.call_twisted(self.context,
'test', {"method": "fail",
"args": {"value": value}})
self.fail("should have thrown rpc.RemoteError")
except rpc.RemoteError as exc:
self.assertEqual(int(exc.value), value)
@@ -74,12 +86,18 @@ class TestReceiver(object):
Uses static methods because we aren't actually storing any state"""
@staticmethod
def echo(value):
def echo(context, value):
"""Simply returns whatever value is sent in"""
logging.debug("Received %s", value)
return defer.succeed(value)
@staticmethod
def fail(value):
def context(context, value):
"""Returns dictionary version of context"""
logging.debug("Received %s", context)
return defer.succeed(context.to_dict())
@staticmethod
def fail(context, value):
"""Raises an exception with the value sent in"""
raise Exception(value)

View File

@@ -19,8 +19,7 @@
Tests For Scheduler
"""
import mox
from nova import context
from nova import db
from nova import flags
from nova import service
@@ -33,7 +32,8 @@ from nova.scheduler import driver
FLAGS = flags.FLAGS
flags.DECLARE('max_instances', 'nova.scheduler.simple')
flags.DECLARE('max_cores', 'nova.scheduler.simple')
class TestDriver(driver.Scheduler):
"""Scheduler Driver for Tests"""
@@ -43,106 +43,204 @@ class TestDriver(driver.Scheduler):
def schedule_named_method(context, topic, num):
return 'named_host'
class SchedulerTestCase(test.TrialTestCase):
"""Test case for scheduler"""
def setUp(self): # pylint: disable=C0103
def setUp(self):
super(SchedulerTestCase, self).setUp()
self.flags(scheduler_driver='nova.tests.scheduler_unittest.TestDriver')
def test_fallback(self):
scheduler = manager.SchedulerManager()
self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
rpc.cast('topic.fallback_host',
ctxt = context.get_admin_context()
rpc.cast(ctxt,
'topic.fallback_host',
{'method': 'noexist',
'args': {'context': None,
'num': 7}})
'args': {'num': 7}})
self.mox.ReplayAll()
scheduler.noexist(None, 'topic', num=7)
scheduler.noexist(ctxt, 'topic', num=7)
def test_named_method(self):
scheduler = manager.SchedulerManager()
self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
rpc.cast('topic.named_host',
ctxt = context.get_admin_context()
rpc.cast(ctxt,
'topic.named_host',
{'method': 'named_method',
'args': {'context': None,
'num': 7}})
'args': {'num': 7}})
self.mox.ReplayAll()
scheduler.named_method(None, 'topic', num=7)
scheduler.named_method(ctxt, 'topic', num=7)
class SimpleDriverTestCase(test.TrialTestCase):
"""Test case for simple driver"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
super(SimpleDriverTestCase, self).setUp()
self.flags(connection_type='fake',
max_instances=4,
max_cores=4,
max_gigabytes=4,
network_manager='nova.network.manager.FlatManager',
volume_driver='nova.volume.driver.FakeAOEDriver',
scheduler_driver='nova.scheduler.simple.SimpleScheduler')
self.scheduler = manager.SchedulerManager()
self.context = None
self.manager = auth_manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake')
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = None
self.service1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
self.service2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
self.context = context.get_admin_context()
def tearDown(self): # pylint: disable-msg=C0103
def tearDown(self):
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
self.service1.kill()
self.service2.kill()
def _create_instance(self):
"""Create a test instance"""
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10'
inst['user_id'] = self.user.id
inst['project_id'] = self.project.id
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
return db.instance_create(self.context, inst)
inst['vcpus'] = 1
return db.instance_create(self.context, inst)['id']
def _create_volume(self):
"""Create a test volume"""
vol = {}
vol['image_id'] = 'ami-test'
vol['reservation_id'] = 'r-fakeres'
vol['size'] = 1
return db.volume_create(self.context, vol)['id']
def test_hosts_are_up(self):
"""Ensures driver can find the hosts that are up"""
# NOTE(vish): constructing service without create method
# because we are going to use it without queue
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
hosts = self.scheduler.driver.hosts_up(self.context, 'compute')
self.assertEqual(len(hosts), 2)
compute1.kill()
compute2.kill()
def test_least_busy_host_gets_instance(self):
instance_id = self._create_instance()
self.service1.run_instance(self.context, instance_id)
"""Ensures the host with less cores gets the next one"""
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
instance_id1 = self._create_instance()
compute1.run_instance(self.context, instance_id1)
instance_id2 = self._create_instance()
host = self.scheduler.driver.schedule_run_instance(self.context,
'compute',
instance_id)
instance_id2)
self.assertEqual(host, 'host2')
self.service1.terminate_instance(self.context, instance_id)
compute1.terminate_instance(self.context, instance_id1)
db.instance_destroy(self.context, instance_id2)
compute1.kill()
compute2.kill()
def test_too_many_instances(self):
def test_too_many_cores(self):
"""Ensures we don't go over max cores"""
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
instance_ids1 = []
instance_ids2 = []
for index in xrange(FLAGS.max_instances):
for index in xrange(FLAGS.max_cores):
instance_id = self._create_instance()
self.service1.run_instance(self.context, instance_id)
compute1.run_instance(self.context, instance_id)
instance_ids1.append(instance_id)
instance_id = self._create_instance()
self.service2.run_instance(self.context, instance_id)
compute2.run_instance(self.context, instance_id)
instance_ids2.append(instance_id)
instance_id = self._create_instance()
self.assertRaises(driver.NoValidHost,
self.scheduler.driver.schedule_run_instance,
self.context,
'compute',
instance_id)
for instance_id in instance_ids1:
self.service1.terminate_instance(self.context, instance_id)
compute1.terminate_instance(self.context, instance_id)
for instance_id in instance_ids2:
self.service2.terminate_instance(self.context, instance_id)
compute2.terminate_instance(self.context, instance_id)
compute1.kill()
compute2.kill()
def test_least_busy_host_gets_volume(self):
"""Ensures the host with less gigabytes gets the next one"""
volume1 = service.Service('host1',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume2.startService()
volume_id1 = self._create_volume()
volume1.create_volume(self.context, volume_id1)
volume_id2 = self._create_volume()
host = self.scheduler.driver.schedule_create_volume(self.context,
volume_id2)
self.assertEqual(host, 'host2')
volume1.delete_volume(self.context, volume_id1)
db.volume_destroy(self.context, volume_id2)
volume1.kill()
volume2.kill()
def test_too_many_gigabytes(self):
"""Ensures we don't go over max gigabytes"""
volume1 = service.Service('host1',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume2.startService()
volume_ids1 = []
volume_ids2 = []
for index in xrange(FLAGS.max_gigabytes):
volume_id = self._create_volume()
volume1.create_volume(self.context, volume_id)
volume_ids1.append(volume_id)
volume_id = self._create_volume()
volume2.create_volume(self.context, volume_id)
volume_ids2.append(volume_id)
volume_id = self._create_volume()
self.assertRaises(driver.NoValidHost,
self.scheduler.driver.schedule_create_volume,
self.context,
volume_id)
for volume_id in volume_ids1:
volume1.delete_volume(self.context, volume_id)
for volume_id in volume_ids2:
volume2.delete_volume(self.context, volume_id)
volume1.kill()
volume2.kill()

View File

@@ -22,6 +22,9 @@ Unit Tests for remote procedure calls using queue
import mox
from twisted.application.app import startApplication
from nova import context
from nova import exception
from nova import flags
from nova import rpc
@@ -36,20 +39,59 @@ flags.DEFINE_string("fake_manager", "nova.tests.service_unittest.FakeManager",
class FakeManager(manager.Manager):
"""Fake manager for tests"""
pass
def test_method(self):
return 'manager'
class ExtendedService(service.Service):
def test_method(self):
return 'service'
class ServiceManagerTestCase(test.BaseTestCase):
"""Test cases for Services"""
def test_attribute_error_for_no_manager(self):
serv = service.Service('test',
'test',
'test',
'nova.tests.service_unittest.FakeManager')
self.assertRaises(AttributeError, getattr, serv, 'test_method')
def test_message_gets_to_manager(self):
serv = service.Service('test',
'test',
'test',
'nova.tests.service_unittest.FakeManager')
serv.startService()
self.assertEqual(serv.test_method(), 'manager')
def test_override_manager_method(self):
serv = ExtendedService('test',
'test',
'test',
'nova.tests.service_unittest.FakeManager')
serv.startService()
self.assertEqual(serv.test_method(), 'service')
class ServiceTestCase(test.BaseTestCase):
"""Test cases for rpc"""
"""Test cases for Services"""
def setUp(self): # pylint: disable=C0103
def setUp(self):
super(ServiceTestCase, self).setUp()
self.mox.StubOutWithMock(service, 'db')
self.context = context.get_admin_context()
def test_create(self):
host='foo'
binary='nova-fake'
topic='fake'
host = 'foo'
binary = 'nova-fake'
topic = 'fake'
# NOTE(vish): Create was moved out of mox replay to make sure that
# the looping calls are created in StartService.
app = service.Service.create(host=host, binary=binary)
self.mox.StubOutWithMock(rpc,
'AdapterConsumer',
use_mock_anything=True)
@@ -65,32 +107,37 @@ class ServiceTestCase(test.BaseTestCase):
proxy=mox.IsA(service.Service)).AndReturn(
rpc.AdapterConsumer)
rpc.AdapterConsumer.attach_to_twisted()
rpc.AdapterConsumer.attach_to_twisted()
# Stub out looping call a bit needlessly since we don't have an easy
# way to cancel it (yet) when the tests finishes
service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
service.task.LoopingCall)
service.task.LoopingCall.start(interval=mox.IgnoreArg(),
now=mox.IgnoreArg())
service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
service.task.LoopingCall)
service.task.LoopingCall.start(interval=mox.IgnoreArg(),
now=mox.IgnoreArg())
rpc.AdapterConsumer.attach_to_twisted()
rpc.AdapterConsumer.attach_to_twisted()
service_create = {'host': host,
'binary': binary,
'topic': topic,
'report_count': 0}
'binary': binary,
'topic': topic,
'report_count': 0}
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
'binary': binary,
'report_count': 0,
'id': 1}
service.db.service_get_by_args(None,
host,
binary).AndRaise(exception.NotFound())
service.db.service_create(None,
service_create).AndReturn(service_ref['id'])
service.db.service_get_by_args(mox.IgnoreArg(),
host,
binary).AndRaise(exception.NotFound())
service.db.service_create(mox.IgnoreArg(),
service_create).AndReturn(service_ref)
self.mox.ReplayAll()
app = service.Service.create(host=host, binary=binary)
startApplication(app, False)
self.assert_(app)
# We're testing sort of weird behavior in how report_state decides
@@ -101,15 +148,15 @@ class ServiceTestCase(test.BaseTestCase):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
service.db.service_get_by_args(self.context,
host,
binary).AndReturn(service_ref)
service.db.service_update(self.context, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()
@@ -119,22 +166,23 @@ class ServiceTestCase(test.BaseTestCase):
host = 'foo'
binary = 'bar'
service_create = {'host': host,
'binary': binary,
'report_count': 0}
'binary': binary,
'report_count': 0}
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
service.db.service_get_by_args(self.context,
host,
binary).AndRaise(exception.NotFound())
service.db.service_create(None,
service_create).AndReturn(service_ref['id'])
service.db.service_get(None, service_ref['id']).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
service.db.service_create(self.context,
service_create).AndReturn(service_ref)
service.db.service_get(self.context,
service_ref['id']).AndReturn(service_ref)
service.db.service_update(self.context, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()
@@ -144,14 +192,14 @@ class ServiceTestCase(test.BaseTestCase):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndRaise(Exception())
service.db.service_get_by_args(self.context,
host,
binary).AndRaise(Exception())
self.mox.ReplayAll()
s = service.Service()
@@ -163,16 +211,16 @@ class ServiceTestCase(test.BaseTestCase):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
service.db.service_get_by_args(self.context,
host,
binary).AndReturn(service_ref)
service.db.service_update(self.context, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()

View File

@@ -1,115 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nova import exception
from nova import flags
from nova import test
from nova.compute import node
from nova.volume import storage
FLAGS = flags.FLAGS
class StorageTestCase(test.TrialTestCase):
    """Tests for the (fake) block storage service: create/delete volumes,
    size and count limits, and attach/detach to a compute node."""

    def setUp(self):
        """Build a compute node and a BlockStore backed by fake storage."""
        logging.getLogger().setLevel(logging.DEBUG)
        super(StorageTestCase, self).setUp()
        self.mynode = node.Node()
        self.mystorage = None
        self.flags(connection_type='fake',
                   fake_storage=True)
        self.mystorage = storage.BlockStore()

    def test_run_create_volume(self):
        """A created volume is retrievable and gone after delete."""
        vol_size = '0'
        user_id = 'fake'
        project_id = 'fake'
        volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)
        # TODO(termie): get_volume returns differently than create_volume
        self.assertEqual(volume_id,
                         storage.get_volume(volume_id)['volume_id'])

        rv = self.mystorage.delete_volume(volume_id)
        # Lookup after delete must fail.
        self.assertRaises(exception.Error,
                          storage.get_volume,
                          volume_id)

    def test_too_big_volume(self):
        """Creating a volume over the size limit raises.

        NOTE(review): TypeError looks odd for a size-limit violation —
        presumably an implementation detail of BlockStore; confirm.
        """
        vol_size = '1001'
        user_id = 'fake'
        project_id = 'fake'
        self.assertRaises(TypeError,
                          self.mystorage.create_volume,
                          vol_size, user_id, project_id)

    def test_too_many_volumes(self):
        """Filling every shelf slot makes the next create raise NoMoreVolumes."""
        vol_size = '1'
        user_id = 'fake'
        project_id = 'fake'
        # Total capacity = shelves * slots per shelf.
        num_shelves = FLAGS.last_shelf_id - FLAGS.first_shelf_id + 1
        total_slots = FLAGS.slots_per_shelf * num_shelves
        vols = []
        for i in xrange(total_slots):
            vid = self.mystorage.create_volume(vol_size, user_id, project_id)
            vols.append(vid)
        self.assertRaises(storage.NoMoreVolumes,
                          self.mystorage.create_volume,
                          vol_size, user_id, project_id)
        for id in vols:
            self.mystorage.delete_volume(id)

    def test_run_attach_detach_volume(self):
        """Attach a volume to an instance, verify state, then detach/delete.

        NOTE(review): uses ``yield`` — presumably run under Twisted trial's
        deferred-generator support; confirm the runner handles this.
        """
        # Create one volume and one node to test with
        instance_id = "storage-test"
        vol_size = "5"
        user_id = "fake"
        project_id = 'fake'
        mountpoint = "/dev/sdf"
        volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)

        volume_obj = storage.get_volume(volume_id)
        volume_obj.start_attach(instance_id, mountpoint)
        rv = yield self.mynode.attach_volume(volume_id,
                                             instance_id,
                                             mountpoint)
        self.assertEqual(volume_obj['status'], "in-use")
        self.assertEqual(volume_obj['attachStatus'], "attached")
        self.assertEqual(volume_obj['instance_id'], instance_id)
        self.assertEqual(volume_obj['mountpoint'], mountpoint)

        # An attached volume must not be deletable.
        self.assertRaises(exception.Error,
                          self.mystorage.delete_volume,
                          volume_id)

        rv = yield self.mystorage.detach_volume(volume_id)
        volume_obj = storage.get_volume(volume_id)
        self.assertEqual(volume_obj['status'], "available")

        rv = self.mystorage.delete_volume(volume_id)
        self.assertRaises(exception.Error,
                          storage.get_volume,
                          volume_id)

    def test_multi_node(self):
        # TODO(termie): Figure out how to test with two nodes,
        # each of them having a different FLAG for storage_node
        # This will allow us to test cross-node interactions
        pass

View File

@@ -0,0 +1,53 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import StringIO
import sys
from nova import twistd
from nova import exception
from nova import flags
from nova import test
FLAGS = flags.FLAGS
class TwistdTestCase(test.TrialTestCase):
    """Tests for nova's twistd option-parsing wrapper (flag integration
    and help output)."""

    def setUp(self):
        """Build the wrapped options class and capture stdout so help
        output can be inspected."""
        super(TwistdTestCase, self).setUp()
        self.Options = twistd.WrapTwistedOptions(twistd.TwistdServerOptions)
        sys.stdout = StringIO.StringIO()

    def tearDown(self):
        """Restore the real stdout replaced in setUp."""
        super(TwistdTestCase, self).tearDown()
        sys.stdout = sys.__stdout__

    def test_basic(self):
        # Parsing no arguments should simply succeed.
        options = self.Options()
        argv = options.parseOptions()

    def test_logfile(self):
        # The --logfile option must propagate into the gflags FLAGS object.
        options = self.Options()
        argv = options.parseOptions(['--logfile=foo'])
        self.assertEqual(FLAGS.logfile, 'foo')

    def test_help(self):
        # --help exits via SystemExit and prints the twistd options
        # (including 'pidfile') to the captured stdout.
        options = self.Options()
        self.assertRaises(SystemExit, options.parseOptions, ['--help'])
        self.assert_('pidfile' in sys.stdout.getvalue())

View File

@@ -35,7 +35,8 @@ class ValidationTestCase(test.TrialTestCase):
self.assertTrue(type_case("foo", 5, 1))
self.assertRaises(TypeError, type_case, "bar", "5", 1)
self.assertRaises(TypeError, type_case, None, 5, 1)
@validate.typetest(instanceid=str, size=int, number_of_instances=int)
def type_case(instanceid, size, number_of_instances):
return True

View File

@@ -14,56 +14,246 @@
# License for the specific language governing permissions and limitations
# under the License.
from xml.etree.ElementTree import fromstring as xml_to_tree
from xml.dom.minidom import parseString as xml_to_dom
from nova import context
from nova import db
from nova import flags
from nova import test
from nova import utils
from nova.api.ec2 import cloud
from nova.auth import manager
from nova.virt import libvirt_conn
FLAGS = flags.FLAGS
flags.DECLARE('instances_path', 'nova.compute.manager')
class LibvirtConnTestCase(test.TrialTestCase):
def setUp(self):
super(LibvirtConnTestCase, self).setUp()
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake',
admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.network = utils.import_object(FLAGS.network_manager)
FLAGS.instances_path = ''
def test_get_uri_and_template(self):
class MockDataModel(object):
def __init__(self):
self.datamodel = { 'name' : 'i-cafebabe',
'memory_kb' : '1024000',
'basepath' : '/some/path',
'bridge_name' : 'br100',
'mac_address' : '02:12:34:46:56:67',
'vcpus' : 2 }
ip = '10.11.12.13'
type_uri_map = { 'qemu' : ('qemu:///system',
[lambda s: '<domain type=\'qemu\'>' in s,
lambda s: 'type>hvm</type' in s,
lambda s: 'emulator>/usr/bin/kvm' not in s]),
'kvm' : ('qemu:///system',
[lambda s: '<domain type=\'kvm\'>' in s,
lambda s: 'type>hvm</type' in s,
lambda s: 'emulator>/usr/bin/qemu<' not in s]),
'uml' : ('uml:///system',
[lambda s: '<domain type=\'uml\'>' in s,
lambda s: 'type>uml</type' in s]),
}
instance = {'internal_id': 1,
'memory_kb': '1024000',
'basepath': '/some/path',
'bridge_name': 'br100',
'mac_address': '02:12:34:46:56:67',
'vcpus': 2,
'project_id': 'fake',
'bridge': 'br101',
'instance_type': 'm1.small'}
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems():
user_context = context.RequestContext(project=self.project,
user=self.user)
instance_ref = db.instance_create(user_context, instance)
network_ref = self.network.get_network(user_context)
self.network.set_network_host(context.get_admin_context(),
network_ref['id'])
fixed_ip = {'address': ip,
'network_id': network_ref['id']}
ctxt = context.get_admin_context()
fixed_ip_ref = db.fixed_ip_create(ctxt, fixed_ip)
db.fixed_ip_update(ctxt, ip, {'allocated': True,
'instance_id': instance_ref['id']})
type_uri_map = {'qemu': ('qemu:///system',
[(lambda t: t.find('.').get('type'), 'qemu'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'kvm': ('qemu:///system',
[(lambda t: t.find('.').get('type'), 'kvm'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'uml': ('uml:///system',
[(lambda t: t.find('.').get('type'), 'uml'),
(lambda t: t.find('./os/type').text, 'uml')])}
common_checks = [
(lambda t: t.find('.').tag, 'domain'),
(lambda t: t.find('./devices/interface/filterref/parameter').\
get('name'), 'IP'),
(lambda t: t.find('./devices/interface/filterref/parameter').\
get('value'), '10.11.12.13')]
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True)
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, expected_uri)
for i, check in enumerate(checks):
xml = conn.toXml(MockDataModel())
self.assertTrue(check(xml), '%s failed check %d' % (xml, i))
xml = conn.to_xml(instance_ref)
tree = xml_to_tree(xml)
for i, (check, expected_result) in enumerate(checks):
self.assertEqual(check(tree),
expected_result,
'%s failed check %d' % (xml, i))
for i, (check, expected_result) in enumerate(common_checks):
self.assertEqual(check(tree),
expected_result,
'%s failed common check %d' % (xml, i))
# Deliberately not just assigning this string to FLAGS.libvirt_uri and
# checking against that later on. This way we make sure the
# implementation doesn't fiddle around with the FLAGS.
testuri = 'something completely different'
FLAGS.libvirt_uri = testuri
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems():
for (libvirt_type, (expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
conn = libvirt_conn.LibvirtConnection(True)
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, testuri)
def tearDown(self):
super(LibvirtConnTestCase, self).tearDown()
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
class NWFilterTestCase(test.TrialTestCase):
def setUp(self):
super(NWFilterTestCase, self).setUp()
class Mock(object):
pass
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake',
admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = context.RequestContext(self.user, self.project)
self.fake_libvirt_connection = Mock()
self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
def tearDown(self):
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
def test_cidr_rule_nwfilter_xml(self):
cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context,
'testgroup',
'test group description')
cloud_controller.authorize_security_group_ingress(self.context,
'testgroup',
from_port='80',
to_port='81',
ip_protocol='tcp',
cidr_ip='0.0.0.0/0')
security_group = db.security_group_get_by_name(self.context,
'fake',
'testgroup')
xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
dom = xml_to_dom(xml)
self.assertEqual(dom.firstChild.tagName, 'filter')
rules = dom.getElementsByTagName('rule')
self.assertEqual(len(rules), 1)
# It's supposed to allow inbound traffic.
self.assertEqual(rules[0].getAttribute('action'), 'accept')
self.assertEqual(rules[0].getAttribute('direction'), 'in')
# Must be lower priority than the base filter (which blocks everything)
self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)
ip_conditions = rules[0].getElementsByTagName('tcp')
self.assertEqual(len(ip_conditions), 1)
self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
self.teardown_security_group()
def teardown_security_group(self):
cloud_controller = cloud.CloudController()
cloud_controller.delete_security_group(self.context, 'testgroup')
def setup_and_return_security_group(self):
cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context,
'testgroup',
'test group description')
cloud_controller.authorize_security_group_ingress(self.context,
'testgroup',
from_port='80',
to_port='81',
ip_protocol='tcp',
cidr_ip='0.0.0.0/0')
return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def test_creates_base_rule_first(self):
# These come pre-defined by libvirt
self.defined_filters = ['no-mac-spoofing',
'no-ip-spoofing',
'no-arp-spoofing',
'allow-dhcp-server']
self.recursive_depends = {}
for f in self.defined_filters:
self.recursive_depends[f] = []
def _filterDefineXMLMock(xml):
dom = xml_to_dom(xml)
name = dom.firstChild.getAttribute('name')
self.recursive_depends[name] = []
for f in dom.getElementsByTagName('filterref'):
ref = f.getAttribute('filter')
self.assertTrue(ref in self.defined_filters,
('%s referenced filter that does ' +
'not yet exist: %s') % (name, ref))
dependencies = [ref] + self.recursive_depends[ref]
self.recursive_depends[name] += dependencies
self.defined_filters.append(name)
return True
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
instance_ref = db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake'})
inst_id = instance_ref['id']
def _ensure_all_called(_):
instance_filter = 'nova-instance-%s' % instance_ref['name']
secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
for required in [secgroup_filter, 'allow-dhcp-server',
'no-arp-spoofing', 'no-ip-spoofing',
'no-mac-spoofing']:
self.assertTrue(required in
self.recursive_depends[instance_filter],
"Instance's filter does not include %s" %
required)
self.security_group = self.setup_and_return_security_group()
db.instance_add_security_group(self.context, inst_id,
self.security_group.id)
instance = db.instance_get(self.context, inst_id)
d = self.fw.setup_nwfilters_for_instance(instance)
d.addCallback(_ensure_all_called)
d.addCallback(lambda _: self.teardown_security_group())
return d

View File

@@ -22,6 +22,7 @@ import logging
from twisted.internet import defer
from nova import context
from nova import exception
from nova import db
from nova import flags
@@ -33,14 +34,13 @@ FLAGS = flags.FLAGS
class VolumeTestCase(test.TrialTestCase):
"""Test Case for volumes"""
def setUp(self): # pylint: disable-msg=C0103
def setUp(self):
logging.getLogger().setLevel(logging.DEBUG)
super(VolumeTestCase, self).setUp()
self.compute = utils.import_object(FLAGS.compute_manager)
self.flags(connection_type='fake',
fake_storage=True)
self.flags(connection_type='fake')
self.volume = utils.import_object(FLAGS.volume_manager)
self.context = None
self.context = context.get_admin_context()
@staticmethod
def _create_volume(size='0'):
@@ -52,19 +52,20 @@ class VolumeTestCase(test.TrialTestCase):
vol['availability_zone'] = FLAGS.storage_availability_zone
vol['status'] = "creating"
vol['attach_status'] = "detached"
return db.volume_create(None, vol)['id']
return db.volume_create(context.get_admin_context(), vol)['id']
@defer.inlineCallbacks
def test_create_delete_volume(self):
"""Test volume can be created and deleted"""
volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id)
self.assertEqual(volume_id, db.volume_get(None, volume_id).id)
self.assertEqual(volume_id, db.volume_get(context.get_admin_context(),
volume_id).id)
yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.NotFound,
db.volume_get,
None,
self.context,
volume_id)
@defer.inlineCallbacks
@@ -93,44 +94,58 @@ class VolumeTestCase(test.TrialTestCase):
self.assertFailure(self.volume.create_volume(self.context,
volume_id),
db.NoMoreBlades)
db.volume_destroy(None, volume_id)
db.volume_destroy(context.get_admin_context(), volume_id)
for volume_id in vols:
yield self.volume.delete_volume(self.context, volume_id)
@defer.inlineCallbacks
def test_run_attach_detach_volume(self):
"""Make sure volume can be attached and detached from instance"""
instance_id = "storage-test"
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10'
inst['user_id'] = 'fake'
inst['project_id'] = 'fake'
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
instance_id = db.instance_create(self.context, inst)['id']
mountpoint = "/dev/sdf"
volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id)
if FLAGS.fake_tests:
db.volume_attached(None, volume_id, instance_id, mountpoint)
db.volume_attached(self.context, volume_id, instance_id,
mountpoint)
else:
yield self.compute.attach_volume(instance_id,
yield self.compute.attach_volume(self.context,
instance_id,
volume_id,
mountpoint)
vol = db.volume_get(None, volume_id)
vol = db.volume_get(context.get_admin_context(), volume_id)
self.assertEqual(vol['status'], "in-use")
self.assertEqual(vol['attach_status'], "attached")
self.assertEqual(vol['instance_id'], instance_id)
self.assertEqual(vol['mountpoint'], mountpoint)
instance_ref = db.volume_get_instance(self.context, volume_id)
self.assertEqual(instance_ref['id'], instance_id)
self.assertFailure(self.volume.delete_volume(self.context, volume_id),
exception.Error)
if FLAGS.fake_tests:
db.volume_detached(None, volume_id)
db.volume_detached(self.context, volume_id)
else:
yield self.compute.detach_volume(instance_id,
yield self.compute.detach_volume(self.context,
instance_id,
volume_id)
vol = db.volume_get(None, volume_id)
vol = db.volume_get(self.context, volume_id)
self.assertEqual(vol['status'], "available")
yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.Error,
db.volume_get,
None,
self.context,
volume_id)
db.instance_destroy(self.context, instance_id)
@defer.inlineCallbacks
def test_concurrent_volumes_get_different_blades(self):
@@ -141,7 +156,8 @@ class VolumeTestCase(test.TrialTestCase):
def _check(volume_id):
"""Make sure blades aren't duplicated"""
volume_ids.append(volume_id)
(shelf_id, blade_id) = db.volume_get_shelf_and_blade(None,
admin_context = context.get_admin_context()
(shelf_id, blade_id) = db.volume_get_shelf_and_blade(admin_context,
volume_id)
shelf_blade = '%s.%s' % (shelf_id, blade_id)
self.assert_(shelf_blade not in shelf_blades)

View File

@@ -51,6 +51,9 @@ class TwistdServerOptions(ServerOptions):
class FlagParser(object):
# this is a required attribute for gflags
syntactic_help = ''
def __init__(self, parser):
self.parser = parser
@@ -61,6 +64,7 @@ class FlagParser(object):
def WrapTwistedOptions(wrapped):
class TwistedOptionsToFlags(wrapped):
subCommands = None
def __init__(self):
# NOTE(termie): _data exists because Twisted stuff expects
# to be able to set arbitrary things that are
@@ -78,16 +82,22 @@ def WrapTwistedOptions(wrapped):
def _absorbFlags(self):
twistd_flags = []
reflect.accumulateClassList(self.__class__, 'optFlags', twistd_flags)
reflect.accumulateClassList(self.__class__, 'optFlags',
twistd_flags)
for flag in twistd_flags:
key = flag[0].replace('-', '_')
if hasattr(FLAGS, key):
continue
flags.DEFINE_boolean(key, None, str(flag[-1]))
def _absorbParameters(self):
twistd_params = []
reflect.accumulateClassList(self.__class__, 'optParameters', twistd_params)
reflect.accumulateClassList(self.__class__, 'optParameters',
twistd_params)
for param in twistd_params:
key = param[0].replace('-', '_')
if hasattr(FLAGS, key):
continue
if len(param) > 4:
flags.DEFINE(FlagParser(param[4]),
key, param[2], str(param[3]),
@@ -97,13 +107,14 @@ def WrapTwistedOptions(wrapped):
def _absorbHandlers(self):
twistd_handlers = {}
reflect.addMethodNamesToDict(self.__class__, twistd_handlers, "opt_")
reflect.addMethodNamesToDict(self.__class__, twistd_handlers,
"opt_")
# NOTE(termie): Much of the following is derived/copied from
# twisted.python.usage with the express purpose of
# providing compatibility
for name in twistd_handlers.keys():
method = getattr(self, 'opt_'+name)
method = getattr(self, 'opt_' + name)
takesArg = not usage.flagFunction(method, name)
doc = getattr(method, '__doc__', None)
@@ -119,7 +130,6 @@ def WrapTwistedOptions(wrapped):
flags.DEFINE_string(name, None, doc)
self._paramHandlers[name] = method
def _doHandlers(self):
for flag, handler in self._flagHandlers.iteritems():
if self[flag]:
@@ -189,7 +199,7 @@ def stop(pidfile):
"""
# Get the pid from the pidfile
try:
pf = file(pidfile,'r')
pf = file(pidfile, 'r')
pid = int(pf.read().strip())
pf.close()
except IOError:
@@ -198,7 +208,8 @@ def stop(pidfile):
if not pid:
message = "pidfile %s does not exist. Daemon not running?\n"
sys.stderr.write(message % pidfile)
return # not an error in a restart
# Not an error in a restart
return
# Try killing the daemon process
try:

View File

@@ -33,6 +33,7 @@ from twisted.internet.threads import deferToThread
from nova import exception
from nova import flags
from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
@@ -48,6 +49,7 @@ def import_class(import_str):
except (ImportError, ValueError, AttributeError):
raise exception.NotFound('Class %s cannot be found' % class_str)
def import_object(import_str):
"""Returns an object including a module or module and class"""
try:
@@ -57,6 +59,7 @@ def import_object(import_str):
cls = import_class(import_str)
return cls()
def fetchfile(url, target):
logging.debug("Fetching %s" % url)
# c = pycurl.Curl()
@@ -68,7 +71,9 @@ def fetchfile(url, target):
# fp.close()
execute("curl --fail %s -o %s" % (url, target))
def execute(cmd, process_input=None, addl_env=None, check_exit_code=True):
logging.debug("Running cmd: %s", cmd)
env = os.environ.copy()
if addl_env:
env.update(addl_env)
@@ -82,9 +87,12 @@ def execute(cmd, process_input=None, addl_env=None, check_exit_code=True):
obj.stdin.close()
if obj.returncode:
logging.debug("Result was %s" % (obj.returncode))
if check_exit_code and obj.returncode <> 0:
raise Exception( "Unexpected exit code: %s. result=%s"
% (obj.returncode, result))
if check_exit_code and obj.returncode != 0:
(stdout, stderr) = result
raise ProcessExecutionError(exit_code=obj.returncode,
stdout=stdout,
stderr=stderr,
cmd=cmd)
return result
@@ -106,7 +114,8 @@ def default_flagfile(filename='nova.conf'):
script_dir = os.path.dirname(inspect.stack()[-1][1])
filename = os.path.abspath(os.path.join(script_dir, filename))
if os.path.exists(filename):
sys.argv = sys.argv[:1] + ['--flagfile=%s' % filename] + sys.argv[1:]
flagfile = ['--flagfile=%s' % filename]
sys.argv = sys.argv[:1] + flagfile + sys.argv[1:]
def debug(arg):
@@ -114,23 +123,32 @@ def debug(arg):
return arg
def runthis(prompt, cmd, check_exit_code = True):
def runthis(prompt, cmd, check_exit_code=True):
logging.debug("Running %s" % (cmd))
exit_code = subprocess.call(cmd.split(" "))
logging.debug(prompt % (exit_code))
if check_exit_code and exit_code <> 0:
raise Exception( "Unexpected exit code: %s from cmd: %s"
% (exit_code, cmd))
if check_exit_code and exit_code != 0:
raise ProcessExecutionError(exit_code=exit_code,
stdout=None,
stderr=None,
cmd=cmd)
def generate_uid(topic, size=8):
return '%s-%s' % (topic, ''.join([random.choice('01234567890abcdefghijklmnopqrstuvwxyz') for x in xrange(size)]))
if topic == "i":
# Instances have integer internal ids.
return random.randint(0, 2 ** 32 - 1)
else:
characters = '01234567890abcdefghijklmnopqrstuvwxyz'
choices = [random.choice(characters) for x in xrange(size)]
return '%s-%s' % (topic, ''.join(choices))
def generate_mac():
mac = [0x02, 0x16, 0x3e, random.randint(0x00, 0x7f),
random.randint(0x00, 0xff), random.randint(0x00, 0xff)
]
mac = [0x02, 0x16, 0x3e,
random.randint(0x00, 0x7f),
random.randint(0x00, 0xff),
random.randint(0x00, 0xff)]
return ':'.join(map(lambda x: "%02x" % x, mac))
@@ -186,13 +204,14 @@ class LazyPluggable(object):
fromlist = backend
self.__backend = __import__(name, None, None, fromlist)
logging.error('backend %s', self.__backend)
logging.info('backend %s', self.__backend)
return self.__backend
def __getattr__(self, key):
backend = self.__get_backend()
return getattr(backend, key)
def deferredToThread(f):
def g(*args, **kwargs):
return deferToThread(f, *args, **kwargs)

View File

@@ -16,18 +16,20 @@
# License for the specific language governing permissions and limitations
# under the License.
"""
Decorators for argument validation, courtesy of
http://rmi.net/~lutz/rangetest.html
"""
"""Decorators for argument validation, courtesy of
http://rmi.net/~lutz/rangetest.html"""
def rangetest(**argchecks): # validate ranges for both+defaults
def onDecorator(func): # onCall remembers func and argchecks
def rangetest(**argchecks):
"""Validate ranges for both + defaults"""
def onDecorator(func):
"""onCall remembers func and argchecks"""
import sys
code = func.__code__ if sys.version_info[0] == 3 else func.func_code
allargs = code.co_varnames[:code.co_argcount]
allargs = code.co_varnames[:code.co_argcount]
funcname = func.__name__
def onCall(*pargs, **kargs):
# all pargs match first N args by position
# the rest must be in kargs or omitted defaults
@@ -38,7 +40,8 @@ def rangetest(**argchecks): # validate ranges for both+defaults
# for all args to be checked
if argname in kargs:
# was passed by name
if float(kargs[argname]) < low or float(kargs[argname]) > high:
if float(kargs[argname]) < low or \
float(kargs[argname]) > high:
errmsg = '{0} argument "{1}" not in {2}..{3}'
errmsg = errmsg.format(funcname, argname, low, high)
raise TypeError(errmsg)
@@ -46,9 +49,12 @@ def rangetest(**argchecks): # validate ranges for both+defaults
elif argname in positionals:
# was passed by position
position = positionals.index(argname)
if float(pargs[position]) < low or float(pargs[position]) > high:
errmsg = '{0} argument "{1}" with value of {4} not in {2}..{3}'
errmsg = errmsg.format(funcname, argname, low, high, pargs[position])
if float(pargs[position]) < low or \
float(pargs[position]) > high:
errmsg = '{0} argument "{1}" with value of {4} ' \
'not in {2}..{3}'
errmsg = errmsg.format(funcname, argname, low, high,
pargs[position])
raise TypeError(errmsg)
else:
pass
@@ -62,9 +68,9 @@ def typetest(**argchecks):
def onDecorator(func):
import sys
code = func.__code__ if sys.version_info[0] == 3 else func.func_code
allargs = code.co_varnames[:code.co_argcount]
allargs = code.co_varnames[:code.co_argcount]
funcname = func.__name__
def onCall(*pargs, **kargs):
positionals = list(allargs)[:len(pargs)]
for (argname, typeof) in argchecks.items():
@@ -76,12 +82,13 @@ def typetest(**argchecks):
elif argname in positionals:
position = positionals.index(argname)
if not isinstance(pargs[position], typeof):
errmsg = '{0} argument "{1}" with value of {2} not of type {3}'
errmsg = errmsg.format(funcname, argname, pargs[position], typeof)
errmsg = '{0} argument "{1}" with value of {2} ' \
'not of type {3}'
errmsg = errmsg.format(funcname, argname,
pargs[position], typeof)
raise TypeError(errmsg)
else:
pass
return func(*pargs, **kargs)
return onCall
return onDecorator

View File

@@ -21,14 +21,17 @@
Utility methods for working with WSGI servers
"""
import json
import logging
import sys
from xml.dom import minidom
import eventlet
import eventlet.wsgi
eventlet.patcher.monkey_patch(all=False, socket=True)
import routes
import routes.middleware
import webob
import webob.dec
import webob.exc
@@ -91,11 +94,11 @@ class Middleware(Application):
behavior.
"""
def __init__(self, application): # pylint: disable-msg=W0231
def __init__(self, application): # pylint: disable-msg=W0231
self.application = application
@webob.dec.wsgify
def __call__(self, req): # pylint: disable-msg=W0221
def __call__(self, req): # pylint: disable-msg=W0221
"""Override to implement middleware behavior."""
return self.application
@@ -213,7 +216,7 @@ class Controller(object):
arg_dict['req'] = req
result = method(**arg_dict)
if type(result) is dict:
return self._serialize(result, req)
return self._serialize(result, req)
else:
return result
@@ -227,10 +230,20 @@ class Controller(object):
serializer = Serializer(request.environ, _metadata)
return serializer.to_content_type(data)
def _deserialize(self, data, request):
"""
Deserialize the request body to the response type requested in request.
Uses self._serialization_metadata if it exists, which is a dict mapping
MIME types to information needed to serialize to that type.
"""
_metadata = getattr(type(self), "_serialization_metadata", {})
serializer = Serializer(request.environ, _metadata)
return serializer.deserialize(data)
class Serializer(object):
"""
Serializes a dictionary to a Content Type specified by a WSGI environment.
Serializes and deserializes dictionaries to certain MIME types.
"""
def __init__(self, environ, metadata=None):
@@ -239,31 +252,79 @@ class Serializer(object):
'metadata' is an optional dict mapping MIME types to information
needed to serialize a dictionary to that type.
"""
self.environ = environ
self.metadata = metadata or {}
self._methods = {
'application/json': self._to_json,
'application/xml': self._to_xml}
req = webob.Request(environ)
suffix = req.path_info.split('.')[-1].lower()
if suffix == 'json':
self.handler = self._to_json
elif suffix == 'xml':
self.handler = self._to_xml
elif 'application/json' in req.accept:
self.handler = self._to_json
elif 'application/xml' in req.accept:
self.handler = self._to_xml
else:
# This is the default
self.handler = self._to_json
def to_content_type(self, data):
"""
Serialize a dictionary into a string. The format of the string
will be decided based on the Content Type requested in self.environ:
by Accept: header, or by URL suffix.
Serialize a dictionary into a string.
The format of the string will be decided based on the Content Type
requested in self.environ: by Accept: header, or by URL suffix.
"""
mimetype = 'application/xml'
# TODO(gundlach): determine mimetype from request
return self._methods.get(mimetype, repr)(data)
return self.handler(data)
def deserialize(self, datastring):
"""
Deserialize a string to a dictionary.
The string must be in the format of a supported MIME type.
"""
datastring = datastring.strip()
try:
is_xml = (datastring[0] == '<')
if not is_xml:
return json.loads(datastring)
return self._from_xml(datastring)
except:
return None
def _from_xml(self, datastring):
xmldata = self.metadata.get('application/xml', {})
plurals = set(xmldata.get('plurals', {}))
node = minidom.parseString(datastring).childNodes[0]
return {node.nodeName: self._from_xml_node(node, plurals)}
def _from_xml_node(self, node, listnames):
"""
Convert a minidom node to a simple Python type.
listnames is a collection of names of XML nodes whose subnodes should
be considered list items.
"""
if len(node.childNodes) == 1 and node.childNodes[0].nodeType == 3:
return node.childNodes[0].nodeValue
elif node.nodeName in listnames:
return [self._from_xml_node(n, listnames) for n in node.childNodes]
else:
result = dict()
for attr in node.attributes.keys():
result[attr] = node.attributes[attr].nodeValue
for child in node.childNodes:
if child.nodeType != node.TEXT_NODE:
result[child.nodeName] = self._from_xml_node(child,
listnames)
return result
def _to_json(self, data):
import json
return json.dumps(data)
def _to_xml(self, data):
metadata = self.metadata.get('application/xml', {})
# We expect data to contain a single key which is the XML root.
root_key = data.keys()[0]
from xml.dom import minidom
doc = minidom.Document()
node = self._to_xml_node(doc, metadata, root_key, data[root_key])
return node.toprettyxml(indent=' ')
@@ -289,7 +350,8 @@ class Serializer(object):
else:
node = self._to_xml_node(doc, metadata, k, v)
result.appendChild(node)
else: # atom
else:
# Type is atom
node = doc.createTextNode(str(data))
result.appendChild(node)
return result

View File

@@ -1,7 +1,8 @@
[Messages Control]
# W0511: TODOs in code comments are fine.
# W0142: *args and **kwargs are fine.
disable-msg=W0511,W0142
# W0622: Redefining id is fine.
disable-msg=W0511,W0142,W0622
[Basic]
# Variable names can be 1 to 31 characters long, with lowercase and underscores
@@ -12,7 +13,7 @@ argument-rgx=[a-z_][a-z0-9_]{1,30}$
# Method names should be at least 3 characters long
# and be lowecased with underscores
method-rgx=[a-z_][a-z0-9_]{2,50}$
method-rgx=([a-z_][a-z0-9_]{2,50}|setUp|tearDown)$
# Module names matching nova-* are ok (files in bin/)
module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+)|(nova-[a-z0-9_-]+))$

View File

@@ -45,7 +45,6 @@ import sys
from twisted.scripts import trial as trial_script
from nova import datastore
from nova import flags
from nova import twistd
@@ -58,11 +57,15 @@ from nova.tests.flags_unittest import *
from nova.tests.network_unittest import *
from nova.tests.objectstore_unittest import *
from nova.tests.process_unittest import *
from nova.tests.quota_unittest import *
from nova.tests.rpc_unittest import *
from nova.tests.scheduler_unittest import *
from nova.tests.service_unittest import *
from nova.tests.twistd_unittest import *
from nova.tests.validator_unittest import *
from nova.tests.virt_unittest import *
from nova.tests.volume_unittest import *
from nova.tests.virt_unittest import *
FLAGS = flags.FLAGS
@@ -83,12 +86,6 @@ if __name__ == '__main__':
# TODO(termie): these should make a call instead of doing work on import
if FLAGS.fake_tests:
from nova.tests.fake_flags import *
# use db 8 for fake tests
FLAGS.redis_db = 8
if FLAGS.flush_db:
logging.info("Flushing redis datastore")
r = datastore.Redis.instance()
r.flushdb()
else:
from nova.tests.real_flags import *

View File

@@ -6,6 +6,7 @@ function usage {
echo ""
echo " -V, --virtual-env Always use virtualenv. Install automatically if not present"
echo " -N, --no-virtual-env Don't use virtualenv. Run tests in local environment"
echo " -f, --force Force a clean re-build of the virtual environment. Useful when dependencies have been added."
echo " -h, --help Print this usage message"
echo ""
echo "Note: with no options specified, the script will try to run the tests in a virtual environment,"
@@ -14,20 +15,12 @@ function usage {
exit
}
function process_options {
array=$1
elements=${#array[@]}
for (( x=0;x<$elements;x++)); do
process_option ${array[${x}]}
done
}
function process_option {
option=$1
case $option in
case "$1" in
-h|--help) usage;;
-V|--virtual-env) let always_venv=1; let never_venv=0;;
-N|--no-virtual-env) let always_venv=0; let never_venv=1;;
-f|--force) let force=1;;
esac
}
@@ -35,9 +28,11 @@ venv=.nova-venv
with_venv=tools/with_venv.sh
always_venv=0
never_venv=0
options=("$@")
force=0
process_options $options
for arg in "$@"; do
process_option $arg
done
if [ $never_venv -eq 1 ]; then
# Just run the test suites in current environment
@@ -45,6 +40,12 @@ if [ $never_venv -eq 1 ]; then
exit
fi
# Remove the virtual environment if --force used
if [ $force -eq 1 ]; then
echo "Cleaning virtualenv..."
rm -rf ${venv}
fi
if [ -e ${venv} ]; then
${with_venv} python run_tests.py $@
else
@@ -54,7 +55,7 @@ else
else
echo -e "No virtual environment found...create one? (Y/n) \c"
read use_ve
if [ "x$use_ve" = "xY" ]; then
if [ "x$use_ve" = "xY" -o "x$use_ve" = "x" -o "x$use_ve" = "xy" ]; then
# Install the virtualenv and run the test suite in it
python tools/install_venv.py
else

View File

@@ -39,7 +39,7 @@ class local_sdist(sdist):
sdist.run(self)
setup(name='nova',
version='0.9.1',
version='2010.1',
description='cloud computing fabric controller',
author='OpenStack',
author_email='nova@lists.launchpad.net',
@@ -54,5 +54,5 @@ setup(name='nova',
'bin/nova-manage',
'bin/nova-network',
'bin/nova-objectstore',
'bin/nova-api-new',
'bin/nova-scheduler',
'bin/nova-volume'])