From 35f094b57b00adfa4edd6e5df8b4a49f08464409 Mon Sep 17 00:00:00 2001
From: Dave Borowitz <dborowitz@google.com>
Date: Mon, 1 Nov 2010 14:37:57 -0700
Subject: [PATCH] Add Tree and TreeEntry classes, with tests.

Change-Id: Idadd5fc59b85506260a1f57b5e7488aed590bfa1
---
 pygit2.c          | 401 +++++++++++++++++++++++++++++++++++++++++++++-
 test/__init__.py  |   2 +-
 test/test_tree.py | 134 ++++++++++++++++
 3 files changed, 535 insertions(+), 2 deletions(-)
 create mode 100644 test/test_tree.py

diff --git a/pygit2.c b/pygit2.c
index ef3ee6b..0c90d03 100644
--- a/pygit2.c
+++ b/pygit2.c
@@ -49,10 +49,19 @@ typedef struct {
 
 OBJECT_STRUCT(Object, git_object, obj)
 OBJECT_STRUCT(Commit, git_commit, commit)
+OBJECT_STRUCT(Tree, git_tree, tree)
+
+typedef struct {
+    PyObject_HEAD
+    git_tree_entry *entry;
+    Tree *tree;
+} TreeEntry;
 
 static PyTypeObject RepositoryType;
 static PyTypeObject ObjectType;
 static PyTypeObject CommitType;
+static PyTypeObject TreeEntryType;
+static PyTypeObject TreeType;
 
 static int
 Repository_init(Repository *self, PyObject *args, PyObject *kwds) {
@@ -106,7 +115,7 @@ static Object *wrap_object(git_object *obj, Repository *repo) {
             py_obj = (Object*)CommitType.tp_alloc(&CommitType, 0);
             break;
         case GIT_OBJ_TREE:
-            py_obj = (Object*)ObjectType.tp_alloc(&ObjectType, 0);
+            py_obj = (Object*)TreeType.tp_alloc(&TreeType, 0);
             break;
         case GIT_OBJ_BLOB:
             py_obj = (Object*)ObjectType.tp_alloc(&ObjectType, 0);
@@ -533,6 +542,382 @@ static PyTypeObject CommitType = {
     0,                                         /* tp_new */
 };
 
+static void
+TreeEntry_dealloc(TreeEntry *self) {
+    Py_XDECREF(self->tree);
+    self->ob_type->tp_free((PyObject *)self);
+}
+
+static PyObject *
+TreeEntry_get_attributes(TreeEntry *self) {
+    return PyInt_FromLong(git_tree_entry_attributes(self->entry));
+}
+
+static int
+TreeEntry_set_attributes(TreeEntry *self, PyObject *value) {
+    unsigned int attributes;
+    attributes = PyInt_AsLong(value);
+    if (PyErr_Occurred())
+        return -1;
+    git_tree_entry_set_attributes(self->entry, attributes);
+    return 0;
+}
+
+static PyObject *
+TreeEntry_get_name(TreeEntry *self) {
+    return PyString_FromString(git_tree_entry_name(self->entry));
+}
+
+static int
+TreeEntry_set_name(TreeEntry *self, PyObject *value) {
+    char *name;
+    name = PyString_AsString(value);
+    if (!name)
+        return -1;
+    git_tree_entry_set_name(self->entry, name);
+    return 0;
+}
+
+static PyObject *
+TreeEntry_get_sha(TreeEntry *self) {
+    char hex[GIT_OID_HEXSZ];
+    git_oid_fmt(hex, git_tree_entry_id(self->entry));
+    return PyString_FromStringAndSize(hex, GIT_OID_HEXSZ);
+}
+
+static int
+TreeEntry_set_sha(TreeEntry *self, PyObject *value) {
+    char *hex;
+    git_oid oid;
+
+    hex = PyString_AsString(value);
+    if (!hex)
+        return -1;
+    if (git_oid_mkstr(&oid, hex) < 0) {
+        PyErr_Format(PyExc_ValueError, "Invalid hex SHA \"%s\"", hex);
+        return -1;
+    }
+    git_tree_entry_set_id(self->entry, &oid);
+    return 0;
+}
+
+static PyObject *
+TreeEntry_to_object(TreeEntry *self) {
+    git_object *obj;
+    char hex[GIT_OID_HEXSZ];
+    PyObject *py_hex;
+
+    obj = git_tree_entry_2object(self->entry);
+    if (!obj) {
+        git_oid_fmt(hex, git_tree_entry_id(self->entry));
+        py_hex = PyString_FromStringAndSize(hex, GIT_OID_HEXSZ);
+        PyErr_SetObject(PyExc_KeyError, py_hex);
+        return NULL;
+    }
+    return (PyObject*)wrap_object(obj, self->tree->repo);
+}
+
+static PyGetSetDef TreeEntry_getseters[] = {
+    {"attributes", (getter)TreeEntry_get_attributes,
+     (setter)TreeEntry_set_attributes, "attributes", NULL},
+    {"name", (getter)TreeEntry_get_name, (setter)TreeEntry_set_name, "name",
+     NULL},
+    {"sha", (getter)TreeEntry_get_sha, (setter)TreeEntry_set_sha, "sha", NULL},
+    {NULL}
+};
+
+static PyMethodDef TreeEntry_methods[] = {
+    {"to_object", (PyCFunction)TreeEntry_to_object, METH_NOARGS,
+     "Look up the corresponding object in the repo."},
+    {NULL, NULL, 0, NULL}
+};
+
+static PyTypeObject TreeEntryType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                         /*ob_size*/
+    "pygit2.TreeEntry",                        /*tp_name*/
+    sizeof(TreeEntry),                         /*tp_basicsize*/
+    0,                                         /*tp_itemsize*/
+    (destructor)TreeEntry_dealloc,             /*tp_dealloc*/
+    0,                                         /*tp_print*/
+    0,                                         /*tp_getattr*/
+    0,                                         /*tp_setattr*/
+    0,                                         /*tp_compare*/
+    0,                                         /*tp_repr*/
+    0,                                         /*tp_as_number*/
+    0,                                         /*tp_as_sequence*/
+    0,                                         /*tp_as_mapping*/
+    0,                                         /*tp_hash */
+    0,                                         /*tp_call*/
+    0,                                         /*tp_str*/
+    0,                                         /*tp_getattro*/
+    0,                                         /*tp_setattro*/
+    0,                                         /*tp_as_buffer*/
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  /*tp_flags*/
+    "TreeEntry objects",                       /* tp_doc */
+    0,                                         /* tp_traverse */
+    0,                                         /* tp_clear */
+    0,                                         /* tp_richcompare */
+    0,                                         /* tp_weaklistoffset */
+    0,                                         /* tp_iter */
+    0,                                         /* tp_iternext */
+    TreeEntry_methods,                         /* tp_methods */
+    0,                                         /* tp_members */
+    TreeEntry_getseters,                       /* tp_getset */
+    0,                                         /* tp_base */
+    0,                                         /* tp_dict */
+    0,                                         /* tp_descr_get */
+    0,                                         /* tp_descr_set */
+    0,                                         /* tp_dictoffset */
+    0,                                         /* tp_init */
+    0,                                         /* tp_alloc */
+    0,                                         /* tp_new */
+};
+
+static int
+Tree_init(Tree *py_tree, PyObject *args, PyObject *kwds) {
+    Repository *repo = NULL;
+    git_tree *tree;
+
+    if (!object_init_check("Tree", args, kwds, &repo))
+        return -1;
+
+    tree = git_tree_new(repo->repo);
+    if (!tree) {
+        PyErr_SetNone(PyExc_MemoryError);
+        return -1;
+    }
+    Py_INCREF(repo);
+    py_tree->repo = repo;
+    py_tree->own_obj = 1;
+    py_tree->tree = tree;
+    return 0;
+}
+
+static Py_ssize_t
+Tree_len(Tree *self) {
+    return (Py_ssize_t)git_tree_entrycount(self->tree);
+}
+
+static int
+Tree_contains(Tree *self, PyObject *py_name) {
+    char *name;
+    name = PyString_AsString(py_name);
+    return name && git_tree_entry_byname(self->tree, name) ? 1 : 0;
+}
+
+static TreeEntry *
+wrap_tree_entry(git_tree_entry *entry, Tree *tree) {
+    TreeEntry *py_entry = NULL;
+    py_entry = (TreeEntry*)TreeEntryType.tp_alloc(&TreeEntryType, 0);
+    if (!py_entry)
+        return NULL;
+
+    py_entry->entry = entry;
+    py_entry->tree = tree;
+    Py_INCREF(tree);
+    return py_entry;
+}
+
+static TreeEntry *
+Tree_getitem_by_name(Tree *self, PyObject *py_name) {
+    char *name;
+    git_tree_entry *entry;
+    name = PyString_AS_STRING(py_name);
+    entry = git_tree_entry_byname(self->tree, name);
+    if (!entry) {
+        PyErr_SetObject(PyExc_KeyError, py_name);
+        return NULL;
+    }
+    return wrap_tree_entry(entry, self);
+}
+
+static int
+Tree_fix_index(Tree *self, PyObject *py_index) {
+    long index;
+    size_t len;
+    long slen;
+
+    index = PyInt_AsLong(py_index);
+    if (PyErr_Occurred())
+        return -1;
+
+    len = git_tree_entrycount(self->tree);
+    slen = (long)len;
+    if (index >= slen) {
+        PyErr_SetObject(PyExc_IndexError, py_index);
+        return -1;
+    } else if (index < -slen) {
+        PyErr_SetObject(PyExc_IndexError, py_index);
+        return -1;
+    }
+
+    /* This function is called via mp_subscript, which doesn't do negative index
+     * rewriting, so we have to do it manually. */
+    if (index < 0)
+        index = len + index;
+    return (int)index;
+}
+
+static TreeEntry *
+Tree_getitem_by_index(Tree *self, PyObject *py_index) {
+    int index;
+    git_tree_entry *entry;
+
+    index = Tree_fix_index(self, py_index);
+    if (PyErr_Occurred())
+        return NULL;
+
+    entry = git_tree_entry_byindex(self->tree, index);
+    if (!entry) {
+        PyErr_SetObject(PyExc_IndexError, py_index);
+        return NULL;
+    }
+    return wrap_tree_entry(entry, self);
+}
+
+static TreeEntry *
+Tree_getitem(Tree *self, PyObject *value) {
+    if (PyString_Check(value)) {
+        return Tree_getitem_by_name(self, value);
+    } else if (PyInt_Check(value)) {
+        return Tree_getitem_by_index(self, value);
+    } else {
+        PyErr_SetString(PyExc_TypeError, "Expected int or str for tree index.");
+        return NULL;
+    }
+}
+
+static int
+Tree_delitem_by_name(Tree *self, PyObject *name) {
+    int err;
+    err = git_tree_remove_entry_byname(self->tree, PyString_AS_STRING(name));
+    if (err < 0) {
+        PyErr_SetObject(PyExc_KeyError, name);
+        return -1;
+    }
+    return 0;
+}
+
+static int
+Tree_delitem_by_index(Tree *self, PyObject *py_index) {
+    int index, err;
+    index = Tree_fix_index(self, py_index);
+    if (PyErr_Occurred())
+        return -1;
+    err = git_tree_remove_entry_byindex(self->tree, index);
+    if (err < 0) {
+        PyErr_SetObject(PyExc_IndexError, py_index);
+        return -1;
+    }
+    return 0;
+}
+
+static int
+Tree_delitem(Tree *self, PyObject *name, PyObject *value) {
+    /* TODO: This function is only used for deleting items. We may be able to
+     * come up with some reasonable assignment semantics, but it's tricky
+     * because git_tree_entry objects are owned by their containing tree. */
+    if (value) {
+        PyErr_SetString(PyExc_ValueError,
+                        "Cannot set TreeEntry directly; use add_entry.");
+        return -1;
+    }
+
+    if (PyString_Check(name)) {
+        return Tree_delitem_by_name(self, name);
+    } else if (PyInt_Check(name)) {
+        return Tree_delitem_by_index(self, name);
+    } else {
+        PyErr_SetString(PyExc_TypeError, "Expected int or str for tree index.");
+        return -1;
+    }
+}
+
+static PyObject *
+Tree_add_entry(Tree *self, PyObject *args) {
+    char *hex, *name;
+    int attributes;
+    git_oid oid;
+
+    if (!PyArg_ParseTuple(args, "ssi", &hex, &name, &attributes))
+        return NULL;
+
+    if (git_oid_mkstr(&oid, hex) < 0) {
+        PyErr_Format(PyExc_ValueError, "Invalid hex SHA \"%s\"", hex);
+        return NULL;
+    }
+
+    if (git_tree_add_entry(self->tree, &oid, name, attributes) < 0)
+        return PyErr_NoMemory();
+    return Py_None;
+}
+
+static PyMethodDef Tree_methods[] = {
+    {"add_entry", (PyCFunction)Tree_add_entry, METH_VARARGS,
+     "Add an entry to a Tree."},
+    {NULL}
+};
+
+static PySequenceMethods Tree_as_sequence = {
+    0,                          /* sq_length */
+    0,                          /* sq_concat */
+    0,                          /* sq_repeat */
+    0,                          /* sq_item */
+    0,                          /* sq_slice */
+    0,                          /* sq_ass_item */
+    0,                          /* sq_ass_slice */
+    (objobjproc)Tree_contains,  /* sq_contains */
+};
+
+static PyMappingMethods Tree_as_mapping = {
+    (lenfunc)Tree_len,            /* mp_length */
+    (binaryfunc)Tree_getitem,     /* mp_subscript */
+    (objobjargproc)Tree_delitem,  /* mp_ass_subscript */
+};
+
+static PyTypeObject TreeType = {
+    PyObject_HEAD_INIT(NULL)
+    0,                                         /*ob_size*/
+    "pygit2.Tree",                             /*tp_name*/
+    sizeof(Tree),                              /*tp_basicsize*/
+    0,                                         /*tp_itemsize*/
+    0,                                         /*tp_dealloc*/
+    0,                                         /*tp_print*/
+    0,                                         /*tp_getattr*/
+    0,                                         /*tp_setattr*/
+    0,                                         /*tp_compare*/
+    0,                                         /*tp_repr*/
+    0,                                         /*tp_as_number*/
+    &Tree_as_sequence,                         /*tp_as_sequence*/
+    &Tree_as_mapping,                          /*tp_as_mapping*/
+    0,                                         /*tp_hash */
+    0,                                         /*tp_call*/
+    0,                                         /*tp_str*/
+    0,                                         /*tp_getattro*/
+    0,                                         /*tp_setattro*/
+    0,                                         /*tp_as_buffer*/
+    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE,  /*tp_flags*/
+    "Tree objects",                            /* tp_doc */
+    0,                                         /* tp_traverse */
+    0,                                         /* tp_clear */
+    0,                                         /* tp_richcompare */
+    0,                                         /* tp_weaklistoffset */
+    0,                                         /* tp_iter */
+    0,                                         /* tp_iternext */
+    Tree_methods,                              /* tp_methods */
+    0,                                         /* tp_members */
+    0,                                         /* tp_getset */
+    0,                                         /* tp_base */
+    0,                                         /* tp_dict */
+    0,                                         /* tp_descr_get */
+    0,                                         /* tp_descr_set */
+    0,                                         /* tp_dictoffset */
+    (initproc)Tree_init,                       /* tp_init */
+    0,                                         /* tp_alloc */
+    0,                                         /* tp_new */
+};
+
 static PyMethodDef module_methods[] = {
     {NULL}
 };
@@ -552,6 +937,14 @@ initpygit2(void)
     CommitType.tp_new = PyType_GenericNew;
     if (PyType_Ready(&CommitType) < 0)
         return;
+    TreeEntryType.tp_base = &ObjectType;
+    TreeEntryType.tp_new = PyType_GenericNew;
+    if (PyType_Ready(&TreeEntryType) < 0)
+        return;
+    TreeType.tp_base = &ObjectType;
+    TreeType.tp_new = PyType_GenericNew;
+    if (PyType_Ready(&TreeType) < 0)
+        return;
 
     m = Py_InitModule3("pygit2", module_methods,
                        "Python bindings for libgit2.");
@@ -568,6 +961,12 @@ initpygit2(void)
     Py_INCREF(&CommitType);
     PyModule_AddObject(m, "Commit", (PyObject *)&CommitType);
 
+    Py_INCREF(&TreeEntryType);
+    PyModule_AddObject(m, "TreeEntry", (PyObject *)&TreeEntryType);
+
+    Py_INCREF(&TreeType);
+    PyModule_AddObject(m, "Tree", (PyObject *)&TreeType);
+
     PyModule_AddIntConstant(m, "GIT_OBJ_ANY", GIT_OBJ_ANY);
     PyModule_AddIntConstant(m, "GIT_OBJ_COMMIT", GIT_OBJ_COMMIT);
     PyModule_AddIntConstant(m, "GIT_OBJ_TREE", GIT_OBJ_TREE);
diff --git a/test/__init__.py b/test/__init__.py
index 8d007ff..89b3d0c 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -36,7 +36,7 @@ import unittest
 
 
 def test_suite():
-    names = ['commit', 'repository']
+    names = ['commit', 'repository', 'tree']
     modules = ['test.test_%s' % n for n in names]
     return unittest.defaultTestLoader.loadTestsFromNames(modules)
 
diff --git a/test/test_tree.py b/test/test_tree.py
new file mode 100644
index 0000000..9a5d7e6
--- /dev/null
+++ b/test/test_tree.py
@@ -0,0 +1,134 @@
+#!/usr/bin/env python
+#
+# Copyright 2010 Google, Inc.
+#
+# This file is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License, version 2,
+# as published by the Free Software Foundation.
+#
+# In addition to the permissions in the GNU General Public License,
+# the authors give you unlimited permission to link the compiled
+# version of this file into combinations with other programs,
+# and to distribute those combinations without any restriction
+# coming from the use of this file.  (The General Public License
+# restrictions do apply in other respects; for example, they cover
+# modification of the file, and distribution when not linked into
+# a combined executable.)
+#
+# This file is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+# General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; see the file COPYING.  If not, write to
+# the Free Software Foundation, 51 Franklin Street, Fifth Floor,
+# Boston, MA 02110-1301, USA.
+
+"""Tests for Commit objects."""
+
+__author__ = 'dborowitz@google.com (Dave Borowitz)'
+
+import unittest
+
+import pygit2
+import utils
+
+TREE_SHA = '967fce8df97cc71722d3c2a5930ef3e6f1d27b12'
+SUBTREE_SHA = '614fd9a3094bf618ea938fffc00e7d1a54f89ad0'
+
+
+class TreeTest(utils.TestRepoTestCase):
+
+    def assertTreeEntryEqual(self, entry, sha, name, attributes):
+        self.assertEqual(entry.sha, sha)
+        self.assertEqual(entry.name, name)
+        self.assertEqual(entry.attributes, attributes,
+                         '0%o != 0%o' % (entry.attributes, attributes))
+
+    def test_read_tree(self):
+        tree = self.repo[TREE_SHA]
+        self.assertRaises(TypeError, lambda: tree[()])
+        self.assertRaises(KeyError, lambda: tree['abcd'])
+        self.assertRaises(IndexError, lambda: tree[-4])
+        self.assertRaises(IndexError, lambda: tree[3])
+
+        self.assertEqual(3, len(tree))
+        a_sha = '7f129fd57e31e935c6d60a0c794efe4e6927664b'
+        self.assertTrue('a' in tree)
+        self.assertTreeEntryEqual(tree[0], a_sha, 'a', 0100644)
+        self.assertTreeEntryEqual(tree[-3], a_sha, 'a', 0100644)
+        self.assertTreeEntryEqual(tree['a'], a_sha, 'a', 0100644)
+
+        b_sha = '85f120ee4dac60d0719fd51731e4199aa5a37df6'
+        self.assertTrue('b' in tree)
+        self.assertTreeEntryEqual(tree[1], b_sha, 'b', 0100644)
+        self.assertTreeEntryEqual(tree[-2], b_sha, 'b', 0100644)
+        self.assertTreeEntryEqual(tree['b'], b_sha, 'b', 0100644)
+
+    def test_read_subtree(self):
+        tree = self.repo[TREE_SHA]
+        subtree_entry = tree['c']
+        self.assertTreeEntryEqual(subtree_entry, SUBTREE_SHA, 'c', 0040000)
+
+        subtree = subtree_entry.to_object()
+        self.assertEqual(1, len(subtree))
+        self.assertTreeEntryEqual(
+          subtree[0], '297efb891a47de80be0cfe9c639e4b8c9b450989', 'd', 0100644)
+
+    def test_new_tree(self):
+        tree = pygit2.Tree(self.repo)
+        self.assertEqual(0, len(tree))
+        tree.add_entry('1' * 40, 'x', 0100644)
+        tree.add_entry('2' * 40, 'y', 0100755)
+        self.assertEqual(2, len(tree))
+        self.assertTrue('x' in tree)
+        self.assertTrue('y' in tree)
+        self.assertRaises(KeyError, tree['x'].to_object)
+
+        tree.add_entry('3' * 40, 'z1', 0100644)
+        tree.add_entry('4' * 40, 'z2', 0100644)
+        self.assertEqual(4, len(tree))
+        del tree['z1']
+        del tree[2]
+        self.assertEqual(2, len(tree))
+
+        self.assertEqual(None, tree.sha)
+        tree.write()
+        contents = '100644 x\0%s100755 y\0%s' % ('\x11' * 20, '\x22' * 20)
+        self.assertEqual((pygit2.GIT_OBJ_TREE, contents),
+                         self.repo.read(tree.sha))
+
+    def test_modify_tree(self):
+        tree = self.repo[TREE_SHA]
+
+        def fail_set():
+            tree['c'] = tree['a']
+        self.assertRaises(ValueError, fail_set)
+
+        def fail_del_by_name():
+            del tree['asdf']
+        self.assertRaises(KeyError, fail_del_by_name)
+
+        def fail_del_by_index():
+            del tree[99]
+        self.assertRaises(IndexError, fail_del_by_index)
+
+        self.assertTrue('c' in tree)
+        self.assertEqual(3, len(tree))
+        del tree['c']
+        self.assertEqual(2, len(tree))
+        self.assertFalse('c' in tree)
+
+        tree.add_entry('1' * 40, 'c', 0100644)
+        self.assertTrue('c' in tree)
+        self.assertEqual(3, len(tree))
+
+        old_sha = tree.sha
+        tree.write()
+        self.assertNotEqual(tree.sha, old_sha)
+        self.assertEqual(tree.sha, self.repo[tree.sha].sha)
+
+
+if __name__ == '__main__':
+  unittest.main()