From 44496e26a3fff8f3b5f25619afd2a2e1e286a27f Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 5 Mar 2014 11:09:59 +0200 Subject: [PATCH] Add string arg support for generic relationship --- docs/generic_relationship.rst | 32 +++++++++++++- sqlalchemy_utils/generic.py | 14 ++++++ tests/test_generic.py | 80 ++++++++++++++++++++++++----------- 3 files changed, 101 insertions(+), 25 deletions(-) diff --git a/docs/generic_relationship.rst b/docs/generic_relationship.rst index 9c2c787..d381168 100644 --- a/docs/generic_relationship.rst +++ b/docs/generic_relationship.rst @@ -47,4 +47,34 @@ Generic relationship is a form of relationship that supports creating a 1 to man # Find any events that are bound to users. session.query(Event).filter(Event.object.is_type(User)).all() -.. _colour: https://github.com/vaab/colour + +Using generic_relationship with abstract base classes +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Generic relationships also allows using string arguments. When using generic_relationship with abstract base classes you need to set up the relationship using declared_attr decorator and string arguments. + + +:: + + + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class EventBase(self.Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index e4aeb49..6fe4499 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -1,3 +1,5 @@ +import six + from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.orm import attributes, class_mapper @@ -86,6 +88,8 @@ class GenericRelationshipProperty(MapperProperty): def __init__(self, discriminator, id, doc=None): self._discriminator_col = discriminator self._id_col = id + self._id = None + self._discriminator = None self.doc = doc set_creation_order(self) @@ -98,6 +102,16 @@ class GenericRelationshipProperty(MapperProperty): def init(self): # Resolve columns to attributes. + if isinstance(self._discriminator_col, six.string_types): + self._discriminator_col = self.parent.columns[ + self._discriminator_col + ] + + if isinstance(self._id_col, six.string_types): + self._id_col = self.parent.columns[ + self._id_col + ] + self.discriminator = self._column_to_property(self._discriminator_col) self.id = self._column_to_property(self._id_col) diff --git a/tests/test_generic.py b/tests/test_generic.py index af94265..0ce6671 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -2,35 +2,14 @@ from __future__ import unicode_literals import sqlalchemy as sa from tests import TestCase from sqlalchemy_utils import generic_relationship +from sqlalchemy.ext.declarative import declared_attr -class TestGenericForiegnKey(TestCase): - - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - object_type = sa.Column(sa.Unicode(255), name="objectType") - object_id = sa.Column(sa.Integer, nullable=False) - - object = generic_relationship(object_type, object_id) - - self.Building = Building - self.User = User - self.Event = Event - +class GenericRelationshipTestCase(TestCase): def test_set_as_none(self): event = self.Event() event.object = None + assert event.object is None def test_set_manual_and_get(self): user = self.User() @@ -128,3 +107,56 @@ class TestGenericForiegnKey(TestCase): statement = self.Event.object.is_type(self.User) q = self.session.query(self.Event).filter(statement) assert q.first() is not None + + +class TestGenericRelationship(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event + + +class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class EventBase(self.Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + self.User = User + self.Event = Event