From 038620064fd56476ce9b22d004a2c962654453de Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Tue, 17 May 2016 18:29:47 +0900 Subject: [PATCH] Allow custom encoding map. Fixes #220 --- pymysql/connections.py | 31 ++++++++++++++++++++----------- pymysql/converters.py | 3 ++- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/pymysql/connections.py b/pymysql/connections.py index c394591..58fbb1a 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -18,8 +18,7 @@ import warnings from .charset import MBLENGTH, charset_by_name, charset_by_id from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS -from .converters import ( - escape_item, encoders, decoders, escape_string, through) +from .converters import escape_item, escape_string, through, conversions as _conv from .cursors import Cursor from .optionfile import Parser from .util import byte2int, int2byte @@ -529,7 +528,7 @@ class Connection(object): def __init__(self, host=None, user=None, password="", database=None, port=0, unix_socket=None, charset='', sql_mode=None, - read_default_file=None, conv=decoders, use_unicode=None, + read_default_file=None, conv=None, use_unicode=None, client_flag=0, cursorclass=Cursor, init_command=None, connect_timeout=None, ssl=None, read_default_group=None, compress=None, named_pipe=None, no_delay=None, @@ -551,8 +550,9 @@ class Connection(object): read_default_file: Specifies my.cnf file to read these parameters from under the [client] section. conv: - Decoders dictionary to use instead of the default one. - This is used to provide custom marshalling of types. See converters. + Conversion dictionary to use instead of the default one. + This is used to provide custom marshalling and unmarshaling of types. + See converters. use_unicode: Whether or not to default to unicode strings. This option defaults to true for Py3k. @@ -667,8 +667,11 @@ class Connection(object): #: specified autocommit mode. None means use server default. self.autocommit_mode = autocommit - self.encoders = encoders # Need for MySQLdb compatibility. - self.decoders = conv + if conv is None: + conv = _conv + # Need for MySQLdb compatibility. + self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int]) + self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int]) self.sql_mode = sql_mode self.init_command = init_command self.max_allowed_packet = max_allowed_packet @@ -770,19 +773,25 @@ class Connection(object): return result.rows def select_db(self, db): - '''Set current db''' + """Set current db""" self._execute_command(COMMAND.COM_INIT_DB, db) self._read_ok_packet() def escape(self, obj, mapping=None): - """Escape whatever value you pass to it""" + """Escape whatever value you pass to it. + + Non-standard, for internal use; do not use this in your applications. + """ if isinstance(obj, str_type): return "'" + self.escape_string(obj) + "'" return escape_item(obj, self.charset, mapping=mapping) def literal(self, obj): - '''Alias for escape()''' - return self.escape(obj) + """Alias for escape() + + Non-standard, for internal use; do not use this in your applications. + """ + return self.escape(obj, self.encoders) def escape_string(self, s): if (self.server_status & diff --git a/pymysql/converters.py b/pymysql/converters.py index 69115c0..7c6c557 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -385,5 +385,6 @@ decoders = { # for MySQLdb compatibility -conversions = decoders +conversions = encoders.copy() +conversions.update(decoders) Thing2Literal = escape_str