diff --git a/taskflow/test.py b/taskflow/test.py new file mode 100644 index 00000000..63da745d --- /dev/null +++ b/taskflow/test.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- + +# vim: tabstop=4 shiftwidth=4 softtabstop=4 + +# Copyright (C) 2012 Yahoo! Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest2 + +from oslo.config import cfg + +CONF = cfg.CONF + + +class TestCase(unittest2.TestCase): + """Test case base class for all unit tests.""" + + def setUp(self): + """Run before each test method to initialize test environment.""" + super(TestCase, self).setUp() + self.overriden = [] + self.addCleanup(self._clear_attrs) + + def tearDown(self): + super(TestCase, self).tearDown() + self._reset_flags() + + def _reset_flags(self): + for k, group in self.overriden: + CONF.clear_override(k, group=group) + + def _clear_attrs(self): + # Delete attributes that don't start with _ so they don't pin + # memory around unnecessarily for the duration of the test + # suite + for key in [k for k in self.__dict__.keys() if k[0] != '_']: + del self.__dict__[key] + + def flags(self, **kw): + """Override flag variables for a test.""" + group = kw.pop('group', None) + for k, v in kw.iteritems(): + CONF.set_override(k, v, group) + self.overriden.append((k, group)) diff --git a/taskflow/tests/unit/persistence/test_memory_persistence.py b/taskflow/tests/unit/persistence/test_memory_persistence.py index 37aaa366..0d940b35 100644 --- a/taskflow/tests/unit/persistence/test_memory_persistence.py +++ b/taskflow/tests/unit/persistence/test_memory_persistence.py @@ -17,12 +17,11 @@ # under the License. from taskflow.persistence.backends import api as b_api +from taskflow import test from taskflow.tests.unit.persistence import base -import unittest2 - -class MemoryPersistenceTest(unittest2.TestCase, base.PersistenceTestMixin): +class MemoryPersistenceTest(test.TestCase, base.PersistenceTestMixin): def _get_backend(self): return 'memory' diff --git a/taskflow/tests/unit/persistence/test_sql_persistence.py b/taskflow/tests/unit/persistence/test_sql_persistence.py index 1a67f6a6..44a7b0cf 100644 --- a/taskflow/tests/unit/persistence/test_sql_persistence.py +++ b/taskflow/tests/unit/persistence/test_sql_persistence.py @@ -22,12 +22,11 @@ import tempfile from taskflow.openstack.common.db.sqlalchemy import session from taskflow.persistence.backends import api as b_api from taskflow.persistence.backends.sqlalchemy import migration +from taskflow import test from taskflow.tests.unit.persistence import base -import unittest2 - -class SqlPersistenceTest(unittest2.TestCase, base.PersistenceTestMixin): +class SqlPersistenceTest(test.TestCase, base.PersistenceTestMixin): """Inherits from the base test and sets up a sqlite temporary db.""" def _get_backend(self): return 'sqlalchemy' diff --git a/taskflow/tests/unit/test_decorators.py b/taskflow/tests/unit/test_decorators.py index a2f91505..46689020 100644 --- a/taskflow/tests/unit/test_decorators.py +++ b/taskflow/tests/unit/test_decorators.py @@ -16,13 +16,12 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest2 - from taskflow import decorators from taskflow.patterns import linear_flow +from taskflow import test -class WrapableObjectsTest(unittest2.TestCase): +class WrapableObjectsTest(test.TestCase): def test_simple_function(self): values = [] diff --git a/taskflow/tests/unit/test_functor_task.py b/taskflow/tests/unit/test_functor_task.py index e9be3468..f2f79e9b 100644 --- a/taskflow/tests/unit/test_functor_task.py +++ b/taskflow/tests/unit/test_functor_task.py @@ -16,10 +16,9 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest2 - from taskflow import functor_task from taskflow.patterns import linear_flow +from taskflow import test def add(a, b): @@ -42,7 +41,7 @@ class BunchOfFunctions(object): raise RuntimeError('Woot!') -class FunctorTaskTest(unittest2.TestCase): +class FunctorTaskTest(test.TestCase): def test_simple(self): task = functor_task.FunctorTask(add) diff --git a/taskflow/tests/unit/test_graph_flow.py b/taskflow/tests/unit/test_graph_flow.py index 409657ce..639cfe8d 100644 --- a/taskflow/tests/unit/test_graph_flow.py +++ b/taskflow/tests/unit/test_graph_flow.py @@ -17,17 +17,16 @@ # under the License. import collections -import unittest2 from taskflow import decorators from taskflow import exceptions as excp -from taskflow import states - from taskflow.patterns import graph_flow as gw +from taskflow import states +from taskflow import test from taskflow.tests import utils -class GraphFlowTest(unittest2.TestCase): +class GraphFlowTest(test.TestCase): def test_reverting_flow(self): flo = gw.Flow("test-flow") reverted = [] diff --git a/taskflow/tests/unit/test_linear_flow.py b/taskflow/tests/unit/test_linear_flow.py index 672306d1..ffe9339e 100644 --- a/taskflow/tests/unit/test_linear_flow.py +++ b/taskflow/tests/unit/test_linear_flow.py @@ -16,19 +16,16 @@ # License for the specific language governing permissions and limitations # under the License. -import unittest2 - from taskflow import decorators from taskflow import exceptions as exc from taskflow import states +from taskflow import test from taskflow.patterns import linear_flow as lw -from taskflow.patterns.resumption import logbook as lr -from taskflow.persistence.backends import memory from taskflow.tests import utils -class LinearFlowTest(unittest2.TestCase): +class LinearFlowTest(test.TestCase): def make_reverting_task(self, token, blowup=False): def do_revert(context, *args, **kwargs): @@ -210,35 +207,34 @@ class LinearFlowTest(unittest2.TestCase): wf.reset() wf.run({}) - @unittest2.skip('') - def test_interrupt_flow(self): - wf = lw.Flow("the-int-action") - - # If we interrupt we need to know how to resume so attach the needed - # parts to do that... - tracker = lr.Resumption(memory.MemoryLogBook()) - tracker.record_for(wf) - wf.resumer = tracker - - wf.add(self.make_reverting_task(1)) - wf.add(self.make_interrupt_task(wf)) - wf.add(self.make_reverting_task(2)) - - self.assertEquals(states.PENDING, wf.state) - context = {} - wf.run(context) - - # Interrupt should have been triggered after task 1 - self.assertEquals(1, len(context)) - self.assertEquals(states.INTERRUPTED, wf.state) - - # And now reset and resume. - wf.reset() - tracker.record_for(wf) - wf.resumer = tracker - self.assertEquals(states.PENDING, wf.state) - wf.run(context) - self.assertEquals(2, len(context)) +# def test_interrupt_flow(self): +# wf = lw.Flow("the-int-action") +# +# # If we interrupt we need to know how to resume so attach the needed +# # parts to do that... +# tracker = lr.Resumption(memory.MemoryLogBook()) +# tracker.record_for(wf) +# wf.resumer = tracker +# +# wf.add(self.make_reverting_task(1)) +# wf.add(self.make_interrupt_task(wf)) +# wf.add(self.make_reverting_task(2)) +# +# self.assertEquals(states.PENDING, wf.state) +# context = {} +# wf.run(context) +# +# # Interrupt should have been triggered after task 1 +# self.assertEquals(1, len(context)) +# self.assertEquals(states.INTERRUPTED, wf.state) +# +# # And now reset and resume. +# wf.reset() +# tracker.record_for(wf) +# wf.resumer = tracker +# self.assertEquals(states.PENDING, wf.state) +# wf.run(context) +# self.assertEquals(2, len(context)) def test_parent_reverting_flow(self): happy_wf = lw.Flow("the-happy-action") diff --git a/taskflow/tests/unit/test_threaded_flow.py b/taskflow/tests/unit/test_threaded_flow.py index a93329c5..a6a70285 100644 --- a/taskflow/tests/unit/test_threaded_flow.py +++ b/taskflow/tests/unit/test_threaded_flow.py @@ -18,13 +18,13 @@ import threading import time -import unittest2 from taskflow import decorators from taskflow import exceptions as excp from taskflow import states from taskflow.patterns import threaded_flow as tf +from taskflow import test from taskflow.tests import utils @@ -35,7 +35,7 @@ def _find_idx(what, search_where): return -1 -class ThreadedFlowTest(unittest2.TestCase): +class ThreadedFlowTest(test.TestCase): def _make_tracking_flow(self, name): notify_lock = threading.RLock() flo = tf.Flow(name) diff --git a/taskflow/tests/unit/test_utils.py b/taskflow/tests/unit/test_utils.py index ad36d81f..987bc663 100644 --- a/taskflow/tests/unit/test_utils.py +++ b/taskflow/tests/unit/test_utils.py @@ -17,12 +17,12 @@ # under the License. import functools -import unittest +from taskflow import test from taskflow import utils -class UtilTest(unittest.TestCase): +class UtilTest(test.TestCase): def test_rollback_accum(self): context = {}