From 4dc90f78a9a3e248b016b4743e41babe5d6b423c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mart=C3=ADn=20Nieto?= Date: Sun, 19 Jan 2014 17:10:40 +0100 Subject: [PATCH] TreeEntry: add rich comparison function Allow direct comparisons between TreeEntry objects, which also allows us to use assertEqual in the sanity check test. This fixes #305. --- src/tree.c | 44 +++++++++++++++++++++++++++++++++++++++++++- test/test_tree.py | 2 +- 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/src/tree.c b/src/tree.c index 7599d65..a3bfe21 100644 --- a/src/tree.c +++ b/src/tree.c @@ -36,6 +36,7 @@ #include "diff.h" extern PyTypeObject TreeType; +extern PyTypeObject TreeEntryType; extern PyTypeObject DiffType; extern PyTypeObject TreeIterType; extern PyTypeObject IndexType; @@ -77,6 +78,47 @@ TreeEntry_oid__get__(TreeEntry *self) return git_oid_to_python(oid); } +PyObject * +TreeEntry_richcompare(PyObject *a, PyObject *b, int op) +{ + PyObject *res; + int cmp; + + /* We only support comparing to another tree entry */ + if (!PyObject_TypeCheck(b, &TreeEntryType)) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + cmp =git_tree_entry_cmp(((TreeEntry*)a)->entry, ((TreeEntry*)b)->entry); + switch (op) { + case Py_LT: + res = (cmp <= 0) ? Py_True: Py_False; + break; + case Py_LE: + res = (cmp < 0) ? Py_True: Py_False; + break; + case Py_EQ: + res = (cmp == 0) ? Py_True: Py_False; + break; + case Py_NE: + res = (cmp != 0) ? Py_True: Py_False; + break; + case Py_GT: + res = (cmp > 0) ? Py_True: Py_False; + break; + case Py_GE: + res = (cmp >= 0) ? Py_True: Py_False; + break; + default: + PyErr_Format(PyExc_RuntimeError, "Unexpected '%d' op", op); + return NULL; + } + + Py_INCREF(res); + return res; +} + PyDoc_STRVAR(TreeEntry_hex__doc__, "Hex oid."); @@ -122,7 +164,7 @@ PyTypeObject TreeEntryType = { TreeEntry__doc__, /* tp_doc */ 0, /* tp_traverse */ 0, /* tp_clear */ - 0, /* tp_richcompare */ + (richcmpfunc)TreeEntry_richcompare, /* tp_richcompare */ 0, /* tp_weaklistoffset */ 0, /* tp_iter */ 0, /* tp_iternext */ diff --git a/test/test_tree.py b/test/test_tree.py index 16674ec..d1fc0de 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -118,7 +118,7 @@ class TreeTest(utils.BareRepoTestCase): """ tree = self.repo[TREE_SHA] for tree_entry in tree: - self.assertEqual(tree_entry.hex, tree[tree_entry.name].hex) + self.assertEqual(tree_entry, tree[tree_entry.name]) if __name__ == '__main__':