diff --git a/trove/guestagent/backup/backupagent.py b/trove/guestagent/backup/backupagent.py index 46449dfcf9..7420f03782 100644 --- a/trove/guestagent/backup/backupagent.py +++ b/trove/guestagent/backup/backupagent.py @@ -119,17 +119,12 @@ class BackupAgent(object): restore_runner = self._get_restore_runner(backup_info['type']) LOG.debug("Getting Storage Strategy") - storage_strategy = get_storage_strategy( + storage = get_storage_strategy( CONF.storage_strategy, CONF.storage_namespace)(context) - LOG.debug("Preparing storage to download stream.") - download_stream = storage_strategy.load(context, - backup_info['location'], - restore_runner.is_zipped, - backup_info['checksum']) - - with restore_runner(restore_stream=download_stream, + with restore_runner(storage, location=backup_info['location'], + checksum=backup_info['checksum'], restore_location=restore_location) as runner: LOG.debug("Restoring instance from backup %s to %s", backup_info['id'], restore_location) diff --git a/trove/guestagent/strategies/restore/base.py b/trove/guestagent/strategies/restore/base.py index 782517dc4e..06fa2d7eea 100644 --- a/trove/guestagent/strategies/restore/base.py +++ b/trove/guestagent/strategies/restore/base.py @@ -59,8 +59,10 @@ class RestoreRunner(Strategy): is_encrypted = BACKUP_USE_OPENSSL decrypt_key = BACKUP_DECRYPT_KEY - def __init__(self, restore_stream, **kwargs): - self.restore_stream = restore_stream + def __init__(self, storage, **kwargs): + self.storage = storage + self.location = kwargs.pop('location') + self.checksum = kwargs.pop('checksum') self.restore_location = kwargs.get('restore_location', '/var/lib/mysql') self.restore_cmd = (self.decrypt_cmd + @@ -102,20 +104,17 @@ class RestoreRunner(Strategy): return content_length def _run_restore(self): - with self.restore_stream as stream: - self.process = subprocess.Popen(self.restore_cmd, shell=True, - stdin=subprocess.PIPE, - stderr=subprocess.PIPE) - self.pid = self.process.pid - content_length = 0 - chunk = stream.read(CHUNK_SIZE) - while chunk: - self.process.stdin.write(chunk) - content_length += len(chunk) - chunk = stream.read(CHUNK_SIZE) - self.process.stdin.close() - LOG.info("Restored %s bytes from swift via xbstream." - % content_length) + stream = self.storage.load(self.location, self.checksum) + self.process = subprocess.Popen(self.restore_cmd, shell=True, + stdin=subprocess.PIPE, + stderr=subprocess.PIPE) + self.pid = self.process.pid + content_length = 0 + for chunk in stream: + self.process.stdin.write(chunk) + content_length += len(chunk) + self.process.stdin.close() + LOG.info("Restored %s bytes from stream." % content_length) return content_length diff --git a/trove/guestagent/strategies/storage/base.py b/trove/guestagent/strategies/storage/base.py index 2a91d7639f..3f3646dc60 100644 --- a/trove/guestagent/strategies/storage/base.py +++ b/trove/guestagent/strategies/storage/base.py @@ -32,5 +32,5 @@ class Storage(Strategy): """Persist information from the stream """ @abc.abstractmethod - def load(self, context, location, is_zipped, backup_checksum): + def load(self, location, backup_checksum): """Load a stream from a persisted storage location """ diff --git a/trove/guestagent/strategies/storage/swift.py b/trove/guestagent/strategies/storage/swift.py index 83cb7b3f61..0f20416f72 100644 --- a/trove/guestagent/strategies/storage/swift.py +++ b/trove/guestagent/strategies/storage/swift.py @@ -20,8 +20,6 @@ from trove.guestagent.strategies.storage import base from trove.openstack.common import log as logging from trove.common.remote import create_swift_client from trove.common import cfg -from trove.common import utils -from eventlet.green import subprocess LOG = logging.getLogger(__name__) CONF = cfg.CONF @@ -173,74 +171,24 @@ class SwiftStorage(base.Storage): filename = location.split('/')[-1] return storage_url, container, filename - def load(self, context, location, is_zipped, backup_checksum): - """Restore a backup from the input stream to the restore_location """ + def _verify_checksum(self, etag, checksum): + etag_checksum = etag.strip('"') + if etag_checksum != checksum: + msg = ("Original checksum: %(original)s does not match" + " the current checksum: %(current)s" % + {'original': etag_checksum, 'current': checksum}) + LOG.error(msg) + raise SwiftDownloadIntegrityError(msg) + return True + def load(self, location, backup_checksum): + """Restore a backup from the input stream to the restore_location""" storage_url, container, filename = self._explodeLocation(location) - return SwiftDownloadStream(context, - auth_token=context.auth_token, - storage_url=storage_url, - container=container, - filename=filename, - is_zipped=is_zipped, - backup_checksum=backup_checksum) + headers, info = self.connection.get_object(container, filename, + resp_chunk_size=CHUNK_SIZE) - -class SwiftDownloadStream(object): - """Class to do the actual swift download using the swiftclient """ - - cmd = ("swift --os-auth-token=%(auth_token)s " - "--os-storage-url=%(storage_url)s " - "download %(container)s %(filename)s -o -") - - def __init__(self, context, **kwargs): - self.process = None - self.pid = None - self.cmd = self.cmd % kwargs - self.container = kwargs.get('container') - self.filename = kwargs.get('filename') - self.original_backup_checksum = kwargs.get('backup_checksum', None) - self.swift_client = create_swift_client(context) - - def __enter__(self): - """Start up the process""" - self.run() - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Clean up everything.""" - if exc_type is None: - utils.raise_if_process_errored(self.process, DownloadError) - - # Make sure to terminate the process - try: - self.process.terminate() - except OSError: - # Already stopped - pass - - def read(self, *args, **kwargs): - return self.process.stdout.read(*args, **kwargs) - - def run(self): if CONF.verify_swift_checksum_on_restore: - # Right before downloading swift object lets check that the current - # swift object checksum matches the original backup checksum - self._verify_checksum() - self._run_download_cmd() + self._verify_checksum(headers.get('etag', ''), backup_checksum) - def _verify_checksum(self): - if self.original_backup_checksum: - resp = self.swift_client.head_object(self.container, self.filename) - current_swift_checksum = resp['etag'].strip('"') - if current_swift_checksum != self.original_backup_checksum: - raise SwiftDownloadIntegrityError("Original backup checksum " - "does not match current " - "checksum.") - - def _run_download_cmd(self): - self.process = subprocess.Popen(self.cmd, shell=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE) - self.pid = self.process.pid + return info diff --git a/trove/tests/fakes/swift.py b/trove/tests/fakes/swift.py index 71049d849e..29fd84ab61 100644 --- a/trove/tests/fakes/swift.py +++ b/trove/tests/fakes/swift.py @@ -44,8 +44,13 @@ class FakeSwiftClient(object): class FakeSwiftConnection(object): """Logging calls instead of executing""" + MANIFEST_HEADER_KEY = 'X-Object-Manifest' + url = 'http://mockswift/v1' + def __init__(self, *args, **kwargs): - pass + self.manifest_prefix = None + self.manifest_name = None + self.container_objects = {} def get_auth(self): return ( @@ -94,9 +99,26 @@ class FakeSwiftConnection(object): def head_object(self, container, name): LOG.debug("fake put_container(%s, %s)" % (container, name)) - return {'etag': 'fake-md5-sum'} + checksum = md5() + if self.manifest_prefix and self.manifest_name == name: + for object_name in sorted(self.container_objects.iterkeys()): + object_checksum = md5(self.container_objects[object_name]) + # The manifest file etag for a HEAD or GET is the checksum of + # the concatenated checksums. + checksum.update(object_checksum.hexdigest()) + # this is included to test bad swift segment etags + if name.startswith("bad_manifest_etag_"): + return {'etag': '"this_is_an_intentional_bad_manifest_etag"'} + else: + if name in self.container_objects: + checksum.update(self.container_objects[name]) + else: + return {'etag': 'fake-md5-sum'} - def get_object(self, container, name): + # Currently a swift HEAD object returns etag with double quotes + return {'etag': '"%s"' % checksum.hexdigest()} + + def get_object(self, container, name, resp_chunk_size=None): LOG.debug("fake get_object(%s, %s)" % (container, name)) if container == 'socket_error_on_get': raise socket.error(111, 'ECONNREFUSED') @@ -121,62 +143,22 @@ class FakeSwiftConnection(object): fake_object_body = metadata_json return (fake_object_header, fake_object_body) - fake_header = None - fake_object_body = os.urandom(1024 * 1024) + fake_header = {'etag': '"fake-md5-sum"'} + if resp_chunk_size: + def _object_info(): + length = 0 + while length < (1024 * 1024): + yield os.urandom(resp_chunk_size) + length += resp_chunk_size + fake_object_body = _object_info() + else: + fake_object_body = os.urandom(1024 * 1024) return (fake_header, fake_object_body) - def put_object(self, container, name, reader): + def put_object(self, container, name, contents, **kwargs): LOG.debug("fake put_object(%s, %s)" % (container, name)) if container == 'socket_error_on_put': raise socket.error(111, 'ECONNREFUSED') - return 'fake-md5-sum' - - def delete_object(self, container, name): - LOG.debug("fake delete_object(%s, %s)" % (container, name)) - if container == 'socket_error_on_delete': - raise socket.error(111, 'ECONNREFUSED') - pass - - -class FakeSwiftConnectionWithRealEtag(FakeSwiftConnection): - """ - Overides methods that deal with object etags/checksums so it returns - the actual object etag/checksum - - This fake swift client is meant to only handle at most one large segmented - object. - """ - - MANIFEST_HEADER_KEY = 'X-Object-Manifest' - url = 'http://mockswift/v1' - - def __init__(self, *args, **kwargs): - super(FakeSwiftConnectionWithRealEtag, self).__init__(args, kwargs) - self.manifest_prefix = None - self.manifest_name = None - self.container_objects = {} - - def head_object(self, container, name): - checksum = md5() - if self.manifest_prefix and self.manifest_name == name: - for object_name in sorted(self.container_objects.iterkeys()): - object_checksum = md5(self.container_objects[object_name]) - # The manifest file etag for a HEAD or GET is the checksum of - # the concatenated checksums. - checksum.update(object_checksum.hexdigest()) - # this is included to test bad swift segment etags - if name.startswith("bad_manifest_etag_"): - return {'etag': '"this_is_an_intentional_bad_manifest_etag"'} - else: - if name in self.container_objects: - checksum.update(self.container_objects[name]) - else: - return {'etag': ""} - - # Currently a swift HEAD object returns etag with double quotes - return {'etag': '"%s"' % checksum.hexdigest()} - - def put_object(self, container, name, contents, **kwargs): headers = kwargs.get('headers', {}) object_checksum = md5() if self.MANIFEST_HEADER_KEY in headers: @@ -206,6 +188,12 @@ class FakeSwiftConnectionWithRealEtag(FakeSwiftConnection): return "this_is_an_intentional_bad_segment_etag" return object_checksum.hexdigest() + def delete_object(self, container, name): + LOG.debug("fake delete_object(%s, %s)" % (container, name)) + if container == 'socket_error_on_delete': + raise socket.error(111, 'ECONNREFUSED') + pass + class SwiftClientStub(object): """ diff --git a/trove/tests/unittests/backup/test_backupagent.py b/trove/tests/unittests/backup/test_backupagent.py index 3a2bce5b66..ccf15d09c7 100644 --- a/trove/tests/unittests/backup/test_backupagent.py +++ b/trove/tests/unittests/backup/test_backupagent.py @@ -112,7 +112,7 @@ class MockStorage(Storage): def __call__(self, *args, **kwargs): return self - def load(self, context, location, is_zipped, backup_checksum): + def load(self, location, backup_checksum): pass def save(self, filename, stream): @@ -123,7 +123,7 @@ class MockStorage(Storage): class MockRestoreRunner(RestoreRunner): - def __init__(self, restore_stream, restore_location): + def __init__(self, storage, **kwargs): pass def __enter__(self): diff --git a/trove/tests/unittests/backup/test_storage.py b/trove/tests/unittests/backup/test_storage.py index cf750e0b6f..2ab2270bb3 100644 --- a/trove/tests/unittests/backup/test_storage.py +++ b/trove/tests/unittests/backup/test_storage.py @@ -13,12 +13,11 @@ #limitations under the License. import testtools -from mockito import when, unstub, mock, any +from mockito import when, unstub import hashlib from trove.common.context import TroveContext from trove.tests.fakes.swift import FakeSwiftConnection -from trove.tests.fakes.swift import FakeSwiftConnectionWithRealEtag from trove.tests.unittests.backup.test_backupagent \ import MockBackup as MockBackupRunner from trove.guestagent.strategies.storage.swift \ @@ -28,17 +27,6 @@ from trove.guestagent.strategies.storage.swift import SwiftStorage from trove.guestagent.strategies.storage.swift import StreamReader -class MockProcess(object): - """Fake swift download process""" - - def __init__(self): - self.pid = 1 - self.stdout = "Mock Process stdout." - - def terminate(self): - pass - - class SwiftStorageSaveChecksumTests(testtools.TestCase): """SwiftStorage.save is used to save a backup to Swift""" @@ -56,7 +44,7 @@ class SwiftStorageSaveChecksumTests(testtools.TestCase): user = 'user' password = 'password' - swift_client = FakeSwiftConnectionWithRealEtag() + swift_client = FakeSwiftConnection() when(swift).create_swift_client(context).thenReturn(swift_client) storage_strategy = SwiftStorage(context) @@ -85,7 +73,7 @@ class SwiftStorageSaveChecksumTests(testtools.TestCase): user = 'user' password = 'password' - swift_client = FakeSwiftConnectionWithRealEtag() + swift_client = FakeSwiftConnection() when(swift).create_swift_client(context).thenReturn(swift_client) storage_strategy = SwiftStorage(context) @@ -117,7 +105,7 @@ class SwiftStorageSaveChecksumTests(testtools.TestCase): user = 'user' password = 'password' - swift_client = FakeSwiftConnectionWithRealEtag() + swift_client = FakeSwiftConnection() when(swift).create_swift_client(context).thenReturn(swift_client) storage_strategy = SwiftStorage(context) @@ -139,6 +127,36 @@ class SwiftStorageSaveChecksumTests(testtools.TestCase): "Incorrect swift location was returned.") +class SwiftStorageUtils(testtools.TestCase): + + def setUp(self): + super(SwiftStorageUtils, self).setUp() + context = TroveContext() + swift_client = FakeSwiftConnection() + when(swift).create_swift_client(context).thenReturn(swift_client) + self.swift = SwiftStorage(context) + + def tearDown(self): + super(SwiftStorageUtils, self).tearDown() + + def test_explode_location(self): + location = 'http://mockswift.com/v1/545433/backups/mybackup.tar' + url, container, filename = self.swift._explodeLocation(location) + self.assertEqual(url, 'http://mockswift.com/v1/545433') + self.assertEqual(container, 'backups') + self.assertEqual(filename, 'mybackup.tar') + + def test_validate_checksum_good(self): + match = self.swift._verify_checksum('"my-good-etag"', 'my-good-etag') + self.assertTrue(match) + + def test_verify_checksum_bad(self): + self.assertRaises(SwiftDownloadIntegrityError, + self.swift._verify_checksum, + '"THE-GOOD-THE-BAD"', + 'AND-THE-UGLY') + + class SwiftStorageLoad(testtools.TestCase): """SwiftStorage.load is used to return SwiftDownloadStream which is used to download a backup object from Swift @@ -158,35 +176,14 @@ class SwiftStorageLoad(testtools.TestCase): context = TroveContext() location = "/backup/location/123" - is_zipped = False backup_checksum = "fake-md5-sum" swift_client = FakeSwiftConnection() when(swift).create_swift_client(context).thenReturn(swift_client) - download_process = MockProcess() - subprocess = mock(swift.subprocess) - when(subprocess).Popen(any(), any(), - any(), any()).thenReturn(download_process) - when(swift.utils).raise_if_process_errored().thenReturn(None) storage_strategy = SwiftStorage(context) - download_stream = storage_strategy.load(context, - location, - is_zipped, - backup_checksum) - - self.assertEqual('location', download_stream.container) - self.assertEqual('123', download_stream.filename) - - with download_stream as stream: - print("Testing SwiftDownloadStream context manager: %s" % stream) - - self.assertIsNotNone(download_stream.process, - "SwiftDownloadStream process/cmd is supposed " - "to run.") - self.assertIsNotNone(download_stream.pid, - "SwiftDownloadStream process/cmd is supposed " - "to run.") + download_stream = storage_strategy.load(location, backup_checksum) + self.assertIsNotNone(download_stream) def test_run_verify_checksum_mismatch(self): """This tests that SwiftDownloadIntegrityError is raised and swift @@ -196,27 +193,17 @@ class SwiftStorageLoad(testtools.TestCase): context = TroveContext() location = "/backup/location/123" - is_zipped = False backup_checksum = "checksum_different_then_fake_swift_etag" swift_client = FakeSwiftConnection() when(swift).create_swift_client(context).thenReturn(swift_client) storage_strategy = SwiftStorage(context) - download_stream = storage_strategy.load(context, - location, - is_zipped, - backup_checksum) - - self.assertEqual('location', download_stream.container) - self.assertEqual('123', download_stream.filename) self.assertRaises(SwiftDownloadIntegrityError, - download_stream.__enter__) - - self.assertIsNone(download_stream.process, - "SwiftDownloadStream process/cmd was not supposed" - "to run.") + storage_strategy.load, + location, + backup_checksum) class MockBackupStream(MockBackupRunner): diff --git a/trove/tests/unittests/guestagent/test_backups.py b/trove/tests/unittests/guestagent/test_backups.py index 0b0a8afeb0..124ae7c175 100644 --- a/trove/tests/unittests/guestagent/test_backups.py +++ b/trove/tests/unittests/guestagent/test_backups.py @@ -112,7 +112,8 @@ class GuestAgentBackupTest(testtools.TestCase): restoreBase.RestoreRunner.is_zipped = True restoreBase.RestoreRunner.is_encrypted = False RunnerClass = utils.import_class(RESTORE_XTRA_CLS) - restr = RunnerClass(None, restore_location="/var/lib/mysql") + restr = RunnerClass(None, restore_location="/var/lib/mysql", + location="filename", checksum="md5") self.assertEqual(restr.restore_cmd, UNZIP + PIPE + XTRA_RESTORE) self.assertEqual(restr.prepare_cmd, PREPARE) @@ -121,7 +122,8 @@ class GuestAgentBackupTest(testtools.TestCase): restoreBase.RestoreRunner.is_encrypted = True restoreBase.RestoreRunner.decrypt_key = CRYPTO_KEY RunnerClass = utils.import_class(RESTORE_XTRA_CLS) - restr = RunnerClass(None, restore_location="/var/lib/mysql") + restr = RunnerClass(None, restore_location="/var/lib/mysql", + location="filename", checksum="md5") self.assertEqual(restr.restore_cmd, DECRYPT + PIPE + UNZIP + PIPE + XTRA_RESTORE) self.assertEqual(restr.prepare_cmd, PREPARE) @@ -131,6 +133,7 @@ class GuestAgentBackupTest(testtools.TestCase): restoreBase.RestoreRunner.is_encrypted = False RunnerClass = utils.import_class(RESTORE_SQLDUMP_CLS) restr = RunnerClass(None, restore_location="/var/lib/mysql", + location="filename", checksum="md5", user="user", password="password") self.assertEqual(restr.restore_cmd, UNZIP + PIPE + SQLDUMP_RESTORE) @@ -140,6 +143,7 @@ class GuestAgentBackupTest(testtools.TestCase): restoreBase.RestoreRunner.decrypt_key = CRYPTO_KEY RunnerClass = utils.import_class(RESTORE_SQLDUMP_CLS) restr = RunnerClass(None, restore_location="/var/lib/mysql", + location="filename", checksum="md5", user="user", password="password") self.assertEqual(restr.restore_cmd, DECRYPT + PIPE + UNZIP + PIPE + SQLDUMP_RESTORE)