From 1fc41647502996f11a2a7a608399ec15a0281018 Mon Sep 17 00:00:00 2001 From: Graham Dumpleton Date: Mon, 19 Aug 2013 20:22:54 +1000 Subject: [PATCH] Implement proxying of rich comparisons. --- src/_wrappers.c | 15 ++++++++++++++- src/wrappers.py | 23 ++++++++++++++++------- tests/test_object_proxy.py | 38 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/_wrappers.c b/src/_wrappers.c index 6dcc34c..c501d32 100644 --- a/src/_wrappers.c +++ b/src/_wrappers.c @@ -469,6 +469,19 @@ static int WraptObjectProxy_setattro( /* ------------------------------------------------------------------------- */ +static PyObject *WraptObjectProxy_richcompare(WraptObjectProxyObject *self, + PyObject *other, int opcode) +{ + if (!self->wrapped) { + PyErr_SetString(PyExc_ValueError, "wrapper has not been initialised"); + return NULL; + } + + return PyObject_RichCompare(self->wrapped, other, opcode); +} + +/* ------------------------------------------------------------------------- */ + static PyObject *WraptObjectProxy_iter(WraptObjectProxyObject *self) { if (!self->wrapped) { @@ -532,7 +545,7 @@ PyTypeObject WraptObjectProxy_Type = { 0, /*tp_doc*/ 0, /*tp_traverse*/ 0, /*tp_clear*/ - 0, /*tp_richcompare*/ + (richcmpfunc)WraptObjectProxy_richcompare, /*tp_richcompare*/ 0, /*tp_weaklistoffset*/ (getiterfunc)WraptObjectProxy_iter, /*tp_iter*/ 0, /*tp_iternext*/ diff --git a/src/wrappers.py b/src/wrappers.py index eff7820..0b4c6f9 100644 --- a/src/wrappers.py +++ b/src/wrappers.py @@ -131,20 +131,29 @@ class ObjectProxy(six.with_metaclass(_ObjectProxyMetaType)): def __dir__(self): return dir(self._self_wrapped) + def __lt__(self, other): + return self._self_wrapped < other + + def __gt__(self, other): + return self._self_wrapped > other + + def __le__(self, other): + return self._self_wrapped <= other + + def __ge__(self, other): + return self._self_wrapped >= other + def __eq__(self, other): - return self._self_target == other + return self._self_wrapped == other def __ne__(self, other): - result = self.__eq__(other) - if result is NotImplemented: - return result - return not result + return self._self_wrapped != other def __hash__(self): - return hash(self._self_target) + return hash(self._self_wrapped) def __repr__(self): - return '<%s for %s>' % (type(self).__name__, str(self._self_target)) + return '<%s for %s>' % (type(self).__name__, str(self._self_wrapped)) def __enter__(self): return self._self_wrapped.__enter__() diff --git a/tests/test_object_proxy.py b/tests/test_object_proxy.py index ce3bea0..fc0b880 100644 --- a/tests/test_object_proxy.py +++ b/tests/test_object_proxy.py @@ -642,5 +642,43 @@ class TestEqualityObjectProxy(unittest.TestCase): self.assertEqual(hash(function2), hash(function1)) + def test_mapping_key(self): + def function1(*args, **kwargs): + return args, kwargs + function2 = wrapt.ObjectProxy(function1) + + table = dict() + table[function1] = True + + self.assertTrue(table.get(function2)) + + table = dict() + table[function2] = True + + self.assertTrue(table.get(function1)) + + def test_comparison(self): + one = wrapt.ObjectProxy(1) + two = wrapt.ObjectProxy(2) + three = wrapt.ObjectProxy(3) + + self.assertTrue(two > 1) + self.assertTrue(two > one) + self.assertTrue(two >= 1) + self.assertTrue(two >= one) + + self.assertTrue(two < 3) + self.assertTrue(two < three) + self.assertTrue(two <= 3) + self.assertTrue(two <= three) + + self.assertTrue(two != 1) + self.assertTrue(two != one) + self.assertTrue(two != 3) + self.assertTrue(two != three) + + self.assertTrue(two == 2) + self.assertTrue(two <= two) + if __name__ == '__main__': unittest.main()