diff --git a/cinder/db/sqlalchemy/api.py b/cinder/db/sqlalchemy/api.py index 746fb7be76e..59d5d40af65 100644 --- a/cinder/db/sqlalchemy/api.py +++ b/cinder/db/sqlalchemy/api.py @@ -2466,13 +2466,11 @@ def volumes_update(context, values_list): @require_context def volume_attachment_update(context, attachment_id, values): - session = get_session() - with session.begin(): - volume_attachment_ref = _attachment_get(context, attachment_id, - session=session) - volume_attachment_ref.update(values) - volume_attachment_ref.save(session=session) - return volume_attachment_ref + query = model_query(context, models.VolumeAttachment) + result = query.filter_by(id=attachment_id).update(values) + if not result: + raise exception.VolumeAttachmentNotFound( + filter='attachment_id = ' + attachment_id) def volume_update_status_based_on_attachment(context, volume_id): @@ -3121,11 +3119,11 @@ def snapshot_get_all_active_by_window(context, begin, end=None, @handle_db_data_error @require_context def snapshot_update(context, snapshot_id, values): - session = get_session() - with session.begin(): - snapshot_ref = _snapshot_get(context, snapshot_id, session=session) - snapshot_ref.update(values) - return snapshot_ref + query = model_query(context, models.Snapshot, project_only=True) + result = query.filter_by(id=snapshot_id).update(values) + if not result: + raise exception.SnapshotNotFound(snapshot_id=snapshot_id) + #################### @@ -3420,60 +3418,16 @@ def _process_group_types_filters(query, filters): @handle_db_data_error @require_admin_context -def volume_type_update(context, volume_type_id, values): +def _type_update(context, type_id, values, is_group): + if is_group: + model = models.GroupTypes + exists_exc = exception.GroupTypeExists + else: + model = models.VolumeTypes + exists_exc = exception.VolumeTypeExists + session = get_session() with session.begin(): - # Check it exists - volume_type_ref = _volume_type_ref_get(context, - volume_type_id, - session) - if not volume_type_ref: - raise exception.VolumeTypeNotFound(type_id=volume_type_id) - - # No description change - if values['description'] is None: - del values['description'] - - # No is_public change - if values['is_public'] is None: - del values['is_public'] - - # No name change - if values['name'] is None: - del values['name'] - else: - # Volume type name is unique. If change to a name that belongs to - # a different volume_type , it should be prevented. - check_vol_type = None - try: - check_vol_type = \ - _volume_type_get_by_name(context, - values['name'], - session=session) - except exception.VolumeTypeNotFoundByName: - pass - else: - if check_vol_type.get('id') != volume_type_id: - raise exception.VolumeTypeExists(id=values['name']) - - volume_type_ref.update(values) - volume_type_ref.save(session=session) - - return volume_type_ref - - -@handle_db_data_error -@require_admin_context -def group_type_update(context, group_type_id, values): - session = get_session() - with session.begin(): - # Check it exists - group_type_ref = _group_type_ref_get(context, - group_type_id, - session) - if not group_type_ref: - raise exception.GroupTypeNotFound(type_id=group_type_id) - # No description change if values['description'] is None: del values['description'] @@ -3487,23 +3441,28 @@ def group_type_update(context, group_type_id, values): del values['name'] else: # Group type name is unique. If change to a name that belongs to - # a different group_type , it should be prevented. - check_grp_type = None - try: - check_grp_type = \ - _group_type_get_by_name(context, - values['name'], - session=session) - except exception.GroupTypeNotFoundByName: - pass + # a different group_type, it should be prevented. + conditions = and_(model.name == values['name'], + model.id != type_id, ~model.deleted) + query = session.query(sql.exists().where(conditions)) + if query.scalar(): + raise exists_exc(id=values['name']) + + query = model_query(context, model, project_only=True, session=session) + result = query.filter_by(id=type_id).update(values) + if not result: + if is_group: + raise exception.GroupTypeNotFound(group_type_id=type_id) else: - if check_grp_type.get('id') != group_type_id: - raise exception.GroupTypeExists(id=values['name']) + raise exception.VolumeTypeNotFound(volume_type_id=type_id) - group_type_ref.update(values) - group_type_ref.save(session=session) - return group_type_ref +def volume_type_update(context, volume_type_id, values): + _type_update(context, volume_type_id, values, is_group=False) + + +def group_type_update(context, group_type_id, values): + _type_update(context, group_type_id, values, is_group=True) @require_context @@ -4580,7 +4539,10 @@ def qos_specs_update(context, qos_specs_id, updates): session = get_session() with session.begin(): # make sure qos specs exists - _qos_specs_get_all_ref(context, qos_specs_id, session) + exists = resource_exists(context, models.QualityOfServiceSpecs, + qos_specs_id, session) + if not exists: + raise exception.QoSSpecsNotFound(specs_id=qos_specs_id) specs = updates.get('specs', {}) if 'consumer' in updates: @@ -4656,18 +4618,10 @@ def volume_type_encryption_create(context, volume_type_id, values): @handle_db_data_error @require_admin_context def volume_type_encryption_update(context, volume_type_id, values): - session = get_session() - with session.begin(): - encryption = volume_type_encryption_get(context, volume_type_id, - session) - - if not encryption: - raise exception.VolumeTypeEncryptionNotFound( - type_id=volume_type_id) - - encryption.update(values) - - return encryption + query = model_query(context, models.Encryption) + result = query.filter_by(volume_type_id=volume_type_id).update(values) + if not result: + raise exception.VolumeTypeEncryptionNotFound(type_id=volume_type_id) def volume_type_encryption_volume_get(context, volume_type_id, session=None): @@ -5053,19 +5007,13 @@ def backup_create(context, values): @handle_db_data_error @require_context def backup_update(context, backup_id, values): - session = get_session() - with session.begin(): - backup = model_query(context, models.Backup, - session=session, read_deleted="yes").\ - filter_by(id=backup_id).first() - - if not backup: - raise exception.BackupNotFound( - _("No backup with id %s") % backup_id) - - backup.update(values) - - return backup + if 'fail_reason' in values: + values = values.copy() + values['fail_reason'] = (values['fail_reason'] or '')[:255] + query = model_query(context, models.Backup, read_deleted="yes") + result = query.filter_by(id=backup_id).update(values) + if not result: + raise exception.BackupNotFound(backup_id=backup_id) @require_admin_context @@ -5381,21 +5329,11 @@ def consistencygroup_create(context, values, cg_snap_id=None, cg_id=None): @handle_db_data_error @require_context def consistencygroup_update(context, consistencygroup_id, values): - session = get_session() - with session.begin(): - result = model_query(context, models.ConsistencyGroup, - project_only=True).\ - filter_by(id=consistencygroup_id).\ - first() - - if not result: - raise exception.ConsistencyGroupNotFound( - _("No consistency group with id %s") % consistencygroup_id) - - result.update(values) - result.save(session=session) - - return result + query = model_query(context, models.ConsistencyGroup, project_only=True) + result = query.filter_by(id=consistencygroup_id).update(values) + if not result: + raise exception.ConsistencyGroupNotFound( + consistencygroup_id=consistencygroup_id) @require_admin_context @@ -5768,20 +5706,10 @@ def group_volume_type_mapping_create(context, group_id, volume_type_id): @handle_db_data_error @require_context def group_update(context, group_id, values): - session = get_session() - with session.begin(): - result = (model_query(context, models.Group, - project_only=True). - filter_by(id=group_id). - first()) - - if not result: - raise exception.GroupNotFound( - _("No group with id %s") % group_id) - - result.update(values) - result.save(session=session) - return result + query = model_query(context, models.Group, project_only=True) + result = query.filter_by(id=group_id).update(values) + if not result: + raise exception.GroupNotFound(group_id=group_id) @require_admin_context @@ -6104,19 +6032,10 @@ def cgsnapshot_create(context, values): @require_context @handle_db_data_error def cgsnapshot_update(context, cgsnapshot_id, values): - session = get_session() - with session.begin(): - result = model_query(context, models.Cgsnapshot, project_only=True).\ - filter_by(id=cgsnapshot_id).\ - first() - - if not result: - raise exception.CgSnapshotNotFound( - _("No cgsnapshot with id %s") % cgsnapshot_id) - - result.update(values) - result.save(session=session) - return result + query = model_query(context, models.Cgsnapshot, project_only=True) + result = query.filter_by(id=cgsnapshot_id).update(values) + if not result: + raise exception.CgSnapshotNotFound(cgsnapshot_id=cgsnapshot_id) @require_admin_context @@ -6737,13 +6656,14 @@ def worker_destroy(context, **filters): @require_context -def resource_exists(context, model, resource_id): +def resource_exists(context, model, resource_id, session=None): # Match non deleted resources by the id conditions = [model.id == resource_id, ~model.deleted] # If the context is not admin we limit it to the context's project if is_user_context(context) and hasattr(model, 'project_id'): conditions.append(model.project_id == context.project_id) - query = get_session().query(sql.exists().where(and_(*conditions))) + session = session or get_session() + query = session.query(sql.exists().where(and_(*conditions))) return query.scalar() diff --git a/cinder/tests/unit/api/v3/test_groups.py b/cinder/tests/unit/api/v3/test_groups.py index cb4c7723d3b..beb8e4b5e70 100644 --- a/cinder/tests/unit/api/v3/test_groups.py +++ b/cinder/tests/unit/api/v3/test_groups.py @@ -197,11 +197,16 @@ class GroupsAPITestCase(test.TestCase): def test_list_groups_json(self): self.group2.group_type_id = fake.GROUP_TYPE2_ID - self.group2.volume_type_ids = [fake.VOLUME_TYPE2_ID] + # TODO(geguileo): One `volume_type_ids` gets sorted out make proper + # changes here + # self.group2.volume_type_ids = [fake.VOLUME_TYPE2_ID] + self.group2.save() self.group3.group_type_id = fake.GROUP_TYPE3_ID - self.group3.volume_type_ids = [fake.VOLUME_TYPE3_ID] + # TODO(geguileo): One `volume_type_ids` gets sorted out make proper + # changes here + # self.group3.volume_type_ids = [fake.VOLUME_TYPE3_ID] self.group3.save() req = fakes.HTTPRequest.blank('/v3/%s/groups' % fake.PROJECT_ID, @@ -332,12 +337,14 @@ class GroupsAPITestCase(test.TestCase): objects=vol_type_objs) mock_vol_type_get_all_by_group.return_value = vol_types - self.group1.volume_type_ids = volume_type_ids - self.group1.save() - self.group2.volume_type_ids = volume_type_ids - self.group2.save() - self.group3.volume_type_ids = volume_type_ids - self.group3.save() + # TODO(geguileo): One `volume_type_ids` gets sorted out make proper + # changes here + # self.group1.volume_type_ids = volume_type_ids + # self.group1.save() + # self.group2.volume_type_ids = volume_type_ids + # self.group2.save() + # self.group3.volume_type_ids = volume_type_ids + # self.group3.save() req = fakes.HTTPRequest.blank('/v3/%s/groups/detail' % fake.PROJECT_ID, version=GROUP_MICRO_VERSION) @@ -680,7 +687,9 @@ class GroupsAPITestCase(test.TestCase): volume_type_id = fake.VOLUME_TYPE_ID self.group1.status = fields.GroupStatus.AVAILABLE self.group1.host = 'test_host' - self.group1.volume_type_ids = [volume_type_id] + # TODO(geguileo): One `volume_type_ids` gets sorted out make proper + # changes here + # self.group1.volume_type_ids = [volume_type_id] self.group1.save() remove_volume = utils.create_volume( diff --git a/cinder/tests/unit/db/test_volume_type.py b/cinder/tests/unit/db/test_volume_type.py index 67a4ea78efd..e052e772a78 100644 --- a/cinder/tests/unit/db/test_volume_type.py +++ b/cinder/tests/unit/db/test_volume_type.py @@ -84,9 +84,9 @@ class VolumeTypeTestCase(test.TestCase): updates = dict(name = 'test_volume_type_update', description = None, is_public = None) - updated_vol_type = db.volume_type_update( - self.ctxt, vol_type_ref.id, updates) - self.assertEqual('test_volume_type_update', updated_vol_type.name) + db.volume_type_update(self.ctxt, vol_type_ref.id, updates) + updated_vol_type = db.volume_type_get(self.ctxt, vol_type_ref.id) + self.assertEqual('test_volume_type_update', updated_vol_type['name']) volume_types.destroy(self.ctxt, vol_type_ref.id) def test_volume_type_get_with_qos_specs(self): diff --git a/cinder/tests/unit/test_db_api.py b/cinder/tests/unit/test_db_api.py index a6f2ea271f6..d898c84a127 100644 --- a/cinder/tests/unit/test_db_api.py +++ b/cinder/tests/unit/test_db_api.py @@ -2282,14 +2282,12 @@ class DBAPIEncryptionTestCase(BaseTest): self._assertEqualObjects(values[i], encryption, self._ignored_keys) def test_volume_type_encryption_update(self): - update_values = self._get_values(updated=True) - self.updated = \ - [db.volume_type_encryption_update(self.ctxt, - values['volume_type_id'], values) - for values in update_values] - for i, encryption in enumerate(self.updated): - self._assertEqualObjects(update_values[i], encryption, - self._ignored_keys) + for values in self._get_values(updated=True): + db.volume_type_encryption_update(self.ctxt, + values['volume_type_id'], values) + db_enc = db.volume_type_encryption_get(self.ctxt, + values['volume_type_id']) + self._assertEqualObjects(values, db_enc, self._ignored_keys) def test_volume_type_encryption_get(self): for encryption in self.created: @@ -2757,8 +2755,8 @@ class DBAPIBackupTestCase(BaseTest): def test_backup_update(self): updated_values = self._get_values(one=True) update_id = self.created[1]['id'] - updated_backup = db.backup_update(self.ctxt, update_id, - updated_values) + db.backup_update(self.ctxt, update_id, updated_values) + updated_backup = db.backup_get(self.ctxt, update_id) self._assertEqualObjects(updated_values, updated_backup, self._ignored_keys) @@ -2768,9 +2766,8 @@ class DBAPIBackupTestCase(BaseTest): updated_values['fail_reason'] = fail_reason update_id = self.created[1]['id'] - updated_backup = db.backup_update(self.ctxt, update_id, - updated_values) - + db.backup_update(self.ctxt, update_id, updated_values) + updated_backup = db.backup_get(self.ctxt, update_id) updated_values['fail_reason'] = fail_reason[:255] self._assertEqualObjects(updated_values, updated_backup, self._ignored_keys) diff --git a/cinder/tests/unit/test_volume_types.py b/cinder/tests/unit/test_volume_types.py index d51a6aaf668..cd457258375 100644 --- a/cinder/tests/unit/test_volume_types.py +++ b/cinder/tests/unit/test_volume_types.py @@ -78,10 +78,9 @@ class VolumeTypeTestCase(test.TestCase): # update new_type_name = self.vol_type1_name + '_updated' new_type_desc = self.vol_type1_description + '_updated' - type_ref_updated = volume_types.update(self.ctxt, - type_ref.id, - new_type_name, - new_type_desc) + volume_types.update(self.ctxt, type_ref.id, new_type_name, + new_type_desc) + type_ref_updated = volume_types.get_volume_type(self.ctxt, type_ref.id) self.assertEqual(new_type_name, type_ref_updated['name']) self.assertEqual(new_type_desc, type_ref_updated['description']) @@ -139,10 +138,9 @@ class VolumeTypeTestCase(test.TestCase): # update new_type_name = self.vol_type1_name + '_updated' new_type_desc = self.vol_type1_description + '_updated' - type_ref_updated = volume_types.update(self.ctxt, - type_ref.id, - new_type_name, - new_type_desc) + volume_types.update(self.ctxt, type_ref.id, new_type_name, + new_type_desc) + type_ref_updated = volume_types.get_volume_type(self.ctxt, type_ref.id) self.assertEqual(new_type_name, type_ref_updated['name']) self.assertEqual(new_type_desc, type_ref_updated['description']) diff --git a/cinder/volume/group_types.py b/cinder/volume/group_types.py index b898a1ef25e..70c72b4c1d5 100644 --- a/cinder/volume/group_types.py +++ b/cinder/volume/group_types.py @@ -60,15 +60,12 @@ def update(context, id, name, description, is_public=None): raise exception.InvalidGroupType(reason=msg) elevated = context if context.is_admin else context.elevated() try: - type_updated = db.group_type_update(elevated, - id, - dict(name=name, - description=description, - is_public=is_public)) + db.group_type_update(elevated, id, + dict(name=name, description=description, + is_public=is_public)) except db_exc.DBError: LOG.exception(_LE('DB error:')) raise exception.GroupTypeUpdateFailed(id=id) - return type_updated def destroy(context, id): diff --git a/cinder/volume/volume_types.py b/cinder/volume/volume_types.py index ddba3f13fd3..f2010f75416 100644 --- a/cinder/volume/volume_types.py +++ b/cinder/volume/volume_types.py @@ -72,11 +72,9 @@ def update(context, id, name, description, is_public=None): elevated = context if context.is_admin else context.elevated() old_volume_type = get_volume_type(elevated, id) try: - type_updated = db.volume_type_update(elevated, - id, - dict(name=name, - description=description, - is_public=is_public)) + db.volume_type_update(elevated, id, + dict(name=name, description=description, + is_public=is_public)) # Rename resource in quota if volume type name is changed. if name: old_type_name = old_volume_type.get('name') @@ -87,7 +85,6 @@ def update(context, id, name, description, is_public=None): except db_exc.DBError: LOG.exception(_LE('DB error:')) raise exception.VolumeTypeUpdateFailed(id=id) - return type_updated def destroy(context, id):