From 26db4e94b4c792aecb479a330771da63493623f1 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 21 Oct 2014 14:15:45 +0300 Subject: [PATCH] Fix query entity handling for get_mapper --- CHANGES.rst | 6 ++++++ sqlalchemy_utils/functions/orm.py | 5 +++-- tests/functions/test_get_mapper.py | 25 +++++++++++++++++++++++++ tests/functions/test_get_tables.py | 18 ++++++++++++------ 4 files changed, 46 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 9e220ba..9277c5a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.27.2 (2014-10-21) +^^^^^^^^^^^^^^^^^^^ + +- Fixed MapperEntity handling in get_mapper and get_tables utility functions + + 0.27.1 (2014-10-20) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 77a20da..6a6eaa0 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -75,6 +75,9 @@ def get_mapper(mixed): .. versionadded: 0.26.1 """ + if isinstance(mixed, sa.orm.query._MapperEntity): + mixed = mixed.expr + if isinstance(mixed, sa.orm.Mapper): return mixed if isinstance(mixed, sa.orm.util.AliasedClass): @@ -83,8 +86,6 @@ def get_mapper(mixed): mixed = mixed.element if isinstance(mixed, AliasedInsp): return mixed.mapper - if isinstance(mixed, sa.orm.query._MapperEntity): - mixed = mixed.expr if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): mixed = mixed.class_ if isinstance(mixed, sa.Table): diff --git a/tests/functions/test_get_mapper.py b/tests/functions/test_get_mapper.py index a282be9..bdcb8d6 100644 --- a/tests/functions/test_get_mapper.py +++ b/tests/functions/test_get_mapper.py @@ -4,6 +4,8 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy_utils import get_mapper +from tests import TestCase + class TestGetMapper(object): def setup_method(self, method): @@ -55,6 +57,29 @@ class TestGetMapper(object): ) +class TestGetMapperWithQueryEntities(TestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + + def test_mapper_entity_with_mapper(self): + entity = self.session.query(self.Building.__mapper__)._entities[0] + assert ( + get_mapper(entity) == + sa.inspect(self.Building) + ) + + def test_mapper_entity_with_class(self): + entity = self.session.query(self.Building)._entities[0] + assert ( + get_mapper(entity) == + sa.inspect(self.Building) + ) + + class TestGetMapperWithMultipleMappersFound(object): def setup_method(self, method): Base = declarative_base() diff --git a/tests/functions/test_get_tables.py b/tests/functions/test_get_tables.py index 4ffe769..58f4edc 100644 --- a/tests/functions/test_get_tables.py +++ b/tests/functions/test_get_tables.py @@ -58,14 +58,20 @@ class TestGetTables(TestCase): self.Article.__table__ ] + def test_mapper_entity_with_class(self): + query = self.session.query(self.Article) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] + + def test_mapper_entity_with_mapper(self): + query = self.session.query(sa.inspect(self.Article)) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] + def test_column_entity(self): query = self.session.query(self.Article.id) assert get_tables(query._entities[0]) == [ self.TextItem.__table__, self.Article.__table__ ] - - def test_mapper_entity(self): - query = self.session.query(self.Article) - assert get_tables(query._entities[0]) == [ - self.TextItem.__table__, self.Article.__table__ - ]