add image download tests

This commit is contained in:
Paul Querna 2013-11-07 20:00:01 +00:00
parent 532ad9b457
commit 448a7da95b
5 changed files with 107 additions and 4 deletions

View File

@ -39,10 +39,10 @@ class ImageDownloaderTask(BaseTask):
def _tick(self):
# TODO: get file download percentages.
self.percent = 0
super(BaseTask, self)._tick()
super(ImageDownloaderTask, self)._tick()
def _download_image_to_file(self, url):
destination = file(self._destination_filename, 'w')
destination = open(self._destination_filename, 'wb')
def push(data):
if self.running:

View File

@ -17,6 +17,7 @@ limitations under the License.
import time
import json
import random
import tempfile
from twisted.application.service import MultiService
from twisted.application.internet import TCPClient
@ -87,6 +88,12 @@ class TeethClient(MultiService, object):
}
}
@property
def conf_image_cache_path(self):
"""Path to iamge cache."""
# TODO: improve:
return tempfile.gettempdir()
def startService(self):
"""Start the Service."""
super(TeethClient, self).startService()

View File

@ -73,6 +73,17 @@ class RPCError(RPCMessage, RuntimeError):
self._raw_message = message
class ImageInfo(object):
"""
Metadata about a machine image.
"""
def __init__(self, image_id, image_urls, image_hashes):
super(ImageInfo, self).__init__()
self.id = image_id
self.urls = image_urls
self.hashes = image_hashes
class CommandValidationError(RuntimeError):
"""
Exception class which can be used to return an error when the

View File

@ -33,7 +33,7 @@ class CacheImagesTask(MultiTask):
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
self._images = images
for image in self._images:
image_path = os.path.join(client.get_cache_path(), image.id + '.img')
image_path = os.path.join(client.conf_image_cache_path, image.id + '.img')
t = ImageDownloaderTask(client,
task_id, image,
image_path,

View File

@ -15,10 +15,21 @@ limitations under the License.
"""
import uuid
import shutil
import tempfile
import hashlib
import os
from mock import Mock, patch
from twisted.internet import defer
from twisted.trial import unittest
from twisted.web.client import ResponseDone
from twisted.python.failure import Failure
from teeth_agent.protocol import ImageInfo
from teeth_agent.base_task import BaseTask, MultiTask
from mock import Mock
from teeth_agent.cache_image import ImageDownloaderTask
class FakeClient(object):
@ -106,3 +117,77 @@ class MultiTaskTest(unittest.TestCase):
self.assertEqual(self.task._state, 'complete')
self.client.finish_task.assert_any_call(t)
self.client.finish_task.assert_any_call(self.task)
class StubResponse(object):
def __init__(self, code, headers, body):
self.version = ('HTTP', 1, 1)
self.code = code
self.status = "ima teapot"
self.headers = headers
self.body = body
self.length = reduce(lambda x, y: x + len(y), body, 0)
self.protocol = None
def deliverBody(self, protocol):
self.protocol = protocol
def run(self):
self.protocol.connectionMade()
for data in self.body:
self.protocol.dataReceived(data)
self.protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
class ImageDownloaderTaskTest(unittest.TestCase):
def setUp(self):
get_patcher = patch('treq.get', autospec=True)
self.TreqGet = get_patcher.start()
self.addCleanup(get_patcher.stop)
self.tmpdir = tempfile.mkdtemp('image_download_test')
self.task_id = str(uuid.uuid4())
self.image_data = str(uuid.uuid4())
self.image_md5 = hashlib.md5(self.image_data).hexdigest()
self.cache_path = os.path.join(self.tmpdir, 'a1234.img')
self.client = FakeClient()
self.image_info = ImageInfo('a1234',
['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5})
self.task = ImageDownloaderTask(self.client,
self.task_id,
self.image_info,
self.cache_path)
def tearDown(self):
shutil.rmtree(self.tmpdir)
def assertFileHash(self, hash_type, path, value):
file_hash = hashlib.new(hash_type)
with open(path, 'r') as fp:
file_hash.update(fp.read())
self.assertEqual(value, file_hash.hexdigest())
def test_download_success(self):
resp = StubResponse(200, [], [self.image_data])
d = defer.Deferred()
self.TreqGet.return_value = d
self.task.run()
self.client.addService.assert_called_once_with(self.task)
self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img')
self.task.startService()
d.callback(resp)
resp.run()
self.client.update_task_status.assert_called_once_with(self.task)
self.assertFileHash('md5', self.cache_path, self.image_md5)
self.task.stopService()
self.assertEqual(self.task._state, 'error')
self.client.finish_task.assert_called_once_with(self.task)