diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index 120f29f..6092984 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -1,13 +1,16 @@ from collections import Iterable import six - +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import attributes, class_mapper +from sqlalchemy.orm import ColumnProperty from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session -from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.util import set_creation_order from sqlalchemy_utils.functions import table_name, identity +from .exceptions import ImproperlyConfigured + def class_from_table_name(state, table): for class_ in state.class_._decl_class_registry.values(): @@ -31,7 +34,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Find class for discriminator. # TODO: Perhaps optimize with some sort of lookup? - discriminator = state.attrs[self.parent_token.discriminator.key].value + discriminator = self.get_state_discriminator(state) target_class = class_from_table_name(state, discriminator) if target_class is None: @@ -39,11 +42,19 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): return None id = self.get_state_id(state) + target = session.query(target_class).get(id) # Return found (or not found) target. return target + def get_state_discriminator(self, state): + discriminator = self.parent_token.discriminator + if isinstance(discriminator, hybrid_property): + return getattr(state.obj(), discriminator.__name__) + else: + return state.attrs[discriminator.key].value + def get_state_id(self, state): # Lookup row with the discriminator and id. return tuple(state.attrs[id.key].value for id in self.parent_token.id) @@ -98,10 +109,16 @@ class GenericRelationshipProperty(MapperProperty): set_creation_order(self) def _column_to_property(self, column): - for name, attr in self.parent.attrs.items(): - other = self.parent.columns.get(name) - if other is not None and column.name == other.name: - return attr + if isinstance(column, hybrid_property): + attr_key = column.__name__ + for key, attr in self.parent.all_orm_descriptors.items(): + if key == attr_key: + return attr + else: + for key, attr in self.parent.attrs.items(): + if isinstance(attr, ColumnProperty): + if attr.columns[0].name == column.name: + return attr def init(self): def convert_strings(column): @@ -118,6 +135,12 @@ class GenericRelationshipProperty(MapperProperty): self._id_cols = [self._id_cols] self.discriminator = self._column_to_property(self._discriminator_col) + + if self.discriminator is None: + raise ImproperlyConfigured( + 'Could not find discriminator descriptor.' + ) + self.id = list(map(self._column_to_property, self._id_cols)) class Comparator(PropComparator): diff --git a/tests/generic_relationship/test_hybrid_properties.py b/tests/generic_relationship/test_hybrid_properties.py new file mode 100644 index 0000000..c6a8ce0 --- /dev/null +++ b/tests/generic_relationship/test_hybrid_properties.py @@ -0,0 +1,65 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy_utils import generic_relationship +from tests import TestCase + + +class TestGenericRelationship(TestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class UserHistory(self.Base): + __tablename__ = 'user_history' + id = sa.Column(sa.Integer, primary_key=True) + + transaction_id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + transaction_id = sa.Column(sa.Integer) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship( + object_type, object_id + ) + + @hybrid_property + def object_version_type(self): + return self.object_type + '_history' + + @object_version_type.expression + def object_version_type(cls): + return sa.func.concat(cls.object_type, '_history') + + object_version = generic_relationship( + object_version_type, (object_id, transaction_id) + ) + + self.User = User + self.UserHistory = UserHistory + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User(id=1) + history = self.UserHistory(id=1, transaction_id=1) + self.session.add(user) + self.session.add(history) + self.session.commit() + + event = self.Event(transaction_id=1) + event.object_id = user.id + event.object_type = type(user).__tablename__ + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + assert event.object_version == history