Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-21 16:38:44 +03:00
parent 03b1985980
commit 223f6a64fa
3 changed files with 157 additions and 35 deletions

View File

@@ -6,12 +6,13 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs
from .batch_fetch import batch_fetch, with_backrefs, compound_path
from .sort_query import sort_query
# Public names exported by ``from <this module> import *``.
# ``__all__`` has to contain the *names* as strings; listing the objects
# themselves makes a star-import fail with ``TypeError``.
__all__ = (
    'batch_fetch',
    'compound_path',
    'sort_query',
    'with_backrefs',
)

View File

@@ -13,6 +13,11 @@ class with_backrefs(object):
self.attr_path = attr_path
class compound_path(object):
    """
    Marker object that bundles several attribute paths into one argument.

    The paths are stored, in the order given, as the ``attr_paths`` tuple.
    """

    def __init__(self, *attr_paths):
        # Varargs already arrive as a tuple, so it is stored untouched.
        bundled_paths = attr_paths
        self.attr_paths = bundled_paths
def batch_fetch(entities, *attr_paths):
"""
Batch fetch given relationship attribute for collection of entities.
@@ -118,7 +123,7 @@ class FetchingCoordinator(object):
'are supported.'
)
def fetcher(self, property_):
def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
@@ -132,7 +137,7 @@ class FetchingCoordinator(object):
else:
return OneToManyFetcher(self, property_)
def __call__(self, attr_path):
def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
@@ -142,10 +147,26 @@ class FetchingCoordinator(object):
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
if not attr:
return
return self.fetcher_for_property(attr.property)
fetcher = self.fetcher(attr.property)
fetcher.fetch()
fetcher.populate()
def __call__(self, attr_path):
    """
    Fetch and populate the relationship(s) described by ``attr_path``.

    A ``compound_path`` is expanded by invoking the coordinator once for
    each path it contains. Any other value is resolved through
    ``fetcher_for_attr_path``; when that yields a fetcher, its ``fetch()``
    and ``populate()`` are run in that order, otherwise nothing happens.
    """
    if isinstance(attr_path, compound_path):
        for sub_path in attr_path.attr_paths:
            self(sub_path)
        return

    fetcher = self.fetcher_for_attr_path(attr_path)
    if fetcher:
        fetcher.fetch()
        fetcher.populate()
class CompoundFetcher(object):
    """
    Holds the coordinator, its entities and their session.

    NOTE(review): ``path`` is accepted but neither stored nor used by the
    code visible here.
    """

    def __init__(self, coordinator, path):
        self.coordinator = coordinator
        entities = coordinator.entities
        self.entities = entities
        # The first entity is used to look up the session; an empty
        # collection therefore fails on the index lookup below.
        head = entities[0]
        self.first = head
        self.session = object_session(head)
class Fetcher(object):
@@ -207,31 +228,49 @@ class Fetcher(object):
if self.coordinator.should_populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote_column_name(self):
    """
    Name of the first column yielded by ``self.prop.remote_side``.

    NOTE(review): only the first column is used; if ``remote_side`` is
    unordered or has several columns the choice may be arbitrary - confirm.
    """
    remote_columns = list(self.prop.remote_side)
    return remote_columns[0].name
@property
def condition(self):
    """
    Filter criterion: the model attribute named by ``remote_column_name``
    restricted with ``in_`` to ``self.local_values_list``.
    """
    remote_column = getattr(self.model, self.remote_column_name)
    return remote_column.in_(self.local_values_list)
@property
def related_entities(self):
    """
    Query over ``self.model`` filtered by ``self.condition``.

    A new query is built on every access to this property.
    """
    model_query = self.session.query(self.model)
    return model_query.filter(self.condition)
class ManyToManyFetcher(Fetcher):
def fetch(self):
column_name = None
@property
def remote_column_name(self):
for column in self.prop.remote_side:
for fk in column.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == self.first.__class__.__table__:
column_name = fk.parent.name
break
if column_name:
break
return fk.parent.name
self.related_entities = (
@property
def related_entities(self):
return (
self.session
.query(self.model, getattr(self.prop.secondary.c, column_name))
.query(
self.model,
getattr(self.prop.secondary.c, self.remote_column_name)
)
.join(
self.prop.secondary, self.prop.secondaryjoin
)
.filter(
getattr(self.prop.secondary.c, column_name).in_(
getattr(self.prop.secondary.c, self.remote_column_name).in_(
self.local_values_list
)
)
)
def fetch(self):
for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append(
entity
@@ -246,31 +285,13 @@ class ManyToOneFetcher(Fetcher):
)
def fetch(self):
column_name = list(self.prop.remote_side)[0].name
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
)
for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)] = entity
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
class OneToManyFetcher(Fetcher):
def fetch(self):
column_name = list(self.prop.remote_side)[0].name
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
)
for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)].append(
self.parent_dict[getattr(entity, self.remote_column_name)].append(
entity
)

View File

@@ -0,0 +1,100 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import compound_path
from tests import TestCase
class TestCompoundBatchFetching(TestCase):
    """
    Batch fetching with a compound path, where ``Equipment`` is reachable
    from ``Building`` both directly and through ``BusinessPremise``.
    """

    def create_models(self):
        """Declare the Building, BusinessPremise and Equipment models."""
        class Building(self.Base):
            __tablename__ = 'building'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))

        class BusinessPremise(self.Base):
            __tablename__ = 'business_premise'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))

            building = sa.orm.relationship(
                Building,
                backref=sa.orm.backref('business_premises')
            )

        class Equipment(self.Base):
            # NOTE(review): the table name 'article' does not match the
            # class name; it is kept exactly as in the original.
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
            business_premise_id = sa.Column(
                sa.Integer, sa.ForeignKey(BusinessPremise.id)
            )

            building = sa.orm.relationship(
                Building,
                backref=sa.orm.backref('equipment')
            )
            business_premise = sa.orm.relationship(
                BusinessPremise,
                backref=sa.orm.backref('equipment')
            )

        self.Building = Building
        self.BusinessPremise = BusinessPremise
        self.Equipment = Equipment

    def setup_method(self, method):
        """Persist three buildings, three premises and four equipment."""
        TestCase.setup_method(self, method)

        self.buildings = [
            self.Building(name=label) for label in (u'B 1', u'B 2', u'B 3')
        ]
        first_building, _, third_building = self.buildings

        self.business_premises = [
            self.BusinessPremise(name=u'BP 1', building=first_building),
            self.BusinessPremise(name=u'BP 2', building=first_building),
            self.BusinessPremise(name=u'BP 3', building=third_building),
        ]
        first_premise, _, third_premise = self.business_premises

        # Two items hang directly off a building, two off a premise.
        self.equipment = [
            self.Equipment(name=u'E 1', building=first_building),
            self.Equipment(name=u'E 2', building=third_building),
            self.Equipment(name=u'E 3', business_premise=first_premise),
            self.Equipment(name=u'E 4', business_premise=third_premise),
        ]

        fixture_groups = (
            self.buildings,
            self.business_premises,
            self.equipment,
        )
        for group in fixture_groups:
            self.session.add_all(group)
        self.session.commit()

    def test_compound_fetching(self):
        """Reading the fetched relations must not issue further queries."""
        buildings = self.session.query(self.Building).all()
        equipment_paths = compound_path(
            'equipment',
            'business_premises.equipment'
        )
        batch_fetch(buildings, 'business_premises', equipment_paths)

        queries_before = self.connection.query_count
        first, second = buildings[0], buildings[1]
        # Bare attribute reads: each would lazy-load if it was not fetched.
        first.equipment
        second.equipment
        first.business_premises[0].equipment
        assert self.connection.query_count == queries_before