From dca4ea5fd6e1a4eb15eeea9879913f64f2851446 Mon Sep 17 00:00:00 2001
From: Tim Burke <tim.burke@gmail.com>
Date: Wed, 29 Jun 2016 11:07:41 -0700
Subject: [PATCH] Fix unicode issues in tempurl command

Previously, we weren't encoding paths and keys as UTF-8, which would
trigger a UnicodeEncodeError on py27.

Change-Id: I2fad428369406c2ae32343a5e943ffb2cd1ca6ef
---
 swiftclient/utils.py     |  43 +++++++-----
 tests/unit/test_utils.py | 138 ++++++++++++++++++++++++++++++---------
 2 files changed, 134 insertions(+), 47 deletions(-)

diff --git a/swiftclient/utils.py b/swiftclient/utils.py
index 0abaed6f..10687bf4 100644
--- a/swiftclient/utils.py
+++ b/swiftclient/utils.py
@@ -78,15 +78,20 @@ def generate_temp_url(path, seconds, key, method, absolute=False):
     :raises: TypeError if seconds is not an integer
     :return: the path portion of a temporary URL
     """
+    try:
+        seconds = int(seconds)
+    except ValueError:
+        raise TypeError('seconds must be an integer')
     if seconds < 0:
         raise ValueError('seconds must be a positive integer')
-    try:
-        if not absolute:
-            expiration = int(time.time() + seconds)
-        else:
-            expiration = int(seconds)
-    except TypeError:
-        raise TypeError('seconds must be an integer')
+
+    if isinstance(path, six.binary_type):
+        try:
+            path_for_body = path.decode('utf-8')
+        except UnicodeDecodeError:
+            raise ValueError('path must be representable as UTF-8')
+    else:
+        path_for_body = path
 
     standard_methods = ['GET', 'PUT', 'HEAD', 'POST', 'DELETE']
     if method.upper() not in standard_methods:
@@ -94,18 +99,24 @@ def generate_temp_url(path, seconds, key, method, absolute=False):
         logger.warning('Non default HTTP method %s for tempurl specified, '
                        'possibly an error', method.upper())
 
-    hmac_body = '\n'.join([method.upper(), str(expiration), path])
+    if not absolute:
+        expiration = int(time.time() + seconds)
+    else:
+        expiration = seconds
+    hmac_body = u'\n'.join([method.upper(), str(expiration), path_for_body])
 
     # Encode to UTF-8 for py3 compatibility
-    sig = hmac.new(key.encode(),
-                   hmac_body.encode(),
-                   hashlib.sha1).hexdigest()
+    if not isinstance(key, six.binary_type):
+        key = key.encode('utf-8')
+    sig = hmac.new(key, hmac_body.encode('utf-8'), hashlib.sha1).hexdigest()
 
-    return ('{path}?temp_url_sig='
-            '{sig}&temp_url_expires={exp}'.format(
-                path=path,
-                sig=sig,
-                exp=expiration))
+    temp_url = u'{path}?temp_url_sig={sig}&temp_url_expires={exp}'.format(
+        path=path_for_body, sig=sig, exp=expiration)
+    # Have return type match path from caller
+    if isinstance(path, six.binary_type):
+        return temp_url.encode('utf-8')
+    else:
+        return temp_url
 
 
 def parse_api_response(headers, body):
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index aae466c7..0f210a3c 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -17,7 +17,7 @@ import unittest
 import mock
 import six
 import tempfile
-from hashlib import md5
+from hashlib import md5, sha1
 
 from swiftclient import utils as u
 
@@ -120,48 +120,124 @@ class TestPrtBytes(unittest.TestCase):
 
 
 class TestTempURL(unittest.TestCase):
+    url = '/v1/AUTH_account/c/o'
+    seconds = 3600
+    key = 'correcthorsebatterystaple'
+    method = 'GET'
+    expected_url = url + ('?temp_url_sig=temp_url_signature'
+                          '&temp_url_expires=1400003600')
+    expected_body = '\n'.join([
+        method,
+        '1400003600',
+        url,
+    ]).encode('utf-8')
 
-    def setUp(self):
-        super(TestTempURL, self).setUp()
-        self.url = '/v1/AUTH_account/c/o'
-        self.seconds = 3600
-        self.key = 'correcthorsebatterystaple'
-        self.method = 'GET'
-
-    @mock.patch('hmac.HMAC.hexdigest', return_value='temp_url_signature')
+    @mock.patch('hmac.HMAC')
     @mock.patch('time.time', return_value=1400000000)
     def test_generate_temp_url(self, time_mock, hmac_mock):
-        expected_url = (
-            '/v1/AUTH_account/c/o?'
-            'temp_url_sig=temp_url_signature&'
-            'temp_url_expires=1400003600')
-        url = u.generate_temp_url(self.url, self.seconds, self.key,
-                                  self.method)
-        self.assertEqual(url, expected_url)
+        hmac_mock().hexdigest.return_value = 'temp_url_signature'
+        url = u.generate_temp_url(self.url, self.seconds,
+                                  self.key, self.method)
+        key = self.key
+        if not isinstance(key, six.binary_type):
+            key = key.encode('utf-8')
+        self.assertEqual(url, self.expected_url)
+        self.assertEqual(hmac_mock.mock_calls, [
+            mock.call(),
+            mock.call(key, self.expected_body, sha1),
+            mock.call().hexdigest(),
+        ])
+        self.assertIsInstance(url, type(self.url))
+
+    def test_generate_temp_url_invalid_path(self):
+        with self.assertRaises(ValueError) as exc_manager:
+            u.generate_temp_url(b'/v1/a/c/\xff', self.seconds, self.key,
+                                self.method)
+        self.assertEqual(exc_manager.exception.args[0],
+                         'path must be representable as UTF-8')
 
     @mock.patch('hmac.HMAC.hexdigest', return_value="temp_url_signature")
     def test_generate_absolute_expiry_temp_url(self, hmac_mock):
-        expected_url = ('/v1/AUTH_account/c/o?'
-                        'temp_url_sig=temp_url_signature&'
-                        'temp_url_expires=2146636800')
+        if isinstance(self.expected_url, six.binary_type):
+            expected_url = self.expected_url.replace(
+                b'1400003600', b'2146636800')
+        else:
+            expected_url = self.expected_url.replace(
+                u'1400003600', u'2146636800')
         url = u.generate_temp_url(self.url, 2146636800, self.key, self.method,
                                   absolute=True)
         self.assertEqual(url, expected_url)
 
     def test_generate_temp_url_bad_seconds(self):
-        self.assertRaises(TypeError,
-                          u.generate_temp_url,
-                          self.url,
-                          'not_an_int',
-                          self.key,
-                          self.method)
+        with self.assertRaises(TypeError) as exc_manager:
+            u.generate_temp_url(self.url, 'not_an_int', self.key, self.method)
+        self.assertEqual(exc_manager.exception.args[0],
+                         'seconds must be an integer')
 
-        self.assertRaises(ValueError,
-                          u.generate_temp_url,
-                          self.url,
-                          -1,
-                          self.key,
-                          self.method)
+        with self.assertRaises(ValueError) as exc_manager:
+            u.generate_temp_url(self.url, -1, self.key, self.method)
+        self.assertEqual(exc_manager.exception.args[0],
+                         'seconds must be a positive integer')
+
+
+class TestTempURLUnicodePathAndKey(TestTempURL):
+    url = u'/v1/\u00e4/c/\u00f3'
+    key = u'k\u00e9y'
+    expected_url = (u'%s?temp_url_sig=temp_url_signature'
+                    u'&temp_url_expires=1400003600') % url
+    expected_body = u'\n'.join([
+        u'GET',
+        u'1400003600',
+        url,
+    ]).encode('utf-8')
+
+
+class TestTempURLUnicodePathBytesKey(TestTempURL):
+    url = u'/v1/\u00e4/c/\u00f3'
+    key = u'k\u00e9y'.encode('utf-8')
+    expected_url = (u'%s?temp_url_sig=temp_url_signature'
+                    u'&temp_url_expires=1400003600') % url
+    expected_body = '\n'.join([
+        u'GET',
+        u'1400003600',
+        url,
+    ]).encode('utf-8')
+
+
+class TestTempURLBytesPathUnicodeKey(TestTempURL):
+    url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8')
+    key = u'k\u00e9y'
+    expected_url = url + (b'?temp_url_sig=temp_url_signature'
+                          b'&temp_url_expires=1400003600')
+    expected_body = b'\n'.join([
+        b'GET',
+        b'1400003600',
+        url,
+    ])
+
+
+class TestTempURLBytesPathAndKey(TestTempURL):
+    url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8')
+    key = u'k\u00e9y'.encode('utf-8')
+    expected_url = url + (b'?temp_url_sig=temp_url_signature'
+                          b'&temp_url_expires=1400003600')
+    expected_body = b'\n'.join([
+        b'GET',
+        b'1400003600',
+        url,
+    ])
+
+
+class TestTempURLBytesPathAndNonUtf8Key(TestTempURL):
+    url = u'/v1/\u00e4/c/\u00f3'.encode('utf-8')
+    key = b'k\xffy'
+    expected_url = url + (b'?temp_url_sig=temp_url_signature'
+                          b'&temp_url_expires=1400003600')
+    expected_body = b'\n'.join([
+        b'GET',
+        b'1400003600',
+        url,
+    ])
 
 
 class TestReadableToIterable(unittest.TestCase):