diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 5d84435..0e2149a 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -24,7 +24,9 @@ def get_referencing_foreign_keys(mixed): :: - get_foreign_keys(User) # set([ForeignKey('user.id')]) + get_foreign_keys(User) # set([ForeignKey('user.id')]) + + get_foreign_keys(User.__table__) # set([ForeignKey('user.id')]) """ if isinstance(mixed, sa.Table): tables = [mixed] diff --git a/tests/functions/test_get_referencing_foreign_keys.py b/tests/functions/test_get_referencing_foreign_keys.py new file mode 100644 index 0000000..4f55ba3 --- /dev/null +++ b/tests/functions/test_get_referencing_foreign_keys.py @@ -0,0 +1,34 @@ +import sqlalchemy as sa +from sqlalchemy_utils import get_referencing_foreign_keys +from tests import TestCase + + +class TestGetReferencingFksWithCompositeKeys(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + first_name = sa.Column(sa.Unicode(255), primary_key=True) + last_name = sa.Column(sa.Unicode(255), primary_key=True) + + class Article(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + author_first_name = sa.Column(sa.Unicode(255)) + author_last_name = sa.Column(sa.Unicode(255)) + __table_args__ = ( + sa.ForeignKeyConstraint( + [author_first_name, author_last_name], + [User.first_name, User.last_name] + ), + ) + + self.User = User + self.Article = Article + + def test_with_declarative_class(self): + fks = get_referencing_foreign_keys(self.User) + assert self.Article.__table__.foreign_keys == fks + + def test_with_table(self): + fks = get_referencing_foreign_keys(self.User.__table__) + assert self.Article.__table__.foreign_keys == fks diff --git a/tests/test_expression_parser.py b/tests/test_expression_parser.py index 4578b52..6510818 100644 --- a/tests/test_expression_parser.py +++ b/tests/test_expression_parser.py @@ -79,3 +79,7 @@ class TestExpressionParser(TestCase): def test_null(self): expr = self.parser(self.User.name == Null()) assert str(expr) == 'category.name IS NULL' + + def test_instrumented_attribute(self): + expr = self.parser(self.User.name) + assert str(expr) == 'category.name'