Run reindent.py

This commit is contained in:
INADA Naoki
2013-08-30 13:22:43 +09:00
parent 050424d289
commit ef6ad66ca6
16 changed files with 298 additions and 309 deletions

View File

@@ -91,14 +91,14 @@ def Connect(*args, **kwargs):
"""
from .connections import Connection
return Connection(*args, **kwargs)
from pymysql import connections as _orig_conn
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + """\nSee connections.Connection.__init__() for
information about defaults."""
del _orig_conn
def get_client_info(): # for MySQLdb compatibility
return '%s.%s.%s' % VERSION
connect = Connection = Connect
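
For context, a minimal usage sketch of the Connect wrapper shown above (not part of this commit); host, credentials and database name are placeholder values:

import pymysql

conn = pymysql.connect(host='127.0.0.1', user='root', passwd='', db='mysql')  # placeholders
cur = conn.cursor()
cur.execute("SELECT VERSION()")
print(cur.fetchone())   # one-row result tuple
conn.close()
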
@@ -115,7 +115,7 @@ def thread_safe():
def install_as_MySQLdb():
"""
After this function is called, any application that imports MySQLdb or
_mysql will unwittingly actually use pymysql.
"""
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]

View File

@@ -241,4 +241,3 @@ def charset_by_name(name):
def charset_by_id(id):
return _charsets.by_id(id)

View File

@@ -68,12 +68,12 @@ DEFAULT_CHARSET = 'latin1'
def dump_packet(data):
def is_ascii(data):
if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
return data
return '.'
try:
print("packet length {}".format(len(data)))
print("method call[1]: {}".format(sys._getframe(1).f_code.co_name))
@@ -164,7 +164,7 @@ def pack_int24(n):
return struct.pack('BBB', n&0xFF, (n>>8)&0xFF, (n>>16)&0xFF)
def unpack_uint16(n):
return struct.unpack('<H', n[0:2])[0]
# TODO: stop using bit-shifting in these functions...
@@ -212,208 +212,208 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
class MysqlPacket(object):
"""Representation of a MySQL response packet. Reads in the packet
from the network socket, removes packet header and provides an interface
for reading/parsing the packet results."""
"""Representation of a MySQL response packet. Reads in the packet
from the network socket, removes packet header and provides an interface
for reading/parsing the packet results."""
def __init__(self, connection):
self.connection = connection
self.__position = 0
self.__recv_packet()
def __recv_packet(self):
"""Parse the packet header and read entire packet payload into buffer."""
packet_header = self.connection._read_bytes(4)
if len(packet_header) < 4:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(packet_header)
packet_length_bin = packet_header[:3]
self.__packet_number = byte2int(packet_header[3])
# TODO: check packet_num is correct (+1 from last packet)
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
bytes_to_read = struct.unpack('<I', bin_length)[0]
recv_data = self.connection._read_bytes(bytes_to_read)
if len(recv_data) < bytes_to_read:
raise OperationalError(2013, "Lost connection to MySQL server during query")
if DEBUG: dump_packet(recv_data)
self.__data = recv_data
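
As a standalone illustration of the header layout parsed by __recv_packet() above (not part of this commit): the first three bytes are the little-endian payload length and the fourth is the sequence number. The header bytes below are made-up example values:

import struct

header = b'\x2c\x00\x00\x01'                                   # example: length 44, sequence 1
payload_length = struct.unpack('<I', header[:3] + b'\x00')[0]  # pad to 4 bytes, as above
sequence_id = header[3]
print(payload_length, sequence_id)                             # -> 44 1
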
def packet_number(self): return self.__packet_number
def get_all_data(self): return self.__data
def read(self, size):
"""Read the first 'size' bytes in packet and advance cursor past them."""
result = self.peek(size)
self.advance(size)
return result
def read_all(self):
"""Read all remaining data in the packet.
(Subsequent read() or peek() will return errors.)
"""
result = self.__data[self.__position:]
self.__position = None # ensure no subsequent read() or peek()
return result
def advance(self, length):
"""Advance the cursor in data buffer 'length' bytes."""
new_position = self.__position + length
if new_position < 0 or new_position > len(self.__data):
raise Exception('Invalid advance amount (%s) for cursor. '
'Position=%s' % (length, new_position))
self.__position = new_position
def rewind(self, position=0):
"""Set the position of the data buffer cursor to 'position'."""
if position < 0 or position > len(self.__data):
raise Exception("Invalid position to rewind cursor to: %s." % position)
self.__position = position
def peek(self, size):
"""Look at the first 'size' bytes in packet without moving cursor."""
result = self.__data[self.__position:(self.__position+size)]
if len(result) != size:
error = ('Result length not requested length:\n'
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
% (size, len(result), self.__position, len(self.__data)))
if DEBUG:
print(error)
self.dump()
raise AssertionError(error)
return result
def get_bytes(self, position, length=1):
"""Get 'length' bytes starting at 'position'.
Position is start of payload (first four packet header bytes are not
included) starting at index '0'.
No error checking is done. If requesting outside end of buffer
an empty string (or string shorter than 'length') may be returned!
"""
return self.__data[position:(position+length)]
def read_length_coded_binary(self):
"""Read a 'Length Coded Binary' number from the data buffer.
Length coded numbers can be anywhere from 1 to 9 bytes depending
on the value of the first byte.
"""
c = byte2int(self.read(1))
if c == NULL_COLUMN:
return None
if c < UNSIGNED_CHAR_COLUMN:
return c
elif c == UNSIGNED_SHORT_COLUMN:
return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
elif c == UNSIGNED_INT24_COLUMN:
return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
elif c == UNSIGNED_INT64_COLUMN:
# TODO: what was 'longlong'? confirm it wasn't used?
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
def read_length_coded_string(self):
"""Read a 'Length Coded String' from the data buffer.
A 'Length Coded String' consists first of a length coded
(unsigned, positive) integer represented in 1-9 bytes followed by
that many bytes of binary data. (For example "cat" would be "3cat".)
"""
length = self.read_length_coded_binary()
if length is None:
return None
return self.read(length)
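
A self-contained sketch of the 'Length Coded Binary' rule the two methods above implement (not part of this commit); the first-byte markers used here (0xfb for NULL, 0xfc/0xfd/0xfe for 2-, 3- and 8-byte integers) are the standard MySQL protocol values rather than the module constants:

import struct

def read_lcb(buf, pos=0):
    c = buf[pos]
    if c == 0xfb:                     # NULL column
        return None, pos + 1
    if c < 0xfb:                      # value stored directly in one byte
        return c, pos + 1
    if c == 0xfc:                     # 2-byte little-endian integer follows
        return struct.unpack('<H', buf[pos+1:pos+3])[0], pos + 3
    if c == 0xfd:                     # 3-byte integer follows
        return struct.unpack('<I', buf[pos+1:pos+4] + b'\x00')[0], pos + 4
    return struct.unpack('<Q', buf[pos+1:pos+9])[0], pos + 9   # 0xfe: 8 bytes

# "3cat" example from the docstring: length byte 3, then the bytes b"cat"
length, pos = read_lcb(b'\x03cat')
print(length, b'\x03cat'[pos:pos+length])   # -> 3 b'cat'
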
def is_ok_packet(self):
return byte2int(self.get_bytes(0)) == 0
def is_eof_packet(self):
return byte2int(self.get_bytes(0)) == 254 # 'fe'
def is_resultset_packet(self):
field_count = byte2int(self.get_bytes(0))
return field_count >= 1 and field_count <= 250
def is_error_packet(self):
return byte2int(self.get_bytes(0)) == 255
def check_error(self):
if self.is_error_packet():
self.rewind()
self.advance(1) # field_count == error (we already know that)
errno = unpack_uint16(self.read(2))
if DEBUG: print("errno = {}".format(errno))
raise_mysql_exception(self.__data)
def dump(self):
dump_packet(self.__data)
class FieldDescriptorPacket(MysqlPacket):
"""A MysqlPacket that represents a specific column's metadata in the result.
"""A MysqlPacket that represents a specific column's metadata in the result.
Parsing is automatically done and the results are exported via public
attributes on the class such as: db, table_name, name, length, type_code.
"""
def __init__(self, *args):
MysqlPacket.__init__(self, *args)
self.__parse_field_descriptor()
def __parse_field_descriptor(self):
"""Parse the 'Field Descriptor' (Metadata) packet.
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
Parsing is automatically done and the results are exported via public
attributes on the class such as: db, table_name, name, length, type_code.
"""
self.catalog = self.read_length_coded_string()
self.db = self.read_length_coded_string()
self.table_name = self.read_length_coded_string()
self.org_table = self.read_length_coded_string()
self.name = self.read_length_coded_string().decode(self.connection.charset)
self.org_name = self.read_length_coded_string()
self.advance(1) # non-null filler
self.charsetnr = struct.unpack('<H', self.read(2))[0]
self.length = struct.unpack('<I', self.read(4))[0]
self.type_code = byte2int(self.read(1))
self.flags = struct.unpack('<H', self.read(2))[0]
self.scale = byte2int(self.read(1)) # "decimals"
self.advance(2) # filler (always 0x00)
# 'default' is a length coded binary and is still in the buffer?
# not used for normal result sets...
def __init__(self, *args):
MysqlPacket.__init__(self, *args)
self.__parse_field_descriptor()
def description(self):
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
desc = []
desc.append(self.name)
desc.append(self.type_code)
desc.append(None) # TODO: display_length; should this be self.length?
desc.append(self.get_column_length()) # 'internal_size'
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!?
desc.append(self.scale)
def __parse_field_descriptor(self):
"""Parse the 'Field Descriptor' (Metadata) packet.
# 'null_ok' -- can this be True/False rather than 1/0?
# if so just do: desc.append(bool(self.flags % 2 == 0))
if self.flags % 2 == 0:
desc.append(1)
else:
desc.append(0)
return tuple(desc)
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
"""
self.catalog = self.read_length_coded_string()
self.db = self.read_length_coded_string()
self.table_name = self.read_length_coded_string()
self.org_table = self.read_length_coded_string()
self.name = self.read_length_coded_string().decode(self.connection.charset)
self.org_name = self.read_length_coded_string()
self.advance(1) # non-null filler
self.charsetnr = struct.unpack('<H', self.read(2))[0]
self.length = struct.unpack('<I', self.read(4))[0]
self.type_code = byte2int(self.read(1))
self.flags = struct.unpack('<H', self.read(2))[0]
self.scale = byte2int(self.read(1)) # "decimals"
self.advance(2) # filler (always 0x00)
def get_column_length(self):
if self.type_code == FIELD_TYPE.VAR_STRING:
mblen = MBLENGTH.get(self.charsetnr, 1)
return self.length // mblen
return self.length
# 'default' is a length coded binary and is still in the buffer?
# not used for normal result sets...
def __str__(self):
return ('%s %s.%s.%s, type=%s'
% (self.__class__, self.db, self.table_name, self.name,
self.type_code))
def description(self):
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
desc = []
desc.append(self.name)
desc.append(self.type_code)
desc.append(None) # TODO: display_length; should this be self.length?
desc.append(self.get_column_length()) # 'internal_size'
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!?
desc.append(self.scale)
# 'null_ok' -- can this be True/False rather than 1/0?
# if so just do: desc.append(bool(self.flags % 2 == 0))
if self.flags % 2 == 0:
desc.append(1)
else:
desc.append(0)
return tuple(desc)
def get_column_length(self):
if self.type_code == FIELD_TYPE.VAR_STRING:
mblen = MBLENGTH.get(self.charsetnr, 1)
return self.length // mblen
return self.length
def __str__(self):
return ('%s %s.%s.%s, type=%s'
% (self.__class__, self.db, self.table_name, self.name,
self.type_code))
class OKPacketWrapper(object):
"""
@@ -426,20 +426,20 @@ class OKPacketWrapper(object):
if not from_packet.is_ok_packet():
raise ValueError('Cannot create ' + str(self.__class__.__name__)
+ ' object from invalid packet type')
self.packet = from_packet
self.packet.advance(1)
self.affected_rows = self.packet.read_length_coded_binary()
self.insert_id = self.packet.read_length_coded_binary()
self.server_status = struct.unpack('<H', self.packet.read(2))[0]
self.warning_count = struct.unpack('<H', self.packet.read(2))[0]
self.message = self.packet.read_all()
def __getattr__(self, key):
if hasattr(self.packet, key):
return getattr(self.packet, key)
raise AttributeError(str(self.__class__)
+ " instance has no attribute '" + key + "'")
@@ -454,7 +454,7 @@ class EOFPacketWrapper(object):
if not from_packet.is_eof_packet():
raise ValueError('Cannot create ' + str(self.__class__.__name__)
+ ' object from invalid packet type')
self.packet = from_packet
self.warning_count = self.packet.read(2)
server_status = struct.unpack('<h', self.packet.read(2))[0]
@@ -464,7 +464,7 @@ class EOFPacketWrapper(object):
def __getattr__(self, key):
if hasattr(self.packet, key):
return getattr(self.packet, key)
raise AttributeError(str(self.__class__)
+ " instance has no attribute '" + key + "'")
@@ -756,12 +756,12 @@ class Connection(object):
raise OperationalError(2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e.args[0]))
def read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results."""
"""Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results."""
packet = packet_type(self)
packet.check_error()
return packet
def _read_bytes(self, num_bytes):
d = self.socket.recv(num_bytes)
@@ -820,7 +820,7 @@ class Connection(object):
def _execute_command(self, command, sql):
self._send_command(command, sql)
def _request_authentication(self):
self.client_flag |= CAPABILITIES
if self.server_version.startswith('5'):
@@ -987,7 +987,7 @@ class MySQLResult(object):
else:
self.field_count = byte2int(first_packet.read(1))
self._get_descriptions()
# Apparently, MySQLdb picks this number because it's the maximum
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
# we set it to this instead of None, which would be preferred.
@@ -1017,7 +1017,7 @@ class MySQLResult(object):
def _read_rowdata_packet_unbuffered(self):
# Check if in an active query
if self.unbuffered_active == False: return
# EOF
packet = self.connection.read_packet()
if self._check_packet_is_eof(packet):
@@ -1052,29 +1052,29 @@ class MySQLResult(object):
# TODO: implement this as an iteratable so that it is more
# memory efficient and lower-latency to client...
def _read_rowdata_packet(self):
"""Read a rowdata packet for each data row in the result set."""
rows = []
while True:
packet = self.connection.read_packet()
if self._check_packet_is_eof(packet):
break
"""Read a rowdata packet for each data row in the result set."""
rows = []
while True:
packet = self.connection.read_packet()
if self._check_packet_is_eof(packet):
break
row = []
for field in self.fields:
data = packet.read_length_coded_string()
converted = None
if field.type_code in self.connection.decoders:
converter = self.connection.decoders[field.type_code]
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
if data != None:
converted = converter(self.connection, field, data)
row.append(converted)
rows.append(tuple(row))
self.affected_rows = len(rows)
self.rows = tuple(rows)
if DEBUG: self.rows
def _get_descriptions(self):
"""Read a column descriptor packet for each column in the result."""

View File

@@ -9,4 +9,3 @@ SERVER_STATUS_LAST_ROW_SENT = 128
SERVER_STATUS_DB_DROPPED = 256
SERVER_STATUS_NO_BACKSLASH_ESCAPES = 512
SERVER_STATUS_METADATA_CHANGED = 1024

View File

@@ -304,22 +304,22 @@ class SSCursor(Cursor):
"""
Unbuffered Cursor, mainly useful for queries that return a lot of data,
or for connections to remote servers over a slow network.
Instead of copying every row of data into a buffer, this will fetch
rows as needed. The upside of this, is the client uses much less memory,
and rows are returned much faster when traveling over a slow network,
or if the result set is very big.
There are limitations, though. The MySQL protocol doesn't support
returning the total number of rows, so the only way to tell how many rows
there are is to iterate over every row returned. Also, it currently isn't
possible to scroll backwards, as only the current row is held in memory.
"""
def close(self):
conn = self._get_db()
conn._result._finish_unbuffered_query()
try:
if self._has_next:
while self.nextset(): pass
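
A usage sketch for the unbuffered cursor described in the docstring above (not part of this commit); connection parameters and the query are placeholders:

import pymysql
import pymysql.cursors

conn = pymysql.connect(host='127.0.0.1', user='root', passwd='', db='mysql')
cur = conn.cursor(pymysql.cursors.SSCursor)
cur.execute("SELECT host, user FROM user")
for row in cur.fetchall_unbuffered():   # generator: rows are fetched as needed
    print(row)
cur.close()
conn.close()
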
@@ -331,31 +331,31 @@ class SSCursor(Cursor):
conn.query(q, unbuffered=True)
self._do_get_result()
return self.rowcount
def read_next(self):
""" Read next row """
conn = self._get_db()
conn._result._read_rowdata_packet_unbuffered()
return conn._result.rows
def fetchone(self):
""" Fetch next row """
self._check_executed()
row = self.read_next()
if row is None:
return None
self.rownumber += 1
return row
def fetchall(self):
"""
Fetch all, as per MySQLdb. Pretty useless for large queries, as
it is buffered. See fetchall_unbuffered(), if you want an unbuffered
generator version of this method.
"""
rows = []
while True:
row = self.fetchone()
@@ -370,19 +370,19 @@ class SSCursor(Cursor):
however, it doesn't make sense to return everything in a list, as that
would use ridiculous memory for large result sets.
"""
row = self.fetchone()
while row is not None:
yield row
row = self.fetchone()
def fetchmany(self, size=None):
""" Fetch many """
self._check_executed()
if size is None:
size = self.arraysize
rows = []
for i in range(0, size):
row = self.read_next()
@@ -391,25 +391,25 @@ class SSCursor(Cursor):
rows.append(row)
self.rownumber += 1
return tuple(rows)
def scroll(self, value, mode='relative'):
self._check_executed()
if not mode == 'relative' and not mode == 'absolute':
self.errorhandler(self, ProgrammingError,
"unknown scroll mode %s" % mode)
if mode == 'relative':
if value < 0:
self.errorhandler(self, NotSupportedError,
"Backwards scrolling not supported by this cursor")
for i in range(0, value): self.read_next()
self.rownumber += value
else:
if value < self.rownumber:
self.errorhandler(self, NotSupportedError,
"Backwards scrolling not supported by this cursor")
end = value - self.rownumber
for i in range(0, end): self.read_next()
self.rownumber = value

View File

@@ -4,7 +4,7 @@ from .constants import ER
import sys
class MySQLError(Exception):
"""Exception related to operation with MySQL."""
@@ -95,12 +95,12 @@ _map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
ER.CANNOT_ADD_FOREIGN)
_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
del _map_error, ER
def _get_error_info(data):
errno = struct.unpack('<h', data[1:3])[0]
if sys.version_info[0] == 3:
@@ -117,7 +117,7 @@ def _get_error_info(data):
return (errno, None, data[3:].decode("utf8"))
def _check_mysql_exception(errinfo):
errno, sqlstate, errorvalue = errinfo
errorclass = error_map.get(errno, None)
if errorclass:
raise errorclass(errno,errorvalue)
@@ -128,8 +128,3 @@ def _check_mysql_exception(errinfo):
def raise_mysql_exception(data):
errinfo = _get_error_info(data)
_check_mysql_exception(errinfo)
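
To illustrate how the _map_error() table above surfaces to callers (not part of this commit), a sketch in which a query against a missing table comes back as the mapped ProgrammingError rather than a bare MySQLError; connection parameters are placeholders:

import pymysql

conn = pymysql.connect(host='127.0.0.1', user='root', passwd='', db='mysql')
cur = conn.cursor()
try:
    cur.execute("DESCRIBE some_non_existent_table")
except pymysql.ProgrammingError as exc:
    errno, message = exc.args        # e.g. the no-such-table error code and server message
    print(errno, message)
finally:
    conn.close()
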

View File

@@ -17,4 +17,3 @@ class PyMySQLTestCase(unittest.TestCase):
def tearDown(self):
for connection in self.connections:
connection.close()

View File

@@ -6,7 +6,7 @@ import datetime
class TestDictCursor(base.PyMySQLTestCase):
def test_DictCursor(self):
#all assert test compare to the structure as would come out from MySQLdb
conn = self.connections[0]
c = conn.cursor(pymysql.cursors.DictCursor)
# create a table ane some data to query
@@ -30,7 +30,7 @@ class TestDictCursor(base.PyMySQLTestCase):
c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchall()
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
# same test again but iterate over the
c.execute("SELECT * from dictcursor where name='bob'")
for r in c:
self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")

View File

@@ -12,7 +12,7 @@ except:
class TestSSCursor(base.PyMySQLTestCase):
def test_SSCursor(self):
affected_rows = 18446744073709551615
conn = self.connections[0]
data = [
('America', '', 'America/Jamaica'),
@@ -25,22 +25,22 @@ class TestSSCursor(base.PyMySQLTestCase):
('America', '', 'America/Costa_Rica'),
('America', '', 'America/Denver'),
('America', '', 'America/Detroit'),]
try:
cursor = conn.cursor(pymysql.cursors.SSCursor)
# Create table
cursor.execute(('CREATE TABLE tz_data ('
'region VARCHAR(64),'
'zone VARCHAR(64),'
'name VARCHAR(64))'))
# Test INSERT
for i in data:
cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i)
self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match')
conn.commit()
# Test fetchone()
iter = 0
cursor.execute('SELECT * FROM tz_data')
@@ -49,46 +49,46 @@ class TestSSCursor(base.PyMySQLTestCase):
if row is None:
break
iter += 1
# Test cursor.rowcount
self.assertEqual(cursor.rowcount, affected_rows,
'cursor.rowcount != %s' % (str(affected_rows)))
# Test cursor.rownumber
self.assertEqual(cursor.rownumber, iter,
'cursor.rowcount != %s' % (str(iter)))
# Test row came out the same as it went in
self.assertEqual((row in data), True,
'Row not found in source data')
# Test fetchall
cursor.execute('SELECT * FROM tz_data')
self.assertEqual(len(cursor.fetchall()), len(data),
'fetchall failed. Number of rows does not match')
# Test fetchmany
cursor.execute('SELECT * FROM tz_data')
self.assertEqual(len(cursor.fetchmany(2)), 2,
'fetchmany failed. Number of rows does not match')
# So MySQLdb won't throw "Commands out of sync"
while True:
res = cursor.fetchone()
if res is None:
break
# Test update, affected_rows()
cursor.execute('UPDATE tz_data SET zone = %s', ['Foo'])
conn.commit()
self.assertEqual(cursor.rowcount, len(data),
'Update failed. affected_rows != %s' % (str(len(data))))
# Test executemany
cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data)
self.assertEqual(cursor.rowcount, len(data),
'executemany failed. cursor.rowcount != %s' % (str(len(data))))
finally:
cursor.execute('DROP TABLE tz_data')
cursor.close()

View File

@@ -92,7 +92,7 @@ class TestConversion(base.PyMySQLTestCase):
self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
finally:
c.execute("drop table test_big_blob")
def test_untyped(self):
""" test conversion of null, empty string """
conn = self.connections[0]
@@ -101,7 +101,7 @@ class TestConversion(base.PyMySQLTestCase):
self.assertEqual((None,u''), c.fetchone())
c.execute("select '',null")
self.assertEqual((u'',None), c.fetchone())
def test_datetime(self):
""" test conversion of null, empty string """
conn = self.connections[0]

View File

@@ -4,7 +4,7 @@ from pymysql.tests import base
class TestExample(base.PyMySQLTestCase):
def test_example(self):
conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='mysql')
cur = conn.cursor()

View File

@@ -143,7 +143,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
c.execute("insert into issue17 (x) values ('hello, world!')")
c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
conn.commit()
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
c2 = conn2.cursor()
c2.execute("select x from issue17")
@@ -229,7 +229,7 @@ class TestNewIssues(base.PyMySQLTestCase):
conn = self.connections[0]
c = conn.cursor()
datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
try:
c.execute("create table issue38 (id integer, data mediumblob)")
c.execute("insert into issue38 values (1, %s)", (datum,))

View File

@@ -3,7 +3,7 @@
for functionality and memory leaks.
Adapted from a script by M-A Lemburg.
"""
import sys
from time import time
@@ -20,7 +20,7 @@ class DatabaseTest(unittest.TestCase):
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
rows = 10
debug = False
def setUp(self):
import gc
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
@@ -34,18 +34,18 @@ class DatabaseTest(unittest.TestCase):
self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
leak_test = True
def tearDown(self):
if self.leak_test:
import gc
del self.cursor
orphans = gc.collect()
self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
del self.connection
orphans = gc.collect()
self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)
def table_exists(self, name):
try:
self.cursor.execute('select * from %s where 1=0' % name)
@@ -56,7 +56,7 @@ class DatabaseTest(unittest.TestCase):
def quote_identifier(self, ident):
return '"%s"' % ident
def new_table_name(self):
i = id(self.cursor)
while True:
@@ -69,14 +69,14 @@ class DatabaseTest(unittest.TestCase):
""" Create a table using a list of column definitions given in
columndefs.
generator must be a function taking arguments (row_number,
col_number) returning a suitable data object for insertion
into the table.
"""
self.table = self.new_table_name()
self.cursor.execute('CREATE TABLE %s (%s) %s' %
(self.table,
',\n'.join(columndefs),
self.create_table_extra))
@@ -84,7 +84,7 @@ class DatabaseTest(unittest.TestCase):
def check_data_integrity(self, columndefs, generator):
# insert
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
@@ -113,7 +113,7 @@ class DatabaseTest(unittest.TestCase):
if col == 0: return row
else: return ('%i' % (row%10))*255
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
@@ -146,7 +146,7 @@ class DatabaseTest(unittest.TestCase):
if col == 0: return row
else: return ('%i' % (row%10))*((255-self.rows//2)+row)
self.create_table(columndefs)
insert_statement = ('INSERT INTO %s VALUES (%s)' %
(self.table,
','.join(['%s'] * len(columndefs))))
@@ -160,7 +160,7 @@ class DatabaseTest(unittest.TestCase):
self.fail("Over-long column did not generate warnings/exception with single insert")
self.connection.rollback()
try:
for i in range(self.rows):
data = []
@@ -175,7 +175,7 @@ class DatabaseTest(unittest.TestCase):
self.fail("Over-long columns did not generate warnings/exception with execute()")
self.connection.rollback()
try:
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
for i in range(self.rows) ]
@@ -300,4 +300,3 @@ class DatabaseTest(unittest.TestCase):
self.check_data_integrity(
('col1 INT','col2 BLOB'),
generator)

View File

@@ -1,6 +1,6 @@
#!/usr/bin/env python
''' Python DB API 2.0 driver compliance unit test suite.
This software is Public Domain and may be used without restrictions.
"Now we have booze and barflies entering the discussion, plus rumours of
@@ -67,8 +67,8 @@ import time
class DatabaseAPI20Test(unittest.TestCase):
''' Test a database self.driver for DB API 2.0 compatibility.
This implementation tests Gadfly, but the TestCase
is structured so that other self.drivers can subclass this
test case to ensure compiliance with the DB-API. It is
expected that this TestCase may be expanded in the future
if ambiguities or edge conditions are discovered.
@@ -78,9 +78,9 @@ class DatabaseAPI20Test(unittest.TestCase):
self.driver, connect_args and connect_kw_args. Class specification
should be as follows:
import dbapi20
class mytest(dbapi20.DatabaseAPI20Test):
[...]
Don't 'import DatabaseAPI20Test from dbapi20', or you will
confuse the unit tester - just 'import dbapi20'.
@@ -99,7 +99,7 @@ class DatabaseAPI20Test(unittest.TestCase):
xddl2 = 'drop table %sbarflys' % table_prefix
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
# Some drivers may need to override these helpers, for example adding
# a 'commit' after the execute.
def executeDDL1(self,cursor):
@@ -123,10 +123,10 @@ class DatabaseAPI20Test(unittest.TestCase):
try:
cur = con.cursor()
for ddl in (self.xddl1,self.xddl2):
try:
cur.execute(ddl)
con.commit()
except self.driver.Error:
# Assume table didn't exist. Other tests will check if
# execute is busted.
pass
@@ -238,7 +238,7 @@ class DatabaseAPI20Test(unittest.TestCase):
con.rollback()
except self.driver.NotSupportedError:
pass
def test_cursor(self):
con = self._connect()
try:
@@ -392,7 +392,7 @@ class DatabaseAPI20Test(unittest.TestCase):
)
elif self.driver.paramstyle == 'named':
cur.execute(
'insert into %sbooze values (:beer)' % self.table_prefix,
{'beer':"Cooper's"}
)
elif self.driver.paramstyle == 'format':
@@ -532,7 +532,7 @@ class DatabaseAPI20Test(unittest.TestCase):
tests.
'''
populate = [
"insert into %sbooze values ('%s')" % (self.table_prefix,s)
"insert into %sbooze values ('%s')" % (self.table_prefix,s)
for s in self.samples
]
return populate
@@ -593,7 +593,7 @@ class DatabaseAPI20Test(unittest.TestCase):
self.assertEqual(len(rows),6)
rows = [r[0] for r in rows]
rows.sort()
# Make sure we get the right data back out
for i in range(0,6):
self.assertEqual(rows[i],self.samples[i],
@@ -664,10 +664,10 @@ class DatabaseAPI20Test(unittest.TestCase):
'cursor.fetchall should return an empty list if '
'a select query returns no rows'
)
finally:
con.close()
def test_mixedfetch(self):
con = self._connect()
try:
@@ -703,8 +703,8 @@ class DatabaseAPI20Test(unittest.TestCase):
def help_nextset_setUp(self,cur):
''' Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
'''
raise NotImplementedError('Helper not implemented')
#sql="""
@@ -850,4 +850,3 @@ class DatabaseAPI20Test(unittest.TestCase):
self.assertTrue(hasattr(self.driver,'ROWID'),
'module.ROWID must be defined.'
)

View File

@@ -18,7 +18,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
leak_test = False
def quote_identifier(self, ident):
return "`%s`" % ident
@@ -40,7 +40,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
self.check_data_integrity(
('col1 TINYINT',),
generator)
def test_stored_procedures(self):
db = self.connection
c = self.cursor
@@ -49,7 +49,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
list(enumerate('ash birch cedar larch pine'.split())))
db.commit()
c.execute("""
CREATE PROCEDURE test_sp(IN t VARCHAR(255))
BEGIN
@@ -57,7 +57,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
END
""" % self.table)
db.commit()
c.callproc('test_sp', ('larch',))
rows = c.fetchall()
self.assertEquals(len(rows), 1)
@@ -84,7 +84,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
self.cursor.execute("describe some_non_existent_table");
except self.connection.ProgrammingError as msg:
self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE)
def test_insert_values(self):
from pymysql.cursors import insert_values
query = """INSERT FOO (a, b, c) VALUES (a, b, c)"""
@@ -92,24 +92,23 @@ class test_MySQLdb(capabilities.DatabaseTest):
self.assertTrue(matched)
values = matched.group(1)
self.assertTrue(values == "(a, b, c)")
def test_ping(self):
self.connection.ping()
def test_literal_int(self):
self.assertTrue("2" == self.connection.literal(2))
def test_literal_float(self):
self.assertTrue("3.1415" == self.connection.literal(3.1415))
def test_literal_string(self):
self.assertTrue("'foo'" == self.connection.literal("foo"))
if __name__ == '__main__':
if test_MySQLdb.leak_test:
import gc
gc.enable()
gc.set_debug(gc.DEBUG_LEAK)
unittest.main()

View File

@@ -15,7 +15,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
connect_kw_args.update(dict(read_default_file='~/.my.cnf',
charset='utf8',
sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
def test_setoutputsize(self): pass
def test_setoutputsize_basic(self): pass
def test_nextset(self): pass
@@ -24,7 +24,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
test for an exception if the statement cannot return a
result set. MySQL always returns a result set; it's just that
some things return empty result sets."""
def test_fetchall(self):
con = self._connect()
try:
@@ -70,10 +70,10 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
'cursor.fetchall should return an empty list if '
'a select query returns no rows'
)
finally:
con.close()
def test_fetchone(self):
con = self._connect()
try:
@@ -152,8 +152,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
def help_nextset_setUp(self,cur):
''' Should create a procedure called deleteme
that returns two result sets, first the
number of rows in booze then "name from booze"
'''
sql="""
create procedure deleteme()
@@ -205,6 +205,6 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
finally:
con.close()
if __name__ == '__main__':
unittest.main()