Merge pull request #42 from kvesteri/topics/generic
Generic relationships
This commit is contained in:
@@ -209,6 +209,56 @@ TimezoneType saves timezone objects as strings on the way in and converts them b
|
|||||||
timezone = sa.Column(TimezoneType(backend='pytz'))
|
timezone = sa.Column(TimezoneType(backend='pytz'))
|
||||||
|
|
||||||
|
|
||||||
|
Generic Relationship
|
||||||
|
--------------------
|
||||||
|
|
||||||
|
Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model.
|
||||||
|
|
||||||
|
::
|
||||||
|
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Customer(Base):
|
||||||
|
__tablename__ = 'customer'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Event(Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
# This is used to discriminate between the linked tables.
|
||||||
|
object_type = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
# This is used to point to the primary key of the linked row.
|
||||||
|
object_id = sa.Column(sa.Integer)
|
||||||
|
|
||||||
|
object = generic_relationship(object_type, object_id)
|
||||||
|
|
||||||
|
|
||||||
|
# Some general usage to attach an event to a user.
|
||||||
|
us_1 = User()
|
||||||
|
cu_1 = Customer()
|
||||||
|
|
||||||
|
session.add_all([us_1, cu_1])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
ev = Event()
|
||||||
|
ev.object = us_1
|
||||||
|
|
||||||
|
session.add(ev)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Find the event we just made.
|
||||||
|
session.query(Event).filter_by(object=us_1).first()
|
||||||
|
|
||||||
|
# Find any events that are bound to users.
|
||||||
|
session.query(Event).filter(Event.object.is_type(User)).all()
|
||||||
|
|
||||||
|
|
||||||
API Documentation
|
API Documentation
|
||||||
-----------------
|
-----------------
|
||||||
|
|
||||||
|
@@ -11,6 +11,7 @@ from .functions import (
|
|||||||
)
|
)
|
||||||
from .listeners import coercion_listener
|
from .listeners import coercion_listener
|
||||||
from .merge import merge, Merger
|
from .merge import merge, Merger
|
||||||
|
from .generic import generic_relationship
|
||||||
from .proxy_dict import ProxyDict, proxy_dict
|
from .proxy_dict import ProxyDict, proxy_dict
|
||||||
from .types import (
|
from .types import (
|
||||||
ArrowType,
|
ArrowType,
|
||||||
@@ -53,6 +54,7 @@ __all__ = (
|
|||||||
sort_query,
|
sort_query,
|
||||||
table_name,
|
table_name,
|
||||||
with_backrefs,
|
with_backrefs,
|
||||||
|
generic_relationship,
|
||||||
ArrowType,
|
ArrowType,
|
||||||
ColorType,
|
ColorType,
|
||||||
CountryType,
|
CountryType,
|
||||||
|
@@ -90,14 +90,22 @@ def primary_keys(class_):
|
|||||||
yield column
|
yield column
|
||||||
|
|
||||||
|
|
||||||
def table_name(class_):
|
def table_name(obj):
|
||||||
"""
|
"""
|
||||||
Return table name of given declarative class.
|
Return table name of given target, declarative class or the
|
||||||
|
table name where the declarative attribute is bound to.
|
||||||
"""
|
"""
|
||||||
|
class_ = getattr(obj, 'class_', obj)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return class_.__tablename__
|
return class_.__tablename__
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
return class_.__table__.name
|
return class_.__table__.name
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def non_indexed_foreign_keys(metadata, engine=None):
|
def non_indexed_foreign_keys(metadata, engine=None):
|
||||||
|
128
sqlalchemy_utils/generic.py
Normal file
128
sqlalchemy_utils/generic.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||||
|
from sqlalchemy.orm.session import _state_session
|
||||||
|
from sqlalchemy.orm import attributes, class_mapper
|
||||||
|
from sqlalchemy.util import set_creation_order
|
||||||
|
from sqlalchemy import exc as sa_exc
|
||||||
|
from .functions import table_name
|
||||||
|
|
||||||
|
|
||||||
|
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||||
|
|
||||||
|
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
||||||
|
if self.key in dict_:
|
||||||
|
return dict_[self.key]
|
||||||
|
|
||||||
|
# Retrieve the session bound to the state in order to perform
|
||||||
|
# a lazy query for the attribute.
|
||||||
|
session = _state_session(state)
|
||||||
|
if session is None:
|
||||||
|
# State is not bound to a session; we cannot proceed.
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Find class for discriminator.
|
||||||
|
# TODO: Perhaps optimize with some sort of lookup?
|
||||||
|
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
||||||
|
target_class = None
|
||||||
|
for class_ in state.class_._decl_class_registry.values():
|
||||||
|
name = table_name(class_)
|
||||||
|
if name and name == discriminator:
|
||||||
|
target_class = class_
|
||||||
|
|
||||||
|
if target_class is None:
|
||||||
|
# Unknown discriminator; return nothing.
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Lookup row with the discriminator and id.
|
||||||
|
id = state.attrs[self.parent_token.id.key].value
|
||||||
|
target = session.query(target_class).get(id)
|
||||||
|
|
||||||
|
# Return found (or not found) target.
|
||||||
|
return target
|
||||||
|
|
||||||
|
def set(self, state, dict_, initiator,
|
||||||
|
passive=attributes.PASSIVE_OFF,
|
||||||
|
check_old=None,
|
||||||
|
pop=False):
|
||||||
|
|
||||||
|
# Set us on the state.
|
||||||
|
dict_[self.key] = initiator
|
||||||
|
|
||||||
|
# Get the primary key of the initiator and ensure we
|
||||||
|
# can support this assignment.
|
||||||
|
mapper = class_mapper(type(initiator))
|
||||||
|
if len(mapper.primary_key) > 1:
|
||||||
|
raise sa_exc.InvalidRequestError(
|
||||||
|
'Generic relationships against tables with composite '
|
||||||
|
'primary keys are not supported.')
|
||||||
|
|
||||||
|
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
||||||
|
|
||||||
|
# Set the identifier and the discriminator.
|
||||||
|
discriminator = table_name(initiator)
|
||||||
|
dict_[self.parent_token.id.key] = pk
|
||||||
|
dict_[self.parent_token.discriminator.key] = discriminator
|
||||||
|
|
||||||
|
|
||||||
|
class GenericRelationshipProperty(MapperProperty):
|
||||||
|
"""A generic form of the relationship property.
|
||||||
|
|
||||||
|
Creates a 1 to many relationship between the parent model
|
||||||
|
and any other models using a descriminator (the table name).
|
||||||
|
|
||||||
|
:param discriminator
|
||||||
|
Field to discriminate which model we are referring to.
|
||||||
|
:param id:
|
||||||
|
Field to point to the model we are referring to.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, discriminator, id, doc=None):
|
||||||
|
self._discriminator_col = discriminator
|
||||||
|
self._id_col = id
|
||||||
|
self.doc = doc
|
||||||
|
|
||||||
|
set_creation_order(self)
|
||||||
|
|
||||||
|
def _column_to_property(self, column):
|
||||||
|
for name, attr in self.parent.attrs.items():
|
||||||
|
other = self.parent.columns.get(name)
|
||||||
|
if other is not None and column.name == other.name:
|
||||||
|
return attr
|
||||||
|
|
||||||
|
def init(self):
|
||||||
|
# Resolve columns to attributes.
|
||||||
|
self.discriminator = self._column_to_property(self._discriminator_col)
|
||||||
|
self.id = self._column_to_property(self._id_col)
|
||||||
|
|
||||||
|
class Comparator(PropComparator):
|
||||||
|
|
||||||
|
def __init__(self, prop, parentmapper):
|
||||||
|
self.prop = prop
|
||||||
|
self._parentmapper = parentmapper
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
discriminator = table_name(other)
|
||||||
|
q = self.prop._discriminator_col == discriminator
|
||||||
|
q &= self.prop._id_col == other.id
|
||||||
|
return q
|
||||||
|
|
||||||
|
def __ne__(self, other):
|
||||||
|
return ~(self == other)
|
||||||
|
|
||||||
|
def is_type(self, other):
|
||||||
|
discriminator = table_name(other)
|
||||||
|
return self.prop._discriminator_col == discriminator
|
||||||
|
|
||||||
|
def instrument_class(self, mapper):
|
||||||
|
attributes.register_attribute(
|
||||||
|
mapper.class_,
|
||||||
|
self.key,
|
||||||
|
comparator=self.Comparator(self, mapper),
|
||||||
|
parententity=mapper,
|
||||||
|
doc=self.doc,
|
||||||
|
impl_class=GenericAttributeImpl,
|
||||||
|
parent_token=self
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generic_relationship(*args, **kwargs):
|
||||||
|
return GenericRelationshipProperty(*args, **kwargs)
|
126
tests/test_generic.py
Normal file
126
tests/test_generic.py
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
from __future__ import unicode_literals
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from tests import TestCase
|
||||||
|
from sqlalchemy_utils import generic_relationship
|
||||||
|
|
||||||
|
|
||||||
|
class TestGenericForiegnKey(TestCase):
|
||||||
|
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
class Event(self.Base):
|
||||||
|
__tablename__ = 'event'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
||||||
|
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||||
|
object_id = sa.Column(sa.Integer, nullable=False)
|
||||||
|
|
||||||
|
object = generic_relationship(object_type, object_id)
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
self.User = User
|
||||||
|
self.Event = Event
|
||||||
|
|
||||||
|
def test_set_manual_and_get(self):
|
||||||
|
user = self.User()
|
||||||
|
|
||||||
|
self.session.add(user)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event()
|
||||||
|
event.object_id = user.id
|
||||||
|
event.object_type = type(user).__tablename__
|
||||||
|
|
||||||
|
assert event.object is None
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user
|
||||||
|
|
||||||
|
def test_set_and_get(self):
|
||||||
|
user = self.User()
|
||||||
|
|
||||||
|
self.session.add(user)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user)
|
||||||
|
|
||||||
|
assert event.object_id == user.id
|
||||||
|
assert event.object_type == type(user).__tablename__
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user
|
||||||
|
|
||||||
|
def test_compare_instance(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
assert event.object == user1
|
||||||
|
assert event.object != user2
|
||||||
|
|
||||||
|
def test_compare_query(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter_by(object=user1).first() is not None
|
||||||
|
assert q.filter_by(object=user2).first() is None
|
||||||
|
assert q.filter(self.Event.object == user2).first() is None
|
||||||
|
|
||||||
|
def test_compare_not_query(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event = self.Event(object=user1)
|
||||||
|
|
||||||
|
self.session.add(event)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
q = self.session.query(self.Event)
|
||||||
|
assert q.filter(self.Event.object != user2).first() is not None
|
||||||
|
|
||||||
|
def test_compare_type(self):
|
||||||
|
user1 = self.User()
|
||||||
|
user2 = self.User()
|
||||||
|
|
||||||
|
self.session.add_all([user1, user2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
event1 = self.Event(object=user1)
|
||||||
|
event2 = self.Event(object=user2)
|
||||||
|
|
||||||
|
self.session.add_all([event1, event2])
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
statement = self.Event.object.is_type(self.User)
|
||||||
|
q = self.session.query(self.Event).filter(statement)
|
||||||
|
assert q.first() is not None
|
@@ -12,7 +12,14 @@ class TestTableName(TestCase):
|
|||||||
|
|
||||||
self.Building = Building
|
self.Building = Building
|
||||||
|
|
||||||
def test_table_name(self):
|
def test_class(self):
|
||||||
assert table_name(self.Building) == 'building'
|
assert table_name(self.Building) == 'building'
|
||||||
del self.Building.__tablename__
|
del self.Building.__tablename__
|
||||||
assert table_name(self.Building) == 'building'
|
assert table_name(self.Building) == 'building'
|
||||||
|
|
||||||
|
def test_attribute(self):
|
||||||
|
assert table_name(self.Building.id) == 'building'
|
||||||
|
assert table_name(self.Building.name) == 'building'
|
||||||
|
|
||||||
|
def test_target(self):
|
||||||
|
assert table_name(self.Building()) == 'building'
|
||||||
|
Reference in New Issue
Block a user