Stage 1 conversion to JSON for storing Credentials.

Reviewed in http://codereview.appspot.com/4972065/
This commit is contained in:
Joe Gregorio
2011-09-15 09:06:38 -04:00
parent 85964c56b3
commit 562b7312cf
16 changed files with 379 additions and 71 deletions

View File

@@ -59,11 +59,12 @@ DEFAULT_METHOD_DOC = 'A description of how to use this function'
STACK_QUERY_PARAMETERS = ['trace', 'fields', 'pp', 'prettyPrint', 'userIp', STACK_QUERY_PARAMETERS = ['trace', 'fields', 'pp', 'prettyPrint', 'userIp',
'userip', 'strict'] 'userip', 'strict']
RESERVED_WORDS = [ 'and', 'assert', 'break', 'class', 'continue', 'def', 'del', RESERVED_WORDS = ['and', 'assert', 'break', 'class', 'continue', 'def', 'del',
'elif', 'else', 'except', 'exec', 'finally', 'for', 'from', 'elif', 'else', 'except', 'exec', 'finally', 'for', 'from',
'global', 'if', 'import', 'in', 'is', 'lambda', 'not', 'or', 'global', 'if', 'import', 'in', 'is', 'lambda', 'not', 'or',
'pass', 'print', 'raise', 'return', 'try', 'while' ] 'pass', 'print', 'raise', 'return', 'try', 'while' ]
def _fix_method_name(name): def _fix_method_name(name):
if name in RESERVED_WORDS: if name in RESERVED_WORDS:
return name + '_' return name + '_'
@@ -242,10 +243,10 @@ def _cast(value, schema_type):
return str(value) return str(value)
MULTIPLIERS = { MULTIPLIERS = {
"KB": 2**10, "KB": 2 ** 10,
"MB": 2**20, "MB": 2 ** 20,
"GB": 2**30, "GB": 2 ** 30,
"TB": 2**40, "TB": 2 ** 40,
} }
def _media_size_to_long(maxSize): def _media_size_to_long(maxSize):
@@ -255,7 +256,7 @@ def _media_size_to_long(maxSize):
units = maxSize[-2:].upper() units = maxSize[-2:].upper()
multiplier = MULTIPLIERS.get(units, 0) multiplier = MULTIPLIERS.get(units, 0)
if multiplier: if multiplier:
return int(maxSize[:-2])*multiplier return int(maxSize[:-2]) * multiplier
else: else:
return int(maxSize) return int(maxSize)

View File

@@ -222,7 +222,8 @@ class BaseModel(Model):
_abstract() _abstract()
def deserialize(self, content): def deserialize(self, content):
"""Perform the actual deserialization from response string to Python object. """Perform the actual deserialization from response string to Python
object.
Args: Args:
content: string, the body of the HTTP response content: string, the body of the HTTP response
@@ -285,8 +286,8 @@ class ProtocolBufferModel(BaseModel):
de-serialized using the given protocol buffer class. de-serialized using the given protocol buffer class.
Args: Args:
protocol_buffer: The protocol buffer class used to de-serialize a response protocol_buffer: The protocol buffer class used to de-serialize a
from the API. response from the API.
""" """
self._protocol_buffer = protocol_buffer self._protocol_buffer = protocol_buffer

View File

@@ -377,7 +377,6 @@ class TwoLeggedOAuthCredentials(Credentials):
return http return http
class FlowThreeLegged(Flow): class FlowThreeLegged(Flow):
"""Does the Three Legged Dance for OAuth 1.0a. """Does the Three Legged Dance for OAuth 1.0a.
""" """

View File

@@ -95,6 +95,16 @@ class AppAssertionCredentials(AssertionCredentials):
None, None,
token_uri) token_uri)
@classmethod
def from_json(cls, json):
data = simplejson.loads(json)
retval = AccessTokenCredentials(
data['scope'],
data['audience'],
data['assertion_type'],
data['token_uri'])
return retval
def _generate_assertion(self): def _generate_assertion(self):
header = { header = {
'typ': 'JWT', 'typ': 'JWT',
@@ -165,17 +175,28 @@ class CredentialsProperty(db.Property):
def get_value_for_datastore(self, model_instance): def get_value_for_datastore(self, model_instance):
cred = super(CredentialsProperty, cred = super(CredentialsProperty,
self).get_value_for_datastore(model_instance) self).get_value_for_datastore(model_instance)
return db.Blob(pickle.dumps(cred)) if cred is None:
cred = ''
else:
cred = cred.to_json()
return db.Blob(cred)
# For reading from datastore. # For reading from datastore.
def make_value_from_datastore(self, value): def make_value_from_datastore(self, value):
if value is None: if value is None:
return None return None
return pickle.loads(value) if len(value) == 0:
return None
credentials = None
try:
credentials = Credentials.new_from_json(value)
except ValueError:
credentials = pickle.loads(value)
return credentials
def validate(self, value): def validate(self, value):
if value is not None and not isinstance(value, Credentials): if value is not None and not isinstance(value, Credentials):
raise BadValueError('Property %s must be convertible ' raise db.BadValueError('Property %s must be convertible '
'to an Credentials instance (%s)' % 'to an Credentials instance (%s)' %
(self.name, value)) (self.name, value))
return super(CredentialsProperty, self).validate(value) return super(CredentialsProperty, self).validate(value)
@@ -215,15 +236,15 @@ class StorageByKeyName(Storage):
oauth2client.Credentials oauth2client.Credentials
""" """
if self._cache: if self._cache:
credential = self._cache.get(self._key_name) json = self._cache.get(self._key_name)
if credential: if json:
return pickle.loads(credential) return Credentials.new_from_json(json)
entity = self._model.get_or_insert(self._key_name) entity = self._model.get_or_insert(self._key_name)
credential = getattr(entity, self._property_name) credential = getattr(entity, self._property_name)
if credential and hasattr(credential, 'set_store'): if credential and hasattr(credential, 'set_store'):
credential.set_store(self) credential.set_store(self)
if self._cache: if self._cache:
self._cache.set(self._key_name, pickle.dumps(credentials)) self._cache.set(self._key_name, credentials.to_json())
return credential return credential
@@ -237,7 +258,7 @@ class StorageByKeyName(Storage):
setattr(entity, self._property_name, credentials) setattr(entity, self._property_name, credentials)
entity.put() entity.put()
if self._cache: if self._cache:
self._cache.set(self._key_name, pickle.dumps(credentials)) self._cache.set(self._key_name, credentials.to_json())
class CredentialsModel(db.Model): class CredentialsModel(db.Model):

View File

@@ -43,6 +43,9 @@ except ImportError:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Expiry is stored in RFC3339 UTC format
EXPIRY_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"
class Error(Exception): class Error(Exception):
"""Base error for this module.""" """Base error for this module."""
@@ -71,10 +74,15 @@ def _abstract():
class Credentials(object): class Credentials(object):
"""Base class for all Credentials objects. """Base class for all Credentials objects.
Subclasses must define an authorize() method Subclasses must define an authorize() method that applies the credentials to
that applies the credentials to an HTTP transport. an HTTP transport.
Subclasses must also specify a classmethod named 'from_json' that takes a JSON
string as input and returns an instaniated Crentials object.
""" """
NON_SERIALIZED_MEMBERS = ['store']
def authorize(self, http): def authorize(self, http):
"""Take an httplib2.Http instance (or equivalent) and """Take an httplib2.Http instance (or equivalent) and
authorizes it for the set of credentials, usually by authorizes it for the set of credentials, usually by
@@ -84,6 +92,58 @@ class Credentials(object):
""" """
_abstract() _abstract()
def _to_json(self, strip):
"""Utility function for creating a JSON representation of an instance of Credentials.
Args:
strip: array, An array of names of members to not include in the JSON.
Returns:
string, a JSON representation of this instance, suitable to pass to
from_json().
"""
t = type(self)
d = copy.copy(self.__dict__)
for member in strip:
del d[member]
if 'token_expiry' in d and isinstance(d['token_expiry'], datetime.datetime):
d['token_expiry'] = d['token_expiry'].strftime(EXPIRY_FORMAT)
# Add in information we will need later to reconsistitue this instance.
d['_class'] = t.__name__
d['_module'] = t.__module__
return simplejson.dumps(d)
def to_json(self):
"""Creating a JSON representation of an instance of Credentials.
Returns:
string, a JSON representation of this instance, suitable to pass to
from_json().
"""
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
@classmethod
def new_from_json(cls, s):
"""Utility class method to instantiate a Credentials subclass from a JSON
representation produced by to_json().
Args:
s: string, JSON from to_json().
Returns:
An instance of the subclass of Credentials that was serialized with
to_json().
"""
data = simplejson.loads(s)
# Find and call the right classmethod from_json() to restore the object.
module = data['_module']
m = __import__(module)
for sub_module in module.split('.')[1:]:
m = getattr(m, sub_module)
kls = getattr(m, data['_class'])
from_json = getattr(kls, 'from_json')
return from_json(s)
class Flow(object): class Flow(object):
"""Base class for all Flow objects.""" """Base class for all Flow objects."""
@@ -206,6 +266,36 @@ class OAuth2Credentials(Credentials):
# refreshed. # refreshed.
self.invalid = False self.invalid = False
def to_json(self):
return self._to_json(Credentials.NON_SERIALIZED_MEMBERS)
@classmethod
def from_json(cls, s):
"""Instantiate a Credentials object from a JSON description of it. The JSON
should have been produced by calling .to_json() on the object.
Args:
data: dict, A deserialized JSON object.
Returns:
An instance of a Credentials subclass.
"""
data = simplejson.loads(s)
if 'token_expiry' in data and not isinstance(data['token_expiry'],
datetime.datetime):
data['token_expiry'] = datetime.datetime.strptime(
data['token_expiry'], EXPIRY_FORMAT)
retval = OAuth2Credentials(
data['access_token'],
data['client_id'],
data['client_secret'],
data['refresh_token'],
data['token_expiry'],
data['token_uri'],
data['user_agent'])
retval.invalid = data['invalid']
return retval
@property @property
def access_token_expired(self): def access_token_expired(self):
"""True if the credential is expired or invalid. """True if the credential is expired or invalid.
@@ -218,7 +308,7 @@ class OAuth2Credentials(Credentials):
if not self.token_expiry: if not self.token_expiry:
return False return False
now = datetime.datetime.now() now = datetime.datetime.utcnow()
if now >= self.token_expiry: if now >= self.token_expiry:
logger.info('access_token is expired. Now: %s, token_expiry: %s', logger.info('access_token is expired. Now: %s, token_expiry: %s',
now, self.token_expiry) now, self.token_expiry)
@@ -318,7 +408,7 @@ class OAuth2Credentials(Credentials):
self.refresh_token = d.get('refresh_token', self.refresh_token) self.refresh_token = d.get('refresh_token', self.refresh_token)
if 'expires_in' in d: if 'expires_in' in d:
self.token_expiry = datetime.timedelta( self.token_expiry = datetime.timedelta(
seconds=int(d['expires_in'])) + datetime.datetime.now() seconds=int(d['expires_in'])) + datetime.datetime.utcnow()
else: else:
self.token_expiry = None self.token_expiry = None
if self.store: if self.store:
@@ -446,6 +536,15 @@ class AccessTokenCredentials(OAuth2Credentials):
None, None,
user_agent) user_agent)
@classmethod
def from_json(cls, s):
data = simplejson.loads(s)
retval = AccessTokenCredentials(
data['access_token'],
data['user_agent'])
return retval
def _refresh(self, http_request): def _refresh(self, http_request):
raise AccessTokenCredentialsError( raise AccessTokenCredentialsError(
"The access_token is expired or invalid and can't be refreshed.") "The access_token is expired or invalid and can't be refreshed.")
@@ -601,7 +700,7 @@ class OAuth2WebServerFlow(Flow):
refresh_token = d.get('refresh_token', None) refresh_token = d.get('refresh_token', None)
token_expiry = None token_expiry = None
if 'expires_in' in d: if 'expires_in' in d:
token_expiry = datetime.datetime.now() + datetime.timedelta( token_expiry = datetime.datetime.utcnow() + datetime.timedelta(
seconds=int(d['expires_in'])) seconds=int(d['expires_in']))
logger.info('Successfully retrieved access token: %s' % content) logger.info('Successfully retrieved access token: %s' % content)

View File

@@ -23,7 +23,20 @@ __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import pickle import pickle
import threading import threading
try: # pragma: no cover
import simplejson
except ImportError: # pragma: no cover
try:
# Try to import from django, should work on App Engine
from django.utils import simplejson
except ImportError:
# Should work for Python2.6 and higher.
import json as simplejson
from client import Storage as BaseStorage from client import Storage as BaseStorage
from client import Credentials
class Storage(BaseStorage): class Storage(BaseStorage):
@@ -40,25 +53,40 @@ class Storage(BaseStorage):
oauth2client.client.Credentials oauth2client.client.Credentials
""" """
self._lock.acquire() self._lock.acquire()
credentials = None
try: try:
f = open(self._filename, 'r') f = open(self._filename, 'r')
credentials = pickle.loads(f.read()) content = f.read()
f.close() f.close()
except IOError:
self._lock.release()
return credentials
# First try reading as JSON, and if that fails fall back to pickle.
try:
credentials = Credentials.new_from_json(content)
credentials.set_store(self) credentials.set_store(self)
except: except ValueError:
credentials = None # TODO(jcgregorio) On a future release remove this path to finally remove
self._lock.release() # all pickle support.
try:
credentials = pickle.loads(content)
credentials.set_store(self)
except:
pass
finally:
self._lock.release()
return credentials return credentials
def put(self, credentials): def put(self, credentials):
"""Write a pickled Credentials to file. """Write Credentials to file.
Args: Args:
credentials: Credentials, the credentials to store. credentials: Credentials, the credentials to store.
""" """
self._lock.acquire() self._lock.acquire()
f = open(self._filename, 'w') f = open(self._filename, 'w')
f.write(pickle.dumps(credentials)) f.write(credentials.to_json())
f.close() f.close()
self._lock.release() self._lock.release()

View File

@@ -21,7 +21,9 @@ The format of the stored data is like so:
'userAgent': '<user agent>', 'userAgent': '<user agent>',
'scope': '<scope>' 'scope': '<scope>'
}, },
'credential': '<base64 encoding of pickeled Credential object>' 'credential': {
# JSON serialized Credentials.
}
} }
] ]
} }
@@ -47,6 +49,7 @@ except ImportError: # pragma: no cover
import json as simplejson import json as simplejson
from client import Storage as BaseStorage from client import Storage as BaseStorage
from client import Credentials
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -295,7 +298,8 @@ class _MultiStore(object):
user_agent = raw_key['userAgent'] user_agent = raw_key['userAgent']
scope = raw_key['scope'] scope = raw_key['scope']
key = (client_id, user_agent, scope) key = (client_id, user_agent, scope)
credential = pickle.loads(base64.b64decode(cred_entry['credential'])) credential = None
credential = Credentials.new_from_json(simplejson.dumps(cred_entry['credential']))
return (key, credential) return (key, credential)
def _write(self): def _write(self):
@@ -312,7 +316,7 @@ class _MultiStore(object):
'userAgent': cred_key[1], 'userAgent': cred_key[1],
'scope': cred_key[2] 'scope': cred_key[2]
} }
raw_cred = base64.b64encode(pickle.dumps(cred)) raw_cred = simplejson.loads(cred.to_json())
raw_creds.append({'key': raw_key, 'credential': raw_cred}) raw_creds.append({'key': raw_key, 'credential': raw_cred})
self._locked_json_write(raw_data) self._locked_json_write(raw_data)
@@ -330,6 +334,7 @@ class _MultiStore(object):
The credential specified or None if not present The credential specified or None if not present
""" """
key = (client_id, user_agent, scope) key = (client_id, user_agent, scope)
return self._data.get(key, None) return self._data.get(key, None)
def _update_credential(self, cred, scope): def _update_credential(self, cred, scope):

View File

@@ -129,12 +129,16 @@ def run(flow, storage):
print '--noauth_local_webserver.' print '--noauth_local_webserver.'
print print
code = None
if FLAGS.auth_local_webserver: if FLAGS.auth_local_webserver:
httpd.handle_request() httpd.handle_request()
if 'error' in httpd.query_params: if 'error' in httpd.query_params:
sys.exit('Authentication request was rejected.') sys.exit('Authentication request was rejected.')
if 'code' in httpd.query_params: if 'code' in httpd.query_params:
code = httpd.query_params['code'] code = httpd.query_params['code']
else:
print 'Failed to find "code" in the query parameters of the redirect.'
sys.exit('Try running with --noauth_local_webserver.')
else: else:
code = raw_input('Enter verification code: ').strip() code = raw_input('Enter verification code: ').strip()

View File

@@ -43,8 +43,7 @@ from google.appengine.ext.webapp.util import run_wsgi_app
decorator = OAuth2Decorator( decorator = OAuth2Decorator(
client_id='837647042410-75ifgipj95q4agpm0cs452mg7i2pn17c.apps.googleusercontent.com', client_id='837647042410-75ifgipj95q4agpm0cs452mg7i2pn17c.apps.googleusercontent.com',
client_secret='QhxYsjM__u4vy5N0DXUFRwwI', client_secret='QhxYsjM__u4vy5N0DXUFRwwI',
scope='https://www.googleapis.com/auth/buzz', scope='https://www.googleapis.com/auth/buzz')
user_agent='my-sample-app/1.0')
http = httplib2.Http(memcache) http = httplib2.Http(memcache)
service = build("buzz", "v1", http=http) service = build("buzz", "v1", http=http)

View File

@@ -88,6 +88,7 @@ def main(argv):
# Credentials will get written back to a file. # Credentials will get written back to a file.
storage = Storage('buzz.dat') storage = Storage('buzz.dat')
credentials = storage.get() credentials = storage.get()
if credentials is None or credentials.invalid: if credentials is None or credentials.invalid:
credentials = run(FLOW, storage) credentials = run(FLOW, storage)
@@ -112,31 +113,10 @@ def main(argv):
activitylist = activities.list_next(activitylist).execute() activitylist = activities.list_next(activitylist).execute()
print "Retrieved the next two activities" print "Retrieved the next two activities"
# Add a new activity # List the number of followers
new_activity_body = { followers = service.people().list(
'title': 'Testing insert', userId='@me', groupId='@followers').execute(http)
'object': { print 'Hello, you have %s followers!' % followers['totalResults']
'content':
u'Just a short note to show that insert is working. ☄',
'type': 'note'}
}
activity = activities.insert(userId='@me', body=new_activity_body).execute()
print "Added a new activity"
activitylist = activities.list(
max_results='2', scope='@self', userId='@me').execute()
# Add a comment to that activity
comment_body = {
"content": "This is a comment"
}
item = activitylist['items'][0]
comment = service.comments().insert(
userId=item['actor']['id'], postId=item['id'], body=comment_body
).execute()
print 'Added a comment to the new activity'
pprint.pprint(comment)
except AccessTokenRefreshError: except AccessTokenRefreshError:
print ("The credentials have been revoked or expired, please re-run" print ("The credentials have been revoked or expired, please re-run"

View File

@@ -140,8 +140,6 @@ def main(argv):
body=vote_body) body=vote_body)
print "Voted on the submission" print "Voted on the submission"
except AccessTokenRefreshError: except AccessTokenRefreshError:
print ("The credentials have been revoked or expired, please re-run" print ("The credentials have been revoked or expired, please re-run"
"the application to re-authorize") "the application to re-authorize")

View File

@@ -109,7 +109,7 @@ def main(argv):
# Start training on a data set # Start training on a data set
train = service.training() train = service.training()
body = {'id' : FLAGS.object_name} body = {'id': FLAGS.object_name}
start = train.insert(body=body).execute() start = train.insert(body=body).execute()
print 'Started training' print 'Started training'

View File

@@ -43,4 +43,3 @@
submissionId=submission['id']['submissionId'], submissionId=submission['id']['submissionId'],
body=vote_body) body=vote_body)
print "Voted on the submission" print "Voted on the submission"

View File

@@ -22,6 +22,7 @@ Unit tests for oauth2client.
__author__ = 'jcgregorio@google.com (Joe Gregorio)' __author__ = 'jcgregorio@google.com (Joe Gregorio)'
import datetime
import httplib2 import httplib2
import unittest import unittest
import urlparse import urlparse
@@ -31,6 +32,16 @@ try:
except ImportError: except ImportError:
from cgi import parse_qs from cgi import parse_qs
try: # pragma: no cover
import simplejson
except ImportError: # pragma: no cover
try:
# Try to import from django, should work on App Engine
from django.utils import simplejson
except ImportError:
# Should work for Python2.6 and higher.
import json as simplejson
from apiclient.http import HttpMockSequence from apiclient.http import HttpMockSequence
from oauth2client.client import AccessTokenCredentials from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AccessTokenCredentialsError from oauth2client.client import AccessTokenCredentialsError
@@ -48,7 +59,7 @@ class OAuth2CredentialsTests(unittest.TestCase):
client_id = "some_client_id" client_id = "some_client_id"
client_secret = "cOuDdkfjxxnv+" client_secret = "cOuDdkfjxxnv+"
refresh_token = "1/0/a.df219fjls0" refresh_token = "1/0/a.df219fjls0"
token_expiry = "ignored" token_expiry = datetime.datetime.utcnow()
token_uri = "https://www.google.com/accounts/o8/oauth2/token" token_uri = "https://www.google.com/accounts/o8/oauth2/token"
user_agent = "refresh_checker/1.0" user_agent = "refresh_checker/1.0"
self.credentials = OAuth2Credentials( self.credentials = OAuth2Credentials(
@@ -86,6 +97,12 @@ class OAuth2CredentialsTests(unittest.TestCase):
resp, content = http.request("http://example.com") resp, content = http.request("http://example.com")
self.assertEqual(400, resp.status) self.assertEqual(400, resp.status)
def test_to_from_json(self):
json = self.credentials.to_json()
instance = OAuth2Credentials.from_json(json)
self.assertEquals(type(instance), OAuth2Credentials)
self.assertEquals(self.credentials.__dict__, instance.__dict__)
class AccessTokenCredentialsTests(unittest.TestCase): class AccessTokenCredentialsTests(unittest.TestCase):

View File

@@ -173,7 +173,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('code', q['response_type'][0]) self.assertEqual('code', q['response_type'][0])
self.assertEqual(False, self.decorator.has_credentials()) self.assertEqual(False, self.decorator.has_credentials())
# Now simulate the callback to /oauth2callback # Now simulate the callback to /oauth2callback.
response = self.app.get('/oauth2callback', { response = self.app.get('/oauth2callback', {
'code': 'foo_access_code', 'code': 'foo_access_code',
'state': 'foo_path', 'state': 'foo_path',
@@ -181,7 +181,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/foo_path', response.headers['Location']) self.assertEqual('http://localhost/foo_path', response.headers['Location'])
self.assertEqual(None, self.decorator.credentials) self.assertEqual(None, self.decorator.credentials)
# Now requesting the decorated path should work # Now requesting the decorated path should work.
response = self.app.get('/foo_path') response = self.app.get('/foo_path')
self.assertEqual('200 OK', response.status) self.assertEqual('200 OK', response.status)
self.assertEqual(True, self.decorator.has_credentials()) self.assertEqual(True, self.decorator.has_credentials())
@@ -190,18 +190,18 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('foo_access_token', self.assertEqual('foo_access_token',
self.decorator.credentials.access_token) self.decorator.credentials.access_token)
# Invalidate the stored Credentials # Invalidate the stored Credentials.
self.decorator.credentials.invalid = True self.decorator.credentials.invalid = True
self.decorator.credentials.store.put(self.decorator.credentials) self.decorator.credentials.store.put(self.decorator.credentials)
# Invalid Credentials should start the OAuth dance again # Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path') response = self.app.get('/foo_path')
self.assertTrue(response.status.startswith('302')) self.assertTrue(response.status.startswith('302'))
q = parse_qs(response.headers['Location'].split('?', 1)[1]) q = parse_qs(response.headers['Location'].split('?', 1)[1])
self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0]) self.assertEqual('http://localhost/oauth2callback', q['redirect_uri'][0])
def test_aware(self): def test_aware(self):
# An initial request to an oauth_aware decorated path should not redirect # An initial request to an oauth_aware decorated path should not redirect.
response = self.app.get('/bar_path') response = self.app.get('/bar_path')
self.assertEqual('Hello World!', response.body) self.assertEqual('Hello World!', response.body)
self.assertEqual('200 OK', response.status) self.assertEqual('200 OK', response.status)
@@ -214,7 +214,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/bar_path', q['state'][0]) self.assertEqual('http://localhost/bar_path', q['state'][0])
self.assertEqual('code', q['response_type'][0]) self.assertEqual('code', q['response_type'][0])
# Now simulate the callback to /oauth2callback # Now simulate the callback to /oauth2callback.
url = self.decorator.authorize_url() url = self.decorator.authorize_url()
response = self.app.get('/oauth2callback', { response = self.app.get('/oauth2callback', {
'code': 'foo_access_code', 'code': 'foo_access_code',
@@ -223,7 +223,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('http://localhost/bar_path', response.headers['Location']) self.assertEqual('http://localhost/bar_path', response.headers['Location'])
self.assertEqual(False, self.decorator.has_credentials()) self.assertEqual(False, self.decorator.has_credentials())
# Now requesting the decorated path will have credentials # Now requesting the decorated path will have credentials.
response = self.app.get('/bar_path') response = self.app.get('/bar_path')
self.assertEqual('200 OK', response.status) self.assertEqual('200 OK', response.status)
self.assertEqual('Hello World!', response.body) self.assertEqual('Hello World!', response.body)

View File

@@ -0,0 +1,157 @@
#!/usr/bin/python2.4
#
# Copyright 2010 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Oauth2client.file tests
Unit tests for oauth2client.file
"""
__author__ = 'jcgregorio@google.com (Joe Gregorio)'
import os
import pickle
import unittest
import datetime
try: # pragma: no cover
import simplejson
except ImportError: # pragma: no cover
try:
# Try to import from django, should work on App Engine
from django.utils import simplejson
except ImportError:
# Should work for Python2.6 and higher.
import json as simplejson
from oauth2client.client import OAuth2Credentials
from oauth2client.client import AccessTokenCredentials
from oauth2client.client import AssertionCredentials
from oauth2client.file import Storage
from oauth2client import multistore_file
FILENAME = os.path.join(os.path.dirname(__file__), 'test_file_storage.data')
class OAuth2ClientFileTests(unittest.TestCase):
def tearDown(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def setUp(self):
try:
os.unlink(FILENAME)
except OSError:
pass
def test_non_existent_file_storage(self):
s = Storage(FILENAME)
credentials = s.get()
self.assertEquals(None, credentials)
def test_pickle_and_json_interop(self):
# Write a file with a pickled OAuth2Credentials.
access_token = 'foo'
client_id = 'some_client_id'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent)
f = open(FILENAME, 'w')
pickle.dump(credentials, f)
f.close()
# Storage should be able to read that object.
# TODO(jcgregorio) This should fail once pickle support is removed.
s = Storage(FILENAME)
credentials = s.get()
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
# Now write it back out and confirm it has been rewritten as JSON
s.put(credentials)
f = file(FILENAME)
data = simplejson.load(f)
f.close()
self.assertEquals(data['access_token'], 'foo')
self.assertEquals(data['_class'], 'OAuth2Credentials')
self.assertEquals(data['_module'], 'oauth2client.client')
def test_access_token_credentials(self):
access_token = 'foo'
user_agent = 'refresh_checker/1.0'
credentials = AccessTokenCredentials(access_token, user_agent)
s = Storage(FILENAME)
credentials = s.put(credentials)
credentials = s.get()
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
def test_multistore_non_existent_file(self):
store = multistore_file.get_credential_storage(
FILENAME,
'some_client_id',
'user-agent/1.0',
'some-scope')
credentials = store.get()
self.assertEquals(None, credentials)
def test_multistore_file(self):
access_token = 'foo'
client_secret = 'cOuDdkfjxxnv+'
refresh_token = '1/0/a.df219fjls0'
token_expiry = datetime.datetime.utcnow()
token_uri = 'https://www.google.com/accounts/o8/oauth2/token'
user_agent = 'refresh_checker/1.0'
client_id = 'some_client_id'
credentials = OAuth2Credentials(
access_token, client_id, client_secret,
refresh_token, token_expiry, token_uri,
user_agent)
store = multistore_file.get_credential_storage(
FILENAME,
credentials.client_id,
credentials.user_agent,
'some-scope')
store.put(credentials)
credentials = store.get()
self.assertNotEquals(None, credentials)
self.assertEquals('foo', credentials.access_token)
if __name__ == '__main__':
unittest.main()