prevent leakage of FLAGS changes across tests

This commit is contained in:
Andy Smith
2010-10-15 17:37:26 +09:00
parent 2ba6b6b605
commit 13d30a1c68
3 changed files with 43 additions and 27 deletions

View File

@@ -90,6 +90,12 @@ class FlagValues(gflags.FlagValues):
self.ClearDirty() self.ClearDirty()
return args return args
def Reset(self):
gflags.FlagValues.Reset(self)
self.__dict__['__dirty'] = []
self.__dict__['__was_already_parsed'] = False
self.__dict__['__stored_argv'] = []
def SetDirty(self, name): def SetDirty(self, name):
"""Mark a flag as dirty so that accessing it will case a reparse.""" """Mark a flag as dirty so that accessing it will case a reparse."""
self.__dict__['__dirty'].append(name) self.__dict__['__dirty'].append(name)

View File

@@ -22,9 +22,9 @@ Allows overriding of flags for use of fakes,
and some black magic for inline callbacks. and some black magic for inline callbacks.
""" """
import datetime
import sys import sys
import time import time
import datetime
import mox import mox
import stubout import stubout
@@ -80,30 +80,33 @@ class TrialTestCase(unittest.TestCase):
self.flag_overrides = {} self.flag_overrides = {}
self.injected = [] self.injected = []
self._monkeyPatchAttach() self._monkeyPatchAttach()
self._original_flags = FLAGS.FlagValuesDict()
def tearDown(self): # pylint: disable-msg=C0103 def tearDown(self): # pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment""" """Runs after each test method to finalize/tear down test environment"""
self.reset_flags() try:
self.mox.UnsetStubs() self.mox.UnsetStubs()
self.stubs.UnsetAll() self.stubs.UnsetAll()
self.stubs.SmartUnsetAll() self.stubs.SmartUnsetAll()
self.mox.VerifyAll() self.mox.VerifyAll()
# NOTE(vish): Clean up any ips associated during the test. # NOTE(vish): Clean up any ips associated during the test.
ctxt = context.get_admin_context() ctxt = context.get_admin_context()
db.fixed_ip_disassociate_all_by_timeout(ctxt, FLAGS.host, self.start) db.fixed_ip_disassociate_all_by_timeout(ctxt, FLAGS.host, self.start)
db.network_disassociate_all(ctxt) db.network_disassociate_all(ctxt)
rpc.Consumer.attach_to_twisted = self.originalAttach rpc.Consumer.attach_to_twisted = self.originalAttach
for x in self.injected: for x in self.injected:
try: try:
x.stop() x.stop()
except AssertionError: except AssertionError:
pass pass
if FLAGS.fake_rabbit: if FLAGS.fake_rabbit:
fakerabbit.reset_all() fakerabbit.reset_all()
db.security_group_destroy_all(ctxt)
super(TrialTestCase, self).tearDown() db.security_group_destroy_all(ctxt)
super(TrialTestCase, self).tearDown()
finally:
self.reset_flags()
def flags(self, **kw): def flags(self, **kw):
"""Override flag variables for a test""" """Override flag variables for a test"""
@@ -117,7 +120,8 @@ class TrialTestCase(unittest.TestCase):
def reset_flags(self): def reset_flags(self):
"""Resets all flag variables for the test. Runs after each test""" """Resets all flag variables for the test. Runs after each test"""
for k, v in self.flag_overrides.iteritems(): FLAGS.Reset()
for k, v in self._original_flags.iteritems():
setattr(FLAGS, k, v) setattr(FLAGS, k, v)
def run(self, result=None): def run(self, result=None):
@@ -171,12 +175,6 @@ class BaseTestCase(TrialTestCase):
self._done_waiting = False self._done_waiting = False
self._timed_out = False self._timed_out = False
def tearDown(self):# pylint: disable-msg=C0103
"""Runs after each test method to finalize/tear down test environment"""
super(BaseTestCase, self).tearDown()
if FLAGS.fake_rabbit:
fakerabbit.reset_all()
def _wait_for_test(self, timeout=60): def _wait_for_test(self, timeout=60):
""" Push the ioloop along to wait for our test to complete. """ """ Push the ioloop along to wait for our test to complete. """
self._waiting = self.ioloop.add_timeout(time.time() + timeout, self._waiting = self.ioloop.add_timeout(time.time() + timeout,

View File

@@ -20,6 +20,8 @@ from nova import exception
from nova import flags from nova import flags
from nova import test from nova import test
FLAGS = flags.FLAGS
flags.DEFINE_string('flags_unittest', 'foo', 'for testing purposes only')
class FlagsTestCase(test.TrialTestCase): class FlagsTestCase(test.TrialTestCase):
def setUp(self): def setUp(self):
@@ -85,3 +87,13 @@ class FlagsTestCase(test.TrialTestCase):
self.assert_('runtime_answer' in self.global_FLAGS) self.assert_('runtime_answer' in self.global_FLAGS)
self.assertEqual(self.global_FLAGS.runtime_answer, 60) self.assertEqual(self.global_FLAGS.runtime_answer, 60)
def test_flag_leak_left(self):
self.assertEqual(FLAGS.flags_unittest, 'foo')
FLAGS.flags_unittest = 'bar'
self.assertEqual(FLAGS.flags_unittest, 'bar')
def test_flag_leak_right(self):
self.assertEqual(FLAGS.flags_unittest, 'foo')
FLAGS.flags_unittest = 'bar'
self.assertEqual(FLAGS.flags_unittest, 'bar')