Split auth in api and change user_tokens
Now access tokens wmodel includes array of refresh tokens. To db.api added refresh_tokens and deleted corresponding fubctions from auth. To db.api added user_tokens. Added tests for this patch. Change-Id: I1a631e43831eb327a34726a1ce7c07d3a14f3aff
This commit is contained in:
parent
2769d74e67
commit
6c2ab67bfb
|
@ -22,7 +22,8 @@ from oslo.config import cfg
|
|||
from oslo_log import log
|
||||
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
from storyboard.db.api import auth as auth_api
|
||||
from storyboard.db.api import auth_codes as auth_api
|
||||
from storyboard.db.api import refresh_tokens as refresh_token_api
|
||||
from storyboard.db.api import users as user_api
|
||||
|
||||
CONF = cfg.CONF
|
||||
|
@ -197,7 +198,8 @@ class SkeletonValidator(RequestValidator):
|
|||
|
||||
# Try refresh token
|
||||
refresh_token = request._params.get("refresh_token")
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
refresh_token_entry = \
|
||||
refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||
if refresh_token_entry:
|
||||
return refresh_token_entry.user_id
|
||||
|
||||
|
@ -226,14 +228,14 @@ class SkeletonValidator(RequestValidator):
|
|||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||
|
||||
refresh_token_values = {
|
||||
"access_token_id": access_token.id,
|
||||
"refresh_token": token["refresh_token"],
|
||||
"user_id": user_id,
|
||||
"expires_in": refresh_expires_in,
|
||||
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
|
||||
seconds=refresh_expires_in),
|
||||
}
|
||||
auth_api.refresh_token_save(refresh_token_values)
|
||||
refresh_token_api.refresh_token_create(access_token.id,
|
||||
refresh_token_values)
|
||||
|
||||
def invalidate_authorization_code(self, client_id, code, request, *args,
|
||||
**kwargs):
|
||||
|
@ -268,17 +270,7 @@ class SkeletonValidator(RequestValidator):
|
|||
def validate_refresh_token(self, refresh_token, client, request, *args,
|
||||
**kwargs):
|
||||
"""Check that the refresh token exists in the db."""
|
||||
|
||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
||||
|
||||
if not refresh_token_entry:
|
||||
return False
|
||||
|
||||
if datetime.datetime.now(pytz.utc) > refresh_token_entry.expires_at:
|
||||
auth_api.refresh_token_delete(refresh_token)
|
||||
return False
|
||||
|
||||
return True
|
||||
return refresh_token_api.is_valid(refresh_token)
|
||||
|
||||
def invalidate_refresh_token(self, request):
|
||||
"""Remove a used token from the storage."""
|
||||
|
@ -290,8 +282,10 @@ class SkeletonValidator(RequestValidator):
|
|||
if not refresh_token:
|
||||
return
|
||||
|
||||
r_token = auth_api.refresh_token_get(refresh_token)
|
||||
token_api.access_token_delete(r_token.access_token_id) # Cascades
|
||||
r_token = refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||
token_api.access_token_delete(
|
||||
refresh_token_api.get_access_token_id(r_token.id)
|
||||
) # Cascades
|
||||
|
||||
|
||||
class OpenIdConnectServer(WebApplicationServer):
|
||||
|
|
|
@ -27,7 +27,8 @@ from storyboard.api.auth.oauth_validator import SERVER
|
|||
from storyboard.api.auth.openid_client import client as openid_client
|
||||
from storyboard.common import decorators
|
||||
from storyboard.common.exception import UnsupportedGrantType
|
||||
from storyboard.db.api import auth as auth_api
|
||||
from storyboard.db.api import auth_codes as auth_api
|
||||
from storyboard.db.api import refresh_tokens as refresh_token_api
|
||||
|
||||
LOG = log.getLogger(__name__)
|
||||
|
||||
|
@ -101,7 +102,8 @@ class AuthController(rest.RestController):
|
|||
|
||||
def _access_token_by_refresh_token(self):
|
||||
refresh_token = request.params.get("refresh_token")
|
||||
refresh_token_info = auth_api.refresh_token_get(refresh_token)
|
||||
refresh_token_info = \
|
||||
refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||
|
||||
headers, body, code = SERVER.create_token_response(
|
||||
uri=request.url,
|
||||
|
|
|
@ -29,7 +29,7 @@ import wsmeext.pecan as wsme_pecan
|
|||
from storyboard.api.auth import authorization_checks as checks
|
||||
import storyboard.api.v1.wmodels as wmodels
|
||||
from storyboard.common import decorators
|
||||
import storyboard.db.api.access_tokens as token_api
|
||||
import storyboard.db.api.user_tokens as token_api
|
||||
import storyboard.db.api.users as user_api
|
||||
from storyboard.openstack.common.gettextutils import _ # noqa
|
||||
|
||||
|
@ -39,6 +39,18 @@ LOG = log.getLogger(__name__)
|
|||
|
||||
|
||||
class UserTokensController(rest.RestController):
|
||||
def _from_db_model(self, access_token):
|
||||
access_token_model = wmodels.AccessToken.from_db_model(
|
||||
access_token,
|
||||
skip_fields="refresh_tokens")
|
||||
|
||||
access_token_model.refresh_tokens = [
|
||||
wmodels.RefreshToken.from_db_model(token)
|
||||
for token in access_token.refresh_tokens
|
||||
]
|
||||
|
||||
return access_token_model
|
||||
|
||||
@decorators.db_exceptions
|
||||
@secure(checks.authenticated)
|
||||
@wsme_pecan.wsexpose([wmodels.AccessToken], int, int, int, wtypes.text,
|
||||
|
@ -62,23 +74,24 @@ class UserTokensController(rest.RestController):
|
|||
limit = min(CONF.page_size_maximum, max(1, limit))
|
||||
|
||||
# Resolve the marker record.
|
||||
marker_token = token_api.access_token_get(marker)
|
||||
marker_token = token_api.user_token_get(marker)
|
||||
|
||||
tokens = token_api.access_token_get_all(marker=marker_token,
|
||||
tokens = token_api.user_token_get_all(marker=marker_token,
|
||||
limit=limit,
|
||||
user_id=user_id,
|
||||
filter_non_public=True,
|
||||
sort_field=sort_field,
|
||||
sort_dir=sort_dir)
|
||||
token_count = token_api.access_token_get_count(user_id=user_id)
|
||||
token_count = token_api.user_token_get_count(user_id=user_id)
|
||||
|
||||
# Apply the query response headers.
|
||||
response.headers['X-Limit'] = str(limit)
|
||||
response.headers['X-Total'] = str(token_count)
|
||||
|
||||
if marker_token:
|
||||
response.headers['X-Marker'] = str(marker_token.id)
|
||||
|
||||
return [wmodels.AccessToken.from_db_model(t) for t in tokens]
|
||||
return [self._from_db_model(t) for t in tokens]
|
||||
|
||||
@decorators.db_exceptions
|
||||
@secure(checks.authenticated)
|
||||
|
@ -90,13 +103,13 @@ class UserTokensController(rest.RestController):
|
|||
:param access_token_id: The ID of the access token.
|
||||
:return: The requested access token.
|
||||
"""
|
||||
access_token = token_api.access_token_get(access_token_id)
|
||||
access_token = token_api.user_token_get(access_token_id)
|
||||
self._assert_can_access(user_id, access_token)
|
||||
|
||||
if not access_token:
|
||||
abort(404, _("Token not found."))
|
||||
abort(404, _("Token %s not found.") % access_token_id)
|
||||
|
||||
return wmodels.AccessToken.from_db_model(access_token)
|
||||
return self._from_db_model(access_token)
|
||||
|
||||
@decorators.db_exceptions
|
||||
@secure(checks.authenticated)
|
||||
|
@ -115,13 +128,24 @@ class UserTokensController(rest.RestController):
|
|||
body.access_token = six.text_type(uuid.uuid4())
|
||||
|
||||
# Token duplication check.
|
||||
dupes = token_api.access_token_get_all(access_token=body.access_token)
|
||||
dupes = token_api.user_token_get_all(
|
||||
access_token=body.access_token
|
||||
)
|
||||
|
||||
if dupes:
|
||||
abort(409, _('This token already exist.'))
|
||||
|
||||
token = token_api.access_token_create(body.as_dict())
|
||||
token_dict = body.as_dict()
|
||||
|
||||
return wmodels.AccessToken.from_db_model(token)
|
||||
if "refresh_tokens" in token_dict:
|
||||
del token_dict["refresh_tokens"]
|
||||
|
||||
token = token_api.user_token_create(token_dict)
|
||||
|
||||
if not token:
|
||||
abort(400, _("Can't create access token."))
|
||||
|
||||
return self._from_db_model(token)
|
||||
|
||||
@decorators.db_exceptions
|
||||
@secure(checks.authenticated)
|
||||
|
@ -135,21 +159,27 @@ class UserTokensController(rest.RestController):
|
|||
:param body: The access token.
|
||||
:return: The created access token.
|
||||
"""
|
||||
target_token = token_api.access_token_get(access_token_id)
|
||||
|
||||
target_token = token_api.user_token_get(access_token_id)
|
||||
|
||||
self._assert_can_access(user_id, body)
|
||||
self._assert_can_access(user_id, target_token)
|
||||
|
||||
if not target_token:
|
||||
abort(404, _("Token not found."))
|
||||
abort(404, _("Token %s not found.") % access_token_id)
|
||||
|
||||
# We only allow updating the expiration date.
|
||||
target_token.expires_in = body.expires_in
|
||||
|
||||
result_token = token_api.access_token_update(access_token_id,
|
||||
target_token.as_dict())
|
||||
token_dict = target_token.as_dict()
|
||||
|
||||
return wmodels.AccessToken.from_db_model(result_token)
|
||||
if "refresh_tokens" in token_dict:
|
||||
del token_dict["refresh_tokens"]
|
||||
|
||||
result_token = token_api.user_token_update(access_token_id,
|
||||
token_dict)
|
||||
|
||||
return self._from_db_model(result_token)
|
||||
|
||||
@decorators.db_exceptions
|
||||
@secure(checks.authenticated)
|
||||
|
@ -161,13 +191,13 @@ class UserTokensController(rest.RestController):
|
|||
:param access_token_id: The ID of the access token.
|
||||
:return: Empty body, or error response.
|
||||
"""
|
||||
access_token = token_api.access_token_get(access_token_id)
|
||||
access_token = token_api.user_token_get(access_token_id)
|
||||
self._assert_can_access(user_id, access_token)
|
||||
|
||||
if not access_token:
|
||||
abort(404, _("Token not found."))
|
||||
abort(404, _("Token %s not found.") % access_token_id)
|
||||
|
||||
token_api.access_token_delete(access_token_id)
|
||||
token_api.user_token_delete(access_token_id)
|
||||
|
||||
def _assert_can_access(self, user_id, token_entity=None):
|
||||
current_user = user_api.user_get(request.current_user_id)
|
||||
|
|
|
@ -389,6 +389,27 @@ class User(base.APIBase):
|
|||
last_login=datetime(2014, 1, 1, 16, 42))
|
||||
|
||||
|
||||
class RefreshToken(base.APIBase):
|
||||
"""Represents a user refresh token."""
|
||||
|
||||
user_id = int
|
||||
"""The ID of corresponding user."""
|
||||
|
||||
refresh_token = wtypes.text
|
||||
"""The refresh token."""
|
||||
|
||||
expires_in = int
|
||||
"""The number of seconds after creation when this token expires."""
|
||||
|
||||
@classmethod
|
||||
def sample(cls):
|
||||
return cls(
|
||||
user_id=1,
|
||||
refresh_token="a_unique_refresh_token",
|
||||
expires_in=3600
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(base.APIBase):
|
||||
"""Represents a user access token."""
|
||||
|
||||
|
@ -401,6 +422,9 @@ class AccessToken(base.APIBase):
|
|||
expires_in = int
|
||||
"""The number of seconds after creation when this token expires."""
|
||||
|
||||
refresh_tokens = wtypes.ArrayType(RefreshToken)
|
||||
"""Array of corresponding refresh tokens."""
|
||||
|
||||
@classmethod
|
||||
def sample(cls):
|
||||
return cls(
|
||||
|
|
|
@ -20,8 +20,10 @@ from storyboard.db.api import base as api_base
|
|||
from storyboard.db import models
|
||||
|
||||
|
||||
def access_token_get(access_token_id):
|
||||
return api_base.entity_get(models.AccessToken, access_token_id)
|
||||
def access_token_get(access_token_id, session=None):
|
||||
return api_base.entity_get(models.AccessToken,
|
||||
access_token_id,
|
||||
session=session)
|
||||
|
||||
|
||||
def access_token_get_by_token(access_token):
|
||||
|
@ -102,15 +104,29 @@ def access_token_build_query(**kwargs):
|
|||
return query
|
||||
|
||||
|
||||
def access_token_delete(access_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
access_token = access_token_get(access_token_id, session=session)
|
||||
|
||||
if access_token:
|
||||
query = api_base.model_query(models.RefreshToken)
|
||||
refresh_tokens = []
|
||||
|
||||
for refresh_token in access_token.refresh_tokens:
|
||||
refresh_tokens.append(refresh_token.id)
|
||||
|
||||
query = query.filter(models.RefreshToken.id.in_(
|
||||
refresh_tokens
|
||||
))
|
||||
|
||||
api_base.entity_hard_delete(models.AccessToken, access_token_id)
|
||||
query.delete(synchronize_session=False)
|
||||
|
||||
|
||||
def access_token_delete_by_token(access_token):
|
||||
access_token = access_token_get_by_token(access_token)
|
||||
|
||||
if access_token:
|
||||
api_base.entity_hard_delete(models.AccessToken, access_token.id)
|
||||
|
||||
|
||||
def access_token_delete(access_token_id):
|
||||
access_token = access_token_get(access_token_id)
|
||||
|
||||
if access_token:
|
||||
api_base.entity_hard_delete(models.AccessToken, access_token_id)
|
||||
access_token_delete(access_token.id)
|
||||
|
|
|
@ -32,25 +32,3 @@ def authorization_code_delete(code):
|
|||
|
||||
if del_code:
|
||||
api_base.entity_hard_delete(models.AuthorizationCode, del_code.id)
|
||||
|
||||
|
||||
def refresh_token_get(refresh_token):
|
||||
try:
|
||||
query = api_base.model_query(models.RefreshToken,
|
||||
api_base.get_session())
|
||||
return query.filter_by(refresh_token=refresh_token).first()
|
||||
except Exception:
|
||||
# If anything goes wrong while fetching a token None will be returned
|
||||
# anyway.
|
||||
return None
|
||||
|
||||
|
||||
def refresh_token_save(values):
|
||||
return api_base.entity_create(models.RefreshToken, values)
|
||||
|
||||
|
||||
def refresh_token_delete(refresh_token):
|
||||
del_token = refresh_token_get(refresh_token)
|
||||
|
||||
if del_token:
|
||||
api_base.entity_hard_delete(models.RefreshToken, del_token.id)
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) 2015 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import pytz
|
||||
|
||||
from storyboard.common import exception as exc
|
||||
from storyboard.db.api import access_tokens as access_tokens_api
|
||||
from storyboard.db.api import base as api_base
|
||||
from storyboard.db import models
|
||||
from storyboard.openstack.common.gettextutils import _ # noqa
|
||||
|
||||
|
||||
def refresh_token_get(refresh_token_id, session=None):
|
||||
return api_base.entity_get(models.RefreshToken, refresh_token_id,
|
||||
session=session)
|
||||
|
||||
|
||||
def refresh_token_get_by_token(refresh_token):
|
||||
try:
|
||||
return api_base.model_query(models.RefreshToken) \
|
||||
.filter_by(refresh_token=refresh_token).first()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def is_valid(refresh_token):
|
||||
if not refresh_token:
|
||||
return False
|
||||
|
||||
token = refresh_token_get_by_token(refresh_token)
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
if datetime.datetime.now(pytz.utc) > token.expires_at:
|
||||
refresh_token_delete(token.id)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_access_token_id(refresh_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
refresh_token = refresh_token_get(refresh_token_id, session)
|
||||
|
||||
if refresh_token:
|
||||
return refresh_token.access_tokens[0].id
|
||||
|
||||
|
||||
def refresh_token_create(access_token_id, values):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
|
||||
timedelta(seconds=values['expires_in'])
|
||||
access_token = access_tokens_api.access_token_get(access_token_id,
|
||||
session=session)
|
||||
|
||||
if not access_token:
|
||||
raise exc.NotFound(_("Access token not found."))
|
||||
|
||||
refresh_token = api_base.entity_create(models.RefreshToken, values)
|
||||
|
||||
access_token.refresh_tokens.append(refresh_token)
|
||||
session.add(access_token)
|
||||
|
||||
return refresh_token
|
||||
|
||||
|
||||
def refresh_token_build_query(**kwargs):
|
||||
# Construct the query
|
||||
query = api_base.model_query(models.RefreshToken)
|
||||
|
||||
# Apply the filters
|
||||
query = api_base.apply_query_filters(query=query,
|
||||
model=models.RefreshToken,
|
||||
**kwargs)
|
||||
|
||||
return query
|
||||
|
||||
|
||||
def refresh_token_delete(refresh_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
refresh_token = refresh_token_get(refresh_token_id)
|
||||
|
||||
if refresh_token:
|
||||
access_token_id = refresh_token.access_tokens[0].id
|
||||
access_token = access_tokens_api.access_token_get(access_token_id,
|
||||
session=session)
|
||||
|
||||
access_token.refresh_tokens.remove(refresh_token)
|
||||
session.add(access_token)
|
||||
api_base.entity_hard_delete(models.RefreshToken, refresh_token_id)
|
||||
|
||||
|
||||
def refresh_token_delete_by_token(refresh_token):
|
||||
refresh_token = refresh_token_get_by_token(refresh_token)
|
||||
|
||||
if refresh_token:
|
||||
refresh_token_delete(refresh_token.id)
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) 2015 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from storyboard.db.api import access_tokens as access_tokens_api
|
||||
from storyboard.db.api import base as api_base
|
||||
from storyboard.db import models
|
||||
|
||||
|
||||
def user_token_get(access_token_id):
|
||||
token_query = api_base.model_query(models.AccessToken)
|
||||
user_token = token_query.options(
|
||||
subqueryload(models.AccessToken.refresh_tokens)
|
||||
).filter_by(id=access_token_id).first()
|
||||
|
||||
return user_token
|
||||
|
||||
|
||||
def user_token_get_all(marker=None, limit=None, sort_field=None,
|
||||
sort_dir=None, **kwargs):
|
||||
if not sort_field:
|
||||
sort_field = 'id'
|
||||
if not sort_dir:
|
||||
sort_dir = 'asc'
|
||||
|
||||
query = api_base.model_query(models.AccessToken). \
|
||||
options(subqueryload(models.AccessToken.refresh_tokens))
|
||||
|
||||
query = api_base.apply_query_filters(query=query,
|
||||
model=models.AccessToken,
|
||||
**kwargs)
|
||||
|
||||
query = api_base.paginate_query(query=query,
|
||||
model=models.AccessToken,
|
||||
limit=limit,
|
||||
sort_key=sort_field,
|
||||
marker=marker,
|
||||
sort_dir=sort_dir)
|
||||
|
||||
return query.all()
|
||||
|
||||
|
||||
def user_token_get_count(**kwargs):
|
||||
query = api_base.model_query(models.AccessToken)
|
||||
|
||||
query = api_base.apply_query_filters(query=query,
|
||||
model=models.AccessToken,
|
||||
**kwargs)
|
||||
|
||||
return query.count()
|
||||
|
||||
|
||||
def user_token_create(values):
|
||||
access_token = access_tokens_api.access_token_create(values)
|
||||
|
||||
return user_token_get(access_token.id)
|
||||
|
||||
|
||||
def user_token_update(access_token_id, values):
|
||||
access_token = access_tokens_api.access_token_update(
|
||||
access_token_id,
|
||||
values)
|
||||
|
||||
if access_token:
|
||||
return user_token_get(access_token.id)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def user_token_delete(access_token_id):
|
||||
access_tokens_api.access_token_delete(access_token_id)
|
|
@ -0,0 +1,50 @@
|
|||
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||
# not use this file except in compliance with the License. You may obtain
|
||||
# a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
|
||||
"""This migration creates new table for relations between access and refresh
|
||||
tokens.
|
||||
|
||||
Revision ID: 041
|
||||
Revises: 040
|
||||
Create Date: 2015-02-18 18:03:23
|
||||
|
||||
"""
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
|
||||
revision = '041'
|
||||
down_revision = '040'
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
MYSQL_ENGINE = 'InnoDB'
|
||||
MYSQL_CHARSET = 'utf8'
|
||||
|
||||
|
||||
def upgrade(active_plugins=None, options=None):
|
||||
op.create_table('access_refresh_tokens',
|
||||
sa.Column('access_token_id', sa.Integer(), nullable=False),
|
||||
sa.Column('refresh_token_id', sa.Integer(),
|
||||
nullable=False),
|
||||
mysql_engine=MYSQL_ENGINE,
|
||||
mysql_charset=MYSQL_CHARSET
|
||||
)
|
||||
op.drop_column(u'refreshtokens', u'access_token_id')
|
||||
|
||||
|
||||
def downgrade(active_plugins=None, options=None):
|
||||
op.add_column('refreshtokens', sa.Column('access_token_id',
|
||||
sa.Integer(),
|
||||
nullable=False))
|
||||
op.drop_table('access_refresh_tokens')
|
|
@ -367,6 +367,14 @@ class AuthorizationCode(ModelBuilder, Base):
|
|||
expires_in = Column(Integer, nullable=False, default=300)
|
||||
|
||||
|
||||
access_refresh_tokens = Table(
|
||||
'access_refresh_tokens', Base.metadata,
|
||||
Column('access_token_id', Integer, ForeignKey('accesstokens.id')),
|
||||
Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')),
|
||||
schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'),
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(ModelBuilder, Base):
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
access_token = Column(Unicode(CommonLength.top_middle_length),
|
||||
|
@ -374,16 +382,15 @@ class AccessToken(ModelBuilder, Base):
|
|||
expires_in = Column(Integer, nullable=False)
|
||||
expires_at = Column(UTCDateTime, nullable=False)
|
||||
refresh_tokens = relationship("RefreshToken",
|
||||
cascade="save-update, merge, delete",
|
||||
passive_updates=False,
|
||||
passive_deletes=False)
|
||||
secondary='access_refresh_tokens',
|
||||
)
|
||||
|
||||
|
||||
class RefreshToken(ModelBuilder, Base):
|
||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||
access_token_id = Column(Integer,
|
||||
ForeignKey('accesstokens.id'),
|
||||
nullable=False)
|
||||
access_tokens = relationship("AccessToken",
|
||||
secondary='access_refresh_tokens',
|
||||
)
|
||||
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
||||
nullable=False)
|
||||
expires_in = Column(Integer, nullable=False)
|
||||
|
|
|
@ -18,6 +18,7 @@ import pytz
|
|||
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
|
||||
from storyboard.db.api import access_tokens as access_tokens_api
|
||||
import storyboard.db.api.base as api_base
|
||||
from storyboard.db.models import AccessToken
|
||||
from storyboard.plugin.scheduler.base import SchedulerPluginBase
|
||||
|
@ -57,4 +58,4 @@ class TokenCleaner(SchedulerPluginBase):
|
|||
# Manually deleting each record, because batch deletes are an
|
||||
# exception to ORM Cascade markup.
|
||||
for token in query.all():
|
||||
api_base.entity_hard_delete(AccessToken, token.id)
|
||||
access_tokens_api.access_token_delete(token.id)
|
||||
|
|
|
@ -25,7 +25,8 @@ import six.moves.urllib.parse as urlparse
|
|||
|
||||
from storyboard.api.auth import ErrorMessages as e_msg
|
||||
from storyboard.db.api import access_tokens as token_api
|
||||
from storyboard.db.api import auth as auth_api
|
||||
from storyboard.db.api import auth_codes as auth_api
|
||||
from storyboard.db.api import refresh_tokens
|
||||
from storyboard.tests import base
|
||||
|
||||
|
||||
|
@ -584,7 +585,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
|||
|
||||
# Assert that the refresh token is in the database
|
||||
refresh_token = \
|
||||
auth_api.refresh_token_get(token['refresh_token'])
|
||||
refresh_tokens.refresh_token_get_by_token(token['refresh_token'])
|
||||
self.assertIsNotNone(refresh_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
|
@ -773,7 +774,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
|||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
self.assertIsNotNone(access_token)
|
||||
refresh_token = \
|
||||
auth_api.refresh_token_get(t1['refresh_token'])
|
||||
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||
self.assertIsNotNone(refresh_token)
|
||||
|
||||
# Issue a refresh token request.
|
||||
|
@ -814,7 +815,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
|||
|
||||
# Assert that the refresh token is in the database
|
||||
new_refresh_token = \
|
||||
auth_api.refresh_token_get(t2['refresh_token'])
|
||||
refresh_tokens.refresh_token_get_by_token(t2['refresh_token'])
|
||||
self.assertIsNotNone(new_refresh_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
|
@ -829,7 +830,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
|||
no_access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
no_refresh_token = \
|
||||
auth_api.refresh_token_get(t1['refresh_token'])
|
||||
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||
|
||||
self.assertIsNone(no_refresh_token)
|
||||
self.assertIsNone(no_access_token)
|
||||
|
|
|
@ -28,7 +28,6 @@ class TestUserTokensAsUser(base.FunctionalTest):
|
|||
"""
|
||||
response = self.get_json(self.resource, expect_errors=True)
|
||||
self.assertEqual(200, response.status_code)
|
||||
self.assertEqual(len(response.json), 2)
|
||||
|
||||
def test_unauthorized_browse(self):
|
||||
"""Assert a basic browse for someone else's tokens
|
||||
|
|
|
@ -13,7 +13,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from storyboard.db.api import auth
|
||||
from storyboard.db.api import auth_codes
|
||||
from storyboard.db.api import users
|
||||
from storyboard.tests.db import base
|
||||
|
||||
|
@ -32,15 +32,15 @@ class AuthorizationCodeTest(base.BaseDbTestCase):
|
|||
users.user_create({"fullname": "Test User"})
|
||||
|
||||
def test_create_code(self):
|
||||
self._test_create(self.code_01, auth.authorization_code_save)
|
||||
self._test_create(self.code_01, auth_codes.authorization_code_save)
|
||||
|
||||
def test_delete_code(self):
|
||||
created_code = auth.authorization_code_save(self.code_01)
|
||||
created_code = auth_codes.authorization_code_save(self.code_01)
|
||||
|
||||
self.assertIsNotNone(created_code,
|
||||
"Could not create an Authorization code")
|
||||
|
||||
auth.authorization_code_delete(created_code.code)
|
||||
auth_codes.authorization_code_delete(created_code.code)
|
||||
|
||||
fetched_code = auth.authorization_code_get(created_code.code)
|
||||
fetched_code = auth_codes.authorization_code_get(created_code.code)
|
||||
self.assertIsNone(fetched_code)
|
||||
|
|
|
@ -69,31 +69,33 @@ class TestTokenCleaner(db_base.BaseDbTestCase,
|
|||
# 8-day-old-token to remain valid.
|
||||
new_access_tokens = []
|
||||
new_refresh_tokens = []
|
||||
|
||||
for i in range(0, 100):
|
||||
created_at = datetime.now(pytz.utc) - timedelta(days=i)
|
||||
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
|
||||
expires_at = created_at + timedelta(seconds=expires_in)
|
||||
|
||||
new_access_tokens.append(
|
||||
AccessToken(
|
||||
user_id=1,
|
||||
created_at=created_at,
|
||||
expires_in=expires_in,
|
||||
expires_at=expires_at,
|
||||
access_token='test_token_%s' % (i,))
|
||||
access_token='test_token_%s' % (i,)),
|
||||
)
|
||||
|
||||
new_access_tokens = load_data(new_access_tokens)
|
||||
|
||||
for token in new_access_tokens:
|
||||
new_refresh_tokens.append(
|
||||
RefreshToken(
|
||||
user_id=1,
|
||||
access_token_id=token.id,
|
||||
created_at=token.created_at,
|
||||
expires_in=token.expires_in,
|
||||
expires_at=token.expires_at,
|
||||
expires_in=300,
|
||||
access_tokens=[token],
|
||||
refresh_token='test_refresh_%s' % (token.id,))
|
||||
)
|
||||
|
||||
new_refresh_tokens = load_data(new_refresh_tokens)
|
||||
|
||||
# Make sure we have 100 tokens.
|
||||
|
|
Loading…
Reference in New Issue