Remove line containing comment - # vim: tabstop=4 shiftwidth=4 softtabstop=4 Change-Id: I7581cc88b8de433d5609ed06c6570b0b45c13573 Closes-Bug:#1229324
		
			
				
	
	
		
			285 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			285 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
# -*- coding: utf-8 -*-
 | 
						|
 | 
						|
#    Copyright (C) 2013 Yahoo! Inc. All Rights Reserved.
 | 
						|
#
 | 
						|
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
 | 
						|
#    not use this file except in compliance with the License. You may obtain
 | 
						|
#    a copy of the License at
 | 
						|
#
 | 
						|
#         http://www.apache.org/licenses/LICENSE-2.0
 | 
						|
#
 | 
						|
#    Unless required by applicable law or agreed to in writing, software
 | 
						|
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 | 
						|
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 | 
						|
#    License for the specific language governing permissions and limitations
 | 
						|
#    under the License.
 | 
						|
 | 
						|
import mock
 | 
						|
 | 
						|
from taskflow import task
 | 
						|
from taskflow import test
 | 
						|
from taskflow.utils import reflection
 | 
						|
 | 
						|
 | 
						|
class MyTask(task.Task):
    """Do-nothing task whose ``execute`` takes three required arguments."""

    def execute(self, context, spam, eggs):
        """Accept ``context``, ``spam`` and ``eggs``; perform no work."""
 | 
						|
 | 
						|
 | 
						|
class KwargsTask(task.Task):
    """Do-nothing task taking one required argument plus ``**kwargs``."""

    def execute(self, spam, **kwargs):
        """Accept ``spam`` and any extra keyword arguments; do nothing."""
 | 
						|
 | 
						|
 | 
						|
class DefaultArgTask(task.Task):
    """Do-nothing task with one required and one defaulted argument."""

    def execute(self, spam, eggs=()):
        """Accept ``spam`` and optional ``eggs`` (default ``()``); no-op."""
 | 
						|
 | 
						|
 | 
						|
class DefaultProvidesTask(task.Task):
    """Argument-less task that declares a class-level ``default_provides``."""

    # Class-level default; the tests below check it against the task's
    # ``provides`` and ``save_as`` attributes.
    default_provides = 'def'

    def execute(self):
        """Take no arguments and return ``None``."""
        return
 | 
						|
 | 
						|
 | 
						|
class ProgressTask(task.Task):
    """Task that forwards each given value to ``update_progress``."""

    def execute(self, values, **kwargs):
        """Call ``self.update_progress`` once per item of ``values``, in order.

        Extra keyword arguments are accepted but not used.
        """
        for step in values:
            self.update_progress(step)
 | 
						|
 | 
						|
 | 
						|
class TaskTest(test.TestCase):
    """Tests for task naming, provides/requires/rebind attributes and the
    event (bind/unbind/autobind) interface of the sample tasks above.
    """

    # -- naming ------------------------------------------------------------

    def test_passed_name(self):
        # A name given to the constructor is exposed unchanged.
        named = MyTask(name='my name')
        self.assertEqual(named.name, 'my name')

    def test_generated_name(self):
        # With no name given, the name is '<this module>.<class name>'.
        unnamed = MyTask()
        expected_name = '%s.%s' % (__name__, 'MyTask')
        self.assertEqual(unnamed.name, expected_name)

    # -- provides / save_as --------------------------------------------------

    def test_no_provides(self):
        # No ``provides`` argument yields an empty ``save_as`` mapping.
        plain = MyTask()
        expected = {}
        self.assertEqual(plain.save_as, expected)

    def test_provides(self):
        # A single string name maps to ``None``.
        expected = {'food': None}
        provider = MyTask(provides='food')
        self.assertEqual(provider.save_as, expected)

    def test_multi_provides(self):
        # A tuple of names maps each name to its position in the tuple.
        expected = {'food': 0, 'water': 1}
        provider = MyTask(provides=('food', 'water'))
        self.assertEqual(provider.save_as, expected)

    def test_unpack(self):
        # A one-element tuple is still positional: the name maps to 0.
        expected = {'food': 0}
        provider = MyTask(provides=('food',))
        self.assertEqual(provider.save_as, expected)

    def test_bad_provides(self):
        # An arbitrary object as ``provides`` is rejected with TypeError.
        self.assertRaisesRegexp(
            TypeError, '^Task provides',
            MyTask, provides=object())

    # -- requires ------------------------------------------------------------

    def test_requires_by_default(self):
        # Every argument of ``execute`` is bound to itself by default.
        expected = {
            'spam': 'spam',
            'eggs': 'eggs',
            'context': 'context'
        }
        subject = MyTask()
        self.assertEqual(subject.rebind, expected)

    def test_requires_amended(self):
        # Listing only some arguments still ends up with all of them bound.
        expected = {
            'spam': 'spam',
            'eggs': 'eggs',
            'context': 'context'
        }
        subject = MyTask(requires=('spam', 'eggs'))
        self.assertEqual(subject.rebind, expected)

    def test_requires_explicit(self):
        # With auto_extract disabled, a complete explicit list is accepted.
        expected = {
            'spam': 'spam',
            'eggs': 'eggs',
            'context': 'context'
        }
        subject = MyTask(auto_extract=False,
                         requires=('spam', 'eggs', 'context'))
        self.assertEqual(subject.rebind, expected)

    def test_requires_explicit_not_enough(self):
        # With auto_extract disabled, an incomplete list raises ValueError.
        self.assertRaisesRegexp(
            ValueError, '^Missing arguments',
            MyTask,
            auto_extract=False, requires=('spam', 'eggs'))

    def test_requires_ignores_optional(self):
        # The defaulted ``eggs`` argument is not part of ``requires``.
        expected = set(['spam'])
        subject = DefaultArgTask()
        self.assertEqual(subject.requires, expected)

    def test_requires_allows_optional(self):
        # A defaulted argument may be required explicitly.
        expected = set(['spam', 'eggs'])
        subject = DefaultArgTask(requires=('spam', 'eggs'))
        self.assertEqual(subject.requires, expected)

    # -- rebind --------------------------------------------------------------

    def test_rebind_all_args(self):
        # A dict naming every argument is used as given.
        expected = {
            'spam': 'a',
            'eggs': 'b',
            'context': 'c'
        }
        subject = MyTask(rebind={'spam': 'a', 'eggs': 'b', 'context': 'c'})
        self.assertEqual(subject.rebind, expected)

    def test_rebind_partial(self):
        # Arguments missing from the dict stay bound to themselves.
        expected = {
            'spam': 'a',
            'eggs': 'b',
            'context': 'context'
        }
        subject = MyTask(rebind={'spam': 'a', 'eggs': 'b'})
        self.assertEqual(subject.rebind, expected)

    def test_rebind_unknown(self):
        # A key that is not an ``execute`` argument raises ValueError.
        self.assertRaisesRegexp(
            ValueError, '^Extra arguments',
            MyTask, rebind={'foo': 'bar'})

    def test_rebind_unknown_kwargs(self):
        # With ``**kwargs`` in ``execute`` an unknown key is accepted.
        expected = {
            'foo': 'bar',
            'spam': 'spam'
        }
        kwargs_task = KwargsTask(rebind={'foo': 'bar'})
        self.assertEqual(kwargs_task.rebind, expected)

    def test_rebind_list_all(self):
        # A sequence is matched to the arguments in declaration order.
        expected = {
            'context': 'a',
            'spam': 'b',
            'eggs': 'c'
        }
        subject = MyTask(rebind=('a', 'b', 'c'))
        self.assertEqual(subject.rebind, expected)

    def test_rebind_list_partial(self):
        # A short sequence leaves the trailing arguments self-bound.
        expected = {
            'context': 'a',
            'spam': 'b',
            'eggs': 'eggs'
        }
        subject = MyTask(rebind=('a', 'b'))
        self.assertEqual(subject.rebind, expected)

    def test_rebind_list_more(self):
        # More entries than ``execute`` has arguments raises ValueError.
        self.assertRaisesRegexp(
            ValueError, '^Extra arguments',
            MyTask, rebind=('a', 'b', 'c', 'd'))

    def test_rebind_list_more_kwargs(self):
        # With ``**kwargs`` the surplus entries are bound to themselves.
        expected = {
            'spam': 'a',
            'b': 'b',
            'c': 'c'
        }
        kwargs_task = KwargsTask(rebind=('a', 'b', 'c'))
        self.assertEqual(kwargs_task.rebind, expected)

    def test_rebind_list_bad_value(self):
        # An arbitrary object as ``rebind`` is rejected with TypeError.
        self.assertRaisesRegexp(
            TypeError, '^Invalid rebind value:',
            MyTask, rebind=object())

    # -- default_provides ----------------------------------------------------

    def test_default_provides(self):
        # The class-level default is used when nothing is passed in.
        provider = DefaultProvidesTask()
        self.assertEqual(provider.provides, set(['def']))
        self.assertEqual(provider.save_as, {'def': None})

    def test_default_provides_can_be_overridden(self):
        # A constructor argument replaces the class-level default.
        provider = DefaultProvidesTask(provides=('spam', 'eggs'))
        self.assertEqual(provider.provides, set(['spam', 'eggs']))
        self.assertEqual(provider.save_as, {'spam': 0, 'eggs': 1})

    # -- update_progress -----------------------------------------------------

    def test_update_progress_within_bounds(self):
        # The three values used here reach the listener unchanged.
        reported = [0.0, 0.5, 1.0]
        seen = []

        def progress_callback(task, event_data, progress):
            seen.append(progress)

        progress_task = ProgressTask()
        with progress_task.autobind('update_progress', progress_callback):
            progress_task.execute(reported)
        self.assertEqual(seen, reported)

    @mock.patch.object(task.LOG, 'warn')
    def test_update_progress_lower_bound(self, warn_mock):
        # The listener records 0.0 for each input; warn is called twice.
        seen = []

        def progress_callback(task, event_data, progress):
            seen.append(progress)

        progress_task = ProgressTask()
        with progress_task.autobind('update_progress', progress_callback):
            progress_task.execute([-1.0, -0.5, 0.0])
        self.assertEqual(seen, [0.0, 0.0, 0.0])
        self.assertEqual(warn_mock.call_count, 2)

    @mock.patch.object(task.LOG, 'warn')
    def test_update_progress_upper_bound(self, warn_mock):
        # The listener records 1.0 for each input; warn is called twice.
        seen = []

        def progress_callback(task, event_data, progress):
            seen.append(progress)

        progress_task = ProgressTask()
        with progress_task.autobind('update_progress', progress_callback):
            progress_task.execute([1.0, 1.5, 2.0])
        self.assertEqual(seen, [1.0, 1.0, 1.0])
        self.assertEqual(warn_mock.call_count, 2)

    @mock.patch.object(task.LOG, 'exception')
    def test_update_progress_handler_failure(self, exception_mock):
        # A raising listener does not propagate; it is logged exactly once.
        def progress_callback(*args, **kwargs):
            raise Exception('Woot!')

        progress_task = ProgressTask()
        with progress_task.autobind('update_progress', progress_callback):
            progress_task.execute([0.5])
        callback_name = reflection.get_callable_name(progress_callback)
        exception_mock.assert_called_once_with(
            mock.ANY, callback_name, 'update_progress')

    # -- bind / unbind / autobind --------------------------------------------

    @mock.patch.object(task.LOG, 'exception')
    def test_autobind_non_existent_event(self, exception_mock):
        # Autobinding to this event name registers nothing and logs once.
        def handler():
            return None

        event = 'test-event'
        subject = MyTask()
        with subject.autobind(event, handler):
            self.assertEqual(len(subject._events_listeners), 0)
            exception_mock.assert_called_once_with(
                mock.ANY, handler, event, subject)

    def test_autobind_handler_is_none(self):
        # A ``None`` handler registers nothing.
        subject = MyTask()
        with subject.autobind('update_progress', None):
            self.assertEqual(len(subject._events_listeners), 0)

    def test_unbind_any_handler(self):
        # Unbinding an event without naming a handler empties the listeners.
        subject = MyTask()
        self.assertEqual(len(subject._events_listeners), 0)
        subject.bind('update_progress', lambda: None)
        self.assertEqual(len(subject._events_listeners), 1)
        self.assertTrue(subject.unbind('update_progress'))
        self.assertEqual(len(subject._events_listeners), 0)

    def test_unbind_any_handler_empty_listeners(self):
        # Unbinding when nothing is bound returns a false value.
        subject = MyTask()
        self.assertEqual(len(subject._events_listeners), 0)
        self.assertFalse(subject.unbind('update_progress'))
        self.assertEqual(len(subject._events_listeners), 0)

    def test_unbind_non_existent_listener(self):
        # Unbinding a handler that was never bound leaves the bound one.
        def handler1():
            return None

        def handler2():
            return None

        subject = MyTask()
        subject.bind('update_progress', handler1)
        self.assertEqual(len(subject._events_listeners), 1)
        self.assertFalse(subject.unbind('update_progress', handler2))
        self.assertEqual(len(subject._events_listeners), 1)
 | 
						|
 | 
						|
 | 
						|
class FunctorTaskTest(test.TestCase):
    """Tests for ``task.FunctorTask`` construction."""

    def test_creation_with_version(self):
        # The version passed to the constructor is exposed as ``version``.
        wanted_version = (2, 0)
        functor_task = task.FunctorTask(lambda: None, version=wanted_version)
        self.assertEqual(functor_task.version, wanted_version)
 |