diff --git a/CHANGES.rst b/CHANGES.rst index 923a080..3aaf785 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.30.13 (2015-07-16) +^^^^^^^^^^^^^^^^^^^^ + +- Added support for InstrumentedAttributes, ColumnProperties and Columns in get_columns function + + 0.30.12 (2015-07-05) ^^^^^^^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/aggregates.py b/sqlalchemy_utils/aggregates.py index a5f8caf..2daa537 100644 --- a/sqlalchemy_utils/aggregates.py +++ b/sqlalchemy_utils/aggregates.py @@ -377,7 +377,6 @@ from .relationships import ( select_correlated_expression ) - aggregated_attrs = WeakKeyDictionary(defaultdict(list)) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 56dba9e..583d6aa 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -403,14 +403,18 @@ def get_columns(mixed): SA Table object, SA Mapper, SA declarative class, SA declarative class instance or an alias of any of these objects """ - if isinstance(mixed, sa.Table): + if isinstance(mixed, sa.sql.selectable.Selectable): return mixed.c if isinstance(mixed, sa.orm.util.AliasedClass): return sa.inspect(mixed).mapper.columns - if isinstance(mixed, sa.sql.selectable.Alias): - return mixed.c if isinstance(mixed, sa.orm.Mapper): return mixed.columns + if isinstance(mixed, InstrumentedAttribute): + return mixed.property.columns + if isinstance(mixed, ColumnProperty): + return mixed.columns + if isinstance(mixed, sa.Column): + return [mixed] if not isclass(mixed): mixed = mixed.__class__ return sa.inspect(mixed).columns diff --git a/tests/functions/test_get_columns.py b/tests/functions/test_get_columns.py index 1ed3e3a..d09dd39 100644 --- a/tests/functions/test_get_columns.py +++ b/tests/functions/test_get_columns.py @@ -21,6 +21,19 @@ class TestGetColumns(object): sa.sql.base.ImmutableColumnCollection ) + def test_instrumented_attribute(self): + assert get_columns(self.Building.id) == [self.Building.__table__.c._id] + + def test_column_property(self): + assert get_columns(self.Building.id.property) == [ + self.Building.__table__.c._id + ] + + def test_column(self): + assert get_columns(self.Building.__table__.c._id) == [ + self.Building.__table__.c._id + ] + def test_declarative_class(self): assert isinstance( get_columns(self.Building),