diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 74e41d1..bc6caeb 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -23,27 +23,25 @@ from .orm import ( ) __all__ = ( - create_database, - create_mock_engine, - database_exists, - declarative_base, - defer_except, - drop_database, - escape_like, - getdotattr, - has_changes, - identity, - is_auto_assigned_date_column, - is_indexed_foreign_key, - mock_engine, - naturally_equivalent, - non_indexed_foreign_keys, - primary_keys, - QuerySorterException, - render_expression, - render_statement, - sort_query, - table_name, + 'create_database', + 'create_mock_engine', + 'database_exists', + 'declarative_base', + 'defer_except', + 'drop_database', + 'escape_like', + 'getdotattr', + 'has_changes', + 'identity', + 'is_auto_assigned_date_column', + 'is_indexed_foreign_key', + 'mock_engine', + 'naturally_equivalent', + 'non_indexed_foreign_keys', + 'primary_keys', + 'QuerySorterException', + 'render_expression', + 'render_statement', + 'sort_query', + 'table_name', ) - - diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 8811dd3..1a276d7 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -1,4 +1,9 @@ +try: + from collections import OrderedDict +except ImportError: + from ordereddict import OrderedDict from functools import partial +from inspect import isclass from operator import attrgetter import sqlalchemy as sa from sqlalchemy import inspect @@ -9,13 +14,19 @@ from sqlalchemy.orm.query import _ColumnEntity from sqlalchemy.orm.util import AliasedInsp -def primary_keys(class_): +def primary_keys(obj_or_class): """ - Returns all primary keys for given declarative class. + Return an OrderedDict of all primary keys for given declarative class or + object. """ - for column in class_.__table__.c: + if not isclass(obj_or_class): + obj_or_class = obj_or_class.__class__ + + columns = OrderedDict() + for key, column in sa.inspect(obj_or_class).columns.items(): if column.primary_key: - yield column + columns[key] = column + return columns def table_name(obj): @@ -325,8 +336,8 @@ def identity(obj_or_class): :param obj: SQLAlchemy declarative model object """ return tuple( - getattr(obj_or_class, column.name) - for column in primary_keys(obj_or_class) + getattr(obj_or_class, column_key) + for column_key in primary_keys(obj_or_class).keys() ) diff --git a/tests/functions/test_identity.py b/tests/functions/test_identity.py index 74094fb..f220f0d 100644 --- a/tests/functions/test_identity.py +++ b/tests/functions/test_identity.py @@ -3,15 +3,7 @@ from sqlalchemy_utils.functions import identity from tests import TestCase -class TestIdentity(TestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column(sa.Unicode(255)) - - self.Building = Building - +class IdentityTestCase(TestCase): def test_for_transient_class_without_id(self): assert identity(self.Building()) == (None, ) @@ -24,3 +16,23 @@ class TestIdentity(TestCase): def test_identity_for_class(self): assert identity(self.Building) == (self.Building.id, ) + + +class TestIdentity(IdentityTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.Building = Building + + +class TestIdentityWithColumnAlias(IdentityTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column('_id', sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.Building = Building