From a6169a553c30fe5f6c1e5344070db47a1c106f22 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sun, 8 Sep 2013 11:41:43 -0700 Subject: [PATCH] Initial generic relationship implementation. --- sqlalchemy_utils/__init__.py | 1 + sqlalchemy_utils/generic.py | 111 +++++++++++++++++++++++++++++++++++ tests/test_generic.py | 66 +++++++++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 sqlalchemy_utils/generic.py create mode 100644 tests/test_generic.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 2917856..28c015b 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -11,6 +11,7 @@ from .functions import ( ) from .listeners import coercion_listener from .merge import merge, Merger +from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py new file mode 100644 index 0000000..ce57424 --- /dev/null +++ b/sqlalchemy_utils/generic.py @@ -0,0 +1,111 @@ +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.orm import attributes, class_mapper +from sqlalchemy.util import set_creation_order +from sqlalchemy import exc as sa_exc, inspect + + +class GenericAttributeImpl(attributes.ScalarAttributeImpl): + + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = state.attrs[self.parent_token.discriminator.key].value + target_class = None + for class_ in state.class_._decl_class_registry.values(): + name = getattr(class_, '__tablename__', None) + if name and name == discriminator: + target_class = class_ + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + # Lookup row with the discriminator and id. + id = state.attrs[self.parent_token.id.key].value + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def set(self, state, dict_, initiator, + passive=attributes.PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + # Get the primary key of the initiator. + mapper = class_mapper(type(initiator)) + if len(mapper.primary_key) > 1: + raise sa_exc.InvalidRequestError( + 'Generic relationships against tables with composite ' + 'primary keys are not supported.') + + pk = mapper.identity_key_from_instance(initiator)[1][0] + + # Set the identifier and the discriminator. + discriminator = type(initiator).__tablename__ + dict_[self.parent_token.id.key] = pk + dict_[self.parent_token.discriminator.key] = discriminator + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a descriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + self.discriminator = discriminator + self.id = id + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + for name, attr in self.parent.attrs.items(): + other = self.parent.columns.get(name) + if other is not None and column.name == other.name: + return attr + + def init(self): + # Resolve columns to attributes. + self.discriminator = self._column_to_property(self.discriminator) + self.id = self._column_to_property(self.id) + + class Comparator(PropComparator): + + def __init__(self, prop, parentmapper): + self.prop = prop + self._parentmapper = parentmapper + + def instrument_class(self, mapper): + attributes.register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000..3738be7 --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,66 @@ +import sqlalchemy as sa +from sqlalchemy import orm +from tests import TestCase +from sqlalchemy_utils import generic_relationship + + +class TestGenericForiegnKey(TestCase): + + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID') + + building = orm.relationship(Building) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = type(user).__tablename__ + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_set_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event(object=user) + + assert event.object_id == user.id + assert event.object_type == type(user).__tablename__ + + self.session.add(event) + self.session.commit() + + assert event.object == user