Added QueryChain
This commit is contained in:
@@ -4,11 +4,12 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.25.5 (2014-xx-xx)
|
||||
0.26.0 (2014-xx-xx)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added get_referencing_foreign_keys
|
||||
- Added get_tables
|
||||
- Added QueryChain
|
||||
|
||||
|
||||
0.25.4 (2014-04-22)
|
||||
|
@@ -31,6 +31,7 @@ from .listeners import (
|
||||
from .merge import merge, Merger
|
||||
from .generic import generic_relationship
|
||||
from .proxy_dict import ProxyDict, proxy_dict
|
||||
from .query_chain import QueryChain
|
||||
from .types import (
|
||||
ArrowType,
|
||||
Choice,
|
||||
@@ -118,6 +119,7 @@ __all__ = (
|
||||
PhoneNumber,
|
||||
PhoneNumberType,
|
||||
ProxyDict,
|
||||
QueryChain,
|
||||
ScalarListException,
|
||||
ScalarListType,
|
||||
TimezoneType,
|
||||
|
@@ -3,6 +3,7 @@ try:
|
||||
except ImportError:
|
||||
from ordereddict import OrderedDict
|
||||
from functools import partial
|
||||
from itertools import groupby
|
||||
from inspect import isclass
|
||||
from operator import attrgetter
|
||||
import sqlalchemy as sa
|
||||
@@ -115,7 +116,7 @@ def get_tables(mixed):
|
||||
get_tables(Article.__mapper__)
|
||||
|
||||
|
||||
.. versionadded: 0.25.5
|
||||
.. versionadded: 0.26.0
|
||||
|
||||
:param mixed:
|
||||
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping
|
||||
|
48
sqlalchemy_utils/query_chain.py
Normal file
48
sqlalchemy_utils/query_chain.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from copy import copy
|
||||
|
||||
|
||||
class QueryChain(object):
|
||||
"""
|
||||
:param queries: A sequence of SQLAlchemy Query objects
|
||||
:param limit: Similar to normal query limit this parameter can be used for
|
||||
limiting the number of results for the whole query chain.
|
||||
:param offset: Similar to normal query offset this parameter can be used
|
||||
for offsetting the query chain as a whole.
|
||||
|
||||
::
|
||||
|
||||
chain = QueryChain([session.query(User), session.query(Article)])
|
||||
|
||||
for obj in chain[0:5]:
|
||||
print obj
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
"""
|
||||
def __init__(self, queries, limit=None, offset=None):
|
||||
self.queries = queries
|
||||
self.limit = limit
|
||||
self.offset = offset
|
||||
|
||||
def __iter__(self):
|
||||
consumed = 0
|
||||
skipped = 0
|
||||
for query in self.queries:
|
||||
query_copy = copy(query)
|
||||
if self.limit:
|
||||
query = query.limit(self.limit - consumed)
|
||||
if self.offset:
|
||||
query = query.offset(self.offset - skipped)
|
||||
|
||||
obj_count = 0
|
||||
for obj in query:
|
||||
consumed += 1
|
||||
obj_count += 1
|
||||
yield obj
|
||||
|
||||
if not obj_count:
|
||||
skipped += query_copy.count()
|
||||
else:
|
||||
skipped += obj_count
|
||||
|
||||
def __repr__(self):
|
||||
return '<QueryChain at 0x%x>' % id(self)
|
@@ -40,8 +40,8 @@ class TestGetReferencingFksWithInheritance(TestCase):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
type = sa.Column(sa.Unicode)
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
first_name = sa.Column(sa.Unicode(255))
|
||||
last_name = sa.Column(sa.Unicode(255))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': 'type'
|
||||
|
82
tests/test_query_chain.py
Normal file
82
tests/test_query_chain.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import QueryChain
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestQueryChain(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.User = User
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
self.users = [
|
||||
self.User(),
|
||||
self.User()
|
||||
]
|
||||
self.articles = [
|
||||
self.Article(),
|
||||
self.Article(),
|
||||
self.Article(),
|
||||
self.Article()
|
||||
]
|
||||
self.posts = [
|
||||
self.BlogPost(),
|
||||
self.BlogPost(),
|
||||
self.BlogPost(),
|
||||
]
|
||||
|
||||
self.session.add_all(self.users)
|
||||
self.session.add_all(self.articles)
|
||||
self.session.add_all(self.posts)
|
||||
self.session.commit()
|
||||
|
||||
self.chain = QueryChain(
|
||||
[
|
||||
self.session.query(self.User).order_by('id'),
|
||||
self.session.query(self.Article).order_by('id'),
|
||||
self.session.query(self.BlogPost).order_by('id')
|
||||
]
|
||||
)
|
||||
|
||||
def test_iter(self):
|
||||
assert len(list(self.chain)) == 9
|
||||
|
||||
def test_iter_with_limit(self):
|
||||
self.chain.limit = 4
|
||||
objects = list(self.chain)
|
||||
assert self.users == objects[0:2]
|
||||
assert self.articles[0:2] == objects[2:]
|
||||
|
||||
def test_iter_with_offset(self):
|
||||
self.chain.offset = 3
|
||||
objects = list(self.chain)
|
||||
assert self.articles[1:] + self.posts == objects
|
||||
|
||||
def test_iter_with_limit_and_offset(self):
|
||||
self.chain.offset = 3
|
||||
self.chain.limit = 4
|
||||
objects = list(self.chain)
|
||||
assert self.articles[1:] + self.posts[0:1] == objects
|
||||
|
||||
def test_iter_with_offset_spanning_multiple_queries(self):
|
||||
self.chain.offset = 7
|
||||
objects = list(self.chain)
|
||||
assert self.posts[1:] == objects
|
||||
|
||||
def test_repr(self):
|
||||
assert repr(self.chain) == '<QueryChain at 0x%x>' % id(self.chain)
|
Reference in New Issue
Block a user