Add get_class_by_table function
This commit is contained in:
@@ -16,6 +16,12 @@ get_bind
|
||||
.. autofunction:: get_bind
|
||||
|
||||
|
||||
get_class_by_table
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_class_by_table
|
||||
|
||||
|
||||
get_column_key
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
|
@@ -19,6 +19,7 @@ from .functions import (
|
||||
drop_database,
|
||||
escape_like,
|
||||
get_bind,
|
||||
get_class_by_table,
|
||||
get_column_key,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
|
@@ -26,6 +26,7 @@ from .foreign_keys import (
|
||||
)
|
||||
from .orm import (
|
||||
get_bind,
|
||||
get_class_by_table,
|
||||
get_column_key,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
@@ -51,6 +52,7 @@ __all__ = (
|
||||
'drop_database',
|
||||
'escape_like',
|
||||
'get_bind',
|
||||
'get_class_by_table',
|
||||
'get_columns',
|
||||
'get_declarative_base',
|
||||
'get_hybrid_properties',
|
||||
|
@@ -19,6 +19,87 @@ from sqlalchemy.orm.util import AliasedInsp
|
||||
from sqlalchemy_utils.utils import is_sequence
|
||||
|
||||
|
||||
def get_class_by_table(base, table, data=None):
|
||||
"""
|
||||
Return declarative class associated with given table. If no class is found
|
||||
this function returns `None`. If multiple classes were found (polymorphic
|
||||
cases) additional `data` parameter can be given to hint which class
|
||||
to return.
|
||||
|
||||
::
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
||||
|
||||
get_class_by_table(Base, User.__table__) # User class
|
||||
|
||||
|
||||
This function also supports models using single table inheritance.
|
||||
Additional data paratemer should be provided in these case.
|
||||
|
||||
::
|
||||
|
||||
class Entity(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
type = sa.Column(sa.String)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
|
||||
class User(Entity):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
|
||||
|
||||
# Entity class
|
||||
get_class_by_table(Base, Entity.__table__, {'type': 'entity'})
|
||||
|
||||
# User class
|
||||
get_class_by_table(Base, Entity.__table__, {'type': 'user'})
|
||||
|
||||
|
||||
:param base: Declarative model base
|
||||
:param table: SQLAlchemy Table object
|
||||
:param data: Data row to determine the class in polymorphic scenarios
|
||||
:return: Declarative class or None.
|
||||
"""
|
||||
found_classes = set()
|
||||
for c in base._decl_class_registry.values():
|
||||
if hasattr(c, '__table__') and c.__table__ is table:
|
||||
found_classes.add(c)
|
||||
if len(found_classes) > 1:
|
||||
if not data:
|
||||
raise ValueError(
|
||||
"Multiple declarative classes found for table '{0}'. "
|
||||
"Please provide data parameter for this function to be able "
|
||||
"to determine polymorphic scenarios.".format(
|
||||
table.name
|
||||
)
|
||||
)
|
||||
else:
|
||||
for cls in found_classes:
|
||||
mapper = sa.inspect(cls)
|
||||
polymorphic_on = mapper.polymorphic_on.name
|
||||
if polymorphic_on in data:
|
||||
if data[polymorphic_on] == mapper.polymorphic_identity:
|
||||
return cls
|
||||
raise ValueError(
|
||||
"Multiple declarative classes found for table '{0}'. Given "
|
||||
"data row matches does not match any polymorphic identity of "
|
||||
"the found classes."
|
||||
)
|
||||
elif found_classes:
|
||||
return found_classes.pop()
|
||||
return None
|
||||
|
||||
|
||||
def get_column_key(model, column):
|
||||
"""
|
||||
Return the key for given column in given model.
|
||||
|
100
tests/functions/test_get_class_by_table.py
Normal file
100
tests/functions/test_get_class_by_table.py
Normal file
@@ -0,0 +1,100 @@
|
||||
from pytest import raises
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from sqlalchemy_utils import get_class_by_table
|
||||
|
||||
|
||||
class TestGetClassByTableWithJoinedTableInheritance(object):
|
||||
def setup_method(self, method):
|
||||
self.Base = declarative_base()
|
||||
|
||||
class Entity(self.Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
type = sa.Column(sa.String)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
|
||||
class User(Entity):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey(Entity.id, ondelete='CASCADE'),
|
||||
primary_key=True
|
||||
)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
|
||||
self.Entity = Entity
|
||||
self.User = User
|
||||
|
||||
def test_returns_class(self):
|
||||
assert get_class_by_table(self.Base, self.User.__table__) == self.User
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__
|
||||
) == self.Entity
|
||||
|
||||
def test_table_with_no_associated_class(self):
|
||||
table = sa.Table(
|
||||
'some_table',
|
||||
self.Base.metadata,
|
||||
sa.Column('id', sa.Integer)
|
||||
)
|
||||
assert get_class_by_table(self.Base, table) is None
|
||||
|
||||
|
||||
class TestGetClassByTableWithSingleTableInheritance(object):
|
||||
def setup_method(self, method):
|
||||
self.Base = declarative_base()
|
||||
|
||||
class Entity(self.Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
type = sa.Column(sa.String)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
|
||||
class User(Entity):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
|
||||
self.Entity = Entity
|
||||
self.User = User
|
||||
|
||||
def test_multiple_classes_without_data_parameter(self):
|
||||
with raises(ValueError):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__
|
||||
)
|
||||
|
||||
def test_multiple_classes_with_data_parameter(self):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
{'type': 'entity'}
|
||||
) == self.Entity
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
{'type': 'user'}
|
||||
) == self.User
|
||||
|
||||
def test_multiple_classes_with_bogus_data(self):
|
||||
with raises(ValueError):
|
||||
assert get_class_by_table(
|
||||
self.Base,
|
||||
self.Entity.__table__,
|
||||
{'type': 'unknown'}
|
||||
)
|
Reference in New Issue
Block a user