diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index ce57424..676cdb5 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -2,7 +2,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.util import set_creation_order -from sqlalchemy import exc as sa_exc, inspect +from sqlalchemy import exc as sa_exc +from .functions import table_name class GenericAttributeImpl(attributes.ScalarAttributeImpl): @@ -23,7 +24,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): discriminator = state.attrs[self.parent_token.discriminator.key].value target_class = None for class_ in state.class_._decl_class_registry.values(): - name = getattr(class_, '__tablename__', None) + name = table_name(class_) if name and name == discriminator: target_class = class_ @@ -46,7 +47,8 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Set us on the state. dict_[self.key] = initiator - # Get the primary key of the initiator. + # Get the primary key of the initiator and ensure we + # can support this assignment. mapper = class_mapper(type(initiator)) if len(mapper.primary_key) > 1: raise sa_exc.InvalidRequestError( @@ -56,10 +58,11 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): pk = mapper.identity_key_from_instance(initiator)[1][0] # Set the identifier and the discriminator. - discriminator = type(initiator).__tablename__ + discriminator = table_name(initiator) dict_[self.parent_token.id.key] = pk dict_[self.parent_token.discriminator.key] = discriminator + class GenericRelationshipProperty(MapperProperty): """A generic form of the relationship property. @@ -73,8 +76,8 @@ class GenericRelationshipProperty(MapperProperty): """ def __init__(self, discriminator, id, doc=None): - self.discriminator = discriminator - self.id = id + self._discriminator_col = discriminator + self._id_col = id self.doc = doc set_creation_order(self) @@ -87,8 +90,8 @@ class GenericRelationshipProperty(MapperProperty): def init(self): # Resolve columns to attributes. - self.discriminator = self._column_to_property(self.discriminator) - self.id = self._column_to_property(self.id) + self.discriminator = self._column_to_property(self._discriminator_col) + self.id = self._column_to_property(self._id_col) class Comparator(PropComparator): @@ -96,6 +99,19 @@ class GenericRelationshipProperty(MapperProperty): self.prop = prop self._parentmapper = parentmapper + def __eq__(self, other): + discriminator = table_name(other) + q = self.prop._discriminator_col == discriminator + q &= self.prop._id_col == other.id + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + discriminator = table_name(other) + return self.prop._discriminator_col == discriminator + def instrument_class(self, mapper): attributes.register_attribute( mapper.class_, @@ -107,5 +123,6 @@ class GenericRelationshipProperty(MapperProperty): parent_token=self ) + def generic_relationship(*args, **kwargs): return GenericRelationshipProperty(*args, **kwargs) diff --git a/tests/test_generic.py b/tests/test_generic.py index 3738be7..6443076 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,5 +1,4 @@ import sqlalchemy as sa -from sqlalchemy import orm from tests import TestCase from sqlalchemy_utils import generic_relationship @@ -15,10 +14,6 @@ class TestGenericForiegnKey(TestCase): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) - building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID') - - building = orm.relationship(Building) - class Event(self.Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) @@ -64,3 +59,67 @@ class TestGenericForiegnKey(TestCase): self.session.commit() assert event.object == user + + def test_compare_instance(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + assert event.object == user1 + assert event.object != user2 + + def test_compare_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=user1).first() is not None + assert q.filter_by(object=user2).first() is None + assert q.filter(self.Event.object == user2).first() is None + + def test_compare_not_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != user2).first() is not None + + def test_compare_type(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event1 = self.Event(object=user1) + event2 = self.Event(object=user2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.User) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None