add image download tests
This commit is contained in:
parent
532ad9b457
commit
448a7da95b
|
@ -39,10 +39,10 @@ class ImageDownloaderTask(BaseTask):
|
|||
def _tick(self):
|
||||
# TODO: get file download percentages.
|
||||
self.percent = 0
|
||||
super(BaseTask, self)._tick()
|
||||
super(ImageDownloaderTask, self)._tick()
|
||||
|
||||
def _download_image_to_file(self, url):
|
||||
destination = file(self._destination_filename, 'w')
|
||||
destination = open(self._destination_filename, 'wb')
|
||||
|
||||
def push(data):
|
||||
if self.running:
|
||||
|
|
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||
import time
|
||||
import json
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
from twisted.application.service import MultiService
|
||||
from twisted.application.internet import TCPClient
|
||||
|
@ -87,6 +88,12 @@ class TeethClient(MultiService, object):
|
|||
}
|
||||
}
|
||||
|
||||
@property
|
||||
def conf_image_cache_path(self):
|
||||
"""Path to iamge cache."""
|
||||
# TODO: improve:
|
||||
return tempfile.gettempdir()
|
||||
|
||||
def startService(self):
|
||||
"""Start the Service."""
|
||||
super(TeethClient, self).startService()
|
||||
|
|
|
@ -73,6 +73,17 @@ class RPCError(RPCMessage, RuntimeError):
|
|||
self._raw_message = message
|
||||
|
||||
|
||||
class ImageInfo(object):
|
||||
"""
|
||||
Metadata about a machine image.
|
||||
"""
|
||||
def __init__(self, image_id, image_urls, image_hashes):
|
||||
super(ImageInfo, self).__init__()
|
||||
self.id = image_id
|
||||
self.urls = image_urls
|
||||
self.hashes = image_hashes
|
||||
|
||||
|
||||
class CommandValidationError(RuntimeError):
|
||||
"""
|
||||
Exception class which can be used to return an error when the
|
||||
|
|
|
@ -33,7 +33,7 @@ class CacheImagesTask(MultiTask):
|
|||
super(CacheImagesTask, self).__init__(client, task_id, reporting_interval=reporting_interval)
|
||||
self._images = images
|
||||
for image in self._images:
|
||||
image_path = os.path.join(client.get_cache_path(), image.id + '.img')
|
||||
image_path = os.path.join(client.conf_image_cache_path, image.id + '.img')
|
||||
t = ImageDownloaderTask(client,
|
||||
task_id, image,
|
||||
image_path,
|
||||
|
|
|
@ -15,10 +15,21 @@ limitations under the License.
|
|||
"""
|
||||
|
||||
import uuid
|
||||
import shutil
|
||||
import tempfile
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
from mock import Mock, patch
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.trial import unittest
|
||||
from twisted.web.client import ResponseDone
|
||||
from twisted.python.failure import Failure
|
||||
|
||||
from teeth_agent.protocol import ImageInfo
|
||||
from teeth_agent.base_task import BaseTask, MultiTask
|
||||
from mock import Mock
|
||||
from teeth_agent.cache_image import ImageDownloaderTask
|
||||
|
||||
|
||||
class FakeClient(object):
|
||||
|
@ -106,3 +117,77 @@ class MultiTaskTest(unittest.TestCase):
|
|||
self.assertEqual(self.task._state, 'complete')
|
||||
self.client.finish_task.assert_any_call(t)
|
||||
self.client.finish_task.assert_any_call(self.task)
|
||||
|
||||
|
||||
class StubResponse(object):
|
||||
def __init__(self, code, headers, body):
|
||||
self.version = ('HTTP', 1, 1)
|
||||
self.code = code
|
||||
self.status = "ima teapot"
|
||||
self.headers = headers
|
||||
self.body = body
|
||||
self.length = reduce(lambda x, y: x + len(y), body, 0)
|
||||
self.protocol = None
|
||||
|
||||
def deliverBody(self, protocol):
|
||||
self.protocol = protocol
|
||||
|
||||
def run(self):
|
||||
self.protocol.connectionMade()
|
||||
|
||||
for data in self.body:
|
||||
self.protocol.dataReceived(data)
|
||||
|
||||
self.protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
|
||||
|
||||
|
||||
class ImageDownloaderTaskTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
get_patcher = patch('treq.get', autospec=True)
|
||||
self.TreqGet = get_patcher.start()
|
||||
self.addCleanup(get_patcher.stop)
|
||||
|
||||
self.tmpdir = tempfile.mkdtemp('image_download_test')
|
||||
self.task_id = str(uuid.uuid4())
|
||||
self.image_data = str(uuid.uuid4())
|
||||
self.image_md5 = hashlib.md5(self.image_data).hexdigest()
|
||||
self.cache_path = os.path.join(self.tmpdir, 'a1234.img')
|
||||
self.client = FakeClient()
|
||||
self.image_info = ImageInfo('a1234',
|
||||
['http://127.0.0.1/images/a1234.img'], {'md5': self.image_md5})
|
||||
self.task = ImageDownloaderTask(self.client,
|
||||
self.task_id,
|
||||
self.image_info,
|
||||
self.cache_path)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmpdir)
|
||||
|
||||
def assertFileHash(self, hash_type, path, value):
|
||||
file_hash = hashlib.new(hash_type)
|
||||
with open(path, 'r') as fp:
|
||||
file_hash.update(fp.read())
|
||||
self.assertEqual(value, file_hash.hexdigest())
|
||||
|
||||
def test_download_success(self):
|
||||
resp = StubResponse(200, [], [self.image_data])
|
||||
d = defer.Deferred()
|
||||
self.TreqGet.return_value = d
|
||||
self.task.run()
|
||||
self.client.addService.assert_called_once_with(self.task)
|
||||
|
||||
self.TreqGet.assert_called_once_with('http://127.0.0.1/images/a1234.img')
|
||||
|
||||
self.task.startService()
|
||||
|
||||
d.callback(resp)
|
||||
|
||||
resp.run()
|
||||
|
||||
self.client.update_task_status.assert_called_once_with(self.task)
|
||||
self.assertFileHash('md5', self.cache_path, self.image_md5)
|
||||
|
||||
self.task.stopService()
|
||||
self.assertEqual(self.task._state, 'error')
|
||||
self.client.finish_task.assert_called_once_with(self.task)
|
||||
|
|
Loading…
Reference in New Issue