From 6c2ab67bfbce56f93755c7c8c237d0164b54d0d4 Mon Sep 17 00:00:00 2001 From: Aleksey Ripinen Date: Fri, 6 Feb 2015 18:01:24 +0300 Subject: [PATCH] Split auth in api and change user_tokens Now access tokens wmodel includes array of refresh tokens. To db.api added refresh_tokens and deleted corresponding fubctions from auth. To db.api added user_tokens. Added tests for this patch. Change-Id: I1a631e43831eb327a34726a1ce7c07d3a14f3aff --- storyboard/api/auth/oauth_validator.py | 28 ++--- storyboard/api/v1/auth.py | 6 +- storyboard/api/v1/user_tokens.py | 78 ++++++++---- storyboard/api/v1/wmodels.py | 24 ++++ storyboard/db/api/access_tokens.py | 36 ++++-- storyboard/db/api/{auth.py => auth_codes.py} | 22 ---- storyboard/db/api/refresh_tokens.py | 117 ++++++++++++++++++ storyboard/db/api/user_tokens.py | 84 +++++++++++++ .../versions/041_refresh_access_tokens.py | 50 ++++++++ storyboard/db/models.py | 19 ++- storyboard/plugin/token_cleaner/cleaner.py | 3 +- storyboard/tests/api/auth/test_oauth.py | 11 +- storyboard/tests/api/test_user_tokens.py | 1 - .../tests/db/api/test_authorization_codes.py | 10 +- .../plugin/token_cleaner/test_cleaner.py | 10 +- 15 files changed, 402 insertions(+), 97 deletions(-) rename storyboard/db/api/{auth.py => auth_codes.py} (63%) create mode 100644 storyboard/db/api/refresh_tokens.py create mode 100644 storyboard/db/api/user_tokens.py create mode 100644 storyboard/db/migration/alembic_migrations/versions/041_refresh_access_tokens.py diff --git a/storyboard/api/auth/oauth_validator.py b/storyboard/api/auth/oauth_validator.py index 647587aa..feb694fb 100644 --- a/storyboard/api/auth/oauth_validator.py +++ b/storyboard/api/auth/oauth_validator.py @@ -22,7 +22,8 @@ from oslo.config import cfg from oslo_log import log from storyboard.db.api import access_tokens as token_api -from storyboard.db.api import auth as auth_api +from storyboard.db.api import auth_codes as auth_api +from storyboard.db.api import refresh_tokens as refresh_token_api from storyboard.db.api import users as user_api CONF = cfg.CONF @@ -197,7 +198,8 @@ class SkeletonValidator(RequestValidator): # Try refresh token refresh_token = request._params.get("refresh_token") - refresh_token_entry = auth_api.refresh_token_get(refresh_token) + refresh_token_entry = \ + refresh_token_api.refresh_token_get_by_token(refresh_token) if refresh_token_entry: return refresh_token_entry.user_id @@ -226,14 +228,14 @@ class SkeletonValidator(RequestValidator): refresh_expires_in = CONF.oauth.refresh_token_ttl refresh_token_values = { - "access_token_id": access_token.id, "refresh_token": token["refresh_token"], "user_id": user_id, "expires_in": refresh_expires_in, "expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta( seconds=refresh_expires_in), } - auth_api.refresh_token_save(refresh_token_values) + refresh_token_api.refresh_token_create(access_token.id, + refresh_token_values) def invalidate_authorization_code(self, client_id, code, request, *args, **kwargs): @@ -268,17 +270,7 @@ class SkeletonValidator(RequestValidator): def validate_refresh_token(self, refresh_token, client, request, *args, **kwargs): """Check that the refresh token exists in the db.""" - - refresh_token_entry = auth_api.refresh_token_get(refresh_token) - - if not refresh_token_entry: - return False - - if datetime.datetime.now(pytz.utc) > refresh_token_entry.expires_at: - auth_api.refresh_token_delete(refresh_token) - return False - - return True + return refresh_token_api.is_valid(refresh_token) def invalidate_refresh_token(self, request): """Remove a used token from the storage.""" @@ -290,8 +282,10 @@ class SkeletonValidator(RequestValidator): if not refresh_token: return - r_token = auth_api.refresh_token_get(refresh_token) - token_api.access_token_delete(r_token.access_token_id) # Cascades + r_token = refresh_token_api.refresh_token_get_by_token(refresh_token) + token_api.access_token_delete( + refresh_token_api.get_access_token_id(r_token.id) + ) # Cascades class OpenIdConnectServer(WebApplicationServer): diff --git a/storyboard/api/v1/auth.py b/storyboard/api/v1/auth.py index f2625c4b..2e5fc711 100644 --- a/storyboard/api/v1/auth.py +++ b/storyboard/api/v1/auth.py @@ -27,7 +27,8 @@ from storyboard.api.auth.oauth_validator import SERVER from storyboard.api.auth.openid_client import client as openid_client from storyboard.common import decorators from storyboard.common.exception import UnsupportedGrantType -from storyboard.db.api import auth as auth_api +from storyboard.db.api import auth_codes as auth_api +from storyboard.db.api import refresh_tokens as refresh_token_api LOG = log.getLogger(__name__) @@ -101,7 +102,8 @@ class AuthController(rest.RestController): def _access_token_by_refresh_token(self): refresh_token = request.params.get("refresh_token") - refresh_token_info = auth_api.refresh_token_get(refresh_token) + refresh_token_info = \ + refresh_token_api.refresh_token_get_by_token(refresh_token) headers, body, code = SERVER.create_token_response( uri=request.url, diff --git a/storyboard/api/v1/user_tokens.py b/storyboard/api/v1/user_tokens.py index ee9a114d..d9cd1d9d 100644 --- a/storyboard/api/v1/user_tokens.py +++ b/storyboard/api/v1/user_tokens.py @@ -29,7 +29,7 @@ import wsmeext.pecan as wsme_pecan from storyboard.api.auth import authorization_checks as checks import storyboard.api.v1.wmodels as wmodels from storyboard.common import decorators -import storyboard.db.api.access_tokens as token_api +import storyboard.db.api.user_tokens as token_api import storyboard.db.api.users as user_api from storyboard.openstack.common.gettextutils import _ # noqa @@ -39,6 +39,18 @@ LOG = log.getLogger(__name__) class UserTokensController(rest.RestController): + def _from_db_model(self, access_token): + access_token_model = wmodels.AccessToken.from_db_model( + access_token, + skip_fields="refresh_tokens") + + access_token_model.refresh_tokens = [ + wmodels.RefreshToken.from_db_model(token) + for token in access_token.refresh_tokens + ] + + return access_token_model + @decorators.db_exceptions @secure(checks.authenticated) @wsme_pecan.wsexpose([wmodels.AccessToken], int, int, int, wtypes.text, @@ -62,23 +74,24 @@ class UserTokensController(rest.RestController): limit = min(CONF.page_size_maximum, max(1, limit)) # Resolve the marker record. - marker_token = token_api.access_token_get(marker) + marker_token = token_api.user_token_get(marker) - tokens = token_api.access_token_get_all(marker=marker_token, - limit=limit, - user_id=user_id, - filter_non_public=True, - sort_field=sort_field, - sort_dir=sort_dir) - token_count = token_api.access_token_get_count(user_id=user_id) + tokens = token_api.user_token_get_all(marker=marker_token, + limit=limit, + user_id=user_id, + filter_non_public=True, + sort_field=sort_field, + sort_dir=sort_dir) + token_count = token_api.user_token_get_count(user_id=user_id) # Apply the query response headers. response.headers['X-Limit'] = str(limit) response.headers['X-Total'] = str(token_count) + if marker_token: response.headers['X-Marker'] = str(marker_token.id) - return [wmodels.AccessToken.from_db_model(t) for t in tokens] + return [self._from_db_model(t) for t in tokens] @decorators.db_exceptions @secure(checks.authenticated) @@ -90,13 +103,13 @@ class UserTokensController(rest.RestController): :param access_token_id: The ID of the access token. :return: The requested access token. """ - access_token = token_api.access_token_get(access_token_id) + access_token = token_api.user_token_get(access_token_id) self._assert_can_access(user_id, access_token) if not access_token: - abort(404, _("Token not found.")) + abort(404, _("Token %s not found.") % access_token_id) - return wmodels.AccessToken.from_db_model(access_token) + return self._from_db_model(access_token) @decorators.db_exceptions @secure(checks.authenticated) @@ -115,13 +128,24 @@ class UserTokensController(rest.RestController): body.access_token = six.text_type(uuid.uuid4()) # Token duplication check. - dupes = token_api.access_token_get_all(access_token=body.access_token) + dupes = token_api.user_token_get_all( + access_token=body.access_token + ) + if dupes: abort(409, _('This token already exist.')) - token = token_api.access_token_create(body.as_dict()) + token_dict = body.as_dict() - return wmodels.AccessToken.from_db_model(token) + if "refresh_tokens" in token_dict: + del token_dict["refresh_tokens"] + + token = token_api.user_token_create(token_dict) + + if not token: + abort(400, _("Can't create access token.")) + + return self._from_db_model(token) @decorators.db_exceptions @secure(checks.authenticated) @@ -135,21 +159,27 @@ class UserTokensController(rest.RestController): :param body: The access token. :return: The created access token. """ - target_token = token_api.access_token_get(access_token_id) + + target_token = token_api.user_token_get(access_token_id) self._assert_can_access(user_id, body) self._assert_can_access(user_id, target_token) if not target_token: - abort(404, _("Token not found.")) + abort(404, _("Token %s not found.") % access_token_id) # We only allow updating the expiration date. target_token.expires_in = body.expires_in - result_token = token_api.access_token_update(access_token_id, - target_token.as_dict()) + token_dict = target_token.as_dict() - return wmodels.AccessToken.from_db_model(result_token) + if "refresh_tokens" in token_dict: + del token_dict["refresh_tokens"] + + result_token = token_api.user_token_update(access_token_id, + token_dict) + + return self._from_db_model(result_token) @decorators.db_exceptions @secure(checks.authenticated) @@ -161,13 +191,13 @@ class UserTokensController(rest.RestController): :param access_token_id: The ID of the access token. :return: Empty body, or error response. """ - access_token = token_api.access_token_get(access_token_id) + access_token = token_api.user_token_get(access_token_id) self._assert_can_access(user_id, access_token) if not access_token: - abort(404, _("Token not found.")) + abort(404, _("Token %s not found.") % access_token_id) - token_api.access_token_delete(access_token_id) + token_api.user_token_delete(access_token_id) def _assert_can_access(self, user_id, token_entity=None): current_user = user_api.user_get(request.current_user_id) diff --git a/storyboard/api/v1/wmodels.py b/storyboard/api/v1/wmodels.py index b2b625ee..7ddf5c63 100644 --- a/storyboard/api/v1/wmodels.py +++ b/storyboard/api/v1/wmodels.py @@ -389,6 +389,27 @@ class User(base.APIBase): last_login=datetime(2014, 1, 1, 16, 42)) +class RefreshToken(base.APIBase): + """Represents a user refresh token.""" + + user_id = int + """The ID of corresponding user.""" + + refresh_token = wtypes.text + """The refresh token.""" + + expires_in = int + """The number of seconds after creation when this token expires.""" + + @classmethod + def sample(cls): + return cls( + user_id=1, + refresh_token="a_unique_refresh_token", + expires_in=3600 + ) + + class AccessToken(base.APIBase): """Represents a user access token.""" @@ -401,6 +422,9 @@ class AccessToken(base.APIBase): expires_in = int """The number of seconds after creation when this token expires.""" + refresh_tokens = wtypes.ArrayType(RefreshToken) + """Array of corresponding refresh tokens.""" + @classmethod def sample(cls): return cls( diff --git a/storyboard/db/api/access_tokens.py b/storyboard/db/api/access_tokens.py index 4f4f2a00..92ec5372 100644 --- a/storyboard/db/api/access_tokens.py +++ b/storyboard/db/api/access_tokens.py @@ -20,8 +20,10 @@ from storyboard.db.api import base as api_base from storyboard.db import models -def access_token_get(access_token_id): - return api_base.entity_get(models.AccessToken, access_token_id) +def access_token_get(access_token_id, session=None): + return api_base.entity_get(models.AccessToken, + access_token_id, + session=session) def access_token_get_by_token(access_token): @@ -102,15 +104,29 @@ def access_token_build_query(**kwargs): return query +def access_token_delete(access_token_id): + session = api_base.get_session() + + with session.begin(): + access_token = access_token_get(access_token_id, session=session) + + if access_token: + query = api_base.model_query(models.RefreshToken) + refresh_tokens = [] + + for refresh_token in access_token.refresh_tokens: + refresh_tokens.append(refresh_token.id) + + query = query.filter(models.RefreshToken.id.in_( + refresh_tokens + )) + + api_base.entity_hard_delete(models.AccessToken, access_token_id) + query.delete(synchronize_session=False) + + def access_token_delete_by_token(access_token): access_token = access_token_get_by_token(access_token) if access_token: - api_base.entity_hard_delete(models.AccessToken, access_token.id) - - -def access_token_delete(access_token_id): - access_token = access_token_get(access_token_id) - - if access_token: - api_base.entity_hard_delete(models.AccessToken, access_token_id) + access_token_delete(access_token.id) diff --git a/storyboard/db/api/auth.py b/storyboard/db/api/auth_codes.py similarity index 63% rename from storyboard/db/api/auth.py rename to storyboard/db/api/auth_codes.py index 32939b54..17be2a36 100644 --- a/storyboard/db/api/auth.py +++ b/storyboard/db/api/auth_codes.py @@ -32,25 +32,3 @@ def authorization_code_delete(code): if del_code: api_base.entity_hard_delete(models.AuthorizationCode, del_code.id) - - -def refresh_token_get(refresh_token): - try: - query = api_base.model_query(models.RefreshToken, - api_base.get_session()) - return query.filter_by(refresh_token=refresh_token).first() - except Exception: - # If anything goes wrong while fetching a token None will be returned - # anyway. - return None - - -def refresh_token_save(values): - return api_base.entity_create(models.RefreshToken, values) - - -def refresh_token_delete(refresh_token): - del_token = refresh_token_get(refresh_token) - - if del_token: - api_base.entity_hard_delete(models.RefreshToken, del_token.id) diff --git a/storyboard/db/api/refresh_tokens.py b/storyboard/db/api/refresh_tokens.py new file mode 100644 index 00000000..93dbdb0d --- /dev/null +++ b/storyboard/db/api/refresh_tokens.py @@ -0,0 +1,117 @@ +# Copyright (c) 2015 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import pytz + +from storyboard.common import exception as exc +from storyboard.db.api import access_tokens as access_tokens_api +from storyboard.db.api import base as api_base +from storyboard.db import models +from storyboard.openstack.common.gettextutils import _ # noqa + + +def refresh_token_get(refresh_token_id, session=None): + return api_base.entity_get(models.RefreshToken, refresh_token_id, + session=session) + + +def refresh_token_get_by_token(refresh_token): + try: + return api_base.model_query(models.RefreshToken) \ + .filter_by(refresh_token=refresh_token).first() + except Exception: + return None + + +def is_valid(refresh_token): + if not refresh_token: + return False + + token = refresh_token_get_by_token(refresh_token) + + if not token: + return False + + if datetime.datetime.now(pytz.utc) > token.expires_at: + refresh_token_delete(token.id) + return False + + return True + + +def get_access_token_id(refresh_token_id): + session = api_base.get_session() + + with session.begin(): + refresh_token = refresh_token_get(refresh_token_id, session) + + if refresh_token: + return refresh_token.access_tokens[0].id + + +def refresh_token_create(access_token_id, values): + session = api_base.get_session() + + with session.begin(): + values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\ + timedelta(seconds=values['expires_in']) + access_token = access_tokens_api.access_token_get(access_token_id, + session=session) + + if not access_token: + raise exc.NotFound(_("Access token not found.")) + + refresh_token = api_base.entity_create(models.RefreshToken, values) + + access_token.refresh_tokens.append(refresh_token) + session.add(access_token) + + return refresh_token + + +def refresh_token_build_query(**kwargs): + # Construct the query + query = api_base.model_query(models.RefreshToken) + + # Apply the filters + query = api_base.apply_query_filters(query=query, + model=models.RefreshToken, + **kwargs) + + return query + + +def refresh_token_delete(refresh_token_id): + session = api_base.get_session() + + with session.begin(): + refresh_token = refresh_token_get(refresh_token_id) + + if refresh_token: + access_token_id = refresh_token.access_tokens[0].id + access_token = access_tokens_api.access_token_get(access_token_id, + session=session) + + access_token.refresh_tokens.remove(refresh_token) + session.add(access_token) + api_base.entity_hard_delete(models.RefreshToken, refresh_token_id) + + +def refresh_token_delete_by_token(refresh_token): + refresh_token = refresh_token_get_by_token(refresh_token) + + if refresh_token: + refresh_token_delete(refresh_token.id) diff --git a/storyboard/db/api/user_tokens.py b/storyboard/db/api/user_tokens.py new file mode 100644 index 00000000..514d2744 --- /dev/null +++ b/storyboard/db/api/user_tokens.py @@ -0,0 +1,84 @@ +# Copyright (c) 2015 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.orm import subqueryload + +from storyboard.db.api import access_tokens as access_tokens_api +from storyboard.db.api import base as api_base +from storyboard.db import models + + +def user_token_get(access_token_id): + token_query = api_base.model_query(models.AccessToken) + user_token = token_query.options( + subqueryload(models.AccessToken.refresh_tokens) + ).filter_by(id=access_token_id).first() + + return user_token + + +def user_token_get_all(marker=None, limit=None, sort_field=None, + sort_dir=None, **kwargs): + if not sort_field: + sort_field = 'id' + if not sort_dir: + sort_dir = 'asc' + + query = api_base.model_query(models.AccessToken). \ + options(subqueryload(models.AccessToken.refresh_tokens)) + + query = api_base.apply_query_filters(query=query, + model=models.AccessToken, + **kwargs) + + query = api_base.paginate_query(query=query, + model=models.AccessToken, + limit=limit, + sort_key=sort_field, + marker=marker, + sort_dir=sort_dir) + + return query.all() + + +def user_token_get_count(**kwargs): + query = api_base.model_query(models.AccessToken) + + query = api_base.apply_query_filters(query=query, + model=models.AccessToken, + **kwargs) + + return query.count() + + +def user_token_create(values): + access_token = access_tokens_api.access_token_create(values) + + return user_token_get(access_token.id) + + +def user_token_update(access_token_id, values): + access_token = access_tokens_api.access_token_update( + access_token_id, + values) + + if access_token: + return user_token_get(access_token.id) + else: + return None + + +def user_token_delete(access_token_id): + access_tokens_api.access_token_delete(access_token_id) diff --git a/storyboard/db/migration/alembic_migrations/versions/041_refresh_access_tokens.py b/storyboard/db/migration/alembic_migrations/versions/041_refresh_access_tokens.py new file mode 100644 index 00000000..605245a7 --- /dev/null +++ b/storyboard/db/migration/alembic_migrations/versions/041_refresh_access_tokens.py @@ -0,0 +1,50 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +"""This migration creates new table for relations between access and refresh +tokens. + +Revision ID: 041 +Revises: 040 +Create Date: 2015-02-18 18:03:23 + +""" + +# revision identifiers, used by Alembic. + +revision = '041' +down_revision = '040' + +from alembic import op +import sqlalchemy as sa + +MYSQL_ENGINE = 'InnoDB' +MYSQL_CHARSET = 'utf8' + + +def upgrade(active_plugins=None, options=None): + op.create_table('access_refresh_tokens', + sa.Column('access_token_id', sa.Integer(), nullable=False), + sa.Column('refresh_token_id', sa.Integer(), + nullable=False), + mysql_engine=MYSQL_ENGINE, + mysql_charset=MYSQL_CHARSET + ) + op.drop_column(u'refreshtokens', u'access_token_id') + + +def downgrade(active_plugins=None, options=None): + op.add_column('refreshtokens', sa.Column('access_token_id', + sa.Integer(), + nullable=False)) + op.drop_table('access_refresh_tokens') diff --git a/storyboard/db/models.py b/storyboard/db/models.py index d4283ffe..93f7aca7 100644 --- a/storyboard/db/models.py +++ b/storyboard/db/models.py @@ -367,6 +367,14 @@ class AuthorizationCode(ModelBuilder, Base): expires_in = Column(Integer, nullable=False, default=300) +access_refresh_tokens = Table( + 'access_refresh_tokens', Base.metadata, + Column('access_token_id', Integer, ForeignKey('accesstokens.id')), + Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')), + schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'), +) + + class AccessToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) access_token = Column(Unicode(CommonLength.top_middle_length), @@ -374,16 +382,15 @@ class AccessToken(ModelBuilder, Base): expires_in = Column(Integer, nullable=False) expires_at = Column(UTCDateTime, nullable=False) refresh_tokens = relationship("RefreshToken", - cascade="save-update, merge, delete", - passive_updates=False, - passive_deletes=False) + secondary='access_refresh_tokens', + ) class RefreshToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) - access_token_id = Column(Integer, - ForeignKey('accesstokens.id'), - nullable=False) + access_tokens = relationship("AccessToken", + secondary='access_refresh_tokens', + ) refresh_token = Column(Unicode(CommonLength.top_middle_length), nullable=False) expires_in = Column(Integer, nullable=False) diff --git a/storyboard/plugin/token_cleaner/cleaner.py b/storyboard/plugin/token_cleaner/cleaner.py index 4a456388..4984295a 100644 --- a/storyboard/plugin/token_cleaner/cleaner.py +++ b/storyboard/plugin/token_cleaner/cleaner.py @@ -18,6 +18,7 @@ import pytz from apscheduler.triggers.interval import IntervalTrigger +from storyboard.db.api import access_tokens as access_tokens_api import storyboard.db.api.base as api_base from storyboard.db.models import AccessToken from storyboard.plugin.scheduler.base import SchedulerPluginBase @@ -57,4 +58,4 @@ class TokenCleaner(SchedulerPluginBase): # Manually deleting each record, because batch deletes are an # exception to ORM Cascade markup. for token in query.all(): - api_base.entity_hard_delete(AccessToken, token.id) + access_tokens_api.access_token_delete(token.id) diff --git a/storyboard/tests/api/auth/test_oauth.py b/storyboard/tests/api/auth/test_oauth.py index 49038059..0f56b5d3 100644 --- a/storyboard/tests/api/auth/test_oauth.py +++ b/storyboard/tests/api/auth/test_oauth.py @@ -25,7 +25,8 @@ import six.moves.urllib.parse as urlparse from storyboard.api.auth import ErrorMessages as e_msg from storyboard.db.api import access_tokens as token_api -from storyboard.db.api import auth as auth_api +from storyboard.db.api import auth_codes as auth_api +from storyboard.db.api import refresh_tokens from storyboard.tests import base @@ -584,7 +585,7 @@ class TestOAuthAccessToken(BaseOAuthTest): # Assert that the refresh token is in the database refresh_token = \ - auth_api.refresh_token_get(token['refresh_token']) + refresh_tokens.refresh_token_get_by_token(token['refresh_token']) self.assertIsNotNone(refresh_token) # Assert that system configured values is owned by the correct user. @@ -773,7 +774,7 @@ class TestOAuthAccessToken(BaseOAuthTest): token_api.access_token_get_by_token(t1['access_token']) self.assertIsNotNone(access_token) refresh_token = \ - auth_api.refresh_token_get(t1['refresh_token']) + refresh_tokens.refresh_token_get_by_token(t1['refresh_token']) self.assertIsNotNone(refresh_token) # Issue a refresh token request. @@ -814,7 +815,7 @@ class TestOAuthAccessToken(BaseOAuthTest): # Assert that the refresh token is in the database new_refresh_token = \ - auth_api.refresh_token_get(t2['refresh_token']) + refresh_tokens.refresh_token_get_by_token(t2['refresh_token']) self.assertIsNotNone(new_refresh_token) # Assert that system configured values is owned by the correct user. @@ -829,7 +830,7 @@ class TestOAuthAccessToken(BaseOAuthTest): no_access_token = \ token_api.access_token_get_by_token(t1['access_token']) no_refresh_token = \ - auth_api.refresh_token_get(t1['refresh_token']) + refresh_tokens.refresh_token_get_by_token(t1['refresh_token']) self.assertIsNone(no_refresh_token) self.assertIsNone(no_access_token) diff --git a/storyboard/tests/api/test_user_tokens.py b/storyboard/tests/api/test_user_tokens.py index 510030fe..f2c6c891 100644 --- a/storyboard/tests/api/test_user_tokens.py +++ b/storyboard/tests/api/test_user_tokens.py @@ -28,7 +28,6 @@ class TestUserTokensAsUser(base.FunctionalTest): """ response = self.get_json(self.resource, expect_errors=True) self.assertEqual(200, response.status_code) - self.assertEqual(len(response.json), 2) def test_unauthorized_browse(self): """Assert a basic browse for someone else's tokens diff --git a/storyboard/tests/db/api/test_authorization_codes.py b/storyboard/tests/db/api/test_authorization_codes.py index f535fe2e..d6cfd163 100644 --- a/storyboard/tests/db/api/test_authorization_codes.py +++ b/storyboard/tests/db/api/test_authorization_codes.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from storyboard.db.api import auth +from storyboard.db.api import auth_codes from storyboard.db.api import users from storyboard.tests.db import base @@ -32,15 +32,15 @@ class AuthorizationCodeTest(base.BaseDbTestCase): users.user_create({"fullname": "Test User"}) def test_create_code(self): - self._test_create(self.code_01, auth.authorization_code_save) + self._test_create(self.code_01, auth_codes.authorization_code_save) def test_delete_code(self): - created_code = auth.authorization_code_save(self.code_01) + created_code = auth_codes.authorization_code_save(self.code_01) self.assertIsNotNone(created_code, "Could not create an Authorization code") - auth.authorization_code_delete(created_code.code) + auth_codes.authorization_code_delete(created_code.code) - fetched_code = auth.authorization_code_get(created_code.code) + fetched_code = auth_codes.authorization_code_get(created_code.code) self.assertIsNone(fetched_code) diff --git a/storyboard/tests/plugin/token_cleaner/test_cleaner.py b/storyboard/tests/plugin/token_cleaner/test_cleaner.py index ddaf19ed..7ccfa815 100644 --- a/storyboard/tests/plugin/token_cleaner/test_cleaner.py +++ b/storyboard/tests/plugin/token_cleaner/test_cleaner.py @@ -69,31 +69,33 @@ class TestTokenCleaner(db_base.BaseDbTestCase, # 8-day-old-token to remain valid. new_access_tokens = [] new_refresh_tokens = [] + for i in range(0, 100): created_at = datetime.now(pytz.utc) - timedelta(days=i) expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above. expires_at = created_at + timedelta(seconds=expires_in) - new_access_tokens.append( AccessToken( user_id=1, created_at=created_at, expires_in=expires_in, expires_at=expires_at, - access_token='test_token_%s' % (i,)) + access_token='test_token_%s' % (i,)), ) + new_access_tokens = load_data(new_access_tokens) for token in new_access_tokens: new_refresh_tokens.append( RefreshToken( user_id=1, - access_token_id=token.id, created_at=token.created_at, + expires_in=token.expires_in, expires_at=token.expires_at, - expires_in=300, + access_tokens=[token], refresh_token='test_refresh_%s' % (token.id,)) ) + new_refresh_tokens = load_data(new_refresh_tokens) # Make sure we have 100 tokens.