Add support for pk aliases
This commit is contained in:
@@ -23,27 +23,25 @@ from .orm import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
create_database,
|
'create_database',
|
||||||
create_mock_engine,
|
'create_mock_engine',
|
||||||
database_exists,
|
'database_exists',
|
||||||
declarative_base,
|
'declarative_base',
|
||||||
defer_except,
|
'defer_except',
|
||||||
drop_database,
|
'drop_database',
|
||||||
escape_like,
|
'escape_like',
|
||||||
getdotattr,
|
'getdotattr',
|
||||||
has_changes,
|
'has_changes',
|
||||||
identity,
|
'identity',
|
||||||
is_auto_assigned_date_column,
|
'is_auto_assigned_date_column',
|
||||||
is_indexed_foreign_key,
|
'is_indexed_foreign_key',
|
||||||
mock_engine,
|
'mock_engine',
|
||||||
naturally_equivalent,
|
'naturally_equivalent',
|
||||||
non_indexed_foreign_keys,
|
'non_indexed_foreign_keys',
|
||||||
primary_keys,
|
'primary_keys',
|
||||||
QuerySorterException,
|
'QuerySorterException',
|
||||||
render_expression,
|
'render_expression',
|
||||||
render_statement,
|
'render_statement',
|
||||||
sort_query,
|
'sort_query',
|
||||||
table_name,
|
'table_name',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -1,4 +1,9 @@
|
|||||||
|
try:
|
||||||
|
from collections import OrderedDict
|
||||||
|
except ImportError:
|
||||||
|
from ordereddict import OrderedDict
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from inspect import isclass
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
@@ -9,13 +14,19 @@ from sqlalchemy.orm.query import _ColumnEntity
|
|||||||
from sqlalchemy.orm.util import AliasedInsp
|
from sqlalchemy.orm.util import AliasedInsp
|
||||||
|
|
||||||
|
|
||||||
def primary_keys(class_):
|
def primary_keys(obj_or_class):
|
||||||
"""
|
"""
|
||||||
Returns all primary keys for given declarative class.
|
Return an OrderedDict of all primary keys for given declarative class or
|
||||||
|
object.
|
||||||
"""
|
"""
|
||||||
for column in class_.__table__.c:
|
if not isclass(obj_or_class):
|
||||||
|
obj_or_class = obj_or_class.__class__
|
||||||
|
|
||||||
|
columns = OrderedDict()
|
||||||
|
for key, column in sa.inspect(obj_or_class).columns.items():
|
||||||
if column.primary_key:
|
if column.primary_key:
|
||||||
yield column
|
columns[key] = column
|
||||||
|
return columns
|
||||||
|
|
||||||
|
|
||||||
def table_name(obj):
|
def table_name(obj):
|
||||||
@@ -325,8 +336,8 @@ def identity(obj_or_class):
|
|||||||
:param obj: SQLAlchemy declarative model object
|
:param obj: SQLAlchemy declarative model object
|
||||||
"""
|
"""
|
||||||
return tuple(
|
return tuple(
|
||||||
getattr(obj_or_class, column.name)
|
getattr(obj_or_class, column_key)
|
||||||
for column in primary_keys(obj_or_class)
|
for column_key in primary_keys(obj_or_class).keys()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -3,15 +3,7 @@ from sqlalchemy_utils.functions import identity
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestIdentity(TestCase):
|
class IdentityTestCase(TestCase):
|
||||||
def create_models(self):
|
|
||||||
class Building(self.Base):
|
|
||||||
__tablename__ = 'building'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
name = sa.Column(sa.Unicode(255))
|
|
||||||
|
|
||||||
self.Building = Building
|
|
||||||
|
|
||||||
def test_for_transient_class_without_id(self):
|
def test_for_transient_class_without_id(self):
|
||||||
assert identity(self.Building()) == (None, )
|
assert identity(self.Building()) == (None, )
|
||||||
|
|
||||||
@@ -24,3 +16,23 @@ class TestIdentity(TestCase):
|
|||||||
|
|
||||||
def test_identity_for_class(self):
|
def test_identity_for_class(self):
|
||||||
assert identity(self.Building) == (self.Building.id, )
|
assert identity(self.Building) == (self.Building.id, )
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdentity(IdentityTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdentityWithColumnAlias(IdentityTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column('_id', sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
Reference in New Issue
Block a user