diff --git a/cinder/tests/unit/test_pure.py b/cinder/tests/unit/test_pure.py index 88116cac129..6625948e420 100644 --- a/cinder/tests/unit/test_pure.py +++ b/cinder/tests/unit/test_pure.py @@ -1594,9 +1594,19 @@ class PureBaseVolumeDriverTestCase(PureBaseSharedDriverTestCase): def test_is_vol_replicated_has_repl_extra_specs(self, mock_get_vol_type): mock_get_vol_type.return_value = REPLICATED_VOL_TYPE volume = fake_volume.fake_volume_obj(mock.MagicMock()) + volume.volume_type_id = REPLICATED_VOL_TYPE['id'] actual = self.driver._is_volume_replicated_type(volume) self.assertTrue(actual) + @mock.patch('cinder.volume.volume_types.get_volume_type') + def test_is_vol_replicated_none_type(self, mock_get_vol_type): + mock_get_vol_type.side_effect = exception.InvalidVolumeType(reason='') + volume = fake_volume.fake_volume_obj(mock.MagicMock()) + volume.volume_type = None + volume.volume_type_id = None + actual = self.driver._is_volume_replicated_type(volume) + self.assertFalse(actual) + @mock.patch('cinder.volume.volume_types.get_volume_type') def test_is_vol_replicated_has_other_extra_specs(self, mock_get_vol_type): vtype_test = deepcopy(NON_REPLICATED_VOL_TYPE) diff --git a/cinder/volume/drivers/pure.py b/cinder/volume/drivers/pure.py index 34b7e99d322..29c663e67b7 100644 --- a/cinder/volume/drivers/pure.py +++ b/cinder/volume/drivers/pure.py @@ -1423,14 +1423,16 @@ class PureBaseVolumeDriver(san.SanDriver): def _is_volume_replicated_type(self, volume): ctxt = context.get_admin_context() - volume_type = volume_types.get_volume_type(ctxt, - volume["volume_type_id"]) replication_flag = False - specs = volume_type.get("extra_specs") - if specs and EXTRA_SPECS_REPL_ENABLED in specs: - replication_capability = specs[EXTRA_SPECS_REPL_ENABLED] - # Do not validate settings, ignore invalid. - replication_flag = (replication_capability == " True") + if volume["volume_type_id"]: + volume_type = volume_types.get_volume_type( + ctxt, volume["volume_type_id"]) + + specs = volume_type.get("extra_specs") + if specs and EXTRA_SPECS_REPL_ENABLED in specs: + replication_capability = specs[EXTRA_SPECS_REPL_ENABLED] + # Do not validate settings, ignore invalid. + replication_flag = (replication_capability == " True") return replication_flag def _find_failover_target(self, secondary):