Merge "No more legacy engine facade in tests"

This commit is contained in:
Jenkins 2016-02-25 19:11:52 +00:00 committed by Gerrit Code Review
commit 05c58bb35e
4 changed files with 63 additions and 52 deletions

View File

@ -130,7 +130,8 @@ def find_migrate_repo(package=None, repo_name='migrate_repo'):
def _sync_common_repo(version): def _sync_common_repo(version):
abs_path = find_migrate_repo() abs_path = find_migrate_repo()
init_version = migrate_repo.DB_INIT_VERSION init_version = migrate_repo.DB_INIT_VERSION
engine = sql.get_engine() with sql.session_for_write() as session:
engine = session.get_bind()
_assert_not_schema_downgrade(version=version) _assert_not_schema_downgrade(version=version)
migration.db_sync(engine, abs_path, version=version, migration.db_sync(engine, abs_path, version=version,
init_version=init_version, sanity_check=False) init_version=init_version, sanity_check=False)
@ -154,7 +155,8 @@ def _sync_extension_repo(extension, version):
raise exception.MigrationMovedFailure(extension=extension) raise exception.MigrationMovedFailure(extension=extension)
init_version = 0 init_version = 0
engine = sql.get_engine() with sql.session_for_write() as session:
engine = session.get_bind()
try: try:
package_name = '.'.join((contrib.__name__, extension)) package_name = '.'.join((contrib.__name__, extension))
@ -195,7 +197,9 @@ def sync_database_to_version(extension=None, version=None):
def get_db_version(extension=None): def get_db_version(extension=None):
if not extension: if not extension:
return migration.db_version(sql.get_engine(), find_migrate_repo(), with sql.session_for_write() as session:
return migration.db_version(session.get_bind(),
find_migrate_repo(),
migrate_repo.DB_INIT_VERSION) migrate_repo.DB_INIT_VERSION)
try: try:
@ -205,8 +209,9 @@ def get_db_version(extension=None):
raise ImportError(_("%s extension does not exist.") raise ImportError(_("%s extension does not exist.")
% package_name) % package_name)
with sql.session_for_write() as session:
return migration.db_version( return migration.db_version(
sql.get_engine(), find_migrate_repo(package), 0) session.get_bind(), find_migrate_repo(package), 0)
def print_db_version(extension=None): def print_db_version(extension=None):

View File

@ -148,7 +148,8 @@ class Database(fixtures.Fixture):
def setUp(self): def setUp(self):
super(Database, self).setUp() super(Database, self).setUp()
self.engine = sql.get_engine() with sql.session_for_write() as session:
self.engine = session.get_bind()
self.addCleanup(sql.cleanup) self.addCleanup(sql.cleanup)
sql.ModelBase.metadata.create_all(bind=self.engine) sql.ModelBase.metadata.create_all(bind=self.engine)
self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine) self.addCleanup(sql.ModelBase.metadata.drop_all, bind=self.engine)

View File

@ -184,16 +184,18 @@ class SqlModels(SqlTests):
class SqlIdentity(SqlTests, test_backend.IdentityTests): class SqlIdentity(SqlTests, test_backend.IdentityTests):
def test_password_hashed(self): def test_password_hashed(self):
session = sql.get_session() with sql.session_for_read() as session:
user_ref = self.identity_api._get_user(session, self.user_foo['id']) user_ref = self.identity_api._get_user(session,
self.assertNotEqual(self.user_foo['password'], user_ref['password']) self.user_foo['id'])
self.assertNotEqual(self.user_foo['password'],
user_ref['password'])
def test_create_user_with_null_password(self): def test_create_user_with_null_password(self):
user_dict = unit.new_user_ref( user_dict = unit.new_user_ref(
domain_id=CONF.identity.default_domain_id) domain_id=CONF.identity.default_domain_id)
user_dict["password"] = None user_dict["password"] = None
new_user_dict = self.identity_api.create_user(user_dict) new_user_dict = self.identity_api.create_user(user_dict)
session = sql.get_session() with sql.session_for_read() as session:
new_user_ref = self.identity_api._get_user(session, new_user_ref = self.identity_api._get_user(session,
new_user_dict['id']) new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords) self.assertFalse(new_user_ref.local_user.passwords)
@ -206,7 +208,7 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests):
new_user_dict["password"] = None new_user_dict["password"] = None
new_user_dict = self.identity_api.update_user(new_user_dict['id'], new_user_dict = self.identity_api.update_user(new_user_dict['id'],
new_user_dict) new_user_dict)
session = sql.get_session() with sql.session_for_read() as session:
new_user_ref = self.identity_api._get_user(session, new_user_ref = self.identity_api._get_user(session,
new_user_dict['id']) new_user_dict['id'])
self.assertFalse(new_user_ref.local_user.passwords) self.assertFalse(new_user_ref.local_user.passwords)
@ -339,7 +341,7 @@ class SqlIdentity(SqlTests, test_backend.IdentityTests):
def test_sql_user_to_dict_null_default_project_id(self): def test_sql_user_to_dict_null_default_project_id(self):
user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id) user = unit.new_user_ref(domain_id=CONF.identity.default_domain_id)
user = self.identity_api.create_user(user) user = self.identity_api.create_user(user)
session = sql.get_session() with sql.session_for_read() as session:
query = session.query(identity_sql.User) query = session.query(identity_sql.User)
query = query.filter_by(id=user['id']) query = query.filter_by(id=user['id'])
raw_user_ref = query.one() raw_user_ref = query.one()

View File

@ -151,7 +151,9 @@ class SqlMigrateBase(unit.SQLDriverOverrides, unit.TestCase):
connection='sqlite:///%s' % db_file) connection='sqlite:///%s' % db_file)
# create and share a single sqlalchemy engine for testing # create and share a single sqlalchemy engine for testing
self.engine = sql.get_engine() with sql.session_for_write() as session:
self.engine = session.get_bind()
self.addCleanup(self.cleanup_instance('engine'))
self.Session = db_session.get_maker(self.engine, autocommit=False) self.Session = db_session.get_maker(self.engine, autocommit=False)
self.addCleanup(sqlalchemy.orm.session.Session.close_all) self.addCleanup(sqlalchemy.orm.session.Session.close_all)
@ -284,7 +286,8 @@ class SqlUpgradeTests(SqlMigrateBase):
self.assertTableDoesNotExist('user') self.assertTableDoesNotExist('user')
def test_start_version_db_init_version(self): def test_start_version_db_init_version(self):
version = migration.db_version(sql.get_engine(), self.repo_path, with sql.session_for_write() as session:
version = migration.db_version(session.get_bind(), self.repo_path,
migrate_repo.DB_INIT_VERSION) migrate_repo.DB_INIT_VERSION)
self.assertEqual( self.assertEqual(
migrate_repo.DB_INIT_VERSION, migrate_repo.DB_INIT_VERSION,