From 2617fde2ed3b2dcc59d30ac70e662fba1d5256a3 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 1 Feb 2015 16:06:02 -0500 Subject: [PATCH] -SSCursor no longer attempts to expire un-collected rows within __del__, delaying termination of an interrupted program; cleanup of uncollected rows is left to the Connection on next execute, which emits a warning at that time. (fixes #287) --- CHANGELOG | 6 ++- pymysql/_compat.py | 1 + pymysql/connections.py | 2 + pymysql/cursors.py | 6 --- pymysql/tests/base.py | 47 ++++++++++++++++++++- pymysql/tests/test_DictCursor.py | 10 +++++ pymysql/tests/test_cursor.py | 71 ++++++++++++++++++++++++++++++++ 7 files changed, 135 insertions(+), 8 deletions(-) create mode 100644 pymysql/tests/test_cursor.py diff --git a/CHANGELOG b/CHANGELOG index 2f2e95f..dec002e 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -4,6 +4,10 @@ Changes 0.6.4 -Support "LOAD LOCAL INFILE". Thanks @wraziens -Show MySQL warnings after execute query. -Fix MySQLError may be wrapped with OperationalError while connectiong. (#274) + -SSCursor no longer attempts to expire un-collected rows within __del__, + delaying termination of an interrupted program; cleanup of uncollected + rows is left to the Connection on next execute, which emits a + warning at that time. (#287) 0.6.3 -Fixed multiple result sets with SSCursor. -Fixed connection timeout. @@ -47,7 +51,7 @@ Changes -Removed DeprecationWarnings -Ran against the MySQLdb unit tests to check for bugs -Added support for client_flag, charset, sql_mode, read_default_file, - use_unicode, cursorclass, init_command, and connect_timeout. + use_unicode, cursorclass, init_command, and connect_timeout. -Refactoring for some more compatibility with MySQLdb including a fake pymysql.version_info attribute. -Now runs with no warnings with the -3 command-line switch diff --git a/pymysql/_compat.py b/pymysql/_compat.py index b97dfd1..0c55346 100644 --- a/pymysql/_compat.py +++ b/pymysql/_compat.py @@ -4,6 +4,7 @@ PY2 = sys.version_info[0] == 2 PYPY = hasattr(sys, 'pypy_translation_info') JYTHON = sys.platform.startswith('java') IRONPYTHON = sys.platform == 'cli' +CPYTHON = not PYPY and not JYTHON and not IRONPYTHON if PY2: range_type = xrange diff --git a/pymysql/connections.py b/pymysql/connections.py index d4ae1f9..ba5a079 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -14,6 +14,7 @@ import os import socket import struct import sys +import warnings try: import ssl @@ -923,6 +924,7 @@ class Connection(object): # If the last query was unbuffered, make sure it finishes before # sending new commands if self._result is not None and self._result.unbuffered_active: + warnings.warn("Previous unbuffered result was left incomplete") self._result._finish_unbuffered_query() if isinstance(sql, text_type): diff --git a/pymysql/cursors.py b/pymysql/cursors.py index 35e3b75..06ada7a 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -40,12 +40,6 @@ class Cursor(object): self._result = None self._rows = None - def __del__(self): - ''' - When this gets GC'd close it. - ''' - self.close() - def close(self): ''' Closing a cursor just exhausts all remaining data. diff --git a/pymysql/tests/base.py b/pymysql/tests/base.py index ee53aa2..a8e542d 100644 --- a/pymysql/tests/base.py +++ b/pymysql/tests/base.py @@ -1,10 +1,16 @@ +import gc import os import json import pymysql + +from .._compat import CPYTHON + + try: import unittest2 as unittest except ImportError: import unittest +import warnings class PyMySQLTestCase(unittest.TestCase): # You can specify your test environment creating a file named @@ -23,7 +29,46 @@ class PyMySQLTestCase(unittest.TestCase): self.connections = [] for params in self.databases: self.connections.append(pymysql.connect(**params)) + self.addCleanup(self._teardown_connections) - def tearDown(self): + def _teardown_connections(self): for connection in self.connections: connection.close() + + def safe_create_table(self, connection, tablename, ddl, cleanup=False): + """create a table. + + Ensures any existing version of that table + is first dropped. + + Also adds a cleanup rule to drop the table after the test + completes. + + """ + + cursor = connection.cursor() + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + cursor.execute("drop table if exists test") + cursor.execute("create table test (data varchar(10))") + cursor.close() + if cleanup: + self.addCleanup(self.drop_table, connection, tablename) + + def drop_table(self, connection, tablename): + cursor = connection.cursor() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + cursor.execute("drop table if exists %s" % tablename) + cursor.close() + + def safe_gc_collect(self): + """Ensure cycles are collected via gc. + + Runs additional times on non-CPython platforms. + + """ + gc.collect() + if not CPYTHON: + gc.collect() \ No newline at end of file diff --git a/pymysql/tests/test_DictCursor.py b/pymysql/tests/test_DictCursor.py index 323850b..08d188e 100644 --- a/pymysql/tests/test_DictCursor.py +++ b/pymysql/tests/test_DictCursor.py @@ -32,6 +32,9 @@ class TestDictCursor(base.PyMySQLTestCase): c.execute("drop table dictcursor") super(TestDictCursor, self).tearDown() + def _ensure_cursor_expired(self, cursor): + pass + def test_DictCursor(self): bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy() #all assert test compare to the structure as would come out from MySQLdb @@ -45,6 +48,8 @@ class TestDictCursor(base.PyMySQLTestCase): c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchone() self.assertEqual(bob, r, "fetchone via DictCursor failed") + self._ensure_cursor_expired(c) + # same again, but via fetchall => tuple) c.execute("SELECT * from dictcursor where name='bob'") r = c.fetchall() @@ -65,6 +70,7 @@ class TestDictCursor(base.PyMySQLTestCase): c.execute("SELECT * from dictcursor") r = c.fetchmany(2) self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor") + self._ensure_cursor_expired(c) def test_custom_dict(self): class MyDict(dict): pass @@ -81,6 +87,7 @@ class TestDictCursor(base.PyMySQLTestCase): cur.execute("SELECT * FROM dictcursor WHERE name='bob'") r = cur.fetchone() self.assertEqual(bob, r, "fetchone() returns MyDictCursor") + self._ensure_cursor_expired(cur) cur.execute("SELECT * FROM dictcursor") r = cur.fetchall() @@ -96,11 +103,14 @@ class TestDictCursor(base.PyMySQLTestCase): r = cur.fetchmany(2) self.assertEqual([bob, jim], r, "list failed via MyDictCursor") + self._ensure_cursor_expired(cur) class TestSSDictCursor(TestDictCursor): cursor_type = pymysql.cursors.SSDictCursor + def _ensure_cursor_expired(self, cursor): + list(cursor.fetchall_unbuffered()) if __name__ == "__main__": import unittest diff --git a/pymysql/tests/test_cursor.py b/pymysql/tests/test_cursor.py new file mode 100644 index 0000000..f590b9a --- /dev/null +++ b/pymysql/tests/test_cursor.py @@ -0,0 +1,71 @@ +import warnings + +from pymysql.tests import base +import pymysql.cursors + +class CursorTest(base.PyMySQLTestCase): + def setUp(self): + super(CursorTest, self).setUp() + + conn = self.connections[0] + self.safe_create_table( + conn, + "test", "create table test (data varchar(10))", + cleanup=True) + cursor = conn.cursor() + cursor.execute( + "insert into test (data) values " + "('row1'), ('row2'), ('row3'), ('row4'), ('row5')") + cursor.close() + self.test_connection = pymysql.connect(**self.databases[0]) + self.addCleanup(self.test_connection.close) + + def test_cleanup_rows_unbuffered(self): + conn = self.test_connection + cursor = conn.cursor(pymysql.cursors.SSCursor) + + cursor.execute("select * from test as t1, test as t2") + for counter, row in enumerate(cursor): + if counter > 10: + break + + del cursor + self.safe_gc_collect() + + c2 = conn.cursor() + + with warnings.catch_warnings(record=True) as log: + warnings.filterwarnings("always") + + c2.execute("select 1") + + self.assertGreater(len(log), 0) + self.assertEqual( + "Previous unbuffered result was left incomplete", + str(log[-1].message)) + self.assertEqual( + c2.fetchone(), (1,) + ) + self.assertIsNone(c2.fetchone()) + + def test_cleanup_rows_buffered(self): + conn = self.test_connection + cursor = conn.cursor(pymysql.cursors.Cursor) + + cursor.execute("select * from test as t1, test as t2") + for counter, row in enumerate(cursor): + if counter > 10: + break + + del cursor + self.safe_gc_collect() + + c2 = conn.cursor() + + c2.execute("select 1") + + self.assertEqual( + c2.fetchone(), (1,) + ) + self.assertIsNone(c2.fetchone()) +