Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-22 13:12:59 +03:00
parent 223f6a64fa
commit e197a56f53
2 changed files with 48 additions and 44 deletions

View File

@@ -1,4 +1,6 @@
from collections import defaultdict
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.session import object_session
@@ -104,24 +106,14 @@ class FetchingCoordinator(object):
if should_populate_backrefs:
subpath = with_backrefs(subpath)
batch_fetch(
related_entities,
return self.__class__(related_entities).fetcher_for_attr_path(
subpath
)
return
else:
return getattr(
attr_path = getattr(
self.first.__class__, attrs[0]
)
else:
return attr_path
def fetch_relation_entities(self):
if len(self.prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
return self.fetcher_for_property(attr_path.property)
def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty):
@@ -130,12 +122,18 @@ class FetchingCoordinator(object):
)
if property_.secondary is not None:
return ManyToManyFetcher(self, property_)
fetcher_class = ManyToManyFetcher
else:
if property_.direction.name == 'MANYTOONE':
return ManyToOneFetcher(self, property_)
fetcher_class = ManyToOneFetcher
else:
return OneToManyFetcher(self, property_)
fetcher_class = OneToManyFetcher
return fetcher_class(
self.entities,
property_,
self.should_populate_backrefs
)
def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs):
@@ -144,15 +142,19 @@ class FetchingCoordinator(object):
else:
self.should_populate_backrefs = False
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
if not attr:
return
return self.fetcher_for_property(attr.property)
return self.parse_attr_path(
attr_path,
self.should_populate_backrefs
)
def __call__(self, attr_path):
if isinstance(attr_path, compound_path):
fetchers = []
for path in attr_path.attr_paths:
self(path)
fetchers.append(self.fetcher_for_attr_path(path))
fetcher = CompoundFetcher(*fetchers)
print fetcher.condition
else:
fetcher = self.fetcher_for_attr_path(attr_path)
if not fetcher:
@@ -162,28 +164,33 @@ class FetchingCoordinator(object):
class CompoundFetcher(object):
def __init__(self, coordinator, path):
self.coordinator = coordinator
self.entities = coordinator.entities
self.first = self.entities[0]
self.session = object_session(self.first)
def __init__(self, *fetchers):
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
raise Exception(
'Each relationship property must have the same class when '
'using CompoundFetcher.'
)
self.fetchers = fetchers
@property
def condition(self):
return sa.or_(
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def local_values(self):
pass
class Fetcher(object):
def __init__(self, coordinator, property_):
self.coordinator = coordinator
def __init__(self, entities, property_, populate_backrefs=False):
self.should_populate_backrefs = populate_backrefs
self.entities = entities
self.prop = property_
self.model = self.prop.mapper.class_
self.entities = coordinator.entities
self.first = self.entities[0]
self.session = object_session(self.first)
self.init_parent_dict()
def init_parent_dict(self):
self.parent_dict = dict(
(self.local_values(entity), [])
for entity in self.entities
)
@property
def local_values_list(self):
@@ -225,7 +232,7 @@ class Fetcher(object):
self.parent_dict[self.local_values(entity)]
)
if self.coordinator.should_populate_backrefs:
if self.should_populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
@@ -271,6 +278,7 @@ class ManyToManyFetcher(Fetcher):
)
def fetch(self):
self.parent_dict = defaultdict(list)
for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append(
entity
@@ -278,19 +286,15 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher):
def init_parent_dict(self):
self.parent_dict = dict(
(self.local_values(entity), None)
for entity in self.entities
)
def fetch(self):
self.parent_dict = defaultdict(lambda: None)
for entity in self.related_entities:
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
class OneToManyFetcher(Fetcher):
def fetch(self):
self.parent_dict = defaultdict(list)
for entity in self.related_entities:
self.parent_dict[getattr(entity, self.remote_column_name)].append(
entity

View File

@@ -26,7 +26,7 @@ class TestCompoundBatchFetching(TestCase):
)
class Equipment(self.Base):
__tablename__ = 'article'
__tablename__ = 'equipment'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))