From ddb969a83964e858f6cd812d0e1b26f5c421f602 Mon Sep 17 00:00:00 2001 From: Joe Gregorio Date: Wed, 11 Jul 2012 11:04:12 -0400 Subject: [PATCH] oauth2client support for URL-encoded format of exchange token response (e.g. Facebook) Contribution from crhyme. Reviewed in http://codereview.appspot.com/6352088/. --- oauth2client/client.py | 45 +++++++++++++++++++++-------- tests/test_oauth2client.py | 58 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 11 deletions(-) diff --git a/oauth2client/client.py b/oauth2client/client.py index e8e9001..a5f2c2d 100644 --- a/oauth2client/client.py +++ b/oauth2client/client.py @@ -888,6 +888,33 @@ def _extract_id_token(id_token): return simplejson.loads(_urlsafe_b64decode(segments[1])) +def _parse_exchange_token_response(content): + """Parses response of an exchange token request. + + Most providers return JSON but some (e.g. Facebook) return a + url-encoded string. + + Args: + content: The body of a response + + Returns: + Content as a dictionary object. Note that the dict could be empty, + i.e. {}. That basically indicates a failure. + """ + resp = {} + try: + resp = simplejson.loads(content) + except StandardError: + # different JSON libs raise different exceptions, + # so we just do a catch-all here + resp = dict(parse_qsl(content)) + + # some providers respond with 'expires', others with 'expires_in' + if resp and 'expires' in resp: + resp['expires_in'] = resp.pop('expires') + + return resp + def credentials_from_code(client_id, client_secret, scope, code, redirect_uri = 'postmessage', http=None, user_agent=None, @@ -1074,9 +1101,8 @@ class OAuth2WebServerFlow(Flow): resp, content = http.request(self.token_uri, method='POST', body=body, headers=headers) - if resp.status == 200: - # TODO(jcgregorio) Raise an error if simplejson.loads fails? - d = simplejson.loads(content) + d = _parse_exchange_token_response(content) + if resp.status == 200 and 'access_token' in d: access_token = d['access_token'] refresh_token = d.get('refresh_token', None) token_expiry = None @@ -1094,14 +1120,11 @@ class OAuth2WebServerFlow(Flow): id_token=d.get('id_token', None)) else: logger.info('Failed to retrieve access token: %s' % content) - error_msg = 'Invalid response %s.' % resp['status'] - try: - d = simplejson.loads(content) - if 'error' in d: - error_msg = d['error'] - except: - pass - + if 'error' in d: + # you never know what those providers got to say + error_msg = unicode(d['error']) + else: + error_msg = 'Invalid response: %s.' % str(resp.status) raise FlowExchangeError(error_msg) def flow_from_clientsecrets(filename, scope, message=None): diff --git a/tests/test_oauth2client.py b/tests/test_oauth2client.py index eb2d7db..49433df 100644 --- a/tests/test_oauth2client.py +++ b/tests/test_oauth2client.py @@ -257,6 +257,35 @@ class OAuth2WebServerFlowTest(unittest.TestCase): except FlowExchangeError: pass + def test_urlencoded_exchange_failure(self): + http = HttpMockSequence([ + ({'status': '400'}, "error=invalid_request"), + ]) + + try: + credentials = self.flow.step2_exchange('some random code', http) + self.fail("should raise exception if exchange doesn't get 200") + except FlowExchangeError, e: + self.assertEquals('invalid_request', str(e)) + + def test_exchange_failure_with_json_error(self): + # Some providers have "error" attribute as a JSON object + # in place of regular string. + # This test makes sure no strange object-to-string coversion + # exceptions are being raised instead of FlowExchangeError. + http = HttpMockSequence([ + ({'status': '400'}, + """ {"error": { + "type": "OAuthException", + "message": "Error validating verification code."} }"""), + ]) + + try: + credentials = self.flow.step2_exchange('some random code', http) + self.fail("should raise exception if exchange doesn't get 200") + except FlowExchangeError, e: + pass + def test_exchange_success(self): http = HttpMockSequence([ ({'status': '200'}, @@ -270,6 +299,25 @@ class OAuth2WebServerFlowTest(unittest.TestCase): self.assertNotEqual(None, credentials.token_expiry) self.assertEqual('8xLOxBtZp8', credentials.refresh_token) + def test_urlencoded_exchange_success(self): + http = HttpMockSequence([ + ({'status': '200'}, "access_token=SlAV32hkKG&expires_in=3600"), + ]) + + credentials = self.flow.step2_exchange('some random code', http) + self.assertEqual('SlAV32hkKG', credentials.access_token) + self.assertNotEqual(None, credentials.token_expiry) + + def test_urlencoded_expires_param(self): + http = HttpMockSequence([ + # Note the "expires=3600" where you'd normally + # have if named "expires_in" + ({'status': '200'}, "access_token=SlAV32hkKG&expires=3600"), + ]) + + credentials = self.flow.step2_exchange('some random code', http) + self.assertNotEqual(None, credentials.token_expiry) + def test_exchange_no_expires_in(self): http = HttpMockSequence([ ({'status': '200'}, """{ "access_token":"SlAV32hkKG", @@ -279,6 +327,16 @@ class OAuth2WebServerFlowTest(unittest.TestCase): credentials = self.flow.step2_exchange('some random code', http) self.assertEqual(None, credentials.token_expiry) + def test_urlencoded_exchange_no_expires_in(self): + http = HttpMockSequence([ + # This might be redundant but just to make sure + # urlencoded access_token gets parsed correctly + ({'status': '200'}, "access_token=SlAV32hkKG"), + ]) + + credentials = self.flow.step2_exchange('some random code', http) + self.assertEqual(None, credentials.token_expiry) + def test_exchange_fails_if_no_code(self): http = HttpMockSequence([ ({'status': '200'}, """{ "access_token":"SlAV32hkKG",