Add support for single table inheritance, refs #69
This commit is contained in:
@@ -7,9 +7,7 @@ from sqlalchemy.orm.attributes import (
|
||||
set_committed_value, InstrumentedAttribute
|
||||
)
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy_utils.generic import (
|
||||
GenericRelationshipProperty, class_from_table_name
|
||||
)
|
||||
from sqlalchemy_utils.generic import GenericRelationshipProperty
|
||||
from sqlalchemy_utils.functions.orm import (
|
||||
list_local_values,
|
||||
list_local_remote_exprs,
|
||||
@@ -345,9 +343,7 @@ class GenericRelationshipFetcher(object):
|
||||
|
||||
def _queries(self, state, id_dict):
|
||||
for discriminator, ids in six.iteritems(id_dict):
|
||||
class_ = class_from_table_name(
|
||||
state, discriminator
|
||||
)
|
||||
class_ = state.class_._decl_class_registry.get(discriminator)
|
||||
yield self.path.session.query(
|
||||
class_
|
||||
).filter(
|
||||
|
@@ -7,19 +7,11 @@ from sqlalchemy.orm import ColumnProperty
|
||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||
from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.util import set_creation_order
|
||||
from sqlalchemy_utils.functions import table_name, identity
|
||||
from sqlalchemy_utils.functions import identity
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
|
||||
|
||||
def class_from_table_name(state, table):
|
||||
for class_ in state.class_._decl_class_registry.values():
|
||||
name = table_name(class_)
|
||||
if name and name == table:
|
||||
return class_
|
||||
return None
|
||||
|
||||
|
||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
||||
if self.key in dict_:
|
||||
@@ -35,7 +27,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
# Find class for discriminator.
|
||||
# TODO: Perhaps optimize with some sort of lookup?
|
||||
discriminator = self.get_state_discriminator(state)
|
||||
target_class = class_from_table_name(state, discriminator)
|
||||
target_class = state.class_._decl_class_registry.get(discriminator)
|
||||
|
||||
if target_class is None:
|
||||
# Unknown discriminator; return nothing.
|
||||
@@ -75,12 +67,13 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
else:
|
||||
# Get the primary key of the initiator and ensure we
|
||||
# can support this assignment.
|
||||
mapper = class_mapper(type(initiator))
|
||||
class_ = type(initiator)
|
||||
mapper = class_mapper(class_)
|
||||
|
||||
pk = mapper.identity_key_from_instance(initiator)[1]
|
||||
|
||||
# Set the identifier and the discriminator.
|
||||
discriminator = table_name(initiator)
|
||||
discriminator = unicode(class_.__name__)
|
||||
|
||||
for index, id in enumerate(self.parent_token.id):
|
||||
dict_[id.key] = pk[index]
|
||||
@@ -150,7 +143,7 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
self._parentmapper = parentmapper
|
||||
|
||||
def __eq__(self, other):
|
||||
discriminator = table_name(other)
|
||||
discriminator = unicode(type(other).__name__)
|
||||
q = self.property._discriminator_col == discriminator
|
||||
other_id = identity(other)
|
||||
for index, id in enumerate(self.property._id_cols):
|
||||
@@ -161,7 +154,7 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
return ~(self == other)
|
||||
|
||||
def is_type(self, other):
|
||||
discriminator = table_name(other)
|
||||
discriminator = unicode(other.__name__)
|
||||
return self.property._discriminator_col == discriminator
|
||||
|
||||
def instrument_class(self, mapper):
|
||||
|
@@ -16,7 +16,7 @@ class GenericRelationshipTestCase(TestCase):
|
||||
|
||||
event = self.Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = type(user).__tablename__
|
||||
event.object_type = unicode(type(user).__name__)
|
||||
|
||||
assert event.object is None
|
||||
|
||||
@@ -34,7 +34,7 @@ class GenericRelationshipTestCase(TestCase):
|
||||
event = self.Event(object=user)
|
||||
|
||||
assert event.object_id == user.id
|
||||
assert event.object_type == type(user).__tablename__
|
||||
assert event.object_type == type(user).__name__
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
@@ -52,7 +52,7 @@ class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
|
||||
event = self.Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = type(user).__tablename__
|
||||
event.object_type = unicode(type(user).__name__)
|
||||
event.object_code = user.code
|
||||
|
||||
assert event.object is None
|
||||
|
@@ -32,11 +32,11 @@ class TestGenericRelationship(TestCase):
|
||||
|
||||
@hybrid_property
|
||||
def object_version_type(self):
|
||||
return self.object_type + '_history'
|
||||
return self.object_type + 'History'
|
||||
|
||||
@object_version_type.expression
|
||||
def object_version_type(cls):
|
||||
return sa.func.concat(cls.object_type, '_history')
|
||||
return sa.func.concat(cls.object_type, 'History')
|
||||
|
||||
object_version = generic_relationship(
|
||||
object_version_type, (object_id, transaction_id)
|
||||
@@ -55,7 +55,7 @@ class TestGenericRelationship(TestCase):
|
||||
|
||||
event = self.Event(transaction_id=1)
|
||||
event.object_id = user.id
|
||||
event.object_type = type(user).__tablename__
|
||||
event.object_type = unicode(type(user).__name__)
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
|
144
tests/generic_relationship/test_single_table_inheritance.py
Normal file
144
tests/generic_relationship/test_single_table_inheritance.py
Normal file
@@ -0,0 +1,144 @@
|
||||
from __future__ import unicode_literals
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestGenericRelationship(TestCase):
|
||||
def create_models(self):
|
||||
class Employee(self.Base):
|
||||
__tablename__ = 'employee'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(50))
|
||||
type = sa.Column(sa.String(20))
|
||||
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'employee'
|
||||
}
|
||||
|
||||
class Manager(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'manager'
|
||||
}
|
||||
|
||||
class Engineer(Employee):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'engineer'
|
||||
}
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Employee = Employee
|
||||
self.Manager = Manager
|
||||
self.Engineer = Engineer
|
||||
self.Event = Event
|
||||
|
||||
def test_set_as_none(self):
|
||||
event = self.Event()
|
||||
event.object = None
|
||||
assert event.object is None
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
manager = self.Manager()
|
||||
|
||||
self.session.add(manager)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event.object_id = manager.id
|
||||
event.object_type = unicode(type(manager).__name__)
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == manager
|
||||
|
||||
def test_set_and_get(self):
|
||||
manager = self.Manager()
|
||||
|
||||
self.session.add(manager)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event(object=manager)
|
||||
|
||||
assert event.object_id == manager.id
|
||||
assert event.object_type == type(manager).__name__
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == manager
|
||||
|
||||
def test_compare_instance(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == manager1
|
||||
assert event.object != manager2
|
||||
|
||||
def test_compare_query(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
assert q.filter_by(object=manager1).first() is not None
|
||||
assert q.filter_by(object=manager2).first() is None
|
||||
assert q.filter(self.Event.object == manager2).first() is None
|
||||
|
||||
def test_compare_not_query(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event(object=manager1)
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
q = self.session.query(self.Event)
|
||||
assert q.filter(self.Event.object != manager2).first() is not None
|
||||
|
||||
def test_compare_type(self):
|
||||
manager1 = self.Manager()
|
||||
manager2 = self.Manager()
|
||||
|
||||
self.session.add_all([manager1, manager2])
|
||||
self.session.commit()
|
||||
|
||||
event1 = self.Event(object=manager1)
|
||||
event2 = self.Event(object=manager2)
|
||||
|
||||
self.session.add_all([event1, event2])
|
||||
self.session.commit()
|
||||
|
||||
statement = self.Event.object.is_type(self.Manager)
|
||||
q = self.session.query(self.Event).filter(statement)
|
||||
assert q.first() is not None
|
Reference in New Issue
Block a user