Merge pull request #158 from methane/dict-type
Add DictCursor.dict_type to override it.
This commit is contained in:
@@ -12,9 +12,10 @@ if PY2:
|
|||||||
else:
|
else:
|
||||||
import io
|
import io
|
||||||
|
|
||||||
from .err import Warning, Error, InterfaceError, DataError, \
|
from .err import (
|
||||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
Warning, Error, InterfaceError, DataError,
|
||||||
NotSupportedError, ProgrammingError
|
DatabaseError, OperationalError, IntegrityError, InternalError,
|
||||||
|
NotSupportedError, ProgrammingError)
|
||||||
|
|
||||||
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
|
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
|
||||||
|
|
||||||
@@ -115,9 +116,8 @@ class Cursor(object):
|
|||||||
result = 0
|
result = 0
|
||||||
try:
|
try:
|
||||||
result = self._query(query)
|
result = self._query(query)
|
||||||
except:
|
except Exception:
|
||||||
exc, value, tb = exc_info()
|
exc, value = exc_info()[:2]
|
||||||
del tb
|
|
||||||
self.errorhandler(self, exc, value)
|
self.errorhandler(self, exc, value)
|
||||||
|
|
||||||
self._executed = query
|
self._executed = query
|
||||||
@@ -261,6 +261,8 @@ class Cursor(object):
|
|||||||
|
|
||||||
class DictCursor(Cursor):
|
class DictCursor(Cursor):
|
||||||
"""A cursor which returns results as a dictionary"""
|
"""A cursor which returns results as a dictionary"""
|
||||||
|
# You can override this to use OrderedDict or other dict-like types.
|
||||||
|
dict_type = dict
|
||||||
|
|
||||||
def execute(self, query, args=None):
|
def execute(self, query, args=None):
|
||||||
result = super(DictCursor, self).execute(query, args)
|
result = super(DictCursor, self).execute(query, args)
|
||||||
@@ -273,7 +275,7 @@ class DictCursor(Cursor):
|
|||||||
self._check_executed()
|
self._check_executed()
|
||||||
if self._rows is None or self.rownumber >= len(self._rows):
|
if self._rows is None or self.rownumber >= len(self._rows):
|
||||||
return None
|
return None
|
||||||
result = dict(zip(self._fields, self._rows[self.rownumber]))
|
result = self.dict_type(zip(self._fields, self._rows[self.rownumber]))
|
||||||
self.rownumber += 1
|
self.rownumber += 1
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -283,9 +285,9 @@ class DictCursor(Cursor):
|
|||||||
if self._rows is None:
|
if self._rows is None:
|
||||||
return None
|
return None
|
||||||
end = self.rownumber + (size or self.arraysize)
|
end = self.rownumber + (size or self.arraysize)
|
||||||
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ]
|
result = [self.dict_type(zip(self._fields, r)) for r in self._rows[self.rownumber:end]]
|
||||||
self.rownumber = min(end, len(self._rows))
|
self.rownumber = min(end, len(self._rows))
|
||||||
return tuple(result)
|
return result
|
||||||
|
|
||||||
def fetchall(self):
|
def fetchall(self):
|
||||||
''' Fetch all the rows '''
|
''' Fetch all the rows '''
|
||||||
@@ -293,11 +295,11 @@ class DictCursor(Cursor):
|
|||||||
if self._rows is None:
|
if self._rows is None:
|
||||||
return None
|
return None
|
||||||
if self.rownumber:
|
if self.rownumber:
|
||||||
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ]
|
result = [self.dict_type(zip(self._fields, r)) for r in self._rows[self.rownumber:]]
|
||||||
else:
|
else:
|
||||||
result = [ dict(zip(self._fields, r)) for r in self._rows ]
|
result = [self.dict_type(zip(self._fields, r)) for r in self._rows]
|
||||||
self.rownumber = len(self._rows)
|
self.rownumber = len(self._rows)
|
||||||
return tuple(result)
|
return result
|
||||||
|
|
||||||
class SSCursor(Cursor):
|
class SSCursor(Cursor):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -3,53 +3,94 @@ import pymysql.cursors
|
|||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
class TestDictCursor(base.PyMySQLTestCase):
|
|
||||||
|
|
||||||
def test_DictCursor(self):
|
class TestDictCursor(base.PyMySQLTestCase):
|
||||||
#all assert test compare to the structure as would come out from MySQLdb
|
bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)}
|
||||||
conn = self.connections[0]
|
jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)}
|
||||||
|
fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)}
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(TestDictCursor, self).setUp()
|
||||||
|
self.conn = conn = self.connections[0]
|
||||||
c = conn.cursor(pymysql.cursors.DictCursor)
|
c = conn.cursor(pymysql.cursors.DictCursor)
|
||||||
|
|
||||||
# create a table ane some data to query
|
# create a table ane some data to query
|
||||||
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
|
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
|
||||||
data = (("bob",21,"1990-02-06 23:04:56"),
|
data = [("bob", 21, "1990-02-06 23:04:56"),
|
||||||
("jim",56,"1955-05-09 13:12:45"),
|
("jim", 56, "1955-05-09 13:12:45"),
|
||||||
("fred",100,"1911-09-12 01:01:01"))
|
("fred", 100, "1911-09-12 01:01:01")]
|
||||||
bob = {'name':'bob','age':21,'DOB':datetime.datetime(1990, 2, 6, 23, 4, 56)}
|
c.executemany("insert into dictcursor values (%s,%s,%s)", data)
|
||||||
jim = {'name':'jim','age':56,'DOB':datetime.datetime(1955, 5, 9, 13, 12, 45)}
|
|
||||||
fred = {'name':'fred','age':100,'DOB':datetime.datetime(1911, 9, 12, 1, 1, 1)}
|
def tearDown(self):
|
||||||
try:
|
c = self.conn.cursor()
|
||||||
c.executemany("insert into dictcursor values (%s,%s,%s)", data)
|
c.execute("drop table dictcursor")
|
||||||
# try an update which should return no rows
|
super(TestDictCursor, self).tearDown()
|
||||||
c.execute("update dictcursor set age=20 where name='bob'")
|
|
||||||
bob['age'] = 20
|
def test_DictCursor(self):
|
||||||
# pull back the single row dict for bob and check
|
bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
|
||||||
c.execute("SELECT * from dictcursor where name='bob'")
|
#all assert test compare to the structure as would come out from MySQLdb
|
||||||
r = c.fetchone()
|
conn = self.conn
|
||||||
self.assertEqual(bob,r,"fetchone via DictCursor failed")
|
c = conn.cursor(pymysql.cursors.DictCursor)
|
||||||
# same again, but via fetchall => tuple)
|
|
||||||
c.execute("SELECT * from dictcursor where name='bob'")
|
# try an update which should return no rows
|
||||||
r = c.fetchall()
|
c.execute("update dictcursor set age=20 where name='bob'")
|
||||||
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
|
bob['age'] = 20
|
||||||
# same test again but iterate over the
|
# pull back the single row dict for bob and check
|
||||||
c.execute("SELECT * from dictcursor where name='bob'")
|
c.execute("SELECT * from dictcursor where name='bob'")
|
||||||
for r in c:
|
r = c.fetchone()
|
||||||
self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")
|
self.assertEqual(bob, r, "fetchone via DictCursor failed")
|
||||||
# get all 3 row via fetchall
|
# same again, but via fetchall => tuple)
|
||||||
c.execute("SELECT * from dictcursor")
|
c.execute("SELECT * from dictcursor where name='bob'")
|
||||||
r = c.fetchall()
|
r = c.fetchall()
|
||||||
self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor")
|
self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor")
|
||||||
#same test again but do a list comprehension
|
# same test again but iterate over the
|
||||||
c.execute("SELECT * from dictcursor")
|
c.execute("SELECT * from dictcursor where name='bob'")
|
||||||
r = [x for x in c]
|
for r in c:
|
||||||
self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor")
|
self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor")
|
||||||
# get all 2 row via fetchmany
|
# get all 3 row via fetchall
|
||||||
c.execute("SELECT * from dictcursor")
|
c.execute("SELECT * from dictcursor")
|
||||||
r = c.fetchmany(2)
|
r = c.fetchall()
|
||||||
self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor")
|
self.assertEqual([bob,jim,fred], r, "fetchall failed via DictCursor")
|
||||||
finally:
|
#same test again but do a list comprehension
|
||||||
c.execute("drop table dictcursor")
|
c.execute("SELECT * from dictcursor")
|
||||||
|
r = list(c)
|
||||||
|
self.assertEqual([bob,jim,fred], r, "DictCursor should be iterable")
|
||||||
|
# get all 2 row via fetchmany
|
||||||
|
c.execute("SELECT * from dictcursor")
|
||||||
|
r = c.fetchmany(2)
|
||||||
|
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
|
||||||
|
|
||||||
|
def test_custom_dict(self):
|
||||||
|
class MyDict(dict): pass
|
||||||
|
|
||||||
|
class MyDictCursor(pymysql.cursors.DictCursor):
|
||||||
|
dict_type = MyDict
|
||||||
|
|
||||||
|
keys = ['name', 'age', 'DOB']
|
||||||
|
bob = MyDict([(k, self.bob[k]) for k in keys])
|
||||||
|
jim = MyDict([(k, self.jim[k]) for k in keys])
|
||||||
|
fred = MyDict([(k, self.fred[k]) for k in keys])
|
||||||
|
|
||||||
|
cur = self.conn.cursor(MyDictCursor)
|
||||||
|
cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
|
||||||
|
r = cur.fetchone()
|
||||||
|
self.assertEqual(bob, r, "fetchone() returns MyDictCursor")
|
||||||
|
|
||||||
|
cur.execute("SELECT * FROM dictcursor")
|
||||||
|
r = cur.fetchall()
|
||||||
|
self.assertEqual([bob, jim, fred], r,
|
||||||
|
"fetchall failed via MyDictCursor")
|
||||||
|
|
||||||
|
cur.execute("SELECT * FROM dictcursor")
|
||||||
|
r = list(cur)
|
||||||
|
self.assertEqual([bob, jim, fred], r,
|
||||||
|
"list failed via MyDictCursor")
|
||||||
|
|
||||||
|
cur.execute("SELECT * FROM dictcursor")
|
||||||
|
r = cur.fetchmany(2)
|
||||||
|
self.assertEqual([bob, jim], r,
|
||||||
|
"list failed via MyDictCursor")
|
||||||
|
|
||||||
__all__ = ["TestDictCursor"]
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import unittest
|
import unittest
|
||||||
|
|||||||
Reference in New Issue
Block a user