Add support for single table inheritance, refs #69

This commit is contained in:
Konsta Vesterinen
2014-03-05 15:16:54 +02:00
parent d12d0ad5d6
commit 2ec373443a
7 changed files with 159 additions and 26 deletions

View File

@@ -7,9 +7,7 @@ from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute set_committed_value, InstrumentedAttribute
) )
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
from sqlalchemy_utils.generic import ( from sqlalchemy_utils.generic import GenericRelationshipProperty
GenericRelationshipProperty, class_from_table_name
)
from sqlalchemy_utils.functions.orm import ( from sqlalchemy_utils.functions.orm import (
list_local_values, list_local_values,
list_local_remote_exprs, list_local_remote_exprs,
@@ -345,9 +343,7 @@ class GenericRelationshipFetcher(object):
def _queries(self, state, id_dict): def _queries(self, state, id_dict):
for discriminator, ids in six.iteritems(id_dict): for discriminator, ids in six.iteritems(id_dict):
class_ = class_from_table_name( class_ = state.class_._decl_class_registry.get(discriminator)
state, discriminator
)
yield self.path.session.query( yield self.path.session.query(
class_ class_
).filter( ).filter(

View File

@@ -7,19 +7,11 @@ from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order from sqlalchemy.util import set_creation_order
from sqlalchemy_utils.functions import table_name, identity from sqlalchemy_utils.functions import identity
from .exceptions import ImproperlyConfigured from .exceptions import ImproperlyConfigured
def class_from_table_name(state, table):
for class_ in state.class_._decl_class_registry.values():
name = table_name(class_)
if name and name == table:
return class_
return None
class GenericAttributeImpl(attributes.ScalarAttributeImpl): class GenericAttributeImpl(attributes.ScalarAttributeImpl):
def get(self, state, dict_, passive=attributes.PASSIVE_OFF): def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if self.key in dict_: if self.key in dict_:
@@ -35,7 +27,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Find class for discriminator. # Find class for discriminator.
# TODO: Perhaps optimize with some sort of lookup? # TODO: Perhaps optimize with some sort of lookup?
discriminator = self.get_state_discriminator(state) discriminator = self.get_state_discriminator(state)
target_class = class_from_table_name(state, discriminator) target_class = state.class_._decl_class_registry.get(discriminator)
if target_class is None: if target_class is None:
# Unknown discriminator; return nothing. # Unknown discriminator; return nothing.
@@ -75,12 +67,13 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
else: else:
# Get the primary key of the initiator and ensure we # Get the primary key of the initiator and ensure we
# can support this assignment. # can support this assignment.
mapper = class_mapper(type(initiator)) class_ = type(initiator)
mapper = class_mapper(class_)
pk = mapper.identity_key_from_instance(initiator)[1] pk = mapper.identity_key_from_instance(initiator)[1]
# Set the identifier and the discriminator. # Set the identifier and the discriminator.
discriminator = table_name(initiator) discriminator = unicode(class_.__name__)
for index, id in enumerate(self.parent_token.id): for index, id in enumerate(self.parent_token.id):
dict_[id.key] = pk[index] dict_[id.key] = pk[index]
@@ -150,7 +143,7 @@ class GenericRelationshipProperty(MapperProperty):
self._parentmapper = parentmapper self._parentmapper = parentmapper
def __eq__(self, other): def __eq__(self, other):
discriminator = table_name(other) discriminator = unicode(type(other).__name__)
q = self.property._discriminator_col == discriminator q = self.property._discriminator_col == discriminator
other_id = identity(other) other_id = identity(other)
for index, id in enumerate(self.property._id_cols): for index, id in enumerate(self.property._id_cols):
@@ -161,7 +154,7 @@ class GenericRelationshipProperty(MapperProperty):
return ~(self == other) return ~(self == other)
def is_type(self, other): def is_type(self, other):
discriminator = table_name(other) discriminator = unicode(other.__name__)
return self.property._discriminator_col == discriminator return self.property._discriminator_col == discriminator
def instrument_class(self, mapper): def instrument_class(self, mapper):

View File

@@ -16,7 +16,7 @@ class GenericRelationshipTestCase(TestCase):
event = self.Event() event = self.Event()
event.object_id = user.id event.object_id = user.id
event.object_type = type(user).__tablename__ event.object_type = unicode(type(user).__name__)
assert event.object is None assert event.object is None
@@ -34,7 +34,7 @@ class GenericRelationshipTestCase(TestCase):
event = self.Event(object=user) event = self.Event(object=user)
assert event.object_id == user.id assert event.object_id == user.id
assert event.object_type == type(user).__tablename__ assert event.object_type == type(user).__name__
self.session.add(event) self.session.add(event)
self.session.commit() self.session.commit()

View File

@@ -52,7 +52,7 @@ class TestGenericRelationship(GenericRelationshipTestCase):
event = self.Event() event = self.Event()
event.object_id = user.id event.object_id = user.id
event.object_type = type(user).__tablename__ event.object_type = unicode(type(user).__name__)
event.object_code = user.code event.object_code = user.code
assert event.object is None assert event.object is None

View File

@@ -32,11 +32,11 @@ class TestGenericRelationship(TestCase):
@hybrid_property @hybrid_property
def object_version_type(self): def object_version_type(self):
return self.object_type + '_history' return self.object_type + 'History'
@object_version_type.expression @object_version_type.expression
def object_version_type(cls): def object_version_type(cls):
return sa.func.concat(cls.object_type, '_history') return sa.func.concat(cls.object_type, 'History')
object_version = generic_relationship( object_version = generic_relationship(
object_version_type, (object_id, transaction_id) object_version_type, (object_id, transaction_id)
@@ -55,7 +55,7 @@ class TestGenericRelationship(TestCase):
event = self.Event(transaction_id=1) event = self.Event(transaction_id=1)
event.object_id = user.id event.object_id = user.id
event.object_type = type(user).__tablename__ event.object_type = unicode(type(user).__name__)
assert event.object is None assert event.object is None
self.session.add(event) self.session.add(event)

View File

@@ -0,0 +1,144 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests import TestCase
class TestGenericRelationship(TestCase):
def create_models(self):
class Employee(self.Base):
__tablename__ = 'employee'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(50))
type = sa.Column(sa.String(20))
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'employee'
}
class Manager(Employee):
__mapper_args__ = {
'polymorphic_identity': 'manager'
}
class Engineer(Employee):
__mapper_args__ = {
'polymorphic_identity': 'engineer'
}
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
self.Employee = Employee
self.Manager = Manager
self.Engineer = Engineer
self.Event = Event
def test_set_as_none(self):
event = self.Event()
event.object = None
assert event.object is None
def test_set_manual_and_get(self):
manager = self.Manager()
self.session.add(manager)
self.session.commit()
event = self.Event()
event.object_id = manager.id
event.object_type = unicode(type(manager).__name__)
assert event.object is None
self.session.add(event)
self.session.commit()
assert event.object == manager
def test_set_and_get(self):
manager = self.Manager()
self.session.add(manager)
self.session.commit()
event = self.Event(object=manager)
assert event.object_id == manager.id
assert event.object_type == type(manager).__name__
self.session.add(event)
self.session.commit()
assert event.object == manager
def test_compare_instance(self):
manager1 = self.Manager()
manager2 = self.Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
event = self.Event(object=manager1)
self.session.add(event)
self.session.commit()
assert event.object == manager1
assert event.object != manager2
def test_compare_query(self):
manager1 = self.Manager()
manager2 = self.Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
event = self.Event(object=manager1)
self.session.add(event)
self.session.commit()
q = self.session.query(self.Event)
assert q.filter_by(object=manager1).first() is not None
assert q.filter_by(object=manager2).first() is None
assert q.filter(self.Event.object == manager2).first() is None
def test_compare_not_query(self):
manager1 = self.Manager()
manager2 = self.Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
event = self.Event(object=manager1)
self.session.add(event)
self.session.commit()
q = self.session.query(self.Event)
assert q.filter(self.Event.object != manager2).first() is not None
def test_compare_type(self):
manager1 = self.Manager()
manager2 = self.Manager()
self.session.add_all([manager1, manager2])
self.session.commit()
event1 = self.Event(object=manager1)
event2 = self.Event(object=manager2)
self.session.add_all([event1, event2])
self.session.commit()
statement = self.Event.object.is_type(self.Manager)
q = self.session.query(self.Event).filter(statement)
assert q.first() is not None