From 8df5f6dfa81375a982e0ef8fa97b7d796e9ad411 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 6 May 2014 13:25:33 +0300 Subject: [PATCH] Added QueryChain --- CHANGES.rst | 3 +- sqlalchemy_utils/__init__.py | 2 + sqlalchemy_utils/functions/orm.py | 3 +- sqlalchemy_utils/query_chain.py | 48 +++++++++++ .../test_get_referencing_foreign_keys.py | 4 +- tests/test_query_chain.py | 82 +++++++++++++++++++ 6 files changed, 138 insertions(+), 4 deletions(-) create mode 100644 sqlalchemy_utils/query_chain.py create mode 100644 tests/test_query_chain.py diff --git a/CHANGES.rst b/CHANGES.rst index 3dbae83..f05bad5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,11 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.25.5 (2014-xx-xx) +0.26.0 (2014-xx-xx) ^^^^^^^^^^^^^^^^^^^ - Added get_referencing_foreign_keys - Added get_tables +- Added QueryChain 0.25.4 (2014-04-22) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 87641fc..c310539 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -31,6 +31,7 @@ from .listeners import ( from .merge import merge, Merger from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict +from .query_chain import QueryChain from .types import ( ArrowType, Choice, @@ -118,6 +119,7 @@ __all__ = ( PhoneNumber, PhoneNumberType, ProxyDict, + QueryChain, ScalarListException, ScalarListType, TimezoneType, diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 543113f..b5ab5c3 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -3,6 +3,7 @@ try: except ImportError: from ordereddict import OrderedDict from functools import partial +from itertools import groupby from inspect import isclass from operator import attrgetter import sqlalchemy as sa @@ -115,7 +116,7 @@ def get_tables(mixed): get_tables(Article.__mapper__) - .. versionadded: 0.25.5 + .. versionadded: 0.26.0 :param mixed: SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping diff --git a/sqlalchemy_utils/query_chain.py b/sqlalchemy_utils/query_chain.py new file mode 100644 index 0000000..8d25448 --- /dev/null +++ b/sqlalchemy_utils/query_chain.py @@ -0,0 +1,48 @@ +from copy import copy + + +class QueryChain(object): + """ + :param queries: A sequence of SQLAlchemy Query objects + :param limit: Similar to normal query limit this parameter can be used for + limiting the number of results for the whole query chain. + :param offset: Similar to normal query offset this parameter can be used + for offsetting the query chain as a whole. + + :: + + chain = QueryChain([session.query(User), session.query(Article)]) + + for obj in chain[0:5]: + print obj + + .. versionadded: 0.26.0 + """ + def __init__(self, queries, limit=None, offset=None): + self.queries = queries + self.limit = limit + self.offset = offset + + def __iter__(self): + consumed = 0 + skipped = 0 + for query in self.queries: + query_copy = copy(query) + if self.limit: + query = query.limit(self.limit - consumed) + if self.offset: + query = query.offset(self.offset - skipped) + + obj_count = 0 + for obj in query: + consumed += 1 + obj_count += 1 + yield obj + + if not obj_count: + skipped += query_copy.count() + else: + skipped += obj_count + + def __repr__(self): + return '' % id(self) diff --git a/tests/functions/test_get_referencing_foreign_keys.py b/tests/functions/test_get_referencing_foreign_keys.py index 3ad4385..4e0660e 100644 --- a/tests/functions/test_get_referencing_foreign_keys.py +++ b/tests/functions/test_get_referencing_foreign_keys.py @@ -40,8 +40,8 @@ class TestGetReferencingFksWithInheritance(TestCase): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) type = sa.Column(sa.Unicode) - first_name = sa.Column(sa.Unicode(255), primary_key=True) - last_name = sa.Column(sa.Unicode(255), primary_key=True) + first_name = sa.Column(sa.Unicode(255)) + last_name = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_on': 'type' diff --git a/tests/test_query_chain.py b/tests/test_query_chain.py new file mode 100644 index 0000000..61541da --- /dev/null +++ b/tests/test_query_chain.py @@ -0,0 +1,82 @@ +import sqlalchemy as sa +from sqlalchemy_utils import QueryChain + +from tests import TestCase + + +class TestQueryChain(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + + self.User = User + self.Article = Article + self.BlogPost = BlogPost + + def setup_method(self, method): + TestCase.setup_method(self, method) + self.users = [ + self.User(), + self.User() + ] + self.articles = [ + self.Article(), + self.Article(), + self.Article(), + self.Article() + ] + self.posts = [ + self.BlogPost(), + self.BlogPost(), + self.BlogPost(), + ] + + self.session.add_all(self.users) + self.session.add_all(self.articles) + self.session.add_all(self.posts) + self.session.commit() + + self.chain = QueryChain( + [ + self.session.query(self.User).order_by('id'), + self.session.query(self.Article).order_by('id'), + self.session.query(self.BlogPost).order_by('id') + ] + ) + + def test_iter(self): + assert len(list(self.chain)) == 9 + + def test_iter_with_limit(self): + self.chain.limit = 4 + objects = list(self.chain) + assert self.users == objects[0:2] + assert self.articles[0:2] == objects[2:] + + def test_iter_with_offset(self): + self.chain.offset = 3 + objects = list(self.chain) + assert self.articles[1:] + self.posts == objects + + def test_iter_with_limit_and_offset(self): + self.chain.offset = 3 + self.chain.limit = 4 + objects = list(self.chain) + assert self.articles[1:] + self.posts[0:1] == objects + + def test_iter_with_offset_spanning_multiple_queries(self): + self.chain.offset = 7 + objects = list(self.chain) + assert self.posts[1:] == objects + + def test_repr(self): + assert repr(self.chain) == '' % id(self.chain)