Added support for compound one to many batch fetching
This commit is contained in:
@@ -154,16 +154,19 @@ class FetchingCoordinator(object):
|
|||||||
fetchers.append(self.fetcher_for_attr_path(path))
|
fetchers.append(self.fetcher_for_attr_path(path))
|
||||||
|
|
||||||
fetcher = CompoundFetcher(*fetchers)
|
fetcher = CompoundFetcher(*fetchers)
|
||||||
print fetcher.condition
|
|
||||||
else:
|
else:
|
||||||
fetcher = self.fetcher_for_attr_path(attr_path)
|
fetcher = self.fetcher_for_attr_path(attr_path)
|
||||||
if not fetcher:
|
fetcher.fetch()
|
||||||
return
|
fetcher.populate()
|
||||||
fetcher.fetch()
|
|
||||||
fetcher.populate()
|
|
||||||
|
|
||||||
|
|
||||||
class CompoundFetcher(object):
|
class AbstractFetcher(object):
|
||||||
|
@property
|
||||||
|
def related_entities(self):
|
||||||
|
return self.session.query(self.model).filter(self.condition)
|
||||||
|
|
||||||
|
|
||||||
|
class CompoundFetcher(AbstractFetcher):
|
||||||
def __init__(self, *fetchers):
|
def __init__(self, *fetchers):
|
||||||
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
|
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@@ -172,18 +175,36 @@ class CompoundFetcher(object):
|
|||||||
)
|
)
|
||||||
self.fetchers = fetchers
|
self.fetchers = fetchers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
return self.fetchers[0].session
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
return self.fetchers[0].model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def condition(self):
|
def condition(self):
|
||||||
return sa.or_(
|
return sa.or_(
|
||||||
*[fetcher.condition for fetcher in self.fetchers]
|
*[fetcher.condition for fetcher in self.fetchers]
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def fetcher_for_entity(self, entity):
|
||||||
def local_values(self):
|
for fetcher in self.fetchers:
|
||||||
pass
|
if getattr(entity, fetcher.remote_column_name) is not None:
|
||||||
|
return fetcher
|
||||||
|
|
||||||
|
def fetch(self):
|
||||||
|
self.parent_dict = defaultdict(list)
|
||||||
|
for entity in self.related_entities:
|
||||||
|
self.fetcher_for_entity(entity).append_entity(entity)
|
||||||
|
|
||||||
|
def populate(self):
|
||||||
|
for fetcher in self.fetchers:
|
||||||
|
fetcher.populate()
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(object):
|
class Fetcher(AbstractFetcher):
|
||||||
def __init__(self, entities, property_, populate_backrefs=False):
|
def __init__(self, entities, property_, populate_backrefs=False):
|
||||||
self.should_populate_backrefs = populate_backrefs
|
self.should_populate_backrefs = populate_backrefs
|
||||||
self.entities = entities
|
self.entities = entities
|
||||||
@@ -191,6 +212,7 @@ class Fetcher(object):
|
|||||||
self.model = self.prop.mapper.class_
|
self.model = self.prop.mapper.class_
|
||||||
self.first = self.entities[0]
|
self.first = self.entities[0]
|
||||||
self.session = object_session(self.first)
|
self.session = object_session(self.first)
|
||||||
|
self.parent_dict = defaultdict(list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_values_list(self):
|
def local_values_list(self):
|
||||||
@@ -245,9 +267,9 @@ class Fetcher(object):
|
|||||||
self.local_values_list
|
self.local_values_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def fetch(self):
|
||||||
def related_entities(self):
|
for entity in self.related_entities:
|
||||||
return self.session.query(self.model).filter(self.condition)
|
self.append_entity(entity)
|
||||||
|
|
||||||
|
|
||||||
class ManyToManyFetcher(Fetcher):
|
class ManyToManyFetcher(Fetcher):
|
||||||
@@ -278,7 +300,6 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
self.parent_dict = defaultdict(list)
|
|
||||||
for entity, parent_id in self.related_entities:
|
for entity, parent_id in self.related_entities:
|
||||||
self.parent_dict[parent_id].append(
|
self.parent_dict[parent_id].append(
|
||||||
entity
|
entity
|
||||||
@@ -286,16 +307,16 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
|
|
||||||
|
|
||||||
class ManyToOneFetcher(Fetcher):
|
class ManyToOneFetcher(Fetcher):
|
||||||
def fetch(self):
|
def __init__(self, entities, property_, populate_backrefs=False):
|
||||||
|
Fetcher.__init__(self, entities, property_, populate_backrefs)
|
||||||
self.parent_dict = defaultdict(lambda: None)
|
self.parent_dict = defaultdict(lambda: None)
|
||||||
for entity in self.related_entities:
|
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
def append_entity(self, entity):
|
||||||
|
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def fetch(self):
|
def append_entity(self, entity):
|
||||||
self.parent_dict = defaultdict(list)
|
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
||||||
for entity in self.related_entities:
|
entity
|
||||||
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
)
|
||||||
entity
|
|
||||||
)
|
|
||||||
|
@@ -4,7 +4,7 @@ from sqlalchemy_utils.functions import compound_path
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestCompoundBatchFetching(TestCase):
|
class TestCompoundOneToManyBatchFetching(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Building(self.Base):
|
class Building(self.Base):
|
||||||
__tablename__ = 'building'
|
__tablename__ = 'building'
|
||||||
@@ -54,27 +54,33 @@ class TestCompoundBatchFetching(TestCase):
|
|||||||
def setup_method(self, method):
|
def setup_method(self, method):
|
||||||
TestCase.setup_method(self, method)
|
TestCase.setup_method(self, method)
|
||||||
self.buildings = [
|
self.buildings = [
|
||||||
self.Building(name=u'B 1'),
|
self.Building(id=12, name=u'B 1'),
|
||||||
self.Building(name=u'B 2'),
|
self.Building(id=15, name=u'B 2'),
|
||||||
self.Building(name=u'B 3'),
|
self.Building(id=19, name=u'B 3'),
|
||||||
]
|
]
|
||||||
self.business_premises = [
|
self.business_premises = [
|
||||||
self.BusinessPremise(name=u'BP 1', building=self.buildings[0]),
|
self.BusinessPremise(
|
||||||
self.BusinessPremise(name=u'BP 2', building=self.buildings[0]),
|
id=22, name=u'BP 1', building=self.buildings[0]
|
||||||
self.BusinessPremise(name=u'BP 3', building=self.buildings[2]),
|
),
|
||||||
|
self.BusinessPremise(
|
||||||
|
id=33, name=u'BP 2', building=self.buildings[0]
|
||||||
|
),
|
||||||
|
self.BusinessPremise(
|
||||||
|
id=44, name=u'BP 3', building=self.buildings[2]
|
||||||
|
),
|
||||||
]
|
]
|
||||||
self.equipment = [
|
self.equipment = [
|
||||||
self.Equipment(
|
self.Equipment(
|
||||||
name=u'E 1', building=self.buildings[0]
|
id=2, name=u'E 1', building=self.buildings[0]
|
||||||
),
|
),
|
||||||
self.Equipment(
|
self.Equipment(
|
||||||
name=u'E 2', building=self.buildings[2]
|
id=4, name=u'E 2', building=self.buildings[2]
|
||||||
),
|
),
|
||||||
self.Equipment(
|
self.Equipment(
|
||||||
name=u'E 3', business_premise=self.business_premises[0]
|
id=6, name=u'E 3', business_premise=self.business_premises[0]
|
||||||
),
|
),
|
||||||
self.Equipment(
|
self.Equipment(
|
||||||
name=u'E 4', business_premise=self.business_premises[2]
|
id=8, name=u'E 4', business_premise=self.business_premises[2]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
self.session.add_all(self.buildings)
|
self.session.add_all(self.buildings)
|
||||||
@@ -94,7 +100,10 @@ class TestCompoundBatchFetching(TestCase):
|
|||||||
)
|
)
|
||||||
query_count = self.connection.query_count
|
query_count = self.connection.query_count
|
||||||
|
|
||||||
buildings[0].equipment
|
assert len(buildings[0].equipment) == 1
|
||||||
buildings[1].equipment
|
assert buildings[0].equipment[0].name == 'E 1'
|
||||||
buildings[0].business_premises[0].equipment
|
assert not buildings[1].equipment
|
||||||
|
assert buildings[0].business_premises[0].equipment
|
||||||
|
assert self.business_premises[2].equipment
|
||||||
|
assert self.business_premises[2].equipment[0].name == 'E 4'
|
||||||
assert self.connection.query_count == query_count
|
assert self.connection.query_count == query_count
|
||||||
|
@@ -63,5 +63,16 @@ class TestBatchFetchOneToManyRelationships(TestCase):
|
|||||||
categories = self.session.query(self.Category).all()
|
categories = self.session.query(self.Category).all()
|
||||||
batch_fetch(categories, self.Category.articles)
|
batch_fetch(categories, self.Category.articles)
|
||||||
query_count = self.connection.query_count
|
query_count = self.connection.query_count
|
||||||
categories[0].articles # no lazy load should occur
|
articles = categories[0].articles # no lazy load should occur
|
||||||
|
assert len(articles) == 2
|
||||||
|
article_names = [article.name for article in articles]
|
||||||
|
|
||||||
|
assert 'Article 1' in article_names
|
||||||
|
assert 'Article 2' in article_names
|
||||||
|
articles = categories[1].articles # no lazy load should occur
|
||||||
|
assert len(articles) == 3
|
||||||
|
article_names = [article.name for article in articles]
|
||||||
|
assert 'Article 3' in article_names
|
||||||
|
assert 'Article 4' in article_names
|
||||||
|
assert 'Article 5' in article_names
|
||||||
assert self.connection.query_count == query_count
|
assert self.connection.query_count == query_count
|
||||||
|
Reference in New Issue
Block a user