Merged with trunk, fixed broken stuff

This commit is contained in:
Justin Santa Barbara
2010-10-14 12:59:36 -07:00
51 changed files with 3164 additions and 3300 deletions

View File

@@ -1,61 +1,48 @@
#!/usr/bin/env python
# pylint: disable-msg=C0103
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tornado daemon for the main API endpoint.
Nova API daemon.
"""
import logging
from tornado import httpserver
from tornado import ioloop
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import rpc
from nova import server
from nova import utils
from nova.endpoint import admin
from nova.endpoint import api
from nova.endpoint import cloud
from nova import server
FLAGS = flags.FLAGS
flags.DEFINE_integer('api_port', 8773, 'API port')
def main(_argv):
"""Load the controllers and start the tornado I/O loop."""
controllers = {
'Cloud': cloud.CloudController(),
'Admin': admin.AdminController()}
_app = api.APIServerApplication(controllers)
conn = rpc.Connection.instance()
consumer = rpc.AdapterConsumer(connection=conn,
topic=FLAGS.cloud_topic,
proxy=controllers['Cloud'])
io_inst = ioloop.IOLoop.instance()
_injected = consumer.attach_to_tornado(io_inst)
http_server = httpserver.HTTPServer(_app)
http_server.listen(FLAGS.cc_port)
logging.debug('Started HTTP server on %s', FLAGS.cc_port)
io_inst.start()
def main(_args):
from nova import api
from nova import wsgi
wsgi.run_server(api.API(), FLAGS.api_port)
if __name__ == '__main__':
utils.default_flagfile()

View File

@@ -1,34 +0,0 @@
#!/usr/bin/env python
# pylint: disable-msg=C0103
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Nova API daemon.
"""
from nova import api
from nova import flags
from nova import utils
from nova import wsgi
FLAGS = flags.FLAGS
flags.DEFINE_integer('api_port', 8773, 'API port')
if __name__ == '__main__':
utils.default_flagfile()
wsgi.run_server(api.API(), FLAGS.api_port)

View File

@@ -21,12 +21,23 @@
Twistd daemon for the nova compute nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd
from nova.compute import service
if __name__ == '__main__':
twistd.serve(__file__)
if __name__ == '__builtin__':
application = service.ComputeService.create()
application = service.Service.create() # pylint: disable=C0103

View File

@@ -25,53 +25,65 @@ import logging
import os
import sys
#TODO(joshua): there is concern that the user dnsmasq runs under will not
# have nova in the path. This should be verified and if it is
# not true the ugly line below can be removed
sys.path.append(os.path.abspath(os.path.join(__file__, "../../")))
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import db
from nova import flags
from nova import rpc
from nova import utils
from nova.network import linux_net
from nova.network import model
from nova.network import service
FLAGS = flags.FLAGS
flags.DECLARE('auth_driver', 'nova.auth.manager')
flags.DECLARE('redis_db', 'nova.datastore')
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('update_dhcp_on_disassociate', 'nova.network.manager')
def add_lease(_mac, ip, _hostname, _interface):
def add_lease(mac, ip_address, _hostname, _interface):
"""Set the IP that was assigned by the DHCP server."""
if FLAGS.fake_rabbit:
service.VlanNetworkService().lease_ip(ip)
logging.debug("leasing ip")
network_manager = utils.import_object(FLAGS.network_manager)
network_manager.lease_fixed_ip(None, mac, ip_address)
else:
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.node_name),
{"method": "lease_ip",
"args": {"fixed_ip": ip}})
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.host),
{"method": "lease_fixed_ip",
"args": {"context": None,
"mac": mac,
"address": ip_address}})
def old_lease(_mac, _ip, _hostname, _interface):
def old_lease(_mac, _ip_address, _hostname, _interface):
"""Do nothing, just an old lease update."""
logging.debug("Adopted old lease or got a change of mac/hostname")
def del_lease(_mac, ip, _hostname, _interface):
def del_lease(mac, ip_address, _hostname, _interface):
"""Called when a lease expires."""
if FLAGS.fake_rabbit:
service.VlanNetworkService().release_ip(ip)
logging.debug("releasing ip")
network_manager = utils.import_object(FLAGS.network_manager)
network_manager.release_fixed_ip(None, mac, ip_address)
else:
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.node_name),
{"method": "release_ip",
"args": {"fixed_ip": ip}})
rpc.cast("%s.%s" % (FLAGS.network_topic, FLAGS.host),
{"method": "release_fixed_ip",
"args": {"context": None,
"mac": mac,
"address": ip_address}})
def init_leases(interface):
"""Get the list of hosts for an interface."""
net = model.get_network_by_interface(interface)
res = ""
for address in net.assigned_objs:
res += "%s\n" % linux_net.host_dhcp(address)
return res
network_ref = db.network_get_by_bridge(None, interface)
return linux_net.get_dhcp_hosts(None, network_ref['id'])
def main():
@@ -83,10 +95,16 @@ def main():
if int(os.environ.get('TESTING', '0')):
FLAGS.fake_rabbit = True
FLAGS.redis_db = 8
FLAGS.network_size = 32
FLAGS.network_size = 16
FLAGS.connection_type = 'fake'
FLAGS.fake_network = True
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.num_networks = 5
path = os.path.abspath(os.path.join(os.path.dirname(__file__),
'..',
'_trial_temp',
'nova.sqlite'))
FLAGS.sql_connection = 'sqlite:///%s' % path
action = argv[1]
if action in ['add', 'del', 'old']:
mac = argv[2]

View File

@@ -29,18 +29,26 @@ import subprocess
import sys
import urllib2
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import utils
from nova.objectstore import image
FLAGS = flags.FLAGS
api_url = 'https://imagestore.canonical.com/api/dashboard'
API_URL = 'https://imagestore.canonical.com/api/dashboard'
def get_images():
"""Get a list of the images from the imagestore URL."""
images = json.load(urllib2.urlopen(api_url))['images']
images = json.load(urllib2.urlopen(API_URL))['images']
images = [img for img in images if img['title'].find('amd64') > -1]
return images

View File

@@ -21,9 +21,19 @@
Daemon for Nova RRD based instance resource monitoring.
"""
import os
import logging
import sys
from twisted.application import service
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import twistd
from nova.compute import monitor
@@ -35,9 +45,10 @@ if __name__ == '__main__':
if __name__ == '__builtin__':
logging.warn('Starting instance monitor')
m = monitor.InstanceMonitor()
# pylint: disable-msg=C0103
monitor = monitor.InstanceMonitor()
# This is the parent service that twistd will be looking for when it
# parses this file, return it so that we can get it into globals below
application = service.Application('nova-instancemonitor')
m.setServiceParent(application)
monitor.setServiceParent(application)

View File

@@ -17,20 +17,64 @@
# License for the specific language governing permissions and limitations
# under the License.
# Interactive shell based on Django:
#
# Copyright (c) 2005, the Lawrence Journal-World
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# 3. Neither the name of Django nor the names of its contributors may be used
# to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
CLI interface for nova management.
Connects to the running ADMIN api in the api daemon.
"""
import logging
import os
import sys
import time
import IPy
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import db
from nova import exception
from nova import flags
from nova import quota
from nova import utils
from nova.auth import manager
from nova.compute import model
from nova.network import manager as network_manager
from nova.cloudpipe import pipelib
from nova.endpoint import cloud
FLAGS = flags.FLAGS
@@ -41,23 +85,27 @@ class VpnCommands(object):
def __init__(self):
self.manager = manager.AuthManager()
self.instdir = model.InstanceDirectory()
self.pipe = pipelib.CloudPipe(cloud.CloudController())
self.pipe = pipelib.CloudPipe()
def list(self):
"""Print a listing of the VPNs for all projects."""
print "%-12s\t" % 'project',
print "%-12s\t" % 'ip:port',
print "%-20s\t" % 'ip:port',
print "%s" % 'state'
for project in self.manager.get_projects():
print "%-12s\t" % project.name,
print "%s:%s\t" % (project.vpn_ip, project.vpn_port),
try:
s = "%s:%s" % (project.vpn_ip, project.vpn_port)
except exception.NotFound:
s = "None"
print "%-20s\t" % s,
vpn = self._vpn_for(project.id)
if vpn:
command = "ping -c1 -w1 %s > /dev/null; echo $?"
out, _err = utils.execute( command % vpn['private_dns_name'],
check_exit_code=False)
out, _err = utils.execute(command % vpn['private_dns_name'],
check_exit_code=False)
if out.strip() == '0':
net = 'up'
else:
@@ -73,9 +121,8 @@ class VpnCommands(object):
def _vpn_for(self, project_id):
"""Get the VPN instance for a project ID."""
for instance in self.instdir.all:
if ('image_id' in instance.state
and instance['image_id'] == FLAGS.vpn_image_id
for instance in db.instance_get_all(None):
if (instance['image_id'] == FLAGS.vpn_image_id
and not instance['state_description'] in
['shutting_down', 'shutdown']
and instance['project_id'] == project_id):
@@ -94,6 +141,67 @@ class VpnCommands(object):
self.pipe.launch_vpn_instance(project_id)
class ShellCommands(object):
def bpython(self):
"""Runs a bpython shell.
Falls back to Ipython/python shell if unavailable"""
self.run('bpython')
def ipython(self):
"""Runs an Ipython shell.
Falls back to Python shell if unavailable"""
self.run('ipython')
def python(self):
"""Runs a python shell.
Falls back to Python shell if unavailable"""
self.run('python')
def run(self, shell=None):
"""Runs a Python interactive interpreter.
args: [shell=bpython]"""
if not shell:
shell = 'bpython'
if shell == 'bpython':
try:
import bpython
bpython.embed()
except ImportError:
shell = 'ipython'
if shell == 'ipython':
try:
import IPython
# Explicitly pass an empty list as arguments, because otherwise IPython
# would use sys.argv from this script.
shell = IPython.Shell.IPShell(argv=[])
shell.mainloop()
except ImportError:
shell = 'python'
if shell == 'python':
import code
try: # Try activating rlcompleter, because it's handy.
import readline
except ImportError:
pass
else:
# We don't have to wrap the following import in a 'try', because
# we already know 'readline' was imported successfully.
import rlcompleter
readline.parse_and_bind("tab:complete")
code.interact()
def script(self, path):
"""Runs the script from the specifed path with flags set properly.
arguments: path"""
exec(compile(open(path).read(), path, 'exec'), locals(), globals())
class RoleCommands(object):
"""Class for managing roles."""
@@ -123,6 +231,13 @@ class RoleCommands(object):
class UserCommands(object):
"""Class for managing users."""
@staticmethod
def _print_export(user):
"""Print export variables to use with API."""
print 'export EC2_ACCESS_KEY=%s' % user.access
print 'export EC2_SECRET_KEY=%s' % user.secret
def __init__(self):
self.manager = manager.AuthManager()
@@ -130,13 +245,13 @@ class UserCommands(object):
"""creates a new admin and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, True)
print_export(user)
self._print_export(user)
def create(self, name, access=None, secret=None):
"""creates a new user and prints exports
arguments: name [access] [secret]"""
user = self.manager.create_user(name, access, secret, False)
print_export(user)
self._print_export(user)
def delete(self, name):
"""deletes an existing user
@@ -148,7 +263,7 @@ class UserCommands(object):
arguments: name"""
user = self.manager.get_user(name)
if user:
print_export(user)
self._print_export(user)
else:
print "User %s doesn't exist" % name
@@ -158,12 +273,18 @@ class UserCommands(object):
for user in self.manager.get_users():
print user.name
def print_export(user):
"""Print export variables to use with API."""
print 'export EC2_ACCESS_KEY=%s' % user.access
print 'export EC2_SECRET_KEY=%s' % user.secret
def modify(self, name, access_key, secret_key, is_admin):
"""update a users keys & admin flag
arguments: accesskey secretkey admin
leave any field blank to ignore it, admin should be 'T', 'F', or blank
"""
if not is_admin:
is_admin = None
elif is_admin.upper()[0] == 'T':
is_admin = True
else:
is_admin = False
self.manager.modify_user(name, access_key, secret_key, is_admin)
class ProjectCommands(object):
"""Class for managing projects."""
@@ -189,7 +310,7 @@ class ProjectCommands(object):
def environment(self, project_id, user_id, filename='novarc'):
"""Exports environment variables to an sourcable file
arguments: project_id user_id [filename='novarc]"""
rc = self.manager.get_environment_rc(project_id, user_id)
rc = self.manager.get_environment_rc(user_id, project_id)
with open(filename, 'w') as f:
f.write(rc)
@@ -199,6 +320,19 @@ class ProjectCommands(object):
for project in self.manager.get_projects():
print project.name
def quota(self, project_id, key=None, value=None):
"""Set or display quotas for project
arguments: project_id [key] [value]"""
if key:
quo = {'project_id': project_id, key: value}
try:
db.quota_update(None, project_id, quo)
except exception.NotFound:
db.quota_create(None, quo)
project_quota = quota.get_quota(None, project_id)
for key, value in project_quota.iteritems():
print '%s: %s' % (key, value)
def remove(self, project, user):
"""Removes user from project
arguments: project user"""
@@ -212,11 +346,70 @@ class ProjectCommands(object):
f.write(zip_file)
categories = [
class FloatingIpCommands(object):
"""Class for managing floating ip."""
def create(self, host, range):
"""Creates floating ips for host by range
arguments: host ip_range"""
for address in IPy.IP(range):
db.floating_ip_create(None, {'address': str(address),
'host': host})
def delete(self, ip_range):
"""Deletes floating ips by range
arguments: range"""
for address in IPy.IP(ip_range):
db.floating_ip_destroy(None, str(address))
def list(self, host=None):
"""Lists all floating ips (optionally by host)
arguments: [host]"""
if host == None:
floating_ips = db.floating_ip_get_all(None)
else:
floating_ips = db.floating_ip_get_all_by_host(None, host)
for floating_ip in floating_ips:
instance = None
if floating_ip['fixed_ip']:
instance = floating_ip['fixed_ip']['instance']['ec2_id']
print "%s\t%s\t%s" % (floating_ip['host'],
floating_ip['address'],
instance)
class NetworkCommands(object):
"""Class for managing networks."""
def create(self, fixed_range=None, num_networks=None,
network_size=None, vlan_start=None, vpn_start=None):
"""Creates fixed ips for host by range
arguments: [fixed_range=FLAG], [num_networks=FLAG],
[network_size=FLAG], [vlan_start=FLAG],
[vpn_start=FLAG]"""
if not fixed_range:
fixed_range = FLAGS.fixed_range
if not num_networks:
num_networks = FLAGS.num_networks
if not network_size:
network_size = FLAGS.network_size
if not vlan_start:
vlan_start = FLAGS.vlan_start
if not vpn_start:
vpn_start = FLAGS.vpn_start
net_manager = utils.import_object(FLAGS.network_manager)
net_manager.create_networks(None, fixed_range, int(num_networks),
int(network_size), int(vlan_start),
int(vpn_start))
CATEGORIES = [
('user', UserCommands),
('project', ProjectCommands),
('role', RoleCommands),
('shell', ShellCommands),
('vpn', VpnCommands),
('floating', FloatingIpCommands),
('network', NetworkCommands)
]
@@ -255,15 +448,19 @@ def main():
"""Parse options and call the appropriate class/method."""
utils.default_flagfile('/etc/nova/nova-manage.conf')
argv = FLAGS(sys.argv)
if FLAGS.verbose:
logging.getLogger().setLevel(logging.DEBUG)
script_name = argv.pop(0)
if len(argv) < 1:
print script_name + " category action [<args>]"
print "Available categories:"
for k, _ in categories:
for k, _ in CATEGORIES:
print "\t%s" % k
sys.exit(2)
category = argv.pop(0)
matches = lazy_match(category, categories)
matches = lazy_match(category, CATEGORIES)
# instantiate the command group object
category, fn = matches[0]
command_object = fn()
@@ -282,9 +479,9 @@ def main():
fn(*argv)
sys.exit(0)
except TypeError:
print "Wrong number of arguments supplied"
print "Possible wrong number of arguments supplied"
print "%s %s: %s" % (category, action, fn.__doc__)
sys.exit(2)
raise
if __name__ == '__main__':
main()

View File

@@ -21,16 +21,23 @@
Twistd daemon for the nova network nodes.
"""
from nova import flags
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd
from nova.network import service
FLAGS = flags.FLAGS
if __name__ == '__main__':
twistd.serve(__file__)
if __name__ == '__builtin__':
application = service.type_to_class(FLAGS.network_type).create()
application = service.Service.create() # pylint: disable-msg=C0103

View File

@@ -21,6 +21,17 @@
Twisted daemon for nova objectstore. Supports S3 API.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import flags
from nova import utils
from nova import twistd
@@ -35,4 +46,4 @@ if __name__ == '__main__':
if __name__ == '__builtin__':
utils.default_flagfile()
application = handler.get_application()
application = handler.get_application() # pylint: disable-msg=C0103

43
bin/nova-scheduler Executable file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env python
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Twistd daemon for the nova scheduler nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd
if __name__ == '__main__':
twistd.serve(__file__)
if __name__ == '__builtin__':
application = service.Service.create()

View File

@@ -21,12 +21,23 @@
Twistd daemon for the nova volume nodes.
"""
import os
import sys
# If ../nova/__init__.py exists, add ../ to Python search path, so that
# it will override what happens to be installed in /usr/(local/)lib/python...
possible_topdir = os.path.normpath(os.path.join(os.path.abspath(sys.argv[0]),
os.pardir,
os.pardir))
if os.path.exists(os.path.join(possible_topdir, 'nova', '__init__.py')):
sys.path.insert(0, possible_topdir)
from nova import service
from nova import twistd
from nova.volume import service
if __name__ == '__main__':
twistd.serve(__file__)
if __name__ == '__builtin__':
application = service.VolumeService.create()
application = service.Service.create() # pylint: disable-msg=C0103

View File

@@ -0,0 +1,59 @@
# Copyright 2010 OpenStack LLC
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Log format for Nova's changelog."""
import bzrlib.log
from bzrlib.osutils import format_date
#
# This is mostly stolen from bzrlib.log.GnuChangelogLogFormatter
# The difference is that it logs the author rather than the committer
# which for Nova always is Tarmac.
#
class NovaLogFormat(bzrlib.log.GnuChangelogLogFormatter):
preferred_levels = 1
def log_revision(self, revision):
"""Log a revision, either merged or not."""
to_file = self.to_file
date_str = format_date(revision.rev.timestamp,
revision.rev.timezone or 0,
self.show_timezone,
date_fmt='%Y-%m-%d',
show_offset=False)
authors = revision.rev.get_apparent_authors()
to_file.write('%s %s\n\n' % (date_str, ", ".join(authors)))
if revision.delta is not None and revision.delta.has_changed():
for c in revision.delta.added + revision.delta.removed + revision.delta.modified:
path, = c[:1]
to_file.write('\t* %s:\n' % (path,))
for c in revision.delta.renamed:
oldpath,newpath = c[:2]
# For renamed files, show both the old and the new path
to_file.write('\t* %s:\n\t* %s:\n' % (oldpath,newpath))
to_file.write('\n')
if not revision.rev.message:
to_file.write('\tNo commit message\n')
else:
message = revision.rev.message.rstrip('\r\n')
for l in message.split('\n'):
to_file.write('\t%s\n' % (l.lstrip(),))
to_file.write('\n')
bzrlib.log.register_formatter('novalog', NovaLogFormat)

View File

@@ -20,11 +20,17 @@ Nova User API client library.
"""
import base64
import boto
import httplib
from boto.ec2.regioninfo import RegionInfo
DEFAULT_CLC_URL='http://127.0.0.1:8773'
DEFAULT_REGION='nova'
DEFAULT_ACCESS_KEY='admin'
DEFAULT_SECRET_KEY='admin'
class UserInfo(object):
"""
Information about a Nova user, as parsed through SAX
@@ -68,13 +74,13 @@ class UserRole(object):
def __init__(self, connection=None):
self.connection = connection
self.role = None
def __repr__(self):
return 'UserRole:%s' % self.role
def startElement(self, name, attrs, connection):
return None
def endElement(self, name, value, connection):
if name == 'role':
self.role = value
@@ -128,20 +134,20 @@ class ProjectMember(object):
def __init__(self, connection=None):
self.connection = connection
self.memberId = None
def __repr__(self):
return 'ProjectMember:%s' % self.memberId
def startElement(self, name, attrs, connection):
return None
def endElement(self, name, value, connection):
if name == 'member':
self.memberId = value
else:
setattr(self, name, str(value))
class HostInfo(object):
"""
Information about a Nova Host, as parsed through SAX:
@@ -171,35 +177,56 @@ class HostInfo(object):
class NovaAdminClient(object):
def __init__(self, clc_ip='127.0.0.1', region='nova', access_key='admin',
secret_key='admin', **kwargs):
self.clc_ip = clc_ip
def __init__(self, clc_url=DEFAULT_CLC_URL, region=DEFAULT_REGION,
access_key=DEFAULT_ACCESS_KEY, secret_key=DEFAULT_SECRET_KEY,
**kwargs):
parts = self.split_clc_url(clc_url)
self.clc_url = clc_url
self.region = region
self.access = access_key
self.secret = secret_key
self.apiconn = boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=secret_key,
is_secure=False,
region=RegionInfo(None, region, clc_ip),
port=8773,
is_secure=parts['is_secure'],
region=RegionInfo(None,
region,
parts['ip']),
port=parts['port'],
path='/services/Admin',
**kwargs)
self.apiconn.APIVersion = 'nova'
def connection_for(self, username, project, **kwargs):
def connection_for(self, username, project, clc_url=None, region=None,
**kwargs):
"""
Returns a boto ec2 connection for the given username.
"""
if not clc_url:
clc_url = self.clc_url
if not region:
region = self.region
parts = self.split_clc_url(clc_url)
user = self.get_user(username)
access_key = '%s:%s' % (user.accesskey, project)
return boto.connect_ec2(
aws_access_key_id=access_key,
aws_secret_access_key=user.secretkey,
is_secure=False,
region=RegionInfo(None, self.region, self.clc_ip),
port=8773,
path='/services/Cloud',
**kwargs)
return boto.connect_ec2(aws_access_key_id=access_key,
aws_secret_access_key=user.secretkey,
is_secure=parts['is_secure'],
region=RegionInfo(None,
self.region,
parts['ip']),
port=parts['port'],
path='/services/Cloud',
**kwargs)
def split_clc_url(self, clc_url):
"""
Splits a cloud controller endpoint url.
"""
parts = httplib.urlsplit(clc_url)
is_secure = parts.scheme == 'https'
ip, port = parts.netloc.split(':')
return {'ip': ip, 'port': int(port), 'is_secure': is_secure}
def get_users(self):
""" grabs the list of all users """
@@ -289,7 +316,7 @@ class NovaAdminClient(object):
if project.projectname != None:
return project
def create_project(self, projectname, manager_user, description=None,
member_users=None):
"""
@@ -322,7 +349,7 @@ class NovaAdminClient(object):
Adds a user to a project.
"""
return self.modify_project_member(user, project, operation='add')
def remove_project_member(self, user, project):
"""
Removes a user from a project.

236
nova/auth/dbdriver.py Normal file
View File

@@ -0,0 +1,236 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Auth driver using the DB as its backend.
"""
import logging
import sys
from nova import exception
from nova import db
class DbDriver(object):
"""DB Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
def __init__(self):
"""Imports the LDAP module"""
pass
db
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
pass
def get_user(self, uid):
"""Retrieve user by id"""
return self._db_user_to_auth_user(db.user_get({}, uid))
def get_user_from_access_key(self, access):
"""Retrieve user by access key"""
return self._db_user_to_auth_user(db.user_get_by_access_key({}, access))
def get_project(self, pid):
"""Retrieve project by id"""
return self._db_project_to_auth_projectuser(db.project_get({}, pid))
def get_users(self):
"""Retrieve list of users"""
return [self._db_user_to_auth_user(user) for user in db.user_get_all({})]
def get_projects(self, uid=None):
"""Retrieve list of projects"""
if uid:
result = db.project_get_by_user({}, uid)
else:
result = db.project_get_all({})
return [self._db_project_to_auth_projectuser(proj) for proj in result]
def create_user(self, name, access_key, secret_key, is_admin):
"""Create a user"""
values = { 'id' : name,
'access_key' : access_key,
'secret_key' : secret_key,
'is_admin' : is_admin
}
try:
user_ref = db.user_create({}, values)
return self._db_user_to_auth_user(user_ref)
except exception.Duplicate, e:
raise exception.Duplicate('User %s already exists' % name)
def _db_user_to_auth_user(self, user_ref):
return { 'id' : user_ref['id'],
'name' : user_ref['id'],
'access' : user_ref['access_key'],
'secret' : user_ref['secret_key'],
'admin' : user_ref['is_admin'] }
def _db_project_to_auth_projectuser(self, project_ref):
return { 'id' : project_ref['id'],
'name' : project_ref['name'],
'project_manager_id' : project_ref['project_manager'],
'description' : project_ref['description'],
'member_ids' : [member['id'] for member in project_ref['members']] }
def create_project(self, name, manager_uid,
description=None, member_uids=None):
"""Create a project"""
manager = db.user_get({}, manager_uid)
if not manager:
raise exception.NotFound("Project can't be created because "
"manager %s doesn't exist" % manager_uid)
# description is a required attribute
if description is None:
description = name
# First, we ensure that all the given users exist before we go
# on to create the project. This way we won't have to destroy
# the project again because a user turns out to be invalid.
members = set([manager])
if member_uids != None:
for member_uid in member_uids:
member = db.user_get({}, member_uid)
if not member:
raise exception.NotFound("Project can't be created "
"because user %s doesn't exist"
% member_uid)
members.add(member)
values = { 'id' : name,
'name' : name,
'project_manager' : manager['id'],
'description': description }
try:
project = db.project_create({}, values)
except exception.Duplicate:
raise exception.Duplicate("Project can't be created because "
"project %s already exists" % name)
for member in members:
db.project_add_member({}, project['id'], member['id'])
# This looks silly, but ensures that the members element has been
# correctly populated
project_ref = db.project_get({}, project['id'])
return self._db_project_to_auth_projectuser(project_ref)
def modify_project(self, project_id, manager_uid=None, description=None):
"""Modify an existing project"""
if not manager_uid and not description:
return
values = {}
if manager_uid:
manager = db.user_get({}, manager_uid)
if not manager:
raise exception.NotFound("Project can't be modified because "
"manager %s doesn't exist" %
manager_uid)
values['project_manager'] = manager['id']
if description:
values['description'] = description
db.project_update({}, project_id, values)
def add_to_project(self, uid, project_id):
"""Add user to project"""
user, project = self._validate_user_and_project(uid, project_id)
db.project_add_member({}, project['id'], user['id'])
def remove_from_project(self, uid, project_id):
"""Remove user from project"""
user, project = self._validate_user_and_project(uid, project_id)
db.project_remove_member({}, project['id'], user['id'])
def is_in_project(self, uid, project_id):
"""Check if user is in project"""
user, project = self._validate_user_and_project(uid, project_id)
return user in project.members
def has_role(self, uid, role, project_id=None):
"""Check if user has role
If project is specified, it checks for local role, otherwise it
checks for global role
"""
return role in self.get_user_roles(uid, project_id)
def add_role(self, uid, role, project_id=None):
"""Add role for user (or user and project)"""
if not project_id:
db.user_add_role({}, uid, role)
return
db.user_add_project_role({}, uid, project_id, role)
def remove_role(self, uid, role, project_id=None):
"""Remove role for user (or user and project)"""
if not project_id:
db.user_remove_role({}, uid, role)
return
db.user_remove_project_role({}, uid, project_id, role)
def get_user_roles(self, uid, project_id=None):
"""Retrieve list of roles for user (or user and project)"""
if project_id is None:
roles = db.user_get_roles({}, uid)
return roles
else:
roles = db.user_get_roles_for_project({}, uid, project_id)
return roles
def delete_user(self, id):
"""Delete a user"""
user = db.user_get({}, id)
db.user_delete({}, user['id'])
def delete_project(self, project_id):
"""Delete a project"""
db.project_delete({}, project_id)
def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
"""Modify an existing user"""
if not access_key and not secret_key and admin is None:
return
values = {}
if access_key:
values['access_key'] = access_key
if secret_key:
values['secret_key'] = secret_key
if admin is not None:
values['is_admin'] = admin
db.user_update({}, uid, values)
def _validate_user_and_project(self, user_id, project_id):
user = db.user_get({}, user_id)
if not user:
raise exception.NotFound('User "%s" not found' % user_id)
project = db.project_get({}, project_id)
if not project:
raise exception.NotFound('Project "%s" not found' % project_id)
return user, project

View File

@@ -30,20 +30,24 @@ from nova import datastore
SCOPE_BASE = 0
SCOPE_ONELEVEL = 1 # not implemented
SCOPE_SUBTREE = 2
SCOPE_SUBTREE = 2
MOD_ADD = 0
MOD_DELETE = 1
MOD_REPLACE = 2
class NO_SUCH_OBJECT(Exception):
class NO_SUCH_OBJECT(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module."""
pass
class OBJECT_CLASS_VIOLATION(Exception):
class OBJECT_CLASS_VIOLATION(Exception): # pylint: disable-msg=C0103
"""Duplicate exception class from real LDAP module."""
pass
def initialize(uri):
def initialize(_uri):
"""Opens a fake connection with an LDAP server."""
return FakeLDAP()
@@ -68,7 +72,7 @@ def _match_query(query, attrs):
# cut off the ! and the nested parentheses
return not _match_query(query[2:-1], attrs)
(k, sep, v) = inner.partition('=')
(k, _sep, v) = inner.partition('=')
return _match(k, v, attrs)
@@ -85,20 +89,20 @@ def _paren_groups(source):
if source[pos] == ')':
count -= 1
if count == 0:
result.append(source[start:pos+1])
result.append(source[start:pos + 1])
return result
def _match(k, v, attrs):
def _match(key, value, attrs):
"""Match a given key and value against an attribute list."""
if k not in attrs:
if key not in attrs:
return False
if k != "objectclass":
return v in attrs[k]
if key != "objectclass":
return value in attrs[key]
# it is an objectclass check, so check subclasses
values = _subs(v)
for value in values:
if value in attrs[k]:
values = _subs(value)
for v in values:
if v in attrs[key]:
return True
return False
@@ -145,6 +149,7 @@ def _to_json(unencoded):
class FakeLDAP(object):
#TODO(vish): refactor this class to use a wrapper instead of accessing
# redis directly
"""Fake LDAP connection."""
def simple_bind_s(self, dn, password):
"""This method is ignored, but provided for compatibility."""
@@ -171,7 +176,7 @@ class FakeLDAP(object):
Args:
dn -- a dn
attrs -- a list of tuples in the following form:
([MOD_ADD | MOD_DELETE], attribute, value)
([MOD_ADD | MOD_DELETE | MOD_REPACE], attribute, value)
"""
redis = datastore.Redis.instance()
@@ -181,6 +186,8 @@ class FakeLDAP(object):
values = _from_json(redis.hget(key, k))
if cmd == MOD_ADD:
values.append(v)
elif cmd == MOD_REPLACE:
values = [v]
else:
values.remove(v)
values = redis.hset(key, k, _to_json(values))
@@ -207,6 +214,7 @@ class FakeLDAP(object):
# get the attributes from redis
attrs = redis.hgetall(key)
# turn the values from redis into lists
# pylint: disable-msg=E1103
attrs = dict([(k, _from_json(v))
for k, v in attrs.iteritems()])
# filter the objects by query
@@ -215,12 +223,12 @@ class FakeLDAP(object):
attrs = dict([(k, v) for k, v in attrs.iteritems()
if not fields or k in fields])
objects.append((key[len(self.__redis_prefix):], attrs))
# pylint: enable-msg=E1103
if objects == []:
raise NO_SUCH_OBJECT()
return objects
@property
def __redis_prefix(self):
def __redis_prefix(self): # pylint: disable-msg=R0201
"""Get the prefix to use for all redis keys."""
return 'ldap:'

View File

@@ -34,7 +34,7 @@ from nova import flags
FLAGS = flags.FLAGS
flags.DEFINE_string('ldap_url', 'ldap://localhost',
'Point this at your ldap server')
flags.DEFINE_string('ldap_password', 'changeme', 'LDAP password')
flags.DEFINE_string('ldap_password', 'changeme', 'LDAP password')
flags.DEFINE_string('ldap_user_dn', 'cn=Manager,dc=example,dc=com',
'DN of admin user')
flags.DEFINE_string('ldap_user_unit', 'Users', 'OID for Users')
@@ -63,14 +63,18 @@ flags.DEFINE_string('ldap_developer',
# to define a set interface for AuthDrivers. I'm delaying
# creating this now because I'm expecting an auth refactor
# in which we may want to change the interface a bit more.
class LdapDriver(object):
"""Ldap Auth driver
Defines enter and exit and therefore supports the with/as syntax.
"""
def __init__(self):
"""Imports the LDAP module"""
self.ldap = __import__('ldap')
self.conn = None
def __enter__(self):
"""Creates the connection to LDAP"""
@@ -78,7 +82,7 @@ class LdapDriver(object):
self.conn.simple_bind_s(FLAGS.ldap_user_dn, FLAGS.ldap_password)
return self
def __exit__(self, type, value, traceback):
def __exit__(self, exc_type, exc_value, traceback):
"""Destroys the connection to LDAP"""
self.conn.unbind_s()
return False
@@ -95,13 +99,6 @@ class LdapDriver(object):
dn = FLAGS.ldap_user_subtree
return self.__to_user(self.__find_object(dn, query))
def get_key_pair(self, uid, key_name):
"""Retrieve key pair by uid and key name"""
dn = 'cn=%s,%s' % (key_name,
self.__uid_to_dn(uid))
attr = self.__find_object(dn, '(objectclass=novaKeyPair)')
return self.__to_key_pair(uid, attr)
def get_project(self, pid):
"""Retrieve project by id"""
dn = 'cn=%s,%s' % (pid,
@@ -115,19 +112,13 @@ class LdapDriver(object):
'(objectclass=novaUser)')
return [self.__to_user(attr) for attr in attrs]
def get_key_pairs(self, uid):
"""Retrieve list of key pairs"""
attrs = self.__find_objects(self.__uid_to_dn(uid),
'(objectclass=novaKeyPair)')
return [self.__to_key_pair(uid, attr) for attr in attrs]
def get_projects(self, uid=None):
"""Retrieve list of projects"""
filter = '(objectclass=novaProject)'
pattern = '(objectclass=novaProject)'
if uid:
filter = "(&%s(member=%s))" % (filter, self.__uid_to_dn(uid))
pattern = "(&%s(member=%s))" % (pattern, self.__uid_to_dn(uid))
attrs = self.__find_objects(FLAGS.ldap_project_subtree,
filter)
pattern)
return [self.__to_project(attr) for attr in attrs]
def create_user(self, name, access_key, secret_key, is_admin):
@@ -150,21 +141,6 @@ class LdapDriver(object):
self.conn.add_s(self.__uid_to_dn(name), attr)
return self.__to_user(dict(attr))
def create_key_pair(self, uid, key_name, public_key, fingerprint):
"""Create a key pair"""
# TODO(vish): possibly refactor this to store keys in their own ou
# and put dn reference in the user object
attr = [
('objectclass', ['novaKeyPair']),
('cn', [key_name]),
('sshPublicKey', [public_key]),
('keyFingerprint', [fingerprint]),
]
self.conn.add_s('cn=%s,%s' % (key_name,
self.__uid_to_dn(uid)),
attr)
return self.__to_key_pair(uid, dict(attr))
def create_project(self, name, manager_uid,
description=None, member_uids=None):
"""Create a project"""
@@ -194,11 +170,28 @@ class LdapDriver(object):
('cn', [name]),
('description', [description]),
('projectManager', [manager_dn]),
('member', members)
]
('member', members)]
self.conn.add_s('cn=%s,%s' % (name, FLAGS.ldap_project_subtree), attr)
return self.__to_project(dict(attr))
def modify_project(self, project_id, manager_uid=None, description=None):
"""Modify an existing project"""
if not manager_uid and not description:
return
attr = []
if manager_uid:
if not self.__user_exists(manager_uid):
raise exception.NotFound("Project can't be modified because "
"manager %s doesn't exist" %
manager_uid)
manager_dn = self.__uid_to_dn(manager_uid)
attr.append((self.ldap.MOD_REPLACE, 'projectManager', manager_dn))
if description:
attr.append((self.ldap.MOD_REPLACE, 'description', description))
self.conn.modify_s('cn=%s,%s' % (project_id,
FLAGS.ldap_project_subtree),
attr)
def add_to_project(self, uid, project_id):
"""Add user to project"""
dn = 'cn=%s,%s' % (project_id, FLAGS.ldap_project_subtree)
@@ -262,18 +255,8 @@ class LdapDriver(object):
"""Delete a user"""
if not self.__user_exists(uid):
raise exception.NotFound("User %s doesn't exist" % uid)
self.__delete_key_pairs(uid)
self.__remove_from_all(uid)
self.conn.delete_s('uid=%s,%s' % (uid,
FLAGS.ldap_user_subtree))
def delete_key_pair(self, uid, key_name):
"""Delete a key pair"""
if not self.__key_pair_exists(uid, key_name):
raise exception.NotFound("Key Pair %s doesn't exist for user %s" %
(key_name, uid))
self.conn.delete_s('cn=%s,uid=%s,%s' % (key_name, uid,
FLAGS.ldap_user_subtree))
self.conn.delete_s(self.__uid_to_dn(uid))
def delete_project(self, project_id):
"""Delete a project"""
@@ -281,15 +264,23 @@ class LdapDriver(object):
self.__delete_roles(project_dn)
self.__delete_group(project_dn)
def modify_user(self, uid, access_key=None, secret_key=None, admin=None):
"""Modify an existing project"""
if not access_key and not secret_key and admin is None:
return
attr = []
if access_key:
attr.append((self.ldap.MOD_REPLACE, 'accessKey', access_key))
if secret_key:
attr.append((self.ldap.MOD_REPLACE, 'secretKey', secret_key))
if admin is not None:
attr.append((self.ldap.MOD_REPLACE, 'isAdmin', str(admin).upper()))
self.conn.modify_s(self.__uid_to_dn(uid), attr)
def __user_exists(self, uid):
"""Check if user exists"""
return self.get_user(uid) != None
def __key_pair_exists(self, uid, key_name):
"""Check if key pair exists"""
return self.get_user(uid) != None
return self.get_key_pair(uid, key_name) != None
def __project_exists(self, project_id):
"""Check if project exists"""
return self.get_project(project_id) != None
@@ -310,7 +301,7 @@ class LdapDriver(object):
except self.ldap.NO_SUCH_OBJECT:
return []
# just return the DNs
return [dn for dn, attributes in res]
return [dn for dn, _attributes in res]
def __find_objects(self, dn, query=None, scope=None):
"""Find objects by query"""
@@ -339,14 +330,8 @@ class LdapDriver(object):
"""Check if group exists"""
return self.__find_object(dn, '(objectclass=groupOfNames)') != None
def __delete_key_pairs(self, uid):
"""Delete all key pairs for user"""
keys = self.get_key_pairs(uid)
if keys != None:
for key in keys:
self.delete_key_pair(uid, key['name'])
def __role_to_dn(self, role, project_id=None):
@staticmethod
def __role_to_dn(role, project_id=None):
"""Convert role to corresponding dn"""
if project_id == None:
return FLAGS.__getitem__("ldap_%s" % role).value
@@ -356,7 +341,7 @@ class LdapDriver(object):
FLAGS.ldap_project_subtree)
def __create_group(self, group_dn, name, uid,
description, member_uids = None):
description, member_uids=None):
"""Create a group"""
if self.__group_exists(group_dn):
raise exception.Duplicate("Group can't be created because "
@@ -375,8 +360,7 @@ class LdapDriver(object):
('objectclass', ['groupOfNames']),
('cn', [name]),
('description', [description]),
('member', members)
]
('member', members)]
self.conn.add_s(group_dn, attr)
def __is_in_group(self, uid, group_dn):
@@ -402,9 +386,7 @@ class LdapDriver(object):
if self.__is_in_group(uid, group_dn):
raise exception.Duplicate("User %s is already a member of "
"the group %s" % (uid, group_dn))
attr = [
(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))
]
attr = [(self.ldap.MOD_ADD, 'member', self.__uid_to_dn(uid))]
self.conn.modify_s(group_dn, attr)
def __remove_from_group(self, uid, group_dn):
@@ -432,7 +414,7 @@ class LdapDriver(object):
self.conn.modify_s(group_dn, attr)
except self.ldap.OBJECT_CLASS_VIOLATION:
logging.debug("Attempted to remove the last member of a group. "
"Deleting the group at %s instead." % group_dn )
"Deleting the group at %s instead.", group_dn)
self.__delete_group(group_dn)
def __remove_from_all(self, uid):
@@ -440,7 +422,6 @@ class LdapDriver(object):
if not self.__user_exists(uid):
raise exception.NotFound("User %s can't be removed from all "
"because the user doesn't exist" % (uid,))
dn = self.__uid_to_dn(uid)
role_dns = self.__find_group_dns_with_member(
FLAGS.role_project_subtree, uid)
for role_dn in role_dns:
@@ -448,7 +429,7 @@ class LdapDriver(object):
project_dns = self.__find_group_dns_with_member(
FLAGS.ldap_project_subtree, uid)
for project_dn in project_dns:
self.__safe_remove_from_group(uid, role_dn)
self.__safe_remove_from_group(uid, project_dn)
def __delete_group(self, group_dn):
"""Delete Group"""
@@ -461,7 +442,8 @@ class LdapDriver(object):
for role_dn in self.__find_role_dns(project_dn):
self.__delete_group(role_dn)
def __to_user(self, attr):
@staticmethod
def __to_user(attr):
"""Convert ldap attributes to User object"""
if attr == None:
return None
@@ -470,20 +452,7 @@ class LdapDriver(object):
'name': attr['cn'][0],
'access': attr['accessKey'][0],
'secret': attr['secretKey'][0],
'admin': (attr['isAdmin'][0] == 'TRUE')
}
def __to_key_pair(self, owner, attr):
"""Convert ldap attributes to KeyPair object"""
if attr == None:
return None
return {
'id': attr['cn'][0],
'name': attr['cn'][0],
'owner_id': owner,
'public_key': attr['sshPublicKey'][0],
'fingerprint': attr['keyFingerprint'][0],
}
'admin': (attr['isAdmin'][0] == 'TRUE')}
def __to_project(self, attr):
"""Convert ldap attributes to Project object"""
@@ -495,21 +464,22 @@ class LdapDriver(object):
'name': attr['cn'][0],
'project_manager_id': self.__dn_to_uid(attr['projectManager'][0]),
'description': attr.get('description', [None])[0],
'member_ids': [self.__dn_to_uid(x) for x in member_dns]
}
'member_ids': [self.__dn_to_uid(x) for x in member_dns]}
def __dn_to_uid(self, dn):
@staticmethod
def __dn_to_uid(dn):
"""Convert user dn to uid"""
return dn.split(',')[0].split('=')[1]
def __uid_to_dn(self, dn):
@staticmethod
def __uid_to_dn(dn):
"""Convert uid to dn"""
return 'uid=%s,%s' % (dn, FLAGS.ldap_user_subtree)
class FakeLdapDriver(LdapDriver):
"""Fake Ldap Auth driver"""
def __init__(self):
def __init__(self): # pylint: disable-msg=W0231
__import__('nova.auth.fakeldap')
self.ldap = sys.modules['nova.auth.fakeldap']

View File

@@ -23,17 +23,17 @@ Nova authentication management
import logging
import os
import shutil
import string
import string # pylint: disable-msg=W0402
import tempfile
import uuid
import zipfile
from nova import crypto
from nova import db
from nova import exception
from nova import flags
from nova import utils
from nova.auth import signer
from nova.network import vpn
FLAGS = flags.FLAGS
@@ -44,7 +44,7 @@ flags.DEFINE_list('allowed_roles',
# NOTE(vish): a user with one of these roles will be a superuser and
# have access to all api commands
flags.DEFINE_list('superuser_roles', ['cloudadmin'],
'Roles that ignore rbac checking completely')
'Roles that ignore authorization checking completely')
# NOTE(vish): a user with one of these roles will have it for every
# project, even if he or she is not a member of the project
@@ -69,7 +69,7 @@ flags.DEFINE_string('credential_cert_subject',
'/C=US/ST=California/L=MountainView/O=AnsoLabs/'
'OU=NovaDev/CN=%s-%s',
'Subject for certificate for users')
flags.DEFINE_string('auth_driver', 'nova.auth.ldapdriver.FakeLdapDriver',
flags.DEFINE_string('auth_driver', 'nova.auth.dbdriver.DbDriver',
'Driver that auth manager uses')
@@ -128,24 +128,6 @@ class User(AuthBase):
def is_project_manager(self, project):
return AuthManager().is_project_manager(self, project)
def generate_key_pair(self, name):
return AuthManager().generate_key_pair(self.id, name)
def create_key_pair(self, name, public_key, fingerprint):
return AuthManager().create_key_pair(self.id,
name,
public_key,
fingerprint)
def get_key_pair(self, name):
return AuthManager().get_key_pair(self.id, name)
def delete_key_pair(self, name):
return AuthManager().delete_key_pair(self.id, name)
def get_key_pairs(self):
return AuthManager().get_key_pairs(self.id)
def __repr__(self):
return "User('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
@@ -154,29 +136,6 @@ class User(AuthBase):
self.admin)
class KeyPair(AuthBase):
"""Represents an ssh key returned from the datastore
Even though this object is named KeyPair, only the public key and
fingerprint is stored. The user's private key is not saved.
"""
def __init__(self, id, name, owner_id, public_key, fingerprint):
AuthBase.__init__(self)
self.id = id
self.name = name
self.owner_id = owner_id
self.public_key = public_key
self.fingerprint = fingerprint
def __repr__(self):
return "KeyPair('%s', '%s', '%s', '%s', '%s')" % (self.id,
self.name,
self.owner_id,
self.public_key,
self.fingerprint)
class Project(AuthBase):
"""Represents a Project returned from the datastore"""
@@ -194,12 +153,12 @@ class Project(AuthBase):
@property
def vpn_ip(self):
ip, port = AuthManager().get_project_vpn_data(self)
ip, _port = AuthManager().get_project_vpn_data(self)
return ip
@property
def vpn_port(self):
ip, port = AuthManager().get_project_vpn_data(self)
_ip, port = AuthManager().get_project_vpn_data(self)
return port
def has_manager(self, user):
@@ -221,11 +180,9 @@ class Project(AuthBase):
return AuthManager().get_credentials(user, self)
def __repr__(self):
return "Project('%s', '%s', '%s', '%s', %s)" % (self.id,
self.name,
self.project_manager_id,
self.description,
self.member_ids)
return "Project('%s', '%s', '%s', '%s', %s)" % \
(self.id, self.name, self.project_manager_id, self.description,
self.member_ids)
class AuthManager(object):
@@ -254,6 +211,7 @@ class AuthManager(object):
__init__ is run every time AuthManager() is called, so we only
reset the driver if it is not set or a new driver is specified.
"""
self.network_manager = utils.import_object(FLAGS.network_manager)
if driver or not getattr(self, 'driver', None):
self.driver = utils.import_class(driver or FLAGS.auth_driver)
@@ -297,7 +255,7 @@ class AuthManager(object):
@return: User and project that the request represents.
"""
# TODO(vish): check for valid timestamp
(access_key, sep, project_id) = access.partition(':')
(access_key, _sep, project_id) = access.partition(':')
logging.info('Looking up user: %r', access_key)
user = self.get_user_from_access_key(access_key)
@@ -308,7 +266,7 @@ class AuthManager(object):
# NOTE(vish): if we stop using project name as id we need better
# logic to find a default project for user
if project_id is '':
if project_id == '':
project_id = user.name
project = self.get_project(project_id)
@@ -320,7 +278,8 @@ class AuthManager(object):
raise exception.NotFound('User %s is not a member of project %s' %
(user.id, project.id))
if check_type == 's3':
expected_signature = signer.Signer(user.secret.encode()).s3_authorization(headers, verb, path)
sign = signer.Signer(user.secret.encode())
expected_signature = sign.s3_authorization(headers, verb, path)
logging.debug('user.secret: %s', user.secret)
logging.debug('expected_signature: %s', expected_signature)
logging.debug('signature: %s', signature)
@@ -345,7 +304,7 @@ class AuthManager(object):
return "%s:%s" % (user.access, Project.safe_id(project))
def is_superuser(self, user):
"""Checks for superuser status, allowing user to bypass rbac
"""Checks for superuser status, allowing user to bypass authorization
@type user: User or uid
@param user: User to check.
@@ -465,7 +424,8 @@ class AuthManager(object):
with self.driver() as drv:
drv.remove_role(User.safe_id(user), role, Project.safe_id(project))
def get_roles(self, project_roles=True):
@staticmethod
def get_roles(project_roles=True):
"""Get list of allowed roles"""
if project_roles:
return list(set(FLAGS.allowed_roles) - set(FLAGS.global_roles))
@@ -493,8 +453,8 @@ class AuthManager(object):
return []
return [Project(**project_dict) for project_dict in project_list]
def create_project(self, name, manager_user,
description=None, member_users=None):
def create_project(self, name, manager_user, description=None,
member_users=None, context=None):
"""Create a project
@type name: str
@@ -518,12 +478,33 @@ class AuthManager(object):
if member_users:
member_users = [User.safe_id(u) for u in member_users]
with self.driver() as drv:
project_dict = drv.create_project(name,
User.safe_id(manager_user),
description,
member_users)
project_dict = drv.create_project(name,
User.safe_id(manager_user),
description,
member_users)
if project_dict:
return Project(**project_dict)
project = Project(**project_dict)
return project
def modify_project(self, project, manager_user=None, description=None):
"""Modify a project
@type name: Project or project_id
@param project: The project to modify.
@type manager_user: User or uid
@param manager_user: This user will be the new project manager.
@type description: str
@param project: This will be the new description of the project.
"""
if manager_user:
manager_user = User.safe_id(manager_user)
with self.driver() as drv:
drv.modify_project(Project.safe_id(project),
manager_user,
description)
def add_to_project(self, user, project):
"""Add user to project"""
@@ -549,7 +530,8 @@ class AuthManager(object):
return drv.remove_from_project(User.safe_id(user),
Project.safe_id(project))
def get_project_vpn_data(self, project):
@staticmethod
def get_project_vpn_data(project, context=None):
"""Gets vpn ip and port for project
@type project: Project or project_id
@@ -559,15 +541,19 @@ class AuthManager(object):
@return: A tuple containing (ip, port) or None, None if vpn has
not been allocated for user.
"""
network_data = vpn.NetworkData.lookup(Project.safe_id(project))
if not network_data:
raise exception.NotFound('project network data has not been set')
return (network_data.ip, network_data.port)
def delete_project(self, project):
network_ref = db.project_get_network(context,
Project.safe_id(project))
if not network_ref['vpn_public_port']:
raise exception.NotFound('project network data has not been set')
return (network_ref['vpn_public_address'],
network_ref['vpn_public_port'])
def delete_project(self, project, context=None):
"""Deletes a project"""
with self.driver() as drv:
return drv.delete_project(Project.safe_id(project))
drv.delete_project(Project.safe_id(project))
def get_user(self, uid):
"""Retrieves a user by id"""
@@ -613,75 +599,29 @@ class AuthManager(object):
@rtype: User
@return: The new user.
"""
if access == None: access = str(uuid.uuid4())
if secret == None: secret = str(uuid.uuid4())
if access == None:
access = str(uuid.uuid4())
if secret == None:
secret = str(uuid.uuid4())
with self.driver() as drv:
user_dict = drv.create_user(name, access, secret, admin)
if user_dict:
return User(**user_dict)
def delete_user(self, user):
"""Deletes a user"""
"""Deletes a user
Additionally deletes all users key_pairs"""
uid = User.safe_id(user)
db.key_pair_destroy_all_by_user(None, uid)
with self.driver() as drv:
drv.delete_user(User.safe_id(user))
drv.delete_user(uid)
def generate_key_pair(self, user, key_name):
"""Generates a key pair for a user
Generates a public and private key, stores the public key using the
key_name, and returns the private key and fingerprint.
@type user: User or uid
@param user: User for which to create key pair.
@type key_name: str
@param key_name: Name to use for the generated KeyPair.
@rtype: tuple (private_key, fingerprint)
@return: A tuple containing the private_key and fingerprint.
"""
# NOTE(vish): generating key pair is slow so check for legal
# creation before creating keypair
def modify_user(self, user, access_key=None, secret_key=None, admin=None):
"""Modify credentials for a user"""
uid = User.safe_id(user)
with self.driver() as drv:
if not drv.get_user(uid):
raise exception.NotFound("User %s doesn't exist" % user)
if drv.get_key_pair(uid, key_name):
raise exception.Duplicate("The keypair %s already exists"
% key_name)
private_key, public_key, fingerprint = crypto.generate_key_pair()
self.create_key_pair(uid, key_name, public_key, fingerprint)
return private_key, fingerprint
def create_key_pair(self, user, key_name, public_key, fingerprint):
"""Creates a key pair for user"""
with self.driver() as drv:
kp_dict = drv.create_key_pair(User.safe_id(user),
key_name,
public_key,
fingerprint)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pair(self, user, key_name):
"""Retrieves a key pair for user"""
with self.driver() as drv:
kp_dict = drv.get_key_pair(User.safe_id(user), key_name)
if kp_dict:
return KeyPair(**kp_dict)
def get_key_pairs(self, user):
"""Retrieves all key pairs for user"""
with self.driver() as drv:
kp_list = drv.get_key_pairs(User.safe_id(user))
if not kp_list:
return []
return [KeyPair(**kp_dict) for kp_dict in kp_list]
def delete_key_pair(self, user, key_name):
"""Deletes a key pair for user"""
with self.driver() as drv:
drv.delete_key_pair(User.safe_id(user), key_name)
drv.modify_user(uid, access_key, secret_key, admin)
def get_credentials(self, user, project=None):
"""Get credential zip for user in project"""
@@ -700,15 +640,18 @@ class AuthManager(object):
zippy.writestr(FLAGS.credential_key_file, private_key)
zippy.writestr(FLAGS.credential_cert_file, signed_cert)
network_data = vpn.NetworkData.lookup(pid)
if network_data:
configfile = open(FLAGS.vpn_client_template,"r")
try:
(vpn_ip, vpn_port) = self.get_project_vpn_data(project)
except exception.NotFound:
vpn_ip = None
if vpn_ip:
configfile = open(FLAGS.vpn_client_template, "r")
s = string.Template(configfile.read())
configfile.close()
config = s.substitute(keyfile=FLAGS.credential_key_file,
certfile=FLAGS.credential_cert_file,
ip=network_data.ip,
port=network_data.port)
ip=vpn_ip,
port=vpn_port)
zippy.writestr(FLAGS.credential_vpn_file, config)
else:
logging.warn("No vpn data for project %s" %
@@ -717,10 +660,10 @@ class AuthManager(object):
zippy.writestr(FLAGS.ca_file, crypto.fetch_ca(user.id))
zippy.close()
with open(zf, 'rb') as f:
buffer = f.read()
read_buffer = f.read()
shutil.rmtree(tmpdir)
return buffer
return read_buffer
def get_environment_rc(self, user, project=None):
"""Get credential zip for user in project"""
@@ -731,18 +674,18 @@ class AuthManager(object):
pid = Project.safe_id(project)
return self.__generate_rc(user.access, user.secret, pid)
def __generate_rc(self, access, secret, pid):
@staticmethod
def __generate_rc(access, secret, pid):
"""Generate rc file for user"""
rc = open(FLAGS.credentials_template).read()
rc = rc % { 'access': access,
'project': pid,
'secret': secret,
'ec2': FLAGS.ec2_url,
's3': 'http://%s:%s' % (FLAGS.s3_host, FLAGS.s3_port),
'nova': FLAGS.ca_file,
'cert': FLAGS.credential_cert_file,
'key': FLAGS.credential_key_file,
}
rc = rc % {'access': access,
'project': pid,
'secret': secret,
'ec2': FLAGS.ec2_url,
's3': 'http://%s:%s' % (FLAGS.s3_host, FLAGS.s3_port),
'nova': FLAGS.ca_file,
'cert': FLAGS.credential_cert_file,
'key': FLAGS.credential_key_file}
return rc
def _generate_x509_cert(self, uid, pid):
@@ -753,6 +696,7 @@ class AuthManager(object):
signed_cert = crypto.sign_csr(csr, pid)
return (private_key, signed_cert)
def __cert_subject(self, uid):
@staticmethod
def __cert_subject(uid):
"""Helper to generate cert subject"""
return FLAGS.credential_cert_subject % (uid, utils.isotime())

View File

@@ -1,55 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from nova import exception
from nova.auth import manager
def allow(*roles):
def wrap(f):
def wrapped_f(self, context, *args, **kwargs):
if context.user.is_superuser():
return f(self, context, *args, **kwargs)
for role in roles:
if __matches_role(context, role):
return f(self, context, *args, **kwargs)
raise exception.NotAuthorized()
return wrapped_f
return wrap
def deny(*roles):
def wrap(f):
def wrapped_f(self, context, *args, **kwargs):
if context.user.is_superuser():
return f(self, context, *args, **kwargs)
for role in roles:
if __matches_role(context, role):
raise exception.NotAuthorized()
return f(self, context, *args, **kwargs)
return wrapped_f
return wrap
def __matches_role(context, role):
if role == 'all':
return True
if role == 'none':
return False
return context.project.has_role(context.user.id, role)

View File

@@ -50,15 +50,15 @@ import logging
import urllib
# NOTE(vish): for new boto
import boto
import boto
# NOTE(vish): for old boto
import boto.utils
import boto.utils
from nova.exception import Error
class Signer(object):
""" hacked up code from boto/connection.py """
"""Hacked up code from boto/connection.py"""
def __init__(self, secret_key):
self.hmac = hmac.new(secret_key, digestmod=hashlib.sha1)
@@ -66,22 +66,27 @@ class Signer(object):
self.hmac_256 = hmac.new(secret_key, digestmod=hashlib.sha256)
def s3_authorization(self, headers, verb, path):
"""Generate S3 authorization string."""
c_string = boto.utils.canonical_string(verb, path, headers)
hmac = self.hmac.copy()
hmac.update(c_string)
b64_hmac = base64.encodestring(hmac.digest()).strip()
hmac_copy = self.hmac.copy()
hmac_copy.update(c_string)
b64_hmac = base64.encodestring(hmac_copy.digest()).strip()
return b64_hmac
def generate(self, params, verb, server_string, path):
"""Generate auth string according to what SignatureVersion is given."""
if params['SignatureVersion'] == '0':
return self._calc_signature_0(params)
if params['SignatureVersion'] == '1':
return self._calc_signature_1(params)
if params['SignatureVersion'] == '2':
return self._calc_signature_2(params, verb, server_string, path)
raise Error('Unknown Signature Version: %s' % self.SignatureVersion)
raise Error('Unknown Signature Version: %s' %
params['SignatureVersion'])
def _get_utf8_value(self, value):
@staticmethod
def _get_utf8_value(value):
"""Get the UTF8-encoded version of a value."""
if not isinstance(value, str) and not isinstance(value, unicode):
value = str(value)
if isinstance(value, unicode):
@@ -90,10 +95,11 @@ class Signer(object):
return value
def _calc_signature_0(self, params):
"""Generate AWS signature version 0 string."""
s = params['Action'] + params['Timestamp']
self.hmac.update(s)
keys = params.keys()
keys.sort(cmp = lambda x, y: cmp(x.lower(), y.lower()))
keys.sort(cmp=lambda x, y: cmp(x.lower(), y.lower()))
pairs = []
for key in keys:
val = self._get_utf8_value(params[key])
@@ -101,8 +107,9 @@ class Signer(object):
return base64.b64encode(self.hmac.digest())
def _calc_signature_1(self, params):
"""Generate AWS signature version 1 string."""
keys = params.keys()
keys.sort(cmp = lambda x, y: cmp(x.lower(), y.lower()))
keys.sort(cmp=lambda x, y: cmp(x.lower(), y.lower()))
pairs = []
for key in keys:
self.hmac.update(key)
@@ -112,30 +119,34 @@ class Signer(object):
return base64.b64encode(self.hmac.digest())
def _calc_signature_2(self, params, verb, server_string, path):
"""Generate AWS signature version 2 string."""
logging.debug('using _calc_signature_2')
string_to_sign = '%s\n%s\n%s\n' % (verb, server_string, path)
if self.hmac_256:
hmac = self.hmac_256
current_hmac = self.hmac_256
params['SignatureMethod'] = 'HmacSHA256'
else:
hmac = self.hmac
current_hmac = self.hmac
params['SignatureMethod'] = 'HmacSHA1'
keys = params.keys()
keys.sort()
pairs = []
for key in keys:
val = self._get_utf8_value(params[key])
pairs.append(urllib.quote(key, safe='') + '=' + urllib.quote(val, safe='-_~'))
val = urllib.quote(val, safe='-_~')
pairs.append(urllib.quote(key, safe='') + '=' + val)
qs = '&'.join(pairs)
logging.debug('query string: %s' % qs)
logging.debug('query string: %s', qs)
string_to_sign += qs
logging.debug('string_to_sign: %s' % string_to_sign)
hmac.update(string_to_sign)
b64 = base64.b64encode(hmac.digest())
logging.debug('len(b64)=%d' % len(b64))
logging.debug('base64 encoded digest: %s' % b64)
logging.debug('string_to_sign: %s', string_to_sign)
current_hmac.update(string_to_sign)
b64 = base64.b64encode(current_hmac.digest())
logging.debug('len(b64)=%d', len(b64))
logging.debug('base64 encoded digest: %s', b64)
return b64
if __name__ == '__main__':
print Signer('foo').generate({"SignatureMethod": 'HmacSHA256', 'SignatureVersion': '2'}, "get", "server", "/foo")
print Signer('foo').generate({'SignatureMethod': 'HmacSHA256',
'SignatureVersion': '2'},
'get', 'server', '/foo')

View File

@@ -26,10 +26,7 @@ before trying to run this.
import logging
import redis
from nova import exception
from nova import flags
from nova import utils
FLAGS = flags.FLAGS
flags.DEFINE_string('redis_host', '127.0.0.1',
@@ -54,209 +51,3 @@ class Redis(object):
return cls._instance
class ConnectionError(exception.Error):
pass
def absorb_connection_error(fn):
def _wrapper(*args, **kwargs):
try:
return fn(*args, **kwargs)
except redis.exceptions.ConnectionError, ce:
raise ConnectionError(str(ce))
return _wrapper
class BasicModel(object):
"""
All Redis-backed data derives from this class.
You MUST specify an identifier() property that returns a unique string
per instance.
You MUST have an initializer that takes a single argument that is a value
returned by identifier() to load a new class with.
You may want to specify a dictionary for default_state().
You may also specify override_type at the class left to use a key other
than __class__.__name__.
You override save and destroy calls to automatically build and destroy
associations.
"""
override_type = None
@absorb_connection_error
def __init__(self):
state = Redis.instance().hgetall(self.__redis_key)
if state:
self.initial_state = state
self.state = dict(self.initial_state)
else:
self.initial_state = {}
self.state = self.default_state()
def default_state(self):
"""You probably want to define this in your subclass"""
return {}
@classmethod
def _redis_name(cls):
return cls.override_type or cls.__name__.lower()
@classmethod
def lookup(cls, identifier):
rv = cls(identifier)
if rv.is_new_record():
return None
else:
return rv
@classmethod
@absorb_connection_error
def all(cls):
"""yields all objects in the store"""
redis_set = cls._redis_set_name(cls.__name__)
for identifier in Redis.instance().smembers(redis_set):
yield cls(identifier)
@classmethod
def associated_to(cls, foreign_type, foreign_id):
for identifier in cls.associated_keys(foreign_type, foreign_id):
yield cls(identifier)
@classmethod
@absorb_connection_error
def associated_keys(cls, foreign_type, foreign_id):
redis_set = cls._redis_association_name(foreign_type, foreign_id)
return Redis.instance().smembers(redis_set) or []
@classmethod
def _redis_set_name(cls, kls_name):
# stupidly pluralize (for compatiblity with previous codebase)
return kls_name.lower() + "s"
@classmethod
def _redis_association_name(cls, foreign_type, foreign_id):
return cls._redis_set_name("%s:%s:%s" %
(foreign_type, foreign_id, cls._redis_name()))
@property
def identifier(self):
"""You DEFINITELY want to define this in your subclass"""
raise NotImplementedError("Your subclass should define identifier")
@property
def __redis_key(self):
return '%s:%s' % (self._redis_name(), self.identifier)
def __repr__(self):
return "<%s:%s>" % (self.__class__.__name__, self.identifier)
def keys(self):
return self.state.keys()
def copy(self):
copyDict = {}
for item in self.keys():
copyDict[item] = self[item]
return copyDict
def get(self, item, default):
return self.state.get(item, default)
def update(self, update_dict):
return self.state.update(update_dict)
def setdefault(self, item, default):
return self.state.setdefault(item, default)
def __contains__(self, item):
return item in self.state
def __getitem__(self, item):
return self.state[item]
def __setitem__(self, item, val):
self.state[item] = val
return self.state[item]
def __delitem__(self, item):
"""We don't support this"""
raise Exception("Silly monkey, models NEED all their properties.")
def is_new_record(self):
return self.initial_state == {}
@absorb_connection_error
def add_to_index(self):
"""Each insance of Foo has its id tracked int the set named Foos"""
set_name = self.__class__._redis_set_name(self.__class__.__name__)
Redis.instance().sadd(set_name, self.identifier)
@absorb_connection_error
def remove_from_index(self):
"""Remove id of this instance from the set tracking ids of this type"""
set_name = self.__class__._redis_set_name(self.__class__.__name__)
Redis.instance().srem(set_name, self.identifier)
@absorb_connection_error
def associate_with(self, foreign_type, foreign_id):
"""Add this class id into the set foreign_type:foreign_id:this_types"""
# note the extra 's' on the end is for plurality
# to match the old data without requiring a migration of any sort
self.add_associated_model_to_its_set(foreign_type, foreign_id)
redis_set = self.__class__._redis_association_name(foreign_type,
foreign_id)
Redis.instance().sadd(redis_set, self.identifier)
@absorb_connection_error
def unassociate_with(self, foreign_type, foreign_id):
"""Delete from foreign_type:foreign_id:this_types set"""
redis_set = self.__class__._redis_association_name(foreign_type,
foreign_id)
Redis.instance().srem(redis_set, self.identifier)
def add_associated_model_to_its_set(self, model_type, model_id):
"""
When associating an X to a Y, save Y for newer timestamp, etc, and to
make sure to save it if Y is a new record.
If the model_type isn't found as a usable class, ignore it, this can
happen when associating to things stored in LDAP (user, project, ...).
"""
table = globals()
klsname = model_type.capitalize()
if table.has_key(klsname):
model_class = table[klsname]
model_inst = model_class(model_id)
model_inst.save()
@absorb_connection_error
def save(self):
"""
update the directory with the state from this model
also add it to the index of items of the same type
then set the initial_state = state so new changes are tracked
"""
# TODO(ja): implement hmset in redis-py and use it
# instead of multiple calls to hset
if self.is_new_record():
self["create_time"] = utils.isotime()
for key, val in self.state.iteritems():
Redis.instance().hset(self.__redis_key, key, val)
self.add_to_index()
self.initial_state = dict(self.state)
return True
@absorb_connection_error
def destroy(self):
"""deletes all related records from datastore."""
logging.info("Destroying datamodel for %s %s",
self.__class__.__name__, self.identifier)
Redis.instance().delete(self.__redis_key)
self.remove_from_index()
return True

View File

@@ -1,211 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Admin API controller, exposed through http via the api worker.
"""
import base64
from nova.auth import manager
from nova.compute import model
def user_dict(user, base64_file=None):
"""Convert the user object to a result dict"""
if user:
return {
'username': user.id,
'accesskey': user.access,
'secretkey': user.secret,
'file': base64_file}
else:
return {}
def project_dict(project):
"""Convert the project object to a result dict"""
if project:
return {
'projectname': project.id,
'project_manager_id': project.project_manager_id,
'description': project.description}
else:
return {}
def host_dict(host):
"""Convert a host model object to a result dict"""
if host:
return host.state
else:
return {}
def admin_only(target):
"""Decorator for admin-only API calls"""
def wrapper(*args, **kwargs):
"""Internal wrapper method for admin-only API calls"""
context = args[1]
if context.user.is_admin():
return target(*args, **kwargs)
else:
return {}
return wrapper
class AdminController(object):
"""
API Controller for users, hosts, nodes, and workers.
Trivial admin_only wrapper will be replaced with RBAC,
allowing project managers to administer project users.
"""
def __str__(self):
return 'AdminController'
@admin_only
def describe_user(self, _context, name, **_kwargs):
"""Returns user data, including access and secret keys."""
return user_dict(manager.AuthManager().get_user(name))
@admin_only
def describe_users(self, _context, **_kwargs):
"""Returns all users - should be changed to deal with a list."""
return {'userSet':
[user_dict(u) for u in manager.AuthManager().get_users()] }
@admin_only
def register_user(self, _context, name, **_kwargs):
"""Creates a new user, and returns generated credentials."""
return user_dict(manager.AuthManager().create_user(name))
@admin_only
def deregister_user(self, _context, name, **_kwargs):
"""Deletes a single user (NOT undoable.)
Should throw an exception if the user has instances,
volumes, or buckets remaining.
"""
manager.AuthManager().delete_user(name)
return True
@admin_only
def describe_roles(self, context, project_roles=True, **kwargs):
"""Returns a list of allowed roles."""
roles = manager.AuthManager().get_roles(project_roles)
return { 'roles': [{'role': r} for r in roles]}
@admin_only
def describe_user_roles(self, context, user, project=None, **kwargs):
"""Returns a list of roles for the given user.
Omitting project will return any global roles that the user has.
Specifying project will return only project specific roles.
"""
roles = manager.AuthManager().get_user_roles(user, project=project)
return { 'roles': [{'role': r} for r in roles]}
@admin_only
def modify_user_role(self, context, user, role, project=None,
operation='add', **kwargs):
"""Add or remove a role for a user and project."""
if operation == 'add':
manager.AuthManager().add_role(user, role, project)
elif operation == 'remove':
manager.AuthManager().remove_role(user, role, project)
else:
raise exception.ApiError('operation must be add or remove')
return True
@admin_only
def generate_x509_for_user(self, _context, name, project=None, **kwargs):
"""Generates and returns an x509 certificate for a single user.
Is usually called from a client that will wrap this with
access and secret key info, and return a zip file.
"""
if project is None:
project = name
project = manager.AuthManager().get_project(project)
user = manager.AuthManager().get_user(name)
return user_dict(user, base64.b64encode(project.get_credentials(user)))
@admin_only
def describe_project(self, context, name, **kwargs):
"""Returns project data, including member ids."""
return project_dict(manager.AuthManager().get_project(name))
@admin_only
def describe_projects(self, context, user=None, **kwargs):
"""Returns all projects - should be changed to deal with a list."""
return {'projectSet':
[project_dict(u) for u in
manager.AuthManager().get_projects(user=user)]}
@admin_only
def register_project(self, context, name, manager_user, description=None,
member_users=None, **kwargs):
"""Creates a new project"""
return project_dict(
manager.AuthManager().create_project(
name,
manager_user,
description=None,
member_users=None))
@admin_only
def deregister_project(self, context, name):
"""Permanently deletes a project."""
manager.AuthManager().delete_project(name)
return True
@admin_only
def describe_project_members(self, context, name, **kwargs):
project = manager.AuthManager().get_project(name)
result = {
'members': [{'member': m} for m in project.member_ids]}
return result
@admin_only
def modify_project_member(self, context, user, project, operation, **kwargs):
"""Add or remove a user from a project."""
if operation =='add':
manager.AuthManager().add_to_project(user, project)
elif operation == 'remove':
manager.AuthManager().remove_from_project(user, project)
else:
raise exception.ApiError('operation must be add or remove')
return True
@admin_only
def describe_hosts(self, _context, **_kwargs):
"""Returns status info for all nodes. Includes:
* Disk Space
* Instance List
* RAM used
* CPU used
* DHCP servers running
* Iptables / bridges
"""
return {'hostSet': [host_dict(h) for h in model.Host.all()]}
@admin_only
def describe_host(self, _context, name, **_kwargs):
"""Returns status info for single node."""
return host_dict(model.Host.lookup(name))

View File

@@ -1,344 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tornado REST API Request Handlers for Nova functions
Most calls are proxied into the responsible controller.
"""
import logging
import multiprocessing
import random
import re
import urllib
# TODO(termie): replace minidom with etree
from xml.dom import minidom
import tornado.web
from twisted.internet import defer
from nova import crypto
from nova import exception
from nova import flags
from nova import utils
from nova.auth import manager
import nova.cloudpipe.api
from nova.endpoint import cloud
FLAGS = flags.FLAGS
flags.DEFINE_integer('cc_port', 8773, 'cloud controller port')
_log = logging.getLogger("api")
_log.setLevel(logging.DEBUG)
_c2u = re.compile('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))')
def _camelcase_to_underscore(str):
return _c2u.sub(r'_\1', str).lower().strip('_')
def _underscore_to_camelcase(str):
return ''.join([x[:1].upper() + x[1:] for x in str.split('_')])
def _underscore_to_xmlcase(str):
res = _underscore_to_camelcase(str)
return res[:1].lower() + res[1:]
class APIRequestContext(object):
def __init__(self, handler, user, project):
self.handler = handler
self.user = user
self.project = project
self.request_id = ''.join(
[random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890-')
for x in xrange(20)]
)
class APIRequest(object):
def __init__(self, controller, action):
self.controller = controller
self.action = action
def send(self, context, **kwargs):
try:
method = getattr(self.controller,
_camelcase_to_underscore(self.action))
except AttributeError:
_error = ('Unsupported API request: controller = %s,'
'action = %s') % (self.controller, self.action)
_log.warning(_error)
# TODO: Raise custom exception, trap in apiserver,
# and reraise as 400 error.
raise Exception(_error)
args = {}
for key, value in kwargs.items():
parts = key.split(".")
key = _camelcase_to_underscore(parts[0])
if len(parts) > 1:
d = args.get(key, {})
d[parts[1]] = value[0]
value = d
else:
value = value[0]
args[key] = value
for key in args.keys():
if isinstance(args[key], dict):
if args[key] != {} and args[key].keys()[0].isdigit():
s = args[key].items()
s.sort()
args[key] = [v for k, v in s]
d = defer.maybeDeferred(method, context, **args)
d.addCallback(self._render_response, context.request_id)
return d
def _render_response(self, response_data, request_id):
xml = minidom.Document()
response_el = xml.createElement(self.action + 'Response')
response_el.setAttribute('xmlns',
'http://ec2.amazonaws.com/doc/2009-11-30/')
request_id_el = xml.createElement('requestId')
request_id_el.appendChild(xml.createTextNode(request_id))
response_el.appendChild(request_id_el)
if(response_data == True):
self._render_dict(xml, response_el, {'return': 'true'})
else:
self._render_dict(xml, response_el, response_data)
xml.appendChild(response_el)
response = xml.toxml()
xml.unlink()
_log.debug(response)
return response
def _render_dict(self, xml, el, data):
try:
for key in data.keys():
val = data[key]
el.appendChild(self._render_data(xml, key, val))
except:
_log.debug(data)
raise
def _render_data(self, xml, el_name, data):
el_name = _underscore_to_xmlcase(el_name)
data_el = xml.createElement(el_name)
if isinstance(data, list):
for item in data:
data_el.appendChild(self._render_data(xml, 'item', item))
elif isinstance(data, dict):
self._render_dict(xml, data_el, data)
elif hasattr(data, '__dict__'):
self._render_dict(xml, data_el, data.__dict__)
elif isinstance(data, bool):
data_el.appendChild(xml.createTextNode(str(data).lower()))
elif data != None:
data_el.appendChild(xml.createTextNode(str(data)))
return data_el
class RootRequestHandler(tornado.web.RequestHandler):
def get(self):
# available api versions
versions = [
'1.0',
'2007-01-19',
'2007-03-01',
'2007-08-29',
'2007-10-10',
'2007-12-15',
'2008-02-01',
'2008-09-01',
'2009-04-04',
]
for version in versions:
self.write('%s\n' % version)
self.finish()
class MetadataRequestHandler(tornado.web.RequestHandler):
def print_data(self, data):
if isinstance(data, dict):
output = ''
for key in data:
if key == '_name':
continue
output += key
if isinstance(data[key], dict):
if '_name' in data[key]:
output += '=' + str(data[key]['_name'])
else:
output += '/'
output += '\n'
self.write(output[:-1]) # cut off last \n
elif isinstance(data, list):
self.write('\n'.join(data))
else:
self.write(str(data))
def lookup(self, path, data):
items = path.split('/')
for item in items:
if item:
if not isinstance(data, dict):
return data
if not item in data:
return None
data = data[item]
return data
def get(self, path):
cc = self.application.controllers['Cloud']
meta_data = cc.get_metadata(self.request.remote_ip)
if meta_data is None:
_log.error('Failed to get metadata for ip: %s' %
self.request.remote_ip)
raise tornado.web.HTTPError(404)
data = self.lookup(path, meta_data)
if data is None:
raise tornado.web.HTTPError(404)
self.print_data(data)
self.finish()
class APIRequestHandler(tornado.web.RequestHandler):
def get(self, controller_name):
self.execute(controller_name)
@tornado.web.asynchronous
def execute(self, controller_name):
# Obtain the appropriate controller for this request.
try:
controller = self.application.controllers[controller_name]
except KeyError:
self._error('unhandled', 'no controller named %s' % controller_name)
return
args = self.request.arguments
# Read request signature.
try:
signature = args.pop('Signature')[0]
except:
raise tornado.web.HTTPError(400)
# Make a copy of args for authentication and signature verification.
auth_params = {}
for key, value in args.items():
auth_params[key] = value[0]
# Get requested action and remove authentication args for final request.
try:
action = args.pop('Action')[0]
access = args.pop('AWSAccessKeyId')[0]
args.pop('SignatureMethod')
args.pop('SignatureVersion')
args.pop('Version')
args.pop('Timestamp')
except:
raise tornado.web.HTTPError(400)
# Authenticate the request.
try:
(user, project) = manager.AuthManager().authenticate(
access,
signature,
auth_params,
self.request.method,
self.request.host,
self.request.path
)
except exception.Error, ex:
logging.debug("Authentication Failure: %s" % ex)
raise tornado.web.HTTPError(403)
_log.debug('action: %s' % action)
for key, value in args.items():
_log.debug('arg: %s\t\tval: %s' % (key, value))
request = APIRequest(controller, action)
context = APIRequestContext(self, user, project)
d = request.send(context, **args)
# d.addCallback(utils.debug)
# TODO: Wrap response in AWS XML format
d.addCallbacks(self._write_callback, self._error_callback)
def _write_callback(self, data):
self.set_header('Content-Type', 'text/xml')
self.write(data)
self.finish()
def _error_callback(self, failure):
try:
failure.raiseException()
except exception.ApiError as ex:
self._error(type(ex).__name__ + "." + ex.code, ex.message)
# TODO(vish): do something more useful with unknown exceptions
except Exception as ex:
self._error(type(ex).__name__, str(ex))
raise
def post(self, controller_name):
self.execute(controller_name)
def _error(self, code, message):
self._status_code = 400
self.set_header('Content-Type', 'text/xml')
self.write('<?xml version="1.0"?>\n')
self.write('<Response><Errors><Error><Code>%s</Code>'
'<Message>%s</Message></Error></Errors>'
'<RequestID>?</RequestID></Response>' % (code, message))
self.finish()
class APIServerApplication(tornado.web.Application):
def __init__(self, controllers):
tornado.web.Application.__init__(self, [
(r'/', RootRequestHandler),
(r'/cloudpipe/(.*)', nova.cloudpipe.api.CloudPipeRequestHandler),
(r'/cloudpipe', nova.cloudpipe.api.CloudPipeRequestHandler),
(r'/services/([A-Za-z0-9]+)/', APIRequestHandler),
(r'/latest/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2009-04-04/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2008-09-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2008-02-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2007-12-15/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2007-10-10/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2007-08-29/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2007-03-01/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/2007-01-19/([-A-Za-z0-9/]*)', MetadataRequestHandler),
(r'/1.0/([-A-Za-z0-9/]*)', MetadataRequestHandler),
], pool=multiprocessing.Pool(4))
self.controllers = controllers

View File

@@ -1,745 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Cloud Controller: Implementation of EC2 REST API calls, which are
dispatched to other nodes via AMQP RPC. State is via distributed
datastore.
"""
import base64
import logging
import os
import time
from twisted.internet import defer
from nova import datastore
from nova import exception
from nova import flags
from nova import rpc
from nova import utils
from nova.auth import rbac
from nova.auth import manager
from nova.compute import model
from nova.compute.instance_types import INSTANCE_TYPES
from nova.endpoint import images
from nova.network import service as network_service
from nova.network import model as network_model
from nova.volume import service
FLAGS = flags.FLAGS
flags.DEFINE_string('cloud_topic', 'cloud', 'the topic clouds listen on')
def _gen_key(user_id, key_name):
""" Tuck this into AuthManager """
try:
mgr = manager.AuthManager()
private_key, fingerprint = mgr.generate_key_pair(user_id, key_name)
except Exception as ex:
return {'exception': ex}
return {'private_key': private_key, 'fingerprint': fingerprint}
class CloudController(object):
""" CloudController provides the critical dispatch between
inbound API calls through the endpoint and messages
sent to the other nodes.
"""
def __init__(self):
self.instdir = model.InstanceDirectory()
self.setup()
@property
def instances(self):
""" All instances in the system, as dicts """
return self.instdir.all
@property
def volumes(self):
""" returns a list of all volumes """
for volume_id in datastore.Redis.instance().smembers("volumes"):
volume = service.get_volume(volume_id)
yield volume
def __str__(self):
return 'CloudController'
def setup(self):
""" Ensure the keychains and folders exist. """
# Create keys folder, if it doesn't exist
if not os.path.exists(FLAGS.keys_path):
os.makedirs(FLAGS.keys_path)
# Gen root CA, if we don't have one
root_ca_path = os.path.join(FLAGS.ca_path, FLAGS.ca_file)
if not os.path.exists(root_ca_path):
start = os.getcwd()
os.chdir(FLAGS.ca_path)
utils.runthis("Generating root CA: %s", "sh genrootca.sh")
os.chdir(start)
# TODO: Do this with M2Crypto instead
def get_instance_by_ip(self, ip):
return self.instdir.by_ip(ip)
def _get_mpi_data(self, project_id):
result = {}
for instance in self.instdir.all:
if instance['project_id'] == project_id:
line = '%s slots=%d' % (instance['private_dns_name'],
INSTANCE_TYPES[instance['instance_type']]['vcpus'])
if instance['key_name'] in result:
result[instance['key_name']].append(line)
else:
result[instance['key_name']] = [line]
return result
def get_metadata(self, ipaddress):
i = self.get_instance_by_ip(ipaddress)
if i is None:
return None
mpi = self._get_mpi_data(i['project_id'])
if i['key_name']:
keys = {
'0': {
'_name': i['key_name'],
'openssh-key': i['key_data']
}
}
else:
keys = ''
address_record = network_model.FixedIp(i['private_dns_name'])
if address_record:
hostname = address_record['hostname']
else:
hostname = 'ip-%s' % i['private_dns_name'].replace('.', '-')
data = {
'user-data': base64.b64decode(i['user_data']),
'meta-data': {
'ami-id': i['image_id'],
'ami-launch-index': i['ami_launch_index'],
'ami-manifest-path': 'FIXME', # image property
'block-device-mapping': { # TODO: replace with real data
'ami': 'sda1',
'ephemeral0': 'sda2',
'root': '/dev/sda1',
'swap': 'sda3'
},
'hostname': hostname,
'instance-action': 'none',
'instance-id': i['instance_id'],
'instance-type': i.get('instance_type', ''),
'local-hostname': hostname,
'local-ipv4': i['private_dns_name'], # TODO: switch to IP
'kernel-id': i.get('kernel_id', ''),
'placement': {
'availaibility-zone': i.get('availability_zone', 'nova'),
},
'public-hostname': hostname,
'public-ipv4': i.get('dns_name', ''), # TODO: switch to IP
'public-keys': keys,
'ramdisk-id': i.get('ramdisk_id', ''),
'reservation-id': i['reservation_id'],
'security-groups': i.get('groups', ''),
'mpi': mpi
}
}
if False: # TODO: store ancestor ids
data['ancestor-ami-ids'] = []
if i.get('product_codes', None):
data['product-codes'] = i['product_codes']
return data
@rbac.allow('all')
def describe_availability_zones(self, context, **kwargs):
return {'availabilityZoneInfo': [{'zoneName': 'nova',
'zoneState': 'available'}]}
@rbac.allow('all')
def describe_regions(self, context, region_name=None, **kwargs):
# TODO(vish): region_name is an array. Support filtering
return {'regionInfo': [{'regionName': 'nova',
'regionUrl': FLAGS.ec2_url}]}
@rbac.allow('all')
def describe_snapshots(self,
context,
snapshot_id=None,
owner=None,
restorable_by=None,
**kwargs):
return {'snapshotSet': [{'snapshotId': 'fixme',
'volumeId': 'fixme',
'status': 'fixme',
'startTime': 'fixme',
'progress': 'fixme',
'ownerId': 'fixme',
'volumeSize': 0,
'description': 'fixme'}]}
@rbac.allow('all')
def describe_key_pairs(self, context, key_name=None, **kwargs):
key_pairs = context.user.get_key_pairs()
if not key_name is None:
key_pairs = [x for x in key_pairs if x.name in key_name]
result = []
for key_pair in key_pairs:
# filter out the vpn keys
suffix = FLAGS.vpn_key_suffix
if context.user.is_admin() or not key_pair.name.endswith(suffix):
result.append({
'keyName': key_pair.name,
'keyFingerprint': key_pair.fingerprint,
})
return {'keypairsSet': result}
@rbac.allow('all')
def create_key_pair(self, context, key_name, **kwargs):
dcall = defer.Deferred()
pool = context.handler.application.settings.get('pool')
def _complete(kwargs):
if 'exception' in kwargs:
dcall.errback(kwargs['exception'])
return
dcall.callback({'keyName': key_name,
'keyFingerprint': kwargs['fingerprint'],
'keyMaterial': kwargs['private_key']})
pool.apply_async(_gen_key, [context.user.id, key_name],
callback=_complete)
return dcall
@rbac.allow('all')
def delete_key_pair(self, context, key_name, **kwargs):
context.user.delete_key_pair(key_name)
# aws returns true even if the key doens't exist
return True
@rbac.allow('all')
def describe_security_groups(self, context, group_names, **kwargs):
groups = {'securityGroupSet': []}
# Stubbed for now to unblock other things.
return groups
@rbac.allow('netadmin')
def create_security_group(self, context, group_name, **kwargs):
return True
@rbac.allow('netadmin')
def delete_security_group(self, context, group_name, **kwargs):
return True
@rbac.allow('projectmanager', 'sysadmin')
def get_console_output(self, context, instance_id, **kwargs):
# instance_id is passed in as a list of instances
instance = self._get_instance(context, instance_id[0])
return rpc.call('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
{"method": "get_console_output",
"args": {"instance_id": instance_id[0]}})
def _get_user_id(self, context):
if context and context.user:
return context.user.id
else:
return None
@rbac.allow('projectmanager', 'sysadmin')
def describe_volumes(self, context, **kwargs):
volumes = []
for volume in self.volumes:
if context.user.is_admin() or volume['project_id'] == context.project.id:
v = self.format_volume(context, volume)
volumes.append(v)
return defer.succeed({'volumeSet': volumes})
def format_volume(self, context, volume):
v = {}
v['volumeId'] = volume['volume_id']
v['status'] = volume['status']
v['size'] = volume['size']
v['availabilityZone'] = volume['availability_zone']
v['createTime'] = volume['create_time']
if context.user.is_admin():
v['status'] = '%s (%s, %s, %s, %s)' % (
volume.get('status', None),
volume.get('user_id', None),
volume.get('node_name', None),
volume.get('instance_id', ''),
volume.get('mountpoint', ''))
if volume['attach_status'] == 'attached':
v['attachmentSet'] = [{'attachTime': volume['attach_time'],
'deleteOnTermination': volume['delete_on_termination'],
'device': volume['mountpoint'],
'instanceId': volume['instance_id'],
'status': 'attached',
'volume_id': volume['volume_id']}]
else:
v['attachmentSet'] = [{}]
return v
@rbac.allow('projectmanager', 'sysadmin')
@defer.inlineCallbacks
def create_volume(self, context, size, **kwargs):
# TODO(vish): refactor this to create the volume object here and tell service to create it
result = yield rpc.call(FLAGS.volume_topic, {"method": "create_volume",
"args": {"size": size,
"user_id": context.user.id,
"project_id": context.project.id}})
# NOTE(vish): rpc returned value is in the result key in the dictionary
volume = self._get_volume(context, result)
defer.returnValue({'volumeSet': [self.format_volume(context, volume)]})
def _get_address(self, context, public_ip):
# FIXME(vish) this should move into network.py
address = network_model.ElasticIp.lookup(public_ip)
if address and (context.user.is_admin() or address['project_id'] == context.project.id):
return address
raise exception.NotFound("Address at ip %s not found" % public_ip)
def _get_image(self, context, image_id):
"""passes in context because
objectstore does its own authorization"""
result = images.list(context, [image_id])
if not result:
raise exception.NotFound('Image %s could not be found' % image_id)
image = result[0]
return image
def _get_instance(self, context, instance_id):
for instance in self.instdir.all:
if instance['instance_id'] == instance_id:
if context.user.is_admin() or instance['project_id'] == context.project.id:
return instance
raise exception.NotFound('Instance %s could not be found' % instance_id)
def _get_volume(self, context, volume_id):
volume = service.get_volume(volume_id)
if context.user.is_admin() or volume['project_id'] == context.project.id:
return volume
raise exception.NotFound('Volume %s could not be found' % volume_id)
@rbac.allow('projectmanager', 'sysadmin')
def attach_volume(self, context, volume_id, instance_id, device, **kwargs):
volume = self._get_volume(context, volume_id)
if volume['status'] == "attached":
raise exception.ApiError("Volume is already attached")
# TODO(vish): looping through all volumes is slow. We should probably maintain an index
for vol in self.volumes:
if vol['instance_id'] == instance_id and vol['mountpoint'] == device:
raise exception.ApiError("Volume %s is already attached to %s" % (vol['volume_id'], vol['mountpoint']))
volume.start_attach(instance_id, device)
instance = self._get_instance(context, instance_id)
compute_node = instance['node_name']
rpc.cast('%s.%s' % (FLAGS.compute_topic, compute_node),
{"method": "attach_volume",
"args": {"volume_id": volume_id,
"instance_id": instance_id,
"mountpoint": device}})
return defer.succeed({'attachTime': volume['attach_time'],
'device': volume['mountpoint'],
'instanceId': instance_id,
'requestId': context.request_id,
'status': volume['attach_status'],
'volumeId': volume_id})
@rbac.allow('projectmanager', 'sysadmin')
def detach_volume(self, context, volume_id, **kwargs):
volume = self._get_volume(context, volume_id)
instance_id = volume.get('instance_id', None)
if not instance_id:
raise exception.Error("Volume isn't attached to anything!")
if volume['status'] == "available":
raise exception.Error("Volume is already detached")
try:
volume.start_detach()
instance = self._get_instance(context, instance_id)
rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
{"method": "detach_volume",
"args": {"instance_id": instance_id,
"volume_id": volume_id}})
except exception.NotFound:
# If the instance doesn't exist anymore,
# then we need to call detach blind
volume.finish_detach()
return defer.succeed({'attachTime': volume['attach_time'],
'device': volume['mountpoint'],
'instanceId': instance_id,
'requestId': context.request_id,
'status': volume['attach_status'],
'volumeId': volume_id})
def _convert_to_set(self, lst, label):
if lst == None or lst == []:
return None
if not isinstance(lst, list):
lst = [lst]
return [{label: x} for x in lst]
@rbac.allow('all')
def describe_instances(self, context, **kwargs):
return defer.succeed(self._format_describe_instances(context))
def _format_describe_instances(self, context):
return { 'reservationSet': self._format_instances(context) }
def _format_run_instances(self, context, reservation_id):
i = self._format_instances(context, reservation_id)
assert len(i) == 1
return i[0]
def _format_instances(self, context, reservation_id = None):
reservations = {}
if context.user.is_admin():
instgenerator = self.instdir.all
else:
instgenerator = self.instdir.by_project(context.project.id)
for instance in instgenerator:
res_id = instance.get('reservation_id', 'Unknown')
if reservation_id != None and reservation_id != res_id:
continue
if not context.user.is_admin():
if instance['image_id'] == FLAGS.vpn_image_id:
continue
i = {}
i['instance_id'] = instance.get('instance_id', None)
i['image_id'] = instance.get('image_id', None)
i['instance_state'] = {
'code': instance.get('state', 0),
'name': instance.get('state_description', 'pending')
}
i['public_dns_name'] = network_model.get_public_ip_for_instance(
i['instance_id'])
i['private_dns_name'] = instance.get('private_dns_name', None)
if not i['public_dns_name']:
i['public_dns_name'] = i['private_dns_name']
i['dns_name'] = instance.get('dns_name', None)
i['key_name'] = instance.get('key_name', None)
if context.user.is_admin():
i['key_name'] = '%s (%s, %s)' % (i['key_name'],
instance.get('project_id', None),
instance.get('node_name', ''))
i['product_codes_set'] = self._convert_to_set(
instance.get('product_codes', None), 'product_code')
i['instance_type'] = instance.get('instance_type', None)
i['launch_time'] = instance.get('launch_time', None)
i['ami_launch_index'] = instance.get('ami_launch_index',
None)
if not reservations.has_key(res_id):
r = {}
r['reservation_id'] = res_id
r['owner_id'] = instance.get('project_id', None)
r['group_set'] = self._convert_to_set(
instance.get('groups', None), 'group_id')
r['instances_set'] = []
reservations[res_id] = r
reservations[res_id]['instances_set'].append(i)
return list(reservations.values())
@rbac.allow('all')
def describe_addresses(self, context, **kwargs):
return self.format_addresses(context)
def format_addresses(self, context):
addresses = []
for address in network_model.ElasticIp.all():
# TODO(vish): implement a by_project iterator for addresses
if (context.user.is_admin() or
address['project_id'] == context.project.id):
address_rv = {
'public_ip': address['address'],
'instance_id': address.get('instance_id', 'free')
}
if context.user.is_admin():
address_rv['instance_id'] = "%s (%s, %s)" % (
address['instance_id'],
address['user_id'],
address['project_id'],
)
addresses.append(address_rv)
return {'addressesSet': addresses}
@rbac.allow('netadmin')
@defer.inlineCallbacks
def allocate_address(self, context, **kwargs):
network_topic = yield self._get_network_topic(context)
public_ip = yield rpc.call(network_topic,
{"method": "allocate_elastic_ip",
"args": {"user_id": context.user.id,
"project_id": context.project.id}})
defer.returnValue({'addressSet': [{'publicIp': public_ip}]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def release_address(self, context, public_ip, **kwargs):
# NOTE(vish): Should we make sure this works?
network_topic = yield self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "deallocate_elastic_ip",
"args": {"elastic_ip": public_ip}})
defer.returnValue({'releaseResponse': ["Address released."]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def associate_address(self, context, instance_id, public_ip, **kwargs):
instance = self._get_instance(context, instance_id)
address = self._get_address(context, public_ip)
network_topic = yield self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "associate_elastic_ip",
"args": {"elastic_ip": address['address'],
"fixed_ip": instance['private_dns_name'],
"instance_id": instance['instance_id']}})
defer.returnValue({'associateResponse': ["Address associated."]})
@rbac.allow('netadmin')
@defer.inlineCallbacks
def disassociate_address(self, context, public_ip, **kwargs):
address = self._get_address(context, public_ip)
network_topic = yield self._get_network_topic(context)
rpc.cast(network_topic,
{"method": "disassociate_elastic_ip",
"args": {"elastic_ip": address['address']}})
defer.returnValue({'disassociateResponse': ["Address disassociated."]})
@defer.inlineCallbacks
def _get_network_topic(self, context):
"""Retrieves the network host for a project"""
host = network_service.get_host_for_project(context.project.id)
if not host:
host = yield rpc.call(FLAGS.network_topic,
{"method": "set_network_host",
"args": {"user_id": context.user.id,
"project_id": context.project.id}})
defer.returnValue('%s.%s' %(FLAGS.network_topic, host))
@rbac.allow('projectmanager', 'sysadmin')
@defer.inlineCallbacks
def run_instances(self, context, **kwargs):
# make sure user can access the image
# vpn image is private so it doesn't show up on lists
if kwargs['image_id'] != FLAGS.vpn_image_id:
image = self._get_image(context, kwargs['image_id'])
# FIXME(ja): if image is cloudpipe, this breaks
# get defaults from imagestore
image_id = image['imageId']
kernel_id = image.get('kernelId', FLAGS.default_kernel)
ramdisk_id = image.get('ramdiskId', FLAGS.default_ramdisk)
# API parameters overrides of defaults
kernel_id = kwargs.get('kernel_id', kernel_id)
ramdisk_id = kwargs.get('ramdisk_id', ramdisk_id)
if kernel_id == str(FLAGS.null_kernel):
kernel_id = None
ramdisk_id = None
# make sure we have access to kernel and ramdisk
if kernel_id:
self._get_image(context, kernel_id)
if ramdisk_id:
self._get_image(context, ramdisk_id)
logging.debug("Going to run instances...")
reservation_id = utils.generate_uid('r')
launch_time = time.strftime('%Y-%m-%dT%H:%M:%SZ', time.gmtime())
key_data = None
if kwargs.has_key('key_name'):
key_pair = context.user.get_key_pair(kwargs['key_name'])
if not key_pair:
raise exception.ApiError('Key Pair %s not found' %
kwargs['key_name'])
key_data = key_pair.public_key
network_topic = yield self._get_network_topic(context)
# TODO: Get the real security group of launch in here
security_group = "default"
for num in range(int(kwargs['max_count'])):
is_vpn = False
if image_id == FLAGS.vpn_image_id:
is_vpn = True
inst = self.instdir.new()
allocate_data = yield rpc.call(network_topic,
{"method": "allocate_fixed_ip",
"args": {"user_id": context.user.id,
"project_id": context.project.id,
"security_group": security_group,
"is_vpn": is_vpn,
"hostname": inst.instance_id}})
inst['image_id'] = image_id
inst['kernel_id'] = kernel_id or ''
inst['ramdisk_id'] = ramdisk_id or ''
inst['user_data'] = kwargs.get('user_data', '')
inst['instance_type'] = kwargs.get('instance_type', 'm1.small')
inst['reservation_id'] = reservation_id
inst['launch_time'] = launch_time
inst['key_data'] = key_data or ''
inst['key_name'] = kwargs.get('key_name', '')
inst['user_id'] = context.user.id
inst['project_id'] = context.project.id
inst['ami_launch_index'] = num
inst['security_group'] = security_group
inst['hostname'] = inst.instance_id
for (key, value) in allocate_data.iteritems():
inst[key] = value
inst.save()
rpc.cast(FLAGS.compute_topic,
{"method": "run_instance",
"args": {"instance_id": inst.instance_id}})
logging.debug("Casting to node for %s's instance with IP of %s" %
(context.user.name, inst['private_dns_name']))
# TODO: Make Network figure out the network name from ip.
defer.returnValue(self._format_run_instances(context, reservation_id))
@rbac.allow('projectmanager', 'sysadmin')
@defer.inlineCallbacks
def terminate_instances(self, context, instance_id, **kwargs):
logging.debug("Going to start terminating instances")
network_topic = yield self._get_network_topic(context)
for i in instance_id:
logging.debug("Going to try and terminate %s" % i)
try:
instance = self._get_instance(context, i)
except exception.NotFound:
logging.warning("Instance %s was not found during terminate"
% i)
continue
elastic_ip = network_model.get_public_ip_for_instance(i)
if elastic_ip:
logging.debug("Disassociating address %s" % elastic_ip)
# NOTE(vish): Right now we don't really care if the ip is
# disassociated. We may need to worry about
# checking this later. Perhaps in the scheduler?
rpc.cast(network_topic,
{"method": "disassociate_elastic_ip",
"args": {"elastic_ip": elastic_ip}})
fixed_ip = instance.get('private_dns_name', None)
if fixed_ip:
logging.debug("Deallocating address %s" % fixed_ip)
# NOTE(vish): Right now we don't really care if the ip is
# actually removed. We may need to worry about
# checking this later. Perhaps in the scheduler?
rpc.cast(network_topic,
{"method": "deallocate_fixed_ip",
"args": {"fixed_ip": fixed_ip}})
if instance.get('node_name', 'unassigned') != 'unassigned':
# NOTE(joshua?): It's also internal default
rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
{"method": "terminate_instance",
"args": {"instance_id": i}})
else:
instance.destroy()
defer.returnValue(True)
@rbac.allow('projectmanager', 'sysadmin')
def reboot_instances(self, context, instance_id, **kwargs):
"""instance_id is a list of instance ids"""
for i in instance_id:
instance = self._get_instance(context, i)
rpc.cast('%s.%s' % (FLAGS.compute_topic, instance['node_name']),
{"method": "reboot_instance",
"args": {"instance_id": i}})
return defer.succeed(True)
@rbac.allow('projectmanager', 'sysadmin')
def delete_volume(self, context, volume_id, **kwargs):
# TODO: return error if not authorized
volume = self._get_volume(context, volume_id)
volume_node = volume['node_name']
rpc.cast('%s.%s' % (FLAGS.volume_topic, volume_node),
{"method": "delete_volume",
"args": {"volume_id": volume_id}})
return defer.succeed(True)
@rbac.allow('all')
def describe_images(self, context, image_id=None, **kwargs):
# The objectstore does its own authorization for describe
imageSet = images.list(context, image_id)
return defer.succeed({'imagesSet': imageSet})
@rbac.allow('projectmanager', 'sysadmin')
def deregister_image(self, context, image_id, **kwargs):
# FIXME: should the objectstore be doing these authorization checks?
images.deregister(context, image_id)
return defer.succeed({'imageId': image_id})
@rbac.allow('projectmanager', 'sysadmin')
def register_image(self, context, image_location=None, **kwargs):
# FIXME: should the objectstore be doing these authorization checks?
if image_location is None and kwargs.has_key('name'):
image_location = kwargs['name']
image_id = images.register(context, image_location)
logging.debug("Registered %s as %s" % (image_location, image_id))
return defer.succeed({'imageId': image_id})
@rbac.allow('all')
def describe_image_attribute(self, context, image_id, attribute, **kwargs):
if attribute != 'launchPermission':
raise exception.ApiError('attribute not supported: %s' % attribute)
try:
image = images.list(context, image_id)[0]
except IndexError:
raise exception.ApiError('invalid id: %s' % image_id)
result = {'image_id': image_id, 'launchPermission': []}
if image['isPublic']:
result['launchPermission'].append({'group': 'all'})
return defer.succeed(result)
@rbac.allow('projectmanager', 'sysadmin')
def modify_image_attribute(self, context, image_id, attribute, operation_type, **kwargs):
# TODO(devcamcar): Support users and groups other than 'all'.
if attribute != 'launchPermission':
raise exception.ApiError('attribute not supported: %s' % attribute)
if not 'user_group' in kwargs:
raise exception.ApiError('user or group not specified')
if len(kwargs['user_group']) != 1 and kwargs['user_group'][0] != 'all':
raise exception.ApiError('only group "all" is supported')
if not operation_type in ['add', 'remove']:
raise exception.ApiError('operation_type must be add or remove')
result = images.modify(context, image_id, operation_type)
return defer.succeed(result)
def update_state(self, topic, value):
""" accepts status reports from the queue and consolidates them """
# TODO(jmc): if an instance has disappeared from
# the node, call instance_death
if topic == "instances":
return defer.succeed(True)
aggregate_state = getattr(self, topic)
node_name = value.keys()[0]
items = value[node_name]
logging.debug("Updating %s state for %s" % (topic, node_name))
for item_id in items.keys():
if (aggregate_state.has_key('pending') and
aggregate_state['pending'].has_key(item_id)):
del aggregate_state['pending'][item_id]
aggregate_state[node_name] = items
return defer.succeed(True)

View File

@@ -1,100 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Proxy AMI-related calls from the cloud controller, to the running
objectstore daemon.
"""
import json
import urllib
import boto.s3.connection
from nova import flags
from nova import utils
from nova.auth import manager
FLAGS = flags.FLAGS
def modify(context, image_id, operation):
conn(context).make_request(
method='POST',
bucket='_images',
query_args=qs({'image_id': image_id, 'operation': operation}))
return True
def register(context, image_location):
""" rpc call to register a new image based from a manifest """
image_id = utils.generate_uid('ami')
conn(context).make_request(
method='PUT',
bucket='_images',
query_args=qs({'image_location': image_location,
'image_id': image_id}))
return image_id
def list(context, filter_list=[]):
""" return a list of all images that a user can see
optionally filtered by a list of image_id """
# FIXME: send along the list of only_images to check for
response = conn(context).make_request(
method='GET',
bucket='_images')
result = json.loads(response.read())
if not filter_list is None:
return [i for i in result if i['imageId'] in filter_list]
return result
def deregister(context, image_id):
""" unregister an image """
conn(context).make_request(
method='DELETE',
bucket='_images',
query_args=qs({'image_id': image_id}))
def conn(context):
access = manager.AuthManager().get_access_key(context.user,
context.project)
secret = str(context.user.secret)
calling = boto.s3.connection.OrdinaryCallingFormat()
return boto.s3.connection.S3Connection(aws_access_key_id=access,
aws_secret_access_key=secret,
is_secure=False,
calling_format=calling,
port=FLAGS.s3_port,
host=FLAGS.s3_host)
def qs(params):
pairs = []
for key in params.keys():
pairs.append(key + '=' + urllib.quote(params[key]))
return '&'.join(pairs)

View File

@@ -22,6 +22,7 @@ import logging
import Queue as queue
from carrot.backends import base
from eventlet import greenthread
class Message(base.BaseMessage):
@@ -38,6 +39,7 @@ class Exchange(object):
def publish(self, message, routing_key=None):
logging.debug('(%s) publish (key: %s) %s',
self.name, routing_key, message)
routing_key = routing_key.split('.')[0]
if routing_key in self._routes:
for f in self._routes[routing_key]:
logging.debug('Publishing to route %s', f)
@@ -94,6 +96,18 @@ class Backend(object):
self._exchanges[exchange].bind(self._queues[queue].push,
routing_key)
def declare_consumer(self, queue, callback, *args, **kwargs):
self.current_queue = queue
self.current_callback = callback
def consume(self, *args, **kwargs):
while True:
item = self.get(self.current_queue)
if item:
self.current_callback(item)
raise StopIteration()
greenthread.sleep(0)
def get(self, queue, no_ack=False):
if not queue in self._queues or not self._queues[queue].size():
return None
@@ -102,6 +116,7 @@ class Backend(object):
message = Message(backend=self, body=message_data,
content_type=content_type,
content_encoding=content_encoding)
message.result = True
logging.debug('Getting from %s: %s', queue, message)
return message

View File

@@ -22,6 +22,7 @@ where they're used.
"""
import getopt
import os
import socket
import sys
@@ -34,7 +35,7 @@ class FlagValues(gflags.FlagValues):
Unknown flags will be ignored when parsing the command line, but the
command line will be kept so that it can be replayed if new flags are
defined after the initial parsing.
"""
def __init__(self):
@@ -50,7 +51,7 @@ class FlagValues(gflags.FlagValues):
# leftover args at the end
sneaky_unparsed_args = {"value": None}
original_argv = list(argv)
if self.IsGnuGetOpt():
orig_getopt = getattr(getopt, 'gnu_getopt')
orig_name = 'gnu_getopt'
@@ -74,14 +75,14 @@ class FlagValues(gflags.FlagValues):
unparsed_args = sneaky_unparsed_args['value']
if unparsed_args:
if self.IsGnuGetOpt():
args = argv[:1] + unparsed
args = argv[:1] + unparsed_args
else:
args = argv[:1] + original_argv[-len(unparsed_args):]
else:
args = argv[:1]
finally:
setattr(getopt, orig_name, orig_getopt)
# Store the arguments for later, we'll need them for new flags
# added at runtime
self.__dict__['__stored_argv'] = original_argv
@@ -92,7 +93,7 @@ class FlagValues(gflags.FlagValues):
def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name)
def IsDirty(self, name):
return name in self.__dict__['__dirty']
@@ -113,12 +114,12 @@ class FlagValues(gflags.FlagValues):
for k in self.__dict__['__dirty']:
setattr(self, k, getattr(new_flags, k))
self.ClearDirty()
def __setitem__(self, name, flag):
gflags.FlagValues.__setitem__(self, name, flag)
if self.WasAlreadyParsed():
self.SetDirty(name)
def __getitem__(self, name):
if self.IsDirty(name):
self.ParseNewFlags()
@@ -141,6 +142,7 @@ def _wrapper(func):
return _wrapped
DEFINE = _wrapper(gflags.DEFINE)
DEFINE_string = _wrapper(gflags.DEFINE_string)
DEFINE_integer = _wrapper(gflags.DEFINE_integer)
DEFINE_bool = _wrapper(gflags.DEFINE_bool)
@@ -165,11 +167,14 @@ def DECLARE(name, module_string, flag_values=FLAGS):
# Define any app-specific flags in their own files, docs at:
# http://code.google.com/p/python-gflags/source/browse/trunk/gflags.py#39
DEFINE_list('region_list',
[],
'list of region=url pairs separated by commas')
DEFINE_string('connection_type', 'libvirt', 'libvirt, xenapi or fake')
DEFINE_integer('s3_port', 3333, 's3 port')
DEFINE_string('s3_host', '127.0.0.1', 's3 host')
#DEFINE_string('cloud_topic', 'cloud', 'the topic clouds listen on')
DEFINE_string('compute_topic', 'compute', 'the topic compute nodes listen on')
DEFINE_string('scheduler_topic', 'scheduler', 'the topic scheduler nodes listen on')
DEFINE_string('volume_topic', 'volume', 'the topic volume nodes listen on')
DEFINE_string('network_topic', 'network', 'the topic network nodes listen on')
@@ -183,6 +188,8 @@ DEFINE_string('rabbit_userid', 'guest', 'rabbit userid')
DEFINE_string('rabbit_password', 'guest', 'rabbit password')
DEFINE_string('rabbit_virtual_host', '/', 'rabbit virtual host')
DEFINE_string('control_exchange', 'nova', 'the main exchange to connect to')
DEFINE_string('cc_host', '127.0.0.1', 'ip of api server')
DEFINE_integer('cc_port', 8773, 'cloud controller port')
DEFINE_string('ec2_url', 'http://127.0.0.1:8773/services/Cloud',
'Url to ec2 api server')
@@ -204,9 +211,26 @@ DEFINE_string('vpn_key_suffix',
DEFINE_integer('auth_token_ttl', 3600, 'Seconds for auth tokens to linger')
DEFINE_string('sql_connection',
'sqlite:///%s/nova.sqlite' % os.path.abspath("./"),
'connection string for sql database')
DEFINE_string('compute_manager', 'nova.compute.manager.ComputeManager',
'Manager for compute')
DEFINE_string('network_manager', 'nova.network.manager.VlanManager',
'Manager for network')
DEFINE_string('volume_manager', 'nova.volume.manager.AOEManager',
'Manager for volume')
DEFINE_string('scheduler_manager', 'nova.scheduler.manager.SchedulerManager',
'Manager for scheduler')
# The service to use for image search and retrieval
DEFINE_string('image_service', 'nova.image.service.LocalImageService',
'The service to use for retrieving and searching for images.')
DEFINE_string('host', socket.gethostname(),
'name of this node')
# UNUSED
DEFINE_string('node_availability_zone', 'nova',
'availability zone of this node')
DEFINE_string('node_name', socket.gethostname(),
'name of this node')

View File

@@ -18,9 +18,10 @@
# under the License.
"""
Process pool, still buggy right now.
Process pool using twisted threading
"""
import logging
import StringIO
from twisted.internet import defer
@@ -29,30 +30,14 @@ from twisted.internet import protocol
from twisted.internet import reactor
from nova import flags
from nova.exception import ProcessExecutionError
FLAGS = flags.FLAGS
flags.DEFINE_integer('process_pool_size', 4,
'Number of processes to use in the process pool')
# NOTE(termie): this is copied from twisted.internet.utils but since
# they don't export it I've copied and modified
class UnexpectedErrorOutput(IOError):
"""
Standard error data was received where it was not expected. This is a
subclass of L{IOError} to preserve backward compatibility with the previous
error behavior of L{getProcessOutput}.
@ivar processEnded: A L{Deferred} which will fire when the process which
produced the data on stderr has ended (exited and all file descriptors
closed).
"""
def __init__(self, stdout=None, stderr=None):
IOError.__init__(self, "got stdout: %r\nstderr: %r" % (stdout, stderr))
# This is based on _BackRelay from twister.internal.utils, but modified to
# capture both stdout and stderr, without odd stderr handling, and also to
# This is based on _BackRelay from twister.internal.utils, but modified to
# capture both stdout and stderr, without odd stderr handling, and also to
# handle stdin
class BackRelayWithInput(protocol.ProcessProtocol):
"""
@@ -62,22 +47,23 @@ class BackRelayWithInput(protocol.ProcessProtocol):
@ivar deferred: A L{Deferred} which will be called back with all of stdout
and all of stderr as well (as a tuple). C{terminate_on_stderr} is true
and any bytes are received over stderr, this will fire with an
L{_UnexpectedErrorOutput} instance and the attribute will be set to
L{_ProcessExecutionError} instance and the attribute will be set to
C{None}.
@ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
received over stderr, this attribute will refer to a L{Deferred} which
will be called back when the process ends. This C{Deferred} is also
associated with the L{_UnexpectedErrorOutput} which C{deferred} fires
with earlier in this case so that users can determine when the process
@ivar onProcessEnded: If C{terminate_on_stderr} is false and bytes are
received over stderr, this attribute will refer to a L{Deferred} which
will be called back when the process ends. This C{Deferred} is also
associated with the L{_ProcessExecutionError} which C{deferred} fires
with earlier in this case so that users can determine when the process
has actually ended, in addition to knowing when bytes have been received
via stderr.
"""
def __init__(self, deferred, started_deferred=None,
terminate_on_stderr=False, check_exit_code=True,
process_input=None):
def __init__(self, deferred, cmd, started_deferred=None,
terminate_on_stderr=False, check_exit_code=True,
process_input=None):
self.deferred = deferred
self.cmd = cmd
self.stdout = StringIO.StringIO()
self.stderr = StringIO.StringIO()
self.started_deferred = started_deferred
@@ -85,14 +71,18 @@ class BackRelayWithInput(protocol.ProcessProtocol):
self.check_exit_code = check_exit_code
self.process_input = process_input
self.on_process_ended = None
def _build_execution_error(self, exit_code=None):
return ProcessExecutionError(cmd=self.cmd,
exit_code=exit_code,
stdout=self.stdout.getvalue(),
stderr=self.stderr.getvalue())
def errReceived(self, text):
self.stderr.write(text)
if self.terminate_on_stderr and (self.deferred is not None):
self.on_process_ended = defer.Deferred()
self.deferred.errback(UnexpectedErrorOutput(
stdout=self.stdout.getvalue(),
stderr=self.stderr.getvalue()))
self.deferred.errback(self._build_execution_error())
self.deferred = None
self.transport.loseConnection()
@@ -102,15 +92,19 @@ class BackRelayWithInput(protocol.ProcessProtocol):
def processEnded(self, reason):
if self.deferred is not None:
stdout, stderr = self.stdout.getvalue(), self.stderr.getvalue()
try:
if self.check_exit_code:
reason.trap(error.ProcessDone)
self.deferred.callback((stdout, stderr))
except:
# NOTE(justinsb): This logic is a little suspicious to me...
# If the callback throws an exception, then errback will be
# called also. However, this is what the unit tests test for...
self.deferred.errback(UnexpectedErrorOutput(stdout, stderr))
exit_code = reason.value.exitCode
if self.check_exit_code and exit_code <> 0:
self.deferred.errback(self._build_execution_error(exit_code))
else:
try:
if self.check_exit_code:
reason.trap(error.ProcessDone)
self.deferred.callback((stdout, stderr))
except:
# NOTE(justinsb): This logic is a little suspicious to me...
# If the callback throws an exception, then errback will be
# called also. However, this is what the unit tests test for...
self.deferred.errback(self._build_execution_error(exit_code))
elif self.on_process_ended is not None:
self.on_process_ended.errback(reason)
@@ -119,11 +113,11 @@ class BackRelayWithInput(protocol.ProcessProtocol):
if self.started_deferred:
self.started_deferred.callback(self)
if self.process_input:
self.transport.write(self.process_input)
self.transport.write(str(self.process_input))
self.transport.closeStdin()
def get_process_output(executable, args=None, env=None, path=None,
process_reactor=None, check_exit_code=True,
def get_process_output(executable, args=None, env=None, path=None,
process_reactor=None, check_exit_code=True,
process_input=None, started_deferred=None,
terminate_on_stderr=False):
if process_reactor is None:
@@ -131,10 +125,15 @@ def get_process_output(executable, args=None, env=None, path=None,
args = args and args or ()
env = env and env and {}
deferred = defer.Deferred()
cmd = executable
if args:
cmd = " ".join([cmd] + args)
logging.debug("Running cmd: %s", cmd)
process_handler = BackRelayWithInput(
deferred,
started_deferred=started_deferred,
check_exit_code=check_exit_code,
deferred,
cmd,
started_deferred=started_deferred,
check_exit_code=check_exit_code,
process_input=process_input,
terminate_on_stderr=terminate_on_stderr)
# NOTE(vish): commands come in as unicode, but self.executes needs
@@ -142,8 +141,8 @@ def get_process_output(executable, args=None, env=None, path=None,
executable = str(executable)
if not args is None:
args = [str(x) for x in args]
process_reactor.spawnProcess( process_handler, executable,
(executable,)+tuple(args), env, path)
process_reactor.spawnProcess(process_handler, executable,
(executable,)+tuple(args), env, path)
return deferred

View File

@@ -28,6 +28,7 @@ import uuid
from carrot import connection as carrot_connection
from carrot import messaging
from eventlet import greenthread
from twisted.internet import defer
from twisted.internet import task
@@ -46,9 +47,9 @@ LOG.setLevel(logging.DEBUG)
class Connection(carrot_connection.BrokerConnection):
"""Connection instance object"""
@classmethod
def instance(cls):
def instance(cls, new=False):
"""Returns the instance"""
if not hasattr(cls, '_instance'):
if new or not hasattr(cls, '_instance'):
params = dict(hostname=FLAGS.rabbit_host,
port=FLAGS.rabbit_port,
userid=FLAGS.rabbit_userid,
@@ -60,7 +61,10 @@ class Connection(carrot_connection.BrokerConnection):
# NOTE(vish): magic is fun!
# pylint: disable-msg=W0142
cls._instance = cls(**params)
if new:
return cls(**params)
else:
cls._instance = cls(**params)
return cls._instance
@classmethod
@@ -81,21 +85,6 @@ class Consumer(messaging.Consumer):
self.failed_connection = False
super(Consumer, self).__init__(*args, **kwargs)
# TODO(termie): it would be nice to give these some way of automatically
# cleaning up after themselves
def attach_to_tornado(self, io_inst=None):
"""Attach a callback to tornado that fires 10 times a second"""
from tornado import ioloop
if io_inst is None:
io_inst = ioloop.IOLoop.instance()
injected = ioloop.PeriodicCallback(
lambda: self.fetch(enable_callbacks=True), 100, io_loop=io_inst)
injected.start()
return injected
attachToTornado = attach_to_tornado
def fetch(self, no_ack=None, auto_ack=None, enable_callbacks=False):
"""Wraps the parent fetch with some logic for failed connections"""
# TODO(vish): the logic for failed connections and logging should be
@@ -119,10 +108,19 @@ class Consumer(messaging.Consumer):
logging.exception("Failed to fetch message from queue")
self.failed_connection = True
def attach_to_eventlet(self):
"""Only needed for unit tests!"""
def fetch_repeatedly():
while True:
self.fetch(enable_callbacks=True)
greenthread.sleep(0.1)
greenthread.spawn(fetch_repeatedly)
def attach_to_twisted(self):
"""Attach a callback to twisted that fires 10 times a second"""
loop = task.LoopingCall(self.fetch, enable_callbacks=True)
loop.start(interval=0.1)
return loop
class Publisher(messaging.Publisher):
@@ -265,6 +263,41 @@ def call(topic, msg):
msg.update({'_msg_id': msg_id})
LOG.debug("MSG_ID is %s" % (msg_id))
class WaitMessage(object):
def __call__(self, data, message):
"""Acks message and sets result."""
message.ack()
if data['failure']:
self.result = RemoteError(*data['failure'])
else:
self.result = data['result']
wait_msg = WaitMessage()
conn = Connection.instance(True)
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
consumer.register_callback(wait_msg)
conn = Connection.instance()
publisher = TopicPublisher(connection=conn, topic=topic)
publisher.send(msg)
publisher.close()
try:
consumer.wait(limit=1)
except StopIteration:
pass
consumer.close()
return wait_msg.result
def call_twisted(topic, msg):
"""Sends a message on a topic and wait for a response"""
LOG.debug("Making asynchronous call...")
msg_id = uuid.uuid4().hex
msg.update({'_msg_id': msg_id})
LOG.debug("MSG_ID is %s" % (msg_id))
conn = Connection.instance()
d = defer.Deferred()
consumer = DirectConsumer(connection=conn, msg_id=msg_id)
@@ -278,7 +311,7 @@ def call(topic, msg):
return d.callback(data['result'])
consumer.register_callback(deferred_receive)
injected = consumer.attach_to_tornado()
injected = consumer.attach_to_twisted()
# clean up after the injected listened and return x
d.addCallback(lambda x: injected.stop() and x or x)

View File

@@ -0,0 +1,25 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2010 Openstack, LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
:mod:`nova.scheduler` -- Scheduler Nodes
=====================================================
.. automodule:: nova.scheduler
:platform: Unix
:synopsis: Module that picks a compute node to run a VM instance.
.. moduleauthor:: Chris Behrens <cbehrens@codestud.com>
"""

90
nova/scheduler/simple.py Normal file
View File

@@ -0,0 +1,90 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright (c) 2010 Openstack, LLC.
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Simple Scheduler
"""
import datetime
from nova import db
from nova import flags
from nova.scheduler import driver
from nova.scheduler import chance
FLAGS = flags.FLAGS
flags.DEFINE_integer("max_cores", 16,
"maximum number of instance cores to allow per host")
flags.DEFINE_integer("max_gigabytes", 10000,
"maximum number of volume gigabytes to allow per host")
flags.DEFINE_integer("max_networks", 1000,
"maximum number of networks to allow per host")
class SimpleScheduler(chance.ChanceScheduler):
"""Implements Naive Scheduler that tries to find least loaded host."""
def schedule_run_instance(self, context, instance_id, *_args, **_kwargs):
"""Picks a host that is up and has the fewest running instances."""
instance_ref = db.instance_get(context, instance_id)
results = db.service_get_all_compute_sorted(context)
for result in results:
(service, instance_cores) = result
if instance_cores + instance_ref['vcpus'] > FLAGS.max_cores:
raise driver.NoValidHost("All hosts have too many cores")
if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow
now = datetime.datetime.utcnow()
db.instance_update(context,
instance_id,
{'host': service['host'],
'scheduled_at': now})
return service['host']
raise driver.NoValidHost("No hosts found")
def schedule_create_volume(self, context, volume_id, *_args, **_kwargs):
"""Picks a host that is up and has the fewest volumes."""
volume_ref = db.volume_get(context, volume_id)
results = db.service_get_all_volume_sorted(context)
for result in results:
(service, volume_gigabytes) = result
if volume_gigabytes + volume_ref['size'] > FLAGS.max_gigabytes:
raise driver.NoValidHost("All hosts have too many gigabytes")
if self.service_is_up(service):
# NOTE(vish): this probably belongs in the manager, if we
# can generalize this somehow
now = datetime.datetime.utcnow()
db.volume_update(context,
volume_id,
{'host': service['host'],
'scheduled_at': now})
return service['host']
raise driver.NoValidHost("No hosts found")
def schedule_set_network_host(self, context, *_args, **_kwargs):
"""Picks a host that is up and has the fewest networks."""
results = db.service_get_all_network_sorted(context)
for result in results:
(service, instance_count) = result
if instance_count >= FLAGS.max_networks:
raise driver.NoValidHost("All hosts have too many networks")
if self.service_is_up(service):
return service['host']
raise driver.NoValidHost("No hosts found")

View File

@@ -44,6 +44,8 @@ flags.DEFINE_bool('use_syslog', True, 'output to syslog when daemonizing')
flags.DEFINE_string('logfile', None, 'log file to output to')
flags.DEFINE_string('pidfile', None, 'pid file to output to')
flags.DEFINE_string('working_directory', './', 'working directory...')
flags.DEFINE_integer('uid', os.getuid(), 'uid under which to run')
flags.DEFINE_integer('gid', os.getgid(), 'gid under which to run')
def stop(pidfile):
@@ -58,7 +60,7 @@ def stop(pidfile):
sys.stderr.write(message % pidfile)
return # not an error in a restart
# Try killing the daemon process
# Try killing the daemon process
try:
while 1:
os.kill(pid, signal.SIGTERM)
@@ -104,6 +106,7 @@ def serve(name, main):
def daemonize(args, name, main):
"""Does the work of daemonizing the process"""
logging.getLogger('amqplib').setLevel(logging.WARN)
files_to_keep = []
if FLAGS.daemonize:
logger = logging.getLogger()
formatter = logging.Formatter(
@@ -112,12 +115,14 @@ def daemonize(args, name, main):
syslog = logging.handlers.SysLogHandler(address='/dev/log')
syslog.setFormatter(formatter)
logger.addHandler(syslog)
files_to_keep.append(syslog.socket)
else:
if not FLAGS.logfile:
FLAGS.logfile = '%s.log' % name
logfile = logging.FileHandler(FLAGS.logfile)
logfile.setFormatter(formatter)
logger.addHandler(logfile)
files_to_keep.append(logfile.stream)
stdin, stdout, stderr = None, None, None
else:
stdin, stdout, stderr = sys.stdin, sys.stdout, sys.stderr
@@ -135,6 +140,9 @@ def daemonize(args, name, main):
threaded=False),
stdin=stdin,
stdout=stdout,
stderr=stderr
stderr=stderr,
uid=FLAGS.uid,
gid=FLAGS.gid,
files_preserve=files_to_keep
):
main(args)

View File

@@ -18,23 +18,22 @@
import unittest
import logging
import webob
from nova import exception
from nova import flags
from nova import test
from nova.api import ec2
from nova.auth import manager
from nova.auth import rbac
FLAGS = flags.FLAGS
class Context(object):
pass
class AccessTestCase(test.BaseTestCase):
class AccessTestCase(test.TrialTestCase):
def setUp(self):
super(AccessTestCase, self).setUp()
FLAGS.connection_type = 'fake'
FLAGS.fake_storage = True
um = manager.AuthManager()
# Make test users
try:
@@ -74,9 +73,17 @@ class AccessTestCase(test.BaseTestCase):
try:
self.project.add_role(self.testsys, 'sysadmin')
except: pass
self.context = Context()
self.context.project = self.project
#user is set in each test
def noopWSGIApp(environ, start_response):
start_response('200 OK', [])
return ['']
self.mw = ec2.Authorizer(noopWSGIApp)
self.mw.action_roles = {'str': {
'_allow_all': ['all'],
'_allow_none': [],
'_allow_project_manager': ['projectmanager'],
'_allow_sys_and_net': ['sysadmin', 'netadmin'],
'_allow_sysadmin': ['sysadmin']}}
def tearDown(self):
um = manager.AuthManager()
@@ -89,76 +96,46 @@ class AccessTestCase(test.BaseTestCase):
um.delete_user('testsys')
super(AccessTestCase, self).tearDown()
def response_status(self, user, methodName):
context = Context()
context.project = self.project
context.user = user
environ = {'ec2.context' : context,
'ec2.controller': 'some string',
'ec2.action': methodName}
req = webob.Request.blank('/', environ)
resp = req.get_response(self.mw)
return resp.status_int
def shouldAllow(self, user, methodName):
self.assertEqual(200, self.response_status(user, methodName))
def shouldDeny(self, user, methodName):
self.assertEqual(401, self.response_status(user, methodName))
def test_001_allow_all(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testpmsys
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testnet
self.assertTrue(self._allow_all(self.context))
self.context.user = self.testsys
self.assertTrue(self._allow_all(self.context))
users = [self.testadmin, self.testpmsys, self.testnet, self.testsys]
for user in users:
self.shouldAllow(user, '_allow_all')
def test_002_allow_none(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_none(self.context))
self.context.user = self.testpmsys
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.context.user = self.testsys
self.assertRaises(exception.NotAuthorized, self._allow_none, self.context)
self.shouldAllow(self.testadmin, '_allow_none')
users = [self.testpmsys, self.testnet, self.testsys]
for user in users:
self.shouldDeny(user, '_allow_none')
def test_003_allow_project_manager(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_project_manager(self.context))
self.context.user = self.testpmsys
self.assertTrue(self._allow_project_manager(self.context))
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
self.context.user = self.testsys
self.assertRaises(exception.NotAuthorized, self._allow_project_manager, self.context)
for user in [self.testadmin, self.testpmsys]:
self.shouldAllow(user, '_allow_project_manager')
for user in [self.testnet, self.testsys]:
self.shouldDeny(user, '_allow_project_manager')
def test_004_allow_sys_and_net(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_sys_and_net(self.context))
self.context.user = self.testpmsys # doesn't have the per project sysadmin
self.assertRaises(exception.NotAuthorized, self._allow_sys_and_net, self.context)
self.context.user = self.testnet
self.assertTrue(self._allow_sys_and_net(self.context))
self.context.user = self.testsys
self.assertTrue(self._allow_sys_and_net(self.context))
def test_005_allow_sys_no_pm(self):
self.context.user = self.testadmin
self.assertTrue(self._allow_sys_no_pm(self.context))
self.context.user = self.testpmsys
self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
self.context.user = self.testnet
self.assertRaises(exception.NotAuthorized, self._allow_sys_no_pm, self.context)
self.context.user = self.testsys
self.assertTrue(self._allow_sys_no_pm(self.context))
@rbac.allow('all')
def _allow_all(self, context):
return True
@rbac.allow('none')
def _allow_none(self, context):
return True
@rbac.allow('projectmanager')
def _allow_project_manager(self, context):
return True
@rbac.allow('sysadmin', 'netadmin')
def _allow_sys_and_net(self, context):
return True
@rbac.allow('sysadmin')
@rbac.deny('projectmanager')
def _allow_sys_no_pm(self, context):
return True
for user in [self.testadmin, self.testnet, self.testsys]:
self.shouldAllow(user, '_allow_sys_and_net')
# denied because it doesn't have the per project sysadmin
for user in [self.testpmsys]:
self.shouldDeny(user, '_allow_sys_and_net')
if __name__ == "__main__":
# TODO: Implement use_fake as an option

View File

@@ -16,167 +16,102 @@
# License for the specific language governing permissions and limitations
# under the License.
"""Unit tests for the API endpoint"""
import boto
from boto.ec2 import regioninfo
import httplib
import random
import StringIO
from tornado import httpserver
from twisted.internet import defer
import webob
from nova import flags
from nova import test
from nova import api
from nova.api.ec2 import cloud
from nova.auth import manager
from nova.endpoint import api
from nova.endpoint import cloud
FLAGS = flags.FLAGS
# NOTE(termie): These are a bunch of helper methods and classes to short
# circuit boto calls and feed them into our tornado handlers,
# it's pretty damn circuitous so apologies if you have to fix
# a bug in it
def boto_to_tornado(method, path, headers, data, host, connection=None):
""" translate boto requests into tornado requests
connection should be a FakeTornadoHttpConnection instance
"""
try:
headers = httpserver.HTTPHeaders()
except AttributeError:
from tornado import httputil
headers = httputil.HTTPHeaders()
for k, v in headers.iteritems():
headers[k] = v
req = httpserver.HTTPRequest(method=method,
uri=path,
headers=headers,
body=data,
host=host,
remote_ip='127.0.0.1',
connection=connection)
return req
def raw_to_httpresponse(s):
""" translate a raw tornado http response into an httplib.HTTPResponse """
sock = FakeHttplibSocket(s)
resp = httplib.HTTPResponse(sock)
resp.begin()
return resp
FLAGS.FAKE_subdomain = 'ec2'
class FakeHttplibSocket(object):
""" a fake socket implementation for httplib.HTTPResponse, trivial """
def __init__(self, s):
self.fp = StringIO.StringIO(s)
"""a fake socket implementation for httplib.HTTPResponse, trivial"""
def __init__(self, response_string):
self._buffer = StringIO.StringIO(response_string)
def makefile(self, mode, other):
return self.fp
class FakeTornadoStream(object):
""" a fake stream to satisfy tornado's assumptions, trivial """
def set_close_callback(self, f):
pass
class FakeTornadoConnection(object):
""" a fake connection object for tornado to pass to its handlers
web requests are expected to write to this as they get data and call
finish when they are done with the request, we buffer the writes and
kick off a callback when it is done so that we can feed the result back
into boto.
"""
def __init__(self, d):
self.d = d
self._buffer = StringIO.StringIO()
def write(self, chunk):
self._buffer.write(chunk)
def finish(self):
s = self._buffer.getvalue()
self.d.callback(s)
xheaders = None
@property
def stream(self):
return FakeTornadoStream()
def makefile(self, _mode, _other):
"""Returns the socket's internal buffer"""
return self._buffer
class FakeHttplibConnection(object):
""" a fake httplib.HTTPConnection for boto to use
"""A fake httplib.HTTPConnection for boto to use
requests made via this connection actually get translated and routed into
our tornado app, we then wait for the response and turn it back into
our WSGI app, we then wait for the response and turn it back into
the httplib.HTTPResponse that boto expects.
"""
def __init__(self, app, host, is_secure=False):
self.app = app
self.host = host
self.deferred = defer.Deferred()
def request(self, method, path, data, headers):
req = boto_to_tornado
conn = FakeTornadoConnection(self.deferred)
request = boto_to_tornado(connection=conn,
method=method,
path=path,
headers=headers,
data=data,
host=self.host)
handler = self.app(request)
self.deferred.addCallback(raw_to_httpresponse)
req = webob.Request.blank(path)
req.method = method
req.body = data
req.headers = headers
req.headers['Accept'] = 'text/html'
req.host = self.host
# Call the WSGI app, get the HTTP response
resp = str(req.get_response(self.app))
# For some reason, the response doesn't have "HTTP/1.0 " prepended; I
# guess that's a function the web server usually provides.
resp = "HTTP/1.0 %s" % resp
sock = FakeHttplibSocket(resp)
self.http_response = httplib.HTTPResponse(sock)
self.http_response.begin()
def getresponse(self):
@defer.inlineCallbacks
def _waiter():
result = yield self.deferred
defer.returnValue(result)
d = _waiter()
# NOTE(termie): defer.returnValue above should ensure that
# this deferred has already been called by the time
# we get here, we are going to cheat and return
# the result of the callback
return d.result
return self.http_response
def close(self):
"""Required for compatibility with boto/tornado"""
pass
class ApiEc2TestCase(test.BaseTestCase):
def setUp(self):
"""Unit test for the cloud controller on an EC2 API"""
def setUp(self): # pylint: disable-msg=C0103,C0111
super(ApiEc2TestCase, self).setUp()
self.manager = manager.AuthManager()
self.cloud = cloud.CloudController()
self.host = '127.0.0.1'
self.app = api.APIServerApplication({'Cloud': self.cloud})
self.app = api.API()
def expect_http(self, host=None, is_secure=False):
"""Returns a new EC2 connection"""
self.ec2 = boto.connect_ec2(
aws_access_key_id='fake',
aws_secret_access_key='fake',
is_secure=False,
region=regioninfo.RegionInfo(None, 'test', self.host),
port=FLAGS.cc_port,
port=8773,
path='/services/Cloud')
self.mox.StubOutWithMock(self.ec2, 'new_http_connection')
def expect_http(self, host=None, is_secure=False):
http = FakeHttplibConnection(
self.app, '%s:%d' % (self.host, FLAGS.cc_port), False)
self.app, '%s:8773' % (self.host), False)
# pylint: disable-msg=E1103
self.ec2.new_http_connection(host, is_secure).AndReturn(http)
return http
def test_describe_instances(self):
"""Test that, after creating a user and a project, the describe
instances call to the API works properly"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake')
@@ -187,14 +122,201 @@ class ApiEc2TestCase(test.BaseTestCase):
def test_get_all_key_pairs(self):
"""Test that, after creating a user and project and generating
a key pair, that the API call to list key pairs works properly"""
self.expect_http()
self.mox.ReplayAll()
keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") for x in range(random.randint(4, 8)))
keyname = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
self.manager.generate_key_pair(user.id, keyname)
# NOTE(vish): create depends on pool, so call helper directly
cloud._gen_key(None, user.id, keyname)
rv = self.ec2.get_all_key_pairs()
self.assertTrue(filter(lambda k: k.name == keyname, rv))
results = [k for k in rv if k.name == keyname]
self.assertEquals(len(results), 1)
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_get_all_security_groups(self):
"""Test that we can retrieve security groups"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
rv = self.ec2.get_all_security_groups()
self.assertEquals(len(rv), 1)
self.assertEquals(rv[0].name, 'default')
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_create_delete_security_group(self):
"""Test that we can create a security group"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
self.ec2.create_security_group(security_group_name, 'test group')
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
self.assertEquals(len(rv), 2)
self.assertTrue(security_group_name in [group.name for group in rv])
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.manager.delete_project(project)
self.manager.delete_user(user)
def test_authorize_revoke_security_group_cidr(self):
"""
Test that we can add and remove CIDR based rules
to a security group
"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake')
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name, 'test group')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.authorize('tcp', 80, 81, '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
# I don't bother checkng that we actually find it here,
# because the create/delete unit test further up should
# be good enough for that.
for group in rv:
if group.name == security_group_name:
self.assertEquals(len(group.rules), 1)
self.assertEquals(int(group.rules[0].from_port), 80)
self.assertEquals(int(group.rules[0].to_port), 81)
self.assertEquals(len(group.rules[0].grants), 1)
self.assertEquals(str(group.rules[0].grants[0]), '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.revoke('tcp', 80, 81, '0.0.0.0/0')
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
rv = self.ec2.get_all_security_groups()
self.assertEqual(len(rv), 1)
self.assertEqual(rv[0].name, 'default')
self.manager.delete_project(project)
self.manager.delete_user(user)
return
def test_authorize_revoke_security_group_foreign_group(self):
"""
Test that we can grant and revoke another security group access
to a security group
"""
self.expect_http()
self.mox.ReplayAll()
user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
project = self.manager.create_project('fake', 'fake', 'fake')
# At the moment, you need both of these to actually be netadmin
self.manager.add_role('fake', 'netadmin')
project.add_role('fake', 'netadmin')
security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
other_security_group_name = "".join(random.choice("sdiuisudfsdcnpaqwertasd") \
for x in range(random.randint(4, 8)))
group = self.ec2.create_security_group(security_group_name, 'test group')
self.expect_http()
self.mox.ReplayAll()
other_group = self.ec2.create_security_group(other_security_group_name,
'some other group')
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.authorize(src_group=other_group)
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
# I don't bother checkng that we actually find it here,
# because the create/delete unit test further up should
# be good enough for that.
for group in rv:
if group.name == security_group_name:
self.assertEquals(len(group.rules), 1)
self.assertEquals(len(group.rules[0].grants), 1)
self.assertEquals(str(group.rules[0].grants[0]),
'%s-%s' % (other_security_group_name, 'fake'))
self.expect_http()
self.mox.ReplayAll()
rv = self.ec2.get_all_security_groups()
for group in rv:
if group.name == security_group_name:
self.expect_http()
self.mox.ReplayAll()
group.connection = self.ec2
group.revoke(src_group=other_group)
self.expect_http()
self.mox.ReplayAll()
self.ec2.delete_security_group(security_group_name)
self.manager.delete_project(project)
self.manager.delete_user(user)
return

View File

@@ -17,8 +17,6 @@
# under the License.
import logging
from M2Crypto import BIO
from M2Crypto import RSA
from M2Crypto import X509
import unittest
@@ -26,31 +24,76 @@ from nova import crypto
from nova import flags
from nova import test
from nova.auth import manager
from nova.endpoint import cloud
from nova.api.ec2 import cloud
FLAGS = flags.FLAGS
class user_generator(object):
def __init__(self, manager, **user_state):
if 'name' not in user_state:
user_state['name'] = 'test1'
self.manager = manager
self.user = manager.create_user(**user_state)
class AuthTestCase(test.BaseTestCase):
flush_db = False
def __enter__(self):
return self.user
def __exit__(self, value, type, trace):
self.manager.delete_user(self.user)
class project_generator(object):
def __init__(self, manager, **project_state):
if 'name' not in project_state:
project_state['name'] = 'testproj'
if 'manager_user' not in project_state:
project_state['manager_user'] = 'test1'
self.manager = manager
self.project = manager.create_project(**project_state)
def __enter__(self):
return self.project
def __exit__(self, value, type, trace):
self.manager.delete_project(self.project)
class user_and_project_generator(object):
def __init__(self, manager, user_state={}, project_state={}):
self.manager = manager
if 'name' not in user_state:
user_state['name'] = 'test1'
if 'name' not in project_state:
project_state['name'] = 'testproj'
if 'manager_user' not in project_state:
project_state['manager_user'] = 'test1'
self.user = manager.create_user(**user_state)
self.project = manager.create_project(**project_state)
def __enter__(self):
return (self.user, self.project)
def __exit__(self, value, type, trace):
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
class AuthManagerTestCase(object):
def setUp(self):
super(AuthTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
FLAGS.auth_driver = self.auth_driver
super(AuthManagerTestCase, self).setUp()
self.flags(connection_type='fake')
self.manager = manager.AuthManager()
def test_001_can_create_users(self):
self.manager.create_user('test1', 'access', 'secret')
self.manager.create_user('test2')
def test_create_and_find_user(self):
with user_generator(self.manager):
self.assert_(self.manager.get_user('test1'))
def test_002_can_get_user(self):
user = self.manager.get_user('test1')
def test_003_can_retreive_properties(self):
user = self.manager.get_user('test1')
self.assertEqual('test1', user.id)
self.assertEqual('access', user.access)
self.assertEqual('secret', user.secret)
def test_create_and_find_with_properties(self):
with user_generator(self.manager, name="herbert", secret="classified",
access="private-party"):
u = self.manager.get_user('herbert')
self.assertEqual('herbert', u.id)
self.assertEqual('herbert', u.name)
self.assertEqual('classified', u.secret)
self.assertEqual('private-party', u.access)
def test_004_signature_is_valid(self):
#self.assertTrue(self.manager.authenticate( **boto.generate_url ... ? ? ? ))
@@ -67,156 +110,222 @@ class AuthTestCase(test.BaseTestCase):
'export S3_URL="http://127.0.0.1:3333/"\n' +
'export EC2_USER_ID="test1"\n')
def test_006_test_key_storage(self):
user = self.manager.get_user('test1')
user.create_key_pair('public', 'key', 'fingerprint')
key = user.get_key_pair('public')
self.assertEqual('key', key.public_key)
self.assertEqual('fingerprint', key.fingerprint)
def test_can_list_users(self):
with user_generator(self.manager):
with user_generator(self.manager, name="test2"):
users = self.manager.get_users()
self.assert_(filter(lambda u: u.id == 'test1', users))
self.assert_(filter(lambda u: u.id == 'test2', users))
self.assert_(not filter(lambda u: u.id == 'test3', users))
def test_can_add_and_remove_user_role(self):
with user_generator(self.manager):
self.assertFalse(self.manager.has_role('test1', 'itsec'))
self.manager.add_role('test1', 'itsec')
self.assertTrue(self.manager.has_role('test1', 'itsec'))
self.manager.remove_role('test1', 'itsec')
self.assertFalse(self.manager.has_role('test1', 'itsec'))
def test_007_test_key_generation(self):
user = self.manager.get_user('test1')
private_key, fingerprint = user.generate_key_pair('public2')
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = user.get_key_pair('public2').public_key
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_can_create_and_get_project(self):
with user_and_project_generator(self.manager) as (u,p):
self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_user('test1'))
self.assert_(self.manager.get_project('testproj'))
def test_008_can_list_key_pairs(self):
keys = self.manager.get_user('test1').get_key_pairs()
self.assertTrue(filter(lambda k: k.name == 'public', keys))
self.assertTrue(filter(lambda k: k.name == 'public2', keys))
def test_can_list_projects(self):
with user_and_project_generator(self.manager):
with project_generator(self.manager, name="testproj2"):
projects = self.manager.get_projects()
self.assert_(filter(lambda p: p.name == 'testproj', projects))
self.assert_(filter(lambda p: p.name == 'testproj2', projects))
self.assert_(not filter(lambda p: p.name == 'testproj3',
projects))
def test_009_can_delete_key_pair(self):
self.manager.get_user('test1').delete_key_pair('public')
keys = self.manager.get_user('test1').get_key_pairs()
self.assertFalse(filter(lambda k: k.name == 'public', keys))
def test_can_create_and_get_project_with_attributes(self):
with user_generator(self.manager):
with project_generator(self.manager, description='A test project'):
project = self.manager.get_project('testproj')
self.assertEqual('A test project', project.description)
def test_010_can_list_users(self):
users = self.manager.get_users()
logging.warn(users)
self.assertTrue(filter(lambda u: u.id == 'test1', users))
def test_can_create_project_with_manager(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertEqual('test1', project.project_manager_id)
self.assertTrue(self.manager.is_project_manager(user, project))
def test_101_can_add_user_role(self):
self.assertFalse(self.manager.has_role('test1', 'itsec'))
self.manager.add_role('test1', 'itsec')
self.assertTrue(self.manager.has_role('test1', 'itsec'))
def test_create_project_assigns_manager_to_members(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertTrue(self.manager.is_project_member(user, project))
def test_199_can_remove_user_role(self):
self.assertTrue(self.manager.has_role('test1', 'itsec'))
self.manager.remove_role('test1', 'itsec')
self.assertFalse(self.manager.has_role('test1', 'itsec'))
def test_no_extra_project_members(self):
with user_generator(self.manager, name='test2') as baduser:
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.is_project_member(baduser,
project))
def test_201_can_create_project(self):
project = self.manager.create_project('testproj', 'test1', 'A test project', ['test1'])
self.assertTrue(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
self.assertEqual(project.name, 'testproj')
self.assertEqual(project.description, 'A test project')
self.assertEqual(project.project_manager_id, 'test1')
self.assertTrue(project.has_member('test1'))
def test_no_extra_project_managers(self):
with user_generator(self.manager, name='test2') as baduser:
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.is_project_manager(baduser,
project))
def test_202_user1_is_project_member(self):
self.assertTrue(self.manager.get_user('test1').is_project_member('testproj'))
def test_can_add_user_to_project(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.assertTrue(self.manager.is_project_member(user, project))
def test_203_user2_is_not_project_member(self):
self.assertFalse(self.manager.get_user('test2').is_project_member('testproj'))
def test_can_remove_user_from_project(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.assertTrue(self.manager.is_project_member(user, project))
self.manager.remove_from_project(user, project)
project = self.manager.get_project('testproj')
self.assertFalse(self.manager.is_project_member(user, project))
def test_204_user1_is_project_manager(self):
self.assertTrue(self.manager.get_user('test1').is_project_manager('testproj'))
def test_can_add_remove_user_with_role(self):
with user_generator(self.manager, name='test2') as user:
with user_and_project_generator(self.manager) as (_user, project):
# NOTE(todd): after modifying users you must reload project
self.manager.add_to_project(user, project)
project = self.manager.get_project('testproj')
self.manager.add_role(user, 'developer', project)
self.assertTrue(self.manager.is_project_member(user, project))
self.manager.remove_from_project(user, project)
project = self.manager.get_project('testproj')
self.assertFalse(self.manager.has_role(user, 'developer',
project))
self.assertFalse(self.manager.is_project_member(user, project))
def test_205_user2_is_not_project_manager(self):
self.assertFalse(self.manager.get_user('test2').is_project_manager('testproj'))
def test_can_generate_x509(self):
# NOTE(todd): this doesn't assert against the auth manager
# so it probably belongs in crypto_unittest
# but I'm leaving it where I found it.
with user_and_project_generator(self.manager) as (user, project):
# NOTE(todd): Should mention why we must setup controller first
# (somebody please clue me in)
cloud_controller = cloud.CloudController()
cloud_controller.setup()
_key, cert_str = self.manager._generate_x509_cert('test1',
'testproj')
logging.debug(cert_str)
def test_206_can_add_user_to_project(self):
self.manager.add_to_project('test2', 'testproj')
self.assertTrue(self.manager.get_project('testproj').has_member('test2'))
# Need to verify that it's signed by the right intermediate CA
full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
cloud_cert = crypto.fetch_ca()
logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
signed_cert = X509.load_cert_string(cert_str)
chain_cert = X509.load_cert_string(full_chain)
int_cert = X509.load_cert_string(int_cert)
cloud_cert = X509.load_cert_string(cloud_cert)
self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
if not FLAGS.use_intermediate_ca:
self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
else:
self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
def test_207_can_remove_user_from_project(self):
self.manager.remove_from_project('test2', 'testproj')
self.assertFalse(self.manager.get_project('testproj').has_member('test2'))
def test_adding_role_to_project_is_ignored_unless_added_to_user(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin', project)
# NOTE(todd): it will still show up in get_user_roles(u, project)
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin')
self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
def test_208_can_remove_add_user_with_role(self):
self.manager.add_to_project('test2', 'testproj')
self.manager.add_role('test2', 'developer', 'testproj')
self.manager.remove_from_project('test2', 'testproj')
self.assertFalse(self.manager.has_role('test2', 'developer', 'testproj'))
self.manager.add_to_project('test2', 'testproj')
self.manager.remove_from_project('test2', 'testproj')
def test_add_user_role_doesnt_infect_project_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.manager.add_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
def test_209_can_generate_x509(self):
# MUST HAVE RUN CLOUD SETUP BY NOW
self.cloud = cloud.CloudController()
self.cloud.setup()
_key, cert_str = self.manager._generate_x509_cert('test1', 'testproj')
logging.debug(cert_str)
def test_can_list_user_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
roles = self.manager.get_user_roles(user)
self.assertTrue('sysadmin' in roles)
self.assertFalse('netadmin' in roles)
# Need to verify that it's signed by the right intermediate CA
full_chain = crypto.fetch_ca(project_id='testproj', chain=True)
int_cert = crypto.fetch_ca(project_id='testproj', chain=False)
cloud_cert = crypto.fetch_ca()
logging.debug("CA chain:\n\n =====\n%s\n\n=====" % full_chain)
signed_cert = X509.load_cert_string(cert_str)
chain_cert = X509.load_cert_string(full_chain)
int_cert = X509.load_cert_string(int_cert)
cloud_cert = X509.load_cert_string(cloud_cert)
self.assertTrue(signed_cert.verify(chain_cert.get_pubkey()))
self.assertTrue(signed_cert.verify(int_cert.get_pubkey()))
def test_can_list_project_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.manager.add_role(user, 'netadmin', project)
project_roles = self.manager.get_user_roles(user, project)
self.assertTrue('sysadmin' in project_roles)
self.assertTrue('netadmin' in project_roles)
# has role should be false user-level role is missing
self.assertFalse(self.manager.has_role(user, 'netadmin', project))
if not FLAGS.use_intermediate_ca:
self.assertTrue(signed_cert.verify(cloud_cert.get_pubkey()))
else:
self.assertFalse(signed_cert.verify(cloud_cert.get_pubkey()))
def test_can_remove_user_roles(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
self.manager.remove_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin'))
def test_210_can_add_project_role(self):
project = self.manager.get_project('testproj')
self.assertFalse(project.has_role('test1', 'sysadmin'))
self.manager.add_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
project.add_role('test1', 'sysadmin')
self.assertTrue(project.has_role('test1', 'sysadmin'))
def test_removing_user_role_hides_it_from_project(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.assertTrue(self.manager.has_role(user, 'sysadmin', project))
self.manager.remove_role(user, 'sysadmin')
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
def test_211_can_list_project_roles(self):
project = self.manager.get_project('testproj')
user = self.manager.get_user('test1')
self.manager.add_role(user, 'netadmin', project)
roles = self.manager.get_user_roles(user)
self.assertTrue('sysadmin' in roles)
self.assertFalse('netadmin' in roles)
project_roles = self.manager.get_user_roles(user, project)
self.assertTrue('sysadmin' in project_roles)
self.assertTrue('netadmin' in project_roles)
# has role should be false because global role is missing
self.assertFalse(self.manager.has_role(user, 'netadmin', project))
def test_can_remove_project_role_but_keep_user_role(self):
with user_and_project_generator(self.manager) as (user, project):
self.manager.add_role(user, 'sysadmin')
self.manager.add_role(user, 'sysadmin', project)
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
self.manager.remove_role(user, 'sysadmin', project)
self.assertFalse(self.manager.has_role(user, 'sysadmin', project))
self.assertTrue(self.manager.has_role(user, 'sysadmin'))
def test_can_retrieve_project_by_user(self):
with user_and_project_generator(self.manager) as (user, project):
self.assertEqual(1, len(self.manager.get_projects('test1')))
def test_212_can_remove_project_role(self):
project = self.manager.get_project('testproj')
self.assertTrue(project.has_role('test1', 'sysadmin'))
project.remove_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
self.manager.remove_role('test1', 'sysadmin')
self.assertFalse(project.has_role('test1', 'sysadmin'))
def test_can_modify_project(self):
with user_and_project_generator(self.manager):
with user_generator(self.manager, name='test2'):
self.manager.modify_project('testproj', 'test2', 'new desc')
project = self.manager.get_project('testproj')
self.assertEqual('test2', project.project_manager_id)
self.assertEqual('new desc', project.description)
def test_214_can_retrieve_project_by_user(self):
project = self.manager.create_project('testproj2', 'test2', 'Another test project', ['test2'])
self.assert_(len(self.manager.get_projects()) > 1)
self.assertEqual(len(self.manager.get_projects('test2')), 1)
def test_can_delete_project(self):
with user_generator(self.manager):
self.manager.create_project('testproj', 'test1')
self.assert_(self.manager.get_project('testproj'))
self.manager.delete_project('testproj')
projectlist = self.manager.get_projects()
self.assert_(not filter(lambda p: p.name == 'testproj',
projectlist))
def test_299_can_delete_project(self):
self.manager.delete_project('testproj')
self.assertFalse(filter(lambda p: p.name == 'testproj', self.manager.get_projects()))
self.manager.delete_project('testproj2')
def test_999_can_delete_users(self):
def test_can_delete_user(self):
self.manager.create_user('test1')
self.assert_(self.manager.get_user('test1'))
self.manager.delete_user('test1')
users = self.manager.get_users()
self.assertFalse(filter(lambda u: u.id == 'test1', users))
self.manager.delete_user('test2')
self.assertEqual(self.manager.get_user('test2'), None)
userlist = self.manager.get_users()
self.assert_(not filter(lambda u: u.id == 'test1', userlist))
def test_can_modify_users(self):
with user_generator(self.manager):
self.manager.modify_user('test1', 'access', 'secret', True)
user = self.manager.get_user('test1')
self.assertEqual('access', user.access)
self.assertEqual('secret', user.secret)
self.assertTrue(user.is_admin())
class AuthManagerLdapTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
class AuthManagerDbTestCase(AuthManagerTestCase, test.TrialTestCase):
auth_driver = 'nova.auth.dbdriver.DbDriver'
if __name__ == "__main__":

View File

@@ -16,70 +16,119 @@
# License for the specific language governing permissions and limitations
# under the License.
from base64 import b64decode
import json
import logging
from M2Crypto import BIO
from M2Crypto import RSA
import os
import StringIO
import tempfile
import time
from tornado import ioloop
from twisted.internet import defer
import unittest
from xml.etree import ElementTree
from nova import crypto
from nova import db
from nova import flags
from nova import rpc
from nova import test
from nova import utils
from nova.auth import manager
from nova.compute import service
from nova.endpoint import api
from nova.endpoint import cloud
from nova.compute import power_state
from nova.api.ec2 import context
from nova.api.ec2 import cloud
from nova.objectstore import image
FLAGS = flags.FLAGS
class CloudTestCase(test.BaseTestCase):
# Temp dirs for working with image attributes through the cloud controller
# (stole this from objectstore_unittest.py)
OSS_TEMPDIR = tempfile.mkdtemp(prefix='test_oss-')
IMAGES_PATH = os.path.join(OSS_TEMPDIR, 'images')
os.makedirs(IMAGES_PATH)
class CloudTestCase(test.TrialTestCase):
def setUp(self):
super(CloudTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
self.flags(connection_type='fake', images_path=IMAGES_PATH)
self.conn = rpc.Connection.instance()
logging.getLogger().setLevel(logging.DEBUG)
# set up our cloud
self.cloud = cloud.CloudController()
self.cloud_consumer = rpc.AdapterConsumer(connection=self.conn,
topic=FLAGS.cloud_topic,
proxy=self.cloud)
self.injected.append(self.cloud_consumer.attach_to_tornado(self.ioloop))
# set up a service
self.compute = service.ComputeService()
self.compute = utils.import_object(FLAGS.compute_manager)
self.compute_consumer = rpc.AdapterConsumer(connection=self.conn,
topic=FLAGS.compute_topic,
proxy=self.compute)
self.injected.append(self.compute_consumer.attach_to_tornado(self.ioloop))
topic=FLAGS.compute_topic,
proxy=self.compute)
self.compute_consumer.attach_to_eventlet()
self.network = utils.import_object(FLAGS.network_manager)
self.network_consumer = rpc.AdapterConsumer(connection=self.conn,
topic=FLAGS.network_topic,
proxy=self.network)
self.network_consumer.attach_to_eventlet()
try:
manager.AuthManager().create_user('admin', 'admin', 'admin')
except: pass
admin = manager.AuthManager().get_user('admin')
project = manager.AuthManager().create_project('proj', 'admin', 'proj')
self.context = api.APIRequestContext(handler=None,project=project,user=admin)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('proj', 'admin', 'proj')
self.context = context.APIRequestContext(user=self.user,
project=self.project)
def tearDown(self):
manager.AuthManager().delete_project('proj')
manager.AuthManager().delete_user('admin')
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
super(CloudTestCase, self).tearDown()
def _create_key(self, name):
# NOTE(vish): create depends on pool, so just call helper directly
return cloud._gen_key(self.context, self.context.user.id, name)
def test_console_output(self):
if FLAGS.connection_type == 'fake':
logging.debug("Can't test instances without a real virtual env.")
return
instance_id = 'foo'
inst = yield self.compute.run_instance(instance_id)
output = yield self.cloud.get_console_output(self.context, [instance_id])
logging.debug(output)
self.assert_(output)
rv = yield self.compute.terminate_instance(instance_id)
image_id = FLAGS.default_image
instance_type = FLAGS.default_instance_type
max_count = 1
kwargs = {'image_id': image_id,
'instance_type': instance_type,
'max_count': max_count }
rv = yield self.cloud.run_instances(self.context, **kwargs)
instance_id = rv['instancesSet'][0]['instanceId']
output = yield self.cloud.get_console_output(context=self.context, instance_id=[instance_id])
self.assertEquals(b64decode(output['output']), 'FAKE CONSOLE OUTPUT')
rv = yield self.cloud.terminate_instances(self.context, [instance_id])
def test_key_generation(self):
result = self._create_key('test')
private_key = result['private_key']
key = RSA.load_key_string(private_key, callback=lambda: None)
bio = BIO.MemoryBuffer()
public_key = db.key_pair_get(self.context,
self.context.user.id,
'test')['public_key']
key.save_pub_key_bio(bio)
converted = crypto.ssl_pub_to_ssh_pub(bio.read())
# assert key fields are equal
self.assertEqual(public_key.split(" ")[1].strip(),
converted.split(" ")[1].strip())
def test_describe_key_pairs(self):
self._create_key('test1')
self._create_key('test2')
result = self.cloud.describe_key_pairs(self.context)
keys = result["keypairsSet"]
self.assertTrue(filter(lambda k: k['keyName'] == 'test1', keys))
self.assertTrue(filter(lambda k: k['keyName'] == 'test2', keys))
def test_delete_key_pair(self):
self._create_key('test')
self.cloud.delete_key_pair(self.context, 'test')
def test_run_instances(self):
if FLAGS.connection_type == 'fake':
@@ -99,7 +148,7 @@ class CloudTestCase(test.BaseTestCase):
rv = yield defer.succeed(time.sleep(1))
info = self.cloud._get_instance(instance['instance_id'])
logging.debug(info['state'])
if info['state'] == node.Instance.RUNNING:
if info['state'] == power_state.RUNNING:
break
self.assert_(rv)
@@ -160,3 +209,68 @@ class CloudTestCase(test.BaseTestCase):
#for i in xrange(4):
# data = self.cloud.get_metadata(instance(i)['private_dns_name'])
# self.assert_(data['meta-data']['ami-id'] == 'ami-%s' % i)
@staticmethod
def _fake_set_image_description(ctxt, image_id, description):
from nova.objectstore import handler
class req:
pass
request = req()
request.context = ctxt
request.args = {'image_id': [image_id],
'description': [description]}
resource = handler.ImagesResource()
resource.render_POST(request)
def test_user_editable_image_endpoint(self):
pathdir = os.path.join(FLAGS.images_path, 'ami-testing')
os.mkdir(pathdir)
info = {'isPublic': False}
with open(os.path.join(pathdir, 'info.json'), 'w') as f:
json.dump(info, f)
img = image.Image('ami-testing')
# self.cloud.set_image_description(self.context, 'ami-testing',
# 'Foo Img')
# NOTE(vish): Above won't work unless we start objectstore or create
# a fake version of api/ec2/images.py conn that can
# call methods directly instead of going through boto.
# for now, just cheat and call the method directly
self._fake_set_image_description(self.context, 'ami-testing',
'Foo Img')
self.assertEqual('Foo Img', img.metadata['description'])
self._fake_set_image_description(self.context, 'ami-testing', '')
self.assertEqual('', img.metadata['description'])
def test_update_of_instance_display_fields(self):
inst = db.instance_create({}, {})
ec2_id = cloud.internal_id_to_ec2_id(inst['internal_id'])
self.cloud.update_instance(self.context, ec2_id,
display_name='c00l 1m4g3')
inst = db.instance_get({}, inst['id'])
self.assertEqual('c00l 1m4g3', inst['display_name'])
db.instance_destroy({}, inst['id'])
def test_update_of_instance_wont_update_private_fields(self):
inst = db.instance_create({}, {})
self.cloud.update_instance(self.context, inst['id'],
mac_address='DE:AD:BE:EF')
inst = db.instance_get({}, inst['id'])
self.assertEqual(None, inst['mac_address'])
db.instance_destroy({}, inst['id'])
def test_update_of_volume_display_fields(self):
vol = db.volume_create({}, {})
self.cloud.update_volume(self.context, vol['id'],
display_name='c00l v0lum3')
vol = db.volume_get({}, vol['id'])
self.assertEqual('c00l v0lum3', vol['display_name'])
db.volume_destroy({}, vol['id'])
def test_update_of_volume_wont_update_private_fields(self):
vol = db.volume_create({}, {})
self.cloud.update_volume(self.context, vol['id'],
mountpoint='/not/here')
vol = db.volume_get({}, vol['id'])
self.assertEqual(None, vol['mountpoint'])
db.volume_destroy({}, vol['id'])

View File

@@ -15,113 +15,119 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Compute
"""
import datetime
import logging
import time
from twisted.internet import defer
from xml.etree import ElementTree
from twisted.internet import defer
from nova import db
from nova import exception
from nova import flags
from nova import test
from nova import utils
from nova.compute import model
from nova.compute import service
from nova.auth import manager
from nova.api import context
FLAGS = flags.FLAGS
class InstanceXmlTestCase(test.TrialTestCase):
# @defer.inlineCallbacks
def test_serialization(self):
# TODO: Reimplement this, it doesn't make sense in redis-land
return
# instance_id = 'foo'
# first_node = node.Node()
# inst = yield first_node.run_instance(instance_id)
#
# # force the state so that we can verify that it changes
# inst._s['state'] = node.Instance.NOSTATE
# xml = inst.toXml()
# self.assert_(ElementTree.parse(StringIO.StringIO(xml)))
#
# second_node = node.Node()
# new_inst = node.Instance.fromXml(second_node._conn, pool=second_node._pool, xml=xml)
# self.assertEqual(new_inst.state, node.Instance.RUNNING)
# rv = yield first_node.terminate_instance(instance_id)
class ComputeConnectionTestCase(test.TrialTestCase):
def setUp(self):
class ComputeTestCase(test.TrialTestCase):
"""Test case for compute"""
def setUp(self): # pylint: disable-msg=C0103
logging.getLogger().setLevel(logging.DEBUG)
super(ComputeConnectionTestCase, self).setUp()
super(ComputeTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
self.compute = service.ComputeService()
network_manager='nova.network.manager.FlatManager')
self.compute = utils.import_object(FLAGS.compute_manager)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake')
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = None
def create_instance(self):
instdir = model.InstanceDirectory()
inst = instdir.new()
# TODO(ja): add ami, ari, aki, user_data
def tearDown(self): # pylint: disable-msg=C0103
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
super(ComputeTestCase, self).tearDown()
def _create_instance(self):
"""Create a test instance"""
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10'
inst['user_id'] = 'fake'
inst['project_id'] = 'fake'
inst['user_id'] = self.user.id
inst['project_id'] = self.project.id
inst['instance_type'] = 'm1.tiny'
inst['node_name'] = FLAGS.node_name
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
inst.save()
return inst['instance_id']
return db.instance_create(self.context, inst)['id']
@defer.inlineCallbacks
def test_run_describe_terminate(self):
instance_id = self.create_instance()
def test_run_terminate(self):
"""Make sure it is possible to run and terminate instance"""
instance_id = self._create_instance()
rv = yield self.compute.run_instance(instance_id)
yield self.compute.run_instance(self.context, instance_id)
rv = yield self.compute.describe_instances()
logging.info("Running instances: %s", rv)
self.assertEqual(rv[instance_id].name, instance_id)
instances = db.instance_get_all(None)
logging.info("Running instances: %s", instances)
self.assertEqual(len(instances), 1)
rv = yield self.compute.terminate_instance(instance_id)
yield self.compute.terminate_instance(self.context, instance_id)
rv = yield self.compute.describe_instances()
logging.info("After terminating instances: %s", rv)
self.assertEqual(rv, {})
instances = db.instance_get_all(None)
logging.info("After terminating instances: %s", instances)
self.assertEqual(len(instances), 0)
@defer.inlineCallbacks
def test_run_terminate_timestamps(self):
"""Make sure timestamps are set for launched and destroyed"""
instance_id = self._create_instance()
instance_ref = db.instance_get(self.context, instance_id)
self.assertEqual(instance_ref['launched_at'], None)
self.assertEqual(instance_ref['deleted_at'], None)
launch = datetime.datetime.utcnow()
yield self.compute.run_instance(self.context, instance_id)
instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] > launch)
self.assertEqual(instance_ref['deleted_at'], None)
terminate = datetime.datetime.utcnow()
yield self.compute.terminate_instance(self.context, instance_id)
self.context = context.get_admin_context(user=self.user,
read_deleted=True)
instance_ref = db.instance_get(self.context, instance_id)
self.assert_(instance_ref['launched_at'] < terminate)
self.assert_(instance_ref['deleted_at'] > terminate)
@defer.inlineCallbacks
def test_reboot(self):
instance_id = self.create_instance()
rv = yield self.compute.run_instance(instance_id)
rv = yield self.compute.describe_instances()
self.assertEqual(rv[instance_id].name, instance_id)
yield self.compute.reboot_instance(instance_id)
rv = yield self.compute.describe_instances()
self.assertEqual(rv[instance_id].name, instance_id)
rv = yield self.compute.terminate_instance(instance_id)
"""Ensure instance can be rebooted"""
instance_id = self._create_instance()
yield self.compute.run_instance(self.context, instance_id)
yield self.compute.reboot_instance(self.context, instance_id)
yield self.compute.terminate_instance(self.context, instance_id)
@defer.inlineCallbacks
def test_console_output(self):
instance_id = self.create_instance()
rv = yield self.compute.run_instance(instance_id)
"""Make sure we can get console output from instance"""
instance_id = self._create_instance()
yield self.compute.run_instance(self.context, instance_id)
console = yield self.compute.get_console_output(instance_id)
console = yield self.compute.get_console_output(self.context,
instance_id)
self.assert_(console)
rv = yield self.compute.terminate_instance(instance_id)
yield self.compute.terminate_instance(self.context, instance_id)
@defer.inlineCallbacks
def test_run_instance_existing(self):
instance_id = self.create_instance()
rv = yield self.compute.run_instance(instance_id)
rv = yield self.compute.describe_instances()
self.assertEqual(rv[instance_id].name, instance_id)
self.assertRaises(exception.Error, self.compute.run_instance, instance_id)
rv = yield self.compute.terminate_instance(instance_id)
"""Ensure failure when running an instance that already exists"""
instance_id = self._create_instance()
yield self.compute.run_instance(self.context, instance_id)
self.assertFailure(self.compute.run_instance(self.context,
instance_id),
exception.Error)
yield self.compute.terminate_instance(self.context, instance_id)

View File

@@ -20,9 +20,20 @@ from nova import flags
FLAGS = flags.FLAGS
flags.DECLARE('volume_driver', 'nova.volume.manager')
FLAGS.volume_driver = 'nova.volume.driver.FakeAOEDriver'
FLAGS.connection_type = 'fake'
FLAGS.fake_storage = True
FLAGS.fake_rabbit = True
FLAGS.auth_driver = 'nova.auth.dbdriver.DbDriver'
flags.DECLARE('network_size', 'nova.network.manager')
flags.DECLARE('num_networks', 'nova.network.manager')
flags.DECLARE('fake_network', 'nova.network.manager')
FLAGS.network_size = 16
FLAGS.num_networks = 5
FLAGS.fake_network = True
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
flags.DECLARE('num_shelves', 'nova.volume.manager')
flags.DECLARE('blades_per_shelf', 'nova.volume.manager')
FLAGS.num_shelves = 2
FLAGS.blades_per_shelf = 4
FLAGS.verbose = True
FLAGS.sql_connection = 'sqlite:///nova.sqlite'

View File

@@ -1,292 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from datetime import datetime, timedelta
import logging
import time
from nova import flags
from nova import test
from nova import utils
from nova.compute import model
FLAGS = flags.FLAGS
class ModelTestCase(test.TrialTestCase):
def setUp(self):
super(ModelTestCase, self).setUp()
self.flags(connection_type='fake',
fake_storage=True)
def tearDown(self):
model.Instance('i-test').destroy()
model.Host('testhost').destroy()
model.Daemon('testhost', 'nova-testdaemon').destroy()
def create_instance(self):
inst = model.Instance('i-test')
inst['reservation_id'] = 'r-test'
inst['launch_time'] = '10'
inst['user_id'] = 'fake'
inst['project_id'] = 'fake'
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
inst['private_dns_name'] = '10.0.0.1'
inst.save()
return inst
def create_host(self):
host = model.Host('testhost')
host.save()
return host
def create_daemon(self):
daemon = model.Daemon('testhost', 'nova-testdaemon')
daemon.save()
return daemon
def create_session_token(self):
session_token = model.SessionToken('tk12341234')
session_token['user'] = 'testuser'
session_token.save()
return session_token
def test_create_instance(self):
"""store with create_instace, then test that a load finds it"""
instance = self.create_instance()
old = model.Instance(instance.identifier)
self.assertFalse(old.is_new_record())
def test_delete_instance(self):
"""create, then destroy, then make sure loads a new record"""
instance = self.create_instance()
instance.destroy()
newinst = model.Instance('i-test')
self.assertTrue(newinst.is_new_record())
def test_instance_added_to_set(self):
"""create, then check that it is listed in global set"""
instance = self.create_instance()
found = False
for x in model.InstanceDirectory().all:
if x.identifier == 'i-test':
found = True
self.assert_(found)
def test_instance_associates_project(self):
"""create, then check that it is listed for the project"""
instance = self.create_instance()
found = False
for x in model.InstanceDirectory().by_project(instance.project):
if x.identifier == 'i-test':
found = True
self.assert_(found)
def test_instance_associates_ip(self):
"""create, then check that it is listed for the ip"""
instance = self.create_instance()
found = False
x = model.InstanceDirectory().by_ip(instance['private_dns_name'])
self.assertEqual(x.identifier, 'i-test')
def test_instance_associates_node(self):
"""create, then check that it is listed for the node_name"""
instance = self.create_instance()
found = False
for x in model.InstanceDirectory().by_node(FLAGS.node_name):
if x.identifier == 'i-test':
found = True
self.assertFalse(found)
instance['node_name'] = 'test_node'
instance.save()
for x in model.InstanceDirectory().by_node('test_node'):
if x.identifier == 'i-test':
found = True
self.assert_(found)
def test_host_class_finds_hosts(self):
host = self.create_host()
self.assertEqual('testhost', model.Host.lookup('testhost').identifier)
def test_host_class_doesnt_find_missing_hosts(self):
rv = model.Host.lookup('woahnelly')
self.assertEqual(None, rv)
def test_create_host(self):
"""store with create_host, then test that a load finds it"""
host = self.create_host()
old = model.Host(host.identifier)
self.assertFalse(old.is_new_record())
def test_delete_host(self):
"""create, then destroy, then make sure loads a new record"""
instance = self.create_host()
instance.destroy()
newinst = model.Host('testhost')
self.assertTrue(newinst.is_new_record())
def test_host_added_to_set(self):
"""create, then check that it is included in list"""
instance = self.create_host()
found = False
for x in model.Host.all():
if x.identifier == 'testhost':
found = True
self.assert_(found)
def test_create_daemon_two_args(self):
"""create a daemon with two arguments"""
d = self.create_daemon()
d = model.Daemon('testhost', 'nova-testdaemon')
self.assertFalse(d.is_new_record())
def test_create_daemon_single_arg(self):
"""Create a daemon using the combined host:bin format"""
d = model.Daemon("testhost:nova-testdaemon")
d.save()
d = model.Daemon('testhost:nova-testdaemon')
self.assertFalse(d.is_new_record())
def test_equality_of_daemon_single_and_double_args(self):
"""Create a daemon using the combined host:bin arg, find with 2"""
d = model.Daemon("testhost:nova-testdaemon")
d.save()
d = model.Daemon('testhost', 'nova-testdaemon')
self.assertFalse(d.is_new_record())
def test_equality_daemon_of_double_and_single_args(self):
"""Create a daemon using the combined host:bin arg, find with 2"""
d = self.create_daemon()
d = model.Daemon('testhost:nova-testdaemon')
self.assertFalse(d.is_new_record())
def test_delete_daemon(self):
"""create, then destroy, then make sure loads a new record"""
instance = self.create_daemon()
instance.destroy()
newinst = model.Daemon('testhost', 'nova-testdaemon')
self.assertTrue(newinst.is_new_record())
def test_daemon_heartbeat(self):
"""Create a daemon, sleep, heartbeat, check for update"""
d = self.create_daemon()
ts = d['updated_at']
time.sleep(2)
d.heartbeat()
d2 = model.Daemon('testhost', 'nova-testdaemon')
ts2 = d2['updated_at']
self.assert_(ts2 > ts)
def test_daemon_added_to_set(self):
"""create, then check that it is included in list"""
instance = self.create_daemon()
found = False
for x in model.Daemon.all():
if x.identifier == 'testhost:nova-testdaemon':
found = True
self.assert_(found)
def test_daemon_associates_host(self):
"""create, then check that it is listed for the host"""
instance = self.create_daemon()
found = False
for x in model.Daemon.by_host('testhost'):
if x.identifier == 'testhost:nova-testdaemon':
found = True
self.assertTrue(found)
def test_create_session_token(self):
"""create"""
d = self.create_session_token()
d = model.SessionToken(d.token)
self.assertFalse(d.is_new_record())
def test_delete_session_token(self):
"""create, then destroy, then make sure loads a new record"""
instance = self.create_session_token()
instance.destroy()
newinst = model.SessionToken(instance.token)
self.assertTrue(newinst.is_new_record())
def test_session_token_added_to_set(self):
"""create, then check that it is included in list"""
instance = self.create_session_token()
found = False
for x in model.SessionToken.all():
if x.identifier == instance.token:
found = True
self.assert_(found)
def test_session_token_associates_user(self):
"""create, then check that it is listed for the user"""
instance = self.create_session_token()
found = False
for x in model.SessionToken.associated_to('user', 'testuser'):
if x.identifier == instance.identifier:
found = True
self.assertTrue(found)
def test_session_token_generation(self):
instance = model.SessionToken.generate('username', 'TokenType')
self.assertFalse(instance.is_new_record())
def test_find_generated_session_token(self):
instance = model.SessionToken.generate('username', 'TokenType')
found = model.SessionToken.lookup(instance.identifier)
self.assert_(found)
def test_update_session_token_expiry(self):
instance = model.SessionToken('tk12341234')
oldtime = datetime.utcnow()
instance['expiry'] = oldtime.strftime(utils.TIME_FORMAT)
instance.update_expiry()
expiry = utils.parse_isotime(instance['expiry'])
self.assert_(expiry > datetime.utcnow())
def test_session_token_lookup_when_expired(self):
instance = model.SessionToken.generate("testuser")
instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
instance.save()
inst = model.SessionToken.lookup(instance.identifier)
self.assertFalse(inst)
def test_session_token_lookup_when_not_expired(self):
instance = model.SessionToken.generate("testuser")
inst = model.SessionToken.lookup(instance.identifier)
self.assert_(inst)
def test_session_token_is_expired_when_expired(self):
instance = model.SessionToken.generate("testuser")
instance['expiry'] = datetime.utcnow().strftime(utils.TIME_FORMAT)
self.assert_(instance.is_expired())
def test_session_token_is_expired_when_not_expired(self):
instance = model.SessionToken.generate("testuser")
self.assertFalse(instance.is_expired())
def test_session_token_ttl(self):
instance = model.SessionToken.generate("testuser")
now = datetime.utcnow()
delta = timedelta(hours=1)
instance['expiry'] = (now + delta).strftime(utils.TIME_FORMAT)
# give 5 seconds of fuzziness
self.assert_(abs(instance.ttl() - FLAGS.auth_token_ttl) < 5)

View File

@@ -22,14 +22,13 @@ import IPy
import os
import logging
from nova import db
from nova import exception
from nova import flags
from nova import test
from nova import utils
from nova.auth import manager
from nova.network import model
from nova.network import service
from nova.network import vpn
from nova.network.exception import NoMoreAddresses
from nova.api.ec2 import context
FLAGS = flags.FLAGS
@@ -41,169 +40,200 @@ class NetworkTestCase(test.TrialTestCase):
# NOTE(vish): if you change these flags, make sure to change the
# flags in the corresponding section in nova-dhcpbridge
self.flags(connection_type='fake',
fake_storage=True,
fake_network=True,
auth_driver='nova.auth.ldapdriver.FakeLdapDriver',
network_size=32)
network_size=16,
num_networks=5)
logging.getLogger().setLevel(logging.DEBUG)
self.manager = manager.AuthManager()
self.user = self.manager.create_user('netuser', 'netuser', 'netuser')
self.projects = []
self.projects.append(self.manager.create_project('netuser',
'netuser',
'netuser'))
for i in range(0, 6):
self.network = utils.import_object(FLAGS.network_manager)
self.context = context.APIRequestContext(project=None, user=self.user)
for i in range(5):
name = 'project%s' % i
self.projects.append(self.manager.create_project(name,
'netuser',
name))
vpn.NetworkData.create(self.projects[i].id)
self.service = service.VlanNetworkService()
project = self.manager.create_project(name, 'netuser', name)
self.projects.append(project)
# create the necessary network data for the project
user_context = context.APIRequestContext(project=self.projects[i],
user=self.user)
network_ref = self.network.get_network(user_context)
self.network.set_network_host(context.get_admin_context(),
network_ref['id'])
instance_ref = self._create_instance(0)
self.instance_id = instance_ref['id']
instance_ref = self._create_instance(1)
self.instance2_id = instance_ref['id']
def tearDown(self): # pylint: disable-msg=C0103
super(NetworkTestCase, self).tearDown()
# TODO(termie): this should really be instantiating clean datastores
# in between runs, one failure kills all the tests
db.instance_destroy(None, self.instance_id)
db.instance_destroy(None, self.instance2_id)
for project in self.projects:
self.manager.delete_project(project)
self.manager.delete_user(self.user)
def test_public_network_allocation(self):
def _create_instance(self, project_num, mac=None):
if not mac:
mac = utils.generate_mac()
project = self.projects[project_num]
self.context.project = project
return db.instance_create(self.context,
{'project_id': project.id,
'mac_address': mac})
def _create_address(self, project_num, instance_id=None):
"""Create an address in given project num"""
if instance_id is None:
instance_id = self.instance_id
self.context.project = self.projects[project_num]
return self.network.allocate_fixed_ip(self.context, instance_id)
def _deallocate_address(self, project_num, address):
self.context.project = self.projects[project_num]
self.network.deallocate_fixed_ip(self.context, address)
def test_public_network_association(self):
"""Makes sure that we can allocaate a public ip"""
pubnet = IPy.IP(flags.FLAGS.public_range)
address = self.service.allocate_elastic_ip(self.user.id,
self.projects[0].id)
self.assertTrue(IPy.IP(address) in pubnet)
# TODO(vish): better way of adding floating ips
self.context.project = self.projects[0]
pubnet = IPy.IP(flags.FLAGS.floating_range)
address = str(pubnet[0])
try:
db.floating_ip_get_by_address(None, address)
except exception.NotFound:
db.floating_ip_create(None, {'address': address,
'host': FLAGS.host})
float_addr = self.network.allocate_floating_ip(self.context,
self.projects[0].id)
fix_addr = self._create_address(0)
lease_ip(fix_addr)
self.assertEqual(float_addr, str(pubnet[0]))
self.network.associate_floating_ip(self.context, float_addr, fix_addr)
address = db.instance_get_floating_address(None, self.instance_id)
self.assertEqual(address, float_addr)
self.network.disassociate_floating_ip(self.context, float_addr)
address = db.instance_get_floating_address(None, self.instance_id)
self.assertEqual(address, None)
self.network.deallocate_floating_ip(self.context, float_addr)
self.network.deallocate_fixed_ip(self.context, fix_addr)
release_ip(fix_addr)
def test_allocate_deallocate_fixed_ip(self):
"""Makes sure that we can allocate and deallocate a fixed ip"""
result = self.service.allocate_fixed_ip(
self.user.id, self.projects[0].id)
address = result['private_dns_name']
mac = result['mac_address']
net = model.get_project_network(self.projects[0].id, "default")
self.assertEqual(True, is_in_project(address, self.projects[0].id))
hostname = "test-host"
issue_ip(mac, address, hostname, net.bridge_name)
self.service.deallocate_fixed_ip(address)
address = self._create_address(0)
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
lease_ip(address)
self._deallocate_address(0, address)
# Doesn't go away until it's dhcp released
self.assertEqual(True, is_in_project(address, self.projects[0].id))
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
release_ip(mac, address, hostname, net.bridge_name)
self.assertEqual(False, is_in_project(address, self.projects[0].id))
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
def test_side_effects(self):
"""Ensures allocating and releasing has no side effects"""
hostname = "side-effect-host"
result = self.service.allocate_fixed_ip(self.user.id,
self.projects[0].id)
mac = result['mac_address']
address = result['private_dns_name']
result = self.service.allocate_fixed_ip(self.user,
self.projects[1].id)
secondmac = result['mac_address']
secondaddress = result['private_dns_name']
address = self._create_address(0)
address2 = self._create_address(1, self.instance2_id)
net = model.get_project_network(self.projects[0].id, "default")
secondnet = model.get_project_network(self.projects[1].id, "default")
self.assertEqual(True, is_in_project(address, self.projects[0].id))
self.assertEqual(True, is_in_project(secondaddress,
self.projects[1].id))
self.assertEqual(False, is_in_project(address, self.projects[1].id))
self.assertTrue(is_allocated_in_project(address, self.projects[0].id))
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
self.assertFalse(is_allocated_in_project(address, self.projects[1].id))
# Addresses are allocated before they're issued
issue_ip(mac, address, hostname, net.bridge_name)
issue_ip(secondmac, secondaddress, hostname, secondnet.bridge_name)
lease_ip(address)
lease_ip(address2)
self.service.deallocate_fixed_ip(address)
release_ip(mac, address, hostname, net.bridge_name)
self.assertEqual(False, is_in_project(address, self.projects[0].id))
self._deallocate_address(0, address)
release_ip(address)
self.assertFalse(is_allocated_in_project(address, self.projects[0].id))
# First address release shouldn't affect the second
self.assertEqual(True, is_in_project(secondaddress,
self.projects[1].id))
self.assertTrue(is_allocated_in_project(address2, self.projects[1].id))
self.service.deallocate_fixed_ip(secondaddress)
release_ip(secondmac, secondaddress, hostname, secondnet.bridge_name)
self.assertEqual(False, is_in_project(secondaddress,
self.projects[1].id))
self._deallocate_address(1, address2)
release_ip(address2)
self.assertFalse(is_allocated_in_project(address2,
self.projects[1].id))
def test_subnet_edge(self):
"""Makes sure that private ips don't overlap"""
result = self.service.allocate_fixed_ip(self.user.id,
self.projects[0].id)
firstaddress = result['private_dns_name']
hostname = "toomany-hosts"
first = self._create_address(0)
lease_ip(first)
instance_ids = []
for i in range(1, 5):
project_id = self.projects[i].id
result = self.service.allocate_fixed_ip(
self.user, project_id)
mac = result['mac_address']
address = result['private_dns_name']
result = self.service.allocate_fixed_ip(
self.user, project_id)
mac2 = result['mac_address']
address2 = result['private_dns_name']
result = self.service.allocate_fixed_ip(
self.user, project_id)
mac3 = result['mac_address']
address3 = result['private_dns_name']
net = model.get_project_network(project_id, "default")
issue_ip(mac, address, hostname, net.bridge_name)
issue_ip(mac2, address2, hostname, net.bridge_name)
issue_ip(mac3, address3, hostname, net.bridge_name)
self.assertEqual(False, is_in_project(address,
self.projects[0].id))
self.assertEqual(False, is_in_project(address2,
self.projects[0].id))
self.assertEqual(False, is_in_project(address3,
self.projects[0].id))
self.service.deallocate_fixed_ip(address)
self.service.deallocate_fixed_ip(address2)
self.service.deallocate_fixed_ip(address3)
release_ip(mac, address, hostname, net.bridge_name)
release_ip(mac2, address2, hostname, net.bridge_name)
release_ip(mac3, address3, hostname, net.bridge_name)
net = model.get_project_network(self.projects[0].id, "default")
self.service.deallocate_fixed_ip(firstaddress)
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address = self._create_address(i, instance_ref['id'])
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address2 = self._create_address(i, instance_ref['id'])
instance_ref = self._create_instance(i, mac=utils.generate_mac())
instance_ids.append(instance_ref['id'])
address3 = self._create_address(i, instance_ref['id'])
lease_ip(address)
lease_ip(address2)
lease_ip(address3)
self.context.project = self.projects[i]
self.assertFalse(is_allocated_in_project(address,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address2,
self.projects[0].id))
self.assertFalse(is_allocated_in_project(address3,
self.projects[0].id))
self.network.deallocate_fixed_ip(self.context, address)
self.network.deallocate_fixed_ip(self.context, address2)
self.network.deallocate_fixed_ip(self.context, address3)
release_ip(address)
release_ip(address2)
release_ip(address3)
for instance_id in instance_ids:
db.instance_destroy(None, instance_id)
self.context.project = self.projects[0]
self.network.deallocate_fixed_ip(self.context, first)
self._deallocate_address(0, first)
release_ip(first)
def test_vpn_ip_and_port_looks_valid(self):
"""Ensure the vpn ip and port are reasonable"""
self.assert_(self.projects[0].vpn_ip)
self.assert_(self.projects[0].vpn_port >= FLAGS.vpn_start_port)
self.assert_(self.projects[0].vpn_port <= FLAGS.vpn_end_port)
self.assert_(self.projects[0].vpn_port >= FLAGS.vpn_start)
self.assert_(self.projects[0].vpn_port <= FLAGS.vpn_start +
FLAGS.num_networks)
def test_too_many_vpns(self):
"""Ensure error is raised if we run out of vpn ports"""
vpns = []
for i in xrange(vpn.NetworkData.num_ports_for_ip(FLAGS.vpn_ip)):
vpns.append(vpn.NetworkData.create("vpnuser%s" % i))
self.assertRaises(vpn.NoMorePorts, vpn.NetworkData.create, "boom")
for network_datum in vpns:
network_datum.destroy()
def test_too_many_networks(self):
"""Ensure error is raised if we run out of networks"""
projects = []
networks_left = FLAGS.num_networks - db.network_count(None)
for i in range(networks_left):
project = self.manager.create_project('many%s' % i, self.user)
projects.append(project)
db.project_get_network(None, project.id)
project = self.manager.create_project('last', self.user)
projects.append(project)
self.assertRaises(db.NoMoreNetworks,
db.project_get_network,
None,
project.id)
for project in projects:
self.manager.delete_project(project)
def test_ips_are_reused(self):
"""Makes sure that ip addresses that are deallocated get reused"""
result = self.service.allocate_fixed_ip(
self.user.id, self.projects[0].id)
mac = result['mac_address']
address = result['private_dns_name']
address = self._create_address(0)
lease_ip(address)
self.network.deallocate_fixed_ip(self.context, address)
release_ip(address)
hostname = "reuse-host"
net = model.get_project_network(self.projects[0].id, "default")
issue_ip(mac, address, hostname, net.bridge_name)
self.service.deallocate_fixed_ip(address)
release_ip(mac, address, hostname, net.bridge_name)
result = self.service.allocate_fixed_ip(
self.user, self.projects[0].id)
secondmac = result['mac_address']
secondaddress = result['private_dns_name']
self.assertEqual(address, secondaddress)
issue_ip(secondmac, secondaddress, hostname, net.bridge_name)
self.service.deallocate_fixed_ip(secondaddress)
release_ip(secondmac, secondaddress, hostname, net.bridge_name)
address2 = self._create_address(0)
self.assertEqual(address, address2)
lease_ip(address)
self.network.deallocate_fixed_ip(self.context, address2)
release_ip(address)
def test_available_ips(self):
"""Make sure the number of available ips for the network is correct
@@ -216,44 +246,51 @@ class NetworkTestCase(test.TrialTestCase):
There are ips reserved at the bottom and top of the range.
services (network, gateway, CloudPipe, broadcast)
"""
net = model.get_project_network(self.projects[0].id, "default")
num_preallocated_ips = len(net.assigned)
network = db.project_get_network(None, self.projects[0].id)
net_size = flags.FLAGS.network_size
num_available_ips = net_size - (net.num_bottom_reserved_ips +
num_preallocated_ips +
net.num_top_reserved_ips)
self.assertEqual(num_available_ips, len(list(net.available)))
total_ips = (db.network_count_available_ips(None, network['id']) +
db.network_count_reserved_ips(None, network['id']) +
db.network_count_allocated_ips(None, network['id']))
self.assertEqual(total_ips, net_size)
def test_too_many_addresses(self):
"""Test for a NoMoreAddresses exception when all fixed ips are used.
"""
net = model.get_project_network(self.projects[0].id, "default")
hostname = "toomany-hosts"
macs = {}
addresses = {}
# Number of availaible ips is len of the available list
num_available_ips = len(list(net.available))
network = db.project_get_network(None, self.projects[0].id)
num_available_ips = db.network_count_available_ips(None,
network['id'])
addresses = []
instance_ids = []
for i in range(num_available_ips):
result = self.service.allocate_fixed_ip(self.user.id,
self.projects[0].id)
macs[i] = result['mac_address']
addresses[i] = result['private_dns_name']
issue_ip(macs[i], addresses[i], hostname, net.bridge_name)
instance_ref = self._create_instance(0)
instance_ids.append(instance_ref['id'])
address = self._create_address(0, instance_ref['id'])
addresses.append(address)
lease_ip(address)
self.assertEqual(len(list(net.available)), 0)
self.assertRaises(NoMoreAddresses, self.service.allocate_fixed_ip,
self.user.id, self.projects[0].id)
self.assertEqual(db.network_count_available_ips(None,
network['id']), 0)
self.assertRaises(db.NoMoreAddresses,
self.network.allocate_fixed_ip,
self.context,
'foo')
for i in range(len(addresses)):
self.service.deallocate_fixed_ip(addresses[i])
release_ip(macs[i], addresses[i], hostname, net.bridge_name)
self.assertEqual(len(list(net.available)), num_available_ips)
for i in range(num_available_ips):
self.network.deallocate_fixed_ip(self.context, addresses[i])
release_ip(addresses[i])
db.instance_destroy(None, instance_ids[i])
self.assertEqual(db.network_count_available_ips(None,
network['id']),
num_available_ips)
def is_in_project(address, project_id):
def is_allocated_in_project(address, project_id):
"""Returns true if address is in specified project"""
return address in model.get_project_network(project_id).assigned
project_net = db.project_get_network(None, project_id)
network = db.fixed_ip_get_network(None, address)
instance = db.fixed_ip_get_instance(None, address)
# instance exists until release
return instance is not None and network['id'] == project_net['id']
def binpath(script):
@@ -261,22 +298,28 @@ def binpath(script):
return os.path.abspath(os.path.join(__file__, "../../../bin", script))
def issue_ip(mac, private_ip, hostname, interface):
def lease_ip(private_ip):
"""Run add command on dhcpbridge"""
cmd = "%s add %s %s %s" % (binpath('nova-dhcpbridge'),
mac, private_ip, hostname)
env = {'DNSMASQ_INTERFACE': interface,
network_ref = db.fixed_ip_get_network(None, private_ip)
instance_ref = db.fixed_ip_get_instance(None, private_ip)
cmd = "%s add %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'],
private_ip)
env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}
(out, err) = utils.execute(cmd, addl_env=env)
logging.debug("ISSUE_IP: %s, %s ", out, err)
def release_ip(mac, private_ip, hostname, interface):
def release_ip(private_ip):
"""Run del command on dhcpbridge"""
cmd = "%s del %s %s %s" % (binpath('nova-dhcpbridge'),
mac, private_ip, hostname)
env = {'DNSMASQ_INTERFACE': interface,
network_ref = db.fixed_ip_get_network(None, private_ip)
instance_ref = db.fixed_ip_get_instance(None, private_ip)
cmd = "%s del %s %s fake" % (binpath('nova-dhcpbridge'),
instance_ref['mac_address'],
private_ip)
env = {'DNSMASQ_INTERFACE': network_ref['bridge'],
'TESTING': '1',
'FLAGFILE': FLAGS.dhcpbridge_flagfile}
(out, err) = utils.execute(cmd, addl_env=env)

View File

@@ -53,7 +53,7 @@ os.makedirs(os.path.join(OSS_TEMPDIR, 'images'))
os.makedirs(os.path.join(OSS_TEMPDIR, 'buckets'))
class ObjectStoreTestCase(test.BaseTestCase):
class ObjectStoreTestCase(test.TrialTestCase):
"""Test objectstore API directly."""
def setUp(self): # pylint: disable-msg=C0103
@@ -133,13 +133,22 @@ class ObjectStoreTestCase(test.BaseTestCase):
self.assertRaises(NotFound, objectstore.bucket.Bucket, 'new_bucket')
def test_images(self):
self.do_test_images('1mb.manifest.xml', True,
'image_bucket1', 'i-testing1')
def test_images_no_kernel_or_ramdisk(self):
self.do_test_images('1mb.no_kernel_or_ramdisk.manifest.xml',
False, 'image_bucket2', 'i-testing2')
def do_test_images(self, manifest_file, expect_kernel_and_ramdisk,
image_bucket, image_name):
"Test the image API."
self.context.user = self.auth_manager.get_user('user1')
self.context.project = self.auth_manager.get_project('proj1')
# create a bucket for our bundle
objectstore.bucket.Bucket.create('image_bucket', self.context)
bucket = objectstore.bucket.Bucket('image_bucket')
objectstore.bucket.Bucket.create(image_bucket, self.context)
bucket = objectstore.bucket.Bucket(image_bucket)
# upload an image manifest/parts
bundle_path = os.path.join(os.path.dirname(__file__), 'bundle')
@@ -147,23 +156,39 @@ class ObjectStoreTestCase(test.BaseTestCase):
bucket[os.path.basename(path)] = open(path, 'rb').read()
# register an image
image.Image.register_aws_image('i-testing',
'image_bucket/1mb.manifest.xml',
image.Image.register_aws_image(image_name,
'%s/%s' % (image_bucket, manifest_file),
self.context)
# verify image
my_img = image.Image('i-testing')
my_img = image.Image(image_name)
result_image_file = os.path.join(my_img.path, 'image')
self.assertEqual(os.stat(result_image_file).st_size, 1048576)
sha = hashlib.sha1(open(result_image_file).read()).hexdigest()
self.assertEqual(sha, '3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3')
if expect_kernel_and_ramdisk:
# Verify the default kernel and ramdisk are set
self.assertEqual(my_img.metadata['kernelId'], 'aki-test')
self.assertEqual(my_img.metadata['ramdiskId'], 'ari-test')
else:
# Verify that the default kernel and ramdisk (the one from FLAGS)
# doesn't get embedded in the metadata
self.assertFalse('kernelId' in my_img.metadata)
self.assertFalse('ramdiskId' in my_img.metadata)
# verify image permissions
self.context.user = self.auth_manager.get_user('user2')
self.context.project = self.auth_manager.get_project('proj2')
self.assertFalse(my_img.is_authorized(self.context))
# change user-editable fields
my_img.update_user_editable_fields({'display_name': 'my cool image'})
self.assertEqual('my cool image', my_img.metadata['displayName'])
my_img.update_user_editable_fields({'display_name': ''})
self.assert_(not my_img.metadata['displayName'])
class TestHTTPChannel(http.HTTPChannel):
"""Dummy site required for twisted.web"""
@@ -185,7 +210,7 @@ class S3APITestCase(test.TrialTestCase):
"""Setup users, projects, and start a test server."""
super(S3APITestCase, self).setUp()
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver',
FLAGS.auth_driver = 'nova.auth.ldapdriver.FakeLdapDriver'
FLAGS.buckets_path = os.path.join(OSS_TEMPDIR, 'buckets')
self.auth_manager = manager.AuthManager()

View File

@@ -0,0 +1,152 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nova import db
from nova import exception
from nova import flags
from nova import quota
from nova import test
from nova import utils
from nova.auth import manager
from nova.api.ec2 import cloud
from nova.api.ec2 import context
FLAGS = flags.FLAGS
class QuotaTestCase(test.TrialTestCase):
def setUp(self): # pylint: disable-msg=C0103
logging.getLogger().setLevel(logging.DEBUG)
super(QuotaTestCase, self).setUp()
self.flags(connection_type='fake',
quota_instances=2,
quota_cores=4,
quota_volumes=2,
quota_gigabytes=20,
quota_floating_ips=1)
self.cloud = cloud.CloudController()
self.manager = manager.AuthManager()
self.user = self.manager.create_user('admin', 'admin', 'admin', True)
self.project = self.manager.create_project('admin', 'admin', 'admin')
self.network = utils.import_object(FLAGS.network_manager)
self.context = context.APIRequestContext(project=self.project,
user=self.user)
def tearDown(self): # pylint: disable-msg=C0103
manager.AuthManager().delete_project(self.project)
manager.AuthManager().delete_user(self.user)
super(QuotaTestCase, self).tearDown()
def _create_instance(self, cores=2):
"""Create a test instance"""
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['user_id'] = self.user.id
inst['project_id'] = self.project.id
inst['instance_type'] = 'm1.large'
inst['vcpus'] = cores
inst['mac_address'] = utils.generate_mac()
return db.instance_create(self.context, inst)['id']
def _create_volume(self, size=10):
"""Create a test volume"""
vol = {}
vol['user_id'] = self.user.id
vol['project_id'] = self.project.id
vol['size'] = size
return db.volume_create(self.context, vol)['id']
def test_quota_overrides(self):
"""Make sure overriding a projects quotas works"""
num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
self.assertEqual(num_instances, 2)
db.quota_create(self.context, {'project_id': self.project.id,
'instances': 10})
num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
self.assertEqual(num_instances, 4)
db.quota_update(self.context, self.project.id, {'cores': 100})
num_instances = quota.allowed_instances(self.context, 100, 'm1.small')
self.assertEqual(num_instances, 10)
db.quota_destroy(self.context, self.project.id)
def test_too_many_instances(self):
instance_ids = []
for i in range(FLAGS.quota_instances):
instance_id = self._create_instance()
instance_ids.append(instance_id)
self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
self.context,
min_count=1,
max_count=1,
instance_type='m1.small')
for instance_id in instance_ids:
db.instance_destroy(self.context, instance_id)
def test_too_many_cores(self):
instance_ids = []
instance_id = self._create_instance(cores=4)
instance_ids.append(instance_id)
self.assertRaises(cloud.QuotaError, self.cloud.run_instances,
self.context,
min_count=1,
max_count=1,
instance_type='m1.small')
for instance_id in instance_ids:
db.instance_destroy(self.context, instance_id)
def test_too_many_volumes(self):
volume_ids = []
for i in range(FLAGS.quota_volumes):
volume_id = self._create_volume()
volume_ids.append(volume_id)
self.assertRaises(cloud.QuotaError, self.cloud.create_volume,
self.context,
size=10)
for volume_id in volume_ids:
db.volume_destroy(self.context, volume_id)
def test_too_many_gigabytes(self):
volume_ids = []
volume_id = self._create_volume(size=20)
volume_ids.append(volume_id)
self.assertRaises(cloud.QuotaError,
self.cloud.create_volume,
self.context,
size=10)
for volume_id in volume_ids:
db.volume_destroy(self.context, volume_id)
def test_too_many_addresses(self):
address = '192.168.0.100'
try:
db.floating_ip_get_by_address(None, address)
except exception.NotFound:
db.floating_ip_create(None, {'address': address,
'host': FLAGS.host})
float_addr = self.network.allocate_floating_ip(self.context,
self.project.id)
# NOTE(vish): This assert never fails. When cloud attempts to
# make an rpc.call, the test just finishes with OK. It
# appears to be something in the magic inline callbacks
# that is breaking.
self.assertRaises(cloud.QuotaError, self.cloud.allocate_address, self.context)

View File

@@ -21,7 +21,6 @@ from nova import flags
FLAGS = flags.FLAGS
FLAGS.connection_type = 'libvirt'
FLAGS.fake_storage = False
FLAGS.fake_rabbit = False
FLAGS.fake_network = False
FLAGS.verbose = False

View File

@@ -30,7 +30,7 @@ from nova import test
FLAGS = flags.FLAGS
class RpcTestCase(test.BaseTestCase):
class RpcTestCase(test.TrialTestCase):
"""Test cases for rpc"""
def setUp(self): # pylint: disable-msg=C0103
super(RpcTestCase, self).setUp()
@@ -39,14 +39,13 @@ class RpcTestCase(test.BaseTestCase):
self.consumer = rpc.AdapterConsumer(connection=self.conn,
topic='test',
proxy=self.receiver)
self.injected.append(self.consumer.attach_to_tornado(self.ioloop))
self.consumer.attach_to_twisted()
def test_call_succeed(self):
"""Get a value through rpc call"""
value = 42
result = yield rpc.call('test', {"method": "echo",
"args": {"value": value}})
result = yield rpc.call_twisted('test', {"method": "echo",
"args": {"value": value}})
self.assertEqual(value, result)
def test_call_exception(self):
@@ -57,12 +56,12 @@ class RpcTestCase(test.BaseTestCase):
to an int in the test.
"""
value = 42
self.assertFailure(rpc.call('test', {"method": "fail",
"args": {"value": value}}),
self.assertFailure(rpc.call_twisted('test', {"method": "fail",
"args": {"value": value}}),
rpc.RemoteError)
try:
yield rpc.call('test', {"method": "fail",
"args": {"value": value}})
yield rpc.call_twisted('test', {"method": "fail",
"args": {"value": value}})
self.fail("should have thrown rpc.RemoteError")
except rpc.RemoteError as exc:
self.assertEqual(int(exc.value), value)

View File

@@ -0,0 +1,242 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests For Scheduler
"""
from nova import db
from nova import flags
from nova import service
from nova import test
from nova import rpc
from nova import utils
from nova.auth import manager as auth_manager
from nova.scheduler import manager
from nova.scheduler import driver
FLAGS = flags.FLAGS
flags.DECLARE('max_cores', 'nova.scheduler.simple')
class TestDriver(driver.Scheduler):
"""Scheduler Driver for Tests"""
def schedule(context, topic, *args, **kwargs):
return 'fallback_host'
def schedule_named_method(context, topic, num):
return 'named_host'
class SchedulerTestCase(test.TrialTestCase):
"""Test case for scheduler"""
def setUp(self): # pylint: disable=C0103
super(SchedulerTestCase, self).setUp()
self.flags(scheduler_driver='nova.tests.scheduler_unittest.TestDriver')
def test_fallback(self):
scheduler = manager.SchedulerManager()
self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
rpc.cast('topic.fallback_host',
{'method': 'noexist',
'args': {'context': None,
'num': 7}})
self.mox.ReplayAll()
scheduler.noexist(None, 'topic', num=7)
def test_named_method(self):
scheduler = manager.SchedulerManager()
self.mox.StubOutWithMock(rpc, 'cast', use_mock_anything=True)
rpc.cast('topic.named_host',
{'method': 'named_method',
'args': {'context': None,
'num': 7}})
self.mox.ReplayAll()
scheduler.named_method(None, 'topic', num=7)
class SimpleDriverTestCase(test.TrialTestCase):
"""Test case for simple driver"""
def setUp(self): # pylint: disable-msg=C0103
super(SimpleDriverTestCase, self).setUp()
self.flags(connection_type='fake',
max_cores=4,
max_gigabytes=4,
network_manager='nova.network.manager.FlatManager',
volume_driver='nova.volume.driver.FakeAOEDriver',
scheduler_driver='nova.scheduler.simple.SimpleScheduler')
self.scheduler = manager.SchedulerManager()
self.context = None
self.manager = auth_manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake')
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = None
def tearDown(self): # pylint: disable-msg=C0103
self.manager.delete_user(self.user)
self.manager.delete_project(self.project)
def _create_instance(self):
"""Create a test instance"""
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['user_id'] = self.user.id
inst['project_id'] = self.project.id
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
inst['vcpus'] = 1
return db.instance_create(self.context, inst)['id']
def _create_volume(self):
"""Create a test volume"""
vol = {}
vol['image_id'] = 'ami-test'
vol['reservation_id'] = 'r-fakeres'
vol['size'] = 1
return db.volume_create(self.context, vol)['id']
def test_hosts_are_up(self):
"""Ensures driver can find the hosts that are up"""
# NOTE(vish): constructing service without create method
# because we are going to use it without queue
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
hosts = self.scheduler.driver.hosts_up(self.context, 'compute')
self.assertEqual(len(hosts), 2)
compute1.kill()
compute2.kill()
def test_least_busy_host_gets_instance(self):
"""Ensures the host with less cores gets the next one"""
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
instance_id1 = self._create_instance()
compute1.run_instance(self.context, instance_id1)
instance_id2 = self._create_instance()
host = self.scheduler.driver.schedule_run_instance(self.context,
instance_id2)
self.assertEqual(host, 'host2')
compute1.terminate_instance(self.context, instance_id1)
db.instance_destroy(self.context, instance_id2)
compute1.kill()
compute2.kill()
def test_too_many_cores(self):
"""Ensures we don't go over max cores"""
compute1 = service.Service('host1',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute1.startService()
compute2 = service.Service('host2',
'nova-compute',
'compute',
FLAGS.compute_manager)
compute2.startService()
instance_ids1 = []
instance_ids2 = []
for index in xrange(FLAGS.max_cores):
instance_id = self._create_instance()
compute1.run_instance(self.context, instance_id)
instance_ids1.append(instance_id)
instance_id = self._create_instance()
compute2.run_instance(self.context, instance_id)
instance_ids2.append(instance_id)
instance_id = self._create_instance()
self.assertRaises(driver.NoValidHost,
self.scheduler.driver.schedule_run_instance,
self.context,
instance_id)
for instance_id in instance_ids1:
compute1.terminate_instance(self.context, instance_id)
for instance_id in instance_ids2:
compute2.terminate_instance(self.context, instance_id)
compute1.kill()
compute2.kill()
def test_least_busy_host_gets_volume(self):
"""Ensures the host with less gigabytes gets the next one"""
volume1 = service.Service('host1',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume2.startService()
volume_id1 = self._create_volume()
volume1.create_volume(self.context, volume_id1)
volume_id2 = self._create_volume()
host = self.scheduler.driver.schedule_create_volume(self.context,
volume_id2)
self.assertEqual(host, 'host2')
volume1.delete_volume(self.context, volume_id1)
db.volume_destroy(self.context, volume_id2)
volume1.kill()
volume2.kill()
def test_too_many_gigabytes(self):
"""Ensures we don't go over max gigabytes"""
volume1 = service.Service('host1',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume1.startService()
volume2 = service.Service('host2',
'nova-volume',
'volume',
FLAGS.volume_manager)
volume2.startService()
volume_ids1 = []
volume_ids2 = []
for index in xrange(FLAGS.max_gigabytes):
volume_id = self._create_volume()
volume1.create_volume(self.context, volume_id)
volume_ids1.append(volume_id)
volume_id = self._create_volume()
volume2.create_volume(self.context, volume_id)
volume_ids2.append(volume_id)
volume_id = self._create_volume()
self.assertRaises(driver.NoValidHost,
self.scheduler.driver.schedule_create_volume,
self.context,
volume_id)
for volume_id in volume_ids1:
volume1.delete_volume(self.context, volume_id)
for volume_id in volume_ids2:
volume2.delete_volume(self.context, volume_id)
volume1.kill()
volume2.kill()

View File

@@ -0,0 +1,190 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Unit Tests for remote procedure calls using queue
"""
import mox
from twisted.application.app import startApplication
from nova import exception
from nova import flags
from nova import rpc
from nova import test
from nova import service
from nova import manager
FLAGS = flags.FLAGS
flags.DEFINE_string("fake_manager", "nova.tests.service_unittest.FakeManager",
"Manager for testing")
class FakeManager(manager.Manager):
"""Fake manager for tests"""
pass
class ServiceTestCase(test.BaseTestCase):
"""Test cases for rpc"""
def setUp(self): # pylint: disable=C0103
super(ServiceTestCase, self).setUp()
self.mox.StubOutWithMock(service, 'db')
def test_create(self):
host = 'foo'
binary = 'nova-fake'
topic = 'fake'
self.mox.StubOutWithMock(rpc,
'AdapterConsumer',
use_mock_anything=True)
self.mox.StubOutWithMock(
service.task, 'LoopingCall', use_mock_anything=True)
rpc.AdapterConsumer(connection=mox.IgnoreArg(),
topic=topic,
proxy=mox.IsA(service.Service)).AndReturn(
rpc.AdapterConsumer)
rpc.AdapterConsumer(connection=mox.IgnoreArg(),
topic='%s.%s' % (topic, host),
proxy=mox.IsA(service.Service)).AndReturn(
rpc.AdapterConsumer)
rpc.AdapterConsumer.attach_to_twisted()
rpc.AdapterConsumer.attach_to_twisted()
# Stub out looping call a bit needlessly since we don't have an easy
# way to cancel it (yet) when the tests finishes
service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
service.task.LoopingCall)
service.task.LoopingCall.start(interval=mox.IgnoreArg(),
now=mox.IgnoreArg())
service.task.LoopingCall(mox.IgnoreArg()).AndReturn(
service.task.LoopingCall)
service.task.LoopingCall.start(interval=mox.IgnoreArg(),
now=mox.IgnoreArg())
service_create = {'host': host,
'binary': binary,
'topic': topic,
'report_count': 0}
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
service.db.service_get_by_args(None,
host,
binary).AndRaise(exception.NotFound())
service.db.service_create(None,
service_create).AndReturn(service_ref)
self.mox.ReplayAll()
app = service.Service.create(host=host, binary=binary)
startApplication(app, False)
self.assert_(app)
# We're testing sort of weird behavior in how report_state decides
# whether it is disconnected, it looks for a variable on itself called
# 'model_disconnected' and report_state doesn't really do much so this
# these are mostly just for coverage
def test_report_state(self):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()
rv = yield s.report_state(host, binary)
def test_report_state_no_service(self):
host = 'foo'
binary = 'bar'
service_create = {'host': host,
'binary': binary,
'report_count': 0}
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndRaise(exception.NotFound())
service.db.service_create(None,
service_create).AndReturn(service_ref)
service.db.service_get(None, service_ref['id']).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()
rv = yield s.report_state(host, binary)
def test_report_state_newly_disconnected(self):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndRaise(Exception())
self.mox.ReplayAll()
s = service.Service()
rv = yield s.report_state(host, binary)
self.assert_(s.model_disconnected)
def test_report_state_newly_connected(self):
host = 'foo'
binary = 'bar'
service_ref = {'host': host,
'binary': binary,
'report_count': 0,
'id': 1}
service.db.__getattr__('report_state')
service.db.service_get_by_args(None,
host,
binary).AndReturn(service_ref)
service.db.service_update(None, service_ref['id'],
mox.ContainsKeyValue('report_count', 1))
self.mox.ReplayAll()
s = service.Service()
s.model_disconnected = True
rv = yield s.report_state(host, binary)
self.assert_(not s.model_disconnected)

View File

@@ -1,115 +0,0 @@
# vim: tabstop=4 shiftwidth=4 softtabstop=4
# Copyright 2010 United States Government as represented by the
# Administrator of the National Aeronautics and Space Administration.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import logging
from nova import exception
from nova import flags
from nova import test
from nova.compute import node
from nova.volume import storage
FLAGS = flags.FLAGS
class StorageTestCase(test.TrialTestCase):
def setUp(self):
logging.getLogger().setLevel(logging.DEBUG)
super(StorageTestCase, self).setUp()
self.mynode = node.Node()
self.mystorage = None
self.flags(connection_type='fake',
fake_storage=True)
self.mystorage = storage.BlockStore()
def test_run_create_volume(self):
vol_size = '0'
user_id = 'fake'
project_id = 'fake'
volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)
# TODO(termie): get_volume returns differently than create_volume
self.assertEqual(volume_id,
storage.get_volume(volume_id)['volume_id'])
rv = self.mystorage.delete_volume(volume_id)
self.assertRaises(exception.Error,
storage.get_volume,
volume_id)
def test_too_big_volume(self):
vol_size = '1001'
user_id = 'fake'
project_id = 'fake'
self.assertRaises(TypeError,
self.mystorage.create_volume,
vol_size, user_id, project_id)
def test_too_many_volumes(self):
vol_size = '1'
user_id = 'fake'
project_id = 'fake'
num_shelves = FLAGS.last_shelf_id - FLAGS.first_shelf_id + 1
total_slots = FLAGS.slots_per_shelf * num_shelves
vols = []
for i in xrange(total_slots):
vid = self.mystorage.create_volume(vol_size, user_id, project_id)
vols.append(vid)
self.assertRaises(storage.NoMoreVolumes,
self.mystorage.create_volume,
vol_size, user_id, project_id)
for id in vols:
self.mystorage.delete_volume(id)
def test_run_attach_detach_volume(self):
# Create one volume and one node to test with
instance_id = "storage-test"
vol_size = "5"
user_id = "fake"
project_id = 'fake'
mountpoint = "/dev/sdf"
volume_id = self.mystorage.create_volume(vol_size, user_id, project_id)
volume_obj = storage.get_volume(volume_id)
volume_obj.start_attach(instance_id, mountpoint)
rv = yield self.mynode.attach_volume(volume_id,
instance_id,
mountpoint)
self.assertEqual(volume_obj['status'], "in-use")
self.assertEqual(volume_obj['attachStatus'], "attached")
self.assertEqual(volume_obj['instance_id'], instance_id)
self.assertEqual(volume_obj['mountpoint'], mountpoint)
self.assertRaises(exception.Error,
self.mystorage.delete_volume,
volume_id)
rv = yield self.mystorage.detach_volume(volume_id)
volume_obj = storage.get_volume(volume_id)
self.assertEqual(volume_obj['status'], "available")
rv = self.mystorage.delete_volume(volume_id)
self.assertRaises(exception.Error,
storage.get_volume,
volume_id)
def test_multi_node(self):
# TODO(termie): Figure out how to test with two nodes,
# each of them having a different FLAG for storage_node
# This will allow us to test cross-node interactions
pass

View File

@@ -14,36 +14,77 @@
# License for the specific language governing permissions and limitations
# under the License.
from xml.etree.ElementTree import fromstring as xml_to_tree
from xml.dom.minidom import parseString as xml_to_dom
from nova import db
from nova import flags
from nova import test
from nova import utils
from nova.api import context
from nova.api.ec2 import cloud
from nova.auth import manager
from nova.virt import libvirt_conn
FLAGS = flags.FLAGS
flags.DECLARE('instances_path', 'nova.compute.manager')
class LibvirtConnTestCase(test.TrialTestCase):
def setUp(self):
super(LibvirtConnTestCase, self).setUp()
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.network = utils.import_object(FLAGS.network_manager)
FLAGS.instances_path = ''
def test_get_uri_and_template(self):
class MockDataModel(object):
def __init__(self):
self.datamodel = { 'name' : 'i-cafebabe',
'memory_kb' : '1024000',
'basepath' : '/some/path',
'bridge_name' : 'br100',
'mac_address' : '02:12:34:46:56:67',
'vcpus' : 2 }
ip = '10.11.12.13'
instance = { 'internal_id' : 1,
'memory_kb' : '1024000',
'basepath' : '/some/path',
'bridge_name' : 'br100',
'mac_address' : '02:12:34:46:56:67',
'vcpus' : 2,
'project_id' : 'fake',
'bridge' : 'br101',
'instance_type' : 'm1.small'}
instance_ref = db.instance_create(None, instance)
user_context = context.APIRequestContext(project=self.project,
user=self.user)
network_ref = self.network.get_network(user_context)
self.network.set_network_host(context.get_admin_context(),
network_ref['id'])
fixed_ip = { 'address' : ip,
'network_id' : network_ref['id'] }
fixed_ip_ref = db.fixed_ip_create(None, fixed_ip)
db.fixed_ip_update(None, ip, { 'allocated' : True,
'instance_id' : instance_ref['id'] })
type_uri_map = { 'qemu' : ('qemu:///system',
[lambda s: '<domain type=\'qemu\'>' in s,
lambda s: 'type>hvm</type' in s,
lambda s: 'emulator>/usr/bin/kvm' not in s]),
[(lambda t: t.find('.').get('type'), 'qemu'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'kvm' : ('qemu:///system',
[lambda s: '<domain type=\'kvm\'>' in s,
lambda s: 'type>hvm</type' in s,
lambda s: 'emulator>/usr/bin/qemu<' not in s]),
[(lambda t: t.find('.').get('type'), 'kvm'),
(lambda t: t.find('./os/type').text, 'hvm'),
(lambda t: t.find('./devices/emulator'), None)]),
'uml' : ('uml:///system',
[lambda s: '<domain type=\'uml\'>' in s,
lambda s: 'type>uml</type' in s]),
}
[(lambda t: t.find('.').get('type'), 'uml'),
(lambda t: t.find('./os/type').text, 'uml')]),
}
common_checks = [(lambda t: t.find('.').tag, 'domain'),
(lambda t: \
t.find('./devices/interface/filterref/parameter') \
.get('name'), 'IP'),
(lambda t: \
t.find('./devices/interface/filterref/parameter') \
.get('value'), '10.11.12.13')]
for (libvirt_type,(expected_uri, checks)) in type_uri_map.iteritems():
FLAGS.libvirt_type = libvirt_type
@@ -52,9 +93,17 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, expected_uri)
for i, check in enumerate(checks):
xml = conn.toXml(MockDataModel())
self.assertTrue(check(xml), '%s failed check %d' % (xml, i))
xml = conn.to_xml(instance_ref)
tree = xml_to_tree(xml)
for i, (check, expected_result) in enumerate(checks):
self.assertEqual(check(tree),
expected_result,
'%s failed check %d' % (xml, i))
for i, (check, expected_result) in enumerate(common_checks):
self.assertEqual(check(tree),
expected_result,
'%s failed common check %d' % (xml, i))
# Deliberately not just assigning this string to FLAGS.libvirt_uri and
# checking against that later on. This way we make sure the
@@ -67,3 +116,143 @@ class LibvirtConnTestCase(test.TrialTestCase):
uri, template = conn.get_uri_and_template()
self.assertEquals(uri, testuri)
def tearDown(self):
super(LibvirtConnTestCase, self).tearDown()
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
class NWFilterTestCase(test.TrialTestCase):
def setUp(self):
super(NWFilterTestCase, self).setUp()
class Mock(object):
pass
self.manager = manager.AuthManager()
self.user = self.manager.create_user('fake', 'fake', 'fake', admin=True)
self.project = self.manager.create_project('fake', 'fake', 'fake')
self.context = context.APIRequestContext(self.user, self.project)
self.fake_libvirt_connection = Mock()
self.fw = libvirt_conn.NWFilterFirewall(self.fake_libvirt_connection)
def tearDown(self):
self.manager.delete_project(self.project)
self.manager.delete_user(self.user)
def test_cidr_rule_nwfilter_xml(self):
cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context,
'testgroup',
'test group description')
cloud_controller.authorize_security_group_ingress(self.context,
'testgroup',
from_port='80',
to_port='81',
ip_protocol='tcp',
cidr_ip='0.0.0.0/0')
security_group = db.security_group_get_by_name(self.context,
'fake',
'testgroup')
xml = self.fw.security_group_to_nwfilter_xml(security_group.id)
dom = xml_to_dom(xml)
self.assertEqual(dom.firstChild.tagName, 'filter')
rules = dom.getElementsByTagName('rule')
self.assertEqual(len(rules), 1)
# It's supposed to allow inbound traffic.
self.assertEqual(rules[0].getAttribute('action'), 'accept')
self.assertEqual(rules[0].getAttribute('direction'), 'in')
# Must be lower priority than the base filter (which blocks everything)
self.assertTrue(int(rules[0].getAttribute('priority')) < 1000)
ip_conditions = rules[0].getElementsByTagName('tcp')
self.assertEqual(len(ip_conditions), 1)
self.assertEqual(ip_conditions[0].getAttribute('srcipaddr'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('srcipmask'), '0.0.0.0')
self.assertEqual(ip_conditions[0].getAttribute('dstportstart'), '80')
self.assertEqual(ip_conditions[0].getAttribute('dstportend'), '81')
self.teardown_security_group()
def teardown_security_group(self):
cloud_controller = cloud.CloudController()
cloud_controller.delete_security_group(self.context, 'testgroup')
def setup_and_return_security_group(self):
cloud_controller = cloud.CloudController()
cloud_controller.create_security_group(self.context,
'testgroup',
'test group description')
cloud_controller.authorize_security_group_ingress(self.context,
'testgroup',
from_port='80',
to_port='81',
ip_protocol='tcp',
cidr_ip='0.0.0.0/0')
return db.security_group_get_by_name(self.context, 'fake', 'testgroup')
def test_creates_base_rule_first(self):
# These come pre-defined by libvirt
self.defined_filters = ['no-mac-spoofing',
'no-ip-spoofing',
'no-arp-spoofing',
'allow-dhcp-server']
self.recursive_depends = {}
for f in self.defined_filters:
self.recursive_depends[f] = []
def _filterDefineXMLMock(xml):
dom = xml_to_dom(xml)
name = dom.firstChild.getAttribute('name')
self.recursive_depends[name] = []
for f in dom.getElementsByTagName('filterref'):
ref = f.getAttribute('filter')
self.assertTrue(ref in self.defined_filters,
('%s referenced filter that does ' +
'not yet exist: %s') % (name, ref))
dependencies = [ref] + self.recursive_depends[ref]
self.recursive_depends[name] += dependencies
self.defined_filters.append(name)
return True
self.fake_libvirt_connection.nwfilterDefineXML = _filterDefineXMLMock
instance_ref = db.instance_create(self.context,
{'user_id': 'fake',
'project_id': 'fake'})
inst_id = instance_ref['id']
def _ensure_all_called(_):
instance_filter = 'nova-instance-%s' % instance_ref['name']
secgroup_filter = 'nova-secgroup-%s' % self.security_group['id']
for required in [secgroup_filter, 'allow-dhcp-server',
'no-arp-spoofing', 'no-ip-spoofing',
'no-mac-spoofing']:
self.assertTrue(required in self.recursive_depends[instance_filter],
"Instance's filter does not include %s" % required)
self.security_group = self.setup_and_return_security_group()
db.instance_add_security_group(self.context, inst_id, self.security_group.id)
instance = db.instance_get(self.context, inst_id)
d = self.fw.setup_nwfilters_for_instance(instance)
d.addCallback(_ensure_all_called)
d.addCallback(lambda _:self.teardown_security_group())
return d

View File

@@ -15,139 +15,159 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""
Tests for Volume Code
"""
import logging
import shutil
import tempfile
from twisted.internet import defer
from nova import compute
from nova import exception
from nova import db
from nova import flags
from nova import test
from nova.volume import service as volume_service
from nova import utils
FLAGS = flags.FLAGS
class VolumeTestCase(test.TrialTestCase):
def setUp(self):
"""Test Case for volumes"""
def setUp(self): # pylint: disable-msg=C0103
logging.getLogger().setLevel(logging.DEBUG)
super(VolumeTestCase, self).setUp()
self.compute = compute.service.ComputeService()
self.volume = None
self.tempdir = tempfile.mkdtemp()
self.flags(connection_type='fake',
fake_storage=True,
aoe_export_dir=self.tempdir)
self.volume = volume_service.VolumeService()
self.compute = utils.import_object(FLAGS.compute_manager)
self.flags(connection_type='fake')
self.volume = utils.import_object(FLAGS.volume_manager)
self.context = None
def tearDown(self):
shutil.rmtree(self.tempdir)
@staticmethod
def _create_volume(size='0'):
"""Create a volume object"""
vol = {}
vol['size'] = size
vol['user_id'] = 'fake'
vol['project_id'] = 'fake'
vol['availability_zone'] = FLAGS.storage_availability_zone
vol['status'] = "creating"
vol['attach_status'] = "detached"
return db.volume_create(None, vol)['id']
@defer.inlineCallbacks
def test_run_create_volume(self):
vol_size = '0'
user_id = 'fake'
project_id = 'fake'
volume_id = yield self.volume.create_volume(vol_size, user_id, project_id)
# TODO(termie): get_volume returns differently than create_volume
self.assertEqual(volume_id,
volume_service.get_volume(volume_id)['volume_id'])
def test_create_delete_volume(self):
"""Test volume can be created and deleted"""
volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id)
self.assertEqual(volume_id, db.volume_get(None, volume_id).id)
rv = self.volume.delete_volume(volume_id)
self.assertRaises(exception.Error, volume_service.get_volume, volume_id)
yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.NotFound,
db.volume_get,
None,
volume_id)
@defer.inlineCallbacks
def test_too_big_volume(self):
vol_size = '1001'
user_id = 'fake'
project_id = 'fake'
"""Ensure failure if a too large of a volume is requested"""
# FIXME(vish): validation needs to move into the data layer in
# volume_create
defer.returnValue(True)
try:
yield self.volume.create_volume(vol_size, user_id, project_id)
volume_id = self._create_volume('1001')
yield self.volume.create_volume(self.context, volume_id)
self.fail("Should have thrown TypeError")
except TypeError:
pass
@defer.inlineCallbacks
def test_too_many_volumes(self):
vol_size = '1'
user_id = 'fake'
project_id = 'fake'
num_shelves = FLAGS.last_shelf_id - FLAGS.first_shelf_id + 1
total_slots = FLAGS.blades_per_shelf * num_shelves
"""Ensure that NoMoreBlades is raised when we run out of volumes"""
vols = []
from nova import datastore
redis = datastore.Redis.instance()
for i in xrange(total_slots):
vid = yield self.volume.create_volume(vol_size, user_id, project_id)
vols.append(vid)
self.assertFailure(self.volume.create_volume(vol_size,
user_id,
project_id),
volume_service.NoMoreBlades)
for id in vols:
yield self.volume.delete_volume(id)
total_slots = FLAGS.num_shelves * FLAGS.blades_per_shelf
for _index in xrange(total_slots):
volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id)
vols.append(volume_id)
volume_id = self._create_volume()
self.assertFailure(self.volume.create_volume(self.context,
volume_id),
db.NoMoreBlades)
db.volume_destroy(None, volume_id)
for volume_id in vols:
yield self.volume.delete_volume(self.context, volume_id)
@defer.inlineCallbacks
def test_run_attach_detach_volume(self):
# Create one volume and one compute to test with
instance_id = "storage-test"
vol_size = "5"
user_id = "fake"
project_id = 'fake'
"""Make sure volume can be attached and detached from instance"""
inst = {}
inst['image_id'] = 'ami-test'
inst['reservation_id'] = 'r-fakeres'
inst['launch_time'] = '10'
inst['user_id'] = 'fake'
inst['project_id'] = 'fake'
inst['instance_type'] = 'm1.tiny'
inst['mac_address'] = utils.generate_mac()
inst['ami_launch_index'] = 0
instance_id = db.instance_create(self.context, inst)['id']
mountpoint = "/dev/sdf"
volume_id = yield self.volume.create_volume(vol_size, user_id, project_id)
volume_obj = volume_service.get_volume(volume_id)
volume_obj.start_attach(instance_id, mountpoint)
volume_id = self._create_volume()
yield self.volume.create_volume(self.context, volume_id)
if FLAGS.fake_tests:
volume_obj.finish_attach()
db.volume_attached(None, volume_id, instance_id, mountpoint)
else:
rv = yield self.compute.attach_volume(instance_id,
volume_id,
mountpoint)
self.assertEqual(volume_obj['status'], "in-use")
self.assertEqual(volume_obj['attach_status'], "attached")
self.assertEqual(volume_obj['instance_id'], instance_id)
self.assertEqual(volume_obj['mountpoint'], mountpoint)
yield self.compute.attach_volume(instance_id,
volume_id,
mountpoint)
vol = db.volume_get(None, volume_id)
self.assertEqual(vol['status'], "in-use")
self.assertEqual(vol['attach_status'], "attached")
self.assertEqual(vol['mountpoint'], mountpoint)
instance_ref = db.volume_get_instance(self.context, volume_id)
self.assertEqual(instance_ref['id'], instance_id)
self.assertFailure(self.volume.delete_volume(volume_id), exception.Error)
volume_obj.start_detach()
self.assertFailure(self.volume.delete_volume(self.context, volume_id),
exception.Error)
if FLAGS.fake_tests:
volume_obj.finish_detach()
db.volume_detached(None, volume_id)
else:
rv = yield self.volume.detach_volume(instance_id,
volume_id)
volume_obj = volume_service.get_volume(volume_id)
self.assertEqual(volume_obj['status'], "available")
yield self.compute.detach_volume(instance_id,
volume_id)
vol = db.volume_get(None, volume_id)
self.assertEqual(vol['status'], "available")
rv = self.volume.delete_volume(volume_id)
yield self.volume.delete_volume(self.context, volume_id)
self.assertRaises(exception.Error,
volume_service.get_volume,
db.volume_get,
None,
volume_id)
db.instance_destroy(self.context, instance_id)
@defer.inlineCallbacks
def test_multiple_volume_race_condition(self):
vol_size = "5"
user_id = "fake"
project_id = 'fake'
def test_concurrent_volumes_get_different_blades(self):
"""Ensure multiple concurrent volumes get different blades"""
volume_ids = []
shelf_blades = []
def _check(volume_id):
vol = volume_service.get_volume(volume_id)
shelf_blade = '%s.%s' % (vol['shelf_id'], vol['blade_id'])
"""Make sure blades aren't duplicated"""
volume_ids.append(volume_id)
(shelf_id, blade_id) = db.volume_get_shelf_and_blade(None,
volume_id)
shelf_blade = '%s.%s' % (shelf_id, blade_id)
self.assert_(shelf_blade not in shelf_blades)
shelf_blades.append(shelf_blade)
logging.debug("got %s" % shelf_blade)
vol.destroy()
logging.debug("Blade %s allocated", shelf_blade)
deferreds = []
for i in range(5):
d = self.volume.create_volume(vol_size, user_id, project_id)
total_slots = FLAGS.num_shelves * FLAGS.blades_per_shelf
for _index in xrange(total_slots):
volume_id = self._create_volume()
d = self.volume.create_volume(self.context, volume_id)
d.addCallback(_check)
d.addErrback(self.fail)
deferreds.append(d)
yield defer.DeferredList(deferreds)
for volume_id in volume_ids:
self.volume.delete_volume(self.context, volume_id)
def test_multi_node(self):
# TODO(termie): Figure out how to test with two nodes,

View File

@@ -21,6 +21,7 @@ Twisted daemon helpers, specifically to parse out gFlags from twisted flags,
manage pid files and support syslogging.
"""
import gflags
import logging
import os
import signal
@@ -49,6 +50,14 @@ class TwistdServerOptions(ServerOptions):
return
class FlagParser(object):
def __init__(self, parser):
self.parser = parser
def Parse(self, s):
return self.parser(s)
def WrapTwistedOptions(wrapped):
class TwistedOptionsToFlags(wrapped):
subCommands = None
@@ -79,7 +88,12 @@ def WrapTwistedOptions(wrapped):
reflect.accumulateClassList(self.__class__, 'optParameters', twistd_params)
for param in twistd_params:
key = param[0].replace('-', '_')
flags.DEFINE_string(key, param[2], str(param[-1]))
if len(param) > 4:
flags.DEFINE(FlagParser(param[4]),
key, param[2], str(param[3]),
serializer=gflags.ArgumentSerializer())
else:
flags.DEFINE_string(key, param[2], str(param[3]))
def _absorbHandlers(self):
twistd_handlers = {}

View File

@@ -55,13 +55,17 @@ from nova.tests.api_unittest import *
from nova.tests.cloud_unittest import *
from nova.tests.compute_unittest import *
from nova.tests.flags_unittest import *
from nova.tests.model_unittest import *
from nova.tests.network_unittest import *
from nova.tests.objectstore_unittest import *
from nova.tests.process_unittest import *
from nova.tests.quota_unittest import *
from nova.tests.rpc_unittest import *
from nova.tests.scheduler_unittest import *
from nova.tests.service_unittest import *
from nova.tests.validator_unittest import *
from nova.tests.virt_unittest import *
from nova.tests.volume_unittest import *
from nova.tests.virt_unittest import *
FLAGS = flags.FLAGS