Respond to review:

- Refactor both the base64 encoding and decoding into utility functions.

Also:

- Mechanically fix some other broken untested code.
This commit is contained in:
Barry Warsaw
2015-01-27 15:03:52 -05:00
parent 6e7bba7a56
commit adcd95583c
9 changed files with 42 additions and 71 deletions

View File

@@ -38,13 +38,7 @@ def _decode(data, encoding=None):
if not encoding or encoding.lower() in ['raw']: if not encoding or encoding.lower() in ['raw']:
return data return data
elif encoding.lower() in ['base64', 'b64']: elif encoding.lower() in ['base64', 'b64']:
# Try to give us a native string in both Python 2 and 3, and remember return util.b64d(data)
# that b64decode() returns bytes in Python 3.
decoded = base64.b64decode(data)
try:
return decoded.decode('utf-8')
except UnicodeDecodeError:
return decoded
elif encoding.lower() in ['gzip', 'gz']: elif encoding.lower() in ['gzip', 'gz']:
return util.decomp_gzip(data, quiet=False) return util.decomp_gzip(data, quiet=False)
else: else:

View File

@@ -32,7 +32,7 @@ from cloudinit import util
def _split_hash(bin_hash): def _split_hash(bin_hash):
split_up = [] split_up = []
for i in xrange(0, len(bin_hash), 2): for i in range(0, len(bin_hash), 2):
split_up.append(bin_hash[i:i + 2]) split_up.append(bin_hash[i:i + 2])
return split_up return split_up

View File

@@ -426,12 +426,7 @@ def read_context_disk_dir(source_dir, asuser=None):
context.get('USER_DATA_ENCODING')) context.get('USER_DATA_ENCODING'))
if encoding == "base64": if encoding == "base64":
try: try:
userdata = base64.b64decode(results['userdata']) results['userdata'] = util.b64d(results['userdata'])
# In Python 3 we still expect a str, but b64decode will return
# bytes. Convert to str.
if isinstance(userdata, bytes):
userdata = userdata.decode('utf-8')
results['userdata'] = userdata
except TypeError: except TypeError:
LOG.warn("Failed base64 decoding of userdata") LOG.warn("Failed base64 decoding of userdata")

View File

@@ -351,16 +351,7 @@ def query_data(noun, seed_device, seed_timeout, strip=False, default=None,
if b64: if b64:
try: try:
# Generally, we want native strings in the values. Python 3's return util.b64d(resp)
# b64decode will return bytes though, so decode them to utf-8 if
# possible. If that fails, return the bytes.
decoded = base64.b64decode(resp)
try:
if isinstance(decoded, bytes):
return decoded.decode('utf-8')
except UnicodeDecodeError:
pass
return decoded
# Bogus input produces different errors in Python 2 and 3; catch both. # Bogus input produces different errors in Python 2 and 3; catch both.
except (TypeError, binascii.Error): except (TypeError, binascii.Error):
LOG.warn("Failed base64 decoding key '%s'", noun) LOG.warn("Failed base64 decoding key '%s'", noun)

View File

@@ -44,6 +44,7 @@ import sys
import tempfile import tempfile
import time import time
from base64 import b64decode, b64encode
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
import six import six
@@ -90,6 +91,25 @@ def encode_text(text, encoding='utf-8'):
return text return text
return text.encode(encoding) return text.encode(encoding)
def b64d(source):
    # Base64 decode some data, accepting bytes or unicode/str.
    #
    # Returns str/unicode when the decoded payload is valid utf-8; otherwise
    # the raw decoded bytes are returned unchanged.  Errors raised by
    # b64decode() for malformed input are not caught here, so callers can
    # handle them.
    decoded = b64decode(source)
    # b64decode() always returns bytes (which is str on Python 2), so no
    # isinstance() guard is needed; guarding would only add an implicit
    # "return None" path.
    try:
        return decoded.decode('utf-8')
    except UnicodeDecodeError:
        return decoded
def b64e(source):
    # Base64 encode some data, accepting bytes or unicode/str.
    #
    # Text input is first encoded to utf-8 bytes.  The result is always
    # returned as str/unicode (never bytes): base64 output consists only of
    # ASCII characters, so decoding it cannot fail.
    if not isinstance(source, bytes):
        source = source.encode('utf-8')
    return b64encode(source).decode('utf-8')
# Path for DMI Data # Path for DMI Data
DMI_SYS_PATH = "/sys/class/dmi/id" DMI_SYS_PATH = "/sys/class/dmi/id"

View File

@@ -1,5 +1,5 @@
from cloudinit import helpers from cloudinit import helpers
from cloudinit.util import load_file from cloudinit.util import b64e, load_file
from cloudinit.sources import DataSourceAzure from cloudinit.sources import DataSourceAzure
from ..helpers import TestCase, populate_dir from ..helpers import TestCase, populate_dir
@@ -12,7 +12,6 @@ try:
except ImportError: except ImportError:
from contextlib2 import ExitStack from contextlib2 import ExitStack
import base64
import crypt import crypt
import os import os
import stat import stat
@@ -22,13 +21,6 @@ import tempfile
import unittest import unittest
def b64(source):
# In Python 3, b64encode only accepts bytes and returns bytes.
if not isinstance(source, bytes):
source = source.encode('utf-8')
return base64.b64encode(source).decode('us-ascii')
def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None): def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):
if data is None: if data is None:
data = {'HostName': 'FOOHOST'} data = {'HostName': 'FOOHOST'}
@@ -58,7 +50,7 @@ def construct_valid_ovf_env(data=None, pubkeys=None, userdata=None):
content += "<%s%s>%s</%s>\n" % (key, attrs, val, key) content += "<%s%s>%s</%s>\n" % (key, attrs, val, key)
if userdata: if userdata:
content += "<UserData>%s</UserData>\n" % (b64(userdata)) content += "<UserData>%s</UserData>\n" % (b64e(userdata))
if pubkeys: if pubkeys:
content += "<SSH><PublicKeys>\n" content += "<SSH><PublicKeys>\n"
@@ -189,7 +181,7 @@ class TestAzureDataSource(TestCase):
# set dscfg in via base64 encoded yaml # set dscfg in via base64 encoded yaml
cfg = {'agent_command': "my_command"} cfg = {'agent_command': "my_command"}
odata = {'HostName': "myhost", 'UserName': "myuser", odata = {'HostName': "myhost", 'UserName': "myuser",
'dscfg': {'text': b64(yaml.dump(cfg)), 'dscfg': {'text': b64e(yaml.dump(cfg)),
'encoding': 'base64'}} 'encoding': 'base64'}}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
@@ -241,7 +233,7 @@ class TestAzureDataSource(TestCase):
def test_userdata_found(self): def test_userdata_found(self):
mydata = "FOOBAR" mydata = "FOOBAR"
odata = {'UserData': b64(mydata)} odata = {'UserData': b64e(mydata)}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
dsrc = self._get_ds(data) dsrc = self._get_ds(data)
@@ -289,7 +281,7 @@ class TestAzureDataSource(TestCase):
'command': 'my-bounce-command', 'command': 'my-bounce-command',
'hostname_command': 'my-hostname-command'}} 'hostname_command': 'my-hostname-command'}}
odata = {'HostName': "xhost", odata = {'HostName': "xhost",
'dscfg': {'text': b64(yaml.dump(cfg)), 'dscfg': {'text': b64e(yaml.dump(cfg)),
'encoding': 'base64'}} 'encoding': 'base64'}}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
self._get_ds(data).get_data() self._get_ds(data).get_data()
@@ -304,7 +296,7 @@ class TestAzureDataSource(TestCase):
# config specifying set_hostname off should not bounce # config specifying set_hostname off should not bounce
cfg = {'set_hostname': False} cfg = {'set_hostname': False}
odata = {'HostName': "xhost", odata = {'HostName': "xhost",
'dscfg': {'text': b64(yaml.dump(cfg)), 'dscfg': {'text': b64e(yaml.dump(cfg)),
'encoding': 'base64'}} 'encoding': 'base64'}}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
self._get_ds(data).get_data() self._get_ds(data).get_data()
@@ -333,7 +325,7 @@ class TestAzureDataSource(TestCase):
# Make sure that user can affect disk aliases # Make sure that user can affect disk aliases
dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}} dscfg = {'disk_aliases': {'ephemeral0': '/dev/sdc'}}
odata = {'HostName': "myhost", 'UserName': "myuser", odata = {'HostName': "myhost", 'UserName': "myuser",
'dscfg': {'text': b64(yaml.dump(dscfg)), 'dscfg': {'text': b64e(yaml.dump(dscfg)),
'encoding': 'base64'}} 'encoding': 'base64'}}
usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'}, usercfg = {'disk_setup': {'/dev/sdc': {'something': '...'},
'ephemeral0': False}} 'ephemeral0': False}}
@@ -370,7 +362,7 @@ class TestAzureDataSource(TestCase):
def test_existing_ovf_same(self): def test_existing_ovf_same(self):
# waagent/SharedConfig left alone if found ovf-env.xml same as cached # waagent/SharedConfig left alone if found ovf-env.xml same as cached
odata = {'UserData': b64("SOMEUSERDATA")} odata = {'UserData': b64e("SOMEUSERDATA")}
data = {'ovfcontent': construct_valid_ovf_env(data=odata)} data = {'ovfcontent': construct_valid_ovf_env(data=odata)}
populate_dir(self.waagent_d, populate_dir(self.waagent_d,
@@ -394,9 +386,9 @@ class TestAzureDataSource(TestCase):
# 'get_data' should remove SharedConfig.xml in /var/lib/waagent # 'get_data' should remove SharedConfig.xml in /var/lib/waagent
# if ovf-env.xml differs. # if ovf-env.xml differs.
cached_ovfenv = construct_valid_ovf_env( cached_ovfenv = construct_valid_ovf_env(
{'userdata': b64("FOO_USERDATA")}) {'userdata': b64e("FOO_USERDATA")})
new_ovfenv = construct_valid_ovf_env( new_ovfenv = construct_valid_ovf_env(
{'userdata': b64("NEW_USERDATA")}) {'userdata': b64e("NEW_USERDATA")})
populate_dir(self.waagent_d, populate_dir(self.waagent_d,
{'ovf-env.xml': cached_ovfenv, {'ovf-env.xml': cached_ovfenv,

View File

@@ -3,19 +3,12 @@ from cloudinit.sources import DataSourceOpenNebula as ds
from cloudinit import util from cloudinit import util
from ..helpers import TestCase, populate_dir from ..helpers import TestCase, populate_dir
from base64 import b64encode
import os import os
import pwd import pwd
import shutil import shutil
import tempfile import tempfile
import unittest import unittest
def b64(source):
# In Python 3, b64encode only accepts bytes and returns bytes.
if not isinstance(source, bytes):
source = source.encode('utf-8')
return b64encode(source).decode('us-ascii')
TEST_VARS = { TEST_VARS = {
'VAR1': 'single', 'VAR1': 'single',
@@ -186,7 +179,7 @@ class TestOpenNebulaDataSource(TestCase):
self.assertEqual(USER_DATA, results['userdata']) self.assertEqual(USER_DATA, results['userdata'])
def test_user_data_encoding_required_for_decode(self): def test_user_data_encoding_required_for_decode(self):
b64userdata = b64(USER_DATA) b64userdata = util.b64e(USER_DATA)
for k in ('USER_DATA', 'USERDATA'): for k in ('USER_DATA', 'USERDATA'):
my_d = os.path.join(self.tmp, k) my_d = os.path.join(self.tmp, k)
populate_context_dir(my_d, {k: b64userdata}) populate_context_dir(my_d, {k: b64userdata})
@@ -198,7 +191,7 @@ class TestOpenNebulaDataSource(TestCase):
def test_user_data_base64_encoding(self): def test_user_data_base64_encoding(self):
for k in ('USER_DATA', 'USERDATA'): for k in ('USER_DATA', 'USERDATA'):
my_d = os.path.join(self.tmp, k) my_d = os.path.join(self.tmp, k)
populate_context_dir(my_d, {k: b64(USER_DATA), populate_context_dir(my_d, {k: util.b64e(USER_DATA),
'USERDATA_ENCODING': 'base64'}) 'USERDATA_ENCODING': 'base64'})
results = ds.read_context_disk_dir(my_d) results = ds.read_context_disk_dir(my_d)

View File

@@ -24,9 +24,9 @@
from __future__ import print_function from __future__ import print_function
import base64
from cloudinit import helpers as c_helpers from cloudinit import helpers as c_helpers
from cloudinit.sources import DataSourceSmartOS from cloudinit.sources import DataSourceSmartOS
from cloudinit.util import b64e
from .. import helpers from .. import helpers
import os import os
import os.path import os.path
@@ -36,12 +36,6 @@ import tempfile
import stat import stat
import uuid import uuid
def b64(source):
# In Python 3, b64encode only accepts bytes and returns bytes.
if not isinstance(source, bytes):
source = source.encode('utf-8')
return base64.b64encode(source).decode('us-ascii')
MOCK_RETURNS = { MOCK_RETURNS = {
'hostname': 'test-host', 'hostname': 'test-host',
@@ -239,7 +233,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
my_returns = MOCK_RETURNS.copy() my_returns = MOCK_RETURNS.copy()
my_returns['base64_all'] = "true" my_returns['base64_all'] = "true"
for k in ('hostname', 'cloud-init:user-data'): for k in ('hostname', 'cloud-init:user-data'):
my_returns[k] = b64(my_returns[k]) my_returns[k] = b64e(my_returns[k])
dsrc = self._get_ds(mockdata=my_returns) dsrc = self._get_ds(mockdata=my_returns)
ret = dsrc.get_data() ret = dsrc.get_data()
@@ -260,7 +254,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
my_returns['b64-cloud-init:user-data'] = "true" my_returns['b64-cloud-init:user-data'] = "true"
my_returns['b64-hostname'] = "true" my_returns['b64-hostname'] = "true"
for k in ('hostname', 'cloud-init:user-data'): for k in ('hostname', 'cloud-init:user-data'):
my_returns[k] = b64(my_returns[k]) my_returns[k] = b64e(my_returns[k])
dsrc = self._get_ds(mockdata=my_returns) dsrc = self._get_ds(mockdata=my_returns)
ret = dsrc.get_data() ret = dsrc.get_data()
@@ -276,7 +270,7 @@ class TestSmartOSDataSource(helpers.FilesystemMockingTestCase):
my_returns = MOCK_RETURNS.copy() my_returns = MOCK_RETURNS.copy()
my_returns['base64_keys'] = 'hostname,ignored' my_returns['base64_keys'] = 'hostname,ignored'
for k in ('hostname',): for k in ('hostname',):
my_returns[k] = b64(my_returns[k]) my_returns[k] = b64e(my_returns[k])
dsrc = self._get_ds(mockdata=my_returns) dsrc = self._get_ds(mockdata=my_returns)
ret = dsrc.get_data() ret = dsrc.get_data()

View File

@@ -18,7 +18,6 @@
from cloudinit.config import cc_seed_random from cloudinit.config import cc_seed_random
import base64
import gzip import gzip
import tempfile import tempfile
@@ -38,13 +37,6 @@ import logging
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def b64(source):
# In Python 3, b64encode only accepts bytes and returns bytes.
if not isinstance(source, bytes):
source = source.encode('utf-8')
return base64.b64encode(source).decode('us-ascii')
class TestRandomSeed(t_help.TestCase): class TestRandomSeed(t_help.TestCase):
def setUp(self): def setUp(self):
super(TestRandomSeed, self).setUp() super(TestRandomSeed, self).setUp()
@@ -141,7 +133,7 @@ class TestRandomSeed(t_help.TestCase):
self.assertEquals("big-toe", contents) self.assertEquals("big-toe", contents)
def test_append_random_base64(self): def test_append_random_base64(self):
data = b64('bubbles') data = util.b64e('bubbles')
cfg = { cfg = {
'random_seed': { 'random_seed': {
'file': self._seed_file, 'file': self._seed_file,
@@ -154,7 +146,7 @@ class TestRandomSeed(t_help.TestCase):
self.assertEquals("bubbles", contents) self.assertEquals("bubbles", contents)
def test_append_random_b64(self): def test_append_random_b64(self):
data = b64('kit-kat') data = util.b64e('kit-kat')
cfg = { cfg = {
'random_seed': { 'random_seed': {
'file': self._seed_file, 'file': self._seed_file,