# Copyright 2010-2011 OpenStack Foundation # Copyright 2012-2013 IBM Corp. # All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # License for the specific language governing permissions and limitations # under the License. """ Tests for database migrations. There are "opportunistic" tests for both mysql and postgresql in here, which allows testing against these databases in a properly configured unit test environment. For the opportunistic testing you need to set up a db named 'openstack_citest' with user 'openstack_citest' and password 'openstack_citest' on localhost. The test will then use that db and u/p combo to run the tests. For postgres on Ubuntu this can be done with the following commands: :: sudo -u postgres psql postgres=# create user openstack_citest with createdb login password 'openstack_citest'; postgres=# create database openstack_citest with owner openstack_citest; """ import contextlib from alembic import script import mock from oslo_db import exception as db_exc from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import test_base from oslo_db.sqlalchemy import test_migrations from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log as logging from oslo_utils import uuidutils import sqlalchemy import sqlalchemy.exc from ironic.common.i18n import _LE from ironic.db.sqlalchemy import migration from ironic.db.sqlalchemy import models from ironic.tests.unit import base LOG = logging.getLogger(__name__) def _get_connect_string(backend, user, passwd, database): """Get database connection Try to get a connection with a very specific set of values, if we get these then we'll run the tests, otherwise they are skipped """ if backend == "postgres": backend = "postgresql+psycopg2" elif backend == "mysql": backend = "mysql+mysqldb" else: raise Exception("Unrecognized backend: '%s'" % backend) return ("%(backend)s://%(user)s:%(passwd)s@localhost/%(database)s" % {'backend': backend, 'user': user, 'passwd': passwd, 'database': database}) def _is_backend_avail(backend, user, passwd, database): try: connect_uri = _get_connect_string(backend, user, passwd, database) engine = sqlalchemy.create_engine(connect_uri) connection = engine.connect() except Exception: # intentionally catch all to handle exceptions even if we don't # have any backend code loaded. return False else: connection.close() engine.dispose() return True @contextlib.contextmanager def patch_with_engine(engine): with mock.patch.object(enginefacade.get_legacy_facade(), 'get_engine') as patch_engine: patch_engine.return_value = engine yield class WalkVersionsMixin(object): def _walk_versions(self, engine=None, alembic_cfg=None, downgrade=True): # Determine latest version script from the repo, then # upgrade from 1 through to the latest, with no data # in the databases. This just checks that the schema itself # upgrades successfully. # Place the database under version control with patch_with_engine(engine): script_directory = script.ScriptDirectory.from_config(alembic_cfg) self.assertIsNone(self.migration_api.version(alembic_cfg)) versions = [ver for ver in script_directory.walk_revisions()] for version in reversed(versions): self._migrate_up(engine, alembic_cfg, version.revision, with_data=True) if downgrade: for version in versions: self._migrate_down(engine, alembic_cfg, version.revision) def _migrate_down(self, engine, config, version, with_data=False): try: self.migration_api.downgrade(version, config=config) except NotImplementedError: # NOTE(sirp): some migrations, namely release-level # migrations, don't support a downgrade. return False self.assertEqual(version, self.migration_api.version(config)) # NOTE(sirp): `version` is what we're downgrading to (i.e. the 'target' # version). So if we have any downgrade checks, they need to be run for # the previous (higher numbered) migration. if with_data: post_downgrade = getattr( self, "_post_downgrade_%s" % (version), None) if post_downgrade: post_downgrade(engine) return True def _migrate_up(self, engine, config, version, with_data=False): """migrate up to a new version of the db. We allow for data insertion and post checks at every migration version with special _pre_upgrade_### and _check_### functions in the main test. """ # NOTE(sdague): try block is here because it's impossible to debug # where a failed data migration happens otherwise try: if with_data: data = None pre_upgrade = getattr( self, "_pre_upgrade_%s" % version, None) if pre_upgrade: data = pre_upgrade(engine) self.migration_api.upgrade(version, config=config) self.assertEqual(version, self.migration_api.version(config)) if with_data: check = getattr(self, "_check_%s" % version, None) if check: check(engine, data) except Exception: LOG.error(_LE("Failed to migrate to version %(version)s on engine " "%(engine)s"), {'version': version, 'engine': engine}) raise class TestWalkVersions(base.TestCase, WalkVersionsMixin): def setUp(self): super(TestWalkVersions, self).setUp() self.migration_api = mock.MagicMock() self.engine = mock.MagicMock() self.config = mock.MagicMock() self.versions = [mock.Mock(revision='2b2'), mock.Mock(revision='1a1')] def test_migrate_up(self): self.migration_api.version.return_value = 'dsa123' self._migrate_up(self.engine, self.config, 'dsa123') self.migration_api.upgrade.assert_called_with('dsa123', config=self.config) self.migration_api.version.assert_called_with(self.config) def test_migrate_up_with_data(self): test_value = {"a": 1, "b": 2} self.migration_api.version.return_value = '141' self._pre_upgrade_141 = mock.MagicMock() self._pre_upgrade_141.return_value = test_value self._check_141 = mock.MagicMock() self._migrate_up(self.engine, self.config, '141', True) self._pre_upgrade_141.assert_called_with(self.engine) self._check_141.assert_called_with(self.engine, test_value) def test_migrate_down(self): self.migration_api.version.return_value = '42' self.assertTrue(self._migrate_down(self.engine, self.config, '42')) self.migration_api.version.assert_called_with(self.config) def test_migrate_down_not_implemented(self): self.migration_api.downgrade.side_effect = NotImplementedError self.assertFalse(self._migrate_down(self.engine, self.config, '42')) def test_migrate_down_with_data(self): self._post_downgrade_043 = mock.MagicMock() self.migration_api.version.return_value = '043' self._migrate_down(self.engine, self.config, '043', True) self._post_downgrade_043.assert_called_with(self.engine) @mock.patch.object(script, 'ScriptDirectory') @mock.patch.object(WalkVersionsMixin, '_migrate_up') @mock.patch.object(WalkVersionsMixin, '_migrate_down') def test_walk_versions_all_default(self, _migrate_up, _migrate_down, script_directory): fc = script_directory.from_config() fc.walk_revisions.return_value = self.versions self.migration_api.version.return_value = None self._walk_versions(self.engine, self.config) self.migration_api.version.assert_called_with(self.config) upgraded = [mock.call(self.engine, self.config, v.revision, with_data=True) for v in reversed(self.versions)] self.assertEqual(self._migrate_up.call_args_list, upgraded) downgraded = [mock.call(self.engine, self.config, v.revision) for v in self.versions] self.assertEqual(self._migrate_down.call_args_list, downgraded) @mock.patch.object(script, 'ScriptDirectory') @mock.patch.object(WalkVersionsMixin, '_migrate_up') @mock.patch.object(WalkVersionsMixin, '_migrate_down') def test_walk_versions_all_false(self, _migrate_up, _migrate_down, script_directory): fc = script_directory.from_config() fc.walk_revisions.return_value = self.versions self.migration_api.version.return_value = None self._walk_versions(self.engine, self.config, downgrade=False) upgraded = [mock.call(self.engine, self.config, v.revision, with_data=True) for v in reversed(self.versions)] self.assertEqual(upgraded, self._migrate_up.call_args_list) class MigrationCheckersMixin(object): def setUp(self): super(MigrationCheckersMixin, self).setUp() self.config = migration._alembic_config() self.migration_api = migration def test_walk_versions(self): self._walk_versions(self.engine, self.config, downgrade=False) def test_connect_fail(self): """Test that we can trigger a database connection failure Test that we can fail gracefully to ensure we don't break people without specific database backend """ if _is_backend_avail(self.FIXTURE.DRIVER, "openstack_cifail", self.FIXTURE.USERNAME, self.FIXTURE.DBNAME): self.fail("Shouldn't have connected") def _check_21b331f883ef(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('provision_updated_at', col_names) self.assertIsInstance(nodes.c.provision_updated_at.type, sqlalchemy.types.DateTime) def _check_3cb628139ea4(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('console_enabled', col_names) # in some backends bool type is integer self.assertTrue(isinstance(nodes.c.console_enabled.type, sqlalchemy.types.Boolean) or isinstance(nodes.c.console_enabled.type, sqlalchemy.types.Integer)) def _check_31baaf680d2b(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('instance_info', col_names) self.assertIsInstance(nodes.c.instance_info.type, sqlalchemy.types.TEXT) def _check_3bea56f25597(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') instance_uuid = 'aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee' data = {'driver': 'fake', 'uuid': uuidutils.generate_uuid(), 'instance_uuid': instance_uuid} nodes.insert().values(data).execute() data['uuid'] = uuidutils.generate_uuid() self.assertRaises(db_exc.DBDuplicateEntry, nodes.insert().execute, data) def _check_242cc6a923b3(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('maintenance_reason', col_names) self.assertIsInstance(nodes.c.maintenance_reason.type, sqlalchemy.types.String) def _pre_upgrade_5674c57409b9(self, engine): # add some nodes in various states so we can assert that "None" # was replaced by "available", and nothing else changed. nodes = db_utils.get_table(engine, 'nodes') data = [{'uuid': uuidutils.generate_uuid(), 'provision_state': 'fake state'}, {'uuid': uuidutils.generate_uuid(), 'provision_state': 'active'}, {'uuid': uuidutils.generate_uuid(), 'provision_state': 'deleting'}, {'uuid': uuidutils.generate_uuid(), 'provision_state': None}] nodes.insert().values(data).execute() return data def _check_5674c57409b9(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') result = engine.execute(nodes.select()) def _get_state(uuid): for row in data: if row['uuid'] == uuid: return row['provision_state'] for row in result: old = _get_state(row['uuid']) new = row['provision_state'] if old is None: self.assertEqual('available', new) else: self.assertEqual(old, new) def _check_bb59b63f55a(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('driver_internal_info', col_names) self.assertIsInstance(nodes.c.driver_internal_info.type, sqlalchemy.types.TEXT) def _check_4f399b21ae71(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('clean_step', col_names) self.assertIsInstance(nodes.c.clean_step.type, sqlalchemy.types.String) def _check_789acc877671(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') col_names = [column.name for column in nodes.c] self.assertIn('raid_config', col_names) self.assertIn('target_raid_config', col_names) self.assertIsInstance(nodes.c.raid_config.type, sqlalchemy.types.String) self.assertIsInstance(nodes.c.target_raid_config.type, sqlalchemy.types.String) def _check_2fb93ffd2af1(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') bigstring = 'a' * 255 uuid = uuidutils.generate_uuid() data = {'uuid': uuid, 'name': bigstring} nodes.insert().execute(data) node = nodes.select(nodes.c.uuid == uuid).execute().first() self.assertEqual(bigstring, node['name']) def _check_516faf1bb9b1(self, engine, data): nodes = db_utils.get_table(engine, 'nodes') bigstring = 'a' * 255 uuid = uuidutils.generate_uuid() data = {'uuid': uuid, 'driver': bigstring} nodes.insert().execute(data) node = nodes.select(nodes.c.uuid == uuid).execute().first() self.assertEqual(bigstring, node['driver']) def test_upgrade_and_version(self): with patch_with_engine(self.engine): self.migration_api.upgrade('head') self.assertIsNotNone(self.migration_api.version()) def test_create_schema_and_version(self): with patch_with_engine(self.engine): self.migration_api.create_schema() self.assertIsNotNone(self.migration_api.version()) def test_upgrade_and_create_schema(self): with patch_with_engine(self.engine): self.migration_api.upgrade('31baaf680d2b') self.assertRaises(db_exc.DbMigrationError, self.migration_api.create_schema) def test_upgrade_twice(self): with patch_with_engine(self.engine): self.migration_api.upgrade('31baaf680d2b') v1 = self.migration_api.version() self.migration_api.upgrade('head') v2 = self.migration_api.version() self.assertNotEqual(v1, v2) class TestMigrationsMySQL(MigrationCheckersMixin, WalkVersionsMixin, test_base.MySQLOpportunisticTestCase): pass class TestMigrationsPostgreSQL(MigrationCheckersMixin, WalkVersionsMixin, test_base.PostgreSQLOpportunisticTestCase): pass class ModelsMigrationSyncMixin(object): def get_metadata(self): return models.Base.metadata def get_engine(self): return self.engine def db_sync(self, engine): with patch_with_engine(engine): migration.upgrade('head') class ModelsMigrationsSyncMysql(ModelsMigrationSyncMixin, test_migrations.ModelsMigrationsSync, test_base.MySQLOpportunisticTestCase): pass class ModelsMigrationsSyncPostgres(ModelsMigrationSyncMixin, test_migrations.ModelsMigrationsSync, test_base.PostgreSQLOpportunisticTestCase): pass