From fc9d5b5b89155e55391d4678b74e5dd1a3e1cb21 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 21 Apr 2014 11:40:02 +0300 Subject: [PATCH] Add aliases support for get_columns --- sqlalchemy_utils/functions/orm.py | 17 +++++++++++++++-- tests/functions/test_get_columns.py | 19 +++++++++++++++++++ tests/functions/test_primary_keys.py | 11 +++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 6cccd41..5b44754 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -47,7 +47,7 @@ def primary_keys(mixed): def get_columns(mixed): """ Return a collection of all Column objects for given SQLAlchemy - Table object, declarative class or declarative class instance. + object. The type of the collection depends on the type of the object to return the columns from. @@ -60,12 +60,25 @@ def get_columns(mixed): get_columns(User.__table__) + get_columns(User.__mapper__) + + get_column(sa.orm.aliased(User)) + + get_columns(sa.orm.alised(User.__table__)) + :param mixed: - SA Table object, SA declarative class or SA declarative class instance + SA Table object, SA Mapper, SA declarative class, SA declarative class + instance or an alias of any of these objects """ if isinstance(mixed, sa.Table): return mixed.c + if isinstance(mixed, sa.orm.util.AliasedClass): + return sa.inspect(mixed).mapper.columns + if isinstance(mixed, sa.sql.selectable.Alias): + return mixed.c + if isinstance(mixed, sa.orm.Mapper): + return mixed.columns if not isclass(mixed): mixed = mixed.__class__ return sa.inspect(mixed).columns diff --git a/tests/functions/test_get_columns.py b/tests/functions/test_get_columns.py index 42712a8..18dd57c 100644 --- a/tests/functions/test_get_columns.py +++ b/tests/functions/test_get_columns.py @@ -29,3 +29,22 @@ class TestGetColumns(TestCase): get_columns(self.Building()), sa.util._collections.OrderedProperties ) + + def test_mapper(self): + assert isinstance( + get_columns(self.Building.__mapper__), + sa.util._collections.OrderedProperties + ) + + def test_class_alias(self): + assert isinstance( + get_columns(sa.orm.aliased(self.Building)), + sa.util._collections.OrderedProperties + ) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert isinstance( + get_columns(alias), + sa.sql.base.ImmutableColumnCollection + ) diff --git a/tests/functions/test_primary_keys.py b/tests/functions/test_primary_keys.py index c1c9f3a..d86fe2f 100644 --- a/tests/functions/test_primary_keys.py +++ b/tests/functions/test_primary_keys.py @@ -30,3 +30,14 @@ class TestPrimaryKeys(TestCase): assert primary_keys(self.Building()) == OrderedDict({ 'id': self.Building.__table__.c._id }) + + def test_class_alias(self): + assert primary_keys(sa.orm.aliased(self.Building())) == OrderedDict({ + 'id': self.Building.__table__.c._id + }) + + def test_table_alias(self): + alias = sa.orm.aliased(self.Building.__table__) + assert primary_keys(alias) == OrderedDict({ + '_id': alias.c._id + })