diff --git a/swift/common/middleware/crypto/keymaster.py b/swift/common/middleware/crypto/keymaster.py index da67c50845..40e3698a7d 100644 --- a/swift/common/middleware/crypto/keymaster.py +++ b/swift/common/middleware/crypto/keymaster.py @@ -18,7 +18,7 @@ import os from swift.common.middleware.crypto.crypto_utils import CRYPTO_KEY_CALLBACK from swift.common.swob import Request, HTTPException -from swift.common.utils import readconf, base64decode +from swift.common.utils import readconf, strict_b64decode from swift.common.wsgi import WSGIContext @@ -137,8 +137,8 @@ class KeyMaster(object): conf = readconf(self.keymaster_config_path, 'keymaster') b64_root_secret = conf.get('encryption_root_secret') try: - binary_root_secret = base64decode(b64_root_secret, - allow_line_breaks=True) + binary_root_secret = strict_b64decode(b64_root_secret, + allow_line_breaks=True) if len(binary_root_secret) < 32: raise ValueError return binary_root_secret diff --git a/swift/common/utils.py b/swift/common/utils.py index 2ca9ea6a06..80a645b5f3 100644 --- a/swift/common/utils.py +++ b/swift/common/utils.py @@ -4343,7 +4343,7 @@ def safe_json_loads(value): return None -def base64decode(value, allow_line_breaks=False): +def strict_b64decode(value, allow_line_breaks=False): ''' Validate and decode Base64-encoded data. diff --git a/test/unit/common/test_utils.py b/test/unit/common/test_utils.py index a1f9aa736f..2355d984c3 100644 --- a/test/unit/common/test_utils.py +++ b/test/unit/common/test_utils.py @@ -38,6 +38,7 @@ import string import sys import json import math +import inspect import six from six import BytesIO, StringIO @@ -3895,7 +3896,7 @@ cluster_dfw1 = http://dfw1.host/v1/ self.fail('Invalid results from pure function:\n%s' % '\n'.join(failures)) - def test_base64decode(self): + def test_strict_b64decode(self): expectations = { None: ValueError, 0: ValueError, @@ -3920,19 +3921,23 @@ cluster_dfw1 = http://dfw1.host/v1/ failures = [] for value, expected in expectations.items(): - if expected is ValueError: - try: - result = utils.base64decode(value) - except ValueError: - pass + try: + result = utils.strict_b64decode(value) + except Exception as e: + if inspect.isclass(expected) and issubclass( + expected, Exception): + if not isinstance(e, expected): + failures.append('%r raised %r (expected to raise %r)' % + (value, e, expected)) else: - failures.append('%r => %r (expected to raise ValueError)' % - (value, result)) + failures.append('%r raised %r (expected to return %r)' % + (value, e, expected)) else: - try: - result = utils.base64decode(value) - self.assertEqual(expected, result) - except AssertionError: + if inspect.isclass(expected) and issubclass( + expected, Exception): + failures.append('%r => %r (expected to raise %r)' % + (value, result, expected)) + elif result != expected: failures.append('%r => %r (expected %r)' % ( value, result, expected)) if failures: