diff --git a/cloudbaseinit/metadata/services/base.py b/cloudbaseinit/metadata/services/base.py index 010f19d7..f2016e80 100644 --- a/cloudbaseinit/metadata/services/base.py +++ b/cloudbaseinit/metadata/services/base.py @@ -20,6 +20,7 @@ import time from oslo.config import cfg from cloudbaseinit.openstack.common import log as logging +from cloudbaseinit.utils import encoding opts = [ @@ -88,13 +89,17 @@ class BaseMetadataService(object): else: raise - def _get_cache_data(self, path): - if path in self._cache: + def _get_cache_data(self, path, decode=False): + """Get meta data with caching and decoding support.""" + key = (path, decode) + if key in self._cache: LOG.debug("Using cached copy of metadata: '%s'" % path) - return self._cache[path] + return self._cache[key] else: data = self._exec_with_retry(lambda: self._get_data(path)) - self._cache[path] = data + if decode: + data = encoding.get_as_string(data) + self._cache[key] = data return data def get_instance_id(self): diff --git a/cloudbaseinit/metadata/services/baseopenstackservice.py b/cloudbaseinit/metadata/services/baseopenstackservice.py index b7cb480b..3b019768 100644 --- a/cloudbaseinit/metadata/services/baseopenstackservice.py +++ b/cloudbaseinit/metadata/services/baseopenstackservice.py @@ -51,9 +51,9 @@ class BaseOpenStackService(base.BaseMetadataService): def _get_meta_data(self, version='latest'): path = posixpath.normpath( posixpath.join('openstack', version, 'meta_data.json')) - data = self._get_cache_data(path) + data = self._get_cache_data(path, decode=True) if data: - return json.loads(encoding.get_as_string(data)) + return json.loads(data) def get_instance_id(self): return self._get_meta_data().get('uuid') @@ -136,10 +136,10 @@ class BaseOpenStackService(base.BaseMetadataService): if not certs: # Look if the user_data contains a PEM certificate try: - user_data = self.get_user_data() + user_data = self.get_user_data().strip() if user_data.startswith( x509constants.PEM_HEADER.encode()): - certs.append(user_data) + certs.append(encoding.get_as_string(user_data)) except base.NotExistingMetadataException: LOG.debug("user_data metadata not present") diff --git a/cloudbaseinit/metadata/services/cloudstack.py b/cloudbaseinit/metadata/services/cloudstack.py index 8abc6e1d..b3645da8 100644 --- a/cloudbaseinit/metadata/services/cloudstack.py +++ b/cloudbaseinit/metadata/services/cloudstack.py @@ -21,6 +21,7 @@ from six.moves import urllib from cloudbaseinit.metadata.services import base from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.osutils import factory as osutils_factory +from cloudbaseinit.utils import encoding LOG = logging.getLogger(__name__) @@ -104,11 +105,11 @@ class CloudStack(base.BaseMetadataService): def get_instance_id(self): """Instance name of the virtual machine.""" - return self._get_cache_data('instance-id') + return self._get_cache_data('instance-id', decode=True) def get_host_name(self): """Hostname of the virtual machine.""" - return self._get_cache_data('local-hostname') + return self._get_cache_data('local-hostname', decode=True) def get_user_data(self): """User data for this virtual machine.""" @@ -117,7 +118,9 @@ class CloudStack(base.BaseMetadataService): def get_public_keys(self): """Available ssh public keys.""" ssh_keys = [] - for ssh_key in self._get_cache_data('public-keys').splitlines(): + ssh_chunks = self._get_cache_data('public-keys', + decode=True).splitlines() + for ssh_key in ssh_chunks: ssh_key = ssh_key.strip() if not ssh_key: continue @@ -155,14 +158,13 @@ class CloudStack(base.BaseMetadataService): if response.status != 200: LOG.warning("Getting password failed: %(status)s " - "%(reason)s - %(message)s", + "%(reason)s - %(message)r", {"status": response.status, "reason": response.reason, "message": response.read()}) continue - content = response.read() - content = content.strip() + content = response.read().strip() if not content: LOG.warning("The Password Server did not have any " "password for the current instance.") @@ -180,7 +182,7 @@ class CloudStack(base.BaseMetadataService): LOG.info("The password server return a valid password " "for the current instance.") - password = content.decode() + password = encoding.get_as_string(content) break return password @@ -201,14 +203,14 @@ class CloudStack(base.BaseMetadataService): response = connection.getresponse() if response.status != 200: LOG.warning("Removing password failed: %(status)s " - "%(reason)s - %(message)s", + "%(reason)s - %(message)r", {"status": response.status, "reason": response.reason, "message": response.read()}) continue content = response.read() - if content.decode() != BAD_REQUEST: + if content != BAD_REQUEST: # comparing bytes with bytes LOG.info("The password was removed from the Password Server.") break else: diff --git a/cloudbaseinit/metadata/services/ec2service.py b/cloudbaseinit/metadata/services/ec2service.py index 94995b3e..b6a0fafa 100644 --- a/cloudbaseinit/metadata/services/ec2service.py +++ b/cloudbaseinit/metadata/services/ec2service.py @@ -78,25 +78,25 @@ class EC2Service(base.BaseMetadataService): def get_host_name(self): return self._get_cache_data('%s/meta-data/local-hostname' % - self._metadata_version) + self._metadata_version, decode=True) def get_instance_id(self): return self._get_cache_data('%s/meta-data/instance-id' % - self._metadata_version) + self._metadata_version, decode=True) def get_public_keys(self): ssh_keys = [] keys_info = self._get_cache_data( '%s/meta-data/public-keys' % - self._metadata_version).split("\n") + self._metadata_version, decode=True).splitlines() for key_info in keys_info: (idx, key_name) = key_info.split('=') ssh_key = self._get_cache_data( '%(version)s/meta-data/public-keys/%(idx)s/openssh-key' % - {'version': self._metadata_version, 'idx': idx}) + {'version': self._metadata_version, 'idx': idx}, decode=True) ssh_keys.append(ssh_key.strip()) return ssh_keys diff --git a/cloudbaseinit/metadata/services/maasservice.py b/cloudbaseinit/metadata/services/maasservice.py index 79498dfc..cac1f05d 100644 --- a/cloudbaseinit/metadata/services/maasservice.py +++ b/cloudbaseinit/metadata/services/maasservice.py @@ -13,6 +13,7 @@ # under the License. import posixpath +import re from oauthlib import oauth1 from oslo.config import cfg @@ -108,23 +109,25 @@ class MaaSHttpService(base.BaseMetadataService): def get_host_name(self): return self._get_cache_data('%s/meta-data/local-hostname' % - self._metadata_version) + self._metadata_version, decode=True) def get_instance_id(self): return self._get_cache_data('%s/meta-data/instance-id' % - self._metadata_version) - - def _get_list_from_text(self, text, delimiter): - return [v + delimiter for v in text.split(delimiter)] + self._metadata_version, decode=True) def get_public_keys(self): return self._get_cache_data('%s/meta-data/public-keys' % - self._metadata_version).splitlines() + self._metadata_version, + decode=True).splitlines() def get_client_auth_certs(self): - return self._get_list_from_text( - self._get_cache_data('%s/meta-data/x509' % self._metadata_version), - "%s\n" % x509constants.PEM_FOOTER) + certs_data = self._get_cache_data('%s/meta-data/x509' % + self._metadata_version, + decode=True) + pattern = r"{begin}[\s\S]+?{end}".format( + begin=x509constants.PEM_HEADER, + end=x509constants.PEM_FOOTER) + return re.findall(pattern, certs_data) def get_user_data(self): return self._get_cache_data('%s/user-data' % self._metadata_version) diff --git a/cloudbaseinit/metadata/services/opennebulaservice.py b/cloudbaseinit/metadata/services/opennebulaservice.py index 0f6c7070..f46dc1cb 100644 --- a/cloudbaseinit/metadata/services/opennebulaservice.py +++ b/cloudbaseinit/metadata/services/opennebulaservice.py @@ -146,7 +146,7 @@ class OpenNebulaService(base.BaseMetadataService): raise base.NotExistingMetadataException(msg) return self._dict_content[name] - def _get_cache_data(self, names, iid=None): + def _get_cache_data(self, names, iid=None, decode=False): # Solves caching issues when working with # multiple names (lists not hashable). # This happens because the caching function used @@ -160,7 +160,8 @@ class OpenNebulaService(base.BaseMetadataService): names[ind] = value.format(iid=iid) for name in names: try: - return super(OpenNebulaService, self)._get_cache_data(name) + return super(OpenNebulaService, self)._get_cache_data( + name, decode=decode) except base.NotExistingMetadataException: pass msg = "None of {} metadata was found".format(", ".join(names)) @@ -192,14 +193,13 @@ class OpenNebulaService(base.BaseMetadataService): return INSTANCE_ID def get_host_name(self): - return encoding.get_as_string(self._get_cache_data(HOST_NAME)) + return self._get_cache_data(HOST_NAME, decode=True) def get_user_data(self): return self._get_cache_data(USER_DATA) def get_public_keys(self): - return encoding.get_as_string( - self._get_cache_data(PUBLIC_KEY)).splitlines() + return self._get_cache_data(PUBLIC_KEY, decode=True).splitlines() def get_network_details(self): """Return a list of NetworkDetails objects. @@ -215,19 +215,17 @@ class OpenNebulaService(base.BaseMetadataService): for iid in range(ncount): try: # get existing values - mac = encoding.get_as_string( - self._get_cache_data(MAC, iid=iid)).upper() - address = encoding.get_as_string(self._get_cache_data(ADDRESS, - iid=iid)) + mac = self._get_cache_data(MAC, iid=iid, decode=True).upper() + address = self._get_cache_data(ADDRESS, iid=iid, decode=True) # try to find/predict and compute the rest try: - gateway = encoding.get_as_string( - self._get_cache_data(GATEWAY, iid=iid)) + gateway = self._get_cache_data(GATEWAY, iid=iid, + decode=True) except base.NotExistingMetadataException: gateway = None try: - netmask = encoding.get_as_string( - self._get_cache_data(NETMASK, iid=iid)) + netmask = self._get_cache_data(NETMASK, iid=iid, + decode=True) except base.NotExistingMetadataException: if not gateway: raise @@ -244,8 +242,8 @@ class OpenNebulaService(base.BaseMetadataService): broadcast=broadcast, gateway=gateway, gateway6=None, - dnsnameservers=encoding.get_as_string( - self._get_cache_data(DNSNS, iid=iid)).split(" ") + dnsnameservers=self._get_cache_data( + DNSNS, iid=iid, decode=True).split(" ") ) except base.NotExistingMetadataException: LOG.debug("Incomplete NIC details") diff --git a/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py b/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py index c3f89b98..72db9d0d 100644 --- a/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py +++ b/cloudbaseinit/tests/metadata/services/test_baseopenstackservice.py @@ -70,11 +70,11 @@ class TestBaseOpenStackService(unittest.TestCase): @mock.patch(MODPATH + ".BaseOpenStackService._get_cache_data") def test_get_meta_data(self, mock_get_cache_data): - mock_get_cache_data.return_value = b'{"fake": "data"}' + mock_get_cache_data.return_value = '{"fake": "data"}' response = self._service._get_meta_data( version='fake version') path = posixpath.join('openstack', 'fake version', 'meta_data.json') - mock_get_cache_data.assert_called_with(path) + mock_get_cache_data.assert_called_with(path, decode=True) self.assertEqual({"fake": "data"}, response) @mock.patch(MODPATH + @@ -151,7 +151,7 @@ class TestBaseOpenStackService(unittest.TestCase): if isinstance(ret_value, bytes) and ret_value.startswith( x509constants.PEM_HEADER.encode()): mock_get_user_data.assert_called_once_with() - self.assertEqual([ret_value], response) + self.assertEqual([ret_value.decode()], response) elif ret_value is base.NotExistingMetadataException: self.assertFalse(response) else: diff --git a/cloudbaseinit/tests/metadata/services/test_cloudstack.py b/cloudbaseinit/tests/metadata/services/test_cloudstack.py index 30c47fec..71570d06 100644 --- a/cloudbaseinit/tests/metadata/services/test_cloudstack.py +++ b/cloudbaseinit/tests/metadata/services/test_cloudstack.py @@ -12,6 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. +import functools import socket import unittest @@ -149,12 +150,17 @@ class CloudStackTest(unittest.TestCase): @mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack' '._get_cache_data') - def _test_cache_response(self, mock_get_cache_data, method, metadata): + def _test_cache_response(self, mock_get_cache_data, method, metadata, + decode=True): mock_get_cache_data.side_effect = [mock.sentinel.response] response = method() self.assertEqual(mock.sentinel.response, response) - mock_get_cache_data.assert_called_once_with(metadata) + cache_assert = functools.partial( + mock_get_cache_data.assert_called_once_with, + metadata) + if decode: + cache_assert(decode=decode) def test_get_instance_id(self): self._test_cache_response(method=self._service.get_instance_id, @@ -166,7 +172,7 @@ class CloudStackTest(unittest.TestCase): def test_get_user_data(self): self._test_cache_response(method=self._service.get_user_data, - metadata='../user-data') + metadata='../user-data', decode=False) @mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack' '._get_cache_data') diff --git a/cloudbaseinit/tests/metadata/services/test_ec2service.py b/cloudbaseinit/tests/metadata/services/test_ec2service.py index 5ea8a874..73299853 100644 --- a/cloudbaseinit/tests/metadata/services/test_ec2service.py +++ b/cloudbaseinit/tests/metadata/services/test_ec2service.py @@ -99,7 +99,8 @@ class EC2ServiceTest(unittest.TestCase): def test_get_host_name(self, mock_get_cache_data): response = self._service.get_host_name() mock_get_cache_data.assert_called_once_with( - '%s/meta-data/local-hostname' % self._service._metadata_version) + '%s/meta-data/local-hostname' % self._service._metadata_version, + decode=True) self.assertEqual(mock_get_cache_data.return_value, response) @mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service' @@ -107,7 +108,8 @@ class EC2ServiceTest(unittest.TestCase): def test_get_instance_id(self, mock_get_cache_data): response = self._service.get_instance_id() mock_get_cache_data.assert_called_once_with( - '%s/meta-data/instance-id' % self._service._metadata_version) + '%s/meta-data/instance-id' % self._service._metadata_version, + decode=True) self.assertEqual(mock_get_cache_data.return_value, response) @mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service' @@ -117,10 +119,11 @@ class EC2ServiceTest(unittest.TestCase): response = self._service.get_public_keys() expected = [ mock.call('%s/meta-data/public-keys' % - self._service._metadata_version), + self._service._metadata_version, + decode=True), mock.call('%(version)s/meta-data/public-keys/%(' 'idx)s/openssh-key' % {'version': self._service._metadata_version, - 'idx': 'key'})] + 'idx': 'key'}, decode=True)] self.assertEqual(expected, mock_get_cache_data.call_args_list) self.assertEqual(['fake key'], response) diff --git a/cloudbaseinit/tests/metadata/services/test_maasservice.py b/cloudbaseinit/tests/metadata/services/test_maasservice.py index 599fe00d..1bee1926 100644 --- a/cloudbaseinit/tests/metadata/services/test_maasservice.py +++ b/cloudbaseinit/tests/metadata/services/test_maasservice.py @@ -148,7 +148,8 @@ class MaaSHttpServiceTest(unittest.TestCase): response = self._maasservice.get_host_name() mock_get_cache_data.assert_called_once_with( '%s/meta-data/local-hostname' % - self._maasservice._metadata_version) + self._maasservice._metadata_version, + decode=True) self.assertEqual(mock_get_cache_data.return_value, response) @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" @@ -156,13 +157,10 @@ class MaaSHttpServiceTest(unittest.TestCase): def test_get_instance_id(self, mock_get_cache_data): response = self._maasservice.get_instance_id() mock_get_cache_data.assert_called_once_with( - '%s/meta-data/instance-id' % self._maasservice._metadata_version) + '%s/meta-data/instance-id' % self._maasservice._metadata_version, + decode=True) self.assertEqual(mock_get_cache_data.return_value, response) - def test_get_list_from_text(self): - response = self._maasservice._get_list_from_text('fake:text', ':') - self.assertEqual(['fake:', 'text:'], response) - @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" "._get_cache_data") def test_get_public_keys(self, mock_get_cache_data): @@ -174,21 +172,26 @@ class MaaSHttpServiceTest(unittest.TestCase): mock_get_cache_data.return_value = public_key response = self._maasservice.get_public_keys() mock_get_cache_data.assert_called_with( - '%s/meta-data/public-keys' % self._maasservice._metadata_version) + '%s/meta-data/public-keys' % self._maasservice._metadata_version, + decode=True) self.assertEqual(public_keys, response) - @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" - "._get_list_from_text") @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" "._get_cache_data") - def test_get_client_auth_certs(self, mock_get_cache_data, - mock_get_list_from_text): + def test_get_client_auth_certs(self, mock_get_cache_data): + certs = [ + "{begin}\n{cert}\n{end}".format( + begin=x509constants.PEM_HEADER, + end=x509constants.PEM_FOOTER, + cert=cert) + for cert in ("first cert", "second cert") + ] + mock_get_cache_data.return_value = "\n".join(certs) + "\n" response = self._maasservice.get_client_auth_certs() mock_get_cache_data.assert_called_with( - '%s/meta-data/x509' % self._maasservice._metadata_version) - mock_get_list_from_text.assert_called_once_with( - mock_get_cache_data(), "%s\n" % x509constants.PEM_FOOTER) - self.assertEqual(mock_get_list_from_text.return_value, response) + '%s/meta-data/x509' % self._maasservice._metadata_version, + decode=True) + self.assertEqual(certs, response) @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" "._get_cache_data") diff --git a/cloudbaseinit/tests/utils/test_encoding.py b/cloudbaseinit/tests/utils/test_encoding.py new file mode 100644 index 00000000..61caf631 --- /dev/null +++ b/cloudbaseinit/tests/utils/test_encoding.py @@ -0,0 +1,54 @@ +# Copyright 2014 Cloudbase Solutions Srl +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +import os +import tempfile +import unittest + +from cloudbaseinit.tests import testutils +from cloudbaseinit.utils import encoding + + +class TestEncoding(unittest.TestCase): + + def test_get_as_string(self): + content_map = [ + ("data", "data"), + (b"data", "data"), + ("data".encode(), "data"), + ("data".encode("utf-16"), None) + ] + with testutils.LogSnatcher("cloudbaseinit.utils.encoding") as snatch: + for content, expect in content_map: + self.assertEqual(expect, encoding.get_as_string(content)) + self.assertIn("couldn't decode", snatch.output[0].lower()) + + def test_write_file(self): + mode_map = [ + (("w", "r"), "my test\ndata\n\n", False), + (("wb", "rb"), "\r\n".join((chr(x) for x in + (32, 125, 0))).encode(), False), + (("wb", "rb"), "my test\ndata\n\n", True) + ] + with testutils.create_tempdir() as temp: + fd, path = tempfile.mkstemp(dir=temp) + os.close(fd) + for (write, read), data, encode in mode_map: + encoding.write_file(path, data, mode=write) + with open(path, read) as stream: + content = stream.read() + if encode: + data = data.encode() + self.assertEqual(data, content) diff --git a/cloudbaseinit/tests/utils/windows/test_x509.py b/cloudbaseinit/tests/utils/windows/test_x509.py index 6423545a..26d8966e 100644 --- a/cloudbaseinit/tests/utils/windows/test_x509.py +++ b/cloudbaseinit/tests/utils/windows/test_x509.py @@ -285,7 +285,6 @@ class CryptoAPICertManagerTests(unittest.TestCase): fake_cert_data += x509constants.PEM_HEADER + '\n' fake_cert_data += 'fake cert' + '\n' fake_cert_data += x509constants.PEM_FOOTER - fake_cert_data = fake_cert_data.encode() response = self._x509_manager._get_cert_base64(fake_cert_data) self.assertEqual('fake cert', response) diff --git a/cloudbaseinit/utils/encoding.py b/cloudbaseinit/utils/encoding.py index bbd25a78..606af08f 100644 --- a/cloudbaseinit/utils/encoding.py +++ b/cloudbaseinit/utils/encoding.py @@ -14,6 +14,11 @@ import six +from cloudbaseinit.openstack.common import log as logging + + +LOG = logging.getLogger(__name__) + def get_as_string(value): if value is None or isinstance(value, six.text_type): @@ -22,7 +27,9 @@ def get_as_string(value): try: return value.decode() except Exception: - pass + # This is important, because None will be returned, + # but not that serious to raise an exception. + LOG.error("Couldn't decode: %r", value) def write_file(target_path, data, mode='wb'): @@ -31,13 +38,3 @@ def write_file(target_path, data, mode='wb'): with open(target_path, mode) as f: f.write(data) - - -def read_file(target_path, mode='rb'): - with open(target_path, mode) as f: - data = f.read() - - if 'b' in mode: - data = data.decode() - - return data diff --git a/cloudbaseinit/utils/windows/x509.py b/cloudbaseinit/utils/windows/x509.py index 19ff5462..4ed0170d 100644 --- a/cloudbaseinit/utils/windows/x509.py +++ b/cloudbaseinit/utils/windows/x509.py @@ -20,7 +20,6 @@ import uuid import six -from cloudbaseinit.utils import encoding from cloudbaseinit.utils.windows import cryptoapi from cloudbaseinit.utils import x509constants @@ -205,13 +204,17 @@ class CryptoAPICertManager(object): free(subject_encoded) def _get_cert_base64(self, cert_data): - base64_cert_data = encoding.get_as_string(cert_data) - if base64_cert_data.startswith(x509constants.PEM_HEADER): - base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):] - if base64_cert_data.endswith(x509constants.PEM_FOOTER): - base64_cert_data = base64_cert_data[:len(base64_cert_data) - - len(x509constants.PEM_FOOTER)] - return base64_cert_data.replace("\n", "") + """Remove certificate header and footer and also new lines.""" + # It's assured that the certificate is already a string. + removal = [ + x509constants.PEM_HEADER, + x509constants.PEM_FOOTER, + "\r", + "\n" + ] + for remove in removal: + cert_data = cert_data.replace(remove, "") + return cert_data def import_cert(self, cert_data, machine_keyset=True, store_name=STORE_NAME_MY):