Make decorators thread-safe.

Reviewed in https://codereview.appspot.com/9363044/.
This commit is contained in:
Joe Gregorio
2013-05-16 15:52:57 -04:00
parent 48d10b04b5
commit b8b6feab37
2 changed files with 85 additions and 14 deletions

View File

@@ -25,6 +25,7 @@ import httplib2
import logging
import os
import pickle
import threading
import time
from google.appengine.api import app_identity
@@ -570,6 +571,22 @@ class OAuth2Decorator(object):
"""
def set_credentials(self, credentials):
self._tls.credentials = credentials
def get_credentials(self):
return self._tls.credentials
def set_flow(self, flow):
self._tls.flow = flow
def get_flow(self):
return self._tls.flow
flow = property(get_flow, set_flow)
credentials = property(get_credentials, set_credentials)
@util.positional(4)
def __init__(self, client_id, client_secret, scope,
auth_uri=GOOGLE_AUTH_URI,
@@ -621,6 +638,7 @@ class OAuth2Decorator(object):
**kwargs: dict, Keyword arguments are be passed along as kwargs to the
OAuth2WebServerFlow constructor.
"""
self._tls = threading.local()
self.flow = None
self.credentials = None
self._client_id = client_id
@@ -678,9 +696,12 @@ class OAuth2Decorator(object):
if not self.has_credentials():
return request_handler.redirect(self.authorize_url())
try:
return method(request_handler, *args, **kwargs)
resp = method(request_handler, *args, **kwargs)
except AccessTokenRefreshError:
return request_handler.redirect(self.authorize_url())
finally:
self.credentials = None
return resp
return check_oauth
@@ -737,9 +758,14 @@ class OAuth2Decorator(object):
self.credentials = self._storage_class(
self._credentials_class, None,
self._credentials_property_name, user=user).get()
return method(request_handler, *args, **kwargs)
try:
resp = method(request_handler, *args, **kwargs)
finally:
self.credentials = None
return resp
return setup_oauth
def has_credentials(self):
"""True if for the logged in user there are valid access Credentials.

View File

@@ -441,20 +441,31 @@ class DecoratorTests(unittest.TestCase):
def _finish_setup(self, decorator, user_mock):
self.decorator = decorator
self.had_credentials = False
self.found_credentials = None
self.should_raise = False
parent = self
class TestRequiredHandler(webapp2.RequestHandler):
@decorator.oauth_required
def get(self):
pass
if decorator.has_credentials():
parent.had_credentials = True
parent.found_credentials = decorator.credentials
if parent.should_raise:
raise Exception('')
class TestAwareHandler(webapp2.RequestHandler):
@decorator.oauth_aware
def get(self, *args, **kwargs):
self.response.out.write('Hello World!')
assert(kwargs['year'] == '2012')
assert(kwargs['month'] == '01')
if decorator.has_credentials():
parent.had_credentials = True
parent.found_credentials = decorator.credentials
if parent.should_raise:
raise Exception('')
application = webapp2.WSGIApplication([
@@ -507,6 +518,9 @@ class DecoratorTests(unittest.TestCase):
response = parse_qs(parts[1])[self.decorator._token_response_param][0]
self.assertEqual(Http2Mock.content,
simplejson.loads(urllib.unquote(response)))
self.assertEqual(self.decorator.flow, self.decorator._tls.flow)
self.assertEqual(self.decorator.credentials,
self.decorator._tls.credentials)
m.UnsetStubs()
m.VerifyAll()
@@ -514,15 +528,26 @@ class DecoratorTests(unittest.TestCase):
# Now requesting the decorated path should work.
response = self.app.get('/foo_path')
self.assertEqual('200 OK', response.status)
self.assertEqual(True, self.decorator.has_credentials())
self.assertEqual(True, self.had_credentials)
self.assertEqual('foo_refresh_token',
self.decorator.credentials.refresh_token)
self.found_credentials.refresh_token)
self.assertEqual('foo_access_token',
self.decorator.credentials.access_token)
self.found_credentials.access_token)
self.assertEqual(None, self.decorator.credentials)
# Raising an exception still clears the Credentials.
self.should_raise = True
try:
response = self.app.get('/foo_path')
self.fail('Should have raised an exception.')
except Exception:
pass
self.assertEqual(None, self.decorator.credentials)
self.should_raise = False
# Invalidate the stored Credentials.
self.decorator.credentials.invalid = True
self.decorator.credentials.store.put(self.decorator.credentials)
self.found_credentials.invalid = True
self.found_credentials.store.put(self.found_credentials)
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
@@ -553,8 +578,13 @@ class DecoratorTests(unittest.TestCase):
# Now requesting the decorated path should work.
response = self.app.get('/foo_path')
self.assertTrue(self.had_credentials)
# Credentials should be cleared after each call.
self.assertEqual(None, self.decorator.credentials)
# Invalidate the stored Credentials.
self.decorator.credentials.store.delete()
self.found_credentials.store.delete()
# Invalid Credentials should start the OAuth dance again.
response = self.app.get('/foo_path')
@@ -600,11 +630,25 @@ class DecoratorTests(unittest.TestCase):
response = self.app.get('/bar_path/2012/01')
self.assertEqual('200 OK', response.status)
self.assertEqual('Hello World!', response.body)
self.assertEqual(True, self.decorator.has_credentials())
self.assertEqual(True, self.had_credentials)
self.assertEqual('foo_refresh_token',
self.decorator.credentials.refresh_token)
self.found_credentials.refresh_token)
self.assertEqual('foo_access_token',
self.decorator.credentials.access_token)
self.found_credentials.access_token)
# Credentials should be cleared after each call.
self.assertEqual(None, self.decorator.credentials)
# Raising an exception still clears the Credentials.
self.should_raise = True
try:
response = self.app.get('/bar_path/2012/01')
self.fail('Should have raised an exception.')
except Exception:
pass
self.assertEqual(None, self.decorator.credentials)
self.should_raise = False
def test_error_in_step2(self):
# An initial request to an oauth_aware decorated path should not redirect.
@@ -634,6 +678,7 @@ class DecoratorTests(unittest.TestCase):
self.assertEqual('foo_user_agent', decorator.flow.user_agent)
self.assertEqual('dummy_revoke_uri', decorator.flow.revoke_uri)
self.assertEqual(None, decorator.flow.params.get('user_agent', None))
self.assertEqual(decorator.flow, decorator._tls.flow)
def test_token_response_param(self):
self.decorator._token_response_param = 'foobar'