From d68666032e25f5e6896964a8a01ee48088146022 Mon Sep 17 00:00:00 2001 From: Pete Hunt Date: Tue, 27 Jul 2010 18:19:31 +0000 Subject: [PATCH] Fixed a potential bug in _get_server_information() regarding buffering. Fixed a bug having to do with old versions of Python and the set module. Basic Jython 2.2 compatibility now exists. --- README | 1 + pymysql/__init__.py | 5 ++++- pymysql/connections.py | 19 ++++++++++--------- pymysql/converters.py | 10 +++++++++- 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/README b/README index 231328d..fdbf399 100644 --- a/README +++ b/README @@ -29,6 +29,7 @@ Changes pymysql.version_info attribute. -Now runs with no warnings with the -3 command-line switch -Added test cases for all outstanding tickets and closed most of them. + -Basic Jython support added 0.2 -Changed connection parameter name 'password' to 'passwd' to make it more plugin replaceable for the other mysql clients. diff --git a/pymysql/__init__.py b/pymysql/__init__.py index 6381cf2..eabc3ea 100644 --- a/pymysql/__init__.py +++ b/pymysql/__init__.py @@ -14,7 +14,10 @@ try: frozenset except NameError: from sets import ImmutableSet as frozenset - from sets import BaseSet as set + try: + from sets import BaseSet as set + except ImportError: + from sets import Set as set threadsafety = 1 apilevel = "2.0" diff --git a/pymysql/connections.py b/pymysql/connections.py index 85822cd..e4aa6cb 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -1,8 +1,6 @@ # Python implementation of the MySQL client-server protocol # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol -# TODO: use streams instead of send() and recv() - import re try: @@ -127,7 +125,7 @@ def _hash_password_323(password): add = 7L nr2 = 0x12345671L - for c in (ord(x) for x in password if x not in (' ', '\t')): + for c in [ord(x) for x in password if x not in (' ', '\t')]: nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF add= (add + c) & 0xFFFFFFFF @@ -211,6 +209,8 @@ class MysqlPacket(object): def packet_number(self): return self.__packet_number + def get_all_data(self): return self.__data + def read(self, size): """Read the first 'size' bytes in packet and advance cursor past them.""" result = self.peek(size) @@ -618,7 +618,7 @@ class Connection(object): data = (struct.pack('i', self.client_flag)) + "\0\0\0\x01" + \ '\x08' + '\0'*23 + \ self.user+"\0" + _scramble(self.password, self.salt) - + if self.db: data += self.db + "\0" @@ -635,7 +635,7 @@ class Connection(object): # if old_passwords is enabled the packet will be 1 byte long and # have the octet 254 - if auth_packet.get_bytes(0,2) == chr(254): + if auth_packet.is_eof_packet(): # send legacy handshake raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table." #data = _scramble_323(self.password, self.salt) + "\0" @@ -663,11 +663,12 @@ class Connection(object): def _get_server_information(self): sock = self.socket i = 0 - # TODO: likely bug here because recv() might return less bytes than we need - data = sock.recv(BUFFER_SIZE) + packet = MysqlPacket(sock) + data = packet.get_all_data() + if DEBUG: dump_packet(data) - packet_len = ord(data[i:i+1]) - i += 4 + #packet_len = ord(data[i:i+1]) + #i += 4 self.protocol_version = ord(data[i:i+1]) i += 1 diff --git a/pymysql/converters.py b/pymysql/converters.py index feb051f..bf90a3b 100644 --- a/pymysql/converters.py +++ b/pymysql/converters.py @@ -7,6 +7,14 @@ import struct from pymysql.times import Date, Time, TimeDelta, Timestamp from pymysql.constants import FIELD_TYPE +try: + set +except NameError: + try: + from sets import BaseSet as set + except ImportError: + from sets import Set as set + ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]") ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', '\'': '\\\'', '"': '\\"', '\\': '\\\\'} @@ -125,7 +133,7 @@ def convert_timedelta(obj): """ from math import modf try: - hours, minutes, seconds = tuple(int(x) for x in obj.split(':')) + hours, minutes, seconds = tuple([int(x) for x in obj.split(':')]) tdelta = datetime.timedelta( hours = int(hours), minutes = int(minutes),