Add hybrid_property support for generic_relationships

This commit is contained in:
Konsta Vesterinen
2014-03-05 14:46:22 +02:00
parent 29b831b97a
commit a529fb5cba
2 changed files with 95 additions and 7 deletions

View File

@@ -1,13 +1,16 @@
from collections import Iterable from collections import Iterable
import six import six
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session from sqlalchemy.orm.session import _state_session
from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.util import set_creation_order from sqlalchemy.util import set_creation_order
from sqlalchemy_utils.functions import table_name, identity from sqlalchemy_utils.functions import table_name, identity
from .exceptions import ImproperlyConfigured
def class_from_table_name(state, table): def class_from_table_name(state, table):
for class_ in state.class_._decl_class_registry.values(): for class_ in state.class_._decl_class_registry.values():
@@ -31,7 +34,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Find class for discriminator. # Find class for discriminator.
# TODO: Perhaps optimize with some sort of lookup? # TODO: Perhaps optimize with some sort of lookup?
discriminator = state.attrs[self.parent_token.discriminator.key].value discriminator = self.get_state_discriminator(state)
target_class = class_from_table_name(state, discriminator) target_class = class_from_table_name(state, discriminator)
if target_class is None: if target_class is None:
@@ -39,11 +42,19 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
return None return None
id = self.get_state_id(state) id = self.get_state_id(state)
target = session.query(target_class).get(id) target = session.query(target_class).get(id)
# Return found (or not found) target. # Return found (or not found) target.
return target return target
def get_state_discriminator(self, state):
discriminator = self.parent_token.discriminator
if isinstance(discriminator, hybrid_property):
return getattr(state.obj(), discriminator.__name__)
else:
return state.attrs[discriminator.key].value
def get_state_id(self, state): def get_state_id(self, state):
# Lookup row with the discriminator and id. # Lookup row with the discriminator and id.
return tuple(state.attrs[id.key].value for id in self.parent_token.id) return tuple(state.attrs[id.key].value for id in self.parent_token.id)
@@ -98,10 +109,16 @@ class GenericRelationshipProperty(MapperProperty):
set_creation_order(self) set_creation_order(self)
def _column_to_property(self, column): def _column_to_property(self, column):
for name, attr in self.parent.attrs.items(): if isinstance(column, hybrid_property):
other = self.parent.columns.get(name) attr_key = column.__name__
if other is not None and column.name == other.name: for key, attr in self.parent.all_orm_descriptors.items():
return attr if key == attr_key:
return attr
else:
for key, attr in self.parent.attrs.items():
if isinstance(attr, ColumnProperty):
if attr.columns[0].name == column.name:
return attr
def init(self): def init(self):
def convert_strings(column): def convert_strings(column):
@@ -118,6 +135,12 @@ class GenericRelationshipProperty(MapperProperty):
self._id_cols = [self._id_cols] self._id_cols = [self._id_cols]
self.discriminator = self._column_to_property(self._discriminator_col) self.discriminator = self._column_to_property(self._discriminator_col)
if self.discriminator is None:
raise ImproperlyConfigured(
'Could not find discriminator descriptor.'
)
self.id = list(map(self._column_to_property, self._id_cols)) self.id = list(map(self._column_to_property, self._id_cols))
class Comparator(PropComparator): class Comparator(PropComparator):

View File

@@ -0,0 +1,65 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase):
def create_models(self):
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class UserHistory(self.Base):
__tablename__ = 'user_history'
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
transaction_id = sa.Column(sa.Integer)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, object_id
)
@hybrid_property
def object_version_type(self):
return self.object_type + '_history'
@object_version_type.expression
def object_version_type(cls):
return sa.func.concat(cls.object_type, '_history')
object_version = generic_relationship(
object_version_type, (object_id, transaction_id)
)
self.User = User
self.UserHistory = UserHistory
self.Event = Event
def test_set_manual_and_get(self):
user = self.User(id=1)
history = self.UserHistory(id=1, transaction_id=1)
self.session.add(user)
self.session.add(history)
self.session.commit()
event = self.Event(transaction_id=1)
event.object_id = user.id
event.object_type = type(user).__tablename__
assert event.object is None
self.session.add(event)
self.session.commit()
assert event.object == user
assert event.object_version == history