Add get_class_by_table function

This commit is contained in:
Konsta Vesterinen
2015-03-03 14:01:17 +02:00
parent e62fe449c0
commit 8290696139
5 changed files with 190 additions and 0 deletions

View File

@@ -16,6 +16,12 @@ get_bind
.. autofunction:: get_bind
get_class_by_table
^^^^^^^^^^^^^^^^^^
.. autofunction:: get_class_by_table
get_column_key
^^^^^^^^^^^^^^

View File

@@ -19,6 +19,7 @@ from .functions import (
drop_database,
escape_like,
get_bind,
get_class_by_table,
get_column_key,
get_columns,
get_declarative_base,

View File

@@ -26,6 +26,7 @@ from .foreign_keys import (
)
from .orm import (
get_bind,
get_class_by_table,
get_column_key,
get_columns,
get_declarative_base,
@@ -51,6 +52,7 @@ __all__ = (
'drop_database',
'escape_like',
'get_bind',
'get_class_by_table',
'get_columns',
'get_declarative_base',
'get_hybrid_properties',

View File

@@ -19,6 +19,87 @@ from sqlalchemy.orm.util import AliasedInsp
from sqlalchemy_utils.utils import is_sequence
def get_class_by_table(base, table, data=None):
"""
Return declarative class associated with given table. If no class is found
this function returns `None`. If multiple classes were found (polymorphic
cases) additional `data` parameter can be given to hint which class
to return.
::
class User(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
get_class_by_table(Base, User.__table__) # User class
This function also supports models using single table inheritance.
Additional data paratemer should be provided in these case.
::
class Entity(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
type = sa.Column(sa.String)
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
class User(Entity):
__mapper_args__ = {
'polymorphic_identity': 'user'
}
# Entity class
get_class_by_table(Base, Entity.__table__, {'type': 'entity'})
# User class
get_class_by_table(Base, Entity.__table__, {'type': 'user'})
:param base: Declarative model base
:param table: SQLAlchemy Table object
:param data: Data row to determine the class in polymorphic scenarios
:return: Declarative class or None.
"""
found_classes = set()
for c in base._decl_class_registry.values():
if hasattr(c, '__table__') and c.__table__ is table:
found_classes.add(c)
if len(found_classes) > 1:
if not data:
raise ValueError(
"Multiple declarative classes found for table '{0}'. "
"Please provide data parameter for this function to be able "
"to determine polymorphic scenarios.".format(
table.name
)
)
else:
for cls in found_classes:
mapper = sa.inspect(cls)
polymorphic_on = mapper.polymorphic_on.name
if polymorphic_on in data:
if data[polymorphic_on] == mapper.polymorphic_identity:
return cls
raise ValueError(
"Multiple declarative classes found for table '{0}'. Given "
"data row matches does not match any polymorphic identity of "
"the found classes."
)
elif found_classes:
return found_classes.pop()
return None
def get_column_key(model, column):
"""
Return the key for given column in given model.

View File

@@ -0,0 +1,100 @@
from pytest import raises
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy_utils import get_class_by_table
class TestGetClassByTableWithJoinedTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
type = sa.Column(sa.String)
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
class User(Entity):
__tablename__ = 'user'
id = sa.Column(
sa.Integer,
sa.ForeignKey(Entity.id, ondelete='CASCADE'),
primary_key=True
)
__mapper_args__ = {
'polymorphic_identity': 'user'
}
self.Entity = Entity
self.User = User
def test_returns_class(self):
assert get_class_by_table(self.Base, self.User.__table__) == self.User
assert get_class_by_table(
self.Base,
self.Entity.__table__
) == self.Entity
def test_table_with_no_associated_class(self):
table = sa.Table(
'some_table',
self.Base.metadata,
sa.Column('id', sa.Integer)
)
assert get_class_by_table(self.Base, table) is None
class TestGetClassByTableWithSingleTableInheritance(object):
def setup_method(self, method):
self.Base = declarative_base()
class Entity(self.Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
type = sa.Column(sa.String)
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
class User(Entity):
__mapper_args__ = {
'polymorphic_identity': 'user'
}
self.Entity = Entity
self.User = User
def test_multiple_classes_without_data_parameter(self):
with raises(ValueError):
assert get_class_by_table(
self.Base,
self.Entity.__table__
)
def test_multiple_classes_with_data_parameter(self):
assert get_class_by_table(
self.Base,
self.Entity.__table__,
{'type': 'entity'}
) == self.Entity
assert get_class_by_table(
self.Base,
self.Entity.__table__,
{'type': 'user'}
) == self.User
def test_multiple_classes_with_bogus_data(self):
with raises(ValueError):
assert get_class_by_table(
self.Base,
self.Entity.__table__,
{'type': 'unknown'}
)