Add bytes_mode support (Python 2 compatibility) -- reviewed revision.

Fixes applied on top of the original patch (added lines and the affected
hunk headers only; "index" blob ids are kept from the original patch and
no longer describe the post-image of the three modified files):

  * Lib/ldap/ldapobject.py: _unbytesify_modlist() gained with_opcode so
    that add_ext(), whose modlist holds (attr, value) pairs, no longer
    fails while unpacking three items per entry.
  * Lib/ldap/ldapobject.py: _bytesify_keys() leaves payloads that are not
    dicts untouched instead of calling .items() on them.
  * Lib/ldap/ldapobject.py: _unbytesify_values() always returns a tuple.
  * Modules/message.c: a failed PyUnicode_FromString(attr) is handled, and
    the borrowed reference from PyDict_GetItem() is owned before the
    unconditional Py_DECREF( valuelist ).
  * Tests/t_search.py: the bytes-mode helper uses self.server and the
    expected DN is built as real bytes.

diff --git a/Lib/ldap/functions.py b/Lib/ldap/functions.py
index 2a977a6..a6d5fa3 100644
--- a/Lib/ldap/functions.py
+++ b/Lib/ldap/functions.py
@@ -74,7 +74,7 @@ def _ldap_function_call(lock,func,*args,**kwargs):
   return result
 
 
-def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
+def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None, bytes_mode=None):
   """
   Return LDAPObject instance by opening LDAP connection to
   LDAP host specified by LDAP URL
@@ -88,11 +88,13 @@ def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
   trace_file
         File object where to write the trace output to.
         Default is to use stdout.
+  bytes_mode
+        Whether to enable "bytes_mode" for backwards compatibility under Py2.
   """
-  return LDAPObject(uri,trace_level,trace_file,trace_stack_limit)
+  return LDAPObject(uri,trace_level,trace_file,trace_stack_limit,bytes_mode)
 
 
-def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
+def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None,bytes_mode=None):
   """
   Return LDAPObject instance by opening LDAP connection to
   specified LDAP host
@@ -107,10 +109,12 @@ def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=Non
   trace_file
         File object where to write the trace output to.
         Default is to use stdout.
+  bytes_mode
+        Whether to enable "bytes_mode" for backwards compatibility under Py2.
   """
   import warnings
   warnings.warn('ldap.open() is deprecated! Use ldap.initialize() instead.', DeprecationWarning,2)
-  return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit)
+  return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit,bytes_mode)
 
 init = open
 
diff --git a/Lib/ldap/ldapobject.py b/Lib/ldap/ldapobject.py
index 9eb3858..3e22509 100644
--- a/Lib/ldap/ldapobject.py
+++ b/Lib/ldap/ldapobject.py
@@ -18,6 +18,8 @@ Basically calls into the LDAP lib are serialized by the module-wide
 lock self._ldap_object_lock.
 """
 
+from __future__ import unicode_literals
+
 from ldap import __version__
 
 __all__ = [
@@ -33,6 +35,7 @@ if __debug__:
   import traceback
 
 import sys,time,pprint,_ldap,ldap,ldap.sasl,ldap.functions
+import warnings
 
 from ldap.schema import SCHEMA_ATTRS
 from ldap.controls import LDAPControl,DecodeControlTuples,RequestControlTuples
@@ -40,6 +43,11 @@ from ldap.extop import ExtendedRequest,ExtendedResponse
 
 from ldap import LDAPError
 
+PY2 = bool(sys.version_info[0] <= 2)
+if PY2:
+  text_type = unicode
+else:
+  text_type = str
 
 class NO_UNIQUE_ENTRY(ldap.NO_SUCH_OBJECT):
   """
@@ -67,7 +75,7 @@ class SimpleLDAPObject:
 
   def __init__(
     self,uri,
-    trace_level=0,trace_file=None,trace_stack_limit=5
+    trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None
   ):
     self._trace_level = trace_level
     self._trace_file = trace_file or sys.stdout
@@ -78,6 +86,134 @@ class SimpleLDAPObject:
     self.timeout = -1
     self.protocol_version = ldap.VERSION3
 
+    # Bytes mode
+    # ----------
+    if bytes_mode is None and PY2:
+      warnings.warn(
+        "Under Python 2, python-ldap uses bytes by default. "
+        "This will be removed in Python 3 (no bytes for DN/RDN/field names). "
+        "Please call initialize(..., bytes_mode=False) explicitly.",
+        BytesWarning,
+        stacklevel=2,
+      )
+      bytes_mode = True
+    elif bytes_mode and not PY2:
+      raise ValueError("bytes_mode is *not* supported under Python 3.")
+    # On by default on Py2, off on Py3.
+    self.bytes_mode = bytes_mode
+
+  def _unbytesify_value(self, value):
+    """Adapt a value following bytes_mode.
+
+    With bytes_mode ON, takes bytes or None and returns unicode or None.
+    With bytes_mode OFF, takes unicode or None and returns unicode or None.
+    """
+    if not PY2:
+      return value
+
+    if value is None:
+      return value
+    elif self.bytes_mode:
+      if not isinstance(value, bytes):
+        raise TypeError("All provided fields *must* be bytes in bytes mode; got %r" % (value,))
+      return value.decode('utf-8')
+    else:
+      if not isinstance(value, text_type):
+        raise TypeError("All provided fields *must* be text when bytes mode is off; got %r" % (value,))
+      assert not isinstance(value, bytes)
+      return value
+
+  def _unbytesify_values(self, *values):
+    """Adapt values following bytes_mode.
+
+    Applies _unbytesify_value on each arg.
+
+    Usage:
+    >>> a, b, c= self._unbytesify_values(a, b, c)
+    """
+    if not PY2:
+      return values
+    return tuple(
+      self._unbytesify_value(value)
+      for value in values
+    )
+
+  def _unbytesify_modlist(self, modlist, with_opcode=True):
+    """Adapt a modlist according to bytes_mode.
+
+    A modlist is a sequence of (op, attr, value) tuples, or of
+    (attr, value) tuples when with_opcode is False (the form taken
+    by add_ext(), which has no operation integer), where:
+    - With bytes_mode ON, attr is converted from bytes to unicode
+    - With bytes_mode OFF, attr is checked to be unicode
+    - value is *always* bytes
+    """
+    if not PY2:
+      return modlist
+    if with_opcode:
+      return tuple(
+        (op, self._unbytesify_value(attr), val)
+        for op, attr, val in modlist
+      )
+    return tuple(
+      (self._unbytesify_value(attr), val)
+      for attr, val in modlist
+    )
+
+  def _bytesify_value(self, value):
+    """Adapt a returned value according to bytes_mode.
+
+    Takes unicode (and checks for it), and returns:
+    - bytes under bytes_mode
+    - unicode otherwise.
+    """
+    if not PY2:
+      return value
+
+    if value is None:
+      return value
+
+    assert isinstance(value, text_type), "Should return text, got bytes instead (%r)" % (value,)
+    if not self.bytes_mode:
+      return value
+    else:
+      return value.encode('utf-8')
+
+  def _bytesify_keys(self, dct):
+    """Applies bytes_mode to the keys of a dict.
+
+    Anything without an items() method (e.g. the list of URLs carried by
+    a search continuation reference) is returned unchanged.
+    """
+    if not PY2 or not hasattr(dct, 'items'):
+      return dct
+    return dict(
+      (self._bytesify_value(key), value)
+      for (key, value) in dct.items()
+    )
+
+  def _bytesify_results(self, results, with_ctrls=False):
+    """Converts a "results" object according to bytes_mode.
+
+    Takes:
+    - a list of (dn, {field: [values]}) if with_ctrls is False
+    - a list of (dn, {field: [values]}, ctrls) if with_ctrls is True
+
+    And, if bytes_mode is on, converts dn and fields to bytes.
+    """
+    if not PY2:
+      return results
+    if with_ctrls:
+      return [
+        (self._bytesify_value(dn), self._bytesify_keys(fields), ctrls)
+        for (dn, fields, ctrls) in results
+      ]
+    else:
+      return [
+        (self._bytesify_value(dn), self._bytesify_keys(fields))
+        for (dn, fields) in results
+      ]
+
   def _ldap_lock(self,desc=''):
     if ldap.LIBLDAP_R:
       return ldap.LDAPLock(desc='%s within %s' %(desc,repr(self)))
@@ -188,6 +324,8 @@ class SimpleLDAPObject:
         The parameter modlist is similar to the one passed to modify(),
         except that no operation integer need be included in the tuples.
     """
+    dn = self._unbytesify_value(dn)
+    modlist = self._unbytesify_modlist(modlist, with_opcode=False)
     return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
@@ -213,6 +351,7 @@ class SimpleLDAPObject:
     """
     simple_bind([who='' [,cred='']]) -> int
     """
+    who, cred = self._unbytesify_values(who, cred)
     return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def simple_bind_s(self,who='',cred='',serverctrls=None,clientctrls=None):
@@ -285,6 +424,7 @@ class SimpleLDAPObject:
         A design bug in the library prevents value from containing
         nul characters.
     """
+    dn, attr = self._unbytesify_values(dn, attr)
     return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
@@ -313,6 +453,7 @@ class SimpleLDAPObject:
         form returns the message id of the initiated request, and the
         result can be obtained from a subsequent call to result().
     """
+    dn = self._unbytesify_value(dn)
     return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def delete_ext_s(self,dn,serverctrls=None,clientctrls=None):
@@ -361,6 +502,8 @@ class SimpleLDAPObject:
     """
     modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int
     """
+    dn = self._unbytesify_value(dn)
+    modlist = self._unbytesify_modlist(modlist)
     return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
@@ -415,6 +558,7 @@ class SimpleLDAPObject:
     return self.rename_s(dn,newrdn,None,delold)
 
   def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
+    user, oldpw, newpw = self._unbytesify_values(user, oldpw, newpw)
     return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def passwd_s(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
@@ -436,6 +580,7 @@ class SimpleLDAPObject:
     This actually corresponds to the rename* routines in the
     LDAP-EXT C API library.
     """
+    dn, newrdn, newsuperior = self._unbytesify_values(dn, newrdn, newsuperior)
     return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
 
   def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
@@ -524,6 +669,8 @@ class SimpleLDAPObject:
     if add_ctrls:
       resp_data = [ (t,r,DecodeControlTuples(c,resp_ctrl_classes)) for t,r,c in resp_data ]
     decoded_resp_ctrls = DecodeControlTuples(resp_ctrls,resp_ctrl_classes)
+    if resp_data is not None:
+      resp_data = self._bytesify_results(resp_data, with_ctrls=add_ctrls)
     return resp_type, resp_data, resp_msgid, decoded_resp_ctrls, resp_name, resp_value
 
   def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0):
@@ -571,6 +718,9 @@ class SimpleLDAPObject:
         The amount of search results retrieved can be limited with the
         sizelimit parameter if non-zero.
     """
+    base, filterstr = self._unbytesify_values(base, filterstr)
+    if attrlist is not None:
+      attrlist = tuple(self._unbytesify_values(*attrlist))
     return self._ldap_call(
       self._l.search_ext,
       base,scope,filterstr,
diff --git a/Lib/ldap/schema/subentry.py b/Lib/ldap/schema/subentry.py
index 3060672..082464f 100644
--- a/Lib/ldap/schema/subentry.py
+++ b/Lib/ldap/schema/subentry.py
@@ -445,7 +445,7 @@ class SubSchema:
     return r_must,r_may # attribute_types()
 
 
-def urlfetch(uri,trace_level=0):
+def urlfetch(uri,trace_level=0,bytes_mode=None):
   """
   Fetches a parsed schema entry by uri.
 
@@ -457,7 +457,7 @@ def urlfetch(uri,trace_level=0):
   if uri.startswith('ldap:') or uri.startswith('ldaps:') or uri.startswith('ldapi:'):
     import ldapurl
     ldap_url = ldapurl.LDAPUrl(uri)
-    l=ldap.initialize(ldap_url.initializeUrl(),trace_level)
+    l=ldap.initialize(ldap_url.initializeUrl(),trace_level,bytes_mode=bytes_mode)
     l.protocol_version = ldap.VERSION3
     l.simple_bind_s(ldap_url.who or '', ldap_url.cred or '')
     subschemasubentry_dn = l.search_subschemasubentry_s(ldap_url.dn)
diff --git a/Modules/message.c b/Modules/message.c
index a159e10..e1e7609 100644
--- a/Modules/message.c
+++ b/Modules/message.c
@@ -50,6 +50,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
         BerElement *ber = NULL;
         PyObject* entrytuple;
         PyObject* attrdict;
+        PyObject* pydn;
 
         dn = ldap_get_dn( ld, entry );
         if (dn == NULL) {
@@ -57,9 +58,17 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
             ldap_msgfree( m );
             return LDAPerror( ld, "ldap_get_dn" );
         }
+        pydn = PyUnicode_FromString(dn);
+        if (pydn == NULL) {
+            Py_DECREF(result);
+            ldap_msgfree( m );
+            ldap_memfree(dn);
+            return NULL;
+        }
 
         attrdict = PyDict_New();
         if (attrdict == NULL) {
+            Py_DECREF(pydn);
             Py_DECREF(result);
             ldap_msgfree( m );
             ldap_memfree(dn);
@@ -68,6 +77,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
 
         rc = ldap_get_entry_controls( ld, entry, &serverctrls );
         if (rc) {
+            Py_DECREF(pydn);
             Py_DECREF(result);
             ldap_msgfree( m );
             ldap_memfree(dn);
@@ -78,6 +88,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
         if ( ! ( pyctrls = LDAPControls_to_List( serverctrls ) ) ) {
             int err = LDAP_NO_MEMORY;
             ldap_set_option( ld, LDAP_OPT_ERROR_NUMBER, &err );
+            Py_DECREF(pydn);
             Py_DECREF(result);
             ldap_msgfree( m );
             ldap_memfree(dn);
@@ -92,22 +103,32 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
              attr = ldap_next_attribute( ld, entry, ber )
         ) {
             PyObject* valuelist;
+            PyObject* pyattr;
+            pyattr = PyUnicode_FromString(attr);
+
             struct berval ** bvals =
                 ldap_get_values_len( ld, entry, attr );
 
             /* Find which list to append to */
-            if ( PyMapping_HasKeyString( attrdict, attr ) ) {
-                valuelist = PyMapping_GetItemString( attrdict, attr );
+            if ( pyattr == NULL ) {
+                valuelist = NULL; /* conversion failed; catch error later */
+            } else if ( PyDict_Contains( attrdict, pyattr ) ) {
+                valuelist = PyDict_GetItem( attrdict, pyattr );
+                /* PyDict_GetItem returns a borrowed reference; own it,
+                 * because this loop ends with Py_DECREF( valuelist ). */
+                Py_XINCREF( valuelist );
             } else {
                 valuelist = PyList_New(0);
-                if (valuelist != NULL && PyMapping_SetItemString(attrdict,
-                    attr, valuelist) == -1) {
+                if (valuelist != NULL && PyDict_SetItem(attrdict,
+                    pyattr, valuelist) == -1) {
                     Py_DECREF(valuelist);
                     valuelist = NULL; /* catch error later */
                 }
             }
 
             if (valuelist == NULL) {
+                Py_DECREF(pydn);
+                Py_XDECREF(pyattr);
                 Py_DECREF(attrdict);
                 Py_DECREF(result);
                 if (ber != NULL)
@@ -126,6 +147,8 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
 
                     valuestr = LDAPberval_to_object(bvals[i]);
                     if (PyList_Append( valuelist, valuestr ) == -1) {
+                        Py_DECREF(pydn);
+                        Py_DECREF(pyattr);
                         Py_DECREF(attrdict);
                         Py_DECREF(result);
                         Py_DECREF(valuestr);
@@ -142,15 +165,17 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
                 }
                 ldap_value_free_len(bvals);
             }
+            Py_DECREF(pyattr);
             Py_DECREF( valuelist );
             ldap_memfree(attr);
         }
 
         if (add_ctrls) {
-            entrytuple = Py_BuildValue("(sOO)", dn, attrdict, pyctrls);
+            entrytuple = Py_BuildValue("(OOO)", pydn, attrdict, pyctrls);
         } else {
-            entrytuple = Py_BuildValue("(sO)", dn, attrdict);
+            entrytuple = Py_BuildValue("(OO)", pydn, attrdict);
         }
+        Py_DECREF(pydn);
         ldap_memfree(dn);
         Py_DECREF(attrdict);
         Py_XDECREF(pyctrls);
diff --git a/Tests/t_search.py b/Tests/t_search.py
index 8da55c9..e398287 100644
--- a/Tests/t_search.py
+++ b/Tests/t_search.py
@@ -1,5 +1,14 @@
 from __future__ import unicode_literals
 
+import sys
+
+if sys.version_info[0] <= 2:
+    PY2 = True
+    text_type = unicode
+else:
+    PY2 = False
+    text_type = str
+
 import ldap, unittest
 import slapd
 
@@ -40,7 +49,7 @@ class TestSearch(unittest.TestCase):
                 "",
             ])+"\n")
 
-        l = LDAPObject(server.get_url())
+        l = LDAPObject(server.get_url(), bytes_mode=False)
         l.protocol_version = 3
         l.set_option(ldap.OPT_REFERRALS,0)
         l.simple_bind_s(server.get_root_dn(),
@@ -48,6 +57,64 @@ class TestSearch(unittest.TestCase):
         self.ldap = l
         self.server = server
 
+    def test_reject_bytes_base(self):
+        base = self.server.get_dn_suffix()
+        l = self.ldap
+
+        with self.assertRaises(TypeError):
+            l.search_s(base.encode('utf-8'), ldap.SCOPE_SUBTREE, '(cn=Foo*)', ['*'])
+        with self.assertRaises(TypeError):
+            l.search_s(base, ldap.SCOPE_SUBTREE, b'(cn=Foo*)', ['*'])
+        with self.assertRaises(TypeError):
+            l.search_s(base, ldap.SCOPE_SUBTREE, '(cn=Foo*)', [b'*'])
+
+    def test_search_keys_are_text(self):
+        base = self.server.get_dn_suffix()
+        l = self.ldap
+        result = l.search_s(base, ldap.SCOPE_SUBTREE, '(cn=Foo*)', ['*'])
+        result.sort()
+        dn, fields = result[0]
+        self.assertEqual(dn, 'cn=Foo1,%s' % base)
+        self.assertEqual(type(dn), text_type)
+        for key, values in fields.items():
+            self.assertEqual(type(key), text_type)
+            for value in values:
+                self.assertEqual(type(value), bytes)
+
+    def _get_bytes_ldapobject(self):
+        l = LDAPObject(self.server.get_url(), bytes_mode=True)
+        l.protocol_version = 3
+        l.set_option(ldap.OPT_REFERRALS,0)
+        l.simple_bind_s(self.server.get_root_dn().encode('utf-8'),
+                self.server.get_root_password().encode('utf-8'))
+        return l
+
+    @unittest.skipUnless(PY2, "no bytes_mode under Py3")
+    def test_bytesmode_search_requires_bytes(self):
+        l = self._get_bytes_ldapobject()
+        base = self.server.get_dn_suffix()
+
+        with self.assertRaises(TypeError):
+            l.search_s(base.encode('utf-8'), ldap.SCOPE_SUBTREE, '(cn=Foo*)', [b'*'])
+        with self.assertRaises(TypeError):
+            l.search_s(base.encode('utf-8'), ldap.SCOPE_SUBTREE, b'(cn=Foo*)', ['*'])
+        with self.assertRaises(TypeError):
+            l.search_s(base, ldap.SCOPE_SUBTREE, b'(cn=Foo*)', [b'*'])
+
+    @unittest.skipUnless(PY2, "no bytes_mode under Py3")
+    def test_bytesmode_search_results_have_bytes(self):
+        l = self._get_bytes_ldapobject()
+        base = self.server.get_dn_suffix()
+        result = l.search_s(base.encode('utf-8'), ldap.SCOPE_SUBTREE, b'(cn=Foo*)', [b'*'])
+        result.sort()
+        dn, fields = result[0]
+        self.assertEqual(dn, ('cn=Foo1,%s' % base).encode('utf-8'))
+        self.assertEqual(type(dn), bytes)
+        for key, values in fields.items():
+            self.assertEqual(type(key), bytes)
+            for value in values:
+                self.assertEqual(type(value), bytes)
+
     def test_search_subtree(self):
         base = self.server.get_dn_suffix()
         l = self.ldap