From ca4454a3e7cf283838b6c3f1871fcc1a85eb5e3e Mon Sep 17 00:00:00 2001 From: Pete Hunt Date: Mon, 26 Jul 2010 19:28:33 +0000 Subject: [PATCH] My first commit. Did a lot of refactoring, so see the updated README file --- README | 28 +++- example.py | 6 +- pymysql/__init__.py | 36 +++-- pymysql/connections.py | 229 +++++++++++++++++++++++------- pymysql/constants/CLIENT.py | 20 +++ pymysql/constants/CLIENT_FLAG.py | 20 --- pymysql/constants/FLAG.py | 15 ++ pymysql/converters.py | 148 ++++++++++++------- pymysql/{cursor.py => cursors.py} | 20 ++- pymysql/{exceptions.py => err.py} | 24 ++-- pymysql/tests/__init__.py | 5 + pymysql/tests/base.py | 18 +++ pymysql/tests/test_issues.py | 144 +++++++++++++++++++ setup.py | 5 +- 14 files changed, 557 insertions(+), 161 deletions(-) create mode 100644 pymysql/constants/CLIENT.py delete mode 100644 pymysql/constants/CLIENT_FLAG.py create mode 100644 pymysql/constants/FLAG.py rename pymysql/{cursor.py => cursors.py} (89%) rename pymysql/{exceptions.py => err.py} (88%) create mode 100644 pymysql/tests/__init__.py create mode 100644 pymysql/tests/base.py create mode 100644 pymysql/tests/test_issues.py diff --git a/README b/README index 6be47d8..98569a0 100644 --- a/README +++ b/README @@ -1,15 +1,35 @@ ==================== -pymysql Installation +PyMySQL Installation ==================== .. contents:: .. - This package contains a pure python mysql client library. + This package contains a pure-Python MySQL client library. Documentation on the MySQL client/server protocol can be found here: http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol + If you would like to run the test suite, create a ~/.my.cnf file and + a database called "test_pymysql". The goal of pymysql is to be a drop-in + replacement for MySQLdb and work on CPython 2.3+, Jython, IronPython, PyPy + and Python 3. We test for compatibility by simply changing the import + statements in the Django MySQL backend and running its unit tests as well + as running it against the MySQLdb and myconnpy unit tests. Changes -------- +0.3 -Implemented most of the extended DBAPI 2.0 spec including callproc() + -Fixed error handling to include the message from the server and support + multiple protocol versions. + -Implemented ping() + -Implemented unicode support (probably needs better testing) + -Removed DeprecationWarnings + -Ran against the MySQLdb and myconnpy unit tests to check for bugs + -Added support for client_flag, charset, sql_mode, read_default_file, and + use_unicode. + -Refactoring for some more compatibility with MySQLdb including a fake + pymysql.version_info attribute. + -Now runs with no warnings with the -3 command-line switch + -Added test cases for all outstanding tickets and closed most of them. + 0.2 -Changed connection parameter name 'password' to 'passwd' to make it more plugin replaceable for the other mysql clients. -Changed pack()/unpack() calls so it runs on 64 bit OSes too. @@ -26,11 +46,11 @@ Requirements * http://www.python.org/ - * 2.5 is the primary test environment. + * 2.6 is the primary test environment. * MySQL 4.1 or higher - * protocol41 support + * protocol41 support, experimental 4.0 support Installation ------------ diff --git a/example.py b/example.py index d7e2d00..7d6283e 100644 --- a/example.py +++ b/example.py @@ -2,8 +2,8 @@ import pymysql -conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql') -# conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql') +#conn = pymysql.connect(host='127.0.0.1', unix_socket='/tmp/mysql.sock', user='root', passwd=None, db='mysql') +conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd=None, db='mysql') cur = conn.cursor() @@ -14,7 +14,7 @@ cur.execute("SELECT Host,User FROM user") # r = cur.fetchall() # print r # ...or... -for r in cur: +for r in cur.fetchall(): print r cur.close() diff --git a/pymysql/__init__.py b/pymysql/__init__.py index 1e42bea..9c7cdee 100644 --- a/pymysql/__init__.py +++ b/pymysql/__init__.py @@ -1,37 +1,41 @@ -VERSION = (0, 2, None) -__version__ = VERSION # for MySQLdb compatibility +VERSION = (0, 3, None) from pymysql.constants import FIELD_TYPE from pymysql.converters import escape_dict, escape_sequence, escape_string -from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \ +from pymysql.err import Warning, Error, InterfaceError, DataError, \ DatabaseError, OperationalError, IntegrityError, InternalError, \ NotSupportedError, ProgrammingError from pymysql.times import Date, Time, Timestamp, \ DateFromTicks, TimeFromTicks, TimestampFromTicks -from sets import ImmutableSet +try: + frozenset +except NameError: + from sets import ImmutableSet as frozenset + from sets import BaseSet as set threadsafety = 1 apilevel = "2.0" paramstyle = "format" -class DBAPISet(ImmutableSet): +class DBAPISet(frozenset): def __ne__(self, other): - from sets import BaseSet - if isinstance(other, BaseSet): + if isinstance(other, set): return super(DBAPISet, self).__ne__(self, other) else: return other not in self def __eq__(self, other): - from sets import BaseSet - if isinstance(other, BaseSet): - return super(DBAPISet, self).__eq__(self, other) + if isinstance(other, frozenset): + return frozenset.__eq__(self, other) else: return other in self + def __hash__(self): + return frozenset.__hash__(self) + STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, FIELD_TYPE.VAR_STRING]) @@ -59,6 +63,16 @@ def get_client_info(): # for MySQLdb compatibility connect = Connection = Connect +# we include a doctored version_info here for MySQLdb compatibility +version_info = (1,2,2,"final",0) + +NULL = "NULL" + +__version__ = get_client_info() + +def thread_safe(): + # Pure python, so yes we're threadsafe + return True __all__ = [ 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date', @@ -70,4 +84,6 @@ __all__ = [ 'connections', 'constants', 'converters', 'cursors', 'debug', 'escape', 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info', 'paramstyle', 'string_literal', 'threadsafety', 'version_info', + + "NULL","__version__", ] diff --git a/pymysql/connections.py b/pymysql/connections.py index 3f3ffe1..6393691 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -2,19 +2,28 @@ # http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol import re -import sha + +try: + import hashlib + sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs) +except ImportError: + import sha + sha_new = sha.new + import socket import struct import sys +import os +import ConfigParser from pymysql.charset import MBLENGTH -from pymysql.cursor import Cursor +from pymysql.cursors import Cursor from pymysql.constants import FIELD_TYPE from pymysql.constants import SERVER_STATUS -from pymysql.constants.CLIENT_FLAG import * +from pymysql.constants.CLIENT import * from pymysql.constants.COMMAND import * -from pymysql.converters import escape_item, encoders, decoders -from pymysql.exceptions import raise_mysql_exception, Warning, Error, \ +from pymysql.converters import escape_item, encoders, decoders, field_decoders +from pymysql.err import raise_mysql_exception, Warning, Error, \ InterfaceError, DataError, DatabaseError, OperationalError, \ IntegrityError, InternalError, NotSupportedError, ProgrammingError @@ -57,9 +66,9 @@ def _scramble(password, message): if password == None or len(password) == 0: return '\0' if DEBUG: print 'password=' + password - stage1 = sha.new(password).digest() - stage2 = sha.new(stage1).digest() - s = sha.new() + stage1 = sha_new(password).digest() + stage2 = sha_new(stage1).digest() + s = sha_new() s.update(message) s.update(stage2) result = s.digest() @@ -105,7 +114,11 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue): connection.messages.append(err) del cursor del connection - raise errorclass, errorvalue + + if not issubclass(errorclass, Error): + raise Error(errorclass, errorvalue) + else: + raise errorclass, errorvalue class MysqlPacket(object): @@ -121,6 +134,9 @@ class MysqlPacket(object): def __recv_packet(self, socket): """Parse the packet header and read entire packet payload into buffer.""" packet_header = socket.recv(4) + while len(packet_header) < 4: + packet_header += socket.recv(4 - len(packet_header)) + if DEBUG: dump_packet(packet_header) packet_length_bin = packet_header[:3] self.__packet_number = ord(packet_header[3]) @@ -303,7 +319,7 @@ class FieldDescriptorPacket(MysqlPacket): def get_column_length(self): if self.type_code == FIELD_TYPE.VAR_STRING: mblen = MBLENGTH.get(self.charsetnr, 1) - return self.length / mblen + return self.length // mblen return self.length def __str__(self): @@ -316,54 +332,103 @@ class Connection(object): """Representation of a socket with a mysql server.""" errorhandler = defaulterrorhandler - def __init__(self, *args, **kwargs): - self.host = kwargs['host'] - self.port = kwargs.get('port', 3306) - self.user = kwargs['user'] - self.password = kwargs['passwd'] - self.db = kwargs.get('db', None) - self.unix_socket = kwargs.get('unix_socket', None) - self.charset = DEFAULT_CHARSET - - client_flag = CLIENT_CAPABILITIES - #client_flag = kwargs.get('client_flag', None) - client_flag |= CLIENT_MULTI_STATEMENTS + def __init__(self, host="localhost", user=None, passwd="", + db=None, port=3306, unix_socket=None, + charset=DEFAULT_CHARSET, sql_mode=None, + read_default_file=None, conv=decoders, use_unicode=True, + client_flag=0): + + if read_default_file: + cfg = ConfigParser.RawConfigParser() + cfg.read(os.path.expanduser(read_default_file)) + + def _config(key, default): + try: + return cfg.get("client",key) + except: + return default + + user = _config("user",user) + passwd = _config("password",passwd) + host = _config("host", host) + db = _config("db",db) + unix_socket = _config("socket",unix_socket) + port = _config("port", port) + charset = _config("default-character-set", charset) + + self.host = host + self.port = port + self.user = user + self.password = passwd + self.db = db + self.unix_socket = unix_socket + self.use_unicode = use_unicode + self.charset = charset + client_flag |= CAPABILITIES + client_flag |= MULTI_STATEMENTS if self.db: - client_flag |= CLIENT_CONNECT_WITH_DB + client_flag |= CONNECT_WITH_DB self.client_flag = client_flag self._connect() - charset = kwargs.get('charset', None) - self.set_chatset_set(charset) + self.set_charset_set(charset) self.messages = [] self.encoders = encoders - self.decoders = decoders + self.decoders = conv + self.field_decoders = field_decoders + + self._affected_rows = 0 + self.host_info = "Not connected" self.autocommit(False) + if sql_mode is not None: + c = self.cursor() + c.execute("SET sql_mode=%s", (sql_mode,)) + self.commit() + def close(self): - send_data = struct.pack('Q", b)[0] encoders = { bool: escape_bool, - int: escape_object, - long: escape_object, + int: escape_int, + long: escape_long, float: escape_float, str: escape_string, + unicode: escape_unicode, tuple: escape_sequence, list:escape_sequence, + set:escape_sequence, dict:escape_dict, + type(None):escape_None, + datetime.date: escape_date, datetime.datetime : escape_datetime, - datetime.timedelta : escape_timedelta + datetime.timedelta : escape_timedelta, + datetime.time : escape_time, + time.struct_time : escape_struct_time, } decoders = { + FIELD_TYPE.BIT: convert_bit, FIELD_TYPE.TINY: int, FIELD_TYPE.SHORT: int, FIELD_TYPE.LONG: long, @@ -236,17 +260,37 @@ decoders = { FIELD_TYPE.DATETIME: convert_datetime, FIELD_TYPE.TIME: convert_timedelta, FIELD_TYPE.DATE: convert_date, - FIELD_TYPE.BLOB: str, - FIELD_TYPE.STRING: str, - FIELD_TYPE.VAR_STRING: str, - FIELD_TYPE.VARCHAR: str + FIELD_TYPE.SET: convert_set, + #FIELD_TYPE.BLOB: str, + #FIELD_TYPE.STRING: str, + #FIELD_TYPE.VAR_STRING: str, + #FIELD_TYPE.VARCHAR: str } conversions = decoders # for MySQLdb compatibility +def decode_characters(connection, field, data): + if field.charsetnr == 63 or not connection.use_unicode: + # binary data, leave it alone + return data + return data.decode(connection.charset) + +# These take a field instance rather than just the data. +field_decoders = { + FIELD_TYPE.BLOB: decode_characters, + FIELD_TYPE.STRING: decode_characters, + FIELD_TYPE.VAR_STRING: decode_characters, + FIELD_TYPE.VARCHAR: decode_characters, +} + try: # python version > 2.3 from decimal import Decimal decoders[FIELD_TYPE.DECIMAL] = Decimal decoders[FIELD_TYPE.NEWDECIMAL] = Decimal + + def escape_decimal(obj): + return str(obj) + encoders[Decimal] = escape_decimal + except ImportError: pass diff --git a/pymysql/cursor.py b/pymysql/cursors.py similarity index 89% rename from pymysql/cursor.py rename to pymysql/cursors.py index 9d02268..fb9668f 100644 --- a/pymysql/cursor.py +++ b/pymysql/cursors.py @@ -1,9 +1,11 @@ import struct +import re -from pymysql.exceptions import Warning, Error, InterfaceError, DataError, \ +from pymysql.err import Warning, Error, InterfaceError, DataError, \ DatabaseError, OperationalError, IntegrityError, InternalError, \ NotSupportedError, ProgrammingError - + +insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE) class Cursor(object): @@ -16,7 +18,7 @@ class Cursor(object): self.arraysize = 1 self._executed = None self.messages = [] - self.errorhandler =connection.errorhandler + self.errorhandler = connection.errorhandler self._has_next = None self._rows = () @@ -68,11 +70,15 @@ class Cursor(object): charset = conn.charset del self.messages[:] + # this ordering is good because conn.escape() returns + # an encoded string. if isinstance(query, unicode): query = query.encode(charset) + if args is not None: query = query % conn.escape(args) + result = 0 try: result = self._query(query) except: @@ -98,7 +104,13 @@ class Cursor(object): def callproc(self, procname, args=()): - self.errorhandler(self, NotSupportedError, "not supported") + #self.errorhandler(self, NotSupportedError, "not supported") + if not isinstance(args, tuple): + args = (args,) + + argstr = ("%s," * len(args))[:-1] + + return self.execute("CALL `%s`(%s)" % (procname, argstr), args) def fetchone(self): self._check_executed() diff --git a/pymysql/exceptions.py b/pymysql/err.py similarity index 88% rename from pymysql/exceptions.py rename to pymysql/err.py index 843122a..7e8df78 100644 --- a/pymysql/exceptions.py +++ b/pymysql/err.py @@ -108,22 +108,24 @@ del StandardError, _map_error, ER def _get_error_info(data): - errno = struct.unpack('h', data[5:7])[0] - sqlstate = struct.unpack('5s', data[8:8+5])[0] - start = 13 - end = data.find('\0', start) - errorvalue = data[start:end] - return (errno, sqlstate, errorvalue) + errno = struct.unpack('