Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-22 13:12:59 +03:00
parent 223f6a64fa
commit e197a56f53
2 changed files with 48 additions and 44 deletions

View File

@@ -1,4 +1,6 @@
from collections import defaultdict
import six import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import set_committed_value from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
@@ -104,24 +106,14 @@ class FetchingCoordinator(object):
if should_populate_backrefs: if should_populate_backrefs:
subpath = with_backrefs(subpath) subpath = with_backrefs(subpath)
batch_fetch( return self.__class__(related_entities).fetcher_for_attr_path(
related_entities,
subpath subpath
) )
return
else: else:
return getattr( attr_path = getattr(
self.first.__class__, attrs[0] self.first.__class__, attrs[0]
) )
else: return self.fetcher_for_property(attr_path.property)
return attr_path
def fetch_relation_entities(self):
if len(self.prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
def fetcher_for_property(self, property_): def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty): if not isinstance(property_, RelationshipProperty):
@@ -130,12 +122,18 @@ class FetchingCoordinator(object):
) )
if property_.secondary is not None: if property_.secondary is not None:
return ManyToManyFetcher(self, property_) fetcher_class = ManyToManyFetcher
else: else:
if property_.direction.name == 'MANYTOONE': if property_.direction.name == 'MANYTOONE':
return ManyToOneFetcher(self, property_) fetcher_class = ManyToOneFetcher
else: else:
return OneToManyFetcher(self, property_) fetcher_class = OneToManyFetcher
return fetcher_class(
self.entities,
property_,
self.should_populate_backrefs
)
def fetcher_for_attr_path(self, attr_path): def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs): if isinstance(attr_path, with_backrefs):
@@ -144,15 +142,19 @@ class FetchingCoordinator(object):
else: else:
self.should_populate_backrefs = False self.should_populate_backrefs = False
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) return self.parse_attr_path(
if not attr: attr_path,
return self.should_populate_backrefs
return self.fetcher_for_property(attr.property) )
def __call__(self, attr_path): def __call__(self, attr_path):
if isinstance(attr_path, compound_path): if isinstance(attr_path, compound_path):
fetchers = []
for path in attr_path.attr_paths: for path in attr_path.attr_paths:
self(path) fetchers.append(self.fetcher_for_attr_path(path))
fetcher = CompoundFetcher(*fetchers)
print fetcher.condition
else: else:
fetcher = self.fetcher_for_attr_path(attr_path) fetcher = self.fetcher_for_attr_path(attr_path)
if not fetcher: if not fetcher:
@@ -162,28 +164,33 @@ class FetchingCoordinator(object):
class CompoundFetcher(object): class CompoundFetcher(object):
def __init__(self, coordinator, path): def __init__(self, *fetchers):
self.coordinator = coordinator if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
self.entities = coordinator.entities raise Exception(
self.first = self.entities[0] 'Each relationship property must have the same class when '
self.session = object_session(self.first) 'using CompoundFetcher.'
)
self.fetchers = fetchers
@property
def condition(self):
return sa.or_(
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def local_values(self):
pass
class Fetcher(object): class Fetcher(object):
def __init__(self, coordinator, property_): def __init__(self, entities, property_, populate_backrefs=False):
self.coordinator = coordinator self.should_populate_backrefs = populate_backrefs
self.entities = entities
self.prop = property_ self.prop = property_
self.model = self.prop.mapper.class_ self.model = self.prop.mapper.class_
self.entities = coordinator.entities
self.first = self.entities[0] self.first = self.entities[0]
self.session = object_session(self.first) self.session = object_session(self.first)
self.init_parent_dict()
def init_parent_dict(self):
self.parent_dict = dict(
(self.local_values(entity), [])
for entity in self.entities
)
@property @property
def local_values_list(self): def local_values_list(self):
@@ -225,7 +232,7 @@ class Fetcher(object):
self.parent_dict[self.local_values(entity)] self.parent_dict[self.local_values(entity)]
) )
if self.coordinator.should_populate_backrefs: if self.should_populate_backrefs:
self.populate_backrefs(self.related_entities) self.populate_backrefs(self.related_entities)
@property @property
@@ -271,6 +278,7 @@ class ManyToManyFetcher(Fetcher):
) )
def fetch(self): def fetch(self):
self.parent_dict = defaultdict(list)
for entity, parent_id in self.related_entities: for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append( self.parent_dict[parent_id].append(
entity entity
@@ -278,19 +286,15 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher): class ManyToOneFetcher(Fetcher):
def init_parent_dict(self):
self.parent_dict = dict(
(self.local_values(entity), None)
for entity in self.entities
)
def fetch(self): def fetch(self):
self.parent_dict = defaultdict(lambda: None)
for entity in self.related_entities: for entity in self.related_entities:
self.parent_dict[getattr(entity, self.remote_column_name)] = entity self.parent_dict[getattr(entity, self.remote_column_name)] = entity
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def fetch(self): def fetch(self):
self.parent_dict = defaultdict(list)
for entity in self.related_entities: for entity in self.related_entities:
self.parent_dict[getattr(entity, self.remote_column_name)].append( self.parent_dict[getattr(entity, self.remote_column_name)].append(
entity entity

View File

@@ -26,7 +26,7 @@ class TestCompoundBatchFetching(TestCase):
) )
class Equipment(self.Base): class Equipment(self.Base):
__tablename__ = 'article' __tablename__ = 'equipment'
id = sa.Column(sa.Integer, primary_key=True) id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255)) name = sa.Column(sa.Unicode(255))
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id)) building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))