diff --git a/swiftclient/utils.py b/swiftclient/utils.py index 0abaed6f..10687bf4 100644 --- a/swiftclient/utils.py +++ b/swiftclient/utils.py @@ -78,15 +78,20 @@ def generate_temp_url(path, seconds, key, method, absolute=False): :raises: TypeError if seconds is not an integer :return: the path portion of a temporary URL """ + try: + seconds = int(seconds) + except ValueError: + raise TypeError('seconds must be an integer') if seconds < 0: raise ValueError('seconds must be a positive integer') - try: - if not absolute: - expiration = int(time.time() + seconds) - else: - expiration = int(seconds) - except TypeError: - raise TypeError('seconds must be an integer') + + if isinstance(path, six.binary_type): + try: + path_for_body = path.decode('utf-8') + except UnicodeDecodeError: + raise ValueError('path must be representable as UTF-8') + else: + path_for_body = path standard_methods = ['GET', 'PUT', 'HEAD', 'POST', 'DELETE'] if method.upper() not in standard_methods: @@ -94,18 +99,24 @@ def generate_temp_url(path, seconds, key, method, absolute=False): logger.warning('Non default HTTP method %s for tempurl specified, ' 'possibly an error', method.upper()) - hmac_body = '\n'.join([method.upper(), str(expiration), path]) + if not absolute: + expiration = int(time.time() + seconds) + else: + expiration = seconds + hmac_body = u'\n'.join([method.upper(), str(expiration), path_for_body]) # Encode to UTF-8 for py3 compatibility - sig = hmac.new(key.encode(), - hmac_body.encode(), - hashlib.sha1).hexdigest() + if not isinstance(key, six.binary_type): + key = key.encode('utf-8') + sig = hmac.new(key, hmac_body.encode('utf-8'), hashlib.sha1).hexdigest() - return ('{path}?temp_url_sig=' - '{sig}&temp_url_expires={exp}'.format( - path=path, - sig=sig, - exp=expiration)) + temp_url = u'{path}?temp_url_sig={sig}&temp_url_expires={exp}'.format( + path=path_for_body, sig=sig, exp=expiration) + # Have return type match path from caller + if isinstance(path, six.binary_type): + return temp_url.encode('utf-8') + else: + return temp_url def parse_api_response(headers, body): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index aae466c7..0f210a3c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -17,7 +17,7 @@ import unittest import mock import six import tempfile -from hashlib import md5 +from hashlib import md5, sha1 from swiftclient import utils as u @@ -120,48 +120,124 @@ class TestPrtBytes(unittest.TestCase): class TestTempURL(unittest.TestCase): + url = '/v1/AUTH_account/c/o' + seconds = 3600 + key = 'correcthorsebatterystaple' + method = 'GET' + expected_url = url + ('?temp_url_sig=temp_url_signature' + '&temp_url_expires=1400003600') + expected_body = '\n'.join([ + method, + '1400003600', + url, + ]).encode('utf-8') - def setUp(self): - super(TestTempURL, self).setUp() - self.url = '/v1/AUTH_account/c/o' - self.seconds = 3600 - self.key = 'correcthorsebatterystaple' - self.method = 'GET' - - @mock.patch('hmac.HMAC.hexdigest', return_value='temp_url_signature') + @mock.patch('hmac.HMAC') @mock.patch('time.time', return_value=1400000000) def test_generate_temp_url(self, time_mock, hmac_mock): - expected_url = ( - '/v1/AUTH_account/c/o?' - 'temp_url_sig=temp_url_signature&' - 'temp_url_expires=1400003600') - url = u.generate_temp_url(self.url, self.seconds, self.key, - self.method) - self.assertEqual(url, expected_url) + hmac_mock().hexdigest.return_value = 'temp_url_signature' + url = u.generate_temp_url(self.url, self.seconds, + self.key, self.method) + key = self.key + if not isinstance(key, six.binary_type): + key = key.encode('utf-8') + self.assertEqual(url, self.expected_url) + self.assertEqual(hmac_mock.mock_calls, [ + mock.call(), + mock.call(key, self.expected_body, sha1), + mock.call().hexdigest(), + ]) + self.assertIsInstance(url, type(self.url)) + + def test_generate_temp_url_invalid_path(self): + with self.assertRaises(ValueError) as exc_manager: + u.generate_temp_url(b'/v1/a/c/\xff', self.seconds, self.key, + self.method) + self.assertEqual(exc_manager.exception.args[0], + 'path must be representable as UTF-8') @mock.patch('hmac.HMAC.hexdigest', return_value="temp_url_signature") def test_generate_absolute_expiry_temp_url(self, hmac_mock): - expected_url = ('/v1/AUTH_account/c/o?' - 'temp_url_sig=temp_url_signature&' - 'temp_url_expires=2146636800') + if isinstance(self.expected_url, six.binary_type): + expected_url = self.expected_url.replace( + b'1400003600', b'2146636800') + else: + expected_url = self.expected_url.replace( + u'1400003600', u'2146636800') url = u.generate_temp_url(self.url, 2146636800, self.key, self.method, absolute=True) self.assertEqual(url, expected_url) def test_generate_temp_url_bad_seconds(self): - self.assertRaises(TypeError, - u.generate_temp_url, - self.url, - 'not_an_int', - self.key, - self.method) + with self.assertRaises(TypeError) as exc_manager: + u.generate_temp_url(self.url, 'not_an_int', self.key, self.method) + self.assertEqual(exc_manager.exception.args[0], + 'seconds must be an integer') - self.assertRaises(ValueError, - u.generate_temp_url, - self.url, - -1, - self.key, - self.method) + with self.assertRaises(ValueError) as exc_manager: + u.generate_temp_url(self.url, -1, self.key, self.method) + self.assertEqual(exc_manager.exception.args[0], + 'seconds must be a positive integer') + + +class TestTempURLUnicodePathAndKey(TestTempURL): + url = u'/v1/\u00e4/c/\u00f3' + key = u'k\u00e9y' + expected_url = (u'%s?temp_url_sig=temp_url_signature' + u'&temp_url_expires=1400003600') % url + expected_body = u'\n'.join([ + u'GET', + u'1400003600', + url, + ]).encode('utf-8') + + +class TestTempURLUnicodePathBytesKey(TestTempURL): + url = u'/v1/\u00e4/c/\u00f3' + key = u'k\u00e9y'.encode('utf-8') + expected_url = (u'%s?temp_url_sig=temp_url_signature' + u'&temp_url_expires=1400003600') % url + expected_body = '\n'.join([ + u'GET', + u'1400003600', + url, + ]).encode('utf-8') + + +class TestTempURLBytesPathUnicodeKey(TestTempURL): + url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8') + key = u'k\u00e9y' + expected_url = url + (b'?temp_url_sig=temp_url_signature' + b'&temp_url_expires=1400003600') + expected_body = b'\n'.join([ + b'GET', + b'1400003600', + url, + ]) + + +class TestTempURLBytesPathAndKey(TestTempURL): + url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8') + key = u'k\u00e9y'.encode('utf-8') + expected_url = url + (b'?temp_url_sig=temp_url_signature' + b'&temp_url_expires=1400003600') + expected_body = b'\n'.join([ + b'GET', + b'1400003600', + url, + ]) + + +class TestTempURLBytesPathAndNonUtf8Key(TestTempURL): + url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8') + key = b'k\xffy' + expected_url = url + (b'?temp_url_sig=temp_url_signature' + b'&temp_url_expires=1400003600') + expected_body = b'\n'.join([ + b'GET', + b'1400003600', + url, + ]) class TestReadableToIterable(unittest.TestCase):