Finish generic relationship property; add comparators and further testing.
This commit is contained in:
@@ -2,7 +2,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
|||||||
from sqlalchemy.orm.session import _state_session
|
from sqlalchemy.orm.session import _state_session
|
||||||
from sqlalchemy.orm import attributes, class_mapper
|
from sqlalchemy.orm import attributes, class_mapper
|
||||||
from sqlalchemy.util import set_creation_order
|
from sqlalchemy.util import set_creation_order
|
||||||
from sqlalchemy import exc as sa_exc, inspect
|
from sqlalchemy import exc as sa_exc
|
||||||
|
from .functions import table_name
|
||||||
|
|
||||||
|
|
||||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||||
@@ -23,7 +24,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
||||||
target_class = None
|
target_class = None
|
||||||
for class_ in state.class_._decl_class_registry.values():
|
for class_ in state.class_._decl_class_registry.values():
|
||||||
name = getattr(class_, '__tablename__', None)
|
name = table_name(class_)
|
||||||
if name and name == discriminator:
|
if name and name == discriminator:
|
||||||
target_class = class_
|
target_class = class_
|
||||||
|
|
||||||
@@ -46,7 +47,8 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
# Set us on the state.
|
# Set us on the state.
|
||||||
dict_[self.key] = initiator
|
dict_[self.key] = initiator
|
||||||
|
|
||||||
# Get the primary key of the initiator.
|
# Get the primary key of the initiator and ensure we
|
||||||
|
# can support this assignment.
|
||||||
mapper = class_mapper(type(initiator))
|
mapper = class_mapper(type(initiator))
|
||||||
if len(mapper.primary_key) > 1:
|
if len(mapper.primary_key) > 1:
|
||||||
raise sa_exc.InvalidRequestError(
|
raise sa_exc.InvalidRequestError(
|
||||||
@@ -56,10 +58,11 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
||||||
|
|
||||||
# Set the identifier and the discriminator.
|
# Set the identifier and the discriminator.
|
||||||
discriminator = type(initiator).__tablename__
|
discriminator = table_name(initiator)
|
||||||
dict_[self.parent_token.id.key] = pk
|
dict_[self.parent_token.id.key] = pk
|
||||||
dict_[self.parent_token.discriminator.key] = discriminator
|
dict_[self.parent_token.discriminator.key] = discriminator
|
||||||
|
|
||||||
|
|
||||||
class GenericRelationshipProperty(MapperProperty):
|
class GenericRelationshipProperty(MapperProperty):
|
||||||
"""A generic form of the relationship property.
|
"""A generic form of the relationship property.
|
||||||
|
|
||||||
@@ -73,8 +76,8 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, discriminator, id, doc=None):
|
def __init__(self, discriminator, id, doc=None):
|
||||||
self.discriminator = discriminator
|
self._discriminator_col = discriminator
|
||||||
self.id = id
|
self._id_col = id
|
||||||
self.doc = doc
|
self.doc = doc
|
||||||
|
|
||||||
set_creation_order(self)
|
set_creation_order(self)
|
||||||
@@ -87,8 +90,8 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
# Resolve columns to attributes.
|
# Resolve columns to attributes.
|
||||||
self.discriminator = self._column_to_property(self.discriminator)
|
self.discriminator = self._column_to_property(self._discriminator_col)
|
||||||
self.id = self._column_to_property(self.id)
|
self.id = self._column_to_property(self._id_col)
|
||||||
|
|
||||||
class Comparator(PropComparator):
|
class Comparator(PropComparator):
|
||||||
|
|
||||||
@@ -96,6 +99,19 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
self.prop = prop
|
self.prop = prop
|
||||||
self._parentmapper = parentmapper
|
self._parentmapper = parentmapper
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
discriminator = table_name(other)
|
||||||
|
q = self.prop._discriminator_col == discriminator
|
||||||
|
q &= self.prop._id_col == other.id
|
||||||
|
return q
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return ~(self == other)
|
||||||
|
|
||||||
|
def is_type(self, other):
|
||||||
|
discriminator = table_name(other)
|
||||||
|
return self.prop._discriminator_col == discriminator
|
||||||
|
|
||||||
def instrument_class(self, mapper):
|
def instrument_class(self, mapper):
|
||||||
attributes.register_attribute(
|
attributes.register_attribute(
|
||||||
mapper.class_,
|
mapper.class_,
|
||||||
@@ -107,5 +123,6 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
parent_token=self
|
parent_token=self
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generic_relationship(*args, **kwargs):
|
def generic_relationship(*args, **kwargs):
|
||||||
return GenericRelationshipProperty(*args, **kwargs)
|
return GenericRelationshipProperty(*args, **kwargs)
|
||||||
|
@@ -1,5 +1,4 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import orm
|
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
from sqlalchemy_utils import generic_relationship
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
|
||||||
@@ -15,10 +14,6 @@ class TestGenericForiegnKey(TestCase):
|
|||||||
__tablename__ = 'user'
|
__tablename__ = 'user'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID')
|
|
||||||
|
|
||||||
building = orm.relationship(Building)
|
|
||||||
|
|
||||||
class Event(self.Base):
|
class Event(self.Base):
|
||||||
__tablename__ = 'event'
|
__tablename__ = 'event'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
@@ -64,3 +59,67 @@ class TestGenericForiegnKey(TestCase):
|
|||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
assert event.object == user
|
assert event.object == user
|
||||||
|
|
||||||
|
def test_compare_instance(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user1
|
||||||
|
assert event.object != user2
|
||||||
|
|
||||||
|
def test_compare_query(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter_by(object=user1).first() is not None
|
||||||
|
assert q.filter_by(object=user2).first() is None
|
||||||
|
assert q.filter(self.Event.object == user2).first() is None
|
||||||
|
|
||||||
|
def test_compare_not_query(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter(self.Event.object != user2).first() is not None
|
||||||
|
|
||||||
|
def test_compare_type(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event1 = self.Event(object=user1)
|
||||||
|
event2 = self.Event(object=user2)
|
||||||
|
|
||||||
|
self.session.add_all([event1, event2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
statement = self.Event.object.is_type(self.User)
|
||||||
|
q = self.session.query(self.Event).filter(statement)
|
||||||
|
assert q.first() is not None
|
||||||
|
Reference in New Issue
Block a user