Add support for pk aliases

This commit is contained in:
Konsta Vesterinen
2014-04-21 10:56:04 +03:00
parent 13964b3759
commit 73d4dbb2d1
3 changed files with 59 additions and 38 deletions

View File

@@ -23,27 +23,25 @@ from .orm import (
)
__all__ = (
create_database,
create_mock_engine,
database_exists,
declarative_base,
defer_except,
drop_database,
escape_like,
getdotattr,
has_changes,
identity,
is_auto_assigned_date_column,
is_indexed_foreign_key,
mock_engine,
naturally_equivalent,
non_indexed_foreign_keys,
primary_keys,
QuerySorterException,
render_expression,
render_statement,
sort_query,
table_name,
'create_database',
'create_mock_engine',
'database_exists',
'declarative_base',
'defer_except',
'drop_database',
'escape_like',
'getdotattr',
'has_changes',
'identity',
'is_auto_assigned_date_column',
'is_indexed_foreign_key',
'mock_engine',
'naturally_equivalent',
'non_indexed_foreign_keys',
'primary_keys',
'QuerySorterException',
'render_expression',
'render_statement',
'sort_query',
'table_name',
)

View File

@@ -1,4 +1,9 @@
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict
from functools import partial
from inspect import isclass
from operator import attrgetter
import sqlalchemy as sa
from sqlalchemy import inspect
@@ -9,13 +14,19 @@ from sqlalchemy.orm.query import _ColumnEntity
from sqlalchemy.orm.util import AliasedInsp
def primary_keys(class_):
def primary_keys(obj_or_class):
"""
Returns all primary keys for given declarative class.
Return an OrderedDict of all primary keys for given declarative class or
object.
"""
for column in class_.__table__.c:
if not isclass(obj_or_class):
obj_or_class = obj_or_class.__class__
columns = OrderedDict()
for key, column in sa.inspect(obj_or_class).columns.items():
if column.primary_key:
yield column
columns[key] = column
return columns
def table_name(obj):
@@ -325,8 +336,8 @@ def identity(obj_or_class):
:param obj: SQLAlchemy declarative model object
"""
return tuple(
getattr(obj_or_class, column.name)
for column in primary_keys(obj_or_class)
getattr(obj_or_class, column_key)
for column_key in primary_keys(obj_or_class).keys()
)

View File

@@ -3,15 +3,7 @@ from sqlalchemy_utils.functions import identity
from tests import TestCase
class TestIdentity(TestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
self.Building = Building
class IdentityTestCase(TestCase):
def test_for_transient_class_without_id(self):
assert identity(self.Building()) == (None, )
@@ -24,3 +16,23 @@ class TestIdentity(TestCase):
def test_identity_for_class(self):
assert identity(self.Building) == (self.Building.id, )
class TestIdentity(IdentityTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
self.Building = Building
class TestIdentityWithColumnAlias(IdentityTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column('_id', sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
self.Building = Building