diff --git a/keystone/common/sql/core.py b/keystone/common/sql/core.py index 694d30bd07..e1bd138e28 100644 --- a/keystone/common/sql/core.py +++ b/keystone/common/sql/core.py @@ -99,10 +99,6 @@ def set_global_engine(engine): GLOBAL_ENGINE = engine -def get_global_engine(): - return GLOBAL_ENGINE - - # Special Fields class JsonBlob(sql_types.TypeDecorator): @@ -244,14 +240,19 @@ class Base(object): return sql.create_engine(CONF.sql.connection, **engine_config) - engine = get_global_engine() or new_engine() + if not allow_global_engine: + return new_engine() + + if GLOBAL_ENGINE: + return GLOBAL_ENGINE + + engine = new_engine() # auto-build the db to support wsgi server w/ in-memory backend - if allow_global_engine and CONF.sql.connection == 'sqlite://': + if CONF.sql.connection == 'sqlite://': ModelBase.metadata.create_all(bind=engine) - if allow_global_engine: - set_global_engine(engine) + set_global_engine(engine) return engine diff --git a/tests/test_sql_core.py b/tests/test_sql_core.py new file mode 100644 index 0000000000..d8f2a4f7b4 --- /dev/null +++ b/tests/test_sql_core.py @@ -0,0 +1,40 @@ +# Copyright 2013 IBM Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + + +from keystone.common import sql +from keystone import test + + +class TestBase(test.TestCase): + + def tearDown(self): + sql.set_global_engine(None) + super(TestBase, self).tearDown() + + def test_get_engine_global(self): + # If call get_engine() twice, get the same global engine. + base = sql.Base() + engine1 = base.get_engine() + self.assertIsNotNone(engine1) + engine2 = base.get_engine() + self.assertIs(engine1, engine2) + + def test_get_engine_not_global(self): + # If call get_engine() twice, once with allow_global_engine=True + # and once with allow_global_engine=False, get different engines. + base = sql.Base() + engine1 = base.get_engine() + engine2 = base.get_engine(allow_global_engine=False) + self.assertIsNot(engine1, engine2)