My first commit. Did a lot of refactoring, so see the updated README file

This commit is contained in:
Pete Hunt
2010-07-26 19:28:33 +00:00
parent 718be87596
commit ca4454a3e7
14 changed files with 557 additions and 161 deletions

28
README
View File

@@ -1,15 +1,35 @@
====================
pymysql Installation
PyMySQL Installation
====================
.. contents::
..
This package contains a pure python mysql client library.
This package contains a pure-Python MySQL client library.
Documentation on the MySQL client/server protocol can be found here:
http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
If you would like to run the test suite, create a ~/.my.cnf file and
a database called "test_pymysql". The goal of pymysql is to be a drop-in
replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy
and Python 3. We test for compatibility by simply changing the import
statements in the Django MySQL backend and running its unit tests as well
as running it against the MySQLdb and myconnpy unit tests.
Changes
--------
0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc()
-Fixed error handling to include the message from the server and support
multiple protocol versions.
-Implemented ping()
-Implemented unicode support (probably needs better testing)
-Removed DeprecationWarnings
-Ran against the MySQLdb and myconnpy unit tests to check for bugs
-Added support for client_flag, charset, sql_mode, read_default_file, and
use_unicode.
-Refactoring for some more compatibility with MySQLdb including a fake
pymysql.version_info attribute.
-Now runs with no warnings with the -3 command-line switch
-Added test cases for all outstanding tickets and closed most of them.
0.2 -Changed connection parameter name 'password' to 'passwd'
to make it more plugin replaceable for the other mysql clients.
-Changed pack()/unpack() calls so it runs on 64 bit OSes too.
@@ -26,11 +46,11 @@ Requirements
* http://www.python.org/
* 2.5 is the primary test environment.
* 2.6 is the primary test environment.
* MySQL 4.1 or higher
* protocol41 support
* protocol41 support, experimental 4.0 support
Installation
------------

View File

@@ -2,8 +2,8 @@
import pymysql
conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql')
# conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql')
#conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql')
conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql')
cur = conn.cursor()
@@ -14,7 +14,7 @@ cur.execute("SELECT Host,User FROM user")
# r = cur.fetchall()
# print r
# ...or...
for r in cur:
for r in cur.fetchall():
print r
cur.close()

View File

@@ -1,37 +1,41 @@
VERSION = (0, 2, None)
__version__ = VERSION # for MySQLdb compatibility
VERSION = (0, 3, None)
from pymysql.constants import FIELD_TYPE
from pymysql.converters import escape_dict, escape_sequence, escape_string
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \
from pymysql.err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError
from pymysql.times import Date, Time, Timestamp, \
DateFromTicks, TimeFromTicks, TimestampFromTicks
from sets import ImmutableSet
try:
frozenset
except NameError:
from sets import ImmutableSet as frozenset
from sets import BaseSet as set
threadsafety = 1
apilevel = "2.0"
paramstyle = "format"
class DBAPISet(ImmutableSet):
class DBAPISet(frozenset):
def __ne__(self, other):
from sets import BaseSet
if isinstance(other, BaseSet):
if isinstance(other, set):
return super(DBAPISet, self).__ne__(self, other)
else:
return other not in self
def __eq__(self, other):
from sets import BaseSet
if isinstance(other, BaseSet):
return super(DBAPISet, self).__eq__(self, other)
if isinstance(other, frozenset):
return frozenset.__eq__(self, other)
else:
return other in self
def __hash__(self):
return frozenset.__hash__(self)
STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
FIELD_TYPE.VAR_STRING])
@@ -59,6 +63,16 @@ def get_client_info(): # for MySQLdb compatibility
connect = Connection = Connect
# we include a doctored version_info here for MySQLdb compatibility
version_info = (1,2,2,"final",0)
NULL = "NULL"
__version__ = get_client_info()
def thread_safe():
# Pure python, so yes we're threadsafe
return True
__all__ = [
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
@@ -70,4 +84,6 @@ __all__ = [
'connections', 'constants', 'converters', 'cursors', 'debug', 'escape',
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
'paramstyle', 'string_literal', 'threadsafety', 'version_info',
"NULL","__version__",
]

View File

@@ -2,19 +2,28 @@
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
import re
import sha
try:
import hashlib
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
except ImportError:
import sha
sha_new = sha.new
import socket
import struct
import sys
import os
import ConfigParser
from pymysql.charset import MBLENGTH
from pymysql.cursor import Cursor
from pymysql.cursors import Cursor
from pymysql.constants import FIELD_TYPE
from pymysql.constants import SERVER_STATUS
from pymysql.constants.CLIENT_FLAG import *
from pymysql.constants.CLIENT import *
from pymysql.constants.COMMAND import *
from pymysql.converters import escape_item, encoders, decoders
from pymysql.exceptions import raise_mysql_exception, Warning, Error, \
from pymysql.converters import escape_item, encoders, decoders, field_decoders
from pymysql.err import raise_mysql_exception, Warning, Error, \
InterfaceError, DataError, DatabaseError, OperationalError, \
IntegrityError, InternalError, NotSupportedError, ProgrammingError
@@ -57,9 +66,9 @@ def _scramble(password, message):
if password == None or len(password) == 0:
return '\0'
if DEBUG: print 'password=' + password
stage1 = sha.new(password).digest()
stage2 = sha.new(stage1).digest()
s = sha.new()
stage1 = sha_new(password).digest()
stage2 = sha_new(stage1).digest()
s = sha_new()
s.update(message)
s.update(stage2)
result = s.digest()
@@ -105,7 +114,11 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
connection.messages.append(err)
del cursor
del connection
raise errorclass, errorvalue
if not issubclass(errorclass, Error):
raise Error(errorclass, errorvalue)
else:
raise errorclass, errorvalue
class MysqlPacket(object):
@@ -121,6 +134,9 @@ class MysqlPacket(object):
def __recv_packet(self, socket):
"""Parse the packet header and read entire packet payload into buffer."""
packet_header = socket.recv(4)
while len(packet_header) < 4:
packet_header += socket.recv(4 - len(packet_header))
if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3]
self.__packet_number = ord(packet_header[3])
@@ -303,7 +319,7 @@ class FieldDescriptorPacket(MysqlPacket):
def get_column_length(self):
if self.type_code == FIELD_TYPE.VAR_STRING:
mblen = MBLENGTH.get(self.charsetnr, 1)
return self.length / mblen
return self.length // mblen
return self.length
def __str__(self):
@@ -316,54 +332,103 @@ class Connection(object):
"""Representation of a socket with a mysql server."""
errorhandler = defaulterrorhandler
def __init__(self, *args, **kwargs):
self.host = kwargs['host']
self.port = kwargs.get('port', 3306)
self.user = kwargs['user']
self.password = kwargs['passwd']
self.db = kwargs.get('db', None)
self.unix_socket = kwargs.get('unix_socket', None)
self.charset = DEFAULT_CHARSET
client_flag = CLIENT_CAPABILITIES
#client_flag = kwargs.get('client_flag', None)
client_flag |= CLIENT_MULTI_STATEMENTS
def __init__(self, host="localhost", user=None, passwd="",
db=None, port=3306, unix_socket=None,
charset=DEFAULT_CHARSET, sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=True,
client_flag=0):
if read_default_file:
cfg = ConfigParser.RawConfigParser()
cfg.read(os.path.expanduser(read_default_file))
def _config(key, default):
try:
return cfg.get("client",key)
except:
return default
user = _config("user",user)
passwd = _config("password",passwd)
host = _config("host", host)
db = _config("db",db)
unix_socket = _config("socket",unix_socket)
port = _config("port", port)
charset = _config("default-character-set", charset)
self.host = host
self.port = port
self.user = user
self.password = passwd
self.db = db
self.unix_socket = unix_socket
self.use_unicode = use_unicode
self.charset = charset
client_flag |= CAPABILITIES
client_flag |= MULTI_STATEMENTS
if self.db:
client_flag |= CLIENT_CONNECT_WITH_DB
client_flag |= CONNECT_WITH_DB
self.client_flag = client_flag
self._connect()
charset = kwargs.get('charset', None)
self.set_chatset_set(charset)
self.set_charset_set(charset)
self.messages = []
self.encoders = encoders
self.decoders = decoders
self.decoders = conv
self.field_decoders = field_decoders
self._affected_rows = 0
self.host_info = "Not connected"
self.autocommit(False)
if sql_mode is not None:
c = self.cursor()
c.execute("SET sql_mode=%s", (sql_mode,))
self.commit()
def close(self):
send_data = struct.pack('<i',1) + COM_QUIT
sock = self.socket
sock.send(send_data)
sock.close()
try:
send_data = struct.pack('<i',1) + COM_QUIT
sock = self.socket
sock.send(send_data)
sock.close()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def autocommit(self, value):
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
self.escape(value))
self.read_packet()
try:
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
self.escape(value))
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def commit(self):
self._execute_command(COM_QUERY, "COMMIT")
self.read_packet()
try:
self._execute_command(COM_QUERY, "COMMIT")
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def rollback(self):
self._execute_command(COM_QUERY, "ROLLBACK")
self.read_packet()
try:
self._execute_command(COM_QUERY, "ROLLBACK")
self.read_packet()
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def escape(self, obj):
return escape_item(obj)
return escape_item(obj, self.charset)
def literal(self, obj):
return escape_item(obj, self.charset)
def cursor(self):
return Cursor(self)
@@ -379,26 +444,53 @@ class Connection(object):
def query(self, sql):
self._execute_command(COM_QUERY, sql)
return self._read_query_result()
self._affected_rows = self._read_query_result()
return self._affected_rows
def next_result(self):
return self._read_query_result()
self._affected_rows = self._read_query_result()
return self._affected_rows
def set_chatset_set(self, charset):
sock = self.socket
if charset and self.charset != charset:
self._execute_command(COM_QUERY, "SET NAMES %s" % charset)
self.read_packet()
self.charset = charset
def affected_rows(self):
return self._affected_rows
def ping(self, reconnect=True):
try:
self._execute_command(COM_PING, "")
except:
if reconnect:
self._connect()
return self.ping(False)
else:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
return
pkt = self.read_packet()
return pkt.is_ok_packet()
def set_charset_set(self, charset):
try:
sock = self.socket
if charset and self.charset != charset:
self._execute_command(COM_QUERY, "SET NAMES %s" %
self.escape(charset))
self.read_packet()
self.charset = charset
except:
exc,value,tb = sys.exc_info()
self.errorhandler(None, exc, value)
def _connect(self):
if self.unix_socket and (self.host == 'localhost' or self.host == '127.0.0.1'):
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.connect(self.unix_socket)
self.host_info = "Localhost via UNIX socket"
if DEBUG: print 'connected using unix_socket'
else:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect((self.host, self.port))
self.host_info = "socket %s:%d" % (self.host, self.port)
if DEBUG: print 'connected using socket'
self.socket = sock
self._get_server_information()
@@ -424,8 +516,10 @@ class Connection(object):
def _send_command(self, command, sql):
send_data = struct.pack('<i', len(sql) + 1) + command + sql
sock = self.socket
sock.send(send_data)
if DEBUG: dump_packet(send_data)
def _execute_command(self, command, sql):
@@ -437,9 +531,12 @@ class Connection(object):
def _send_authentication(self):
sock = self.socket
self.client_flag |= CLIENT_CAPABILITIES
self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'):
self.client_flag |= CLIENT_MULTI_RESULTS
self.client_flag |= MULTI_RESULTS
if self.user is None:
raise ValueError, "Did not specify a username"
data = (struct.pack('i', self.client_flag)) + "\0\0\0\x01" + \
'\x08' + '\0'*23 + \
@@ -458,6 +555,19 @@ class Connection(object):
auth_packet.check_error()
if DEBUG: auth_packet.dump()
# _mysql support
def thread_id(self):
return self.server_thread_id[0]
def character_set_name(self):
return self.charset
def get_host_info(self):
return self.host_info
def get_proto_info(self):
return self.protocol_version
def _get_server_information(self):
sock = self.socket
i = 0
@@ -473,7 +583,6 @@ class Connection(object):
i = server_end + 1
self.server_thread_id = struct.unpack('h', data[i:i+2])
self.thread_id = self.server_thread_id # MySQLdb compatibility
i += 4
self.salt = data[i:i+8]
@@ -561,12 +670,22 @@ class MySQLResult(object):
row = []
for field in self.fields:
converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None
if data != None:
converted = converter(data)
if field.type_code in self.connection.decoders:
converter = self.connection.decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None
if data != None:
converted = converter(data)
else:
converter = self.connection.field_decoders[field.type_code]
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
data = packet.read_length_coded_binary()
converted = None
if data != None:
converted = converter(self.connection, field, data)
row.append(converted)
rows.append(tuple(row))

View File

@@ -0,0 +1,20 @@
LONG_PASSWORD = 1
FOUND_ROWS = 1 << 1
LONG_FLAG = 1 << 2
CONNECT_WITH_DB = 1 << 3
NO_SCHEMA = 1 << 4
COMPRESS = 1 << 5
ODBC = 1 << 6
LOCAL_FILES = 1 << 7
IGNORE_SPACE = 1 << 8
PROTOCOL_41 = 1 << 9
INTERACTIVE = 1 << 10
SSL = 1 << 11
IGNORE_SIGPIPE = 1 << 12
TRANSACTIONS = 1 << 13
SECURE_CONNECTION = 1 << 15
MULTI_STATEMENTS = 1 << 16
MULTI_RESULTS = 1 << 17
CAPABILITIES = LONG_PASSWORD|LONG_FLAG|TRANSACTIONS| \
PROTOCOL_41|SECURE_CONNECTION

View File

@@ -1,20 +0,0 @@
CLIENT_LONG_PASSWORD = 1
CLIENT_FOUND_ROWS = 1 << 1
CLIENT_LONG_FLAG = 1 << 2
CLIENT_CONNECT_WITH_DB = 1 << 3
CLIENT_NO_SCHEMA = 1 << 4
CLIENT_COMPRESS = 1 << 5
CLIENT_ODBC = 1 << 6
CLIENT_LOCAL_FILES = 1 << 7
CLIENT_IGNORE_SPACE = 1 << 8
CLIENT_PROTOCOL_41 = 1 << 9
CLIENT_INTERACTIVE = 1 << 10
CLIENT_SSL = 1 << 11
CLIENT_IGNORE_SIGPIPE = 1 << 12
CLIENT_TRANSACTIONS = 1 << 13
CLIENT_SECURE_CONNECTION = 1 << 15
CLIENT_MULTI_STATEMENTS = 1 << 16
CLIENT_MULTI_RESULTS = 1 << 17
CLIENT_CAPABILITIES = CLIENT_LONG_PASSWORD|CLIENT_LONG_FLAG|CLIENT_TRANSACTIONS| \
CLIENT_PROTOCOL_41|CLIENT_SECURE_CONNECTION

15
pymysql/constants/FLAG.py Normal file
View File

@@ -0,0 +1,15 @@
NOT_NULL = 1
PRI_KEY = 2
UNIQUE_KEY = 4
MULTIPLE_KEY = 8
BLOB = 16
UNSIGNED = 32
ZEROFILL = 64
BINARY = 128
ENUM = 256
AUTO_INCREMENT = 512
TIMESTAMP = 1024
SET = 2048
PART_KEY = 16384
GROUP = 32767
UNIQUE = 65536

View File

@@ -1,68 +1,81 @@
import re
import datetime
import time
import array
import struct
from pymysql.times import Date, Time, TimeDelta, Timestamp
from pymysql.constants import FIELD_TYPE
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]", re.IGNORECASE)
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val):
encorder = encoders[type(val)]
return encorder(val)
def escape_item(val, charset):
encoder = encoders[type(val)]
return encoder(val, charset)
def escape_dict(val):
def escape_dict(val, charset):
n = {}
for k, v in val.items():
quoted = escape_item(v)
quoted = escape_item(v, charset)
n[k] = quoted
return n
def escape_sequence(val):
def escape_sequence(val, charset):
n = []
for item in val:
quoted = escape_item(item)
quoted = escape_item(item, charset)
n.append(quoted)
return tuple(n)
def escape_bool(value):
return str(int(value))
def escape_bool(value, charset):
return str(int(value)).encode(charset)
def escape_object(value):
return str(val)
def escape_object(value, charset):
return str(value).encode(charset)
def escape_float(value):
return '%.15g' % value
escape_int = escape_long = escape_object
def escape_sequence(value):
return value
def escape_float(value, charset):
return ('%.15g' % value).encode(charset)
def escape_string(value):
def escape_string(value, charset):
r = ("'%s'" % ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value))
if not charset is None:
r = r.encode(charset)
return r
def rep(m):
n = m.group(0)
if n == "\0":
return "\\0"
elif n == "\n":
return "\\n"
elif n == "\r":
return "\\r"
elif n == "\032":
return "\\Z"
else:
return "\\"+n
s = re.sub(ESCAPE_REGEX, rep, value)
return s
def escape_unicode(value, charset):
# pass None as the charset because we already encode it
return escape_string(value.encode(charset), None)
def escape_None(value, charset):
return 'NULL'.encode(charset)
def escape_timedelta(obj):
def escape_timedelta(obj, charset):
seconds = int(obj.seconds) % 60
minutes = int(obj.seconds / 60) % 60
hours = int(obj.seconds / 3600) % 24
return '%d %02d:%02d:%02d' % (obj.days, hours, minutes, seconds)
minutes = int(obj.seconds // 60) % 60
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
def escape_datetime(obj):
return obj.strftime("%Y-%m-%d %H:%M:%S")
def escape_time(obj, charset):
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
int(obj.second))
if obj.microsecond:
s += ".%f" % obj.microsecond
return escape_string(s, charset)
def escape_datetime(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
def escape_date(obj, charset):
return escape_string(obj.strftime("%Y-%m-%d"), charset)
def escape_struct_time(obj, charset):
return escape_datetime(datetime.datetime(*obj[:6]), charset)
def convert_datetime(obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
@@ -85,13 +98,13 @@ def convert_datetime(obj):
elif 'T' in obj:
sep = 'T'
else:
return date_or_None(obj)
return convert_date(obj)
try:
ymd, hms = obj.split(sep, 1)
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
except ValueError:
return date_or_None(obj)
return convert_date(obj)
def convert_timedelta(obj):
"""Returns a TIME column as a timedelta object:
@@ -112,17 +125,14 @@ def convert_timedelta(obj):
"""
from math import modf
try:
hours, minutes, seconds = obj.split(':')
hours, minutes, seconds = tuple(int(x) for x in obj.split(':'))
tdelta = datetime.timedelta(
hours = int(hours),
minutes = int(minutes),
seconds = int(seconds),
microseconds = int(modf(float(seconds))[0]*1000000),
)
if hours < 0:
return -tdelta
else:
return tdelta
return tdelta
except ValueError:
return None
@@ -197,7 +207,7 @@ def convert_mysql_timestamp(timestamp):
"""
if timestamp[4] == '-':
return datetime_or_None(timestamp)
return convert_datetime(timestamp)
timestamp += "0"*(14-len(timestamp)) # padding
year, month, day, hour, minute, second = \
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
@@ -207,21 +217,35 @@ def convert_mysql_timestamp(timestamp):
except ValueError:
return None
def convert_set(s):
# TODO: this may not be correct
return set(s.split(","))
def convert_bit(b):
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
return struct.unpack(">Q", b)[0]
encoders = {
bool: escape_bool,
int: escape_object,
long: escape_object,
int: escape_int,
long: escape_long,
float: escape_float,
str: escape_string,
unicode: escape_unicode,
tuple: escape_sequence,
list:escape_sequence,
set:escape_sequence,
dict:escape_dict,
type(None):escape_None,
datetime.date: escape_date,
datetime.datetime : escape_datetime,
datetime.timedelta : escape_timedelta
datetime.timedelta : escape_timedelta,
datetime.time : escape_time,
time.struct_time : escape_struct_time,
}
decoders = {
FIELD_TYPE.BIT: convert_bit,
FIELD_TYPE.TINY: int,
FIELD_TYPE.SHORT: int,
FIELD_TYPE.LONG: long,
@@ -236,17 +260,37 @@ decoders = {
FIELD_TYPE.DATETIME: convert_datetime,
FIELD_TYPE.TIME: convert_timedelta,
FIELD_TYPE.DATE: convert_date,
FIELD_TYPE.BLOB: str,
FIELD_TYPE.STRING: str,
FIELD_TYPE.VAR_STRING: str,
FIELD_TYPE.VARCHAR: str
FIELD_TYPE.SET: convert_set,
#FIELD_TYPE.BLOB: str,
#FIELD_TYPE.STRING: str,
#FIELD_TYPE.VAR_STRING: str,
#FIELD_TYPE.VARCHAR: str
}
conversions = decoders # for MySQLdb compatibility
def decode_characters(connection, field, data):
if field.charsetnr == 63 or not connection.use_unicode:
# binary data, leave it alone
return data
return data.decode(connection.charset)
# These take a field instance rather than just the data.
field_decoders = {
FIELD_TYPE.BLOB: decode_characters,
FIELD_TYPE.STRING: decode_characters,
FIELD_TYPE.VAR_STRING: decode_characters,
FIELD_TYPE.VARCHAR: decode_characters,
}
try:
# python version > 2.3
from decimal import Decimal
decoders[FIELD_TYPE.DECIMAL] = Decimal
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal
def escape_decimal(obj):
return str(obj)
encoders[Decimal] = escape_decimal
except ImportError:
pass

View File

@@ -1,9 +1,11 @@
import struct
import re
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \
from pymysql.err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
class Cursor(object):
@@ -16,7 +18,7 @@ class Cursor(object):
self.arraysize = 1
self._executed = None
self.messages = []
self.errorhandler =connection.errorhandler
self.errorhandler = connection.errorhandler
self._has_next = None
self._rows = ()
@@ -68,11 +70,15 @@ class Cursor(object):
charset = conn.charset
del self.messages[:]
# this ordering is good because conn.escape() returns
# an encoded string.
if isinstance(query, unicode):
query = query.encode(charset)
if args is not None:
query = query % conn.escape(args)
result = 0
try:
result = self._query(query)
except:
@@ -98,7 +104,13 @@ class Cursor(object):
def callproc(self, procname, args=()):
self.errorhandler(self, NotSupportedError, "not supported")
#self.errorhandler(self, NotSupportedError, "not supported")
if not isinstance(args, tuple):
args = (args,)
argstr = ("%s," * len(args))[:-1]
return self.execute("CALL `%s`(%s)" % (procname, argstr), args)
def fetchone(self):
self._check_executed()

View File

@@ -108,22 +108,24 @@ del StandardError, _map_error, ER
def _get_error_info(data):
errno = struct.unpack('h', data[5:7])[0]
sqlstate = struct.unpack('5s', data[8:8+5])[0]
start = 13
end = data.find('\0', start)
errorvalue = data[start:end]
return (errno, sqlstate, errorvalue)
errno = struct.unpack('<h', data[1:3])[0]
if data[3] == "#":
# version 4.1
sqlstate = data[4:9].decode("utf8")
errorvalue = data[9:].decode("utf8")
return (errno, sqlstate, errorvalue)
else:
# version 4.0
return (errno, None, data[3:].decode("utf8"))
def _check_mysql_exception(errinfo):
errno, sqlstate, errorvalue = errinfo
errorclass = error_map.get(errno, None)
if errorclass:
raise errorclass, errorvalue
"""
TODO not found errno
"""
raise InternalError, ""
raise errorclass, (errno,errorvalue)
# couldn't find the right error number
raise InternalError, (errno, errorvalue)
def raise_mysql_exception(data):
errinfo = _get_error_info(data)

View File

@@ -0,0 +1,5 @@
from test_issues import *
if __name__ == "__main__":
import unittest
unittest.main()

18
pymysql/tests/base.py Normal file
View File

@@ -0,0 +1,18 @@
import pymysql
import unittest
class PyMySqlTestCase(unittest.TestCase):
databases = [
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self):
self.connections = []
for params in self.databases:
self.connections.append(pymysql.connect(**params))
def tearDown(self):
for connection in self.connections:
connection.close()

View File

@@ -0,0 +1,144 @@
import pymysql
import base
import imp
import datetime
class TestOldIssues(base.PyMySqlTestCase):
def test_issue_3(self):
""" undefined methods datetime_or_None, date_or_None """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
try:
c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
c.execute("select d from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select t from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select dt from issue3")
self.assertEqual(None, c.fetchone()[0])
c.execute("select ts from issue3")
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
finally:
c.execute("drop table issue3")
def test_issue_4(self):
""" can't retrieve TIMESTAMP fields """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue4 (ts timestamp)")
try:
c.execute("insert into issue4 (ts) values (now())")
c.execute("select ts from issue4")
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
finally:
c.execute("drop table issue4")
def test_issue_6(self):
""" exception: TypeError: ord() expected a character, but string of length 0 found """
conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql")
c = conn.cursor()
c.execute("select * from user")
conn.close()
def test_issue_8(self):
""" Primary Key and Index error when selecting data """
conn = self.connections[0]
c = conn.cursor()
c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
try:
self.assertEqual(0, c.execute("SELECT * FROM test"))
c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
self.assertEqual(0, c.execute("SELECT * FROM test"))
finally:
c.execute("drop table test")
def test_issue_9(self):
""" sets DeprecationWarning in Python 2.6 """
try:
imp.reload(pymysql)
except DeprecationWarning:
self.fail()
def test_issue_10(self):
""" Allocate a variable to return when the exception handler is permissive """
conn = self.connections[0]
conn.errorhandler = lambda cursor, errorclass, errorvalue: None
cur = conn.cursor()
cur.execute( "create table t( n int )" )
cur.execute( "create table t( n int )" )
def test_issue_13(self):
""" can't handle large result fields """
conn = self.connections[0]
cur = conn.cursor()
cur.execute("create table issue13 (t text)")
try:
# ticket says 18k
size = 18*1024
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
cur.execute("select t from issue13")
# use assert_ so that obscenely huge error messages don't print
r = cur.fetchone()[0]
self.assert_("x" * size == r)
finally:
cur.execute("drop table issue13")
def test_issue_14(self):
""" typo in converters.py """
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8"))
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8"))
def test_issue_15(self):
""" query should be expanded before perform character encoding """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue15 (t varchar(32))")
try:
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc'))
c.execute("select t from issue15")
self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
finally:
c.execute("drop table issue15")
def test_issue_16(self):
""" Patch for string and tuple escaping """
conn = self.connections[0]
c = conn.cursor()
c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
try:
c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
c.execute("select email from issue16 where name=%s", ("pete",))
self.assertEqual("floydophone", c.fetchone()[0])
finally:
c.execute("drop table issue16")
def test_issue_17(self):
""" could not connect mysql use passwod """
conn = self.connections[0]
host = self.databases[0]["host"]
db = self.databases[0]["db"]
c = conn.cursor()
# grant access to a table to a user with a password
c.execute("create table issue17 (x varchar(32) primary key)")
try:
c.execute("insert into issue17 (x) values ('hello, world!')")
c.execute("grant all privileges on issue17 to issue17user identified by '1234'")
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
c2 = conn2.cursor()
c2.execute("select x from issue17")
self.assertEqual("hello, world!", c2.fetchone()[0])
finally:
c.execute("drop table issue17")
if __name__ == "__main__":
import unittest
unittest.main()

View File

@@ -1,5 +1,6 @@
from setuptools import setup
#from setuptools import setup
from distutils.core import setup
version_tuple = __import__('pymysql').VERSION
@@ -17,5 +18,5 @@ setup(
maintainer = 'David.Story',
maintainer_email = 'iDavidStory@gmail.com',
description = 'Pure Python MySQL Driver ',
packages = ['pymysql', 'pymysql.constants']
packages = ['pymysql', 'pymysql.constants', 'pymysql.tests']
)