Refactored batch fetch
This commit is contained in:
@@ -1,4 +1,6 @@
|
|||||||
|
from collections import defaultdict
|
||||||
import six
|
import six
|
||||||
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.orm import RelationshipProperty
|
from sqlalchemy.orm import RelationshipProperty
|
||||||
from sqlalchemy.orm.attributes import set_committed_value
|
from sqlalchemy.orm.attributes import set_committed_value
|
||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
@@ -104,24 +106,14 @@ class FetchingCoordinator(object):
|
|||||||
if should_populate_backrefs:
|
if should_populate_backrefs:
|
||||||
subpath = with_backrefs(subpath)
|
subpath = with_backrefs(subpath)
|
||||||
|
|
||||||
batch_fetch(
|
return self.__class__(related_entities).fetcher_for_attr_path(
|
||||||
related_entities,
|
|
||||||
subpath
|
subpath
|
||||||
)
|
)
|
||||||
return
|
|
||||||
else:
|
else:
|
||||||
return getattr(
|
attr_path = getattr(
|
||||||
self.first.__class__, attrs[0]
|
self.first.__class__, attrs[0]
|
||||||
)
|
)
|
||||||
else:
|
return self.fetcher_for_property(attr_path.property)
|
||||||
return attr_path
|
|
||||||
|
|
||||||
def fetch_relation_entities(self):
|
|
||||||
if len(self.prop.remote_side) > 1:
|
|
||||||
raise Exception(
|
|
||||||
'Only relationships with single remote side columns '
|
|
||||||
'are supported.'
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetcher_for_property(self, property_):
|
def fetcher_for_property(self, property_):
|
||||||
if not isinstance(property_, RelationshipProperty):
|
if not isinstance(property_, RelationshipProperty):
|
||||||
@@ -130,12 +122,18 @@ class FetchingCoordinator(object):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if property_.secondary is not None:
|
if property_.secondary is not None:
|
||||||
return ManyToManyFetcher(self, property_)
|
fetcher_class = ManyToManyFetcher
|
||||||
else:
|
else:
|
||||||
if property_.direction.name == 'MANYTOONE':
|
if property_.direction.name == 'MANYTOONE':
|
||||||
return ManyToOneFetcher(self, property_)
|
fetcher_class = ManyToOneFetcher
|
||||||
else:
|
else:
|
||||||
return OneToManyFetcher(self, property_)
|
fetcher_class = OneToManyFetcher
|
||||||
|
|
||||||
|
return fetcher_class(
|
||||||
|
self.entities,
|
||||||
|
property_,
|
||||||
|
self.should_populate_backrefs
|
||||||
|
)
|
||||||
|
|
||||||
def fetcher_for_attr_path(self, attr_path):
|
def fetcher_for_attr_path(self, attr_path):
|
||||||
if isinstance(attr_path, with_backrefs):
|
if isinstance(attr_path, with_backrefs):
|
||||||
@@ -144,15 +142,19 @@ class FetchingCoordinator(object):
|
|||||||
else:
|
else:
|
||||||
self.should_populate_backrefs = False
|
self.should_populate_backrefs = False
|
||||||
|
|
||||||
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
return self.parse_attr_path(
|
||||||
if not attr:
|
attr_path,
|
||||||
return
|
self.should_populate_backrefs
|
||||||
return self.fetcher_for_property(attr.property)
|
)
|
||||||
|
|
||||||
def __call__(self, attr_path):
|
def __call__(self, attr_path):
|
||||||
if isinstance(attr_path, compound_path):
|
if isinstance(attr_path, compound_path):
|
||||||
|
fetchers = []
|
||||||
for path in attr_path.attr_paths:
|
for path in attr_path.attr_paths:
|
||||||
self(path)
|
fetchers.append(self.fetcher_for_attr_path(path))
|
||||||
|
|
||||||
|
fetcher = CompoundFetcher(*fetchers)
|
||||||
|
print fetcher.condition
|
||||||
else:
|
else:
|
||||||
fetcher = self.fetcher_for_attr_path(attr_path)
|
fetcher = self.fetcher_for_attr_path(attr_path)
|
||||||
if not fetcher:
|
if not fetcher:
|
||||||
@@ -162,28 +164,33 @@ class FetchingCoordinator(object):
|
|||||||
|
|
||||||
|
|
||||||
class CompoundFetcher(object):
|
class CompoundFetcher(object):
|
||||||
def __init__(self, coordinator, path):
|
def __init__(self, *fetchers):
|
||||||
self.coordinator = coordinator
|
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
|
||||||
self.entities = coordinator.entities
|
raise Exception(
|
||||||
self.first = self.entities[0]
|
'Each relationship property must have the same class when '
|
||||||
self.session = object_session(self.first)
|
'using CompoundFetcher.'
|
||||||
|
)
|
||||||
|
self.fetchers = fetchers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def condition(self):
|
||||||
|
return sa.or_(
|
||||||
|
*[fetcher.condition for fetcher in self.fetchers]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def local_values(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(object):
|
class Fetcher(object):
|
||||||
def __init__(self, coordinator, property_):
|
def __init__(self, entities, property_, populate_backrefs=False):
|
||||||
self.coordinator = coordinator
|
self.should_populate_backrefs = populate_backrefs
|
||||||
|
self.entities = entities
|
||||||
self.prop = property_
|
self.prop = property_
|
||||||
self.model = self.prop.mapper.class_
|
self.model = self.prop.mapper.class_
|
||||||
self.entities = coordinator.entities
|
|
||||||
self.first = self.entities[0]
|
self.first = self.entities[0]
|
||||||
self.session = object_session(self.first)
|
self.session = object_session(self.first)
|
||||||
self.init_parent_dict()
|
|
||||||
|
|
||||||
def init_parent_dict(self):
|
|
||||||
self.parent_dict = dict(
|
|
||||||
(self.local_values(entity), [])
|
|
||||||
for entity in self.entities
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_values_list(self):
|
def local_values_list(self):
|
||||||
@@ -225,7 +232,7 @@ class Fetcher(object):
|
|||||||
self.parent_dict[self.local_values(entity)]
|
self.parent_dict[self.local_values(entity)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.coordinator.should_populate_backrefs:
|
if self.should_populate_backrefs:
|
||||||
self.populate_backrefs(self.related_entities)
|
self.populate_backrefs(self.related_entities)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -271,6 +278,7 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
|
self.parent_dict = defaultdict(list)
|
||||||
for entity, parent_id in self.related_entities:
|
for entity, parent_id in self.related_entities:
|
||||||
self.parent_dict[parent_id].append(
|
self.parent_dict[parent_id].append(
|
||||||
entity
|
entity
|
||||||
@@ -278,19 +286,15 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
|
|
||||||
|
|
||||||
class ManyToOneFetcher(Fetcher):
|
class ManyToOneFetcher(Fetcher):
|
||||||
def init_parent_dict(self):
|
|
||||||
self.parent_dict = dict(
|
|
||||||
(self.local_values(entity), None)
|
|
||||||
for entity in self.entities
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
|
self.parent_dict = defaultdict(lambda: None)
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
|
self.parent_dict = defaultdict(list)
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
||||||
entity
|
entity
|
||||||
|
@@ -26,7 +26,7 @@ class TestCompoundBatchFetching(TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
class Equipment(self.Base):
|
class Equipment(self.Base):
|
||||||
__tablename__ = 'article'
|
__tablename__ = 'equipment'
|
||||||
id = sa.Column(sa.Integer, primary_key=True)
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
name = sa.Column(sa.Unicode(255))
|
name = sa.Column(sa.Unicode(255))
|
||||||
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
||||||
|
Reference in New Issue
Block a user