add basic tasks structure
This commit is contained in:
parent
c4e9a6dd0a
commit
532ad9b457
|
@ -4,3 +4,4 @@ argparse==1.2.1
|
|||
wsgiref==0.1.2
|
||||
zope.interface==4.0.5
|
||||
structlog==0.3.0
|
||||
treq==0.2.0
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TimerService
|
||||
from twisted.internet import defer
|
||||
from teeth_agent.logging import get_logger
|
||||
|
||||
__all__ = ['BaseTask', 'MultiTask']
|
||||
|
||||
|
||||
class BaseTask(MultiService, object):
|
||||
"""
|
||||
Task to execute, reporting status periodically to TeethClient instance.
|
||||
"""
|
||||
|
||||
task_name = 'task_undefined'
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(BaseTask, self).__init__()
|
||||
self.log = get_logger(task_id=task_id, task_name=self.task_name)
|
||||
self.setName(self.task_name)
|
||||
self._client = client
|
||||
self._id = task_id
|
||||
self._percent = 0
|
||||
self._reporting_interval = reporting_interval
|
||||
self._state = 'starting'
|
||||
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||
self._timer.setServiceParent(self)
|
||||
self._error_msg = None
|
||||
self._done = False
|
||||
self._d = defer.Deferred()
|
||||
|
||||
def _run(self):
|
||||
"""Do the actual work here."""
|
||||
|
||||
def run(self):
|
||||
"""Run the Task."""
|
||||
# setServiceParent actually starts the task if it is already running
|
||||
# so we run it in start.
|
||||
if not self.parent:
|
||||
self.setServiceParent(self._client)
|
||||
self._run()
|
||||
return self._d
|
||||
|
||||
def _tick(self):
|
||||
if not self.running:
|
||||
# log.debug("_tick called while not running :()")
|
||||
return
|
||||
|
||||
if self._state in ['error', 'complete']:
|
||||
self.stopService()
|
||||
|
||||
return self._client.update_task_status(self)
|
||||
|
||||
def error(self, message, *args, **kwargs):
|
||||
"""Error out running of the task."""
|
||||
self._error_msg = message
|
||||
self._state = 'error'
|
||||
self.stopService()
|
||||
|
||||
def complete(self, *args, **kwargs):
|
||||
"""Complete running of the task."""
|
||||
self._state = 'complete'
|
||||
self.stopService()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
self._state = 'running'
|
||||
super(BaseTask, self).startService()
|
||||
|
||||
def stopService(self):
|
||||
"""Stop the Service."""
|
||||
super(BaseTask, self).stopService()
|
||||
|
||||
if self._state not in ['error', 'complete']:
|
||||
self.log.err("told to shutdown before task could complete, marking as error.")
|
||||
self._error_msg = 'service being shutdown'
|
||||
self._state = 'error'
|
||||
|
||||
if self._done is False:
|
||||
self._done = True
|
||||
self._d.callback(None)
|
||||
self._client.finish_task(self)
|
||||
|
||||
|
||||
class MultiTask(BaseTask):
|
||||
|
||||
"""Run multiple tasks in parallel."""
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(MultiTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._tasks = []
|
||||
|
||||
def _tick(self):
|
||||
if len(self._tasks):
|
||||
percents = [t._percent for t in self._tasks]
|
||||
self._percent = sum(percents)/float(len(percents))
|
||||
else:
|
||||
self._percent = 0
|
||||
super(MultiTask, self)._tick()
|
||||
|
||||
def _run(self):
|
||||
ds = []
|
||||
for t in self._tasks:
|
||||
ds.append(t.run())
|
||||
dl = defer.DeferredList(ds)
|
||||
dl.addBoth(self.complete, self.error)
|
||||
|
||||
def add_task(self, task):
|
||||
"""Add a task to be ran."""
|
||||
task.setServiceParent(self)
|
||||
self._tasks.append(task)
|
|
@ -0,0 +1,54 @@
|
|||
"""
|
||||
Copyright 2013 Rackspace, Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from teeth_agent.base_task import BaseTask
|
||||
import treq
|
||||
|
||||
|
||||
class ImageDownloaderTask(BaseTask):
|
||||
"""Download image to cache. """
|
||||
task_name = 'image_download'
|
||||
|
||||
def __init__(self, client, task_id, image_info, destination_filename, reporting_interval=10):
|
||||
super(ImageDownloaderTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._destination_filename = destination_filename
|
||||
self._image_id = image_info.id
|
||||
self._image_hashes = image_info.hashes
|
||||
self._iamge_urls = image_info.urls
|
||||
self._destination_filename = destination_filename
|
||||
|
||||
def _run(self):
|
||||
# TODO: pick by protocol priority.
|
||||
url = self._iamge_urls[0]
|
||||
# TODO: more than just download, sha1 it.
|
||||
return self._download_image_to_file(url)
|
||||
|
||||
def _tick(self):
|
||||
# TODO: get file download percentages.
|
||||
self.percent = 0
|
||||
super(BaseTask, self)._tick()
|
||||
|
||||
def _download_image_to_file(self, url):
|
||||
destination = file(self._destination_filename, 'w')
|
||||
|
||||
def push(data):
|
||||
if self.running:
|
||||
destination.write(data)
|
||||
|
||||
d = treq.get(url)
|
||||
d.addCallback(treq.collect, push)
|
||||
d.addBoth(lambda _: destination.close())
|
||||
return d
|
|
@ -14,82 +14,34 @@ See the License for the specific language governing permissions and
|
|||
limitations under the License.
|
||||
"""
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TimerService
|
||||
from teeth_agent.logging import get_logger
|
||||
import os
|
||||
|
||||
from teeth_agent.base_task import MultiTask, BaseTask
|
||||
from teeth_agent.cache_image import ImageDownloaderTask
|
||||
|
||||
|
||||
__all__ = ['Task', 'PrepareImageTask']
|
||||
__all__ = ['CacheImagesTask', 'PrepareImageTask']
|
||||
|
||||
|
||||
class Task(MultiService, object):
|
||||
"""
|
||||
Task to execute, reporting status periodically to TeethClient instance.
|
||||
"""
|
||||
class CacheImagesTask(MultiTask):
|
||||
|
||||
task_name = 'task_undefined'
|
||||
"""Cache an array of images on a machine."""
|
||||
|
||||
def __init__(self, client, task_id, reporting_interval=10):
|
||||
super(Task, self).__init__()
|
||||
self.setName(self.task_name)
|
||||
self._client = client
|
||||
self._id = task_id
|
||||
self._percent = 0
|
||||
self._reporting_interval = reporting_interval
|
||||
self._state = 'starting'
|
||||
self._timer = TimerService(self._reporting_interval, self._tick)
|
||||
self._timer.setServiceParent(self)
|
||||
self._error_msg = None
|
||||
self.log = get_logger(task_id=task_id, task_name=self.task_name)
|
||||
task_name = 'cache_images'
|
||||
|
||||
def _run(self):
|
||||
"""Do the actual work here."""
|
||||
|
||||
def run(self):
|
||||
"""Run the Task."""
|
||||
# setServiceParent actually starts the task if it is already running
|
||||
# so we run it in start.
|
||||
self.setServiceParent(self._client)
|
||||
self._run()
|
||||
|
||||
def _tick(self):
|
||||
if not self.running:
|
||||
# log.debug("_tick called while not running :()")
|
||||
return
|
||||
return self._client.update_task_status(self)
|
||||
|
||||
def error(self, message):
|
||||
"""Error out running of the task."""
|
||||
self._error_msg = message
|
||||
self._state = 'error'
|
||||
self.stopService()
|
||||
|
||||
def complete(self):
|
||||
"""Complete running of the task."""
|
||||
self._state = 'complete'
|
||||
self.stopService()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
super(Task, self).startService()
|
||||
self._state = 'running'
|
||||
|
||||
def stopService(self):
|
||||
"""Stop the Service."""
|
||||
super(Task, self).stopService()
|
||||
|
||||
if not self._client.running:
|
||||
return
|
||||
|
||||
if self._state not in ['error', 'complete']:
|
||||
self.log.err("told to shutdown before task could complete, marking as error.")
|
||||
self._error_msg = 'service being shutdown'
|
||||
self._state = 'error'
|
||||
|
||||
self._client.finish_task(self)
|
||||
def __init__(self, client, task_id, images, reporting_interval=10):
|
||||
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._images = images
|
||||
for image in self._images:
|
||||
image_path = os.path.join(client.get_cache_path(), image.id + '.img')
|
||||
t = ImageDownloaderTask(client,
|
||||
task_id, image,
|
||||
image_path,
|
||||
reporting_interval=reporting_interval)
|
||||
self.add_task(t)
|
||||
|
||||
|
||||
class PrepareImageTask(Task):
|
||||
class PrepareImageTask(BaseTask):
|
||||
|
||||
"""Prepare an image to be ran on the machine."""
|
||||
|
||||
|
|
|
@ -17,23 +17,24 @@ limitations under the License.
|
|||
import uuid
|
||||
|
||||
from twisted.trial import unittest
|
||||
from teeth_agent.task import Task
|
||||
from teeth_agent.base_task import BaseTask, MultiTask
|
||||
from mock import Mock
|
||||
|
||||
|
||||
class FakeClient(object):
|
||||
addService = Mock(return_value=None)
|
||||
running = Mock(return_value=0)
|
||||
update_task_status = Mock(return_value=None)
|
||||
finish_task = Mock(return_value=None)
|
||||
def __init__(self):
|
||||
self.addService = Mock(return_value=None)
|
||||
self.running = Mock(return_value=0)
|
||||
self.update_task_status = Mock(return_value=None)
|
||||
self.finish_task = Mock(return_value=None)
|
||||
|
||||
|
||||
class TestTask(Task):
|
||||
class TestTask(BaseTask):
|
||||
task_name = 'test_task'
|
||||
|
||||
|
||||
class TaskTest(unittest.TestCase):
|
||||
"""Event Emitter tests."""
|
||||
"""Basic tests of the Task API."""
|
||||
|
||||
def setUp(self):
|
||||
self.task_id = str(uuid.uuid4())
|
||||
|
@ -45,6 +46,15 @@ class TaskTest(unittest.TestCase):
|
|||
del self.task
|
||||
del self.client
|
||||
|
||||
def test_error(self):
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.task.error('chaos monkey attack')
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
def test_run(self):
|
||||
self.assertEqual(self.task._state, 'starting')
|
||||
self.assertEqual(self.task._id, self.task_id)
|
||||
|
@ -55,3 +65,44 @@ class TaskTest(unittest.TestCase):
|
|||
self.task.complete()
|
||||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
def test_fast_shutdown(self):
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.task.stopService()
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
||||
|
||||
class MultiTestTask(MultiTask):
|
||||
task_name = 'test_multitask'
|
||||
|
||||
|
||||
class MultiTaskTest(unittest.TestCase):
|
||||
"""Basic tests of the Multi Task API."""
|
||||
|
||||
def setUp(self):
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.client = FakeClient()
|
||||
self.task = MultiTestTask(self.client, self.task_id)
|
||||
|
||||
def tearDown(self):
|
||||
del self.task_id
|
||||
del self.task
|
||||
del self.client
|
||||
|
||||
def test_tasks(self):
|
||||
t = TestTask(self.client, self.task_id)
|
||||
self.task.add_task(t)
|
||||
self.assertEqual(self.task._state, 'starting')
|
||||
self.assertEqual(self.task._id, self.task_id)
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
self.task.startService()
|
||||
self.client.update_task_status.assert_any_call(self.task)
|
||||
t.complete()
|
||||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_any_call(t)
|
||||
self.client.finish_task.assert_any_call(self.task)
|
||||
|
|
Loading…
Reference in New Issue