Normalize all metadata providers and plugins

Every meta data service should return bytes only for these capabilities:
    * get_content
    * get_user_data
While `_get_meta_data` and any other method derrived from it
(including public keys, certificates etc.) should return homogeneous
data types and only strings, not bytes.
The decoding procedure is handled at its roots, not in the plugins
and is done by only using `encoding.get_as_string` function.

Fixed bugs:
    * invalid certificate splitting under maas service which usually
      generated an extra invalid certificate (empty string + footer)
    * text operations on bytes in maas and cloudstack (split, comparing)
    * multiple types for certificates (now only strings)
    * not receiving bytes from opennebula service when using `get_user_data`
      (which leads to crash under later processing through io.BytesIO)
    * erroneous certificate parsing/stripping/replacing under x509 importing
      (footer remains, not all possible EOLs replaced as it should)

Also added new and refined actual misleading unittests.

Change-Id: I704c43f5f784458a881293d761a21e62aed85732
This commit is contained in:
Cosmin Poieana
2015-06-04 16:44:16 +03:00
parent 55880b8ff8
commit ae15fee086
14 changed files with 163 additions and 90 deletions

View File

@@ -20,6 +20,7 @@ import time
from oslo.config import cfg from oslo.config import cfg
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.utils import encoding
opts = [ opts = [
@@ -88,13 +89,17 @@ class BaseMetadataService(object):
else: else:
raise raise
def _get_cache_data(self, path): def _get_cache_data(self, path, decode=False):
if path in self._cache: """Get meta data with caching and decoding support."""
key = (path, decode)
if key in self._cache:
LOG.debug("Using cached copy of metadata: '%s'" % path) LOG.debug("Using cached copy of metadata: '%s'" % path)
return self._cache[path] return self._cache[key]
else: else:
data = self._exec_with_retry(lambda: self._get_data(path)) data = self._exec_with_retry(lambda: self._get_data(path))
self._cache[path] = data if decode:
data = encoding.get_as_string(data)
self._cache[key] = data
return data return data
def get_instance_id(self): def get_instance_id(self):

View File

@@ -51,9 +51,9 @@ class BaseOpenStackService(base.BaseMetadataService):
def _get_meta_data(self, version='latest'): def _get_meta_data(self, version='latest'):
path = posixpath.normpath( path = posixpath.normpath(
posixpath.join('openstack', version, 'meta_data.json')) posixpath.join('openstack', version, 'meta_data.json'))
data = self._get_cache_data(path) data = self._get_cache_data(path, decode=True)
if data: if data:
return json.loads(encoding.get_as_string(data)) return json.loads(data)
def get_instance_id(self): def get_instance_id(self):
return self._get_meta_data().get('uuid') return self._get_meta_data().get('uuid')
@@ -136,10 +136,10 @@ class BaseOpenStackService(base.BaseMetadataService):
if not certs: if not certs:
# Look if the user_data contains a PEM certificate # Look if the user_data contains a PEM certificate
try: try:
user_data = self.get_user_data() user_data = self.get_user_data().strip()
if user_data.startswith( if user_data.startswith(
x509constants.PEM_HEADER.encode()): x509constants.PEM_HEADER.encode()):
certs.append(user_data) certs.append(encoding.get_as_string(user_data))
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
LOG.debug("user_data metadata not present") LOG.debug("user_data metadata not present")

View File

@@ -21,6 +21,7 @@ from six.moves import urllib
from cloudbaseinit.metadata.services import base from cloudbaseinit.metadata.services import base
from cloudbaseinit.openstack.common import log as logging from cloudbaseinit.openstack.common import log as logging
from cloudbaseinit.osutils import factory as osutils_factory from cloudbaseinit.osutils import factory as osutils_factory
from cloudbaseinit.utils import encoding
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -104,11 +105,11 @@ class CloudStack(base.BaseMetadataService):
def get_instance_id(self): def get_instance_id(self):
"""Instance name of the virtual machine.""" """Instance name of the virtual machine."""
return self._get_cache_data('instance-id') return self._get_cache_data('instance-id', decode=True)
def get_host_name(self): def get_host_name(self):
"""Hostname of the virtual machine.""" """Hostname of the virtual machine."""
return self._get_cache_data('local-hostname') return self._get_cache_data('local-hostname', decode=True)
def get_user_data(self): def get_user_data(self):
"""User data for this virtual machine.""" """User data for this virtual machine."""
@@ -117,7 +118,9 @@ class CloudStack(base.BaseMetadataService):
def get_public_keys(self): def get_public_keys(self):
"""Available ssh public keys.""" """Available ssh public keys."""
ssh_keys = [] ssh_keys = []
for ssh_key in self._get_cache_data('public-keys').splitlines(): ssh_chunks = self._get_cache_data('public-keys',
decode=True).splitlines()
for ssh_key in ssh_chunks:
ssh_key = ssh_key.strip() ssh_key = ssh_key.strip()
if not ssh_key: if not ssh_key:
continue continue
@@ -155,14 +158,13 @@ class CloudStack(base.BaseMetadataService):
if response.status != 200: if response.status != 200:
LOG.warning("Getting password failed: %(status)s " LOG.warning("Getting password failed: %(status)s "
"%(reason)s - %(message)s", "%(reason)s - %(message)r",
{"status": response.status, {"status": response.status,
"reason": response.reason, "reason": response.reason,
"message": response.read()}) "message": response.read()})
continue continue
content = response.read() content = response.read().strip()
content = content.strip()
if not content: if not content:
LOG.warning("The Password Server did not have any " LOG.warning("The Password Server did not have any "
"password for the current instance.") "password for the current instance.")
@@ -180,7 +182,7 @@ class CloudStack(base.BaseMetadataService):
LOG.info("The password server return a valid password " LOG.info("The password server return a valid password "
"for the current instance.") "for the current instance.")
password = content.decode() password = encoding.get_as_string(content)
break break
return password return password
@@ -201,14 +203,14 @@ class CloudStack(base.BaseMetadataService):
response = connection.getresponse() response = connection.getresponse()
if response.status != 200: if response.status != 200:
LOG.warning("Removing password failed: %(status)s " LOG.warning("Removing password failed: %(status)s "
"%(reason)s - %(message)s", "%(reason)s - %(message)r",
{"status": response.status, {"status": response.status,
"reason": response.reason, "reason": response.reason,
"message": response.read()}) "message": response.read()})
continue continue
content = response.read() content = response.read()
if content.decode() != BAD_REQUEST: if content != BAD_REQUEST: # comparing bytes with bytes
LOG.info("The password was removed from the Password Server.") LOG.info("The password was removed from the Password Server.")
break break
else: else:

View File

@@ -78,25 +78,25 @@ class EC2Service(base.BaseMetadataService):
def get_host_name(self): def get_host_name(self):
return self._get_cache_data('%s/meta-data/local-hostname' % return self._get_cache_data('%s/meta-data/local-hostname' %
self._metadata_version) self._metadata_version, decode=True)
def get_instance_id(self): def get_instance_id(self):
return self._get_cache_data('%s/meta-data/instance-id' % return self._get_cache_data('%s/meta-data/instance-id' %
self._metadata_version) self._metadata_version, decode=True)
def get_public_keys(self): def get_public_keys(self):
ssh_keys = [] ssh_keys = []
keys_info = self._get_cache_data( keys_info = self._get_cache_data(
'%s/meta-data/public-keys' % '%s/meta-data/public-keys' %
self._metadata_version).split("\n") self._metadata_version, decode=True).splitlines()
for key_info in keys_info: for key_info in keys_info:
(idx, key_name) = key_info.split('=') (idx, key_name) = key_info.split('=')
ssh_key = self._get_cache_data( ssh_key = self._get_cache_data(
'%(version)s/meta-data/public-keys/%(idx)s/openssh-key' % '%(version)s/meta-data/public-keys/%(idx)s/openssh-key' %
{'version': self._metadata_version, 'idx': idx}) {'version': self._metadata_version, 'idx': idx}, decode=True)
ssh_keys.append(ssh_key.strip()) ssh_keys.append(ssh_key.strip())
return ssh_keys return ssh_keys

View File

@@ -13,6 +13,7 @@
# under the License. # under the License.
import posixpath import posixpath
import re
from oauthlib import oauth1 from oauthlib import oauth1
from oslo.config import cfg from oslo.config import cfg
@@ -108,23 +109,25 @@ class MaaSHttpService(base.BaseMetadataService):
def get_host_name(self): def get_host_name(self):
return self._get_cache_data('%s/meta-data/local-hostname' % return self._get_cache_data('%s/meta-data/local-hostname' %
self._metadata_version) self._metadata_version, decode=True)
def get_instance_id(self): def get_instance_id(self):
return self._get_cache_data('%s/meta-data/instance-id' % return self._get_cache_data('%s/meta-data/instance-id' %
self._metadata_version) self._metadata_version, decode=True)
def _get_list_from_text(self, text, delimiter):
return [v + delimiter for v in text.split(delimiter)]
def get_public_keys(self): def get_public_keys(self):
return self._get_cache_data('%s/meta-data/public-keys' % return self._get_cache_data('%s/meta-data/public-keys' %
self._metadata_version).splitlines() self._metadata_version,
decode=True).splitlines()
def get_client_auth_certs(self): def get_client_auth_certs(self):
return self._get_list_from_text( certs_data = self._get_cache_data('%s/meta-data/x509' %
self._get_cache_data('%s/meta-data/x509' % self._metadata_version), self._metadata_version,
"%s\n" % x509constants.PEM_FOOTER) decode=True)
pattern = r"{begin}[\s\S]+?{end}".format(
begin=x509constants.PEM_HEADER,
end=x509constants.PEM_FOOTER)
return re.findall(pattern, certs_data)
def get_user_data(self): def get_user_data(self):
return self._get_cache_data('%s/user-data' % self._metadata_version) return self._get_cache_data('%s/user-data' % self._metadata_version)

View File

@@ -146,7 +146,7 @@ class OpenNebulaService(base.BaseMetadataService):
raise base.NotExistingMetadataException(msg) raise base.NotExistingMetadataException(msg)
return self._dict_content[name] return self._dict_content[name]
def _get_cache_data(self, names, iid=None): def _get_cache_data(self, names, iid=None, decode=False):
# Solves caching issues when working with # Solves caching issues when working with
# multiple names (lists not hashable). # multiple names (lists not hashable).
# This happens because the caching function used # This happens because the caching function used
@@ -160,7 +160,8 @@ class OpenNebulaService(base.BaseMetadataService):
names[ind] = value.format(iid=iid) names[ind] = value.format(iid=iid)
for name in names: for name in names:
try: try:
return super(OpenNebulaService, self)._get_cache_data(name) return super(OpenNebulaService, self)._get_cache_data(
name, decode=decode)
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
pass pass
msg = "None of {} metadata was found".format(", ".join(names)) msg = "None of {} metadata was found".format(", ".join(names))
@@ -192,14 +193,13 @@ class OpenNebulaService(base.BaseMetadataService):
return INSTANCE_ID return INSTANCE_ID
def get_host_name(self): def get_host_name(self):
return encoding.get_as_string(self._get_cache_data(HOST_NAME)) return self._get_cache_data(HOST_NAME, decode=True)
def get_user_data(self): def get_user_data(self):
return self._get_cache_data(USER_DATA) return self._get_cache_data(USER_DATA)
def get_public_keys(self): def get_public_keys(self):
return encoding.get_as_string( return self._get_cache_data(PUBLIC_KEY, decode=True).splitlines()
self._get_cache_data(PUBLIC_KEY)).splitlines()
def get_network_details(self): def get_network_details(self):
"""Return a list of NetworkDetails objects. """Return a list of NetworkDetails objects.
@@ -215,19 +215,17 @@ class OpenNebulaService(base.BaseMetadataService):
for iid in range(ncount): for iid in range(ncount):
try: try:
# get existing values # get existing values
mac = encoding.get_as_string( mac = self._get_cache_data(MAC, iid=iid, decode=True).upper()
self._get_cache_data(MAC, iid=iid)).upper() address = self._get_cache_data(ADDRESS, iid=iid, decode=True)
address = encoding.get_as_string(self._get_cache_data(ADDRESS,
iid=iid))
# try to find/predict and compute the rest # try to find/predict and compute the rest
try: try:
gateway = encoding.get_as_string( gateway = self._get_cache_data(GATEWAY, iid=iid,
self._get_cache_data(GATEWAY, iid=iid)) decode=True)
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
gateway = None gateway = None
try: try:
netmask = encoding.get_as_string( netmask = self._get_cache_data(NETMASK, iid=iid,
self._get_cache_data(NETMASK, iid=iid)) decode=True)
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
if not gateway: if not gateway:
raise raise
@@ -244,8 +242,8 @@ class OpenNebulaService(base.BaseMetadataService):
broadcast=broadcast, broadcast=broadcast,
gateway=gateway, gateway=gateway,
gateway6=None, gateway6=None,
dnsnameservers=encoding.get_as_string( dnsnameservers=self._get_cache_data(
self._get_cache_data(DNSNS, iid=iid)).split(" ") DNSNS, iid=iid, decode=True).split(" ")
) )
except base.NotExistingMetadataException: except base.NotExistingMetadataException:
LOG.debug("Incomplete NIC details") LOG.debug("Incomplete NIC details")

View File

@@ -70,11 +70,11 @@ class TestBaseOpenStackService(unittest.TestCase):
@mock.patch(MODPATH + @mock.patch(MODPATH +
".BaseOpenStackService._get_cache_data") ".BaseOpenStackService._get_cache_data")
def test_get_meta_data(self, mock_get_cache_data): def test_get_meta_data(self, mock_get_cache_data):
mock_get_cache_data.return_value = b'{"fake": "data"}' mock_get_cache_data.return_value = '{"fake": "data"}'
response = self._service._get_meta_data( response = self._service._get_meta_data(
version='fake version') version='fake version')
path = posixpath.join('openstack', 'fake version', 'meta_data.json') path = posixpath.join('openstack', 'fake version', 'meta_data.json')
mock_get_cache_data.assert_called_with(path) mock_get_cache_data.assert_called_with(path, decode=True)
self.assertEqual({"fake": "data"}, response) self.assertEqual({"fake": "data"}, response)
@mock.patch(MODPATH + @mock.patch(MODPATH +
@@ -151,7 +151,7 @@ class TestBaseOpenStackService(unittest.TestCase):
if isinstance(ret_value, bytes) and ret_value.startswith( if isinstance(ret_value, bytes) and ret_value.startswith(
x509constants.PEM_HEADER.encode()): x509constants.PEM_HEADER.encode()):
mock_get_user_data.assert_called_once_with() mock_get_user_data.assert_called_once_with()
self.assertEqual([ret_value], response) self.assertEqual([ret_value.decode()], response)
elif ret_value is base.NotExistingMetadataException: elif ret_value is base.NotExistingMetadataException:
self.assertFalse(response) self.assertFalse(response)
else: else:

View File

@@ -12,6 +12,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import functools
import socket import socket
import unittest import unittest
@@ -149,12 +150,17 @@ class CloudStackTest(unittest.TestCase):
@mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack' @mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack'
'._get_cache_data') '._get_cache_data')
def _test_cache_response(self, mock_get_cache_data, method, metadata): def _test_cache_response(self, mock_get_cache_data, method, metadata,
decode=True):
mock_get_cache_data.side_effect = [mock.sentinel.response] mock_get_cache_data.side_effect = [mock.sentinel.response]
response = method() response = method()
self.assertEqual(mock.sentinel.response, response) self.assertEqual(mock.sentinel.response, response)
mock_get_cache_data.assert_called_once_with(metadata) cache_assert = functools.partial(
mock_get_cache_data.assert_called_once_with,
metadata)
if decode:
cache_assert(decode=decode)
def test_get_instance_id(self): def test_get_instance_id(self):
self._test_cache_response(method=self._service.get_instance_id, self._test_cache_response(method=self._service.get_instance_id,
@@ -166,7 +172,7 @@ class CloudStackTest(unittest.TestCase):
def test_get_user_data(self): def test_get_user_data(self):
self._test_cache_response(method=self._service.get_user_data, self._test_cache_response(method=self._service.get_user_data,
metadata='../user-data') metadata='../user-data', decode=False)
@mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack' @mock.patch('cloudbaseinit.metadata.services.cloudstack.CloudStack'
'._get_cache_data') '._get_cache_data')

View File

@@ -99,7 +99,8 @@ class EC2ServiceTest(unittest.TestCase):
def test_get_host_name(self, mock_get_cache_data): def test_get_host_name(self, mock_get_cache_data):
response = self._service.get_host_name() response = self._service.get_host_name()
mock_get_cache_data.assert_called_once_with( mock_get_cache_data.assert_called_once_with(
'%s/meta-data/local-hostname' % self._service._metadata_version) '%s/meta-data/local-hostname' % self._service._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response) self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service' @mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service'
@@ -107,7 +108,8 @@ class EC2ServiceTest(unittest.TestCase):
def test_get_instance_id(self, mock_get_cache_data): def test_get_instance_id(self, mock_get_cache_data):
response = self._service.get_instance_id() response = self._service.get_instance_id()
mock_get_cache_data.assert_called_once_with( mock_get_cache_data.assert_called_once_with(
'%s/meta-data/instance-id' % self._service._metadata_version) '%s/meta-data/instance-id' % self._service._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response) self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service' @mock.patch('cloudbaseinit.metadata.services.ec2service.EC2Service'
@@ -117,10 +119,11 @@ class EC2ServiceTest(unittest.TestCase):
response = self._service.get_public_keys() response = self._service.get_public_keys()
expected = [ expected = [
mock.call('%s/meta-data/public-keys' % mock.call('%s/meta-data/public-keys' %
self._service._metadata_version), self._service._metadata_version,
decode=True),
mock.call('%(version)s/meta-data/public-keys/%(' mock.call('%(version)s/meta-data/public-keys/%('
'idx)s/openssh-key' % 'idx)s/openssh-key' %
{'version': self._service._metadata_version, {'version': self._service._metadata_version,
'idx': 'key'})] 'idx': 'key'}, decode=True)]
self.assertEqual(expected, mock_get_cache_data.call_args_list) self.assertEqual(expected, mock_get_cache_data.call_args_list)
self.assertEqual(['fake key'], response) self.assertEqual(['fake key'], response)

View File

@@ -148,7 +148,8 @@ class MaaSHttpServiceTest(unittest.TestCase):
response = self._maasservice.get_host_name() response = self._maasservice.get_host_name()
mock_get_cache_data.assert_called_once_with( mock_get_cache_data.assert_called_once_with(
'%s/meta-data/local-hostname' % '%s/meta-data/local-hostname' %
self._maasservice._metadata_version) self._maasservice._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response) self.assertEqual(mock_get_cache_data.return_value, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
@@ -156,13 +157,10 @@ class MaaSHttpServiceTest(unittest.TestCase):
def test_get_instance_id(self, mock_get_cache_data): def test_get_instance_id(self, mock_get_cache_data):
response = self._maasservice.get_instance_id() response = self._maasservice.get_instance_id()
mock_get_cache_data.assert_called_once_with( mock_get_cache_data.assert_called_once_with(
'%s/meta-data/instance-id' % self._maasservice._metadata_version) '%s/meta-data/instance-id' % self._maasservice._metadata_version,
decode=True)
self.assertEqual(mock_get_cache_data.return_value, response) self.assertEqual(mock_get_cache_data.return_value, response)
def test_get_list_from_text(self):
response = self._maasservice._get_list_from_text('fake:text', ':')
self.assertEqual(['fake:', 'text:'], response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data") "._get_cache_data")
def test_get_public_keys(self, mock_get_cache_data): def test_get_public_keys(self, mock_get_cache_data):
@@ -174,21 +172,26 @@ class MaaSHttpServiceTest(unittest.TestCase):
mock_get_cache_data.return_value = public_key mock_get_cache_data.return_value = public_key
response = self._maasservice.get_public_keys() response = self._maasservice.get_public_keys()
mock_get_cache_data.assert_called_with( mock_get_cache_data.assert_called_with(
'%s/meta-data/public-keys' % self._maasservice._metadata_version) '%s/meta-data/public-keys' % self._maasservice._metadata_version,
decode=True)
self.assertEqual(public_keys, response) self.assertEqual(public_keys, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_list_from_text")
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data") "._get_cache_data")
def test_get_client_auth_certs(self, mock_get_cache_data, def test_get_client_auth_certs(self, mock_get_cache_data):
mock_get_list_from_text): certs = [
"{begin}\n{cert}\n{end}".format(
begin=x509constants.PEM_HEADER,
end=x509constants.PEM_FOOTER,
cert=cert)
for cert in ("first cert", "second cert")
]
mock_get_cache_data.return_value = "\n".join(certs) + "\n"
response = self._maasservice.get_client_auth_certs() response = self._maasservice.get_client_auth_certs()
mock_get_cache_data.assert_called_with( mock_get_cache_data.assert_called_with(
'%s/meta-data/x509' % self._maasservice._metadata_version) '%s/meta-data/x509' % self._maasservice._metadata_version,
mock_get_list_from_text.assert_called_once_with( decode=True)
mock_get_cache_data(), "%s\n" % x509constants.PEM_FOOTER) self.assertEqual(certs, response)
self.assertEqual(mock_get_list_from_text.return_value, response)
@mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService" @mock.patch("cloudbaseinit.metadata.services.maasservice.MaaSHttpService"
"._get_cache_data") "._get_cache_data")

View File

@@ -0,0 +1,54 @@
# Copyright 2014 Cloudbase Solutions Srl
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os
import tempfile
import unittest
from cloudbaseinit.tests import testutils
from cloudbaseinit.utils import encoding
class TestEncoding(unittest.TestCase):
def test_get_as_string(self):
content_map = [
("data", "data"),
(b"data", "data"),
("data".encode(), "data"),
("data".encode("utf-16"), None)
]
with testutils.LogSnatcher("cloudbaseinit.utils.encoding") as snatch:
for content, expect in content_map:
self.assertEqual(expect, encoding.get_as_string(content))
self.assertIn("couldn't decode", snatch.output[0].lower())
def test_write_file(self):
mode_map = [
(("w", "r"), "my test\ndata\n\n", False),
(("wb", "rb"), "\r\n".join((chr(x) for x in
(32, 125, 0))).encode(), False),
(("wb", "rb"), "my test\ndata\n\n", True)
]
with testutils.create_tempdir() as temp:
fd, path = tempfile.mkstemp(dir=temp)
os.close(fd)
for (write, read), data, encode in mode_map:
encoding.write_file(path, data, mode=write)
with open(path, read) as stream:
content = stream.read()
if encode:
data = data.encode()
self.assertEqual(data, content)

View File

@@ -285,7 +285,6 @@ class CryptoAPICertManagerTests(unittest.TestCase):
fake_cert_data += x509constants.PEM_HEADER + '\n' fake_cert_data += x509constants.PEM_HEADER + '\n'
fake_cert_data += 'fake cert' + '\n' fake_cert_data += 'fake cert' + '\n'
fake_cert_data += x509constants.PEM_FOOTER fake_cert_data += x509constants.PEM_FOOTER
fake_cert_data = fake_cert_data.encode()
response = self._x509_manager._get_cert_base64(fake_cert_data) response = self._x509_manager._get_cert_base64(fake_cert_data)
self.assertEqual('fake cert', response) self.assertEqual('fake cert', response)

View File

@@ -14,6 +14,11 @@
import six import six
from cloudbaseinit.openstack.common import log as logging
LOG = logging.getLogger(__name__)
def get_as_string(value): def get_as_string(value):
if value is None or isinstance(value, six.text_type): if value is None or isinstance(value, six.text_type):
@@ -22,7 +27,9 @@ def get_as_string(value):
try: try:
return value.decode() return value.decode()
except Exception: except Exception:
pass # This is important, because None will be returned,
# but not that serious to raise an exception.
LOG.error("Couldn't decode: %r", value)
def write_file(target_path, data, mode='wb'): def write_file(target_path, data, mode='wb'):
@@ -31,13 +38,3 @@ def write_file(target_path, data, mode='wb'):
with open(target_path, mode) as f: with open(target_path, mode) as f:
f.write(data) f.write(data)
def read_file(target_path, mode='rb'):
with open(target_path, mode) as f:
data = f.read()
if 'b' in mode:
data = data.decode()
return data

View File

@@ -20,7 +20,6 @@ import uuid
import six import six
from cloudbaseinit.utils import encoding
from cloudbaseinit.utils.windows import cryptoapi from cloudbaseinit.utils.windows import cryptoapi
from cloudbaseinit.utils import x509constants from cloudbaseinit.utils import x509constants
@@ -205,13 +204,17 @@ class CryptoAPICertManager(object):
free(subject_encoded) free(subject_encoded)
def _get_cert_base64(self, cert_data): def _get_cert_base64(self, cert_data):
base64_cert_data = encoding.get_as_string(cert_data) """Remove certificate header and footer and also new lines."""
if base64_cert_data.startswith(x509constants.PEM_HEADER): # It's assured that the certificate is already a string.
base64_cert_data = base64_cert_data[len(x509constants.PEM_HEADER):] removal = [
if base64_cert_data.endswith(x509constants.PEM_FOOTER): x509constants.PEM_HEADER,
base64_cert_data = base64_cert_data[:len(base64_cert_data) - x509constants.PEM_FOOTER,
len(x509constants.PEM_FOOTER)] "\r",
return base64_cert_data.replace("\n", "") "\n"
]
for remove in removal:
cert_data = cert_data.replace(remove, "")
return cert_data
def import_cert(self, cert_data, machine_keyset=True, def import_cert(self, cert_data, machine_keyset=True,
store_name=STORE_NAME_MY): store_name=STORE_NAME_MY):