Refactored batch fetch

This commit is contained in:
Konsta Vesterinen
2013-08-21 16:38:44 +03:00
parent 03b1985980
commit 223f6a64fa
3 changed files with 157 additions and 35 deletions

View File

@@ -6,12 +6,13 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs from .batch_fetch import batch_fetch, with_backrefs, compound_path
from .sort_query import sort_query from .sort_query import sort_query
__all__ = ( __all__ = (
batch_fetch, batch_fetch,
compound_path,
sort_query, sort_query,
with_backrefs with_backrefs
) )

View File

@@ -13,6 +13,11 @@ class with_backrefs(object):
self.attr_path = attr_path self.attr_path = attr_path
class compound_path(object):
def __init__(self, *attr_paths):
self.attr_paths = attr_paths
def batch_fetch(entities, *attr_paths): def batch_fetch(entities, *attr_paths):
""" """
Batch fetch given relationship attribute for collection of entities. Batch fetch given relationship attribute for collection of entities.
@@ -118,7 +123,7 @@ class FetchingCoordinator(object):
'are supported.' 'are supported.'
) )
def fetcher(self, property_): def fetcher_for_property(self, property_):
if not isinstance(property_, RelationshipProperty): if not isinstance(property_, RelationshipProperty):
raise Exception( raise Exception(
'Given attribute is not a relationship property.' 'Given attribute is not a relationship property.'
@@ -132,7 +137,7 @@ class FetchingCoordinator(object):
else: else:
return OneToManyFetcher(self, property_) return OneToManyFetcher(self, property_)
def __call__(self, attr_path): def fetcher_for_attr_path(self, attr_path):
if isinstance(attr_path, with_backrefs): if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True self.should_populate_backrefs = True
attr_path = attr_path.attr_path attr_path = attr_path.attr_path
@@ -142,12 +147,28 @@ class FetchingCoordinator(object):
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
if not attr: if not attr:
return return
return self.fetcher_for_property(attr.property)
fetcher = self.fetcher(attr.property) def __call__(self, attr_path):
if isinstance(attr_path, compound_path):
for path in attr_path.attr_paths:
self(path)
else:
fetcher = self.fetcher_for_attr_path(attr_path)
if not fetcher:
return
fetcher.fetch() fetcher.fetch()
fetcher.populate() fetcher.populate()
class CompoundFetcher(object):
def __init__(self, coordinator, path):
self.coordinator = coordinator
self.entities = coordinator.entities
self.first = self.entities[0]
self.session = object_session(self.first)
class Fetcher(object): class Fetcher(object):
def __init__(self, coordinator, property_): def __init__(self, coordinator, property_):
self.coordinator = coordinator self.coordinator = coordinator
@@ -207,31 +228,49 @@ class Fetcher(object):
if self.coordinator.should_populate_backrefs: if self.coordinator.should_populate_backrefs:
self.populate_backrefs(self.related_entities) self.populate_backrefs(self.related_entities)
@property
def remote_column_name(self):
return list(self.prop.remote_side)[0].name
@property
def condition(self):
return getattr(self.model, self.remote_column_name).in_(
self.local_values_list
)
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
def fetch(self): @property
column_name = None def remote_column_name(self):
for column in self.prop.remote_side: for column in self.prop.remote_side:
for fk in column.foreign_keys: for fk in column.foreign_keys:
# TODO: make this support inherited tables # TODO: make this support inherited tables
if fk.column.table == self.first.__class__.__table__: if fk.column.table == self.first.__class__.__table__:
column_name = fk.parent.name return fk.parent.name
break
if column_name:
break
self.related_entities = ( @property
def related_entities(self):
return (
self.session self.session
.query(self.model, getattr(self.prop.secondary.c, column_name)) .query(
self.model,
getattr(self.prop.secondary.c, self.remote_column_name)
)
.join( .join(
self.prop.secondary, self.prop.secondaryjoin self.prop.secondary, self.prop.secondaryjoin
) )
.filter( .filter(
getattr(self.prop.secondary.c, column_name).in_( getattr(self.prop.secondary.c, self.remote_column_name).in_(
self.local_values_list self.local_values_list
) )
) )
) )
def fetch(self):
for entity, parent_id in self.related_entities: for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append( self.parent_dict[parent_id].append(
entity entity
@@ -246,31 +285,13 @@ class ManyToOneFetcher(Fetcher):
) )
def fetch(self): def fetch(self):
column_name = list(self.prop.remote_side)[0].name
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
)
for entity in self.related_entities: for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)] = entity self.parent_dict[getattr(entity, self.remote_column_name)] = entity
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def fetch(self): def fetch(self):
column_name = list(self.prop.remote_side)[0].name
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.local_values_list)
)
)
for entity in self.related_entities: for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)].append( self.parent_dict[getattr(entity, self.remote_column_name)].append(
entity entity
) )

View File

@@ -0,0 +1,100 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import compound_path
from tests import TestCase
class TestCompoundBatchFetching(TestCase):
def create_models(self):
class Building(self.Base):
__tablename__ = 'building'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
class BusinessPremise(self.Base):
__tablename__ = 'business_premise'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
building = sa.orm.relationship(
Building,
backref=sa.orm.backref(
'business_premises'
)
)
class Equipment(self.Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
business_premise_id = sa.Column(
sa.Integer, sa.ForeignKey(BusinessPremise.id)
)
building = sa.orm.relationship(
Building,
backref=sa.orm.backref(
'equipment'
)
)
business_premise = sa.orm.relationship(
BusinessPremise,
backref=sa.orm.backref(
'equipment'
)
)
self.Building = Building
self.BusinessPremise = BusinessPremise
self.Equipment = Equipment
def setup_method(self, method):
TestCase.setup_method(self, method)
self.buildings = [
self.Building(name=u'B 1'),
self.Building(name=u'B 2'),
self.Building(name=u'B 3'),
]
self.business_premises = [
self.BusinessPremise(name=u'BP 1', building=self.buildings[0]),
self.BusinessPremise(name=u'BP 2', building=self.buildings[0]),
self.BusinessPremise(name=u'BP 3', building=self.buildings[2]),
]
self.equipment = [
self.Equipment(
name=u'E 1', building=self.buildings[0]
),
self.Equipment(
name=u'E 2', building=self.buildings[2]
),
self.Equipment(
name=u'E 3', business_premise=self.business_premises[0]
),
self.Equipment(
name=u'E 4', business_premise=self.business_premises[2]
),
]
self.session.add_all(self.buildings)
self.session.add_all(self.business_premises)
self.session.add_all(self.equipment)
self.session.commit()
def test_compound_fetching(self):
buildings = self.session.query(self.Building).all()
batch_fetch(
buildings,
'business_premises',
compound_path(
'equipment',
'business_premises.equipment'
)
)
query_count = self.connection.query_count
buildings[0].equipment
buildings[1].equipment
buildings[0].business_premises[0].equipment
assert self.connection.query_count == query_count