From af8340c3fe75210aa1abfd69af14bf634d97377c Mon Sep 17 00:00:00 2001 From: Dave Stinson Date: Thu, 5 Feb 2015 08:54:29 +0000 Subject: [PATCH] Support custom mapping on connection escape. --- pymysql/connections.py | 4 ++-- pymysql/converters.py | 6 +++--- pymysql/tests/test_connection.py | 10 ++++++++++ 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index e91bba8..0d32d60 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -725,11 +725,11 @@ class Connection(object): self._execute_command(COMMAND.COM_INIT_DB, db) self._read_ok_packet() - def escape(self, obj): + def escape(self, obj, encoders=None): ''' Escape whatever value you pass to it ''' if isinstance(obj, str_type): return "'" + self.escape_string(obj) + "'" - return escape_item(obj, self.charset) + return escape_item(obj, self.charset, custom_encoders=encoders) def literal(self, obj): '''Alias for escape()''' diff --git a/pymysql/converters.py b/pymysql/converters.py index 893c7c6..e2e3c1b 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -15,13 +15,13 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]") ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', '\'': '\\\'', '"': '\\"', '\\': '\\\\'} - -def escape_item(val, charset): +def escape_item(val, charset, custom_encoders=None): if type(val) in [tuple, list, set]: return escape_sequence(val, charset) if type(val) is dict: return escape_dict(val, charset) - encoder = encoders[type(val)] + item_encoders = custom_encoders or encoders + encoder = item_encoders[type(val)] val = encoder(val) return val diff --git a/pymysql/tests/test_connection.py b/pymysql/tests/test_connection.py index 7580ca6..4fa521b 100644 --- a/pymysql/tests/test_connection.py +++ b/pymysql/tests/test_connection.py @@ -29,6 +29,16 @@ class TestConnection(base.PyMySQLTestCase): cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'") self.assertEqual(con.escape("foo'bar"), "'foo''bar'") + def test_escape_custom_object(self): + con = self.connections[0] + cur = con.cursor() + + class Foo(object): + value = "bar" + encoder = lambda x: x.value + + self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar") + def test_autocommit(self): con = self.connections[0] self.assertFalse(con.get_autocommit())