Fix query entity handling for get_mapper

This commit is contained in:
Konsta Vesterinen
2014-10-21 14:15:45 +03:00
parent 92efb4d326
commit 26db4e94b4
4 changed files with 46 additions and 8 deletions

View File

@@ -4,6 +4,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.27.2 (2014-10-21)
^^^^^^^^^^^^^^^^^^^
- Fixed MapperEntity handling in get_mapper and get_tables utility functions
0.27.1 (2014-10-20) 0.27.1 (2014-10-20)
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^

View File

@@ -75,6 +75,9 @@ def get_mapper(mixed):
.. versionadded: 0.26.1 .. versionadded: 0.26.1
""" """
if isinstance(mixed, sa.orm.query._MapperEntity):
mixed = mixed.expr
if isinstance(mixed, sa.orm.Mapper): if isinstance(mixed, sa.orm.Mapper):
return mixed return mixed
if isinstance(mixed, sa.orm.util.AliasedClass): if isinstance(mixed, sa.orm.util.AliasedClass):
@@ -83,8 +86,6 @@ def get_mapper(mixed):
mixed = mixed.element mixed = mixed.element
if isinstance(mixed, AliasedInsp): if isinstance(mixed, AliasedInsp):
return mixed.mapper return mixed.mapper
if isinstance(mixed, sa.orm.query._MapperEntity):
mixed = mixed.expr
if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute): if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
mixed = mixed.class_ mixed = mixed.class_
if isinstance(mixed, sa.Table): if isinstance(mixed, sa.Table):

View File

@@ -4,6 +4,8 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_mapper from sqlalchemy_utils import get_mapper
from tests import TestCase
class TestGetMapper(object): class TestGetMapper(object):
def setup_method(self, method): def setup_method(self, method):
@@ -55,6 +57,29 @@ class TestGetMapper(object):
) )
class TestGetMapperWithQueryEntities(TestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
def test_mapper_entity_with_mapper(self):
entity = self.session.query(self.Building.__mapper__)._entities[0]
assert (
get_mapper(entity) ==
sa.inspect(self.Building)
)
def test_mapper_entity_with_class(self):
entity = self.session.query(self.Building)._entities[0]
assert (
get_mapper(entity) ==
sa.inspect(self.Building)
)
class TestGetMapperWithMultipleMappersFound(object): class TestGetMapperWithMultipleMappersFound(object):
def setup_method(self, method): def setup_method(self, method):
Base = declarative_base() Base = declarative_base()

View File

@@ -58,14 +58,20 @@ class TestGetTables(TestCase):
self.Article.__table__ self.Article.__table__
] ]
def test_mapper_entity_with_class(self):
query = self.session.query(self.Article)
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
]
def test_mapper_entity_with_mapper(self):
query = self.session.query(sa.inspect(self.Article))
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
]
def test_column_entity(self): def test_column_entity(self):
query = self.session.query(self.Article.id) query = self.session.query(self.Article.id)
assert get_tables(query._entities[0]) == [ assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__ self.TextItem.__table__, self.Article.__table__
] ]
def test_mapper_entity(self):
query = self.session.query(self.Article)
assert get_tables(query._entities[0]) == [
self.TextItem.__table__, self.Article.__table__
]