Refactored batch fetch
This commit is contained in:
@@ -6,15 +6,15 @@ from sqlalchemy.orm import defer
|
|||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||||
from .batch_fetch import batch_fetch, with_backrefs, compound_path
|
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
|
||||||
from .sort_query import sort_query
|
from .sort_query import sort_query
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
batch_fetch,
|
batch_fetch,
|
||||||
compound_path,
|
|
||||||
sort_query,
|
sort_query,
|
||||||
with_backrefs
|
with_backrefs,
|
||||||
|
CompositePath,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -2,7 +2,9 @@ from collections import defaultdict
|
|||||||
import six
|
import six
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.orm import RelationshipProperty
|
from sqlalchemy.orm import RelationshipProperty
|
||||||
from sqlalchemy.orm.attributes import set_committed_value
|
from sqlalchemy.orm.attributes import (
|
||||||
|
set_committed_value, InstrumentedAttribute
|
||||||
|
)
|
||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
|
|
||||||
|
|
||||||
@@ -11,13 +13,73 @@ class with_backrefs(object):
|
|||||||
Marks given attribute path so that whenever its fetched with batch_fetch
|
Marks given attribute path so that whenever its fetched with batch_fetch
|
||||||
the backref relations are force set too.
|
the backref relations are force set too.
|
||||||
"""
|
"""
|
||||||
def __init__(self, attr_path):
|
def __init__(self, path):
|
||||||
self.attr_path = attr_path
|
self.path = path
|
||||||
|
|
||||||
|
|
||||||
class compound_path(object):
|
class Path(object):
|
||||||
def __init__(self, *attr_paths):
|
"""
|
||||||
self.attr_paths = attr_paths
|
A class that represents an attribute path.
|
||||||
|
"""
|
||||||
|
def __init__(self, entities, prop, populate_backrefs=False):
|
||||||
|
self.property = prop
|
||||||
|
self.entities = entities
|
||||||
|
self.populate_backrefs = populate_backrefs
|
||||||
|
if not isinstance(self.property, RelationshipProperty):
|
||||||
|
raise Exception(
|
||||||
|
'Given attribute is not a relationship property.'
|
||||||
|
)
|
||||||
|
self.fetcher = self.fetcher_class(self)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self):
|
||||||
|
return object_session(self.entities[0])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def parent_model(self):
|
||||||
|
return self.entities[0].__class__
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model(self):
|
||||||
|
return self.property.mapper.class_
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, entities, path, populate_backrefs=False):
|
||||||
|
if isinstance(path, six.string_types):
|
||||||
|
attrs = path.split('.')
|
||||||
|
|
||||||
|
if len(attrs) > 1:
|
||||||
|
related_entities = []
|
||||||
|
for entity in entities:
|
||||||
|
related_entities.extend(getattr(entity, attrs[0]))
|
||||||
|
|
||||||
|
subpath = '.'.join(attrs[1:])
|
||||||
|
return Path.parse(related_entities, subpath, populate_backrefs)
|
||||||
|
else:
|
||||||
|
attr = getattr(
|
||||||
|
entities[0].__class__, attrs[0]
|
||||||
|
)
|
||||||
|
elif isinstance(path, InstrumentedAttribute):
|
||||||
|
attr = path
|
||||||
|
else:
|
||||||
|
raise Exception('Unknown path type.')
|
||||||
|
|
||||||
|
return Path(entities, attr.property, populate_backrefs)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fetcher_class(self):
|
||||||
|
if self.property.secondary is not None:
|
||||||
|
return ManyToManyFetcher
|
||||||
|
else:
|
||||||
|
if self.property.direction.name == 'MANYTOONE':
|
||||||
|
return ManyToOneFetcher
|
||||||
|
else:
|
||||||
|
return OneToManyFetcher
|
||||||
|
|
||||||
|
|
||||||
|
class CompositePath(object):
|
||||||
|
def __init__(self, *paths):
|
||||||
|
self.paths = paths
|
||||||
|
|
||||||
|
|
||||||
def batch_fetch(entities, *attr_paths):
|
def batch_fetch(entities, *attr_paths):
|
||||||
@@ -28,9 +90,9 @@ def batch_fetch(entities, *attr_paths):
|
|||||||
subqueryload and performs lot better.
|
subqueryload and performs lot better.
|
||||||
|
|
||||||
:param entities: list of entities of the same type
|
:param entities: list of entities of the same type
|
||||||
:param attr:
|
:param attr_paths:
|
||||||
Either InstrumentedAttribute object or a string representing the name
|
List of either InstrumentedAttribute objects or a strings representing
|
||||||
of the instrumented attribute
|
the name of the instrumented attribute
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
@@ -81,107 +143,51 @@ def batch_fetch(entities, *attr_paths):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if entities:
|
if entities:
|
||||||
fetcher = FetchingCoordinator(entities)
|
fetcher = FetchingCoordinator()
|
||||||
for attr_path in attr_paths:
|
for attr_path in attr_paths:
|
||||||
fetcher(attr_path)
|
fetcher(entities, attr_path)
|
||||||
|
|
||||||
|
|
||||||
class FetchingCoordinator(object):
|
class FetchingCoordinator(object):
|
||||||
def __init__(self, entities):
|
def __call__(self, entities, path):
|
||||||
self.entities = entities
|
populate_backrefs = False
|
||||||
self.first = entities[0]
|
if isinstance(path, with_backrefs):
|
||||||
self.session = object_session(self.first)
|
path = path.path
|
||||||
|
populate_backrefs = True
|
||||||
|
|
||||||
def parse_attr_path(self, attr_path, should_populate_backrefs):
|
if isinstance(path, CompositePath):
|
||||||
if isinstance(attr_path, six.string_types):
|
|
||||||
attrs = attr_path.split('.')
|
|
||||||
|
|
||||||
if len(attrs) > 1:
|
|
||||||
related_entities = []
|
|
||||||
for entity in self.entities:
|
|
||||||
related_entities.extend(getattr(entity, attrs[0]))
|
|
||||||
|
|
||||||
subpath = '.'.join(attrs[1:])
|
|
||||||
|
|
||||||
if should_populate_backrefs:
|
|
||||||
subpath = with_backrefs(subpath)
|
|
||||||
|
|
||||||
return self.__class__(related_entities).fetcher_for_attr_path(
|
|
||||||
subpath
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
attr_path = getattr(
|
|
||||||
self.first.__class__, attrs[0]
|
|
||||||
)
|
|
||||||
return self.fetcher_for_property(attr_path.property)
|
|
||||||
|
|
||||||
def fetcher_for_property(self, property_):
|
|
||||||
if not isinstance(property_, RelationshipProperty):
|
|
||||||
raise Exception(
|
|
||||||
'Given attribute is not a relationship property.'
|
|
||||||
)
|
|
||||||
|
|
||||||
if property_.secondary is not None:
|
|
||||||
fetcher_class = ManyToManyFetcher
|
|
||||||
else:
|
|
||||||
if property_.direction.name == 'MANYTOONE':
|
|
||||||
fetcher_class = ManyToOneFetcher
|
|
||||||
else:
|
|
||||||
fetcher_class = OneToManyFetcher
|
|
||||||
|
|
||||||
return fetcher_class(
|
|
||||||
self.entities,
|
|
||||||
property_,
|
|
||||||
self.should_populate_backrefs
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetcher_for_attr_path(self, attr_path):
|
|
||||||
if isinstance(attr_path, with_backrefs):
|
|
||||||
self.should_populate_backrefs = True
|
|
||||||
attr_path = attr_path.attr_path
|
|
||||||
else:
|
|
||||||
self.should_populate_backrefs = False
|
|
||||||
|
|
||||||
return self.parse_attr_path(
|
|
||||||
attr_path,
|
|
||||||
self.should_populate_backrefs
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, attr_path):
|
|
||||||
if isinstance(attr_path, compound_path):
|
|
||||||
fetchers = []
|
fetchers = []
|
||||||
for path in attr_path.attr_paths:
|
for path in path.paths:
|
||||||
fetchers.append(self.fetcher_for_attr_path(path))
|
fetchers.append(
|
||||||
|
Path.parse(entities, path, populate_backrefs).fetcher
|
||||||
|
)
|
||||||
|
|
||||||
fetcher = CompoundFetcher(*fetchers)
|
fetcher = CompositeFetcher(*fetchers)
|
||||||
else:
|
else:
|
||||||
fetcher = self.fetcher_for_attr_path(attr_path)
|
fetcher = Path.parse(entities, path, populate_backrefs).fetcher
|
||||||
fetcher.fetch()
|
fetcher.fetch()
|
||||||
fetcher.populate()
|
fetcher.populate()
|
||||||
|
|
||||||
|
|
||||||
class AbstractFetcher(object):
|
class CompositeFetcher(object):
|
||||||
@property
|
|
||||||
def related_entities(self):
|
|
||||||
return self.session.query(self.model).filter(self.condition)
|
|
||||||
|
|
||||||
|
|
||||||
class CompoundFetcher(AbstractFetcher):
|
|
||||||
def __init__(self, *fetchers):
|
def __init__(self, *fetchers):
|
||||||
if not all(fetchers[0].model == fetcher.model for fetcher in fetchers):
|
if not all(
|
||||||
|
fetchers[0].path.model == fetcher.path.model
|
||||||
|
for fetcher in fetchers
|
||||||
|
):
|
||||||
raise Exception(
|
raise Exception(
|
||||||
'Each relationship property must have the same class when '
|
'Each relationship property must have the same class when '
|
||||||
'using CompoundFetcher.'
|
'using CompositeFetcher.'
|
||||||
)
|
)
|
||||||
self.fetchers = fetchers
|
self.fetchers = fetchers
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def session(self):
|
def session(self):
|
||||||
return self.fetchers[0].session
|
return self.fetchers[0].path.session
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
return self.fetchers[0].model
|
return self.fetchers[0].path.model
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def condition(self):
|
def condition(self):
|
||||||
@@ -189,6 +195,10 @@ class CompoundFetcher(AbstractFetcher):
|
|||||||
*[fetcher.condition for fetcher in self.fetchers]
|
*[fetcher.condition for fetcher in self.fetchers]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def related_entities(self):
|
||||||
|
return self.session.query(self.model).filter(self.condition)
|
||||||
|
|
||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
for fetcher in self.fetchers:
|
for fetcher in self.fetchers:
|
||||||
@@ -200,23 +210,27 @@ class CompoundFetcher(AbstractFetcher):
|
|||||||
fetcher.populate()
|
fetcher.populate()
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(AbstractFetcher):
|
class Fetcher(object):
|
||||||
def __init__(self, entities, property_, populate_backrefs=False):
|
def __init__(self, path):
|
||||||
self.should_populate_backrefs = populate_backrefs
|
self.path = path
|
||||||
self.entities = entities
|
self.prop = self.path.property
|
||||||
self.prop = property_
|
|
||||||
self.model = self.prop.mapper.class_
|
|
||||||
self.first = self.entities[0]
|
|
||||||
self.session = object_session(self.first)
|
|
||||||
self.parent_dict = defaultdict(list)
|
self.parent_dict = defaultdict(list)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def local_values_list(self):
|
def local_values_list(self):
|
||||||
return [
|
return [
|
||||||
self.local_values(entity)
|
self.local_values(entity)
|
||||||
for entity in self.entities
|
for entity in self.path.entities
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def related_entities(self):
|
||||||
|
return self.path.session.query(self.path.model).filter(self.condition)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def remote_column_name(self):
|
||||||
|
return list(self.path.property.remote_side)[0].name
|
||||||
|
|
||||||
def local_values(self, entity):
|
def local_values(self, entity):
|
||||||
return getattr(entity, list(self.prop.local_columns)[0].name)
|
return getattr(entity, list(self.prop.local_columns)[0].name)
|
||||||
|
|
||||||
@@ -230,7 +244,7 @@ class Fetcher(AbstractFetcher):
|
|||||||
)
|
)
|
||||||
for entity, parent_id in related_entities:
|
for entity, parent_id in related_entities:
|
||||||
backref_dict[self.local_values(entity)].append(
|
backref_dict[self.local_values(entity)].append(
|
||||||
self.session.query(self.first.__class__).get(parent_id)
|
self.path.session.query(self.path.parent_model).get(parent_id)
|
||||||
)
|
)
|
||||||
for entity, parent_id in related_entities:
|
for entity, parent_id in related_entities:
|
||||||
set_committed_value(
|
set_committed_value(
|
||||||
@@ -243,23 +257,19 @@ class Fetcher(AbstractFetcher):
|
|||||||
"""
|
"""
|
||||||
Populate batch fetched entities to parent objects.
|
Populate batch fetched entities to parent objects.
|
||||||
"""
|
"""
|
||||||
for entity in self.entities:
|
for entity in self.path.entities:
|
||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
self.parent_dict[self.local_values(entity)]
|
self.parent_dict[self.local_values(entity)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.should_populate_backrefs:
|
if self.path.populate_backrefs:
|
||||||
self.populate_backrefs(self.related_entities)
|
self.populate_backrefs(self.related_entities)
|
||||||
|
|
||||||
@property
|
|
||||||
def remote_column_name(self):
|
|
||||||
return list(self.prop.remote_side)[0].name
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def condition(self):
|
def condition(self):
|
||||||
return getattr(self.model, self.remote_column_name).in_(
|
return getattr(self.path.model, self.remote_column_name).in_(
|
||||||
self.local_values_list
|
self.local_values_list
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -274,15 +284,15 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
for column in self.prop.remote_side:
|
for column in self.prop.remote_side:
|
||||||
for fk in column.foreign_keys:
|
for fk in column.foreign_keys:
|
||||||
# TODO: make this support inherited tables
|
# TODO: make this support inherited tables
|
||||||
if fk.column.table == self.first.__class__.__table__:
|
if fk.column.table == self.path.parent_model.__table__:
|
||||||
return fk.parent.name
|
return fk.parent.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def related_entities(self):
|
def related_entities(self):
|
||||||
return (
|
return (
|
||||||
self.session
|
self.path.session
|
||||||
.query(
|
.query(
|
||||||
self.model,
|
self.path.model,
|
||||||
getattr(self.prop.secondary.c, self.remote_column_name)
|
getattr(self.prop.secondary.c, self.remote_column_name)
|
||||||
)
|
)
|
||||||
.join(
|
.join(
|
||||||
@@ -303,8 +313,8 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
|
|
||||||
|
|
||||||
class ManyToOneFetcher(Fetcher):
|
class ManyToOneFetcher(Fetcher):
|
||||||
def __init__(self, entities, property_, populate_backrefs=False):
|
def __init__(self, path):
|
||||||
Fetcher.__init__(self, entities, property_, populate_backrefs)
|
Fetcher.__init__(self, path)
|
||||||
self.parent_dict = defaultdict(lambda: None)
|
self.parent_dict = defaultdict(lambda: None)
|
||||||
|
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
|
@@ -1,6 +1,6 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils import batch_fetch
|
from sqlalchemy_utils import batch_fetch
|
||||||
from sqlalchemy_utils.functions import compound_path
|
from sqlalchemy_utils.functions import CompositePath
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -93,7 +93,7 @@ class TestCompoundOneToManyBatchFetching(TestCase):
|
|||||||
batch_fetch(
|
batch_fetch(
|
||||||
buildings,
|
buildings,
|
||||||
'business_premises',
|
'business_premises',
|
||||||
compound_path(
|
CompositePath(
|
||||||
'equipment',
|
'equipment',
|
||||||
'business_premises.equipment'
|
'business_premises.equipment'
|
||||||
)
|
)
|
||||||
@@ -198,7 +198,7 @@ class TestCompoundManyToOneBatchFetching(TestCase):
|
|||||||
batch_fetch(
|
batch_fetch(
|
||||||
buildings,
|
buildings,
|
||||||
'business_premises',
|
'business_premises',
|
||||||
compound_path(
|
CompositePath(
|
||||||
'equipment',
|
'equipment',
|
||||||
'business_premises.equipment'
|
'business_premises.equipment'
|
||||||
)
|
)
|
||||||
|
@@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch
|
|||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
class TestBatchFetchAssociations(TestCase):
|
class TestBatchFetchJoinTableInheritedModels(TestCase):
|
||||||
def create_models(self):
|
def create_models(self):
|
||||||
class Category(self.Base):
|
class Category(self.Base):
|
||||||
__tablename__ = 'category'
|
__tablename__ = 'category'
|
||||||
|
Reference in New Issue
Block a user