From a6169a553c30fe5f6c1e5344070db47a1c106f22 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Sun, 8 Sep 2013 11:41:43 -0700 Subject: [PATCH 1/6] Initial generic relationship implementation. --- sqlalchemy_utils/__init__.py | 1 + sqlalchemy_utils/generic.py | 111 +++++++++++++++++++++++++++++++++++ tests/test_generic.py | 66 +++++++++++++++++++++ 3 files changed, 178 insertions(+) create mode 100644 sqlalchemy_utils/generic.py create mode 100644 tests/test_generic.py diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 2917856..28c015b 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -11,6 +11,7 @@ from .functions import ( ) from .listeners import coercion_listener from .merge import merge, Merger +from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py new file mode 100644 index 0000000..ce57424 --- /dev/null +++ b/sqlalchemy_utils/generic.py @@ -0,0 +1,111 @@ +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.orm import attributes, class_mapper +from sqlalchemy.util import set_creation_order +from sqlalchemy import exc as sa_exc, inspect + + +class GenericAttributeImpl(attributes.ScalarAttributeImpl): + + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = state.attrs[self.parent_token.discriminator.key].value + target_class = None + for class_ in state.class_._decl_class_registry.values(): + name = getattr(class_, '__tablename__', None) + if name and name == discriminator: + target_class = class_ + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + # Lookup row with the discriminator and id. + id = state.attrs[self.parent_token.id.key].value + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def set(self, state, dict_, initiator, + passive=attributes.PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + # Get the primary key of the initiator. + mapper = class_mapper(type(initiator)) + if len(mapper.primary_key) > 1: + raise sa_exc.InvalidRequestError( + 'Generic relationships against tables with composite ' + 'primary keys are not supported.') + + pk = mapper.identity_key_from_instance(initiator)[1][0] + + # Set the identifier and the discriminator. + discriminator = type(initiator).__tablename__ + dict_[self.parent_token.id.key] = pk + dict_[self.parent_token.discriminator.key] = discriminator + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a descriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + self.discriminator = discriminator + self.id = id + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + for name, attr in self.parent.attrs.items(): + other = self.parent.columns.get(name) + if other is not None and column.name == other.name: + return attr + + def init(self): + # Resolve columns to attributes. + self.discriminator = self._column_to_property(self.discriminator) + self.id = self._column_to_property(self.id) + + class Comparator(PropComparator): + + def __init__(self, prop, parentmapper): + self.prop = prop + self._parentmapper = parentmapper + + def instrument_class(self, mapper): + attributes.register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000..3738be7 --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,66 @@ +import sqlalchemy as sa +from sqlalchemy import orm +from tests import TestCase +from sqlalchemy_utils import generic_relationship + + +class TestGenericForiegnKey(TestCase): + + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID') + + building = orm.relationship(Building) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = type(user).__tablename__ + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_set_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event(object=user) + + assert event.object_id == user.id + assert event.object_type == type(user).__tablename__ + + self.session.add(event) + self.session.commit() + + assert event.object == user From 684617f87504cd2efc7db0eaa7f0d461b4bd8d2a Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 9 Sep 2013 00:35:31 -0700 Subject: [PATCH 2/6] Extend table_name to support attributes. --- sqlalchemy_utils/functions/__init__.py | 12 ++++++++++-- tests/test_table_name.py | 9 ++++++++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index d536df4..01208b6 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -90,14 +90,22 @@ def primary_keys(class_): yield column -def table_name(class_): +def table_name(obj): """ - Return table name of given declarative class. + Return table name of given target, declarative class or the + table name where the declarative attribute is bound to. """ + class_ = getattr(obj, 'class_', obj) + try: return class_.__tablename__ except AttributeError: + pass + + try: return class_.__table__.name + except AttributeError: + pass def non_indexed_foreign_keys(metadata, engine=None): diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 2cfc19e..234d22f 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -12,7 +12,14 @@ class TestTableName(TestCase): self.Building = Building - def test_table_name(self): + def test_class(self): assert table_name(self.Building) == 'building' del self.Building.__tablename__ assert table_name(self.Building) == 'building' + + def test_attribute(self): + assert table_name(self.Building.id) == 'building' + assert table_name(self.Building.name) == 'building' + + def test_target(self): + assert table_name(self.Building()) == 'building' From d94aa24110042860b193c1dd33eeb8abc27e38bb Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 9 Sep 2013 00:38:32 -0700 Subject: [PATCH 3/6] Finish generic relationship property; add comparators and further testing. --- sqlalchemy_utils/generic.py | 33 +++++++++++++----- tests/test_generic.py | 69 ++++++++++++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 13 deletions(-) diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py index ce57424..676cdb5 100644 --- a/sqlalchemy_utils/generic.py +++ b/sqlalchemy_utils/generic.py @@ -2,7 +2,8 @@ from sqlalchemy.orm.interfaces import MapperProperty, PropComparator from sqlalchemy.orm.session import _state_session from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.util import set_creation_order -from sqlalchemy import exc as sa_exc, inspect +from sqlalchemy import exc as sa_exc +from .functions import table_name class GenericAttributeImpl(attributes.ScalarAttributeImpl): @@ -23,7 +24,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): discriminator = state.attrs[self.parent_token.discriminator.key].value target_class = None for class_ in state.class_._decl_class_registry.values(): - name = getattr(class_, '__tablename__', None) + name = table_name(class_) if name and name == discriminator: target_class = class_ @@ -46,7 +47,8 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): # Set us on the state. dict_[self.key] = initiator - # Get the primary key of the initiator. + # Get the primary key of the initiator and ensure we + # can support this assignment. mapper = class_mapper(type(initiator)) if len(mapper.primary_key) > 1: raise sa_exc.InvalidRequestError( @@ -56,10 +58,11 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl): pk = mapper.identity_key_from_instance(initiator)[1][0] # Set the identifier and the discriminator. - discriminator = type(initiator).__tablename__ + discriminator = table_name(initiator) dict_[self.parent_token.id.key] = pk dict_[self.parent_token.discriminator.key] = discriminator + class GenericRelationshipProperty(MapperProperty): """A generic form of the relationship property. @@ -73,8 +76,8 @@ class GenericRelationshipProperty(MapperProperty): """ def __init__(self, discriminator, id, doc=None): - self.discriminator = discriminator - self.id = id + self._discriminator_col = discriminator + self._id_col = id self.doc = doc set_creation_order(self) @@ -87,8 +90,8 @@ class GenericRelationshipProperty(MapperProperty): def init(self): # Resolve columns to attributes. - self.discriminator = self._column_to_property(self.discriminator) - self.id = self._column_to_property(self.id) + self.discriminator = self._column_to_property(self._discriminator_col) + self.id = self._column_to_property(self._id_col) class Comparator(PropComparator): @@ -96,6 +99,19 @@ class GenericRelationshipProperty(MapperProperty): self.prop = prop self._parentmapper = parentmapper + def __eq__(self, other): + discriminator = table_name(other) + q = self.prop._discriminator_col == discriminator + q &= self.prop._id_col == other.id + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + discriminator = table_name(other) + return self.prop._discriminator_col == discriminator + def instrument_class(self, mapper): attributes.register_attribute( mapper.class_, @@ -107,5 +123,6 @@ class GenericRelationshipProperty(MapperProperty): parent_token=self ) + def generic_relationship(*args, **kwargs): return GenericRelationshipProperty(*args, **kwargs) diff --git a/tests/test_generic.py b/tests/test_generic.py index 3738be7..6443076 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,5 +1,4 @@ import sqlalchemy as sa -from sqlalchemy import orm from tests import TestCase from sqlalchemy_utils import generic_relationship @@ -15,10 +14,6 @@ class TestGenericForiegnKey(TestCase): __tablename__ = 'user' id = sa.Column(sa.Integer, primary_key=True) - building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID') - - building = orm.relationship(Building) - class Event(self.Base): __tablename__ = 'event' id = sa.Column(sa.Integer, primary_key=True) @@ -64,3 +59,67 @@ class TestGenericForiegnKey(TestCase): self.session.commit() assert event.object == user + + def test_compare_instance(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + assert event.object == user1 + assert event.object != user2 + + def test_compare_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=user1).first() is not None + assert q.filter_by(object=user2).first() is None + assert q.filter(self.Event.object == user2).first() is None + + def test_compare_not_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != user2).first() is not None + + def test_compare_type(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event1 = self.Event(object=user1) + event2 = self.Event(object=user2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.User) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None From 93bc135866beb6ad5bb9a84d9f67bf3ee89b34c4 Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 9 Sep 2013 00:50:21 -0700 Subject: [PATCH 4/6] Export generic_relationship in __init__ --- sqlalchemy_utils/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 28c015b..aa4d3fa 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -54,6 +54,7 @@ __all__ = ( sort_query, table_name, with_backrefs, + generic_relationship, ArrowType, ColorType, CountryType, From 7d7996e16d5c0060100fbad3f3cfdbf4eaf7db6b Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 9 Sep 2013 00:51:45 -0700 Subject: [PATCH 5/6] Add small documentation and example for generic_relationship. --- docs/index.rst | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/docs/index.rst b/docs/index.rst index efe5d16..eacc165 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -209,6 +209,56 @@ TimezoneType saves timezone objects as strings on the way in and converts them b timezone = sa.Column(TimezoneType(backend='pytz')) +Generic Relationship +-------------------- + +Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model. + +:: + + from sqlalchemy_utils import generic_relationship + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Customer(Base): + __tablename__ = 'customer' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + # This is used to discriminate between the linked tables. + object_type = sa.Column(sa.Unicode(255)) + + # This is used to point to the primary key of the linked row. + object_id = sa.Column(sa.Integer) + + object = generic_relationship(object_type, object_id) + + + # Some general usage to attach an event to a user. + us_1 = User() + cu_1 = Customer() + + session.add_all([us_1, cu_1]) + session.commit() + + ev = Event() + ev.object = us_1 + + session.add(ev) + session.commit() + + # Find the event we just made. + session.query(Event).filter_by(object=us_1).first() + + # Find any events that are bound to users. + session.query(Event).filter(Event.object.is_type(User)).all() + + API Documentation ----------------- From 90a3b4e3f1f0e7774562f9a8bf3d23e34461a76e Mon Sep 17 00:00:00 2001 From: Ryan Leckey Date: Mon, 9 Sep 2013 01:05:43 -0700 Subject: [PATCH 6/6] Unicode fix for python 2.x --- tests/test_generic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_generic.py b/tests/test_generic.py index 6443076..c52a3ab 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals import sqlalchemy as sa from tests import TestCase from sqlalchemy_utils import generic_relationship