399 lines
11 KiB
Python
399 lines
11 KiB
Python
from collections import defaultdict
|
|
import six
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.orm import RelationshipProperty
|
|
from sqlalchemy.orm.attributes import (
|
|
set_committed_value, InstrumentedAttribute
|
|
)
|
|
from sqlalchemy.orm.session import object_session
|
|
|
|
|
|
class PathException(Exception):
|
|
pass
|
|
|
|
|
|
class with_backrefs(object):
|
|
"""
|
|
Marks given attribute path so that whenever its fetched with batch_fetch
|
|
the backref relations are force set too. Very useful when dealing with
|
|
certain many-to-many relationship scenarios.
|
|
"""
|
|
def __init__(self, path):
|
|
self.path = path
|
|
|
|
|
|
class Path(object):
|
|
"""
|
|
A class that represents an attribute path.
|
|
"""
|
|
def __init__(self, entities, prop, populate_backrefs=False):
|
|
self.property = prop
|
|
self.entities = entities
|
|
self.populate_backrefs = populate_backrefs
|
|
if not isinstance(self.property, RelationshipProperty):
|
|
raise PathException(
|
|
'Given attribute is not a relationship property.'
|
|
)
|
|
self.fetcher = self.fetcher_class(self)
|
|
|
|
@property
|
|
def session(self):
|
|
return object_session(self.entities[0])
|
|
|
|
@property
|
|
def parent_model(self):
|
|
return self.entities[0].__class__
|
|
|
|
@property
|
|
def model(self):
|
|
return self.property.mapper.class_
|
|
|
|
@classmethod
|
|
def parse(cls, entities, path, populate_backrefs=False):
|
|
if isinstance(path, six.string_types):
|
|
attrs = path.split('.')
|
|
|
|
if len(attrs) > 1:
|
|
related_entities = []
|
|
for entity in entities:
|
|
related_entities.extend(getattr(entity, attrs[0]))
|
|
|
|
if not related_entities:
|
|
return
|
|
subpath = '.'.join(attrs[1:])
|
|
return Path.parse(related_entities, subpath, populate_backrefs)
|
|
else:
|
|
attr = getattr(
|
|
entities[0].__class__, attrs[0]
|
|
)
|
|
elif isinstance(path, InstrumentedAttribute):
|
|
attr = path
|
|
else:
|
|
raise PathException('Unknown path type.')
|
|
|
|
return Path(entities, attr.property, populate_backrefs)
|
|
|
|
@property
|
|
def fetcher_class(self):
|
|
if self.property.secondary is not None:
|
|
return ManyToManyFetcher
|
|
else:
|
|
if self.property.direction.name == 'MANYTOONE':
|
|
return ManyToOneFetcher
|
|
else:
|
|
return OneToManyFetcher
|
|
|
|
|
|
class CompositePath(object):
|
|
def __init__(self, *paths):
|
|
self.paths = paths
|
|
|
|
|
|
def batch_fetch(entities, *attr_paths):
|
|
"""
|
|
Batch fetch given relationship attribute for collection of entities.
|
|
|
|
This function is in many cases a valid alternative for SQLAlchemy's
|
|
subqueryload and performs lot better.
|
|
|
|
:param entities: list of entities of the same type
|
|
:param attr_paths:
|
|
List of either InstrumentedAttribute objects or a strings representing
|
|
the name of the instrumented attribute
|
|
|
|
Example::
|
|
|
|
|
|
from sqlalchemy_utils import batch_fetch
|
|
|
|
|
|
users = session.query(User).limit(20).all()
|
|
|
|
batch_fetch(users, User.phonenumbers)
|
|
|
|
|
|
Function also accepts strings as attribute names: ::
|
|
|
|
|
|
users = session.query(User).limit(20).all()
|
|
|
|
batch_fetch(users, 'phonenumbers')
|
|
|
|
|
|
Multiple attributes may be provided: ::
|
|
|
|
|
|
clubs = session.query(Club).limit(20).all()
|
|
|
|
batch_fetch(
|
|
clubs,
|
|
'teams',
|
|
'teams.players',
|
|
'teams.players.user_groups'
|
|
)
|
|
|
|
You can also force populate backrefs: ::
|
|
|
|
|
|
from sqlalchemy_utils import with_backrefs
|
|
|
|
|
|
clubs = session.query(Club).limit(20).all()
|
|
|
|
batch_fetch(
|
|
clubs,
|
|
'teams',
|
|
'teams.players',
|
|
with_backrefs('teams.players.user_groups')
|
|
)
|
|
|
|
"""
|
|
|
|
if entities:
|
|
for path in attr_paths:
|
|
fetcher = fetcher_factory(entities, path)
|
|
if fetcher:
|
|
fetcher.fetch()
|
|
fetcher.populate()
|
|
|
|
|
|
def fetcher_factory(entities, path):
|
|
populate_backrefs = False
|
|
if isinstance(path, with_backrefs):
|
|
path = path.path
|
|
populate_backrefs = True
|
|
|
|
if isinstance(path, CompositePath):
|
|
fetchers = []
|
|
for path in path.paths:
|
|
path = Path.parse(entities, path, populate_backrefs)
|
|
if path:
|
|
fetchers.append(
|
|
path.fetcher
|
|
)
|
|
|
|
return CompositeFetcher(*fetchers)
|
|
else:
|
|
path = Path.parse(entities, path, populate_backrefs)
|
|
if path:
|
|
return path.fetcher
|
|
|
|
|
|
class CompositeFetcher(object):
|
|
def __init__(self, *fetchers):
|
|
if not all(
|
|
fetchers[0].path.model == fetcher.path.model
|
|
for fetcher in fetchers
|
|
):
|
|
raise PathException(
|
|
'Each relationship property must have the same class when '
|
|
'using CompositeFetcher.'
|
|
)
|
|
self.fetchers = fetchers
|
|
|
|
@property
|
|
def session(self):
|
|
return self.fetchers[0].path.session
|
|
|
|
@property
|
|
def model(self):
|
|
return self.fetchers[0].path.model
|
|
|
|
@property
|
|
def condition(self):
|
|
return sa.or_(
|
|
*[fetcher.condition for fetcher in self.fetchers]
|
|
)
|
|
|
|
@property
|
|
def related_entities(self):
|
|
return self.session.query(self.model).filter(self.condition)
|
|
|
|
def fetch(self):
|
|
for entity in self.related_entities:
|
|
for fetcher in self.fetchers:
|
|
if any(
|
|
getattr(entity, name)
|
|
for name in fetcher.remote_column_names
|
|
):
|
|
fetcher.append_entity(entity)
|
|
|
|
def populate(self):
|
|
for fetcher in self.fetchers:
|
|
fetcher.populate()
|
|
|
|
|
|
class Fetcher(object):
|
|
def __init__(self, path):
|
|
self.path = path
|
|
self.prop = self.path.property
|
|
if self.prop.uselist:
|
|
self.parent_dict = defaultdict(list)
|
|
else:
|
|
self.parent_dict = defaultdict(lambda: None)
|
|
|
|
@property
|
|
def local_values_list(self):
|
|
return [
|
|
self.local_values(entity)
|
|
for entity in self.path.entities
|
|
]
|
|
|
|
@property
|
|
def relation_query_base(self):
|
|
return self.path.session.query(self.path.model)
|
|
|
|
@property
|
|
def related_entities(self):
|
|
return self.relation_query_base.filter(self.condition)
|
|
|
|
@property
|
|
def local_column_names(self):
|
|
return [local.name for local, remote in self.prop.local_remote_pairs]
|
|
|
|
def parent_key(self, entity):
|
|
return tuple(
|
|
getattr(entity, name)
|
|
for name in self.remote_column_names
|
|
)
|
|
|
|
def local_values(self, entity):
|
|
return tuple(
|
|
getattr(entity, name)
|
|
for name in self.local_column_names
|
|
)
|
|
|
|
def populate_backrefs(self, related_entities):
|
|
"""
|
|
Populates backrefs for given related entities.
|
|
"""
|
|
backref_dict = dict(
|
|
(self.local_values(value[0]), [])
|
|
for value in related_entities
|
|
)
|
|
for value in related_entities:
|
|
backref_dict[self.local_values(value[0])].append(
|
|
self.path.session.query(self.path.parent_model).get(
|
|
tuple(value[1:])
|
|
)
|
|
)
|
|
for value in related_entities:
|
|
set_committed_value(
|
|
value[0],
|
|
self.prop.back_populates,
|
|
backref_dict[self.local_values(value[0])]
|
|
)
|
|
|
|
def populate(self):
|
|
"""
|
|
Populate batch fetched entities to parent objects.
|
|
"""
|
|
for entity in self.path.entities:
|
|
set_committed_value(
|
|
entity,
|
|
self.prop.key,
|
|
self.parent_dict[self.local_values(entity)]
|
|
)
|
|
|
|
if self.path.populate_backrefs:
|
|
self.populate_backrefs(self.related_entities)
|
|
|
|
@property
|
|
def remote(self):
|
|
return self.path.model
|
|
|
|
@property
|
|
def condition(self):
|
|
names = self.remote_column_names
|
|
if len(names) == 1:
|
|
return getattr(self.remote, names[0]).in_(
|
|
value[0] for value in self.local_values_list
|
|
)
|
|
elif len(names) > 1:
|
|
conditions = []
|
|
for entity in self.path.entities:
|
|
conditions.append(
|
|
sa.and_(
|
|
*[
|
|
getattr(self.remote, remote.name)
|
|
==
|
|
getattr(entity, local.name)
|
|
for local, remote in self.prop.local_remote_pairs
|
|
if remote in self.remote_column_names
|
|
]
|
|
)
|
|
)
|
|
return sa.or_(*conditions)
|
|
else:
|
|
raise PathException(
|
|
'Could not obtain remote column names.'
|
|
)
|
|
|
|
def fetch(self):
|
|
for entity in self.related_entities:
|
|
self.append_entity(entity)
|
|
|
|
@property
|
|
def remote_column_names(self):
|
|
return [remote.name for local, remote in self.prop.local_remote_pairs]
|
|
|
|
|
|
class ManyToManyFetcher(Fetcher):
|
|
@property
|
|
def remote(self):
|
|
return self.prop.secondary.c
|
|
|
|
@property
|
|
def local_column_names(self):
|
|
names = []
|
|
for local, remote in self.prop.local_remote_pairs:
|
|
for fk in remote.foreign_keys:
|
|
if fk.column.table in self.prop.parent.tables:
|
|
names.append(local.name)
|
|
return names
|
|
|
|
@property
|
|
def remote_column_names(self):
|
|
names = []
|
|
for local, remote in self.prop.local_remote_pairs:
|
|
for fk in remote.foreign_keys:
|
|
if fk.column.table in self.prop.parent.tables:
|
|
names.append(remote.name)
|
|
return names
|
|
|
|
@property
|
|
def relation_query_base(self):
|
|
return (
|
|
self.path.session
|
|
.query(
|
|
self.path.model,
|
|
*[
|
|
getattr(self.prop.secondary.c, name)
|
|
for name in self.remote_column_names
|
|
]
|
|
)
|
|
.join(
|
|
self.prop.secondary, self.prop.secondaryjoin
|
|
)
|
|
)
|
|
|
|
def fetch(self):
|
|
for value in self.related_entities:
|
|
self.parent_dict[tuple(value[1:])].append(
|
|
value[0]
|
|
)
|
|
|
|
|
|
class ManyToOneFetcher(Fetcher):
|
|
def append_entity(self, entity):
|
|
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
|
self.parent_dict[self.parent_key(entity)] = entity
|
|
|
|
|
|
class OneToManyFetcher(Fetcher):
|
|
def append_entity(self, entity):
|
|
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
|
self.parent_dict[self.parent_key(entity)].append(
|
|
entity
|
|
)
|