Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-22 21:43:56 +03:00
parent e106c71795
commit 7a98f47e4c
4 changed files with 126 additions and 116 deletions

View File

@@ -6,15 +6,15 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs, compound_path
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
from .sort_query import sort_query
__all__ = (
batch_fetch,
compound_path,
sort_query,
with_backrefs
with_backrefs,
CompositePath,
)

View File

@@ -2,7 +2,9 @@ from collections import defaultdict
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
@@ -11,13 +13,73 @@ class with_backrefs(object):
Marks the given attribute path so that whenever it is fetched with
batch_fetch the backref relations are force-set too.
"""
def __init__(self, attr_path):
self.attr_path = attr_path
def __init__(self, path):
self.path = path
class compound_path(object):
def __init__(self, *attr_paths):
self.attr_paths = attr_paths
class Path(object):
    """
    Represents one relationship attribute to batch fetch for some entities.

    A path ties a list of parent entities to a single relationship property
    and creates the fetcher object matching that relationship's type.

    :param entities:
        List of parent entities. The first entity is used to resolve the
        session and the parent model, so the list must not be empty when
        those are accessed.
    :param prop:
        The RelationshipProperty to fetch for the given entities.
    :param populate_backrefs:
        Flag stored on the path; fetchers read it to decide whether backref
        relations should be populated as well.
    :raises Exception: if ``prop`` is not a RelationshipProperty.
    """
    def __init__(self, entities, prop, populate_backrefs=False):
        if not isinstance(prop, RelationshipProperty):
            raise Exception(
                'Given attribute is not a relationship property.'
            )
        self.property = prop
        self.entities = entities
        self.populate_backrefs = populate_backrefs
        # The fetcher receives the path itself and reads entities, model
        # and session through it.
        self.fetcher = self.fetcher_class(self)

    @property
    def session(self):
        """Session of the first parent entity."""
        return object_session(self.entities[0])

    @property
    def parent_model(self):
        """Class of the first parent entity."""
        return self.entities[0].__class__

    @property
    def model(self):
        """Mapped class the relationship property points to."""
        return self.property.mapper.class_

    @classmethod
    def parse(cls, entities, path, populate_backrefs=False):
        """
        Build a path object from a string or an InstrumentedAttribute.

        A dotted string such as ``'a.b'`` is resolved one hop at a time: the
        objects found behind ``a`` on every entity become the parent entities
        for the remaining path ``'b'``.

        :param entities: list of parent entities of the same type
        :param path:
            Either a (possibly dotted) attribute name or an
            InstrumentedAttribute object.
        :param populate_backrefs: passed through to the created path
        :raises Exception: if ``path`` is neither a string nor an
            InstrumentedAttribute.
        """
        if isinstance(path, InstrumentedAttribute):
            return cls(entities, path.property, populate_backrefs)
        if not isinstance(path, six.string_types):
            raise Exception('Unknown path type.')

        # Splitting once yields the first hop and the untouched remainder.
        attrs = path.split('.', 1)
        if len(attrs) == 1:
            attr = getattr(entities[0].__class__, attrs[0])
            return cls(entities, attr.property, populate_backrefs)

        head, subpath = attrs
        # A scalar (uselist=False) relationship yields a single object or
        # None instead of a collection, so it cannot be passed to extend().
        # Anything that is not explicitly scalar keeps the extend() handling.
        hop_property = getattr(
            getattr(entities[0].__class__, head, None), 'property', None
        )
        scalar_hop = getattr(hop_property, 'uselist', None) is False

        related_entities = []
        for entity in entities:
            related = getattr(entity, head)
            if not scalar_hop:
                related_entities.extend(related)
            elif related is not None:
                related_entities.append(related)

        # Use cls (not the hard-coded base class) so subclasses parse into
        # instances of themselves.
        return cls.parse(related_entities, subpath, populate_backrefs)

    @property
    def fetcher_class(self):
        """
        Fetcher class matching the relationship: many-to-many when a
        secondary table is present, otherwise chosen by direction.
        """
        if self.property.secondary is not None:
            return ManyToManyFetcher
        if self.property.direction.name == 'MANYTOONE':
            return ManyToOneFetcher
        return OneToManyFetcher
class CompositePath(object):
    """
    Groups several attribute paths into one unit.

    :param paths: the attribute paths to group, kept as a tuple in ``paths``
    """
    def __init__(self, *paths):
        self.paths = paths
def batch_fetch(entities, *attr_paths):
@@ -28,9 +90,9 @@ def batch_fetch(entities, *attr_paths):
subqueryload and performs lot better.
:param entities: list of entities of the same type
:param attr:
Either InstrumentedAttribute object or a string representing the name
of the instrumented attribute
:param attr_paths:
List of either InstrumentedAttribute objects or strings representing
the names of the instrumented attributes
Example::
@@ -81,107 +143,51 @@ def batch_fetch(entities, *attr_paths):
"""
if entities:
fetcher = FetchingCoordinator(entities)
fetcher = FetchingCoordinator()
for attr_path in attr_paths:
fetcher(attr_path)
fetcher(entities, attr_path)
class FetchingCoordinator(object):
def __init__(self, entities):
self.entities = entities
self.first = entities[0]
self.session = object_session(self.first)
def __call__(self, entities, path):
populate_backrefs = False
if isinstance(path, with_backrefs):
path = path.path
populate_backrefs = True
def parse_attr_path(self, attr_path, should_populate_backrefs):
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in self.entities:
related_entities.extend(getattr(entity, attrs[0]))
subpath = '.'.join(attrs[1:])
if should_populate_backrefs:
subpath = with_backrefs(subpath)
return self.__class__(related_entities).fetcher_for_attr_path(
subpath
)
else:
attr_path = getattr(
self.first.__class__, attrs[0]
)
return self.fetcher_for_property(attr_path.property)
def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
if property_.secondary is not None:
fetcher_class = ManyToManyFetcher
else:
if property_.direction.name == 'MANYTOONE':
fetcher_class = ManyToOneFetcher
else:
fetcher_class = OneToManyFetcher
return fetcher_class(
self.entities,
property_,
self.should_populate_backrefs
)
def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
return self.parse_attr_path(
attr_path,
self.should_populate_backrefs
)
def __call__(self, attr_path):
if isinstance(attr_path, compound_path):
if isinstance(path, CompositePath):
fetchers = []
for path in attr_path.attr_paths:
fetchers.append(self.fetcher_for_attr_path(path))
for path in path.paths:
fetchers.append(
Path.parse(entities, path, populate_backrefs).fetcher
)
fetcher = CompoundFetcher(*fetchers)
fetcher = CompositeFetcher(*fetchers)
else:
fetcher = self.fetcher_for_attr_path(attr_path)
fetcher = Path.parse(entities, path, populate_backrefs).fetcher
fetcher.fetch()
fetcher.populate()
class AbstractFetcher(object):
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
class CompoundFetcher(AbstractFetcher):
class CompositeFetcher(object):
def __init__(self, *fetchers):
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
if not all(
fetchers[0].path.model == fetcher.path.model
for fetcher in fetchers
):
raise Exception(
'Each relationship property must have the same class when '
'using CompoundFetcher.'
'using CompositeFetcher.'
)
self.fetchers = fetchers
@property
def session(self):
return self.fetchers[0].session
return self.fetchers[0].path.session
@property
def model(self):
return self.fetchers[0].model
return self.fetchers[0].path.model
@property
def condition(self):
@@ -189,6 +195,10 @@ class CompoundFetcher(AbstractFetcher):
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
@@ -200,23 +210,27 @@ class CompoundFetcher(AbstractFetcher):
fetcher.populate()
class Fetcher(AbstractFetcher):
def __init__(self, entities, property_, populate_backrefs=False):
self.should_populate_backrefs = populate_backrefs
self.entities = entities
self.prop = property_
self.model = self.prop.mapper.class_
self.first = self.entities[0]
self.session = object_session(self.first)
class Fetcher(object):
def __init__(self, path):
self.path = path
self.prop = self.path.property
self.parent_dict = defaultdict(list)
@property
def local_values_list(self):
return [
self.local_values(entity)
for entity in self.entities
for entity in self.path.entities
]
@property
def related_entities(self):
return self.path.session.query(self.path.model).filter(self.condition)
@property
def remote_column_name(self):
return list(self.path.property.remote_side)[0].name
def local_values(self, entity):
return getattr(entity, list(self.prop.local_columns)[0].name)
@@ -230,7 +244,7 @@ class Fetcher(AbstractFetcher):
)
for entity, parent_id in related_entities:
backref_dict[self.local_values(entity)].append(
self.session.query(self.first.__class__).get(parent_id)
self.path.session.query(self.path.parent_model).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
@@ -243,23 +257,19 @@ class Fetcher(AbstractFetcher):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.entities:
for entity in self.path.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
if self.should_populate_backrefs:
if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote_column_name(self):
return list(self.prop.remote_side)[0].name
@property
def condition(self):
return getattr(self.model, self.remote_column_name).in_(
return getattr(self.path.model, self.remote_column_name).in_(
self.local_values_list
)
@@ -274,15 +284,15 @@ class ManyToManyFetcher(Fetcher):
for column in self.prop.remote_side:
for fk in column.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == self.first.__class__.__table__:
if fk.column.table == self.path.parent_model.__table__:
return fk.parent.name
@property
def related_entities(self):
return (
self.session
self.path.session
.query(
self.model,
self.path.model,
getattr(self.prop.secondary.c, self.remote_column_name)
)
.join(
@@ -303,8 +313,8 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher):
def __init__(self, entities, property_, populate_backrefs=False):
Fetcher.__init__(self, entities, property_, populate_backrefs)
def __init__(self, path):
Fetcher.__init__(self, path)
self.parent_dict = defaultdict(lambda: None)
def append_entity(self, entity):

View File

@@ -1,6 +1,6 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import compound_path
from sqlalchemy_utils.functions import CompositePath
from tests import TestCase
@@ -93,7 +93,7 @@ class TestCompoundOneToManyBatchFetching(TestCase):
batch_fetch(
buildings,
'business_premises',
compound_path(
CompositePath(
'equipment',
'business_premises.equipment'
)
@@ -198,7 +198,7 @@ class TestCompoundManyToOneBatchFetching(TestCase):
batch_fetch(
buildings,
'business_premises',
compound_path(
CompositePath(
'equipment',
'business_premises.equipment'
)

View File

@@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch
from tests import TestCase
class TestBatchFetchAssociations(TestCase):
class TestBatchFetchJoinTableInheritedModels(TestCase):
def create_models(self):
class Category(self.Base):
__tablename__ = 'category'