diff --git a/requirements.txt b/requirements.txt index b945a5f99..4f4583148 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ argparse==1.2.1 wsgiref==0.1.2 zope.interface==4.0.5 structlog==0.3.0 +treq==0.2.0 diff --git a/teeth_agent/base_task.py b/teeth_agent/base_task.py new file mode 100644 index 000000000..f64ac78df --- /dev/null +++ b/teeth_agent/base_task.py @@ -0,0 +1,126 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from twisted.application.service import MultiService +from twisted.application.internet import TimerService +from twisted.internet import defer +from teeth_agent.logging import get_logger + +__all__ = ['BaseTask', 'MultiTask'] + + +class BaseTask(MultiService, object): + """ + Task to execute, reporting status periodically to TeethClient instance. + """ + + task_name = 'task_undefined' + + def __init__(self, client, task_id, reporting_interval=10): + super(BaseTask, self).__init__() + self.log = get_logger(task_id=task_id, task_name=self.task_name) + self.setName(self.task_name + '.' 
+ task_id) + self._client = client + self._id = task_id + self._percent = 0 + self._reporting_interval = reporting_interval + self._state = 'starting' + self._timer = TimerService(self._reporting_interval, self._tick) + self._timer.setServiceParent(self) + self._error_msg = None + self._done = False + self._d = defer.Deferred() + + def _run(self): + """Do the actual work here.""" + + def run(self): + """Run the Task.""" + # setServiceParent actually starts the task if it is already running + # so we run it in start. + if not self.parent: + self.setServiceParent(self._client) + self._run() + return self._d + + def _tick(self): + if not self.running: + # log.debug("_tick called while not running :()") + return + + if self._state in ['error', 'complete']: + self.stopService() + + return self._client.update_task_status(self) + + def error(self, message, *args, **kwargs): + """Error out running of the task.""" + self._error_msg = message + self._state = 'error' + self.stopService() + + def complete(self, *args, **kwargs): + """Complete running of the task.""" + self._state = 'complete' + self.stopService() + + def startService(self): + """Start the Service.""" + self._state = 'running' + super(BaseTask, self).startService() + + def stopService(self): + """Stop the Service.""" + super(BaseTask, self).stopService() + + if self._state not in ['error', 'complete']: + self.log.err("told to shutdown before task could complete, marking as error.") + self._error_msg = 'service being shutdown' + self._state = 'error' + + if self._done is False: + self._done = True + self._d.callback(None) + self._client.finish_task(self) + + +class MultiTask(BaseTask): + + """Run multiple tasks in parallel.""" + + def __init__(self, client, task_id, reporting_interval=10): + super(MultiTask, self).__init__(client, task_id, reporting_interval=reporting_interval) + self._tasks = [] + + def _tick(self): + if len(self._tasks): + percents = [t._percent for t in self._tasks] + self._percent = 
sum(percents)/float(len(percents)) + else: + self._percent = 0 + return super(MultiTask, self)._tick() + + def _run(self): + ds = [] + for t in self._tasks: + ds.append(t.run()) + dl = defer.DeferredList(ds) + dl.addCallbacks(self.complete, self.error) + + def add_task(self, task): + """Add a task to be run.""" + task.setServiceParent(self) + self._tasks.append(task) diff --git a/teeth_agent/cache_image.py b/teeth_agent/cache_image.py new file mode 100644 index 000000000..98c8f3975 --- /dev/null +++ b/teeth_agent/cache_image.py @@ -0,0 +1,54 @@ +""" +Copyright 2013 Rackspace, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from teeth_agent.base_task import BaseTask +import treq + + +class ImageDownloaderTask(BaseTask): + """Download image to cache.""" + task_name = 'image_download' + + def __init__(self, client, task_id, image_info, destination_filename, reporting_interval=10): + super(ImageDownloaderTask, self).__init__(client, task_id, reporting_interval=reporting_interval) + self._destination_filename = destination_filename + self._image_id = image_info.id + self._image_hashes = image_info.hashes + self._image_urls = image_info.urls + self._destination_filename = destination_filename + + def _run(self): + # TODO: pick by protocol priority. + url = self._image_urls[0] + # TODO: more than just download, sha1 it. + return self._download_image_to_file(url) + + def _tick(self): + # TODO: get file download percentages. 
+ self._percent = 0 + return super(ImageDownloaderTask, self)._tick() + + def _download_image_to_file(self, url): + destination = open(self._destination_filename, 'wb') + + def push(data): + if self.running: + destination.write(data) + + d = treq.get(url) + d.addCallback(treq.collect, push) + d.addBoth(lambda _: destination.close()) + return d diff --git a/teeth_agent/client.py b/teeth_agent/client.py index 2c02e11ad..93083baee 100644 --- a/teeth_agent/client.py +++ b/teeth_agent/client.py @@ -17,6 +17,7 @@ limitations under the License. import time import json import random +import tempfile from twisted.application.service import MultiService from twisted.application.internet import TCPClient @@ -87,6 +88,12 @@ class TeethClient(MultiService, object): } } + @property + def conf_image_cache_path(self): + """Path to image cache.""" + # TODO: improve: + return tempfile.gettempdir() + def startService(self): """Start the Service.""" super(TeethClient, self).startService() diff --git a/teeth_agent/protocol.py b/teeth_agent/protocol.py index 14cce06aa..52f06c54b 100644 --- a/teeth_agent/protocol.py +++ b/teeth_agent/protocol.py @@ -73,6 +73,17 @@ class RPCError(RPCMessage, RuntimeError): self._raw_message = message +class ImageInfo(object): + """ + Metadata about a machine image. + """ + def __init__(self, image_id, image_urls, image_hashes): + super(ImageInfo, self).__init__() + self.id = image_id + self.urls = image_urls + self.hashes = image_hashes + + class CommandValidationError(RuntimeError): """ Exception class which can be used to return an error when the diff --git a/teeth_agent/task.py b/teeth_agent/task.py index 685c0b3a8..08793c908 100644 --- a/teeth_agent/task.py +++ b/teeth_agent/task.py @@ -14,82 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. 
""" -from twisted.application.service import MultiService -from twisted.application.internet import TimerService -from teeth_agent.logging import get_logger +import os + +from teeth_agent.base_task import MultiTask, BaseTask +from teeth_agent.cache_image import ImageDownloaderTask -__all__ = ['Task', 'PrepareImageTask'] +__all__ = ['CacheImagesTask', 'PrepareImageTask'] -class Task(MultiService, object): - """ - Task to execute, reporting status periodically to TeethClient instance. - """ +class CacheImagesTask(MultiTask): - task_name = 'task_undefined' + """Cache an array of images on a machine.""" - def __init__(self, client, task_id, reporting_interval=10): - super(Task, self).__init__() - self.setName(self.task_name) - self._client = client - self._id = task_id - self._percent = 0 - self._reporting_interval = reporting_interval - self._state = 'starting' - self._timer = TimerService(self._reporting_interval, self._tick) - self._timer.setServiceParent(self) - self._error_msg = None - self.log = get_logger(task_id=task_id, task_name=self.task_name) + task_name = 'cache_images' - def _run(self): - """Do the actual work here.""" - - def run(self): - """Run the Task.""" - # setServiceParent actually starts the task if it is already running - # so we run it in start. 
- self.setServiceParent(self._client) - self._run() - - def _tick(self): - if not self.running: - # log.debug("_tick called while not running :()") - return - return self._client.update_task_status(self) - - def error(self, message): - """Error out running of the task.""" - self._error_msg = message - self._state = 'error' - self.stopService() - - def complete(self): - """Complete running of the task.""" - self._state = 'complete' - self.stopService() - - def startService(self): - """Start the Service.""" - super(Task, self).startService() - self._state = 'running' - - def stopService(self): - """Stop the Service.""" - super(Task, self).stopService() - - if not self._client.running: - return - - if self._state not in ['error', 'complete']: - self.log.err("told to shutdown before task could complete, marking as error.") - self._error_msg = 'service being shutdown' - self._state = 'error' - - self._client.finish_task(self) + def __init__(self, client, task_id, images, reporting_interval=10): + super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval) + self._images = images + for image in self._images: + image_path = os.path.join(client.conf_image_cache_path, image.id + '.img') + t = ImageDownloaderTask(client, + task_id, image, + image_path, + reporting_interval=reporting_interval) + self.add_task(t) -class PrepareImageTask(Task): +class PrepareImageTask(BaseTask): """Prepare an image to be ran on the machine.""" diff --git a/teeth_agent/tests/test_task.py b/teeth_agent/tests/test_task.py index d0acc1527..896c876d7 100644 --- a/teeth_agent/tests/test_task.py +++ b/teeth_agent/tests/test_task.py @@ -15,25 +15,37 @@ limitations under the License. 
""" import uuid +import shutil +import tempfile +import hashlib +import os +from mock import Mock, patch + +from twisted.internet import defer from twisted.trial import unittest -from teeth_agent.task import Task -from mock import Mock +from twisted.web.client import ResponseDone +from twisted.python.failure import Failure + +from teeth_agent.protocol import ImageInfo +from teeth_agent.base_task import BaseTask, MultiTask +from teeth_agent.cache_image import ImageDownloaderTask class FakeClient(object): - addService = Mock(return_value=None) - running = Mock(return_value=0) - update_task_status = Mock(return_value=None) - finish_task = Mock(return_value=None) + def __init__(self): + self.addService = Mock(return_value=None) + self.running = Mock(return_value=0) + self.update_task_status = Mock(return_value=None) + self.finish_task = Mock(return_value=None) -class TestTask(Task): +class TestTask(BaseTask): task_name = 'test_task' class TaskTest(unittest.TestCase): - """Event Emitter tests.""" + """Basic tests of the Task API.""" def setUp(self): self.task_id = str(uuid.uuid4()) @@ -45,6 +57,15 @@ class TaskTest(unittest.TestCase): del self.task del self.client + def test_error(self): + self.task.run() + self.client.addService.assert_called_once_with(self.task) + self.task.startService() + self.client.update_task_status.assert_called_once_with(self.task) + self.task.error('chaos monkey attack') + self.assertEqual(self.task._state, 'error') + self.client.finish_task.assert_called_once_with(self.task) + def test_run(self): self.assertEqual(self.task._state, 'starting') self.assertEqual(self.task._id, self.task_id) @@ -55,3 +76,118 @@ class TaskTest(unittest.TestCase): self.task.complete() self.assertEqual(self.task._state, 'complete') self.client.finish_task.assert_called_once_with(self.task) + + def test_fast_shutdown(self): + self.task.run() + self.client.addService.assert_called_once_with(self.task) + self.task.startService() + 
self.client.update_task_status.assert_called_once_with(self.task) + self.task.stopService() + self.assertEqual(self.task._state, 'error') + self.client.finish_task.assert_called_once_with(self.task) + + +class MultiTestTask(MultiTask): + task_name = 'test_multitask' + + +class MultiTaskTest(unittest.TestCase): + """Basic tests of the Multi Task API.""" + + def setUp(self): + self.task_id = str(uuid.uuid4()) + self.client = FakeClient() + self.task = MultiTestTask(self.client, self.task_id) + + def tearDown(self): + del self.task_id + del self.task + del self.client + + def test_tasks(self): + t = TestTask(self.client, self.task_id) + self.task.add_task(t) + self.assertEqual(self.task._state, 'starting') + self.assertEqual(self.task._id, self.task_id) + self.task.run() + self.client.addService.assert_called_once_with(self.task) + self.task.startService() + self.client.update_task_status.assert_any_call(self.task) + t.complete() + self.assertEqual(self.task._state, 'complete') + self.client.finish_task.assert_any_call(t) + self.client.finish_task.assert_any_call(self.task) + + +class StubResponse(object): + def __init__(self, code, headers, body): + self.version = ('HTTP', 1, 1) + self.code = code + self.status = "ima teapot" + self.headers = headers + self.body = body + self.length = reduce(lambda x, y: x + len(y), body, 0) + self.protocol = None + + def deliverBody(self, protocol): + self.protocol = protocol + + def run(self): + self.protocol.connectionMade() + + for data in self.body: + self.protocol.dataReceived(data) + + self.protocol.connectionLost(Failure(ResponseDone("Response body fully received"))) + + +class ImageDownloaderTaskTest(unittest.TestCase): + + def setUp(self): + get_patcher = patch('treq.get', autospec=True) + self.TreqGet = get_patcher.start() + self.addCleanup(get_patcher.stop) + + self.tmpdir = tempfile.mkdtemp('image_download_test') + self.task_id = str(uuid.uuid4()) + self.image_data = str(uuid.uuid4()) + self.image_md5 = 
hashlib.md5(self.image_data).hexdigest() + self.cache_path = os.path.join(self.tmpdir, 'a1234.img') + self.client = FakeClient() + self.image_info = ImageInfo('a1234', + ['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5}) + self.task = ImageDownloaderTask(self.client, + self.task_id, + self.image_info, + self.cache_path) + + def tearDown(self): + shutil.rmtree(self.tmpdir) + + def assertFileHash(self, hash_type, path, value): + file_hash = hashlib.new(hash_type) + with open(path, 'rb') as fp: + file_hash.update(fp.read()) + self.assertEqual(value, file_hash.hexdigest()) + + def test_download_success(self): + resp = StubResponse(200, [], [self.image_data]) + d = defer.Deferred() + self.TreqGet.return_value = d + self.task.run() + self.client.addService.assert_called_once_with(self.task) + + self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img') + + self.task.startService() + + d.callback(resp) + + resp.run() + + self.client.update_task_status.assert_called_once_with(self.task) + self.assertFileHash('md5', self.cache_path, self.image_md5) + + self.task.stopService() + self.assertEqual(self.task._state, 'error') + self.client.finish_task.assert_called_once_with(self.task)