Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-22 21:43:56 +03:00
parent e106c71795
commit 7a98f47e4c
4 changed files with 126 additions and 116 deletions

View File

@@ -6,15 +6,15 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs, compound_path from .batch_fetch import batch_fetch, with_backrefs, CompositePath
from .sort_query import sort_query from .sort_query import sort_query
__all__ = ( __all__ = (
batch_fetch, batch_fetch,
compound_path,
sort_query, sort_query,
with_backrefs with_backrefs,
CompositePath,
) )

View File

@@ -2,7 +2,9 @@ from collections import defaultdict
import six import six
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import set_committed_value from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
@@ -11,13 +13,73 @@ class with_backrefs(object):
Marks given attribute path so that whenever it is fetched with batch_fetch Marks given attribute path so that whenever it is fetched with batch_fetch
the backref relations are force set too. the backref relations are force set too.
""" """
def __init__(self, attr_path): def __init__(self, path):
self.attr_path = attr_path self.path = path
class compound_path(object): class Path(object):
def __init__(self, *attr_paths): """
self.attr_paths = attr_paths A class that represents an attribute path.
"""
def __init__(self, entities, prop, populate_backrefs=False):
self.property = prop
self.entities = entities
self.populate_backrefs = populate_backrefs
if not isinstance(self.property, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
self.fetcher = self.fetcher_class(self)
@property
def session(self):
return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property
def model(self):
return self.property.mapper.class_
@classmethod
def parse(cls, entities, path, populate_backrefs=False):
if isinstance(path, six.string_types):
attrs = path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
subpath = '.'.join(attrs[1:])
return Path.parse(related_entities, subpath, populate_backrefs)
else:
attr = getattr(
entities[0].__class__, attrs[0]
)
elif isinstance(path, InstrumentedAttribute):
attr = path
else:
raise Exception('Unknown path type.')
return Path(entities, attr.property, populate_backrefs)
@property
def fetcher_class(self):
if self.property.secondary is not None:
return ManyToManyFetcher
else:
if self.property.direction.name == 'MANYTOONE':
return ManyToOneFetcher
else:
return OneToManyFetcher
class CompositePath(object):
def __init__(self, *paths):
self.paths = paths
def batch_fetch(entities, *attr_paths): def batch_fetch(entities, *attr_paths):
@@ -28,9 +90,9 @@ def batch_fetch(entities, *attr_paths):
subqueryload and performs a lot better. subqueryload and performs a lot better.
:param entities: list of entities of the same type :param entities: list of entities of the same type
:param attr: :param attr_paths:
Either InstrumentedAttribute object or a string representing the name List of either InstrumentedAttribute objects or strings representing
of the instrumented attribute the name of the instrumented attribute
Example:: Example::
@@ -81,107 +143,51 @@ def batch_fetch(entities, *attr_paths):
""" """
if entities: if entities:
fetcher = FetchingCoordinator(entities) fetcher = FetchingCoordinator()
for attr_path in attr_paths: for attr_path in attr_paths:
fetcher(attr_path) fetcher(entities, attr_path)
class FetchingCoordinator(object): class FetchingCoordinator(object):
def __init__(self, entities): def __call__(self, entities, path):
self.entities = entities populate_backrefs = False
self.first = entities[0] if isinstance(path, with_backrefs):
self.session = object_session(self.first) path = path.path
populate_backrefs = True
def parse_attr_path(self, attr_path, should_populate_backrefs): if isinstance(path, CompositePath):
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in self.entities:
related_entities.extend(getattr(entity, attrs[0]))
subpath = '.'.join(attrs[1:])
if should_populate_backrefs:
subpath = with_backrefs(subpath)
return self.__class__(related_entities).fetcher_for_attr_path(
subpath
)
else:
attr_path = getattr(
self.first.__class__, attrs[0]
)
return self.fetcher_for_property(attr_path.property)
def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
if property_.secondary is not None:
fetcher_class = ManyToManyFetcher
else:
if property_.direction.name == 'MANYTOONE':
fetcher_class = ManyToOneFetcher
else:
fetcher_class = OneToManyFetcher
return fetcher_class(
self.entities,
property_,
self.should_populate_backrefs
)
def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
return self.parse_attr_path(
attr_path,
self.should_populate_backrefs
)
def __call__(self, attr_path):
if isinstance(attr_path, compound_path):
fetchers = [] fetchers = []
for path in attr_path.attr_paths: for path in path.paths:
fetchers.append(self.fetcher_for_attr_path(path)) fetchers.append(
Path.parse(entities, path, populate_backrefs).fetcher
)
fetcher = CompoundFetcher(*fetchers) fetcher = CompositeFetcher(*fetchers)
else: else:
fetcher = self.fetcher_for_attr_path(attr_path) fetcher = Path.parse(entities, path, populate_backrefs).fetcher
fetcher.fetch() fetcher.fetch()
fetcher.populate() fetcher.populate()
class AbstractFetcher(object): class CompositeFetcher(object):
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
class CompoundFetcher(AbstractFetcher):
def __init__(self, *fetchers): def __init__(self, *fetchers):
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): if not all(
fetchers[0].path.model == fetcher.path.model
for fetcher in fetchers
):
raise Exception( raise Exception(
'Each relationship property must have the same class when ' 'Each relationship property must have the same class when '
'using CompoundFetcher.' 'using CompositeFetcher.'
) )
self.fetchers = fetchers self.fetchers = fetchers
@property @property
def session(self): def session(self):
return self.fetchers[0].session return self.fetchers[0].path.session
@property @property
def model(self): def model(self):
return self.fetchers[0].model return self.fetchers[0].path.model
@property @property
def condition(self): def condition(self):
@@ -189,6 +195,10 @@ class CompoundFetcher(AbstractFetcher):
*[fetcher.condition for fetcher in self.fetchers] *[fetcher.condition for fetcher in self.fetchers]
) )
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
for fetcher in self.fetchers: for fetcher in self.fetchers:
@@ -200,23 +210,27 @@ class CompoundFetcher(AbstractFetcher):
fetcher.populate() fetcher.populate()
class Fetcher(AbstractFetcher): class Fetcher(object):
def __init__(self, entities, property_, populate_backrefs=False): def __init__(self, path):
self.should_populate_backrefs = populate_backrefs self.path = path
self.entities = entities self.prop = self.path.property
self.prop = property_
self.model = self.prop.mapper.class_
self.first = self.entities[0]
self.session = object_session(self.first)
self.parent_dict = defaultdict(list) self.parent_dict = defaultdict(list)
@property @property
def local_values_list(self): def local_values_list(self):
return [ return [
self.local_values(entity) self.local_values(entity)
for entity in self.entities for entity in self.path.entities
] ]
@property
def related_entities(self):
return self.path.session.query(self.path.model).filter(self.condition)
@property
def remote_column_name(self):
return list(self.path.property.remote_side)[0].name
def local_values(self, entity): def local_values(self, entity):
return getattr(entity, list(self.prop.local_columns)[0].name) return getattr(entity, list(self.prop.local_columns)[0].name)
@@ -230,7 +244,7 @@ class Fetcher(AbstractFetcher):
) )
for entity, parent_id in related_entities: for entity, parent_id in related_entities:
backref_dict[self.local_values(entity)].append( backref_dict[self.local_values(entity)].append(
self.session.query(self.first.__class__).get(parent_id) self.path.session.query(self.path.parent_model).get(parent_id)
) )
for entity, parent_id in related_entities: for entity, parent_id in related_entities:
set_committed_value( set_committed_value(
@@ -243,23 +257,19 @@ class Fetcher(AbstractFetcher):
""" """
Populate batch fetched entities to parent objects. Populate batch fetched entities to parent objects.
""" """
for entity in self.entities: for entity in self.path.entities:
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
self.parent_dict[self.local_values(entity)] self.parent_dict[self.local_values(entity)]
) )
if self.should_populate_backrefs: if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities) self.populate_backrefs(self.related_entities)
@property
def remote_column_name(self):
return list(self.prop.remote_side)[0].name
@property @property
def condition(self): def condition(self):
return getattr(self.model, self.remote_column_name).in_( return getattr(self.path.model, self.remote_column_name).in_(
self.local_values_list self.local_values_list
) )
@@ -274,15 +284,15 @@ class ManyToManyFetcher(Fetcher):
for column in self.prop.remote_side: for column in self.prop.remote_side:
for fk in column.foreign_keys: for fk in column.foreign_keys:
# TODO: make this support inherited tables # TODO: make this support inherited tables
if fk.column.table == self.first.__class__.__table__: if fk.column.table == self.path.parent_model.__table__:
return fk.parent.name return fk.parent.name
@property @property
def related_entities(self): def related_entities(self):
return ( return (
self.session self.path.session
.query( .query(
self.model, self.path.model,
getattr(self.prop.secondary.c, self.remote_column_name) getattr(self.prop.secondary.c, self.remote_column_name)
) )
.join( .join(
@@ -303,8 +313,8 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher): class ManyToOneFetcher(Fetcher):
def __init__(self, entities, property_, populate_backrefs=False): def __init__(self, path):
Fetcher.__init__(self, entities, property_, populate_backrefs) Fetcher.__init__(self, path)
self.parent_dict = defaultdict(lambda: None) self.parent_dict = defaultdict(lambda: None)
def append_entity(self, entity): def append_entity(self, entity):

View File

@@ -1,6 +1,6 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import compound_path from sqlalchemy_utils.functions import CompositePath
from tests import TestCase from tests import TestCase
@@ -93,7 +93,7 @@ class TestCompoundOneToManyBatchFetching(TestCase):
batch_fetch( batch_fetch(
buildings, buildings,
'business_premises', 'business_premises',
compound_path( CompositePath(
'equipment', 'equipment',
'business_premises.equipment' 'business_premises.equipment'
) )
@@ -198,7 +198,7 @@ class TestCompoundManyToOneBatchFetching(TestCase):
batch_fetch( batch_fetch(
buildings, buildings,
'business_premises', 'business_premises',
compound_path( CompositePath(
'equipment', 'equipment',
'business_premises.equipment' 'business_premises.equipment'
) )

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch
from tests import TestCase from tests import TestCase
class TestBatchFetchAssociations(TestCase): class TestBatchFetchJoinTableInheritedModels(TestCase):
def create_models(self): def create_models(self):
class Category(self.Base): class Category(self.Base):
__tablename__ = 'category' __tablename__ = 'category'