Refactored batch fetch
This commit is contained in:
@@ -6,12 +6,13 @@ from sqlalchemy.orm import defer
|
|||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||||
from .batch_fetch import batch_fetch, with_backrefs
|
from .batch_fetch import batch_fetch, with_backrefs, compound_path
|
||||||
from .sort_query import sort_query
|
from .sort_query import sort_query
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
batch_fetch,
|
batch_fetch,
|
||||||
|
compound_path,
|
||||||
sort_query,
|
sort_query,
|
||||||
with_backrefs
|
with_backrefs
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ class with_backrefs(object):
|
|||||||
self.attr_path = attr_path
|
self.attr_path = attr_path
|
||||||
|
|
||||||
|
|
||||||
|
class compound_path(object):
|
||||||
|
def __init__(self, *attr_paths):
|
||||||
|
self.attr_paths = attr_paths
|
||||||
|
|
||||||
|
|
||||||
def batch_fetch(entities, *attr_paths):
|
def batch_fetch(entities, *attr_paths):
|
||||||
"""
|
"""
|
||||||
Batch fetch given relationship attribute for collection of entities.
|
Batch fetch given relationship attribute for collection of entities.
|
||||||
@@ -118,7 +123,7 @@ class FetchingCoordinator(object):
|
|||||||
'are supported.'
|
'are supported.'
|
||||||
)
|
)
|
||||||
|
|
||||||
def fetcher(self, property_):
|
def fetcher_for_property(self, property_):
|
||||||
if not isinstance(property_, RelationshipProperty):
|
if not isinstance(property_, RelationshipProperty):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Given attribute is not a relationship property.'
|
'Given attribute is not a relationship property.'
|
||||||
@@ -132,7 +137,7 @@ class FetchingCoordinator(object):
|
|||||||
else:
|
else:
|
||||||
return OneToManyFetcher(self, property_)
|
return OneToManyFetcher(self, property_)
|
||||||
|
|
||||||
def __call__(self, attr_path):
|
def fetcher_for_attr_path(self, attr_path):
|
||||||
if isinstance(attr_path, with_backrefs):
|
if isinstance(attr_path, with_backrefs):
|
||||||
self.should_populate_backrefs = True
|
self.should_populate_backrefs = True
|
||||||
attr_path = attr_path.attr_path
|
attr_path = attr_path.attr_path
|
||||||
@@ -142,12 +147,28 @@ class FetchingCoordinator(object):
|
|||||||
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
||||||
if not attr:
|
if not attr:
|
||||||
return
|
return
|
||||||
|
return self.fetcher_for_property(attr.property)
|
||||||
|
|
||||||
fetcher = self.fetcher(attr.property)
|
def __call__(self, attr_path):
|
||||||
|
if isinstance(attr_path, compound_path):
|
||||||
|
for path in attr_path.attr_paths:
|
||||||
|
self(path)
|
||||||
|
else:
|
||||||
|
fetcher = self.fetcher_for_attr_path(attr_path)
|
||||||
|
if not fetcher:
|
||||||
|
return
|
||||||
fetcher.fetch()
|
fetcher.fetch()
|
||||||
fetcher.populate()
|
fetcher.populate()
|
||||||
|
|
||||||
|
|
||||||
|
class CompoundFetcher(object):
|
||||||
|
def __init__(self, coordinator, path):
|
||||||
|
self.coordinator = coordinator
|
||||||
|
self.entities = coordinator.entities
|
||||||
|
self.first = self.entities[0]
|
||||||
|
self.session = object_session(self.first)
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(object):
|
class Fetcher(object):
|
||||||
def __init__(self, coordinator, property_):
|
def __init__(self, coordinator, property_):
|
||||||
self.coordinator = coordinator
|
self.coordinator = coordinator
|
||||||
@@ -207,31 +228,49 @@ class Fetcher(object):
|
|||||||
if self.coordinator.should_populate_backrefs:
|
if self.coordinator.should_populate_backrefs:
|
||||||
self.populate_backrefs(self.related_entities)
|
self.populate_backrefs(self.related_entities)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_column_name(self):
|
||||||
|
return list(self.prop.remote_side)[0].name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def condition(self):
|
||||||
|
return getattr(self.model, self.remote_column_name).in_(
|
||||||
|
self.local_values_list
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def related_entities(self):
|
||||||
|
return self.session.query(self.model).filter(self.condition)
|
||||||
|
|
||||||
|
|
||||||
class ManyToManyFetcher(Fetcher):
|
class ManyToManyFetcher(Fetcher):
|
||||||
def fetch(self):
|
@property
|
||||||
column_name = None
|
def remote_column_name(self):
|
||||||
for column in self.prop.remote_side:
|
for column in self.prop.remote_side:
|
||||||
for fk in column.foreign_keys:
|
for fk in column.foreign_keys:
|
||||||
# TODO: make this support inherited tables
|
# TODO: make this support inherited tables
|
||||||
if fk.column.table == self.first.__class__.__table__:
|
if fk.column.table == self.first.__class__.__table__:
|
||||||
column_name = fk.parent.name
|
return fk.parent.name
|
||||||
break
|
|
||||||
if column_name:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.related_entities = (
|
@property
|
||||||
|
def related_entities(self):
|
||||||
|
return (
|
||||||
self.session
|
self.session
|
||||||
.query(self.model, getattr(self.prop.secondary.c, column_name))
|
.query(
|
||||||
|
self.model,
|
||||||
|
getattr(self.prop.secondary.c, self.remote_column_name)
|
||||||
|
)
|
||||||
.join(
|
.join(
|
||||||
self.prop.secondary, self.prop.secondaryjoin
|
self.prop.secondary, self.prop.secondaryjoin
|
||||||
)
|
)
|
||||||
.filter(
|
.filter(
|
||||||
getattr(self.prop.secondary.c, column_name).in_(
|
getattr(self.prop.secondary.c, self.remote_column_name).in_(
|
||||||
self.local_values_list
|
self.local_values_list
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def fetch(self):
|
||||||
for entity, parent_id in self.related_entities:
|
for entity, parent_id in self.related_entities:
|
||||||
self.parent_dict[parent_id].append(
|
self.parent_dict[parent_id].append(
|
||||||
entity
|
entity
|
||||||
@@ -246,31 +285,13 @@ class ManyToOneFetcher(Fetcher):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
column_name = list(self.prop.remote_side)[0].name
|
|
||||||
|
|
||||||
self.related_entities = (
|
|
||||||
self.session.query(self.model)
|
|
||||||
.filter(
|
|
||||||
getattr(self.model, column_name).in_(self.local_values_list)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
self.parent_dict[getattr(entity, column_name)] = entity
|
self.parent_dict[getattr(entity, self.remote_column_name)] = entity
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
column_name = list(self.prop.remote_side)[0].name
|
|
||||||
|
|
||||||
self.related_entities = (
|
|
||||||
self.session.query(self.model)
|
|
||||||
.filter(
|
|
||||||
getattr(self.model, column_name).in_(self.local_values_list)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
self.parent_dict[getattr(entity, column_name)].append(
|
self.parent_dict[getattr(entity, self.remote_column_name)].append(
|
||||||
entity
|
entity
|
||||||
)
|
)
|
||||||
|
|||||||
100
tests/batch_fetch/test_compound_fetching.py
Normal file
100
tests/batch_fetch/test_compound_fetching.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy_utils import batch_fetch
|
||||||
|
from sqlalchemy_utils.functions import compound_path
|
||||||
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestCompoundBatchFetching(TestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Building(self.Base):
|
||||||
|
__tablename__ = 'building'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class BusinessPremise(self.Base):
|
||||||
|
__tablename__ = 'business_premise'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
||||||
|
|
||||||
|
building = sa.orm.relationship(
|
||||||
|
Building,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'business_premises'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Equipment(self.Base):
|
||||||
|
__tablename__ = 'article'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
building_id = sa.Column(sa.Integer, sa.ForeignKey(Building.id))
|
||||||
|
business_premise_id = sa.Column(
|
||||||
|
sa.Integer, sa.ForeignKey(BusinessPremise.id)
|
||||||
|
)
|
||||||
|
|
||||||
|
building = sa.orm.relationship(
|
||||||
|
Building,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'equipment'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
business_premise = sa.orm.relationship(
|
||||||
|
BusinessPremise,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'equipment'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.Building = Building
|
||||||
|
self.BusinessPremise = BusinessPremise
|
||||||
|
self.Equipment = Equipment
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
TestCase.setup_method(self, method)
|
||||||
|
self.buildings = [
|
||||||
|
self.Building(name=u'B 1'),
|
||||||
|
self.Building(name=u'B 2'),
|
||||||
|
self.Building(name=u'B 3'),
|
||||||
|
]
|
||||||
|
self.business_premises = [
|
||||||
|
self.BusinessPremise(name=u'BP 1', building=self.buildings[0]),
|
||||||
|
self.BusinessPremise(name=u'BP 2', building=self.buildings[0]),
|
||||||
|
self.BusinessPremise(name=u'BP 3', building=self.buildings[2]),
|
||||||
|
]
|
||||||
|
self.equipment = [
|
||||||
|
self.Equipment(
|
||||||
|
name=u'E 1', building=self.buildings[0]
|
||||||
|
),
|
||||||
|
self.Equipment(
|
||||||
|
name=u'E 2', building=self.buildings[2]
|
||||||
|
),
|
||||||
|
self.Equipment(
|
||||||
|
name=u'E 3', business_premise=self.business_premises[0]
|
||||||
|
),
|
||||||
|
self.Equipment(
|
||||||
|
name=u'E 4', business_premise=self.business_premises[2]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.session.add_all(self.buildings)
|
||||||
|
self.session.add_all(self.business_premises)
|
||||||
|
self.session.add_all(self.equipment)
|
||||||
|
self.session.commit()
|
||||||
|
|
||||||
|
def test_compound_fetching(self):
|
||||||
|
buildings = self.session.query(self.Building).all()
|
||||||
|
batch_fetch(
|
||||||
|
buildings,
|
||||||
|
'business_premises',
|
||||||
|
compound_path(
|
||||||
|
'equipment',
|
||||||
|
'business_premises.equipment'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
query_count = self.connection.query_count
|
||||||
|
|
||||||
|
buildings[0].equipment
|
||||||
|
buildings[1].equipment
|
||||||
|
buildings[0].business_premises[0].equipment
|
||||||
|
assert self.connection.query_count == query_count
|
||||||
Reference in New Issue
Block a user