Add support for various polymorphic scenarios in sort_query

- Rename query_entities to get_query_entities
- Add tests for get_query_entities
This commit is contained in:
Konsta Vesterinen
2014-08-06 15:17:54 +03:00
parent 1723b1eb69
commit a975591c2e
9 changed files with 241 additions and 119 deletions

View File

@@ -4,10 +4,11 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.26.9 (2014-08-xx) 0.26.9 (2014-08-06)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
- Fixed PasswordType with Oracle dialect - Fixed PasswordType with Oracle dialect
- Added support for sort_query and attributes on mappers using with_polymorphic
0.26.8 (2014-07-30) 0.26.8 (2014-07-30)

View File

@@ -46,6 +46,12 @@ get_mapper
.. autofunction:: get_mapper .. autofunction:: get_mapper
get_query_entities
^^^^^^^^^^^^^^^^^^
.. autofunction:: get_query_entities
get_primary_keys get_primary_keys
^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^
@@ -58,12 +64,6 @@ get_tables
.. autofunction:: get_tables .. autofunction:: get_tables
query_entities
^^^^^^^^^^^^^^
.. autofunction:: query_entities
has_changes has_changes
^^^^^^^^^^^ ^^^^^^^^^^^

View File

@@ -17,6 +17,7 @@ from .functions import (
get_declarative_base, get_declarative_base,
get_hybrid_properties, get_hybrid_properties,
get_mapper, get_mapper,
get_query_entities,
get_primary_keys, get_primary_keys,
get_referencing_foreign_keys, get_referencing_foreign_keys,
get_tables, get_tables,
@@ -99,6 +100,7 @@ __all__ = (
get_declarative_base, get_declarative_base,
get_hybrid_properties, get_hybrid_properties,
get_mapper, get_mapper,
get_query_entities,
get_primary_keys, get_primary_keys,
get_referencing_foreign_keys, get_referencing_foreign_keys,
get_tables, get_tables,

View File

@@ -26,13 +26,13 @@ from .orm import (
get_hybrid_properties, get_hybrid_properties,
get_mapper, get_mapper,
get_primary_keys, get_primary_keys,
get_query_entities,
get_tables, get_tables,
getdotattr, getdotattr,
has_any_changes, has_any_changes,
has_changes, has_changes,
identity, identity,
naturally_equivalent, naturally_equivalent,
query_entities,
table_name, table_name,
) )
@@ -49,6 +49,7 @@ __all__ = (
'get_declarative_base', 'get_declarative_base',
'get_hybrid_properties', 'get_hybrid_properties',
'get_mapper', 'get_mapper',
'get_query_entities',
'get_primary_keys', 'get_primary_keys',
'get_referencing_foreign_keys', 'get_referencing_foreign_keys',
'get_tables', 'get_tables',

View File

@@ -4,15 +4,14 @@ except ImportError:
from ordereddict import OrderedDict from ordereddict import OrderedDict
from functools import partial from functools import partial
from inspect import isclass from inspect import isclass
from itertools import chain
from operator import attrgetter from operator import attrgetter
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import inspect
from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import mapperlib from sqlalchemy.orm import mapperlib
from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError from sqlalchemy.orm.exc import UnmappedInstanceError
from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp from sqlalchemy.orm.util import AliasedInsp
@@ -81,6 +80,8 @@ def get_mapper(mixed):
return sa.inspect(mixed).mapper return sa.inspect(mixed).mapper
if isinstance(mixed, sa.sql.selectable.Alias): if isinstance(mixed, sa.sql.selectable.Alias):
mixed = mixed.element mixed = mixed.element
if isinstance(mixed, AliasedInsp):
return mixed.mapper
if isinstance(mixed, sa.Table): if isinstance(mixed, sa.Table):
mappers = [ mappers = [
mapper for mapper in mapperlib._mapper_registry mapper for mapper in mapperlib._mapper_registry
@@ -344,31 +345,31 @@ def query_labels(query):
db.func.count(Article.id).label('articles') db.func.count(Article.id).label('articles')
) )
query_labels(query) # ('articles', ) query_labels(query) # ['articles']
:param query: SQLAlchemy Query object :param query: SQLAlchemy Query object
""" """
for entity in query._entities: return [
if isinstance(entity, _ColumnEntity) and entity._label_name: entity._label_name for entity in query._entities
yield entity._label_name if isinstance(entity, _ColumnEntity) and entity._label_name
]
def query_entities(query): def get_query_entities(query):
""" """
Return a generator that iterates through all entities for given SQLAlchemy Return a list of all entities present in given SQLAlchemy query object.
query object.
Examples:: Examples::
query = session.query(Category) query = session.query(Category)
query_entities(query) # <Category> query_entities(query) # [<Category>]
query = session.query(Category.id) query = session.query(Category.id)
query_entities(query) # <Category> query_entities(query) # [<Category>]
This function also supports queries with joins. This function also supports queries with joins.
@@ -378,42 +379,66 @@ def query_entities(query):
query = session.query(Category).join(Article) query = session.query(Category).join(Article)
query_entities(query) # (<Category>, <Article>) query_entities(query) # [<Category>, <Article>]
.. versionchanged: 0.26.7
This function now returns a list instead of generator
:param query: SQLAlchemy Query object :param query: SQLAlchemy Query object
""" """
for entity in query._entities: def get_expr(mixed):
if entity.entity_zero: if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)):
yield entity.entity_zero.class_ return mixed
expr = mixed.expr
for entity in query._join_entities: if isinstance(expr, sa.orm.attributes.InstrumentedAttribute):
if isinstance(entity, Mapper): expr = expr.parent
yield entity.class_ elif isinstance(expr, sa.Column):
else: expr = expr.table
yield entity elif isinstance(expr, sa.sql.expression.Label):
if mixed.entity_zero:
return mixed.entity_zero
else:
return expr
return expr
return [
get_expr(entity) for entity in
chain(query._entities, query._join_entities)
]
def get_query_entity_by_alias(query, alias): def get_query_entity_by_alias(query, alias):
entities = query_entities(query) entities = get_query_entities(query)
if not alias: if not alias:
return list(entities)[0] return entities[0]
for entity in entities: for entity in entities:
if isinstance(entity, AliasedInsp): if isinstance(entity, AliasedInsp):
name = entity.name name = entity.name
else: else:
name = entity.__table__.name name = get_mapper(entity).tables[0].name
if name == alias: if name == alias:
return entity return entity
def get_attrs(expr): def get_polymorphic_mappers(mixed):
if isinstance(expr, AliasedInsp): if isinstance(mixed, AliasedInsp):
return expr.mapper.attrs return mixed.with_polymorphic_mappers
else: else:
return inspect(expr).attrs return mixed.polymorphic_map.values()
def get_attrs(expr):
insp = sa.inspect(expr)
mapper = get_mapper(expr)
polymorphic_mappers = get_polymorphic_mappers(insp)
if polymorphic_mappers:
attrs = {}
for submapper in polymorphic_mappers:
attrs.update(submapper.attrs)
return attrs
return mapper.attrs
def get_hybrid_properties(model): def get_hybrid_properties(model):
@@ -467,11 +492,11 @@ def get_hybrid_properties(model):
) )
def get_expr_attr(expr, attr_name): def get_expr_attr(expr, prop):
if isinstance(expr, AliasedInsp): if isinstance(expr, AliasedInsp):
return getattr(expr.selectable.c, attr_name) return getattr(expr.selectable.c, prop.key)
else: else:
return getattr(expr, attr_name) return getattr(prop.parent.class_, prop.key)
def get_declarative_base(model): def get_declarative_base(model):

View File

@@ -7,7 +7,7 @@ from .orm import (
get_expr_attr, get_expr_attr,
get_hybrid_properties, get_hybrid_properties,
get_query_entity_by_alias, get_query_entity_by_alias,
query_entities, get_query_entities,
query_labels, query_labels,
) )
@@ -18,7 +18,6 @@ class QuerySorterException(Exception):
class QuerySorter(object): class QuerySorter(object):
def __init__(self, silent=True, separator='-'): def __init__(self, silent=True, separator='-'):
self.entities = []
self.labels = [] self.labels = []
self.separator = separator self.separator = separator
self.silent = silent self.silent = silent
@@ -44,20 +43,21 @@ class QuerySorter(object):
properties = get_attrs(entity) properties = get_attrs(entity)
if attr in properties: if attr in properties:
property_ = properties[attr] property_ = properties[attr]
if isinstance(property_, ColumnProperty): if isinstance(property_, ColumnProperty):
if isinstance(property_.columns[0], Label): if isinstance(property_.columns[0], Label):
return getattr(entity, property_.key) return getattr(entity, property_.key)
else: else:
return get_expr_attr(entity, property_.key) return get_expr_attr(entity, property_)
elif isinstance(property_, SynonymProperty): elif isinstance(property_, SynonymProperty):
return get_expr_attr(entity, property_.key) return get_expr_attr(entity, property_)
return return
mapper = sa.inspect(entity) mapper = sa.inspect(entity)
entity = mapper.entity
if isinstance(mapper, AliasedInsp): if isinstance(mapper, AliasedInsp):
mapper = mapper.mapper mapper = mapper.mapper
entity = mapper.entity
for key in get_hybrid_properties(mapper).keys(): for key in get_hybrid_properties(mapper).keys():
if attr == key: if attr == key:
@@ -80,7 +80,6 @@ class QuerySorter(object):
def __call__(self, query, *args): def __call__(self, query, *args):
self.query = query self.query = query
self.labels = query_labels(query) self.labels = query_labels(query)
self.entities = query_entities(query)
for sort in args: for sort in args:
if not sort: if not sort:

View File

@@ -0,0 +1,101 @@
import sqlalchemy as sa
from sqlalchemy_utils import get_query_entities
from tests import TestCase
class TestGetQueryEntities(TestCase):
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
}
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_mapper(self):
query = self.session.query(sa.inspect(self.TextItem))
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
def test_entity(self):
query = self.session.query(self.TextItem)
assert get_query_entities(query) == [self.TextItem]
def test_instrumented_attribute(self):
query = self.session.query(self.TextItem.id)
assert get_query_entities(query) == [sa.inspect(self.TextItem)]
def test_column(self):
query = self.session.query(self.TextItem.__table__.c.id)
assert get_query_entities(query) == [self.TextItem.__table__]
def test_aliased_selectable(self):
selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
query = self.session.query(selectable)
assert get_query_entities(query) == [selectable]
def test_joined_entity(self):
query = self.session.query(self.TextItem).join(
self.BlogPost, self.BlogPost.id == self.TextItem.id
)
assert get_query_entities(query) == [
self.TextItem, sa.inspect(self.BlogPost)
]
def test_joined_aliased_entity(self):
alias = sa.orm.aliased(self.BlogPost)
query = self.session.query(self.TextItem).join(
alias, alias.id == self.TextItem.id
)
assert get_query_entities(query) == [
self.TextItem, sa.inspect(alias)
]
def test_column_entity_with_label(self):
query = self.session.query(self.Article.id.label('id'))
assert get_query_entities(query) == [sa.inspect(self.Article)]
def test_with_subquery(self):
number_of_articles = (
sa.select(
[sa.func.count(self.Article.id)],
)
.select_from(
self.Article.__table__
)
).label('number_of_articles')
query = self.session.query(self.Article, number_of_articles)
assert get_query_entities(query) == [self.Article, number_of_articles]
def test_aliased_entity(self):
alias = sa.orm.aliased(self.Article)
query = self.session.query(alias)
assert get_query_entities(query) == [alias]

View File

@@ -1,74 +0,0 @@
import sqlalchemy as sa
from tests import TestCase
from sqlalchemy_utils.functions import query_entities
class TestQueryEntities(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Category(self.Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class Article(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
author_id = sa.Column(
sa.Integer, sa.ForeignKey(User.id), index=True
)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
category = sa.orm.relationship(
Category,
primaryjoin=category_id == Category.id,
backref=sa.orm.backref(
'articles',
)
)
self.User = User
self.Category = Category
self.Article = Article
def test_simple_query(self):
query = self.session.query(self.User)
assert list(query_entities(query)) == [self.User]
def test_column_entity(self):
query = self.session.query(self.User.id)
assert list(query_entities(query)) == [self.User]
def test_column_entity_with_label(self):
query = self.session.query(self.User.id.label('id'))
assert list(query_entities(query)) == [self.User]
def test_with_subquery(self):
number_of_sales = (
sa.select(
[sa.func.count(self.Article.id)],
)
.select_from(
self.Article.__table__
)
).label('number_of_articles')
query = self.session.query(self.User, number_of_sales)
assert list(query_entities(query)) == [self.User]
def test_mapper(self):
query = self.session.query(self.User.__mapper__)
assert list(query_entities(query)) == [self.User]
def test_joins(self):
query = self.session.query(self.User.__mapper__).join(self.Article)
assert list(query_entities(query)) == [self.User, self.Article]
def test_aliased_entity(self):
query = self.session.query(sa.orm.aliased(self.User))
assert list(query_entities(query)) == [self.User]

View File

@@ -210,6 +210,7 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
id = sa.Column( id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
) )
category = sa.Column(sa.Unicode(255))
__mapper_args__ = { __mapper_args__ = {
'polymorphic_identity': u'article' 'polymorphic_identity': u'article'
} }
@@ -233,3 +234,69 @@ class TestSortQueryWithPolymorphicInheritance(TestCase):
'item_count' 'item_count'
) )
assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query) assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query)
def test_child_class_attribute(self):
query = sort_query(
self.session.query(self.TextItem),
'category'
)
assert_contains('ORDER BY article.category ASC', query)
class TestSortQueryWithCustomPolymorphic(TestCase):
"""
Currently this doesn't work with SQLite
"""
dns = 'postgres://postgres@localhost/sqlalchemy_utils_test'
def create_models(self):
class TextItem(self.Base):
__tablename__ = 'text_item'
id = sa.Column(sa.Integer, primary_key=True)
type = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_on': type,
}
class Article(TextItem):
__tablename__ = 'article'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
category = sa.Column(sa.Unicode(255))
__mapper_args__ = {
'polymorphic_identity': u'article'
}
class BlogPost(TextItem):
__tablename__ = 'blog_post'
id = sa.Column(
sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': u'blog_post'
}
self.TextItem = TextItem
self.Article = Article
self.BlogPost = BlogPost
def test_with_unknown_column(self):
query = sort_query(
self.session.query(
sa.orm.with_polymorphic(self.TextItem, [self.BlogPost])
),
'category'
)
assert 'ORDER BY' not in str(query)
def test_with_existing_column(self):
query = sort_query(
self.session.query(
sa.orm.with_polymorphic(self.TextItem, [self.Article])
),
'category'
)
assert 'ORDER BY' in str(query)