diff --git a/pymysql/converters.py b/pymysql/converters.py index 4e74ced..1e97dc2 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -4,17 +4,12 @@ import sys import binascii import datetime from decimal import Decimal -import re import time from .constants import FIELD_TYPE, FLAG from .charset import charset_by_id, charset_to_encoding -ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]") -ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', - '\'': '\\\'', '"': '\\"', '\\': '\\\\'} - def escape_item(val, charset, mapping=None): if mapping is None: mapping = encoders @@ -48,8 +43,7 @@ def escape_sequence(val, charset, mapping=None): return "(" + ",".join(n) + ")" def escape_set(val, charset, mapping=None): - val = map(lambda x: escape_item(x, charset, mapping), val) - return ','.join(val) + return ','.join([escape_item(x, charset, mapping) for x in val]) def escape_bool(value, mapping=None): return str(int(value)) @@ -63,9 +57,46 @@ def escape_int(value, mapping=None): def escape_float(value, mapping=None): return ('%.15g' % value) -def escape_string(value, mapping=None): - return ("%s" % (ESCAPE_REGEX.sub( - lambda match: ESCAPE_MAP.get(match.group(0)), value),)) +if PY2: + def escape_string(value, mapping=None): + """escape_string escapes *value* but not surround it with quotes. + + Value should be bytes or unicode. + """ + value = value.replace('\\', '\\\\') + value = value.replace('\0', '\\0') + value = value.replace('\n', '\\n') + value = value.replace('\r', '\\r') + value = value.replace('\032', '\\Z') + value = value.replace("'", "\\'") + value = value.replace('"', '\\"') + return value +else: + _escape_table = [chr(x) for x in range(128)] + _escape_table[0] = '\\0' + _escape_table[ord('\\')] = '\\\\' + _escape_table[ord('\n')] = '\\n' + _escape_table[ord('\r')] = '\\r' + _escape_table[ord('\032')] = '\\Z' + _escape_table[ord('"')] = '\\"' + _escape_table[ord("'")] = "\\'" + + def escape_string(value, mapping=None): + """escape_string escapes *value* but not surround it with quotes. + + Value should be str (unicode). + """ + return value.translate(_escape_table) + + # On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow. + # (fixed in Python 3.6, http://bugs.python.org/issue24870) + # Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff. + # We can escape special chars and surrogateescape at once. + _escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)] + + def escape_bytes(value, mapping=None): + return "'%s'" % value.decode('latin1').translate(_escape_bytes_table) + def escape_str(value, mapping=None): return "'%s'" % escape_string(value, mapping) @@ -73,10 +104,6 @@ def escape_str(value, mapping=None): def escape_unicode(value, mapping=None): return escape_str(value, mapping) -def escape_bytes(value, mapping=None): - # escape_bytes is calld only on Python 3. - return escape_str(value.decode('ascii', 'surrogateescape'), mapping) - def escape_None(value, mapping=None): return 'NULL' diff --git a/pymysql/cursors.py b/pymysql/cursors.py index 266e137..ff00d7b 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -117,7 +117,7 @@ class Cursor(object): # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error if PY2: - ensure_bytes(args) + args = ensure_bytes(args) return conn.escape(args) def mogrify(self, query, args=None):