From 242f3c2ffa267006f0dbba098bbec8fe3e92d620 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn=20Nieto?= Date: Fri, 9 Mar 2012 12:12:05 +0100 Subject: [PATCH] TreeBuilder: allow the source to be a Tree If the user passes a tree, use it as the source for the TreeBuilder. On the way, make sure we free the tree we looked up, and fix a test to make sure the TreeBuilder starts empty. --- pygit2.c | 29 +++++++++++++++++++---------- test/test_treebuilder.py | 8 +++++++- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/pygit2.c b/pygit2.c index 0d9fd53..2a7c317 100644 --- a/pygit2.c +++ b/pygit2.c @@ -942,27 +942,36 @@ Repository_TreeBuilder(Repository *self, PyObject *args) { TreeBuilder *builder; git_treebuilder *bld; - PyObject *py_oid = NULL; + PyObject *py_src = NULL; size_t oid_len; git_oid oid; git_tree *tree = NULL; int err; - if (!PyArg_ParseTuple(args, "|O", &py_oid)) + if (!PyArg_ParseTuple(args, "|O", &py_src)) return NULL; - if (py_oid) { - oid_len = py_str_to_git_oid(py_oid, &oid); - TODO_SUPPORT_SHORT_HEXS(oid_len) - if (oid_len == 0) - return NULL; + if (py_src) { + if (PyObject_TypeCheck(py_src, &TreeType)) { + Tree *py_tree = (Tree *)py_src; + if (py_tree->repo->repo != self->repo) { + return Error_set(GIT_EINVALIDARGS); + } + tree = py_tree->tree; + } else { + oid_len = py_str_to_git_oid(py_src, &oid); + TODO_SUPPORT_SHORT_HEXS(oid_len) + if (oid_len == 0) + return NULL; - err = git_tree_lookup(&tree, self->repo, &oid); - if (err < 0) - return Error_set(err); + err = git_tree_lookup(&tree, self->repo, &oid); + if (err < 0) + return Error_set(err); + } } err = git_treebuilder_create(&bld, tree); + git_tree_free(tree); if (err < 0) return Error_set(err); diff --git a/test/test_treebuilder.py b/test/test_treebuilder.py index d1b003b..f9628b1 100644 --- a/test/test_treebuilder.py +++ b/test/test_treebuilder.py @@ -50,9 +50,15 @@ class TreeBuilderTest(utils.BareRepoTestCase): result = bld.write() self.assertEqual(tree.oid, result) + def test_noop_treebuilder_from_tree(self): + tree = self.repo[TREE_SHA] + bld = self.repo.TreeBuilder(tree) + result = bld.write() + self.assertEqual(tree.oid, result) + def test_rebuild_treebuilder(self): tree = self.repo[TREE_SHA] - bld = self.repo.TreeBuilder(TREE_SHA) + bld = self.repo.TreeBuilder() for e in tree: bld.insert(e)