diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 03486db..d536df4 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -6,15 +6,15 @@ from sqlalchemy.orm import defer from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint -from .batch_fetch import batch_fetch, with_backrefs, compound_path +from .batch_fetch import batch_fetch, with_backrefs, CompositePath from .sort_query import sort_query __all__ = ( batch_fetch, - compound_path, sort_query, - with_backrefs + with_backrefs, + CompositePath, ) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index d8b8a58..cbe4715 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -2,7 +2,9 @@ from collections import defaultdict import six import sqlalchemy as sa from sqlalchemy.orm import RelationshipProperty -from sqlalchemy.orm.attributes import set_committed_value +from sqlalchemy.orm.attributes import ( + set_committed_value, InstrumentedAttribute +) from sqlalchemy.orm.session import object_session @@ -11,13 +13,73 @@ class with_backrefs(object): Marks given attribute path so that whenever its fetched with batch_fetch the backref relations are force set too. """ - def __init__(self, attr_path): - self.attr_path = attr_path + def __init__(self, path): + self.path = path -class compound_path(object): - def __init__(self, *attr_paths): - self.attr_paths = attr_paths +class Path(object): + """ + A class that represents an attribute path. + """ + def __init__(self, entities, prop, populate_backrefs=False): + self.property = prop + self.entities = entities + self.populate_backrefs = populate_backrefs + if not isinstance(self.property, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' + ) + self.fetcher = self.fetcher_class(self) + + @property + def session(self): + return object_session(self.entities[0]) + + @property + def parent_model(self): + return self.entities[0].__class__ + + @property + def model(self): + return self.property.mapper.class_ + + @classmethod + def parse(cls, entities, path, populate_backrefs=False): + if isinstance(path, six.string_types): + attrs = path.split('.') + + if len(attrs) > 1: + related_entities = [] + for entity in entities: + related_entities.extend(getattr(entity, attrs[0])) + + subpath = '.'.join(attrs[1:]) + return Path.parse(related_entities, subpath, populate_backrefs) + else: + attr = getattr( + entities[0].__class__, attrs[0] + ) + elif isinstance(path, InstrumentedAttribute): + attr = path + else: + raise Exception('Unknown path type.') + + return Path(entities, attr.property, populate_backrefs) + + @property + def fetcher_class(self): + if self.property.secondary is not None: + return ManyToManyFetcher + else: + if self.property.direction.name == 'MANYTOONE': + return ManyToOneFetcher + else: + return OneToManyFetcher + + +class CompositePath(object): + def __init__(self, *paths): + self.paths = paths def batch_fetch(entities, *attr_paths): @@ -28,9 +90,9 @@ def batch_fetch(entities, *attr_paths): subqueryload and performs lot better. :param entities: list of entities of the same type - :param attr: - Either InstrumentedAttribute object or a string representing the name - of the instrumented attribute + :param attr_paths: + List of either InstrumentedAttribute objects or a strings representing + the name of the instrumented attribute Example:: @@ -81,107 +143,51 @@ def batch_fetch(entities, *attr_paths): """ if entities: - fetcher = FetchingCoordinator(entities) + fetcher = FetchingCoordinator() for attr_path in attr_paths: - fetcher(attr_path) + fetcher(entities, attr_path) class FetchingCoordinator(object): - def __init__(self, entities): - self.entities = entities - self.first = entities[0] - self.session = object_session(self.first) + def __call__(self, entities, path): + populate_backrefs = False + if isinstance(path, with_backrefs): + path = path.path + populate_backrefs = True - def parse_attr_path(self, attr_path, should_populate_backrefs): - if isinstance(attr_path, six.string_types): - attrs = attr_path.split('.') - - if len(attrs) > 1: - related_entities = [] - for entity in self.entities: - related_entities.extend(getattr(entity, attrs[0])) - - subpath = '.'.join(attrs[1:]) - - if should_populate_backrefs: - subpath = with_backrefs(subpath) - - return self.__class__(related_entities).fetcher_for_attr_path( - subpath - ) - else: - attr_path = getattr( - self.first.__class__, attrs[0] - ) - return self.fetcher_for_property(attr_path.property) - - def fetcher_for_property(self, property_): - if not isinstance(property_, RelationshipProperty): - raise Exception( - 'Given attribute is not a relationship property.' - ) - - if property_.secondary is not None: - fetcher_class = ManyToManyFetcher - else: - if property_.direction.name == 'MANYTOONE': - fetcher_class = ManyToOneFetcher - else: - fetcher_class = OneToManyFetcher - - return fetcher_class( - self.entities, - property_, - self.should_populate_backrefs - ) - - def fetcher_for_attr_path(self, attr_path): - if isinstance(attr_path, with_backrefs): - self.should_populate_backrefs = True - attr_path = attr_path.attr_path - else: - self.should_populate_backrefs = False - - return self.parse_attr_path( - attr_path, - self.should_populate_backrefs - ) - - def __call__(self, attr_path): - if isinstance(attr_path, compound_path): + if isinstance(path, CompositePath): fetchers = [] - for path in attr_path.attr_paths: - fetchers.append(self.fetcher_for_attr_path(path)) + for path in path.paths: + fetchers.append( + Path.parse(entities, path, populate_backrefs).fetcher + ) - fetcher = CompoundFetcher(*fetchers) + fetcher = CompositeFetcher(*fetchers) else: - fetcher = self.fetcher_for_attr_path(attr_path) + fetcher = Path.parse(entities, path, populate_backrefs).fetcher fetcher.fetch() fetcher.populate() -class AbstractFetcher(object): - @property - def related_entities(self): - return self.session.query(self.model).filter(self.condition) - - -class CompoundFetcher(AbstractFetcher): +class CompositeFetcher(object): def __init__(self, *fetchers): - if not all(fetchers[0].model == fetcher.model for fetcher in fetchers): + if not all( + fetchers[0].path.model == fetcher.path.model + for fetcher in fetchers + ): raise Exception( 'Each relationship property must have the same class when ' - 'using CompoundFetcher.' + 'using CompositeFetcher.' ) self.fetchers = fetchers @property def session(self): - return self.fetchers[0].session + return self.fetchers[0].path.session @property def model(self): - return self.fetchers[0].model + return self.fetchers[0].path.model @property def condition(self): @@ -189,6 +195,10 @@ class CompoundFetcher(AbstractFetcher): *[fetcher.condition for fetcher in self.fetchers] ) + @property + def related_entities(self): + return self.session.query(self.model).filter(self.condition) + def fetch(self): for entity in self.related_entities: for fetcher in self.fetchers: @@ -200,23 +210,27 @@ class CompoundFetcher(AbstractFetcher): fetcher.populate() -class Fetcher(AbstractFetcher): - def __init__(self, entities, property_, populate_backrefs=False): - self.should_populate_backrefs = populate_backrefs - self.entities = entities - self.prop = property_ - self.model = self.prop.mapper.class_ - self.first = self.entities[0] - self.session = object_session(self.first) +class Fetcher(object): + def __init__(self, path): + self.path = path + self.prop = self.path.property self.parent_dict = defaultdict(list) @property def local_values_list(self): return [ self.local_values(entity) - for entity in self.entities + for entity in self.path.entities ] + @property + def related_entities(self): + return self.path.session.query(self.path.model).filter(self.condition) + + @property + def remote_column_name(self): + return list(self.path.property.remote_side)[0].name + def local_values(self, entity): return getattr(entity, list(self.prop.local_columns)[0].name) @@ -230,7 +244,7 @@ class Fetcher(AbstractFetcher): ) for entity, parent_id in related_entities: backref_dict[self.local_values(entity)].append( - self.session.query(self.first.__class__).get(parent_id) + self.path.session.query(self.path.parent_model).get(parent_id) ) for entity, parent_id in related_entities: set_committed_value( @@ -243,23 +257,19 @@ class Fetcher(AbstractFetcher): """ Populate batch fetched entities to parent objects. """ - for entity in self.entities: + for entity in self.path.entities: set_committed_value( entity, self.prop.key, self.parent_dict[self.local_values(entity)] ) - if self.should_populate_backrefs: + if self.path.populate_backrefs: self.populate_backrefs(self.related_entities) - @property - def remote_column_name(self): - return list(self.prop.remote_side)[0].name - @property def condition(self): - return getattr(self.model, self.remote_column_name).in_( + return getattr(self.path.model, self.remote_column_name).in_( self.local_values_list ) @@ -274,15 +284,15 @@ class ManyToManyFetcher(Fetcher): for column in self.prop.remote_side: for fk in column.foreign_keys: # TODO: make this support inherited tables - if fk.column.table == self.first.__class__.__table__: + if fk.column.table == self.path.parent_model.__table__: return fk.parent.name @property def related_entities(self): return ( - self.session + self.path.session .query( - self.model, + self.path.model, getattr(self.prop.secondary.c, self.remote_column_name) ) .join( @@ -303,8 +313,8 @@ class ManyToManyFetcher(Fetcher): class ManyToOneFetcher(Fetcher): - def __init__(self, entities, property_, populate_backrefs=False): - Fetcher.__init__(self, entities, property_, populate_backrefs) + def __init__(self, path): + Fetcher.__init__(self, path) self.parent_dict = defaultdict(lambda: None) def append_entity(self, entity): diff --git a/tests/batch_fetch/test_compound_fetching.py b/tests/batch_fetch/test_compound_fetching.py index 7d7ce35..977a330 100644 --- a/tests/batch_fetch/test_compound_fetching.py +++ b/tests/batch_fetch/test_compound_fetching.py @@ -1,6 +1,6 @@ import sqlalchemy as sa from sqlalchemy_utils import batch_fetch -from sqlalchemy_utils.functions import compound_path +from sqlalchemy_utils.functions import CompositePath from tests import TestCase @@ -93,7 +93,7 @@ class TestCompoundOneToManyBatchFetching(TestCase): batch_fetch( buildings, 'business_premises', - compound_path( + CompositePath( 'equipment', 'business_premises.equipment' ) @@ -198,7 +198,7 @@ class TestCompoundManyToOneBatchFetching(TestCase): batch_fetch( buildings, 'business_premises', - compound_path( + CompositePath( 'equipment', 'business_premises.equipment' ) diff --git a/tests/batch_fetch/test_join_table_inheritance.py b/tests/batch_fetch/test_join_table_inheritance.py index 95abdec..d793d4e 100644 --- a/tests/batch_fetch/test_join_table_inheritance.py +++ b/tests/batch_fetch/test_join_table_inheritance.py @@ -3,7 +3,7 @@ from sqlalchemy_utils import batch_fetch from tests import TestCase -class TestBatchFetchAssociations(TestCase): +class TestBatchFetchJoinTableInheritedModels(TestCase): def create_models(self): class Category(self.Base): __tablename__ = 'category'