Add Connection.show_warnings() method.

This commit is contained in:
INADA Naoki
2015-01-14 11:59:15 +09:00
parent 92f35a6dd7
commit eb1727541c
8 changed files with 116 additions and 75 deletions

View File

@@ -1,5 +1,4 @@
[ [
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true}, {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true, "local_infile": true},
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }, {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true}
] ]

View File

@@ -11,7 +11,6 @@ from functools import partial
import hashlib import hashlib
import io import io
import os import os
import re
import socket import socket
import struct import struct
import sys import sys
@@ -675,7 +674,7 @@ class Connection(object):
raise OperationalError(2014, "Command Out of Sync") raise OperationalError(2014, "Command Out of Sync")
ok = OKPacketWrapper(pkt) ok = OKPacketWrapper(pkt)
self.server_status = ok.server_status self.server_status = ok.server_status
return True return ok
def _send_autocommit_mode(self): def _send_autocommit_mode(self):
''' Set whether or not to commit after every execute() ''' ''' Set whether or not to commit after every execute() '''
@@ -698,6 +697,13 @@ class Connection(object):
self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
self._read_ok_packet() self._read_ok_packet()
def show_warnings(self):
"""SHOW WARNINGS"""
self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
result = MySQLResult(self)
result.read()
return result.rows
def select_db(self, db): def select_db(self, db):
'''Set current db''' '''Set current db'''
self._execute_command(COMMAND.COM_INIT_DB, db) self._execute_command(COMMAND.COM_INIT_DB, db)
@@ -1069,8 +1075,6 @@ class Connection(object):
NotSupportedError = NotSupportedError NotSupportedError = NotSupportedError
# TODO: move OK and EOF packet parsing/logic into a proper subclass
# of MysqlPacket like has been done with FieldDescriptorPacket.
class MySQLResult(object): class MySQLResult(object):
def __init__(self, connection): def __init__(self, connection):
@@ -1085,7 +1089,6 @@ class MySQLResult(object):
self.rows = None self.rows = None
self.has_next = None self.has_next = None
self.unbuffered_active = False self.unbuffered_active = False
self.filename = None
def __del__(self): def __del__(self):
if self.unbuffered_active: if self.unbuffered_active:
@@ -1095,7 +1098,6 @@ class MySQLResult(object):
try: try:
first_packet = self.connection._read_packet() first_packet = self.connection._read_packet()
# TODO: use classes for different packet types?
if first_packet.is_ok_packet(): if first_packet.is_ok_packet():
self._read_ok_packet(first_packet) self._read_ok_packet(first_packet)
elif first_packet.is_load_local_packet(): elif first_packet.is_load_local_packet():
@@ -1133,19 +1135,14 @@ class MySQLResult(object):
def _read_load_local_packet(self, first_packet): def _read_load_local_packet(self, first_packet):
load_packet = LoadLocalPacketWrapper(first_packet) load_packet = LoadLocalPacketWrapper(first_packet)
local_packet = LoadLocalFile(load_packet.filename, self.connection) sender = LoadLocalFile(load_packet.filename, self.connection)
self.filename = load_packet.filename sender.send_data()
local_packet.send_data()
ok_packet = self.connection._read_packet() ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet(): if not ok_packet.is_ok_packet():
raise OperationalError(2014, "Commands Out of Sync") raise OperationalError(2014, "Commands Out of Sync")
self._read_ok_packet(ok_packet) self._read_ok_packet(ok_packet)
if self.warning_count > 0:
self._print_warnings()
self.filename = None
def _check_packet_is_eof(self, packet): def _check_packet_is_eof(self, packet):
if packet.is_eof_packet(): if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet) eof_packet = EOFPacketWrapper(packet)
@@ -1154,16 +1151,6 @@ class MySQLResult(object):
return True return True
return False return False
def _print_warnings(self):
from warnings import warn
self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS')
self.read()
if self.rows:
message = "\n"
for db_warning in self.rows:
message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8'))
warn(message, Warning, 3)
def _read_result_packet(self, first_packet): def _read_result_packet(self, first_packet):
self.field_count = first_packet.read_length_encoded_integer() self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions() self._get_descriptions()

View File

@@ -1,13 +1,11 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from __future__ import print_function, absolute_import from __future__ import print_function, absolute_import
import re import re
import warnings
from ._compat import range_type, text_type, PY2 from ._compat import range_type, text_type, PY2
from .err import ( from . import err
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
#: Regular expression for :meth:`Cursor.executemany`. #: Regular expression for :meth:`Cursor.executemany`.
@@ -63,12 +61,12 @@ class Cursor(object):
def _get_db(self): def _get_db(self):
if not self.connection: if not self.connection:
raise ProgrammingError("Cursor closed") raise err.ProgrammingError("Cursor closed")
return self.connection return self.connection
def _check_executed(self): def _check_executed(self):
if not self._executed: if not self._executed:
raise ProgrammingError("execute() first") raise err.ProgrammingError("execute() first")
def _conv_row(self, row): def _conv_row(self, row):
return row return row
@@ -262,7 +260,7 @@ class Cursor(object):
elif mode == 'absolute': elif mode == 'absolute':
r = value r = value
else: else:
raise ProgrammingError("unknown scroll mode %s" % mode) raise err.ProgrammingError("unknown scroll mode %s" % mode)
if not (0 <= r < len(self._rows)): if not (0 <= r < len(self._rows)):
raise IndexError("out of range") raise IndexError("out of range")
@@ -286,19 +284,27 @@ class Cursor(object):
self.lastrowid = result.insert_id self.lastrowid = result.insert_id
self._rows = result.rows self._rows = result.rows
if result.warning_count > 0:
self._show_warnings(conn)
def _show_warnings(self, conn):
ws = conn.show_warnings()
for w in ws:
warnings.warn(w[-1], err.Warning, 4)
def __iter__(self): def __iter__(self):
return iter(self.fetchone, None) return iter(self.fetchone, None)
Warning = Warning Warning = err.Warning
Error = Error Error = err.Error
InterfaceError = InterfaceError InterfaceError = err.InterfaceError
DatabaseError = DatabaseError DatabaseError = err.DatabaseError
DataError = DataError DataError = err.DataError
OperationalError = OperationalError OperationalError = err.OperationalError
IntegrityError = IntegrityError IntegrityError = err.IntegrityError
InternalError = InternalError InternalError = err.InternalError
ProgrammingError = ProgrammingError ProgrammingError = err.ProgrammingError
NotSupportedError = NotSupportedError NotSupportedError = err.NotSupportedError
class DictCursorMixin(object): class DictCursorMixin(object):
@@ -426,7 +432,7 @@ class SSCursor(Cursor):
if mode == 'relative': if mode == 'relative':
if value < 0: if value < 0:
raise NotSupportedError( raise err.NotSupportedError(
"Backwards scrolling not supported by this cursor") "Backwards scrolling not supported by this cursor")
for _ in range_type(value): for _ in range_type(value):
@@ -434,7 +440,7 @@ class SSCursor(Cursor):
self.rownumber += value self.rownumber += value
elif mode == 'absolute': elif mode == 'absolute':
if value < self.rownumber: if value < self.rownumber:
raise NotSupportedError( raise err.NotSupportedError(
"Backwards scrolling not supported by this cursor") "Backwards scrolling not supported by this cursor")
end = value - self.rownumber end = value - self.rownumber
@@ -442,7 +448,7 @@ class SSCursor(Cursor):
self.read_next() self.read_next()
self.rownumber = value self.rownumber = value
else: else:
raise ProgrammingError("unknown scroll mode %s" % mode) raise err.ProgrammingError("unknown scroll mode %s" % mode)
class SSDictCursor(DictCursorMixin, SSCursor): class SSDictCursor(DictCursorMixin, SSCursor):

View File

@@ -16,7 +16,7 @@ class PyMySQLTestCase(unittest.TestCase):
else: else:
databases = [ databases = [
{"host":"localhost","user":"root", {"host":"localhost","user":"root",
"passwd":"","db":"test_pymysql", "use_unicode": True}, "passwd":"","db":"test_pymysql", "use_unicode": True, 'local_infile': True},
{"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}] {"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}]
def setUp(self): def setUp(self):

View File

@@ -2,6 +2,7 @@ from pymysql.tests import base
import pymysql.cursors import pymysql.cursors
import datetime import datetime
import warnings
class TestDictCursor(base.PyMySQLTestCase): class TestDictCursor(base.PyMySQLTestCase):
@@ -17,7 +18,9 @@ class TestDictCursor(base.PyMySQLTestCase):
c = conn.cursor(self.cursor_type) c = conn.cursor(self.cursor_type)
# create a table ane some data to query # create a table ane some data to query
c.execute("drop table if exists dictcursor") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists dictcursor")
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""") c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
data = [("bob", 21, "1990-02-06 23:04:56"), data = [("bob", 21, "1990-02-06 23:04:56"),
("jim", 56, "1955-05-09 13:12:45"), ("jim", 56, "1955-05-09 13:12:45"),

View File

@@ -6,6 +6,7 @@ from pymysql.err import ProgrammingError
import time import time
import datetime import datetime
import warnings
__all__ = ["TestConversion", "TestCursor", "TestBulkInserts"] __all__ = ["TestConversion", "TestCursor", "TestBulkInserts"]
@@ -136,7 +137,9 @@ class TestConversion(base.PyMySQLTestCase):
# User is running a version of MySQL that doesn't support msecs within datetime # User is running a version of MySQL that doesn't support msecs within datetime
pass pass
finally: finally:
c.execute("drop table if exists test_datetime") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists test_datetime")
class TestCursor(base.PyMySQLTestCase): class TestCursor(base.PyMySQLTestCase):
@@ -243,7 +246,9 @@ class TestBulkInserts(base.PyMySQLTestCase):
c = conn.cursor(self.cursor_type) c = conn.cursor(self.cursor_type)
# create a table ane some data to query # create a table ane some data to query
c.execute("drop table if exists bulkinsert") with warnings.catch_warnings():
warnings.simplefilter("ignore")
c.execute("drop table if exists bulkinsert")
c.execute( c.execute(
"""CREATE TABLE bulkinsert """CREATE TABLE bulkinsert
( (
@@ -309,6 +314,16 @@ values (0,
cursor.execute('commit') cursor.execute('commit')
self._verify_records(data) self._verify_records(data)
def test_warnings(self):
con = self.connections[0]
cur = con.cursor()
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
cur.execute("drop table if exists no_exists_table")
self.assertEqual(len(ws), 1)
self.assertEqual(ws[0].category, pymysql.Warning)
self.assertTrue(u"no_exists_table" in str(ws[0].message))
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest

View File

@@ -12,6 +12,7 @@ except AttributeError:
pass pass
import datetime import datetime
import warnings
class TestOldIssues(base.PyMySQLTestCase): class TestOldIssues(base.PyMySQLTestCase):
@@ -19,7 +20,9 @@ class TestOldIssues(base.PyMySQLTestCase):
""" undefined methods datetime_or_None, date_or_None """ """ undefined methods datetime_or_None, date_or_None """
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists issue3") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue3")
c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)") c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
try: try:
c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None)) c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
@@ -38,7 +41,9 @@ class TestOldIssues(base.PyMySQLTestCase):
""" can't retrieve TIMESTAMP fields """ """ can't retrieve TIMESTAMP fields """
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists issue4") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue4")
c.execute("create table issue4 (ts timestamp)") c.execute("create table issue4 (ts timestamp)")
try: try:
c.execute("insert into issue4 (ts) values (now())") c.execute("insert into issue4 (ts) values (now())")
@@ -67,7 +72,9 @@ class TestOldIssues(base.PyMySQLTestCase):
""" Primary Key and Index error when selecting data """ """ Primary Key and Index error when selecting data """
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists test") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists test")
c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh` c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
@@ -90,7 +97,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
""" can't handle large result fields """ """ can't handle large result fields """
conn = self.connections[0] conn = self.connections[0]
cur = conn.cursor() cur = conn.cursor()
cur.execute("drop table if exists issue13") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
cur.execute("drop table if exists issue13")
try: try:
cur.execute("create table issue13 (t text)") cur.execute("create table issue13 (t text)")
# ticket says 18k # ticket says 18k
@@ -107,7 +116,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
""" query should be expanded before perform character encoding """ """ query should be expanded before perform character encoding """
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists issue15") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue15")
c.execute("create table issue15 (t varchar(32))") c.execute("create table issue15 (t varchar(32))")
try: try:
c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',)) c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
@@ -120,7 +131,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
""" Patch for string and tuple escaping """ """ Patch for string and tuple escaping """
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists issue16") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue16")
c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))") c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
try: try:
c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')") c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
@@ -138,7 +151,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
c = conn.cursor() c = conn.cursor()
# grant access to a table to a user with a password # grant access to a table to a user with a password
try: try:
c.execute("drop table if exists issue17") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue17")
c.execute("create table issue17 (x varchar(32) primary key)") c.execute("create table issue17 (x varchar(32) primary key)")
c.execute("insert into issue17 (x) values ('hello, world!')") c.execute("insert into issue17 (x) values ('hello, world!')")
c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db) c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
@@ -165,7 +180,9 @@ class TestNewIssues(base.PyMySQLTestCase):
conn = pymysql.connect(charset="utf8", **self.databases[0]) conn = pymysql.connect(charset="utf8", **self.databases[0])
c = conn.cursor() c = conn.cursor()
try: try:
c.execute(b"drop table if exists hei\xc3\x9fe".decode("utf8")) with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute(b"drop table if exists hei\xc3\x9fe".decode("utf8"))
c.execute(b"create table hei\xc3\x9fe (name varchar(32))".decode("utf8")) c.execute(b"create table hei\xc3\x9fe (name varchar(32))".decode("utf8"))
c.execute(b"insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')".decode("utf8")) c.execute(b"insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')".decode("utf8"))
c.execute(b"select name from hei\xc3\x9fe".decode("utf8")) c.execute(b"select name from hei\xc3\x9fe".decode("utf8"))
@@ -197,7 +214,7 @@ class TestNewIssues(base.PyMySQLTestCase):
kill_id = id kill_id = id
break break
# now nuke the connection # now nuke the connection
conn.kill(kill_id) self.connections[1].kill(kill_id)
# make sure this connection has broken # make sure this connection has broken
try: try:
c.execute("show tables") c.execute("show tables")
@@ -227,7 +244,9 @@ class TestNewIssues(base.PyMySQLTestCase):
datum = "a" * 1024 * 1023 # reduced size for most default mysql installs datum = "a" * 1024 * 1023 # reduced size for most default mysql installs
try: try:
c.execute("drop table if exists issue38") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue38")
c.execute("create table issue38 (id integer, data mediumblob)") c.execute("create table issue38 (id integer, data mediumblob)")
c.execute("insert into issue38 values (1, %s)", (datum,)) c.execute("insert into issue38 values (1, %s)", (datum,))
finally: finally:
@@ -236,7 +255,9 @@ class TestNewIssues(base.PyMySQLTestCase):
def disabled_test_issue_54(self): def disabled_test_issue_54(self):
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("drop table if exists issue54") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue54")
big_sql = "select * from issue54 where " big_sql = "select * from issue54 where "
big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000)) big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000))
@@ -255,7 +276,9 @@ class TestGitHubIssues(base.PyMySQLTestCase):
c = conn.cursor() c = conn.cursor()
self.assertEqual(0, conn.insert_id()) self.assertEqual(0, conn.insert_id())
try: try:
c.execute("drop table if exists issue66") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
c.execute("drop table if exists issue66")
c.execute("create table issue66 (id integer primary key auto_increment, x integer)") c.execute("create table issue66 (id integer primary key auto_increment, x integer)")
c.execute("insert into issue66 (x) values (1)") c.execute("insert into issue66 (x) values (1)")
c.execute("insert into issue66 (x) values (1)") c.execute("insert into issue66 (x) values (1)")
@@ -268,8 +291,10 @@ class TestGitHubIssues(base.PyMySQLTestCase):
conn = self.connections[0] conn = self.connections[0]
c = conn.cursor(pymysql.cursors.DictCursor) c = conn.cursor(pymysql.cursors.DictCursor)
c.execute("drop table if exists a") with warnings.catch_warnings():
c.execute("drop table if exists b") warnings.filterwarnings("ignore")
c.execute("drop table if exists a")
c.execute("drop table if exists b")
c.execute("""CREATE TABLE a (id int, value int)""") c.execute("""CREATE TABLE a (id int, value int)""")
c.execute("""CREATE TABLE b (id int, value int)""") c.execute("""CREATE TABLE b (id int, value int)""")
@@ -292,7 +317,9 @@ class TestGitHubIssues(base.PyMySQLTestCase):
""" Leftover trailing OK packet for "CALL my_sp" queries """ """ Leftover trailing OK packet for "CALL my_sp" queries """
conn = self.connections[0] conn = self.connections[0]
cur = conn.cursor() cur = conn.cursor()
cur.execute("DROP PROCEDURE IF EXISTS `foo`") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
cur.execute("DROP PROCEDURE IF EXISTS `foo`")
cur.execute("""CREATE PROCEDURE `foo` () cur.execute("""CREATE PROCEDURE `foo` ()
BEGIN BEGIN
SELECT 1; SELECT 1;
@@ -302,7 +329,9 @@ class TestGitHubIssues(base.PyMySQLTestCase):
cur.execute("""SELECT 1""") cur.execute("""SELECT 1""")
self.assertEqual(cur.fetchone()[0], 1) self.assertEqual(cur.fetchone()[0], 1)
finally: finally:
cur.execute("DROP PROCEDURE IF EXISTS `foo`") with warnings.catch_warnings():
warnings.filterwarnings("ignore")
cur.execute("DROP PROCEDURE IF EXISTS `foo`")
def test_issue_114(self): def test_issue_114(self):
""" autocommit is not set after reconnecting with ping() """ """ autocommit is not set after reconnecting with ping() """
@@ -341,7 +370,9 @@ class TestGitHubIssues(base.PyMySQLTestCase):
cur.execute('select * from test_field_count') cur.execute('select * from test_field_count')
assert len(cur.description) == length assert len(cur.description) == length
finally: finally:
cur.execute('drop table if exists test_field_count') with warnings.catch_warnings():
warnings.filterwarnings("ignore")
cur.execute('drop table if exists test_field_count')
__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"] __all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]

View File

@@ -1,7 +1,8 @@
from pymysql.err import OperationalError from pymysql import OperationalError, Warning
from pymysql.tests import base from pymysql.tests import base
import os import os
import warnings
__all__ = ["TestLoadLocal"] __all__ = ["TestLoadLocal"]
@@ -9,7 +10,7 @@ __all__ = ["TestLoadLocal"]
class TestLoadLocal(base.PyMySQLTestCase): class TestLoadLocal(base.PyMySQLTestCase):
def test_no_file(self): def test_no_file(self):
"""Test load local infile when the file does not exist""" """Test load local infile when the file does not exist"""
conn = self.connections[2] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
try: try:
@@ -25,7 +26,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
def test_load_file(self): def test_load_file(self):
"""Test load local infile with a valid file""" """Test load local infile with a valid file"""
conn = self.connections[2] conn = self.connections[0]
c = conn.cursor() c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
@@ -43,8 +44,7 @@ class TestLoadLocal(base.PyMySQLTestCase):
def test_load_warnings(self): def test_load_warnings(self):
"""Test load local infile produces the appropriate warnings""" """Test load local infile produces the appropriate warnings"""
import warnings conn = self.connections[0]
conn = self.connections[2]
c = conn.cursor() c = conn.cursor()
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
@@ -52,13 +52,13 @@ class TestLoadLocal(base.PyMySQLTestCase):
'load_local_warn_data.txt') 'load_local_warn_data.txt')
try: try:
with warnings.catch_warnings(record=True) as w: with warnings.catch_warnings(record=True) as w:
warnings.simplefilter('always')
c.execute( c.execute(
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
"test_load_local FIELDS TERMINATED BY ','").format(filename) "test_load_local FIELDS TERMINATED BY ','").format(filename)
) )
self.assertEqual(True, "Incorrect integer value" in str(w[-1].message)) self.assertEqual(w[0].category, Warning)
except Warning as w: self.assertTrue("Incorrect integer value" in str(w[-1].message))
self.assertLess(0, str(w).find("Incorrect integer value"))
finally: finally:
c.execute("DROP TABLE test_load_local") c.execute("DROP TABLE test_load_local")