diff --git a/pymysql/charset.py b/pymysql/charset.py index 6a5907a..10a91bd 100644 --- a/pymysql/charset.py +++ b/pymysql/charset.py @@ -1,5 +1,3 @@ - - MBLENGTH = { 8:1, 33:3, @@ -7,4 +5,170 @@ MBLENGTH = { 91:2 } +class Charset: + def __init__(self, id, name, collation, is_default): + self.id, self.name, self.collation = id, name, collation + self.is_default = is_default == 'Yes' + +class Charsets: + def __init__(self): + self._by_id = {} + + def add(self, c): + self._by_id[c.id] = c + + def by_id(self, id): + return self._by_id[id] + + def by_name(self, name): + for c in self._by_id.values(): + if c.name == name and c.is_default: + return c + +_charsets = Charsets() +""" +Generated with: + +mysql -N -s -e "select id, character_set_name, collation_name, is_default +from information_schema.collations order by id;" | python -c "import sys +for l in sys.stdin.readlines(): + id, name, collation, is_default = l.split(chr(9)) + print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \ + % (id, name, collation, is_default.strip()) +" + +""" +_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes')) +_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', '')) +_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes')) +_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes')) +_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', '')) +_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes')) +_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes')) +_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes')) +_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes')) +_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes')) +_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes')) +_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes')) +_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes')) +_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', '')) +_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', '')) +_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes')) +_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes')) +_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes')) +_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', '')) +_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', '')) +_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes')) +_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', '')) +_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes')) +_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes')) +_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes')) +_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', '')) +_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes')) +_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', '')) +_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes')) +_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', '')) +_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes')) +_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes')) +_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', '')) +_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes')) +_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes')) +_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes')) +_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes')) +_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes')) +_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes')) +_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes')) +_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', '')) +_charsets.add(Charset(43, 'macce', 'macce_bin', '')) +_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', '')) +_charsets.add(Charset(47, 'latin1', 'latin1_bin', '')) +_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', '')) +_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', '')) +_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', '')) +_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes')) +_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', '')) +_charsets.add(Charset(53, 'macroman', 'macroman_bin', '')) +_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes')) +_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', '')) +_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes')) +_charsets.add(Charset(63, 'binary', 'binary', 'Yes')) +_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', '')) +_charsets.add(Charset(65, 'ascii', 'ascii_bin', '')) +_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', '')) +_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', '')) +_charsets.add(Charset(68, 'cp866', 'cp866_bin', '')) +_charsets.add(Charset(69, 'dec8', 'dec8_bin', '')) +_charsets.add(Charset(70, 'greek', 'greek_bin', '')) +_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', '')) +_charsets.add(Charset(72, 'hp8', 'hp8_bin', '')) +_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', '')) +_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', '')) +_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', '')) +_charsets.add(Charset(77, 'latin2', 'latin2_bin', '')) +_charsets.add(Charset(78, 'latin5', 'latin5_bin', '')) +_charsets.add(Charset(79, 'latin7', 'latin7_bin', '')) +_charsets.add(Charset(80, 'cp850', 'cp850_bin', '')) +_charsets.add(Charset(81, 'cp852', 'cp852_bin', '')) +_charsets.add(Charset(82, 'swe7', 'swe7_bin', '')) +_charsets.add(Charset(83, 'utf8', 'utf8_bin', '')) +_charsets.add(Charset(84, 'big5', 'big5_bin', '')) +_charsets.add(Charset(85, 'euckr', 'euckr_bin', '')) +_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', '')) +_charsets.add(Charset(87, 'gbk', 'gbk_bin', '')) +_charsets.add(Charset(88, 'sjis', 'sjis_bin', '')) +_charsets.add(Charset(89, 'tis620', 'tis620_bin', '')) +_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', '')) +_charsets.add(Charset(91, 'ujis', 'ujis_bin', '')) +_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes')) +_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', '')) +_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', '')) +_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes')) +_charsets.add(Charset(96, 'cp932', 'cp932_bin', '')) +_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes')) +_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', '')) +_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', '')) +_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', '')) +_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', '')) +_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', '')) +_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', '')) +_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', '')) +_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', '')) +_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', '')) +_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', '')) +_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', '')) +_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', '')) +_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', '')) +_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', '')) +_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', '')) +_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', '')) +_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', '')) +_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', '')) +_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', '')) +_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', '')) +_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', '')) +_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', '')) +_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', '')) +_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', '')) +_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', '')) +_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', '')) +_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', '')) +_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', '')) +_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', '')) +_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', '')) +_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', '')) +_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', '')) +_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', '')) +_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', '')) +_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', '')) +_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', '')) +_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', '')) +_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', '')) +_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', '')) +_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', '')) + +def charset_by_name(name): + return _charsets.by_name(name) + +def charset_by_id(id): + return _charsets.by_id(id) diff --git a/pymysql/connections.py b/pymysql/connections.py index 82562be..2228849 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -1,8 +1,6 @@ # Python implementation of the MySQL client-server protocol # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol -import re - try: import hashlib sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs) @@ -22,13 +20,13 @@ try: except ImportError: import StringIO -from charset import MBLENGTH +from charset import MBLENGTH, charset_by_name, charset_by_id from cursors import Cursor -from constants import FIELD_TYPE +from constants import FIELD_TYPE, FLAG from constants import SERVER_STATUS from constants.CLIENT import * from constants.COMMAND import * -from converters import escape_item, encoders, decoders, field_decoders +from converters import escape_item, encoders, decoders from err import raise_mysql_exception, Warning, Error, \ InterfaceError, DataError, DatabaseError, OperationalError, \ IntegrityError, InternalError, NotSupportedError, ProgrammingError @@ -64,7 +62,8 @@ def dump_packet(data): dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0] for d in dump_data: print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \ - ' ' * (16 - len(d)) + ' ' * 2 + ' '.join(map(lambda x:"%s" % is_ascii(x), d)) + ' ' * (16 - len(d)) + ' ' * 2 + \ + ' '.join(map(lambda x:"%s" % is_ascii(x), d)) print "-" * 88 print "" @@ -84,7 +83,8 @@ def _my_crypt(message1, message2): length = len(message1) result = struct.pack('B', length) for i in xrange(length): - x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0]) + x = (struct.unpack('B', message1[i:i+1])[0] ^ \ + struct.unpack('B', message2[i:i+1])[0]) result += struct.pack('B', x) return result @@ -161,9 +161,10 @@ def unpack_int64(n): (struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56) def defaulterrorhandler(connection, cursor, errorclass, errorvalue): - raise err = errorclass, errorvalue - + if DEBUG: + raise + if cursor: cursor.messages.append(err) else: @@ -271,8 +272,8 @@ class MysqlPacket(object): """ return self.__data[position:(position+length)] - def read_coded_length(self): - """Read a 'Length Coded' number from the data buffer. + def read_length_coded_binary(self): + """Read a 'Length Coded Binary' number from the data buffer. Length coded numbers can be anywhere from 1 to 9 bytes depending on the value of the first byte. @@ -290,16 +291,17 @@ class MysqlPacket(object): # TODO: what was 'longlong'? confirm it wasn't used? return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) - def read_length_coded_binary(self): - """Read a 'Length Coded Binary' from the data buffer. + def read_length_coded_string(self): + """Read a 'Length Coded String' from the data buffer. - A 'Length Coded Binary' consists first of a length coded + A 'Length Coded String' consists first of a length coded (unsigned, positive) integer represented in 1-9 bytes followed by that many bytes of binary data. (For example "cat" would be "3cat".) """ - length = self.read_coded_length() - if length: - return self.read(length) + length = self.read_length_coded_binary() + if length is None: + return None + return self.read(length) def is_ok_packet(self): return ord(self.get_bytes(0)) == 0 @@ -342,19 +344,17 @@ class FieldDescriptorPacket(MysqlPacket): This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). """ - self.catalog = self.read_length_coded_binary() - self.db = self.read_length_coded_binary() - self.table_name = self.read_length_coded_binary() - self.org_table = self.read_length_coded_binary() - self.name = self.read_length_coded_binary() - self.org_name = self.read_length_coded_binary() + self.catalog = self.read_length_coded_string() + self.db = self.read_length_coded_string() + self.table_name = self.read_length_coded_string() + self.org_table = self.read_length_coded_string() + self.name = self.read_length_coded_string() + self.org_name = self.read_length_coded_string() self.advance(1) # non-null filler - self.charsetnr = struct.unpack('= MAX_PACKET_LENGTH: header = struct.pack('= i + 1: i += 1 - + self.server_capabilities = struct.unpack('= i+12-1: rest_salt = data[i:i+12] self.salt += rest_salt @@ -840,8 +850,8 @@ class MySQLResult(object): def _read_ok_packet(self): self.first_packet.advance(1) # field_count (always '0') - self.affected_rows = self.first_packet.read_coded_length() - self.insert_id = self.first_packet.read_coded_length() + self.affected_rows = self.first_packet.read_length_coded_binary() + self.insert_id = self.first_packet.read_length_coded_binary() self.server_status = struct.unpack('>> datetime_or_None('2007-02-25 23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) >>> datetime_or_None('2007-02-25T23:06:20') datetime.datetime(2007, 2, 25, 23, 6, 20) - + Illegal values are returned as None: - + >>> datetime_or_None('2007-02-31T23:06:20') is None True >>> datetime_or_None('0000-00-00 00:00:00') is None True - + """ if ' ' in obj: sep = ' ' elif 'T' in obj: sep = 'T' else: - return convert_date(obj) + return convert_date(connection, field, obj) try: ymd, hms = obj.split(sep, 1) return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ]) except ValueError: - return convert_date(obj) + return convert_date(connection, field, obj) -def convert_timedelta(obj): +def convert_timedelta(connection, field, obj): """Returns a TIME column as a timedelta object: >>> timedelta_or_None('25:06:17') datetime.timedelta(1, 3977) >>> timedelta_or_None('-25:06:17') datetime.timedelta(-2, 83177) - + Illegal values are returned as None: - + >>> timedelta_or_None('random crap') is None True - + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not be parsed correctly by this function. @@ -147,23 +150,23 @@ def convert_timedelta(obj): except ValueError: return None -def convert_time(obj): +def convert_time(connection, field, obj): """Returns a TIME column as a time object: >>> time_or_None('15:06:17') datetime.time(15, 6, 17) - + Illegal values are returned as None: - + >>> time_or_None('-25:06:17') is None True >>> time_or_None('random crap') is None True - + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but can accept values as (+|-)DD HH:MM:SS. The latter format will not be parsed correctly by this function. - + Also note that MySQL's TIME column corresponds more closely to Python's timedelta and not time. However if you want TIME columns to be treated as time-of-day and not a time offset, then you can @@ -172,53 +175,54 @@ def convert_time(obj): from math import modf try: hour, minute, second = obj.split(':') - return datetime.time(hour=int(hour), minute=int(minute), second=int(second), - microsecond=int(modf(float(second))[0]*1000000)) + return datetime.time(hour=int(hour), minute=int(minute), + second=int(second), + microsecond=int(modf(float(second))[0]*1000000)) except ValueError: return None -def convert_date(obj): +def convert_date(connection, field, obj): """Returns a DATE column as a date object: >>> date_or_None('2007-02-26') datetime.date(2007, 2, 26) - + Illegal values are returned as None: - + >>> date_or_None('2007-02-31') is None True >>> date_or_None('0000-00-00') is None True - + """ try: return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) except ValueError: return None -def convert_mysql_timestamp(timestamp): +def convert_mysql_timestamp(connection, field, timestamp): """Convert a MySQL TIMESTAMP to a Timestamp object. MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: - + >>> mysql_timestamp_converter('2007-02-25 22:32:17') datetime.datetime(2007, 2, 25, 22, 32, 17) - + MySQL < 4.1 uses a big string of numbers: - + >>> mysql_timestamp_converter('20070225223217') datetime.datetime(2007, 2, 25, 22, 32, 17) - + Illegal values are returned as None: - + >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None True >>> mysql_timestamp_converter('00000000000000') is None True - + """ if timestamp[4] == '-': - return convert_datetime(timestamp) + return convert_datetime(connection, field, timestamp) timestamp += "0"*(14-len(timestamp)) # padding year, month, day, hour, minute, second = \ int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ @@ -229,13 +233,38 @@ def convert_mysql_timestamp(timestamp): return None def convert_set(s): - # TODO: this may not be correct return set(s.split(",")) -def convert_bit(b): - b = "\x00" * (8 - len(b)) + b # pad w/ zeroes - return struct.unpack(">Q", b)[0] - +def convert_bit(connection, field, b): + #b = "\x00" * (8 - len(b)) + b # pad w/ zeroes + #return struct.unpack(">Q", b)[0] + # + # the snippet above is right, but MySQLdb doesn't process bits, + # so we shouldn't either + return b + +def convert_characters(connection, field, data): + if field.flags & FLAG.SET: + return convert_set(data) + if field.flags & FLAG.BINARY: + return data + field_charset = charset_by_id(field.charsetnr).name + if connection.use_unicode: + data = data.decode(field_charset) + elif connection.charset != field_charset: + data = data.decode(field_charset) + data = data.encode(connection.charset) + return data + +def convert_int(connection, field, data): + return int(data) + +def convert_long(connection, field, data): + return long(data) + +def convert_float(connection, field, data): + return float(data) + encoders = { bool: escape_bool, int: escape_int, @@ -257,21 +286,28 @@ encoders = { decoders = { FIELD_TYPE.BIT: convert_bit, - FIELD_TYPE.TINY: int, - FIELD_TYPE.SHORT: int, - FIELD_TYPE.LONG: long, - FIELD_TYPE.FLOAT: float, - FIELD_TYPE.DOUBLE: float, - FIELD_TYPE.DECIMAL: float, - FIELD_TYPE.NEWDECIMAL: float, - FIELD_TYPE.LONGLONG: long, - FIELD_TYPE.INT24: int, - FIELD_TYPE.YEAR: int, + FIELD_TYPE.TINY: convert_int, + FIELD_TYPE.SHORT: convert_int, + FIELD_TYPE.LONG: convert_long, + FIELD_TYPE.FLOAT: convert_float, + FIELD_TYPE.DOUBLE: convert_float, + FIELD_TYPE.DECIMAL: convert_float, + FIELD_TYPE.NEWDECIMAL: convert_float, + FIELD_TYPE.LONGLONG: convert_long, + FIELD_TYPE.INT24: convert_int, + FIELD_TYPE.YEAR: convert_int, FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.DATE: convert_date, FIELD_TYPE.SET: convert_set, + FIELD_TYPE.BLOB: convert_characters, + FIELD_TYPE.TINY_BLOB: convert_characters, + FIELD_TYPE.MEDIUM_BLOB: convert_characters, + FIELD_TYPE.LONG_BLOB: convert_characters, + FIELD_TYPE.STRING: convert_characters, + FIELD_TYPE.VAR_STRING: convert_characters, + FIELD_TYPE.VARCHAR: convert_characters, #FIELD_TYPE.BLOB: str, #FIELD_TYPE.STRING: str, #FIELD_TYPE.VAR_STRING: str, @@ -279,28 +315,13 @@ decoders = { } conversions = decoders # for MySQLdb compatibility -def decode_characters(connection, field, data): - if field.charsetnr == 63 or not connection.use_unicode: - # binary data, leave it alone - return data - return data.decode(connection.charset) - -# These take a field instance rather than just the data. -field_decoders = { - FIELD_TYPE.BLOB: decode_characters, - FIELD_TYPE.TINY_BLOB: decode_characters, - FIELD_TYPE.MEDIUM_BLOB: decode_characters, - FIELD_TYPE.LONG_BLOB: decode_characters, - FIELD_TYPE.STRING: decode_characters, - FIELD_TYPE.VAR_STRING: decode_characters, - FIELD_TYPE.VARCHAR: decode_characters, -} - try: # python version > 2.3 from decimal import Decimal - decoders[FIELD_TYPE.DECIMAL] = Decimal - decoders[FIELD_TYPE.NEWDECIMAL] = Decimal + def convert_decimal(connection, field, data): + return Decimal(data) + decoders[FIELD_TYPE.DECIMAL] = convert_decimal + decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal def escape_decimal(obj, charset): return unicode(obj).encode(charset) diff --git a/pymysql/cursors.py b/pymysql/cursors.py index f653b87..8804579 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- import struct import re @@ -52,23 +53,23 @@ class Cursor(object): if not self.connection: self.errorhandler(self, ProgrammingError, "cursor closed") return self.connection - + def _check_executed(self): if not self._executed: self.errorhandler(self, ProgrammingError, "execute() first") - + def setinputsizes(self, *args): """Does nothing, required by DB API.""" - + def setoutputsizes(self, *args): """Does nothing, required by DB API.""" - + def nextset(self): ''' Get the next query set ''' if self._executed: self.fetchall() del self.messages[:] - + if not self._has_next: return None connection = self._get_db() @@ -79,11 +80,11 @@ class Cursor(object): def execute(self, query, args=None): ''' Execute a query ''' from sys import exc_info - + conn = self._get_db() charset = conn.charset del self.messages[:] - + # this ordering is good because conn.escape() returns # an encoded string. if isinstance(query, unicode): @@ -91,7 +92,7 @@ class Cursor(object): if args is not None: query = query % conn.escape(args) - + result = 0 try: result = self._query(query) @@ -103,7 +104,7 @@ class Cursor(object): self._executed = query return result - + def executemany(self, query, args): ''' Run several data against one query ''' del self.messages[:] @@ -113,30 +114,66 @@ class Cursor(object): charset = conn.charset if isinstance(query, unicode): query = query.encode(charset) - + self.rowcount = sum([ self.execute(query, arg) for arg in args ]) return self.rowcount - - + + def callproc(self, procname, args=()): - ''' Call a stored procedure. Take care to ensure that procname is - properly escaped. ''' - if not isinstance(args, tuple): - args = (args,) + """Execute stored procedure procname with args - argstr = ("%s," * len(args))[:-1] + procname -- string, name of procedure to execute on server - return self.execute("CALL `%s`(%s)" % (procname, argstr), args) + args -- Sequence of parameters to use with procedure + + Returns the original args. + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. + The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + """ + conn = self._get_db() + for index, arg in enumerate(args): + q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg)) + if isinstance(q, unicode): + q = q.encode(conn.charset) + self._query(q) + self.nextset() + + q = "CALL %s(%s)" % (procname, + ','.join(['@_%s_%d' % (procname, i) + for i in range(len(args))])) + if isinstance(q, unicode): + q = q.encode(conn.charset) + self._query(q) + self._executed = q + + return args def fetchone(self): ''' Fetch the next row ''' - self._check_executed() + self._check_executed() if self._rows is None or self.rownumber >= len(self._rows): return None result = self._rows[self.rownumber] self.rownumber += 1 return result - + def fetchmany(self, size=None): ''' Fetch several rows ''' self._check_executed() @@ -158,15 +195,15 @@ class Cursor(object): result = self._rows self.rownumber = len(self._rows) return result - + def scroll(self, value, mode='relative'): - + self._check_executed() if mode == 'relative': r = self.rownumber + value elif mode == 'absolute': r = value else: - self.errorhandler(self, ProgrammingError, + self.errorhandler(self, ProgrammingError, "unknown scroll mode %s" % mode) if r < 0 or r >= len(self._rows): @@ -179,23 +216,23 @@ class Cursor(object): conn.query(q) self._do_get_result() return self.rowcount - + def _do_get_result(self): conn = self._get_db() self.rowcount = conn._result.affected_rows - + self.rownumber = 0 self.description = conn._result.description self.lastrowid = conn._result.insert_id self._rows = conn._result.rows self._has_next = conn._result.has_next conn._result = None - + def __iter__(self): self._check_executed() result = self.rownumber and self._rows[self.rownumber:] or self._rows return iter(result) - + Warning = Warning Error = Error InterfaceError = InterfaceError diff --git a/pymysql/tests/base.py b/pymysql/tests/base.py index a1c0a59..3e97dad 100644 --- a/pymysql/tests/base.py +++ b/pymysql/tests/base.py @@ -3,7 +3,8 @@ import unittest class PyMySQLTestCase(unittest.TestCase): databases = [ - {"host":"localhost","user":"root","passwd":"","db":"test_pymysql"}, + {"host":"localhost","user":"root", + "passwd":"","db":"test_pymysql", "use_unicode": True}, {"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}] def setUp(self): diff --git a/pymysql/tests/test_basic.py b/pymysql/tests/test_basic.py index eb223ef..3204eb7 100644 --- a/pymysql/tests/test_basic.py +++ b/pymysql/tests/test_basic.py @@ -15,7 +15,8 @@ class TestConversion(base.PyMySQLTestCase): c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") r = c.fetchone() - self.assertEqual(v[:8], r[:8]) + self.assertEqual("\x01", r[0]) + self.assertEqual(v[1:8], r[1:8]) # mysql throws away microseconds so we need to check datetimes # specially. additionally times are turned into timedeltas. self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8]) diff --git a/pymysql/tests/test_issues.py b/pymysql/tests/test_issues.py index 3af30ec..610b663 100644 --- a/pymysql/tests/test_issues.py +++ b/pymysql/tests/test_issues.py @@ -104,8 +104,8 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") self.assertEqual('1', pymysql.converters.escape_item(1, "utf8")) self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8")) - self.assertEqual('1', pymysql.converters.escape_object(1, "utf8")) - self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8")) + self.assertEqual('1', pymysql.converters.escape_object(1)) + self.assertEqual('1', pymysql.converters.escape_object(1L)) def test_issue_15(self): """ query should be expanded before perform character encoding """