diff --git a/cloudbaseinit/metadata/services/base.py b/cloudbaseinit/metadata/services/base.py index a1e0fc49..bda2317a 100644 --- a/cloudbaseinit/metadata/services/base.py +++ b/cloudbaseinit/metadata/services/base.py @@ -15,6 +15,8 @@ import abc import collections +import gzip +import io import time from oslo_config import cfg @@ -63,6 +65,8 @@ class NotExistingMetadataException(Exception): @six.add_metaclass(abc.ABCMeta) class BaseMetadataService(object): + _GZIP_MAGIC_NUMBER = b'\x1f\x8b' + def __init__(self): self._cache = {} self._enable_retry = False @@ -113,6 +117,21 @@ class BaseMetadataService(object): def get_user_data(self): pass + def get_decoded_user_data(self): + """Get the decoded user data, if any + + The user data can be gzip-encoded, which means + that every access to it should verify this fact, + leading to code duplication. + """ + user_data = self.get_user_data() + if user_data and user_data[:2] == self._GZIP_MAGIC_NUMBER: + bio = io.BytesIO(user_data) + with gzip.GzipFile(fileobj=bio, mode='rb') as out: + user_data = out.read() + + return user_data + def get_host_name(self): pass diff --git a/cloudbaseinit/plugins/common/userdata.py b/cloudbaseinit/plugins/common/userdata.py index c587aa20..61321a37 100644 --- a/cloudbaseinit/plugins/common/userdata.py +++ b/cloudbaseinit/plugins/common/userdata.py @@ -13,8 +13,6 @@ # under the License. import email -import gzip -import io from oslo_log import log as oslo_logging @@ -36,7 +34,7 @@ class UserDataPlugin(base.BasePlugin): def execute(self, service, shared_data): try: - user_data = service.get_user_data() + user_data = service.get_decoded_user_data() except metadata_services_base.NotExistingMetadataException: return base.PLUGIN_EXECUTION_DONE, False @@ -44,18 +42,8 @@ class UserDataPlugin(base.BasePlugin): return base.PLUGIN_EXECUTION_DONE, False LOG.debug('User data content length: %d' % len(user_data)) - user_data = self._check_gzip_compression(user_data) - return self._process_user_data(user_data) - def _check_gzip_compression(self, user_data): - if user_data[:2] == self._GZIP_MAGIC_NUMBER: - bio = io.BytesIO(user_data) - with gzip.GzipFile(fileobj=bio, mode='rb') as f: - user_data = f.read() - - return user_data - @staticmethod def _parse_mime(user_data): user_data_str = encoding.get_as_string(user_data) diff --git a/cloudbaseinit/tests/metadata/services/test_base.py b/cloudbaseinit/tests/metadata/services/test_base.py new file mode 100644 index 00000000..18d77b34 --- /dev/null +++ b/cloudbaseinit/tests/metadata/services/test_base.py @@ -0,0 +1,37 @@ +# Copyright 2015 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +import unittest + +from cloudbaseinit.metadata.services import base + + +class FakeService(base.BaseMetadataService): + def _get_data(self): + return (b'\x1f\x8b\x08\x00\x93\x90\xf2U\x02' + b'\xff\xcbOSH\xce/-*NU\xc8,Q(\xcf/\xca.' + b'\x06\x00\x12:\xf6a\x12\x00\x00\x00') + + def get_user_data(self): + return self._get_data() + + +class TestBase(unittest.TestCase): + + def setUp(self): + self._service = FakeService() + + def test_get_decoded_user_data(self): + userdata = self._service.get_decoded_user_data() + self.assertEqual(b"of course it works", userdata) diff --git a/cloudbaseinit/tests/plugins/common/test_userdata.py b/cloudbaseinit/tests/plugins/common/test_userdata.py index 3cc52896..1bce2ea6 100644 --- a/cloudbaseinit/tests/plugins/common/test_userdata.py +++ b/cloudbaseinit/tests/plugins/common/test_userdata.py @@ -34,7 +34,7 @@ class FakeService(object): def __init__(self, user_data): self.user_data = user_data - def get_user_data(self): + def get_decoded_user_data(self): return self.user_data.encode() @@ -53,26 +53,18 @@ class UserDataPluginTest(unittest.TestCase): @mock.patch('cloudbaseinit.plugins.common.userdata.UserDataPlugin' '._process_user_data') - @mock.patch('cloudbaseinit.plugins.common.userdata.UserDataPlugin' - '._check_gzip_compression') - def _test_execute(self, mock_check_gzip_compression, - mock_process_user_data, ret_val): + def _test_execute(self, mock_process_user_data, ret_val): mock_service = mock.MagicMock() - mock_service.get_user_data.side_effect = [ret_val] + mock_service.get_decoded_user_data.side_effect = [ret_val] response = self._userdata.execute(service=mock_service, shared_data=None) - mock_service.get_user_data.assert_called_once_with() + mock_service.get_decoded_user_data.assert_called_once_with() if ret_val is metadata_services_base.NotExistingMetadataException: self.assertEqual(response, (base.PLUGIN_EXECUTION_DONE, False)) elif ret_val is None: self.assertEqual(response, (base.PLUGIN_EXECUTION_DONE, False)) - else: - mock_check_gzip_compression.assert_called_once_with(ret_val) - mock_process_user_data.assert_called_once_with( - mock_check_gzip_compression.return_value) - self.assertEqual(response, mock_process_user_data.return_value) def test_execute(self): self._test_execute(ret_val='fake_data') @@ -87,20 +79,6 @@ class UserDataPluginTest(unittest.TestCase): def test_execute_not_user_data(self): self._test_execute(ret_val=None) - @mock.patch('io.BytesIO') - @mock.patch('gzip.GzipFile') - def test_check_gzip_compression(self, mock_GzipFile, mock_BytesIO): - fake_userdata = b'\x1f\x8b' - fake_userdata += self._userdata._GZIP_MAGIC_NUMBER - - response = self._userdata._check_gzip_compression(fake_userdata) - - mock_BytesIO.assert_called_once_with(fake_userdata) - mock_GzipFile.assert_called_once_with( - fileobj=mock_BytesIO.return_value, mode='rb') - data = mock_GzipFile().__enter__().read.return_value - self.assertEqual(data, response) - @mock.patch('email.message_from_string') @mock.patch('cloudbaseinit.utils.encoding.get_as_string') def test_parse_mime(self, mock_get_as_string,