From 0b34e9447d46d283bf5c0552598fa2d7e0bfe9b7 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 6 May 2014 10:53:04 +0300 Subject: [PATCH] Add get_tables utility function --- CHANGES.rst | 1 + docs/model_helpers.rst | 6 +++ sqlalchemy_utils/__init__.py | 2 + sqlalchemy_utils/functions/__init__.py | 2 + sqlalchemy_utils/functions/orm.py | 47 ++++++++++++++--- .../test_get_referencing_foreign_keys.py | 51 +++++++++++++++++++ 6 files changed, 101 insertions(+), 8 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 02473d0..3dbae83 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -8,6 +8,7 @@ Here you can see the full list of changes between each SQLAlchemy-Utils release. ^^^^^^^^^^^^^^^^^^^ - Added get_referencing_foreign_keys +- Added get_tables 0.25.4 (2014-04-22) diff --git a/docs/model_helpers.rst b/docs/model_helpers.rst index 5213374..61f535a 100644 --- a/docs/model_helpers.rst +++ b/docs/model_helpers.rst @@ -34,6 +34,12 @@ get_referencing_foreign_keys .. autofunction:: get_referencing_foreign_keys +get_tables +^^^^^^^^^^ + +.. autofunction:: get_tables + + query_entities ^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 2589a29..87641fc 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -14,6 +14,7 @@ from .functions import ( get_declarative_base, get_primary_keys, get_referencing_foreign_keys, + get_tables, identity, mock_engine, naturally_equivalent, @@ -82,6 +83,7 @@ __all__ = ( get_declarative_base, get_primary_keys, get_referencing_foreign_keys, + get_tables, identity, instrumented_list, merge, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 6f35ef2..b62ad46 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -16,6 +16,7 @@ from .orm import ( get_declarative_base, get_primary_keys, get_referencing_foreign_keys, + get_tables, getdotattr, has_changes, identity, @@ -35,6 +36,7 @@ __all__ = ( 'get_declarative_base', 'get_primary_keys', 'get_referencing_foreign_keys', + 'get_tables', 'getdotattr', 'has_changes', 'identity', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 3ca3c65..161b7c2 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -31,17 +31,17 @@ def get_referencing_foreign_keys(mixed): if isinstance(mixed, sa.Table): tables = [mixed] else: - # TODO: make this support joined table inheritance - tables = [mixed.__table__] + tables = get_tables(mixed) referencing_foreign_keys = set() for table in mixed.metadata.tables.values(): - for constraint in table.constraints: - if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): - for fk in constraint.elements: - if any(fk.references(t) for t in tables): - referencing_foreign_keys.add(fk) + if table not in tables: + for constraint in table.constraints: + if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): + for fk in constraint.elements: + if any(fk.references(t) for t in tables): + referencing_foreign_keys.add(fk) return referencing_foreign_keys @@ -65,7 +65,7 @@ def get_primary_keys(mixed): get_primary_keys(sa.orm.aliased(User)) - get_primary_keys(sa.orm.alised(User.__table__)) + get_primary_keys(sa.orm.aliased(User.__table__)) .. versionchanged: 0.25.3 @@ -84,6 +84,37 @@ def get_primary_keys(mixed): ) +def get_tables(mixed): + """ + Return a list of tables associated with given SQLAlchemy object. + + Let's say we have three classes which use joined table inheritance + TextItem, Article and BlogPost. Article and BlogPost inherit TextItem. + + :: + + get_tables(Article) # [Table('article', ...), Table('text_item')] + + get_tables(Article()) + + get_tables(Article.__mapper__) + + + .. versionadded: 0.25.5 + + :param mixed: + SQLAlchemy Mapper / Declarative class or a SA Alias object wrapping + any of these objects. + """ + if isinstance(mixed, sa.orm.util.AliasedClass): + mapper = sa.inspect(mixed).mapper + else: + if not isclass(mixed): + mixed = mixed.__class__ + mapper = sa.inspect(mixed) + return mapper.tables + + def get_columns(mixed): """ Return a collection of all Column objects for given SQLAlchemy diff --git a/tests/functions/test_get_referencing_foreign_keys.py b/tests/functions/test_get_referencing_foreign_keys.py index 4f55ba3..3ad4385 100644 --- a/tests/functions/test_get_referencing_foreign_keys.py +++ b/tests/functions/test_get_referencing_foreign_keys.py @@ -32,3 +32,54 @@ class TestGetReferencingFksWithCompositeKeys(TestCase): def test_with_table(self): fks = get_referencing_foreign_keys(self.User.__table__) assert self.Article.__table__.foreign_keys == fks + + +class TestGetReferencingFksWithInheritance(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.Unicode) + first_name = sa.Column(sa.Unicode(255), primary_key=True) + last_name = sa.Column(sa.Unicode(255), primary_key=True) + + __mapper_args__ = { + 'polymorphic_on': 'type' + } + + class Admin(User): + __tablename__ = 'admin' + id = sa.Column( + sa.Integer, sa.ForeignKey(User.id), primary_key=True + ) + + class TextItem(self.Base): + __tablename__ = 'textitem' + id = sa.Column(sa.Integer, primary_key=True) + type = sa.Column(sa.Unicode) + author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id)) + __mapper_args__ = { + 'polymorphic_on': 'type' + } + + class Article(TextItem): + __tablename__ = 'article' + id = sa.Column( + sa.Integer, sa.ForeignKey(TextItem.id), primary_key=True + ) + __mapper_args__ = { + 'polymorphic_identity': 'article' + } + + self.Admin = Admin + self.User = User + self.Article = Article + self.TextItem = TextItem + + def test_with_declarative_class(self): + fks = get_referencing_foreign_keys(self.Admin) + assert self.TextItem.__table__.foreign_keys == fks + + def test_with_table(self): + fks = get_referencing_foreign_keys(self.Admin.__table__) + assert fks == set([])