cleaned up charset and unicode functionality, aligned BIT type handling with MySQLdb, disambiguated length coded binary and strings, raised exception after connection is close()'d

This commit is contained in:
Pete Hunt
2010-12-10 05:32:06 +00:00
parent bfa2cf574d
commit d51400138c
8 changed files with 449 additions and 222 deletions

View File

@@ -1,5 +1,3 @@
MBLENGTH = { MBLENGTH = {
8:1, 8:1,
33:3, 33:3,
@@ -7,4 +5,170 @@ MBLENGTH = {
91:2 91:2
} }
class Charset:
def __init__(self, id, name, collation, is_default):
self.id, self.name, self.collation = id, name, collation
self.is_default = is_default == 'Yes'
class Charsets:
def __init__(self):
self._by_id = {}
def add(self, c):
self._by_id[c.id] = c
def by_id(self, id):
return self._by_id[id]
def by_name(self, name):
for c in self._by_id.values():
if c.name == name and c.is_default:
return c
_charsets = Charsets()
"""
Generated with:
mysql -N -s -e "select id, character_set_name, collation_name, is_default
from information_schema.collations order by id;" | python -c "import sys
for l in sys.stdin.readlines():
id, name, collation, is_default = l.split(chr(9))
print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
% (id, name, collation, is_default.strip())
"
"""
_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
def charset_by_name(name):
return _charsets.by_name(name)
def charset_by_id(id):
return _charsets.by_id(id)

View File

@@ -1,8 +1,6 @@
# Python implementation of the MySQL client-server protocol # Python implementation of the MySQL client-server protocol
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
import re
try: try:
import hashlib import hashlib
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs) sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
@@ -22,13 +20,13 @@ try:
except ImportError: except ImportError:
import StringIO import StringIO
from charset import MBLENGTH from charset import MBLENGTH, charset_by_name, charset_by_id
from cursors import Cursor from cursors import Cursor
from constants import FIELD_TYPE from constants import FIELD_TYPE, FLAG
from constants import SERVER_STATUS from constants import SERVER_STATUS
from constants.CLIENT import * from constants.CLIENT import *
from constants.COMMAND import * from constants.COMMAND import *
from converters import escape_item, encoders, decoders, field_decoders from converters import escape_item, encoders, decoders
from err import raise_mysql_exception, Warning, Error, \ from err import raise_mysql_exception, Warning, Error, \
InterfaceError, DataError, DatabaseError, OperationalError, \ InterfaceError, DataError, DatabaseError, OperationalError, \
IntegrityError, InternalError, NotSupportedError, ProgrammingError IntegrityError, InternalError, NotSupportedError, ProgrammingError
@@ -64,7 +62,8 @@ def dump_packet(data):
dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0] dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
for d in dump_data: for d in dump_data:
print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \ print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \
' ' * (16 - len(d)) + ' ' * 2 + ' '.join(map(lambda x:"%s" % is_ascii(x), d)) ' ' * (16 - len(d)) + ' ' * 2 + \
' '.join(map(lambda x:"%s" % is_ascii(x), d))
print "-" * 88 print "-" * 88
print "" print ""
@@ -84,7 +83,8 @@ def _my_crypt(message1, message2):
length = len(message1) length = len(message1)
result = struct.pack('B', length) result = struct.pack('B', length)
for i in xrange(length): for i in xrange(length):
x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0]) x = (struct.unpack('B', message1[i:i+1])[0] ^ \
struct.unpack('B', message2[i:i+1])[0])
result += struct.pack('B', x) result += struct.pack('B', x)
return result return result
@@ -161,9 +161,10 @@ def unpack_int64(n):
(struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56) (struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56)
def defaulterrorhandler(connection, cursor, errorclass, errorvalue): def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
raise
err = errorclass, errorvalue err = errorclass, errorvalue
if DEBUG:
raise
if cursor: if cursor:
cursor.messages.append(err) cursor.messages.append(err)
else: else:
@@ -271,8 +272,8 @@ class MysqlPacket(object):
""" """
return self.__data[position:(position+length)] return self.__data[position:(position+length)]
def read_coded_length(self): def read_length_coded_binary(self):
"""Read a 'Length Coded' number from the data buffer. """Read a 'Length Coded Binary' number from the data buffer.
Length coded numbers can be anywhere from 1 to 9 bytes depending Length coded numbers can be anywhere from 1 to 9 bytes depending
on the value of the first byte. on the value of the first byte.
@@ -290,16 +291,17 @@ class MysqlPacket(object):
# TODO: what was 'longlong'? confirm it wasn't used? # TODO: what was 'longlong'? confirm it wasn't used?
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
def read_length_coded_binary(self): def read_length_coded_string(self):
"""Read a 'Length Coded Binary' from the data buffer. """Read a 'Length Coded String' from the data buffer.
A 'Length Coded Binary' consists first of a length coded A 'Length Coded String' consists first of a length coded
(unsigned, positive) integer represented in 1-9 bytes followed by (unsigned, positive) integer represented in 1-9 bytes followed by
that many bytes of binary data. (For example "cat" would be "3cat".) that many bytes of binary data. (For example "cat" would be "3cat".)
""" """
length = self.read_coded_length() length = self.read_length_coded_binary()
if length: if length is None:
return self.read(length) return None
return self.read(length)
def is_ok_packet(self): def is_ok_packet(self):
return ord(self.get_bytes(0)) == 0 return ord(self.get_bytes(0)) == 0
@@ -342,19 +344,17 @@ class FieldDescriptorPacket(MysqlPacket):
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
""" """
self.catalog = self.read_length_coded_binary() self.catalog = self.read_length_coded_string()
self.db = self.read_length_coded_binary() self.db = self.read_length_coded_string()
self.table_name = self.read_length_coded_binary() self.table_name = self.read_length_coded_string()
self.org_table = self.read_length_coded_binary() self.org_table = self.read_length_coded_string()
self.name = self.read_length_coded_binary() self.name = self.read_length_coded_string()
self.org_name = self.read_length_coded_binary() self.org_name = self.read_length_coded_string()
self.advance(1) # non-null filler self.advance(1) # non-null filler
self.charsetnr = struct.unpack('<h', self.read(2))[0] self.charsetnr = struct.unpack('<H', self.read(2))[0]
self.length = struct.unpack('<i', self.read(4))[0] self.length = struct.unpack('<I', self.read(4))[0]
self.type_code = ord(self.read(1)) self.type_code = ord(self.read(1))
flags = struct.unpack('<h', self.read(2)) self.flags = struct.unpack('<H', self.read(2))[0]
# TODO: what is going on here with this flag parsing???
self.flags = int(("%02X" % flags)[1:], 16)
self.scale = ord(self.read(1)) # "decimals" self.scale = ord(self.read(1)) # "decimals"
self.advance(2) # filler (always 0x00) self.advance(2) # filler (always 0x00)
@@ -401,8 +401,8 @@ class Connection(object):
def __init__(self, host="localhost", user=None, passwd="", def __init__(self, host="localhost", user=None, passwd="",
db=None, port=3306, unix_socket=None, db=None, port=3306, unix_socket=None,
charset=DEFAULT_CHARSET, sql_mode=None, charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=True, read_default_file=None, conv=decoders, use_unicode=False,
client_flag=0, cursorclass=Cursor, init_command=None, client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None, connect_timeout=None, ssl=None, read_default_group=None,
compress=None, named_pipe=None): compress=None, named_pipe=None):
@@ -457,7 +457,7 @@ class Connection(object):
return cfg.get("client",key) return cfg.get("client",key)
except: except:
return default return default
user = _config("user",user) user = _config("user",user)
passwd = _config("password",passwd) passwd = _config("password",passwd)
host = _config("host", host) host = _config("host", host)
@@ -465,15 +465,22 @@ class Connection(object):
unix_socket = _config("socket",unix_socket) unix_socket = _config("socket",unix_socket)
port = _config("port", port) port = _config("port", port)
charset = _config("default-character-set", charset) charset = _config("default-character-set", charset)
self.host = host self.host = host
self.port = port self.port = port
self.user = user self.user = user
self.password = passwd self.password = passwd
self.db = db self.db = db
self.unix_socket = unix_socket self.unix_socket = unix_socket
self.use_unicode = use_unicode if charset:
self.charset = DEFAULT_CHARSET self.charset = charset
self.use_unicode = True
else:
self.charset = DEFAULT_CHARSET
self.use_unicode = False
if use_unicode:
self.use_unicode = use_unicode
client_flag |= CAPABILITIES client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS client_flag |= MULTI_STATEMENTS
@@ -483,20 +490,19 @@ class Connection(object):
self.cursorclass = cursorclass self.cursorclass = cursorclass
self.connect_timeout = connect_timeout self.connect_timeout = connect_timeout
self._connect() self._connect()
self.set_charset_set(charset)
self.messages = [] self.messages = []
self.set_charset(charset)
self.encoders = encoders self.encoders = encoders
self.decoders = conv self.decoders = conv
self.field_decoders = field_decoders
self._affected_rows = 0 self._affected_rows = 0
self.host_info = "Not connected" self.host_info = "Not connected"
self.autocommit(False) self.autocommit(False)
if sql_mode is not None: if sql_mode is not None:
c = self.cursor() c = self.cursor()
c.execute("SET sql_mode=%s", (sql_mode,)) c.execute("SET sql_mode=%s", (sql_mode,))
@@ -506,21 +512,20 @@ class Connection(object):
if init_command is not None: if init_command is not None:
c = self.cursor() c = self.cursor()
c.execute(init_command) c.execute(init_command)
self.commit() self.commit()
def close(self): def close(self):
''' Send the quit message and close the socket ''' ''' Send the quit message and close the socket '''
try: if self.socket:
send_data = struct.pack('<i',1) + COM_QUIT send_data = struct.pack('<i',1) + COM_QUIT
sock = self.socket self.socket.send(send_data)
sock.send(send_data) self.socket.close()
sock.close() self.socket = None
except: else:
exc,value,tb = sys.exc_info() self.errorhandler(None, InterfaceError, "(0, '')")
self.errorhandler(None, exc, value)
def autocommit(self, value): def autocommit(self, value):
''' Set whether or not to commit after every execute() ''' ''' Set whether or not to commit after every execute() '''
try: try:
@@ -560,7 +565,7 @@ class Connection(object):
def cursor(self): def cursor(self):
''' Create a new cursor to execute queries with ''' ''' Create a new cursor to execute queries with '''
return self.cursorclass(self) return self.cursorclass(self)
def __enter__(self): def __enter__(self):
''' Context manager that returns a Cursor ''' ''' Context manager that returns a Cursor '''
return self.cursor() return self.cursor()
@@ -577,7 +582,7 @@ class Connection(object):
self._execute_command(COM_QUERY, sql) self._execute_command(COM_QUERY, sql)
self._affected_rows = self._read_query_result() self._affected_rows = self._read_query_result()
return self._affected_rows return self._affected_rows
def next_result(self): def next_result(self):
self._affected_rows = self._read_query_result() self._affected_rows = self._read_query_result()
return self._affected_rows return self._affected_rows
@@ -595,7 +600,7 @@ class Connection(object):
return return
pkt = self.read_packet() pkt = self.read_packet()
return pkt.is_ok_packet() return pkt.is_ok_packet()
def ping(self, reconnect=True): def ping(self, reconnect=True):
''' Check if the server is alive ''' ''' Check if the server is alive '''
try: try:
@@ -612,14 +617,13 @@ class Connection(object):
pkt = self.read_packet() pkt = self.read_packet()
return pkt.is_ok_packet() return pkt.is_ok_packet()
def set_charset_set(self, charset): def set_charset(self, charset):
try: try:
sock = self.socket if charset:
if charset and self.charset != charset:
self._execute_command(COM_QUERY, "SET NAMES %s" % self._execute_command(COM_QUERY, "SET NAMES %s" %
self.escape(charset)) self.escape(charset))
self.read_packet() self.read_packet()
self.charset = charset self.charset = charset
except: except:
exc,value,tb = sys.exc_info() exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value) self.errorhandler(None, exc, value)
@@ -647,7 +651,7 @@ class Connection(object):
self._request_authentication() self._request_authentication()
except socket.error, e: except socket.error, e:
raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0])) raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0]))
def read_packet(self, packet_type=MysqlPacket): def read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network """Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results.""" and return a MysqlPacket type that represents the results."""
@@ -673,7 +677,9 @@ class Connection(object):
pckt_no = 0 pckt_no = 0
while len(buf) >= MAX_PACKET_LENGTH: while len(buf) >= MAX_PACKET_LENGTH:
header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+chr(pckt_no) header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+chr(pckt_no)
self.socket.send(header+buf[:MAX_PACKET_LENGTH]) send_data = header + buf[:MAX_PACKET_LENGTH]
self.socket.send(send_data)
if DEBUG: dump_packet(send_data)
buf = buf[MAX_PACKET_LENGTH:] buf = buf[MAX_PACKET_LENGTH:]
pckt_no += 1 pckt_no += 1
header = struct.pack('<i', len(buf))[:-1]+chr(pckt_no) header = struct.pack('<i', len(buf))[:-1]+chr(pckt_no)
@@ -683,13 +689,12 @@ class Connection(object):
#sock = self.socket #sock = self.socket
#sock.send(send_data) #sock.send(send_data)
if DEBUG: dump_packet(send_data) #
def _execute_command(self, command, sql): def _execute_command(self, command, sql):
self._send_command(command, sql) self._send_command(command, sql)
def _request_authentication(self): def _request_authentication(self):
sock = self.socket
self._send_authentication() self._send_authentication()
def _send_authentication(self): def _send_authentication(self):
@@ -700,9 +705,12 @@ class Connection(object):
if self.user is None: if self.user is None:
raise ValueError, "Did not specify a username" raise ValueError, "Did not specify a username"
data_init = (struct.pack('<i', self.client_flag)) \ charset_id = charset_by_name(self.charset).id
+ "\0\0\0\x01" + '\x08' + '\0'*23 self.user = self.user.encode(self.charset)
data_init = struct.pack('<i', self.client_flag) + "\0\0\0\x01" + \
chr(charset_id) + '\0'*23
next_packet = 1 next_packet = 1
@@ -722,13 +730,14 @@ class Connection(object):
data = data_init + self.user+"\0" + _scramble(self.password, self.salt) data = data_init + self.user+"\0" + _scramble(self.password, self.salt)
if self.db: if self.db:
data += self.db.encode(self.charset) + "\0" self.db = self.db.encode(self.charset)
data += self.db + "\0"
data = pack_int24(len(data)) + chr(next_packet) + data data = pack_int24(len(data)) + chr(next_packet) + data
next_packet += 2 next_packet += 2
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)
sock.send(data) sock.send(data)
auth_packet = MysqlPacket(sock) auth_packet = MysqlPacket(sock)
@@ -743,13 +752,13 @@ class Connection(object):
#raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table." #raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table."
data = _scramble_323(self.password, self.salt) + "\0" data = _scramble_323(self.password, self.salt) + "\0"
data = pack_int24(len(data)) + chr(next_packet) + data data = pack_int24(len(data)) + chr(next_packet) + data
sock.send(data) sock.send(data)
auth_packet = MysqlPacket(sock) auth_packet = MysqlPacket(sock)
auth_packet.check_error() auth_packet.check_error()
if DEBUG: auth_packet.dump() if DEBUG: auth_packet.dump()
# _mysql support # _mysql support
def thread_id(self): def thread_id(self):
return self.server_thread_id[0] return self.server_thread_id[0]
@@ -759,10 +768,10 @@ class Connection(object):
def get_host_info(self): def get_host_info(self):
return self.host_info return self.host_info
def get_proto_info(self): def get_proto_info(self):
return self.protocol_version return self.protocol_version
def _get_server_information(self): def _get_server_information(self):
sock = self.socket sock = self.socket
i = 0 i = 0
@@ -773,27 +782,28 @@ class Connection(object):
#packet_len = ord(data[i:i+1]) #packet_len = ord(data[i:i+1])
#i += 4 #i += 4
self.protocol_version = ord(data[i:i+1]) self.protocol_version = ord(data[i:i+1])
i += 1 i += 1
server_end = data.find("\0", i) server_end = data.find("\0", i)
self.server_version = data[i:server_end] self.server_version = data[i:server_end]
i = server_end + 1 i = server_end + 1
self.server_thread_id = struct.unpack('<h', data[i:i+2]) self.server_thread_id = struct.unpack('<h', data[i:i+2])
i += 4 i += 4
self.salt = data[i:i+8] self.salt = data[i:i+8]
i += 9 i += 9
if len(data) >= i + 1: if len(data) >= i + 1:
i += 1 i += 1
self.server_capabilities = struct.unpack('<h', data[i:i+2])[0] self.server_capabilities = struct.unpack('<h', data[i:i+2])[0]
i += 1 i += 1
self.server_language = ord(data[i:i+1]) self.server_language = ord(data[i:i+1])
self.server_charset = charset_by_id(self.server_language).name
i += 16
i += 16
if len(data) >= i+12-1: if len(data) >= i+12-1:
rest_salt = data[i:i+12] rest_salt = data[i:i+12]
self.salt += rest_salt self.salt += rest_salt
@@ -840,8 +850,8 @@ class MySQLResult(object):
def _read_ok_packet(self): def _read_ok_packet(self):
self.first_packet.advance(1) # field_count (always '0') self.first_packet.advance(1) # field_count (always '0')
self.affected_rows = self.first_packet.read_coded_length() self.affected_rows = self.first_packet.read_length_coded_binary()
self.insert_id = self.first_packet.read_coded_length() self.insert_id = self.first_packet.read_length_coded_binary()
self.server_status = struct.unpack('<H', self.first_packet.read(2))[0] self.server_status = struct.unpack('<H', self.first_packet.read(2))[0]
self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0] self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0]
self.message = self.first_packet.read_all() self.message = self.first_packet.read_all()
@@ -871,14 +881,7 @@ class MySQLResult(object):
converter = self.connection.decoders[field.type_code] converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter) if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary() data = packet.read_length_coded_string()
converted = None
if data != None:
converted = converter(data)
else:
converter = self.connection.field_decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None converted = None
if data != None: if data != None:
converted = converter(self.connection, field, data) converted = converter(self.connection, field, data)

View File

@@ -29,4 +29,4 @@ STRING = 254
GEOMETRY = 255 GEOMETRY = 255
CHAR = TINY CHAR = TINY
INTERVAL = ENUM INTERVAL = ENUM

View File

@@ -1,11 +1,9 @@
import re import re
import datetime import datetime
import time import time
import array
import struct
from times import Date, Time, TimeDelta, Timestamp from constants import FIELD_TYPE, FLAG
from constants import FIELD_TYPE from charset import charset_by_id
try: try:
set set
@@ -20,8 +18,16 @@ ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'} '\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset): def escape_item(val, charset):
if type(val) in [tuple, list, set]:
return escape_sequence(val, charset)
if type(val) is dict:
return escape_dict(val, charset)
encoder = encoders[type(val)] encoder = encoders[type(val)]
return encoder(val, charset) val = encoder(val)
if type(val) is str:
return val
val = val.encode(charset)
return val
def escape_dict(val, charset): def escape_dict(val, charset):
n = {} n = {}
@@ -37,99 +43,96 @@ def escape_sequence(val, charset):
n.append(quoted) n.append(quoted)
return tuple(n) return tuple(n)
def escape_bool(value, charset): def escape_set(val, charset):
return str(int(value)).encode(charset) val = map(lambda x: escape_item(x, charset), val)
return ','.join(val)
def escape_object(value, charset): def escape_bool(value):
return str(value).encode(charset) return str(int(value))
def escape_object(value):
return str(value)
escape_int = escape_long = escape_object escape_int = escape_long = escape_object
def escape_float(value, charset): def escape_float(value):
return ('%.15g' % value).encode(charset) return ('%.15g' % value)
def escape_string(value, charset): def escape_string(value):
r = ("'%s'" % ESCAPE_REGEX.sub( return ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value)) lambda match: ESCAPE_MAP.get(match.group(0)), value))
# TODO: make sure that encodings are handled correctly here.
# Since we may be dealing with binary data, the encoding
# routine below is commented out.
#if not charset is None:
# r = r.encode(charset)
return r
def escape_unicode(value, charset):
# pass None as the charset because we already encode it
return escape_string(value.encode(charset), None)
def escape_None(value, charset): def escape_unicode(value):
return 'NULL'.encode(charset) return escape_string(value)
def escape_timedelta(obj, charset): def escape_None(value):
return 'NULL'
def escape_timedelta(obj):
seconds = int(obj.seconds) % 60 seconds = int(obj.seconds) % 60
minutes = int(obj.seconds // 60) % 60 minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24 hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset) return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
def escape_time(obj, charset): def escape_time(obj):
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute), s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
int(obj.second)) int(obj.second))
if obj.microsecond: if obj.microsecond:
s += ".%f" % obj.microsecond s += ".%f" % obj.microsecond
return escape_string(s, charset) return escape_string(s)
def escape_datetime(obj, charset): def escape_datetime(obj):
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset) return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))
def escape_date(obj, charset): def escape_date(obj):
return escape_string(obj.strftime("%Y-%m-%d"), charset) return escape_string(obj.strftime("%Y-%m-%d"))
def escape_struct_time(obj, charset): def escape_struct_time(obj):
return escape_datetime(datetime.datetime(*obj[:6]), charset) return escape_datetime(datetime.datetime(*obj[:6]))
def convert_datetime(obj): def convert_datetime(connection, field, obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object: """Returns a DATETIME or TIMESTAMP column value as a datetime object:
>>> datetime_or_None('2007-02-25 23:06:20') >>> datetime_or_None('2007-02-25 23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20) datetime.datetime(2007, 2, 25, 23, 6, 20)
>>> datetime_or_None('2007-02-25T23:06:20') >>> datetime_or_None('2007-02-25T23:06:20')
datetime.datetime(2007, 2, 25, 23, 6, 20) datetime.datetime(2007, 2, 25, 23, 6, 20)
Illegal values are returned as None: Illegal values are returned as None:
>>> datetime_or_None('2007-02-31T23:06:20') is None >>> datetime_or_None('2007-02-31T23:06:20') is None
True True
>>> datetime_or_None('0000-00-00 00:00:00') is None >>> datetime_or_None('0000-00-00 00:00:00') is None
True True
""" """
if ' ' in obj: if ' ' in obj:
sep = ' ' sep = ' '
elif 'T' in obj: elif 'T' in obj:
sep = 'T' sep = 'T'
else: else:
return convert_date(obj) return convert_date(connection, field, obj)
try: try:
ymd, hms = obj.split(sep, 1) ymd, hms = obj.split(sep, 1)
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ]) return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
except ValueError: except ValueError:
return convert_date(obj) return convert_date(connection, field, obj)
def convert_timedelta(obj): def convert_timedelta(connection, field, obj):
"""Returns a TIME column as a timedelta object: """Returns a TIME column as a timedelta object:
>>> timedelta_or_None('25:06:17') >>> timedelta_or_None('25:06:17')
datetime.timedelta(1, 3977) datetime.timedelta(1, 3977)
>>> timedelta_or_None('-25:06:17') >>> timedelta_or_None('-25:06:17')
datetime.timedelta(-2, 83177) datetime.timedelta(-2, 83177)
Illegal values are returned as None: Illegal values are returned as None:
>>> timedelta_or_None('random crap') is None >>> timedelta_or_None('random crap') is None
True True
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function. be parsed correctly by this function.
@@ -147,23 +150,23 @@ def convert_timedelta(obj):
except ValueError: except ValueError:
return None return None
def convert_time(obj): def convert_time(connection, field, obj):
"""Returns a TIME column as a time object: """Returns a TIME column as a time object:
>>> time_or_None('15:06:17') >>> time_or_None('15:06:17')
datetime.time(15, 6, 17) datetime.time(15, 6, 17)
Illegal values are returned as None: Illegal values are returned as None:
>>> time_or_None('-25:06:17') is None >>> time_or_None('-25:06:17') is None
True True
>>> time_or_None('random crap') is None >>> time_or_None('random crap') is None
True True
Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
can accept values as (+|-)DD HH:MM:SS. The latter format will not can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function. be parsed correctly by this function.
Also note that MySQL's TIME column corresponds more closely to Also note that MySQL's TIME column corresponds more closely to
Python's timedelta and not time. However if you want TIME columns Python's timedelta and not time. However if you want TIME columns
to be treated as time-of-day and not a time offset, then you can to be treated as time-of-day and not a time offset, then you can
@@ -172,53 +175,54 @@ def convert_time(obj):
from math import modf from math import modf
try: try:
hour, minute, second = obj.split(':') hour, minute, second = obj.split(':')
return datetime.time(hour=int(hour), minute=int(minute), second=int(second), return datetime.time(hour=int(hour), minute=int(minute),
microsecond=int(modf(float(second))[0]*1000000)) second=int(second),
microsecond=int(modf(float(second))[0]*1000000))
except ValueError: except ValueError:
return None return None
def convert_date(obj): def convert_date(connection, field, obj):
"""Returns a DATE column as a date object: """Returns a DATE column as a date object:
>>> date_or_None('2007-02-26') >>> date_or_None('2007-02-26')
datetime.date(2007, 2, 26) datetime.date(2007, 2, 26)
Illegal values are returned as None: Illegal values are returned as None:
>>> date_or_None('2007-02-31') is None >>> date_or_None('2007-02-31') is None
True True
>>> date_or_None('0000-00-00') is None >>> date_or_None('0000-00-00') is None
True True
""" """
try: try:
return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
except ValueError: except ValueError:
return None return None
def convert_mysql_timestamp(timestamp): def convert_mysql_timestamp(connection, field, timestamp):
"""Convert a MySQL TIMESTAMP to a Timestamp object. """Convert a MySQL TIMESTAMP to a Timestamp object.
MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:
>>> mysql_timestamp_converter('2007-02-25 22:32:17') >>> mysql_timestamp_converter('2007-02-25 22:32:17')
datetime.datetime(2007, 2, 25, 22, 32, 17) datetime.datetime(2007, 2, 25, 22, 32, 17)
MySQL < 4.1 uses a big string of numbers: MySQL < 4.1 uses a big string of numbers:
>>> mysql_timestamp_converter('20070225223217') >>> mysql_timestamp_converter('20070225223217')
datetime.datetime(2007, 2, 25, 22, 32, 17) datetime.datetime(2007, 2, 25, 22, 32, 17)
Illegal values are returned as None: Illegal values are returned as None:
>>> mysql_timestamp_converter('2007-02-31 22:32:17') is None >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
True True
>>> mysql_timestamp_converter('00000000000000') is None >>> mysql_timestamp_converter('00000000000000') is None
True True
""" """
if timestamp[4] == '-': if timestamp[4] == '-':
return convert_datetime(timestamp) return convert_datetime(connection, field, timestamp)
timestamp += "0"*(14-len(timestamp)) # padding timestamp += "0"*(14-len(timestamp)) # padding
year, month, day, hour, minute, second = \ year, month, day, hour, minute, second = \
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
@@ -229,13 +233,38 @@ def convert_mysql_timestamp(timestamp):
return None return None
def convert_set(s): def convert_set(s):
# TODO: this may not be correct
return set(s.split(",")) return set(s.split(","))
def convert_bit(b): def convert_bit(connection, field, b):
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes #b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
return struct.unpack(">Q", b)[0] #return struct.unpack(">Q", b)[0]
#
# the snippet above is right, but MySQLdb doesn't process bits,
# so we shouldn't either
return b
def convert_characters(connection, field, data):
if field.flags & FLAG.SET:
return convert_set(data)
if field.flags & FLAG.BINARY:
return data
field_charset = charset_by_id(field.charsetnr).name
if connection.use_unicode:
data = data.decode(field_charset)
elif connection.charset != field_charset:
data = data.decode(field_charset)
data = data.encode(connection.charset)
return data
def convert_int(connection, field, data):
return int(data)
def convert_long(connection, field, data):
return long(data)
def convert_float(connection, field, data):
return float(data)
encoders = { encoders = {
bool: escape_bool, bool: escape_bool,
int: escape_int, int: escape_int,
@@ -257,21 +286,28 @@ encoders = {
decoders = { decoders = {
FIELD_TYPE.BIT: convert_bit, FIELD_TYPE.BIT: convert_bit,
FIELD_TYPE.TINY: int, FIELD_TYPE.TINY: convert_int,
FIELD_TYPE.SHORT: int, FIELD_TYPE.SHORT: convert_int,
FIELD_TYPE.LONG: long, FIELD_TYPE.LONG: convert_long,
FIELD_TYPE.FLOAT: float, FIELD_TYPE.FLOAT: convert_float,
FIELD_TYPE.DOUBLE: float, FIELD_TYPE.DOUBLE: convert_float,
FIELD_TYPE.DECIMAL: float, FIELD_TYPE.DECIMAL: convert_float,
FIELD_TYPE.NEWDECIMAL: float, FIELD_TYPE.NEWDECIMAL: convert_float,
FIELD_TYPE.LONGLONG: long, FIELD_TYPE.LONGLONG: convert_long,
FIELD_TYPE.INT24: int, FIELD_TYPE.INT24: convert_int,
FIELD_TYPE.YEAR: int, FIELD_TYPE.YEAR: convert_int,
FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.DATETIME: convert_datetime,
FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.TIME: convert_timedelta,
FIELD_TYPE.DATE: convert_date, FIELD_TYPE.DATE: convert_date,
FIELD_TYPE.SET: convert_set, FIELD_TYPE.SET: convert_set,
FIELD_TYPE.BLOB: convert_characters,
FIELD_TYPE.TINY_BLOB: convert_characters,
FIELD_TYPE.MEDIUM_BLOB: convert_characters,
FIELD_TYPE.LONG_BLOB: convert_characters,
FIELD_TYPE.STRING: convert_characters,
FIELD_TYPE.VAR_STRING: convert_characters,
FIELD_TYPE.VARCHAR: convert_characters,
#FIELD_TYPE.BLOB: str, #FIELD_TYPE.BLOB: str,
#FIELD_TYPE.STRING: str, #FIELD_TYPE.STRING: str,
#FIELD_TYPE.VAR_STRING: str, #FIELD_TYPE.VAR_STRING: str,
@@ -279,28 +315,13 @@ decoders = {
} }
conversions = decoders # for MySQLdb compatibility conversions = decoders # for MySQLdb compatibility
def decode_characters(connection, field, data):
if field.charsetnr == 63 or not connection.use_unicode:
# binary data, leave it alone
return data
return data.decode(connection.charset)
# These take a field instance rather than just the data.
field_decoders = {
FIELD_TYPE.BLOB: decode_characters,
FIELD_TYPE.TINY_BLOB: decode_characters,
FIELD_TYPE.MEDIUM_BLOB: decode_characters,
FIELD_TYPE.LONG_BLOB: decode_characters,
FIELD_TYPE.STRING: decode_characters,
FIELD_TYPE.VAR_STRING: decode_characters,
FIELD_TYPE.VARCHAR: decode_characters,
}
try: try:
# python version > 2.3 # python version > 2.3
from decimal import Decimal from decimal import Decimal
decoders[FIELD_TYPE.DECIMAL] = Decimal def convert_decimal(connection, field, data):
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal return Decimal(data)
decoders[FIELD_TYPE.DECIMAL] = convert_decimal
decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
def escape_decimal(obj, charset): def escape_decimal(obj, charset):
return unicode(obj).encode(charset) return unicode(obj).encode(charset)

View File

@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import struct import struct
import re import re
@@ -52,23 +53,23 @@ class Cursor(object):
if not self.connection: if not self.connection:
self.errorhandler(self, ProgrammingError, "cursor closed") self.errorhandler(self, ProgrammingError, "cursor closed")
return self.connection return self.connection
def _check_executed(self): def _check_executed(self):
if not self._executed: if not self._executed:
self.errorhandler(self, ProgrammingError, "execute() first") self.errorhandler(self, ProgrammingError, "execute() first")
def setinputsizes(self, *args): def setinputsizes(self, *args):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""
def setoutputsizes(self, *args): def setoutputsizes(self, *args):
"""Does nothing, required by DB API.""" """Does nothing, required by DB API."""
def nextset(self): def nextset(self):
''' Get the next query set ''' ''' Get the next query set '''
if self._executed: if self._executed:
self.fetchall() self.fetchall()
del self.messages[:] del self.messages[:]
if not self._has_next: if not self._has_next:
return None return None
connection = self._get_db() connection = self._get_db()
@@ -79,11 +80,11 @@ class Cursor(object):
def execute(self, query, args=None): def execute(self, query, args=None):
''' Execute a query ''' ''' Execute a query '''
from sys import exc_info from sys import exc_info
conn = self._get_db() conn = self._get_db()
charset = conn.charset charset = conn.charset
del self.messages[:] del self.messages[:]
# this ordering is good because conn.escape() returns # this ordering is good because conn.escape() returns
# an encoded string. # an encoded string.
if isinstance(query, unicode): if isinstance(query, unicode):
@@ -91,7 +92,7 @@ class Cursor(object):
if args is not None: if args is not None:
query = query % conn.escape(args) query = query % conn.escape(args)
result = 0 result = 0
try: try:
result = self._query(query) result = self._query(query)
@@ -103,7 +104,7 @@ class Cursor(object):
self._executed = query self._executed = query
return result return result
def executemany(self, query, args): def executemany(self, query, args):
''' Run several data against one query ''' ''' Run several data against one query '''
del self.messages[:] del self.messages[:]
@@ -113,30 +114,66 @@ class Cursor(object):
charset = conn.charset charset = conn.charset
if isinstance(query, unicode): if isinstance(query, unicode):
query = query.encode(charset) query = query.encode(charset)
self.rowcount = sum([ self.execute(query, arg) for arg in args ]) self.rowcount = sum([ self.execute(query, arg) for arg in args ])
return self.rowcount return self.rowcount
def callproc(self, procname, args=()): def callproc(self, procname, args=()):
''' Call a stored procedure. Take care to ensure that procname is """Execute stored procedure procname with args
properly escaped. '''
if not isinstance(args, tuple):
args = (args,)
argstr = ("%s," * len(args))[:-1] procname -- string, name of procedure to execute on server
return self.execute("CALL `%s`(%s)" % (procname, argstr), args) args -- Sequence of parameters to use with procedure
Returns the original args.
Compatibility warning: PEP-249 specifies that any modified
parameters must be returned. This is currently impossible
as they are only available by storing them in a server
variable and then retrieved by a query. Since stored
procedures return zero or more result sets, there is no
reliable way to get at OUT or INOUT parameters via callproc.
The server variables are named @_procname_n, where procname
is the parameter above and n is the position of the parameter
(from zero). Once all result sets generated by the procedure
have been fetched, you can issue a SELECT @_procname_0, ...
query using .execute() to get any OUT or INOUT values.
Compatibility warning: The act of calling a stored procedure
itself creates an empty result set. This appears after any
result sets generated by the procedure. This is non-standard
behavior with respect to the DB-API. Be sure to use nextset()
to advance through all result sets; otherwise you may get
disconnected.
"""
conn = self._get_db()
for index, arg in enumerate(args):
q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
if isinstance(q, unicode):
q = q.encode(conn.charset)
self._query(q)
self.nextset()
q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if isinstance(q, unicode):
q = q.encode(conn.charset)
self._query(q)
self._executed = q
return args
def fetchone(self): def fetchone(self):
''' Fetch the next row ''' ''' Fetch the next row '''
self._check_executed() self._check_executed()
if self._rows is None or self.rownumber >= len(self._rows): if self._rows is None or self.rownumber >= len(self._rows):
return None return None
result = self._rows[self.rownumber] result = self._rows[self.rownumber]
self.rownumber += 1 self.rownumber += 1
return result return result
def fetchmany(self, size=None): def fetchmany(self, size=None):
''' Fetch several rows ''' ''' Fetch several rows '''
self._check_executed() self._check_executed()
@@ -158,15 +195,15 @@ class Cursor(object):
result = self._rows result = self._rows
self.rownumber = len(self._rows) self.rownumber = len(self._rows)
return result return result
def scroll(self, value, mode='relative'): def scroll(self, value, mode='relative'):
self._check_executed()
if mode == 'relative': if mode == 'relative':
r = self.rownumber + value r = self.rownumber + value
elif mode == 'absolute': elif mode == 'absolute':
r = value r = value
else: else:
self.errorhandler(self, ProgrammingError, self.errorhandler(self, ProgrammingError,
"unknown scroll mode %s" % mode) "unknown scroll mode %s" % mode)
if r < 0 or r >= len(self._rows): if r < 0 or r >= len(self._rows):
@@ -179,23 +216,23 @@ class Cursor(object):
conn.query(q) conn.query(q)
self._do_get_result() self._do_get_result()
return self.rowcount return self.rowcount
def _do_get_result(self): def _do_get_result(self):
conn = self._get_db() conn = self._get_db()
self.rowcount = conn._result.affected_rows self.rowcount = conn._result.affected_rows
self.rownumber = 0 self.rownumber = 0
self.description = conn._result.description self.description = conn._result.description
self.lastrowid = conn._result.insert_id self.lastrowid = conn._result.insert_id
self._rows = conn._result.rows self._rows = conn._result.rows
self._has_next = conn._result.has_next self._has_next = conn._result.has_next
conn._result = None conn._result = None
def __iter__(self): def __iter__(self):
self._check_executed() self._check_executed()
result = self.rownumber and self._rows[self.rownumber:] or self._rows result = self.rownumber and self._rows[self.rownumber:] or self._rows
return iter(result) return iter(result)
Warning = Warning Warning = Warning
Error = Error Error = Error
InterfaceError = InterfaceError InterfaceError = InterfaceError

View File

@@ -3,7 +3,8 @@ import unittest
class PyMySQLTestCase(unittest.TestCase): class PyMySQLTestCase(unittest.TestCase):
databases = [ databases = [
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"}, {"host":"localhost","user":"root",
"passwd":"","db":"test_pymysql", "use_unicode": True},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}] {"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self): def setUp(self):

View File

@@ -15,7 +15,8 @@ class TestConversion(base.PyMySQLTestCase):
c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
r = c.fetchone() r = c.fetchone()
self.assertEqual(v[:8], r[:8]) self.assertEqual("\x01", r[0])
self.assertEqual(v[1:8], r[1:8])
# mysql throws away microseconds so we need to check datetimes # mysql throws away microseconds so we need to check datetimes
# specially. additionally times are turned into timedeltas. # specially. additionally times are turned into timedeltas.
self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8]) self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8])

View File

@@ -104,8 +104,8 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8")) self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8")) self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8")) self.assertEqual('1', pymysql.converters.escape_object(1))
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8")) self.assertEqual('1', pymysql.converters.escape_object(1L))
def test_issue_15(self): def test_issue_15(self):
""" query should be expanded before perform character encoding """ """ query should be expanded before perform character encoding """