Added QueryChain

This commit is contained in:
Konsta Vesterinen
2014-05-06 13:25:33 +03:00
parent abb3f83d94
commit 8df5f6dfa8
6 changed files with 138 additions and 4 deletions

View File

@@ -4,11 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.25.5 (2014-xx-xx)
0.26.0 (2014-xx-xx)
^^^^^^^^^^^^^^^^^^^
- Added get_referencing_foreign_keys
- Added get_tables
- Added QueryChain
0.25.4 (2014-04-22)

View File

@@ -31,6 +31,7 @@ from .listeners import (
from .merge import merge, Merger
from .generic import generic_relationship
from .proxy_dict import ProxyDict, proxy_dict
from .query_chain import QueryChain
from .types import (
ArrowType,
Choice,
@@ -118,6 +119,7 @@ __all__ = (
PhoneNumber,
PhoneNumberType,
ProxyDict,
QueryChain,
ScalarListException,
ScalarListType,
TimezoneType,

View File

@@ -3,6 +3,7 @@ try:
except ImportError:
from ordereddict import OrderedDict
from functools import partial
from itertools import groupby
from inspect import isclass
from operator import attrgetter
import sqlalchemy as sa
@@ -115,7 +116,7 @@ def get_tables(mixed):
get_tables(Article.__mapper__)
.. versionadded: 0.25.5
.. versionadded: 0.26.0
:param mixed:
SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping

View File

@@ -0,0 +1,48 @@
from copy import copy
class QueryChain(object):
"""
:param queries: A sequence of SQLAlchemy Query objects
:param limit: Similar to normal query limit this parameter can be used for
limiting the number of results for the whole query chain.
:param offset: Similar to normal query offset this parameter can be used
for offsetting the query chain as a whole.
::
chain = QueryChain([session.query(User), session.query(Article)])
for obj in chain[0:5]:
print obj
.. versionadded: 0.26.0
"""
def __init__(self, queries, limit=None, offset=None):
self.queries = queries
self.limit = limit
self.offset = offset
def __iter__(self):
consumed = 0
skipped = 0
for query in self.queries:
query_copy = copy(query)
if self.limit:
query = query.limit(self.limit - consumed)
if self.offset:
query = query.offset(self.offset - skipped)
obj_count = 0
for obj in query:
consumed += 1
obj_count += 1
yield obj
if not obj_count:
skipped += query_copy.count()
else:
skipped += obj_count
def __repr__(self):
return '<QueryChain at 0x%x>' % id(self)

View File

@@ -40,8 +40,8 @@ class TestGetReferencingFksWithInheritance(TestCase):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode)
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
first_name = sa.Column(sa.Unicode(255))
last_name = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': 'type'

82
tests/test_query_chain.py Normal file
View File

@@ -0,0 +1,82 @@
import sqlalchemy as sa
from sqlalchemy_utils import QueryChain
from tests import TestCase
class TestQueryChain(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
self.User = User
self.Article = Article
self.BlogPost = BlogPost
def setup_method(self, method):
TestCase.setup_method(self, method)
self.users = [
self.User(),
self.User()
]
self.articles = [
self.Article(),
self.Article(),
self.Article(),
self.Article()
]
self.posts = [
self.BlogPost(),
self.BlogPost(),
self.BlogPost(),
]
self.session.add_all(self.users)
self.session.add_all(self.articles)
self.session.add_all(self.posts)
self.session.commit()
self.chain = QueryChain(
[
self.session.query(self.User).order_by('id'),
self.session.query(self.Article).order_by('id'),
self.session.query(self.BlogPost).order_by('id')
]
)
def test_iter(self):
assert len(list(self.chain)) == 9
def test_iter_with_limit(self):
self.chain.limit = 4
objects = list(self.chain)
assert self.users == objects[0:2]
assert self.articles[0:2] == objects[2:]
def test_iter_with_offset(self):
self.chain.offset = 3
objects = list(self.chain)
assert self.articles[1:] + self.posts == objects
def test_iter_with_limit_and_offset(self):
self.chain.offset = 3
self.chain.limit = 4
objects = list(self.chain)
assert self.articles[1:] + self.posts[0:1] == objects
def test_iter_with_offset_spanning_multiple_queries(self):
self.chain.offset = 7
objects = list(self.chain)
assert self.posts[1:] == objects
def test_repr(self):
assert repr(self.chain) == '<QueryChain at 0x%x>' % id(self.chain)