diff --git a/test/test_blob.py b/test/test_blob.py index 8105455..2f52596 100644 --- a/test/test_blob.py +++ b/test/test_blob.py @@ -93,13 +93,12 @@ class BlobTest(utils.RepoTestCase): def test_create_blob_outside_workdir(self): - path = join(dirname(__file__), 'data', self.repo_dir + '.tar') + path = __file__ self.assertRaises(KeyError, self.repo.create_blob_fromworkdir, path) def test_create_blob_fromdisk(self): - path = join(dirname(__file__), 'data', self.repo_dir + '.tar') - blob_oid = self.repo.create_blob_fromdisk(path) + blob_oid = self.repo.create_blob_fromdisk(__file__) blob = self.repo[blob_oid] self.assertTrue(isinstance(blob, pygit2.Blob)) diff --git a/test/test_repository.py b/test/test_repository.py index 87c9fc7..ab0b22c 100644 --- a/test/test_repository.py +++ b/test/test_repository.py @@ -143,7 +143,7 @@ class RepositoryTest(utils.BareRepoTestCase): def test_get_path(self): directory = realpath(self.repo.path) - expected = realpath(join(self._temp_dir, 'testrepo.git')) + expected = realpath(self.repo_path) self.assertEqual(directory, expected) def test_get_workdir(self): @@ -179,12 +179,12 @@ class RepositoryTest_II(utils.RepoTestCase): def test_get_path(self): directory = realpath(self.repo.path) - expected = realpath(join(self._temp_dir, 'testrepo', '.git')) + expected = realpath(join(self.repo_path, '.git')) self.assertEqual(directory, expected) def test_get_workdir(self): directory = realpath(self.repo.workdir) - expected = realpath(join(self._temp_dir, 'testrepo')) + expected = realpath(self.repo_path) self.assertEqual(directory, expected) def test_checkout_ref(self): diff --git a/test/utils.py b/test/utils.py index 10a4e17..fb7ca58 100644 --- a/test/utils.py +++ b/test/utils.py @@ -65,6 +65,27 @@ def rmtree(path): shutil.rmtree(path, onerror=onerror) +class TemporaryRepository(object): + def __init__(self, repo_spec): + self.repo_spec = repo_spec + + def __enter__(self): + container, name = self.repo_spec + repo_path = os.path.join(os.path.dirname(__file__), 'data', name) + self.temp_dir = tempfile.mkdtemp() + temp_repo_path = os.path.join(self.temp_dir, name) + if container == 'tar': + tar = tarfile.open('.'.join((repo_path, 'tar'))) + tar.extractall(self.temp_dir) + tar.close() + else: + shutil.copytree(repo_path, temp_repo_path) + return temp_repo_path + + def __exit__(self, exc_type, exc_value, traceback): + rmtree(self.temp_dir) + + class NoRepoTestCase(unittest.TestCase): def setUp(self): @@ -103,45 +124,33 @@ class NoRepoTestCase(unittest.TestCase): self.assertEqual(a.offset, b.offset) -class BareRepoTestCase(NoRepoTestCase): - - repo_dir = 'testrepo.git' - +class AutoRepoTestCase(NoRepoTestCase): def setUp(self): - super(BareRepoTestCase, self).setUp() + super(AutoRepoTestCase, self).setUp() + self.repo_ctxtmgr = TemporaryRepository(self.repo_spec) + self.repo_path = self.repo_ctxtmgr.__enter__() + self.repo = pygit2.Repository(self.repo_path) - repo_dir = self.repo_dir - repo_path = os.path.join(os.path.dirname(__file__), 'data', repo_dir) - temp_repo_path = os.path.join(self._temp_dir, repo_dir) - - shutil.copytree(repo_path, temp_repo_path) - - self.repo = pygit2.Repository(temp_repo_path) + def tearDown(self): + self.repo_ctxtmgr.__exit__(None, None, None) + super(AutoRepoTestCase, self).tearDown() -class RepoTestCase(NoRepoTestCase): +class BareRepoTestCase(AutoRepoTestCase): - repo_dir = 'testrepo' - - def setUp(self): - super(RepoTestCase, self).setUp() - - repo_dir = self.repo_dir - repo_path = os.path.join(os.path.dirname(__file__), 'data', repo_dir) - temp_repo_path = os.path.join(self._temp_dir, repo_dir, '.git') - - tar = tarfile.open(repo_path + '.tar') - tar.extractall(self._temp_dir) - tar.close() - - self.repo = pygit2.Repository(temp_repo_path) + repo_spec = 'git', 'testrepo.git' -class DirtyRepoTestCase(RepoTestCase): +class RepoTestCase(AutoRepoTestCase): - repo_dir = 'dirtyrepo' + repo_spec = 'tar', 'testrepo' -class EmptyRepoTestCase(RepoTestCase): +class DirtyRepoTestCase(AutoRepoTestCase): - repo_dir = 'emptyrepo' + repo_spec = 'tar', 'dirtyrepo' + + +class EmptyRepoTestCase(AutoRepoTestCase): + + repo_spec = 'tar', 'emptyrepo'