diff --git a/.travis.databases.json b/.travis.databases.json index b700531..f27349d 100644 --- a/.travis.databases.json +++ b/.travis.databases.json @@ -1,5 +1,4 @@ [ - {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true}, - {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }, - {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "local_infile": true} + {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true, "local_infile": true}, + {"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" } ] diff --git a/pymysql/connections.py b/pymysql/connections.py index 2f16e9d..359cece 100644 --- a/pymysql/connections.py +++ b/pymysql/connections.py @@ -11,7 +11,6 @@ from functools import partial import hashlib import io import os -import re import socket import struct import sys @@ -675,7 +674,7 @@ class Connection(object): raise OperationalError(2014, "Command Out of Sync") ok = OKPacketWrapper(pkt) self.server_status = ok.server_status - return True + return ok def _send_autocommit_mode(self): ''' Set whether or not to commit after every execute() ''' @@ -698,6 +697,13 @@ class Connection(object): self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") self._read_ok_packet() + def show_warnings(self): + """SHOW WARNINGS""" + self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS") + result = MySQLResult(self) + result.read() + return result.rows + def select_db(self, db): '''Set current db''' self._execute_command(COMMAND.COM_INIT_DB, db) @@ -1069,8 +1075,6 @@ class Connection(object): NotSupportedError = NotSupportedError -# TODO: move OK and EOF packet parsing/logic into a proper subclass -# of MysqlPacket like has been done with FieldDescriptorPacket. class MySQLResult(object): def __init__(self, connection): @@ -1085,7 +1089,6 @@ class MySQLResult(object): self.rows = None self.has_next = None self.unbuffered_active = False - self.filename = None def __del__(self): if self.unbuffered_active: @@ -1095,7 +1098,6 @@ class MySQLResult(object): try: first_packet = self.connection._read_packet() - # TODO: use classes for different packet types? if first_packet.is_ok_packet(): self._read_ok_packet(first_packet) elif first_packet.is_load_local_packet(): @@ -1133,19 +1135,14 @@ class MySQLResult(object): def _read_load_local_packet(self, first_packet): load_packet = LoadLocalPacketWrapper(first_packet) - local_packet = LoadLocalFile(load_packet.filename, self.connection) - self.filename = load_packet.filename - local_packet.send_data() + sender = LoadLocalFile(load_packet.filename, self.connection) + sender.send_data() ok_packet = self.connection._read_packet() if not ok_packet.is_ok_packet(): raise OperationalError(2014, "Commands Out of Sync") self._read_ok_packet(ok_packet) - if self.warning_count > 0: - self._print_warnings() - self.filename = None - def _check_packet_is_eof(self, packet): if packet.is_eof_packet(): eof_packet = EOFPacketWrapper(packet) @@ -1154,16 +1151,6 @@ class MySQLResult(object): return True return False - def _print_warnings(self): - from warnings import warn - self.connection._execute_command(COMMAND.COM_QUERY, 'SHOW WARNINGS') - self.read() - if self.rows: - message = "\n" - for db_warning in self.rows: - message += "{0} in file '{1}'\n".format(db_warning[2], self.filename.decode('utf-8')) - warn(message, Warning, 3) - def _read_result_packet(self, first_packet): self.field_count = first_packet.read_length_encoded_integer() self._get_descriptions() diff --git a/pymysql/cursors.py b/pymysql/cursors.py index 0a85f04..1dbdd12 100644 --- a/pymysql/cursors.py +++ b/pymysql/cursors.py @@ -1,13 +1,11 @@ # -*- coding: utf-8 -*- from __future__ import print_function, absolute_import import re +import warnings from ._compat import range_type, text_type, PY2 -from .err import ( - Warning, Error, InterfaceError, DataError, - DatabaseError, OperationalError, IntegrityError, InternalError, - NotSupportedError, ProgrammingError) +from . import err #: Regular expression for :meth:`Cursor.executemany`. @@ -63,12 +61,12 @@ class Cursor(object): def _get_db(self): if not self.connection: - raise ProgrammingError("Cursor closed") + raise err.ProgrammingError("Cursor closed") return self.connection def _check_executed(self): if not self._executed: - raise ProgrammingError("execute() first") + raise err.ProgrammingError("execute() first") def _conv_row(self, row): return row @@ -262,7 +260,7 @@ class Cursor(object): elif mode == 'absolute': r = value else: - raise ProgrammingError("unknown scroll mode %s" % mode) + raise err.ProgrammingError("unknown scroll mode %s" % mode) if not (0 <= r < len(self._rows)): raise IndexError("out of range") @@ -286,19 +284,27 @@ class Cursor(object): self.lastrowid = result.insert_id self._rows = result.rows + if result.warning_count > 0: + self._show_warnings(conn) + + def _show_warnings(self, conn): + ws = conn.show_warnings() + for w in ws: + warnings.warn(w[-1], err.Warning, 4) + def __iter__(self): return iter(self.fetchone, None) - Warning = Warning - Error = Error - InterfaceError = InterfaceError - DatabaseError = DatabaseError - DataError = DataError - OperationalError = OperationalError - IntegrityError = IntegrityError - InternalError = InternalError - ProgrammingError = ProgrammingError - NotSupportedError = NotSupportedError + Warning = err.Warning + Error = err.Error + InterfaceError = err.InterfaceError + DatabaseError = err.DatabaseError + DataError = err.DataError + OperationalError = err.OperationalError + IntegrityError = err.IntegrityError + InternalError = err.InternalError + ProgrammingError = err.ProgrammingError + NotSupportedError = err.NotSupportedError class DictCursorMixin(object): @@ -426,7 +432,7 @@ class SSCursor(Cursor): if mode == 'relative': if value < 0: - raise NotSupportedError( + raise err.NotSupportedError( "Backwards scrolling not supported by this cursor") for _ in range_type(value): @@ -434,7 +440,7 @@ class SSCursor(Cursor): self.rownumber += value elif mode == 'absolute': if value < self.rownumber: - raise NotSupportedError( + raise err.NotSupportedError( "Backwards scrolling not supported by this cursor") end = value - self.rownumber @@ -442,7 +448,7 @@ class SSCursor(Cursor): self.read_next() self.rownumber = value else: - raise ProgrammingError("unknown scroll mode %s" % mode) + raise err.ProgrammingError("unknown scroll mode %s" % mode) class SSDictCursor(DictCursorMixin, SSCursor): diff --git a/pymysql/tests/base.py b/pymysql/tests/base.py index dafd1de..ee53aa2 100644 --- a/pymysql/tests/base.py +++ b/pymysql/tests/base.py @@ -16,7 +16,7 @@ class PyMySQLTestCase(unittest.TestCase): else: databases = [ {"host":"localhost","user":"root", - "passwd":"","db":"test_pymysql", "use_unicode": True}, + "passwd":"","db":"test_pymysql", "use_unicode": True, 'local_infile': True}, {"host":"localhost","user":"root","passwd":"","db":"test_pymysql2"}] def setUp(self): diff --git a/pymysql/tests/test_DictCursor.py b/pymysql/tests/test_DictCursor.py index f1cd06a..323850b 100644 --- a/pymysql/tests/test_DictCursor.py +++ b/pymysql/tests/test_DictCursor.py @@ -2,6 +2,7 @@ from pymysql.tests import base import pymysql.cursors import datetime +import warnings class TestDictCursor(base.PyMySQLTestCase): @@ -17,7 +18,9 @@ class TestDictCursor(base.PyMySQLTestCase): c = conn.cursor(self.cursor_type) # create a table ane some data to query - c.execute("drop table if exists dictcursor") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists dictcursor") c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""") data = [("bob", 21, "1990-02-06 23:04:56"), ("jim", 56, "1955-05-09 13:12:45"), diff --git a/pymysql/tests/test_basic.py b/pymysql/tests/test_basic.py index af23690..77024e6 100644 --- a/pymysql/tests/test_basic.py +++ b/pymysql/tests/test_basic.py @@ -6,6 +6,7 @@ from pymysql.err import ProgrammingError import time import datetime +import warnings __all__ = ["TestConversion", "TestCursor", "TestBulkInserts"] @@ -136,7 +137,9 @@ class TestConversion(base.PyMySQLTestCase): # User is running a version of MySQL that doesn't support msecs within datetime pass finally: - c.execute("drop table if exists test_datetime") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists test_datetime") class TestCursor(base.PyMySQLTestCase): @@ -243,7 +246,9 @@ class TestBulkInserts(base.PyMySQLTestCase): c = conn.cursor(self.cursor_type) # create a table ane some data to query - c.execute("drop table if exists bulkinsert") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + c.execute("drop table if exists bulkinsert") c.execute( """CREATE TABLE bulkinsert ( @@ -309,6 +314,16 @@ values (0, cursor.execute('commit') self._verify_records(data) + def test_warnings(self): + con = self.connections[0] + cur = con.cursor() + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") + cur.execute("drop table if exists no_exists_table") + self.assertEqual(len(ws), 1) + self.assertEqual(ws[0].category, pymysql.Warning) + self.assertTrue(u"no_exists_table" in str(ws[0].message)) + if __name__ == "__main__": import unittest diff --git a/pymysql/tests/test_issues.py b/pymysql/tests/test_issues.py index 3a09831..643c9fd 100644 --- a/pymysql/tests/test_issues.py +++ b/pymysql/tests/test_issues.py @@ -12,6 +12,7 @@ except AttributeError: pass import datetime +import warnings class TestOldIssues(base.PyMySQLTestCase): @@ -19,7 +20,9 @@ class TestOldIssues(base.PyMySQLTestCase): """ undefined methods datetime_or_None, date_or_None """ conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists issue3") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue3") c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)") try: c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None)) @@ -38,7 +41,9 @@ class TestOldIssues(base.PyMySQLTestCase): """ can't retrieve TIMESTAMP fields """ conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists issue4") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue4") c.execute("create table issue4 (ts timestamp)") try: c.execute("insert into issue4 (ts) values (now())") @@ -67,7 +72,9 @@ class TestOldIssues(base.PyMySQLTestCase): """ Primary Key and Index error when selecting data """ conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists test") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists test") c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh` datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY @@ -90,7 +97,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") """ can't handle large result fields """ conn = self.connections[0] cur = conn.cursor() - cur.execute("drop table if exists issue13") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + cur.execute("drop table if exists issue13") try: cur.execute("create table issue13 (t text)") # ticket says 18k @@ -107,7 +116,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") """ query should be expanded before perform character encoding """ conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists issue15") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue15") c.execute("create table issue15 (t varchar(32))") try: c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',)) @@ -120,7 +131,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") """ Patch for string and tuple escaping """ conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists issue16") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue16") c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))") try: c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')") @@ -138,7 +151,9 @@ KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") c = conn.cursor() # grant access to a table to a user with a password try: - c.execute("drop table if exists issue17") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue17") c.execute("create table issue17 (x varchar(32) primary key)") c.execute("insert into issue17 (x) values ('hello, world!')") c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db) @@ -165,7 +180,9 @@ class TestNewIssues(base.PyMySQLTestCase): conn = pymysql.connect(charset="utf8", **self.databases[0]) c = conn.cursor() try: - c.execute(b"drop table if exists hei\xc3\x9fe".decode("utf8")) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute(b"drop table if exists hei\xc3\x9fe".decode("utf8")) c.execute(b"create table hei\xc3\x9fe (name varchar(32))".decode("utf8")) c.execute(b"insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')".decode("utf8")) c.execute(b"select name from hei\xc3\x9fe".decode("utf8")) @@ -197,7 +214,7 @@ class TestNewIssues(base.PyMySQLTestCase): kill_id = id break # now nuke the connection - conn.kill(kill_id) + self.connections[1].kill(kill_id) # make sure this connection has broken try: c.execute("show tables") @@ -227,7 +244,9 @@ class TestNewIssues(base.PyMySQLTestCase): datum = "a" * 1024 * 1023 # reduced size for most default mysql installs try: - c.execute("drop table if exists issue38") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue38") c.execute("create table issue38 (id integer, data mediumblob)") c.execute("insert into issue38 values (1, %s)", (datum,)) finally: @@ -236,7 +255,9 @@ class TestNewIssues(base.PyMySQLTestCase): def disabled_test_issue_54(self): conn = self.connections[0] c = conn.cursor() - c.execute("drop table if exists issue54") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue54") big_sql = "select * from issue54 where " big_sql += " and ".join("%d=%d" % (i,i) for i in range(0, 100000)) @@ -255,7 +276,9 @@ class TestGitHubIssues(base.PyMySQLTestCase): c = conn.cursor() self.assertEqual(0, conn.insert_id()) try: - c.execute("drop table if exists issue66") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists issue66") c.execute("create table issue66 (id integer primary key auto_increment, x integer)") c.execute("insert into issue66 (x) values (1)") c.execute("insert into issue66 (x) values (1)") @@ -268,8 +291,10 @@ class TestGitHubIssues(base.PyMySQLTestCase): conn = self.connections[0] c = conn.cursor(pymysql.cursors.DictCursor) - c.execute("drop table if exists a") - c.execute("drop table if exists b") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + c.execute("drop table if exists a") + c.execute("drop table if exists b") c.execute("""CREATE TABLE a (id int, value int)""") c.execute("""CREATE TABLE b (id int, value int)""") @@ -292,7 +317,9 @@ class TestGitHubIssues(base.PyMySQLTestCase): """ Leftover trailing OK packet for "CALL my_sp" queries """ conn = self.connections[0] cur = conn.cursor() - cur.execute("DROP PROCEDURE IF EXISTS `foo`") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + cur.execute("DROP PROCEDURE IF EXISTS `foo`") cur.execute("""CREATE PROCEDURE `foo` () BEGIN SELECT 1; @@ -302,7 +329,9 @@ class TestGitHubIssues(base.PyMySQLTestCase): cur.execute("""SELECT 1""") self.assertEqual(cur.fetchone()[0], 1) finally: - cur.execute("DROP PROCEDURE IF EXISTS `foo`") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + cur.execute("DROP PROCEDURE IF EXISTS `foo`") def test_issue_114(self): """ autocommit is not set after reconnecting with ping() """ @@ -341,7 +370,9 @@ class TestGitHubIssues(base.PyMySQLTestCase): cur.execute('select * from test_field_count') assert len(cur.description) == length finally: - cur.execute('drop table if exists test_field_count') + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + cur.execute('drop table if exists test_field_count') __all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"] diff --git a/pymysql/tests/test_load_local.py b/pymysql/tests/test_load_local.py index 5136e70..1115bb3 100644 --- a/pymysql/tests/test_load_local.py +++ b/pymysql/tests/test_load_local.py @@ -1,7 +1,8 @@ -from pymysql.err import OperationalError +from pymysql import OperationalError, Warning from pymysql.tests import base import os +import warnings __all__ = ["TestLoadLocal"] @@ -9,7 +10,7 @@ __all__ = ["TestLoadLocal"] class TestLoadLocal(base.PyMySQLTestCase): def test_no_file(self): """Test load local infile when the file does not exist""" - conn = self.connections[2] + conn = self.connections[0] c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") try: @@ -25,7 +26,7 @@ class TestLoadLocal(base.PyMySQLTestCase): def test_load_file(self): """Test load local infile with a valid file""" - conn = self.connections[2] + conn = self.connections[0] c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), @@ -43,8 +44,7 @@ class TestLoadLocal(base.PyMySQLTestCase): def test_load_warnings(self): """Test load local infile produces the appropriate warnings""" - import warnings - conn = self.connections[2] + conn = self.connections[0] c = conn.cursor() c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)") filename = os.path.join(os.path.dirname(os.path.realpath(__file__)), @@ -52,13 +52,13 @@ class TestLoadLocal(base.PyMySQLTestCase): 'load_local_warn_data.txt') try: with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') c.execute( ("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " + "test_load_local FIELDS TERMINATED BY ','").format(filename) ) - self.assertEqual(True, "Incorrect integer value" in str(w[-1].message)) - except Warning as w: - self.assertLess(0, str(w).find("Incorrect integer value")) + self.assertEqual(w[0].category, Warning) + self.assertTrue("Incorrect integer value" in str(w[-1].message)) finally: c.execute("DROP TABLE test_load_local")