cleaned up charset and unicode functionality, aligned BIT type handling with MySQLdb, disambiguated length coded binary and strings, raised exception after connection is close()'d
This commit is contained in:
		| @@ -1,5 +1,3 @@ | ||||
|  | ||||
|  | ||||
| MBLENGTH = { | ||||
|         8:1, | ||||
|         33:3, | ||||
| @@ -7,4 +5,170 @@ MBLENGTH = { | ||||
|         91:2 | ||||
|         } | ||||
|  | ||||
| class Charset: | ||||
|     def __init__(self, id, name, collation, is_default): | ||||
|         self.id, self.name, self.collation = id, name, collation | ||||
|         self.is_default = is_default == 'Yes' | ||||
|  | ||||
| class Charsets: | ||||
|     def __init__(self): | ||||
|         self._by_id = {} | ||||
|  | ||||
|     def add(self, c): | ||||
|         self._by_id[c.id] = c | ||||
|  | ||||
|     def by_id(self, id): | ||||
|         return self._by_id[id] | ||||
|  | ||||
|     def by_name(self, name): | ||||
|         for c in self._by_id.values(): | ||||
|             if c.name == name and c.is_default: | ||||
|                 return c | ||||
|  | ||||
| _charsets = Charsets() | ||||
| """ | ||||
| Generated with: | ||||
|  | ||||
| mysql -N -s -e "select id, character_set_name, collation_name, is_default | ||||
| from information_schema.collations order by id;" | python -c "import sys | ||||
| for l in sys.stdin.readlines(): | ||||
|         id, name, collation, is_default  = l.split(chr(9)) | ||||
|         print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \ | ||||
|                 % (id, name, collation, is_default.strip()) | ||||
| " | ||||
|  | ||||
| """ | ||||
| _charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes')) | ||||
| _charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', '')) | ||||
| _charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes')) | ||||
| _charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', '')) | ||||
| _charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes')) | ||||
| _charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes')) | ||||
| _charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes')) | ||||
| _charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes')) | ||||
| _charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes')) | ||||
| _charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', '')) | ||||
| _charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', '')) | ||||
| _charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes')) | ||||
| _charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes')) | ||||
| _charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', '')) | ||||
| _charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', '')) | ||||
| _charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', '')) | ||||
| _charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes')) | ||||
| _charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', '')) | ||||
| _charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes')) | ||||
| _charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', '')) | ||||
| _charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes')) | ||||
| _charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', '')) | ||||
| _charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', '')) | ||||
| _charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(42, 'latin7', 'latin7_general_cs', '')) | ||||
| _charsets.add(Charset(43, 'macce', 'macce_bin', '')) | ||||
| _charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', '')) | ||||
| _charsets.add(Charset(47, 'latin1', 'latin1_bin', '')) | ||||
| _charsets.add(Charset(48, 'latin1', 'latin1_general_ci', '')) | ||||
| _charsets.add(Charset(49, 'latin1', 'latin1_general_cs', '')) | ||||
| _charsets.add(Charset(50, 'cp1251', 'cp1251_bin', '')) | ||||
| _charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', '')) | ||||
| _charsets.add(Charset(53, 'macroman', 'macroman_bin', '')) | ||||
| _charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(58, 'cp1257', 'cp1257_bin', '')) | ||||
| _charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(63, 'binary', 'binary', 'Yes')) | ||||
| _charsets.add(Charset(64, 'armscii8', 'armscii8_bin', '')) | ||||
| _charsets.add(Charset(65, 'ascii', 'ascii_bin', '')) | ||||
| _charsets.add(Charset(66, 'cp1250', 'cp1250_bin', '')) | ||||
| _charsets.add(Charset(67, 'cp1256', 'cp1256_bin', '')) | ||||
| _charsets.add(Charset(68, 'cp866', 'cp866_bin', '')) | ||||
| _charsets.add(Charset(69, 'dec8', 'dec8_bin', '')) | ||||
| _charsets.add(Charset(70, 'greek', 'greek_bin', '')) | ||||
| _charsets.add(Charset(71, 'hebrew', 'hebrew_bin', '')) | ||||
| _charsets.add(Charset(72, 'hp8', 'hp8_bin', '')) | ||||
| _charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', '')) | ||||
| _charsets.add(Charset(74, 'koi8r', 'koi8r_bin', '')) | ||||
| _charsets.add(Charset(75, 'koi8u', 'koi8u_bin', '')) | ||||
| _charsets.add(Charset(77, 'latin2', 'latin2_bin', '')) | ||||
| _charsets.add(Charset(78, 'latin5', 'latin5_bin', '')) | ||||
| _charsets.add(Charset(79, 'latin7', 'latin7_bin', '')) | ||||
| _charsets.add(Charset(80, 'cp850', 'cp850_bin', '')) | ||||
| _charsets.add(Charset(81, 'cp852', 'cp852_bin', '')) | ||||
| _charsets.add(Charset(82, 'swe7', 'swe7_bin', '')) | ||||
| _charsets.add(Charset(83, 'utf8', 'utf8_bin', '')) | ||||
| _charsets.add(Charset(84, 'big5', 'big5_bin', '')) | ||||
| _charsets.add(Charset(85, 'euckr', 'euckr_bin', '')) | ||||
| _charsets.add(Charset(86, 'gb2312', 'gb2312_bin', '')) | ||||
| _charsets.add(Charset(87, 'gbk', 'gbk_bin', '')) | ||||
| _charsets.add(Charset(88, 'sjis', 'sjis_bin', '')) | ||||
| _charsets.add(Charset(89, 'tis620', 'tis620_bin', '')) | ||||
| _charsets.add(Charset(90, 'ucs2', 'ucs2_bin', '')) | ||||
| _charsets.add(Charset(91, 'ujis', 'ujis_bin', '')) | ||||
| _charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes')) | ||||
| _charsets.add(Charset(93, 'geostd8', 'geostd8_bin', '')) | ||||
| _charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', '')) | ||||
| _charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes')) | ||||
| _charsets.add(Charset(96, 'cp932', 'cp932_bin', '')) | ||||
| _charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes')) | ||||
| _charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', '')) | ||||
| _charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', '')) | ||||
| _charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', '')) | ||||
| _charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', '')) | ||||
| _charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', '')) | ||||
| _charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', '')) | ||||
| _charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', '')) | ||||
| _charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', '')) | ||||
| _charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', '')) | ||||
| _charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', '')) | ||||
| _charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', '')) | ||||
| _charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', '')) | ||||
| _charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', '')) | ||||
| _charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', '')) | ||||
| _charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', '')) | ||||
| _charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', '')) | ||||
| _charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', '')) | ||||
| _charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', '')) | ||||
| _charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', '')) | ||||
| _charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', '')) | ||||
| _charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', '')) | ||||
| _charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', '')) | ||||
| _charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', '')) | ||||
| _charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', '')) | ||||
| _charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', '')) | ||||
| _charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', '')) | ||||
| _charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', '')) | ||||
| _charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', '')) | ||||
| _charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', '')) | ||||
| _charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', '')) | ||||
| _charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', '')) | ||||
| _charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', '')) | ||||
| _charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', '')) | ||||
| _charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', '')) | ||||
| _charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', '')) | ||||
| _charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', '')) | ||||
| _charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', '')) | ||||
| _charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', '')) | ||||
| _charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', '')) | ||||
| _charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', '')) | ||||
|  | ||||
| def charset_by_name(name): | ||||
|     return _charsets.by_name(name) | ||||
|  | ||||
| def charset_by_id(id): | ||||
|     return _charsets.by_id(id) | ||||
|  | ||||
|   | ||||
| @@ -1,8 +1,6 @@ | ||||
| # Python implementation of the MySQL client-server protocol | ||||
| #   http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol | ||||
|  | ||||
| import re | ||||
|  | ||||
| try: | ||||
|     import hashlib | ||||
|     sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs) | ||||
| @@ -22,13 +20,13 @@ try: | ||||
| except ImportError: | ||||
|     import StringIO | ||||
|  | ||||
| from charset import MBLENGTH | ||||
| from charset import MBLENGTH, charset_by_name, charset_by_id | ||||
| from cursors import Cursor | ||||
| from constants import FIELD_TYPE | ||||
| from constants import FIELD_TYPE, FLAG | ||||
| from constants import SERVER_STATUS | ||||
| from constants.CLIENT import * | ||||
| from constants.COMMAND import * | ||||
| from converters import escape_item, encoders, decoders, field_decoders | ||||
| from converters import escape_item, encoders, decoders | ||||
| from err import raise_mysql_exception, Warning, Error, \ | ||||
|      InterfaceError, DataError, DatabaseError, OperationalError, \ | ||||
|      IntegrityError, InternalError, NotSupportedError, ProgrammingError | ||||
| @@ -64,7 +62,8 @@ def dump_packet(data): | ||||
|     dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0] | ||||
|     for d in dump_data: | ||||
|         print ' '.join(map(lambda x:"%02X" % ord(x), d)) + \ | ||||
|                 '   ' * (16 - len(d)) + ' ' * 2 + ' '.join(map(lambda x:"%s" % is_ascii(x), d)) | ||||
|                 '   ' * (16 - len(d)) + ' ' * 2 + \ | ||||
|                 ' '.join(map(lambda x:"%s" % is_ascii(x), d)) | ||||
|     print "-" * 88 | ||||
|     print "" | ||||
|  | ||||
| @@ -84,7 +83,8 @@ def _my_crypt(message1, message2): | ||||
|     length = len(message1) | ||||
|     result = struct.pack('B', length) | ||||
|     for i in xrange(length): | ||||
|         x = (struct.unpack('B', message1[i:i+1])[0] ^ struct.unpack('B', message2[i:i+1])[0]) | ||||
|         x = (struct.unpack('B', message1[i:i+1])[0] ^ \ | ||||
|              struct.unpack('B', message2[i:i+1])[0]) | ||||
|         result += struct.pack('B', x) | ||||
|     return result | ||||
|  | ||||
| @@ -161,9 +161,10 @@ def unpack_int64(n): | ||||
|     (struct.unpack('B',n[6])[0] << 48) + (struct.unpack('B',n[7])[0]<<56) | ||||
|  | ||||
| def defaulterrorhandler(connection, cursor, errorclass, errorvalue): | ||||
|     raise | ||||
|     err = errorclass, errorvalue | ||||
|      | ||||
|     if DEBUG: | ||||
|         raise | ||||
|  | ||||
|     if cursor: | ||||
|         cursor.messages.append(err) | ||||
|     else: | ||||
| @@ -271,8 +272,8 @@ class MysqlPacket(object): | ||||
|     """ | ||||
|     return self.__data[position:(position+length)] | ||||
|  | ||||
|   def read_coded_length(self): | ||||
|     """Read a 'Length Coded' number from the data buffer. | ||||
|   def read_length_coded_binary(self): | ||||
|     """Read a 'Length Coded Binary' number from the data buffer. | ||||
|  | ||||
|     Length coded numbers can be anywhere from 1 to 9 bytes depending | ||||
|     on the value of the first byte. | ||||
| @@ -290,16 +291,17 @@ class MysqlPacket(object): | ||||
|       # TODO: what was 'longlong'?  confirm it wasn't used? | ||||
|       return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) | ||||
|  | ||||
|   def read_length_coded_binary(self): | ||||
|     """Read a 'Length Coded Binary' from the data buffer. | ||||
|   def read_length_coded_string(self): | ||||
|     """Read a 'Length Coded String' from the data buffer. | ||||
|  | ||||
|     A 'Length Coded Binary' consists first of a length coded | ||||
|     A 'Length Coded String' consists first of a length coded | ||||
|     (unsigned, positive) integer represented in 1-9 bytes followed by | ||||
|     that many bytes of binary data.  (For example "cat" would be "3cat".) | ||||
|     """ | ||||
|     length = self.read_coded_length() | ||||
|     if length: | ||||
|       return self.read(length) | ||||
|     length = self.read_length_coded_binary() | ||||
|     if length is None: | ||||
|         return None | ||||
|     return self.read(length) | ||||
|  | ||||
|   def is_ok_packet(self): | ||||
|     return ord(self.get_bytes(0)) == 0 | ||||
| @@ -342,19 +344,17 @@ class FieldDescriptorPacket(MysqlPacket): | ||||
|  | ||||
|     This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). | ||||
|     """ | ||||
|     self.catalog = self.read_length_coded_binary() | ||||
|     self.db = self.read_length_coded_binary() | ||||
|     self.table_name = self.read_length_coded_binary() | ||||
|     self.org_table = self.read_length_coded_binary() | ||||
|     self.name = self.read_length_coded_binary() | ||||
|     self.org_name = self.read_length_coded_binary() | ||||
|     self.catalog = self.read_length_coded_string() | ||||
|     self.db = self.read_length_coded_string() | ||||
|     self.table_name = self.read_length_coded_string() | ||||
|     self.org_table = self.read_length_coded_string() | ||||
|     self.name = self.read_length_coded_string() | ||||
|     self.org_name = self.read_length_coded_string() | ||||
|     self.advance(1)  # non-null filler | ||||
|     self.charsetnr = struct.unpack('<h', self.read(2))[0] | ||||
|     self.length = struct.unpack('<i', self.read(4))[0] | ||||
|     self.charsetnr = struct.unpack('<H', self.read(2))[0] | ||||
|     self.length = struct.unpack('<I', self.read(4))[0] | ||||
|     self.type_code = ord(self.read(1)) | ||||
|     flags = struct.unpack('<h', self.read(2)) | ||||
|     # TODO: what is going on here with this flag parsing??? | ||||
|     self.flags = int(("%02X" % flags)[1:], 16) | ||||
|     self.flags = struct.unpack('<H', self.read(2))[0] | ||||
|     self.scale = ord(self.read(1))  # "decimals" | ||||
|     self.advance(2)  # filler (always 0x00) | ||||
|  | ||||
| @@ -401,8 +401,8 @@ class Connection(object): | ||||
|  | ||||
|     def __init__(self, host="localhost", user=None, passwd="", | ||||
|                  db=None, port=3306, unix_socket=None, | ||||
|                  charset=DEFAULT_CHARSET, sql_mode=None, | ||||
|                  read_default_file=None, conv=decoders, use_unicode=True, | ||||
|                  charset='', sql_mode=None, | ||||
|                  read_default_file=None, conv=decoders, use_unicode=False, | ||||
|                  client_flag=0, cursorclass=Cursor, init_command=None, | ||||
|                  connect_timeout=None, ssl=None, read_default_group=None, | ||||
|                  compress=None, named_pipe=None): | ||||
| @@ -457,7 +457,7 @@ class Connection(object): | ||||
|                     return cfg.get("client",key) | ||||
|                 except: | ||||
|                     return default | ||||
|              | ||||
|  | ||||
|             user = _config("user",user) | ||||
|             passwd = _config("password",passwd) | ||||
|             host = _config("host", host) | ||||
| @@ -465,15 +465,22 @@ class Connection(object): | ||||
|             unix_socket = _config("socket",unix_socket) | ||||
|             port = _config("port", port) | ||||
|             charset = _config("default-character-set", charset) | ||||
|                  | ||||
|  | ||||
|         self.host = host | ||||
|         self.port = port | ||||
|         self.user = user | ||||
|         self.password = passwd | ||||
|         self.db = db | ||||
|         self.unix_socket = unix_socket | ||||
|         self.use_unicode = use_unicode | ||||
|         self.charset = DEFAULT_CHARSET | ||||
|         if charset: | ||||
|             self.charset = charset | ||||
|             self.use_unicode = True | ||||
|         else: | ||||
|             self.charset = DEFAULT_CHARSET | ||||
|             self.use_unicode = False | ||||
|  | ||||
|         if use_unicode: | ||||
|             self.use_unicode = use_unicode | ||||
|  | ||||
|         client_flag |= CAPABILITIES | ||||
|         client_flag |= MULTI_STATEMENTS | ||||
| @@ -483,20 +490,19 @@ class Connection(object): | ||||
|  | ||||
|         self.cursorclass = cursorclass | ||||
|         self.connect_timeout = connect_timeout | ||||
|          | ||||
|  | ||||
|         self._connect() | ||||
|          | ||||
|         self.set_charset_set(charset) | ||||
|  | ||||
|         self.messages = [] | ||||
|         self.set_charset(charset) | ||||
|         self.encoders = encoders | ||||
|         self.decoders = conv | ||||
|         self.field_decoders = field_decoders | ||||
|  | ||||
|         self._affected_rows = 0 | ||||
|         self.host_info = "Not connected" | ||||
|          | ||||
|  | ||||
|         self.autocommit(False) | ||||
|          | ||||
|  | ||||
|         if sql_mode is not None: | ||||
|             c = self.cursor() | ||||
|             c.execute("SET sql_mode=%s", (sql_mode,)) | ||||
| @@ -506,21 +512,20 @@ class Connection(object): | ||||
|         if init_command is not None: | ||||
|             c = self.cursor() | ||||
|             c.execute(init_command) | ||||
|              | ||||
|  | ||||
|             self.commit() | ||||
|          | ||||
|  | ||||
|  | ||||
|     def close(self): | ||||
|         ''' Send the quit message and close the socket ''' | ||||
|         try: | ||||
|         if self.socket: | ||||
|             send_data = struct.pack('<i',1) + COM_QUIT | ||||
|             sock = self.socket | ||||
|             sock.send(send_data) | ||||
|             sock.close() | ||||
|         except: | ||||
|             exc,value,tb = sys.exc_info() | ||||
|             self.errorhandler(None, exc, value) | ||||
|      | ||||
|             self.socket.send(send_data) | ||||
|             self.socket.close() | ||||
|             self.socket = None | ||||
|         else: | ||||
|             self.errorhandler(None, InterfaceError, "(0, '')") | ||||
|  | ||||
|     def autocommit(self, value): | ||||
|         ''' Set whether or not to commit after every execute() ''' | ||||
|         try: | ||||
| @@ -560,7 +565,7 @@ class Connection(object): | ||||
|     def cursor(self): | ||||
|         ''' Create a new cursor to execute queries with ''' | ||||
|         return self.cursorclass(self) | ||||
|      | ||||
|  | ||||
|     def __enter__(self): | ||||
|         ''' Context manager that returns a Cursor ''' | ||||
|         return self.cursor() | ||||
| @@ -577,7 +582,7 @@ class Connection(object): | ||||
|         self._execute_command(COM_QUERY, sql) | ||||
|         self._affected_rows = self._read_query_result() | ||||
|         return self._affected_rows | ||||
|      | ||||
|  | ||||
|     def next_result(self): | ||||
|         self._affected_rows = self._read_query_result() | ||||
|         return self._affected_rows | ||||
| @@ -595,7 +600,7 @@ class Connection(object): | ||||
|             return | ||||
|         pkt = self.read_packet() | ||||
|         return pkt.is_ok_packet() | ||||
|      | ||||
|  | ||||
|     def ping(self, reconnect=True): | ||||
|         ''' Check if the server is alive ''' | ||||
|         try: | ||||
| @@ -612,14 +617,13 @@ class Connection(object): | ||||
|         pkt = self.read_packet() | ||||
|         return pkt.is_ok_packet() | ||||
|  | ||||
|     def set_charset_set(self, charset): | ||||
|     def set_charset(self, charset): | ||||
|         try: | ||||
|             sock = self.socket | ||||
|             if charset and self.charset != charset: | ||||
|             if charset: | ||||
|                 self._execute_command(COM_QUERY, "SET NAMES %s" % | ||||
|                                       self.escape(charset)) | ||||
|                 self.read_packet() | ||||
|                 self.charset = charset      | ||||
|                 self.charset = charset | ||||
|         except: | ||||
|             exc,value,tb = sys.exc_info() | ||||
|             self.errorhandler(None, exc, value) | ||||
| @@ -647,7 +651,7 @@ class Connection(object): | ||||
|             self._request_authentication() | ||||
|         except socket.error, e: | ||||
|             raise OperationalError(2003, "Can't connect to MySQL server on %r (%d)" % (self.host, e.args[0])) | ||||
|      | ||||
|  | ||||
|     def read_packet(self, packet_type=MysqlPacket): | ||||
|       """Read an entire "mysql packet" in its entirety from the network | ||||
|       and return a MysqlPacket type that represents the results.""" | ||||
| @@ -673,7 +677,9 @@ class Connection(object): | ||||
|         pckt_no = 0 | ||||
|         while len(buf) >= MAX_PACKET_LENGTH: | ||||
|             header = struct.pack('<i', MAX_PACKET_LENGTH)[:-1]+chr(pckt_no) | ||||
|             self.socket.send(header+buf[:MAX_PACKET_LENGTH]) | ||||
|             send_data = header + buf[:MAX_PACKET_LENGTH] | ||||
|             self.socket.send(send_data) | ||||
|             if DEBUG: dump_packet(send_data) | ||||
|             buf = buf[MAX_PACKET_LENGTH:] | ||||
|             pckt_no += 1 | ||||
|         header = struct.pack('<i', len(buf))[:-1]+chr(pckt_no) | ||||
| @@ -683,13 +689,12 @@ class Connection(object): | ||||
|         #sock = self.socket | ||||
|         #sock.send(send_data) | ||||
|  | ||||
|         if DEBUG: dump_packet(send_data) | ||||
|         # | ||||
|  | ||||
|     def _execute_command(self, command, sql): | ||||
|         self._send_command(command, sql) | ||||
|          | ||||
|     def _request_authentication(self): | ||||
|         sock = self.socket | ||||
|         self._send_authentication() | ||||
|  | ||||
|     def _send_authentication(self): | ||||
| @@ -700,9 +705,12 @@ class Connection(object): | ||||
|  | ||||
|         if self.user is None: | ||||
|             raise ValueError, "Did not specify a username" | ||||
|      | ||||
|         data_init = (struct.pack('<i', self.client_flag)) \ | ||||
|                             + "\0\0\0\x01" + '\x08' + '\0'*23 | ||||
|  | ||||
|         charset_id = charset_by_name(self.charset).id | ||||
|         self.user = self.user.encode(self.charset) | ||||
|  | ||||
|         data_init = struct.pack('<i', self.client_flag) + "\0\0\0\x01" + \ | ||||
|                      chr(charset_id) + '\0'*23 | ||||
|  | ||||
|         next_packet = 1 | ||||
|  | ||||
| @@ -722,13 +730,14 @@ class Connection(object): | ||||
|         data = data_init + self.user+"\0" + _scramble(self.password, self.salt) | ||||
|  | ||||
|         if self.db: | ||||
|             data += self.db.encode(self.charset) + "\0" | ||||
|             self.db = self.db.encode(self.charset) | ||||
|             data += self.db + "\0" | ||||
|  | ||||
|         data = pack_int24(len(data)) + chr(next_packet) + data | ||||
|         next_packet += 2 | ||||
|          | ||||
|  | ||||
|         if DEBUG: dump_packet(data) | ||||
|          | ||||
|  | ||||
|         sock.send(data) | ||||
|  | ||||
|         auth_packet = MysqlPacket(sock) | ||||
| @@ -743,13 +752,13 @@ class Connection(object): | ||||
|             #raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table." | ||||
|             data = _scramble_323(self.password, self.salt) + "\0" | ||||
|             data = pack_int24(len(data)) + chr(next_packet) + data | ||||
|          | ||||
|  | ||||
|             sock.send(data) | ||||
|             auth_packet = MysqlPacket(sock) | ||||
|             auth_packet.check_error() | ||||
|             if DEBUG: auth_packet.dump() | ||||
|          | ||||
|          | ||||
|  | ||||
|  | ||||
|     # _mysql support | ||||
|     def thread_id(self): | ||||
|         return self.server_thread_id[0] | ||||
| @@ -759,10 +768,10 @@ class Connection(object): | ||||
|  | ||||
|     def get_host_info(self): | ||||
|         return self.host_info | ||||
|      | ||||
|  | ||||
|     def get_proto_info(self): | ||||
|         return self.protocol_version | ||||
|          | ||||
|  | ||||
|     def _get_server_information(self): | ||||
|         sock = self.socket | ||||
|         i = 0 | ||||
| @@ -773,27 +782,28 @@ class Connection(object): | ||||
|         #packet_len = ord(data[i:i+1]) | ||||
|         #i += 4 | ||||
|         self.protocol_version = ord(data[i:i+1]) | ||||
|          | ||||
|  | ||||
|         i += 1 | ||||
|         server_end = data.find("\0", i) | ||||
|         self.server_version = data[i:server_end] | ||||
|          | ||||
|  | ||||
|         i = server_end + 1 | ||||
|         self.server_thread_id = struct.unpack('<h', data[i:i+2]) | ||||
|  | ||||
|         i += 4 | ||||
|         self.salt = data[i:i+8] | ||||
|          | ||||
|  | ||||
|         i += 9 | ||||
|         if len(data) >= i + 1: | ||||
|             i += 1 | ||||
|         | ||||
|  | ||||
|         self.server_capabilities = struct.unpack('<h', data[i:i+2])[0] | ||||
|  | ||||
|         i += 1 | ||||
|         self.server_language = ord(data[i:i+1]) | ||||
|          | ||||
|         i += 16  | ||||
|         self.server_charset = charset_by_id(self.server_language).name | ||||
|  | ||||
|         i += 16 | ||||
|         if len(data) >= i+12-1: | ||||
|             rest_salt = data[i:i+12] | ||||
|             self.salt += rest_salt | ||||
| @@ -840,8 +850,8 @@ class MySQLResult(object): | ||||
|  | ||||
|     def _read_ok_packet(self): | ||||
|         self.first_packet.advance(1)  # field_count (always '0') | ||||
|         self.affected_rows = self.first_packet.read_coded_length() | ||||
|         self.insert_id = self.first_packet.read_coded_length() | ||||
|         self.affected_rows = self.first_packet.read_length_coded_binary() | ||||
|         self.insert_id = self.first_packet.read_length_coded_binary() | ||||
|         self.server_status = struct.unpack('<H', self.first_packet.read(2))[0] | ||||
|         self.warning_count = struct.unpack('<H', self.first_packet.read(2))[0] | ||||
|         self.message = self.first_packet.read_all() | ||||
| @@ -871,14 +881,7 @@ class MySQLResult(object): | ||||
|                 converter = self.connection.decoders[field.type_code] | ||||
|  | ||||
|                 if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter) | ||||
|                 data = packet.read_length_coded_binary() | ||||
|                 converted = None | ||||
|                 if data != None: | ||||
|                     converted = converter(data) | ||||
|             else: | ||||
|                 converter = self.connection.field_decoders[field.type_code] | ||||
|                 if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter) | ||||
|                 data = packet.read_length_coded_binary() | ||||
|                 data = packet.read_length_coded_string() | ||||
|                 converted = None | ||||
|                 if data != None: | ||||
|                     converted = converter(self.connection, field, data) | ||||
|   | ||||
| @@ -29,4 +29,4 @@ STRING = 254 | ||||
| GEOMETRY = 255 | ||||
|  | ||||
| CHAR = TINY | ||||
| INTERVAL = ENUM  | ||||
| INTERVAL = ENUM | ||||
|   | ||||
| @@ -1,11 +1,9 @@ | ||||
| import re | ||||
| import datetime | ||||
| import time | ||||
| import array | ||||
| import struct | ||||
|  | ||||
| from times import Date, Time, TimeDelta, Timestamp | ||||
| from constants import FIELD_TYPE | ||||
| from constants import FIELD_TYPE, FLAG | ||||
| from charset import charset_by_id | ||||
|  | ||||
| try: | ||||
|     set | ||||
| @@ -20,8 +18,16 @@ ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', | ||||
|               '\'': '\\\'', '"': '\\"', '\\': '\\\\'} | ||||
|  | ||||
| def escape_item(val, charset): | ||||
|     if type(val) in [tuple, list, set]: | ||||
|         return escape_sequence(val, charset) | ||||
|     if type(val) is dict: | ||||
|         return escape_dict(val, charset) | ||||
|     encoder = encoders[type(val)] | ||||
|     return encoder(val, charset) | ||||
|     val = encoder(val) | ||||
|     if type(val) is str: | ||||
|         return val | ||||
|     val = val.encode(charset) | ||||
|     return val | ||||
|  | ||||
| def escape_dict(val, charset): | ||||
|     n = {} | ||||
| @@ -37,99 +43,96 @@ def escape_sequence(val, charset): | ||||
|         n.append(quoted) | ||||
|     return tuple(n) | ||||
|  | ||||
| def escape_bool(value, charset): | ||||
|     return str(int(value)).encode(charset) | ||||
| def escape_set(val, charset): | ||||
|     val = map(lambda x: escape_item(x, charset), val) | ||||
|     return ','.join(val) | ||||
|  | ||||
| def escape_object(value, charset): | ||||
|     return str(value).encode(charset) | ||||
| def escape_bool(value): | ||||
|     return str(int(value)) | ||||
|  | ||||
| def escape_object(value): | ||||
|     return str(value) | ||||
|  | ||||
| escape_int = escape_long = escape_object | ||||
|  | ||||
| def escape_float(value, charset): | ||||
|     return ('%.15g' % value).encode(charset) | ||||
| def escape_float(value): | ||||
|     return ('%.15g' % value) | ||||
|  | ||||
| def escape_string(value, charset): | ||||
|     r = ("'%s'" % ESCAPE_REGEX.sub( | ||||
|         lambda match: ESCAPE_MAP.get(match.group(0)), value)) | ||||
|     # TODO: make sure that encodings are handled correctly here. | ||||
|     # Since we may be dealing with binary data, the encoding | ||||
|     # routine below is commented out. | ||||
|     #if not charset is None: | ||||
|     #    r = r.encode(charset) | ||||
|     return r | ||||
|      | ||||
| def escape_unicode(value, charset): | ||||
|     # pass None as the charset because we already encode it | ||||
|     return escape_string(value.encode(charset), None) | ||||
| def escape_string(value): | ||||
|     return ("'%s'" % ESCAPE_REGEX.sub( | ||||
|             lambda match: ESCAPE_MAP.get(match.group(0)), value)) | ||||
|  | ||||
| def escape_None(value, charset): | ||||
|     return 'NULL'.encode(charset) | ||||
| def escape_unicode(value): | ||||
|     return escape_string(value) | ||||
|  | ||||
| def escape_timedelta(obj, charset): | ||||
| def escape_None(value): | ||||
|     return 'NULL' | ||||
|  | ||||
| def escape_timedelta(obj): | ||||
|     seconds = int(obj.seconds) % 60 | ||||
|     minutes = int(obj.seconds // 60) % 60 | ||||
|     hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24 | ||||
|     return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset) | ||||
|     return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds)) | ||||
|  | ||||
| def escape_time(obj, charset): | ||||
| def escape_time(obj): | ||||
|     s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute), | ||||
|                             int(obj.second)) | ||||
|     if obj.microsecond: | ||||
|         s += ".%f" % obj.microsecond | ||||
|  | ||||
|     return escape_string(s, charset) | ||||
|     return escape_string(s) | ||||
|  | ||||
| def escape_datetime(obj, charset): | ||||
|     return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset) | ||||
| def escape_datetime(obj): | ||||
|     return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S")) | ||||
|  | ||||
| def escape_date(obj, charset): | ||||
|     return escape_string(obj.strftime("%Y-%m-%d"), charset) | ||||
| def escape_date(obj): | ||||
|     return escape_string(obj.strftime("%Y-%m-%d")) | ||||
|  | ||||
| def escape_struct_time(obj, charset): | ||||
|     return escape_datetime(datetime.datetime(*obj[:6]), charset) | ||||
| def escape_struct_time(obj): | ||||
|     return escape_datetime(datetime.datetime(*obj[:6])) | ||||
|  | ||||
| def convert_datetime(obj): | ||||
| def convert_datetime(connection, field, obj): | ||||
|     """Returns a DATETIME or TIMESTAMP column value as a datetime object: | ||||
|      | ||||
|  | ||||
|       >>> datetime_or_None('2007-02-25 23:06:20') | ||||
|       datetime.datetime(2007, 2, 25, 23, 6, 20) | ||||
|       >>> datetime_or_None('2007-02-25T23:06:20') | ||||
|       datetime.datetime(2007, 2, 25, 23, 6, 20) | ||||
|      | ||||
|  | ||||
|     Illegal values are returned as None: | ||||
|      | ||||
|  | ||||
|       >>> datetime_or_None('2007-02-31T23:06:20') is None | ||||
|       True | ||||
|       >>> datetime_or_None('0000-00-00 00:00:00') is None | ||||
|       True | ||||
|     | ||||
|  | ||||
|     """ | ||||
|     if ' ' in obj: | ||||
|         sep = ' ' | ||||
|     elif 'T' in obj: | ||||
|         sep = 'T' | ||||
|     else: | ||||
|         return convert_date(obj) | ||||
|         return convert_date(connection, field, obj) | ||||
|  | ||||
|     try: | ||||
|         ymd, hms = obj.split(sep, 1) | ||||
|         return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ]) | ||||
|     except ValueError: | ||||
|         return convert_date(obj) | ||||
|         return convert_date(connection, field, obj) | ||||
|  | ||||
| def convert_timedelta(obj): | ||||
| def convert_timedelta(connection, field, obj): | ||||
|     """Returns a TIME column as a timedelta object: | ||||
|  | ||||
|       >>> timedelta_or_None('25:06:17') | ||||
|       datetime.timedelta(1, 3977) | ||||
|       >>> timedelta_or_None('-25:06:17') | ||||
|       datetime.timedelta(-2, 83177) | ||||
|        | ||||
|  | ||||
|     Illegal values are returned as None: | ||||
|      | ||||
|  | ||||
|       >>> timedelta_or_None('random crap') is None | ||||
|       True | ||||
|     | ||||
|  | ||||
|     Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but | ||||
|     can accept values as (+|-)DD HH:MM:SS. The latter format will not | ||||
|     be parsed correctly by this function. | ||||
| @@ -147,23 +150,23 @@ def convert_timedelta(obj): | ||||
|     except ValueError: | ||||
|         return None | ||||
|  | ||||
| def convert_time(obj): | ||||
| def convert_time(connection, field, obj): | ||||
|     """Returns a TIME column as a time object: | ||||
|  | ||||
|       >>> time_or_None('15:06:17') | ||||
|       datetime.time(15, 6, 17) | ||||
|        | ||||
|  | ||||
|     Illegal values are returned as None: | ||||
|   | ||||
|  | ||||
|       >>> time_or_None('-25:06:17') is None | ||||
|       True | ||||
|       >>> time_or_None('random crap') is None | ||||
|       True | ||||
|     | ||||
|  | ||||
|     Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but | ||||
|     can accept values as (+|-)DD HH:MM:SS. The latter format will not | ||||
|     be parsed correctly by this function. | ||||
|      | ||||
|  | ||||
|     Also note that MySQL's TIME column corresponds more closely to | ||||
|     Python's timedelta and not time. However if you want TIME columns | ||||
|     to be treated as time-of-day and not a time offset, then you can | ||||
| @@ -172,53 +175,54 @@ def convert_time(obj): | ||||
|     from math import modf | ||||
|     try: | ||||
|         hour, minute, second = obj.split(':') | ||||
|         return datetime.time(hour=int(hour), minute=int(minute), second=int(second), | ||||
|                     microsecond=int(modf(float(second))[0]*1000000)) | ||||
|         return datetime.time(hour=int(hour), minute=int(minute), | ||||
|                              second=int(second), | ||||
|                              microsecond=int(modf(float(second))[0]*1000000)) | ||||
|     except ValueError: | ||||
|         return None | ||||
|  | ||||
| def convert_date(obj): | ||||
| def convert_date(connection, field, obj): | ||||
|     """Returns a DATE column as a date object: | ||||
|  | ||||
|       >>> date_or_None('2007-02-26') | ||||
|       datetime.date(2007, 2, 26) | ||||
|        | ||||
|  | ||||
|     Illegal values are returned as None: | ||||
|   | ||||
|  | ||||
|       >>> date_or_None('2007-02-31') is None | ||||
|       True | ||||
|       >>> date_or_None('0000-00-00') is None | ||||
|       True | ||||
|      | ||||
|  | ||||
|     """ | ||||
|     try: | ||||
|         return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) | ||||
|     except ValueError: | ||||
|         return None | ||||
|  | ||||
| def convert_mysql_timestamp(timestamp): | ||||
| def convert_mysql_timestamp(connection, field, timestamp): | ||||
|     """Convert a MySQL TIMESTAMP to a Timestamp object. | ||||
|  | ||||
|     MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: | ||||
|      | ||||
|  | ||||
|       >>> mysql_timestamp_converter('2007-02-25 22:32:17') | ||||
|       datetime.datetime(2007, 2, 25, 22, 32, 17) | ||||
|      | ||||
|  | ||||
|     MySQL < 4.1 uses a big string of numbers: | ||||
|      | ||||
|  | ||||
|       >>> mysql_timestamp_converter('20070225223217') | ||||
|       datetime.datetime(2007, 2, 25, 22, 32, 17) | ||||
|      | ||||
|  | ||||
|     Illegal values are returned as None: | ||||
|      | ||||
|  | ||||
|       >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None | ||||
|       True | ||||
|       >>> mysql_timestamp_converter('00000000000000') is None | ||||
|       True | ||||
|        | ||||
|  | ||||
|     """ | ||||
|     if timestamp[4] == '-': | ||||
|         return convert_datetime(timestamp) | ||||
|         return convert_datetime(connection, field, timestamp) | ||||
|     timestamp += "0"*(14-len(timestamp)) # padding | ||||
|     year, month, day, hour, minute, second = \ | ||||
|         int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ | ||||
| @@ -229,13 +233,38 @@ def convert_mysql_timestamp(timestamp): | ||||
|         return None | ||||
|  | ||||
| def convert_set(s): | ||||
|     # TODO: this may not be correct | ||||
|     return set(s.split(",")) | ||||
|  | ||||
| def convert_bit(b): | ||||
|     b = "\x00" * (8 - len(b)) + b # pad w/ zeroes | ||||
|     return struct.unpack(">Q", b)[0] | ||||
|      | ||||
| def convert_bit(connection, field, b): | ||||
|     #b = "\x00" * (8 - len(b)) + b # pad w/ zeroes | ||||
|     #return struct.unpack(">Q", b)[0] | ||||
|     # | ||||
|     # the snippet above is right, but MySQLdb doesn't process bits, | ||||
|     # so we shouldn't either | ||||
|     return b | ||||
|  | ||||
| def convert_characters(connection, field, data): | ||||
|     if field.flags & FLAG.SET: | ||||
|         return convert_set(data) | ||||
|     if field.flags & FLAG.BINARY: | ||||
|         return data | ||||
|     field_charset = charset_by_id(field.charsetnr).name | ||||
|     if connection.use_unicode: | ||||
|         data = data.decode(field_charset) | ||||
|     elif connection.charset != field_charset: | ||||
|         data = data.decode(field_charset) | ||||
|         data = data.encode(connection.charset) | ||||
|     return data | ||||
|  | ||||
| def convert_int(connection, field, data): | ||||
|     return int(data) | ||||
|  | ||||
| def convert_long(connection, field, data): | ||||
|     return long(data) | ||||
|  | ||||
| def convert_float(connection, field, data): | ||||
|     return float(data) | ||||
|  | ||||
| encoders = { | ||||
|         bool: escape_bool, | ||||
|         int: escape_int, | ||||
| @@ -257,21 +286,28 @@ encoders = { | ||||
|  | ||||
| decoders = { | ||||
|         FIELD_TYPE.BIT: convert_bit, | ||||
|         FIELD_TYPE.TINY: int, | ||||
|         FIELD_TYPE.SHORT: int, | ||||
|         FIELD_TYPE.LONG: long, | ||||
|         FIELD_TYPE.FLOAT: float, | ||||
|         FIELD_TYPE.DOUBLE: float, | ||||
|         FIELD_TYPE.DECIMAL: float, | ||||
|         FIELD_TYPE.NEWDECIMAL: float, | ||||
|         FIELD_TYPE.LONGLONG: long, | ||||
|         FIELD_TYPE.INT24: int, | ||||
|         FIELD_TYPE.YEAR: int, | ||||
|         FIELD_TYPE.TINY: convert_int, | ||||
|         FIELD_TYPE.SHORT: convert_int, | ||||
|         FIELD_TYPE.LONG: convert_long, | ||||
|         FIELD_TYPE.FLOAT: convert_float, | ||||
|         FIELD_TYPE.DOUBLE: convert_float, | ||||
|         FIELD_TYPE.DECIMAL: convert_float, | ||||
|         FIELD_TYPE.NEWDECIMAL: convert_float, | ||||
|         FIELD_TYPE.LONGLONG: convert_long, | ||||
|         FIELD_TYPE.INT24: convert_int, | ||||
|         FIELD_TYPE.YEAR: convert_int, | ||||
|         FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, | ||||
|         FIELD_TYPE.DATETIME: convert_datetime, | ||||
|         FIELD_TYPE.TIME: convert_timedelta, | ||||
|         FIELD_TYPE.DATE: convert_date, | ||||
|         FIELD_TYPE.SET: convert_set, | ||||
|         FIELD_TYPE.BLOB: convert_characters, | ||||
|         FIELD_TYPE.TINY_BLOB: convert_characters, | ||||
|         FIELD_TYPE.MEDIUM_BLOB: convert_characters, | ||||
|         FIELD_TYPE.LONG_BLOB: convert_characters, | ||||
|         FIELD_TYPE.STRING: convert_characters, | ||||
|         FIELD_TYPE.VAR_STRING: convert_characters, | ||||
|         FIELD_TYPE.VARCHAR: convert_characters, | ||||
|         #FIELD_TYPE.BLOB: str, | ||||
|         #FIELD_TYPE.STRING: str, | ||||
|         #FIELD_TYPE.VAR_STRING: str, | ||||
| @@ -279,28 +315,13 @@ decoders = { | ||||
|         } | ||||
| conversions = decoders  # for MySQLdb compatibility | ||||
|  | ||||
| def decode_characters(connection, field, data): | ||||
|     if field.charsetnr == 63 or not connection.use_unicode: | ||||
|         # binary data, leave it alone | ||||
|         return data | ||||
|     return data.decode(connection.charset) | ||||
|  | ||||
| # These take a field instance rather than just the data. | ||||
| field_decoders = { | ||||
|     FIELD_TYPE.BLOB: decode_characters, | ||||
|     FIELD_TYPE.TINY_BLOB: decode_characters, | ||||
|     FIELD_TYPE.MEDIUM_BLOB: decode_characters, | ||||
|     FIELD_TYPE.LONG_BLOB: decode_characters, | ||||
|     FIELD_TYPE.STRING: decode_characters, | ||||
|     FIELD_TYPE.VAR_STRING: decode_characters, | ||||
|     FIELD_TYPE.VARCHAR: decode_characters, | ||||
| } | ||||
|  | ||||
| try: | ||||
|     # python version > 2.3 | ||||
|     from decimal import Decimal | ||||
|     decoders[FIELD_TYPE.DECIMAL] = Decimal | ||||
|     decoders[FIELD_TYPE.NEWDECIMAL] = Decimal | ||||
|     def convert_decimal(connection, field, data): | ||||
|         return Decimal(data) | ||||
|     decoders[FIELD_TYPE.DECIMAL] = convert_decimal | ||||
|     decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal | ||||
|  | ||||
|     def escape_decimal(obj, charset): | ||||
|         return unicode(obj).encode(charset) | ||||
|   | ||||
| @@ -1,3 +1,4 @@ | ||||
| # -*- coding: utf-8 -*- | ||||
| import struct | ||||
| import re | ||||
|  | ||||
| @@ -52,23 +53,23 @@ class Cursor(object): | ||||
|         if not self.connection: | ||||
|             self.errorhandler(self, ProgrammingError, "cursor closed") | ||||
|         return self.connection | ||||
|      | ||||
|  | ||||
|     def _check_executed(self): | ||||
|         if not self._executed: | ||||
|             self.errorhandler(self, ProgrammingError, "execute() first") | ||||
|      | ||||
|  | ||||
|     def setinputsizes(self, *args): | ||||
|         """Does nothing, required by DB API.""" | ||||
|                    | ||||
|  | ||||
|     def setoutputsizes(self, *args): | ||||
|         """Does nothing, required by DB API.""" | ||||
|                                    | ||||
|  | ||||
|     def nextset(self): | ||||
|         ''' Get the next query set ''' | ||||
|         if self._executed: | ||||
|             self.fetchall() | ||||
|         del self.messages[:] | ||||
|          | ||||
|  | ||||
|         if not self._has_next: | ||||
|             return None | ||||
|         connection = self._get_db() | ||||
| @@ -79,11 +80,11 @@ class Cursor(object): | ||||
|     def execute(self, query, args=None): | ||||
|         ''' Execute a query ''' | ||||
|         from sys import exc_info | ||||
|          | ||||
|  | ||||
|         conn = self._get_db() | ||||
|         charset = conn.charset | ||||
|         del self.messages[:] | ||||
|          | ||||
|  | ||||
|         # this ordering is good because conn.escape() returns | ||||
|         # an encoded string. | ||||
|         if isinstance(query, unicode): | ||||
| @@ -91,7 +92,7 @@ class Cursor(object): | ||||
|  | ||||
|         if args is not None: | ||||
|             query = query % conn.escape(args) | ||||
|                  | ||||
|  | ||||
|         result = 0 | ||||
|         try: | ||||
|             result = self._query(query) | ||||
| @@ -103,7 +104,7 @@ class Cursor(object): | ||||
|  | ||||
|         self._executed = query | ||||
|         return result | ||||
|      | ||||
|  | ||||
|     def executemany(self, query, args): | ||||
|         ''' Run several data against one query ''' | ||||
|         del self.messages[:] | ||||
| @@ -113,30 +114,66 @@ class Cursor(object): | ||||
|         charset = conn.charset | ||||
|         if isinstance(query, unicode): | ||||
|             query = query.encode(charset) | ||||
|      | ||||
|  | ||||
|         self.rowcount = sum([ self.execute(query, arg) for arg in args ]) | ||||
|         return self.rowcount | ||||
|      | ||||
|      | ||||
|  | ||||
|  | ||||
|     def callproc(self, procname, args=()): | ||||
|         ''' Call a stored procedure. Take care to ensure that procname is | ||||
|         properly escaped. ''' | ||||
|         if not isinstance(args, tuple): | ||||
|             args = (args,) | ||||
|         """Execute stored procedure procname with args | ||||
|  | ||||
|         argstr = ("%s," * len(args))[:-1] | ||||
|         procname -- string, name of procedure to execute on server | ||||
|  | ||||
|         return self.execute("CALL `%s`(%s)" % (procname, argstr), args) | ||||
|         args -- Sequence of parameters to use with procedure | ||||
|  | ||||
|         Returns the original args. | ||||
|  | ||||
|         Compatibility warning: PEP-249 specifies that any modified | ||||
|         parameters must be returned. This is currently impossible | ||||
|         as they are only available by storing them in a server | ||||
|         variable and then retrieved by a query. Since stored | ||||
|         procedures return zero or more result sets, there is no | ||||
|         reliable way to get at OUT or INOUT parameters via callproc. | ||||
|         The server variables are named @_procname_n, where procname | ||||
|         is the parameter above and n is the position of the parameter | ||||
|         (from zero). Once all result sets generated by the procedure | ||||
|         have been fetched, you can issue a SELECT @_procname_0, ... | ||||
|         query using .execute() to get any OUT or INOUT values. | ||||
|  | ||||
|         Compatibility warning: The act of calling a stored procedure | ||||
|         itself creates an empty result set. This appears after any | ||||
|         result sets generated by the procedure. This is non-standard | ||||
|         behavior with respect to the DB-API. Be sure to use nextset() | ||||
|         to advance through all result sets; otherwise you may get | ||||
|         disconnected. | ||||
|         """ | ||||
|         conn = self._get_db() | ||||
|         for index, arg in enumerate(args): | ||||
|             q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg)) | ||||
|             if isinstance(q, unicode): | ||||
|                 q = q.encode(conn.charset) | ||||
|             self._query(q) | ||||
|             self.nextset() | ||||
|  | ||||
|         q = "CALL %s(%s)" % (procname, | ||||
|                              ','.join(['@_%s_%d' % (procname, i) | ||||
|                                        for i in range(len(args))])) | ||||
|         if isinstance(q, unicode): | ||||
|             q = q.encode(conn.charset) | ||||
|         self._query(q) | ||||
|         self._executed = q | ||||
|  | ||||
|         return args | ||||
|  | ||||
|     def fetchone(self): | ||||
|         ''' Fetch the next row ''' | ||||
|         self._check_executed()         | ||||
|         self._check_executed() | ||||
|         if self._rows is None or self.rownumber >= len(self._rows): | ||||
|             return None | ||||
|         result = self._rows[self.rownumber] | ||||
|         self.rownumber += 1 | ||||
|         return result | ||||
|      | ||||
|  | ||||
|     def fetchmany(self, size=None): | ||||
|         ''' Fetch several rows ''' | ||||
|         self._check_executed() | ||||
| @@ -158,15 +195,15 @@ class Cursor(object): | ||||
|             result = self._rows | ||||
|         self.rownumber = len(self._rows) | ||||
|         return result | ||||
|      | ||||
|  | ||||
|     def scroll(self, value, mode='relative'): | ||||
|          | ||||
|         self._check_executed() | ||||
|         if mode == 'relative': | ||||
|             r = self.rownumber + value | ||||
|         elif mode == 'absolute': | ||||
|             r = value | ||||
|         else: | ||||
|             self.errorhandler(self, ProgrammingError,  | ||||
|             self.errorhandler(self, ProgrammingError, | ||||
|                     "unknown scroll mode %s" % mode) | ||||
|  | ||||
|         if r < 0 or r >= len(self._rows): | ||||
| @@ -179,23 +216,23 @@ class Cursor(object): | ||||
|         conn.query(q) | ||||
|         self._do_get_result() | ||||
|         return self.rowcount | ||||
|      | ||||
|  | ||||
|     def _do_get_result(self): | ||||
|         conn = self._get_db() | ||||
|         self.rowcount = conn._result.affected_rows | ||||
|          | ||||
|  | ||||
|         self.rownumber = 0 | ||||
|         self.description = conn._result.description | ||||
|         self.lastrowid = conn._result.insert_id | ||||
|         self._rows = conn._result.rows | ||||
|         self._has_next = conn._result.has_next | ||||
|         conn._result = None | ||||
|      | ||||
|  | ||||
|     def __iter__(self): | ||||
|         self._check_executed() | ||||
|         result = self.rownumber and self._rows[self.rownumber:] or self._rows | ||||
|         return iter(result) | ||||
|      | ||||
|  | ||||
|     Warning = Warning | ||||
|     Error = Error | ||||
|     InterfaceError = InterfaceError | ||||
|   | ||||
| @@ -3,7 +3,8 @@ import unittest | ||||
|  | ||||
| class PyMySQLTestCase(unittest.TestCase): | ||||
|     databases = [ | ||||
|         {"host":"localhost","user":"root","passwd":"","db":"test_pymysql"}, | ||||
|         {"host":"localhost","user":"root", | ||||
|          "passwd":"","db":"test_pymysql", "use_unicode": True}, | ||||
|         {"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}] | ||||
|  | ||||
|     def setUp(self): | ||||
|   | ||||
| @@ -15,7 +15,8 @@ class TestConversion(base.PyMySQLTestCase): | ||||
|             c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) | ||||
|             c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") | ||||
|             r = c.fetchone() | ||||
|             self.assertEqual(v[:8], r[:8]) | ||||
|             self.assertEqual("\x01", r[0]) | ||||
|             self.assertEqual(v[1:8], r[1:8]) | ||||
|             # mysql throws away microseconds so we need to check datetimes | ||||
|             # specially. additionally times are turned into timedeltas. | ||||
|             self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8]) | ||||
|   | ||||
| @@ -104,8 +104,8 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") | ||||
|         self.assertEqual('1', pymysql.converters.escape_item(1, "utf8")) | ||||
|         self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8")) | ||||
|  | ||||
|         self.assertEqual('1', pymysql.converters.escape_object(1, "utf8")) | ||||
|         self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8")) | ||||
|         self.assertEqual('1', pymysql.converters.escape_object(1)) | ||||
|         self.assertEqual('1', pymysql.converters.escape_object(1L)) | ||||
|  | ||||
|     def test_issue_15(self): | ||||
|         """ query should be expanded before perform character encoding """ | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Pete Hunt
					Pete Hunt