Finish generic relationship property; add comparators and further testing.

This commit is contained in:
Ryan Leckey
2013-09-09 00:38:32 -07:00
parent 684617f875
commit d94aa24110
2 changed files with 89 additions and 13 deletions

View File

@@ -2,7 +2,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session from sqlalchemy.orm.session import _state_session
from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.util import set_creation_order from sqlalchemy.util import set_creation_order
from sqlalchemy import exc as sa_exc, inspect from sqlalchemy import exc as sa_exc
from .functions import table_name
class GenericAttributeImpl(attributes.ScalarAttributeImpl): class GenericAttributeImpl(attributes.ScalarAttributeImpl):
@@ -23,7 +24,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
discriminator = state.attrs[self.parent_token.discriminator.key].value discriminator = state.attrs[self.parent_token.discriminator.key].value
target_class = None target_class = None
for class_ in state.class_._decl_class_registry.values(): for class_ in state.class_._decl_class_registry.values():
name = getattr(class_, '__tablename__', None) name = table_name(class_)
if name and name == discriminator: if name and name == discriminator:
target_class = class_ target_class = class_
@@ -46,7 +47,8 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Set us on the state. # Set us on the state.
dict_[self.key] = initiator dict_[self.key] = initiator
# Get the primary key of the initiator. # Get the primary key of the initiator and ensure we
# can support this assignment.
mapper = class_mapper(type(initiator)) mapper = class_mapper(type(initiator))
if len(mapper.primary_key) > 1: if len(mapper.primary_key) > 1:
raise sa_exc.InvalidRequestError( raise sa_exc.InvalidRequestError(
@@ -56,10 +58,11 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
pk = mapper.identity_key_from_instance(initiator)[1][0] pk = mapper.identity_key_from_instance(initiator)[1][0]
# Set the identifier and the discriminator. # Set the identifier and the discriminator.
discriminator = type(initiator).__tablename__ discriminator = table_name(initiator)
dict_[self.parent_token.id.key] = pk dict_[self.parent_token.id.key] = pk
dict_[self.parent_token.discriminator.key] = discriminator dict_[self.parent_token.discriminator.key] = discriminator
class GenericRelationshipProperty(MapperProperty): class GenericRelationshipProperty(MapperProperty):
"""A generic form of the relationship property. """A generic form of the relationship property.
@@ -73,8 +76,8 @@ class GenericRelationshipProperty(MapperProperty):
""" """
def __init__(self, discriminator, id, doc=None): def __init__(self, discriminator, id, doc=None):
self.discriminator = discriminator self._discriminator_col = discriminator
self.id = id self._id_col = id
self.doc = doc self.doc = doc
set_creation_order(self) set_creation_order(self)
@@ -87,8 +90,8 @@ class GenericRelationshipProperty(MapperProperty):
def init(self): def init(self):
# Resolve columns to attributes. # Resolve columns to attributes.
self.discriminator = self._column_to_property(self.discriminator) self.discriminator = self._column_to_property(self._discriminator_col)
self.id = self._column_to_property(self.id) self.id = self._column_to_property(self._id_col)
class Comparator(PropComparator): class Comparator(PropComparator):
@@ -96,6 +99,19 @@ class GenericRelationshipProperty(MapperProperty):
self.prop = prop self.prop = prop
self._parentmapper = parentmapper self._parentmapper = parentmapper
def __eq__(self, other):
discriminator = table_name(other)
q = self.prop._discriminator_col == discriminator
q &= self.prop._id_col == other.id
return q
def __ne__(self, other):
return ~(self == other)
def is_type(self, other):
discriminator = table_name(other)
return self.prop._discriminator_col == discriminator
def instrument_class(self, mapper): def instrument_class(self, mapper):
attributes.register_attribute( attributes.register_attribute(
mapper.class_, mapper.class_,
@@ -107,5 +123,6 @@ class GenericRelationshipProperty(MapperProperty):
parent_token=self parent_token=self
) )
def generic_relationship(*args, **kwargs): def generic_relationship(*args, **kwargs):
return GenericRelationshipProperty(*args, **kwargs) return GenericRelationshipProperty(*args, **kwargs)

View File

@@ -1,5 +1,4 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import orm
from tests import TestCase from tests import TestCase
from sqlalchemy_utils import generic_relationship from sqlalchemy_utils import generic_relationship
@@ -15,10 +14,6 @@ class TestGenericForiegnKey(TestCase):
__tablename__ = 'user' __tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID')
building = orm.relationship(Building)
class Event(self.Base): class Event(self.Base):
__tablename__ = 'event' __tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
@@ -64,3 +59,67 @@ class TestGenericForiegnKey(TestCase):
self.session.commit() self.session.commit()
assert event.object == user assert event.object == user
def test_compare_instance(self):
user1 = self.User()
user2 = self.User()
self.session.add_all([user1, user2])
self.session.commit()
event = self.Event(object=user1)
self.session.add(event)
self.session.commit()
assert event.object == user1
assert event.object != user2
def test_compare_query(self):
user1 = self.User()
user2 = self.User()
self.session.add_all([user1, user2])
self.session.commit()
event = self.Event(object=user1)
self.session.add(event)
self.session.commit()
q = self.session.query(self.Event)
assert q.filter_by(object=user1).first() is not None
assert q.filter_by(object=user2).first() is None
assert q.filter(self.Event.object == user2).first() is None
def test_compare_not_query(self):
user1 = self.User()
user2 = self.User()
self.session.add_all([user1, user2])
self.session.commit()
event = self.Event(object=user1)
self.session.add(event)
self.session.commit()
q = self.session.query(self.Event)
assert q.filter(self.Event.object != user2).first() is not None
def test_compare_type(self):
user1 = self.User()
user2 = self.User()
self.session.add_all([user1, user2])
self.session.commit()
event1 = self.Event(object=user1)
event2 = self.Event(object=user2)
self.session.add_all([event1, event2])
self.session.commit()
statement = self.Event.object.is_type(self.User)
q = self.session.query(self.Event).filter(statement)
assert q.first() is not None