diff --git a/nova/tests/test_virt_drivers.py b/nova/tests/test_virt_drivers.py index 774450e0..f5ecd344 100644 --- a/nova/tests/test_virt_drivers.py +++ b/nova/tests/test_virt_drivers.py @@ -19,17 +19,14 @@ import netaddr import sys import traceback +from nova.compute.manager import ComputeManager from nova import exception -from nova import flags from nova import image from nova import log as logging from nova.openstack.common import importutils from nova import test from nova.tests import utils as test_utils -libvirt = None -FLAGS = flags.FLAGS - LOG = logging.getLogger(__name__) @@ -53,7 +50,127 @@ def catch_notimplementederror(f): return wrapped_func -class _VirtDriverTestCase(test.TestCase): +class _FakeDriverBackendTestCase(test.TestCase): + def _setup_fakelibvirt(self): + # So that the _supports_direct_io does the test based + # on the current working directory, instead of the + # default instances_path which doesn't exist + self.flags(instances_path='') + + # Put fakelibvirt in place + if 'libvirt' in sys.modules: + self.saved_libvirt = sys.modules['libvirt'] + else: + self.saved_libvirt = None + + import fake_libvirt_utils + import fakelibvirt + + sys.modules['libvirt'] = fakelibvirt + import nova.virt.libvirt.connection + import nova.virt.libvirt.firewall + + nova.virt.libvirt.connection.libvirt = fakelibvirt + nova.virt.libvirt.connection.libvirt_utils = fake_libvirt_utils + nova.virt.libvirt.firewall.libvirt = fakelibvirt + + self.flags(firewall_driver=nova.virt.libvirt.firewall.drivers[0], + rescue_image_id="2", + rescue_kernel_id="3", + rescue_ramdisk_id=None) + + def fake_extend(image, size): + pass + + self.stubs.Set(nova.virt.libvirt.connection.disk, + 'extend', fake_extend) + + def _teardown_fakelibvirt(self): + # Restore libvirt + import nova.virt.libvirt.connection + import nova.virt.libvirt.firewall + if self.saved_libvirt: + sys.modules['libvirt'] = self.saved_libvirt + nova.virt.libvirt.connection.libvirt = self.saved_libvirt + nova.virt.libvirt.connection.libvirt_utils = self.saved_libvirt + nova.virt.libvirt.firewall.libvirt = self.saved_libvirt + + def setUp(self): + super(_FakeDriverBackendTestCase, self).setUp() + # TODO(sdague): it would be nice to do this in a way that only + # the relevant backends where replaced for tests, though this + # should not harm anything by doing it for all backends + self._setup_fakelibvirt() + + def tearDown(self): + self._teardown_fakelibvirt() + super(_FakeDriverBackendTestCase, self).tearDown() + + +class VirtDriverLoaderTestCase(_FakeDriverBackendTestCase): + """Test that ComputeManager can successfully load both + old style and new style drivers and end up with the correct + final class""" + + # if your driver supports being tested in a fake way, it can go here + new_drivers = { + 'nova.virt.fake.FakeDriver': 'FakeDriver', + 'nova.virt.libvirt.connection.LibvirtDriver': 'LibvirtDriver' + } + + # NOTE(sdague): remove after Folsom release when connection_type + # is removed + old_drivers = { + 'libvirt': 'LibvirtDriver', + 'fake': 'FakeDriver' + } + + def test_load_new_drivers(self): + for cls, driver in self.new_drivers.iteritems(): + self.flags(compute_driver=cls) + # NOTE(sdague) the try block is to make it easier to debug a + # failure by knowing which driver broke + try: + cm = ComputeManager() + except Exception as e: + self.fail("Couldn't load driver %s - %s" % (cls, e)) + + self.assertEqual(cm.driver.__class__.__name__, driver, + "Could't load driver %s" % cls) + + # NOTE(sdague): remove after Folsom release when connection_type + # is removed + def test_load_old_drivers(self): + # we explicitly use the old default + self.flags(compute_driver='nova.virt.connection.get_connection') + for cls, driver in self.old_drivers.iteritems(): + self.flags(connection_type=cls) + # NOTE(sdague) the try block is to make it easier to debug a + # failure by knowing which driver broke + try: + cm = ComputeManager() + except Exception as e: + self.fail("Couldn't load connection %s - %s" % (cls, e)) + + self.assertEqual(cm.driver.__class__.__name__, driver, + "Could't load connection %s" % cls) + + def test_fail_to_load_old_drivers(self): + self.flags(compute_driver='nova.virt.connection.get_connection') + self.flags(connection_type='56kmodem') + self.assertRaises(exception.VirtDriverNotFound, ComputeManager) + + def test_fail_to_load_new_drivers(self): + self.flags(compute_driver='nova.virt.amiga') + + def _fake_exit(error): + raise test.TestingException() + + self.stubs.Set(sys, 'exit', _fake_exit) + self.assertRaises(test.TestingException, ComputeManager) + + +class _VirtDriverTestCase(_FakeDriverBackendTestCase): def setUp(self): super(_VirtDriverTestCase, self).setUp() self.connection = importutils.import_object(self.driver_module, '') @@ -440,52 +557,11 @@ class FakeConnectionTestCase(_VirtDriverTestCase): class LibvirtConnTestCase(_VirtDriverTestCase): def setUp(self): - # Put fakelibvirt in place - if 'libvirt' in sys.modules: - self.saved_libvirt = sys.modules['libvirt'] - else: - self.saved_libvirt = None - - import fake_libvirt_utils - import fakelibvirt - - sys.modules['libvirt'] = fakelibvirt - - import nova.virt.libvirt.connection - import nova.virt.libvirt.firewall - - nova.virt.libvirt.connection.libvirt = fakelibvirt - nova.virt.libvirt.connection.libvirt_utils = fake_libvirt_utils - nova.virt.libvirt.firewall.libvirt = fakelibvirt - - # So that the _supports_direct_io does the test based - # on the current working directory, instead of the - # default instances_path which doesn't exist - FLAGS.instances_path = '' - # Point _VirtDriverTestCase at the right module self.driver_module = 'nova.virt.libvirt.connection.LibvirtDriver' super(LibvirtConnTestCase, self).setUp() - self.flags(firewall_driver=nova.virt.libvirt.firewall.drivers[0], - rescue_image_id="2", - rescue_kernel_id="3", - rescue_ramdisk_id=None) - - def fake_extend(image, size): - pass - - self.stubs.Set(nova.virt.libvirt.connection.disk, - 'extend', fake_extend) def tearDown(self): - # Restore libvirt - import nova.virt.libvirt.connection - import nova.virt.libvirt.firewall - if self.saved_libvirt: - sys.modules['libvirt'] = self.saved_libvirt - nova.virt.libvirt.connection.libvirt = self.saved_libvirt - nova.virt.libvirt.connection.libvirt_utils = self.saved_libvirt - nova.virt.libvirt.firewall.libvirt = self.saved_libvirt super(LibvirtConnTestCase, self).tearDown() def test_force_hard_reboot(self):