diff --git a/storyboard/api/auth/oauth_validator.py b/storyboard/api/auth/oauth_validator.py index b3faca43..59701844 100644 --- a/storyboard/api/auth/oauth_validator.py +++ b/storyboard/api/auth/oauth_validator.py @@ -226,14 +226,14 @@ class SkeletonValidator(RequestValidator): refresh_expires_in = CONF.oauth.refresh_token_ttl refresh_token_values = { + "access_token_id": access_token.id, "refresh_token": token["refresh_token"], "user_id": user_id, "expires_in": refresh_expires_in, "expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta( seconds=refresh_expires_in), } - refresh_token_api.refresh_token_create(access_token.id, - refresh_token_values) + refresh_token_api.refresh_token_create(refresh_token_values) def invalidate_authorization_code(self, client_id, code, request, *args, **kwargs): diff --git a/storyboard/api/v1/user_tokens.py b/storyboard/api/v1/user_tokens.py index d5603dcc..e7a29bd0 100644 --- a/storyboard/api/v1/user_tokens.py +++ b/storyboard/api/v1/user_tokens.py @@ -44,12 +44,11 @@ class UserTokensController(rest.RestController): def _from_db_model(self, access_token): access_token_model = wmodels.AccessToken.from_db_model( access_token, - skip_fields="refresh_tokens") + skip_fields="refresh_token") - access_token_model.refresh_tokens = [ - wmodels.RefreshToken.from_db_model(token) - for token in access_token.refresh_tokens - ] + if access_token.refresh_token: + access_token_model.refresh_token = wmodels.RefreshToken \ + .from_db_model(access_token.refresh_token) return access_token_model @@ -142,8 +141,8 @@ class UserTokensController(rest.RestController): token_dict = body.as_dict() - if "refresh_tokens" in token_dict: - del token_dict["refresh_tokens"] + if "refresh_token" in token_dict: + del token_dict["refresh_token"] token = token_api.user_token_create(token_dict) @@ -178,8 +177,8 @@ class UserTokensController(rest.RestController): token_dict = target_token.as_dict() - if "refresh_tokens" in token_dict: - del token_dict["refresh_tokens"] + if "refresh_token" in token_dict: + del token_dict["refresh_token"] result_token = token_api.user_token_update(access_token_id, token_dict) diff --git a/storyboard/api/v1/wmodels.py b/storyboard/api/v1/wmodels.py index 7b895d5f..522c9757 100644 --- a/storyboard/api/v1/wmodels.py +++ b/storyboard/api/v1/wmodels.py @@ -431,8 +431,8 @@ class AccessToken(base.APIBase): expires_in = int """The number of seconds after creation when this token expires.""" - refresh_tokens = wtypes.ArrayType(RefreshToken) - """Array of corresponding refresh tokens.""" + refresh_token = RefreshToken + """The associated refresh token.""" @classmethod def sample(cls): diff --git a/storyboard/db/api/access_tokens.py b/storyboard/db/api/access_tokens.py index ff173a1b..ef2d373b 100644 --- a/storyboard/db/api/access_tokens.py +++ b/storyboard/db/api/access_tokens.py @@ -111,18 +111,7 @@ def access_token_delete(access_token_id): access_token = access_token_get(access_token_id, session=session) if access_token: - query = api_base.model_query(models.RefreshToken) - refresh_tokens = [] - - for refresh_token in access_token.refresh_tokens: - refresh_tokens.append(refresh_token.id) - - query = query.filter(models.RefreshToken.id.in_( - refresh_tokens - )) - - api_base.entity_hard_delete(models.AccessToken, access_token_id) - query.delete(synchronize_session=False) + session.delete(access_token) def access_token_delete_by_token(access_token): diff --git a/storyboard/db/api/refresh_tokens.py b/storyboard/db/api/refresh_tokens.py index 54ba7410..01375d22 100644 --- a/storyboard/db/api/refresh_tokens.py +++ b/storyboard/db/api/refresh_tokens.py @@ -16,11 +16,8 @@ import datetime import pytz -from storyboard.common import exception as exc -from storyboard.db.api import access_tokens as access_tokens_api from storyboard.db.api import base as api_base from storyboard.db import models -from storyboard.openstack.common.gettextutils import _ # noqa def refresh_token_get(refresh_token_id, session=None): @@ -59,26 +56,18 @@ def get_access_token_id(refresh_token_id): refresh_token = refresh_token_get(refresh_token_id, session) if refresh_token: - return refresh_token.access_tokens[0].id + return refresh_token.access_token.id -def refresh_token_create(access_token_id, values): +def refresh_token_create(values): session = api_base.get_session() with session.begin(subtransactions=True): values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\ timedelta(seconds=values['expires_in']) - access_token = access_tokens_api.access_token_get(access_token_id, - session=session) - - if not access_token: - raise exc.NotFound(_("Access token not found.")) refresh_token = api_base.entity_create(models.RefreshToken, values) - access_token.refresh_tokens.append(refresh_token) - session.add(access_token) - return refresh_token @@ -101,13 +90,7 @@ def refresh_token_delete(refresh_token_id): refresh_token = refresh_token_get(refresh_token_id) if refresh_token: - access_token_id = refresh_token.access_tokens[0].id - access_token = access_tokens_api.access_token_get(access_token_id, - session=session) - - access_token.refresh_tokens.remove(refresh_token) - session.add(access_token) - api_base.entity_hard_delete(models.RefreshToken, refresh_token_id) + session.delete(refresh_token) def refresh_token_delete_by_token(refresh_token): diff --git a/storyboard/db/api/user_tokens.py b/storyboard/db/api/user_tokens.py index af7b5cc7..b5f9fd87 100644 --- a/storyboard/db/api/user_tokens.py +++ b/storyboard/db/api/user_tokens.py @@ -13,20 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from sqlalchemy.orm import subqueryload - from storyboard.db.api import access_tokens as access_tokens_api from storyboard.db.api import base as api_base from storyboard.db import models def user_token_get(access_token_id): - token_query = api_base.model_query(models.AccessToken) - user_token = token_query.options( - subqueryload(models.AccessToken.refresh_tokens) - ).filter_by(id=access_token_id).first() - - return user_token + return api_base.model_query(models.AccessToken) \ + .filter_by(id=access_token_id).first() def user_token_get_all(marker=None, limit=None, sort_field=None, @@ -36,8 +30,7 @@ def user_token_get_all(marker=None, limit=None, sort_field=None, if not sort_dir: sort_dir = 'asc' - query = api_base.model_query(models.AccessToken). \ - options(subqueryload(models.AccessToken.refresh_tokens)) + query = api_base.model_query(models.AccessToken) query = api_base.apply_query_filters(query=query, model=models.AccessToken, diff --git a/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py b/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py new file mode 100644 index 00000000..dbeb0efc --- /dev/null +++ b/storyboard/db/migration/alembic_migrations/versions/048_refresh_access_token_link.py @@ -0,0 +1,65 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# + +"""This migration converts our many-to-many mapping among auth tokens into +a one-to-one relationship. + +Revision ID: 048 +Revises: 047 +Create Date: 2015-04-12 18:03:23 + +""" + +# revision identifiers, used by Alembic. + +revision = '048' +down_revision = '047' + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.sql.expression import table + +MYSQL_ENGINE = 'InnoDB' +MYSQL_CHARSET = 'utf8' + + +def upgrade(active_plugins=None, options=None): + op.add_column('refreshtokens', sa.Column('access_token_id', + sa.Integer(), + nullable=False)) + op.drop_table('access_refresh_tokens') + + # Delete all refresh and access tokens, as the relationship is no longer + # valid. + bind = op.get_bind() + + refresh_table = table( + 'refreshtokens' + ) + access_table = table( + 'accesstokens' + ) + + bind.execute(refresh_table.delete()) + bind.execute(access_table.delete()) + + +def downgrade(active_plugins=None, options=None): + op.create_table('access_refresh_tokens', + sa.Column('access_token_id', sa.Integer(), nullable=False), + sa.Column('refresh_token_id', sa.Integer(), + nullable=False), + mysql_engine=MYSQL_ENGINE, + mysql_charset=MYSQL_CHARSET + ) + op.drop_column(u'refreshtokens', u'access_token_id') diff --git a/storyboard/db/models.py b/storyboard/db/models.py index b5b028bf..2546aa11 100644 --- a/storyboard/db/models.py +++ b/storyboard/db/models.py @@ -393,30 +393,25 @@ class AuthorizationCode(ModelBuilder, Base): expires_in = Column(Integer, nullable=False, default=300) -access_refresh_tokens = Table( - 'access_refresh_tokens', Base.metadata, - Column('access_token_id', Integer, ForeignKey('accesstokens.id')), - Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')), - schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'), -) - - class AccessToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) access_token = Column(Unicode(CommonLength.top_middle_length), nullable=False) expires_in = Column(Integer, nullable=False) expires_at = Column(UTCDateTime, nullable=False) - refresh_tokens = relationship("RefreshToken", - secondary='access_refresh_tokens', - ) + refresh_token = relationship("RefreshToken", + uselist=False, + cascade="all, delete-orphan", + backref="access_token", + passive_updates=False, + passive_deletes=False) class RefreshToken(ModelBuilder, Base): user_id = Column(Integer, ForeignKey('users.id'), nullable=False) - access_tokens = relationship("AccessToken", - secondary='access_refresh_tokens', - ) + access_token_id = Column(Integer, + ForeignKey('accesstokens.id'), + nullable=False) refresh_token = Column(Unicode(CommonLength.top_middle_length), nullable=False) expires_in = Column(Integer, nullable=False) diff --git a/storyboard/tests/db/api/test_access_tokens.py b/storyboard/tests/db/api/test_access_tokens.py index 1b7c5105..705c06a0 100644 --- a/storyboard/tests/db/api/test_access_tokens.py +++ b/storyboard/tests/db/api/test_access_tokens.py @@ -28,7 +28,6 @@ class TokenTest(base.BaseDbTestCase): self.token_01 = { "access_token": u'an_access_token', - "refresh_token": u'a_refresh_token', "expires_in": 3600, "expires_at": datetime.now(pytz.utc), "user_id": 1 diff --git a/storyboard/tests/plugin/token_cleaner/test_cleaner.py b/storyboard/tests/plugin/token_cleaner/test_cleaner.py index 7ccfa815..a4357fda 100644 --- a/storyboard/tests/plugin/token_cleaner/test_cleaner.py +++ b/storyboard/tests/plugin/token_cleaner/test_cleaner.py @@ -92,7 +92,7 @@ class TestTokenCleaner(db_base.BaseDbTestCase, created_at=token.created_at, expires_in=token.expires_in, expires_at=token.expires_at, - access_tokens=[token], + access_token_id=token.id, refresh_token='test_refresh_%s' % (token.id,)) )