Update interface of encoders to include the mapping argument.
This commit is contained in:
@@ -725,11 +725,11 @@ class Connection(object):
|
|||||||
self._execute_command(COMMAND.COM_INIT_DB, db)
|
self._execute_command(COMMAND.COM_INIT_DB, db)
|
||||||
self._read_ok_packet()
|
self._read_ok_packet()
|
||||||
|
|
||||||
def escape(self, obj, encoders=None):
|
def escape(self, obj, mapping=None):
|
||||||
''' Escape whatever value you pass to it '''
|
''' Escape whatever value you pass to it '''
|
||||||
if isinstance(obj, str_type):
|
if isinstance(obj, str_type):
|
||||||
return "'" + self.escape_string(obj) + "'"
|
return "'" + self.escape_string(obj) + "'"
|
||||||
return escape_item(obj, self.charset, custom_encoders=encoders)
|
return escape_item(obj, self.charset, mapping=mapping)
|
||||||
|
|
||||||
def literal(self, obj):
|
def literal(self, obj):
|
||||||
'''Alias for escape()'''
|
'''Alias for escape()'''
|
||||||
|
|||||||
@@ -15,65 +15,72 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
|
|||||||
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
||||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
||||||
|
|
||||||
def escape_item(val, charset, custom_encoders=None):
|
def escape_item(val, charset, mapping=None):
|
||||||
if type(val) in [tuple, list, set]:
|
if mapping is None:
|
||||||
return escape_sequence(val, charset)
|
mapping = encoders
|
||||||
if type(val) is dict:
|
encoder = mapping.get(type(val))
|
||||||
return escape_dict(val, charset)
|
|
||||||
item_encoders = custom_encoders or encoders
|
# Fallback to default when no encoder found
|
||||||
encoder = item_encoders[type(val)]
|
if not encoder:
|
||||||
val = encoder(val)
|
try:
|
||||||
|
encoder = mapping[text_type]
|
||||||
|
except KeyError:
|
||||||
|
raise TypeError("no default type converter defined")
|
||||||
|
|
||||||
|
if encoder in (escape_dict, escape_sequence):
|
||||||
|
val = encoder(val, charset, mapping)
|
||||||
|
else:
|
||||||
|
val = encoder(val, mapping)
|
||||||
return val
|
return val
|
||||||
|
|
||||||
def escape_dict(val, charset):
|
def escape_dict(val, charset, mapping=None):
|
||||||
n = {}
|
n = {}
|
||||||
for k, v in val.items():
|
for k, v in val.items():
|
||||||
quoted = escape_item(v, charset)
|
quoted = escape_item(v, charset, mapping)
|
||||||
n[k] = quoted
|
n[k] = quoted
|
||||||
return n
|
return n
|
||||||
|
|
||||||
def escape_sequence(val, charset):
|
def escape_sequence(val, charset, mapping=None):
|
||||||
n = []
|
n = []
|
||||||
for item in val:
|
for item in val:
|
||||||
quoted = escape_item(item, charset)
|
quoted = escape_item(item, charset, mapping)
|
||||||
n.append(quoted)
|
n.append(quoted)
|
||||||
return "(" + ",".join(n) + ")"
|
return "(" + ",".join(n) + ")"
|
||||||
|
|
||||||
def escape_set(val, charset):
|
def escape_set(val, charset, mapping=None):
|
||||||
val = map(lambda x: escape_item(x, charset), val)
|
val = map(lambda x: escape_item(x, charset, mapping), val)
|
||||||
return ','.join(val)
|
return ','.join(val)
|
||||||
|
|
||||||
def escape_bool(value):
|
def escape_bool(value, mapping=None):
|
||||||
return str(int(value))
|
return str(int(value))
|
||||||
|
|
||||||
def escape_object(value):
|
def escape_object(value, mapping=None):
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
def escape_int(value):
|
def escape_int(value, mapping=None):
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
def escape_float(value, mapping=None):
|
||||||
def escape_float(value):
|
|
||||||
return ('%.15g' % value)
|
return ('%.15g' % value)
|
||||||
|
|
||||||
def escape_string(value):
|
def escape_string(value, mapping=None):
|
||||||
return ("%s" % (ESCAPE_REGEX.sub(
|
return ("%s" % (ESCAPE_REGEX.sub(
|
||||||
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
|
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
|
||||||
|
|
||||||
def escape_str(value):
|
def escape_str(value, mapping=None):
|
||||||
return "'%s'" % escape_string(value)
|
return "'%s'" % escape_string(value, mapping)
|
||||||
|
|
||||||
def escape_unicode(value):
|
def escape_unicode(value, mapping=None):
|
||||||
return escape_str(value)
|
return escape_str(value, mapping)
|
||||||
|
|
||||||
def escape_bytes(value):
|
def escape_bytes(value, mapping=None):
|
||||||
# escape_bytes is calld only on Python 3.
|
# escape_bytes is calld only on Python 3.
|
||||||
return escape_str(value.decode('ascii', 'surrogateescape'))
|
return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
|
||||||
|
|
||||||
def escape_None(value):
|
def escape_None(value, mapping=None):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
|
|
||||||
def escape_timedelta(obj):
|
def escape_timedelta(obj, mapping=None):
|
||||||
seconds = int(obj.seconds) % 60
|
seconds = int(obj.seconds) % 60
|
||||||
minutes = int(obj.seconds // 60) % 60
|
minutes = int(obj.seconds // 60) % 60
|
||||||
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
|
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
|
||||||
@@ -83,25 +90,25 @@ def escape_timedelta(obj):
|
|||||||
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
|
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
|
||||||
return fmt.format(hours, minutes, seconds, obj.microseconds)
|
return fmt.format(hours, minutes, seconds, obj.microseconds)
|
||||||
|
|
||||||
def escape_time(obj):
|
def escape_time(obj, mapping=None):
|
||||||
if obj.microsecond:
|
if obj.microsecond:
|
||||||
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
||||||
else:
|
else:
|
||||||
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
|
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
|
||||||
return fmt.format(obj)
|
return fmt.format(obj)
|
||||||
|
|
||||||
def escape_datetime(obj):
|
def escape_datetime(obj, mapping=None):
|
||||||
if obj.microsecond:
|
if obj.microsecond:
|
||||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
|
||||||
else:
|
else:
|
||||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
|
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
|
||||||
return fmt.format(obj)
|
return fmt.format(obj)
|
||||||
|
|
||||||
def escape_date(obj):
|
def escape_date(obj, mapping=None):
|
||||||
fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
|
fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
|
||||||
return fmt.format(obj)
|
return fmt.format(obj)
|
||||||
|
|
||||||
def escape_struct_time(obj):
|
def escape_struct_time(obj, mapping=None):
|
||||||
return escape_datetime(datetime.datetime(*obj[:6]))
|
return escape_datetime(datetime.datetime(*obj[:6]))
|
||||||
|
|
||||||
def convert_datetime(obj):
|
def convert_datetime(obj):
|
||||||
@@ -309,7 +316,7 @@ encoders = {
|
|||||||
datetime.timedelta: escape_timedelta,
|
datetime.timedelta: escape_timedelta,
|
||||||
datetime.time: escape_time,
|
datetime.time: escape_time,
|
||||||
time.struct_time: escape_struct_time,
|
time.struct_time: escape_struct_time,
|
||||||
Decimal: str,
|
Decimal: escape_object,
|
||||||
}
|
}
|
||||||
|
|
||||||
if not PY2 or JYTHON or IRONPYTHON:
|
if not PY2 or JYTHON or IRONPYTHON:
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import datetime
|
||||||
|
import decimal
|
||||||
import pymysql
|
import pymysql
|
||||||
import time
|
import time
|
||||||
from pymysql.tests import base
|
from pymysql.tests import base
|
||||||
@@ -21,24 +23,6 @@ class TestConnection(base.PyMySQLTestCase):
|
|||||||
cur.execute("SELECT '" + t + "'")
|
cur.execute("SELECT '" + t + "'")
|
||||||
assert cur.fetchone()[0] == t
|
assert cur.fetchone()[0] == t
|
||||||
|
|
||||||
def test_escape_string(self):
|
|
||||||
con = self.connections[0]
|
|
||||||
cur = con.cursor()
|
|
||||||
|
|
||||||
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
|
|
||||||
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
|
|
||||||
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
|
|
||||||
|
|
||||||
def test_escape_custom_object(self):
|
|
||||||
con = self.connections[0]
|
|
||||||
cur = con.cursor()
|
|
||||||
|
|
||||||
class Foo(object):
|
|
||||||
value = "bar"
|
|
||||||
encoder = lambda x: x.value
|
|
||||||
|
|
||||||
self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar")
|
|
||||||
|
|
||||||
def test_autocommit(self):
|
def test_autocommit(self):
|
||||||
con = self.connections[0]
|
con = self.connections[0]
|
||||||
self.assertFalse(con.get_autocommit())
|
self.assertFalse(con.get_autocommit())
|
||||||
@@ -79,3 +63,68 @@ class TestConnection(base.PyMySQLTestCase):
|
|||||||
# error occures while reading, not writing because of socket buffer.
|
# error occures while reading, not writing because of socket buffer.
|
||||||
#self.assertEquals(cm.exception.args[0], 2006)
|
#self.assertEquals(cm.exception.args[0], 2006)
|
||||||
self.assertIn(cm.exception.args[0], (2006, 2013))
|
self.assertIn(cm.exception.args[0], (2006, 2013))
|
||||||
|
|
||||||
|
|
||||||
|
# A custom type and function to escape it
|
||||||
|
class Foo(object):
|
||||||
|
value = "bar"
|
||||||
|
|
||||||
|
def escape_foo(x, d):
|
||||||
|
return x.value
|
||||||
|
|
||||||
|
|
||||||
|
class TestEscape(base.PyMySQLTestCase):
|
||||||
|
def test_escape_string(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
|
||||||
|
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
|
||||||
|
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
|
||||||
|
|
||||||
|
def test_escape_builtin_encoders(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
val = datetime.datetime(2012, 3, 4, 5, 6)
|
||||||
|
self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")
|
||||||
|
|
||||||
|
def test_escape_custom_object(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
mapping = {Foo: escape_foo}
|
||||||
|
self.assertEqual(con.escape(Foo(), mapping), "bar")
|
||||||
|
|
||||||
|
def test_escape_fallback_encoder(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
class Custom(str):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mapping = {pymysql.text_type: pymysql.escape_string}
|
||||||
|
self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")
|
||||||
|
|
||||||
|
def test_escape_no_default(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
self.assertRaises(TypeError, con.escape, 42, {})
|
||||||
|
|
||||||
|
def test_escape_dict_value(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
mapping = con.encoders.copy()
|
||||||
|
mapping[Foo] = escape_foo
|
||||||
|
self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})
|
||||||
|
|
||||||
|
def test_escape_list_item(self):
|
||||||
|
con = self.connections[0]
|
||||||
|
cur = con.cursor()
|
||||||
|
|
||||||
|
mapping = con.encoders.copy()
|
||||||
|
mapping[Foo] = escape_foo
|
||||||
|
self.assertEqual(con.escape([Foo()], mapping), "(bar)")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user