diff --git a/src/tree.c b/src/tree.c index e8f7111..e51e139 100644 --- a/src/tree.c +++ b/src/tree.c @@ -87,10 +87,20 @@ TreeEntry_oid__get__(TreeEntry *self) return TreeEntry_id__get__(self); } +static int +compare_ids(TreeEntry *a, TreeEntry *b) +{ + const git_oid *id_a, *id_b; + id_a = git_tree_entry_id(a->entry); + id_b = git_tree_entry_id(b->entry); + return git_oid_cmp(id_a, id_b); +} + PyObject * TreeEntry_richcompare(PyObject *a, PyObject *b, int op) { PyObject *res; + TreeEntry *ta, *tb; int cmp; /* We only support comparing to another tree entry */ @@ -99,7 +109,14 @@ TreeEntry_richcompare(PyObject *a, PyObject *b, int op) return Py_NotImplemented; } - cmp =git_tree_entry_cmp(((TreeEntry*)a)->entry, ((TreeEntry*)b)->entry); + ta = (TreeEntry *) a; + tb = (TreeEntry *) b; + + /* This is sorting order, if they sort equally, we still need to compare the ids */ + cmp = git_tree_entry_cmp(ta->entry, tb->entry); + if (cmp == 0) + cmp = compare_ids(ta, tb); + switch (op) { case Py_LT: res = (cmp <= 0) ? Py_True: Py_False; @@ -147,7 +164,6 @@ PyGetSetDef TreeEntry_getseters[] = { {NULL} }; - PyDoc_STRVAR(TreeEntry__doc__, "TreeEntry objects."); PyTypeObject TreeEntryType = { diff --git a/test/test_tree.py b/test/test_tree.py index 128fc75..6c8b707 100644 --- a/test/test_tree.py +++ b/test/test_tree.py @@ -32,6 +32,7 @@ from __future__ import unicode_literals import operator import unittest +from pygit2 import TreeEntry from . import utils @@ -71,6 +72,18 @@ class TreeTest(utils.BareRepoTestCase): self.assertTreeEntryEqual(tree['c/d'], sha, 'd', 0o0100644) self.assertRaisesWithArg(KeyError, 'ab/cd', lambda: tree['ab/cd']) + def test_equality(self): + tree_a = self.repo['18e2d2e9db075f9eb43bcb2daa65a2867d29a15e'] + tree_b = self.repo['2ad1d3456c5c4a1c9e40aeeddb9cd20b409623c8'] + + self.assertNotEqual(tree_a['a'], tree_b['a']) + self.assertNotEqual(tree_a['a'], tree_b['b']) + self.assertEqual(tree_a['b'], tree_b['b']) + + def test_sorting(self): + tree_a = self.repo['18e2d2e9db075f9eb43bcb2daa65a2867d29a15e'] + self.assertEqual(list(tree_a), sorted(reversed(list(tree_a)))) + self.assertNotEqual(list(tree_a), reversed(list(tree_a))) def test_read_subtree(self): tree = self.repo[TREE_SHA]