Merge "Fix failure with "None" volume type in Pure drivers"

This commit is contained in:
Jenkins
2016-07-12 01:18:29 +00:00
committed by Gerrit Code Review
2 changed files with 19 additions and 7 deletions

View File

@@ -1594,9 +1594,19 @@ class PureBaseVolumeDriverTestCase(PureBaseSharedDriverTestCase):
def test_is_vol_replicated_has_repl_extra_specs(self, mock_get_vol_type):
mock_get_vol_type.return_value = REPLICATED_VOL_TYPE
volume = fake_volume.fake_volume_obj(mock.MagicMock())
volume.volume_type_id = REPLICATED_VOL_TYPE['id']
actual = self.driver._is_volume_replicated_type(volume)
self.assertTrue(actual)
@mock.patch('cinder.volume.volume_types.get_volume_type')
def test_is_vol_replicated_none_type(self, mock_get_vol_type):
mock_get_vol_type.side_effect = exception.InvalidVolumeType(reason='')
volume = fake_volume.fake_volume_obj(mock.MagicMock())
volume.volume_type = None
volume.volume_type_id = None
actual = self.driver._is_volume_replicated_type(volume)
self.assertFalse(actual)
@mock.patch('cinder.volume.volume_types.get_volume_type')
def test_is_vol_replicated_has_other_extra_specs(self, mock_get_vol_type):
vtype_test = deepcopy(NON_REPLICATED_VOL_TYPE)

View File

@@ -1423,14 +1423,16 @@ class PureBaseVolumeDriver(san.SanDriver):
def _is_volume_replicated_type(self, volume):
ctxt = context.get_admin_context()
volume_type = volume_types.get_volume_type(ctxt,
volume["volume_type_id"])
replication_flag = False
specs = volume_type.get("extra_specs")
if specs and EXTRA_SPECS_REPL_ENABLED in specs:
replication_capability = specs[EXTRA_SPECS_REPL_ENABLED]
# Do not validate settings, ignore invalid.
replication_flag = (replication_capability == "<is> True")
if volume["volume_type_id"]:
volume_type = volume_types.get_volume_type(
ctxt, volume["volume_type_id"])
specs = volume_type.get("extra_specs")
if specs and EXTRA_SPECS_REPL_ENABLED in specs:
replication_capability = specs[EXTRA_SPECS_REPL_ENABLED]
# Do not validate settings, ignore invalid.
replication_flag = (replication_capability == "<is> True")
return replication_flag
def _find_failover_target(self, secondary):