One-To-One relationship for Access and Refresh Tokens
This changes the relationship in our access and refresh token table from a many-to-many to a one-to-one. This is for two reasons: Firstly, a many-to-many means it's possible to have a refresh token issued that can refresh multiple access tokens, which actually doesn't make a lot of sense, since our logic _also_ deletes all access tokens associated with a used refresh token. Secondly, many-to-many relationships do not support cascading updates in SQLAlchemy, so our goal to manage all of our foreign relationships in SQLA is not achievable if we maintain these kinds of relationships. Change-Id: Iab3c65ac403961fccf7db15ef930f9d896626108
This commit is contained in:
@@ -226,14 +226,14 @@ class SkeletonValidator(RequestValidator):
|
|||||||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||||
|
|
||||||
refresh_token_values = {
|
refresh_token_values = {
|
||||||
|
"access_token_id": access_token.id,
|
||||||
"refresh_token": token["refresh_token"],
|
"refresh_token": token["refresh_token"],
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"expires_in": refresh_expires_in,
|
"expires_in": refresh_expires_in,
|
||||||
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
|
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
|
||||||
seconds=refresh_expires_in),
|
seconds=refresh_expires_in),
|
||||||
}
|
}
|
||||||
refresh_token_api.refresh_token_create(access_token.id,
|
refresh_token_api.refresh_token_create(refresh_token_values)
|
||||||
refresh_token_values)
|
|
||||||
|
|
||||||
def invalidate_authorization_code(self, client_id, code, request, *args,
|
def invalidate_authorization_code(self, client_id, code, request, *args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
@@ -44,12 +44,11 @@ class UserTokensController(rest.RestController):
|
|||||||
def _from_db_model(self, access_token):
|
def _from_db_model(self, access_token):
|
||||||
access_token_model = wmodels.AccessToken.from_db_model(
|
access_token_model = wmodels.AccessToken.from_db_model(
|
||||||
access_token,
|
access_token,
|
||||||
skip_fields="refresh_tokens")
|
skip_fields="refresh_token")
|
||||||
|
|
||||||
access_token_model.refresh_tokens = [
|
if access_token.refresh_token:
|
||||||
wmodels.RefreshToken.from_db_model(token)
|
access_token_model.refresh_token = wmodels.RefreshToken \
|
||||||
for token in access_token.refresh_tokens
|
.from_db_model(access_token.refresh_token)
|
||||||
]
|
|
||||||
|
|
||||||
return access_token_model
|
return access_token_model
|
||||||
|
|
||||||
@@ -142,8 +141,8 @@ class UserTokensController(rest.RestController):
|
|||||||
|
|
||||||
token_dict = body.as_dict()
|
token_dict = body.as_dict()
|
||||||
|
|
||||||
if "refresh_tokens" in token_dict:
|
if "refresh_token" in token_dict:
|
||||||
del token_dict["refresh_tokens"]
|
del token_dict["refresh_token"]
|
||||||
|
|
||||||
token = token_api.user_token_create(token_dict)
|
token = token_api.user_token_create(token_dict)
|
||||||
|
|
||||||
@@ -178,8 +177,8 @@ class UserTokensController(rest.RestController):
|
|||||||
|
|
||||||
token_dict = target_token.as_dict()
|
token_dict = target_token.as_dict()
|
||||||
|
|
||||||
if "refresh_tokens" in token_dict:
|
if "refresh_token" in token_dict:
|
||||||
del token_dict["refresh_tokens"]
|
del token_dict["refresh_token"]
|
||||||
|
|
||||||
result_token = token_api.user_token_update(access_token_id,
|
result_token = token_api.user_token_update(access_token_id,
|
||||||
token_dict)
|
token_dict)
|
||||||
|
@@ -431,8 +431,8 @@ class AccessToken(base.APIBase):
|
|||||||
expires_in = int
|
expires_in = int
|
||||||
"""The number of seconds after creation when this token expires."""
|
"""The number of seconds after creation when this token expires."""
|
||||||
|
|
||||||
refresh_tokens = wtypes.ArrayType(RefreshToken)
|
refresh_token = RefreshToken
|
||||||
"""Array of corresponding refresh tokens."""
|
"""The associated refresh token."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample(cls):
|
def sample(cls):
|
||||||
|
@@ -111,18 +111,7 @@ def access_token_delete(access_token_id):
|
|||||||
access_token = access_token_get(access_token_id, session=session)
|
access_token = access_token_get(access_token_id, session=session)
|
||||||
|
|
||||||
if access_token:
|
if access_token:
|
||||||
query = api_base.model_query(models.RefreshToken)
|
session.delete(access_token)
|
||||||
refresh_tokens = []
|
|
||||||
|
|
||||||
for refresh_token in access_token.refresh_tokens:
|
|
||||||
refresh_tokens.append(refresh_token.id)
|
|
||||||
|
|
||||||
query = query.filter(models.RefreshToken.id.in_(
|
|
||||||
refresh_tokens
|
|
||||||
))
|
|
||||||
|
|
||||||
api_base.entity_hard_delete(models.AccessToken, access_token_id)
|
|
||||||
query.delete(synchronize_session=False)
|
|
||||||
|
|
||||||
|
|
||||||
def access_token_delete_by_token(access_token):
|
def access_token_delete_by_token(access_token):
|
||||||
|
@@ -16,11 +16,8 @@
|
|||||||
import datetime
|
import datetime
|
||||||
import pytz
|
import pytz
|
||||||
|
|
||||||
from storyboard.common import exception as exc
|
|
||||||
from storyboard.db.api import access_tokens as access_tokens_api
|
|
||||||
from storyboard.db.api import base as api_base
|
from storyboard.db.api import base as api_base
|
||||||
from storyboard.db import models
|
from storyboard.db import models
|
||||||
from storyboard.openstack.common.gettextutils import _ # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_get(refresh_token_id, session=None):
|
def refresh_token_get(refresh_token_id, session=None):
|
||||||
@@ -59,26 +56,18 @@ def get_access_token_id(refresh_token_id):
|
|||||||
refresh_token = refresh_token_get(refresh_token_id, session)
|
refresh_token = refresh_token_get(refresh_token_id, session)
|
||||||
|
|
||||||
if refresh_token:
|
if refresh_token:
|
||||||
return refresh_token.access_tokens[0].id
|
return refresh_token.access_token.id
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_create(access_token_id, values):
|
def refresh_token_create(values):
|
||||||
session = api_base.get_session()
|
session = api_base.get_session()
|
||||||
|
|
||||||
with session.begin(subtransactions=True):
|
with session.begin(subtransactions=True):
|
||||||
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
|
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
|
||||||
timedelta(seconds=values['expires_in'])
|
timedelta(seconds=values['expires_in'])
|
||||||
access_token = access_tokens_api.access_token_get(access_token_id,
|
|
||||||
session=session)
|
|
||||||
|
|
||||||
if not access_token:
|
|
||||||
raise exc.NotFound(_("Access token not found."))
|
|
||||||
|
|
||||||
refresh_token = api_base.entity_create(models.RefreshToken, values)
|
refresh_token = api_base.entity_create(models.RefreshToken, values)
|
||||||
|
|
||||||
access_token.refresh_tokens.append(refresh_token)
|
|
||||||
session.add(access_token)
|
|
||||||
|
|
||||||
return refresh_token
|
return refresh_token
|
||||||
|
|
||||||
|
|
||||||
@@ -101,13 +90,7 @@ def refresh_token_delete(refresh_token_id):
|
|||||||
refresh_token = refresh_token_get(refresh_token_id)
|
refresh_token = refresh_token_get(refresh_token_id)
|
||||||
|
|
||||||
if refresh_token:
|
if refresh_token:
|
||||||
access_token_id = refresh_token.access_tokens[0].id
|
session.delete(refresh_token)
|
||||||
access_token = access_tokens_api.access_token_get(access_token_id,
|
|
||||||
session=session)
|
|
||||||
|
|
||||||
access_token.refresh_tokens.remove(refresh_token)
|
|
||||||
session.add(access_token)
|
|
||||||
api_base.entity_hard_delete(models.RefreshToken, refresh_token_id)
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_delete_by_token(refresh_token):
|
def refresh_token_delete_by_token(refresh_token):
|
||||||
|
@@ -13,20 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from sqlalchemy.orm import subqueryload
|
|
||||||
|
|
||||||
from storyboard.db.api import access_tokens as access_tokens_api
|
from storyboard.db.api import access_tokens as access_tokens_api
|
||||||
from storyboard.db.api import base as api_base
|
from storyboard.db.api import base as api_base
|
||||||
from storyboard.db import models
|
from storyboard.db import models
|
||||||
|
|
||||||
|
|
||||||
def user_token_get(access_token_id):
|
def user_token_get(access_token_id):
|
||||||
token_query = api_base.model_query(models.AccessToken)
|
return api_base.model_query(models.AccessToken) \
|
||||||
user_token = token_query.options(
|
.filter_by(id=access_token_id).first()
|
||||||
subqueryload(models.AccessToken.refresh_tokens)
|
|
||||||
).filter_by(id=access_token_id).first()
|
|
||||||
|
|
||||||
return user_token
|
|
||||||
|
|
||||||
|
|
||||||
def user_token_get_all(marker=None, limit=None, sort_field=None,
|
def user_token_get_all(marker=None, limit=None, sort_field=None,
|
||||||
@@ -36,8 +30,7 @@ def user_token_get_all(marker=None, limit=None, sort_field=None,
|
|||||||
if not sort_dir:
|
if not sort_dir:
|
||||||
sort_dir = 'asc'
|
sort_dir = 'asc'
|
||||||
|
|
||||||
query = api_base.model_query(models.AccessToken). \
|
query = api_base.model_query(models.AccessToken)
|
||||||
options(subqueryload(models.AccessToken.refresh_tokens))
|
|
||||||
|
|
||||||
query = api_base.apply_query_filters(query=query,
|
query = api_base.apply_query_filters(query=query,
|
||||||
model=models.AccessToken,
|
model=models.AccessToken,
|
||||||
|
@@ -0,0 +1,65 @@
|
|||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""This migration converts our many-to-many mapping among auth tokens into
|
||||||
|
a one-to-one relationship.
|
||||||
|
|
||||||
|
Revision ID: 048
|
||||||
|
Revises: 047
|
||||||
|
Create Date: 2015-04-12 18:03:23
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
|
||||||
|
revision = '048'
|
||||||
|
down_revision = '047'
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.sql.expression import table
|
||||||
|
|
||||||
|
MYSQL_ENGINE = 'InnoDB'
|
||||||
|
MYSQL_CHARSET = 'utf8'
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade(active_plugins=None, options=None):
|
||||||
|
op.add_column('refreshtokens', sa.Column('access_token_id',
|
||||||
|
sa.Integer(),
|
||||||
|
nullable=False))
|
||||||
|
op.drop_table('access_refresh_tokens')
|
||||||
|
|
||||||
|
# Delete all refresh and access tokens, as the relationship is no longer
|
||||||
|
# valid.
|
||||||
|
bind = op.get_bind()
|
||||||
|
|
||||||
|
refresh_table = table(
|
||||||
|
'refreshtokens'
|
||||||
|
)
|
||||||
|
access_table = table(
|
||||||
|
'accesstokens'
|
||||||
|
)
|
||||||
|
|
||||||
|
bind.execute(refresh_table.delete())
|
||||||
|
bind.execute(access_table.delete())
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade(active_plugins=None, options=None):
|
||||||
|
op.create_table('access_refresh_tokens',
|
||||||
|
sa.Column('access_token_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('refresh_token_id', sa.Integer(),
|
||||||
|
nullable=False),
|
||||||
|
mysql_engine=MYSQL_ENGINE,
|
||||||
|
mysql_charset=MYSQL_CHARSET
|
||||||
|
)
|
||||||
|
op.drop_column(u'refreshtokens', u'access_token_id')
|
@@ -393,30 +393,25 @@ class AuthorizationCode(ModelBuilder, Base):
|
|||||||
expires_in = Column(Integer, nullable=False, default=300)
|
expires_in = Column(Integer, nullable=False, default=300)
|
||||||
|
|
||||||
|
|
||||||
access_refresh_tokens = Table(
|
|
||||||
'access_refresh_tokens', Base.metadata,
|
|
||||||
Column('access_token_id', Integer, ForeignKey('accesstokens.id')),
|
|
||||||
Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')),
|
|
||||||
schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(ModelBuilder, Base):
|
class AccessToken(ModelBuilder, Base):
|
||||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||||
access_token = Column(Unicode(CommonLength.top_middle_length),
|
access_token = Column(Unicode(CommonLength.top_middle_length),
|
||||||
nullable=False)
|
nullable=False)
|
||||||
expires_in = Column(Integer, nullable=False)
|
expires_in = Column(Integer, nullable=False)
|
||||||
expires_at = Column(UTCDateTime, nullable=False)
|
expires_at = Column(UTCDateTime, nullable=False)
|
||||||
refresh_tokens = relationship("RefreshToken",
|
refresh_token = relationship("RefreshToken",
|
||||||
secondary='access_refresh_tokens',
|
uselist=False,
|
||||||
)
|
cascade="all, delete-orphan",
|
||||||
|
backref="access_token",
|
||||||
|
passive_updates=False,
|
||||||
|
passive_deletes=False)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(ModelBuilder, Base):
|
class RefreshToken(ModelBuilder, Base):
|
||||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||||
access_tokens = relationship("AccessToken",
|
access_token_id = Column(Integer,
|
||||||
secondary='access_refresh_tokens',
|
ForeignKey('accesstokens.id'),
|
||||||
)
|
nullable=False)
|
||||||
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
||||||
nullable=False)
|
nullable=False)
|
||||||
expires_in = Column(Integer, nullable=False)
|
expires_in = Column(Integer, nullable=False)
|
||||||
|
@@ -28,7 +28,6 @@ class TokenTest(base.BaseDbTestCase):
|
|||||||
|
|
||||||
self.token_01 = {
|
self.token_01 = {
|
||||||
"access_token": u'an_access_token',
|
"access_token": u'an_access_token',
|
||||||
"refresh_token": u'a_refresh_token',
|
|
||||||
"expires_in": 3600,
|
"expires_in": 3600,
|
||||||
"expires_at": datetime.now(pytz.utc),
|
"expires_at": datetime.now(pytz.utc),
|
||||||
"user_id": 1
|
"user_id": 1
|
||||||
|
@@ -92,7 +92,7 @@ class TestTokenCleaner(db_base.BaseDbTestCase,
|
|||||||
created_at=token.created_at,
|
created_at=token.created_at,
|
||||||
expires_in=token.expires_in,
|
expires_in=token.expires_in,
|
||||||
expires_at=token.expires_at,
|
expires_at=token.expires_at,
|
||||||
access_tokens=[token],
|
access_token_id=token.id,
|
||||||
refresh_token='test_refresh_%s' % (token.id,))
|
refresh_token='test_refresh_%s' % (token.id,))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user