Add support for composite keys, refs #68

This commit is contained in:
Konsta Vesterinen
2014-03-05 11:52:03 +02:00
parent 44496e26a3
commit ded97b783c
7 changed files with 162 additions and 81 deletions

View File

@@ -3,8 +3,13 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.24.3 (2014-02-xx)
^^^^^^^^^^^^^^^^^^^
0.24.2 (2014-02-21)
- Added string argument support for generic_relationship
0.24.2 (2014-03-05)
^^^^^^^^^^^^^^^^^^^
- Remove toolz from dependencies

View File

@@ -1,11 +1,12 @@
from collections import Iterable
import six
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session
from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.util import set_creation_order
from sqlalchemy import exc as sa_exc
from sqlalchemy_utils.functions import table_name
from sqlalchemy_utils.functions import table_name, identity
def class_from_table_name(state, table):
@@ -37,13 +38,16 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Unknown discriminator; return nothing.
return None
# Lookup row with the discriminator and id.
id = state.attrs[self.parent_token.id.key].value
id = self.get_state_id(state)
target = session.query(target_class).get(id)
# Return found (or not found) target.
return target
def get_state_id(self, state):
# Lookup row with the discriminator and id.
return tuple(state.attrs[id.key].value for id in self.parent_token.id)
def set(self, state, dict_, initiator,
passive=attributes.PASSIVE_OFF,
check_old=None,
@@ -54,22 +58,21 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
if initiator is None:
# Nullify relationship args
dict_[self.parent_token.id.key] = None
for id in self.parent_token.id:
dict_[id.key] = None
dict_[self.parent_token.discriminator.key] = None
else:
# Get the primary key of the initiator and ensure we
# can support this assignment.
mapper = class_mapper(type(initiator))
if len(mapper.primary_key) > 1:
raise sa_exc.InvalidRequestError(
'Generic relationships against tables with composite '
'primary keys are not supported.')
pk = mapper.identity_key_from_instance(initiator)[1][0]
pk = mapper.identity_key_from_instance(initiator)[1]
# Set the identifier and the discriminator.
discriminator = table_name(initiator)
dict_[self.parent_token.id.key] = pk
for index, id in enumerate(self.parent_token.id):
dict_[id.key] = pk[index]
dict_[self.parent_token.discriminator.key] = discriminator
@@ -87,7 +90,7 @@ class GenericRelationshipProperty(MapperProperty):
def __init__(self, discriminator, id, doc=None):
self._discriminator_col = discriminator
self._id_col = id
self._id_cols = id
self._id = None
self._discriminator = None
self.doc = doc
@@ -101,19 +104,21 @@ class GenericRelationshipProperty(MapperProperty):
return attr
def init(self):
# Resolve columns to attributes.
if isinstance(self._discriminator_col, six.string_types):
self._discriminator_col = self.parent.columns[
self._discriminator_col
]
def convert_strings(column):
if isinstance(column, six.string_types):
return self.parent.columns[column]
return column
if isinstance(self._id_col, six.string_types):
self._id_col = self.parent.columns[
self._id_col
]
self._discriminator_col = convert_strings(self._discriminator_col)
self._id_cols = convert_strings(self._id_cols)
if isinstance(self._id_cols, Iterable):
self._id_cols = list(map(convert_strings, self._id_cols))
else:
self._id_cols = [self._id_cols]
self.discriminator = self._column_to_property(self._discriminator_col)
self.id = self._column_to_property(self._id_col)
self.id = list(map(self._column_to_property, self._id_cols))
class Comparator(PropComparator):
@@ -124,7 +129,9 @@ class GenericRelationshipProperty(MapperProperty):
def __eq__(self, other):
discriminator = table_name(other)
q = self.property._discriminator_col == discriminator
q &= self.property._id_col == other.id
other_id = identity(other)
for index, id in enumerate(self.property._id_cols):
q &= id == other_id[index]
return q
def __ne__(self, other):

View File

@@ -97,7 +97,7 @@ All range types support all comparison operators (>, >=, ==, !=, <=, <).
Car.price_range << [300, 500]
# Whether or not range is strictly right of another range
Car.price_range << [300, 500]
Car.price_range >> [300, 500]

View File

@@ -1,8 +1,5 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from tests import TestCase
from sqlalchemy_utils import generic_relationship
from sqlalchemy.ext.declarative import declared_attr
class GenericRelationshipTestCase(TestCase):
@@ -107,56 +104,3 @@ class GenericRelationshipTestCase(TestCase):
statement = self.Event.object.is_type(self.User)
q = self.session.query(self.Event).filter(statement)
assert q.first() is not None
class TestGenericRelationship(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
self.Building = Building
self.User = User
self.Event = Event
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class EventBase(self.Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.User = User
self.Event = Event

View File

@@ -0,0 +1,34 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from sqlalchemy.ext.declarative import declared_attr
from tests.generic_relationship import GenericRelationshipTestCase
class TestGenericRelationshipWithAbstractBase(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class EventBase(self.Base):
__abstract__ = True
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
@declared_attr
def object(cls):
return generic_relationship('object_type', 'object_id')
class Event(EventBase):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
self.Building = Building
self.User = User
self.Event = Event

View File

@@ -0,0 +1,63 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
class TestGenericRelationship(GenericRelationshipTestCase):
index = 1
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(obj_self):
self.index += 1
obj_self.id = self.index
obj_self.code = self.index
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
code = sa.Column(sa.Integer, primary_key=True)
def __init__(obj_self):
self.index += 1
obj_self.id = self.index
obj_self.code = self.index
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255))
object_id = sa.Column(sa.Integer, nullable=False)
object_code = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(
object_type, (object_id, object_code)
)
self.Building = Building
self.User = User
self.Event = Event
def test_set_manual_and_get(self):
user = self.User()
self.session.add(user)
self.session.commit()
event = self.Event()
event.object_id = user.id
event.object_type = type(user).__tablename__
event.object_code = user.code
assert event.object is None
self.session.add(event)
self.session.commit()
assert event.object == user

View File

@@ -0,0 +1,28 @@
from __future__ import unicode_literals
import sqlalchemy as sa
from sqlalchemy_utils import generic_relationship
from tests.generic_relationship import GenericRelationshipTestCase
class TestGenericRelationship(GenericRelationshipTestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
class Event(self.Base):
__tablename__ = 'event'
id = sa.Column(sa.Integer, primary_key=True)
object_type = sa.Column(sa.Unicode(255), name="objectType")
object_id = sa.Column(sa.Integer, nullable=False)
object = generic_relationship(object_type, object_id)
self.Building = Building
self.User = User
self.Event = Event