Add tests for load data local infile
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
[
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true},
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql2" }
|
||||
{"host": "localhost", "user": "root", "passwd": "", "db": "test_pymysql", "use_unicode": true, "load_local": true},
|
||||
]
|
||||
|
||||
@@ -458,7 +458,6 @@ class EOFPacketWrapper(object):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
|
||||
class LoadLocalPacketWrapper(object):
|
||||
"""
|
||||
Load Local Packet Wrapper. It uses an existing packet object, and wraps
|
||||
@@ -480,7 +479,6 @@ class LoadLocalPacketWrapper(object):
|
||||
return getattr(self.packet, key)
|
||||
|
||||
|
||||
|
||||
class Connection(object):
|
||||
"""
|
||||
Representation of a socket with a mysql server.
|
||||
@@ -1261,6 +1259,7 @@ class LoadLocalFile(object):
|
||||
if not self.connection.socket:
|
||||
raise InterfaceError("(0, '')")
|
||||
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
seq_id = 2
|
||||
try:
|
||||
with open(self.filename, 'r') as open_file:
|
||||
@@ -1268,7 +1267,6 @@ class LoadLocalFile(object):
|
||||
prelude = ""
|
||||
packet = ""
|
||||
packet_size = 0
|
||||
# sequence id is 2 as we already sent a query packet
|
||||
|
||||
for line in open_file:
|
||||
line_length = len(line)
|
||||
|
||||
@@ -4,6 +4,7 @@ from pymysql.tests.test_nextset import *
|
||||
from pymysql.tests.test_DictCursor import *
|
||||
from pymysql.tests.test_connection import TestConnection
|
||||
from pymysql.tests.test_SSCursor import *
|
||||
from pymysql.tests.test_load_local import *
|
||||
|
||||
from pymysql.tests.thirdparty import *
|
||||
|
||||
|
||||
22749
pymysql/tests/data/load_local_data.txt
Normal file
22749
pymysql/tests/data/load_local_data.txt
Normal file
File diff suppressed because it is too large
Load Diff
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
50
pymysql/tests/data/load_local_warn_data.txt
Normal file
@@ -0,0 +1,50 @@
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
5,6,
|
||||
7,8,
|
||||
1,2,
|
||||
3,4,
|
||||
75
pymysql/tests/test_load_local.py
Normal file
75
pymysql/tests/test_load_local.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from pymysql.tests import base
|
||||
from pymysql.err import OperationalError
|
||||
|
||||
import os
|
||||
|
||||
__all__ = ["TestLoadLocal"]
|
||||
|
||||
|
||||
class TestLoadLocal(base.PyMySQLTestCase):
|
||||
def test_no_file(self):
|
||||
"""Test load local infile when the file does not exist"""
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
try:
|
||||
with self.assertRaisesRegexp(
|
||||
OperationalError, "Can't find file 'no_data.txt'"):
|
||||
c.execute(
|
||||
"LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE " +
|
||||
"test_load_local fields terminated by ','"
|
||||
)
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
c.close()
|
||||
|
||||
def test_load_file(self):
|
||||
"""Test load local infile with a valid file"""
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'data',
|
||||
'load_local_data.txt')
|
||||
try:
|
||||
c.execute(
|
||||
("LOAD DATA LOCAL INFILE '{}' INTO TABLE " +
|
||||
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||
)
|
||||
c.execute("SELECT COUNT(*) FROM test_load_local")
|
||||
self.assertEquals(22749, c.fetchone()[0])
|
||||
finally:
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
|
||||
def test_load_warnings(self):
|
||||
"""Test load local infile produces the appropriate warnings"""
|
||||
import sys
|
||||
from StringIO import StringIO
|
||||
|
||||
saved_stdout = sys.stdout
|
||||
out = StringIO()
|
||||
sys.stdout = out
|
||||
conn = self.connections[2]
|
||||
c = conn.cursor()
|
||||
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
|
||||
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
|
||||
'data',
|
||||
'load_local_warn_data.txt')
|
||||
|
||||
try:
|
||||
c.execute(
|
||||
("LOAD DATA LOCAL INFILE '{}' INTO TABLE " +
|
||||
"test_load_local FIELDS TERMINATED BY ','").format(filename)
|
||||
)
|
||||
output = out.getvalue().strip().split('\n')
|
||||
self.assertEquals(2, len(output))
|
||||
self.assertEqual(" Warning: Incorrect integer value: '' for column 'a' at row 8", output[1])
|
||||
|
||||
finally:
|
||||
sys.stdout = saved_stdout
|
||||
c.execute("DROP TABLE test_load_local")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user