diff --git a/CHANGES.rst b/CHANGES.rst index 19a7234..6d7ffdc 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,8 +3,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.24.3 (2014-02-xx) +^^^^^^^^^^^^^^^^^^^ -0.24.2 (2014-02-21) +- Added string argument support for generic_relationship + + +0.24.2 (2014-03-05) ^^^^^^^^^^^^^^^^^^^ - Remove toolz from dependencies diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index 6fe4499..120f29f 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -1,11 +1,12 @@ +from collections import Iterable + import six from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.util import set_creation_order -from sqlalchemy import exc as sa_exc -from sqlalchemy_utils.functions import table_name +from sqlalchemy_utils.functions import table_name, identity def class_from_table_name(state, table): @@ -37,13 +38,16 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Unknown discriminator; return nothing. return None - # Lookup row with the discriminator and id. - id = state.attrs[self.parent_token.id.key].value + id = self.get_state_id(state) target = session.query(target_class).get(id) # Return found (or not found) target. return target + def get_state_id(self, state): + # Lookup row with the discriminator and id. + return tuple(state.attrs[id.key].value for id in self.parent_token.id) + def set(self, state, dict_, initiator, passive=attributes.PASSIVE_OFF, check_old=None, @@ -54,22 +58,21 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): if initiator is None: # Nullify relationship args - dict_[self.parent_token.id.key] = None + for id in self.parent_token.id: + dict_[id.key] = None dict_[self.parent_token.discriminator.key] = None else: # Get the primary key of the initiator and ensure we # can support this assignment. mapper = class_mapper(type(initiator)) - if len(mapper.primary_key) > 1: - raise sa_exc.InvalidRequestError( - 'Generic relationships against tables with composite ' - 'primary keys are not supported.') - pk = mapper.identity_key_from_instance(initiator)[1][0] + pk = mapper.identity_key_from_instance(initiator)[1] # Set the identifier and the discriminator. discriminator = table_name(initiator) - dict_[self.parent_token.id.key] = pk + + for index, id in enumerate(self.parent_token.id): + dict_[id.key] = pk[index] dict_[self.parent_token.discriminator.key] = discriminator @@ -87,7 +90,7 @@ class GenericRelationshipProperty(MapperProperty): def __init__(self, discriminator, id, doc=None): self._discriminator_col = discriminator - self._id_col = id + self._id_cols = id self._id = None self._discriminator = None self.doc = doc @@ -101,19 +104,21 @@ class GenericRelationshipProperty(MapperProperty): return attr def init(self): - # Resolve columns to attributes. - if isinstance(self._discriminator_col, six.string_types): - self._discriminator_col = self.parent.columns[ - self._discriminator_col - ] + def convert_strings(column): + if isinstance(column, six.string_types): + return self.parent.columns[column] + return column - if isinstance(self._id_col, six.string_types): - self._id_col = self.parent.columns[ - self._id_col - ] + self._discriminator_col = convert_strings(self._discriminator_col) + self._id_cols = convert_strings(self._id_cols) + + if isinstance(self._id_cols, Iterable): + self._id_cols = list(map(convert_strings, self._id_cols)) + else: + self._id_cols = [self._id_cols] self.discriminator = self._column_to_property(self._discriminator_col) - self.id = self._column_to_property(self._id_col) + self.id = list(map(self._column_to_property, self._id_cols)) class Comparator(PropComparator): @@ -124,7 +129,9 @@ class GenericRelationshipProperty(MapperProperty): def __eq__(self, other): discriminator = table_name(other) q = self.property._discriminator_col == discriminator - q &= self.property._id_col == other.id + other_id = identity(other) + for index, id in enumerate(self.property._id_cols): + q &= id == other_id[index] return q def __ne__(self, other): diff --git a/sqlalchemy_utils/types/range.py b/sqlalchemy_utils/types/range.py index 379e0ea..a40036a 100644 --- a/sqlalchemy_utils/types/range.py +++ b/sqlalchemy_utils/types/range.py @@ -97,7 +97,7 @@ All range types support all comparison operators (>, >=, ==, !=, <=, <). Car.price_range << [300, 500] # Whether or not range is strictly right of another range - Car.price_range << [300, 500] + Car.price_range >> [300, 500] diff --git a/tests/test_generic.py b/tests/generic_relationship/__init__.py similarity index 60% rename from tests/test_generic.py rename to tests/generic_relationship/__init__.py index 0ce6671..fda075f 100644 --- a/tests/test_generic.py +++ b/tests/generic_relationship/__init__.py @@ -1,8 +1,5 @@ from __future__ import unicode_literals -import sqlalchemy as sa from tests import TestCase -from sqlalchemy_utils import generic_relationship -from sqlalchemy.ext.declarative import declared_attr class GenericRelationshipTestCase(TestCase): @@ -107,56 +104,3 @@ class GenericRelationshipTestCase(TestCase): statement = self.Event.object.is_type(self.User) q = self.session.query(self.Event).filter(statement) assert q.first() is not None - - -class TestGenericRelationship(GenericRelationshipTestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - - class Event(self.Base): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - object_type = sa.Column(sa.Unicode(255), name="objectType") - object_id = sa.Column(sa.Integer, nullable=False) - - object = generic_relationship(object_type, object_id) - - self.Building = Building - self.User = User - self.Event = Event - - -class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): - def create_models(self): - class Building(self.Base): - __tablename__ = 'building' - id = sa.Column(sa.Integer, primary_key=True) - - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - - class EventBase(self.Base): - __abstract__ = True - - object_type = sa.Column(sa.Unicode(255)) - object_id = sa.Column(sa.Integer, nullable=False) - - @declared_attr - def object(cls): - return generic_relationship('object_type', 'object_id') - - class Event(EventBase): - __tablename__ = 'event' - id = sa.Column(sa.Integer, primary_key=True) - - self.Building = Building - self.User = User - self.Event = Event diff --git a/tests/generic_relationship/test_abstract_base_class.py b/tests/generic_relationship/test_abstract_base_class.py new file mode 100644 index 0000000..efdf79d --- /dev/null +++ b/tests/generic_relationship/test_abstract_base_class.py @@ -0,0 +1,34 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from sqlalchemy_utils import generic_relationship +from sqlalchemy.ext.declarative import declared_attr +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class EventBase(self.Base): + __abstract__ = True + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + + @declared_attr + def object(cls): + return generic_relationship('object_type', 'object_id') + + class Event(EventBase): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + self.Building = Building + self.User = User + self.Event = Event diff --git a/tests/generic_relationship/test_composite_keys.py b/tests/generic_relationship/test_composite_keys.py new file mode 100644 index 0000000..5458fb8 --- /dev/null +++ b/tests/generic_relationship/test_composite_keys.py @@ -0,0 +1,63 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from sqlalchemy_utils import generic_relationship +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationship(GenericRelationshipTestCase): + index = 1 + + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(obj_self): + self.index += 1 + obj_self.id = self.index + obj_self.code = self.index + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + code = sa.Column(sa.Integer, primary_key=True) + + def __init__(obj_self): + self.index += 1 + obj_self.id = self.index + obj_self.code = self.index + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255)) + object_id = sa.Column(sa.Integer, nullable=False) + object_code = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship( + object_type, (object_id, object_code) + ) + + self.Building = Building + self.User = User + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = type(user).__tablename__ + event.object_code = user.code + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user diff --git a/tests/generic_relationship/test_simple.py b/tests/generic_relationship/test_simple.py new file mode 100644 index 0000000..3e34bc7 --- /dev/null +++ b/tests/generic_relationship/test_simple.py @@ -0,0 +1,28 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from sqlalchemy_utils import generic_relationship +from tests.generic_relationship import GenericRelationshipTestCase + + +class TestGenericRelationship(GenericRelationshipTestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event