diff --git a/nova/tests/fake_libvirt_utils.py b/nova/tests/fake_libvirt_utils.py index 020ff819..0092e116 100644 --- a/nova/tests/fake_libvirt_utils.py +++ b/nova/tests/fake_libvirt_utils.py @@ -86,7 +86,10 @@ def extract_snapshot(disk_path, source_fmt, snapshot_name, out_path, dest_fmt): class File(object): def __init__(self, path, mode=None): - self.fp = StringIO.StringIO(files[path]) + if path in files: + self.fp = StringIO.StringIO(files[path]) + else: + self.fp = StringIO.StringIO(files[os.path.split(path)[-1]]) def __enter__(self): return self.fp diff --git a/nova/tests/test_libvirt.py b/nova/tests/test_libvirt.py index 070024fa..768406e9 100644 --- a/nova/tests/test_libvirt.py +++ b/nova/tests/test_libvirt.py @@ -2304,6 +2304,7 @@ class LibvirtConnTestCase(test.TestCase): CONF.base_dir_name)) def test_get_console_output_file(self): + fake_libvirt_utils.files['console.log'] = '01234567890' with utils.tempdir() as tmpdir: self.flags(instances_path=tmpdir) @@ -2313,11 +2314,7 @@ class LibvirtConnTestCase(test.TestCase): instance = db.instance_create(self.context, instance_ref) console_dir = (os.path.join(tmpdir, instance['name'])) - os.mkdir(console_dir) console_log = '%s/console.log' % (console_dir) - f = open(console_log, "w") - f.write("foo") - f.close() fake_dom_xml = """ @@ -2340,10 +2337,18 @@ class LibvirtConnTestCase(test.TestCase): libvirt_driver.libvirt_utils = fake_libvirt_utils conn = libvirt_driver.LibvirtDriver(fake.FakeVirtAPI(), False) - output = conn.get_console_output(instance) - self.assertEquals("foo", output) + + try: + prev_max = libvirt_driver.MAX_CONSOLE_BYTES + libvirt_driver.MAX_CONSOLE_BYTES = 5 + output = conn.get_console_output(instance) + finally: + libvirt_driver.MAX_CONSOLE_BYTES = prev_max + + self.assertEquals('67890', output) def test_get_console_output_pty(self): + fake_libvirt_utils.files['pty'] = '01234567890' with utils.tempdir() as tmpdir: self.flags(instances_path=tmpdir) @@ -2353,11 +2358,7 @@ class LibvirtConnTestCase(test.TestCase): instance = db.instance_create(self.context, instance_ref) console_dir = (os.path.join(tmpdir, instance['name'])) - os.mkdir(console_dir) pty_file = '%s/fake_pty' % (console_dir) - f = open(pty_file, "w") - f.write("foo") - f.close() fake_dom_xml = """ @@ -2376,17 +2377,27 @@ class LibvirtConnTestCase(test.TestCase): return FakeVirtDomain(fake_dom_xml) def _fake_flush(self, fake_pty): - with open(fake_pty, 'r') as fp: - return fp.read() + return 'foo' + + def _fake_append_to_file(self, data, fpath): + return 'pty' self.create_fake_libvirt_mock() libvirt_driver.LibvirtDriver._conn.lookupByName = fake_lookup libvirt_driver.LibvirtDriver._flush_libvirt_console = _fake_flush + libvirt_driver.LibvirtDriver._append_to_file = _fake_append_to_file libvirt_driver.libvirt_utils = fake_libvirt_utils conn = libvirt_driver.LibvirtDriver(fake.FakeVirtAPI(), False) - output = conn.get_console_output(instance) - self.assertEquals("foo", output) + + try: + prev_max = libvirt_driver.MAX_CONSOLE_BYTES + libvirt_driver.MAX_CONSOLE_BYTES = 5 + output = conn.get_console_output(instance) + finally: + libvirt_driver.MAX_CONSOLE_BYTES = prev_max + + self.assertEquals('67890', output) def test_get_host_ip_addr(self): conn = libvirt_driver.LibvirtDriver(fake.FakeVirtAPI(), False) diff --git a/nova/tests/test_virt_drivers.py b/nova/tests/test_virt_drivers.py index 9d48cdf0..e09f9672 100644 --- a/nova/tests/test_virt_drivers.py +++ b/nova/tests/test_virt_drivers.py @@ -14,8 +14,11 @@ # License for the specific language governing permissions and limitations # under the License. +import __builtin__ import base64 +import mox import netaddr +import StringIO import sys import traceback @@ -25,6 +28,7 @@ from nova import exception from nova.openstack.common import importutils from nova.openstack.common import log as logging from nova import test +from nova.tests import fake_libvirt_utils from nova.tests.image import fake as fake_image from nova.tests import utils as test_utils from nova.virt import fake @@ -428,6 +432,7 @@ class _VirtDriverTestCase(_FakeDriverBackendTestCase): @catch_notimplementederror def test_get_console_output(self): + fake_libvirt_utils.files['dummy.log'] = '' instance_ref, network_info = self._get_running_instance() console_output = self.connection.get_console_output(instance_ref) self.assertTrue(isinstance(console_output, basestring))