diff --git a/test/unit/__init__.py b/test/unit/__init__.py
index 596a5f84c6..c94ae073db 100644
--- a/test/unit/__init__.py
+++ b/test/unit/__init__.py
@@ -121,7 +121,7 @@ def patch_policies(thing_or_policies=None, legacy_only=False,
 class PatchPolicies(object):
     """
     Why not mock.patch?  In my case, when used as a decorator on the class it
-    seemed to patch setUp at the wrong time (i.e. in setup the global wasn't
+    seemed to patch setUp at the wrong time (i.e. in setUp the global wasn't
     patched yet)
     """
 
@@ -168,42 +168,38 @@ class PatchPolicies(object):
         """
 
         orig_setUp = cls.setUp
-        orig_tearDown = cls.tearDown
+
+        def unpatch_cleanup(cls_self):
+            if cls_self._policies_patched:
+                self.__exit__()
+                cls_self._policies_patched = False
 
         def setUp(cls_self):
-            self._orig_POLICIES = storage_policy._POLICIES
             if not getattr(cls_self, '_policies_patched', False):
-                storage_policy._POLICIES = self.policies
-                self._setup_rings()
+                self.__enter__()
                 cls_self._policies_patched = True
-
+                cls_self.addCleanup(unpatch_cleanup, cls_self)
             orig_setUp(cls_self)
 
-        def tearDown(cls_self):
-            orig_tearDown(cls_self)
-            storage_policy._POLICIES = self._orig_POLICIES
-
         cls.setUp = setUp
-        cls.tearDown = tearDown
 
         return cls
 
     def _patch_method(self, f):
         @functools.wraps(f)
         def mywrapper(*args, **kwargs):
-            self._orig_POLICIES = storage_policy._POLICIES
-            try:
-                storage_policy._POLICIES = self.policies
-                self._setup_rings()
+            with self:
                 return f(*args, **kwargs)
-            finally:
-                storage_policy._POLICIES = self._orig_POLICIES
         return mywrapper
 
     def __enter__(self):
         self._orig_POLICIES = storage_policy._POLICIES
         storage_policy._POLICIES = self.policies
-        self._setup_rings()
+        try:
+            self._setup_rings()
+        except:  # noqa
+            self.__exit__()
+            raise
 
     def __exit__(self, *args):
         storage_policy._POLICIES = self._orig_POLICIES