Complete branches from partial test coverages (#629)

This commit is contained in:
Pat Ferate
2016-08-12 14:15:31 -07:00
committed by Jon Wayne Parrott
parent c9b4b07525
commit 5137d7e837
2 changed files with 66 additions and 19 deletions

View File

@@ -15,6 +15,7 @@
import datetime import datetime
import unittest import unittest
import mock
import sqlalchemy import sqlalchemy
import sqlalchemy.ext.declarative import sqlalchemy.ext.declarative
import sqlalchemy.orm import sqlalchemy.orm
@@ -66,7 +67,8 @@ class TestSQLAlchemyStorage(unittest.TestCase):
self.assertEqual(result.token_uri, self.credentials.token_uri) self.assertEqual(result.token_uri, self.credentials.token_uri)
self.assertEqual(result.user_agent, self.credentials.user_agent) self.assertEqual(result.user_agent, self.credentials.user_agent)
def test_get(self): @mock.patch('oauth2client.client.OAuth2Credentials.set_store')
def test_get(self, set_store):
session = self.session() session = self.session()
credentials_storage = oauth2client.contrib.sqlalchemy.Storage( credentials_storage = oauth2client.contrib.sqlalchemy.Storage(
session=session, session=session,
@@ -75,7 +77,21 @@ class TestSQLAlchemyStorage(unittest.TestCase):
key_value=1, key_value=1,
property_name='credentials', property_name='credentials',
) )
# No credentials stored
self.assertIsNone(credentials_storage.get()) self.assertIsNone(credentials_storage.get())
# Invalid credentials stored
session.add(DummyModel(
key=1,
credentials=oauth2client.client.Credentials(),
))
session.commit()
bad_credentials = credentials_storage.get()
self.assertIsInstance(bad_credentials, oauth2client.client.Credentials)
set_store.assert_not_called()
# Valid credentials stored
session.query(DummyModel).filter_by(key=1).delete()
session.add(DummyModel( session.add(DummyModel(
key=1, key=1,
credentials=self.credentials, credentials=self.credentials,
@@ -83,16 +99,20 @@ class TestSQLAlchemyStorage(unittest.TestCase):
session.commit() session.commit()
self.compare_credentials(credentials_storage.get()) self.compare_credentials(credentials_storage.get())
set_store.assert_called_with(credentials_storage)
def test_put(self): def test_put(self):
session = self.session() session = self.session()
oauth2client.contrib.sqlalchemy.Storage( storage = oauth2client.contrib.sqlalchemy.Storage(
session=session, session=session,
model_class=DummyModel, model_class=DummyModel,
key_name='key', key_name='key',
key_value=1, key_value=1,
property_name='credentials', property_name='credentials',
).put(self.credentials) )
# Store invalid credentials first to verify overwriting
storage.put(oauth2client.client.Credentials())
storage.put(self.credentials)
session.commit() session.commit()
entity = session.query(DummyModel).filter_by(key=1).first() entity = session.query(DummyModel).filter_by(key=1).first()

View File

@@ -1619,6 +1619,9 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
user_agent='unittest-sample/1.0', user_agent='unittest-sample/1.0',
revoke_uri='dummy_revoke_uri', revoke_uri='dummy_revoke_uri',
) )
self.bad_verifier = b'__NOT_THE_VERIFIER_YOURE_LOOKING_FOR__'
self.good_verifier = b'__TEST_VERIFIER__'
self.good_challenger = b'__TEST_CHALLENGE__'
def test_construct_authorize_url(self): def test_construct_authorize_url(self):
authorize_url = self.flow.step1_get_authorize_url(state='state+1') authorize_url = self.flow.step1_get_authorize_url(state='state+1')
@@ -1691,19 +1694,42 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
@mock.patch('oauth2client.client._pkce.code_challenge') @mock.patch('oauth2client.client._pkce.code_challenge')
@mock.patch('oauth2client.client._pkce.code_verifier') @mock.patch('oauth2client.client._pkce.code_verifier')
def test_step1_get_authorize_url_pkce(self, fake_verifier, fake_challenge): def test_step1_get_authorize_url_pkce(self, fake_verifier, fake_challenge):
fake_verifier.return_value = b'__TEST_VERIFIER__' fake_verifier.return_value = self.good_verifier
fake_challenge.return_value = b'__TEST_CHALLENGE__' fake_challenge.return_value = self.good_challenger
flow = client.OAuth2WebServerFlow( flow = client.OAuth2WebServerFlow(
'client_id+1', 'client_id+1',
scope='foo', scope='foo',
redirect_uri='http://example.com', redirect_uri='http://example.com',
pkce=True) pkce=True)
auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url()) auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url())
self.assertEqual(flow.code_verifier, b'__TEST_VERIFIER__') self.assertEqual(flow.code_verifier, self.good_verifier)
results = dict(urllib.parse.parse_qsl(auth_url.query)) results = dict(urllib.parse.parse_qsl(auth_url.query))
self.assertEqual(results['code_challenge'], '__TEST_CHALLENGE__') self.assertEqual(
results['code_challenge'], self.good_challenger.decode())
self.assertEqual(results['code_challenge_method'], 'S256') self.assertEqual(results['code_challenge_method'], 'S256')
fake_challenge.assert_called_with(b'__TEST_VERIFIER__') fake_verifier.assert_called()
fake_challenge.assert_called_with(self.good_verifier)
@mock.patch('oauth2client.client._pkce.code_challenge')
@mock.patch('oauth2client.client._pkce.code_verifier')
def test_step1_get_authorize_url_pkce_invalid_verifier(
self, fake_verifier, fake_challenge):
fake_verifier.return_value = self.good_verifier
fake_challenge.return_value = self.good_challenger
flow = client.OAuth2WebServerFlow(
'client_id+1',
scope='foo',
redirect_uri='http://example.com',
pkce=True,
code_verifier=self.bad_verifier)
auth_url = urllib.parse.urlparse(flow.step1_get_authorize_url())
self.assertEqual(flow.code_verifier, self.bad_verifier)
results = dict(urllib.parse.parse_qsl(auth_url.query))
self.assertEqual(
results['code_challenge'], self.good_challenger.decode())
self.assertEqual(results['code_challenge_method'], 'S256')
fake_verifier.assert_not_called()
fake_challenge.assert_called_with(self.bad_verifier)
def test_step1_get_authorize_url_without_redirect(self): def test_step1_get_authorize_url_without_redirect(self):
flow = client.OAuth2WebServerFlow('client_id+1', scope='foo', flow = client.OAuth2WebServerFlow('client_id+1', scope='foo',
@@ -1955,17 +1981,18 @@ class OAuth2WebServerFlowTest(unittest.TestCase):
({'status': http_client.OK}, b'access_token=SlAV32hkKG'), ({'status': http_client.OK}, b'access_token=SlAV32hkKG'),
]) ])
flow = client.OAuth2WebServerFlow( flow = client.OAuth2WebServerFlow(
'client_id+1', 'client_id+1',
scope='foo', scope='foo',
redirect_uri='http://example.com', redirect_uri='http://example.com',
pkce=True, pkce=True,
code_verifier=b'__TEST_VERIFIER__' code_verifier=self.good_verifier)
)
flow.step2_exchange(code='some random code', http=http) flow.step2_exchange(code='some random code', http=http)
self.assertEqual(len(http.requests), 1) self.assertEqual(len(http.requests), 1)
test_request = http.requests[0] test_request = http.requests[0]
self.assertIn('code_verifier=__TEST_VERIFIER__', test_request['body']) self.assertIn(
'code_verifier={0}'.format(self.good_verifier.decode()),
test_request['body'])
def test_exchange_using_authorization_header(self): def test_exchange_using_authorization_header(self):
auth_header = 'Basic Y2xpZW50X2lkKzE6c2Vjexc_managerV0KzE=', auth_header = 'Basic Y2xpZW50X2lkKzE6c2Vjexc_managerV0KzE=',