Add support for composite keys, refs #68
This commit is contained in:
@@ -3,8 +3,13 @@ Changelog
|
|||||||
|
|
||||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||||
|
|
||||||
|
0.24.3 (2014-02-xx)
|
||||||
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
0.24.2 (2014-02-21)
|
- Added string argument support for generic_relationship
|
||||||
|
|
||||||
|
|
||||||
|
0.24.2 (2014-03-05)
|
||||||
^^^^^^^^^^^^^^^^^^^
|
^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
- Remove toolz from dependencies
|
- Remove toolz from dependencies
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
|
from collections import Iterable
|
||||||
|
|
||||||
import six
|
import six
|
||||||
|
|
||||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||||
from sqlalchemy.orm.session import _state_session
|
from sqlalchemy.orm.session import _state_session
|
||||||
from sqlalchemy.orm import attributes, class_mapper
|
from sqlalchemy.orm import attributes, class_mapper
|
||||||
from sqlalchemy.util import set_creation_order
|
from sqlalchemy.util import set_creation_order
|
||||||
from sqlalchemy import exc as sa_exc
|
from sqlalchemy_utils.functions import table_name, identity
|
||||||
from sqlalchemy_utils.functions import table_name
|
|
||||||
|
|
||||||
|
|
||||||
def class_from_table_name(state, table):
|
def class_from_table_name(state, table):
|
||||||
@@ -37,13 +38,16 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
# Unknown discriminator; return nothing.
|
# Unknown discriminator; return nothing.
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Lookup row with the discriminator and id.
|
id = self.get_state_id(state)
|
||||||
id = state.attrs[self.parent_token.id.key].value
|
|
||||||
target = session.query(target_class).get(id)
|
target = session.query(target_class).get(id)
|
||||||
|
|
||||||
# Return found (or not found) target.
|
# Return found (or not found) target.
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
def get_state_id(self, state):
|
||||||
|
# Lookup row with the discriminator and id.
|
||||||
|
return tuple(state.attrs[id.key].value for id in self.parent_token.id)
|
||||||
|
|
||||||
def set(self, state, dict_, initiator,
|
def set(self, state, dict_, initiator,
|
||||||
passive=attributes.PASSIVE_OFF,
|
passive=attributes.PASSIVE_OFF,
|
||||||
check_old=None,
|
check_old=None,
|
||||||
@@ -54,22 +58,21 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
|||||||
|
|
||||||
if initiator is None:
|
if initiator is None:
|
||||||
# Nullify relationship args
|
# Nullify relationship args
|
||||||
dict_[self.parent_token.id.key] = None
|
for id in self.parent_token.id:
|
||||||
|
dict_[id.key] = None
|
||||||
dict_[self.parent_token.discriminator.key] = None
|
dict_[self.parent_token.discriminator.key] = None
|
||||||
else:
|
else:
|
||||||
# Get the primary key of the initiator and ensure we
|
# Get the primary key of the initiator and ensure we
|
||||||
# can support this assignment.
|
# can support this assignment.
|
||||||
mapper = class_mapper(type(initiator))
|
mapper = class_mapper(type(initiator))
|
||||||
if len(mapper.primary_key) > 1:
|
|
||||||
raise sa_exc.InvalidRequestError(
|
|
||||||
'Generic relationships against tables with composite '
|
|
||||||
'primary keys are not supported.')
|
|
||||||
|
|
||||||
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
pk = mapper.identity_key_from_instance(initiator)[1]
|
||||||
|
|
||||||
# Set the identifier and the discriminator.
|
# Set the identifier and the discriminator.
|
||||||
discriminator = table_name(initiator)
|
discriminator = table_name(initiator)
|
||||||
dict_[self.parent_token.id.key] = pk
|
|
||||||
|
for index, id in enumerate(self.parent_token.id):
|
||||||
|
dict_[id.key] = pk[index]
|
||||||
dict_[self.parent_token.discriminator.key] = discriminator
|
dict_[self.parent_token.discriminator.key] = discriminator
|
||||||
|
|
||||||
|
|
||||||
@@ -87,7 +90,7 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
|
|
||||||
def __init__(self, discriminator, id, doc=None):
|
def __init__(self, discriminator, id, doc=None):
|
||||||
self._discriminator_col = discriminator
|
self._discriminator_col = discriminator
|
||||||
self._id_col = id
|
self._id_cols = id
|
||||||
self._id = None
|
self._id = None
|
||||||
self._discriminator = None
|
self._discriminator = None
|
||||||
self.doc = doc
|
self.doc = doc
|
||||||
@@ -101,19 +104,21 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
return attr
|
return attr
|
||||||
|
|
||||||
def init(self):
|
def init(self):
|
||||||
# Resolve columns to attributes.
|
def convert_strings(column):
|
||||||
if isinstance(self._discriminator_col, six.string_types):
|
if isinstance(column, six.string_types):
|
||||||
self._discriminator_col = self.parent.columns[
|
return self.parent.columns[column]
|
||||||
self._discriminator_col
|
return column
|
||||||
]
|
|
||||||
|
|
||||||
if isinstance(self._id_col, six.string_types):
|
self._discriminator_col = convert_strings(self._discriminator_col)
|
||||||
self._id_col = self.parent.columns[
|
self._id_cols = convert_strings(self._id_cols)
|
||||||
self._id_col
|
|
||||||
]
|
if isinstance(self._id_cols, Iterable):
|
||||||
|
self._id_cols = list(map(convert_strings, self._id_cols))
|
||||||
|
else:
|
||||||
|
self._id_cols = [self._id_cols]
|
||||||
|
|
||||||
self.discriminator = self._column_to_property(self._discriminator_col)
|
self.discriminator = self._column_to_property(self._discriminator_col)
|
||||||
self.id = self._column_to_property(self._id_col)
|
self.id = list(map(self._column_to_property, self._id_cols))
|
||||||
|
|
||||||
class Comparator(PropComparator):
|
class Comparator(PropComparator):
|
||||||
|
|
||||||
@@ -124,7 +129,9 @@ class GenericRelationshipProperty(MapperProperty):
|
|||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
discriminator = table_name(other)
|
discriminator = table_name(other)
|
||||||
q = self.property._discriminator_col == discriminator
|
q = self.property._discriminator_col == discriminator
|
||||||
q &= self.property._id_col == other.id
|
other_id = identity(other)
|
||||||
|
for index, id in enumerate(self.property._id_cols):
|
||||||
|
q &= id == other_id[index]
|
||||||
return q
|
return q
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ All range types support all comparison operators (>, >=, ==, !=, <=, <).
|
|||||||
Car.price_range << [300, 500]
|
Car.price_range << [300, 500]
|
||||||
|
|
||||||
# Whether or not range is strictly right of another range
|
# Whether or not range is strictly right of another range
|
||||||
Car.price_range << [300, 500]
|
Car.price_range >> [300, 500]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,5 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import sqlalchemy as sa
|
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
from sqlalchemy_utils import generic_relationship
|
|
||||||
from sqlalchemy.ext.declarative import declared_attr
|
|
||||||
|
|
||||||
|
|
||||||
class GenericRelationshipTestCase(TestCase):
|
class GenericRelationshipTestCase(TestCase):
|
||||||
@@ -107,56 +104,3 @@ class GenericRelationshipTestCase(TestCase):
|
|||||||
statement = self.Event.object.is_type(self.User)
|
statement = self.Event.object.is_type(self.User)
|
||||||
q = self.session.query(self.Event).filter(statement)
|
q = self.session.query(self.Event).filter(statement)
|
||||||
assert q.first() is not None
|
assert q.first() is not None
|
||||||
|
|
||||||
|
|
||||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
|
||||||
def create_models(self):
|
|
||||||
class Building(self.Base):
|
|
||||||
__tablename__ = 'building'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
class Event(self.Base):
|
|
||||||
__tablename__ = 'event'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
|
||||||
object_id = sa.Column(sa.Integer, nullable=False)
|
|
||||||
|
|
||||||
object = generic_relationship(object_type, object_id)
|
|
||||||
|
|
||||||
self.Building = Building
|
|
||||||
self.User = User
|
|
||||||
self.Event = Event
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
|
|
||||||
def create_models(self):
|
|
||||||
class Building(self.Base):
|
|
||||||
__tablename__ = 'building'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
class User(self.Base):
|
|
||||||
__tablename__ = 'user'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
class EventBase(self.Base):
|
|
||||||
__abstract__ = True
|
|
||||||
|
|
||||||
object_type = sa.Column(sa.Unicode(255))
|
|
||||||
object_id = sa.Column(sa.Integer, nullable=False)
|
|
||||||
|
|
||||||
@declared_attr
|
|
||||||
def object(cls):
|
|
||||||
return generic_relationship('object_type', 'object_id')
|
|
||||||
|
|
||||||
class Event(EventBase):
|
|
||||||
__tablename__ = 'event'
|
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
|
||||||
|
|
||||||
self.Building = Building
|
|
||||||
self.User = User
|
|
||||||
self.Event = Event
|
|
||||||
34
tests/generic_relationship/test_abstract_base_class.py
Normal file
34
tests/generic_relationship/test_abstract_base_class.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
from sqlalchemy.ext.declarative import declared_attr
|
||||||
|
from tests.generic_relationship import GenericRelationshipTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class EventBase(self.Base):
|
||||||
|
__abstract__ = True
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255))
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
@declared_attr
|
||||||
|
def object(cls):
|
||||||
|
return generic_relationship('object_type', 'object_id')
|
||||||
|
|
||||||
|
class Event(EventBase):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
self.User = User
|
||||||
|
self.Event = Event
|
||||||
63
tests/generic_relationship/test_composite_keys.py
Normal file
63
tests/generic_relationship/test_composite_keys.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
from tests.generic_relationship import GenericRelationshipTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||||
|
index = 1
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
code = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
def __init__(obj_self):
|
||||||
|
self.index += 1
|
||||||
|
obj_self.id = self.index
|
||||||
|
obj_self.code = self.index
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
code = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
def __init__(obj_self):
|
||||||
|
self.index += 1
|
||||||
|
obj_self.id = self.index
|
||||||
|
obj_self.code = self.index
|
||||||
|
|
||||||
|
class Event(self.Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255))
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
object_code = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
object = generic_relationship(
|
||||||
|
object_type, (object_id, object_code)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
self.User = User
|
||||||
|
self.Event = Event
|
||||||
|
|
||||||
|
def test_set_manual_and_get(self):
|
||||||
|
user = self.User()
|
||||||
|
|
||||||
|
self.session.add(user)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event()
|
||||||
|
event.object_id = user.id
|
||||||
|
event.object_type = type(user).__tablename__
|
||||||
|
event.object_code = user.code
|
||||||
|
|
||||||
|
assert event.object is None
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user
|
||||||
28
tests/generic_relationship/test_simple.py
Normal file
28
tests/generic_relationship/test_simple.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
from tests.generic_relationship import GenericRelationshipTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Event(self.Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
object = generic_relationship(object_type, object_id)
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
self.User = User
|
||||||
|
self.Event = Event
|
||||||
Reference in New Issue
Block a user