Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/batch.py
2014-03-05 15:16:54 +02:00

388 lines
11 KiB
Python

from collections import defaultdict
from itertools import chain
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
from sqlalchemy_utils.generic import GenericRelationshipProperty
from sqlalchemy_utils.functions.orm import (
list_local_values,
list_local_remote_exprs,
local_values,
remote_column_names,
remote_values,
remote
)
class PathException(Exception):
    """
    Raised when an attribute path handed to batch fetching cannot be used,
    e.g. it does not resolve to a relationship property or it is neither a
    string nor an instrumented attribute.
    """
class DataException(Exception):
    """
    Raised when a dotted attribute path runs out of data, i.e. an
    intermediate relationship yields no related entities to continue from.
    """
class with_backrefs(object):
    """
    Marks given attribute path so that whenever it is fetched with
    batch_fetch the backref relations are force set too. Very useful when
    dealing with certain many-to-many relationship scenarios.

    :param path:
        the attribute path to mark: a string, an InstrumentedAttribute or a
        tuple of those
    """
    def __init__(self, path):
        # fetcher_factory unwraps this object again and reads ``path``.
        self.path = path

    def __repr__(self):
        return '%s(%r)' % (self.__class__.__name__, self.path)
class Path(object):
    """
    A class that represents an attribute path.

    A Path binds a list of entities to one relationship property of their
    class and builds the fetcher able to batch load that relationship.
    """
    def __init__(self, entities, prop, populate_backrefs=False):
        self.validate_property(prop)
        self.property = prop
        self.entities = entities
        self.populate_backrefs = populate_backrefs
        self.fetcher = self.fetcher_class(self)

    def validate_property(self, prop):
        """
        Raise PathException unless ``prop`` is a (generic) relationship.
        """
        if not isinstance(
            prop, (RelationshipProperty, GenericRelationshipProperty)
        ):
            raise PathException(
                'Given attribute is not a relationship property.'
            )

    @property
    def session(self):
        # The session of the first entity is used for every query.
        return object_session(self.entities[0])

    @property
    def model(self):
        return self.property.mapper.class_

    @staticmethod
    def _related_objects(value):
        """
        Normalize a relationship attribute value into a list of entities.

        Scalar relationships hold a single entity or None, collection
        relationships hold an iterable, and mapping based collections hold
        the entities as their values.
        """
        if value is None:
            return []
        if isinstance(value, dict):
            return list(value.values())
        if hasattr(value, '__iter__'):
            return list(value)
        return [value]

    @classmethod
    def parse(cls, entities, path, populate_backrefs=False):
        """
        Build a Path for ``path`` relative to ``entities``.

        Dotted string paths are resolved step by step: the entities reached
        through the first attribute become the entities of the remainder.

        :param entities: non-empty list of entities of the same class
        :param path: a (dotted) attribute name or an InstrumentedAttribute
        :param populate_backrefs: passed through to the resulting Path
        :raises DataException:
            if an intermediate step of a dotted path yields no entities
        :raises PathException: if ``path`` is of an unsupported type
        """
        if isinstance(path, six.string_types):
            attrs = path.split('.')
            if len(attrs) > 1:
                related_entities = []
                # Several parents may share a related entity (e.g. through a
                # many-to-one step); keep each object only once.
                seen = set()
                for entity in entities:
                    for related in cls._related_objects(
                        getattr(entity, attrs[0])
                    ):
                        if id(related) not in seen:
                            seen.add(id(related))
                            related_entities.append(related)
                if not related_entities:
                    raise DataException('No related entities.')
                subpath = '.'.join(attrs[1:])
                return cls.parse(related_entities, subpath, populate_backrefs)
            attr = getattr(entities[0].__class__, attrs[0])
        elif isinstance(path, InstrumentedAttribute):
            attr = path
        else:
            raise PathException('Unknown path type.')
        return cls(entities, attr.property, populate_backrefs)

    @property
    def fetcher_class(self):
        """
        The fetcher class matching the kind of relationship of this path.
        """
        prop = self.property
        if isinstance(prop, GenericRelationshipProperty):
            return GenericRelationshipFetcher
        if prop.secondary is not None:
            return ManyToManyFetcher
        if prop.direction.name == 'MANYTOONE':
            return ManyToOneFetcher
        return OneToManyFetcher
def batch_fetch(entities, *attr_paths):
    """
    Batch fetch given relationship attributes for a collection of entities.

    This function is in many cases a valid alternative for SQLAlchemy's
    subqueryload and performs a lot better.

    :param entities: list of entities of the same type
    :param attr_paths:
        List of either InstrumentedAttribute objects or strings representing
        the name of the instrumented attribute. A tuple of paths is loaded
        through one combined query, and a path wrapped in ``with_backrefs``
        also populates the backrefs.

    Example::

        from sqlalchemy_utils import batch_fetch

        users = session.query(User).limit(20).all()

        batch_fetch(users, User.phonenumbers)

    Function also accepts strings as attribute names: ::

        users = session.query(User).limit(20).all()

        batch_fetch(users, 'phonenumbers')

    Multiple attributes may be provided: ::

        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            'teams.players.user_groups'
        )

    You can also force populate backrefs: ::

        from sqlalchemy_utils import with_backrefs

        clubs = session.query(Club).limit(20).all()

        batch_fetch(
            clubs,
            'teams',
            'teams.players',
            with_backrefs('teams.players.user_groups')
        )

    Nothing happens when ``entities`` is empty. A dotted path whose
    intermediate relationship yields no entities is skipped silently.
    """
    if not entities:
        return
    for attr_path in attr_paths:
        try:
            fetcher = fetcher_factory(entities, attr_path)
            fetcher.fetch()
            fetcher.populate()
        except DataException:
            # Nothing to load along this path; move on to the next one.
            continue
def get_fetcher(entities, path, populate_backrefs):
    """
    Parse ``path`` against ``entities`` and return the fetcher of the
    resulting Path.
    """
    parsed_path = Path.parse(entities, path, populate_backrefs)
    return parsed_path.fetcher
def fetcher_factory(entities, path):
    """
    Create the fetcher for a single ``batch_fetch`` path argument.

    :param entities: entities whose relationships are loaded
    :param path:
        a string or InstrumentedAttribute, a tuple of those (combined into a
        CompositeFetcher), or either form wrapped in ``with_backrefs``
    """
    populate_backrefs = isinstance(path, with_backrefs)
    if populate_backrefs:
        path = path.path
    if not isinstance(path, tuple):
        return get_fetcher(entities, path, populate_backrefs)
    fetchers = [
        get_fetcher(entities, sub_path, populate_backrefs)
        for sub_path in path
    ]
    return CompositeFetcher(*fetchers)
class CompositeFetcher(object):
    """
    Combines several fetchers whose relationships target the same model so
    that their related entities are loaded with a single query.

    :raises PathException:
        if no fetchers are given or their models differ
    """
    def __init__(self, *fetchers):
        if not fetchers:
            raise PathException(
                'CompositeFetcher requires at least one fetcher.'
            )
        model = fetchers[0].path.model
        if not all(fetcher.path.model == model for fetcher in fetchers):
            raise PathException(
                'Each relationship property must have the same class when '
                'using CompositeFetcher.'
            )
        self.fetchers = fetchers

    @property
    def session(self):
        return self.fetchers[0].path.session

    @property
    def model(self):
        return self.fetchers[0].path.model

    @property
    def condition(self):
        # A row is loaded if it matches any of the wrapped relationships.
        return sa.or_(
            *(fetcher.condition for fetcher in self.fetchers)
        )

    @property
    def related_entities(self):
        return self.session.query(self.model).filter(self.condition)

    def fetch(self):
        """
        Run the combined query and pass each row to every fetcher whose
        remote relationship columns are set on it.
        """
        for entity in self.related_entities:
            for fetcher in self.fetchers:
                # Test against None explicitly: key values such as 0 or ''
                # are falsy but still reference a parent.
                if any(
                    value is not None
                    for value in remote_values(fetcher.prop, entity)
                ):
                    fetcher.append_entity(entity)

    def populate(self):
        """
        Let every wrapped fetcher assign its results to its parents.
        """
        for fetcher in self.fetchers:
            fetcher.populate()
class Fetcher(object):
    """
    Base class for fetchers of regular relationship properties.

    ``fetch`` loads the related entities and hands each of them to
    ``append_entity`` (provided by subclasses), which files it in
    ``parent_dict`` keyed by relationship column values; ``populate`` then
    assigns the collected values to the parent entities.
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        # Collection relationships collect lists; scalar ones default to None.
        default = list if self.prop.uselist else lambda: None
        self.parent_dict = defaultdict(default)

    @property
    def relation_query_base(self):
        return self.path.session.query(self.path.model)

    @property
    def related_entities(self):
        return self.relation_query_base.filter(self.condition)

    def populate_backrefs(self, related_entities):
        """
        Populates backrefs for given related entities.

        Each row of ``related_entities`` is indexed as: element 0 is the
        related entity, the remaining elements are passed to ``Query.get``
        on the parent class. NOTE(review): only ManyToManyFetcher's query
        appears to produce rows of that shape -- confirm before using with
        other fetchers. The attribute named by ``prop.back_populates`` is
        set on every related entity.
        """
        # Materialize once: iterating a Query re-executes its SQL each time.
        rows = list(related_entities)
        parent_query = self.path.session.query(
            self.path.entities[0].__class__
        )
        backref_dict = defaultdict(list)
        for row in rows:
            backref_dict[local_values(self.prop, row[0])].append(
                parent_query.get(tuple(row[1:]))
            )
        for row in rows:
            set_committed_value(
                row[0],
                self.prop.back_populates,
                backref_dict[local_values(self.prop, row[0])]
            )

    def populate(self):
        """
        Populate batch fetched entities to parent objects.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[local_values(self.prop, entity)]
            )
        if self.path.populate_backrefs:
            self.populate_backrefs(self.related_entities)

    @property
    def condition(self):
        """
        SQL criterion selecting the related rows of all parent entities.

        :raises PathException: if no remote column names can be obtained
        """
        names = list(remote_column_names(self.prop))
        if not names:
            raise PathException(
                'Could not obtain remote column names.'
            )
        if len(names) > 1:
            return sa.or_(
                *list_local_remote_exprs(self.prop, self.path.entities)
            )
        attr = getattr(remote(self.prop), names[0])
        # Hand ``in_`` a concrete list (its documented argument form) and
        # drop repeated keys: many parents may share one key value.
        values = []
        seen = set()
        for local in list_local_values(self.prop, self.path.entities):
            if local[0] not in seen:
                seen.add(local[0])
                values.append(local[0])
        return attr.in_(values)

    def fetch(self):
        """
        Load the related entities and pass each to ``append_entity``.
        """
        for entity in self.related_entities:
            self.append_entity(entity)
class GenericRelationshipFetcher(object):
    """
    Fetcher for GenericRelationshipProperty attributes.

    Related entities may belong to different classes, so one query per
    discriminator value is issued and their results are chained.
    """
    def __init__(self, path):
        self.path = path
        self.prop = self.path.property
        # Generic relationships are scalar: unmatched parents get None.
        self.parent_dict = defaultdict(lambda: None)

    def fetch(self):
        for entity in self.related_entities:
            self.append_entity(entity)

    def append_entity(self, entity):
        self.parent_dict[remote_values(self.prop, entity)] = entity

    def populate(self):
        """
        Populate batch fetched entities to parent objects.
        """
        for entity in self.path.entities:
            set_committed_value(
                entity,
                self.prop.key,
                self.parent_dict[local_values(self.prop, entity)]
            )

    @property
    def related_entities(self):
        """
        Iterator over the related entities of all parent entities.

        Parents whose discriminator is None reference nothing and are
        skipped; ``populate`` leaves them at None.
        """
        entities = self.path.entities
        if not entities:
            return chain()
        id_dict = defaultdict(list)
        for entity in entities:
            discriminator = getattr(entity, self.prop._discriminator_col.key)
            if discriminator is None:
                continue
            for id_col in self.prop._id_cols:
                id_dict[discriminator].append(
                    getattr(entity, id_col.key)
                )
        # The class registry is reached through the parents' class; all
        # parents share it, so the first one suffices.
        return chain(*self._queries(sa.inspect(entities[0]), id_dict))

    def _queries(self, state, id_dict):
        """
        Yield one query per discriminator selecting its entities by id.
        """
        for discriminator, ids in six.iteritems(id_dict):
            class_ = state.class_._decl_class_registry.get(discriminator)
            yield self.path.session.query(
                class_
            ).filter(
                class_.id.in_(ids)
            )
class ManyToManyFetcher(Fetcher):
    """
    Fetcher for relationships that go through a secondary table.

    Every result row holds the related entity followed by the remote key
    columns; those keys are what ``populate`` matches against the parents'
    local values.
    """
    @property
    def relation_query_base(self):
        key_columns = [
            getattr(remote(self.prop), column_name)
            for column_name in remote_column_names(self.prop)
        ]
        query = self.path.session.query(self.path.model, *key_columns)
        return query.join(self.prop.secondary, self.prop.secondaryjoin)

    def fetch(self):
        for row in self.related_entities:
            related, parent_key = row[0], tuple(row[1:])
            self.parent_dict[parent_key].append(related)
class ManyToOneFetcher(Fetcher):
    """
    Fetcher for scalar relationships: each key maps to a single entity.
    """
    def append_entity(self, entity):
        key = remote_values(self.prop, entity)
        self.parent_dict[key] = entity
class OneToManyFetcher(Fetcher):
    """
    Fetcher for collection relationships: entities are grouped per key.
    """
    def append_entity(self, entity):
        key = remote_values(self.prop, entity)
        self.parent_dict[key].append(entity)