From 223f6a64fa92e74a665af4da03b7be8340b9a72c Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 21 Aug 2013 16:38:44 +0300 Subject: [PATCH] Refactored batch fetch --- sqlalchemy_utils/functions/__init__.py | 3 +- sqlalchemy_utils/functions/batch_fetch.py | 89 ++++++++++------- tests/batch_fetch/test_compound_fetching.py | 100 ++++++++++++++++++++ 3 files changed, 157 insertions(+), 35 deletions(-) create mode 100644 tests/batch_fetch/test_compound_fetching.py diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 1bee38e..03486db 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -6,12 +6,13 @@ from sqlalchemy.orm import defer from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint -from .batch_fetch import batch_fetch, with_backrefs +from .batch_fetch import batch_fetch, with_backrefs, compound_path from .sort_query import sort_query __all__ = ( batch_fetch, + compound_path, sort_query, with_backrefs ) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 1c8ccc8..e833b43 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -13,6 +13,11 @@ class with_backrefs(object): self.attr_path = attr_path +class compound_path(object): + def __init__(self, *attr_paths): + self.attr_paths = attr_paths + + def batch_fetch(entities, *attr_paths): """ Batch fetch given relationship attribute for collection of entities. @@ -118,7 +123,7 @@ class FetchingCoordinator(object): 'are supported.' ) - def fetcher(self, property_): + def fetcher_for_property(self, property_): if not isinstance(property_, RelationshipProperty): raise Exception( 'Given attribute is not a relationship property.' @@ -132,7 +137,7 @@ class FetchingCoordinator(object): else: return OneToManyFetcher(self, property_) - def __call__(self, attr_path): + def fetcher_for_attr_path(self, attr_path): if isinstance(attr_path, with_backrefs): self.should_populate_backrefs = True attr_path = attr_path.attr_path @@ -142,10 +147,26 @@ class FetchingCoordinator(object): attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) if not attr: return + return self.fetcher_for_property(attr.property) - fetcher = self.fetcher(attr.property) - fetcher.fetch() - fetcher.populate() + def __call__(self, attr_path): + if isinstance(attr_path, compound_path): + for path in attr_path.attr_paths: + self(path) + else: + fetcher = self.fetcher_for_attr_path(attr_path) + if not fetcher: + return + fetcher.fetch() + fetcher.populate() + + +class CompoundFetcher(object): + def __init__(self, coordinator, path): + self.coordinator = coordinator + self.entities = coordinator.entities + self.first = self.entities[0] + self.session = object_session(self.first) class Fetcher(object): @@ -207,31 +228,49 @@ class Fetcher(object): if self.coordinator.should_populate_backrefs: self.populate_backrefs(self.related_entities) + @property + def remote_column_name(self): + return list(self.prop.remote_side)[0].name + + @property + def condition(self): + return getattr(self.model, self.remote_column_name).in_( + self.local_values_list + ) + + @property + def related_entities(self): + return self.session.query(self.model).filter(self.condition) + class ManyToManyFetcher(Fetcher): - def fetch(self): - column_name = None + @property + def remote_column_name(self): for column in self.prop.remote_side: for fk in column.foreign_keys: # TODO: make this support inherited tables if fk.column.table == self.first.__class__.__table__: - column_name = fk.parent.name - break - if column_name: - break + return fk.parent.name - self.related_entities = ( + @property + def related_entities(self): + return ( self.session - .query(self.model, getattr(self.prop.secondary.c, column_name)) + .query( + self.model, + getattr(self.prop.secondary.c, self.remote_column_name) + ) .join( self.prop.secondary, self.prop.secondaryjoin ) .filter( - getattr(self.prop.secondary.c, column_name).in_( + getattr(self.prop.secondary.c, self.remote_column_name).in_( self.local_values_list ) ) ) + + def fetch(self): for entity, parent_id in self.related_entities: self.parent_dict[parent_id].append( entity @@ -246,31 +285,13 @@ class ManyToOneFetcher(Fetcher): ) def fetch(self): - column_name = list(self.prop.remote_side)[0].name - - self.related_entities = ( - self.session.query(self.model) - .filter( - getattr(self.model, column_name).in_(self.local_values_list) - ) - ) - for entity in self.related_entities: - self.parent_dict[getattr(entity, column_name)] = entity + self.parent_dict[getattr(entity, self.remote_column_name)] = entity class OneToManyFetcher(Fetcher): def fetch(self): - column_name = list(self.prop.remote_side)[0].name - - self.related_entities = ( - self.session.query(self.model) - .filter( - getattr(self.model, column_name).in_(self.local_values_list) - ) - ) - for entity in self.related_entities: - self.parent_dict[getattr(entity, column_name)].append( + self.parent_dict[getattr(entity, self.remote_column_name)].append( entity ) diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py new file mode 100644 index 0000000..d320b63 --- /dev/null +++ b/tests/batch_fetch/test_compound_fetching.py @@ -0,0 +1,100 @@ +import sqlalchemy as sa +from sqlalchemy_utils import batch_fetch +from sqlalchemy_utils.functions import compound_path +from tests import TestCase + + +class TestCompoundBatchFetching(TestCase): + def create_models(self): + class Building(self.Base): + __tablename__ = 'building' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class BusinessPremise(self.Base): + __tablename__ = 'business_premise' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id)) + + building = sa.orm.relationship( + Building, + backref=sa.orm.backref( + 'business_premises' + ) + ) + + class Equipment(self.Base): + __tablename__ = 'article' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id)) + business_premise_id = sa.Column( + sa.Integer, sa.ForeignKey(BusinessPremise.id) + ) + + building = sa.orm.relationship( + Building, + backref=sa.orm.backref( + 'equipment' + ) + ) + business_premise = sa.orm.relationship( + BusinessPremise, + backref=sa.orm.backref( + 'equipment' + ) + ) + + self.Building = Building + self.BusinessPremise = BusinessPremise + self.Equipment = Equipment + + def setup_method(self, method): + TestCase.setup_method(self, method) + self.buildings = [ + self.Building(name=u'B 1'), + self.Building(name=u'B 2'), + self.Building(name=u'B 3'), + ] + self.business_premises = [ + self.BusinessPremise(name=u'BP 1', building=self.buildings[0]), + self.BusinessPremise(name=u'BP 2', building=self.buildings[0]), + self.BusinessPremise(name=u'BP 3', building=self.buildings[2]), + ] + self.equipment = [ + self.Equipment( + name=u'E 1', building=self.buildings[0] + ), + self.Equipment( + name=u'E 2', building=self.buildings[2] + ), + self.Equipment( + name=u'E 3', business_premise=self.business_premises[0] + ), + self.Equipment( + name=u'E 4', business_premise=self.business_premises[2] + ), + ] + self.session.add_all(self.buildings) + self.session.add_all(self.business_premises) + self.session.add_all(self.equipment) + self.session.commit() + + def test_compound_fetching(self): + buildings = self.session.query(self.Building).all() + batch_fetch( + buildings, + 'business_premises', + compound_path( + 'equipment', + 'business_premises.equipment' + ) + ) + query_count = self.connection.query_count + + buildings[0].equipment + buildings[1].equipment + buildings[0].business_premises[0].equipment + assert self.connection.query_count == query_count