diff --git a/manilaclient/client.py b/manilaclient/client.py index 485860e47..5c0e20815 100644 --- a/manilaclient/client.py +++ b/manilaclient/client.py @@ -41,15 +41,29 @@ def get_client_class(version): return importutils.import_class(client_path) -def Client(api_version, *args, **kwargs): - if not hasattr(api_version, 'get_major_version'): - if api_version in ('1', '1.0'): - api_version = api_versions.APIVersion( - api_versions.DEPRECATED_VERSION) - elif api_version == '2': - api_version = api_versions.APIVersion(api_versions.MIN_VERSION) +def Client(client_version, *args, **kwargs): + + def _convert_to_api_version(version): + """Convert version to an APIVersion object unless it already is one.""" + + if hasattr(version, 'get_major_version'): + api_version = version else: - api_version = api_versions.APIVersion(api_version) + if version in ('1', '1.0'): + api_version = api_versions.APIVersion( + api_versions.DEPRECATED_VERSION) + elif version == '2': + api_version = api_versions.APIVersion(api_versions.MIN_VERSION) + else: + api_version = api_versions.APIVersion(version) + return api_version + + api_version = _convert_to_api_version(client_version) client_class = get_client_class(api_version.get_major_version()) - kwargs['api_version'] = api_version + + # Make sure the kwarg api_version is set with an APIVersion object. + # 1st choice is to use the incoming kwarg. 2nd choice is the positional. + kwargs['api_version'] = _convert_to_api_version( + kwargs.get('api_version', api_version)) + return client_class(*args, **kwargs) diff --git a/manilaclient/tests/unit/test_client.py b/manilaclient/tests/unit/test_client.py index 81744bf48..a3cc778e9 100644 --- a/manilaclient/tests/unit/test_client.py +++ b/manilaclient/tests/unit/test_client.py @@ -75,3 +75,59 @@ class ClientTest(utils.TestCase): @ddt.data(None, '', '3', 'v1', 'v2', 'v1.0', 'v2.0') def test_init_client_with_unsupported_version(self, v): self.assertRaises(exceptions.UnsupportedVersion, client.Client, v) + + @ddt.data( + ('1', '1.0'), + ('1', '2.0'), + ('1', '2.7'), + ('1', None), + ('1.0', '1.0'), + ('1.0', '2.0'), + ('1.0', '2.7'), + ('1.0', None), + ('2', '1.0'), + ('2', '2.0'), + ('2', '2.7'), + ('2', None), + ) + @ddt.unpack + def test_init_client_with_version_parms(self, pos, kw): + + major = int(float(pos)) + pos_av = mock.Mock() + kw_av = mock.Mock() + + with mock.patch.object(manilaclient.v1.client, 'Client'): + with mock.patch.object(manilaclient.v2.client, 'Client'): + with mock.patch.object(api_versions, 'APIVersion'): + api_versions.APIVersion.side_effect = [pos_av, kw_av] + pos_av.get_major_version.return_value = str(major) + + if kw is None: + manilaclient.client.Client(pos, 'foo') + expected_av = pos_av + else: + manilaclient.client.Client(pos, 'foo', api_version=kw) + expected_av = kw_av + + if int(float(pos)) == 1: + expected_client_ver = api_versions.DEPRECATED_VERSION + self.assertFalse(manilaclient.v2.client.Client.called) + manilaclient.v1.client.Client.assert_has_calls([ + mock.call('foo', api_version=expected_av) + ]) + else: + expected_client_ver = api_versions.MIN_VERSION + self.assertFalse(manilaclient.v1.client.Client.called) + manilaclient.v2.client.Client.assert_has_calls([ + mock.call('foo', api_version=expected_av) + ]) + + if kw is None: + api_versions.APIVersion.assert_called_once_with( + expected_client_ver) + else: + api_versions.APIVersion.assert_has_calls([ + mock.call(expected_client_ver), + mock.call(kw), + ])