Add support for single table inheritance, refs #69
This commit is contained in:
@@ -7,9 +7,7 @@ from sqlalchemy.orm.attributes import (
|
|||||||
set_committed_value, InstrumentedAttribute
|
set_committed_value, InstrumentedAttribute
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
from sqlalchemy_utils.generic import (
|
from sqlalchemy_utils.generic import GenericRelationshipProperty
|
||||||
GenericRelationshipProperty, class_from_table_name
|
|
||||||
)
|
|
||||||
from sqlalchemy_utils.functions.orm import (
|
from sqlalchemy_utils.functions.orm import (
|
||||||
list_local_values,
|
list_local_values,
|
||||||
list_local_remote_exprs,
|
list_local_remote_exprs,
|
||||||
@@ -345,9 +343,7 @@ class GenericRelationshipFetcher(object):
|
|||||||
|
|
||||||
def _queries(self, state, id_dict):
|
def _queries(self, state, id_dict):
|
||||||
for discriminator, ids in six.iteritems(id_dict):
|
for discriminator, ids in six.iteritems(id_dict):
|
||||||
class_ = class_from_table_name(
|
class_ = state.class_._decl_class_registry.get(discriminator)
|
||||||
state, discriminator
|
|
||||||
)
|
|
||||||
yield self.path.session.query(
|
yield self.path.session.query(
|
||||||
class_
|
class_
|
||||||
).filter(
|
).filter(
|
||||||
|
@@ -7,19 +7,11 @@ from sqlalchemy.orm import ColumnProperty
|
|||||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||||
from sqlalchemy.orm.session import _state_session
|
from sqlalchemy.orm.session import _state_session
|
||||||
from sqlalchemy.util import set_creation_order
|
from sqlalchemy.util import set_creation_order
|
||||||
from sqlalchemy_utils.functions import table_name, identity
|
from sqlalchemy_utils.functions import identity
|
||||||
|
|
||||||
from .exceptions import ImproperlyConfigured
|
from .exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
|
|
||||||
def class_from_table_name(state, table):
|
|
||||||
for class_ in state.class_._decl_class_registry.values():
|
|
||||||
name = table_name(class_)
|
|
||||||
if name and name == table:
|
|
||||||
return class_
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||||
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
||||||
if self.key in dict_:
|
if self.key in dict_:
|
||||||
@@ -35,7 +27,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
# Find class for discriminator.
|
# Find class for discriminator.
|
||||||
# TODO: Perhaps optimize with some sort of lookup?
|
# TODO: Perhaps optimize with some sort of lookup?
|
||||||
discriminator = self.get_state_discriminator(state)
|
discriminator = self.get_state_discriminator(state)
|
||||||
target_class = class_from_table_name(state, discriminator)
|
target_class = state.class_._decl_class_registry.get(discriminator)
|
||||||
|
|
||||||
if target_class is None:
|
if target_class is None:
|
||||||
# Unknown discriminator; return nothing.
|
# Unknown discriminator; return nothing.
|
||||||
@@ -75,12 +67,13 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
else:
|
else:
|
||||||
# Get the primary key of the initiator and ensure we
|
# Get the primary key of the initiator and ensure we
|
||||||
# can support this assignment.
|
# can support this assignment.
|
||||||
mapper = class_mapper(type(initiator))
|
class_ = type(initiator)
|
||||||
|
mapper = class_mapper(class_)
|
||||||
|
|
||||||
pk = mapper.identity_key_from_instance(initiator)[1]
|
pk = mapper.identity_key_from_instance(initiator)[1]
|
||||||
|
|
||||||
# Set the identifier and the discriminator.
|
# Set the identifier and the discriminator.
|
||||||
discriminator = table_name(initiator)
|
discriminator = unicode(class_.__name__)
|
||||||
|
|
||||||
for index, id in enumerate(self.parent_token.id):
|
for index, id in enumerate(self.parent_token.id):
|
||||||
dict_[id.key] = pk[index]
|
dict_[id.key] = pk[index]
|
||||||
@@ -150,7 +143,7 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
self._parentmapper = parentmapper
|
self._parentmapper = parentmapper
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
discriminator = table_name(other)
|
discriminator = unicode(type(other).__name__)
|
||||||
q = self.property._discriminator_col == discriminator
|
q = self.property._discriminator_col == discriminator
|
||||||
other_id = identity(other)
|
other_id = identity(other)
|
||||||
for index, id in enumerate(self.property._id_cols):
|
for index, id in enumerate(self.property._id_cols):
|
||||||
@@ -161,7 +154,7 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
return ~(self == other)
|
return ~(self == other)
|
||||||
|
|
||||||
def is_type(self, other):
|
def is_type(self, other):
|
||||||
discriminator = table_name(other)
|
discriminator = unicode(other.__name__)
|
||||||
return self.property._discriminator_col == discriminator
|
return self.property._discriminator_col == discriminator
|
||||||
|
|
||||||
def instrument_class(self, mapper):
|
def instrument_class(self, mapper):
|
||||||
|
@@ -16,7 +16,7 @@ class GenericRelationshipTestCase(TestCase):
|
|||||||
|
|
||||||
event = self.Event()
|
event = self.Event()
|
||||||
event.object_id = user.id
|
event.object_id = user.id
|
||||||
event.object_type = type(user).__tablename__
|
event.object_type = unicode(type(user).__name__)
|
||||||
|
|
||||||
assert event.object is None
|
assert event.object is None
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ class GenericRelationshipTestCase(TestCase):
|
|||||||
event = self.Event(object=user)
|
event = self.Event(object=user)
|
||||||
|
|
||||||
assert event.object_id == user.id
|
assert event.object_id == user.id
|
||||||
assert event.object_type == type(user).__tablename__
|
assert event.object_type == type(user).__name__
|
||||||
|
|
||||||
self.session.add(event)
|
self.session.add(event)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
@@ -52,7 +52,7 @@ class TestGenericRelationship(GenericRelationshipTestCase):
|
|||||||
|
|
||||||
event = self.Event()
|
event = self.Event()
|
||||||
event.object_id = user.id
|
event.object_id = user.id
|
||||||
event.object_type = type(user).__tablename__
|
event.object_type = unicode(type(user).__name__)
|
||||||
event.object_code = user.code
|
event.object_code = user.code
|
||||||
|
|
||||||
assert event.object is None
|
assert event.object is None
|
||||||
|
@@ -32,11 +32,11 @@ class TestGenericRelationship(TestCase):
|
|||||||
|
|
||||||
@hybrid_property
|
@hybrid_property
|
||||||
def object_version_type(self):
|
def object_version_type(self):
|
||||||
return self.object_type + '_history'
|
return self.object_type + 'History'
|
||||||
|
|
||||||
@object_version_type.expression
|
@object_version_type.expression
|
||||||
def object_version_type(cls):
|
def object_version_type(cls):
|
||||||
return sa.func.concat(cls.object_type, '_history')
|
return sa.func.concat(cls.object_type, 'History')
|
||||||
|
|
||||||
object_version = generic_relationship(
|
object_version = generic_relationship(
|
||||||
object_version_type, (object_id, transaction_id)
|
object_version_type, (object_id, transaction_id)
|
||||||
@@ -55,7 +55,7 @@ class TestGenericRelationship(TestCase):
|
|||||||
|
|
||||||
event = self.Event(transaction_id=1)
|
event = self.Event(transaction_id=1)
|
||||||
event.object_id = user.id
|
event.object_id = user.id
|
||||||
event.object_type = type(user).__tablename__
|
event.object_type = unicode(type(user).__name__)
|
||||||
assert event.object is None
|
assert event.object is None
|
||||||
|
|
||||||
self.session.add(event)
|
self.session.add(event)
|
||||||
|
144
tests/generic_relationship/test_single_table_inheritance.py
Normal file
144
tests/generic_relationship/test_single_table_inheritance.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelationship(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Employee(self.Base):
|
||||||
|
__tablename__ = 'employee'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.String(50))
|
||||||
|
type = sa.Column(sa.String(20))
|
||||||
|
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_on': type,
|
||||||
|
'polymorphic_identity': 'employee'
|
||||||
|
}
|
||||||
|
|
||||||
|
class Manager(Employee):
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'manager'
|
||||||
|
}
|
||||||
|
|
||||||
|
class Engineer(Employee):
|
||||||
|
__mapper_args__ = {
|
||||||
|
'polymorphic_identity': 'engineer'
|
||||||
|
}
|
||||||
|
|
||||||
|
class Event(self.Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255))
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
object = generic_relationship(object_type, object_id)
|
||||||
|
|
||||||
|
self.Employee = Employee
|
||||||
|
self.Manager = Manager
|
||||||
|
self.Engineer = Engineer
|
||||||
|
self.Event = Event
|
||||||
|
|
||||||
|
def test_set_as_none(self):
|
||||||
|
event = self.Event()
|
||||||
|
event.object = None
|
||||||
|
assert event.object is None
|
||||||
|
|
||||||
|
def test_set_manual_and_get(self):
|
||||||
|
manager = self.Manager()
|
||||||
|
|
||||||
|
self.session.add(manager)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event()
|
||||||
|
event.object_id = manager.id
|
||||||
|
event.object_type = unicode(type(manager).__name__)
|
||||||
|
|
||||||
|
assert event.object is None
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == manager
|
||||||
|
|
||||||
|
def test_set_and_get(self):
|
||||||
|
manager = self.Manager()
|
||||||
|
|
||||||
|
self.session.add(manager)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=manager)
|
||||||
|
|
||||||
|
assert event.object_id == manager.id
|
||||||
|
assert event.object_type == type(manager).__name__
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == manager
|
||||||
|
|
||||||
|
def test_compare_instance(self):
|
||||||
|
manager1 = self.Manager()
|
||||||
|
manager2 = self.Manager()
|
||||||
|
|
||||||
|
self.session.add_all([manager1, manager2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=manager1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == manager1
|
||||||
|
assert event.object != manager2
|
||||||
|
|
||||||
|
def test_compare_query(self):
|
||||||
|
manager1 = self.Manager()
|
||||||
|
manager2 = self.Manager()
|
||||||
|
|
||||||
|
self.session.add_all([manager1, manager2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=manager1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter_by(object=manager1).first() is not None
|
||||||
|
assert q.filter_by(object=manager2).first() is None
|
||||||
|
assert q.filter(self.Event.object == manager2).first() is None
|
||||||
|
|
||||||
|
def test_compare_not_query(self):
|
||||||
|
manager1 = self.Manager()
|
||||||
|
manager2 = self.Manager()
|
||||||
|
|
||||||
|
self.session.add_all([manager1, manager2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=manager1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter(self.Event.object != manager2).first() is not None
|
||||||
|
|
||||||
|
def test_compare_type(self):
|
||||||
|
manager1 = self.Manager()
|
||||||
|
manager2 = self.Manager()
|
||||||
|
|
||||||
|
self.session.add_all([manager1, manager2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event1 = self.Event(object=manager1)
|
||||||
|
event2 = self.Event(object=manager2)
|
||||||
|
|
||||||
|
self.session.add_all([event1, event2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
statement = self.Event.object.is_type(self.Manager)
|
||||||
|
q = self.session.query(self.Event).filter(statement)
|
||||||
|
assert q.first() is not None
|
Reference in New Issue
Block a user