From e42c6e38bf1acffeac732fd1980fd688b865b51a Mon Sep 17 00:00:00 2001 From: Michael Krotscheck Date: Mon, 13 Apr 2015 17:43:58 -0700 Subject: [PATCH] One-To-One relationship for Access and Refresh Tokens This changes the relationship in our access and refresh token table from a many-to-many to a one-to-one. This is for two reasons: Firstly, a many-to-many means it's possible to have a refresh token issued that can refresh multiple access tokens, which actually doesn't make a lot of sense, since our logic _also_ deletes all access tokens associated with a used refresh token. Secondly, many-to-many relationships do not support cascading updates in SQLAlchemy, so our goal to manage all of our foreign relationships in SQLA is not achievable if we maintain these kinds of relationships. Change-Id: Iab3c65ac403961fccf7db15ef930f9d896626108 --- storyboard/api/auth/oauth_validator.py | 4 +- storyboard/api/v1/user_tokens.py | 17 +++-- storyboard/api/v1/wmodels.py | 4 +- storyboard/db/api/access_tokens.py | 13 +--- storyboard/db/api/refresh_tokens.py | 23 +------ storyboard/db/api/user_tokens.py | 13 +--- .../versions/048_refresh_access_token_link.py | 65 +++++++++++++++++++ storyboard/db/models.py | 23 +++---- storyboard/tests/db/api/test_access_tokens.py | 1 - .../plugin/token_cleaner/test_cleaner.py | 2 +- 10 files changed, 94 insertions(+), 71 deletions(-) create mode 100644 storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py diff --git a/storyboard/api/auth/oauth_validator.py b/storyboard/api/auth/oauth_validator.py index b3faca43..59701844 100644 --- a/storyboard/api/auth/oauth_validator.py +++ b/storyboard/api/auth/oauth_validator.py @@ -226,14 +226,14 @@ class SkeletonValidator(RequestValidator): refresh_expires_in = CONF.oauth.refresh_token_ttl refresh_token_values = { + "access_token_id": access_token.id, "refresh_token": token["refresh_token"], "user_id": user_id, "expires_in": refresh_expires_in, "expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta( seconds=refresh_expires_in), } - refresh_token_api.refresh_token_create(access_token.id, - refresh_token_values) + refresh_token_api.refresh_token_create(refresh_token_values) def invalidate_authorization_code(self, client_id, code, request, *args, **kwargs): diff --git a/storyboard/api/v1/user_tokens.py b/storyboard/api/v1/user_tokens.py index d5603dcc..e7a29bd0 100644 --- a/storyboard/api/v1/user_tokens.py +++ b/storyboard/api/v1/user_tokens.py @@ -44,12 +44,11 @@ class UserTokensController(rest.RestController): def _from_db_model(self, access_token): access_token_model = wmodels.AccessToken.from_db_model( access_token, - skip_fields="refresh_tokens") + skip_fields="refresh_token") - access_token_model.refresh_tokens = [ - wmodels.RefreshToken.from_db_model(token) - for token in access_token.refresh_tokens - ] + if access_token.refresh_token: + access_token_model.refresh_token = wmodels.RefreshToken \ + .from_db_model(access_token.refresh_token) return access_token_model @@ -142,8 +141,8 @@ class UserTokensController(rest.RestController): token_dict = body.as_dict() - if "refresh_tokens" in token_dict: - del token_dict["refresh_tokens"] + if "refresh_token" in token_dict: + del token_dict["refresh_token"] token = token_api.user_token_create(token_dict) @@ -178,8 +177,8 @@ class UserTokensController(rest.RestController): token_dict = target_token.as_dict() - if "refresh_tokens" in token_dict: - del token_dict["refresh_tokens"] + if "refresh_token" in token_dict: + del token_dict["refresh_token"] result_token = token_api.user_token_update(access_token_id, token_dict) diff --git a/storyboard/api/v1/wmodels.py b/storyboard/api/v1/wmodels.py index 7b895d5f..522c9757 100644 --- a/storyboard/api/v1/wmodels.py +++ b/storyboard/api/v1/wmodels.py @@ -431,8 +431,8 @@ class AccessToken(base.APIBase): expires_in = int """The number of seconds after creation when this token expires.""" - refresh_tokens = wtypes.ArrayType(RefreshToken) - """Array of corresponding refresh tokens.""" + refresh_token = RefreshToken + """The associated refresh token.""" @classmethod def sample(cls): diff --git a/storyboard/db/api/access_tokens.py b/storyboard/db/api/access_tokens.py index ff173a1b..ef2d373b 100644 --- a/storyboard/db/api/access_tokens.py +++ b/storyboard/db/api/access_tokens.py @@ -111,18 +111,7 @@ def access_token_delete(access_token_id): access_token = access_token_get(access_token_id, session=session) if access_token: - query = api_base.model_query(models.RefreshToken) - refresh_tokens = [] - - for refresh_token in access_token.refresh_tokens: - refresh_tokens.append(refresh_token.id) - - query = query.filter(models.RefreshToken.id.in_( - refresh_tokens - )) - - api_base.entity_hard_delete(models.AccessToken, access_token_id) - query.delete(synchronize_session=False) + session.delete(access_token) def access_token_delete_by_token(access_token): diff --git a/storyboard/db/api/refresh_tokens.py b/storyboard/db/api/refresh_tokens.py index 54ba7410..01375d22 100644 --- a/storyboard/db/api/refresh_tokens.py +++ b/storyboard/db/api/refresh_tokens.py @@ -16,11 +16,8 @@ import datetime import pytz -from storyboard.common import exception as exc -from storyboard.db.api import access_tokens as access_tokens_api from storyboard.db.api import base as api_base from storyboard.db import models -from storyboard.openstack.common.gettextutils import _ # noqa def refresh_token_get(refresh_token_id, session=None): @@ -59,26 +56,18 @@ def get_access_token_id(refresh_token_id): refresh_token = refresh_token_get(refresh_token_id, session) if refresh_token: - return refresh_token.access_tokens[0].id + return refresh_token.access_token.id -def refresh_token_create(access_token_id, values): +def refresh_token_create(values): session = api_base.get_session() with session.begin(subtransactions=True): values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\ timedelta(seconds=values['expires_in']) - access_token = access_tokens_api.access_token_get(access_token_id, - session=session) - - if not access_token: - raise exc.NotFound(_("Access token not found.")) refresh_token = api_base.entity_create(models.RefreshToken, values) - access_token.refresh_tokens.append(refresh_token) - session.add(access_token) - return refresh_token @@ -101,13 +90,7 @@ def refresh_token_delete(refresh_token_id): refresh_token = refresh_token_get(refresh_token_id) if refresh_token: - access_token_id = refresh_token.access_tokens[0].id - access_token = access_tokens_api.access_token_get(access_token_id, - session=session) - - access_token.refresh_tokens.remove(refresh_token) - session.add(access_token) - api_base.entity_hard_delete(models.RefreshToken, refresh_token_id) + session.delete(refresh_token) def refresh_token_delete_by_token(refresh_token): diff --git a/storyboard/db/api/user_tokens.py b/storyboard/db/api/user_tokens.py index af7b5cc7..b5f9fd87 100644 --- a/storyboard/db/api/user_tokens.py +++ b/storyboard/db/api/user_tokens.py @@ -13,20 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy.orm import subqueryload - from storyboard.db.api import access_tokens as access_tokens_api from storyboard.db.api import base as api_base from storyboard.db import models def user_token_get(access_token_id): - token_query = api_base.model_query(models.AccessToken) - user_token = token_query.options( - subqueryload(models.AccessToken.refresh_tokens) - ).filter_by(id=access_token_id).first() - - return user_token + return api_base.model_query(models.AccessToken) \ + .filter_by(id=access_token_id).first() def user_token_get_all(marker=None, limit=None, sort_field=None, @@ -36,8 +30,7 @@ def user_token_get_all(marker=None, limit=None, sort_field=None, if not sort_dir: sort_dir = 'asc' - query = api_base.model_query(models.AccessToken). \ - options(subqueryload(models.AccessToken.refresh_tokens)) + query = api_base.model_query(models.AccessToken) query = api_base.apply_query_filters(query=query, model=models.AccessToken, diff --git a/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py b/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py new file mode 100644 index 00000000..dbeb0efc --- /dev/null +++ b/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py @@ -0,0 +1,65 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +"""This migration converts our many-to-many mapping among auth tokens into +a one-to-one relationship. + +Revision ID: 048 +Revises: 047 +Create Date: 2015-04-12 18:03:23 + +""" + +# revision identifiers, used by Alembic. + +revision = '048' +down_revision = '047' + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import table + +MYSQL_ENGINE = 'InnoDB' +MYSQL_CHARSET = 'utf8' + + +def upgrade(active_plugins=None, options=None): + op.add_column('refreshtokens', sa.Column('access_token_id', + sa.Integer(), + nullable=False)) + op.drop_table('access_refresh_tokens') + + # Delete all refresh and access tokens, as the relationship is no longer + # valid. + bind = op.get_bind() + + refresh_table = table( + 'refreshtokens' + ) + access_table = table( + 'accesstokens' + ) + + bind.execute(refresh_table.delete()) + bind.execute(access_table.delete()) + + +def downgrade(active_plugins=None, options=None): + op.create_table('access_refresh_tokens', + sa.Column('access_token_id', sa.Integer(), nullable=False), + sa.Column('refresh_token_id', sa.Integer(), + nullable=False), + mysql_engine=MYSQL_ENGINE, + mysql_charset=MYSQL_CHARSET + ) + op.drop_column(u'refreshtokens', u'access_token_id') diff --git a/storyboard/db/models.py b/storyboard/db/models.py index b5b028bf..2546aa11 100644 --- a/storyboard/db/models.py +++ b/storyboard/db/models.py @@ -393,30 +393,25 @@ class AuthorizationCode(ModelBuilder, Base): expires_in = Column(Integer, nullable=False, default=300) -access_refresh_tokens = Table( - 'access_refresh_tokens', Base.metadata, - Column('access_token_id', Integer, ForeignKey('accesstokens.id')), - Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')), - schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'), -) - - class AccessToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) access_token = Column(Unicode(CommonLength.top_middle_length), nullable=False) expires_in = Column(Integer, nullable=False) expires_at = Column(UTCDateTime, nullable=False) - refresh_tokens = relationship("RefreshToken", - secondary='access_refresh_tokens', - ) + refresh_token = relationship("RefreshToken", + uselist=False, + cascade="all, delete-orphan", + backref="access_token", + passive_updates=False, + passive_deletes=False) class RefreshToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) - access_tokens = relationship("AccessToken", - secondary='access_refresh_tokens', - ) + access_token_id = Column(Integer, + ForeignKey('accesstokens.id'), + nullable=False) refresh_token = Column(Unicode(CommonLength.top_middle_length), nullable=False) expires_in = Column(Integer, nullable=False) diff --git a/storyboard/tests/db/api/test_access_tokens.py b/storyboard/tests/db/api/test_access_tokens.py index 1b7c5105..705c06a0 100644 --- a/storyboard/tests/db/api/test_access_tokens.py +++ b/storyboard/tests/db/api/test_access_tokens.py @@ -28,7 +28,6 @@ class TokenTest(base.BaseDbTestCase): self.token_01 = { "access_token": u'an_access_token', - "refresh_token": u'a_refresh_token', "expires_in": 3600, "expires_at": datetime.now(pytz.utc), "user_id": 1 diff --git a/storyboard/tests/plugin/token_cleaner/test_cleaner.py b/storyboard/tests/plugin/token_cleaner/test_cleaner.py index 7ccfa815..a4357fda 100644 --- a/storyboard/tests/plugin/token_cleaner/test_cleaner.py +++ b/storyboard/tests/plugin/token_cleaner/test_cleaner.py @@ -92,7 +92,7 @@ class TestTokenCleaner(db_base.BaseDbTestCase, created_at=token.created_at, expires_in=token.expires_in, expires_at=token.expires_at, - access_tokens=[token], + access_token_id=token.id, refresh_token='test_refresh_%s' % (token.id,)) )