Added QueryChain

This commit is contained in:
Konsta Vesterinen
2014-05-06 13:25:33 +03:00
parent abb3f83d94
commit 8df5f6dfa8
6 changed files with 138 additions and 4 deletions

View File

@@ -4,11 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.25.5 (2014-xx-xx) 0.26.0 (2014-xx-xx)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
- Added get_referencing_foreign_keys - Added get_referencing_foreign_keys
- Added get_tables - Added get_tables
- Added QueryChain
0.25.4 (2014-04-22) 0.25.4 (2014-04-22)

View File

@@ -31,6 +31,7 @@ from .listeners import (
from .merge import merge, Merger from .merge import merge, Merger
from .generic import generic_relationship from .generic import generic_relationship
from .proxy_dict import ProxyDict, proxy_dict from .proxy_dict import ProxyDict, proxy_dict
from .query_chain import QueryChain
from .types import ( from .types import (
ArrowType, ArrowType,
Choice, Choice,
@@ -118,6 +119,7 @@ __all__ = (
PhoneNumber, PhoneNumber,
PhoneNumberType, PhoneNumberType,
ProxyDict, ProxyDict,
QueryChain,
ScalarListException, ScalarListException,
ScalarListType, ScalarListType,
TimezoneType, TimezoneType,

View File

@@ -3,6 +3,7 @@ try:
except ImportError: except ImportError:
from ordereddict import OrderedDict from ordereddict import OrderedDict
from functools import partial from functools import partial
from itertools import groupby
from inspect import isclass from inspect import isclass
from operator import attrgetter from operator import attrgetter
import sqlalchemy as sa import sqlalchemy as sa
@@ -115,7 +116,7 @@ def get_tables(mixed):
get_tables(Article.__mapper__) get_tables(Article.__mapper__)
.. versionadded: 0.25.5 .. versionadded: 0.26.0
:param mixed: :param mixed:
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping

View File

@@ -0,0 +1,48 @@
from copy import copy
class QueryChain(object):
"""
:param queries: A sequence of SQLAlchemy Query objects
:param limit: Similar to normal query limit this parameter can be used for
limiting the number of results for the whole query chain.
:param offset: Similar to normal query offset this parameter can be used
for offsetting the query chain as a whole.
::
chain = QueryChain([session.query(User), session.query(Article)])
for obj in chain[0:5]:
print obj
.. versionadded: 0.26.0
"""
def __init__(self, queries, limit=None, offset=None):
self.queries = queries
self.limit = limit
self.offset = offset
def __iter__(self):
consumed = 0
skipped = 0
for query in self.queries:
query_copy = copy(query)
if self.limit:
query = query.limit(self.limit - consumed)
if self.offset:
query = query.offset(self.offset - skipped)
obj_count = 0
for obj in query:
consumed += 1
obj_count += 1
yield obj
if not obj_count:
skipped += query_copy.count()
else:
skipped += obj_count
def __repr__(self):
return '<QueryChain at 0x%x>' % id(self)

View File

@@ -40,8 +40,8 @@ class TestGetReferencingFksWithInheritance(TestCase):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode) type = sa.Column(sa.Unicode)
first_name = sa.Column(sa.Unicode(255), primary_key=True) first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255), primary_key=True) last_name = sa.Column(sa.Unicode(255))
__mapper_args__ = { __mapper_args__ = {
'polymorphic_on': 'type' 'polymorphic_on': 'type'

82
tests/test_query_chain.py Normal file
View File

@@ -0,0 +1,82 @@
import sqlalchemy as sa
from sqlalchemy_utils import QueryChain
from tests import TestCase
class TestQueryChain(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
self.User = User
self.Article = Article
self.BlogPost = BlogPost
def setup_method(self, method):
TestCase.setup_method(self, method)
self.users = [
self.User(),
self.User()
]
self.articles = [
self.Article(),
self.Article(),
self.Article(),
self.Article()
]
self.posts = [
self.BlogPost(),
self.BlogPost(),
self.BlogPost(),
]
self.session.add_all(self.users)
self.session.add_all(self.articles)
self.session.add_all(self.posts)
self.session.commit()
self.chain = QueryChain(
[
self.session.query(self.User).order_by('id'),
self.session.query(self.Article).order_by('id'),
self.session.query(self.BlogPost).order_by('id')
]
)
def test_iter(self):
assert len(list(self.chain)) == 9
def test_iter_with_limit(self):
self.chain.limit = 4
objects = list(self.chain)
assert self.users == objects[0:2]
assert self.articles[0:2] == objects[2:]
def test_iter_with_offset(self):
self.chain.offset = 3
objects = list(self.chain)
assert self.articles[1:] + self.posts == objects
def test_iter_with_limit_and_offset(self):
self.chain.offset = 3
self.chain.limit = 4
objects = list(self.chain)
assert self.articles[1:] + self.posts[0:1] == objects
def test_iter_with_offset_spanning_multiple_queries(self):
self.chain.offset = 7
objects = list(self.chain)
assert self.posts[1:] == objects
def test_repr(self):
assert repr(self.chain) == '<QueryChain at 0x%x>' % id(self.chain)