diff --git a/keystone/common/sql/migration_helpers.py b/keystone/common/sql/migration_helpers.py index b6a68091e8..8c33989456 100644 --- a/keystone/common/sql/migration_helpers.py +++ b/keystone/common/sql/migration_helpers.py @@ -130,10 +130,11 @@ def find_migrate_repo(package=None, repo_name='migrate_repo'): def _sync_common_repo(version): abs_path = find_migrate_repo() init_version = migrate_repo.DB_INIT_VERSION - engine = sql.get_engine() - _assert_not_schema_downgrade(version=version) - migration.db_sync(engine, abs_path, version=version, - init_version=init_version, sanity_check=False) + with sql.session_for_write() as session: + engine = session.get_bind() + _assert_not_schema_downgrade(version=version) + migration.db_sync(engine, abs_path, version=version, + init_version=init_version, sanity_check=False) def _assert_not_schema_downgrade(extension=None, version=None): @@ -154,31 +155,32 @@ def _sync_extension_repo(extension, version): raise exception.MigrationMovedFailure(extension=extension) init_version = 0 - engine = sql.get_engine() + with sql.session_for_write() as session: + engine = session.get_bind() - try: - package_name = '.'.join((contrib.__name__, extension)) - package = importutils.import_module(package_name) - except ImportError: - raise ImportError(_("%s extension does not exist.") - % package_name) - try: - abs_path = find_migrate_repo(package) try: - migration.db_version_control(engine, abs_path) - # Register the repo with the version control API - # If it already knows about the repo, it will throw - # an exception that we can safely ignore - except exceptions.DatabaseAlreadyControlledError: # nosec - pass - except exception.MigrationNotProvided as e: - print(e) - sys.exit(1) + package_name = '.'.join((contrib.__name__, extension)) + package = importutils.import_module(package_name) + except ImportError: + raise ImportError(_("%s extension does not exist.") + % package_name) + try: + abs_path = find_migrate_repo(package) + try: + migration.db_version_control(engine, abs_path) + # Register the repo with the version control API + # If it already knows about the repo, it will throw + # an exception that we can safely ignore + except exceptions.DatabaseAlreadyControlledError: # nosec + pass + except exception.MigrationNotProvided as e: + print(e) + sys.exit(1) - _assert_not_schema_downgrade(extension=extension, version=version) + _assert_not_schema_downgrade(extension=extension, version=version) - migration.db_sync(engine, abs_path, version=version, - init_version=init_version, sanity_check=False) + migration.db_sync(engine, abs_path, version=version, + init_version=init_version, sanity_check=False) def sync_database_to_version(extension=None, version=None): @@ -195,8 +197,10 @@ def sync_database_to_version(extension=None, version=None): def get_db_version(extension=None): if not extension: - return migration.db_version(sql.get_engine(), find_migrate_repo(), - migrate_repo.DB_INIT_VERSION) + with sql.session_for_write() as session: + return migration.db_version(session.get_bind(), + find_migrate_repo(), + migrate_repo.DB_INIT_VERSION) try: package_name = '.'.join((contrib.__name__, extension)) @@ -205,8 +209,9 @@ def get_db_version(extension=None): raise ImportError(_("%s extension does not exist.") % package_name) - return migration.db_version( - sql.get_engine(), find_migrate_repo(package), 0) + with sql.session_for_write() as session: + return migration.db_version( + session.get_bind(), find_migrate_repo(package), 0) def print_db_version(extension=None): diff --git a/keystone/tests/unit/ksfixtures/database.py b/keystone/tests/unit/ksfixtures/database.py index 5aec38c2ca..52c35cee87 100644 --- a/keystone/tests/unit/ksfixtures/database.py +++ b/keystone/tests/unit/ksfixtures/database.py @@ -148,7 +148,8 @@ class Database(fixtures.Fixture): def setUp(self): super(Database, self).setUp() - self.engine = sql.get_engine() + with sql.session_for_write() as session: + self.engine = session.get_bind() self.addCleanup(sql.cleanup) sql.ModelBase.metadata.create_all(bind=self.engine) self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine) diff --git a/keystone/tests/unit/test_backend_sql.py b/keystone/tests/unit/test_backend_sql.py index c42e450701..d5d79ac483 100644 --- a/keystone/tests/unit/test_backend_sql.py +++ b/keystone/tests/unit/test_backend_sql.py @@ -184,19 +184,21 @@ class SqlModels(SqlTests): class SqlIdentity(SqlTests, test_backend.IdentityTests): def test_password_hashed(self): - session = sql.get_session() - user_ref = self.identity_api._get_user(session, self.user_foo['id']) - self.assertNotEqual(self.user_foo['password'], user_ref['password']) + with sql.session_for_read() as session: + user_ref = self.identity_api._get_user(session, + self.user_foo['id']) + self.assertNotEqual(self.user_foo['password'], + user_ref['password']) def test_create_user_with_null_password(self): user_dict = unit.new_user_ref( domain_id=CONF.identity.default_domain_id) user_dict["password"] = None new_user_dict = self.identity_api.create_user(user_dict) - session = sql.get_session() - new_user_ref = self.identity_api._get_user(session, - new_user_dict['id']) - self.assertFalse(new_user_ref.local_user.passwords) + with sql.session_for_read() as session: + new_user_ref = self.identity_api._get_user(session, + new_user_dict['id']) + self.assertFalse(new_user_ref.local_user.passwords) def test_update_user_with_null_password(self): user_dict = unit.new_user_ref( @@ -206,10 +208,10 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests): new_user_dict["password"] = None new_user_dict = self.identity_api.update_user(new_user_dict['id'], new_user_dict) - session = sql.get_session() - new_user_ref = self.identity_api._get_user(session, - new_user_dict['id']) - self.assertFalse(new_user_ref.local_user.passwords) + with sql.session_for_read() as session: + new_user_ref = self.identity_api._get_user(session, + new_user_dict['id']) + self.assertFalse(new_user_ref.local_user.passwords) def test_delete_user_with_project_association(self): user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id) @@ -339,14 +341,14 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests): def test_sql_user_to_dict_null_default_project_id(self): user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id) user = self.identity_api.create_user(user) - session = sql.get_session() - query = session.query(identity_sql.User) - query = query.filter_by(id=user['id']) - raw_user_ref = query.one() - self.assertIsNone(raw_user_ref.default_project_id) - user_ref = raw_user_ref.to_dict() - self.assertNotIn('default_project_id', user_ref) - session.close() + with sql.session_for_read() as session: + query = session.query(identity_sql.User) + query = query.filter_by(id=user['id']) + raw_user_ref = query.one() + self.assertIsNone(raw_user_ref.default_project_id) + user_ref = raw_user_ref.to_dict() + self.assertNotIn('default_project_id', user_ref) + session.close() def test_list_domains_for_user(self): domain = unit.new_domain_ref() diff --git a/keystone/tests/unit/test_sql_upgrade.py b/keystone/tests/unit/test_sql_upgrade.py index dea3a5db0f..afc8e23184 100644 --- a/keystone/tests/unit/test_sql_upgrade.py +++ b/keystone/tests/unit/test_sql_upgrade.py @@ -151,7 +151,9 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase): connection='sqlite:///%s' % db_file) # create and share a single sqlalchemy engine for testing - self.engine = sql.get_engine() + with sql.session_for_write() as session: + self.engine = session.get_bind() + self.addCleanup(self.cleanup_instance('engine')) self.Session = db_session.get_maker(self.engine, autocommit=False) self.addCleanup(sqlalchemy.orm.session.Session.close_all) @@ -284,8 +286,9 @@ class SqlUpgradeTests(SqlMigrateBase): self.assertTableDoesNotExist('user') def test_start_version_db_init_version(self): - version = migration.db_version(sql.get_engine(), self.repo_path, - migrate_repo.DB_INIT_VERSION) + with sql.session_for_write() as session: + version = migration.db_version(session.get_bind(), self.repo_path, + migrate_repo.DB_INIT_VERSION) self.assertEqual( migrate_repo.DB_INIT_VERSION, version,