diff --git a/CHANGES.rst b/CHANGES.rst index ee807fb..014e927 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,11 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.26.9 (2014-08-xx) +0.26.9 (2014-08-06) ^^^^^^^^^^^^^^^^^^^ - Fixed PasswordType with Oracle dialect +- Added support for sort_query and attributes on mappers using with_polymorphic 0.26.8 (2014-07-30) diff --git a/docs/orm_helpers.rst b/docs/orm_helpers.rst index 09cc984..41918d7 100644 --- a/docs/orm_helpers.rst +++ b/docs/orm_helpers.rst @@ -46,6 +46,12 @@ get_mapper .. autofunction:: get_mapper +get_query_entities +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: get_query_entities + + get_primary_keys ^^^^^^^^^^^^^^^^ @@ -58,12 +64,6 @@ get_tables .. autofunction:: get_tables -query_entities -^^^^^^^^^^^^^^ - -.. autofunction:: query_entities - - has_changes ^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 3fa2945..8d82dfc 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -17,6 +17,7 @@ from .functions import ( get_declarative_base, get_hybrid_properties, get_mapper, + get_query_entities, get_primary_keys, get_referencing_foreign_keys, get_tables, @@ -99,6 +100,7 @@ __all__ = ( get_declarative_base, get_hybrid_properties, get_mapper, + get_query_entities, get_primary_keys, get_referencing_foreign_keys, get_tables, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 8405c33..cffcb98 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -26,13 +26,13 @@ from .orm import ( get_hybrid_properties, get_mapper, get_primary_keys, + get_query_entities, get_tables, getdotattr, has_any_changes, has_changes, identity, naturally_equivalent, - query_entities, table_name, ) @@ -49,6 +49,7 @@ __all__ = ( 'get_declarative_base', 'get_hybrid_properties', 'get_mapper', + 'get_query_entities', 'get_primary_keys', 'get_referencing_foreign_keys', 'get_tables', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index a51edbd..df83b92 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -4,15 +4,14 @@ except ImportError: from ordereddict import OrderedDict from functools import partial from inspect import isclass +from itertools import chain from operator import attrgetter import six import sqlalchemy as sa -from sqlalchemy import inspect from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.orm import mapperlib from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.exc import UnmappedInstanceError -from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.session import object_session from sqlalchemy.orm.util import AliasedInsp @@ -81,6 +80,8 @@ def get_mapper(mixed): return sa.inspect(mixed).mapper if isinstance(mixed, sa.sql.selectable.Alias): mixed = mixed.element + if isinstance(mixed, AliasedInsp): + return mixed.mapper if isinstance(mixed, sa.Table): mappers = [ mapper for mapper in mapperlib._mapper_registry @@ -344,31 +345,31 @@ def query_labels(query): db.func.count(Article.id).label('articles') ) - query_labels(query) # ('articles', ) + query_labels(query) # ['articles'] :param query: SQLAlchemy Query object """ - for entity in query._entities: - if isinstance(entity, _ColumnEntity) and entity._label_name: - yield entity._label_name + return [ + entity._label_name for entity in query._entities + if isinstance(entity, _ColumnEntity) and entity._label_name + ] -def query_entities(query): +def get_query_entities(query): """ - Return a generator that iterates through all entities for given SQLAlchemy - query object. + Return a list of all entities present in given SQLAlchemy query object. Examples:: query = session.query(Category) - query_entities(query) # + query_entities(query) # [] query = session.query(Category.id) - query_entities(query) # + query_entities(query) # [] This function also supports queries with joins. @@ -378,42 +379,66 @@ def query_entities(query): query = session.query(Category).join(Article) - query_entities(query) # (,
) + query_entities(query) # [,
] + .. versionchanged: 0.26.7 + This function now returns a list instead of generator :param query: SQLAlchemy Query object """ - for entity in query._entities: - if entity.entity_zero: - yield entity.entity_zero.class_ - - for entity in query._join_entities: - if isinstance(entity, Mapper): - yield entity.class_ - else: - yield entity + def get_expr(mixed): + if isinstance(mixed, (sa.orm.Mapper, AliasedInsp)): + return mixed + expr = mixed.expr + if isinstance(expr, sa.orm.attributes.InstrumentedAttribute): + expr = expr.parent + elif isinstance(expr, sa.Column): + expr = expr.table + elif isinstance(expr, sa.sql.expression.Label): + if mixed.entity_zero: + return mixed.entity_zero + else: + return expr + return expr + return [ + get_expr(entity) for entity in + chain(query._entities, query._join_entities) + ] def get_query_entity_by_alias(query, alias): - entities = query_entities(query) + entities = get_query_entities(query) if not alias: - return list(entities)[0] + return entities[0] for entity in entities: if isinstance(entity, AliasedInsp): name = entity.name else: - name = entity.__table__.name + name = get_mapper(entity).tables[0].name if name == alias: return entity -def get_attrs(expr): - if isinstance(expr, AliasedInsp): - return expr.mapper.attrs +def get_polymorphic_mappers(mixed): + if isinstance(mixed, AliasedInsp): + return mixed.with_polymorphic_mappers else: - return inspect(expr).attrs + return mixed.polymorphic_map.values() + + +def get_attrs(expr): + insp = sa.inspect(expr) + mapper = get_mapper(expr) + polymorphic_mappers = get_polymorphic_mappers(insp) + + if polymorphic_mappers: + attrs = {} + for submapper in polymorphic_mappers: + attrs.update(submapper.attrs) + return attrs + return mapper.attrs def get_hybrid_properties(model): @@ -467,11 +492,11 @@ def get_hybrid_properties(model): ) -def get_expr_attr(expr, attr_name): +def get_expr_attr(expr, prop): if isinstance(expr, AliasedInsp): - return getattr(expr.selectable.c, attr_name) + return getattr(expr.selectable.c, prop.key) else: - return getattr(expr, attr_name) + return getattr(prop.parent.class_, prop.key) def get_declarative_base(model): diff --git a/sqlalchemy_utils/functions/sort_query.py b/sqlalchemy_utils/functions/sort_query.py index b999ad2..bd2656a 100644 --- a/sqlalchemy_utils/functions/sort_query.py +++ b/sqlalchemy_utils/functions/sort_query.py @@ -7,7 +7,7 @@ from .orm import ( get_expr_attr, get_hybrid_properties, get_query_entity_by_alias, - query_entities, + get_query_entities, query_labels, ) @@ -18,7 +18,6 @@ class QuerySorterException(Exception): class QuerySorter(object): def __init__(self, silent=True, separator='-'): - self.entities = [] self.labels = [] self.separator = separator self.silent = silent @@ -44,20 +43,21 @@ class QuerySorter(object): properties = get_attrs(entity) if attr in properties: property_ = properties[attr] + if isinstance(property_, ColumnProperty): if isinstance(property_.columns[0], Label): return getattr(entity, property_.key) else: - return get_expr_attr(entity, property_.key) + return get_expr_attr(entity, property_) elif isinstance(property_, SynonymProperty): - return get_expr_attr(entity, property_.key) + return get_expr_attr(entity, property_) return mapper = sa.inspect(entity) + entity = mapper.entity if isinstance(mapper, AliasedInsp): mapper = mapper.mapper - entity = mapper.entity for key in get_hybrid_properties(mapper).keys(): if attr == key: @@ -80,7 +80,6 @@ class QuerySorter(object): def __call__(self, query, *args): self.query = query self.labels = query_labels(query) - self.entities = query_entities(query) for sort in args: if not sort: diff --git a/tests/functions/test_get_query_entities.py b/tests/functions/test_get_query_entities.py new file mode 100644 index 0000000..45b9b4a --- /dev/null +++ b/tests/functions/test_get_query_entities.py @@ -0,0 +1,101 @@ +import sqlalchemy as sa +from sqlalchemy_utils import get_query_entities + +from tests import TestCase + + +class TestGetQueryEntities(TestCase): + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __tablename__ = 'blog_post' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_mapper(self): + query = self.session.query(sa.inspect(self.TextItem)) + assert get_query_entities(query) == [sa.inspect(self.TextItem)] + + def test_entity(self): + query = self.session.query(self.TextItem) + assert get_query_entities(query) == [self.TextItem] + + def test_instrumented_attribute(self): + query = self.session.query(self.TextItem.id) + assert get_query_entities(query) == [sa.inspect(self.TextItem)] + + def test_column(self): + query = self.session.query(self.TextItem.__table__.c.id) + assert get_query_entities(query) == [self.TextItem.__table__] + + def test_aliased_selectable(self): + selectable = sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) + query = self.session.query(selectable) + assert get_query_entities(query) == [selectable] + + def test_joined_entity(self): + query = self.session.query(self.TextItem).join( + self.BlogPost, self.BlogPost.id == self.TextItem.id + ) + assert get_query_entities(query) == [ + self.TextItem, sa.inspect(self.BlogPost) + ] + + def test_joined_aliased_entity(self): + alias = sa.orm.aliased(self.BlogPost) + + query = self.session.query(self.TextItem).join( + alias, alias.id == self.TextItem.id + ) + assert get_query_entities(query) == [ + self.TextItem, sa.inspect(alias) + ] + + def test_column_entity_with_label(self): + query = self.session.query(self.Article.id.label('id')) + assert get_query_entities(query) == [sa.inspect(self.Article)] + + def test_with_subquery(self): + number_of_articles = ( + sa.select( + [sa.func.count(self.Article.id)], + ) + .select_from( + self.Article.__table__ + ) + ).label('number_of_articles') + + query = self.session.query(self.Article, number_of_articles) + assert get_query_entities(query) == [self.Article, number_of_articles] + + def test_aliased_entity(self): + alias = sa.orm.aliased(self.Article) + query = self.session.query(alias) + assert get_query_entities(query) == [alias] diff --git a/tests/functions/test_query_entities.py b/tests/functions/test_query_entities.py deleted file mode 100644 index 377e5b4..0000000 --- a/tests/functions/test_query_entities.py +++ /dev/null @@ -1,74 +0,0 @@ -import sqlalchemy as sa -from tests import TestCase -from sqlalchemy_utils.functions import query_entities - - -class TestQueryEntities(TestCase): - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Category(self.Base): - __tablename__ = 'category' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - class Article(self.Base): - __tablename__ = 'article' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - author_id = sa.Column( - sa.Integer, sa.ForeignKey(User.id), index=True - ) - category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id)) - - category = sa.orm.relationship( - Category, - primaryjoin=category_id == Category.id, - backref=sa.orm.backref( - 'articles', - ) - ) - - self.User = User - self.Category = Category - self.Article = Article - - def test_simple_query(self): - query = self.session.query(self.User) - assert list(query_entities(query)) == [self.User] - - def test_column_entity(self): - query = self.session.query(self.User.id) - assert list(query_entities(query)) == [self.User] - - def test_column_entity_with_label(self): - query = self.session.query(self.User.id.label('id')) - assert list(query_entities(query)) == [self.User] - - def test_with_subquery(self): - number_of_sales = ( - sa.select( - [sa.func.count(self.Article.id)], - ) - .select_from( - self.Article.__table__ - ) - ).label('number_of_articles') - - query = self.session.query(self.User, number_of_sales) - assert list(query_entities(query)) == [self.User] - - def test_mapper(self): - query = self.session.query(self.User.__mapper__) - assert list(query_entities(query)) == [self.User] - - def test_joins(self): - query = self.session.query(self.User.__mapper__).join(self.Article) - assert list(query_entities(query)) == [self.User, self.Article] - - def test_aliased_entity(self): - query = self.session.query(sa.orm.aliased(self.User)) - assert list(query_entities(query)) == [self.User] diff --git a/tests/test_sort_query.py b/tests/test_sort_query.py index afe82bc..0ca2c8a 100644 --- a/tests/test_sort_query.py +++ b/tests/test_sort_query.py @@ -210,6 +210,7 @@ class TestSortQueryWithPolymorphicInheritance(TestCase): id = sa.Column( sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True ) + category = sa.Column(sa.Unicode(255)) __mapper_args__ = { 'polymorphic_identity': u'article' } @@ -233,3 +234,69 @@ class TestSortQueryWithPolymorphicInheritance(TestCase): 'item_count' ) assert_contains('ORDER BY (SELECT count(:param_2) AS count_2', query) + + def test_child_class_attribute(self): + query = sort_query( + self.session.query(self.TextItem), + 'category' + ) + assert_contains('ORDER BY article.category ASC', query) + + +class TestSortQueryWithCustomPolymorphic(TestCase): + """ + Currently this doesn't work with SQLite + """ + dns = 'postgres://postgres@localhost/sqlalchemy_utils_test' + + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + category = sa.Column(sa.Unicode(255)) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + class BlogPost(TextItem): + __tablename__ = 'blog_post' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'blog_post' + } + + self.TextItem = TextItem + self.Article = Article + self.BlogPost = BlogPost + + def test_with_unknown_column(self): + query = sort_query( + self.session.query( + sa.orm.with_polymorphic(self.TextItem, [self.BlogPost]) + ), + 'category' + ) + assert 'ORDER BY' not in str(query) + + def test_with_existing_column(self): + query = sort_query( + self.session.query( + sa.orm.with_polymorphic(self.TextItem, [self.Article]) + ), + 'category' + ) + assert 'ORDER BY' in str(query)