diff --git a/CHANGES.rst b/CHANGES.rst index c2128ed..761db32 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.26.0 (2014-05-xx) +0.26.1 (2014-05-xx) ^^^^^^^^^^^^^^^^^^^ - Added get_bind +- Added group_foreign_keys +- Added get_mapper 0.26.0 (2014-05-07) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 9ef5214..196f030 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -14,6 +14,7 @@ from .functions import ( get_bind, get_columns, get_declarative_base, + get_mapper, get_primary_keys, get_referencing_foreign_keys, get_tables, @@ -87,6 +88,7 @@ __all__ = ( get_bind, get_columns, get_declarative_base, + get_mapper, get_primary_keys, get_referencing_foreign_keys, get_tables, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 95891ba..39e5b7e 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -16,6 +16,7 @@ from .orm import ( get_bind, get_columns, get_declarative_base, + get_mapper, get_primary_keys, get_referencing_foreign_keys, get_tables, @@ -39,6 +40,7 @@ __all__ = ( 'get_bind', 'get_columns', 'get_declarative_base', + 'get_mapper', 'get_primary_keys', 'get_referencing_foreign_keys', 'get_tables', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 0b4fd38..960738a 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -9,6 +9,7 @@ from operator import attrgetter import sqlalchemy as sa from sqlalchemy import inspect from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import mapperlib from sqlalchemy.orm.attributes import InstrumentedAttribute from sqlalchemy.orm.exc import UnmappedInstanceError from sqlalchemy.orm.mapper import Mapper @@ -18,6 +19,37 @@ from sqlalchemy.orm.util import AliasedInsp from ..query_chain import QueryChain +def get_mapper(mixed): + """ + Return related SQLAlchemy Mapper for given SQLAlchemy object. + + :param mixed: SQLAlchemy Table object + + .. versionadded: 0.26.1 + """ + if isinstance(mixed, sa.orm.Mapper): + return mixed + if isinstance(mixed, sa.orm.util.AliasedClass): + return sa.inspect(mixed).mapper + if isinstance(mixed, sa.sql.selectable.Alias): + mixed = mixed.element + if isinstance(mixed, sa.Table): + mappers = [ + mapper for mapper in mapperlib._mapper_registry + if mixed in mapper.tables + ] + if len(mappers) > 1: + raise ValueError( + "Could not get mapper for '%r'. Multiple mappers found." + % mixed + ) + else: + return mappers[0] + if not isclass(mixed): + mixed = type(mixed) + return sa.inspect(mixed) + + def get_bind(obj): """ Return the bind for given SQLAlchemy Engine / Connection / declarative diff --git a/tests/functions/test_get_mapper.py b/tests/functions/test_get_mapper.py new file mode 100644 index 0000000..3d8bee4 --- /dev/null +++ b/tests/functions/test_get_mapper.py @@ -0,0 +1,73 @@ +from pytest import raises +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_mapper + + +class TestGetMapper(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + + def test_table(self): + assert get_mapper(self.Building.__table__) == sa.inspect(self.Building) + + def test_declarative_class(self): + assert ( + get_mapper(self.Building) == + sa.inspect(self.Building) + ) + + def test_declarative_object(self): + assert ( + get_mapper(self.Building()) == + sa.inspect(self.Building) + ) + + def test_mapper(self): + assert ( + get_mapper(self.Building.__mapper__) == + sa.inspect(self.Building) + ) + + def test_class_alias(self): + assert ( + get_mapper(sa.orm.aliased(self.Building)) == + sa.inspect(self.Building) + ) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert ( + get_mapper(alias) == + sa.inspect(self.Building) + ) + + +class TestGetMapperWithMultipleMappersFound(object): + def setup_method(self, method): + Base = declarative_base() + + class Building(Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class BigBuilding(Building): + pass + + self.Building = Building + + def test_table(self): + with raises(ValueError): + get_mapper(self.Building.__table__) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + with raises(ValueError): + get_mapper(alias)