"""Tests for :func:`sqlalchemy_utils.get_mapper`.

Each test passes ``get_mapper`` a different kind of SQLAlchemy construct
(table, class, instance, mapper, alias, column, query entity) and checks
either which mapper comes back or that ``ValueError`` is raised.
"""
import sqlalchemy as sa
from pytest import raises
from sqlalchemy.ext.declarative import declarative_base

from sqlalchemy_utils import get_mapper
from tests import TestCase


def _assert_mapper_is(obj, model):
    """Assert that ``get_mapper(obj)`` equals ``sa.inspect(model)``."""
    assert get_mapper(obj) == sa.inspect(model)


class TestGetMapper(object):
    """Declarative-level constructs all resolve to the ``Building`` mapper."""

    def setup_method(self, method):
        # A new declarative base is built for every test, so each test
        # maps its own Building class.
        self.Base = declarative_base()

        class Building(self.Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)

        self.Building = Building

    def test_table(self):
        _assert_mapper_is(self.Building.__table__, self.Building)

    def test_declarative_class(self):
        _assert_mapper_is(self.Building, self.Building)

    def test_declarative_object(self):
        _assert_mapper_is(self.Building(), self.Building)

    def test_mapper(self):
        _assert_mapper_is(self.Building.__mapper__, self.Building)

    def test_class_alias(self):
        building_alias = sa.orm.aliased(self.Building)
        _assert_mapper_is(building_alias, self.Building)

    def test_instrumented_attribute(self):
        _assert_mapper_is(self.Building.id, self.Building)

    def test_table_alias(self):
        table_alias = sa.orm.aliased(self.Building.__table__)
        _assert_mapper_is(table_alias, self.Building)

    def test_column(self):
        id_column = self.Building.__table__.c.id
        _assert_mapper_is(id_column, self.Building)

    def test_column_of_an_alias(self):
        aliased_id_column = sa.orm.aliased(self.Building.__table__).c.id
        _assert_mapper_is(aliased_id_column, self.Building)


class TestGetMapperWithQueryEntities(TestCase):
    """Entities pulled from ``Query._entities`` resolve to their mapper."""

    def create_models(self):
        # NOTE(review): presumably called by TestCase's setup, which also
        # provides self.Base and self.session -- confirm in tests/__init__.py.
        class Building(self.Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)

        self.Building = Building

    def test_mapper_entity_with_mapper(self):
        query = self.session.query(self.Building.__mapper__)
        _assert_mapper_is(query._entities[0], self.Building)

    def test_mapper_entity_with_class(self):
        query = self.session.query(self.Building)
        _assert_mapper_is(query._entities[0], self.Building)

    def test_column_entity(self):
        query = self.session.query(self.Building.id)
        _assert_mapper_is(query._entities[0], self.Building)


class TestGetMapperWithMultipleMappersFound(object):
    """A table shared by more than one mapper makes get_mapper raise."""

    def setup_method(self, method):
        base = declarative_base()

        class Building(base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)

        # BigBuilding declares no __tablename__ of its own, so it inherits
        # Building's table and the 'building' table ends up with two mappers.
        class BigBuilding(Building):
            pass

        self.Building = Building
        self.BigBuilding = BigBuilding

    def test_table(self):
        building_table = self.Building.__table__
        with raises(ValueError):
            get_mapper(building_table)

    def test_table_alias(self):
        table_alias = sa.orm.aliased(self.Building.__table__)
        with raises(ValueError):
            get_mapper(table_alias)


class TestGetMapperForTableWithoutMapper(object):
    """A table that no class is mapped to makes get_mapper raise."""

    def setup_method(self, method):
        self.building = sa.Table('building', sa.MetaData())

    def test_table(self):
        with raises(ValueError):
            get_mapper(self.building)

    def test_table_alias(self):
        table_alias = sa.orm.aliased(self.building)
        with raises(ValueError):
            get_mapper(table_alias)