My first commit. Did a lot of refactoring, so see the updated README file

This commit is contained in:
Pete Hunt
2010-07-26 19:28:33 +00:00
parent 718be87596
commit ca4454a3e7
14 changed files with 557 additions and 161 deletions

28
README
View File

@@ -1,15 +1,35 @@
==================== ====================
pymysql Installation PyMySQL Installation
==================== ====================
.. contents:: .. contents::
.. ..
This package contains a pure python mysql client library. This package contains a pure-Python MySQL client library.
Documentation on the MySQL client/server protocol can be found here: Documentation on the MySQL client/server protocol can be found here:
http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
If you would like to run the test suite, create a ~/.my.cnf file and
a database called "test_pymysql". The goal of pymysql is to be a drop-in
replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy
and Python 3. We test for compatibility by simply changing the import
statements in the Django MySQL backend and running its unit tests as well
as running it against the MySQLdb and myconnpy unit tests.
Changes Changes
-------- --------
0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc()
-Fixed error handling to include the message from the server and support
multiple protocol versions.
-Implemented ping()
-Implemented unicode support (probably needs better testing)
-Removed DeprecationWarnings
-Ran against the MySQLdb and myconnpy unit tests to check for bugs
-Added support for client_flag, charset, sql_mode, read_default_file, and
use_unicode.
-Refactoring for some more compatibility with MySQLdb including a fake
pymysql.version_info attribute.
-Now runs with no warnings with the -3 command-line switch
-Added test cases for all outstanding tickets and closed most of them.
0.2 -Changed connection parameter name 'password' to 'passwd' 0.2 -Changed connection parameter name 'password' to 'passwd'
to make it more plugin replaceable for the other mysql clients. to make it more plugin replaceable for the other mysql clients.
-Changed pack()/unpack() calls so it runs on 64 bit OSes too. -Changed pack()/unpack() calls so it runs on 64 bit OSes too.
@@ -26,11 +46,11 @@ Requirements
* http://www.python.org/ * http://www.python.org/
* 2.5 is the primary test environment. * 2.6 is the primary test environment.
* MySQL 4.1 or higher * MySQL 4.1 or higher
* protocol41 support * protocol41 support, experimental 4.0 support
Installation Installation
------------ ------------

View File

@@ -2,8 +2,8 @@
import pymysql import pymysql
conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql') #conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql')
# conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql') conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql')
cur = conn.cursor() cur = conn.cursor()
@@ -14,7 +14,7 @@ cur.execute("SELECT Host,User FROM user")
# r = cur.fetchall() # r = cur.fetchall()
# print r # print r
# ...or... # ...or...
for r in cur: for r in cur.fetchall():
print r print r
cur.close() cur.close()

View File

@@ -1,37 +1,41 @@
VERSION = (0, 2, None) VERSION = (0, 3, None)
__version__ = VERSION # for MySQLdb compatibility
from pymysql.constants import FIELD_TYPE from pymysql.constants import FIELD_TYPE
from pymysql.converters import escape_dict, escape_sequence, escape_string from pymysql.converters import escape_dict, escape_sequence, escape_string
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \ from pymysql.err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \ DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError NotSupportedError, ProgrammingError
from pymysql.times import Date, Time, Timestamp, \ from pymysql.times import Date, Time, Timestamp, \
DateFromTicks, TimeFromTicks, TimestampFromTicks DateFromTicks, TimeFromTicks, TimestampFromTicks
from sets import ImmutableSet try:
frozenset
except NameError:
from sets import ImmutableSet as frozenset
from sets import BaseSet as set
threadsafety = 1 threadsafety = 1
apilevel = "2.0" apilevel = "2.0"
paramstyle = "format" paramstyle = "format"
class DBAPISet(ImmutableSet): class DBAPISet(frozenset):
def __ne__(self, other): def __ne__(self, other):
from sets import BaseSet if isinstance(other, set):
if isinstance(other, BaseSet):
return super(DBAPISet, self).__ne__(self, other) return super(DBAPISet, self).__ne__(self, other)
else: else:
return other not in self return other not in self
def __eq__(self, other): def __eq__(self, other):
from sets import BaseSet if isinstance(other, frozenset):
if isinstance(other, BaseSet): return frozenset.__eq__(self, other)
return super(DBAPISet, self).__eq__(self, other)
else: else:
return other in self return other in self
def __hash__(self):
return frozenset.__hash__(self)
STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
FIELD_TYPE.VAR_STRING]) FIELD_TYPE.VAR_STRING])
@@ -59,6 +63,16 @@ def get_client_info(): # for MySQLdb compatibility
connect = Connection = Connect connect = Connection = Connect
# we include a doctored version_info here for MySQLdb compatibility
version_info = (1,2,2,"final",0)
NULL = "NULL"
__version__ = get_client_info()
def thread_safe():
# Pure python, so yes we're threadsafe
return True
__all__ = [ __all__ = [
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date', 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
@@ -70,4 +84,6 @@ __all__ = [
'connections', 'constants', 'converters', 'cursors', 'debug', 'escape', 'connections', 'constants', 'converters', 'cursors', 'debug', 'escape',
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info', 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
'paramstyle', 'string_literal', 'threadsafety', 'version_info', 'paramstyle', 'string_literal', 'threadsafety', 'version_info',
"NULL","__version__",
] ]

View File

@@ -2,19 +2,28 @@
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
import re import re
import sha
try:
import hashlib
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
except ImportError:
import sha
sha_new = sha.new
import socket import socket
import struct import struct
import sys import sys
import os
import ConfigParser
from pymysql.charset import MBLENGTH from pymysql.charset import MBLENGTH
from pymysql.cursor import Cursor from pymysql.cursors import Cursor
from pymysql.constants import FIELD_TYPE from pymysql.constants import FIELD_TYPE
from pymysql.constants import SERVER_STATUS from pymysql.constants import SERVER_STATUS
from pymysql.constants.CLIENT_FLAG import * from pymysql.constants.CLIENT import *
from pymysql.constants.COMMAND import * from pymysql.constants.COMMAND import *
from pymysql.converters import escape_item, encoders, decoders from pymysql.converters import escape_item, encoders, decoders, field_decoders
from pymysql.exceptions import raise_mysql_exception, Warning, Error, \ from pymysql.err import raise_mysql_exception, Warning, Error, \
InterfaceError, DataError, DatabaseError, OperationalError, \ InterfaceError, DataError, DatabaseError, OperationalError, \
IntegrityError, InternalError, NotSupportedError, ProgrammingError IntegrityError, InternalError, NotSupportedError, ProgrammingError
@@ -57,9 +66,9 @@ def _scramble(password, message):
if password == None or len(password) == 0: if password == None or len(password) == 0:
return '\0' return '\0'
if DEBUG: print 'password=' + password if DEBUG: print 'password=' + password
stage1 = sha.new(password).digest() stage1 = sha_new(password).digest()
stage2 = sha.new(stage1).digest() stage2 = sha_new(stage1).digest()
s = sha.new() s = sha_new()
s.update(message) s.update(message)
s.update(stage2) s.update(stage2)
result = s.digest() result = s.digest()
@@ -105,7 +114,11 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
connection.messages.append(err) connection.messages.append(err)
del cursor del cursor
del connection del connection
raise errorclass, errorvalue
if not issubclass(errorclass, Error):
raise Error(errorclass, errorvalue)
else:
raise errorclass, errorvalue
class MysqlPacket(object): class MysqlPacket(object):
@@ -121,6 +134,9 @@ class MysqlPacket(object):
def __recv_packet(self, socket): def __recv_packet(self, socket):
"""Parse the packet header and read entire packet payload into buffer.""" """Parse the packet header and read entire packet payload into buffer."""
packet_header = socket.recv(4) packet_header = socket.recv(4)
while len(packet_header) < 4:
packet_header += socket.recv(4 - len(packet_header))
if DEBUG: dump_packet(packet_header) if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3] packet_length_bin = packet_header[:3]
self.__packet_number = ord(packet_header[3]) self.__packet_number = ord(packet_header[3])
@@ -303,7 +319,7 @@ class FieldDescriptorPacket(MysqlPacket):
def get_column_length(self): def get_column_length(self):
if self.type_code == FIELD_TYPE.VAR_STRING: if self.type_code == FIELD_TYPE.VAR_STRING:
mblen = MBLENGTH.get(self.charsetnr, 1) mblen = MBLENGTH.get(self.charsetnr, 1)
return self.length / mblen return self.length // mblen
return self.length return self.length
def __str__(self): def __str__(self):
@@ -316,54 +332,103 @@ class Connection(object):
"""Representation of a socket with a mysql server.""" """Representation of a socket with a mysql server."""
errorhandler = defaulterrorhandler errorhandler = defaulterrorhandler
def __init__(self, *args, **kwargs): def __init__(self, host="localhost", user=None, passwd="",
self.host = kwargs['host'] db=None, port=3306, unix_socket=None,
self.port = kwargs.get('port', 3306) charset=DEFAULT_CHARSET, sql_mode=None,
self.user = kwargs['user'] read_default_file=None, conv=decoders, use_unicode=True,
self.password = kwargs['passwd'] client_flag=0):
self.db = kwargs.get('db', None)
self.unix_socket = kwargs.get('unix_socket', None) if read_default_file:
self.charset = DEFAULT_CHARSET cfg = ConfigParser.RawConfigParser()
cfg.read(os.path.expanduser(read_default_file))
client_flag = CLIENT_CAPABILITIES
#client_flag = kwargs.get('client_flag', None) def _config(key, default):
client_flag |= CLIENT_MULTI_STATEMENTS try:
return cfg.get("client",key)
except:
return default
user = _config("user",user)
passwd = _config("password",passwd)
host = _config("host", host)
db = _config("db",db)
unix_socket = _config("socket",unix_socket)
port = _config("port", port)
charset = _config("default-character-set", charset)
self.host = host
self.port = port
self.user = user
self.password = passwd
self.db = db
self.unix_socket = unix_socket
self.use_unicode = use_unicode
self.charset = charset
client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS
if self.db: if self.db:
client_flag |= CLIENT_CONNECT_WITH_DB client_flag |= CONNECT_WITH_DB
self.client_flag = client_flag self.client_flag = client_flag
self._connect() self._connect()
charset = kwargs.get('charset', None) self.set_charset_set(charset)
self.set_chatset_set(charset)
self.messages = [] self.messages = []
self.encoders = encoders self.encoders = encoders
self.decoders = decoders self.decoders = conv
self.field_decoders = field_decoders
self._affected_rows = 0
self.host_info = "Not connected"
self.autocommit(False) self.autocommit(False)
if sql_mode is not None:
c = self.cursor()
c.execute("SET sql_mode=%s", (sql_mode,))
self.commit()
def close(self): def close(self):
send_data = struct.pack('<i',1) + COM_QUIT try:
sock = self.socket send_data = struct.pack('<i',1) + COM_QUIT
sock.send(send_data) sock = self.socket
sock.close() sock.send(send_data)
sock.close()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def autocommit(self, value): def autocommit(self, value):
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \ try:
self.escape(value)) self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
self.read_packet() self.escape(value))
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def commit(self): def commit(self):
self._execute_command(COM_QUERY, "COMMIT") try:
self.read_packet() self._execute_command(COM_QUERY, "COMMIT")
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def rollback(self): def rollback(self):
self._execute_command(COM_QUERY, "ROLLBACK") try:
self.read_packet() self._execute_command(COM_QUERY, "ROLLBACK")
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def escape(self, obj): def escape(self, obj):
return escape_item(obj) return escape_item(obj, self.charset)
def literal(self, obj):
return escape_item(obj, self.charset)
def cursor(self): def cursor(self):
return Cursor(self) return Cursor(self)
@@ -379,26 +444,53 @@ class Connection(object):
def query(self, sql): def query(self, sql):
self._execute_command(COM_QUERY, sql) self._execute_command(COM_QUERY, sql)
return self._read_query_result() self._affected_rows = self._read_query_result()
return self._affected_rows
def next_result(self): def next_result(self):
return self._read_query_result() self._affected_rows = self._read_query_result()
return self._affected_rows
def set_chatset_set(self, charset): def affected_rows(self):
sock = self.socket return self._affected_rows
if charset and self.charset != charset:
self._execute_command(COM_QUERY, "SET NAMES %s" % charset) def ping(self, reconnect=True):
self.read_packet() try:
self.charset = charset self._execute_command(COM_PING, "")
except:
if reconnect:
self._connect()
return self.ping(False)
else:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
return
pkt = self.read_packet()
return pkt.is_ok_packet()
def set_charset_set(self, charset):
try:
sock = self.socket
if charset and self.charset != charset:
self._execute_command(COM_QUERY, "SET NAMES %s" %
self.escape(charset))
self.read_packet()
self.charset = charset
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def _connect(self): def _connect(self):
if self.unix_socket and (self.host == 'localhost' or self.host == '127.0.0.1'): if self.unix_socket and (self.host == 'localhost' or self.host == '127.0.0.1'):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(self.unix_socket) sock.connect(self.unix_socket)
self.host_info = "Localhost via UNIX socket"
if DEBUG: print 'connected using unix_socket' if DEBUG: print 'connected using unix_socket'
else: else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port)) sock.connect((self.host, self.port))
self.host_info = "socket %s:%d" % (self.host, self.port)
if DEBUG: print 'connected using socket' if DEBUG: print 'connected using socket'
self.socket = sock self.socket = sock
self._get_server_information() self._get_server_information()
@@ -424,8 +516,10 @@ class Connection(object):
def _send_command(self, command, sql): def _send_command(self, command, sql):
send_data = struct.pack('<i', len(sql) + 1) + command + sql send_data = struct.pack('<i', len(sql) + 1) + command + sql
sock = self.socket sock = self.socket
sock.send(send_data) sock.send(send_data)
if DEBUG: dump_packet(send_data) if DEBUG: dump_packet(send_data)
def _execute_command(self, command, sql): def _execute_command(self, command, sql):
@@ -437,9 +531,12 @@ class Connection(object):
def _send_authentication(self): def _send_authentication(self):
sock = self.socket sock = self.socket
self.client_flag |= CLIENT_CAPABILITIES self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'): if self.server_version.startswith('5'):
self.client_flag |= CLIENT_MULTI_RESULTS self.client_flag |= MULTI_RESULTS
if self.user is None:
raise ValueError, "Did not specify a username"
data = (struct.pack('i', self.client_flag)) + "\0\0\0\x01" + \ data = (struct.pack('i', self.client_flag)) + "\0\0\0\x01" + \
'\x08' + '\0'*23 + \ '\x08' + '\0'*23 + \
@@ -458,6 +555,19 @@ class Connection(object):
auth_packet.check_error() auth_packet.check_error()
if DEBUG: auth_packet.dump() if DEBUG: auth_packet.dump()
# _mysql support
def thread_id(self):
return self.server_thread_id[0]
def character_set_name(self):
return self.charset
def get_host_info(self):
return self.host_info
def get_proto_info(self):
return self.protocol_version
def _get_server_information(self): def _get_server_information(self):
sock = self.socket sock = self.socket
i = 0 i = 0
@@ -473,7 +583,6 @@ class Connection(object):
i = server_end + 1 i = server_end + 1
self.server_thread_id = struct.unpack('h', data[i:i+2]) self.server_thread_id = struct.unpack('h', data[i:i+2])
self.thread_id = self.server_thread_id # MySQLdb compatibility
i += 4 i += 4
self.salt = data[i:i+8] self.salt = data[i:i+8]
@@ -561,12 +670,22 @@ class MySQLResult(object):
row = [] row = []
for field in self.fields: for field in self.fields:
converter = self.connection.decoders[field.type_code] if field.type_code in self.connection.decoders:
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter) converter = self.connection.decoders[field.type_code]
data = packet.read_length_coded_binary()
converted = None if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
if data != None: data = packet.read_length_coded_binary()
converted = converter(data) converted = None
if data != None:
converted = converter(data)
else:
converter = self.connection.field_decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None
if data != None:
converted = converter(self.connection, field, data)
row.append(converted) row.append(converted)
rows.append(tuple(row)) rows.append(tuple(row))

View File

@@ -0,0 +1,20 @@
LONG_PASSWORD = 1
FOUND_ROWS = 1 << 1
LONG_FLAG = 1 << 2
CONNECT_WITH_DB = 1 << 3
NO_SCHEMA = 1 << 4
COMPRESS = 1 << 5
ODBC = 1 << 6
LOCAL_FILES = 1 << 7
IGNORE_SPACE = 1 << 8
PROTOCOL_41 = 1 << 9
INTERACTIVE = 1 << 10
SSL = 1 << 11
IGNORE_SIGPIPE = 1 << 12
TRANSACTIONS = 1 << 13
SECURE_CONNECTION = 1 << 15
MULTI_STATEMENTS = 1 << 16
MULTI_RESULTS = 1 << 17
CAPABILITIES = LONG_PASSWORD|LONG_FLAG|TRANSACTIONS| \
PROTOCOL_41|SECURE_CONNECTION

View File

@@ -1,20 +0,0 @@
CLIENT_LONG_PASSWORD = 1
CLIENT_FOUND_ROWS = 1 << 1
CLIENT_LONG_FLAG = 1 << 2
CLIENT_CONNECT_WITH_DB = 1 << 3
CLIENT_NO_SCHEMA = 1 << 4
CLIENT_COMPRESS = 1 << 5
CLIENT_ODBC = 1 << 6
CLIENT_LOCAL_FILES = 1 << 7
CLIENT_IGNORE_SPACE = 1 << 8
CLIENT_PROTOCOL_41 = 1 << 9
CLIENT_INTERACTIVE = 1 << 10
CLIENT_SSL = 1 << 11
CLIENT_IGNORE_SIGPIPE = 1 << 12
CLIENT_TRANSACTIONS = 1 << 13
CLIENT_SECURE_CONNECTION = 1 << 15
CLIENT_MULTI_STATEMENTS = 1 << 16
CLIENT_MULTI_RESULTS = 1 << 17
CLIENT_CAPABILITIES = CLIENT_LONG_PASSWORD|CLIENT_LONG_FLAG|CLIENT_TRANSACTIONS| \
CLIENT_PROTOCOL_41|CLIENT_SECURE_CONNECTION

15
pymysql/constants/FLAG.py Normal file
View File

@@ -0,0 +1,15 @@
NOT_NULL = 1
PRI_KEY = 2
UNIQUE_KEY = 4
MULTIPLE_KEY = 8
BLOB = 16
UNSIGNED = 32
ZEROFILL = 64
BINARY = 128
ENUM = 256
AUTO_INCREMENT = 512
TIMESTAMP = 1024
SET = 2048
PART_KEY = 16384
GROUP = 32767
UNIQUE = 65536

View File

@@ -1,68 +1,81 @@
import re import re
import datetime import datetime
import time
import array import array
import struct
from pymysql.times import Date, Time, TimeDelta, Timestamp from pymysql.times import Date, Time, TimeDelta, Timestamp
from pymysql.constants import FIELD_TYPE from pymysql.constants import FIELD_TYPE
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]", re.IGNORECASE) ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val): def escape_item(val, charset):
encorder = encoders[type(val)] encoder = encoders[type(val)]
return encorder(val) return encoder(val, charset)
def escape_dict(val): def escape_dict(val, charset):
n = {} n = {}
for k, v in val.items(): for k, v in val.items():
quoted = escape_item(v) quoted = escape_item(v, charset)
n[k] = quoted n[k] = quoted
return n return n
def escape_sequence(val): def escape_sequence(val, charset):
n = [] n = []
for item in val: for item in val:
quoted = escape_item(item) quoted = escape_item(item, charset)
n.append(quoted) n.append(quoted)
return tuple(n) return tuple(n)
def escape_bool(value): def escape_bool(value, charset):
return str(int(value)) return str(int(value)).encode(charset)
def escape_object(value): def escape_object(value, charset):
return str(val) return str(value).encode(charset)
def escape_float(value): escape_int = escape_long = escape_object
return '%.15g' % value
def escape_sequence(value): def escape_float(value, charset):
return value return ('%.15g' % value).encode(charset)
def escape_string(value): def escape_string(value, charset):
r = ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value))
if not charset is None:
r = r.encode(charset)
return r
def rep(m): def escape_unicode(value, charset):
n = m.group(0) # pass None as the charset because we already encode it
if n == "\0": return escape_string(value.encode(charset), None)
return "\\0"
elif n == "\n":
return "\\n"
elif n == "\r":
return "\\r"
elif n == "\032":
return "\\Z"
else:
return "\\"+n
s = re.sub(ESCAPE_REGEX, rep, value)
return s
def escape_None(value, charset):
return 'NULL'.encode(charset)
def escape_timedelta(obj): def escape_timedelta(obj, charset):
seconds = int(obj.seconds) % 60 seconds = int(obj.seconds) % 60
minutes = int(obj.seconds / 60) % 60 minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds / 3600) % 24 hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
return '%d %02d:%02d:%02d' % (obj.days, hours, minutes, seconds) return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
def escape_datetime(obj): def escape_time(obj, charset):
return obj.strftime("%Y-%m-%d %H:%M:%S") s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
int(obj.second))
if obj.microsecond:
s += ".%f" % obj.microsecond
return escape_string(s, charset)
def escape_datetime(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
def escape_date(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d"), charset)
def escape_struct_time(obj, charset):
return escape_datetime(datetime.datetime(*obj[:6]), charset)
def convert_datetime(obj): def convert_datetime(obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object: """Returns a DATETIME or TIMESTAMP column value as a datetime object:
@@ -85,13 +98,13 @@ def convert_datetime(obj):
elif 'T' in obj: elif 'T' in obj:
sep = 'T' sep = 'T'
else: else:
return date_or_None(obj) return convert_date(obj)
try: try:
ymd, hms = obj.split(sep, 1) ymd, hms = obj.split(sep, 1)
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ]) return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
except ValueError: except ValueError:
return date_or_None(obj) return convert_date(obj)
def convert_timedelta(obj): def convert_timedelta(obj):
"""Returns a TIME column as a timedelta object: """Returns a TIME column as a timedelta object:
@@ -112,17 +125,14 @@ def convert_timedelta(obj):
""" """
from math import modf from math import modf
try: try:
hours, minutes, seconds = obj.split(':') hours, minutes, seconds = tuple(int(x) for x in obj.split(':'))
tdelta = datetime.timedelta( tdelta = datetime.timedelta(
hours = int(hours), hours = int(hours),
minutes = int(minutes), minutes = int(minutes),
seconds = int(seconds), seconds = int(seconds),
microseconds = int(modf(float(seconds))[0]*1000000), microseconds = int(modf(float(seconds))[0]*1000000),
) )
if hours < 0: return tdelta
return -tdelta
else:
return tdelta
except ValueError: except ValueError:
return None return None
@@ -197,7 +207,7 @@ def convert_mysql_timestamp(timestamp):
""" """
if timestamp[4] == '-': if timestamp[4] == '-':
return datetime_or_None(timestamp) return convert_datetime(timestamp)
timestamp += "0"*(14-len(timestamp)) # padding timestamp += "0"*(14-len(timestamp)) # padding
year, month, day, hour, minute, second = \ year, month, day, hour, minute, second = \
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
@@ -207,21 +217,35 @@ def convert_mysql_timestamp(timestamp):
except ValueError: except ValueError:
return None return None
def convert_set(s):
# TODO: this may not be correct
return set(s.split(","))
def convert_bit(b):
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
return struct.unpack(">Q", b)[0]
encoders = { encoders = {
bool: escape_bool, bool: escape_bool,
int: escape_object, int: escape_int,
long: escape_object, long: escape_long,
float: escape_float, float: escape_float,
str: escape_string, str: escape_string,
unicode: escape_unicode,
tuple: escape_sequence, tuple: escape_sequence,
list:escape_sequence, list:escape_sequence,
set:escape_sequence,
dict:escape_dict, dict:escape_dict,
type(None):escape_None,
datetime.date: escape_date,
datetime.datetime : escape_datetime, datetime.datetime : escape_datetime,
datetime.timedelta : escape_timedelta datetime.timedelta : escape_timedelta,
datetime.time : escape_time,
time.struct_time : escape_struct_time,
} }
decoders = { decoders = {
FIELD_TYPE.BIT: convert_bit,
FIELD_TYPE.TINY: int, FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int, FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: long, FIELD_TYPE.LONG: long,
@@ -236,17 +260,37 @@ decoders = {
FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.DATETIME: convert_datetime,
FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.TIME: convert_timedelta,
FIELD_TYPE.DATE: convert_date, FIELD_TYPE.DATE: convert_date,
FIELD_TYPE.BLOB: str, FIELD_TYPE.SET: convert_set,
FIELD_TYPE.STRING: str, #FIELD_TYPE.BLOB: str,
FIELD_TYPE.VAR_STRING: str, #FIELD_TYPE.STRING: str,
FIELD_TYPE.VARCHAR: str #FIELD_TYPE.VAR_STRING: str,
#FIELD_TYPE.VARCHAR: str
} }
conversions = decoders # for MySQLdb compatibility conversions = decoders # for MySQLdb compatibility
def decode_characters(connection, field, data):
if field.charsetnr == 63 or not connection.use_unicode:
# binary data, leave it alone
return data
return data.decode(connection.charset)
# These take a field instance rather than just the data.
field_decoders = {
FIELD_TYPE.BLOB: decode_characters,
FIELD_TYPE.STRING: decode_characters,
FIELD_TYPE.VAR_STRING: decode_characters,
FIELD_TYPE.VARCHAR: decode_characters,
}
try: try:
# python version > 2.3 # python version > 2.3
from decimal import Decimal from decimal import Decimal
decoders[FIELD_TYPE.DECIMAL] = Decimal decoders[FIELD_TYPE.DECIMAL] = Decimal
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal decoders[FIELD_TYPE.NEWDECIMAL] = Decimal
def escape_decimal(obj):
return str(obj)
encoders[Decimal] = escape_decimal
except ImportError: except ImportError:
pass pass

View File

@@ -1,9 +1,11 @@
import struct import struct
import re
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \ from pymysql.err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \ DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError NotSupportedError, ProgrammingError
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
class Cursor(object): class Cursor(object):
@@ -16,7 +18,7 @@ class Cursor(object):
self.arraysize = 1 self.arraysize = 1
self._executed = None self._executed = None
self.messages = [] self.messages = []
self.errorhandler =connection.errorhandler self.errorhandler = connection.errorhandler
self._has_next = None self._has_next = None
self._rows = () self._rows = ()
@@ -68,11 +70,15 @@ class Cursor(object):
charset = conn.charset charset = conn.charset
del self.messages[:] del self.messages[:]
# this ordering is good because conn.escape() returns
# an encoded string.
if isinstance(query, unicode): if isinstance(query, unicode):
query = query.encode(charset) query = query.encode(charset)
if args is not None: if args is not None:
query = query % conn.escape(args) query = query % conn.escape(args)
result = 0
try: try:
result = self._query(query) result = self._query(query)
except: except:
@@ -98,7 +104,13 @@ class Cursor(object):
def callproc(self, procname, args=()): def callproc(self, procname, args=()):
self.errorhandler(self, NotSupportedError, "not supported") #self.errorhandler(self, NotSupportedError, "not supported")
if not isinstance(args, tuple):
args = (args,)
argstr = ("%s," * len(args))[:-1]
return self.execute("CALL `%s`(%s)" % (procname, argstr), args)
def fetchone(self): def fetchone(self):
self._check_executed() self._check_executed()

View File

@@ -108,22 +108,24 @@ del StandardError, _map_error, ER
def _get_error_info(data): def _get_error_info(data):
errno = struct.unpack('h', data[5:7])[0] errno = struct.unpack('<h', data[1:3])[0]
sqlstate = struct.unpack('5s', data[8:8+5])[0] if data[3] == "#":
start = 13 # version 4.1
end = data.find('\0', start) sqlstate = data[4:9].decode("utf8")
errorvalue = data[start:end] errorvalue = data[9:].decode("utf8")
return (errno, sqlstate, errorvalue) return (errno, sqlstate, errorvalue)
else:
# version 4.0
return (errno, None, data[3:].decode("utf8"))
def _check_mysql_exception(errinfo): def _check_mysql_exception(errinfo):
errno, sqlstate, errorvalue = errinfo errno, sqlstate, errorvalue = errinfo
errorclass = error_map.get(errno, None) errorclass = error_map.get(errno, None)
if errorclass: if errorclass:
raise errorclass, errorvalue raise errorclass, (errno,errorvalue)
"""
TODO not found errno # couldn't find the right error number
""" raise InternalError, (errno, errorvalue)
raise InternalError, ""
def raise_mysql_exception(data): def raise_mysql_exception(data):
errinfo = _get_error_info(data) errinfo = _get_error_info(data)

View File

@@ -0,0 +1,5 @@
from test_issues import *
if __name__ == "__main__":
import unittest
unittest.main()

18
pymysql/tests/base.py Normal file
View File

@@ -0,0 +1,18 @@
import pymysql
import unittest
class PyMySqlTestCase(unittest.TestCase):
databases = [
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self):
self.connections = []
for params in self.databases:
self.connections.append(pymysql.connect(**params))
def tearDown(self):
for connection in self.connections:
connection.close()

View File

@@ -0,0 +1,144 @@
import pymysql
import base
import imp
import datetime
class TestOldIssues(base.PyMySqlTestCase):
def test_issue_3(self):
""" undefined methods datetime_or_None, date_or_None """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
try:
c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
c.execute("select d from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select t from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select dt from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select ts from issue3")
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
finally:
c.execute("drop table issue3")
def test_issue_4(self):
""" can't retrieve TIMESTAMP fields """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue4 (ts timestamp)")
try:
c.execute("insert into issue4 (ts) values (now())")
c.execute("select ts from issue4")
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
finally:
c.execute("drop table issue4")
def test_issue_6(self):
""" exception: TypeError: ord() expected a character, but string of length 0 found """
conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql")
c = conn.cursor()
c.execute("select * from user")
conn.close()
def test_issue_8(self):
""" Primary Key and Index error when selecting data """
conn = self.connections[0]
c = conn.cursor()
c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
try:
self.assertEqual(0, c.execute("SELECT * FROM test"))
c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
self.assertEqual(0, c.execute("SELECT * FROM test"))
finally:
c.execute("drop table test")
def test_issue_9(self):
""" sets DeprecationWarning in Python 2.6 """
try:
imp.reload(pymysql)
except DeprecationWarning:
self.fail()
def test_issue_10(self):
""" Allocate a variable to return when the exception handler is permissive """
conn = self.connections[0]
conn.errorhandler = lambda cursor, errorclass, errorvalue: None
cur = conn.cursor()
cur.execute( "create table t( n int )" )
cur.execute( "create table t( n int )" )
def test_issue_13(self):
""" can't handle large result fields """
conn = self.connections[0]
cur = conn.cursor()
cur.execute("create table issue13 (t text)")
try:
# ticket says 18k
size = 18*1024
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
cur.execute("select t from issue13")
# use assert_ so that obscenely huge error messages don't print
r = cur.fetchone()[0]
self.assert_("x" * size == r)
finally:
cur.execute("drop table issue13")
def test_issue_14(self):
""" typo in converters.py """
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8"))
def test_issue_15(self):
""" query should be expanded before perform character encoding """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue15 (t varchar(32))")
try:
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc'))
c.execute("select t from issue15")
self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
finally:
c.execute("drop table issue15")
def test_issue_16(self):
""" Patch for string and tuple escaping """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
try:
c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
c.execute("select email from issue16 where name=%s", ("pete",))
self.assertEqual("floydophone", c.fetchone()[0])
finally:
c.execute("drop table issue16")
def test_issue_17(self):
""" could not connect mysql use passwod """
conn = self.connections[0]
host = self.databases[0]["host"]
db = self.databases[0]["db"]
c = conn.cursor()
# grant access to a table to a user with a password
c.execute("create table issue17 (x varchar(32) primary key)")
try:
c.execute("insert into issue17 (x) values ('hello, world!')")
c.execute("grant all privileges on issue17 to issue17user identified by '1234'")
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
c2 = conn2.cursor()
c2.execute("select x from issue17")
self.assertEqual("hello, world!", c2.fetchone()[0])
finally:
c.execute("drop table issue17")
if __name__ == "__main__":
import unittest
unittest.main()

View File

@@ -1,5 +1,6 @@
from setuptools import setup #from setuptools import setup
from distutils.core import setup
version_tuple = __import__('pymysql').VERSION version_tuple = __import__('pymysql').VERSION
@@ -17,5 +18,5 @@ setup(
maintainer = 'David.Story', maintainer = 'David.Story',
maintainer_email = 'iDavidStory@gmail.com', maintainer_email = 'iDavidStory@gmail.com',
description = 'Pure Python MySQL Driver ', description = 'Pure Python MySQL Driver ',
packages = ['pymysql', 'pymysql.constants'] packages = ['pymysql', 'pymysql.constants', 'pymysql.tests']
) )