Fixed a potential bug in _get_server_information() regarding buffering. Fixed a bug having to do with old versions of Python and the set module. Basic Jython 2.2 compatibility now exists.

This commit is contained in:
Pete Hunt
2010-07-27 18:19:31 +00:00
parent 952168b4c0
commit d68666032e
4 changed files with 24 additions and 11 deletions

1
README
View File

@@ -29,6 +29,7 @@ Changes
pymysql.version_info attribute. pymysql.version_info attribute.
-Now runs with no warnings with the -3 command-line switch -Now runs with no warnings with the -3 command-line switch
-Added test cases for all outstanding tickets and closed most of them. -Added test cases for all outstanding tickets and closed most of them.
-Basic Jython support added
0.2 -Changed connection parameter name 'password' to 'passwd' 0.2 -Changed connection parameter name 'password' to 'passwd'
to make it more plugin replaceable for the other mysql clients. to make it more plugin replaceable for the other mysql clients.

View File

@@ -14,7 +14,10 @@ try:
frozenset frozenset
except NameError: except NameError:
from sets import ImmutableSet as frozenset from sets import ImmutableSet as frozenset
from sets import BaseSet as set try:
from sets import BaseSet as set
except ImportError:
from sets import Set as set
threadsafety = 1 threadsafety = 1
apilevel = "2.0" apilevel = "2.0"

View File

@@ -1,8 +1,6 @@
# Python implementation of the MySQL client-server protocol # Python implementation of the MySQL client-server protocol
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
# TODO: use streams instead of send() and recv()
import re import re
try: try:
@@ -127,7 +125,7 @@ def _hash_password_323(password):
add = 7L add = 7L
nr2 = 0x12345671L nr2 = 0x12345671L
for c in (ord(x) for x in password if x not in (' ', '\t')): for c in [ord(x) for x in password if x not in (' ', '\t')]:
nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF
nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
add= (add + c) & 0xFFFFFFFF add= (add + c) & 0xFFFFFFFF
@@ -211,6 +209,8 @@ class MysqlPacket(object):
def packet_number(self): return self.__packet_number def packet_number(self): return self.__packet_number
def get_all_data(self): return self.__data
def read(self, size): def read(self, size):
"""Read the first 'size' bytes in packet and advance cursor past them.""" """Read the first 'size' bytes in packet and advance cursor past them."""
result = self.peek(size) result = self.peek(size)
@@ -635,7 +635,7 @@ class Connection(object):
# if old_passwords is enabled the packet will be 1 byte long and # if old_passwords is enabled the packet will be 1 byte long and
# have the octet 254 # have the octet 254
if auth_packet.get_bytes(0,2) == chr(254): if auth_packet.is_eof_packet():
# send legacy handshake # send legacy handshake
raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table." raise NotImplementedError, "old_passwords are not supported. Check to see if mysqld was started with --old-passwords, if old-passwords=1 in a my.cnf file, or if there are some short hashes in your mysql.user table."
#data = _scramble_323(self.password, self.salt) + "\0" #data = _scramble_323(self.password, self.salt) + "\0"
@@ -663,11 +663,12 @@ class Connection(object):
def _get_server_information(self): def _get_server_information(self):
sock = self.socket sock = self.socket
i = 0 i = 0
# TODO: likely bug here because recv() might return less bytes than we need packet = MysqlPacket(sock)
data = sock.recv(BUFFER_SIZE) data = packet.get_all_data()
if DEBUG: dump_packet(data) if DEBUG: dump_packet(data)
packet_len = ord(data[i:i+1]) #packet_len = ord(data[i:i+1])
i += 4 #i += 4
self.protocol_version = ord(data[i:i+1]) self.protocol_version = ord(data[i:i+1])
i += 1 i += 1

View File

@@ -7,6 +7,14 @@ import struct
from pymysql.times import Date, Time, TimeDelta, Timestamp from pymysql.times import Date, Time, TimeDelta, Timestamp
from pymysql.constants import FIELD_TYPE from pymysql.constants import FIELD_TYPE
try:
set
except NameError:
try:
from sets import BaseSet as set
except ImportError:
from sets import Set as set
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]") ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'} '\'': '\\\'', '"': '\\"', '\\': '\\\\'}
@@ -125,7 +133,7 @@ def convert_timedelta(obj):
""" """
from math import modf from math import modf
try: try:
hours, minutes, seconds = tuple(int(x) for x in obj.split(':')) hours, minutes, seconds = tuple([int(x) for x in obj.split(':')])
tdelta = datetime.timedelta( tdelta = datetime.timedelta(
hours = int(hours), hours = int(hours),
minutes = int(minutes), minutes = int(minutes),