Merge branch 'defer_unbuffered_cleanup' of https://github.com/zzzeek/PyMySQL into zzzeek-defer_unbuffered_cleanup
Conflicts: pymysql/tests/base.py
This commit is contained in:
@@ -4,6 +4,10 @@ Changes
|
||||
0.6.4 -Support "LOAD LOCAL INFILE". Thanks @wraziens
|
||||
-Show MySQL warnings after execute query.
|
||||
-Fix MySQLError may be wrapped with OperationalError while connectiong. (#274)
|
||||
-SSCursor no longer attempts to expire un-collected rows within __del__,
|
||||
delaying termination of an interrupted program; cleanup of uncollected
|
||||
rows is left to the Connection on next execute, which emits a
|
||||
warning at that time. (#287)
|
||||
|
||||
0.6.3 -Fixed multiple result sets with SSCursor.
|
||||
-Fixed connection timeout.
|
||||
@@ -47,7 +51,7 @@ Changes
|
||||
-Removed DeprecationWarnings
|
||||
-Ran against the MySQLdb unit tests to check for bugs
|
||||
-Added support for client_flag, charset, sql_mode, read_default_file,
|
||||
use_unicode, cursorclass, init_command, and connect_timeout.
|
||||
use_unicode, cursorclass, init_command, and connect_timeout.
|
||||
-Refactoring for some more compatibility with MySQLdb including a fake
|
||||
pymysql.version_info attribute.
|
||||
-Now runs with no warnings with the -3 command-line switch
|
||||
|
||||
@@ -4,6 +4,7 @@ PY2 = sys.version_info[0] == 2
|
||||
PYPY = hasattr(sys, 'pypy_translation_info')
|
||||
JYTHON = sys.platform.startswith('java')
|
||||
IRONPYTHON = sys.platform == 'cli'
|
||||
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
|
||||
|
||||
if PY2:
|
||||
range_type = xrange
|
||||
|
||||
@@ -14,6 +14,7 @@ import os
|
||||
import socket
|
||||
import struct
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
try:
|
||||
import ssl
|
||||
@@ -932,6 +933,7 @@ class Connection(object):
|
||||
# If the last query was unbuffered, make sure it finishes before
|
||||
# sending new commands
|
||||
if self._result is not None and self._result.unbuffered_active:
|
||||
warnings.warn("Previous unbuffered result was left incomplete")
|
||||
self._result._finish_unbuffered_query()
|
||||
|
||||
if isinstance(sql, text_type):
|
||||
|
||||
@@ -40,12 +40,6 @@ class Cursor(object):
|
||||
self._result = None
|
||||
self._rows = None
|
||||
|
||||
def __del__(self):
|
||||
'''
|
||||
When this gets GC'd close it.
|
||||
'''
|
||||
self.close()
|
||||
|
||||
def close(self):
|
||||
'''
|
||||
Closing a cursor just exhausts all remaining data.
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import gc
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import re
|
||||
|
||||
from .._compat import CPYTHON
|
||||
|
||||
try:
|
||||
import unittest2 as unittest
|
||||
except ImportError:
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
class PyMySQLTestCase(unittest.TestCase):
|
||||
# You can specify your test environment creating a file named
|
||||
@@ -41,7 +46,46 @@ class PyMySQLTestCase(unittest.TestCase):
|
||||
self.connections = []
|
||||
for params in self.databases:
|
||||
self.connections.append(pymysql.connect(**params))
|
||||
self.addCleanup(self._teardown_connections)
|
||||
|
||||
def tearDown(self):
|
||||
def _teardown_connections(self):
|
||||
for connection in self.connections:
|
||||
connection.close()
|
||||
|
||||
def safe_create_table(self, connection, tablename, ddl, cleanup=False):
|
||||
"""create a table.
|
||||
|
||||
Ensures any existing version of that table
|
||||
is first dropped.
|
||||
|
||||
Also adds a cleanup rule to drop the table after the test
|
||||
completes.
|
||||
|
||||
"""
|
||||
|
||||
cursor = connection.cursor()
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
cursor.execute("drop table if exists test")
|
||||
cursor.execute("create table test (data varchar(10))")
|
||||
cursor.close()
|
||||
if cleanup:
|
||||
self.addCleanup(self.drop_table, connection, tablename)
|
||||
|
||||
def drop_table(self, connection, tablename):
|
||||
cursor = connection.cursor()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
cursor.execute("drop table if exists %s" % tablename)
|
||||
cursor.close()
|
||||
|
||||
def safe_gc_collect(self):
|
||||
"""Ensure cycles are collected via gc.
|
||||
|
||||
Runs additional times on non-CPython platforms.
|
||||
|
||||
"""
|
||||
gc.collect()
|
||||
if not CPYTHON:
|
||||
gc.collect()
|
||||
|
||||
@@ -32,6 +32,9 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
c.execute("drop table dictcursor")
|
||||
super(TestDictCursor, self).tearDown()
|
||||
|
||||
def _ensure_cursor_expired(self, cursor):
|
||||
pass
|
||||
|
||||
def test_DictCursor(self):
|
||||
bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
|
||||
#all assert test compare to the structure as would come out from MySQLdb
|
||||
@@ -45,6 +48,8 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
r = c.fetchone()
|
||||
self.assertEqual(bob, r, "fetchone via DictCursor failed")
|
||||
self._ensure_cursor_expired(c)
|
||||
|
||||
# same again, but via fetchall => tuple)
|
||||
c.execute("SELECT * from dictcursor where name='bob'")
|
||||
r = c.fetchall()
|
||||
@@ -65,6 +70,7 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
c.execute("SELECT * from dictcursor")
|
||||
r = c.fetchmany(2)
|
||||
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
|
||||
self._ensure_cursor_expired(c)
|
||||
|
||||
def test_custom_dict(self):
|
||||
class MyDict(dict): pass
|
||||
@@ -81,6 +87,7 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
|
||||
r = cur.fetchone()
|
||||
self.assertEqual(bob, r, "fetchone() returns MyDictCursor")
|
||||
self._ensure_cursor_expired(cur)
|
||||
|
||||
cur.execute("SELECT * FROM dictcursor")
|
||||
r = cur.fetchall()
|
||||
@@ -96,11 +103,14 @@ class TestDictCursor(base.PyMySQLTestCase):
|
||||
r = cur.fetchmany(2)
|
||||
self.assertEqual([bob, jim], r,
|
||||
"list failed via MyDictCursor")
|
||||
self._ensure_cursor_expired(cur)
|
||||
|
||||
|
||||
class TestSSDictCursor(TestDictCursor):
|
||||
cursor_type = pymysql.cursors.SSDictCursor
|
||||
|
||||
def _ensure_cursor_expired(self, cursor):
|
||||
list(cursor.fetchall_unbuffered())
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
|
||||
71
pymysql/tests/test_cursor.py
Normal file
71
pymysql/tests/test_cursor.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import warnings
|
||||
|
||||
from pymysql.tests import base
|
||||
import pymysql.cursors
|
||||
|
||||
class CursorTest(base.PyMySQLTestCase):
|
||||
def setUp(self):
|
||||
super(CursorTest, self).setUp()
|
||||
|
||||
conn = self.connections[0]
|
||||
self.safe_create_table(
|
||||
conn,
|
||||
"test", "create table test (data varchar(10))",
|
||||
cleanup=True)
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"insert into test (data) values "
|
||||
"('row1'), ('row2'), ('row3'), ('row4'), ('row5')")
|
||||
cursor.close()
|
||||
self.test_connection = pymysql.connect(**self.databases[0])
|
||||
self.addCleanup(self.test_connection.close)
|
||||
|
||||
def test_cleanup_rows_unbuffered(self):
|
||||
conn = self.test_connection
|
||||
cursor = conn.cursor(pymysql.cursors.SSCursor)
|
||||
|
||||
cursor.execute("select * from test as t1, test as t2")
|
||||
for counter, row in enumerate(cursor):
|
||||
if counter > 10:
|
||||
break
|
||||
|
||||
del cursor
|
||||
self.safe_gc_collect()
|
||||
|
||||
c2 = conn.cursor()
|
||||
|
||||
with warnings.catch_warnings(record=True) as log:
|
||||
warnings.filterwarnings("always")
|
||||
|
||||
c2.execute("select 1")
|
||||
|
||||
self.assertGreater(len(log), 0)
|
||||
self.assertEqual(
|
||||
"Previous unbuffered result was left incomplete",
|
||||
str(log[-1].message))
|
||||
self.assertEqual(
|
||||
c2.fetchone(), (1,)
|
||||
)
|
||||
self.assertIsNone(c2.fetchone())
|
||||
|
||||
def test_cleanup_rows_buffered(self):
|
||||
conn = self.test_connection
|
||||
cursor = conn.cursor(pymysql.cursors.Cursor)
|
||||
|
||||
cursor.execute("select * from test as t1, test as t2")
|
||||
for counter, row in enumerate(cursor):
|
||||
if counter > 10:
|
||||
break
|
||||
|
||||
del cursor
|
||||
self.safe_gc_collect()
|
||||
|
||||
c2 = conn.cursor()
|
||||
|
||||
c2.execute("select 1")
|
||||
|
||||
self.assertEqual(
|
||||
c2.fetchone(), (1,)
|
||||
)
|
||||
self.assertIsNone(c2.fetchone())
|
||||
|
||||
Reference in New Issue
Block a user