Update interface of encoders to include the mapping argument.
This commit is contained in:
@@ -725,11 +725,11 @@ class Connection(object):
|
||||
self._execute_command(COMMAND.COM_INIT_DB, db)
|
||||
self._read_ok_packet()
|
||||
|
||||
def escape(self, obj, encoders=None):
|
||||
def escape(self, obj, mapping=None):
|
||||
''' Escape whatever value you pass to it '''
|
||||
if isinstance(obj, str_type):
|
||||
return "'" + self.escape_string(obj) + "'"
|
||||
return escape_item(obj, self.charset, custom_encoders=encoders)
|
||||
return escape_item(obj, self.charset, mapping=mapping)
|
||||
|
||||
def literal(self, obj):
|
||||
'''Alias for escape()'''
|
||||
|
||||
@@ -15,65 +15,72 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
|
||||
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
||||
|
||||
def escape_item(val, charset, custom_encoders=None):
|
||||
if type(val) in [tuple, list, set]:
|
||||
return escape_sequence(val, charset)
|
||||
if type(val) is dict:
|
||||
return escape_dict(val, charset)
|
||||
item_encoders = custom_encoders or encoders
|
||||
encoder = item_encoders[type(val)]
|
||||
val = encoder(val)
|
||||
def escape_item(val, charset, mapping=None):
|
||||
if mapping is None:
|
||||
mapping = encoders
|
||||
encoder = mapping.get(type(val))
|
||||
|
||||
# Fallback to default when no encoder found
|
||||
if not encoder:
|
||||
try:
|
||||
encoder = mapping[text_type]
|
||||
except KeyError:
|
||||
raise TypeError("no default type converter defined")
|
||||
|
||||
if encoder in (escape_dict, escape_sequence):
|
||||
val = encoder(val, charset, mapping)
|
||||
else:
|
||||
val = encoder(val, mapping)
|
||||
return val
|
||||
|
||||
def escape_dict(val, charset):
|
||||
def escape_dict(val, charset, mapping=None):
|
||||
n = {}
|
||||
for k, v in val.items():
|
||||
quoted = escape_item(v, charset)
|
||||
quoted = escape_item(v, charset, mapping)
|
||||
n[k] = quoted
|
||||
return n
|
||||
|
||||
def escape_sequence(val, charset):
|
||||
def escape_sequence(val, charset, mapping=None):
|
||||
n = []
|
||||
for item in val:
|
||||
quoted = escape_item(item, charset)
|
||||
quoted = escape_item(item, charset, mapping)
|
||||
n.append(quoted)
|
||||
return "(" + ",".join(n) + ")"
|
||||
|
||||
def escape_set(val, charset):
|
||||
val = map(lambda x: escape_item(x, charset), val)
|
||||
def escape_set(val, charset, mapping=None):
|
||||
val = map(lambda x: escape_item(x, charset, mapping), val)
|
||||
return ','.join(val)
|
||||
|
||||
def escape_bool(value):
|
||||
def escape_bool(value, mapping=None):
|
||||
return str(int(value))
|
||||
|
||||
def escape_object(value):
|
||||
def escape_object(value, mapping=None):
|
||||
return str(value)
|
||||
|
||||
def escape_int(value):
|
||||
def escape_int(value, mapping=None):
|
||||
return str(value)
|
||||
|
||||
|
||||
def escape_float(value):
|
||||
def escape_float(value, mapping=None):
|
||||
return ('%.15g' % value)
|
||||
|
||||
def escape_string(value):
|
||||
def escape_string(value, mapping=None):
|
||||
return ("%s" % (ESCAPE_REGEX.sub(
|
||||
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
|
||||
|
||||
def escape_str(value):
|
||||
return "'%s'" % escape_string(value)
|
||||
def escape_str(value, mapping=None):
|
||||
return "'%s'" % escape_string(value, mapping)
|
||||
|
||||
def escape_unicode(value):
|
||||
return escape_str(value)
|
||||
def escape_unicode(value, mapping=None):
|
||||
return escape_str(value, mapping)
|
||||
|
||||
def escape_bytes(value):
|
||||
def escape_bytes(value, mapping=None):
|
||||
# escape_bytes is calld only on Python 3.
|
||||
return escape_str(value.decode('ascii', 'surrogateescape'))
|
||||
return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
|
||||
|
||||
def escape_None(value):
|
||||
def escape_None(value, mapping=None):
|
||||
return 'NULL'
|
||||
|
||||
def escape_timedelta(obj):
|
||||
def escape_timedelta(obj, mapping=None):
|
||||
seconds = int(obj.seconds) % 60
|
||||
minutes = int(obj.seconds // 60) % 60
|
||||
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
|
||||
@@ -83,25 +90,25 @@ def escape_timedelta(obj):
|
||||
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
|
||||
return fmt.format(hours, minutes, seconds, obj.microseconds)
|
||||
|
||||
def escape_time(obj):
|
||||
def escape_time(obj, mapping=None):
|
||||
if obj.microsecond:
|
||||
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
||||
else:
|
||||
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
|
||||
return fmt.format(obj)
|
||||
|
||||
def escape_datetime(obj):
|
||||
def escape_datetime(obj, mapping=None):
|
||||
if obj.microsecond:
|
||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
||||
else:
|
||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
|
||||
return fmt.format(obj)
|
||||
|
||||
def escape_date(obj):
|
||||
def escape_date(obj, mapping=None):
|
||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
|
||||
return fmt.format(obj)
|
||||
|
||||
def escape_struct_time(obj):
|
||||
def escape_struct_time(obj, mapping=None):
|
||||
return escape_datetime(datetime.datetime(*obj[:6]))
|
||||
|
||||
def convert_datetime(obj):
|
||||
@@ -309,7 +316,7 @@ encoders = {
|
||||
datetime.timedelta: escape_timedelta,
|
||||
datetime.time: escape_time,
|
||||
time.struct_time: escape_struct_time,
|
||||
Decimal: str,
|
||||
Decimal: escape_object,
|
||||
}
|
||||
|
||||
if not PY2 or JYTHON or IRONPYTHON:
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import datetime
|
||||
import decimal
|
||||
import pymysql
|
||||
import time
|
||||
from pymysql.tests import base
|
||||
@@ -21,24 +23,6 @@ class TestConnection(base.PyMySQLTestCase):
|
||||
cur.execute("SELECT '" + t + "'")
|
||||
assert cur.fetchone()[0] == t
|
||||
|
||||
def test_escape_string(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
|
||||
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
|
||||
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
|
||||
|
||||
def test_escape_custom_object(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
class Foo(object):
|
||||
value = "bar"
|
||||
encoder = lambda x: x.value
|
||||
|
||||
self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar")
|
||||
|
||||
def test_autocommit(self):
|
||||
con = self.connections[0]
|
||||
self.assertFalse(con.get_autocommit())
|
||||
@@ -79,3 +63,68 @@ class TestConnection(base.PyMySQLTestCase):
|
||||
# error occures while reading, not writing because of socket buffer.
|
||||
#self.assertEquals(cm.exception.args[0], 2006)
|
||||
self.assertIn(cm.exception.args[0], (2006, 2013))
|
||||
|
||||
|
||||
# A custom type and function to escape it
|
||||
class Foo(object):
|
||||
value = "bar"
|
||||
|
||||
def escape_foo(x, d):
|
||||
return x.value
|
||||
|
||||
|
||||
class TestEscape(base.PyMySQLTestCase):
|
||||
def test_escape_string(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
|
||||
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
|
||||
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
|
||||
|
||||
def test_escape_builtin_encoders(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
val = datetime.datetime(2012, 3, 4, 5, 6)
|
||||
self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")
|
||||
|
||||
def test_escape_custom_object(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
mapping = {Foo: escape_foo}
|
||||
self.assertEqual(con.escape(Foo(), mapping), "bar")
|
||||
|
||||
def test_escape_fallback_encoder(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
class Custom(str):
|
||||
pass
|
||||
|
||||
mapping = {pymysql.text_type: pymysql.escape_string}
|
||||
self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")
|
||||
|
||||
def test_escape_no_default(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
self.assertRaises(TypeError, con.escape, 42, {})
|
||||
|
||||
def test_escape_dict_value(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
mapping = con.encoders.copy()
|
||||
mapping[Foo] = escape_foo
|
||||
self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})
|
||||
|
||||
def test_escape_list_item(self):
|
||||
con = self.connections[0]
|
||||
cur = con.cursor()
|
||||
|
||||
mapping = con.encoders.copy()
|
||||
mapping[Foo] = escape_foo
|
||||
self.assertEqual(con.escape([Foo()], mapping), "(bar)")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user