Refactored batch fetch
This commit is contained in:
@@ -6,12 +6,13 @@ from sqlalchemy.orm import defer
|
||||
from sqlalchemy.orm.properties import ColumnProperty
|
||||
from sqlalchemy.orm.query import Query
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
from .batch_fetch import batch_fetch, with_backrefs
|
||||
from .batch_fetch import batch_fetch, with_backrefs, compound_path
|
||||
from .sort_query import sort_query
|
||||
|
||||
|
||||
__all__ = (
|
||||
batch_fetch,
|
||||
compound_path,
|
||||
sort_query,
|
||||
with_backrefs
|
||||
)
|
||||
|
||||
@@ -13,6 +13,11 @@ class with_backrefs(object):
|
||||
self.attr_path = attr_path
|
||||
|
||||
|
||||
class compound_path(object):
|
||||
def __init__(self, *attr_paths):
|
||||
self.attr_paths = attr_paths
|
||||
|
||||
|
||||
def batch_fetch(entities, *attr_paths):
|
||||
"""
|
||||
Batch fetch given relationship attribute for collection of entities.
|
||||
@@ -118,7 +123,7 @@ class FetchingCoordinator(object):
|
||||
'are supported.'
|
||||
)
|
||||
|
||||
def fetcher(self, property_):
|
||||
def fetcher_for_property(self, property_):
|
||||
if not isinstance(property_, RelationshipProperty):
|
||||
raise Exception(
|
||||
'Given attribute is not a relationship property.'
|
||||
@@ -132,7 +137,7 @@ class FetchingCoordinator(object):
|
||||
else:
|
||||
return OneToManyFetcher(self, property_)
|
||||
|
||||
def __call__(self, attr_path):
|
||||
def fetcher_for_attr_path(self, attr_path):
|
||||
if isinstance(attr_path, with_backrefs):
|
||||
self.should_populate_backrefs = True
|
||||
attr_path = attr_path.attr_path
|
||||
@@ -142,10 +147,26 @@ class FetchingCoordinator(object):
|
||||
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
||||
if not attr:
|
||||
return
|
||||
return self.fetcher_for_property(attr.property)
|
||||
|
||||
fetcher = self.fetcher(attr.property)
|
||||
fetcher.fetch()
|
||||
fetcher.populate()
|
||||
def __call__(self, attr_path):
|
||||
if isinstance(attr_path, compound_path):
|
||||
for path in attr_path.attr_paths:
|
||||
self(path)
|
||||
else:
|
||||
fetcher = self.fetcher_for_attr_path(attr_path)
|
||||
if not fetcher:
|
||||
return
|
||||
fetcher.fetch()
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
class CompoundFetcher(object):
|
||||
def __init__(self, coordinator, path):
|
||||
self.coordinator = coordinator
|
||||
self.entities = coordinator.entities
|
||||
self.first = self.entities[0]
|
||||
self.session = object_session(self.first)
|
||||
|
||||
|
||||
class Fetcher(object):
|
||||
@@ -207,31 +228,49 @@ class Fetcher(object):
|
||||
if self.coordinator.should_populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
@property
|
||||
def remote_column_name(self):
|
||||
return list(self.prop.remote_side)[0].name
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
return getattr(self.model, self.remote_column_name).in_(
|
||||
self.local_values_list
|
||||
)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
return self.session.query(self.model).filter(self.condition)
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
def fetch(self):
|
||||
column_name = None
|
||||
@property
|
||||
def remote_column_name(self):
|
||||
for column in self.prop.remote_side:
|
||||
for fk in column.foreign_keys:
|
||||
# TODO: make this support inherited tables
|
||||
if fk.column.table == self.first.__class__.__table__:
|
||||
column_name = fk.parent.name
|
||||
break
|
||||
if column_name:
|
||||
break
|
||||
return fk.parent.name
|
||||
|
||||
self.related_entities = (
|
||||
@property
|
||||
def related_entities(self):
|
||||
return (
|
||||
self.session
|
||||
.query(self.model, getattr(self.prop.secondary.c, column_name))
|
||||
.query(
|
||||
self.model,
|
||||
getattr(self.prop.secondary.c, self.remote_column_name)
|
||||
)
|
||||
.join(
|
||||
self.prop.secondary, self.prop.secondaryjoin
|
||||
)
|
||||
.filter(
|
||||
getattr(self.prop.secondary.c, column_name).in_(
|
||||
getattr(self.prop.secondary.c, self.remote_column_name).in_(
|
||||
self.local_values_list
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for entity, parent_id in self.related_entities:
|
||||
self.parent_dict[parent_id].append(
|
||||
entity
|
||||
@@ -246,31 +285,13 @@ class ManyToOneFetcher(Fetcher):
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
column_name = list(self.prop.remote_side)[0].name
|
||||
|
||||
self.related_entities = (
|
||||
self.session.query(self.model)
|
||||
.filter(
|
||||
getattr(self.model, column_name).in_(self.local_values_list)
|
||||
)
|
||||
)
|
||||
|
||||
for entity in self.related_entities:
|
||||
self.parent_dict[getattr(entity, column_name)] = entity
|
||||
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def fetch(self):
|
||||
column_name = list(self.prop.remote_side)[0].name
|
||||
|
||||
self.related_entities = (
|
||||
self.session.query(self.model)
|
||||
.filter(
|
||||
getattr(self.model, column_name).in_(self.local_values_list)
|
||||
)
|
||||
)
|
||||
|
||||
for entity in self.related_entities:
|
||||
self.parent_dict[getattr(entity, column_name)].append(
|
||||
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
||||
entity
|
||||
)
|
||||
|
||||
100
tests/batch_fetch/test_compound_fetching.py
Normal file
100
tests/batch_fetch/test_compound_fetching.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import batch_fetch
|
||||
from sqlalchemy_utils.functions import compound_path
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestCompoundBatchFetching(TestCase):
|
||||
def create_models(self):
|
||||
class Building(self.Base):
|
||||
__tablename__ = 'building'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
class BusinessPremise(self.Base):
|
||||
__tablename__ = 'business_premise'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
||||
|
||||
building = sa.orm.relationship(
|
||||
Building,
|
||||
backref=sa.orm.backref(
|
||||
'business_premises'
|
||||
)
|
||||
)
|
||||
|
||||
class Equipment(self.Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
||||
business_premise_id = sa.Column(
|
||||
sa.Integer, sa.ForeignKey(BusinessPremise.id)
|
||||
)
|
||||
|
||||
building = sa.orm.relationship(
|
||||
Building,
|
||||
backref=sa.orm.backref(
|
||||
'equipment'
|
||||
)
|
||||
)
|
||||
business_premise = sa.orm.relationship(
|
||||
BusinessPremise,
|
||||
backref=sa.orm.backref(
|
||||
'equipment'
|
||||
)
|
||||
)
|
||||
|
||||
self.Building = Building
|
||||
self.BusinessPremise = BusinessPremise
|
||||
self.Equipment = Equipment
|
||||
|
||||
def setup_method(self, method):
|
||||
TestCase.setup_method(self, method)
|
||||
self.buildings = [
|
||||
self.Building(name=u'B 1'),
|
||||
self.Building(name=u'B 2'),
|
||||
self.Building(name=u'B 3'),
|
||||
]
|
||||
self.business_premises = [
|
||||
self.BusinessPremise(name=u'BP 1', building=self.buildings[0]),
|
||||
self.BusinessPremise(name=u'BP 2', building=self.buildings[0]),
|
||||
self.BusinessPremise(name=u'BP 3', building=self.buildings[2]),
|
||||
]
|
||||
self.equipment = [
|
||||
self.Equipment(
|
||||
name=u'E 1', building=self.buildings[0]
|
||||
),
|
||||
self.Equipment(
|
||||
name=u'E 2', building=self.buildings[2]
|
||||
),
|
||||
self.Equipment(
|
||||
name=u'E 3', business_premise=self.business_premises[0]
|
||||
),
|
||||
self.Equipment(
|
||||
name=u'E 4', business_premise=self.business_premises[2]
|
||||
),
|
||||
]
|
||||
self.session.add_all(self.buildings)
|
||||
self.session.add_all(self.business_premises)
|
||||
self.session.add_all(self.equipment)
|
||||
self.session.commit()
|
||||
|
||||
def test_compound_fetching(self):
|
||||
buildings = self.session.query(self.Building).all()
|
||||
batch_fetch(
|
||||
buildings,
|
||||
'business_premises',
|
||||
compound_path(
|
||||
'equipment',
|
||||
'business_premises.equipment'
|
||||
)
|
||||
)
|
||||
query_count = self.connection.query_count
|
||||
|
||||
buildings[0].equipment
|
||||
buildings[1].equipment
|
||||
buildings[0].business_premises[0].equipment
|
||||
assert self.connection.query_count == query_count
|
||||
Reference in New Issue
Block a user