py3: Monkey-patch json.loads to accept bytes on py35

I'm tired of creating code churn where I just slap

    .decode("nearly arbitrary choice of encoding")

in a bunch of places.

Change-Id: I79b2bc59fed130ca537e96c1074212861d7db6b8
This commit is contained in:
Tim Burke 2018-11-02 21:38:49 +00:00
parent 887ba87c5a
commit c112203e0e
12 changed files with 70 additions and 35 deletions

View File

@ -14,6 +14,7 @@
# limitations under the License.
import os
import sys
import gettext
import pkg_resources
@ -39,3 +40,37 @@ _t = gettext.translation('swift', localedir=_localedir, fallback=True)
def gettext_(msg):
return _t.gettext(msg)
if (3, 0) <= sys.version_info[:2] <= (3, 5):
# In the development of py3, json.loads() stopped accepting byte strings
# for a while. https://bugs.python.org/issue17909 got fixed for py36, but
# since it was termed an enhancement and not a regression, we don't expect
# any backports. At the same time, it'd be better if we could avoid
# leaving a whole bunch of json.loads(resp.body.decode(...)) scars in the
# code that'd probably persist even *after* we drop support for 3.5 and
# earlier. So, monkey patch stdlib.
import json
if not getattr(json.loads, 'patched_to_decode', False):
class JsonLoadsPatcher(object):
def __init__(self, orig):
self._orig = orig
def __call__(self, s, **kw):
if isinstance(s, bytes):
# No fancy byte-order mark detection for us; just assume
# UTF-8 and raise a UnicodeDecodeError if appropriate.
s = s.decode('utf8')
return self._orig(s, **kw)
def __getattribute__(self, attr):
if attr == 'patched_to_decode':
return True
if attr == '_orig':
return super().__getattribute__(attr)
# Pass through all other attrs to the original; among other
# things, this preserves doc strings, etc.
return getattr(self._orig, attr)
json.loads = JsonLoadsPatcher(json.loads)
del JsonLoadsPatcher

View File

@ -174,7 +174,7 @@ def _get_direct_account_container(path, stype, node, part,
if resp.status == HTTP_NO_CONTENT:
resp.read()
return resp_headers, []
return resp_headers, json.loads(resp.read().decode('ascii'))
return resp_headers, json.loads(resp.read())
def gen_headers(hdrs_in=None, add_ts=True):

View File

@ -298,7 +298,7 @@ class InternalClient(object):
if resp.status_int >= HTTP_MULTIPLE_CHOICES:
b''.join(resp.app_iter)
break
data = json.loads(resp.body.decode('ascii'))
data = json.loads(resp.body)
if not data:
break
for item in data:
@ -844,7 +844,7 @@ class SimpleClient(object):
body = conn.read()
info = conn.info()
try:
body_data = json.loads(body.decode('ascii'))
body_data = json.loads(body)
except ValueError:
body_data = None
trans_stop = time()

View File

@ -315,7 +315,7 @@ class MemcacheRing(object):
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value.decode('ascii'))
value = json.loads(value)
fp.readline()
line = fp.readline().strip().split()
self._return_conn(server, fp, sock)
@ -484,7 +484,7 @@ class MemcacheRing(object):
else:
value = None
elif int(line[2]) & JSON_FLAG:
value = json.loads(value.decode('ascii'))
value = json.loads(value)
responses[line[1]] = value
fp.readline()
line = fp.readline().strip().split()

View File

@ -185,7 +185,7 @@ class ListingFilter(object):
body = b''.join(resp_iter)
try:
listing = json.loads(body.decode('ascii'))
listing = json.loads(body)
# Do a couple sanity checks
if not isinstance(listing, list):
raise ValueError

View File

@ -295,7 +295,7 @@ class SymlinkContainerContext(WSGIContext):
"""
with closing_if_possible(resp_iter):
resp_body = b''.join(resp_iter)
body_json = json.loads(resp_body.decode('ascii'))
body_json = json.loads(resp_body)
swift_version, account, _junk = split_path(req.path, 2, 3, True)
new_body = json.dumps(
[self._extract_symlink_path_json(obj_dict, swift_version, account)

View File

@ -78,7 +78,7 @@ class RingData(object):
"""
json_len, = struct.unpack('!I', gz_file.read(4))
ring_dict = json.loads(gz_file.read(json_len).decode('ascii'))
ring_dict = json.loads(gz_file.read(json_len))
ring_dict['replica2part2dev_id'] = []
if metadata_only:

View File

@ -3477,7 +3477,7 @@ def dump_recon_cache(cache_dict, cache_file, logger, lock_timeout=2,
try:
existing_entry = cf.readline()
if existing_entry:
cache_entry = json.loads(existing_entry.decode('utf8'))
cache_entry = json.loads(existing_entry)
except ValueError:
# file doesn't have a valid entry, we'll recreate it
pass

View File

@ -59,7 +59,7 @@ class TestListingMiddleware(S3ApiTestCase):
req = Request.blank('/v1/a/c')
status, headers, body = self.call_s3api(req)
self.assertEqual(json.loads(body.decode('ascii')), [
self.assertEqual(json.loads(body), [
{'name': 'obj1', 'hash': '0123456789abcdef0123456789abcdef'},
{'name': 'obj2', 'hash': 'swiftetag', 's3_etag': '"mu-etag"'},
{'name': 'obj2', 'hash': 'swiftetag; something=else'},

View File

@ -240,7 +240,7 @@ class TestListEndpoints(unittest.TestCase):
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/1/a/c/o1",
"http://10.1.2.2:6200/sdd1/1/a/c/o1"
])
@ -260,14 +260,14 @@ class TestListEndpoints(unittest.TestCase):
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")),
self.assertEqual(json.loads(resp.body),
expected[pol.idx])
# Here, 'o1/' is the object name.
resp = Request.blank('/endpoints/a/c/o1/').get_response(
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c/o1/",
"http://10.1.2.2:6200/sdd1/3/a/c/o1/"
])
@ -275,7 +275,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c2').get_response(
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sda1/2/a/c2",
"http://10.1.2.1:6200/sdc1/2/a/c2"
])
@ -283,7 +283,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a1').get_response(
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.2.1:6200/sdc1/0/a1",
"http://10.1.1.1:6200/sda1/0/a1",
"http://10.1.1.1:6200/sdb1/0/a1"
@ -296,7 +296,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c 2').get_response(
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c%202",
"http://10.1.2.2:6200/sdd1/3/a/c%202"
])
@ -304,7 +304,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/a/c%202').get_response(
self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/a/c%202",
"http://10.1.2.2:6200/sdd1/3/a/c%202"
])
@ -312,7 +312,7 @@ class TestListEndpoints(unittest.TestCase):
resp = Request.blank('/endpoints/ac%20count/con%20tainer/ob%20ject') \
.get_response(self.list_endpoints)
self.assertEqual(resp.status_int, 200)
self.assertEqual(json.loads(resp.body.decode("utf-8")), [
self.assertEqual(json.loads(resp.body), [
"http://10.1.1.1:6200/sdb1/3/ac%20count/con%20tainer/ob%20ject",
"http://10.1.2.2:6200/sdd1/3/ac%20count/con%20tainer/ob%20ject"
])
@ -342,7 +342,7 @@ class TestListEndpoints(unittest.TestCase):
.get_response(custom_path_le)
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")),
self.assertEqual(json.loads(resp.body),
expected[pol.idx])
# test custom path without trailing slash
@ -356,7 +356,7 @@ class TestListEndpoints(unittest.TestCase):
.get_response(custom_path_le)
self.assertEqual(resp.status_int, 200)
self.assertEqual(resp.content_type, 'application/json')
self.assertEqual(json.loads(resp.body.decode("utf-8")),
self.assertEqual(json.loads(resp.body),
expected[pol.idx])
def test_v1_response(self):
@ -364,7 +364,7 @@ class TestListEndpoints(unittest.TestCase):
resp = req.get_response(self.list_endpoints)
expected = ["http://10.1.1.1:6200/sdb1/1/a/c/o1",
"http://10.1.2.2:6200/sdd1/1/a/c/o1"]
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
def test_v2_obj_response(self):
req = Request.blank('/endpoints/v2/a/c/o1')
@ -374,7 +374,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.2.2:6200/sdd1/1/a/c/o1"],
'headers': {'X-Backend-Storage-Policy-Index': "0"},
}
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
for policy in POLICIES:
patch_path = 'swift.common.middleware.list_endpoints' \
'.get_container_info'
@ -390,7 +390,7 @@ class TestListEndpoints(unittest.TestCase):
'X-Backend-Storage-Policy-Index': str(int(policy))},
'endpoints': [path % node for node in nodes],
}
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
def test_v2_non_obj_response(self):
# account
@ -403,7 +403,7 @@ class TestListEndpoints(unittest.TestCase):
'headers': {},
}
# container
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v2/a/c')
resp = req.get_response(self.list_endpoints)
expected = {
@ -412,7 +412,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.2.1:6200/sdc1/0/a/c"],
'headers': {},
}
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
def test_version_account_response(self):
req = Request.blank('/endpoints/a')
@ -420,10 +420,10 @@ class TestListEndpoints(unittest.TestCase):
expected = ["http://10.1.2.1:6200/sdc1/0/a",
"http://10.1.1.1:6200/sda1/0/a",
"http://10.1.1.1:6200/sdb1/0/a"]
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v1.0/a')
resp = req.get_response(self.list_endpoints)
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
req = Request.blank('/endpoints/v2/a')
resp = req.get_response(self.list_endpoints)
@ -433,7 +433,7 @@ class TestListEndpoints(unittest.TestCase):
"http://10.1.1.1:6200/sdb1/0/a"],
'headers': {},
}
self.assertEqual(json.loads(resp.body.decode('utf-8')), expected)
self.assertEqual(json.loads(resp.body), expected)
if __name__ == '__main__':

View File

@ -213,7 +213,7 @@ class TestDirectClient(unittest.TestCase):
self.assertEqual(conn.req_headers['user-agent'],
self.user_agent)
self.assertEqual(resp_headers, stub_headers)
self.assertEqual(json.loads(body.decode('ascii')), resp)
self.assertEqual(json.loads(body), resp)
self.assertIn('format=json', conn.query_string)
for k, v in req_params.items():
if v is None:
@ -389,7 +389,7 @@ class TestDirectClient(unittest.TestCase):
self.assertEqual(conn.req_headers['user-agent'],
self.user_agent)
self.assertEqual(headers, resp_headers)
self.assertEqual(json.loads(body.decode('ascii')), resp)
self.assertEqual(json.loads(body), resp)
self.assertIn('format=json', conn.query_string)
for k, v in req_params.items():
if v is None:

View File

@ -62,7 +62,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii'))
info = json.loads(resp.body)
self.assertNotIn('admin', info)
self.assertIn('foo', info)
self.assertIn('bar', info['foo'])
@ -89,7 +89,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii'))
info = json.loads(resp.body)
self.assertNotIn('admin', info)
self.assertIn('foo', info)
self.assertIn('bar', info['foo'])
@ -120,7 +120,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii'))
info = json.loads(resp.body)
self.assertIn('foo', info)
self.assertIn('bar', info['foo'])
self.assertEqual(info['foo']['bar'], 'baz')
@ -156,7 +156,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii'))
info = json.loads(resp.body)
self.assertIn('admin', info)
self.assertIn('qux', info['admin'])
self.assertIn('quux', info['admin']['qux'])
@ -279,7 +279,7 @@ class TestInfoController(unittest.TestCase):
resp = controller.GET(req)
self.assertIsInstance(resp, HTTPException)
self.assertEqual('200 OK', str(resp))
info = json.loads(resp.body.decode('ascii'))
info = json.loads(resp.body)
self.assertNotIn('foo2', info)
self.assertIn('admin', info)
self.assertIn('disallowed_sections', info['admin'])