Add get_class_by_table function
This commit is contained in:
@@ -16,6 +16,12 @@ get_bind
|
|||||||
.. autofunction:: get_bind
|
.. autofunction:: get_bind
|
||||||
|
|
||||||
|
|
||||||
|
get_class_by_table
|
||||||
|
^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
.. autofunction:: get_class_by_table
|
||||||
|
|
||||||
|
|
||||||
get_column_key
|
get_column_key
|
||||||
^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^
|
||||||
|
|
||||||
|
@@ -19,6 +19,7 @@ from .functions import (
|
|||||||
drop_database,
|
drop_database,
|
||||||
escape_like,
|
escape_like,
|
||||||
get_bind,
|
get_bind,
|
||||||
|
get_class_by_table,
|
||||||
get_column_key,
|
get_column_key,
|
||||||
get_columns,
|
get_columns,
|
||||||
get_declarative_base,
|
get_declarative_base,
|
||||||
|
@@ -26,6 +26,7 @@ from .foreign_keys import (
|
|||||||
)
|
)
|
||||||
from .orm import (
|
from .orm import (
|
||||||
get_bind,
|
get_bind,
|
||||||
|
get_class_by_table,
|
||||||
get_column_key,
|
get_column_key,
|
||||||
get_columns,
|
get_columns,
|
||||||
get_declarative_base,
|
get_declarative_base,
|
||||||
@@ -51,6 +52,7 @@ __all__ = (
|
|||||||
'drop_database',
|
'drop_database',
|
||||||
'escape_like',
|
'escape_like',
|
||||||
'get_bind',
|
'get_bind',
|
||||||
|
'get_class_by_table',
|
||||||
'get_columns',
|
'get_columns',
|
||||||
'get_declarative_base',
|
'get_declarative_base',
|
||||||
'get_hybrid_properties',
|
'get_hybrid_properties',
|
||||||
|
@@ -19,6 +19,87 @@ from sqlalchemy.orm.util import AliasedInsp
|
|||||||
from sqlalchemy_utils.utils import is_sequence
|
from sqlalchemy_utils.utils import is_sequence
|
||||||
|
|
||||||
|
|
||||||
|
def get_class_by_table(base, table, data=None):
|
||||||
|
"""
|
||||||
|
Return declarative class associated with given table. If no class is found
|
||||||
|
this function returns `None`. If multiple classes were found (polymorphic
|
||||||
|
cases) additional `data` parameter can be given to hint which class
|
||||||
|
to return.
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = 'entity'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String)
|
||||||
|
|
||||||
|
|
||||||
|
get_class_by_table(Base, User.__table__) # User class
|
||||||
|
|
||||||
|
|
||||||
|
This function also supports models using single table inheritance.
|
||||||
|
Additional data paratemer should be provided in these case.
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
class Entity(Base):
|
||||||
|
__tablename__ = 'entity'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String)
|
||||||
|
type = sa.Column(sa.String)
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_on': type,
|
||||||
|
'polymorphic_identity': 'entity'
|
||||||
|
}
|
||||||
|
|
||||||
|
class User(Entity):
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'user'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Entity class
|
||||||
|
get_class_by_table(Base, Entity.__table__, {'type': 'entity'})
|
||||||
|
|
||||||
|
# User class
|
||||||
|
get_class_by_table(Base, Entity.__table__, {'type': 'user'})
|
||||||
|
|
||||||
|
|
||||||
|
:param base: Declarative model base
|
||||||
|
:param table: SQLAlchemy Table object
|
||||||
|
:param data: Data row to determine the class in polymorphic scenarios
|
||||||
|
:return: Declarative class or None.
|
||||||
|
"""
|
||||||
|
found_classes = set()
|
||||||
|
for c in base._decl_class_registry.values():
|
||||||
|
if hasattr(c, '__table__') and c.__table__ is table:
|
||||||
|
found_classes.add(c)
|
||||||
|
if len(found_classes) > 1:
|
||||||
|
if not data:
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple declarative classes found for table '{0}'. "
|
||||||
|
"Please provide data parameter for this function to be able "
|
||||||
|
"to determine polymorphic scenarios.".format(
|
||||||
|
table.name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
for cls in found_classes:
|
||||||
|
mapper = sa.inspect(cls)
|
||||||
|
polymorphic_on = mapper.polymorphic_on.name
|
||||||
|
if polymorphic_on in data:
|
||||||
|
if data[polymorphic_on] == mapper.polymorphic_identity:
|
||||||
|
return cls
|
||||||
|
raise ValueError(
|
||||||
|
"Multiple declarative classes found for table '{0}'. Given "
|
||||||
|
"data row matches does not match any polymorphic identity of "
|
||||||
|
"the found classes."
|
||||||
|
)
|
||||||
|
elif found_classes:
|
||||||
|
return found_classes.pop()
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_column_key(model, column):
|
def get_column_key(model, column):
|
||||||
"""
|
"""
|
||||||
Return the key for given column in given model.
|
Return the key for given column in given model.
|
||||||
|
100
tests/functions/test_get_class_by_table.py
Normal file
100
tests/functions/test_get_class_by_table.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
from pytest import raises
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
|
||||||
|
from sqlalchemy_utils import get_class_by_table
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetClassByTableWithJoinedTableInheritance(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
class Entity(self.Base):
|
||||||
|
__tablename__ = 'entity'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String)
|
||||||
|
type = sa.Column(sa.String)
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_on': type,
|
||||||
|
'polymorphic_identity': 'entity'
|
||||||
|
}
|
||||||
|
|
||||||
|
class User(Entity):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey(Entity.id, ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'user'
|
||||||
|
}
|
||||||
|
|
||||||
|
self.Entity = Entity
|
||||||
|
self.User = User
|
||||||
|
|
||||||
|
def test_returns_class(self):
|
||||||
|
assert get_class_by_table(self.Base, self.User.__table__) == self.User
|
||||||
|
assert get_class_by_table(
|
||||||
|
self.Base,
|
||||||
|
self.Entity.__table__
|
||||||
|
) == self.Entity
|
||||||
|
|
||||||
|
def test_table_with_no_associated_class(self):
|
||||||
|
table = sa.Table(
|
||||||
|
'some_table',
|
||||||
|
self.Base.metadata,
|
||||||
|
sa.Column('id', sa.Integer)
|
||||||
|
)
|
||||||
|
assert get_class_by_table(self.Base, table) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetClassByTableWithSingleTableInheritance(object):
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
class Entity(self.Base):
|
||||||
|
__tablename__ = 'entity'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String)
|
||||||
|
type = sa.Column(sa.String)
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_on': type,
|
||||||
|
'polymorphic_identity': 'entity'
|
||||||
|
}
|
||||||
|
|
||||||
|
class User(Entity):
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'user'
|
||||||
|
}
|
||||||
|
|
||||||
|
self.Entity = Entity
|
||||||
|
self.User = User
|
||||||
|
|
||||||
|
def test_multiple_classes_without_data_parameter(self):
|
||||||
|
with raises(ValueError):
|
||||||
|
assert get_class_by_table(
|
||||||
|
self.Base,
|
||||||
|
self.Entity.__table__
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_multiple_classes_with_data_parameter(self):
|
||||||
|
assert get_class_by_table(
|
||||||
|
self.Base,
|
||||||
|
self.Entity.__table__,
|
||||||
|
{'type': 'entity'}
|
||||||
|
) == self.Entity
|
||||||
|
assert get_class_by_table(
|
||||||
|
self.Base,
|
||||||
|
self.Entity.__table__,
|
||||||
|
{'type': 'user'}
|
||||||
|
) == self.User
|
||||||
|
|
||||||
|
def test_multiple_classes_with_bogus_data(self):
|
||||||
|
with raises(ValueError):
|
||||||
|
assert get_class_by_table(
|
||||||
|
self.Base,
|
||||||
|
self.Entity.__table__,
|
||||||
|
{'type': 'unknown'}
|
||||||
|
)
|
Reference in New Issue
Block a user