Merge branch 'defer_unbuffered_cleanup' of https://github.com/zzzeek/PyMySQL into zzzeek-defer_unbuffered_cleanup

Conflicts:
	pymysql/tests/base.py
This commit is contained in:
INADA Naoki
2015-02-03 13:32:09 +09:00
7 changed files with 134 additions and 8 deletions

View File

@@ -4,6 +4,10 @@ Changes
0.6.4 -Support "LOAD LOCAL INFILE". Thanks @wraziens
-Show MySQL warnings after execute query.
-Fix MySQLError may be wrapped with OperationalError while connectiong. (#274)
-SSCursor no longer attempts to expire un-collected rows within __del__,
delaying termination of an interrupted program; cleanup of uncollected
rows is left to the Connection on next execute, which emits a
warning at that time. (#287)
0.6.3 -Fixed multiple result sets with SSCursor.
-Fixed connection timeout.
@@ -47,7 +51,7 @@ Changes
-Removed DeprecationWarnings
-Ran against the MySQLdb unit tests to check for bugs
-Added support for client_flag, charset, sql_mode, read_default_file,
use_unicode, cursorclass, init_command, and connect_timeout.
use_unicode, cursorclass, init_command, and connect_timeout.
-Refactoring for some more compatibility with MySQLdb including a fake
pymysql.version_info attribute.
-Now runs with no warnings with the -3 command-line switch

View File

@@ -4,6 +4,7 @@ PY2 = sys.version_info[0] == 2
PYPY = hasattr(sys, 'pypy_translation_info')
JYTHON = sys.platform.startswith('java')
IRONPYTHON = sys.platform == 'cli'
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
if PY2:
range_type = xrange

View File

@@ -14,6 +14,7 @@ import os
import socket
import struct
import sys
import warnings
try:
import ssl
@@ -932,6 +933,7 @@ class Connection(object):
# If the last query was unbuffered, make sure it finishes before
# sending new commands
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
if isinstance(sql, text_type):

View File

@@ -40,12 +40,6 @@ class Cursor(object):
self._result = None
self._rows = None
def __del__(self):
'''
When this gets GC'd close it.
'''
self.close()
def close(self):
'''
Closing a cursor just exhausts all remaining data.

View File

@@ -1,11 +1,16 @@
import gc
import os
import json
import pymysql
import re
from .._compat import CPYTHON
try:
import unittest2 as unittest
except ImportError:
import unittest
import warnings
class PyMySQLTestCase(unittest.TestCase):
# You can specify your test environment creating a file named
@@ -41,7 +46,46 @@ class PyMySQLTestCase(unittest.TestCase):
self.connections = []
for params in self.databases:
self.connections.append(pymysql.connect(**params))
self.addCleanup(self._teardown_connections)
def tearDown(self):
def _teardown_connections(self):
for connection in self.connections:
connection.close()
def safe_create_table(self, connection, tablename, ddl, cleanup=False):
"""create a table.
Ensures any existing version of that table
is first dropped.
Also adds a cleanup rule to drop the table after the test
completes.
"""
cursor = connection.cursor()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cursor.execute("drop table if exists test")
cursor.execute("create table test (data varchar(10))")
cursor.close()
if cleanup:
self.addCleanup(self.drop_table, connection, tablename)
def drop_table(self, connection, tablename):
cursor = connection.cursor()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
cursor.execute("drop table if exists %s" % tablename)
cursor.close()
def safe_gc_collect(self):
"""Ensure cycles are collected via gc.
Runs additional times on non-CPython platforms.
"""
gc.collect()
if not CPYTHON:
gc.collect()

View File

@@ -32,6 +32,9 @@ class TestDictCursor(base.PyMySQLTestCase):
c.execute("drop table dictcursor")
super(TestDictCursor, self).tearDown()
def _ensure_cursor_expired(self, cursor):
pass
def test_DictCursor(self):
bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
#all assert test compare to the structure as would come out from MySQLdb
@@ -45,6 +48,8 @@ class TestDictCursor(base.PyMySQLTestCase):
c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchone()
self.assertEqual(bob, r, "fetchone via DictCursor failed")
self._ensure_cursor_expired(c)
# same again, but via fetchall => tuple)
c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchall()
@@ -65,6 +70,7 @@ class TestDictCursor(base.PyMySQLTestCase):
c.execute("SELECT * from dictcursor")
r = c.fetchmany(2)
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
self._ensure_cursor_expired(c)
def test_custom_dict(self):
class MyDict(dict): pass
@@ -81,6 +87,7 @@ class TestDictCursor(base.PyMySQLTestCase):
cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
r = cur.fetchone()
self.assertEqual(bob, r, "fetchone() returns MyDictCursor")
self._ensure_cursor_expired(cur)
cur.execute("SELECT * FROM dictcursor")
r = cur.fetchall()
@@ -96,11 +103,14 @@ class TestDictCursor(base.PyMySQLTestCase):
r = cur.fetchmany(2)
self.assertEqual([bob, jim], r,
"list failed via MyDictCursor")
self._ensure_cursor_expired(cur)
class TestSSDictCursor(TestDictCursor):
cursor_type = pymysql.cursors.SSDictCursor
def _ensure_cursor_expired(self, cursor):
list(cursor.fetchall_unbuffered())
if __name__ == "__main__":
import unittest

View File

@@ -0,0 +1,71 @@
import warnings
from pymysql.tests import base
import pymysql.cursors
class CursorTest(base.PyMySQLTestCase):
def setUp(self):
super(CursorTest, self).setUp()
conn = self.connections[0]
self.safe_create_table(
conn,
"test", "create table test (data varchar(10))",
cleanup=True)
cursor = conn.cursor()
cursor.execute(
"insert into test (data) values "
"('row1'), ('row2'), ('row3'), ('row4'), ('row5')")
cursor.close()
self.test_connection = pymysql.connect(**self.databases[0])
self.addCleanup(self.test_connection.close)
def test_cleanup_rows_unbuffered(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.SSCursor)
cursor.execute("select * from test as t1, test as t2")
for counter, row in enumerate(cursor):
if counter > 10:
break
del cursor
self.safe_gc_collect()
c2 = conn.cursor()
with warnings.catch_warnings(record=True) as log:
warnings.filterwarnings("always")
c2.execute("select 1")
self.assertGreater(len(log), 0)
self.assertEqual(
"Previous unbuffered result was left incomplete",
str(log[-1].message))
self.assertEqual(
c2.fetchone(), (1,)
)
self.assertIsNone(c2.fetchone())
def test_cleanup_rows_buffered(self):
conn = self.test_connection
cursor = conn.cursor(pymysql.cursors.Cursor)
cursor.execute("select * from test as t1, test as t2")
for counter, row in enumerate(cursor):
if counter > 10:
break
del cursor
self.safe_gc_collect()
c2 = conn.cursor()
c2.execute("select 1")
self.assertEqual(
c2.fetchone(), (1,)
)
self.assertIsNone(c2.fetchone())