Run reindent.py
This commit is contained in:
@@ -91,14 +91,14 @@ def Connect(*args, **kwargs):
|
||||
"""
|
||||
from .connections import Connection
|
||||
return Connection(*args, **kwargs)
|
||||
|
||||
|
||||
from pymysql import connections as _orig_conn
|
||||
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + """\nSee connections.Connection.__init__() for
|
||||
information about defaults."""
|
||||
del _orig_conn
|
||||
|
||||
def get_client_info(): # for MySQLdb compatibility
|
||||
return '%s.%s.%s' % VERSION
|
||||
return '%s.%s.%s' % VERSION
|
||||
|
||||
connect = Connection = Connect
|
||||
|
||||
@@ -115,7 +115,7 @@ def thread_safe():
|
||||
def install_as_MySQLdb():
|
||||
"""
|
||||
After this function is called, any application that imports MySQLdb or
|
||||
_mysql will unwittingly actually use
|
||||
_mysql will unwittingly actually use
|
||||
"""
|
||||
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
|
||||
|
||||
|
||||
@@ -241,4 +241,3 @@ def charset_by_name(name):
|
||||
|
||||
def charset_by_id(id):
|
||||
return _charsets.by_id(id)
|
||||
|
||||
|
||||
@@ -68,12 +68,12 @@ DEFAULT_CHARSET = 'latin1'
|
||||
|
||||
|
||||
def dump_packet(data):
|
||||
|
||||
|
||||
def is_ascii(data):
|
||||
if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
|
||||
return data
|
||||
return '.'
|
||||
|
||||
|
||||
try:
|
||||
print("packet length {}".format(len(data)))
|
||||
print("method call[1]: {}".format(sys._getframe(1).f_code.co_name))
|
||||
@@ -164,7 +164,7 @@ def pack_int24(n):
|
||||
return struct.pack('BBB', n&0xFF, (n>>8)&0xFF, (n>>16)&0xFF)
|
||||
|
||||
def unpack_uint16(n):
|
||||
return struct.unpack('<H', n[0:2])[0]
|
||||
return struct.unpack('<H', n[0:2])[0]
|
||||
|
||||
|
||||
# TODO: stop using bit-shifting in these functions...
|
||||
@@ -212,208 +212,208 @@ def defaulterrorhandler(connection, cursor, errorclass, errorvalue):
|
||||
|
||||
|
||||
class MysqlPacket(object):
|
||||
"""Representation of a MySQL response packet. Reads in the packet
|
||||
from the network socket, removes packet header and provides an interface
|
||||
for reading/parsing the packet results."""
|
||||
"""Representation of a MySQL response packet. Reads in the packet
|
||||
from the network socket, removes packet header and provides an interface
|
||||
for reading/parsing the packet results."""
|
||||
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self.__position = 0
|
||||
self.__recv_packet()
|
||||
def __init__(self, connection):
|
||||
self.connection = connection
|
||||
self.__position = 0
|
||||
self.__recv_packet()
|
||||
|
||||
def __recv_packet(self):
|
||||
"""Parse the packet header and read entire packet payload into buffer."""
|
||||
packet_header = self.connection._read_bytes(4)
|
||||
if len(packet_header) < 4:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
def __recv_packet(self):
|
||||
"""Parse the packet header and read entire packet payload into buffer."""
|
||||
packet_header = self.connection._read_bytes(4)
|
||||
if len(packet_header) < 4:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
|
||||
if DEBUG: dump_packet(packet_header)
|
||||
packet_length_bin = packet_header[:3]
|
||||
self.__packet_number = byte2int(packet_header[3])
|
||||
# TODO: check packet_num is correct (+1 from last packet)
|
||||
if DEBUG: dump_packet(packet_header)
|
||||
packet_length_bin = packet_header[:3]
|
||||
self.__packet_number = byte2int(packet_header[3])
|
||||
# TODO: check packet_num is correct (+1 from last packet)
|
||||
|
||||
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
||||
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
||||
recv_data = self.connection._read_bytes(bytes_to_read)
|
||||
if len(recv_data) < bytes_to_read:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
self.__data = recv_data
|
||||
bin_length = packet_length_bin + int2byte(0) # pad little-endian number
|
||||
bytes_to_read = struct.unpack('<I', bin_length)[0]
|
||||
recv_data = self.connection._read_bytes(bytes_to_read)
|
||||
if len(recv_data) < bytes_to_read:
|
||||
raise OperationalError(2013, "Lost connection to MySQL server during query")
|
||||
if DEBUG: dump_packet(recv_data)
|
||||
self.__data = recv_data
|
||||
|
||||
def packet_number(self): return self.__packet_number
|
||||
def packet_number(self): return self.__packet_number
|
||||
|
||||
def get_all_data(self): return self.__data
|
||||
def get_all_data(self): return self.__data
|
||||
|
||||
def read(self, size):
|
||||
"""Read the first 'size' bytes in packet and advance cursor past them."""
|
||||
result = self.peek(size)
|
||||
self.advance(size)
|
||||
return result
|
||||
def read(self, size):
|
||||
"""Read the first 'size' bytes in packet and advance cursor past them."""
|
||||
result = self.peek(size)
|
||||
self.advance(size)
|
||||
return result
|
||||
|
||||
def read_all(self):
|
||||
"""Read all remaining data in the packet.
|
||||
def read_all(self):
|
||||
"""Read all remaining data in the packet.
|
||||
|
||||
(Subsequent read() or peek() will return errors.)
|
||||
"""
|
||||
result = self.__data[self.__position:]
|
||||
self.__position = None # ensure no subsequent read() or peek()
|
||||
return result
|
||||
(Subsequent read() or peek() will return errors.)
|
||||
"""
|
||||
result = self.__data[self.__position:]
|
||||
self.__position = None # ensure no subsequent read() or peek()
|
||||
return result
|
||||
|
||||
def advance(self, length):
|
||||
"""Advance the cursor in data buffer 'length' bytes."""
|
||||
new_position = self.__position + length
|
||||
if new_position < 0 or new_position > len(self.__data):
|
||||
raise Exception('Invalid advance amount (%s) for cursor. '
|
||||
'Position=%s' % (length, new_position))
|
||||
self.__position = new_position
|
||||
def advance(self, length):
|
||||
"""Advance the cursor in data buffer 'length' bytes."""
|
||||
new_position = self.__position + length
|
||||
if new_position < 0 or new_position > len(self.__data):
|
||||
raise Exception('Invalid advance amount (%s) for cursor. '
|
||||
'Position=%s' % (length, new_position))
|
||||
self.__position = new_position
|
||||
|
||||
def rewind(self, position=0):
|
||||
"""Set the position of the data buffer cursor to 'position'."""
|
||||
if position < 0 or position > len(self.__data):
|
||||
raise Exception("Invalid position to rewind cursor to: %s." % position)
|
||||
self.__position = position
|
||||
def rewind(self, position=0):
|
||||
"""Set the position of the data buffer cursor to 'position'."""
|
||||
if position < 0 or position > len(self.__data):
|
||||
raise Exception("Invalid position to rewind cursor to: %s." % position)
|
||||
self.__position = position
|
||||
|
||||
def peek(self, size):
|
||||
"""Look at the first 'size' bytes in packet without moving cursor."""
|
||||
result = self.__data[self.__position:(self.__position+size)]
|
||||
if len(result) != size:
|
||||
error = ('Result length not requested length:\n'
|
||||
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
|
||||
% (size, len(result), self.__position, len(self.__data)))
|
||||
if DEBUG:
|
||||
print(error)
|
||||
self.dump()
|
||||
raise AssertionError(error)
|
||||
return result
|
||||
def peek(self, size):
|
||||
"""Look at the first 'size' bytes in packet without moving cursor."""
|
||||
result = self.__data[self.__position:(self.__position+size)]
|
||||
if len(result) != size:
|
||||
error = ('Result length not requested length:\n'
|
||||
'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
|
||||
% (size, len(result), self.__position, len(self.__data)))
|
||||
if DEBUG:
|
||||
print(error)
|
||||
self.dump()
|
||||
raise AssertionError(error)
|
||||
return result
|
||||
|
||||
def get_bytes(self, position, length=1):
|
||||
"""Get 'length' bytes starting at 'position'.
|
||||
def get_bytes(self, position, length=1):
|
||||
"""Get 'length' bytes starting at 'position'.
|
||||
|
||||
Position is start of payload (first four packet header bytes are not
|
||||
included) starting at index '0'.
|
||||
Position is start of payload (first four packet header bytes are not
|
||||
included) starting at index '0'.
|
||||
|
||||
No error checking is done. If requesting outside end of buffer
|
||||
an empty string (or string shorter than 'length') may be returned!
|
||||
"""
|
||||
return self.__data[position:(position+length)]
|
||||
No error checking is done. If requesting outside end of buffer
|
||||
an empty string (or string shorter than 'length') may be returned!
|
||||
"""
|
||||
return self.__data[position:(position+length)]
|
||||
|
||||
def read_length_coded_binary(self):
|
||||
"""Read a 'Length Coded Binary' number from the data buffer.
|
||||
def read_length_coded_binary(self):
|
||||
"""Read a 'Length Coded Binary' number from the data buffer.
|
||||
|
||||
Length coded numbers can be anywhere from 1 to 9 bytes depending
|
||||
on the value of the first byte.
|
||||
"""
|
||||
c = byte2int(self.read(1))
|
||||
if c == NULL_COLUMN:
|
||||
return None
|
||||
if c < UNSIGNED_CHAR_COLUMN:
|
||||
return c
|
||||
elif c == UNSIGNED_SHORT_COLUMN:
|
||||
return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
|
||||
elif c == UNSIGNED_INT24_COLUMN:
|
||||
return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
|
||||
elif c == UNSIGNED_INT64_COLUMN:
|
||||
# TODO: what was 'longlong'? confirm it wasn't used?
|
||||
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
|
||||
Length coded numbers can be anywhere from 1 to 9 bytes depending
|
||||
on the value of the first byte.
|
||||
"""
|
||||
c = byte2int(self.read(1))
|
||||
if c == NULL_COLUMN:
|
||||
return None
|
||||
if c < UNSIGNED_CHAR_COLUMN:
|
||||
return c
|
||||
elif c == UNSIGNED_SHORT_COLUMN:
|
||||
return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
|
||||
elif c == UNSIGNED_INT24_COLUMN:
|
||||
return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
|
||||
elif c == UNSIGNED_INT64_COLUMN:
|
||||
# TODO: what was 'longlong'? confirm it wasn't used?
|
||||
return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
|
||||
|
||||
def read_length_coded_string(self):
|
||||
"""Read a 'Length Coded String' from the data buffer.
|
||||
def read_length_coded_string(self):
|
||||
"""Read a 'Length Coded String' from the data buffer.
|
||||
|
||||
A 'Length Coded String' consists first of a length coded
|
||||
(unsigned, positive) integer represented in 1-9 bytes followed by
|
||||
that many bytes of binary data. (For example "cat" would be "3cat".)
|
||||
"""
|
||||
length = self.read_length_coded_binary()
|
||||
if length is None:
|
||||
return None
|
||||
return self.read(length)
|
||||
A 'Length Coded String' consists first of a length coded
|
||||
(unsigned, positive) integer represented in 1-9 bytes followed by
|
||||
that many bytes of binary data. (For example "cat" would be "3cat".)
|
||||
"""
|
||||
length = self.read_length_coded_binary()
|
||||
if length is None:
|
||||
return None
|
||||
return self.read(length)
|
||||
|
||||
def is_ok_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 0
|
||||
def is_ok_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 0
|
||||
|
||||
def is_eof_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 254 # 'fe'
|
||||
def is_eof_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 254 # 'fe'
|
||||
|
||||
def is_resultset_packet(self):
|
||||
field_count = byte2int(self.get_bytes(0))
|
||||
return field_count >= 1 and field_count <= 250
|
||||
def is_resultset_packet(self):
|
||||
field_count = byte2int(self.get_bytes(0))
|
||||
return field_count >= 1 and field_count <= 250
|
||||
|
||||
def is_error_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 255
|
||||
def is_error_packet(self):
|
||||
return byte2int(self.get_bytes(0)) == 255
|
||||
|
||||
def check_error(self):
|
||||
if self.is_error_packet():
|
||||
self.rewind()
|
||||
self.advance(1) # field_count == error (we already know that)
|
||||
errno = unpack_uint16(self.read(2))
|
||||
if DEBUG: print("errno = {}".format(errno))
|
||||
raise_mysql_exception(self.__data)
|
||||
def check_error(self):
|
||||
if self.is_error_packet():
|
||||
self.rewind()
|
||||
self.advance(1) # field_count == error (we already know that)
|
||||
errno = unpack_uint16(self.read(2))
|
||||
if DEBUG: print("errno = {}".format(errno))
|
||||
raise_mysql_exception(self.__data)
|
||||
|
||||
def dump(self):
|
||||
dump_packet(self.__data)
|
||||
def dump(self):
|
||||
dump_packet(self.__data)
|
||||
|
||||
|
||||
class FieldDescriptorPacket(MysqlPacket):
|
||||
"""A MysqlPacket that represents a specific column's metadata in the result.
|
||||
"""A MysqlPacket that represents a specific column's metadata in the result.
|
||||
|
||||
Parsing is automatically done and the results are exported via public
|
||||
attributes on the class such as: db, table_name, name, length, type_code.
|
||||
"""
|
||||
|
||||
def __init__(self, *args):
|
||||
MysqlPacket.__init__(self, *args)
|
||||
self.__parse_field_descriptor()
|
||||
|
||||
def __parse_field_descriptor(self):
|
||||
"""Parse the 'Field Descriptor' (Metadata) packet.
|
||||
|
||||
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
||||
Parsing is automatically done and the results are exported via public
|
||||
attributes on the class such as: db, table_name, name, length, type_code.
|
||||
"""
|
||||
self.catalog = self.read_length_coded_string()
|
||||
self.db = self.read_length_coded_string()
|
||||
self.table_name = self.read_length_coded_string()
|
||||
self.org_table = self.read_length_coded_string()
|
||||
self.name = self.read_length_coded_string().decode(self.connection.charset)
|
||||
self.org_name = self.read_length_coded_string()
|
||||
self.advance(1) # non-null filler
|
||||
self.charsetnr = struct.unpack('<H', self.read(2))[0]
|
||||
self.length = struct.unpack('<I', self.read(4))[0]
|
||||
self.type_code = byte2int(self.read(1))
|
||||
self.flags = struct.unpack('<H', self.read(2))[0]
|
||||
self.scale = byte2int(self.read(1)) # "decimals"
|
||||
self.advance(2) # filler (always 0x00)
|
||||
|
||||
# 'default' is a length coded binary and is still in the buffer?
|
||||
# not used for normal result sets...
|
||||
def __init__(self, *args):
|
||||
MysqlPacket.__init__(self, *args)
|
||||
self.__parse_field_descriptor()
|
||||
|
||||
def description(self):
|
||||
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
|
||||
desc = []
|
||||
desc.append(self.name)
|
||||
desc.append(self.type_code)
|
||||
desc.append(None) # TODO: display_length; should this be self.length?
|
||||
desc.append(self.get_column_length()) # 'internal_size'
|
||||
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!?
|
||||
desc.append(self.scale)
|
||||
def __parse_field_descriptor(self):
|
||||
"""Parse the 'Field Descriptor' (Metadata) packet.
|
||||
|
||||
# 'null_ok' -- can this be True/False rather than 1/0?
|
||||
# if so just do: desc.append(bool(self.flags % 2 == 0))
|
||||
if self.flags % 2 == 0:
|
||||
desc.append(1)
|
||||
else:
|
||||
desc.append(0)
|
||||
return tuple(desc)
|
||||
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
||||
"""
|
||||
self.catalog = self.read_length_coded_string()
|
||||
self.db = self.read_length_coded_string()
|
||||
self.table_name = self.read_length_coded_string()
|
||||
self.org_table = self.read_length_coded_string()
|
||||
self.name = self.read_length_coded_string().decode(self.connection.charset)
|
||||
self.org_name = self.read_length_coded_string()
|
||||
self.advance(1) # non-null filler
|
||||
self.charsetnr = struct.unpack('<H', self.read(2))[0]
|
||||
self.length = struct.unpack('<I', self.read(4))[0]
|
||||
self.type_code = byte2int(self.read(1))
|
||||
self.flags = struct.unpack('<H', self.read(2))[0]
|
||||
self.scale = byte2int(self.read(1)) # "decimals"
|
||||
self.advance(2) # filler (always 0x00)
|
||||
|
||||
def get_column_length(self):
|
||||
if self.type_code == FIELD_TYPE.VAR_STRING:
|
||||
mblen = MBLENGTH.get(self.charsetnr, 1)
|
||||
return self.length // mblen
|
||||
return self.length
|
||||
# 'default' is a length coded binary and is still in the buffer?
|
||||
# not used for normal result sets...
|
||||
|
||||
def __str__(self):
|
||||
return ('%s %s.%s.%s, type=%s'
|
||||
% (self.__class__, self.db, self.table_name, self.name,
|
||||
self.type_code))
|
||||
def description(self):
|
||||
"""Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
|
||||
desc = []
|
||||
desc.append(self.name)
|
||||
desc.append(self.type_code)
|
||||
desc.append(None) # TODO: display_length; should this be self.length?
|
||||
desc.append(self.get_column_length()) # 'internal_size'
|
||||
desc.append(self.get_column_length()) # 'precision' # TODO: why!?!?
|
||||
desc.append(self.scale)
|
||||
|
||||
# 'null_ok' -- can this be True/False rather than 1/0?
|
||||
# if so just do: desc.append(bool(self.flags % 2 == 0))
|
||||
if self.flags % 2 == 0:
|
||||
desc.append(1)
|
||||
else:
|
||||
desc.append(0)
|
||||
return tuple(desc)
|
||||
|
||||
def get_column_length(self):
|
||||
if self.type_code == FIELD_TYPE.VAR_STRING:
|
||||
mblen = MBLENGTH.get(self.charsetnr, 1)
|
||||
return self.length // mblen
|
||||
return self.length
|
||||
|
||||
def __str__(self):
|
||||
return ('%s %s.%s.%s, type=%s'
|
||||
% (self.__class__, self.db, self.table_name, self.name,
|
||||
self.type_code))
|
||||
|
||||
class OKPacketWrapper(object):
|
||||
"""
|
||||
@@ -426,20 +426,20 @@ class OKPacketWrapper(object):
|
||||
if not from_packet.is_ok_packet():
|
||||
raise ValueError('Cannot create ' + str(self.__class__.__name__)
|
||||
+ ' object from invalid packet type')
|
||||
|
||||
|
||||
self.packet = from_packet
|
||||
self.packet.advance(1)
|
||||
|
||||
|
||||
self.affected_rows = self.packet.read_length_coded_binary()
|
||||
self.insert_id = self.packet.read_length_coded_binary()
|
||||
self.server_status = struct.unpack('<H', self.packet.read(2))[0]
|
||||
self.warning_count = struct.unpack('<H', self.packet.read(2))[0]
|
||||
self.message = self.packet.read_all()
|
||||
|
||||
|
||||
def __getattr__(self, key):
|
||||
if hasattr(self.packet, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
raise AttributeError(str(self.__class__)
|
||||
+ " instance has no attribute '" + key + "'")
|
||||
|
||||
@@ -454,7 +454,7 @@ class EOFPacketWrapper(object):
|
||||
if not from_packet.is_eof_packet():
|
||||
raise ValueError('Cannot create ' + str(self.__class__.__name__)
|
||||
+ ' object from invalid packet type')
|
||||
|
||||
|
||||
self.packet = from_packet
|
||||
self.warning_count = self.packet.read(2)
|
||||
server_status = struct.unpack('<h', self.packet.read(2))[0]
|
||||
@@ -464,7 +464,7 @@ class EOFPacketWrapper(object):
|
||||
def __getattr__(self, key):
|
||||
if hasattr(self.packet, key):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
raise AttributeError(str(self.__class__)
|
||||
+ " instance has no attribute '" + key + "'")
|
||||
|
||||
@@ -756,12 +756,12 @@ class Connection(object):
|
||||
raise OperationalError(2003, "Can't connect to MySQL server on %r (%s)" % (self.host, e.args[0]))
|
||||
|
||||
def read_packet(self, packet_type=MysqlPacket):
|
||||
"""Read an entire "mysql packet" in its entirety from the network
|
||||
and return a MysqlPacket type that represents the results."""
|
||||
"""Read an entire "mysql packet" in its entirety from the network
|
||||
and return a MysqlPacket type that represents the results."""
|
||||
|
||||
packet = packet_type(self)
|
||||
packet.check_error()
|
||||
return packet
|
||||
packet = packet_type(self)
|
||||
packet.check_error()
|
||||
return packet
|
||||
|
||||
def _read_bytes(self, num_bytes):
|
||||
d = self.socket.recv(num_bytes)
|
||||
@@ -820,7 +820,7 @@ class Connection(object):
|
||||
|
||||
def _execute_command(self, command, sql):
|
||||
self._send_command(command, sql)
|
||||
|
||||
|
||||
def _request_authentication(self):
|
||||
self.client_flag |= CAPABILITIES
|
||||
if self.server_version.startswith('5'):
|
||||
@@ -987,7 +987,7 @@ class MySQLResult(object):
|
||||
else:
|
||||
self.field_count = byte2int(first_packet.read(1))
|
||||
self._get_descriptions()
|
||||
|
||||
|
||||
# Apparently, MySQLdb picks this number because it's the maximum
|
||||
# value of a 64bit unsigned integer. Since we're emulating MySQLdb,
|
||||
# we set it to this instead of None, which would be preferred.
|
||||
@@ -1017,7 +1017,7 @@ class MySQLResult(object):
|
||||
def _read_rowdata_packet_unbuffered(self):
|
||||
# Check if in an active query
|
||||
if self.unbuffered_active == False: return
|
||||
|
||||
|
||||
# EOF
|
||||
packet = self.connection.read_packet()
|
||||
if self._check_packet_is_eof(packet):
|
||||
@@ -1052,29 +1052,29 @@ class MySQLResult(object):
|
||||
# TODO: implement this as an iteratable so that it is more
|
||||
# memory efficient and lower-latency to client...
|
||||
def _read_rowdata_packet(self):
|
||||
"""Read a rowdata packet for each data row in the result set."""
|
||||
rows = []
|
||||
while True:
|
||||
packet = self.connection.read_packet()
|
||||
if self._check_packet_is_eof(packet):
|
||||
break
|
||||
"""Read a rowdata packet for each data row in the result set."""
|
||||
rows = []
|
||||
while True:
|
||||
packet = self.connection.read_packet()
|
||||
if self._check_packet_is_eof(packet):
|
||||
break
|
||||
|
||||
row = []
|
||||
for field in self.fields:
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if field.type_code in self.connection.decoders:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
row.append(converted)
|
||||
row = []
|
||||
for field in self.fields:
|
||||
data = packet.read_length_coded_string()
|
||||
converted = None
|
||||
if field.type_code in self.connection.decoders:
|
||||
converter = self.connection.decoders[field.type_code]
|
||||
if DEBUG: print("DEBUG: field={}, converter={}".format(field, converter))
|
||||
if data != None:
|
||||
converted = converter(self.connection, field, data)
|
||||
row.append(converted)
|
||||
|
||||
rows.append(tuple(row))
|
||||
rows.append(tuple(row))
|
||||
|
||||
self.affected_rows = len(rows)
|
||||
self.rows = tuple(rows)
|
||||
if DEBUG: self.rows
|
||||
self.affected_rows = len(rows)
|
||||
self.rows = tuple(rows)
|
||||
if DEBUG: self.rows
|
||||
|
||||
def _get_descriptions(self):
|
||||
"""Read a column descriptor packet for each column in the result."""
|
||||
|
||||
@@ -9,4 +9,3 @@ SERVER_STATUS_LAST_ROW_SENT = 128
|
||||
SERVER_STATUS_DB_DROPPED = 256
|
||||
SERVER_STATUS_NO_BACKSLASH_ESCAPES = 512
|
||||
SERVER_STATUS_METADATA_CHANGED = 1024
|
||||
|
||||
|
||||
@@ -304,22 +304,22 @@ class SSCursor(Cursor):
|
||||
"""
|
||||
Unbuffered Cursor, mainly useful for queries that return a lot of data,
|
||||
or for connections to remote servers over a slow network.
|
||||
|
||||
|
||||
Instead of copying every row of data into a buffer, this will fetch
|
||||
rows as needed. The upside of this, is the client uses much less memory,
|
||||
and rows are returned much faster when traveling over a slow network,
|
||||
or if the result set is very big.
|
||||
|
||||
|
||||
There are limitations, though. The MySQL protocol doesn't support
|
||||
returning the total number of rows, so the only way to tell how many rows
|
||||
there are is to iterate over every row returned. Also, it currently isn't
|
||||
possible to scroll backwards, as only the current row is held in memory.
|
||||
"""
|
||||
|
||||
|
||||
def close(self):
|
||||
conn = self._get_db()
|
||||
conn._result._finish_unbuffered_query()
|
||||
|
||||
|
||||
try:
|
||||
if self._has_next:
|
||||
while self.nextset(): pass
|
||||
@@ -331,31 +331,31 @@ class SSCursor(Cursor):
|
||||
conn.query(q, unbuffered=True)
|
||||
self._do_get_result()
|
||||
return self.rowcount
|
||||
|
||||
|
||||
def read_next(self):
|
||||
""" Read next row """
|
||||
|
||||
|
||||
conn = self._get_db()
|
||||
conn._result._read_rowdata_packet_unbuffered()
|
||||
return conn._result.rows
|
||||
|
||||
|
||||
def fetchone(self):
|
||||
""" Fetch next row """
|
||||
|
||||
|
||||
self._check_executed()
|
||||
row = self.read_next()
|
||||
if row is None:
|
||||
return None
|
||||
self.rownumber += 1
|
||||
return row
|
||||
|
||||
|
||||
def fetchall(self):
|
||||
"""
|
||||
Fetch all, as per MySQLdb. Pretty useless for large queries, as
|
||||
it is buffered. See fetchall_unbuffered(), if you want an unbuffered
|
||||
generator version of this method.
|
||||
"""
|
||||
|
||||
|
||||
rows = []
|
||||
while True:
|
||||
row = self.fetchone()
|
||||
@@ -370,19 +370,19 @@ class SSCursor(Cursor):
|
||||
however, it doesn't make sense to return everything in a list, as that
|
||||
would use ridiculous memory for large result sets.
|
||||
"""
|
||||
|
||||
|
||||
row = self.fetchone()
|
||||
while row is not None:
|
||||
yield row
|
||||
row = self.fetchone()
|
||||
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
""" Fetch many """
|
||||
|
||||
|
||||
self._check_executed()
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
|
||||
rows = []
|
||||
for i in range(0, size):
|
||||
row = self.read_next()
|
||||
@@ -391,25 +391,25 @@ class SSCursor(Cursor):
|
||||
rows.append(row)
|
||||
self.rownumber += 1
|
||||
return tuple(rows)
|
||||
|
||||
|
||||
def scroll(self, value, mode='relative'):
|
||||
self._check_executed()
|
||||
if not mode == 'relative' and not mode == 'absolute':
|
||||
self.errorhandler(self, ProgrammingError,
|
||||
"unknown scroll mode %s" % mode)
|
||||
|
||||
|
||||
if mode == 'relative':
|
||||
if value < 0:
|
||||
self.errorhandler(self, NotSupportedError,
|
||||
"Backwards scrolling not supported by this cursor")
|
||||
|
||||
|
||||
for i in range(0, value): self.read_next()
|
||||
self.rownumber += value
|
||||
else:
|
||||
if value < self.rownumber:
|
||||
self.errorhandler(self, NotSupportedError,
|
||||
"Backwards scrolling not supported by this cursor")
|
||||
|
||||
|
||||
end = value - self.rownumber
|
||||
for i in range(0, end): self.read_next()
|
||||
self.rownumber = value
|
||||
|
||||
@@ -4,7 +4,7 @@ from .constants import ER
|
||||
import sys
|
||||
|
||||
class MySQLError(Exception):
|
||||
|
||||
|
||||
"""Exception related to operation with MySQL."""
|
||||
|
||||
|
||||
@@ -95,12 +95,12 @@ _map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
|
||||
ER.CANNOT_ADD_FOREIGN)
|
||||
_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
|
||||
ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
|
||||
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
|
||||
ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
|
||||
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
|
||||
ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)
|
||||
|
||||
del _map_error, ER
|
||||
|
||||
|
||||
|
||||
def _get_error_info(data):
|
||||
errno = struct.unpack('<h', data[1:3])[0]
|
||||
if sys.version_info[0] == 3:
|
||||
@@ -117,7 +117,7 @@ def _get_error_info(data):
|
||||
return (errno, None, data[3:].decode("utf8"))
|
||||
|
||||
def _check_mysql_exception(errinfo):
|
||||
errno, sqlstate, errorvalue = errinfo
|
||||
errno, sqlstate, errorvalue = errinfo
|
||||
errorclass = error_map.get(errno, None)
|
||||
if errorclass:
|
||||
raise errorclass(errno,errorvalue)
|
||||
@@ -128,8 +128,3 @@ def _check_mysql_exception(errinfo):
|
||||
def raise_mysql_exception(data):
|
||||
errinfo = _get_error_info(data)
|
||||
_check_mysql_exception(errinfo)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -17,4 +17,3 @@ class PyMySQLTestCase(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
for connection in self.connections:
|
||||
connection.close()
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import datetime
|
||||
class TestDictCursor(base.PyMySQLTestCase):
|
||||
|
||||
def test_DictCursor(self):
|
||||
#all assert test compare to the structure as would come out from MySQLdb
|
||||
#all assert test compare to the structure as would come out from MySQLdb
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor(pymysql.cursors.DictCursor)
|
||||
# create a table ane some data to query
|
||||
@@ -30,7 +30,7 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
r = c.fetchall()
|
||||
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
|
||||
# same test again but iterate over the
|
||||
# same test again but iterate over the
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
for r in c:
|
||||
self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")
|
||||
|
||||
@@ -12,7 +12,7 @@ except:
|
||||
class TestSSCursor(base.PyMySQLTestCase):
|
||||
def test_SSCursor(self):
|
||||
affected_rows = 18446744073709551615
|
||||
|
||||
|
||||
conn = self.connections[0]
|
||||
data = [
|
||||
('America', '', 'America/Jamaica'),
|
||||
@@ -25,22 +25,22 @@ class TestSSCursor(base.PyMySQLTestCase):
|
||||
('America', '', 'America/Costa_Rica'),
|
||||
('America', '', 'America/Denver'),
|
||||
('America', '', 'America/Detroit'),]
|
||||
|
||||
|
||||
try:
|
||||
cursor = conn.cursor(pymysql.cursors.SSCursor)
|
||||
|
||||
|
||||
# Create table
|
||||
cursor.execute(('CREATE TABLE tz_data ('
|
||||
'region VARCHAR(64),'
|
||||
'zone VARCHAR(64),'
|
||||
'name VARCHAR(64))'))
|
||||
|
||||
|
||||
# Test INSERT
|
||||
for i in data:
|
||||
cursor.execute('INSERT INTO tz_data VALUES (%s, %s, %s)', i)
|
||||
self.assertEqual(conn.affected_rows(), 1, 'affected_rows does not match')
|
||||
conn.commit()
|
||||
|
||||
|
||||
# Test fetchone()
|
||||
iter = 0
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
@@ -49,46 +49,46 @@ class TestSSCursor(base.PyMySQLTestCase):
|
||||
if row is None:
|
||||
break
|
||||
iter += 1
|
||||
|
||||
|
||||
# Test cursor.rowcount
|
||||
self.assertEqual(cursor.rowcount, affected_rows,
|
||||
'cursor.rowcount != %s' % (str(affected_rows)))
|
||||
|
||||
|
||||
# Test cursor.rownumber
|
||||
self.assertEqual(cursor.rownumber, iter,
|
||||
'cursor.rowcount != %s' % (str(iter)))
|
||||
|
||||
|
||||
# Test row came out the same as it went in
|
||||
self.assertEqual((row in data), True,
|
||||
'Row not found in source data')
|
||||
|
||||
|
||||
# Test fetchall
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
self.assertEqual(len(cursor.fetchall()), len(data),
|
||||
'fetchall failed. Number of rows does not match')
|
||||
|
||||
|
||||
# Test fetchmany
|
||||
cursor.execute('SELECT * FROM tz_data')
|
||||
self.assertEqual(len(cursor.fetchmany(2)), 2,
|
||||
'fetchmany failed. Number of rows does not match')
|
||||
|
||||
|
||||
# So MySQLdb won't throw "Commands out of sync"
|
||||
while True:
|
||||
res = cursor.fetchone()
|
||||
if res is None:
|
||||
break
|
||||
|
||||
|
||||
# Test update, affected_rows()
|
||||
cursor.execute('UPDATE tz_data SET zone = %s', ['Foo'])
|
||||
conn.commit()
|
||||
self.assertEqual(cursor.rowcount, len(data),
|
||||
'Update failed. affected_rows != %s' % (str(len(data))))
|
||||
|
||||
|
||||
# Test executemany
|
||||
cursor.executemany('INSERT INTO tz_data VALUES (%s, %s, %s)', data)
|
||||
self.assertEqual(cursor.rowcount, len(data),
|
||||
'executemany failed. cursor.rowcount != %s' % (str(len(data))))
|
||||
|
||||
|
||||
finally:
|
||||
cursor.execute('DROP TABLE tz_data')
|
||||
cursor.close()
|
||||
|
||||
@@ -92,7 +92,7 @@ class TestConversion(base.PyMySQLTestCase):
|
||||
self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("drop table test_big_blob")
|
||||
|
||||
|
||||
def test_untyped(self):
|
||||
""" test conversion of null, empty string """
|
||||
conn = self.connections[0]
|
||||
@@ -101,7 +101,7 @@ class TestConversion(base.PyMySQLTestCase):
|
||||
self.assertEqual((None,u''), c.fetchone())
|
||||
c.execute("select '',null")
|
||||
self.assertEqual((u'',None), c.fetchone())
|
||||
|
||||
|
||||
def test_datetime(self):
|
||||
""" test conversion of null, empty string """
|
||||
conn = self.connections[0]
|
||||
|
||||
@@ -4,7 +4,7 @@ from pymysql.tests import base
|
||||
class TestExample(base.PyMySQLTestCase):
|
||||
def test_example(self):
|
||||
conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='mysql')
|
||||
|
||||
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
|
||||
c.execute("insert into issue17 (x) values ('hello, world!')")
|
||||
c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
|
||||
conn.commit()
|
||||
|
||||
|
||||
conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
|
||||
c2 = conn2.cursor()
|
||||
c2.execute("select x from issue17")
|
||||
@@ -229,7 +229,7 @@ class TestNewIssues(base.PyMySQLTestCase):
|
||||
conn = self.connections[0]
|
||||
c = conn.cursor()
|
||||
datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
|
||||
|
||||
|
||||
try:
|
||||
c.execute("create table issue38 (id integer, data mediumblob)")
|
||||
c.execute("insert into issue38 values (1, %s)", (datum,))
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
for functionality and memory leaks.
|
||||
|
||||
Adapted from a script by M-A Lemburg.
|
||||
|
||||
|
||||
"""
|
||||
import sys
|
||||
from time import time
|
||||
@@ -20,7 +20,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
|
||||
rows = 10
|
||||
debug = False
|
||||
|
||||
|
||||
def setUp(self):
|
||||
import gc
|
||||
db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
|
||||
@@ -34,18 +34,18 @@ class DatabaseTest(unittest.TestCase):
|
||||
self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
|
||||
|
||||
leak_test = True
|
||||
|
||||
|
||||
def tearDown(self):
|
||||
if self.leak_test:
|
||||
import gc
|
||||
del self.cursor
|
||||
orphans = gc.collect()
|
||||
self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
|
||||
|
||||
|
||||
del self.connection
|
||||
orphans = gc.collect()
|
||||
self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)
|
||||
|
||||
|
||||
def table_exists(self, name):
|
||||
try:
|
||||
self.cursor.execute('select * from %s where 1=0' % name)
|
||||
@@ -56,7 +56,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
|
||||
def quote_identifier(self, ident):
|
||||
return '"%s"' % ident
|
||||
|
||||
|
||||
def new_table_name(self):
|
||||
i = id(self.cursor)
|
||||
while True:
|
||||
@@ -69,14 +69,14 @@ class DatabaseTest(unittest.TestCase):
|
||||
|
||||
""" Create a table using a list of column definitions given in
|
||||
columndefs.
|
||||
|
||||
|
||||
generator must be a function taking arguments (row_number,
|
||||
col_number) returning a suitable data object for insertion
|
||||
into the table.
|
||||
|
||||
"""
|
||||
self.table = self.new_table_name()
|
||||
self.cursor.execute('CREATE TABLE %s (%s) %s' %
|
||||
self.cursor.execute('CREATE TABLE %s (%s) %s' %
|
||||
(self.table,
|
||||
',\n'.join(columndefs),
|
||||
self.create_table_extra))
|
||||
@@ -84,7 +84,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
def check_data_integrity(self, columndefs, generator):
|
||||
# insert
|
||||
self.create_table(columndefs)
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
(self.table,
|
||||
','.join(['%s'] * len(columndefs))))
|
||||
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
|
||||
@@ -113,7 +113,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
if col == 0: return row
|
||||
else: return ('%i' % (row%10))*255
|
||||
self.create_table(columndefs)
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
(self.table,
|
||||
','.join(['%s'] * len(columndefs))))
|
||||
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
|
||||
@@ -146,7 +146,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
if col == 0: return row
|
||||
else: return ('%i' % (row%10))*((255-self.rows//2)+row)
|
||||
self.create_table(columndefs)
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
insert_statement = ('INSERT INTO %s VALUES (%s)' %
|
||||
(self.table,
|
||||
','.join(['%s'] * len(columndefs))))
|
||||
|
||||
@@ -160,7 +160,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
self.fail("Over-long column did not generate warnings/exception with single insert")
|
||||
|
||||
self.connection.rollback()
|
||||
|
||||
|
||||
try:
|
||||
for i in range(self.rows):
|
||||
data = []
|
||||
@@ -175,7 +175,7 @@ class DatabaseTest(unittest.TestCase):
|
||||
self.fail("Over-long columns did not generate warnings/exception with execute()")
|
||||
|
||||
self.connection.rollback()
|
||||
|
||||
|
||||
try:
|
||||
data = [ [ generator(i,j) for j in range(len(columndefs)) ]
|
||||
for i in range(self.rows) ]
|
||||
@@ -300,4 +300,3 @@ class DatabaseTest(unittest.TestCase):
|
||||
self.check_data_integrity(
|
||||
('col1 INT','col2 BLOB'),
|
||||
generator)
|
||||
|
||||
|
||||
35
pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
vendored
35
pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
vendored
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
''' Python DB API 2.0 driver compliance unit test suite.
|
||||
|
||||
''' Python DB API 2.0 driver compliance unit test suite.
|
||||
|
||||
This software is Public Domain and may be used without restrictions.
|
||||
|
||||
"Now we have booze and barflies entering the discussion, plus rumours of
|
||||
@@ -67,8 +67,8 @@ import time
|
||||
class DatabaseAPI20Test(unittest.TestCase):
|
||||
''' Test a database self.driver for DB API 2.0 compatibility.
|
||||
This implementation tests Gadfly, but the TestCase
|
||||
is structured so that other self.drivers can subclass this
|
||||
test case to ensure compiliance with the DB-API. It is
|
||||
is structured so that other self.drivers can subclass this
|
||||
test case to ensure compiliance with the DB-API. It is
|
||||
expected that this TestCase may be expanded in the future
|
||||
if ambiguities or edge conditions are discovered.
|
||||
|
||||
@@ -78,9 +78,9 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
self.driver, connect_args and connect_kw_args. Class specification
|
||||
should be as follows:
|
||||
|
||||
import dbapi20
|
||||
import dbapi20
|
||||
class mytest(dbapi20.DatabaseAPI20Test):
|
||||
[...]
|
||||
[...]
|
||||
|
||||
Don't 'import DatabaseAPI20Test from dbapi20', or you will
|
||||
confuse the unit tester - just 'import dbapi20'.
|
||||
@@ -99,7 +99,7 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
xddl2 = 'drop table %sbarflys' % table_prefix
|
||||
|
||||
lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
|
||||
|
||||
|
||||
# Some drivers may need to override these helpers, for example adding
|
||||
# a 'commit' after the execute.
|
||||
def executeDDL1(self,cursor):
|
||||
@@ -123,10 +123,10 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
try:
|
||||
cur = con.cursor()
|
||||
for ddl in (self.xddl1,self.xddl2):
|
||||
try:
|
||||
try:
|
||||
cur.execute(ddl)
|
||||
con.commit()
|
||||
except self.driver.Error:
|
||||
except self.driver.Error:
|
||||
# Assume table didn't exist. Other tests will check if
|
||||
# execute is busted.
|
||||
pass
|
||||
@@ -238,7 +238,7 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
con.rollback()
|
||||
except self.driver.NotSupportedError:
|
||||
pass
|
||||
|
||||
|
||||
def test_cursor(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
@@ -392,7 +392,7 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
)
|
||||
elif self.driver.paramstyle == 'named':
|
||||
cur.execute(
|
||||
'insert into %sbooze values (:beer)' % self.table_prefix,
|
||||
'insert into %sbooze values (:beer)' % self.table_prefix,
|
||||
{'beer':"Cooper's"}
|
||||
)
|
||||
elif self.driver.paramstyle == 'format':
|
||||
@@ -532,7 +532,7 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
tests.
|
||||
'''
|
||||
populate = [
|
||||
"insert into %sbooze values ('%s')" % (self.table_prefix,s)
|
||||
"insert into %sbooze values ('%s')" % (self.table_prefix,s)
|
||||
for s in self.samples
|
||||
]
|
||||
return populate
|
||||
@@ -593,7 +593,7 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
self.assertEqual(len(rows),6)
|
||||
rows = [r[0] for r in rows]
|
||||
rows.sort()
|
||||
|
||||
|
||||
# Make sure we get the right data back out
|
||||
for i in range(0,6):
|
||||
self.assertEqual(rows[i],self.samples[i],
|
||||
@@ -664,10 +664,10 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
'cursor.fetchall should return an empty list if '
|
||||
'a select query returns no rows'
|
||||
)
|
||||
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
|
||||
def test_mixedfetch(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
@@ -703,8 +703,8 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
|
||||
def help_nextset_setUp(self,cur):
|
||||
''' Should create a procedure called deleteme
|
||||
that returns two result sets, first the
|
||||
number of rows in booze then "name from booze"
|
||||
that returns two result sets, first the
|
||||
number of rows in booze then "name from booze"
|
||||
'''
|
||||
raise NotImplementedError('Helper not implemented')
|
||||
#sql="""
|
||||
@@ -850,4 +850,3 @@ class DatabaseAPI20Test(unittest.TestCase):
|
||||
self.assertTrue(hasattr(self.driver,'ROWID'),
|
||||
'module.ROWID must be defined.'
|
||||
)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
|
||||
create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
|
||||
leak_test = False
|
||||
|
||||
|
||||
def quote_identifier(self, ident):
|
||||
return "`%s`" % ident
|
||||
|
||||
@@ -40,7 +40,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
self.check_data_integrity(
|
||||
('col1 TINYINT',),
|
||||
generator)
|
||||
|
||||
|
||||
def test_stored_procedures(self):
|
||||
db = self.connection
|
||||
c = self.cursor
|
||||
@@ -49,7 +49,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
|
||||
list(enumerate('ash birch cedar larch pine'.split())))
|
||||
db.commit()
|
||||
|
||||
|
||||
c.execute("""
|
||||
CREATE PROCEDURE test_sp(IN t VARCHAR(255))
|
||||
BEGIN
|
||||
@@ -57,7 +57,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
END
|
||||
""" % self.table)
|
||||
db.commit()
|
||||
|
||||
|
||||
c.callproc('test_sp', ('larch',))
|
||||
rows = c.fetchall()
|
||||
self.assertEquals(len(rows), 1)
|
||||
@@ -84,7 +84,7 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
self.cursor.execute("describe some_non_existent_table");
|
||||
except self.connection.ProgrammingError as msg:
|
||||
self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE)
|
||||
|
||||
|
||||
def test_insert_values(self):
|
||||
from pymysql.cursors import insert_values
|
||||
query = """INSERT FOO (a, b, c) VALUES (a, b, c)"""
|
||||
@@ -92,24 +92,23 @@ class test_MySQLdb(capabilities.DatabaseTest):
|
||||
self.assertTrue(matched)
|
||||
values = matched.group(1)
|
||||
self.assertTrue(values == "(a, b, c)")
|
||||
|
||||
|
||||
def test_ping(self):
|
||||
self.connection.ping()
|
||||
|
||||
def test_literal_int(self):
|
||||
self.assertTrue("2" == self.connection.literal(2))
|
||||
|
||||
|
||||
def test_literal_float(self):
|
||||
self.assertTrue("3.1415" == self.connection.literal(3.1415))
|
||||
|
||||
|
||||
def test_literal_string(self):
|
||||
self.assertTrue("'foo'" == self.connection.literal("foo"))
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if test_MySQLdb.leak_test:
|
||||
import gc
|
||||
gc.enable()
|
||||
gc.set_debug(gc.DEBUG_LEAK)
|
||||
unittest.main()
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
|
||||
connect_kw_args.update(dict(read_default_file='~/.my.cnf',
|
||||
charset='utf8',
|
||||
sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
|
||||
|
||||
|
||||
def test_setoutputsize(self): pass
|
||||
def test_setoutputsize_basic(self): pass
|
||||
def test_nextset(self): pass
|
||||
@@ -24,7 +24,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
|
||||
test for an exception if the statement cannot return a
|
||||
result set. MySQL always returns a result set; it's just that
|
||||
some things return empty result sets."""
|
||||
|
||||
|
||||
def test_fetchall(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
@@ -70,10 +70,10 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
|
||||
'cursor.fetchall should return an empty list if '
|
||||
'a select query returns no rows'
|
||||
)
|
||||
|
||||
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
|
||||
def test_fetchone(self):
|
||||
con = self._connect()
|
||||
try:
|
||||
@@ -152,8 +152,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
|
||||
|
||||
def help_nextset_setUp(self,cur):
|
||||
''' Should create a procedure called deleteme
|
||||
that returns two result sets, first the
|
||||
number of rows in booze then "name from booze"
|
||||
that returns two result sets, first the
|
||||
number of rows in booze then "name from booze"
|
||||
'''
|
||||
sql="""
|
||||
create procedure deleteme()
|
||||
@@ -205,6 +205,6 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test):
|
||||
finally:
|
||||
con.close()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user