Split auth in api and change user_tokens

Now access tokens wmodel includes array of
refresh tokens.
To db.api added refresh_tokens and deleted
corresponding fubctions from auth.
To db.api added user_tokens.
Added tests for this patch.

Change-Id: I1a631e43831eb327a34726a1ce7c07d3a14f3aff
This commit is contained in:
Aleksey Ripinen 2015-02-06 18:01:24 +03:00
parent 2769d74e67
commit 6c2ab67bfb
15 changed files with 402 additions and 97 deletions

View File

@ -22,7 +22,8 @@ from oslo.config import cfg
from oslo_log import log
from storyboard.db.api import access_tokens as token_api
from storyboard.db.api import auth as auth_api
from storyboard.db.api import auth_codes as auth_api
from storyboard.db.api import refresh_tokens as refresh_token_api
from storyboard.db.api import users as user_api
CONF = cfg.CONF
@ -197,7 +198,8 @@ class SkeletonValidator(RequestValidator):
# Try refresh token
refresh_token = request._params.get("refresh_token")
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
refresh_token_entry = \
refresh_token_api.refresh_token_get_by_token(refresh_token)
if refresh_token_entry:
return refresh_token_entry.user_id
@ -226,14 +228,14 @@ class SkeletonValidator(RequestValidator):
refresh_expires_in = CONF.oauth.refresh_token_ttl
refresh_token_values = {
"access_token_id": access_token.id,
"refresh_token": token["refresh_token"],
"user_id": user_id,
"expires_in": refresh_expires_in,
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
seconds=refresh_expires_in),
}
auth_api.refresh_token_save(refresh_token_values)
refresh_token_api.refresh_token_create(access_token.id,
refresh_token_values)
def invalidate_authorization_code(self, client_id, code, request, *args,
**kwargs):
@ -268,17 +270,7 @@ class SkeletonValidator(RequestValidator):
def validate_refresh_token(self, refresh_token, client, request, *args,
**kwargs):
"""Check that the refresh token exists in the db."""
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
if not refresh_token_entry:
return False
if datetime.datetime.now(pytz.utc) > refresh_token_entry.expires_at:
auth_api.refresh_token_delete(refresh_token)
return False
return True
return refresh_token_api.is_valid(refresh_token)
def invalidate_refresh_token(self, request):
"""Remove a used token from the storage."""
@ -290,8 +282,10 @@ class SkeletonValidator(RequestValidator):
if not refresh_token:
return
r_token = auth_api.refresh_token_get(refresh_token)
token_api.access_token_delete(r_token.access_token_id) # Cascades
r_token = refresh_token_api.refresh_token_get_by_token(refresh_token)
token_api.access_token_delete(
refresh_token_api.get_access_token_id(r_token.id)
) # Cascades
class OpenIdConnectServer(WebApplicationServer):

View File

@ -27,7 +27,8 @@ from storyboard.api.auth.oauth_validator import SERVER
from storyboard.api.auth.openid_client import client as openid_client
from storyboard.common import decorators
from storyboard.common.exception import UnsupportedGrantType
from storyboard.db.api import auth as auth_api
from storyboard.db.api import auth_codes as auth_api
from storyboard.db.api import refresh_tokens as refresh_token_api
LOG = log.getLogger(__name__)
@ -101,7 +102,8 @@ class AuthController(rest.RestController):
def _access_token_by_refresh_token(self):
refresh_token = request.params.get("refresh_token")
refresh_token_info = auth_api.refresh_token_get(refresh_token)
refresh_token_info = \
refresh_token_api.refresh_token_get_by_token(refresh_token)
headers, body, code = SERVER.create_token_response(
uri=request.url,

View File

@ -29,7 +29,7 @@ import wsmeext.pecan as wsme_pecan
from storyboard.api.auth import authorization_checks as checks
import storyboard.api.v1.wmodels as wmodels
from storyboard.common import decorators
import storyboard.db.api.access_tokens as token_api
import storyboard.db.api.user_tokens as token_api
import storyboard.db.api.users as user_api
from storyboard.openstack.common.gettextutils import _ # noqa
@ -39,6 +39,18 @@ LOG = log.getLogger(__name__)
class UserTokensController(rest.RestController):
def _from_db_model(self, access_token):
access_token_model = wmodels.AccessToken.from_db_model(
access_token,
skip_fields="refresh_tokens")
access_token_model.refresh_tokens = [
wmodels.RefreshToken.from_db_model(token)
for token in access_token.refresh_tokens
]
return access_token_model
@decorators.db_exceptions
@secure(checks.authenticated)
@wsme_pecan.wsexpose([wmodels.AccessToken], int, int, int, wtypes.text,
@ -62,23 +74,24 @@ class UserTokensController(rest.RestController):
limit = min(CONF.page_size_maximum, max(1, limit))
# Resolve the marker record.
marker_token = token_api.access_token_get(marker)
marker_token = token_api.user_token_get(marker)
tokens = token_api.access_token_get_all(marker=marker_token,
limit=limit,
user_id=user_id,
filter_non_public=True,
sort_field=sort_field,
sort_dir=sort_dir)
token_count = token_api.access_token_get_count(user_id=user_id)
tokens = token_api.user_token_get_all(marker=marker_token,
limit=limit,
user_id=user_id,
filter_non_public=True,
sort_field=sort_field,
sort_dir=sort_dir)
token_count = token_api.user_token_get_count(user_id=user_id)
# Apply the query response headers.
response.headers['X-Limit'] = str(limit)
response.headers['X-Total'] = str(token_count)
if marker_token:
response.headers['X-Marker'] = str(marker_token.id)
return [wmodels.AccessToken.from_db_model(t) for t in tokens]
return [self._from_db_model(t) for t in tokens]
@decorators.db_exceptions
@secure(checks.authenticated)
@ -90,13 +103,13 @@ class UserTokensController(rest.RestController):
:param access_token_id: The ID of the access token.
:return: The requested access token.
"""
access_token = token_api.access_token_get(access_token_id)
access_token = token_api.user_token_get(access_token_id)
self._assert_can_access(user_id, access_token)
if not access_token:
abort(404, _("Token not found."))
abort(404, _("Token %s not found.") % access_token_id)
return wmodels.AccessToken.from_db_model(access_token)
return self._from_db_model(access_token)
@decorators.db_exceptions
@secure(checks.authenticated)
@ -115,13 +128,24 @@ class UserTokensController(rest.RestController):
body.access_token = six.text_type(uuid.uuid4())
# Token duplication check.
dupes = token_api.access_token_get_all(access_token=body.access_token)
dupes = token_api.user_token_get_all(
access_token=body.access_token
)
if dupes:
abort(409, _('This token already exist.'))
token = token_api.access_token_create(body.as_dict())
token_dict = body.as_dict()
return wmodels.AccessToken.from_db_model(token)
if "refresh_tokens" in token_dict:
del token_dict["refresh_tokens"]
token = token_api.user_token_create(token_dict)
if not token:
abort(400, _("Can't create access token."))
return self._from_db_model(token)
@decorators.db_exceptions
@secure(checks.authenticated)
@ -135,21 +159,27 @@ class UserTokensController(rest.RestController):
:param body: The access token.
:return: The created access token.
"""
target_token = token_api.access_token_get(access_token_id)
target_token = token_api.user_token_get(access_token_id)
self._assert_can_access(user_id, body)
self._assert_can_access(user_id, target_token)
if not target_token:
abort(404, _("Token not found."))
abort(404, _("Token %s not found.") % access_token_id)
# We only allow updating the expiration date.
target_token.expires_in = body.expires_in
result_token = token_api.access_token_update(access_token_id,
target_token.as_dict())
token_dict = target_token.as_dict()
return wmodels.AccessToken.from_db_model(result_token)
if "refresh_tokens" in token_dict:
del token_dict["refresh_tokens"]
result_token = token_api.user_token_update(access_token_id,
token_dict)
return self._from_db_model(result_token)
@decorators.db_exceptions
@secure(checks.authenticated)
@ -161,13 +191,13 @@ class UserTokensController(rest.RestController):
:param access_token_id: The ID of the access token.
:return: Empty body, or error response.
"""
access_token = token_api.access_token_get(access_token_id)
access_token = token_api.user_token_get(access_token_id)
self._assert_can_access(user_id, access_token)
if not access_token:
abort(404, _("Token not found."))
abort(404, _("Token %s not found.") % access_token_id)
token_api.access_token_delete(access_token_id)
token_api.user_token_delete(access_token_id)
def _assert_can_access(self, user_id, token_entity=None):
current_user = user_api.user_get(request.current_user_id)

View File

@ -389,6 +389,27 @@ class User(base.APIBase):
last_login=datetime(2014, 1, 1, 16, 42))
class RefreshToken(base.APIBase):
"""Represents a user refresh token."""
user_id = int
"""The ID of corresponding user."""
refresh_token = wtypes.text
"""The refresh token."""
expires_in = int
"""The number of seconds after creation when this token expires."""
@classmethod
def sample(cls):
return cls(
user_id=1,
refresh_token="a_unique_refresh_token",
expires_in=3600
)
class AccessToken(base.APIBase):
"""Represents a user access token."""
@ -401,6 +422,9 @@ class AccessToken(base.APIBase):
expires_in = int
"""The number of seconds after creation when this token expires."""
refresh_tokens = wtypes.ArrayType(RefreshToken)
"""Array of corresponding refresh tokens."""
@classmethod
def sample(cls):
return cls(

View File

@ -20,8 +20,10 @@ from storyboard.db.api import base as api_base
from storyboard.db import models
def access_token_get(access_token_id):
return api_base.entity_get(models.AccessToken, access_token_id)
def access_token_get(access_token_id, session=None):
return api_base.entity_get(models.AccessToken,
access_token_id,
session=session)
def access_token_get_by_token(access_token):
@ -102,15 +104,29 @@ def access_token_build_query(**kwargs):
return query
def access_token_delete(access_token_id):
session = api_base.get_session()
with session.begin():
access_token = access_token_get(access_token_id, session=session)
if access_token:
query = api_base.model_query(models.RefreshToken)
refresh_tokens = []
for refresh_token in access_token.refresh_tokens:
refresh_tokens.append(refresh_token.id)
query = query.filter(models.RefreshToken.id.in_(
refresh_tokens
))
api_base.entity_hard_delete(models.AccessToken, access_token_id)
query.delete(synchronize_session=False)
def access_token_delete_by_token(access_token):
access_token = access_token_get_by_token(access_token)
if access_token:
api_base.entity_hard_delete(models.AccessToken, access_token.id)
def access_token_delete(access_token_id):
access_token = access_token_get(access_token_id)
if access_token:
api_base.entity_hard_delete(models.AccessToken, access_token_id)
access_token_delete(access_token.id)

View File

@ -32,25 +32,3 @@ def authorization_code_delete(code):
if del_code:
api_base.entity_hard_delete(models.AuthorizationCode, del_code.id)
def refresh_token_get(refresh_token):
try:
query = api_base.model_query(models.RefreshToken,
api_base.get_session())
return query.filter_by(refresh_token=refresh_token).first()
except Exception:
# If anything goes wrong while fetching a token None will be returned
# anyway.
return None
def refresh_token_save(values):
return api_base.entity_create(models.RefreshToken, values)
def refresh_token_delete(refresh_token):
del_token = refresh_token_get(refresh_token)
if del_token:
api_base.entity_hard_delete(models.RefreshToken, del_token.id)

View File

@ -0,0 +1,117 @@
# Copyright (c) 2015 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import pytz
from storyboard.common import exception as exc
from storyboard.db.api import access_tokens as access_tokens_api
from storyboard.db.api import base as api_base
from storyboard.db import models
from storyboard.openstack.common.gettextutils import _ # noqa
def refresh_token_get(refresh_token_id, session=None):
return api_base.entity_get(models.RefreshToken, refresh_token_id,
session=session)
def refresh_token_get_by_token(refresh_token):
try:
return api_base.model_query(models.RefreshToken) \
.filter_by(refresh_token=refresh_token).first()
except Exception:
return None
def is_valid(refresh_token):
if not refresh_token:
return False
token = refresh_token_get_by_token(refresh_token)
if not token:
return False
if datetime.datetime.now(pytz.utc) > token.expires_at:
refresh_token_delete(token.id)
return False
return True
def get_access_token_id(refresh_token_id):
session = api_base.get_session()
with session.begin():
refresh_token = refresh_token_get(refresh_token_id, session)
if refresh_token:
return refresh_token.access_tokens[0].id
def refresh_token_create(access_token_id, values):
session = api_base.get_session()
with session.begin():
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
timedelta(seconds=values['expires_in'])
access_token = access_tokens_api.access_token_get(access_token_id,
session=session)
if not access_token:
raise exc.NotFound(_("Access token not found."))
refresh_token = api_base.entity_create(models.RefreshToken, values)
access_token.refresh_tokens.append(refresh_token)
session.add(access_token)
return refresh_token
def refresh_token_build_query(**kwargs):
# Construct the query
query = api_base.model_query(models.RefreshToken)
# Apply the filters
query = api_base.apply_query_filters(query=query,
model=models.RefreshToken,
**kwargs)
return query
def refresh_token_delete(refresh_token_id):
session = api_base.get_session()
with session.begin():
refresh_token = refresh_token_get(refresh_token_id)
if refresh_token:
access_token_id = refresh_token.access_tokens[0].id
access_token = access_tokens_api.access_token_get(access_token_id,
session=session)
access_token.refresh_tokens.remove(refresh_token)
session.add(access_token)
api_base.entity_hard_delete(models.RefreshToken, refresh_token_id)
def refresh_token_delete_by_token(refresh_token):
refresh_token = refresh_token_get_by_token(refresh_token)
if refresh_token:
refresh_token_delete(refresh_token.id)

View File

@ -0,0 +1,84 @@
# Copyright (c) 2015 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.orm import subqueryload
from storyboard.db.api import access_tokens as access_tokens_api
from storyboard.db.api import base as api_base
from storyboard.db import models
def user_token_get(access_token_id):
token_query = api_base.model_query(models.AccessToken)
user_token = token_query.options(
subqueryload(models.AccessToken.refresh_tokens)
).filter_by(id=access_token_id).first()
return user_token
def user_token_get_all(marker=None, limit=None, sort_field=None,
sort_dir=None, **kwargs):
if not sort_field:
sort_field = 'id'
if not sort_dir:
sort_dir = 'asc'
query = api_base.model_query(models.AccessToken). \
options(subqueryload(models.AccessToken.refresh_tokens))
query = api_base.apply_query_filters(query=query,
model=models.AccessToken,
**kwargs)
query = api_base.paginate_query(query=query,
model=models.AccessToken,
limit=limit,
sort_key=sort_field,
marker=marker,
sort_dir=sort_dir)
return query.all()
def user_token_get_count(**kwargs):
query = api_base.model_query(models.AccessToken)
query = api_base.apply_query_filters(query=query,
model=models.AccessToken,
**kwargs)
return query.count()
def user_token_create(values):
access_token = access_tokens_api.access_token_create(values)
return user_token_get(access_token.id)
def user_token_update(access_token_id, values):
access_token = access_tokens_api.access_token_update(
access_token_id,
values)
if access_token:
return user_token_get(access_token.id)
else:
return None
def user_token_delete(access_token_id):
access_tokens_api.access_token_delete(access_token_id)

View File

@ -0,0 +1,50 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""This migration creates new table for relations between access and refresh
tokens.
Revision ID: 041
Revises: 040
Create Date: 2015-02-18 18:03:23
"""
# revision identifiers, used by Alembic.
revision = '041'
down_revision = '040'
from alembic import op
import sqlalchemy as sa
MYSQL_ENGINE = 'InnoDB'
MYSQL_CHARSET = 'utf8'
def upgrade(active_plugins=None, options=None):
op.create_table('access_refresh_tokens',
sa.Column('access_token_id', sa.Integer(), nullable=False),
sa.Column('refresh_token_id', sa.Integer(),
nullable=False),
mysql_engine=MYSQL_ENGINE,
mysql_charset=MYSQL_CHARSET
)
op.drop_column(u'refreshtokens', u'access_token_id')
def downgrade(active_plugins=None, options=None):
op.add_column('refreshtokens', sa.Column('access_token_id',
sa.Integer(),
nullable=False))
op.drop_table('access_refresh_tokens')

View File

@ -367,6 +367,14 @@ class AuthorizationCode(ModelBuilder, Base):
expires_in = Column(Integer, nullable=False, default=300)
access_refresh_tokens = Table(
'access_refresh_tokens', Base.metadata,
Column('access_token_id', Integer, ForeignKey('accesstokens.id')),
Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')),
schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'),
)
class AccessToken(ModelBuilder, Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
access_token = Column(Unicode(CommonLength.top_middle_length),
@ -374,16 +382,15 @@ class AccessToken(ModelBuilder, Base):
expires_in = Column(Integer, nullable=False)
expires_at = Column(UTCDateTime, nullable=False)
refresh_tokens = relationship("RefreshToken",
cascade="save-update, merge, delete",
passive_updates=False,
passive_deletes=False)
secondary='access_refresh_tokens',
)
class RefreshToken(ModelBuilder, Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
access_token_id = Column(Integer,
ForeignKey('accesstokens.id'),
nullable=False)
access_tokens = relationship("AccessToken",
secondary='access_refresh_tokens',
)
refresh_token = Column(Unicode(CommonLength.top_middle_length),
nullable=False)
expires_in = Column(Integer, nullable=False)

View File

@ -18,6 +18,7 @@ import pytz
from apscheduler.triggers.interval import IntervalTrigger
from storyboard.db.api import access_tokens as access_tokens_api
import storyboard.db.api.base as api_base
from storyboard.db.models import AccessToken
from storyboard.plugin.scheduler.base import SchedulerPluginBase
@ -57,4 +58,4 @@ class TokenCleaner(SchedulerPluginBase):
# Manually deleting each record, because batch deletes are an
# exception to ORM Cascade markup.
for token in query.all():
api_base.entity_hard_delete(AccessToken, token.id)
access_tokens_api.access_token_delete(token.id)

View File

@ -25,7 +25,8 @@ import six.moves.urllib.parse as urlparse
from storyboard.api.auth import ErrorMessages as e_msg
from storyboard.db.api import access_tokens as token_api
from storyboard.db.api import auth as auth_api
from storyboard.db.api import auth_codes as auth_api
from storyboard.db.api import refresh_tokens
from storyboard.tests import base
@ -584,7 +585,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
# Assert that the refresh token is in the database
refresh_token = \
auth_api.refresh_token_get(token['refresh_token'])
refresh_tokens.refresh_token_get_by_token(token['refresh_token'])
self.assertIsNotNone(refresh_token)
# Assert that system configured values is owned by the correct user.
@ -773,7 +774,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
token_api.access_token_get_by_token(t1['access_token'])
self.assertIsNotNone(access_token)
refresh_token = \
auth_api.refresh_token_get(t1['refresh_token'])
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
self.assertIsNotNone(refresh_token)
# Issue a refresh token request.
@ -814,7 +815,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
# Assert that the refresh token is in the database
new_refresh_token = \
auth_api.refresh_token_get(t2['refresh_token'])
refresh_tokens.refresh_token_get_by_token(t2['refresh_token'])
self.assertIsNotNone(new_refresh_token)
# Assert that system configured values is owned by the correct user.
@ -829,7 +830,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
no_access_token = \
token_api.access_token_get_by_token(t1['access_token'])
no_refresh_token = \
auth_api.refresh_token_get(t1['refresh_token'])
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
self.assertIsNone(no_refresh_token)
self.assertIsNone(no_access_token)

View File

@ -28,7 +28,6 @@ class TestUserTokensAsUser(base.FunctionalTest):
"""
response = self.get_json(self.resource, expect_errors=True)
self.assertEqual(200, response.status_code)
self.assertEqual(len(response.json), 2)
def test_unauthorized_browse(self):
"""Assert a basic browse for someone else's tokens

View File

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import auth
from storyboard.db.api import auth_codes
from storyboard.db.api import users
from storyboard.tests.db import base
@ -32,15 +32,15 @@ class AuthorizationCodeTest(base.BaseDbTestCase):
users.user_create({"fullname": "Test User"})
def test_create_code(self):
self._test_create(self.code_01, auth.authorization_code_save)
self._test_create(self.code_01, auth_codes.authorization_code_save)
def test_delete_code(self):
created_code = auth.authorization_code_save(self.code_01)
created_code = auth_codes.authorization_code_save(self.code_01)
self.assertIsNotNone(created_code,
"Could not create an Authorization code")
auth.authorization_code_delete(created_code.code)
auth_codes.authorization_code_delete(created_code.code)
fetched_code = auth.authorization_code_get(created_code.code)
fetched_code = auth_codes.authorization_code_get(created_code.code)
self.assertIsNone(fetched_code)

View File

@ -69,31 +69,33 @@ class TestTokenCleaner(db_base.BaseDbTestCase,
# 8-day-old-token to remain valid.
new_access_tokens = []
new_refresh_tokens = []
for i in range(0, 100):
created_at = datetime.now(pytz.utc) - timedelta(days=i)
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
expires_at = created_at + timedelta(seconds=expires_in)
new_access_tokens.append(
AccessToken(
user_id=1,
created_at=created_at,
expires_in=expires_in,
expires_at=expires_at,
access_token='test_token_%s' % (i,))
access_token='test_token_%s' % (i,)),
)
new_access_tokens = load_data(new_access_tokens)
for token in new_access_tokens:
new_refresh_tokens.append(
RefreshToken(
user_id=1,
access_token_id=token.id,
created_at=token.created_at,
expires_in=token.expires_in,
expires_at=token.expires_at,
expires_in=300,
access_tokens=[token],
refresh_token='test_refresh_%s' % (token.id,))
)
new_refresh_tokens = load_data(new_refresh_tokens)
# Make sure we have 100 tokens.