diff --git a/openstackclient/tests/unit/common/test_extension.py b/openstackclient/tests/unit/common/test_extension.py
index 765903b3b6..b22365899e 100644
--- a/openstackclient/tests/unit/common/test_extension.py
+++ b/openstackclient/tests/unit/common/test_extension.py
@@ -14,7 +14,6 @@ from unittest import mock
 
 from openstackclient.common import extension
 from openstackclient.tests.unit.compute.v2 import fakes as compute_fakes
-from openstackclient.tests.unit import fakes
 from openstackclient.tests.unit.identity.v2_0 import fakes as identity_fakes
 from openstackclient.tests.unit.network.v2 import fakes as network_fakes
 from openstackclient.tests.unit import utils
@@ -22,23 +21,13 @@ from openstackclient.tests.unit import utils as tests_utils
 from openstackclient.tests.unit.volume.v3 import fakes as volume_fakes
 
 
-class TestExtension(network_fakes.FakeClientMixin, utils.TestCommand):
-    def setUp(self):
-        super().setUp()
-
-        identity_client = identity_fakes.FakeIdentityv2Client(
-            endpoint=fakes.AUTH_URL,
-            token=fakes.AUTH_TOKEN,
-        )
-        self.app.client_manager.identity = identity_client
-        self.identity_extensions_mock = identity_client.extensions
-        self.identity_extensions_mock.reset_mock()
-
-        sdk_connection = self.app.client_manager.sdk_connection
-        self.compute_extensions_mock = sdk_connection.compute.extensions
-        self.compute_extensions_mock.reset_mock()
-        self.volume_extensions_mock = sdk_connection.volume.extensions
-        self.volume_extensions_mock.reset_mock()
+class TestExtension(
+    network_fakes.FakeClientMixin,
+    compute_fakes.FakeClientMixin,
+    volume_fakes.FakeClientMixin,
+    identity_fakes.FakeClientMixin,
+    utils.TestCommand,
+): ...
 
 
 class TestExtensionList(TestExtension):
@@ -60,11 +49,15 @@ class TestExtensionList(TestExtension):
     def setUp(self):
         super().setUp()
 
-        self.identity_extensions_mock.list.return_value = [
+        self.identity_client.extensions.list.return_value = [
             self.identity_extension
         ]
-        self.compute_extensions_mock.return_value = [self.compute_extension]
-        self.volume_extensions_mock.return_value = [self.volume_extension]
+        self.compute_sdk_client.extensions.return_value = [
+            self.compute_extension
+        ]
+        self.volume_sdk_client.extensions.return_value = [
+            self.volume_extension
+        ]
         self.network_client.extensions.return_value = [self.network_extension]
 
         # Get the command object to test
@@ -112,9 +105,9 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist)
-        self.identity_extensions_mock.list.assert_called_with()
-        self.compute_extensions_mock.assert_called_with()
-        self.volume_extensions_mock.assert_called_with()
+        self.identity_client.extensions.list.assert_called_with()
+        self.compute_sdk_client.extensions.assert_called_with()
+        self.volume_sdk_client.extensions.assert_called_with()
         self.network_client.extensions.assert_called_with()
 
     def test_extension_list_long(self):
@@ -159,9 +152,9 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist, True)
-        self.identity_extensions_mock.list.assert_called_with()
-        self.compute_extensions_mock.assert_called_with()
-        self.volume_extensions_mock.assert_called_with()
+        self.identity_client.extensions.list.assert_called_with()
+        self.compute_sdk_client.extensions.assert_called_with()
+        self.volume_sdk_client.extensions.assert_called_with()
         self.network_client.extensions.assert_called_with()
 
     def test_extension_list_identity(self):
@@ -179,7 +172,7 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist)
-        self.identity_extensions_mock.list.assert_called_with()
+        self.identity_client.extensions.list.assert_called_with()
 
     def test_extension_list_network(self):
         arglist = [
@@ -237,7 +230,7 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist)
-        self.compute_extensions_mock.assert_called_with()
+        self.compute_sdk_client.extensions.assert_called_with()
 
     def test_extension_list_compute_and_network(self):
         arglist = [
@@ -261,7 +254,7 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist)
-        self.compute_extensions_mock.assert_called_with()
+        self.compute_sdk_client.extensions.assert_called_with()
         self.network_client.extensions.assert_called_with()
 
     def test_extension_list_volume(self):
@@ -279,7 +272,7 @@ class TestExtensionList(TestExtension):
             ),
         )
         self._test_extension_list_helper(arglist, verifylist, datalist)
-        self.volume_extensions_mock.assert_called_with()
+        self.volume_sdk_client.extensions.assert_called_with()
 
 
 class TestExtensionShow(TestExtension):