Implement proxying of rich comparisons.

This commit is contained in:
Graham Dumpleton
2013-08-19 20:22:54 +10:00
parent 50595d1858
commit 1fc4164750
3 changed files with 68 additions and 8 deletions

View File

@@ -469,6 +469,19 @@ static int WraptObjectProxy_setattro(
/* ------------------------------------------------------------------------- */
static PyObject *WraptObjectProxy_richcompare(WraptObjectProxyObject *self,
PyObject *other, int opcode)
{
if (!self->wrapped) {
PyErr_SetString(PyExc_ValueError, "wrapper has not been initialised");
return NULL;
}
return PyObject_RichCompare(self->wrapped, other, opcode);
}
/* ------------------------------------------------------------------------- */
static PyObject *WraptObjectProxy_iter(WraptObjectProxyObject *self)
{
if (!self->wrapped) {
@@ -532,7 +545,7 @@ PyTypeObject WraptObjectProxy_Type = {
0, /*tp_doc*/
0, /*tp_traverse*/
0, /*tp_clear*/
0, /*tp_richcompare*/
(richcmpfunc)WraptObjectProxy_richcompare, /*tp_richcompare*/
0, /*tp_weaklistoffset*/
(getiterfunc)WraptObjectProxy_iter, /*tp_iter*/
0, /*tp_iternext*/

View File

@@ -131,20 +131,29 @@ class ObjectProxy(six.with_metaclass(_ObjectProxyMetaType)):
def __dir__(self):
return dir(self._self_wrapped)
def __lt__(self, other):
return self._self_wrapped < other
def __gt__(self, other):
return self._self_wrapped > other
def __le__(self, other):
return self._self_wrapped <= other
def __ge__(self, other):
return self._self_wrapped >= other
def __eq__(self, other):
return self._self_target == other
return self._self_wrapped == other
def __ne__(self, other):
result = self.__eq__(other)
if result is NotImplemented:
return result
return not result
return self._self_wrapped != other
def __hash__(self):
return hash(self._self_target)
return hash(self._self_wrapped)
def __repr__(self):
return '<%s for %s>' % (type(self).__name__, str(self._self_target))
return '<%s for %s>' % (type(self).__name__, str(self._self_wrapped))
def __enter__(self):
return self._self_wrapped.__enter__()

View File

@@ -642,5 +642,43 @@ class TestEqualityObjectProxy(unittest.TestCase):
self.assertEqual(hash(function2), hash(function1))
def test_mapping_key(self):
def function1(*args, **kwargs):
return args, kwargs
function2 = wrapt.ObjectProxy(function1)
table = dict()
table[function1] = True
self.assertTrue(table.get(function2))
table = dict()
table[function2] = True
self.assertTrue(table.get(function1))
def test_comparison(self):
one = wrapt.ObjectProxy(1)
two = wrapt.ObjectProxy(2)
three = wrapt.ObjectProxy(3)
self.assertTrue(two > 1)
self.assertTrue(two > one)
self.assertTrue(two >= 1)
self.assertTrue(two >= one)
self.assertTrue(two < 3)
self.assertTrue(two < three)
self.assertTrue(two <= 3)
self.assertTrue(two <= three)
self.assertTrue(two != 1)
self.assertTrue(two != one)
self.assertTrue(two != 3)
self.assertTrue(two != three)
self.assertTrue(two == 2)
self.assertTrue(two <= two)
if __name__ == '__main__':
unittest.main()