My first commit. Did a lot of refactoring, so see the updated README file
This commit is contained in:
28
README
28
README
@@ -1,15 +1,35 @@
|
||||
====================
|
||||
pymysql Installation
|
||||
PyMySQL Installation
|
||||
====================
|
||||
|
||||
.. contents::
|
||||
..
|
||||
This package contains a pure python mysql client library.
|
||||
This package contains a pure-Python MySQL client library.
|
||||
Documentation on the MySQL client/server protocol can be found here:
|
||||
http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
|
||||
If you would like to run the test suite, create a ~/.my.cnf file and
|
||||
a database called "test_pymysql". The goal of pymysql is to be a drop-in
|
||||
replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy
|
||||
and Python 3. We test for compatibility by simply changing the import
|
||||
statements in the Django MySQL backend and running its unit tests as well
|
||||
as running it against the MySQLdb and myconnpy unit tests.
|
||||
|
||||
Changes
|
||||
--------
|
||||
0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc()
|
||||
-Fixed error handling to include the message from the server and support
|
||||
multiple protocol versions.
|
||||
-Implemented ping()
|
||||
-Implemented unicode support (probably needs better testing)
|
||||
-Removed DeprecationWarnings
|
||||
-Ran against the MySQLdb and myconnpy unit tests to check for bugs
|
||||
-Added support for client_flag, charset, sql_mode, read_default_file, and
|
||||
use_unicode.
|
||||
-Refactoring for some more compatibility with MySQLdb including a fake
|
||||
pymysql.version_info attribute.
|
||||
-Now runs with no warnings with the -3 command-line switch
|
||||
-Added test cases for all outstanding tickets and closed most of them.
|
||||
|
||||
0.2 -Changed connection parameter name 'password' to 'passwd'
|
||||
to make it more plugin replaceable for the other mysql clients.
|
||||
-Changed pack()/unpack() calls so it runs on 64 bit OSes too.
|
||||
@@ -26,11 +46,11 @@ Requirements
|
||||
|
||||
* http://www.python.org/
|
||||
|
||||
* 2.5 is the primary test environment.
|
||||
* 2.6 is the primary test environment.
|
||||
|
||||
* MySQL 4.1 or higher
|
||||
|
||||
* protocol41 support
|
||||
* protocol41 support, experimental 4.0 support
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
|
||||
import pymysql
|
||||
|
||||
conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql')
|
||||
# conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql')
|
||||
#conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql')
|
||||
conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql')
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
@@ -14,7 +14,7 @@ cur.execute("SELECT Host,User FROM user")
|
||||
# r = cur.fetchall()
|
||||
# print r
|
||||
# ...or...
|
||||
for r in cur:
|
||||
for r in cur.fetchall():
|
||||
print r
|
||||
|
||||
cur.close()
|
||||
|
||||
@@ -1,37 +1,41 @@
|
||||
VERSION = (0, 2, None)
|
||||
__version__ = VERSION # for MySQLdb compatibility
|
||||
VERSION = (0, 3, None)
|
||||
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.converters import escape_dict, escape_sequence, escape_string
|
||||
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \
|
||||
from pymysql.err import Warning, Error, InterfaceError, DataError, \
|
||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
||||
NotSupportedError, ProgrammingError
|
||||
from pymysql.times import Date, Time, Timestamp, \
|
||||
DateFromTicks, TimeFromTicks, TimestampFromTicks
|
||||
|
||||
from sets import ImmutableSet
|
||||
try:
|
||||
frozenset
|
||||
except NameError:
|
||||
from sets import ImmutableSet as frozenset
|
||||
from sets import BaseSet as set
|
||||
|
||||
threadsafety = 1
|
||||
apilevel = "2.0"
|
||||
paramstyle = "format"
|
||||
|
||||
class DBAPISet(ImmutableSet):
|
||||
class DBAPISet(frozenset):
|
||||
|
||||
|
||||
def __ne__(self, other):
|
||||
from sets import BaseSet
|
||||
if isinstance(other, BaseSet):
|
||||
if isinstance(other, set):
|
||||
return super(DBAPISet, self).__ne__(self, other)
|
||||
else:
|
||||
return other not in self
|
||||
|
||||
def __eq__(self, other):
|
||||
from sets import BaseSet
|
||||
if isinstance(other, BaseSet):
|
||||
return super(DBAPISet, self).__eq__(self, other)
|
||||
if isinstance(other, frozenset):
|
||||
return frozenset.__eq__(self, other)
|
||||
else:
|
||||
return other in self
|
||||
|
||||
def __hash__(self):
|
||||
return frozenset.__hash__(self)
|
||||
|
||||
|
||||
STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
|
||||
FIELD_TYPE.VAR_STRING])
|
||||
@@ -59,6 +63,16 @@ def get_client_info(): # for MySQLdb compatibility
|
||||
|
||||
connect = Connection = Connect
|
||||
|
||||
# we include a doctored version_info here for MySQLdb compatibility
|
||||
version_info = (1,2,2,"final",0)
|
||||
|
||||
NULL = "NULL"
|
||||
|
||||
__version__ = get_client_info()
|
||||
|
||||
def thread_safe():
|
||||
# Pure python, so yes we're threadsafe
|
||||
return True
|
||||
|
||||
__all__ = [
|
||||
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
|
||||
@@ -70,4 +84,6 @@ __all__ = [
|
||||
'connections', 'constants', 'converters', 'cursors', 'debug', 'escape',
|
||||
'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
|
||||
'paramstyle', 'string_literal', 'threadsafety', 'version_info',
|
||||
|
||||
"NULL","__version__",
|
||||
]
|
||||
|
||||
@@ -2,19 +2,28 @@
|
||||
# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
|
||||
|
||||
import re
|
||||
import sha
|
||||
|
||||
try:
|
||||
import hashlib
|
||||
sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
|
||||
except ImportError:
|
||||
import sha
|
||||
sha_new = sha.new
|
||||
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import os
|
||||
import ConfigParser
|
||||
|
||||
from pymysql.charset import MBLENGTH
|
||||
from pymysql.cursor import Cursor
|
||||
from pymysql.cursors import Cursor
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
from pymysql.constants import SERVER_STATUS
|
||||
from pymysql.constants.CLIENT_FLAG import *
|
||||
from pymysql.constants.CLIENT import *
|
||||
from pymysql.constants.COMMAND import *
|
||||
from pymysql.converters import escape_item, encoders, decoders
|
||||
from pymysql.exceptions import raise_mysql_exception, Warning, Error, \
|
||||
from pymysql.converters import escape_item, encoders, decoders, field_decoders
|
||||
from pymysql.err import raise_mysql_exception, Warning, Error, \
|
||||
InterfaceError, DataError, DatabaseError, OperationalError, \
|
||||
IntegrityError, InternalError, NotSupportedError, ProgrammingError
|
||||
|
||||
@@ -57,9 +66,9 @@ def _scramble(password, message):
|
||||
if password == None or len(password) == 0:
|
||||
return '\0'
|
||||
if DEBUG: print 'password=' + password
|
||||
stage1 = sha.new(password).digest()
|
||||
stage2 = sha.new(stage1).digest()
|
||||
s = sha.new()
|
||||
stage1 = sha_new(password).digest()
|
||||
stage2 = sha_new(stage1).digest()
|
||||
s = sha_new()
|
||||
s.update(message)
|
||||
s.update(stage2)
|
||||
result = s.digest()
|
||||
@@ -105,7 +114,11 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
|
||||
connection.messages.append(err)
|
||||
del cursor
|
||||
del connection
|
||||
raise errorclass, errorvalue
|
||||
|
||||
if not issubclass(errorclass, Error):
|
||||
raise Error(errorclass, errorvalue)
|
||||
else:
|
||||
raise errorclass, errorvalue
|
||||
|
||||
|
||||
class MysqlPacket(object):
|
||||
@@ -121,6 +134,9 @@ class MysqlPacket(object):
|
||||
def __recv_packet(self, socket):
|
||||
"""Parse the packet header and read entire packet payload into buffer."""
|
||||
packet_header = socket.recv(4)
|
||||
while len(packet_header) < 4:
|
||||
packet_header += socket.recv(4 - len(packet_header))
|
||||
|
||||
if DEBUG: dump_packet(packet_header)
|
||||
packet_length_bin = packet_header[:3]
|
||||
self.__packet_number = ord(packet_header[3])
|
||||
@@ -303,7 +319,7 @@ class FieldDescriptorPacket(MysqlPacket):
|
||||
def get_column_length(self):
|
||||
if self.type_code == FIELD_TYPE.VAR_STRING:
|
||||
mblen = MBLENGTH.get(self.charsetnr, 1)
|
||||
return self.length / mblen
|
||||
return self.length // mblen
|
||||
return self.length
|
||||
|
||||
def __str__(self):
|
||||
@@ -316,54 +332,103 @@ class Connection(object):
|
||||
"""Representation of a socket with a mysql server."""
|
||||
errorhandler = defaulterrorhandler
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.host = kwargs['host']
|
||||
self.port = kwargs.get('port', 3306)
|
||||
self.user = kwargs['user']
|
||||
self.password = kwargs['passwd']
|
||||
self.db = kwargs.get('db', None)
|
||||
self.unix_socket = kwargs.get('unix_socket', None)
|
||||
self.charset = DEFAULT_CHARSET
|
||||
|
||||
client_flag = CLIENT_CAPABILITIES
|
||||
#client_flag = kwargs.get('client_flag', None)
|
||||
client_flag |= CLIENT_MULTI_STATEMENTS
|
||||
def __init__(self, host="localhost", user=None, passwd="",
|
||||
db=None, port=3306, unix_socket=None,
|
||||
charset=DEFAULT_CHARSET, sql_mode=None,
|
||||
read_default_file=None, conv=decoders, use_unicode=True,
|
||||
client_flag=0):
|
||||
|
||||
if read_default_file:
|
||||
cfg = ConfigParser.RawConfigParser()
|
||||
cfg.read(os.path.expanduser(read_default_file))
|
||||
|
||||
def _config(key, default):
|
||||
try:
|
||||
return cfg.get("client",key)
|
||||
except:
|
||||
return default
|
||||
|
||||
user = _config("user",user)
|
||||
passwd = _config("password",passwd)
|
||||
host = _config("host", host)
|
||||
db = _config("db",db)
|
||||
unix_socket = _config("socket",unix_socket)
|
||||
port = _config("port", port)
|
||||
charset = _config("default-character-set", charset)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.user = user
|
||||
self.password = passwd
|
||||
self.db = db
|
||||
self.unix_socket = unix_socket
|
||||
self.use_unicode = use_unicode
|
||||
self.charset = charset
|
||||
client_flag |= CAPABILITIES
|
||||
client_flag |= MULTI_STATEMENTS
|
||||
if self.db:
|
||||
client_flag |= CLIENT_CONNECT_WITH_DB
|
||||
client_flag |= CONNECT_WITH_DB
|
||||
self.client_flag = client_flag
|
||||
|
||||
self._connect()
|
||||
|
||||
charset = kwargs.get('charset', None)
|
||||
self.set_chatset_set(charset)
|
||||
self.set_charset_set(charset)
|
||||
self.messages = []
|
||||
self.encoders = encoders
|
||||
self.decoders = decoders
|
||||
self.decoders = conv
|
||||
self.field_decoders = field_decoders
|
||||
|
||||
self._affected_rows = 0
|
||||
self.host_info = "Not connected"
|
||||
|
||||
self.autocommit(False)
|
||||
|
||||
if sql_mode is not None:
|
||||
c = self.cursor()
|
||||
c.execute("SET sql_mode=%s", (sql_mode,))
|
||||
self.commit()
|
||||
|
||||
|
||||
def close(self):
|
||||
send_data = struct.pack('<i',1) + COM_QUIT
|
||||
sock = self.socket
|
||||
sock.send(send_data)
|
||||
sock.close()
|
||||
try:
|
||||
send_data = struct.pack('<i',1) + COM_QUIT
|
||||
sock = self.socket
|
||||
sock.send(send_data)
|
||||
sock.close()
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
def autocommit(self, value):
|
||||
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
|
||||
self.escape(value))
|
||||
self.read_packet()
|
||||
try:
|
||||
self._execute_command(COM_QUERY, "SET AUTOCOMMIT = %s" % \
|
||||
self.escape(value))
|
||||
self.read_packet()
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
def commit(self):
|
||||
self._execute_command(COM_QUERY, "COMMIT")
|
||||
self.read_packet()
|
||||
try:
|
||||
self._execute_command(COM_QUERY, "COMMIT")
|
||||
self.read_packet()
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
def rollback(self):
|
||||
self._execute_command(COM_QUERY, "ROLLBACK")
|
||||
self.read_packet()
|
||||
try:
|
||||
self._execute_command(COM_QUERY, "ROLLBACK")
|
||||
self.read_packet()
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
def escape(self, obj):
|
||||
return escape_item(obj)
|
||||
return escape_item(obj, self.charset)
|
||||
|
||||
def literal(self, obj):
|
||||
return escape_item(obj, self.charset)
|
||||
|
||||
def cursor(self):
|
||||
return Cursor(self)
|
||||
@@ -379,26 +444,53 @@ class Connection(object):
|
||||
|
||||
def query(self, sql):
|
||||
self._execute_command(COM_QUERY, sql)
|
||||
return self._read_query_result()
|
||||
self._affected_rows = self._read_query_result()
|
||||
return self._affected_rows
|
||||
|
||||
def next_result(self):
|
||||
return self._read_query_result()
|
||||
self._affected_rows = self._read_query_result()
|
||||
return self._affected_rows
|
||||
|
||||
def set_chatset_set(self, charset):
|
||||
sock = self.socket
|
||||
if charset and self.charset != charset:
|
||||
self._execute_command(COM_QUERY, "SET NAMES %s" % charset)
|
||||
self.read_packet()
|
||||
self.charset = charset
|
||||
def affected_rows(self):
|
||||
return self._affected_rows
|
||||
|
||||
def ping(self, reconnect=True):
|
||||
try:
|
||||
self._execute_command(COM_PING, "")
|
||||
except:
|
||||
if reconnect:
|
||||
self._connect()
|
||||
return self.ping(False)
|
||||
else:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
return
|
||||
|
||||
pkt = self.read_packet()
|
||||
return pkt.is_ok_packet()
|
||||
|
||||
def set_charset_set(self, charset):
|
||||
try:
|
||||
sock = self.socket
|
||||
if charset and self.charset != charset:
|
||||
self._execute_command(COM_QUERY, "SET NAMES %s" %
|
||||
self.escape(charset))
|
||||
self.read_packet()
|
||||
self.charset = charset
|
||||
except:
|
||||
exc,value,tb = sys.exc_info()
|
||||
self.errorhandler(None, exc, value)
|
||||
|
||||
def _connect(self):
|
||||
if self.unix_socket and (self.host == 'localhost' or self.host == '127.0.0.1'):
|
||||
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
sock.connect(self.unix_socket)
|
||||
self.host_info = "Localhost via UNIX socket"
|
||||
if DEBUG: print 'connected using unix_socket'
|
||||
else:
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.connect((self.host, self.port))
|
||||
self.host_info = "socket %s:%d" % (self.host, self.port)
|
||||
if DEBUG: print 'connected using socket'
|
||||
self.socket = sock
|
||||
self._get_server_information()
|
||||
@@ -424,8 +516,10 @@ class Connection(object):
|
||||
|
||||
def _send_command(self, command, sql):
|
||||
send_data = struct.pack('<i', len(sql) + 1) + command + sql
|
||||
|
||||
sock = self.socket
|
||||
sock.send(send_data)
|
||||
|
||||
if DEBUG: dump_packet(send_data)
|
||||
|
||||
def _execute_command(self, command, sql):
|
||||
@@ -437,9 +531,12 @@ class Connection(object):
|
||||
|
||||
def _send_authentication(self):
|
||||
sock = self.socket
|
||||
self.client_flag |= CLIENT_CAPABILITIES
|
||||
self.client_flag |= CAPABILITIES
|
||||
if self.server_version.startswith('5'):
|
||||
self.client_flag |= CLIENT_MULTI_RESULTS
|
||||
self.client_flag |= MULTI_RESULTS
|
||||
|
||||
if self.user is None:
|
||||
raise ValueError, "Did not specify a username"
|
||||
|
||||
data = (struct.pack('i', self.client_flag)) + "\0\0\0\x01" + \
|
||||
'\x08' + '\0'*23 + \
|
||||
@@ -458,6 +555,19 @@ class Connection(object):
|
||||
auth_packet.check_error()
|
||||
if DEBUG: auth_packet.dump()
|
||||
|
||||
# _mysql support
|
||||
def thread_id(self):
|
||||
return self.server_thread_id[0]
|
||||
|
||||
def character_set_name(self):
|
||||
return self.charset
|
||||
|
||||
def get_host_info(self):
|
||||
return self.host_info
|
||||
|
||||
def get_proto_info(self):
|
||||
return self.protocol_version
|
||||
|
||||
def _get_server_information(self):
|
||||
sock = self.socket
|
||||
i = 0
|
||||
@@ -473,7 +583,6 @@ class Connection(object):
|
||||
|
||||
i = server_end + 1
|
||||
self.server_thread_id = struct.unpack('h', data[i:i+2])
|
||||
self.thread_id = self.server_thread_id # MySQLdb compatibility
|
||||
|
||||
i += 4
|
||||
self.salt = data[i:i+8]
|
||||
@@ -561,12 +670,22 @@ class MySQLResult(object):
|
||||
|
||||
row = []
|
||||
for field in self.fields:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_binary()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(data)
|
||||
if field.type_code in self.connection.decoders:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_binary()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(data)
|
||||
else:
|
||||
converter = self.connection.field_decoders[field.type_code]
|
||||
if DEBUG: print "DEBUG: field=%s, converter=%s" % (field, converter)
|
||||
data = packet.read_length_coded_binary()
|
||||
converted = None
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
|
||||
row.append(converted)
|
||||
|
||||
rows.append(tuple(row))
|
||||
|
||||
20
pymysql/constants/CLIENT.py
Normal file
20
pymysql/constants/CLIENT.py
Normal file
@@ -0,0 +1,20 @@
|
||||
|
||||
LONG_PASSWORD = 1
|
||||
FOUND_ROWS = 1 << 1
|
||||
LONG_FLAG = 1 << 2
|
||||
CONNECT_WITH_DB = 1 << 3
|
||||
NO_SCHEMA = 1 << 4
|
||||
COMPRESS = 1 << 5
|
||||
ODBC = 1 << 6
|
||||
LOCAL_FILES = 1 << 7
|
||||
IGNORE_SPACE = 1 << 8
|
||||
PROTOCOL_41 = 1 << 9
|
||||
INTERACTIVE = 1 << 10
|
||||
SSL = 1 << 11
|
||||
IGNORE_SIGPIPE = 1 << 12
|
||||
TRANSACTIONS = 1 << 13
|
||||
SECURE_CONNECTION = 1 << 15
|
||||
MULTI_STATEMENTS = 1 << 16
|
||||
MULTI_RESULTS = 1 << 17
|
||||
CAPABILITIES = LONG_PASSWORD|LONG_FLAG|TRANSACTIONS| \
|
||||
PROTOCOL_41|SECURE_CONNECTION
|
||||
@@ -1,20 +0,0 @@
|
||||
|
||||
CLIENT_LONG_PASSWORD = 1
|
||||
CLIENT_FOUND_ROWS = 1 << 1
|
||||
CLIENT_LONG_FLAG = 1 << 2
|
||||
CLIENT_CONNECT_WITH_DB = 1 << 3
|
||||
CLIENT_NO_SCHEMA = 1 << 4
|
||||
CLIENT_COMPRESS = 1 << 5
|
||||
CLIENT_ODBC = 1 << 6
|
||||
CLIENT_LOCAL_FILES = 1 << 7
|
||||
CLIENT_IGNORE_SPACE = 1 << 8
|
||||
CLIENT_PROTOCOL_41 = 1 << 9
|
||||
CLIENT_INTERACTIVE = 1 << 10
|
||||
CLIENT_SSL = 1 << 11
|
||||
CLIENT_IGNORE_SIGPIPE = 1 << 12
|
||||
CLIENT_TRANSACTIONS = 1 << 13
|
||||
CLIENT_SECURE_CONNECTION = 1 << 15
|
||||
CLIENT_MULTI_STATEMENTS = 1 << 16
|
||||
CLIENT_MULTI_RESULTS = 1 << 17
|
||||
CLIENT_CAPABILITIES = CLIENT_LONG_PASSWORD|CLIENT_LONG_FLAG|CLIENT_TRANSACTIONS| \
|
||||
CLIENT_PROTOCOL_41|CLIENT_SECURE_CONNECTION
|
||||
15
pymysql/constants/FLAG.py
Normal file
15
pymysql/constants/FLAG.py
Normal file
@@ -0,0 +1,15 @@
|
||||
NOT_NULL = 1
|
||||
PRI_KEY = 2
|
||||
UNIQUE_KEY = 4
|
||||
MULTIPLE_KEY = 8
|
||||
BLOB = 16
|
||||
UNSIGNED = 32
|
||||
ZEROFILL = 64
|
||||
BINARY = 128
|
||||
ENUM = 256
|
||||
AUTO_INCREMENT = 512
|
||||
TIMESTAMP = 1024
|
||||
SET = 2048
|
||||
PART_KEY = 16384
|
||||
GROUP = 32767
|
||||
UNIQUE = 65536
|
||||
@@ -1,68 +1,81 @@
|
||||
import re
|
||||
import datetime
|
||||
import time
|
||||
import array
|
||||
import struct
|
||||
|
||||
from pymysql.times import Date, Time, TimeDelta, Timestamp
|
||||
from pymysql.constants import FIELD_TYPE
|
||||
|
||||
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]", re.IGNORECASE)
|
||||
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
|
||||
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
||||
|
||||
def escape_item(val):
|
||||
encorder = encoders[type(val)]
|
||||
return encorder(val)
|
||||
def escape_item(val, charset):
|
||||
encoder = encoders[type(val)]
|
||||
return encoder(val, charset)
|
||||
|
||||
def escape_dict(val):
|
||||
def escape_dict(val, charset):
|
||||
n = {}
|
||||
for k, v in val.items():
|
||||
quoted = escape_item(v)
|
||||
quoted = escape_item(v, charset)
|
||||
n[k] = quoted
|
||||
return n
|
||||
|
||||
def escape_sequence(val):
|
||||
def escape_sequence(val, charset):
|
||||
n = []
|
||||
for item in val:
|
||||
quoted = escape_item(item)
|
||||
quoted = escape_item(item, charset)
|
||||
n.append(quoted)
|
||||
return tuple(n)
|
||||
|
||||
def escape_bool(value):
|
||||
return str(int(value))
|
||||
def escape_bool(value, charset):
|
||||
return str(int(value)).encode(charset)
|
||||
|
||||
def escape_object(value):
|
||||
return str(val)
|
||||
def escape_object(value, charset):
|
||||
return str(value).encode(charset)
|
||||
|
||||
def escape_float(value):
|
||||
return '%.15g' % value
|
||||
escape_int = escape_long = escape_object
|
||||
|
||||
def escape_sequence(value):
|
||||
return value
|
||||
def escape_float(value, charset):
|
||||
return ('%.15g' % value).encode(charset)
|
||||
|
||||
def escape_string(value):
|
||||
def escape_string(value, charset):
|
||||
r = ("'%s'" % ESCAPE_REGEX.sub(
|
||||
lambda match: ESCAPE_MAP.get(match.group(0)), value))
|
||||
if not charset is None:
|
||||
r = r.encode(charset)
|
||||
return r
|
||||
|
||||
def rep(m):
|
||||
n = m.group(0)
|
||||
if n == "\0":
|
||||
return "\\0"
|
||||
elif n == "\n":
|
||||
return "\\n"
|
||||
elif n == "\r":
|
||||
return "\\r"
|
||||
elif n == "\032":
|
||||
return "\\Z"
|
||||
else:
|
||||
return "\\"+n
|
||||
s = re.sub(ESCAPE_REGEX, rep, value)
|
||||
return s
|
||||
def escape_unicode(value, charset):
|
||||
# pass None as the charset because we already encode it
|
||||
return escape_string(value.encode(charset), None)
|
||||
|
||||
def escape_None(value, charset):
|
||||
return 'NULL'.encode(charset)
|
||||
|
||||
def escape_timedelta(obj):
|
||||
def escape_timedelta(obj, charset):
|
||||
seconds = int(obj.seconds) % 60
|
||||
minutes = int(obj.seconds / 60) % 60
|
||||
hours = int(obj.seconds / 3600) % 24
|
||||
return '%d %02d:%02d:%02d' % (obj.days, hours, minutes, seconds)
|
||||
minutes = int(obj.seconds // 60) % 60
|
||||
hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
|
||||
return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds), charset)
|
||||
|
||||
def escape_datetime(obj):
|
||||
return obj.strftime("%Y-%m-%d %H:%M:%S")
|
||||
def escape_time(obj, charset):
|
||||
s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
|
||||
int(obj.second))
|
||||
if obj.microsecond:
|
||||
s += ".%f" % obj.microsecond
|
||||
|
||||
return escape_string(s, charset)
|
||||
|
||||
def escape_datetime(obj, charset):
|
||||
return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"), charset)
|
||||
|
||||
def escape_date(obj, charset):
|
||||
return escape_string(obj.strftime("%Y-%m-%d"), charset)
|
||||
|
||||
def escape_struct_time(obj, charset):
|
||||
return escape_datetime(datetime.datetime(*obj[:6]), charset)
|
||||
|
||||
def convert_datetime(obj):
|
||||
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
|
||||
@@ -85,13 +98,13 @@ def convert_datetime(obj):
|
||||
elif 'T' in obj:
|
||||
sep = 'T'
|
||||
else:
|
||||
return date_or_None(obj)
|
||||
return convert_date(obj)
|
||||
|
||||
try:
|
||||
ymd, hms = obj.split(sep, 1)
|
||||
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
|
||||
except ValueError:
|
||||
return date_or_None(obj)
|
||||
return convert_date(obj)
|
||||
|
||||
def convert_timedelta(obj):
|
||||
"""Returns a TIME column as a timedelta object:
|
||||
@@ -112,17 +125,14 @@ def convert_timedelta(obj):
|
||||
"""
|
||||
from math import modf
|
||||
try:
|
||||
hours, minutes, seconds = obj.split(':')
|
||||
hours, minutes, seconds = tuple(int(x) for x in obj.split(':'))
|
||||
tdelta = datetime.timedelta(
|
||||
hours = int(hours),
|
||||
minutes = int(minutes),
|
||||
seconds = int(seconds),
|
||||
microseconds = int(modf(float(seconds))[0]*1000000),
|
||||
)
|
||||
if hours < 0:
|
||||
return -tdelta
|
||||
else:
|
||||
return tdelta
|
||||
return tdelta
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -197,7 +207,7 @@ def convert_mysql_timestamp(timestamp):
|
||||
|
||||
"""
|
||||
if timestamp[4] == '-':
|
||||
return datetime_or_None(timestamp)
|
||||
return convert_datetime(timestamp)
|
||||
timestamp += "0"*(14-len(timestamp)) # padding
|
||||
year, month, day, hour, minute, second = \
|
||||
int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
|
||||
@@ -207,21 +217,35 @@ def convert_mysql_timestamp(timestamp):
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def convert_set(s):
|
||||
# TODO: this may not be correct
|
||||
return set(s.split(","))
|
||||
|
||||
def convert_bit(b):
|
||||
b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
|
||||
return struct.unpack(">Q", b)[0]
|
||||
|
||||
encoders = {
|
||||
bool: escape_bool,
|
||||
int: escape_object,
|
||||
long: escape_object,
|
||||
int: escape_int,
|
||||
long: escape_long,
|
||||
float: escape_float,
|
||||
str: escape_string,
|
||||
unicode: escape_unicode,
|
||||
tuple: escape_sequence,
|
||||
list:escape_sequence,
|
||||
set:escape_sequence,
|
||||
dict:escape_dict,
|
||||
type(None):escape_None,
|
||||
datetime.date: escape_date,
|
||||
datetime.datetime : escape_datetime,
|
||||
datetime.timedelta : escape_timedelta
|
||||
datetime.timedelta : escape_timedelta,
|
||||
datetime.time : escape_time,
|
||||
time.struct_time : escape_struct_time,
|
||||
}
|
||||
|
||||
decoders = {
|
||||
FIELD_TYPE.BIT: convert_bit,
|
||||
FIELD_TYPE.TINY: int,
|
||||
FIELD_TYPE.SHORT: int,
|
||||
FIELD_TYPE.LONG: long,
|
||||
@@ -236,17 +260,37 @@ decoders = {
|
||||
FIELD_TYPE.DATETIME: convert_datetime,
|
||||
FIELD_TYPE.TIME: convert_timedelta,
|
||||
FIELD_TYPE.DATE: convert_date,
|
||||
FIELD_TYPE.BLOB: str,
|
||||
FIELD_TYPE.STRING: str,
|
||||
FIELD_TYPE.VAR_STRING: str,
|
||||
FIELD_TYPE.VARCHAR: str
|
||||
FIELD_TYPE.SET: convert_set,
|
||||
#FIELD_TYPE.BLOB: str,
|
||||
#FIELD_TYPE.STRING: str,
|
||||
#FIELD_TYPE.VAR_STRING: str,
|
||||
#FIELD_TYPE.VARCHAR: str
|
||||
}
|
||||
conversions = decoders # for MySQLdb compatibility
|
||||
|
||||
def decode_characters(connection, field, data):
|
||||
if field.charsetnr == 63 or not connection.use_unicode:
|
||||
# binary data, leave it alone
|
||||
return data
|
||||
return data.decode(connection.charset)
|
||||
|
||||
# These take a field instance rather than just the data.
|
||||
field_decoders = {
|
||||
FIELD_TYPE.BLOB: decode_characters,
|
||||
FIELD_TYPE.STRING: decode_characters,
|
||||
FIELD_TYPE.VAR_STRING: decode_characters,
|
||||
FIELD_TYPE.VARCHAR: decode_characters,
|
||||
}
|
||||
|
||||
try:
|
||||
# python version > 2.3
|
||||
from decimal import Decimal
|
||||
decoders[FIELD_TYPE.DECIMAL] = Decimal
|
||||
decoders[FIELD_TYPE.NEWDECIMAL] = Decimal
|
||||
|
||||
def escape_decimal(obj):
|
||||
return str(obj)
|
||||
encoders[Decimal] = escape_decimal
|
||||
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import struct
|
||||
import re
|
||||
|
||||
from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \
|
||||
from pymysql.err import Warning, Error, InterfaceError, DataError, \
|
||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
||||
NotSupportedError, ProgrammingError
|
||||
|
||||
|
||||
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
|
||||
|
||||
class Cursor(object):
|
||||
|
||||
@@ -16,7 +18,7 @@ class Cursor(object):
|
||||
self.arraysize = 1
|
||||
self._executed = None
|
||||
self.messages = []
|
||||
self.errorhandler =connection.errorhandler
|
||||
self.errorhandler = connection.errorhandler
|
||||
self._has_next = None
|
||||
self._rows = ()
|
||||
|
||||
@@ -68,11 +70,15 @@ class Cursor(object):
|
||||
charset = conn.charset
|
||||
del self.messages[:]
|
||||
|
||||
# this ordering is good because conn.escape() returns
|
||||
# an encoded string.
|
||||
if isinstance(query, unicode):
|
||||
query = query.encode(charset)
|
||||
|
||||
if args is not None:
|
||||
query = query % conn.escape(args)
|
||||
|
||||
result = 0
|
||||
try:
|
||||
result = self._query(query)
|
||||
except:
|
||||
@@ -98,7 +104,13 @@ class Cursor(object):
|
||||
|
||||
|
||||
def callproc(self, procname, args=()):
|
||||
self.errorhandler(self, NotSupportedError, "not supported")
|
||||
#self.errorhandler(self, NotSupportedError, "not supported")
|
||||
if not isinstance(args, tuple):
|
||||
args = (args,)
|
||||
|
||||
argstr = ("%s," * len(args))[:-1]
|
||||
|
||||
return self.execute("CALL `%s`(%s)" % (procname, argstr), args)
|
||||
|
||||
def fetchone(self):
|
||||
self._check_executed()
|
||||
@@ -108,22 +108,24 @@ del StandardError, _map_error, ER
|
||||
|
||||
|
||||
def _get_error_info(data):
|
||||
errno = struct.unpack('h', data[5:7])[0]
|
||||
sqlstate = struct.unpack('5s', data[8:8+5])[0]
|
||||
start = 13
|
||||
end = data.find('\0', start)
|
||||
errorvalue = data[start:end]
|
||||
return (errno, sqlstate, errorvalue)
|
||||
errno = struct.unpack('<h', data[1:3])[0]
|
||||
if data[3] == "#":
|
||||
# version 4.1
|
||||
sqlstate = data[4:9].decode("utf8")
|
||||
errorvalue = data[9:].decode("utf8")
|
||||
return (errno, sqlstate, errorvalue)
|
||||
else:
|
||||
# version 4.0
|
||||
return (errno, None, data[3:].decode("utf8"))
|
||||
|
||||
def _check_mysql_exception(errinfo):
|
||||
errno, sqlstate, errorvalue = errinfo
|
||||
errorclass = error_map.get(errno, None)
|
||||
if errorclass:
|
||||
raise errorclass, errorvalue
|
||||
"""
|
||||
TODO not found errno
|
||||
"""
|
||||
raise InternalError, ""
|
||||
raise errorclass, (errno,errorvalue)
|
||||
|
||||
# couldn't find the right error number
|
||||
raise InternalError, (errno, errorvalue)
|
||||
|
||||
def raise_mysql_exception(data):
|
||||
errinfo = _get_error_info(data)
|
||||
5
pymysql/tests/__init__.py
Normal file
5
pymysql/tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from test_issues import *
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
18
pymysql/tests/base.py
Normal file
18
pymysql/tests/base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import pymysql
|
||||
import unittest
|
||||
|
||||
class PyMySqlTestCase(unittest.TestCase):
|
||||
databases = [
|
||||
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql"},
|
||||
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
|
||||
|
||||
def setUp(self):
|
||||
self.connections = []
|
||||
|
||||
for params in self.databases:
|
||||
self.connections.append(pymysql.connect(**params))
|
||||
|
||||
def tearDown(self):
|
||||
for connection in self.connections:
|
||||
connection.close()
|
||||
|
||||
144
pymysql/tests/test_issues.py
Normal file
144
pymysql/tests/test_issues.py
Normal file
@@ -0,0 +1,144 @@
|
||||
import pymysql
|
||||
import base
|
||||
|
||||
import imp
|
||||
import datetime
|
||||
|
||||
class TestOldIssues(base.PyMySqlTestCase):
|
||||
def test_issue_3(self):
|
||||
""" undefined methods datetime_or_None, date_or_None """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
|
||||
try:
|
||||
c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
|
||||
c.execute("select d from issue3")
|
||||
self.assertEqual(None, c.fetchone()[0])
|
||||
c.execute("select t from issue3")
|
||||
self.assertEqual(None, c.fetchone()[0])
|
||||
c.execute("select dt from issue3")
|
||||
self.assertEqual(None, c.fetchone()[0])
|
||||
c.execute("select ts from issue3")
|
||||
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
|
||||
finally:
|
||||
c.execute("drop table issue3")
|
||||
|
||||
def test_issue_4(self):
|
||||
""" can't retrieve TIMESTAMP fields """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table issue4 (ts timestamp)")
|
||||
try:
|
||||
c.execute("insert into issue4 (ts) values (now())")
|
||||
c.execute("select ts from issue4")
|
||||
self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
|
||||
finally:
|
||||
c.execute("drop table issue4")
|
||||
|
||||
def test_issue_6(self):
|
||||
""" exception: TypeError: ord() expected a character, but string of length 0 found """
|
||||
conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql")
|
||||
c = conn.cursor()
|
||||
c.execute("select * from user")
|
||||
conn.close()
|
||||
|
||||
def test_issue_8(self):
|
||||
""" Primary Key and Index error when selecting data """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
|
||||
datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
|
||||
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
|
||||
KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
try:
|
||||
self.assertEqual(0, c.execute("SELECT * FROM test"))
|
||||
c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
|
||||
self.assertEqual(0, c.execute("SELECT * FROM test"))
|
||||
finally:
|
||||
c.execute("drop table test")
|
||||
|
||||
def test_issue_9(self):
|
||||
""" sets DeprecationWarning in Python 2.6 """
|
||||
try:
|
||||
imp.reload(pymysql)
|
||||
except DeprecationWarning:
|
||||
self.fail()
|
||||
|
||||
def test_issue_10(self):
|
||||
""" Allocate a variable to return when the exception handler is permissive """
|
||||
conn = self.connections[0]
|
||||
conn.errorhandler = lambda cursor, errorclass, errorvalue: None
|
||||
cur = conn.cursor()
|
||||
cur.execute( "create table t( n int )" )
|
||||
cur.execute( "create table t( n int )" )
|
||||
|
||||
def test_issue_13(self):
|
||||
""" can't handle large result fields """
|
||||
conn = self.connections[0]
|
||||
cur = conn.cursor()
|
||||
cur.execute("create table issue13 (t text)")
|
||||
try:
|
||||
# ticket says 18k
|
||||
size = 18*1024
|
||||
cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
|
||||
cur.execute("select t from issue13")
|
||||
# use assert_ so that obscenely huge error messages don't print
|
||||
r = cur.fetchone()[0]
|
||||
self.assert_("x" * size == r)
|
||||
finally:
|
||||
cur.execute("drop table issue13")
|
||||
|
||||
def test_issue_14(self):
|
||||
""" typo in converters.py """
|
||||
self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
|
||||
self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))
|
||||
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1, "utf8"))
|
||||
self.assertEqual('1', pymysql.converters.escape_object(1L, "utf8"))
|
||||
|
||||
def test_issue_15(self):
|
||||
""" query should be expanded before perform character encoding """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table issue15 (t varchar(32))")
|
||||
try:
|
||||
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc'))
|
||||
c.execute("select t from issue15")
|
||||
self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table issue15")
|
||||
|
||||
def test_issue_16(self):
|
||||
""" Patch for string and tuple escaping """
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
|
||||
try:
|
||||
c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
|
||||
c.execute("select email from issue16 where name=%s", ("pete",))
|
||||
self.assertEqual("floydophone", c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table issue16")
|
||||
|
||||
def test_issue_17(self):
|
||||
""" could not connect mysql use passwod """
|
||||
conn = self.connections[0]
|
||||
host = self.databases[0]["host"]
|
||||
db = self.databases[0]["db"]
|
||||
c = conn.cursor()
|
||||
# grant access to a table to a user with a password
|
||||
c.execute("create table issue17 (x varchar(32) primary key)")
|
||||
try:
|
||||
c.execute("insert into issue17 (x) values ('hello, world!')")
|
||||
c.execute("grant all privileges on issue17 to issue17user identified by '1234'")
|
||||
|
||||
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
|
||||
c2 = conn2.cursor()
|
||||
c2.execute("select x from issue17")
|
||||
self.assertEqual("hello, world!", c2.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table issue17")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
5
setup.py
5
setup.py
@@ -1,5 +1,6 @@
|
||||
|
||||
from setuptools import setup
|
||||
#from setuptools import setup
|
||||
from distutils.core import setup
|
||||
|
||||
version_tuple = __import__('pymysql').VERSION
|
||||
|
||||
@@ -17,5 +18,5 @@ setup(
|
||||
maintainer = 'David.Story',
|
||||
maintainer_email = 'iDavidStory@gmail.com',
|
||||
description = 'Pure Python MySQL Driver ',
|
||||
packages = ['pymysql', 'pymysql.constants']
|
||||
packages = ['pymysql', 'pymysql.constants', 'pymysql.tests']
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user