Merge "Base.get_engine honor allow_global_engine=False"

This commit is contained in:
Jenkins 2013-06-22 01:32:57 +00:00 committed by Gerrit Code Review
commit 97225ffc03
2 changed files with 49 additions and 8 deletions

View File

@ -99,10 +99,6 @@ def set_global_engine(engine):
GLOBAL_ENGINE = engine
def get_global_engine():
return GLOBAL_ENGINE
# Special Fields
class JsonBlob(sql_types.TypeDecorator):
@ -244,14 +240,19 @@ class Base(object):
return sql.create_engine(CONF.sql.connection, **engine_config)
engine = get_global_engine() or new_engine()
if not allow_global_engine:
return new_engine()
if GLOBAL_ENGINE:
return GLOBAL_ENGINE
engine = new_engine()
# auto-build the db to support wsgi server w/ in-memory backend
if allow_global_engine and CONF.sql.connection == 'sqlite://':
if CONF.sql.connection == 'sqlite://':
ModelBase.metadata.create_all(bind=engine)
if allow_global_engine:
set_global_engine(engine)
set_global_engine(engine)
return engine

40
tests/test_sql_core.py Normal file
View File

@ -0,0 +1,40 @@
# Copyright 2013 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from keystone.common import sql
from keystone import test
class TestBase(test.TestCase):
def tearDown(self):
sql.set_global_engine(None)
super(TestBase, self).tearDown()
def test_get_engine_global(self):
# If call get_engine() twice, get the same global engine.
base = sql.Base()
engine1 = base.get_engine()
self.assertIsNotNone(engine1)
engine2 = base.get_engine()
self.assertIs(engine1, engine2)
def test_get_engine_not_global(self):
# If call get_engine() twice, once with allow_global_engine=True
# and once with allow_global_engine=False, get different engines.
base = sql.Base()
engine1 = base.get_engine()
engine2 = base.get_engine(allow_global_engine=False)
self.assertIsNot(engine1, engine2)