diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 52c2af7..8e49c95 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -4,6 +4,7 @@ from .functions import ( batch_fetch, defer_except, escape_like, + identity, primary_keys, render_statement, render_expression, @@ -55,6 +56,7 @@ __all__ = ( escape_like, generates, generic_relationship, + identity, instrumented_list, merge, primary_keys, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 1c58c1b..89d4d38 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -201,16 +201,41 @@ def identity(obj): always returns the identity even if object is still in transient state ( new object that is not yet persisted into database). + :: + + from sqlalchemy import inspect + from sqlalchemy_utils import identity + + + user = User(name=u'John Matrix') + session.add(user) + identity(user) # None + inspect(user).identity # None + + session.flush() # User now has id but is still in transient state + + identity(user) # (1,) + inspect(user).identity # None + + session.commit() + + identity(user) # (1,) + inspect(user).identity # (1, ) + + + .. versionadded: 0.21.0 + :param obj: SQLAlchemy declarative model object """ id_ = [] - for attr in obj._sa_class_manager.values(): - prop = attr.property - if isinstance(prop, sa.orm.ColumnProperty): - column = prop.columns[0] - if column.primary_key: - id_.append(getattr(obj, column.name)) - return tuple(id_) + for column in sa.inspect(obj.__class__).columns: + if column.primary_key: + id_.append(getattr(obj, column.name)) + + if all(value is None for value in id_): + return None + else: + return tuple(id_) def naturally_equivalent(obj, obj2): diff --git a/tests/functions/test_identity.py b/tests/functions/test_identity.py index 202bdfe..66bcb90 100644 --- a/tests/functions/test_identity.py +++ b/tests/functions/test_identity.py @@ -13,4 +13,11 @@ class TestIdentity(TestCase): self.Building = Building def test_for_transient_class_without_id(self): - assert identity(self.Building()) == (None,) + assert identity(self.Building()) is None + + def test_for_transient_class_with_id(self): + building = self.Building(name=u'Some building') + self.session.add(building) + self.session.flush() + + assert identity(building) == (building.id, )