Update interface of encoders to include the mapping argument.

This commit is contained in:
Dave Stinson
2015-02-06 02:11:33 +00:00
parent af8340c3fe
commit 7527eeebfc
3 changed files with 109 additions and 53 deletions

View File

@@ -725,11 +725,11 @@ class Connection(object):
self._execute_command(COMMAND.COM_INIT_DB, db)
self._read_ok_packet()
def escape(self, obj, encoders=None):
def escape(self, obj, mapping=None):
''' Escape whatever value you pass to it '''
if isinstance(obj, str_type):
return "'" + self.escape_string(obj) + "'"
return escape_item(obj, self.charset, custom_encoders=encoders)
return escape_item(obj, self.charset, mapping=mapping)
def literal(self, obj):
'''Alias for escape()'''

View File

@@ -15,65 +15,72 @@ ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset, custom_encoders=None):
if type(val) in [tuple, list, set]:
return escape_sequence(val, charset)
if type(val) is dict:
return escape_dict(val, charset)
item_encoders = custom_encoders or encoders
encoder = item_encoders[type(val)]
val = encoder(val)
def escape_item(val, charset, mapping=None):
if mapping is None:
mapping = encoders
encoder = mapping.get(type(val))
# Fallback to default when no encoder found
if not encoder:
try:
encoder = mapping[text_type]
except KeyError:
raise TypeError("no default type converter defined")
if encoder in (escape_dict, escape_sequence):
val = encoder(val, charset, mapping)
else:
val = encoder(val, mapping)
return val
def escape_dict(val, charset):
def escape_dict(val, charset, mapping=None):
n = {}
for k, v in val.items():
quoted = escape_item(v, charset)
quoted = escape_item(v, charset, mapping)
n[k] = quoted
return n
def escape_sequence(val, charset):
def escape_sequence(val, charset, mapping=None):
n = []
for item in val:
quoted = escape_item(item, charset)
quoted = escape_item(item, charset, mapping)
n.append(quoted)
return "(" + ",".join(n) + ")"
def escape_set(val, charset):
val = map(lambda x: escape_item(x, charset), val)
def escape_set(val, charset, mapping=None):
val = map(lambda x: escape_item(x, charset, mapping), val)
return ','.join(val)
def escape_bool(value):
def escape_bool(value, mapping=None):
return str(int(value))
def escape_object(value):
def escape_object(value, mapping=None):
return str(value)
def escape_int(value):
def escape_int(value, mapping=None):
return str(value)
def escape_float(value):
def escape_float(value, mapping=None):
return ('%.15g' % value)
def escape_string(value):
def escape_string(value, mapping=None):
return ("%s" % (ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
def escape_str(value):
return "'%s'" % escape_string(value)
def escape_str(value, mapping=None):
return "'%s'" % escape_string(value, mapping)
def escape_unicode(value):
return escape_str(value)
def escape_unicode(value, mapping=None):
return escape_str(value, mapping)
def escape_bytes(value):
def escape_bytes(value, mapping=None):
# escape_bytes is calld only on Python 3.
return escape_str(value.decode('ascii', 'surrogateescape'))
return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
def escape_None(value):
def escape_None(value, mapping=None):
return 'NULL'
def escape_timedelta(obj):
def escape_timedelta(obj, mapping=None):
seconds = int(obj.seconds) % 60
minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
@@ -83,25 +90,25 @@ def escape_timedelta(obj):
fmt = "'{0:02d}:{1:02d}:{2:02d}'"
return fmt.format(hours, minutes, seconds, obj.microseconds)
def escape_time(obj):
def escape_time(obj, mapping=None):
if obj.microsecond:
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
else:
fmt = "'{0.hour:02}:{0.minute:02}:{0.second:02}'"
return fmt.format(obj)
def escape_datetime(obj):
def escape_datetime(obj, mapping=None):
if obj.microsecond:
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}.{0.microsecond:06}'"
else:
fmt = "'{0.year:04}-{0.month:02}-{0.day:02} {0.hour:02}:{0.minute:02}:{0.second:02}'"
return fmt.format(obj)
def escape_date(obj):
def escape_date(obj, mapping=None):
fmt = "'{0.year:04}-{0.month:02}-{0.day:02}'"
return fmt.format(obj)
def escape_struct_time(obj):
def escape_struct_time(obj, mapping=None):
return escape_datetime(datetime.datetime(*obj[:6]))
def convert_datetime(obj):
@@ -309,7 +316,7 @@ encoders = {
datetime.timedelta: escape_timedelta,
datetime.time: escape_time,
time.struct_time: escape_struct_time,
Decimal: str,
Decimal: escape_object,
}
if not PY2 or JYTHON or IRONPYTHON:

View File

@@ -1,3 +1,5 @@
import datetime
import decimal
import pymysql
import time
from pymysql.tests import base
@@ -21,24 +23,6 @@ class TestConnection(base.PyMySQLTestCase):
cur.execute("SELECT '" + t + "'")
assert cur.fetchone()[0] == t
def test_escape_string(self):
con = self.connections[0]
cur = con.cursor()
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
def test_escape_custom_object(self):
con = self.connections[0]
cur = con.cursor()
class Foo(object):
value = "bar"
encoder = lambda x: x.value
self.assertEqual(con.escape(Foo(), encoders={Foo: encoder}), "bar")
def test_autocommit(self):
con = self.connections[0]
self.assertFalse(con.get_autocommit())
@@ -79,3 +63,68 @@ class TestConnection(base.PyMySQLTestCase):
# error occures while reading, not writing because of socket buffer.
#self.assertEquals(cm.exception.args[0], 2006)
self.assertIn(cm.exception.args[0], (2006, 2013))
# A custom type and function to escape it
class Foo(object):
value = "bar"
def escape_foo(x, d):
return x.value
class TestEscape(base.PyMySQLTestCase):
def test_escape_string(self):
con = self.connections[0]
cur = con.cursor()
self.assertEqual(con.escape("foo'bar"), "'foo\\'bar'")
cur.execute("SET sql_mode='NO_BACKSLASH_ESCAPES'")
self.assertEqual(con.escape("foo'bar"), "'foo''bar'")
def test_escape_builtin_encoders(self):
con = self.connections[0]
cur = con.cursor()
val = datetime.datetime(2012, 3, 4, 5, 6)
self.assertEqual(con.escape(val, con.encoders), "'2012-03-04 05:06:00'")
def test_escape_custom_object(self):
con = self.connections[0]
cur = con.cursor()
mapping = {Foo: escape_foo}
self.assertEqual(con.escape(Foo(), mapping), "bar")
def test_escape_fallback_encoder(self):
con = self.connections[0]
cur = con.cursor()
class Custom(str):
pass
mapping = {pymysql.text_type: pymysql.escape_string}
self.assertEqual(con.escape(Custom('foobar'), mapping), "'foobar'")
def test_escape_no_default(self):
con = self.connections[0]
cur = con.cursor()
self.assertRaises(TypeError, con.escape, 42, {})
def test_escape_dict_value(self):
con = self.connections[0]
cur = con.cursor()
mapping = con.encoders.copy()
mapping[Foo] = escape_foo
self.assertEqual(con.escape({'foo': Foo()}, mapping), {'foo': "bar"})
def test_escape_list_item(self):
con = self.connections[0]
cur = con.cursor()
mapping = con.encoders.copy()
mapping[Foo] = escape_foo
self.assertEqual(con.escape([Foo()], mapping), "(bar)")