Add flag to restore legacy encoding rules (See #1).

With this commit, all ldap connections accept a new parameter,
``bytes_mode``.

When set to ``True``, this flag emulates the old Python 2 behavior,
where all fields are bytes - including those declared as UTF-8 by the RFC (DN,
RDN, attribute names).

If this flag is set to ``False``, the code works with text (unicode) for all
text fields (everything except attribute values).

If no value is set under Python 2, the code will emit a BytesWarning
and proceed with the flag set to ``True``, for backwards compatibility.
Under Python 3, the flag can only be ``False`` (or left unset);
passing ``True`` raises a ValueError.

For safety and ease of upgrade, the code checks that all provided
arguments are of the expected type (unicode with ``bytes_mode=False``,
bytes with ``bytes_mode=True``).
This commit is contained in:
Xelnor
2015-07-21 01:11:55 +02:00
committed by Raphaël Barrois
parent acb163c3e3
commit 770611a056
5 changed files with 244 additions and 14 deletions

View File

@@ -74,7 +74,7 @@ def _ldap_function_call(lock,func,*args,**kwargs):
return result
def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None, bytes_mode=None):
"""
Return LDAPObject instance by opening LDAP connection to
LDAP host specified by LDAP URL
@@ -88,11 +88,13 @@ def initialize(uri,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
trace_file
File object where to write the trace output to.
Default is to use stdout.
bytes_mode
Whether to enable "bytes_mode" for backwards compatibility under Py2.
"""
return LDAPObject(uri,trace_level,trace_file,trace_stack_limit)
return LDAPObject(uri,trace_level,trace_file,trace_stack_limit,bytes_mode)
def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None):
def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=None,bytes_mode=None):
"""
Return LDAPObject instance by opening LDAP connection to
specified LDAP host
@@ -107,10 +109,12 @@ def open(host,port=389,trace_level=0,trace_file=sys.stdout,trace_stack_limit=Non
trace_file
File object where to write the trace output to.
Default is to use stdout.
bytes_mode
Whether to enable "bytes_mode" for backwards compatibility under Py2.
"""
import warnings
warnings.warn('ldap.open() is deprecated! Use ldap.initialize() instead.', DeprecationWarning,2)
return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit)
return initialize('ldap://%s:%d' % (host,port),trace_level,trace_file,trace_stack_limit,bytes_mode)
init = open

View File

@@ -18,6 +18,8 @@ Basically calls into the LDAP lib are serialized by the module-wide
lock self._ldap_object_lock.
"""
from __future__ import unicode_literals
from ldap import __version__
__all__ = [
@@ -33,6 +35,7 @@ if __debug__:
import traceback
import sys,time,pprint,_ldap,ldap,ldap.sasl,ldap.functions
import warnings
from ldap.schema import SCHEMA_ATTRS
from ldap.controls import LDAPControl,DecodeControlTuples,RequestControlTuples
@@ -40,6 +43,11 @@ from ldap.extop import ExtendedRequest,ExtendedResponse
from ldap import LDAPError
# True when running on a Python 2 (or older) interpreter.
PY2 = sys.version_info[0] <= 2

# The interpreter's text (unicode) string type.  The name ``unicode`` is
# only evaluated on Python 2, where it exists.
text_type = unicode if PY2 else str
class NO_UNIQUE_ENTRY(ldap.NO_SUCH_OBJECT):
"""
@@ -67,7 +75,7 @@ class SimpleLDAPObject:
def __init__(
self,uri,
trace_level=0,trace_file=None,trace_stack_limit=5
trace_level=0,trace_file=None,trace_stack_limit=5,bytes_mode=None
):
self._trace_level = trace_level
self._trace_file = trace_file or sys.stdout
@@ -78,6 +86,123 @@ class SimpleLDAPObject:
self.timeout = -1
self.protocol_version = ldap.VERSION3
# Bytes mode
# ----------
if bytes_mode is None and PY2:
warnings.warn(
"Under Python 2, python-ldap uses bytes by default. "
"This will be removed in Python 3 (no bytes for DN/RDN/field names). "
"Please call initialize(..., bytes_mode=False) explicitly.",
BytesWarning,
stacklevel=2,
)
bytes_mode = True
elif bytes_mode and not PY2:
raise ValueError("bytes_mode is *not* supported under Python 3.")
# On by default on Py2, off on Py3.
self.bytes_mode = bytes_mode
def _unbytesify_value(self, value):
    """Normalise one caller-supplied text argument to unicode.

    Under Python 3 the value is handed back untouched. ``None`` always
    passes through unchanged.

    Under Python 2:
    - bytes_mode ON: ``value`` must be bytes and is decoded as UTF-8.
    - bytes_mode OFF: ``value`` must already be unicode and is returned as is.

    Raises TypeError when ``value`` has the wrong type for the active mode.
    """
    if not PY2 or value is None:
        return value

    if self.bytes_mode:
        if isinstance(value, bytes):
            return value.decode('utf-8')
        raise TypeError("All provided fields *must* be bytes in bytes mode; got %r" % (value,))

    if isinstance(value, text_type):
        assert not isinstance(value, bytes)
        return value
    raise TypeError("All provided fields *must* be text when bytes mode is off; got %r" % (value,))
def _unbytesify_values(self, *values):
    """Adapt several values following bytes_mode.

    Applies _unbytesify_value to each positional argument and returns the
    results as a tuple, in the same order.  A tuple is returned on every
    interpreter so that callers get the same type (indexable, reusable)
    under Python 2 and Python 3, and so that a value of the wrong type is
    reported by this call rather than when the result is later consumed.

    Raises TypeError (from _unbytesify_value) if any value has the wrong
    type for the active mode.

    Usage:
    >>> a, b, c = self._unbytesify_values(a, b, c)
    """
    if not PY2:
        # ``values`` is already the tuple of positional arguments.
        return values
    return tuple(self._unbytesify_value(value) for value in values)
def _unbytesify_modlist(self, modlist):
    """Adapt a modlist according to bytes_mode.

    A modlist is a sequence of entries, each being either:
    - (op, attr, value), as passed to modify_ext(), or
    - (attr, value), as passed to add_ext(), whose modlist carries no
      operation integer.

    For every entry:
    - With bytes_mode ON, attr is converted from bytes to unicode
    - With bytes_mode OFF, attr is checked to be unicode
    - value is *always* bytes and is passed through untouched

    Under Python 3 the modlist is returned unchanged; under Python 2 a
    new tuple of adapted entries is returned.

    Raises TypeError (from _unbytesify_value) when an attribute name has
    the wrong type for the active mode.
    """
    if not PY2:
        return modlist

    adapted = []
    for entry in modlist:
        if len(entry) == 2:
            # add_ext()-style entry: no operation code to carry over.
            attr, val = entry
            adapted.append((self._unbytesify_value(attr), val))
        else:
            op, attr, val = entry
            adapted.append((op, self._unbytesify_value(attr), val))
    return tuple(adapted)
def _bytesify_value(self, value):
    """Adapt a returned text value according to bytes_mode.

    Under Python 3, and for ``None``, the value is returned untouched.
    Under Python 2 the value is asserted to be unicode, then returned:
    - UTF-8 encoded to bytes when bytes_mode is on
    - unchanged (unicode) otherwise.
    """
    if not PY2 or value is None:
        return value

    assert isinstance(value, text_type), "Should return text, got bytes instead (%r)" % (value,)
    if self.bytes_mode:
        return value.encode('utf-8')
    return value
def _bytesify_keys(self, dct):
    """Return a copy of ``dct`` whose keys went through _bytesify_value.

    Values are carried over as they are.  Under Python 3 the very same
    dict object is returned without copying.
    """
    if not PY2:
        return dct

    converted = {}
    for name, payload in dct.items():
        converted[self._bytesify_value(name)] = payload
    return converted
def _bytesify_results(self, results, with_ctrls=False):
    """Convert a "results" list according to bytes_mode.

    ``results`` is a list of:
    - (dn, {field: [values]}) when with_ctrls is False
    - (dn, {field: [values]}, ctrls) when with_ctrls is True

    Under Python 2 a new list is built in which each dn went through
    _bytesify_value and each field dict through _bytesify_keys; ctrls are
    carried over untouched.  Under Python 3 ``results`` is returned as is.

    NOTE(review): every entry's second element is handed to
    _bytesify_keys, i.e. assumed to be a dict -- confirm no other shape
    of entry can reach this method.
    """
    if not PY2:
        return results

    converted = []
    for entry in results:
        if with_ctrls:
            dn, fields, ctrls = entry
            converted.append(
                (self._bytesify_value(dn), self._bytesify_keys(fields), ctrls)
            )
        else:
            dn, fields = entry
            converted.append(
                (self._bytesify_value(dn), self._bytesify_keys(fields))
            )
    return converted
def _ldap_lock(self,desc=''):
if ldap.LIBLDAP_R:
return ldap.LDAPLock(desc='%s within %s' %(desc,repr(self)))
@@ -188,6 +313,8 @@ class SimpleLDAPObject:
The parameter modlist is similar to the one passed to modify(),
except that no operation integer need be included in the tuples.
"""
dn = self._unbytesify_value(dn)
modlist = self._unbytesify_modlist(modlist)
return self._ldap_call(self._l.add_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def add_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
@@ -213,6 +340,7 @@ class SimpleLDAPObject:
"""
simple_bind([who='' [,cred='']]) -> int
"""
who, cred = self._unbytesify_values(who, cred)
return self._ldap_call(self._l.simple_bind,who,cred,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def simple_bind_s(self,who='',cred='',serverctrls=None,clientctrls=None):
@@ -285,6 +413,7 @@ class SimpleLDAPObject:
A design bug in the library prevents value from containing
nul characters.
"""
dn, attr = self._unbytesify_values(dn, attr)
return self._ldap_call(self._l.compare_ext,dn,attr,value,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def compare_ext_s(self,dn,attr,value,serverctrls=None,clientctrls=None):
@@ -313,6 +442,7 @@ class SimpleLDAPObject:
form returns the message id of the initiated request, and the
result can be obtained from a subsequent call to result().
"""
dn = self._unbytesify_value(dn)
return self._ldap_call(self._l.delete_ext,dn,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def delete_ext_s(self,dn,serverctrls=None,clientctrls=None):
@@ -361,6 +491,8 @@ class SimpleLDAPObject:
"""
modify_ext(dn, modlist[,serverctrls=None[,clientctrls=None]]) -> int
"""
dn = self._unbytesify_value(dn)
modlist = self._unbytesify_modlist(modlist)
return self._ldap_call(self._l.modify_ext,dn,modlist,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def modify_ext_s(self,dn,modlist,serverctrls=None,clientctrls=None):
@@ -415,6 +547,7 @@ class SimpleLDAPObject:
return self.rename_s(dn,newrdn,None,delold)
def passwd(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
user, oldpw, newpw = self._unbytesify_values(user, oldpw, newpw)
return self._ldap_call(self._l.passwd,user,oldpw,newpw,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def passwd_s(self,user,oldpw,newpw,serverctrls=None,clientctrls=None):
@@ -436,6 +569,7 @@ class SimpleLDAPObject:
This actually corresponds to the rename* routines in the
LDAP-EXT C API library.
"""
dn, newrdn, newsuperior = self._unbytesify_values(dn, newrdn, newsuperior)
return self._ldap_call(self._l.rename,dn,newrdn,newsuperior,delold,RequestControlTuples(serverctrls),RequestControlTuples(clientctrls))
def rename_s(self,dn,newrdn,newsuperior=None,delold=1,serverctrls=None,clientctrls=None):
@@ -524,6 +658,8 @@ class SimpleLDAPObject:
if add_ctrls:
resp_data = [ (t,r,DecodeControlTuples(c,resp_ctrl_classes)) for t,r,c in resp_data ]
decoded_resp_ctrls = DecodeControlTuples(resp_ctrls,resp_ctrl_classes)
if resp_data is not None:
resp_data = self._bytesify_results(resp_data, with_ctrls=add_ctrls)
return resp_type, resp_data, resp_msgid, decoded_resp_ctrls, resp_name, resp_value
def search_ext(self,base,scope,filterstr='(objectClass=*)',attrlist=None,attrsonly=0,serverctrls=None,clientctrls=None,timeout=-1,sizelimit=0):
@@ -571,6 +707,9 @@ class SimpleLDAPObject:
The amount of search results retrieved can be limited with the
sizelimit parameter if non-zero.
"""
base, filterstr = self._unbytesify_values(base, filterstr)
if attrlist is not None:
attrlist = tuple(self._unbytesify_values(*attrlist))
return self._ldap_call(
self._l.search_ext,
base,scope,filterstr,

View File

@@ -445,7 +445,7 @@ class SubSchema:
return r_must,r_may # attribute_types()
def urlfetch(uri,trace_level=0):
def urlfetch(uri,trace_level=0,bytes_mode=None):
"""
Fetches a parsed schema entry by uri.
@@ -457,7 +457,7 @@ def urlfetch(uri,trace_level=0):
if uri.startswith('ldap:') or uri.startswith('ldaps:') or uri.startswith('ldapi:'):
import ldapurl
ldap_url = ldapurl.LDAPUrl(uri)
l=ldap.initialize(ldap_url.initializeUrl(),trace_level)
l=ldap.initialize(ldap_url.initializeUrl(),trace_level,bytes_mode=bytes_mode)
l.protocol_version = ldap.VERSION3
l.simple_bind_s(ldap_url.who or '', ldap_url.cred or '')
subschemasubentry_dn = l.search_subschemasubentry_s(ldap_url.dn)

View File

@@ -50,6 +50,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
BerElement *ber = NULL;
PyObject* entrytuple;
PyObject* attrdict;
PyObject* pydn;
dn = ldap_get_dn( ld, entry );
if (dn == NULL) {
@@ -57,9 +58,17 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
ldap_msgfree( m );
return LDAPerror( ld, "ldap_get_dn" );
}
pydn = PyUnicode_FromString(dn);
if (pydn == NULL) {
Py_DECREF(result);
ldap_msgfree( m );
ldap_memfree(dn);
return NULL;
}
attrdict = PyDict_New();
if (attrdict == NULL) {
Py_DECREF(pydn);
Py_DECREF(result);
ldap_msgfree( m );
ldap_memfree(dn);
@@ -68,6 +77,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
rc = ldap_get_entry_controls( ld, entry, &serverctrls );
if (rc) {
Py_DECREF(pydn);
Py_DECREF(result);
ldap_msgfree( m );
ldap_memfree(dn);
@@ -78,6 +88,7 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
if ( ! ( pyctrls = LDAPControls_to_List( serverctrls ) ) ) {
int err = LDAP_NO_MEMORY;
ldap_set_option( ld, LDAP_OPT_ERROR_NUMBER, &err );
Py_DECREF(pydn);
Py_DECREF(result);
ldap_msgfree( m );
ldap_memfree(dn);
@@ -92,22 +103,27 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
attr = ldap_next_attribute( ld, entry, ber )
) {
PyObject* valuelist;
PyObject* pyattr;
pyattr = PyUnicode_FromString(attr);
struct berval ** bvals =
ldap_get_values_len( ld, entry, attr );
/* Find which list to append to */
if ( PyMapping_HasKeyString( attrdict, attr ) ) {
valuelist = PyMapping_GetItemString( attrdict, attr );
if ( PyDict_Contains( attrdict, pyattr ) ) {
valuelist = PyDict_GetItem( attrdict, pyattr );
} else {
valuelist = PyList_New(0);
if (valuelist != NULL && PyMapping_SetItemString(attrdict,
attr, valuelist) == -1) {
if (valuelist != NULL && PyDict_SetItem(attrdict,
pyattr, valuelist) == -1) {
Py_DECREF(valuelist);
valuelist = NULL; /* catch error later */
}
}
if (valuelist == NULL) {
Py_DECREF(pydn);
Py_DECREF(pyattr);
Py_DECREF(attrdict);
Py_DECREF(result);
if (ber != NULL)
@@ -126,6 +142,8 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
valuestr = LDAPberval_to_object(bvals[i]);
if (PyList_Append( valuelist, valuestr ) == -1) {
Py_DECREF(pydn);
Py_DECREF(pyattr);
Py_DECREF(attrdict);
Py_DECREF(result);
Py_DECREF(valuestr);
@@ -142,15 +160,17 @@ LDAPmessage_to_python(LDAP *ld, LDAPMessage *m, int add_ctrls, int add_intermedi
}
ldap_value_free_len(bvals);
}
Py_DECREF(pyattr);
Py_DECREF( valuelist );
ldap_memfree(attr);
}
if (add_ctrls) {
entrytuple = Py_BuildValue("(sOO)", dn, attrdict, pyctrls);
entrytuple = Py_BuildValue("(OOO)", pydn, attrdict, pyctrls);
} else {
entrytuple = Py_BuildValue("(sO)", dn, attrdict);
entrytuple = Py_BuildValue("(OO)", pydn, attrdict);
}
Py_DECREF(pydn);
ldap_memfree(dn);
Py_DECREF(attrdict);
Py_XDECREF(pyctrls);

View File

@@ -1,5 +1,14 @@
from __future__ import unicode_literals
import sys
# True when the tests run on a Python 2 (or older) interpreter.
PY2 = sys.version_info[0] <= 2

# The interpreter's text (unicode) string type.  The name ``unicode`` is
# only evaluated on Python 2, where it exists.
text_type = unicode if PY2 else str
import ldap, unittest
import slapd
@@ -40,7 +49,7 @@ class TestSearch(unittest.TestCase):
"",
])+"\n")
l = LDAPObject(server.get_url())
l = LDAPObject(server.get_url(), bytes_mode=False)
l.protocol_version = 3
l.set_option(ldap.OPT_REFERRALS,0)
l.simple_bind_s(server.get_root_dn(),
@@ -48,6 +57,64 @@ class TestSearch(unittest.TestCase):
self.ldap = l
self.server = server
def test_reject_bytes_base(self):
    # The shared connection is opened with bytes_mode=False, so bytes must
    # be refused for each of the text arguments of search_s in turn.
    conn = self.ldap
    suffix = self.server.get_dn_suffix()

    # bytes base DN
    self.assertRaises(
        TypeError,
        conn.search_s, suffix.encode('utf-8'), ldap.SCOPE_SUBTREE, '(cn=Foo*)', ['*'])

    # bytes filter string
    self.assertRaises(
        TypeError,
        conn.search_s, suffix, ldap.SCOPE_SUBTREE, b'(cn=Foo*)', ['*'])

    # bytes attribute name
    self.assertRaises(
        TypeError,
        conn.search_s, suffix, ldap.SCOPE_SUBTREE, '(cn=Foo*)', [b'*'])
def test_search_keys_are_text(self):
    # With bytes_mode=False, the DN and the attribute names of a search
    # result are text, while the attribute values are bytes.
    suffix = self.server.get_dn_suffix()
    entries = self.ldap.search_s(suffix, ldap.SCOPE_SUBTREE, '(cn=Foo*)', ['*'])
    entries.sort()

    first_dn, first_attrs = entries[0]
    self.assertEqual(first_dn, 'cn=Foo1,%s' % suffix)
    self.assertEqual(type(first_dn), text_type)

    for attr_name, attr_values in first_attrs.items():
        self.assertEqual(type(attr_name), text_type)
        for attr_value in attr_values:
            self.assertEqual(type(attr_value), bytes)
def _get_bytes_ldapobject(self):
    # Open and bind a separate connection with bytes_mode=True.  In that
    # mode the bind DN and password have to be given as bytes, hence the
    # explicit UTF-8 encoding of the text values supplied by the server
    # fixture.
    #
    # The fixture is reached through ``self.server`` (stored on the test
    # case during set-up), like everywhere else in this method, instead of
    # depending on a bare ``server`` name being available at module level.
    l = LDAPObject(self.server.get_url(), bytes_mode=True)
    l.protocol_version = 3
    l.set_option(ldap.OPT_REFERRALS,0)
    l.simple_bind_s(self.server.get_root_dn().encode('utf-8'),
                    self.server.get_root_password().encode('utf-8'))
    return l
@unittest.skipUnless(PY2, "no bytes_mode under Py3")
def test_bytesmode_search_requires_bytes(self):
    # On a bytes_mode=True connection, text must be refused for each of
    # the text arguments of search_s in turn.
    conn = self._get_bytes_ldapobject()
    suffix = self.server.get_dn_suffix()

    # text filter string
    self.assertRaises(
        TypeError,
        conn.search_s, suffix.encode('utf-8'), ldap.SCOPE_SUBTREE, '(cn=Foo*)', [b'*'])

    # text attribute name
    self.assertRaises(
        TypeError,
        conn.search_s, suffix.encode('utf-8'), ldap.SCOPE_SUBTREE, b'(cn=Foo*)', ['*'])

    # text base DN
    self.assertRaises(
        TypeError,
        conn.search_s, suffix, ldap.SCOPE_SUBTREE, b'(cn=Foo*)', [b'*'])
@unittest.skipUnless(PY2, "no bytes_mode under Py3")
def test_bytesmode_search_results_have_bytes(self):
    # On a bytes_mode=True connection, the DN, the attribute names and the
    # attribute values of a search result must all be bytes.
    l = self._get_bytes_ldapobject()
    # The suffix is text; encode it once and use the bytes form throughout.
    base = self.server.get_dn_suffix().encode('utf-8')
    result = l.search_s(base, ldap.SCOPE_SUBTREE, b'(cn=Foo*)', [b'*'])
    result.sort()
    dn, fields = result[0]
    # Build the expected DN from bytes only: formatting a bytes template
    # with a text suffix would yield text on Python 2 and make this check
    # rely on implicit ASCII coercion.
    self.assertEqual(dn, b'cn=Foo1,%s' % base)
    self.assertEqual(type(dn), bytes)
    for key, values in fields.items():
        self.assertEqual(type(key), bytes)
        for value in values:
            self.assertEqual(type(value), bytes)
def test_search_subtree(self):
base = self.server.get_dn_suffix()
l = self.ldap