diff --git a/src/tree.c b/src/tree.c index 7599d65..a3bfe21 100644 --- a/src/tree.c +++ b/src/tree.c @@ -36,6 +36,7 @@ #include "diff.h" extern PyTypeObject TreeType; +extern PyTypeObject TreeEntryType; extern PyTypeObject DiffType; extern PyTypeObject TreeIterType; extern PyTypeObject IndexType; @@ -77,6 +78,47 @@ TreeEntry_oid__get__(TreeEntry *self) return git_oid_to_python(oid); } +PyObject * +TreeEntry_richcompare(PyObject *a, PyObject *b, int op) +{ + PyObject *res; + int cmp; + + /* We only support comparing to another tree entry */ + if (!PyObject_TypeCheck(b, &TreeEntryType)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + cmp =git_tree_entry_cmp(((TreeEntry*)a)->entry, ((TreeEntry*)b)->entry); + switch (op) { + case Py_LT: + res = (cmp <= 0) ? Py_True: Py_False; + break; + case Py_LE: + res = (cmp < 0) ? Py_True: Py_False; + break; + case Py_EQ: + res = (cmp == 0) ? Py_True: Py_False; + break; + case Py_NE: + res = (cmp != 0) ? Py_True: Py_False; + break; + case Py_GT: + res = (cmp > 0) ? Py_True: Py_False; + break; + case Py_GE: + res = (cmp >= 0) ? Py_True: Py_False; + break; + default: + PyErr_Format(PyExc_RuntimeError, "Unexpected '%d' op", op); + return NULL; + } + + Py_INCREF(res); + return res; +} + PyDoc_STRVAR(TreeEntry_hex__doc__, "Hex oid."); @@ -122,7 +164,7 @@ PyTypeObject TreeEntryType = { TreeEntry__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + (richcmpfunc)TreeEntry_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ diff --git a/test/test_tree.py b/test/test_tree.py index 16674ec..d1fc0de 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -118,7 +118,7 @@ class TreeTest(utils.BareRepoTestCase): """ tree = self.repo[TREE_SHA] for tree_entry in tree: - self.assertEqual(tree_entry.hex, tree[tree_entry.name].hex) + self.assertEqual(tree_entry, tree[tree_entry.name]) if __name__ == '__main__':