Add support for various polymorphic scenarios in sort_query

- Rename query_entities to get_query_entities
- Add tests for get_query_entities
This commit is contained in:
Konsta Vesterinen
2014-08-06 15:17:54 +03:00
parent 1723b1eb69
commit a975591c2e
9 changed files with 241 additions and 119 deletions

View File

@@ -4,10 +4,11 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.26.9 (2014-08-xx)
0.26.9 (2014-08-06)
^^^^^^^^^^^^^^^^^^^
- Fixed PasswordType with Oracle dialect
- Added support for sort_query and attributes on mappers using with_polymorphic
0.26.8 (2014-07-30)

View File

@@ -46,6 +46,12 @@ get_mapper
.. autofunction:: get_mapper
get_query_entities
^^^^^^^^^^^^^^^^^^
.. autofunction:: get_query_entities
get_primary_keys
^^^^^^^^^^^^^^^^
@@ -58,12 +64,6 @@ get_tables
.. autofunction:: get_tables
query_entities
^^^^^^^^^^^^^^
.. autofunction:: query_entities
has_changes
^^^^^^^^^^^

View File

@@ -17,6 +17,7 @@ from .functions import (
get_declarative_base,
get_hybrid_properties,
get_mapper,
get_query_entities,
get_primary_keys,
get_referencing_foreign_keys,
get_tables,
@@ -99,6 +100,7 @@ __all__ = (
get_declarative_base,
get_hybrid_properties,
get_mapper,
get_query_entities,
get_primary_keys,
get_referencing_foreign_keys,
get_tables,

View File

@@ -26,13 +26,13 @@ from .orm import (
get_hybrid_properties,
get_mapper,
get_primary_keys,
get_query_entities,
get_tables,
getdotattr,
has_any_changes,
has_changes,
identity,
naturally_equivalent,
query_entities,
table_name,
)
@@ -49,6 +49,7 @@ __all__ = (
'get_declarative_base',
'get_hybrid_properties',
'get_mapper',
'get_query_entities',
'get_primary_keys',
'get_referencing_foreign_keys',
'get_tables',

View File

@@ -4,15 +4,14 @@ except ImportError:
from ordereddict import OrderedDict
from functools import partial
from inspect import isclass
from itertools import chain
from operator import attrgetter
import six
import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import mapperlib
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
@@ -81,6 +80,8 @@ def get_mapper(mixed):
return sa.inspect(mixed).mapper
if isinstance(mixed, sa.sql.selectable.Alias):
mixed = mixed.element
if isinstance(mixed, AliasedInsp):
return mixed.mapper
if isinstance(mixed, sa.Table):
mappers = [
mapper for mapper in mapperlib._mapper_registry
@@ -344,31 +345,31 @@ def query_labels(query):
db.func.count(Article.id).label('articles')
)
query_labels(query) # ('articles', )
query_labels(query) # ['articles']
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
if isinstance(entity, _ColumnEntity) and entity._label_name:
yield entity._label_name
return [
entity._label_name for entity in query._entities
if isinstance(entity, _ColumnEntity) and entity._label_name
]
def query_entities(query):
def get_query_entities(query):
"""
Return a generator that iterates through all entities for given SQLAlchemy
query object.
Return a list of all entities present in given SQLAlchemy query object.
Examples::
query = session.query(Category)
query_entities(query) # <Category>
query_entities(query) # [<Category>]
query = session.query(Category.id)
query_entities(query) # <Category>
query_entities(query) # [<Category>]
This function also supports queries with joins.
@@ -378,42 +379,66 @@ def query_entities(query):
query = session.query(Category).join(Article)
query_entities(query) # (<Category>, <Article>)
query_entities(query) # [<Category>, <Article>]
.. versionchanged: 0.26.7
This function now returns a list instead of generator
:param query: SQLAlchemy Query object
"""
for entity in query._entities:
if entity.entity_zero:
yield entity.entity_zero.class_
for entity in query._join_entities:
if isinstance(entity, Mapper):
yield entity.class_
else:
yield entity
def get_expr(mixed):
if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
return mixed
expr = mixed.expr
if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
expr = expr.parent
elif isinstance(expr, sa.Column):
expr = expr.table
elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
return expr
return [
get_expr(entity) for entity in
chain(query._entities, query._join_entities)
]
def get_query_entity_by_alias(query, alias):
entities = query_entities(query)
entities = get_query_entities(query)
if not alias:
return list(entities)[0]
return entities[0]
for entity in entities:
if isinstance(entity, AliasedInsp):
name = entity.name
else:
name = entity.__table__.name
name = get_mapper(entity).tables[0].name
if name == alias:
return entity
def get_attrs(expr):
if isinstance(expr, AliasedInsp):
return expr.mapper.attrs
def get_polymorphic_mappers(mixed):
if isinstance(mixed, AliasedInsp):
return mixed.with_polymorphic_mappers
else:
return inspect(expr).attrs
return mixed.polymorphic_map.values()
def get_attrs(expr):
insp = sa.inspect(expr)
mapper = get_mapper(expr)
polymorphic_mappers = get_polymorphic_mappers(insp)
if polymorphic_mappers:
attrs = {}
for submapper in polymorphic_mappers:
attrs.update(submapper.attrs)
return attrs
return mapper.attrs
def get_hybrid_properties(model):
@@ -467,11 +492,11 @@ def get_hybrid_properties(model):
)
def get_expr_attr(expr, attr_name):
def get_expr_attr(expr, prop):
if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name)
return getattr(expr.selectable.c, prop.key)
else:
return getattr(expr, attr_name)
return getattr(prop.parent.class_, prop.key)
def get_declarative_base(model):

View File

@@ -7,7 +7,7 @@ from .orm import (
get_expr_attr,
get_hybrid_properties,
get_query_entity_by_alias,
query_entities,
get_query_entities,
query_labels,
)
@@ -18,7 +18,6 @@ class QuerySorterException(Exception):
class QuerySorter(object):
def __init__(self, silent=True, separator='-'):
self.entities = []
self.labels = []
self.separator = separator
self.silent = silent
@@ -44,20 +43,21 @@ class QuerySorter(object):
properties = get_attrs(entity)
if attr in properties:
property_ = properties[attr]
if isinstance(property_, ColumnProperty):
if isinstance(property_.columns[0], Label):
return getattr(entity, property_.key)
else:
return get_expr_attr(entity, property_.key)
return get_expr_attr(entity, property_)
elif isinstance(property_, SynonymProperty):
return get_expr_attr(entity, property_.key)
return get_expr_attr(entity, property_)
return
mapper = sa.inspect(entity)
entity = mapper.entity
if isinstance(mapper, AliasedInsp):
mapper = mapper.mapper
entity = mapper.entity
for key in get_hybrid_properties(mapper).keys():
if attr == key:
@@ -80,7 +80,6 @@ class QuerySorter(object):
def __call__(self, query, *args):
self.query = query
self.labels = query_labels(query)
self.entities = query_entities(query)
for sort in args:
if not sort:

View File

@@ -0,0 +1,101 @@
import sqlalchemy as sa
from sqlalchemy_utils import get_query_entities
from tests import TestCase
class TestGetQueryEntities(TestCase):
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
}
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem))
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
def test_entity(self):
query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id)
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable)
assert get_query_entities(query) == [selectable]
def test_joined_entity(self):
query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id
)
assert get_query_entities(query) == [
self.TextItem, sa.inspect(self.BlogPost)
]
def test_joined_aliased_entity(self):
alias = sa.orm.aliased(self.BlogPost)
query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id
)
assert get_query_entities(query) == [
self.TextItem, sa.inspect(alias)
]
def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id'))
assert get_query_entities(query) == [sa.inspect(self.Article)]
def test_with_subquery(self):
number_of_articles = (
sa.select(
[sa.func.count(self.Article.id)],
)
.select_from(
self.Article.__table__
)
).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles)
assert get_query_entities(query) == [self.Article, number_of_articles]
def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article)
query = self.session.query(alias)
assert get_query_entities(query) == [alias]

View File

@@ -1,74 +0,0 @@
import sqlalchemy as sa
from tests import TestCase
from sqlalchemy_utils.functions import query_entities
class TestQueryEntities(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
author_id = sa.Column(
sa.Integer, sa.ForeignKey(User.id), index=True
)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
)
)
self.User = User
self.Category = Category
self.Article = Article
def test_simple_query(self):
query = self.session.query(self.User)
assert list(query_entities(query)) == [self.User]
def test_column_entity(self):
query = self.session.query(self.User.id)
assert list(query_entities(query)) == [self.User]
def test_column_entity_with_label(self):
query = self.session.query(self.User.id.label('id'))
assert list(query_entities(query)) == [self.User]
def test_with_subquery(self):
number_of_sales = (
sa.select(
[sa.func.count(self.Article.id)],
)
.select_from(
self.Article.__table__
)
).label('number_of_articles')
query = self.session.query(self.User, number_of_sales)
assert list(query_entities(query)) == [self.User]
def test_mapper(self):
query = self.session.query(self.User.__mapper__)
assert list(query_entities(query)) == [self.User]
def test_joins(self):
query = self.session.query(self.User.__mapper__).join(self.Article)
assert list(query_entities(query)) == [self.User, self.Article]
def test_aliased_entity(self):
query = self.session.query(sa.orm.aliased(self.User))
assert list(query_entities(query)) == [self.User]

View File

@@ -210,6 +210,7 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
@@ -233,3 +234,69 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
'item_count'
)
assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query)
def test_child_class_attribute(self):
query = sort_query(
self.session.query(self.TextItem),
'category'
)
assert_contains('ORDER BY article.category ASC', query)
class TestSortQueryWithCustomPolymorphic(TestCase):
"""
Currently this doesn't work with SQLite
"""
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
}
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_with_unknown_column(self):
query = sort_query(
self.session.query(
sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
),
'category'
)
assert 'ORDER BY' not in str(query)
def test_with_existing_column(self):
query = sort_query(
self.session.query(
sa.orm.with_polymorphic(self.TextItem, [self.Article])
),
'category'
)
assert 'ORDER BY' in str(query)