diff --git a/docs/orm_helpers.rst b/docs/orm_helpers.rst index 76dc044..888e67f 100644 --- a/docs/orm_helpers.rst +++ b/docs/orm_helpers.rst @@ -16,6 +16,12 @@ get_bind .. autofunction:: get_bind +get_class_by_table +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: get_class_by_table + + get_column_key ^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 8329d2b..4c56a7f 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -19,6 +19,7 @@ from .functions import ( drop_database, escape_like, get_bind, + get_class_by_table, get_column_key, get_columns, get_declarative_base, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 188316a..0aa9fe0 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -26,6 +26,7 @@ from .foreign_keys import ( ) from .orm import ( get_bind, + get_class_by_table, get_column_key, get_columns, get_declarative_base, @@ -51,6 +52,7 @@ __all__ = ( 'drop_database', 'escape_like', 'get_bind', + 'get_class_by_table', 'get_columns', 'get_declarative_base', 'get_hybrid_properties', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 961c9e6..859f7b2 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -19,6 +19,87 @@ from sqlalchemy.orm.util import AliasedInsp from sqlalchemy_utils.utils import is_sequence +def get_class_by_table(base, table, data=None): + """ + Return declarative class associated with given table. If no class is found + this function returns `None`. If multiple classes were found (polymorphic + cases) additional `data` parameter can be given to hint which class + to return. + + :: + + class User(Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + + + get_class_by_table(Base, User.__table__) # User class + + + This function also supports models using single table inheritance. + Additional data paratemer should be provided in these case. + + :: + + class Entity(Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + + # Entity class + get_class_by_table(Base, Entity.__table__, {'type': 'entity'}) + + # User class + get_class_by_table(Base, Entity.__table__, {'type': 'user'}) + + + :param base: Declarative model base + :param table: SQLAlchemy Table object + :param data: Data row to determine the class in polymorphic scenarios + :return: Declarative class or None. + """ + found_classes = set() + for c in base._decl_class_registry.values(): + if hasattr(c, '__table__') and c.__table__ is table: + found_classes.add(c) + if len(found_classes) > 1: + if not data: + raise ValueError( + "Multiple declarative classes found for table '{0}'. " + "Please provide data parameter for this function to be able " + "to determine polymorphic scenarios.".format( + table.name + ) + ) + else: + for cls in found_classes: + mapper = sa.inspect(cls) + polymorphic_on = mapper.polymorphic_on.name + if polymorphic_on in data: + if data[polymorphic_on] == mapper.polymorphic_identity: + return cls + raise ValueError( + "Multiple declarative classes found for table '{0}'. Given " + "data row matches does not match any polymorphic identity of " + "the found classes." + ) + elif found_classes: + return found_classes.pop() + return None + + def get_column_key(model, column): """ Return the key for given column in given model. diff --git a/tests/functions/test_get_class_by_table.py b/tests/functions/test_get_class_by_table.py new file mode 100644 index 0000000..af1f529 --- /dev/null +++ b/tests/functions/test_get_class_by_table.py @@ -0,0 +1,100 @@ +from pytest import raises + +import sqlalchemy as sa +from sqlalchemy.ext.declarative import declarative_base + +from sqlalchemy_utils import get_class_by_table + + +class TestGetClassByTableWithJoinedTableInheritance(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Entity(self.Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __tablename__ = 'user' + id = sa.Column( + sa.Integer, + sa.ForeignKey(Entity.id, ondelete='CASCADE'), + primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + self.Entity = Entity + self.User = User + + def test_returns_class(self): + assert get_class_by_table(self.Base, self.User.__table__) == self.User + assert get_class_by_table( + self.Base, + self.Entity.__table__ + ) == self.Entity + + def test_table_with_no_associated_class(self): + table = sa.Table( + 'some_table', + self.Base.metadata, + sa.Column('id', sa.Integer) + ) + assert get_class_by_table(self.Base, table) is None + + +class TestGetClassByTableWithSingleTableInheritance(object): + def setup_method(self, method): + self.Base = declarative_base() + + class Entity(self.Base): + __tablename__ = 'entity' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String) + type = sa.Column(sa.String) + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'entity' + } + + class User(Entity): + __mapper_args__ = { + 'polymorphic_identity': 'user' + } + + self.Entity = Entity + self.User = User + + def test_multiple_classes_without_data_parameter(self): + with raises(ValueError): + assert get_class_by_table( + self.Base, + self.Entity.__table__ + ) + + def test_multiple_classes_with_data_parameter(self): + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'entity'} + ) == self.Entity + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'user'} + ) == self.User + + def test_multiple_classes_with_bogus_data(self): + with raises(ValueError): + assert get_class_by_table( + self.Base, + self.Entity.__table__, + {'type': 'unknown'} + )