Session per request

DB session is now saved and reused for all db operations within an http
request.

Change-Id: If8975bfc7a552c545cf0412f8e0027bec0969234
This commit is contained in:
Nikita Konovalov
2015-02-12 14:36:35 +03:00
committed by Aleksey Ripinen
parent 32bceedb7e
commit 6ea78ee947
17 changed files with 210 additions and 70 deletions

View File

@@ -23,6 +23,7 @@ from wsgiref import simple_server
from storyboard.api import config as api_config
from storyboard.api.middleware.cors_middleware import CORSMiddleware
from storyboard.api.middleware import session_hook
from storyboard.api.middleware import token_middleware
from storyboard.api.middleware import user_id_hook
from storyboard.api.middleware import validation_hook
@@ -85,6 +86,7 @@ def setup_app(pecan_config=None):
log.setup(CONF, 'storyboard')
hooks = [
session_hook.DBSessionHook(),
user_id_hook.UserIdHook(),
validation_hook.ValidationHook()
]

View File

@@ -0,0 +1,65 @@
# Copyright (c) 2015 Mirantis Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pecan import hooks
from pecan import request
from sqlalchemy.exc import InvalidRequestError
import storyboard.common.hook_priorities as priority
from storyboard.db.api import base
class DBSessionHook(hooks.TransactionHook):
priority = priority.PRE_AUTH
def _start_session(self):
# in_request is False because at this point we need a new session
session = base.get_session(autocommit=False, in_request=False)
request.session = session
def _commit_session(self):
if hasattr(request, "session"):
# Commit the session
try:
request.session.commit()
request.session.flush()
except InvalidRequestError:
# Session may have got into error state after a rollback was
# called for a failed create or update. Skipping.
pass
def _rollback_session(self):
if hasattr(request, "session"):
try:
request.session.rollback()
request.session.close()
except InvalidRequestError:
# There may be no transactions to roll back. Skipping.
pass
def _clear_session(self):
if hasattr(request, "session"):
request.session.close()
def is_transactional(self, state):
return True
def __init__(self):
super(DBSessionHook, self).__init__(self._start_session,
self._start_session,
self._commit_session,
self._rollback_session,
self._clear_session)

View File

@@ -84,6 +84,8 @@ class TagsController(rest.RestController):
for tag in tags:
stories_api.story_add_tag(story_id, tag)
# For some reason the story gets cached and the tags do not appear.
stories_api.api_base.get_session().expunge(story)
story = stories_api.story_get(story_id)
return wmodels.Story.from_db_model(story)

View File

@@ -17,10 +17,12 @@
# information, please see the pecan documentation at
# https://github.com/stackforge/pecan/blob/master/pecan/hooks.py
# Low-level hooks. Example setup db session.
PRE_AUTH = 1
# Authentication must occur relatively early in the hook processing,
# as subsequent logic may depend on ACLs.
AUTH = 1
AUTH = 10
# Data validation occurs after we've figured out who is making the request,
# but before we perform any cleaning on the data. It's there to make sure

View File

@@ -107,7 +107,7 @@ def access_token_build_query(**kwargs):
def access_token_delete(access_token_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
access_token = access_token_get(access_token_id, session=session)
if access_token:

View File

@@ -21,6 +21,7 @@ from oslo.db.sqlalchemy import session as db_session
from oslo.db.sqlalchemy.utils import InvalidSortKey
from oslo.db.sqlalchemy.utils import paginate_query as utils_paginate_query
from oslo_log import log
from pecan import request
import six
import sqlalchemy.types as sqltypes
@@ -110,13 +111,19 @@ def paginate_query(query, model, limit, sort_key, marker=None,
sort_key)
def get_session(autocommit=True, expire_on_commit=False, **kwargs):
def get_session(autocommit=True, expire_on_commit=False, in_request=True,
**kwargs):
"""Returns a database session from our facade.
"""
facade = _get_facade_instance()
try:
return facade.get_session(autocommit=autocommit,
expire_on_commit=expire_on_commit, **kwargs)
if in_request:
return request.session
else:
# Ok, no request, just return a new session
return facade.get_session(
autocommit=autocommit,
expire_on_commit=expire_on_commit, **kwargs)
except db_exc.DBConnectionError:
raise exc.DBConnectionError()
except db_exc.DBDeadlock:
@@ -261,8 +268,9 @@ def entity_create(kls, values):
session = get_session()
try:
with session.begin():
with session.begin(subtransactions=True):
session.add(entity)
session.expunge(entity)
except db_exc.DBDuplicateEntry as de:
raise exc.DBDuplicateEntry(object_name=kls.__name__,
@@ -288,7 +296,7 @@ def entity_update(kls, entity_id, values):
session = get_session()
try:
with session.begin():
with session.begin(subtransactions=True):
entity = __entity_get(kls, entity_id, session)
if entity is None:
raise exc.NotFound(_("%(name)s %(id)s not found") %
@@ -298,6 +306,7 @@ def entity_update(kls, entity_id, values):
values_copy["id"] = entity_id
entity.update(values_copy)
session.add(entity)
session.expunge(entity)
except db_exc.DBDuplicateEntry as de:
raise exc.DBDuplicateEntry(object_name=kls.__name__,
@@ -326,7 +335,7 @@ def entity_hard_delete(kls, entity_id):
session = get_session()
try:
with session.begin():
with session.begin(subtransactions=True):
query = model_query(kls, session)
entity = query.filter_by(id=entity_id).first()
if entity is None:

View File

@@ -63,7 +63,7 @@ def project_group_update(project_group_id, values):
def project_group_add_project(project_group_id, project_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
project_group = _entity_get(project_group_id, session)
if project_group is None:
raise exc.NotFound(_("%(name)s %(id)s not found")
@@ -90,7 +90,7 @@ def project_group_add_project(project_group_id, project_id):
def project_group_delete_project(project_group_id, project_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
project_group = _entity_get(project_group_id, session)
if project_group is None:
raise exc.NotFound(_("%(name)s %(id)s not found")

View File

@@ -55,7 +55,7 @@ def is_valid(refresh_token):
def get_access_token_id(refresh_token_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
refresh_token = refresh_token_get(refresh_token_id, session)
if refresh_token:
@@ -65,7 +65,7 @@ def get_access_token_id(refresh_token_id):
def refresh_token_create(access_token_id, values):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
timedelta(seconds=values['expires_in'])
access_token = access_tokens_api.access_token_get(access_token_id,
@@ -97,7 +97,7 @@ def refresh_token_build_query(**kwargs):
def refresh_token_delete(refresh_token_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
refresh_token = refresh_token_get(refresh_token_id)
if refresh_token:

View File

@@ -174,7 +174,7 @@ def story_update(story_id, values):
def story_add_tag(story_id, tag_name):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
# Get a tag or create a new one
tag = story_tags.tag_get_by_name(tag_name, session=session)
@@ -193,12 +193,13 @@ def story_add_tag(story_id, tag_name):
story.tags.append(tag)
session.add(story)
session.expunge(story)
def story_remove_tag(story_id, tag_name):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
story = story_get_simple(story_id, session=session)
if not story:
@@ -213,6 +214,7 @@ def story_remove_tag(story_id, tag_name):
tag = [t for t in story.tags if t.name == tag_name][0]
story.tags.remove(tag)
session.add(story)
session.expunge(story)
def story_delete(story_id):

View File

@@ -63,7 +63,7 @@ def team_update(team_id, values):
def team_add_user(team_id, user_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
team = _entity_get(team_id, session)
if team is None:
raise exc.NotFound(_("Team %s not found") % team_id)
@@ -86,7 +86,7 @@ def team_add_user(team_id, user_id):
def team_delete_user(team_id, user_id):
session = api_base.get_session()
with session.begin():
with session.begin(subtransactions=True):
team = _entity_get(team_id, session)
if team is None:
raise exc.NotFound(_("Team %s not found") % team_id)

View File

@@ -35,7 +35,7 @@ LOG = log.getLogger(__name__)
def do_load_models(filename):
config_file = open(filename)
session = db_api.get_session(autocommit=False)
session = db_api.get_session(autocommit=False, in_request=False)
projects_list = yaml.load(config_file)
project_groups = list()

View File

@@ -30,7 +30,7 @@ def do_load_models(filename):
config_file = open(filename)
superusers_list = yaml.load(config_file)
session = db_api.get_session()
session = db_api.get_session(in_request=False)
with session.begin():
for user in superusers_list:

View File

@@ -389,7 +389,9 @@ class TestOAuthAuthorizeReturn(BaseOAuthTest):
location = response.headers.get('Location')
location_url = urlparse.urlparse(location)
parameters = urlparse.parse_qs(location_url[4])
token = auth_api.authorization_code_get(parameters['code'])
with base.HybridSessionManager():
token = auth_api.authorization_code_get(parameters['code'])
# Validate the redirect response
self.assertValidRedirect(response=response,
@@ -543,11 +545,12 @@ class TestOAuthAccessToken(BaseOAuthTest):
"""
# Generate a valid auth token
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code'
})
with base.HybridSessionManager():
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code'
})
# POST with content: application/x-www-form-urlencoded
response = self.app.post('/v1/openid/token',
@@ -572,8 +575,9 @@ class TestOAuthAccessToken(BaseOAuthTest):
self.assertEqual('Bearer', token['token_type'])
# Assert that the access token is in the database
access_token = \
token_api.access_token_get_by_token(token['access_token'])
with base.HybridSessionManager():
access_token = \
token_api.access_token_get_by_token(token['access_token'])
self.assertIsNotNone(access_token)
# Assert that system configured values is owned by the correct user.
@@ -584,8 +588,11 @@ class TestOAuthAccessToken(BaseOAuthTest):
self.assertEqual(token['access_token'], access_token.access_token)
# Assert that the refresh token is in the database
refresh_token = \
refresh_tokens.refresh_token_get_by_token(token['refresh_token'])
with base.HybridSessionManager():
refresh_token = \
refresh_tokens.refresh_token_get_by_token(
token['refresh_token'])
self.assertIsNotNone(refresh_token)
# Assert that system configured values is owned by the correct user.
@@ -595,9 +602,10 @@ class TestOAuthAccessToken(BaseOAuthTest):
self.assertEqual(token['refresh_token'], refresh_token.refresh_token)
# Assert that the authorization code is no longer in the database.
self.assertIsNone(auth_api.authorization_code_get(
authorization_code.code
))
with base.HybridSessionManager():
none_code = \
auth_api.authorization_code_get(authorization_code.code)
self.assertIsNone(none_code)
def test_valid_access_token_time(self):
"""Assert that a newly created access token is valid if storyboard is
@@ -616,14 +624,14 @@ class TestOAuthAccessToken(BaseOAuthTest):
os.environ['TZ'] = name
# Create a token.
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300
})
with base.HybridSessionManager():
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300
})
# POST with content: application/x-www-form-urlencoded
response = self.app.post('/v1/openid/token',
params={
'code': authorization_code.code,
@@ -662,13 +670,14 @@ class TestOAuthAccessToken(BaseOAuthTest):
os.environ['TZ'] = name
# Create a token.
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300,
'created_at': expired
})
with base.HybridSessionManager():
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300,
'created_at': expired
})
# POST with content: application/x-www-form-urlencoded
response = self.app.post('/v1/openid/token',
@@ -695,12 +704,13 @@ class TestOAuthAccessToken(BaseOAuthTest):
"""
# Generate a valid auth token
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300
})
with base.HybridSessionManager():
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code',
'expires_in': 300
})
# POST with content: application/x-www-form-urlencoded
response = self.app.post('/v1/openid/token',
@@ -747,11 +757,12 @@ class TestOAuthAccessToken(BaseOAuthTest):
"""
# Generate a valid access code
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code'
})
with base.HybridSessionManager():
authorization_code = auth_api.authorization_code_save({
'user_id': 2,
'state': 'test_state',
'code': 'test_valid_code'
})
# Generate an auth and a refresh token.
resp_1 = self.app.post('/v1/openid/token',
@@ -770,11 +781,15 @@ class TestOAuthAccessToken(BaseOAuthTest):
t1 = resp_1.json
# Assert that both are in the database.
access_token = \
token_api.access_token_get_by_token(t1['access_token'])
with base.HybridSessionManager():
access_token = \
token_api.access_token_get_by_token(t1['access_token'])
self.assertIsNotNone(access_token)
refresh_token = \
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
with base.HybridSessionManager():
refresh_token = refresh_tokens.refresh_token_get_by_token(
t1['refresh_token'])
self.assertIsNotNone(refresh_token)
# Issue a refresh token request.
@@ -801,8 +816,9 @@ class TestOAuthAccessToken(BaseOAuthTest):
self.assertEqual('Bearer', t2['token_type'])
# Assert that the access token is in the database
new_access_token = \
token_api.access_token_get_by_token(t2['access_token'])
with base.HybridSessionManager():
new_access_token = \
token_api.access_token_get_by_token(t2['access_token'])
self.assertIsNotNone(new_access_token)
# Assert that system configured values is owned by the correct user.
@@ -814,8 +830,11 @@ class TestOAuthAccessToken(BaseOAuthTest):
new_access_token.access_token)
# Assert that the refresh token is in the database
new_refresh_token = \
refresh_tokens.refresh_token_get_by_token(t2['refresh_token'])
with base.HybridSessionManager():
new_refresh_token = refresh_tokens.refresh_token_get_by_token(
t2['refresh_token'])
self.assertIsNotNone(new_refresh_token)
# Assert that system configured values is owned by the correct user.
@@ -827,10 +846,13 @@ class TestOAuthAccessToken(BaseOAuthTest):
# Assert that the old access tokens are no longer in the database and
# have been cleaned up.
no_access_token = \
token_api.access_token_get_by_token(t1['access_token'])
no_refresh_token = \
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
with base.HybridSessionManager():
no_access_token = \
token_api.access_token_get_by_token(t1['access_token'])
with base.HybridSessionManager():
no_refresh_token = \
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
self.assertIsNone(no_refresh_token)
self.assertIsNone(no_access_token)

View File

@@ -170,6 +170,22 @@ class DbTestCase(WorkingDirTestCase):
PATH_PREFIX = '/v1'
class HybridSessionManager(object):
def _mock_get_session(self, autocommit=True, expire_on_commit=False,
in_request=True, **kwargs):
return self.original_get_session(autocommit=autocommit,
expire_on_commit=expire_on_commit,
in_request=False, **kwargs)
def __enter__(self):
self.original_get_session = db_api_base.get_session
db_api_base.get_session = self._mock_get_session
def __exit__(self, exc_type, exc_val, exc_tb):
db_api_base.get_session = self.original_get_session
class FunctionalTest(DbTestCase):
"""Used for functional tests of Pecan controllers where you need to
test your literal application and its integration with the

View File

@@ -24,6 +24,11 @@ class TestHookPriority(base.TestCase):
def test_hook_order(self):
"""Assert that the hook priorities are ordered properly."""
self.assertLess(priority.PRE_AUTH, priority.AUTH)
self.assertLess(priority.PRE_AUTH, priority.VALIDATION)
self.assertLess(priority.PRE_AUTH, priority.POST_VALIDATION)
self.assertLess(priority.PRE_AUTH, priority.DEFAULT)
self.assertLess(priority.AUTH, priority.VALIDATION)
self.assertLess(priority.AUTH, priority.POST_VALIDATION)
self.assertLess(priority.AUTH, priority.DEFAULT)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from storyboard.db.api import base as db_base
from storyboard.tests import base
from storyboard.tests import mock_data
@@ -20,8 +21,22 @@ from storyboard.tests import mock_data
class BaseDbTestCase(base.DbTestCase):
def setUp(self):
super(BaseDbTestCase, self).setUp()
self.original_get_session = db_base.get_session
self.addCleanup(self._reset_get_session)
db_base.get_session = self._mock_get_session
mock_data.load()
def _mock_get_session(self, autocommit=True, expire_on_commit=False,
in_request=True, **kwargs):
return self.original_get_session(autocommit=autocommit,
expire_on_commit=expire_on_commit,
in_request=False, **kwargs)
def _reset_get_session(self):
db_base.get_session = self.original_get_session
def _assert_saved_fields(self, expected, actual):
for k in expected.keys():
self.assertEqual(expected[k], actual[k])

View File

@@ -317,7 +317,7 @@ def load_data(data):
:param data An iterable collection of database models.
"""
session = db.get_session(autocommit=False)
session = db.get_session(autocommit=False, in_request=False)
for entity in data:
session.add(entity)