Split auth in api and change user_tokens
Now access tokens wmodel includes array of refresh tokens. To db.api added refresh_tokens and deleted corresponding fubctions from auth. To db.api added user_tokens. Added tests for this patch. Change-Id: I1a631e43831eb327a34726a1ce7c07d3a14f3aff
This commit is contained in:
parent
2769d74e67
commit
6c2ab67bfb
|
@ -22,7 +22,8 @@ from oslo.config import cfg
|
||||||
from oslo_log import log
|
from oslo_log import log
|
||||||
|
|
||||||
from storyboard.db.api import access_tokens as token_api
|
from storyboard.db.api import access_tokens as token_api
|
||||||
from storyboard.db.api import auth as auth_api
|
from storyboard.db.api import auth_codes as auth_api
|
||||||
|
from storyboard.db.api import refresh_tokens as refresh_token_api
|
||||||
from storyboard.db.api import users as user_api
|
from storyboard.db.api import users as user_api
|
||||||
|
|
||||||
CONF = cfg.CONF
|
CONF = cfg.CONF
|
||||||
|
@ -197,7 +198,8 @@ class SkeletonValidator(RequestValidator):
|
||||||
|
|
||||||
# Try refresh token
|
# Try refresh token
|
||||||
refresh_token = request._params.get("refresh_token")
|
refresh_token = request._params.get("refresh_token")
|
||||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
refresh_token_entry = \
|
||||||
|
refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||||
if refresh_token_entry:
|
if refresh_token_entry:
|
||||||
return refresh_token_entry.user_id
|
return refresh_token_entry.user_id
|
||||||
|
|
||||||
|
@ -226,14 +228,14 @@ class SkeletonValidator(RequestValidator):
|
||||||
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
refresh_expires_in = CONF.oauth.refresh_token_ttl
|
||||||
|
|
||||||
refresh_token_values = {
|
refresh_token_values = {
|
||||||
"access_token_id": access_token.id,
|
|
||||||
"refresh_token": token["refresh_token"],
|
"refresh_token": token["refresh_token"],
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"expires_in": refresh_expires_in,
|
"expires_in": refresh_expires_in,
|
||||||
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
|
"expires_at": datetime.datetime.now(pytz.utc) + datetime.timedelta(
|
||||||
seconds=refresh_expires_in),
|
seconds=refresh_expires_in),
|
||||||
}
|
}
|
||||||
auth_api.refresh_token_save(refresh_token_values)
|
refresh_token_api.refresh_token_create(access_token.id,
|
||||||
|
refresh_token_values)
|
||||||
|
|
||||||
def invalidate_authorization_code(self, client_id, code, request, *args,
|
def invalidate_authorization_code(self, client_id, code, request, *args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
@ -268,17 +270,7 @@ class SkeletonValidator(RequestValidator):
|
||||||
def validate_refresh_token(self, refresh_token, client, request, *args,
|
def validate_refresh_token(self, refresh_token, client, request, *args,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""Check that the refresh token exists in the db."""
|
"""Check that the refresh token exists in the db."""
|
||||||
|
return refresh_token_api.is_valid(refresh_token)
|
||||||
refresh_token_entry = auth_api.refresh_token_get(refresh_token)
|
|
||||||
|
|
||||||
if not refresh_token_entry:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if datetime.datetime.now(pytz.utc) > refresh_token_entry.expires_at:
|
|
||||||
auth_api.refresh_token_delete(refresh_token)
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def invalidate_refresh_token(self, request):
|
def invalidate_refresh_token(self, request):
|
||||||
"""Remove a used token from the storage."""
|
"""Remove a used token from the storage."""
|
||||||
|
@ -290,8 +282,10 @@ class SkeletonValidator(RequestValidator):
|
||||||
if not refresh_token:
|
if not refresh_token:
|
||||||
return
|
return
|
||||||
|
|
||||||
r_token = auth_api.refresh_token_get(refresh_token)
|
r_token = refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||||
token_api.access_token_delete(r_token.access_token_id) # Cascades
|
token_api.access_token_delete(
|
||||||
|
refresh_token_api.get_access_token_id(r_token.id)
|
||||||
|
) # Cascades
|
||||||
|
|
||||||
|
|
||||||
class OpenIdConnectServer(WebApplicationServer):
|
class OpenIdConnectServer(WebApplicationServer):
|
||||||
|
|
|
@ -27,7 +27,8 @@ from storyboard.api.auth.oauth_validator import SERVER
|
||||||
from storyboard.api.auth.openid_client import client as openid_client
|
from storyboard.api.auth.openid_client import client as openid_client
|
||||||
from storyboard.common import decorators
|
from storyboard.common import decorators
|
||||||
from storyboard.common.exception import UnsupportedGrantType
|
from storyboard.common.exception import UnsupportedGrantType
|
||||||
from storyboard.db.api import auth as auth_api
|
from storyboard.db.api import auth_codes as auth_api
|
||||||
|
from storyboard.db.api import refresh_tokens as refresh_token_api
|
||||||
|
|
||||||
LOG = log.getLogger(__name__)
|
LOG = log.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -101,7 +102,8 @@ class AuthController(rest.RestController):
|
||||||
|
|
||||||
def _access_token_by_refresh_token(self):
|
def _access_token_by_refresh_token(self):
|
||||||
refresh_token = request.params.get("refresh_token")
|
refresh_token = request.params.get("refresh_token")
|
||||||
refresh_token_info = auth_api.refresh_token_get(refresh_token)
|
refresh_token_info = \
|
||||||
|
refresh_token_api.refresh_token_get_by_token(refresh_token)
|
||||||
|
|
||||||
headers, body, code = SERVER.create_token_response(
|
headers, body, code = SERVER.create_token_response(
|
||||||
uri=request.url,
|
uri=request.url,
|
||||||
|
|
|
@ -29,7 +29,7 @@ import wsmeext.pecan as wsme_pecan
|
||||||
from storyboard.api.auth import authorization_checks as checks
|
from storyboard.api.auth import authorization_checks as checks
|
||||||
import storyboard.api.v1.wmodels as wmodels
|
import storyboard.api.v1.wmodels as wmodels
|
||||||
from storyboard.common import decorators
|
from storyboard.common import decorators
|
||||||
import storyboard.db.api.access_tokens as token_api
|
import storyboard.db.api.user_tokens as token_api
|
||||||
import storyboard.db.api.users as user_api
|
import storyboard.db.api.users as user_api
|
||||||
from storyboard.openstack.common.gettextutils import _ # noqa
|
from storyboard.openstack.common.gettextutils import _ # noqa
|
||||||
|
|
||||||
|
@ -39,6 +39,18 @@ LOG = log.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserTokensController(rest.RestController):
|
class UserTokensController(rest.RestController):
|
||||||
|
def _from_db_model(self, access_token):
|
||||||
|
access_token_model = wmodels.AccessToken.from_db_model(
|
||||||
|
access_token,
|
||||||
|
skip_fields="refresh_tokens")
|
||||||
|
|
||||||
|
access_token_model.refresh_tokens = [
|
||||||
|
wmodels.RefreshToken.from_db_model(token)
|
||||||
|
for token in access_token.refresh_tokens
|
||||||
|
]
|
||||||
|
|
||||||
|
return access_token_model
|
||||||
|
|
||||||
@decorators.db_exceptions
|
@decorators.db_exceptions
|
||||||
@secure(checks.authenticated)
|
@secure(checks.authenticated)
|
||||||
@wsme_pecan.wsexpose([wmodels.AccessToken], int, int, int, wtypes.text,
|
@wsme_pecan.wsexpose([wmodels.AccessToken], int, int, int, wtypes.text,
|
||||||
|
@ -62,23 +74,24 @@ class UserTokensController(rest.RestController):
|
||||||
limit = min(CONF.page_size_maximum, max(1, limit))
|
limit = min(CONF.page_size_maximum, max(1, limit))
|
||||||
|
|
||||||
# Resolve the marker record.
|
# Resolve the marker record.
|
||||||
marker_token = token_api.access_token_get(marker)
|
marker_token = token_api.user_token_get(marker)
|
||||||
|
|
||||||
tokens = token_api.access_token_get_all(marker=marker_token,
|
tokens = token_api.user_token_get_all(marker=marker_token,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
filter_non_public=True,
|
filter_non_public=True,
|
||||||
sort_field=sort_field,
|
sort_field=sort_field,
|
||||||
sort_dir=sort_dir)
|
sort_dir=sort_dir)
|
||||||
token_count = token_api.access_token_get_count(user_id=user_id)
|
token_count = token_api.user_token_get_count(user_id=user_id)
|
||||||
|
|
||||||
# Apply the query response headers.
|
# Apply the query response headers.
|
||||||
response.headers['X-Limit'] = str(limit)
|
response.headers['X-Limit'] = str(limit)
|
||||||
response.headers['X-Total'] = str(token_count)
|
response.headers['X-Total'] = str(token_count)
|
||||||
|
|
||||||
if marker_token:
|
if marker_token:
|
||||||
response.headers['X-Marker'] = str(marker_token.id)
|
response.headers['X-Marker'] = str(marker_token.id)
|
||||||
|
|
||||||
return [wmodels.AccessToken.from_db_model(t) for t in tokens]
|
return [self._from_db_model(t) for t in tokens]
|
||||||
|
|
||||||
@decorators.db_exceptions
|
@decorators.db_exceptions
|
||||||
@secure(checks.authenticated)
|
@secure(checks.authenticated)
|
||||||
|
@ -90,13 +103,13 @@ class UserTokensController(rest.RestController):
|
||||||
:param access_token_id: The ID of the access token.
|
:param access_token_id: The ID of the access token.
|
||||||
:return: The requested access token.
|
:return: The requested access token.
|
||||||
"""
|
"""
|
||||||
access_token = token_api.access_token_get(access_token_id)
|
access_token = token_api.user_token_get(access_token_id)
|
||||||
self._assert_can_access(user_id, access_token)
|
self._assert_can_access(user_id, access_token)
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
abort(404, _("Token not found."))
|
abort(404, _("Token %s not found.") % access_token_id)
|
||||||
|
|
||||||
return wmodels.AccessToken.from_db_model(access_token)
|
return self._from_db_model(access_token)
|
||||||
|
|
||||||
@decorators.db_exceptions
|
@decorators.db_exceptions
|
||||||
@secure(checks.authenticated)
|
@secure(checks.authenticated)
|
||||||
|
@ -115,13 +128,24 @@ class UserTokensController(rest.RestController):
|
||||||
body.access_token = six.text_type(uuid.uuid4())
|
body.access_token = six.text_type(uuid.uuid4())
|
||||||
|
|
||||||
# Token duplication check.
|
# Token duplication check.
|
||||||
dupes = token_api.access_token_get_all(access_token=body.access_token)
|
dupes = token_api.user_token_get_all(
|
||||||
|
access_token=body.access_token
|
||||||
|
)
|
||||||
|
|
||||||
if dupes:
|
if dupes:
|
||||||
abort(409, _('This token already exist.'))
|
abort(409, _('This token already exist.'))
|
||||||
|
|
||||||
token = token_api.access_token_create(body.as_dict())
|
token_dict = body.as_dict()
|
||||||
|
|
||||||
return wmodels.AccessToken.from_db_model(token)
|
if "refresh_tokens" in token_dict:
|
||||||
|
del token_dict["refresh_tokens"]
|
||||||
|
|
||||||
|
token = token_api.user_token_create(token_dict)
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
abort(400, _("Can't create access token."))
|
||||||
|
|
||||||
|
return self._from_db_model(token)
|
||||||
|
|
||||||
@decorators.db_exceptions
|
@decorators.db_exceptions
|
||||||
@secure(checks.authenticated)
|
@secure(checks.authenticated)
|
||||||
|
@ -135,21 +159,27 @@ class UserTokensController(rest.RestController):
|
||||||
:param body: The access token.
|
:param body: The access token.
|
||||||
:return: The created access token.
|
:return: The created access token.
|
||||||
"""
|
"""
|
||||||
target_token = token_api.access_token_get(access_token_id)
|
|
||||||
|
target_token = token_api.user_token_get(access_token_id)
|
||||||
|
|
||||||
self._assert_can_access(user_id, body)
|
self._assert_can_access(user_id, body)
|
||||||
self._assert_can_access(user_id, target_token)
|
self._assert_can_access(user_id, target_token)
|
||||||
|
|
||||||
if not target_token:
|
if not target_token:
|
||||||
abort(404, _("Token not found."))
|
abort(404, _("Token %s not found.") % access_token_id)
|
||||||
|
|
||||||
# We only allow updating the expiration date.
|
# We only allow updating the expiration date.
|
||||||
target_token.expires_in = body.expires_in
|
target_token.expires_in = body.expires_in
|
||||||
|
|
||||||
result_token = token_api.access_token_update(access_token_id,
|
token_dict = target_token.as_dict()
|
||||||
target_token.as_dict())
|
|
||||||
|
|
||||||
return wmodels.AccessToken.from_db_model(result_token)
|
if "refresh_tokens" in token_dict:
|
||||||
|
del token_dict["refresh_tokens"]
|
||||||
|
|
||||||
|
result_token = token_api.user_token_update(access_token_id,
|
||||||
|
token_dict)
|
||||||
|
|
||||||
|
return self._from_db_model(result_token)
|
||||||
|
|
||||||
@decorators.db_exceptions
|
@decorators.db_exceptions
|
||||||
@secure(checks.authenticated)
|
@secure(checks.authenticated)
|
||||||
|
@ -161,13 +191,13 @@ class UserTokensController(rest.RestController):
|
||||||
:param access_token_id: The ID of the access token.
|
:param access_token_id: The ID of the access token.
|
||||||
:return: Empty body, or error response.
|
:return: Empty body, or error response.
|
||||||
"""
|
"""
|
||||||
access_token = token_api.access_token_get(access_token_id)
|
access_token = token_api.user_token_get(access_token_id)
|
||||||
self._assert_can_access(user_id, access_token)
|
self._assert_can_access(user_id, access_token)
|
||||||
|
|
||||||
if not access_token:
|
if not access_token:
|
||||||
abort(404, _("Token not found."))
|
abort(404, _("Token %s not found.") % access_token_id)
|
||||||
|
|
||||||
token_api.access_token_delete(access_token_id)
|
token_api.user_token_delete(access_token_id)
|
||||||
|
|
||||||
def _assert_can_access(self, user_id, token_entity=None):
|
def _assert_can_access(self, user_id, token_entity=None):
|
||||||
current_user = user_api.user_get(request.current_user_id)
|
current_user = user_api.user_get(request.current_user_id)
|
||||||
|
|
|
@ -389,6 +389,27 @@ class User(base.APIBase):
|
||||||
last_login=datetime(2014, 1, 1, 16, 42))
|
last_login=datetime(2014, 1, 1, 16, 42))
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(base.APIBase):
|
||||||
|
"""Represents a user refresh token."""
|
||||||
|
|
||||||
|
user_id = int
|
||||||
|
"""The ID of corresponding user."""
|
||||||
|
|
||||||
|
refresh_token = wtypes.text
|
||||||
|
"""The refresh token."""
|
||||||
|
|
||||||
|
expires_in = int
|
||||||
|
"""The number of seconds after creation when this token expires."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample(cls):
|
||||||
|
return cls(
|
||||||
|
user_id=1,
|
||||||
|
refresh_token="a_unique_refresh_token",
|
||||||
|
expires_in=3600
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(base.APIBase):
|
class AccessToken(base.APIBase):
|
||||||
"""Represents a user access token."""
|
"""Represents a user access token."""
|
||||||
|
|
||||||
|
@ -401,6 +422,9 @@ class AccessToken(base.APIBase):
|
||||||
expires_in = int
|
expires_in = int
|
||||||
"""The number of seconds after creation when this token expires."""
|
"""The number of seconds after creation when this token expires."""
|
||||||
|
|
||||||
|
refresh_tokens = wtypes.ArrayType(RefreshToken)
|
||||||
|
"""Array of corresponding refresh tokens."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample(cls):
|
def sample(cls):
|
||||||
return cls(
|
return cls(
|
||||||
|
|
|
@ -20,8 +20,10 @@ from storyboard.db.api import base as api_base
|
||||||
from storyboard.db import models
|
from storyboard.db import models
|
||||||
|
|
||||||
|
|
||||||
def access_token_get(access_token_id):
|
def access_token_get(access_token_id, session=None):
|
||||||
return api_base.entity_get(models.AccessToken, access_token_id)
|
return api_base.entity_get(models.AccessToken,
|
||||||
|
access_token_id,
|
||||||
|
session=session)
|
||||||
|
|
||||||
|
|
||||||
def access_token_get_by_token(access_token):
|
def access_token_get_by_token(access_token):
|
||||||
|
@ -102,15 +104,29 @@ def access_token_build_query(**kwargs):
|
||||||
return query
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
def access_token_delete(access_token_id):
|
||||||
|
session = api_base.get_session()
|
||||||
|
|
||||||
|
with session.begin():
|
||||||
|
access_token = access_token_get(access_token_id, session=session)
|
||||||
|
|
||||||
|
if access_token:
|
||||||
|
query = api_base.model_query(models.RefreshToken)
|
||||||
|
refresh_tokens = []
|
||||||
|
|
||||||
|
for refresh_token in access_token.refresh_tokens:
|
||||||
|
refresh_tokens.append(refresh_token.id)
|
||||||
|
|
||||||
|
query = query.filter(models.RefreshToken.id.in_(
|
||||||
|
refresh_tokens
|
||||||
|
))
|
||||||
|
|
||||||
|
api_base.entity_hard_delete(models.AccessToken, access_token_id)
|
||||||
|
query.delete(synchronize_session=False)
|
||||||
|
|
||||||
|
|
||||||
def access_token_delete_by_token(access_token):
|
def access_token_delete_by_token(access_token):
|
||||||
access_token = access_token_get_by_token(access_token)
|
access_token = access_token_get_by_token(access_token)
|
||||||
|
|
||||||
if access_token:
|
if access_token:
|
||||||
api_base.entity_hard_delete(models.AccessToken, access_token.id)
|
access_token_delete(access_token.id)
|
||||||
|
|
||||||
|
|
||||||
def access_token_delete(access_token_id):
|
|
||||||
access_token = access_token_get(access_token_id)
|
|
||||||
|
|
||||||
if access_token:
|
|
||||||
api_base.entity_hard_delete(models.AccessToken, access_token_id)
|
|
||||||
|
|
|
@ -32,25 +32,3 @@ def authorization_code_delete(code):
|
||||||
|
|
||||||
if del_code:
|
if del_code:
|
||||||
api_base.entity_hard_delete(models.AuthorizationCode, del_code.id)
|
api_base.entity_hard_delete(models.AuthorizationCode, del_code.id)
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_get(refresh_token):
|
|
||||||
try:
|
|
||||||
query = api_base.model_query(models.RefreshToken,
|
|
||||||
api_base.get_session())
|
|
||||||
return query.filter_by(refresh_token=refresh_token).first()
|
|
||||||
except Exception:
|
|
||||||
# If anything goes wrong while fetching a token None will be returned
|
|
||||||
# anyway.
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_save(values):
|
|
||||||
return api_base.entity_create(models.RefreshToken, values)
|
|
||||||
|
|
||||||
|
|
||||||
def refresh_token_delete(refresh_token):
|
|
||||||
del_token = refresh_token_get(refresh_token)
|
|
||||||
|
|
||||||
if del_token:
|
|
||||||
api_base.entity_hard_delete(models.RefreshToken, del_token.id)
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
# Copyright (c) 2015 Mirantis Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import pytz
|
||||||
|
|
||||||
|
from storyboard.common import exception as exc
|
||||||
|
from storyboard.db.api import access_tokens as access_tokens_api
|
||||||
|
from storyboard.db.api import base as api_base
|
||||||
|
from storyboard.db import models
|
||||||
|
from storyboard.openstack.common.gettextutils import _ # noqa
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_get(refresh_token_id, session=None):
|
||||||
|
return api_base.entity_get(models.RefreshToken, refresh_token_id,
|
||||||
|
session=session)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_get_by_token(refresh_token):
|
||||||
|
try:
|
||||||
|
return api_base.model_query(models.RefreshToken) \
|
||||||
|
.filter_by(refresh_token=refresh_token).first()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid(refresh_token):
|
||||||
|
if not refresh_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
token = refresh_token_get_by_token(refresh_token)
|
||||||
|
|
||||||
|
if not token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if datetime.datetime.now(pytz.utc) > token.expires_at:
|
||||||
|
refresh_token_delete(token.id)
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_access_token_id(refresh_token_id):
|
||||||
|
session = api_base.get_session()
|
||||||
|
|
||||||
|
with session.begin():
|
||||||
|
refresh_token = refresh_token_get(refresh_token_id, session)
|
||||||
|
|
||||||
|
if refresh_token:
|
||||||
|
return refresh_token.access_tokens[0].id
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_create(access_token_id, values):
|
||||||
|
session = api_base.get_session()
|
||||||
|
|
||||||
|
with session.begin():
|
||||||
|
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
|
||||||
|
timedelta(seconds=values['expires_in'])
|
||||||
|
access_token = access_tokens_api.access_token_get(access_token_id,
|
||||||
|
session=session)
|
||||||
|
|
||||||
|
if not access_token:
|
||||||
|
raise exc.NotFound(_("Access token not found."))
|
||||||
|
|
||||||
|
refresh_token = api_base.entity_create(models.RefreshToken, values)
|
||||||
|
|
||||||
|
access_token.refresh_tokens.append(refresh_token)
|
||||||
|
session.add(access_token)
|
||||||
|
|
||||||
|
return refresh_token
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_build_query(**kwargs):
|
||||||
|
# Construct the query
|
||||||
|
query = api_base.model_query(models.RefreshToken)
|
||||||
|
|
||||||
|
# Apply the filters
|
||||||
|
query = api_base.apply_query_filters(query=query,
|
||||||
|
model=models.RefreshToken,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_delete(refresh_token_id):
|
||||||
|
session = api_base.get_session()
|
||||||
|
|
||||||
|
with session.begin():
|
||||||
|
refresh_token = refresh_token_get(refresh_token_id)
|
||||||
|
|
||||||
|
if refresh_token:
|
||||||
|
access_token_id = refresh_token.access_tokens[0].id
|
||||||
|
access_token = access_tokens_api.access_token_get(access_token_id,
|
||||||
|
session=session)
|
||||||
|
|
||||||
|
access_token.refresh_tokens.remove(refresh_token)
|
||||||
|
session.add(access_token)
|
||||||
|
api_base.entity_hard_delete(models.RefreshToken, refresh_token_id)
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_token_delete_by_token(refresh_token):
|
||||||
|
refresh_token = refresh_token_get_by_token(refresh_token)
|
||||||
|
|
||||||
|
if refresh_token:
|
||||||
|
refresh_token_delete(refresh_token.id)
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright (c) 2015 Mirantis Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
# implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from sqlalchemy.orm import subqueryload
|
||||||
|
|
||||||
|
from storyboard.db.api import access_tokens as access_tokens_api
|
||||||
|
from storyboard.db.api import base as api_base
|
||||||
|
from storyboard.db import models
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_get(access_token_id):
|
||||||
|
token_query = api_base.model_query(models.AccessToken)
|
||||||
|
user_token = token_query.options(
|
||||||
|
subqueryload(models.AccessToken.refresh_tokens)
|
||||||
|
).filter_by(id=access_token_id).first()
|
||||||
|
|
||||||
|
return user_token
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_get_all(marker=None, limit=None, sort_field=None,
|
||||||
|
sort_dir=None, **kwargs):
|
||||||
|
if not sort_field:
|
||||||
|
sort_field = 'id'
|
||||||
|
if not sort_dir:
|
||||||
|
sort_dir = 'asc'
|
||||||
|
|
||||||
|
query = api_base.model_query(models.AccessToken). \
|
||||||
|
options(subqueryload(models.AccessToken.refresh_tokens))
|
||||||
|
|
||||||
|
query = api_base.apply_query_filters(query=query,
|
||||||
|
model=models.AccessToken,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
query = api_base.paginate_query(query=query,
|
||||||
|
model=models.AccessToken,
|
||||||
|
limit=limit,
|
||||||
|
sort_key=sort_field,
|
||||||
|
marker=marker,
|
||||||
|
sort_dir=sort_dir)
|
||||||
|
|
||||||
|
return query.all()
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_get_count(**kwargs):
|
||||||
|
query = api_base.model_query(models.AccessToken)
|
||||||
|
|
||||||
|
query = api_base.apply_query_filters(query=query,
|
||||||
|
model=models.AccessToken,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
return query.count()
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_create(values):
|
||||||
|
access_token = access_tokens_api.access_token_create(values)
|
||||||
|
|
||||||
|
return user_token_get(access_token.id)
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_update(access_token_id, values):
|
||||||
|
access_token = access_tokens_api.access_token_update(
|
||||||
|
access_token_id,
|
||||||
|
values)
|
||||||
|
|
||||||
|
if access_token:
|
||||||
|
return user_token_get(access_token.id)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def user_token_delete(access_token_id):
|
||||||
|
access_tokens_api.access_token_delete(access_token_id)
|
|
@ -0,0 +1,50 @@
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
#
|
||||||
|
|
||||||
|
"""This migration creates new table for relations between access and refresh
|
||||||
|
tokens.
|
||||||
|
|
||||||
|
Revision ID: 041
|
||||||
|
Revises: 040
|
||||||
|
Create Date: 2015-02-18 18:03:23
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
|
||||||
|
revision = '041'
|
||||||
|
down_revision = '040'
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
MYSQL_ENGINE = 'InnoDB'
|
||||||
|
MYSQL_CHARSET = 'utf8'
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade(active_plugins=None, options=None):
|
||||||
|
op.create_table('access_refresh_tokens',
|
||||||
|
sa.Column('access_token_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('refresh_token_id', sa.Integer(),
|
||||||
|
nullable=False),
|
||||||
|
mysql_engine=MYSQL_ENGINE,
|
||||||
|
mysql_charset=MYSQL_CHARSET
|
||||||
|
)
|
||||||
|
op.drop_column(u'refreshtokens', u'access_token_id')
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade(active_plugins=None, options=None):
|
||||||
|
op.add_column('refreshtokens', sa.Column('access_token_id',
|
||||||
|
sa.Integer(),
|
||||||
|
nullable=False))
|
||||||
|
op.drop_table('access_refresh_tokens')
|
|
@ -367,6 +367,14 @@ class AuthorizationCode(ModelBuilder, Base):
|
||||||
expires_in = Column(Integer, nullable=False, default=300)
|
expires_in = Column(Integer, nullable=False, default=300)
|
||||||
|
|
||||||
|
|
||||||
|
access_refresh_tokens = Table(
|
||||||
|
'access_refresh_tokens', Base.metadata,
|
||||||
|
Column('access_token_id', Integer, ForeignKey('accesstokens.id')),
|
||||||
|
Column('refresh_token_id', Integer, ForeignKey('refreshtokens.id')),
|
||||||
|
schema.UniqueConstraint('refresh_token_id', name='uniq_refresh_token'),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AccessToken(ModelBuilder, Base):
|
class AccessToken(ModelBuilder, Base):
|
||||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||||
access_token = Column(Unicode(CommonLength.top_middle_length),
|
access_token = Column(Unicode(CommonLength.top_middle_length),
|
||||||
|
@ -374,16 +382,15 @@ class AccessToken(ModelBuilder, Base):
|
||||||
expires_in = Column(Integer, nullable=False)
|
expires_in = Column(Integer, nullable=False)
|
||||||
expires_at = Column(UTCDateTime, nullable=False)
|
expires_at = Column(UTCDateTime, nullable=False)
|
||||||
refresh_tokens = relationship("RefreshToken",
|
refresh_tokens = relationship("RefreshToken",
|
||||||
cascade="save-update, merge, delete",
|
secondary='access_refresh_tokens',
|
||||||
passive_updates=False,
|
)
|
||||||
passive_deletes=False)
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(ModelBuilder, Base):
|
class RefreshToken(ModelBuilder, Base):
|
||||||
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
user_id = Column(Integer, ForeignKey('users.id'), nullable=False)
|
||||||
access_token_id = Column(Integer,
|
access_tokens = relationship("AccessToken",
|
||||||
ForeignKey('accesstokens.id'),
|
secondary='access_refresh_tokens',
|
||||||
nullable=False)
|
)
|
||||||
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
refresh_token = Column(Unicode(CommonLength.top_middle_length),
|
||||||
nullable=False)
|
nullable=False)
|
||||||
expires_in = Column(Integer, nullable=False)
|
expires_in = Column(Integer, nullable=False)
|
||||||
|
|
|
@ -18,6 +18,7 @@ import pytz
|
||||||
|
|
||||||
from apscheduler.triggers.interval import IntervalTrigger
|
from apscheduler.triggers.interval import IntervalTrigger
|
||||||
|
|
||||||
|
from storyboard.db.api import access_tokens as access_tokens_api
|
||||||
import storyboard.db.api.base as api_base
|
import storyboard.db.api.base as api_base
|
||||||
from storyboard.db.models import AccessToken
|
from storyboard.db.models import AccessToken
|
||||||
from storyboard.plugin.scheduler.base import SchedulerPluginBase
|
from storyboard.plugin.scheduler.base import SchedulerPluginBase
|
||||||
|
@ -57,4 +58,4 @@ class TokenCleaner(SchedulerPluginBase):
|
||||||
# Manually deleting each record, because batch deletes are an
|
# Manually deleting each record, because batch deletes are an
|
||||||
# exception to ORM Cascade markup.
|
# exception to ORM Cascade markup.
|
||||||
for token in query.all():
|
for token in query.all():
|
||||||
api_base.entity_hard_delete(AccessToken, token.id)
|
access_tokens_api.access_token_delete(token.id)
|
||||||
|
|
|
@ -25,7 +25,8 @@ import six.moves.urllib.parse as urlparse
|
||||||
|
|
||||||
from storyboard.api.auth import ErrorMessages as e_msg
|
from storyboard.api.auth import ErrorMessages as e_msg
|
||||||
from storyboard.db.api import access_tokens as token_api
|
from storyboard.db.api import access_tokens as token_api
|
||||||
from storyboard.db.api import auth as auth_api
|
from storyboard.db.api import auth_codes as auth_api
|
||||||
|
from storyboard.db.api import refresh_tokens
|
||||||
from storyboard.tests import base
|
from storyboard.tests import base
|
||||||
|
|
||||||
|
|
||||||
|
@ -584,7 +585,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||||
|
|
||||||
# Assert that the refresh token is in the database
|
# Assert that the refresh token is in the database
|
||||||
refresh_token = \
|
refresh_token = \
|
||||||
auth_api.refresh_token_get(token['refresh_token'])
|
refresh_tokens.refresh_token_get_by_token(token['refresh_token'])
|
||||||
self.assertIsNotNone(refresh_token)
|
self.assertIsNotNone(refresh_token)
|
||||||
|
|
||||||
# Assert that system configured values is owned by the correct user.
|
# Assert that system configured values is owned by the correct user.
|
||||||
|
@ -773,7 +774,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||||
token_api.access_token_get_by_token(t1['access_token'])
|
token_api.access_token_get_by_token(t1['access_token'])
|
||||||
self.assertIsNotNone(access_token)
|
self.assertIsNotNone(access_token)
|
||||||
refresh_token = \
|
refresh_token = \
|
||||||
auth_api.refresh_token_get(t1['refresh_token'])
|
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||||
self.assertIsNotNone(refresh_token)
|
self.assertIsNotNone(refresh_token)
|
||||||
|
|
||||||
# Issue a refresh token request.
|
# Issue a refresh token request.
|
||||||
|
@ -814,7 +815,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||||
|
|
||||||
# Assert that the refresh token is in the database
|
# Assert that the refresh token is in the database
|
||||||
new_refresh_token = \
|
new_refresh_token = \
|
||||||
auth_api.refresh_token_get(t2['refresh_token'])
|
refresh_tokens.refresh_token_get_by_token(t2['refresh_token'])
|
||||||
self.assertIsNotNone(new_refresh_token)
|
self.assertIsNotNone(new_refresh_token)
|
||||||
|
|
||||||
# Assert that system configured values is owned by the correct user.
|
# Assert that system configured values is owned by the correct user.
|
||||||
|
@ -829,7 +830,7 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||||
no_access_token = \
|
no_access_token = \
|
||||||
token_api.access_token_get_by_token(t1['access_token'])
|
token_api.access_token_get_by_token(t1['access_token'])
|
||||||
no_refresh_token = \
|
no_refresh_token = \
|
||||||
auth_api.refresh_token_get(t1['refresh_token'])
|
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||||
|
|
||||||
self.assertIsNone(no_refresh_token)
|
self.assertIsNone(no_refresh_token)
|
||||||
self.assertIsNone(no_access_token)
|
self.assertIsNone(no_access_token)
|
||||||
|
|
|
@ -28,7 +28,6 @@ class TestUserTokensAsUser(base.FunctionalTest):
|
||||||
"""
|
"""
|
||||||
response = self.get_json(self.resource, expect_errors=True)
|
response = self.get_json(self.resource, expect_errors=True)
|
||||||
self.assertEqual(200, response.status_code)
|
self.assertEqual(200, response.status_code)
|
||||||
self.assertEqual(len(response.json), 2)
|
|
||||||
|
|
||||||
def test_unauthorized_browse(self):
|
def test_unauthorized_browse(self):
|
||||||
"""Assert a basic browse for someone else's tokens
|
"""Assert a basic browse for someone else's tokens
|
||||||
|
|
|
@ -13,7 +13,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from storyboard.db.api import auth
|
from storyboard.db.api import auth_codes
|
||||||
from storyboard.db.api import users
|
from storyboard.db.api import users
|
||||||
from storyboard.tests.db import base
|
from storyboard.tests.db import base
|
||||||
|
|
||||||
|
@ -32,15 +32,15 @@ class AuthorizationCodeTest(base.BaseDbTestCase):
|
||||||
users.user_create({"fullname": "Test User"})
|
users.user_create({"fullname": "Test User"})
|
||||||
|
|
||||||
def test_create_code(self):
|
def test_create_code(self):
|
||||||
self._test_create(self.code_01, auth.authorization_code_save)
|
self._test_create(self.code_01, auth_codes.authorization_code_save)
|
||||||
|
|
||||||
def test_delete_code(self):
|
def test_delete_code(self):
|
||||||
created_code = auth.authorization_code_save(self.code_01)
|
created_code = auth_codes.authorization_code_save(self.code_01)
|
||||||
|
|
||||||
self.assertIsNotNone(created_code,
|
self.assertIsNotNone(created_code,
|
||||||
"Could not create an Authorization code")
|
"Could not create an Authorization code")
|
||||||
|
|
||||||
auth.authorization_code_delete(created_code.code)
|
auth_codes.authorization_code_delete(created_code.code)
|
||||||
|
|
||||||
fetched_code = auth.authorization_code_get(created_code.code)
|
fetched_code = auth_codes.authorization_code_get(created_code.code)
|
||||||
self.assertIsNone(fetched_code)
|
self.assertIsNone(fetched_code)
|
||||||
|
|
|
@ -69,31 +69,33 @@ class TestTokenCleaner(db_base.BaseDbTestCase,
|
||||||
# 8-day-old-token to remain valid.
|
# 8-day-old-token to remain valid.
|
||||||
new_access_tokens = []
|
new_access_tokens = []
|
||||||
new_refresh_tokens = []
|
new_refresh_tokens = []
|
||||||
|
|
||||||
for i in range(0, 100):
|
for i in range(0, 100):
|
||||||
created_at = datetime.now(pytz.utc) - timedelta(days=i)
|
created_at = datetime.now(pytz.utc) - timedelta(days=i)
|
||||||
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
|
expires_in = (60 * 60 * 24) - 5 # Minus five seconds, see above.
|
||||||
expires_at = created_at + timedelta(seconds=expires_in)
|
expires_at = created_at + timedelta(seconds=expires_in)
|
||||||
|
|
||||||
new_access_tokens.append(
|
new_access_tokens.append(
|
||||||
AccessToken(
|
AccessToken(
|
||||||
user_id=1,
|
user_id=1,
|
||||||
created_at=created_at,
|
created_at=created_at,
|
||||||
expires_in=expires_in,
|
expires_in=expires_in,
|
||||||
expires_at=expires_at,
|
expires_at=expires_at,
|
||||||
access_token='test_token_%s' % (i,))
|
access_token='test_token_%s' % (i,)),
|
||||||
)
|
)
|
||||||
|
|
||||||
new_access_tokens = load_data(new_access_tokens)
|
new_access_tokens = load_data(new_access_tokens)
|
||||||
|
|
||||||
for token in new_access_tokens:
|
for token in new_access_tokens:
|
||||||
new_refresh_tokens.append(
|
new_refresh_tokens.append(
|
||||||
RefreshToken(
|
RefreshToken(
|
||||||
user_id=1,
|
user_id=1,
|
||||||
access_token_id=token.id,
|
|
||||||
created_at=token.created_at,
|
created_at=token.created_at,
|
||||||
|
expires_in=token.expires_in,
|
||||||
expires_at=token.expires_at,
|
expires_at=token.expires_at,
|
||||||
expires_in=300,
|
access_tokens=[token],
|
||||||
refresh_token='test_refresh_%s' % (token.id,))
|
refresh_token='test_refresh_%s' % (token.id,))
|
||||||
)
|
)
|
||||||
|
|
||||||
new_refresh_tokens = load_data(new_refresh_tokens)
|
new_refresh_tokens = load_data(new_refresh_tokens)
|
||||||
|
|
||||||
# Make sure we have 100 tokens.
|
# Make sure we have 100 tokens.
|
||||||
|
|
Loading…
Reference in New Issue