diff --git a/CHANGES.rst b/CHANGES.rst index 8e5c464..437cb13 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.27.1 (2014-10-xx) +^^^^^^^^^^^^^^^^^^^ + +- Added support for more SQLAlchemy based objects and classes in get_tables function + + 0.27.0 (2014-10-14) ^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/functions/foreign_keys.py b/sqlalchemy_utils/functions/foreign_keys.py index 2ade46f..11a646f 100644 --- a/sqlalchemy_utils/functions/foreign_keys.py +++ b/sqlalchemy_utils/functions/foreign_keys.py @@ -3,9 +3,8 @@ from itertools import groupby import six import sqlalchemy as sa -from sqlalchemy.engine import reflection from sqlalchemy.exc import NoInspectionAvailable -from sqlalchemy.orm import object_session, mapperlib +from sqlalchemy.orm import object_session from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from .orm import get_mapper, get_tables diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 41ce319..9ff3a56 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -177,27 +177,47 @@ def get_primary_keys(mixed): def get_tables(mixed): """ - Return a list of tables associated with given SQLAlchemy object. + Return a set of tables associated with given SQLAlchemy object. Let's say we have three classes which use joined table inheritance TextItem, Article and BlogPost. Article and BlogPost inherit TextItem. :: - get_tables(Article) # [Table('article', ...), Table('text_item')] + get_tables(Article) # set([Table('article', ...), Table('text_item')]) get_tables(Article()) get_tables(Article.__mapper__) + If the TextItem entity is using with_polymorphic='*' then this function + returns all child tables (article and blog_post) as well. + + :: + + + get_tables(TextItem) # set([Table('text_item', ...)], ...]) + + .. versionadded: 0.26.0 :param mixed: SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping any of these objects. """ - return get_mapper(mixed).tables + if isinstance(mixed, sa.Table): + return [mixed] + elif isinstance(mixed, sa.Column): + return [mixed.table] + elif isinstance(mixed, sa.orm.query._ColumnEntity): + mixed = mixed.expr + mapper = get_mapper(mixed) + + polymorphic_mappers = get_polymorphic_mappers(mapper) + if polymorphic_mappers: + return sum((m.tables for m in polymorphic_mappers), []) + return mapper.tables def get_columns(mixed): diff --git a/tests/functions/test_get_tables.py b/tests/functions/test_get_tables.py new file mode 100644 index 0000000..e81ec15 --- /dev/null +++ b/tests/functions/test_get_tables.py @@ -0,0 +1,55 @@ +import sqlalchemy as sa + +from sqlalchemy_utils import get_tables + +from tests import TestCase + + +class TestGetTables(TestCase): + def create_models(self): + class TextItem(self.Base): + __tablename__ = 'text_item' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + type = sa.Column(sa.Unicode(255)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'with_polymorphic': '*' + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': u'article' + } + + self.TextItem = TextItem + self.Article = Article + + def test_child_class_using_join_table_inheritance(self): + assert get_tables(self.Article) == [ + self.TextItem.__table__, + self.Article.__table__ + ] + + def test_entity_using_with_polymorphic(self): + assert get_tables(self.TextItem) == [ + self.TextItem.__table__, + self.Article.__table__ + ] + + def test_column(self): + assert get_tables(self.Article.__table__.c.id) == [ + self.Article.__table__ + ] + + def test_column_entity(self): + query = self.session.query(self.Article.id) + assert get_tables(query._entities[0]) == [ + self.TextItem.__table__, self.Article.__table__ + ] +