diff --git a/nova/objects/aggregate.py b/nova/objects/aggregate.py index 5d43b75eb..94ae25cd7 100644 --- a/nova/objects/aggregate.py +++ b/nova/objects/aggregate.py @@ -18,6 +18,8 @@ from oslo_utils import excutils from oslo_utils import uuidutils from sqlalchemy.orm import contains_eager from sqlalchemy.orm import joinedload +from sqlalchemy.sql import func +from sqlalchemy.sql import text from nova.compute import utils as compute_utils from nova import db @@ -558,3 +560,64 @@ class AggregateList(base.ObjectListBase, base.NovaObject): all_aggregates = cls._filter_db_aggregates(all_aggregates, hosts) return base.obj_make_list(context, cls(context), objects.Aggregate, all_aggregates) + + +@db_api.main_context_manager.reader +def _get_main_db_aggregate_ids(context, limit): + from nova.db.sqlalchemy import models + return [x[0] for x in context.session.query(models.Aggregate.id). + filter_by(deleted=0). + limit(limit)] + + +def migrate_aggregates(ctxt, count): + main_db_ids = _get_main_db_aggregate_ids(ctxt, count) + if not main_db_ids: + return 0, 0 + + count_all = len(main_db_ids) + count_hit = 0 + + for aggregate_id in main_db_ids: + try: + aggregate = Aggregate.get_by_id(ctxt, aggregate_id) + remove = ['metadata', 'hosts'] + values = {field: getattr(aggregate, field) + for field in aggregate.fields if field not in remove} + _aggregate_create_in_db(ctxt, values, metadata=aggregate.metadata) + for host in aggregate.hosts: + _host_add_to_db(ctxt, aggregate_id, host) + count_hit += 1 + db.aggregate_delete(ctxt, aggregate.id) + except exception.AggregateNotFound: + LOG.warning( + _LW('Aggregate id %(id)i disappeared during migration'), + {'id': aggregate_id}) + except (exception.AggregateNameExists) as e: + LOG.error(str(e)) + + return count_all, count_hit + + +def _adjust_autoincrement(context, value): + engine = db_api.get_api_engine() + if engine.name == 'postgresql': + # NOTE(danms): If we migrated some aggregates in the above function, + # then we will have confused postgres' sequence for the autoincrement + # primary key. MySQL does not care about this, but since postgres does, + # we need to reset this to avoid a failure on the next aggregate + # creation. + engine.execute( + text('ALTER SEQUENCE aggregates_id_seq RESTART WITH %i;' % ( + value))) + + +@db_api.api_context_manager.reader +def _get_max_aggregate_id(context): + return context.session.query(func.max(api_models.Aggregate.id)).one()[0] + + +def migrate_aggregate_reset_autoincrement(ctxt, count): + max_id = _get_max_aggregate_id(ctxt) or 0 + _adjust_autoincrement(ctxt, max_id + 1) + return 0, 0 diff --git a/nova/tests/functional/db/test_aggregate.py b/nova/tests/functional/db/test_aggregate.py index a4dbb57f8..13e34558c 100644 --- a/nova/tests/functional/db/test_aggregate.py +++ b/nova/tests/functional/db/test_aggregate.py @@ -593,3 +593,56 @@ class AggregateObjectMixedTestCase(AggregateObjectCellTestCase): new_agg.name = 'new-aggregate' self.assertRaises(exception.ObjectActionError, new_agg.create) + + +class AggregateObjectMigrationTestCase(AggregateObjectCellTestCase): + """Tests the aggregate in the case where data is migrated to the API db""" + def _seed_data(self): + for i in range(1, 10): + create_aggregate(self.context, i, in_api=False) + aggregate_obj.migrate_aggregates(self.context, 50) + + def test_create(self): + new_agg = aggregate_obj.Aggregate(self.context) + new_agg.name = 'new-aggregate' + new_agg.create() + result = aggregate_obj.Aggregate.get_by_id(self.context, new_agg.id) + self.assertEqual(new_agg.name, result.name) + + +class AggregateMigrationTestCase(test.NoDBTestCase): + USES_DB_SELF = True + + def setUp(self): + super(AggregateMigrationTestCase, self).setUp() + self.useFixture(fixtures.Database()) + self.useFixture(fixtures.Database(database='api')) + self.context = context.get_admin_context() + + def test_migration(self): + db.aggregate_create(self.context, {'name': 'foo'}) + main_aggregates_len = len(db.aggregate_get_all(self.context)) + match, done = aggregate_obj.migrate_aggregates(self.context, 50) + self.assertEqual(1, main_aggregates_len) + self.assertEqual(main_aggregates_len, match) + self.assertEqual(main_aggregates_len, done) + self.assertEqual(0, len(db.aggregate_get_all(self.context))) + self.assertEqual(main_aggregates_len, + len(aggregate_obj.AggregateList.get_all( + self.context))) + + def test_migrate_aggregate_reset_autoincrement(self): + agg = aggregate_obj.Aggregate(self.context, name='foo') + agg.create() + match, done = aggregate_obj.migrate_aggregate_reset_autoincrement( + self.context, 0) + self.assertEqual(0, match) + self.assertEqual(0, done) + + def test_migrate_aggregate_reset_autoincrement_no_aggregates(self): + # NOTE(danms): This validates the "or 0" default if there are no + # aggregates (and thus no max id). + match, done = aggregate_obj.migrate_aggregate_reset_autoincrement( + self.context, 0) + self.assertEqual(0, match) + self.assertEqual(0, done)