Merge pull request #11 from racker/tasks_for_doing_things

Tasks and Image Download
This commit is contained in:
Paul Querna 2013-11-07 13:49:38 -08:00
commit b3ee1d828c
7 changed files with 362 additions and 75 deletions

View File

@ -4,3 +4,4 @@ argparse==1.2.1
wsgiref==0.1.2
zope.interface==4.0.5
structlog==0.3.0
treq==0.2.0

126
teeth_agent/base_task.py Normal file
View File

@ -0,0 +1,126 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.application.service import MultiService
from twisted.application.internet import TimerService
from twisted.internet import defer
from teeth_agent.logging import get_logger
__all__ = ['BaseTask', 'MultiTask']
class BaseTask(MultiService, object):
"""
Task to execute, reporting status periodically to TeethClient instance.
"""
task_name = 'task_undefined'
def __init__(self, client, task_id, reporting_interval=10):
super(BaseTask, self).__init__()
self.log = get_logger(task_id=task_id, task_name=self.task_name)
self.setName(self.task_name + '.' + task_id)
self._client = client
self._id = task_id
self._percent = 0
self._reporting_interval = reporting_interval
self._state = 'starting'
self._timer = TimerService(self._reporting_interval, self._tick)
self._timer.setServiceParent(self)
self._error_msg = None
self._done = False
self._d = defer.Deferred()
def _run(self):
"""Do the actual work here."""
def run(self):
"""Run the Task."""
# setServiceParent actually starts the task if it is already running
# so we run it in start.
if not self.parent:
self.setServiceParent(self._client)
self._run()
return self._d
def _tick(self):
if not self.running:
# log.debug("_tick called while not running :()")
return
if self._state in ['error', 'complete']:
self.stopService()
return self._client.update_task_status(self)
def error(self, message, *args, **kwargs):
"""Error out running of the task."""
self._error_msg = message
self._state = 'error'
self.stopService()
def complete(self, *args, **kwargs):
"""Complete running of the task."""
self._state = 'complete'
self.stopService()
def startService(self):
"""Start the Service."""
self._state = 'running'
super(BaseTask, self).startService()
def stopService(self):
"""Stop the Service."""
super(BaseTask, self).stopService()
if self._state not in ['error', 'complete']:
self.log.err("told to shutdown before task could complete, marking as error.")
self._error_msg = 'service being shutdown'
self._state = 'error'
if self._done is False:
self._done = True
self._d.callback(None)
self._client.finish_task(self)
class MultiTask(BaseTask):
"""Run multiple tasks in parallel."""
def __init__(self, client, task_id, reporting_interval=10):
super(MultiTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._tasks = []
def _tick(self):
if len(self._tasks):
percents = [t._percent for t in self._tasks]
self._percent = sum(percents)/float(len(percents))
else:
self._percent = 0
super(MultiTask, self)._tick()
def _run(self):
ds = []
for t in self._tasks:
ds.append(t.run())
dl = defer.DeferredList(ds)
dl.addBoth(self.complete, self.error)
def add_task(self, task):
"""Add a task to be ran."""
task.setServiceParent(self)
self._tasks.append(task)

View File

@ -0,0 +1,54 @@
"""
Copyright 2013 Rackspace, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from teeth_agent.base_task import BaseTask
import treq
class ImageDownloaderTask(BaseTask):
"""Download image to cache. """
task_name = 'image_download'
def __init__(self, client, task_id, image_info, destination_filename, reporting_interval=10):
super(ImageDownloaderTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._destination_filename = destination_filename
self._image_id = image_info.id
self._image_hashes = image_info.hashes
self._iamge_urls = image_info.urls
self._destination_filename = destination_filename
def _run(self):
# TODO: pick by protocol priority.
url = self._iamge_urls[0]
# TODO: more than just download, sha1 it.
return self._download_image_to_file(url)
def _tick(self):
# TODO: get file download percentages.
self.percent = 0
super(ImageDownloaderTask, self)._tick()
def _download_image_to_file(self, url):
destination = open(self._destination_filename, 'wb')
def push(data):
if self.running:
destination.write(data)
d = treq.get(url)
d.addCallback(treq.collect, push)
d.addBoth(lambda _: destination.close())
return d

View File

@ -17,6 +17,7 @@ limitations under the License.
import time
import json
import random
import tempfile
from twisted.application.service import MultiService
from twisted.application.internet import TCPClient
@ -87,6 +88,12 @@ class TeethClient(MultiService, object):
}
}
@property
def conf_image_cache_path(self):
"""Path to iamge cache."""
# TODO: improve:
return tempfile.gettempdir()
def startService(self):
"""Start the Service."""
super(TeethClient, self).startService()

View File

@ -73,6 +73,17 @@ class RPCError(RPCMessage, RuntimeError):
self._raw_message = message
class ImageInfo(object):
"""
Metadata about a machine image.
"""
def __init__(self, image_id, image_urls, image_hashes):
super(ImageInfo, self).__init__()
self.id = image_id
self.urls = image_urls
self.hashes = image_hashes
class CommandValidationError(RuntimeError):
"""
Exception class which can be used to return an error when the

View File

@ -14,82 +14,34 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from twisted.application.service import MultiService
from twisted.application.internet import TimerService
from teeth_agent.logging import get_logger
import os
from teeth_agent.base_task import MultiTask, BaseTask
from teeth_agent.cache_image import ImageDownloaderTask
__all__ = ['Task', 'PrepareImageTask']
__all__ = ['CacheImagesTask', 'PrepareImageTask']
class Task(MultiService, object):
"""
Task to execute, reporting status periodically to TeethClient instance.
"""
class CacheImagesTask(MultiTask):
task_name = 'task_undefined'
"""Cache an array of images on a machine."""
def __init__(self, client, task_id, reporting_interval=10):
super(Task, self).__init__()
self.setName(self.task_name)
self._client = client
self._id = task_id
self._percent = 0
self._reporting_interval = reporting_interval
self._state = 'starting'
self._timer = TimerService(self._reporting_interval, self._tick)
self._timer.setServiceParent(self)
self._error_msg = None
self.log = get_logger(task_id=task_id, task_name=self.task_name)
task_name = 'cache_images'
def _run(self):
"""Do the actual work here."""
def run(self):
"""Run the Task."""
# setServiceParent actually starts the task if it is already running
# so we run it in start.
self.setServiceParent(self._client)
self._run()
def _tick(self):
if not self.running:
# log.debug("_tick called while not running :()")
return
return self._client.update_task_status(self)
def error(self, message):
"""Error out running of the task."""
self._error_msg = message
self._state = 'error'
self.stopService()
def complete(self):
"""Complete running of the task."""
self._state = 'complete'
self.stopService()
def startService(self):
"""Start the Service."""
super(Task, self).startService()
self._state = 'running'
def stopService(self):
"""Stop the Service."""
super(Task, self).stopService()
if not self._client.running:
return
if self._state not in ['error', 'complete']:
self.log.err("told to shutdown before task could complete, marking as error.")
self._error_msg = 'service being shutdown'
self._state = 'error'
self._client.finish_task(self)
def __init__(self, client, task_id, images, reporting_interval=10):
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._images = images
for image in self._images:
image_path = os.path.join(client.conf_image_cache_path, image.id + '.img')
t = ImageDownloaderTask(client,
task_id, image,
image_path,
reporting_interval=reporting_interval)
self.add_task(t)
class PrepareImageTask(Task):
class PrepareImageTask(BaseTask):
"""Prepare an image to be ran on the machine."""

View File

@ -15,25 +15,37 @@ limitations under the License.
"""
import uuid
import shutil
import tempfile
import hashlib
import os
from mock import Mock, patch
from twisted.internet import defer
from twisted.trial import unittest
from teeth_agent.task import Task
from mock import Mock
from twisted.web.client import ResponseDone
from twisted.python.failure import Failure
from teeth_agent.protocol import ImageInfo
from teeth_agent.base_task import BaseTask, MultiTask
from teeth_agent.cache_image import ImageDownloaderTask
class FakeClient(object):
addService = Mock(return_value=None)
running = Mock(return_value=0)
update_task_status = Mock(return_value=None)
finish_task = Mock(return_value=None)
def __init__(self):
self.addService = Mock(return_value=None)
self.running = Mock(return_value=0)
self.update_task_status = Mock(return_value=None)
self.finish_task = Mock(return_value=None)
class TestTask(Task):
class TestTask(BaseTask):
task_name = 'test_task'
class TaskTest(unittest.TestCase):
"""Event Emitter tests."""
"""Basic tests of the Task API."""
def setUp(self):
self.task_id = str(uuid.uuid4())
@ -45,6 +57,15 @@ class TaskTest(unittest.TestCase):
del self.task
del self.client
def test_error(self):
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.error('chaos monkey attack')
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)
def test_run(self):
self.assertEqual(self.task._state, 'starting')
self.assertEqual(self.task._id, self.task_id)
@ -55,3 +76,118 @@ class TaskTest(unittest.TestCase):
self.task.complete()
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_called_once_with(self.task)
def test_fast_shutdown(self):
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_called_once_with(self.task)
self.task.stopService()
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)
class MultiTestTask(MultiTask):
task_name = 'test_multitask'
class MultiTaskTest(unittest.TestCase):
"""Basic tests of the Multi Task API."""
def setUp(self):
self.task_id = str(uuid.uuid4())
self.client = FakeClient()
self.task = MultiTestTask(self.client, self.task_id)
def tearDown(self):
del self.task_id
del self.task
del self.client
def test_tasks(self):
t = TestTask(self.client, self.task_id)
self.task.add_task(t)
self.assertEqual(self.task._state, 'starting')
self.assertEqual(self.task._id, self.task_id)
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.task.startService()
self.client.update_task_status.assert_any_call(self.task)
t.complete()
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_any_call(t)
self.client.finish_task.assert_any_call(self.task)
class StubResponse(object):
def __init__(self, code, headers, body):
self.version = ('HTTP', 1, 1)
self.code = code
self.status = "ima teapot"
self.headers = headers
self.body = body
self.length = reduce(lambda x, y: x + len(y), body, 0)
self.protocol = None
def deliverBody(self, protocol):
self.protocol = protocol
def run(self):
self.protocol.connectionMade()
for data in self.body:
self.protocol.dataReceived(data)
self.protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
class ImageDownloaderTaskTest(unittest.TestCase):
def setUp(self):
get_patcher = patch('treq.get', autospec=True)
self.TreqGet = get_patcher.start()
self.addCleanup(get_patcher.stop)
self.tmpdir = tempfile.mkdtemp('image_download_test')
self.task_id = str(uuid.uuid4())
self.image_data = str(uuid.uuid4())
self.image_md5 = hashlib.md5(self.image_data).hexdigest()
self.cache_path = os.path.join(self.tmpdir, 'a1234.img')
self.client = FakeClient()
self.image_info = ImageInfo('a1234',
['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5})
self.task = ImageDownloaderTask(self.client,
self.task_id,
self.image_info,
self.cache_path)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def assertFileHash(self, hash_type, path, value):
file_hash = hashlib.new(hash_type)
with open(path, 'r') as fp:
file_hash.update(fp.read())
self.assertEqual(value, file_hash.hexdigest())
def test_download_success(self):
resp = StubResponse(200, [], [self.image_data])
d = defer.Deferred()
self.TreqGet.return_value = d
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img')
self.task.startService()
d.callback(resp)
resp.run()
self.client.update_task_status.assert_called_once_with(self.task)
self.assertFileHash('md5', self.cache_path, self.image_md5)
self.task.stopService()
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)