Initial generic relationship implementation.
This commit is contained in:
@@ -11,6 +11,7 @@ from .functions import (
|
||||
)
|
||||
from .listeners import coercion_listener
|
||||
from .merge import merge, Merger
|
||||
from .generic import generic_relationship
|
||||
from .proxy_dict import ProxyDict, proxy_dict
|
||||
from .types import (
|
||||
ArrowType,
|
||||
|
111
sqlalchemy_utils/generic.py
Normal file
111
sqlalchemy_utils/generic.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||
from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.orm import attributes, class_mapper
|
||||
from sqlalchemy.util import set_creation_order
|
||||
from sqlalchemy import exc as sa_exc, inspect
|
||||
|
||||
|
||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
|
||||
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
||||
if self.key in dict_:
|
||||
return dict_[self.key]
|
||||
|
||||
# Retrieve the session bound to the state in order to perform
|
||||
# a lazy query for the attribute.
|
||||
session = _state_session(state)
|
||||
if session is None:
|
||||
# State is not bound to a session; we cannot proceed.
|
||||
return None
|
||||
|
||||
# Find class for discriminator.
|
||||
# TODO: Perhaps optimize with some sort of lookup?
|
||||
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
||||
target_class = None
|
||||
for class_ in state.class_._decl_class_registry.values():
|
||||
name = getattr(class_, '__tablename__', None)
|
||||
if name and name == discriminator:
|
||||
target_class = class_
|
||||
|
||||
if target_class is None:
|
||||
# Unknown discriminator; return nothing.
|
||||
return None
|
||||
|
||||
# Lookup row with the discriminator and id.
|
||||
id = state.attrs[self.parent_token.id.key].value
|
||||
target = session.query(target_class).get(id)
|
||||
|
||||
# Return found (or not found) target.
|
||||
return target
|
||||
|
||||
def set(self, state, dict_, initiator,
|
||||
passive=attributes.PASSIVE_OFF,
|
||||
check_old=None,
|
||||
pop=False):
|
||||
|
||||
# Set us on the state.
|
||||
dict_[self.key] = initiator
|
||||
|
||||
# Get the primary key of the initiator.
|
||||
mapper = class_mapper(type(initiator))
|
||||
if len(mapper.primary_key) > 1:
|
||||
raise sa_exc.InvalidRequestError(
|
||||
'Generic relationships against tables with composite '
|
||||
'primary keys are not supported.')
|
||||
|
||||
pk = mapper.identity_key_from_instance(initiator)[1][0]
|
||||
|
||||
# Set the identifier and the discriminator.
|
||||
discriminator = type(initiator).__tablename__
|
||||
dict_[self.parent_token.id.key] = pk
|
||||
dict_[self.parent_token.discriminator.key] = discriminator
|
||||
|
||||
class GenericRelationshipProperty(MapperProperty):
|
||||
"""A generic form of the relationship property.
|
||||
|
||||
Creates a 1 to many relationship between the parent model
|
||||
and any other models using a descriminator (the table name).
|
||||
|
||||
:param discriminator
|
||||
Field to discriminate which model we are referring to.
|
||||
:param id:
|
||||
Field to point to the model we are referring to.
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator, id, doc=None):
|
||||
self.discriminator = discriminator
|
||||
self.id = id
|
||||
self.doc = doc
|
||||
|
||||
set_creation_order(self)
|
||||
|
||||
def _column_to_property(self, column):
|
||||
for name, attr in self.parent.attrs.items():
|
||||
other = self.parent.columns.get(name)
|
||||
if other is not None and column.name == other.name:
|
||||
return attr
|
||||
|
||||
def init(self):
|
||||
# Resolve columns to attributes.
|
||||
self.discriminator = self._column_to_property(self.discriminator)
|
||||
self.id = self._column_to_property(self.id)
|
||||
|
||||
class Comparator(PropComparator):
|
||||
|
||||
def __init__(self, prop, parentmapper):
|
||||
self.prop = prop
|
||||
self._parentmapper = parentmapper
|
||||
|
||||
def instrument_class(self, mapper):
|
||||
attributes.register_attribute(
|
||||
mapper.class_,
|
||||
self.key,
|
||||
comparator=self.Comparator(self, mapper),
|
||||
parententity=mapper,
|
||||
doc=self.doc,
|
||||
impl_class=GenericAttributeImpl,
|
||||
parent_token=self
|
||||
)
|
||||
|
||||
def generic_relationship(*args, **kwargs):
|
||||
return GenericRelationshipProperty(*args, **kwargs)
|
66
tests/test_generic.py
Normal file
66
tests/test_generic.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import orm
|
||||
from tests import TestCase
|
||||
from sqlalchemy_utils import generic_relationship
|
||||
|
||||
|
||||
class TestGenericForiegnKey(TestCase):
|
||||
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id), name='buildingID')
|
||||
|
||||
building = orm.relationship(Building)
|
||||
|
||||
class Event(self.Base):
|
||||
__tablename__ = 'event'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
|
||||
object_type = sa.Column(sa.Unicode(255), name="objectType")
|
||||
object_id = sa.Column(sa.Integer, nullable=False)
|
||||
|
||||
object = generic_relationship(object_type, object_id)
|
||||
|
||||
self.Building = Building
|
||||
self.User = User
|
||||
self.Event = Event
|
||||
|
||||
def test_set_manual_and_get(self):
|
||||
user = self.User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event()
|
||||
event.object_id = user.id
|
||||
event.object_type = type(user).__tablename__
|
||||
|
||||
assert event.object is None
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == user
|
||||
|
||||
def test_set_and_get(self):
|
||||
user = self.User()
|
||||
|
||||
self.session.add(user)
|
||||
self.session.commit()
|
||||
|
||||
event = self.Event(object=user)
|
||||
|
||||
assert event.object_id == user.id
|
||||
assert event.object_type == type(user).__tablename__
|
||||
|
||||
self.session.add(event)
|
||||
self.session.commit()
|
||||
|
||||
assert event.object == user
|
Reference in New Issue
Block a user