Support custom mapping on connection escape.

This commit is contained in:
Dave Stinson
2015-02-05 08:54:29 +00:00
parent f225c13c82
commit af8340c3fe
3 changed files with 15 additions and 5 deletions

View File

@@ -725,11 +725,11 @@ class Connection(object):
self._execute_command(COMMAND.COM_INIT_DB, db)
self._read_ok_packet()
def escape(self, obj):
def escape(self, obj, encoders=None):
''' Escape whatever value you pass to it '''
if isinstance(obj, str_type):
return "'" + self.escape_string(obj) + "'"
return escape_item(obj, self.charset)
return escape_item(obj, self.charset, custom_encoders=encoders)
def literal(self, obj):
'''Alias for escape()'''

View File

@@ -15,13 +15,13 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset):
def escape_item(val, charset, custom_encoders=None):
if type(val) in [tuple, list, set]:
return escape_sequence(val, charset)
if type(val) is dict:
return escape_dict(val, charset)
encoder = encoders[type(val)]
item_encoders = custom_encoders or encoders
encoder = item_encoders[type(val)]
val = encoder(val)
return val

View File

@@ -29,6 +29,16 @@ class TestConnection(base.PyMySQLTestCase):
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
def test_escape_custom_object(self):
con = self.connections[0]
cur = con.cursor()
class Foo(object):
value = "bar"
encoder = lambda x: x.value
self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar")
def test_autocommit(self):
con = self.connections[0]
self.assertFalse(con.get_autocommit())