diff --git a/docs/index.rst b/docs/index.rst index efe5d16..eacc165 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -209,6 +209,56 @@ TimezoneType saves timezone objects as strings on the way in and converts them b timezone = sa.Column(TimezoneType(backend='pytz')) +Generic Relationship +-------------------- + +Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model. + +:: + + from sqlalchemy_utils import generic_relationship + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Customer(Base): + __tablename__ = 'customer' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + # This is used to discriminate between the linked tables. + object_type = sa.Column(sa.Unicode(255)) + + # This is used to point to the primary key of the linked row. + object_id = sa.Column(sa.Integer) + + object = generic_relationship(object_type, object_id) + + + # Some general usage to attach an event to a user. + us_1 = User() + cu_1 = Customer() + + session.add_all([us_1, cu_1]) + session.commit() + + ev = Event() + ev.object = us_1 + + session.add(ev) + session.commit() + + # Find the event we just made. + session.query(Event).filter_by(object=us_1).first() + + # Find any events that are bound to users. + session.query(Event).filter(Event.object.is_type(User)).all() + + API Documentation ----------------- diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 2917856..aa4d3fa 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -11,6 +11,7 @@ from .functions import ( ) from .listeners import coercion_listener from .merge import merge, Merger +from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict from .types import ( ArrowType, @@ -53,6 +54,7 @@ __all__ = ( sort_query, table_name, with_backrefs, + generic_relationship, ArrowType, ColorType, CountryType, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index d536df4..01208b6 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -90,14 +90,22 @@ def primary_keys(class_): yield column -def table_name(class_): +def table_name(obj): """ - Return table name of given declarative class. + Return table name of given target, declarative class or the + table name where the declarative attribute is bound to. """ + class_ = getattr(obj, 'class_', obj) + try: return class_.__tablename__ except AttributeError: + pass + + try: return class_.__table__.name + except AttributeError: + pass def non_indexed_foreign_keys(metadata, engine=None): diff --git a/sqlalchemy_utils/generic.py b/sqlalchemy_utils/generic.py new file mode 100644 index 0000000..676cdb5 --- /dev/null +++ b/sqlalchemy_utils/generic.py @@ -0,0 +1,128 @@ +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.orm import attributes, class_mapper +from sqlalchemy.util import set_creation_order +from sqlalchemy import exc as sa_exc +from .functions import table_name + + +class GenericAttributeImpl(attributes.ScalarAttributeImpl): + + def get(self, state, dict_, passive=attributes.PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = state.attrs[self.parent_token.discriminator.key].value + target_class = None + for class_ in state.class_._decl_class_registry.values(): + name = table_name(class_) + if name and name == discriminator: + target_class = class_ + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + # Lookup row with the discriminator and id. + id = state.attrs[self.parent_token.id.key].value + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def set(self, state, dict_, initiator, + passive=attributes.PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + # Get the primary key of the initiator and ensure we + # can support this assignment. + mapper = class_mapper(type(initiator)) + if len(mapper.primary_key) > 1: + raise sa_exc.InvalidRequestError( + 'Generic relationships against tables with composite ' + 'primary keys are not supported.') + + pk = mapper.identity_key_from_instance(initiator)[1][0] + + # Set the identifier and the discriminator. + discriminator = table_name(initiator) + dict_[self.parent_token.id.key] = pk + dict_[self.parent_token.discriminator.key] = discriminator + + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a descriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + self._discriminator_col = discriminator + self._id_col = id + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + for name, attr in self.parent.attrs.items(): + other = self.parent.columns.get(name) + if other is not None and column.name == other.name: + return attr + + def init(self): + # Resolve columns to attributes. + self.discriminator = self._column_to_property(self._discriminator_col) + self.id = self._column_to_property(self._id_col) + + class Comparator(PropComparator): + + def __init__(self, prop, parentmapper): + self.prop = prop + self._parentmapper = parentmapper + + def __eq__(self, other): + discriminator = table_name(other) + q = self.prop._discriminator_col == discriminator + q &= self.prop._id_col == other.id + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + discriminator = table_name(other) + return self.prop._discriminator_col == discriminator + + def instrument_class(self, mapper): + attributes.register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) diff --git a/tests/test_generic.py b/tests/test_generic.py new file mode 100644 index 0000000..c52a3ab --- /dev/null +++ b/tests/test_generic.py @@ -0,0 +1,126 @@ +from __future__ import unicode_literals +import sqlalchemy as sa +from tests import TestCase +from sqlalchemy_utils import generic_relationship + + +class TestGenericForiegnKey(TestCase): + + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + class Event(self.Base): + __tablename__ = 'event' + id = sa.Column(sa.Integer, primary_key=True) + + object_type = sa.Column(sa.Unicode(255), name="objectType") + object_id = sa.Column(sa.Integer, nullable=False) + + object = generic_relationship(object_type, object_id) + + self.Building = Building + self.User = User + self.Event = Event + + def test_set_manual_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event() + event.object_id = user.id + event.object_type = type(user).__tablename__ + + assert event.object is None + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_set_and_get(self): + user = self.User() + + self.session.add(user) + self.session.commit() + + event = self.Event(object=user) + + assert event.object_id == user.id + assert event.object_type == type(user).__tablename__ + + self.session.add(event) + self.session.commit() + + assert event.object == user + + def test_compare_instance(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + assert event.object == user1 + assert event.object != user2 + + def test_compare_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter_by(object=user1).first() is not None + assert q.filter_by(object=user2).first() is None + assert q.filter(self.Event.object == user2).first() is None + + def test_compare_not_query(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event = self.Event(object=user1) + + self.session.add(event) + self.session.commit() + + q = self.session.query(self.Event) + assert q.filter(self.Event.object != user2).first() is not None + + def test_compare_type(self): + user1 = self.User() + user2 = self.User() + + self.session.add_all([user1, user2]) + self.session.commit() + + event1 = self.Event(object=user1) + event2 = self.Event(object=user2) + + self.session.add_all([event1, event2]) + self.session.commit() + + statement = self.Event.object.is_type(self.User) + q = self.session.query(self.Event).filter(statement) + assert q.first() is not None diff --git a/tests/test_table_name.py b/tests/test_table_name.py index 2cfc19e..234d22f 100644 --- a/tests/test_table_name.py +++ b/tests/test_table_name.py @@ -12,7 +12,14 @@ class TestTableName(TestCase): self.Building = Building - def test_table_name(self): + def test_class(self): assert table_name(self.Building) == 'building' del self.Building.__tablename__ assert table_name(self.Building) == 'building' + + def test_attribute(self): + assert table_name(self.Building.id) == 'building' + assert table_name(self.Building.name) == 'building' + + def test_target(self): + assert table_name(self.Building()) == 'building'