From ef6ad66ca64721bb67c910bcf4b887cf06668dcb Mon Sep 17 00:00:00 2001 From: INADA Naoki Date: Fri, 30 Aug 2013 13:22:43 +0900 Subject: [PATCH] Run reindent.py --- pymysql/__init__.py | 6 +- pymysql/charset.py | 1 - pymysql/connections.py | 408 +++++++++--------- pymysql/constants/SERVER_STATUS.py | 1 - pymysql/cursors.py | 36 +- pymysql/err.py | 15 +- pymysql/tests/base.py | 1 - pymysql/tests/test_DictCursor.py | 4 +- pymysql/tests/test_SSCursor.py | 28 +- pymysql/tests/test_basic.py | 4 +- pymysql/tests/test_example.py | 2 +- pymysql/tests/test_issues.py | 4 +- .../thirdparty/test_MySQLdb/capabilities.py | 27 +- .../tests/thirdparty/test_MySQLdb/dbapi20.py | 35 +- .../test_MySQLdb/test_MySQLdb_capabilities.py | 21 +- .../test_MySQLdb/test_MySQLdb_dbapi20.py | 14 +- 16 files changed, 298 insertions(+), 309 deletions(-) diff --git a/pymysql/__init__.py b/pymysql/__init__.py index 803b014..e64ddc8 100644 --- a/pymysql/__init__.py +++ b/pymysql/__init__.py @@ -91,14 +91,14 @@ def Connect(*args, **kwargs): """ from .connections import Connection return Connection(*args, **kwargs) - + from pymysql import connections as _orig_conn Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + """\nSee connections.Connection.__init__() for information about defaults.""" del _orig_conn def get_client_info(): # for MySQLdb compatibility - return '%s.%s.%s' % VERSION + return '%s.%s.%s' % VERSION connect = Connection = Connect @@ -115,7 +115,7 @@ def thread_safe(): def install_as_MySQLdb(): """ After this function is called, any application that imports MySQLdb or - _mysql will unwittingly actually use + _mysql will unwittingly actually use """ sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"] diff --git a/pymysql/charset.py b/pymysql/charset.py index ce994eb..d1eb60d 100644 --- a/pymysql/charset.py +++ b/pymysql/charset.py @@ -241,4 +241,3 @@ def charset_by_name(name): def charset_by_id(id): return _charsets.by_id(id) - diff --git a/pymysql/connections.py b/pymysql/connections.py index f61e124..105f71e 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -68,12 +68,12 @@ DEFAULT_CHARSET = 'latin1' def dump_packet(data): - + def is_ascii(data): if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum(): return data return '.' - + try: print("packet length {}".format(len(data))) print("method call[1]: {}".format(sys._getframe(1).f_code.co_name)) @@ -164,7 +164,7 @@ def pack_int24(n): return struct.pack('BBB', n&0xFF, (n>>8)&0xFF, (n>>16)&0xFF) def unpack_uint16(n): - return struct.unpack(' len(self.__data): - raise Exception('Invalid advance amount (%s) for cursor. ' - 'Position=%s' % (length, new_position)) - self.__position = new_position + def advance(self, length): + """Advance the cursor in data buffer 'length' bytes.""" + new_position = self.__position + length + if new_position < 0 or new_position > len(self.__data): + raise Exception('Invalid advance amount (%s) for cursor. ' + 'Position=%s' % (length, new_position)) + self.__position = new_position - def rewind(self, position=0): - """Set the position of the data buffer cursor to 'position'.""" - if position < 0 or position > len(self.__data): - raise Exception("Invalid position to rewind cursor to: %s." % position) - self.__position = position + def rewind(self, position=0): + """Set the position of the data buffer cursor to 'position'.""" + if position < 0 or position > len(self.__data): + raise Exception("Invalid position to rewind cursor to: %s." % position) + self.__position = position - def peek(self, size): - """Look at the first 'size' bytes in packet without moving cursor.""" - result = self.__data[self.__position:(self.__position+size)] - if len(result) != size: - error = ('Result length not requested length:\n' - 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' - % (size, len(result), self.__position, len(self.__data))) - if DEBUG: - print(error) - self.dump() - raise AssertionError(error) - return result + def peek(self, size): + """Look at the first 'size' bytes in packet without moving cursor.""" + result = self.__data[self.__position:(self.__position+size)] + if len(result) != size: + error = ('Result length not requested length:\n' + 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' + % (size, len(result), self.__position, len(self.__data))) + if DEBUG: + print(error) + self.dump() + raise AssertionError(error) + return result - def get_bytes(self, position, length=1): - """Get 'length' bytes starting at 'position'. + def get_bytes(self, position, length=1): + """Get 'length' bytes starting at 'position'. - Position is start of payload (first four packet header bytes are not - included) starting at index '0'. + Position is start of payload (first four packet header bytes are not + included) starting at index '0'. - No error checking is done. If requesting outside end of buffer - an empty string (or string shorter than 'length') may be returned! - """ - return self.__data[position:(position+length)] + No error checking is done. If requesting outside end of buffer + an empty string (or string shorter than 'length') may be returned! + """ + return self.__data[position:(position+length)] - def read_length_coded_binary(self): - """Read a 'Length Coded Binary' number from the data buffer. + def read_length_coded_binary(self): + """Read a 'Length Coded Binary' number from the data buffer. - Length coded numbers can be anywhere from 1 to 9 bytes depending - on the value of the first byte. - """ - c = byte2int(self.read(1)) - if c == NULL_COLUMN: - return None - if c < UNSIGNED_CHAR_COLUMN: - return c - elif c == UNSIGNED_SHORT_COLUMN: - return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH)) - elif c == UNSIGNED_INT24_COLUMN: - return unpack_int24(self.read(UNSIGNED_INT24_LENGTH)) - elif c == UNSIGNED_INT64_COLUMN: - # TODO: what was 'longlong'? confirm it wasn't used? - return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) + Length coded numbers can be anywhere from 1 to 9 bytes depending + on the value of the first byte. + """ + c = byte2int(self.read(1)) + if c == NULL_COLUMN: + return None + if c < UNSIGNED_CHAR_COLUMN: + return c + elif c == UNSIGNED_SHORT_COLUMN: + return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH)) + elif c == UNSIGNED_INT24_COLUMN: + return unpack_int24(self.read(UNSIGNED_INT24_LENGTH)) + elif c == UNSIGNED_INT64_COLUMN: + # TODO: what was 'longlong'? confirm it wasn't used? + return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) - def read_length_coded_string(self): - """Read a 'Length Coded String' from the data buffer. + def read_length_coded_string(self): + """Read a 'Length Coded String' from the data buffer. - A 'Length Coded String' consists first of a length coded - (unsigned, positive) integer represented in 1-9 bytes followed by - that many bytes of binary data. (For example "cat" would be "3cat".) - """ - length = self.read_length_coded_binary() - if length is None: - return None - return self.read(length) + A 'Length Coded String' consists first of a length coded + (unsigned, positive) integer represented in 1-9 bytes followed by + that many bytes of binary data. (For example "cat" would be "3cat".) + """ + length = self.read_length_coded_binary() + if length is None: + return None + return self.read(length) - def is_ok_packet(self): - return byte2int(self.get_bytes(0)) == 0 + def is_ok_packet(self): + return byte2int(self.get_bytes(0)) == 0 - def is_eof_packet(self): - return byte2int(self.get_bytes(0)) == 254 # 'fe' + def is_eof_packet(self): + return byte2int(self.get_bytes(0)) == 254 # 'fe' - def is_resultset_packet(self): - field_count = byte2int(self.get_bytes(0)) - return field_count >= 1 and field_count <= 250 + def is_resultset_packet(self): + field_count = byte2int(self.get_bytes(0)) + return field_count >= 1 and field_count <= 250 - def is_error_packet(self): - return byte2int(self.get_bytes(0)) == 255 + def is_error_packet(self): + return byte2int(self.get_bytes(0)) == 255 - def check_error(self): - if self.is_error_packet(): - self.rewind() - self.advance(1) # field_count == error (we already know that) - errno = unpack_uint16(self.read(2)) - if DEBUG: print("errno = {}".format(errno)) - raise_mysql_exception(self.__data) + def check_error(self): + if self.is_error_packet(): + self.rewind() + self.advance(1) # field_count == error (we already know that) + errno = unpack_uint16(self.read(2)) + if DEBUG: print("errno = {}".format(errno)) + raise_mysql_exception(self.__data) - def dump(self): - dump_packet(self.__data) + def dump(self): + dump_packet(self.__data) class FieldDescriptorPacket(MysqlPacket): - """A MysqlPacket that represents a specific column's metadata in the result. + """A MysqlPacket that represents a specific column's metadata in the result. - Parsing is automatically done and the results are exported via public - attributes on the class such as: db, table_name, name, length, type_code. - """ - - def __init__(self, *args): - MysqlPacket.__init__(self, *args) - self.__parse_field_descriptor() - - def __parse_field_descriptor(self): - """Parse the 'Field Descriptor' (Metadata) packet. - - This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). + Parsing is automatically done and the results are exported via public + attributes on the class such as: db, table_name, name, length, type_code. """ - self.catalog = self.read_length_coded_string() - self.db = self.read_length_coded_string() - self.table_name = self.read_length_coded_string() - self.org_table = self.read_length_coded_string() - self.name = self.read_length_coded_string().decode(self.connection.charset) - self.org_name = self.read_length_coded_string() - self.advance(1) # non-null filler - self.charsetnr = struct.unpack('lowercase - + # Some drivers may need to override these helpers, for example adding # a 'commit' after the execute. def executeDDL1(self,cursor): @@ -123,10 +123,10 @@ class DatabaseAPI20Test(unittest.TestCase): try: cur = con.cursor() for ddl in (self.xddl1,self.xddl2): - try: + try: cur.execute(ddl) con.commit() - except self.driver.Error: + except self.driver.Error: # Assume table didn't exist. Other tests will check if # execute is busted. pass @@ -238,7 +238,7 @@ class DatabaseAPI20Test(unittest.TestCase): con.rollback() except self.driver.NotSupportedError: pass - + def test_cursor(self): con = self._connect() try: @@ -392,7 +392,7 @@ class DatabaseAPI20Test(unittest.TestCase): ) elif self.driver.paramstyle == 'named': cur.execute( - 'insert into %sbooze values (:beer)' % self.table_prefix, + 'insert into %sbooze values (:beer)' % self.table_prefix, {'beer':"Cooper's"} ) elif self.driver.paramstyle == 'format': @@ -532,7 +532,7 @@ class DatabaseAPI20Test(unittest.TestCase): tests. ''' populate = [ - "insert into %sbooze values ('%s')" % (self.table_prefix,s) + "insert into %sbooze values ('%s')" % (self.table_prefix,s) for s in self.samples ] return populate @@ -593,7 +593,7 @@ class DatabaseAPI20Test(unittest.TestCase): self.assertEqual(len(rows),6) rows = [r[0] for r in rows] rows.sort() - + # Make sure we get the right data back out for i in range(0,6): self.assertEqual(rows[i],self.samples[i], @@ -664,10 +664,10 @@ class DatabaseAPI20Test(unittest.TestCase): 'cursor.fetchall should return an empty list if ' 'a select query returns no rows' ) - + finally: con.close() - + def test_mixedfetch(self): con = self._connect() try: @@ -703,8 +703,8 @@ class DatabaseAPI20Test(unittest.TestCase): def help_nextset_setUp(self,cur): ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" + that returns two result sets, first the + number of rows in booze then "name from booze" ''' raise NotImplementedError('Helper not implemented') #sql=""" @@ -850,4 +850,3 @@ class DatabaseAPI20Test(unittest.TestCase): self.assertTrue(hasattr(self.driver,'ROWID'), 'module.ROWID must be defined.' ) - diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py index cfdfe12..d62f4a3 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py @@ -18,7 +18,7 @@ class test_MySQLdb(capabilities.DatabaseTest): create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" leak_test = False - + def quote_identifier(self, ident): return "`%s`" % ident @@ -40,7 +40,7 @@ class test_MySQLdb(capabilities.DatabaseTest): self.check_data_integrity( ('col1 TINYINT',), generator) - + def test_stored_procedures(self): db = self.connection c = self.cursor @@ -49,7 +49,7 @@ class test_MySQLdb(capabilities.DatabaseTest): c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, list(enumerate('ash birch cedar larch pine'.split()))) db.commit() - + c.execute(""" CREATE PROCEDURE test_sp(IN t VARCHAR(255)) BEGIN @@ -57,7 +57,7 @@ class test_MySQLdb(capabilities.DatabaseTest): END """ % self.table) db.commit() - + c.callproc('test_sp', ('larch',)) rows = c.fetchall() self.assertEquals(len(rows), 1) @@ -84,7 +84,7 @@ class test_MySQLdb(capabilities.DatabaseTest): self.cursor.execute("describe some_non_existent_table"); except self.connection.ProgrammingError as msg: self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE) - + def test_insert_values(self): from pymysql.cursors import insert_values query = """INSERT FOO (a, b, c) VALUES (a, b, c)""" @@ -92,24 +92,23 @@ class test_MySQLdb(capabilities.DatabaseTest): self.assertTrue(matched) values = matched.group(1) self.assertTrue(values == "(a, b, c)") - + def test_ping(self): self.connection.ping() def test_literal_int(self): self.assertTrue("2" == self.connection.literal(2)) - + def test_literal_float(self): self.assertTrue("3.1415" == self.connection.literal(3.1415)) - + def test_literal_string(self): self.assertTrue("'foo'" == self.connection.literal("foo")) - - + + if __name__ == '__main__': if test_MySQLdb.leak_test: import gc gc.enable() gc.set_debug(gc.DEBUG_LEAK) unittest.main() - diff --git a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py index 6366045..829fdb8 100644 --- a/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py +++ b/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py @@ -15,7 +15,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): connect_kw_args.update(dict(read_default_file='~/.my.cnf', charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) - + def test_setoutputsize(self): pass def test_setoutputsize_basic(self): pass def test_nextset(self): pass @@ -24,7 +24,7 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): test for an exception if the statement cannot return a result set. MySQL always returns a result set; it's just that some things return empty result sets.""" - + def test_fetchall(self): con = self._connect() try: @@ -70,10 +70,10 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): 'cursor.fetchall should return an empty list if ' 'a select query returns no rows' ) - + finally: con.close() - + def test_fetchone(self): con = self._connect() try: @@ -152,8 +152,8 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): def help_nextset_setUp(self,cur): ''' Should create a procedure called deleteme - that returns two result sets, first the - number of rows in booze then "name from booze" + that returns two result sets, first the + number of rows in booze then "name from booze" ''' sql=""" create procedure deleteme() @@ -205,6 +205,6 @@ class test_MySQLdb(dbapi20.DatabaseAPI20Test): finally: con.close() - + if __name__ == '__main__': unittest.main()