From c9b9ec93f924a0cb2eee7c4fd9afad8f8491dfe7 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Mon, 20 Jan 2014 21:30:54 +0200 Subject: [PATCH] Fix nullify generic relationship --- sqlalchemy_utils/generic.py | 29 +++++++++++++++++------------ tests/test_generic.py | 4 ++++ 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index f2855d2..e4aeb49 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -50,20 +50,25 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Set us on the state. dict_[self.key] = initiator - # Get the primary key of the initiator and ensure we - # can support this assignment. - mapper = class_mapper(type(initiator)) - if len(mapper.primary_key) > 1: - raise sa_exc.InvalidRequestError( - 'Generic relationships against tables with composite ' - 'primary keys are not supported.') + if initiator is None: + # Nullify relationship args + dict_[self.parent_token.id.key] = None + dict_[self.parent_token.discriminator.key] = None + else: + # Get the primary key of the initiator and ensure we + # can support this assignment. + mapper = class_mapper(type(initiator)) + if len(mapper.primary_key) > 1: + raise sa_exc.InvalidRequestError( + 'Generic relationships against tables with composite ' + 'primary keys are not supported.') - pk = mapper.identity_key_from_instance(initiator)[1][0] + pk = mapper.identity_key_from_instance(initiator)[1][0] - # Set the identifier and the discriminator. - discriminator = table_name(initiator) - dict_[self.parent_token.id.key] = pk - dict_[self.parent_token.discriminator.key] = discriminator + # Set the identifier and the discriminator. + discriminator = table_name(initiator) + dict_[self.parent_token.id.key] = pk + dict_[self.parent_token.discriminator.key] = discriminator class GenericRelationshipProperty(MapperProperty): diff --git a/tests/test_generic.py b/tests/test_generic.py index c52a3ab..af94265 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -28,6 +28,10 @@ class TestGenericForiegnKey(TestCase): self.User = User self.Event = Event + def test_set_as_none(self): + event = self.Event() + event.object = None + def test_set_manual_and_get(self): user = self.User()