diff --git a/sqlalchemy_utils/batch.py b/sqlalchemy_utils/batch.py index c42bec7..8aa97dd 100644 --- a/sqlalchemy_utils/batch.py +++ b/sqlalchemy_utils/batch.py @@ -7,9 +7,7 @@ from sqlalchemy.orm.attributes import ( set_committed_value, InstrumentedAttribute ) from sqlalchemy.orm.session import object_session -from sqlalchemy_utils.generic import ( - GenericRelationshipProperty, class_from_table_name -) +from sqlalchemy_utils.generic import GenericRelationshipProperty from sqlalchemy_utils.functions.orm import ( list_local_values, list_local_remote_exprs, @@ -345,9 +343,7 @@ class GenericRelationshipFetcher(object): def _queries(self, state, id_dict): for discriminator, ids in six.iteritems(id_dict): - class_ = class_from_table_name( - state, discriminator - ) + class_ = state.class_._decl_class_registry.get(discriminator) yield self.path.session.query( class_ ).filter( diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index 6092984..51cc15b 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -7,19 +7,11 @@ from sqlalchemy.orm import ColumnProperty from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.util import set_creation_order -from sqlalchemy_utils.functions import table_name, identity +from sqlalchemy_utils.functions import identity from .exceptions import ImproperlyConfigured -def class_from_table_name(state, table): - for class_ in state.class_._decl_class_registry.values(): - name = table_name(class_) - if name and name == table: - return class_ - return None - - class GenericAttributeImpl(attributes.ScalarAttributeImpl): def get(self, state, dict_, passive=attributes.PASSIVE_OFF): if self.key in dict_: @@ -35,7 +27,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Find class for discriminator. # TODO: Perhaps optimize with some sort of lookup? discriminator = self.get_state_discriminator(state) - target_class = class_from_table_name(state, discriminator) + target_class = state.class_._decl_class_registry.get(discriminator) if target_class is None: # Unknown discriminator; return nothing. @@ -75,12 +67,13 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): else: # Get the primary key of the initiator and ensure we # can support this assignment. - mapper = class_mapper(type(initiator)) + class_ = type(initiator) + mapper = class_mapper(class_) pk = mapper.identity_key_from_instance(initiator)[1] # Set the identifier and the discriminator. - discriminator = table_name(initiator) + discriminator = unicode(class_.__name__) for index, id in enumerate(self.parent_token.id): dict_[id.key] = pk[index] @@ -150,7 +143,7 @@ class GenericRelationshipProperty(MapperProperty): self._parentmapper = parentmapper def __eq__(self, other): - discriminator = table_name(other) + discriminator = unicode(type(other).__name__) q = self.property._discriminator_col == discriminator other_id = identity(other) for index, id in enumerate(self.property._id_cols): @@ -161,7 +154,7 @@ class GenericRelationshipProperty(MapperProperty): return ~(self == other) def is_type(self, other): - discriminator = table_name(other) + discriminator = unicode(other.__name__) return self.property._discriminator_col == discriminator def instrument_class(self, mapper): diff --git a/tests/generic_relationship/__init__.py b/tests/generic_relationship/__init__.py index fda075f..0956591 100644 --- a/tests/generic_relationship/__init__.py +++ b/tests/generic_relationship/__init__.py @@ -16,7 +16,7 @@ class GenericRelationshipTestCase(TestCase): event = self.Event() event.object_id = user.id - event.object_type = type(user).__tablename__ + event.object_type = unicode(type(user).__name__) assert event.object is None @@ -34,7 +34,7 @@ class GenericRelationshipTestCase(TestCase): event = self.Event(object=user) assert event.object_id == user.id - assert event.object_type == type(user).__tablename__ + assert event.object_type == type(user).__name__ self.session.add(event) self.session.commit() diff --git a/tests/generic_relationship/test_simple.py b/tests/generic_relationship/test_column_aliases.py similarity index 100% rename from tests/generic_relationship/test_simple.py rename to tests/generic_relationship/test_column_aliases.py diff --git a/tests/generic_relationship/test_composite_keys.py b/tests/generic_relationship/test_composite_keys.py index 5458fb8..b933738 100644 --- a/tests/generic_relationship/test_composite_keys.py +++ b/tests/generic_relationship/test_composite_keys.py @@ -52,7 +52,7 @@ class TestGenericRelationship(GenericRelationshipTestCase): event = self.Event() event.object_id = user.id - event.object_type = type(user).__tablename__ + event.object_type = unicode(type(user).__name__) event.object_code = user.code assert event.object is None diff --git a/tests/generic_relationship/test_hybrid_properties.py b/tests/generic_relationship/test_hybrid_properties.py index c6a8ce0..a2b845a 100644 --- a/tests/generic_relationship/test_hybrid_properties.py +++ b/tests/generic_relationship/test_hybrid_properties.py @@ -32,11 +32,11 @@ class TestGenericRelationship(TestCase): @hybrid_property def object_version_type(self): - return self.object_type + '_history' + return self.object_type + 'History' @object_version_type.expression def object_version_type(cls): - return sa.func.concat(cls.object_type, '_history') + return sa.func.concat(cls.object_type, 'History') object_version = generic_relationship( object_version_type, (object_id, transaction_id) @@ -55,7 +55,7 @@ class TestGenericRelationship(TestCase): event = self.Event(transaction_id=1) event.object_id = user.id - event.object_type = type(user).__tablename__ + event.object_type = unicode(type(user).__name__) assert event.object is None self.session.add(event) diff --git a/tests/generic_relationship/test_single_table_inheritance.py b/tests/generic_relationship/test_single_table_inheritance.py new file mode 100644 index 0000000..247e480 --- /dev/null +++ b/tests/generic_relationship/test_single_table_inheritance.py @@ -0,0 +1,144 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from sqlalchemy_utils import generic_relationship +from tests import TestCase + + +class TestGenericRelationship(TestCase): + def create_models(self): + class Employee(self.Base): + __tablename__ = 'employee' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.String(50)) + type = sa.Column(sa.String(20)) + + __mapper_args__ = { + 'polymorphic_on': type, + 'polymorphic_identity': 'employee' + } + + class Manager(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'manager' + } + + class Engineer(Employee): + __mapper_args__ = { + 'polymorphic_identity': 'engineer' + } + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Employee = Employee + self.Manager = Manager + self.Engineer = Engineer + self.Event = Event + + def test_set_as_none(self): + event = self.Event() + event.object = None + assert event.object is None + + def test_set_manual_and_get(self): + manager = self.Manager() + + self.session.add(manager) + self.session.commit() + + event = self.Event() + event.object_id = manager.id + event.object_type = unicode(type(manager).__name__) + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == manager + + def test_set_and_get(self): + manager = self.Manager() + + self.session.add(manager) + self.session.commit() + + event = self.Event(object=manager) + + assert event.object_id == manager.id + assert event.object_type == type(manager).__name__ + + self.session.add(event) + self.session.commit() + + assert event.object == manager + + def test_compare_instance(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + assert event.object == manager1 + assert event.object != manager2 + + def test_compare_query(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=manager1).first() is not None + assert q.filter_by(object=manager2).first() is None + assert q.filter(self.Event.object == manager2).first() is None + + def test_compare_not_query(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event = self.Event(object=manager1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != manager2).first() is not None + + def test_compare_type(self): + manager1 = self.Manager() + manager2 = self.Manager() + + self.session.add_all([manager1, manager2]) + self.session.commit() + + event1 = self.Event(object=manager1) + event2 = self.Event(object=manager2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.Manager) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None