diff --git a/pygit2.c b/pygit2.c index 3b3c745..b870e69 100644 --- a/pygit2.c +++ b/pygit2.c @@ -83,6 +83,12 @@ typedef struct { int i; } IndexIter; +typedef struct { + PyObject_HEAD + Tree *owner; + int i; +} TreeIter; + typedef struct { PyObject_HEAD git_index_entry *entry; @@ -101,6 +107,7 @@ static PyTypeObject TreeType; static PyTypeObject BlobType; static PyTypeObject TagType; static PyTypeObject IndexType; +static PyTypeObject TreeIterType; static PyTypeObject IndexIterType; static PyTypeObject IndexEntryType; static PyTypeObject WalkerType; @@ -1156,6 +1163,21 @@ Tree_fix_index(Tree *self, PyObject *py_index) { return (int)index; } +static PyObject * +Tree_iter(Tree *self) { + TreeIter *iter; + + iter = PyObject_New(TreeIter, &TreeIterType); + if (!iter) + return NULL; + + Py_INCREF(self); + iter->owner = self; + iter->i = 0; + + return (PyObject*)iter; +} + static TreeEntry * Tree_getitem_by_index(Tree *self, PyObject *py_index) { int index; @@ -1231,7 +1253,7 @@ static PyTypeObject TreeType = { 0, /* tp_clear */ 0, /* tp_richcompare */ 0, /* tp_weaklistoffset */ - 0, /* tp_iter */ + (getiterfunc)Tree_iter, /* tp_iter */ 0, /* tp_iternext */ 0, /* tp_methods */ 0, /* tp_members */ @@ -1246,6 +1268,55 @@ static PyTypeObject TreeType = { 0, /* tp_new */ }; +static void +TreeIter_dealloc(TreeIter *self) { + Py_CLEAR(self->owner); + PyObject_Del(self); +} + +static TreeEntry * +TreeIter_iternext(TreeIter *self) { + const git_tree_entry *tree_entry; + + tree_entry = git_tree_entry_byindex(self->owner->tree, self->i); + if (!tree_entry) + return NULL; + + self->i += 1; + return (TreeEntry*)wrap_tree_entry(tree_entry, self->owner); +} + +static PyTypeObject TreeIterType = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "pygit2.TreeIter", /* tp_name */ + sizeof(TreeIter), /* tp_basicsize */ + 0, /* tp_itemsize */ + (destructor)TreeIter_dealloc , /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ + 0, /* tp_compare */ + 0, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + PyObject_GenericGetAttr, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT | + Py_TPFLAGS_BASETYPE, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + PyObject_SelfIter, /* tp_iter */ + (iternextfunc)TreeIter_iternext, /* tp_iternext */ + }; + static PyGetSetDef Blob_getseters[] = { {"data", (getter)Object_read_raw, NULL, "raw data", NULL}, {NULL} diff --git a/test/test_tree.py b/test/test_tree.py index bcbac31..b48308f 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -108,6 +108,16 @@ class TreeTest(utils.BareRepoTestCase): self.assertRaises(TypeError, operator.setitem, 'c', tree['a']) self.assertRaises(TypeError, operator.delitem, 'c') + def test_iterate_tree(self): + """ + Testing that we're able to iterate of a Tree object and that the + resulting sha strings are consitent with the sha strings we could + get with other Tree access methods. + """ + tree = self.repo[TREE_SHA] + for tree_entry in tree: + self.assertEqual(tree_entry.sha, tree[tree_entry.name].sha) + if __name__ == '__main__': unittest.main()