Merge pull request #455 from methane/customize-encode

Allow custom encoding mapping via conv parameter.
This commit is contained in:
INADA Naoki
2016-05-17 20:24:42 +09:00
2 changed files with 22 additions and 12 deletions

View File

@@ -18,8 +18,7 @@ import warnings
from .charset import MBLENGTH, charset_by_name, charset_by_id from .charset import MBLENGTH, charset_by_name, charset_by_id
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
from .converters import ( from .converters import escape_item, escape_string, through, conversions as _conv
escape_item, encoders, decoders, escape_string, through)
from .cursors import Cursor from .cursors import Cursor
from .optionfile import Parser from .optionfile import Parser
from .util import byte2int, int2byte from .util import byte2int, int2byte
@@ -529,7 +528,7 @@ class Connection(object):
def __init__(self, host=None, user=None, password="", def __init__(self, host=None, user=None, password="",
database=None, port=0, unix_socket=None, database=None, port=0, unix_socket=None,
charset='', sql_mode=None, charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=None, read_default_file=None, conv=None, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None, client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None, connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None, no_delay=None, compress=None, named_pipe=None, no_delay=None,
@@ -551,8 +550,9 @@ class Connection(object):
read_default_file: read_default_file:
Specifies my.cnf file to read these parameters from under the [client] section. Specifies my.cnf file to read these parameters from under the [client] section.
conv: conv:
Decoders dictionary to use instead of the default one. Conversion dictionary to use instead of the default one.
This is used to provide custom marshalling of types. See converters. This is used to provide custom marshalling and unmarshaling of types.
See converters.
use_unicode: use_unicode:
Whether or not to default to unicode strings. Whether or not to default to unicode strings.
This option defaults to true for Py3k. This option defaults to true for Py3k.
@@ -667,8 +667,11 @@ class Connection(object):
#: specified autocommit mode. None means use server default. #: specified autocommit mode. None means use server default.
self.autocommit_mode = autocommit self.autocommit_mode = autocommit
self.encoders = encoders # Need for MySQLdb compatibility. if conv is None:
self.decoders = conv conv = _conv
# Need for MySQLdb compatibility.
self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
self.sql_mode = sql_mode self.sql_mode = sql_mode
self.init_command = init_command self.init_command = init_command
self.max_allowed_packet = max_allowed_packet self.max_allowed_packet = max_allowed_packet
@@ -770,19 +773,25 @@ class Connection(object):
return result.rows return result.rows
def select_db(self, db): def select_db(self, db):
'''Set current db''' """Set current db"""
self._execute_command(COMMAND.COM_INIT_DB, db) self._execute_command(COMMAND.COM_INIT_DB, db)
self._read_ok_packet() self._read_ok_packet()
def escape(self, obj, mapping=None): def escape(self, obj, mapping=None):
"""Escape whatever value you pass to it""" """Escape whatever value you pass to it.
Non-standard, for internal use; do not use this in your applications.
"""
if isinstance(obj, str_type): if isinstance(obj, str_type):
return "'" + self.escape_string(obj) + "'" return "'" + self.escape_string(obj) + "'"
return escape_item(obj, self.charset, mapping=mapping) return escape_item(obj, self.charset, mapping=mapping)
def literal(self, obj): def literal(self, obj):
'''Alias for escape()''' """Alias for escape()
return self.escape(obj)
Non-standard, for internal use; do not use this in your applications.
"""
return self.escape(obj, self.encoders)
def escape_string(self, s): def escape_string(self, s):
if (self.server_status & if (self.server_status &

View File

@@ -397,5 +397,6 @@ decoders = {
# for MySQLdb compatibility # for MySQLdb compatibility
conversions = decoders conversions = encoders.copy()
conversions.update(decoders)
Thing2Literal = escape_str Thing2Literal = escape_str