Add support for composite keys, refs #68
This commit is contained in:
@@ -3,8 +3,13 @@ Changelog
|
||||
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
0.24.3 (2014-02-xx)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
0.24.2 (2014-02-21)
|
||||
- Added string argument support for generic_relationship
|
||||
|
||||
|
||||
0.24.2 (2014-03-05)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Remove toolz from dependencies
|
||||
|
@@ -1,11 +1,12 @@
|
||||
from collections import Iterable
|
||||
|
||||
import six
|
||||
|
||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||
from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.orm import attributes, class_mapper
|
||||
from sqlalchemy.util import set_creation_order
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from sqlalchemy_utils.functions import table_name
|
||||
from sqlalchemy_utils.functions import table_name, identity
|
||||
|
||||
|
||||
def class_from_table_name(state, table):
|
||||
@@ -37,13 +38,16 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
# Unknown discriminator; return nothing.
|
||||
return None
|
||||
|
||||
# Lookup row with the discriminator and id.
|
||||
id = state.attrs[self.parent_token.id.key].value
|
||||
id = self.get_state_id(state)
|
||||
target = session.query(target_class).get(id)
|
||||
|
||||
# Return found (or not found) target.
|
||||
return target
|
||||
|
||||
def get_state_id(self, state):
|
||||
# Lookup row with the discriminator and id.
|
||||
return tuple(state.attrs[id.key].value for id in self.parent_token.id)
|
||||
|
||||
def set(self, state, dict_, initiator,
|
||||
passive=attributes.PASSIVE_OFF,
|
||||
check_old=None,
|
||||
@@ -54,22 +58,21 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
|
||||
if initiator is None:
|
||||
# Nullify relationship args
|
||||
dict_[self.parent_token.id.key] = None
|
||||
for id in self.parent_token.id:
|
||||
dict_[id.key] = None
|
||||
dict_[self.parent_token.discriminator.key] = None
|
||||
else:
|
||||
# Get the primary key of the initiator and ensure we
|
||||
# can support this assignment.
|
||||
mapper = class_mapper(type(initiator))
|
||||
if len(mapper.primary_key) > 1:
|
||||
raise sa_exc.InvalidRequestError(
|
||||
'Generic relationships against tables with composite '
|
||||
'primary keys are not supported.')
|
||||
|
||||
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
||||
pk = mapper.identity_key_from_instance(initiator)[1]
|
||||
|
||||
# Set the identifier and the discriminator.
|
||||
discriminator = table_name(initiator)
|
||||
dict_[self.parent_token.id.key] = pk
|
||||
|
||||
for index, id in enumerate(self.parent_token.id):
|
||||
dict_[id.key] = pk[index]
|
||||
dict_[self.parent_token.discriminator.key] = discriminator
|
||||
|
||||
|
||||
@@ -87,7 +90,7 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
|
||||
def __init__(self, discriminator, id, doc=None):
|
||||
self._discriminator_col = discriminator
|
||||
self._id_col = id
|
||||
self._id_cols = id
|
||||
self._id = None
|
||||
self._discriminator = None
|
||||
self.doc = doc
|
||||
@@ -101,19 +104,21 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
return attr
|
||||
|
||||
def init(self):
|
||||
# Resolve columns to attributes.
|
||||
if isinstance(self._discriminator_col, six.string_types):
|
||||
self._discriminator_col = self.parent.columns[
|
||||
self._discriminator_col
|
||||
]
|
||||
def convert_strings(column):
|
||||
if isinstance(column, six.string_types):
|
||||
return self.parent.columns[column]
|
||||
return column
|
||||
|
||||
if isinstance(self._id_col, six.string_types):
|
||||
self._id_col = self.parent.columns[
|
||||
self._id_col
|
||||
]
|
||||
self._discriminator_col = convert_strings(self._discriminator_col)
|
||||
self._id_cols = convert_strings(self._id_cols)
|
||||
|
||||
if isinstance(self._id_cols, Iterable):
|
||||
self._id_cols = list(map(convert_strings, self._id_cols))
|
||||
else:
|
||||
self._id_cols = [self._id_cols]
|
||||
|
||||
self.discriminator = self._column_to_property(self._discriminator_col)
|
||||
self.id = self._column_to_property(self._id_col)
|
||||
self.id = list(map(self._column_to_property, self._id_cols))
|
||||
|
||||
class Comparator(PropComparator):
|
||||
|
||||
@@ -124,7 +129,9 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
def __eq__(self, other):
|
||||
discriminator = table_name(other)
|
||||
q = self.property._discriminator_col == discriminator
|
||||
q &= self.property._id_col == other.id
|
||||
other_id = identity(other)
|
||||
for index, id in enumerate(self.property._id_cols):
|
||||
q &= id == other_id[index]
|
||||
return q
|
||||
|
||||
def __ne__(self, other):
|
||||
|
@@ -97,7 +97,7 @@ All range types support all comparison operators (>, >=, ==, !=, <=, <).
|
||||
Car.price_range << [300, 500]
|
||||
|
||||
# Whether or not range is strictly right of another range
|
||||
Car.price_range << [300, 500]
|
||||
Car.price_range >> [300, 500]
|
||||
|
||||
|
||||
|
||||
|
@@ -1,8 +1,5 @@
|
||||
from __future__ import unicode_literals
|
||||
import sqlalchemy as sa
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
|
||||
|
||||
class GenericRelationshipTestCase(TestCase):
|
||||
@@ -107,56 +104,3 @@ class GenericRelationshipTestCase(TestCase):
|
||||
statement = self.Event.object.is_type(self.User)
|
||||
q = self.session.query(self.Event).filter(statement)
|
||||
assert q.first() is not None
|
||||
|
||||
|
||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
|
||||
|
||||
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class EventBase(self.Base):
|
||||
__abstract__ = True
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def object(cls):
|
||||
return generic_relationship('object_type', 'object_id')
|
||||
|
||||
class Event(EventBase):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
34
tests/generic_relationship/test_abstract_base_class.py
Normal file
34
tests/generic_relationship/test_abstract_base_class.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import unicode_literals
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
|
||||
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class EventBase(self.Base):
|
||||
__abstract__ = True
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
@declared_attr
|
||||
def object(cls):
|
||||
return generic_relationship('object_type', 'object_id')
|
||||
|
||||
class Event(EventBase):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
63
tests/generic_relationship/test_composite_keys.py
Normal file
63
tests/generic_relationship/test_composite_keys.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from __future__ import unicode_literals
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
|
||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
index = 1
|
||||
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def __init__(obj_self):
|
||||
self.index += 1
|
||||
obj_self.id = self.index
|
||||
obj_self.code = self.index
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
code = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
def __init__(obj_self):
|
||||
self.index += 1
|
||||
obj_self.id = self.index
|
||||
obj_self.code = self.index
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255))
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
object_code = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(
|
||||
object_type, (object_id, object_code)
|
||||
)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
user = self.User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = type(user).__tablename__
|
||||
event.object_code = user.code
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == user
|
28
tests/generic_relationship/test_simple.py
Normal file
28
tests/generic_relationship/test_simple.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from __future__ import unicode_literals
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
from tests.generic_relationship import GenericRelationshipTestCase
|
||||
|
||||
|
||||
class TestGenericRelationship(GenericRelationshipTestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
Reference in New Issue
Block a user