Merge pull request #158 from methane/dict-type

Add DictCursor.dict_type to override it.
This commit is contained in:
INADA Naoki
2013-09-22 18:49:40 -07:00
2 changed files with 97 additions and 54 deletions

View File

@@ -12,9 +12,10 @@ if PY2:
else: else:
import io import io
from .err import Warning, Error, InterfaceError, DataError, \ from .err import (
DatabaseError, OperationalError, IntegrityError, InternalError, \ Warning, Error, InterfaceError, DataError,
NotSupportedError, ProgrammingError DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError)
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE) insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
@@ -115,9 +116,8 @@ class Cursor(object):
result = 0 result = 0
try: try:
result = self._query(query) result = self._query(query)
except: except Exception:
exc, value, tb = exc_info() exc, value = exc_info()[:2]
del tb
self.errorhandler(self, exc, value) self.errorhandler(self, exc, value)
self._executed = query self._executed = query
@@ -261,6 +261,8 @@ class Cursor(object):
class DictCursor(Cursor): class DictCursor(Cursor):
"""A cursor which returns results as a dictionary""" """A cursor which returns results as a dictionary"""
# You can override this to use OrderedDict or other dict-like types.
dict_type = dict
def execute(self, query, args=None): def execute(self, query, args=None):
result = super(DictCursor, self).execute(query, args) result = super(DictCursor, self).execute(query, args)
@@ -273,7 +275,7 @@ class DictCursor(Cursor):
self._check_executed() self._check_executed()
if self._rows is None or self.rownumber >= len(self._rows): if self._rows is None or self.rownumber >= len(self._rows):
return None return None
result = dict(zip(self._fields, self._rows[self.rownumber])) result = self.dict_type(zip(self._fields, self._rows[self.rownumber]))
self.rownumber += 1 self.rownumber += 1
return result return result
@@ -283,9 +285,9 @@ class DictCursor(Cursor):
if self._rows is None: if self._rows is None:
return None return None
end = self.rownumber + (size or self.arraysize) end = self.rownumber + (size or self.arraysize)
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ] result = [self.dict_type(zip(self._fields, r)) for r in self._rows[self.rownumber:end]]
self.rownumber = min(end, len(self._rows)) self.rownumber = min(end, len(self._rows))
return tuple(result) return result
def fetchall(self): def fetchall(self):
''' Fetch all the rows ''' ''' Fetch all the rows '''
@@ -293,11 +295,11 @@ class DictCursor(Cursor):
if self._rows is None: if self._rows is None:
return None return None
if self.rownumber: if self.rownumber:
result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ] result = [self.dict_type(zip(self._fields, r)) for r in self._rows[self.rownumber:]]
else: else:
result = [ dict(zip(self._fields, r)) for r in self._rows ] result = [self.dict_type(zip(self._fields, r)) for r in self._rows]
self.rownumber = len(self._rows) self.rownumber = len(self._rows)
return tuple(result) return result
class SSCursor(Cursor): class SSCursor(Cursor):
""" """

View File

@@ -3,53 +3,94 @@ import pymysql.cursors
import datetime import datetime
class TestDictCursor(base.PyMySQLTestCase):
def test_DictCursor(self): class TestDictCursor(base.PyMySQLTestCase):
#all assert test compare to the structure as would come out from MySQLdb bob = {'name': 'bob', 'age': 21, 'DOB': datetime.datetime(1990, 2, 6, 23, 4, 56)}
conn = self.connections[0] jim = {'name': 'jim', 'age': 56, 'DOB': datetime.datetime(1955, 5, 9, 13, 12, 45)}
fred = {'name': 'fred', 'age': 100, 'DOB': datetime.datetime(1911, 9, 12, 1, 1, 1)}
def setUp(self):
super(TestDictCursor, self).setUp()
self.conn = conn = self.connections[0]
c = conn.cursor(pymysql.cursors.DictCursor) c = conn.cursor(pymysql.cursors.DictCursor)
# create a table ane some data to query # create a table ane some data to query
c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""") c.execute("""CREATE TABLE dictcursor (name char(20), age int , DOB datetime)""")
data = (("bob",21,"1990-02-06 23:04:56"), data = [("bob", 21, "1990-02-06 23:04:56"),
("jim",56,"1955-05-09 13:12:45"), ("jim", 56, "1955-05-09 13:12:45"),
("fred",100,"1911-09-12 01:01:01")) ("fred", 100, "1911-09-12 01:01:01")]
bob = {'name':'bob','age':21,'DOB':datetime.datetime(1990, 2, 6, 23, 4, 56)} c.executemany("insert into dictcursor values (%s,%s,%s)", data)
jim = {'name':'jim','age':56,'DOB':datetime.datetime(1955, 5, 9, 13, 12, 45)}
fred = {'name':'fred','age':100,'DOB':datetime.datetime(1911, 9, 12, 1, 1, 1)} def tearDown(self):
try: c = self.conn.cursor()
c.executemany("insert into dictcursor values (%s,%s,%s)", data) c.execute("drop table dictcursor")
# try an update which should return no rows super(TestDictCursor, self).tearDown()
c.execute("update dictcursor set age=20 where name='bob'")
bob['age'] = 20 def test_DictCursor(self):
# pull back the single row dict for bob and check bob, jim, fred = self.bob.copy(), self.jim.copy(), self.fred.copy()
c.execute("SELECT * from dictcursor where name='bob'") #all assert test compare to the structure as would come out from MySQLdb
r = c.fetchone() conn = self.conn
self.assertEqual(bob,r,"fetchone via DictCursor failed") c = conn.cursor(pymysql.cursors.DictCursor)
# same again, but via fetchall => tuple)
c.execute("SELECT * from dictcursor where name='bob'") # try an update which should return no rows
r = c.fetchall() c.execute("update dictcursor set age=20 where name='bob'")
self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor") bob['age'] = 20
# same test again but iterate over the # pull back the single row dict for bob and check
c.execute("SELECT * from dictcursor where name='bob'") c.execute("SELECT * from dictcursor where name='bob'")
for r in c: r = c.fetchone()
self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor") self.assertEqual(bob, r, "fetchone via DictCursor failed")
# get all 3 row via fetchall # same again, but via fetchall => tuple)
c.execute("SELECT * from dictcursor") c.execute("SELECT * from dictcursor where name='bob'")
r = c.fetchall() r = c.fetchall()
self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor") self.assertEqual([bob], r, "fetch a 1 row result via fetchall failed via DictCursor")
#same test again but do a list comprehension # same test again but iterate over the
c.execute("SELECT * from dictcursor") c.execute("SELECT * from dictcursor where name='bob'")
r = [x for x in c] for r in c:
self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor") self.assertEqual(bob, r, "fetch a 1 row result via iteration failed via DictCursor")
# get all 2 row via fetchmany # get all 3 row via fetchall
c.execute("SELECT * from dictcursor") c.execute("SELECT * from dictcursor")
r = c.fetchmany(2) r = c.fetchall()
self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor") self.assertEqual([bob,jim,fred], r, "fetchall failed via DictCursor")
finally: #same test again but do a list comprehension
c.execute("drop table dictcursor") c.execute("SELECT * from dictcursor")
r = list(c)
self.assertEqual([bob,jim,fred], r, "DictCursor should be iterable")
# get all 2 row via fetchmany
c.execute("SELECT * from dictcursor")
r = c.fetchmany(2)
self.assertEqual([bob, jim], r, "fetchmany failed via DictCursor")
def test_custom_dict(self):
class MyDict(dict): pass
class MyDictCursor(pymysql.cursors.DictCursor):
dict_type = MyDict
keys = ['name', 'age', 'DOB']
bob = MyDict([(k, self.bob[k]) for k in keys])
jim = MyDict([(k, self.jim[k]) for k in keys])
fred = MyDict([(k, self.fred[k]) for k in keys])
cur = self.conn.cursor(MyDictCursor)
cur.execute("SELECT * FROM dictcursor WHERE name='bob'")
r = cur.fetchone()
self.assertEqual(bob, r, "fetchone() returns MyDictCursor")
cur.execute("SELECT * FROM dictcursor")
r = cur.fetchall()
self.assertEqual([bob, jim, fred], r,
"fetchall failed via MyDictCursor")
cur.execute("SELECT * FROM dictcursor")
r = list(cur)
self.assertEqual([bob, jim, fred], r,
"list failed via MyDictCursor")
cur.execute("SELECT * FROM dictcursor")
r = cur.fetchmany(2)
self.assertEqual([bob, jim], r,
"list failed via MyDictCursor")
__all__ = ["TestDictCursor"]
if __name__ == "__main__": if __name__ == "__main__":
import unittest import unittest