Support custom mapping on connection escape.
This commit is contained in:
@@ -725,11 +725,11 @@ class Connection(object):
|
||||
self._execute_command(COMMAND.COM_INIT_DB, db)
|
||||
self._read_ok_packet()
|
||||
|
||||
def escape(self, obj):
|
||||
def escape(self, obj, encoders=None):
|
||||
''' Escape whatever value you pass to it '''
|
||||
if isinstance(obj, str_type):
|
||||
return "'" + self.escape_string(obj) + "'"
|
||||
return escape_item(obj, self.charset)
|
||||
return escape_item(obj, self.charset, custom_encoders=encoders)
|
||||
|
||||
def literal(self, obj):
|
||||
'''Alias for escape()'''
|
||||
|
||||
@@ -15,13 +15,13 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
|
||||
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
||||
|
||||
|
||||
def escape_item(val, charset):
|
||||
def escape_item(val, charset, custom_encoders=None):
|
||||
if type(val) in [tuple, list, set]:
|
||||
return escape_sequence(val, charset)
|
||||
if type(val) is dict:
|
||||
return escape_dict(val, charset)
|
||||
encoder = encoders[type(val)]
|
||||
item_encoders = custom_encoders or encoders
|
||||
encoder = item_encoders[type(val)]
|
||||
val = encoder(val)
|
||||
return val
|
||||
|
||||
|
||||
@@ -29,6 +29,16 @@ class TestConnection(base.PyMySQLTestCase):
|
||||
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
|
||||
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
|
||||
|
||||
def test_escape_custom_object(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
class Foo(object):
|
||||
value = "bar"
|
||||
encoder = lambda x: x.value
|
||||
|
||||
self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar")
|
||||
|
||||
def test_autocommit(self):
|
||||
con = self.connections[0]
|
||||
self.assertFalse(con.get_autocommit())
|
||||
|
||||
Reference in New Issue
Block a user