Add hybrid_property support for generic_relationships
This commit is contained in:
@@ -1,13 +1,16 @@
|
|||||||
from collections import Iterable
|
from collections import Iterable
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
|
from sqlalchemy.orm import attributes, class_mapper
|
||||||
|
from sqlalchemy.orm import ColumnProperty
|
||||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||||
from sqlalchemy.orm.session import _state_session
|
from sqlalchemy.orm.session import _state_session
|
||||||
from sqlalchemy.orm import attributes, class_mapper
|
|
||||||
from sqlalchemy.util import set_creation_order
|
from sqlalchemy.util import set_creation_order
|
||||||
from sqlalchemy_utils.functions import table_name, identity
|
from sqlalchemy_utils.functions import table_name, identity
|
||||||
|
|
||||||
|
from .exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
|
||||||
def class_from_table_name(state, table):
|
def class_from_table_name(state, table):
|
||||||
for class_ in state.class_._decl_class_registry.values():
|
for class_ in state.class_._decl_class_registry.values():
|
||||||
@@ -31,7 +34,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
|
|
||||||
# Find class for discriminator.
|
# Find class for discriminator.
|
||||||
# TODO: Perhaps optimize with some sort of lookup?
|
# TODO: Perhaps optimize with some sort of lookup?
|
||||||
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
discriminator = self.get_state_discriminator(state)
|
||||||
target_class = class_from_table_name(state, discriminator)
|
target_class = class_from_table_name(state, discriminator)
|
||||||
|
|
||||||
if target_class is None:
|
if target_class is None:
|
||||||
@@ -39,11 +42,19 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
id = self.get_state_id(state)
|
id = self.get_state_id(state)
|
||||||
|
|
||||||
target = session.query(target_class).get(id)
|
target = session.query(target_class).get(id)
|
||||||
|
|
||||||
# Return found (or not found) target.
|
# Return found (or not found) target.
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
def get_state_discriminator(self, state):
|
||||||
|
discriminator = self.parent_token.discriminator
|
||||||
|
if isinstance(discriminator, hybrid_property):
|
||||||
|
return getattr(state.obj(), discriminator.__name__)
|
||||||
|
else:
|
||||||
|
return state.attrs[discriminator.key].value
|
||||||
|
|
||||||
def get_state_id(self, state):
|
def get_state_id(self, state):
|
||||||
# Lookup row with the discriminator and id.
|
# Lookup row with the discriminator and id.
|
||||||
return tuple(state.attrs[id.key].value for id in self.parent_token.id)
|
return tuple(state.attrs[id.key].value for id in self.parent_token.id)
|
||||||
@@ -98,10 +109,16 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
set_creation_order(self)
|
set_creation_order(self)
|
||||||
|
|
||||||
def _column_to_property(self, column):
|
def _column_to_property(self, column):
|
||||||
for name, attr in self.parent.attrs.items():
|
if isinstance(column, hybrid_property):
|
||||||
other = self.parent.columns.get(name)
|
attr_key = column.__name__
|
||||||
if other is not None and column.name == other.name:
|
for key, attr in self.parent.all_orm_descriptors.items():
|
||||||
return attr
|
if key == attr_key:
|
||||||
|
return attr
|
||||||
|
else:
|
||||||
|
for key, attr in self.parent.attrs.items():
|
||||||
|
if isinstance(attr, ColumnProperty):
|
||||||
|
if attr.columns[0].name == column.name:
|
||||||
|
return attr
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
def convert_strings(column):
|
def convert_strings(column):
|
||||||
@@ -118,6 +135,12 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
self._id_cols = [self._id_cols]
|
self._id_cols = [self._id_cols]
|
||||||
|
|
||||||
self.discriminator = self._column_to_property(self._discriminator_col)
|
self.discriminator = self._column_to_property(self._discriminator_col)
|
||||||
|
|
||||||
|
if self.discriminator is None:
|
||||||
|
raise ImproperlyConfigured(
|
||||||
|
'Could not find discriminator descriptor.'
|
||||||
|
)
|
||||||
|
|
||||||
self.id = list(map(self._column_to_property, self._id_cols))
|
self.id = list(map(self._column_to_property, self._id_cols))
|
||||||
|
|
||||||
class Comparator(PropComparator):
|
class Comparator(PropComparator):
|
||||||
|
65
tests/generic_relationship/test_hybrid_properties.py
Normal file
65
tests/generic_relationship/test_hybrid_properties.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.ext.hybrid import hybrid_property
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelationship(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class UserHistory(self.Base):
|
||||||
|
__tablename__ = 'user_history'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
transaction_id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Event(self.Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
transaction_id = sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255))
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
object = generic_relationship(
|
||||||
|
object_type, object_id
|
||||||
|
)
|
||||||
|
|
||||||
|
@hybrid_property
|
||||||
|
def object_version_type(self):
|
||||||
|
return self.object_type + '_history'
|
||||||
|
|
||||||
|
@object_version_type.expression
|
||||||
|
def object_version_type(cls):
|
||||||
|
return sa.func.concat(cls.object_type, '_history')
|
||||||
|
|
||||||
|
object_version = generic_relationship(
|
||||||
|
object_version_type, (object_id, transaction_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.UserHistory = UserHistory
|
||||||
|
self.Event = Event
|
||||||
|
|
||||||
|
def test_set_manual_and_get(self):
|
||||||
|
user = self.User(id=1)
|
||||||
|
history = self.UserHistory(id=1, transaction_id=1)
|
||||||
|
self.session.add(user)
|
||||||
|
self.session.add(history)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(transaction_id=1)
|
||||||
|
event.object_id = user.id
|
||||||
|
event.object_type = type(user).__tablename__
|
||||||
|
assert event.object is None
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user
|
||||||
|
assert event.object_version == history
|
Reference in New Issue
Block a user