Session per request
DB session is now saved and reused for all db operations within an http request. Change-Id: If8975bfc7a552c545cf0412f8e0027bec0969234
This commit is contained in:
committed by
Aleksey Ripinen
parent
32bceedb7e
commit
6ea78ee947
@@ -23,6 +23,7 @@ from wsgiref import simple_server
|
||||
|
||||
from storyboard.api import config as api_config
|
||||
from storyboard.api.middleware.cors_middleware import CORSMiddleware
|
||||
from storyboard.api.middleware import session_hook
|
||||
from storyboard.api.middleware import token_middleware
|
||||
from storyboard.api.middleware import user_id_hook
|
||||
from storyboard.api.middleware import validation_hook
|
||||
@@ -85,6 +86,7 @@ def setup_app(pecan_config=None):
|
||||
log.setup(CONF, 'storyboard')
|
||||
|
||||
hooks = [
|
||||
session_hook.DBSessionHook(),
|
||||
user_id_hook.UserIdHook(),
|
||||
validation_hook.ValidationHook()
|
||||
]
|
||||
|
||||
65
storyboard/api/middleware/session_hook.py
Normal file
65
storyboard/api/middleware/session_hook.py
Normal file
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) 2015 Mirantis Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
# implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pecan import hooks
|
||||
from pecan import request
|
||||
from sqlalchemy.exc import InvalidRequestError
|
||||
|
||||
import storyboard.common.hook_priorities as priority
|
||||
from storyboard.db.api import base
|
||||
|
||||
|
||||
class DBSessionHook(hooks.TransactionHook):
|
||||
|
||||
priority = priority.PRE_AUTH
|
||||
|
||||
def _start_session(self):
|
||||
# in_request is False because at this point we need a new session
|
||||
session = base.get_session(autocommit=False, in_request=False)
|
||||
request.session = session
|
||||
|
||||
def _commit_session(self):
|
||||
if hasattr(request, "session"):
|
||||
# Commit the session
|
||||
try:
|
||||
request.session.commit()
|
||||
request.session.flush()
|
||||
except InvalidRequestError:
|
||||
# Session may have got into error state after a rollback was
|
||||
# called for a failed create or update. Skipping.
|
||||
pass
|
||||
|
||||
def _rollback_session(self):
|
||||
if hasattr(request, "session"):
|
||||
try:
|
||||
request.session.rollback()
|
||||
request.session.close()
|
||||
except InvalidRequestError:
|
||||
# There may be no transactions to roll back. Skipping.
|
||||
pass
|
||||
|
||||
def _clear_session(self):
|
||||
if hasattr(request, "session"):
|
||||
request.session.close()
|
||||
|
||||
def is_transactional(self, state):
|
||||
return True
|
||||
|
||||
def __init__(self):
|
||||
super(DBSessionHook, self).__init__(self._start_session,
|
||||
self._start_session,
|
||||
self._commit_session,
|
||||
self._rollback_session,
|
||||
self._clear_session)
|
||||
@@ -84,6 +84,8 @@ class TagsController(rest.RestController):
|
||||
|
||||
for tag in tags:
|
||||
stories_api.story_add_tag(story_id, tag)
|
||||
# For some reason the story gets cached and the tags do not appear.
|
||||
stories_api.api_base.get_session().expunge(story)
|
||||
|
||||
story = stories_api.story_get(story_id)
|
||||
return wmodels.Story.from_db_model(story)
|
||||
|
||||
@@ -17,10 +17,12 @@
|
||||
# information, please see the pecan documentation at
|
||||
# https://github.com/stackforge/pecan/blob/master/pecan/hooks.py
|
||||
|
||||
# Low-level hooks. Example setup db session.
|
||||
PRE_AUTH = 1
|
||||
|
||||
# Authentication must occur relatively early in the hook processing,
|
||||
# as subsequent logic may depend on ACLs.
|
||||
AUTH = 1
|
||||
AUTH = 10
|
||||
|
||||
# Data validation occurs after we've figured out who is making the request,
|
||||
# but before we perform any cleaning on the data. It's there to make sure
|
||||
|
||||
@@ -107,7 +107,7 @@ def access_token_build_query(**kwargs):
|
||||
def access_token_delete(access_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
access_token = access_token_get(access_token_id, session=session)
|
||||
|
||||
if access_token:
|
||||
|
||||
@@ -21,6 +21,7 @@ from oslo.db.sqlalchemy import session as db_session
|
||||
from oslo.db.sqlalchemy.utils import InvalidSortKey
|
||||
from oslo.db.sqlalchemy.utils import paginate_query as utils_paginate_query
|
||||
from oslo_log import log
|
||||
from pecan import request
|
||||
import six
|
||||
import sqlalchemy.types as sqltypes
|
||||
|
||||
@@ -110,13 +111,19 @@ def paginate_query(query, model, limit, sort_key, marker=None,
|
||||
sort_key)
|
||||
|
||||
|
||||
def get_session(autocommit=True, expire_on_commit=False, **kwargs):
|
||||
def get_session(autocommit=True, expire_on_commit=False, in_request=True,
|
||||
**kwargs):
|
||||
"""Returns a database session from our facade.
|
||||
"""
|
||||
facade = _get_facade_instance()
|
||||
try:
|
||||
return facade.get_session(autocommit=autocommit,
|
||||
expire_on_commit=expire_on_commit, **kwargs)
|
||||
if in_request:
|
||||
return request.session
|
||||
else:
|
||||
# Ok, no request, just return a new session
|
||||
return facade.get_session(
|
||||
autocommit=autocommit,
|
||||
expire_on_commit=expire_on_commit, **kwargs)
|
||||
except db_exc.DBConnectionError:
|
||||
raise exc.DBConnectionError()
|
||||
except db_exc.DBDeadlock:
|
||||
@@ -261,8 +268,9 @@ def entity_create(kls, values):
|
||||
session = get_session()
|
||||
|
||||
try:
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
session.add(entity)
|
||||
session.expunge(entity)
|
||||
|
||||
except db_exc.DBDuplicateEntry as de:
|
||||
raise exc.DBDuplicateEntry(object_name=kls.__name__,
|
||||
@@ -288,7 +296,7 @@ def entity_update(kls, entity_id, values):
|
||||
session = get_session()
|
||||
|
||||
try:
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
entity = __entity_get(kls, entity_id, session)
|
||||
if entity is None:
|
||||
raise exc.NotFound(_("%(name)s %(id)s not found") %
|
||||
@@ -298,6 +306,7 @@ def entity_update(kls, entity_id, values):
|
||||
values_copy["id"] = entity_id
|
||||
entity.update(values_copy)
|
||||
session.add(entity)
|
||||
session.expunge(entity)
|
||||
|
||||
except db_exc.DBDuplicateEntry as de:
|
||||
raise exc.DBDuplicateEntry(object_name=kls.__name__,
|
||||
@@ -326,7 +335,7 @@ def entity_hard_delete(kls, entity_id):
|
||||
session = get_session()
|
||||
|
||||
try:
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
query = model_query(kls, session)
|
||||
entity = query.filter_by(id=entity_id).first()
|
||||
if entity is None:
|
||||
|
||||
@@ -63,7 +63,7 @@ def project_group_update(project_group_id, values):
|
||||
def project_group_add_project(project_group_id, project_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
project_group = _entity_get(project_group_id, session)
|
||||
if project_group is None:
|
||||
raise exc.NotFound(_("%(name)s %(id)s not found")
|
||||
@@ -90,7 +90,7 @@ def project_group_add_project(project_group_id, project_id):
|
||||
def project_group_delete_project(project_group_id, project_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
project_group = _entity_get(project_group_id, session)
|
||||
if project_group is None:
|
||||
raise exc.NotFound(_("%(name)s %(id)s not found")
|
||||
|
||||
@@ -55,7 +55,7 @@ def is_valid(refresh_token):
|
||||
def get_access_token_id(refresh_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
refresh_token = refresh_token_get(refresh_token_id, session)
|
||||
|
||||
if refresh_token:
|
||||
@@ -65,7 +65,7 @@ def get_access_token_id(refresh_token_id):
|
||||
def refresh_token_create(access_token_id, values):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
values['expires_at'] = datetime.datetime.now(pytz.utc) + datetime.\
|
||||
timedelta(seconds=values['expires_in'])
|
||||
access_token = access_tokens_api.access_token_get(access_token_id,
|
||||
@@ -97,7 +97,7 @@ def refresh_token_build_query(**kwargs):
|
||||
def refresh_token_delete(refresh_token_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
refresh_token = refresh_token_get(refresh_token_id)
|
||||
|
||||
if refresh_token:
|
||||
|
||||
@@ -174,7 +174,7 @@ def story_update(story_id, values):
|
||||
def story_add_tag(story_id, tag_name):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
|
||||
# Get a tag or create a new one
|
||||
tag = story_tags.tag_get_by_name(tag_name, session=session)
|
||||
@@ -193,12 +193,13 @@ def story_add_tag(story_id, tag_name):
|
||||
|
||||
story.tags.append(tag)
|
||||
session.add(story)
|
||||
session.expunge(story)
|
||||
|
||||
|
||||
def story_remove_tag(story_id, tag_name):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
|
||||
story = story_get_simple(story_id, session=session)
|
||||
if not story:
|
||||
@@ -213,6 +214,7 @@ def story_remove_tag(story_id, tag_name):
|
||||
tag = [t for t in story.tags if t.name == tag_name][0]
|
||||
story.tags.remove(tag)
|
||||
session.add(story)
|
||||
session.expunge(story)
|
||||
|
||||
|
||||
def story_delete(story_id):
|
||||
|
||||
@@ -63,7 +63,7 @@ def team_update(team_id, values):
|
||||
def team_add_user(team_id, user_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
team = _entity_get(team_id, session)
|
||||
if team is None:
|
||||
raise exc.NotFound(_("Team %s not found") % team_id)
|
||||
@@ -86,7 +86,7 @@ def team_add_user(team_id, user_id):
|
||||
def team_delete_user(team_id, user_id):
|
||||
session = api_base.get_session()
|
||||
|
||||
with session.begin():
|
||||
with session.begin(subtransactions=True):
|
||||
team = _entity_get(team_id, session)
|
||||
if team is None:
|
||||
raise exc.NotFound(_("Team %s not found") % team_id)
|
||||
|
||||
@@ -35,7 +35,7 @@ LOG = log.getLogger(__name__)
|
||||
|
||||
def do_load_models(filename):
|
||||
config_file = open(filename)
|
||||
session = db_api.get_session(autocommit=False)
|
||||
session = db_api.get_session(autocommit=False, in_request=False)
|
||||
projects_list = yaml.load(config_file)
|
||||
|
||||
project_groups = list()
|
||||
|
||||
@@ -30,7 +30,7 @@ def do_load_models(filename):
|
||||
config_file = open(filename)
|
||||
superusers_list = yaml.load(config_file)
|
||||
|
||||
session = db_api.get_session()
|
||||
session = db_api.get_session(in_request=False)
|
||||
|
||||
with session.begin():
|
||||
for user in superusers_list:
|
||||
|
||||
@@ -389,7 +389,9 @@ class TestOAuthAuthorizeReturn(BaseOAuthTest):
|
||||
location = response.headers.get('Location')
|
||||
location_url = urlparse.urlparse(location)
|
||||
parameters = urlparse.parse_qs(location_url[4])
|
||||
token = auth_api.authorization_code_get(parameters['code'])
|
||||
|
||||
with base.HybridSessionManager():
|
||||
token = auth_api.authorization_code_get(parameters['code'])
|
||||
|
||||
# Validate the redirect response
|
||||
self.assertValidRedirect(response=response,
|
||||
@@ -543,11 +545,12 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
"""
|
||||
|
||||
# Generate a valid auth token
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code'
|
||||
})
|
||||
with base.HybridSessionManager():
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code'
|
||||
})
|
||||
|
||||
# POST with content: application/x-www-form-urlencoded
|
||||
response = self.app.post('/v1/openid/token',
|
||||
@@ -572,8 +575,9 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
self.assertEqual('Bearer', token['token_type'])
|
||||
|
||||
# Assert that the access token is in the database
|
||||
access_token = \
|
||||
token_api.access_token_get_by_token(token['access_token'])
|
||||
with base.HybridSessionManager():
|
||||
access_token = \
|
||||
token_api.access_token_get_by_token(token['access_token'])
|
||||
self.assertIsNotNone(access_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
@@ -584,8 +588,11 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
self.assertEqual(token['access_token'], access_token.access_token)
|
||||
|
||||
# Assert that the refresh token is in the database
|
||||
refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(token['refresh_token'])
|
||||
with base.HybridSessionManager():
|
||||
refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(
|
||||
token['refresh_token'])
|
||||
|
||||
self.assertIsNotNone(refresh_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
@@ -595,9 +602,10 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
self.assertEqual(token['refresh_token'], refresh_token.refresh_token)
|
||||
|
||||
# Assert that the authorization code is no longer in the database.
|
||||
self.assertIsNone(auth_api.authorization_code_get(
|
||||
authorization_code.code
|
||||
))
|
||||
with base.HybridSessionManager():
|
||||
none_code = \
|
||||
auth_api.authorization_code_get(authorization_code.code)
|
||||
self.assertIsNone(none_code)
|
||||
|
||||
def test_valid_access_token_time(self):
|
||||
"""Assert that a newly created access token is valid if storyboard is
|
||||
@@ -616,14 +624,14 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
os.environ['TZ'] = name
|
||||
|
||||
# Create a token.
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300
|
||||
})
|
||||
with base.HybridSessionManager():
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300
|
||||
})
|
||||
|
||||
# POST with content: application/x-www-form-urlencoded
|
||||
response = self.app.post('/v1/openid/token',
|
||||
params={
|
||||
'code': authorization_code.code,
|
||||
@@ -662,13 +670,14 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
os.environ['TZ'] = name
|
||||
|
||||
# Create a token.
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300,
|
||||
'created_at': expired
|
||||
})
|
||||
with base.HybridSessionManager():
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300,
|
||||
'created_at': expired
|
||||
})
|
||||
|
||||
# POST with content: application/x-www-form-urlencoded
|
||||
response = self.app.post('/v1/openid/token',
|
||||
@@ -695,12 +704,13 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
"""
|
||||
|
||||
# Generate a valid auth token
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300
|
||||
})
|
||||
with base.HybridSessionManager():
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code',
|
||||
'expires_in': 300
|
||||
})
|
||||
|
||||
# POST with content: application/x-www-form-urlencoded
|
||||
response = self.app.post('/v1/openid/token',
|
||||
@@ -747,11 +757,12 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
"""
|
||||
|
||||
# Generate a valid access code
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code'
|
||||
})
|
||||
with base.HybridSessionManager():
|
||||
authorization_code = auth_api.authorization_code_save({
|
||||
'user_id': 2,
|
||||
'state': 'test_state',
|
||||
'code': 'test_valid_code'
|
||||
})
|
||||
|
||||
# Generate an auth and a refresh token.
|
||||
resp_1 = self.app.post('/v1/openid/token',
|
||||
@@ -770,11 +781,15 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
t1 = resp_1.json
|
||||
|
||||
# Assert that both are in the database.
|
||||
access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
with base.HybridSessionManager():
|
||||
access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
self.assertIsNotNone(access_token)
|
||||
refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||
|
||||
with base.HybridSessionManager():
|
||||
refresh_token = refresh_tokens.refresh_token_get_by_token(
|
||||
t1['refresh_token'])
|
||||
|
||||
self.assertIsNotNone(refresh_token)
|
||||
|
||||
# Issue a refresh token request.
|
||||
@@ -801,8 +816,9 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
self.assertEqual('Bearer', t2['token_type'])
|
||||
|
||||
# Assert that the access token is in the database
|
||||
new_access_token = \
|
||||
token_api.access_token_get_by_token(t2['access_token'])
|
||||
with base.HybridSessionManager():
|
||||
new_access_token = \
|
||||
token_api.access_token_get_by_token(t2['access_token'])
|
||||
self.assertIsNotNone(new_access_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
@@ -814,8 +830,11 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
new_access_token.access_token)
|
||||
|
||||
# Assert that the refresh token is in the database
|
||||
new_refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(t2['refresh_token'])
|
||||
|
||||
with base.HybridSessionManager():
|
||||
new_refresh_token = refresh_tokens.refresh_token_get_by_token(
|
||||
t2['refresh_token'])
|
||||
|
||||
self.assertIsNotNone(new_refresh_token)
|
||||
|
||||
# Assert that system configured values is owned by the correct user.
|
||||
@@ -827,10 +846,13 @@ class TestOAuthAccessToken(BaseOAuthTest):
|
||||
|
||||
# Assert that the old access tokens are no longer in the database and
|
||||
# have been cleaned up.
|
||||
no_access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
no_refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||
|
||||
with base.HybridSessionManager():
|
||||
no_access_token = \
|
||||
token_api.access_token_get_by_token(t1['access_token'])
|
||||
with base.HybridSessionManager():
|
||||
no_refresh_token = \
|
||||
refresh_tokens.refresh_token_get_by_token(t1['refresh_token'])
|
||||
|
||||
self.assertIsNone(no_refresh_token)
|
||||
self.assertIsNone(no_access_token)
|
||||
|
||||
@@ -170,6 +170,22 @@ class DbTestCase(WorkingDirTestCase):
|
||||
PATH_PREFIX = '/v1'
|
||||
|
||||
|
||||
class HybridSessionManager(object):
|
||||
|
||||
def _mock_get_session(self, autocommit=True, expire_on_commit=False,
|
||||
in_request=True, **kwargs):
|
||||
return self.original_get_session(autocommit=autocommit,
|
||||
expire_on_commit=expire_on_commit,
|
||||
in_request=False, **kwargs)
|
||||
|
||||
def __enter__(self):
|
||||
self.original_get_session = db_api_base.get_session
|
||||
db_api_base.get_session = self._mock_get_session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
db_api_base.get_session = self.original_get_session
|
||||
|
||||
|
||||
class FunctionalTest(DbTestCase):
|
||||
"""Used for functional tests of Pecan controllers where you need to
|
||||
test your literal application and its integration with the
|
||||
|
||||
@@ -24,6 +24,11 @@ class TestHookPriority(base.TestCase):
|
||||
def test_hook_order(self):
|
||||
"""Assert that the hook priorities are ordered properly."""
|
||||
|
||||
self.assertLess(priority.PRE_AUTH, priority.AUTH)
|
||||
self.assertLess(priority.PRE_AUTH, priority.VALIDATION)
|
||||
self.assertLess(priority.PRE_AUTH, priority.POST_VALIDATION)
|
||||
self.assertLess(priority.PRE_AUTH, priority.DEFAULT)
|
||||
|
||||
self.assertLess(priority.AUTH, priority.VALIDATION)
|
||||
self.assertLess(priority.AUTH, priority.POST_VALIDATION)
|
||||
self.assertLess(priority.AUTH, priority.DEFAULT)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from storyboard.db.api import base as db_base
|
||||
from storyboard.tests import base
|
||||
from storyboard.tests import mock_data
|
||||
|
||||
@@ -20,8 +21,22 @@ from storyboard.tests import mock_data
|
||||
class BaseDbTestCase(base.DbTestCase):
|
||||
def setUp(self):
|
||||
super(BaseDbTestCase, self).setUp()
|
||||
|
||||
self.original_get_session = db_base.get_session
|
||||
self.addCleanup(self._reset_get_session)
|
||||
db_base.get_session = self._mock_get_session
|
||||
|
||||
mock_data.load()
|
||||
|
||||
def _mock_get_session(self, autocommit=True, expire_on_commit=False,
|
||||
in_request=True, **kwargs):
|
||||
return self.original_get_session(autocommit=autocommit,
|
||||
expire_on_commit=expire_on_commit,
|
||||
in_request=False, **kwargs)
|
||||
|
||||
def _reset_get_session(self):
|
||||
db_base.get_session = self.original_get_session
|
||||
|
||||
def _assert_saved_fields(self, expected, actual):
|
||||
for k in expected.keys():
|
||||
self.assertEqual(expected[k], actual[k])
|
||||
|
||||
@@ -317,7 +317,7 @@ def load_data(data):
|
||||
|
||||
:param data An iterable collection of database models.
|
||||
"""
|
||||
session = db.get_session(autocommit=False)
|
||||
session = db.get_session(autocommit=False, in_request=False)
|
||||
|
||||
for entity in data:
|
||||
session.add(entity)
|
||||
|
||||
Reference in New Issue
Block a user