No more legacy engine facade in tests

Do not use the legacy engine facade from oslo_db even in testing.

Change-Id: I70d0724679a5fd5cd3ace54e4ae44957c2277daf
Closes-Bug: #1490571
This commit is contained in:
Morgan Fainberg 2016-02-24 19:29:20 -08:00
parent 2707f8ffb0
commit 8f432e9b44
4 changed files with 63 additions and 52 deletions

View File

@ -130,10 +130,11 @@ def find_migrate_repo(package=None, repo_name='migrate_repo'):
def _sync_common_repo(version):
abs_path = find_migrate_repo()
init_version = migrate_repo.DB_INIT_VERSION
engine = sql.get_engine()
_assert_not_schema_downgrade(version=version)
migration.db_sync(engine, abs_path, version=version,
init_version=init_version, sanity_check=False)
with sql.session_for_write() as session:
engine = session.get_bind()
_assert_not_schema_downgrade(version=version)
migration.db_sync(engine, abs_path, version=version,
init_version=init_version, sanity_check=False)
def _assert_not_schema_downgrade(extension=None, version=None):
@ -154,31 +155,32 @@ def _sync_extension_repo(extension, version):
raise exception.MigrationMovedFailure(extension=extension)
init_version = 0
engine = sql.get_engine()
with sql.session_for_write() as session:
engine = session.get_bind()
try:
package_name = '.'.join((contrib.__name__, extension))
package = importutils.import_module(package_name)
except ImportError:
raise ImportError(_("%s extension does not exist.")
% package_name)
try:
abs_path = find_migrate_repo(package)
try:
migration.db_version_control(engine, abs_path)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore
except exceptions.DatabaseAlreadyControlledError: # nosec
pass
except exception.MigrationNotProvided as e:
print(e)
sys.exit(1)
package_name = '.'.join((contrib.__name__, extension))
package = importutils.import_module(package_name)
except ImportError:
raise ImportError(_("%s extension does not exist.")
% package_name)
try:
abs_path = find_migrate_repo(package)
try:
migration.db_version_control(engine, abs_path)
# Register the repo with the version control API
# If it already knows about the repo, it will throw
# an exception that we can safely ignore
except exceptions.DatabaseAlreadyControlledError: # nosec
pass
except exception.MigrationNotProvided as e:
print(e)
sys.exit(1)
_assert_not_schema_downgrade(extension=extension, version=version)
_assert_not_schema_downgrade(extension=extension, version=version)
migration.db_sync(engine, abs_path, version=version,
init_version=init_version, sanity_check=False)
migration.db_sync(engine, abs_path, version=version,
init_version=init_version, sanity_check=False)
def sync_database_to_version(extension=None, version=None):
@ -195,8 +197,10 @@ def sync_database_to_version(extension=None, version=None):
def get_db_version(extension=None):
if not extension:
return migration.db_version(sql.get_engine(), find_migrate_repo(),
migrate_repo.DB_INIT_VERSION)
with sql.session_for_write() as session:
return migration.db_version(session.get_bind(),
find_migrate_repo(),
migrate_repo.DB_INIT_VERSION)
try:
package_name = '.'.join((contrib.__name__, extension))
@ -205,8 +209,9 @@ def get_db_version(extension=None):
raise ImportError(_("%s extension does not exist.")
% package_name)
return migration.db_version(
sql.get_engine(), find_migrate_repo(package), 0)
with sql.session_for_write() as session:
return migration.db_version(
session.get_bind(), find_migrate_repo(package), 0)
def print_db_version(extension=None):

View File

@ -148,7 +148,8 @@ class Database(fixtures.Fixture):
def setUp(self):
super(Database, self).setUp()
self.engine = sql.get_engine()
with sql.session_for_write() as session:
self.engine = session.get_bind()
self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)

View File

@ -184,19 +184,21 @@ class SqlModels(SqlTests):
class SqlIdentity(SqlTests, test_backend.IdentityTests):
def test_password_hashed(self):
session = sql.get_session()
user_ref = self.identity_api._get_user(session, self.user_foo['id'])
self.assertNotEqual(self.user_foo['password'], user_ref['password'])
with sql.session_for_read() as session:
user_ref = self.identity_api._get_user(session,
self.user_foo['id'])
self.assertNotEqual(self.user_foo['password'],
user_ref['password'])
def test_create_user_with_null_password(self):
user_dict = unit.new_user_ref(
domain_id=CONF.identity.default_domain_id)
user_dict["password"] = None
new_user_dict = self.identity_api.create_user(user_dict)
session = sql.get_session()
new_user_ref = self.identity_api._get_user(session,
new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords)
with sql.session_for_read() as session:
new_user_ref = self.identity_api._get_user(session,
new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords)
def test_update_user_with_null_password(self):
user_dict = unit.new_user_ref(
@ -206,10 +208,10 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests):
new_user_dict["password"] = None
new_user_dict = self.identity_api.update_user(new_user_dict['id'],
new_user_dict)
session = sql.get_session()
new_user_ref = self.identity_api._get_user(session,
new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords)
with sql.session_for_read() as session:
new_user_ref = self.identity_api._get_user(session,
new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords)
def test_delete_user_with_project_association(self):
user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
@ -339,14 +341,14 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests):
def test_sql_user_to_dict_null_default_project_id(self):
user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
user = self.identity_api.create_user(user)
session = sql.get_session()
query = session.query(identity_sql.User)
query = query.filter_by(id=user['id'])
raw_user_ref = query.one()
self.assertIsNone(raw_user_ref.default_project_id)
user_ref = raw_user_ref.to_dict()
self.assertNotIn('default_project_id', user_ref)
session.close()
with sql.session_for_read() as session:
query = session.query(identity_sql.User)
query = query.filter_by(id=user['id'])
raw_user_ref = query.one()
self.assertIsNone(raw_user_ref.default_project_id)
user_ref = raw_user_ref.to_dict()
self.assertNotIn('default_project_id', user_ref)
session.close()
def test_list_domains_for_user(self):
domain = unit.new_domain_ref()

View File

@ -151,7 +151,9 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
connection='sqlite:///%s' % db_file)
# create and share a single sqlalchemy engine for testing
self.engine = sql.get_engine()
with sql.session_for_write() as session:
self.engine = session.get_bind()
self.addCleanup(self.cleanup_instance('engine'))
self.Session = db_session.get_maker(self.engine, autocommit=False)
self.addCleanup(sqlalchemy.orm.session.Session.close_all)
@ -284,8 +286,9 @@ class SqlUpgradeTests(SqlMigrateBase):
self.assertTableDoesNotExist('user')
def test_start_version_db_init_version(self):
version = migration.db_version(sql.get_engine(), self.repo_path,
migrate_repo.DB_INIT_VERSION)
with sql.session_for_write() as session:
version = migration.db_version(session.get_bind(), self.repo_path,
migrate_repo.DB_INIT_VERSION)
self.assertEqual(
migrate_repo.DB_INIT_VERSION,
version,