diff --git a/src/branch.c b/src/branch.c index a3b4c78..48f5bba 100644 --- a/src/branch.c +++ b/src/branch.c @@ -31,7 +31,10 @@ #include "reference.h" #include "utils.h" + extern PyObject *GitError; +extern PyTypeObject ReferenceType; + PyDoc_STRVAR(Branch_delete__doc__, "delete()\n" @@ -105,6 +108,151 @@ PyObject* Branch_rename(Branch *self, PyObject *args) } +PyDoc_STRVAR(Branch_branch_name__doc__, + "The name of the local or remote branch."); + +PyObject* Branch_branch_name__get__(Branch *self) +{ + int err; + const char *c_name; + + CHECK_REFERENCE(self); + + err = git_branch_name(&c_name, self->reference); + if (err == GIT_OK) + return to_unicode(c_name, NULL, NULL); + else + return Error_set(err); +} + + +PyDoc_STRVAR(Branch_remote_name__doc__, + "The name of the remote that the remote tracking branch belongs to."); + +PyObject* Branch_remote_name__get__(Branch *self) +{ + int err; + const char *branch_name; + char *c_name = NULL; + + CHECK_REFERENCE(self); + + branch_name = git_reference_name(self->reference); + // get the length of the remote name + err = git_branch_remote_name(NULL, 0, self->repo->repo, branch_name); + if (err < GIT_OK) + return Error_set(err); + + // get the actual remote name + c_name = calloc(err, sizeof(char)); + if (c_name == NULL) + return PyErr_NoMemory(); + + err = git_branch_remote_name(c_name, + err * sizeof(char), + self->repo->repo, + branch_name); + if (err < GIT_OK) { + free(c_name); + return Error_set(err); + } + + PyObject *py_name = to_unicode(c_name, NULL, NULL); + free(c_name); + + return py_name; +} + + +PyDoc_STRVAR(Branch_upstream__doc__, + "The branch supporting the remote tracking branch or None if this is not a " + "remote tracking branch. Set to None to unset."); + +PyObject* Branch_upstream__get__(Branch *self) +{ + int err; + git_reference *c_reference; + + CHECK_REFERENCE(self); + + err = git_branch_upstream(&c_reference, self->reference); + if (err == GIT_ENOTFOUND) + Py_RETURN_NONE; + else if (err < GIT_OK) + return Error_set(err); + + return wrap_branch(c_reference, self->repo); +} + +int Branch_upstream__set__(Branch *self, Reference *py_ref) +{ + int err; + const char *branch_name = NULL; + + CHECK_REFERENCE_INT(self); + + if ((PyObject *)py_ref != Py_None) { + if (!PyObject_TypeCheck(py_ref, (PyTypeObject *)&ReferenceType)) { + PyErr_SetObject(PyExc_TypeError, (PyObject *)py_ref); + return -1; + } + + CHECK_REFERENCE_INT(py_ref); + err = git_branch_name(&branch_name, py_ref->reference); + if (err < GIT_OK) { + Error_set(err); + return -1; + } + } + + err = git_branch_set_upstream(self->reference, branch_name); + if (err < GIT_OK) { + Error_set(err); + return -1; + } + + return 0; +} + + +PyDoc_STRVAR(Branch_upstream_name__doc__, + "The name of the reference supporting the remote tracking branch."); + +PyObject* Branch_upstream_name__get__(Branch *self) +{ + int err; + const char *branch_name; + char *c_name = NULL; + + CHECK_REFERENCE(self); + + branch_name = git_reference_name(self->reference); + // get the length of the upstream name + err = git_branch_upstream_name(NULL, 0, self->repo->repo, branch_name); + if (err < GIT_OK) + return Error_set(err); + + // get the actual upstream name + c_name = calloc(err, sizeof(char)); + if (c_name == NULL) + return PyErr_NoMemory(); + + err = git_branch_upstream_name(c_name, + err * sizeof(char), + self->repo->repo, + branch_name); + if (err < GIT_OK) { + free(c_name); + return Error_set(err); + } + + PyObject *py_name = to_unicode(c_name, NULL, NULL); + free(c_name); + + return py_name; +} + + PyMethodDef Branch_methods[] = { METHOD(Branch, delete, METH_NOARGS), METHOD(Branch, is_head, METH_NOARGS), @@ -113,6 +261,10 @@ PyMethodDef Branch_methods[] = { }; PyGetSetDef Branch_getseters[] = { + GETTER(Branch, branch_name), + GETTER(Branch, remote_name), + GETSET(Branch, upstream), + GETTER(Branch, upstream_name), {NULL} }; diff --git a/test/test_branch.py b/test/test_branch.py index eebf1e0..8dc62cb 100644 --- a/test/test_branch.py +++ b/test/test_branch.py @@ -110,6 +110,15 @@ class BranchesTestCase(utils.RepoTestCase): self.assertRaises(ValueError, lambda: original_branch.rename('abc@{123')) + def test_branch_name(self): + branch = self.repo.lookup_branch('master') + self.assertEqual(branch.branch_name, 'master') + self.assertEqual(branch.name, 'refs/heads/master') + + branch = self.repo.lookup_branch('i18n') + self.assertEqual(branch.branch_name, 'i18n') + self.assertEqual(branch.name, 'refs/heads/i18n') + class BranchesEmptyRepoTestCase(utils.EmptyRepoTestCase): def setUp(self): @@ -131,6 +140,40 @@ class BranchesEmptyRepoTestCase(utils.EmptyRepoTestCase): branches = sorted(self.repo.listall_branches(pygit2.GIT_BRANCH_REMOTE)) self.assertEqual(branches, ['origin/master']) + def test_branch_remote_name(self): + self.repo.remotes[0].fetch() + branch = self.repo.lookup_branch('origin/master', + pygit2.GIT_BRANCH_REMOTE) + self.assertEqual(branch.remote_name, 'origin') + + def test_branch_upstream(self): + self.repo.remotes[0].fetch() + remote_master = self.repo.lookup_branch('origin/master', + pygit2.GIT_BRANCH_REMOTE) + master = self.repo.create_branch('master', + self.repo[remote_master.target.hex]) + + self.assertTrue(master.upstream is None) + master.upstream = remote_master + self.assertEqual(master.upstream.branch_name, 'origin/master') + + def set_bad_upstream(): + master.upstream = 2.5 + self.assertRaises(TypeError, set_bad_upstream) + + master.upstream = None + self.assertTrue(master.upstream is None) + + def test_branch_upstream_name(self): + self.repo.remotes[0].fetch() + remote_master = self.repo.lookup_branch('origin/master', + pygit2.GIT_BRANCH_REMOTE) + master = self.repo.create_branch('master', + self.repo[remote_master.target.hex]) + + master.upstream = remote_master + self.assertEqual(master.upstream_name, 'refs/remotes/origin/master') + if __name__ == '__main__': unittest.main()