Add support for various polymorphic scenarios in sort_query
- Rename query_entities to get_query_entities - Add tests for get_query_entities
This commit is contained in:
@@ -4,10 +4,11 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.26.9 (2014-08-xx)
|
||||
0.26.9 (2014-08-06)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Fixed PasswordType with Oracle dialect
|
||||
- Added support for sort_query and attributes on mappers using with_polymorphic
|
||||
|
||||
|
||||
0.26.8 (2014-07-30)
|
||||
|
@@ -46,6 +46,12 @@ get_mapper
|
||||
.. autofunction:: get_mapper
|
||||
|
||||
|
||||
get_query_entities
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_query_entities
|
||||
|
||||
|
||||
get_primary_keys
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
@@ -58,12 +64,6 @@ get_tables
|
||||
.. autofunction:: get_tables
|
||||
|
||||
|
||||
query_entities
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: query_entities
|
||||
|
||||
|
||||
has_changes
|
||||
^^^^^^^^^^^
|
||||
|
||||
|
@@ -17,6 +17,7 @@ from .functions import (
|
||||
get_declarative_base,
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_query_entities,
|
||||
get_primary_keys,
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
@@ -99,6 +100,7 @@ __all__ = (
|
||||
get_declarative_base,
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_query_entities,
|
||||
get_primary_keys,
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
|
@@ -26,13 +26,13 @@ from .orm import (
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_primary_keys,
|
||||
get_query_entities,
|
||||
get_tables,
|
||||
getdotattr,
|
||||
has_any_changes,
|
||||
has_changes,
|
||||
identity,
|
||||
naturally_equivalent,
|
||||
query_entities,
|
||||
table_name,
|
||||
)
|
||||
|
||||
@@ -49,6 +49,7 @@ __all__ = (
|
||||
'get_declarative_base',
|
||||
'get_hybrid_properties',
|
||||
'get_mapper',
|
||||
'get_query_entities',
|
||||
'get_primary_keys',
|
||||
'get_referencing_foreign_keys',
|
||||
'get_tables',
|
||||
|
@@ -4,15 +4,14 @@ except ImportError:
|
||||
from ordereddict import OrderedDict
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from itertools import chain
|
||||
from operator import attrgetter
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import mapperlib
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.exc import UnmappedInstanceError
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import _ColumnEntity
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
@@ -81,6 +80,8 @@ def get_mapper(mixed):
|
||||
return sa.inspect(mixed).mapper
|
||||
if isinstance(mixed, sa.sql.selectable.Alias):
|
||||
mixed = mixed.element
|
||||
if isinstance(mixed, AliasedInsp):
|
||||
return mixed.mapper
|
||||
if isinstance(mixed, sa.Table):
|
||||
mappers = [
|
||||
mapper for mapper in mapperlib._mapper_registry
|
||||
@@ -344,31 +345,31 @@ def query_labels(query):
|
||||
db.func.count(Article.id).label('articles')
|
||||
)
|
||||
|
||||
query_labels(query) # ('articles', )
|
||||
query_labels(query) # ['articles']
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
for entity in query._entities:
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name:
|
||||
yield entity._label_name
|
||||
return [
|
||||
entity._label_name for entity in query._entities
|
||||
if isinstance(entity, _ColumnEntity) and entity._label_name
|
||||
]
|
||||
|
||||
|
||||
def query_entities(query):
|
||||
def get_query_entities(query):
|
||||
"""
|
||||
Return a generator that iterates through all entities for given SQLAlchemy
|
||||
query object.
|
||||
Return a list of all entities present in given SQLAlchemy query object.
|
||||
|
||||
Examples::
|
||||
|
||||
|
||||
query = session.query(Category)
|
||||
|
||||
query_entities(query) # <Category>
|
||||
query_entities(query) # [<Category>]
|
||||
|
||||
|
||||
query = session.query(Category.id)
|
||||
|
||||
query_entities(query) # <Category>
|
||||
query_entities(query) # [<Category>]
|
||||
|
||||
|
||||
This function also supports queries with joins.
|
||||
@@ -378,42 +379,66 @@ def query_entities(query):
|
||||
|
||||
query = session.query(Category).join(Article)
|
||||
|
||||
query_entities(query) # (<Category>, <Article>)
|
||||
query_entities(query) # [<Category>, <Article>]
|
||||
|
||||
.. versionchanged: 0.26.7
|
||||
This function now returns a list instead of generator
|
||||
|
||||
:param query: SQLAlchemy Query object
|
||||
"""
|
||||
for entity in query._entities:
|
||||
if entity.entity_zero:
|
||||
yield entity.entity_zero.class_
|
||||
|
||||
for entity in query._join_entities:
|
||||
if isinstance(entity, Mapper):
|
||||
yield entity.class_
|
||||
else:
|
||||
yield entity
|
||||
def get_expr(mixed):
|
||||
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
|
||||
return mixed
|
||||
expr = mixed.expr
|
||||
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
|
||||
expr = expr.parent
|
||||
elif isinstance(expr, sa.Column):
|
||||
expr = expr.table
|
||||
elif isinstance(expr, sa.sql.expression.Label):
|
||||
if mixed.entity_zero:
|
||||
return mixed.entity_zero
|
||||
else:
|
||||
return expr
|
||||
return expr
|
||||
return [
|
||||
get_expr(entity) for entity in
|
||||
chain(query._entities, query._join_entities)
|
||||
]
|
||||
|
||||
|
||||
def get_query_entity_by_alias(query, alias):
|
||||
entities = query_entities(query)
|
||||
entities = get_query_entities(query)
|
||||
if not alias:
|
||||
return list(entities)[0]
|
||||
return entities[0]
|
||||
|
||||
for entity in entities:
|
||||
if isinstance(entity, AliasedInsp):
|
||||
name = entity.name
|
||||
else:
|
||||
name = entity.__table__.name
|
||||
name = get_mapper(entity).tables[0].name
|
||||
|
||||
if name == alias:
|
||||
return entity
|
||||
|
||||
|
||||
def get_attrs(expr):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return expr.mapper.attrs
|
||||
def get_polymorphic_mappers(mixed):
|
||||
if isinstance(mixed, AliasedInsp):
|
||||
return mixed.with_polymorphic_mappers
|
||||
else:
|
||||
return inspect(expr).attrs
|
||||
return mixed.polymorphic_map.values()
|
||||
|
||||
|
||||
def get_attrs(expr):
|
||||
insp = sa.inspect(expr)
|
||||
mapper = get_mapper(expr)
|
||||
polymorphic_mappers = get_polymorphic_mappers(insp)
|
||||
|
||||
if polymorphic_mappers:
|
||||
attrs = {}
|
||||
for submapper in polymorphic_mappers:
|
||||
attrs.update(submapper.attrs)
|
||||
return attrs
|
||||
return mapper.attrs
|
||||
|
||||
|
||||
def get_hybrid_properties(model):
|
||||
@@ -467,11 +492,11 @@ def get_hybrid_properties(model):
|
||||
)
|
||||
|
||||
|
||||
def get_expr_attr(expr, attr_name):
|
||||
def get_expr_attr(expr, prop):
|
||||
if isinstance(expr, AliasedInsp):
|
||||
return getattr(expr.selectable.c, attr_name)
|
||||
return getattr(expr.selectable.c, prop.key)
|
||||
else:
|
||||
return getattr(expr, attr_name)
|
||||
return getattr(prop.parent.class_, prop.key)
|
||||
|
||||
|
||||
def get_declarative_base(model):
|
||||
|
@@ -7,7 +7,7 @@ from .orm import (
|
||||
get_expr_attr,
|
||||
get_hybrid_properties,
|
||||
get_query_entity_by_alias,
|
||||
query_entities,
|
||||
get_query_entities,
|
||||
query_labels,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ class QuerySorterException(Exception):
|
||||
|
||||
class QuerySorter(object):
|
||||
def __init__(self, silent=True, separator='-'):
|
||||
self.entities = []
|
||||
self.labels = []
|
||||
self.separator = separator
|
||||
self.silent = silent
|
||||
@@ -44,20 +43,21 @@ class QuerySorter(object):
|
||||
properties = get_attrs(entity)
|
||||
if attr in properties:
|
||||
property_ = properties[attr]
|
||||
|
||||
if isinstance(property_, ColumnProperty):
|
||||
if isinstance(property_.columns[0], Label):
|
||||
return getattr(entity, property_.key)
|
||||
else:
|
||||
return get_expr_attr(entity, property_.key)
|
||||
return get_expr_attr(entity, property_)
|
||||
elif isinstance(property_, SynonymProperty):
|
||||
return get_expr_attr(entity, property_.key)
|
||||
return get_expr_attr(entity, property_)
|
||||
return
|
||||
|
||||
mapper = sa.inspect(entity)
|
||||
entity = mapper.entity
|
||||
|
||||
if isinstance(mapper, AliasedInsp):
|
||||
mapper = mapper.mapper
|
||||
entity = mapper.entity
|
||||
|
||||
for key in get_hybrid_properties(mapper).keys():
|
||||
if attr == key:
|
||||
@@ -80,7 +80,6 @@ class QuerySorter(object):
|
||||
def __call__(self, query, *args):
|
||||
self.query = query
|
||||
self.labels = query_labels(query)
|
||||
self.entities = query_entities(query)
|
||||
|
||||
for sort in args:
|
||||
if not sort:
|
||||
|
101
tests/functions/test_get_query_entities.py
Normal file
101
tests/functions/test_get_query_entities.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import get_query_entities
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGetQueryEntities(TestCase):
|
||||
def create_models(self):
|
||||
class TextItem(self.Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
category = sa.Column(sa.Unicode(255))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
|
||||
class BlogPost(TextItem):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_mapper(self):
|
||||
query = self.session.query(sa.inspect(self.TextItem))
|
||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
||||
|
||||
def test_entity(self):
|
||||
query = self.session.query(self.TextItem)
|
||||
assert get_query_entities(query) == [self.TextItem]
|
||||
|
||||
def test_instrumented_attribute(self):
|
||||
query = self.session.query(self.TextItem.id)
|
||||
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
|
||||
|
||||
def test_column(self):
|
||||
query = self.session.query(self.TextItem.__table__.c.id)
|
||||
assert get_query_entities(query) == [self.TextItem.__table__]
|
||||
|
||||
def test_aliased_selectable(self):
|
||||
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||
query = self.session.query(selectable)
|
||||
assert get_query_entities(query) == [selectable]
|
||||
|
||||
def test_joined_entity(self):
|
||||
query = self.session.query(self.TextItem).join(
|
||||
self.BlogPost, self.BlogPost.id == self.TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [
|
||||
self.TextItem, sa.inspect(self.BlogPost)
|
||||
]
|
||||
|
||||
def test_joined_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.BlogPost)
|
||||
|
||||
query = self.session.query(self.TextItem).join(
|
||||
alias, alias.id == self.TextItem.id
|
||||
)
|
||||
assert get_query_entities(query) == [
|
||||
self.TextItem, sa.inspect(alias)
|
||||
]
|
||||
|
||||
def test_column_entity_with_label(self):
|
||||
query = self.session.query(self.Article.id.label('id'))
|
||||
assert get_query_entities(query) == [sa.inspect(self.Article)]
|
||||
|
||||
def test_with_subquery(self):
|
||||
number_of_articles = (
|
||||
sa.select(
|
||||
[sa.func.count(self.Article.id)],
|
||||
)
|
||||
.select_from(
|
||||
self.Article.__table__
|
||||
)
|
||||
).label('number_of_articles')
|
||||
|
||||
query = self.session.query(self.Article, number_of_articles)
|
||||
assert get_query_entities(query) == [self.Article, number_of_articles]
|
||||
|
||||
def test_aliased_entity(self):
|
||||
alias = sa.orm.aliased(self.Article)
|
||||
query = self.session.query(alias)
|
||||
assert get_query_entities(query) == [alias]
|
@@ -1,74 +0,0 @@
|
||||
import sqlalchemy as sa
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils.functions import query_entities
|
||||
|
||||
|
||||
class TestQueryEntities(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Category(self.Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class Article(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
author_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(User.id), index=True
|
||||
)
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
category = sa.orm.relationship(
|
||||
Category,
|
||||
primaryjoin=category_id == Category.id,
|
||||
backref=sa.orm.backref(
|
||||
'articles',
|
||||
)
|
||||
)
|
||||
|
||||
self.User = User
|
||||
self.Category = Category
|
||||
self.Article = Article
|
||||
|
||||
def test_simple_query(self):
|
||||
query = self.session.query(self.User)
|
||||
assert list(query_entities(query)) == [self.User]
|
||||
|
||||
def test_column_entity(self):
|
||||
query = self.session.query(self.User.id)
|
||||
assert list(query_entities(query)) == [self.User]
|
||||
|
||||
def test_column_entity_with_label(self):
|
||||
query = self.session.query(self.User.id.label('id'))
|
||||
assert list(query_entities(query)) == [self.User]
|
||||
|
||||
def test_with_subquery(self):
|
||||
number_of_sales = (
|
||||
sa.select(
|
||||
[sa.func.count(self.Article.id)],
|
||||
)
|
||||
.select_from(
|
||||
self.Article.__table__
|
||||
)
|
||||
).label('number_of_articles')
|
||||
|
||||
query = self.session.query(self.User, number_of_sales)
|
||||
assert list(query_entities(query)) == [self.User]
|
||||
|
||||
def test_mapper(self):
|
||||
query = self.session.query(self.User.__mapper__)
|
||||
assert list(query_entities(query)) == [self.User]
|
||||
|
||||
def test_joins(self):
|
||||
query = self.session.query(self.User.__mapper__).join(self.Article)
|
||||
assert list(query_entities(query)) == [self.User, self.Article]
|
||||
|
||||
def test_aliased_entity(self):
|
||||
query = self.session.query(sa.orm.aliased(self.User))
|
||||
assert list(query_entities(query)) == [self.User]
|
@@ -210,6 +210,7 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
category = sa.Column(sa.Unicode(255))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
@@ -233,3 +234,69 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
|
||||
'item_count'
|
||||
)
|
||||
assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query)
|
||||
|
||||
def test_child_class_attribute(self):
|
||||
query = sort_query(
|
||||
self.session.query(self.TextItem),
|
||||
'category'
|
||||
)
|
||||
assert_contains('ORDER BY article.category ASC', query)
|
||||
|
||||
|
||||
class TestSortQueryWithCustomPolymorphic(TestCase):
|
||||
"""
|
||||
Currently this doesn't work with SQLite
|
||||
"""
|
||||
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
|
||||
|
||||
def create_models(self):
|
||||
class TextItem(self.Base):
|
||||
__tablename__ = 'text_item'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
type = sa.Column(sa.Unicode(255))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
}
|
||||
|
||||
class Article(TextItem):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
category = sa.Column(sa.Unicode(255))
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'article'
|
||||
}
|
||||
|
||||
class BlogPost(TextItem):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': u'blog_post'
|
||||
}
|
||||
|
||||
self.TextItem = TextItem
|
||||
self.Article = Article
|
||||
self.BlogPost = BlogPost
|
||||
|
||||
def test_with_unknown_column(self):
|
||||
query = sort_query(
|
||||
self.session.query(
|
||||
sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
|
||||
),
|
||||
'category'
|
||||
)
|
||||
assert 'ORDER BY' not in str(query)
|
||||
|
||||
def test_with_existing_column(self):
|
||||
query = sort_query(
|
||||
self.session.query(
|
||||
sa.orm.with_polymorphic(self.TextItem, [self.Article])
|
||||
),
|
||||
'category'
|
||||
)
|
||||
assert 'ORDER BY' in str(query)
|
||||
|
Reference in New Issue
Block a user