Enable SQL tests for oauth

Enables testing of the SQL backend for oauth.
As a result of running the tests, I noticed that this also
fixes an update_consumer test cases too.

fixes bug: #1215483
fixes bug: #1216447

Change-Id: I206d164caa66c3211cfc216d13e3d0bab0e7d54a
This commit is contained in:
Steve Martinelli 2013-08-24 20:18:49 -05:00
parent 8fdfbf04ba
commit a746527c41
6 changed files with 44 additions and 24 deletions

View File

@ -44,7 +44,7 @@ class OAuth(auth.AuthMethodHandler):
attribute='oauth_token', target='request')
acc_token = self.oauth_api.get_access_token(access_token_id)
consumer = self.oauth_api._get_consumer(consumer_id)
consumer = self.oauth_api.get_consumer_with_secret(consumer_id)
expires_at = acc_token['expires_at']
if expires_at:

View File

@ -84,19 +84,20 @@ class OAuth1(sql.Base):
def db_sync(self):
migration.db_sync()
def _get_consumer(self, consumer_id):
session = self.get_session()
def _get_consumer(self, session, consumer_id):
consumer_ref = session.query(Consumer).get(consumer_id)
if consumer_ref is None:
raise exception.NotFound(_('Consumer not found'))
return consumer_ref
def get_consumer(self, consumer_id):
def get_consumer_with_secret(self, consumer_id):
session = self.get_session()
consumer_ref = session.query(Consumer).get(consumer_id)
if consumer_ref is None:
raise exception.NotFound(_('Consumer not found'))
return core.filter_consumer(consumer_ref.to_dict())
consumer_ref = self._get_consumer(session, consumer_id)
return consumer_ref.to_dict()
def get_consumer(self, consumer_id):
return core.filter_consumer(
self.get_consumer_with_secret(consumer_id))
def create_consumer(self, consumer):
consumer['secret'] = uuid.uuid4().hex
@ -110,7 +111,7 @@ class OAuth1(sql.Base):
return consumer_ref.to_dict()
def _delete_consumer(self, session, consumer_id):
consumer_ref = self._get_consumer(consumer_id)
consumer_ref = self._get_consumer(session, consumer_id)
q = session.query(Consumer)
q = q.filter_by(id=consumer_id)
q.delete(False)
@ -154,15 +155,11 @@ class OAuth1(sql.Base):
def update_consumer(self, consumer_id, consumer):
session = self.get_session()
with session.begin():
consumer_ref = self._get_consumer(consumer_id)
consumer_ref = self._get_consumer(session, consumer_id)
old_consumer_dict = consumer_ref.to_dict()
old_consumer_dict.update(consumer)
new_consumer = Consumer.from_dict(old_consumer_dict)
for attr in Consumer.attributes:
if (attr != 'id' or attr != 'secret'):
setattr(consumer_ref,
attr,
getattr(new_consumer, attr))
consumer_ref.description = new_consumer.description
consumer_ref.extra = new_consumer.extra
session.flush()
return core.filter_consumer(consumer_ref.to_dict())

View File

@ -172,7 +172,7 @@ class OAuthControllerV3(controller.V3Controller):
attribute='requested_project_id', target='request')
req_role_ids = requested_role_ids.split(',')
consumer_ref = self.oauth_api._get_consumer(consumer_id)
consumer_ref = self.oauth_api.get_consumer_with_secret(consumer_id)
consumer = oauth1.Consumer(key=consumer_ref['id'],
secret=consumer_ref['secret'])
@ -251,7 +251,7 @@ class OAuthControllerV3(controller.V3Controller):
raise exception.ValidationError(
attribute='oauth_verifier', target='request')
consumer = self.oauth_api._get_consumer(consumer_id)
consumer = self.oauth_api.get_consumer_with_secret(consumer_id)
req_token = self.oauth_api.get_request_token(
request_token_id)

View File

@ -169,7 +169,22 @@ class Driver(object):
raise exception.NotImplemented()
def get_consumer(self, consumer_id):
"""Get consumer.
"""Get consumer, returns the consumer id (key)
and description.
:param consumer_id: id of consumer to get
:type consumer_ref: string
:returns: consumer_ref
"""
raise exception.NotImplemented()
def get_consumer_with_secret(self, consumer_id):
"""Like get_consumer() but returned consumer_ref includes
the consumer secret.
Secrets should only be shared upon consumer creation; the
consumer secret is required to verify incoming OAuth requests.
:param consumer_id: id of consumer to get
:type consumer_ref: string

View File

@ -19,9 +19,6 @@ backend = dogpile.cache.memory
enabled = True
debug_cache_backend = True
[oauth1]
driver = keystone.contrib.oauth1.backends.kvs.OAuth1
[signing]
certfile = ../../examples/pki/certs/signing_cert.pem
keyfile = ../../examples/pki/private/signing_key.pem

View File

@ -22,9 +22,12 @@ import uuid
import webtest
from keystone.common import cms
from keystone.common.sql import migration
from keystone import config
from keystone import contrib
from keystone.contrib import oauth1
from keystone.contrib.oauth1 import controllers
from keystone.openstack.common import importutils
from keystone.tests import core
import test_v3
@ -35,12 +38,22 @@ CONF = config.CONF
class OAuth1Tests(test_v3.RestfulTestCase):
EXTENSION_NAME = 'oauth1'
def setup_database(self):
super(OAuth1Tests, self).setup_database()
package_name = "%s.%s.migrate_repo" % (contrib.__name__,
self.EXTENSION_NAME)
package = importutils.import_module(package_name)
self.repo_path = os.path.abspath(os.path.dirname(package.__file__))
migration.db_version_control(version=None, repo_path=self.repo_path)
migration.db_sync(version=None, repo_path=self.repo_path)
def setUp(self):
super(OAuth1Tests, self).setUp()
self.controller = controllers.OAuthControllerV3()
self.base_url = CONF.public_endpoint % CONF + "v3"
self._generate_paste_config()
self.load_backends()
self.admin_app = webtest.TestApp(
self.loadapp('v3_oauth1', name='admin'))
self.public_app = webtest.TestApp(
@ -169,7 +182,6 @@ class ConsumerCRUDTests(OAuth1Tests):
consumer = self._create_single_consumer()
original_id = consumer.get('id')
original_description = consumer.get('description')
original_secret = consumer.get('secret')
update_description = original_description + "_new"
update_ref = {'description': update_description}
@ -179,7 +191,6 @@ class ConsumerCRUDTests(OAuth1Tests):
consumer = update_resp.result.get('consumer')
self.assertEqual(consumer.get('description'), update_description)
self.assertEqual(consumer.get('id'), original_id)
self.assertEqual(consumer.get('secret'), original_secret)
def test_consumer_update_bad_secret(self):
consumer = self._create_single_consumer()