One-To-One relationship for Access and Refresh Tokens

This changes the relationship in our access and refresh token table
from a many-to-many to a one-to-one. This is for two reasons:
Firstly, a many-to-many means it's possible to have a refresh token
issued that can refresh multiple access tokens, which actually
doesn't make a lot of sense, since our logic _also_ deletes all
access tokens associated with a used refresh token.

Secondly, many-to-many relationships do not support cascading updates
in SQLAlchemy, so our goal to manage all of our foreign relationships
in SQLA is not achievable if we maintain these kinds of relationships.

Change-Id: Iab3c65ac403961fccf7db15ef930f9d896626108
This commit is contained in:
Michael Krotscheck
2015-04-13 17:43:58 -07:00
parent 27bf9e5fd9
commit e42c6e38bf
10 changed files with 94 additions and 71 deletions

View File

@@ -226,14 +226,14 @@ class SkeletonValidator(RequestValidator):
refresh_expires_in = CONF.oauth.refresh_token_ttl refresh_expires_in = CONF.oauth.refresh_token_ttl
refresh_token_values = { refresh_token_values = {
"access_token_id": access_token.id,
"refresh_token": token["refresh_token"], "refresh_token": token["refresh_token"],
"user_id": user_id, "user_id": user_id,
"expires_in": refresh_expires_in, "expires_in": refresh_expires_in,
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta( "expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
seconds=refresh_expires_in), seconds=refresh_expires_in),
} }
refresh_token_api.refresh_token_create(access_token.id, refresh_token_api.refresh_token_create(refresh_token_values)
refresh_token_values)
def invalidate_authorization_code(self, client_id, code, request, *args, def invalidate_authorization_code(self, client_id, code, request, *args,
**kwargs): **kwargs):

View File

@@ -44,12 +44,11 @@ class UserTokensController(rest.RestController):
def _from_db_model(self, access_token): def _from_db_model(self, access_token):
access_token_model = wmodels.AccessToken.from_db_model( access_token_model = wmodels.AccessToken.from_db_model(
access_token, access_token,
skip_fields="refresh_tokens") skip_fields="refresh_token")
access_token_model.refresh_tokens = [ if access_token.refresh_token:
wmodels.RefreshToken.from_db_model(token) access_token_model.refresh_token = wmodels.RefreshToken \
for token in access_token.refresh_tokens .from_db_model(access_token.refresh_token)
]
return access_token_model return access_token_model
@@ -142,8 +141,8 @@ class UserTokensController(rest.RestController):
token_dict = body.as_dict() token_dict = body.as_dict()
if "refresh_tokens" in token_dict: if "refresh_token" in token_dict:
del token_dict["refresh_tokens"] del token_dict["refresh_token"]
token = token_api.user_token_create(token_dict) token = token_api.user_token_create(token_dict)
@@ -178,8 +177,8 @@ class UserTokensController(rest.RestController):
token_dict = target_token.as_dict() token_dict = target_token.as_dict()
if "refresh_tokens" in token_dict: if "refresh_token" in token_dict:
del token_dict["refresh_tokens"] del token_dict["refresh_token"]
result_token = token_api.user_token_update(access_token_id, result_token = token_api.user_token_update(access_token_id,
token_dict) token_dict)

View File

@@ -431,8 +431,8 @@ class AccessToken(base.APIBase):
expires_in = int expires_in = int
"""The number of seconds after creation when this token expires.""" """The number of seconds after creation when this token expires."""
refresh_tokens = wtypes.ArrayType(RefreshToken) refresh_token = RefreshToken
"""Array of corresponding refresh tokens.""" """The associated refresh token."""
@classmethod @classmethod
def sample(cls): def sample(cls):

View File

@@ -111,18 +111,7 @@ def access_token_delete(access_token_id):
access_token = access_token_get(access_token_id, session=session) access_token = access_token_get(access_token_id, session=session)
if access_token: if access_token:
query = api_base.model_query(models.RefreshToken) session.delete(access_token)
refresh_tokens = []
for refresh_token in access_token.refresh_tokens:
refresh_tokens.append(refresh_token.id)
query = query.filter(models.RefreshToken.id.in_(
refresh_tokens
))
api_base.entity_hard_delete(models.AccessToken, access_token_id)
query.delete(synchronize_session=False)
def access_token_delete_by_token(access_token): def access_token_delete_by_token(access_token):

View File

@@ -16,11 +16,8 @@
import datetime import datetime
import pytz import pytz
from storyboard.common import exception as exc
from storyboard.db.api import access_tokens as access_tokens_api
from storyboard.db.api import base as api_base from storyboard.db.api import base as api_base
from storyboard.db import models from storyboard.db import models
from storyboard.openstack.common.gettextutils import _ # noqa
def refresh_token_get(refresh_token_id, session=None): def refresh_token_get(refresh_token_id, session=None):
@@ -59,26 +56,18 @@ def get_access_token_id(refresh_token_id):
refresh_token = refresh_token_get(refresh_token_id, session) refresh_token = refresh_token_get(refresh_token_id, session)
if refresh_token: if refresh_token:
return refresh_token.access_tokens[0].id return refresh_token.access_token.id
def refresh_token_create(access_token_id, values): def refresh_token_create(values):
session = api_base.get_session() session = api_base.get_session()
with session.begin(subtransactions=True): with session.begin(subtransactions=True):
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\ values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
timedelta(seconds=values['expires_in']) timedelta(seconds=values['expires_in'])
access_token = access_tokens_api.access_token_get(access_token_id,
session=session)
if not access_token:
raise exc.NotFound(_("Access token not found."))
refresh_token = api_base.entity_create(models.RefreshToken, values) refresh_token = api_base.entity_create(models.RefreshToken, values)
access_token.refresh_tokens.append(refresh_token)
session.add(access_token)
return refresh_token return refresh_token
@@ -101,13 +90,7 @@ def refresh_token_delete(refresh_token_id):
refresh_token = refresh_token_get(refresh_token_id) refresh_token = refresh_token_get(refresh_token_id)
if refresh_token: if refresh_token:
access_token_id = refresh_token.access_tokens[0].id session.delete(refresh_token)
access_token = access_tokens_api.access_token_get(access_token_id,
session=session)
access_token.refresh_tokens.remove(refresh_token)
session.add(access_token)
api_base.entity_hard_delete(models.RefreshToken, refresh_token_id)
def refresh_token_delete_by_token(refresh_token): def refresh_token_delete_by_token(refresh_token):

View File

@@ -13,20 +13,14 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from sqlalchemy.orm import subqueryload
from storyboard.db.api import access_tokens as access_tokens_api from storyboard.db.api import access_tokens as access_tokens_api
from storyboard.db.api import base as api_base from storyboard.db.api import base as api_base
from storyboard.db import models from storyboard.db import models
def user_token_get(access_token_id): def user_token_get(access_token_id):
token_query = api_base.model_query(models.AccessToken) return api_base.model_query(models.AccessToken) \
user_token = token_query.options( .filter_by(id=access_token_id).first()
subqueryload(models.AccessToken.refresh_tokens)
).filter_by(id=access_token_id).first()
return user_token
def user_token_get_all(marker=None, limit=None, sort_field=None, def user_token_get_all(marker=None, limit=None, sort_field=None,
@@ -36,8 +30,7 @@ def user_token_get_all(marker=None, limit=None, sort_field=None,
if not sort_dir: if not sort_dir:
sort_dir = 'asc' sort_dir = 'asc'
query = api_base.model_query(models.AccessToken). \ query = api_base.model_query(models.AccessToken)
options(subqueryload(models.AccessToken.refresh_tokens))
query = api_base.apply_query_filters(query=query, query = api_base.apply_query_filters(query=query,
model=models.AccessToken, model=models.AccessToken,

View File

@@ -0,0 +1,65 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
"""This migration converts our many-to-many mapping among auth tokens into
a one-to-one relationship.
Revision ID: 048
Revises: 047
Create Date: 2015-04-12 18:03:23
"""
# revision identifiers, used by Alembic.
revision = '048'
down_revision = '047'
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql.expression import table
MYSQL_ENGINE = 'InnoDB'
MYSQL_CHARSET = 'utf8'
def upgrade(active_plugins=None, options=None):
op.add_column('refreshtokens', sa.Column('access_token_id',
sa.Integer(),
nullable=False))
op.drop_table('access_refresh_tokens')
# Delete all refresh and access tokens, as the relationship is no longer
# valid.
bind = op.get_bind()
refresh_table = table(
'refreshtokens'
)
access_table = table(
'accesstokens'
)
bind.execute(refresh_table.delete())
bind.execute(access_table.delete())
def downgrade(active_plugins=None, options=None):
op.create_table('access_refresh_tokens',
sa.Column('access_token_id', sa.Integer(), nullable=False),
sa.Column('refresh_token_id', sa.Integer(),
nullable=False),
mysql_engine=MYSQL_ENGINE,
mysql_charset=MYSQL_CHARSET
)
op.drop_column(u'refreshtokens', u'access_token_id')

View File

@@ -393,30 +393,25 @@ class AuthorizationCode(ModelBuilder, Base):
expires_in = Column(Integer, nullable=False, default=300) expires_in = Column(Integer, nullable=False, default=300)
access_refresh_tokens = Table(
'access_refresh_tokens', Base.metadata,
Column('access_token_id', Integer, ForeignKey('accesstokens.id')),
Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')),
schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'),
)
class AccessToken(ModelBuilder, Base): class AccessToken(ModelBuilder, Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
access_token = Column(Unicode(CommonLength.top_middle_length), access_token = Column(Unicode(CommonLength.top_middle_length),
nullable=False) nullable=False)
expires_in = Column(Integer, nullable=False) expires_in = Column(Integer, nullable=False)
expires_at = Column(UTCDateTime, nullable=False) expires_at = Column(UTCDateTime, nullable=False)
refresh_tokens = relationship("RefreshToken", refresh_token = relationship("RefreshToken",
secondary='access_refresh_tokens', uselist=False,
) cascade="all, delete-orphan",
backref="access_token",
passive_updates=False,
passive_deletes=False)
class RefreshToken(ModelBuilder, Base): class RefreshToken(ModelBuilder, Base):
user_id = Column(Integer, ForeignKey('users.id'), nullable=False) user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
access_tokens = relationship("AccessToken", access_token_id = Column(Integer,
secondary='access_refresh_tokens', ForeignKey('accesstokens.id'),
) nullable=False)
refresh_token = Column(Unicode(CommonLength.top_middle_length), refresh_token = Column(Unicode(CommonLength.top_middle_length),
nullable=False) nullable=False)
expires_in = Column(Integer, nullable=False) expires_in = Column(Integer, nullable=False)

View File

@@ -28,7 +28,6 @@ class TokenTest(base.BaseDbTestCase):
self.token_01 = { self.token_01 = {
"access_token": u'an_access_token', "access_token": u'an_access_token',
"refresh_token": u'a_refresh_token',
"expires_in": 3600, "expires_in": 3600,
"expires_at": datetime.now(pytz.utc), "expires_at": datetime.now(pytz.utc),
"user_id": 1 "user_id": 1

View File

@@ -92,7 +92,7 @@ class TestTokenCleaner(db_base.BaseDbTestCase,
created_at=token.created_at, created_at=token.created_at,
expires_in=token.expires_in, expires_in=token.expires_in,
expires_at=token.expires_at, expires_at=token.expires_at,
access_tokens=[token], access_token_id=token.id,
refresh_token='test_refresh_%s' % (token.id,)) refresh_token='test_refresh_%s' % (token.id,))
) )