From 39feea5a6f5665828c321cb2c6a5bd7df7c59c90 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Tue, 13 Aug 2013 10:17:47 +0300 Subject: [PATCH] Refactored batch fetch, rewrote backref population syntax --- sqlalchemy_utils/__init__.py | 2 + sqlalchemy_utils/functions/__init__.py | 5 +- sqlalchemy_utils/functions/batch_fetch.py | 227 +++++++++++-------- tests/batch_fetch/test_deep_relationships.py | 4 +- 4 files changed, 144 insertions(+), 94 deletions(-) diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index ea0259f..1c0727a 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -7,6 +7,7 @@ from .functions import ( render_statement, sort_query, table_name, + with_backrefs ) from .listeners import coercion_listener from .merge import merge, Merger @@ -49,6 +50,7 @@ __all__ = ( render_statement, sort_query, table_name, + with_backrefs, ArrowType, ColorType, EmailType, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 9f1429b..1bee38e 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -6,13 +6,14 @@ from sqlalchemy.orm import defer from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.query import Query from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint -from .batch_fetch import batch_fetch +from .batch_fetch import batch_fetch, with_backrefs from .sort_query import sort_query __all__ = ( batch_fetch, - sort_query + sort_query, + with_backrefs ) diff --git a/sqlalchemy_utils/functions/batch_fetch.py b/sqlalchemy_utils/functions/batch_fetch.py index 41a84e9..2e99d62 100644 --- a/sqlalchemy_utils/functions/batch_fetch.py +++ b/sqlalchemy_utils/functions/batch_fetch.py @@ -4,6 +4,11 @@ from sqlalchemy.orm.attributes import set_committed_value from sqlalchemy.orm.session import object_session +class with_backrefs(object): + def __init__(self, attr_path): + self.attr_path = attr_path + + def batch_fetch(entities, *attr_paths): """ Batch fetch given relationship attribute for collection of entities. @@ -50,121 +55,163 @@ def batch_fetch(entities, *attr_paths): You can also force populate backrefs: :: + from sqlalchemy_utils import with_backrefs + + clubs = session.query(Club).limit(20).all() batch_fetch( clubs, 'teams', 'teams.players', - 'teams.players.user_groups -pb' + with_backrefs('teams.players.user_groups') ) """ if entities: - first = entities[0] - parent_ids = [entity.id for entity in entities] - + fetcher = BatchFetcher(entities) for attr_path in attr_paths: - parent_dict = dict((entity.id, []) for entity in entities) - populate_backrefs = False + fetcher(attr_path) - if isinstance(attr_path, six.string_types): - attrs = attr_path.split('.') - if len(attrs) > 1: - related_entities = [] - for entity in entities: - related_entities.extend(getattr(entity, attrs[0])) +class BatchFetcher(object): + def __init__(self, entities): + self.entities = entities + self.first = entities[0] + self.parent_ids = [entity.id for entity in entities] + self.session = object_session(self.first) - batch_fetch( - related_entities, - '.'.join(attrs[1:]) - ) - continue - else: - args = attrs[-1].split(' ') - if '-pb' in args: - populate_backrefs = True + def populate_backrefs(self, related_entities): + """ + Populates backrefs for given related entities. + """ - attr = getattr( - first.__class__, args[0] - ) + backref_dict = dict( + (entity.id, []) for entity, parent_id in related_entities + ) + for entity, parent_id in related_entities: + backref_dict[entity.id].append( + self.session.query(self.first.__class__).get(parent_id) + ) + for entity, parent_id in related_entities: + set_committed_value( + entity, self.prop.back_populates, backref_dict[entity.id] + ) + + def populate_entities(self): + """ + Populate batch fetched entities to parent objects. + """ + for entity in self.entities: + set_committed_value( + entity, + self.prop.key, + self.parent_dict[entity.id] + ) + + if self.should_populate_backrefs: + self.populate_backrefs(self.related_entities) + + def parse_attr_path(self, attr_path, should_populate_backrefs): + if isinstance(attr_path, six.string_types): + attrs = attr_path.split('.') + + if len(attrs) > 1: + related_entities = [] + for entity in self.entities: + related_entities.extend(getattr(entity, attrs[0])) + + subpath = '.'.join(attrs[1:]) + + if should_populate_backrefs: + subpath = with_backrefs(subpath) + + batch_fetch( + related_entities, + subpath + ) + return else: - attr = attr_path - - prop = attr.property - if not isinstance(prop, RelationshipProperty): - raise Exception( - 'Given attribute is not a relationship property.' + return getattr( + self.first.__class__, attrs[0] ) + else: + return attr_path - model = prop.mapper.class_ + def fetch_relation_entities(self): + if len(self.prop.remote_side) > 1: + raise Exception( + 'Only relationships with single remote side columns ' + 'are supported.' + ) - session = object_session(first) + column_name = list(self.prop.remote_side)[0].name - if prop.secondary is None: - if len(prop.remote_side) > 1: - raise Exception( - 'Only relationships with single remote side columns ' - 'are supported.' - ) + self.related_entities = ( + self.session.query(self.model) + .filter( + getattr(self.model, column_name).in_(self.parent_ids) + ) + ) - column_name = list(prop.remote_side)[0].name + for entity in self.related_entities: + self.parent_dict[getattr(entity, column_name)].append( + entity + ) - related_entities = ( - session.query(model) - .filter( - getattr(model, column_name).in_(parent_ids) - ) + def fetch_association_entities(self): + column_name = None + for column in self.prop.remote_side: + for fk in column.foreign_keys: + # TODO: make this support inherited tables + if fk.column.table == self.first.__class__.__table__: + column_name = fk.parent.name + break + if column_name: + break + + self.related_entities = ( + self.session + .query(self.model, getattr(self.prop.secondary.c, column_name)) + .join( + self.prop.secondary, self.prop.secondaryjoin + ) + .filter( + getattr(self.prop.secondary.c, column_name).in_( + self.parent_ids ) + ) + ) + for entity, parent_id in self.related_entities: + self.parent_dict[parent_id].append( + entity + ) - for entity in related_entities: - parent_dict[getattr(entity, column_name)].append( - entity - ) + def __call__(self, attr_path): + self.parent_dict = dict( + (entity.id, []) for entity in self.entities + ) + if isinstance(attr_path, with_backrefs): + self.should_populate_backrefs = True + attr_path = attr_path.attr_path + else: + self.should_populate_backrefs = False - else: - column_name = None - for column in prop.remote_side: - for fk in column.foreign_keys: - # TODO: make this support inherited tables - if fk.column.table == first.__class__.__table__: - column_name = fk.parent.name - break - if column_name: - break + attr = self.parse_attr_path(attr_path, self.should_populate_backrefs) + if not attr: + return - related_entities = ( - session - .query(model, getattr(prop.secondary.c, column_name)) - .join( - prop.secondary, prop.secondaryjoin - ) - .filter( - getattr(prop.secondary.c, column_name).in_( - parent_ids - ) - ) - ) - for entity, parent_id in related_entities: - parent_dict[parent_id].append( - entity - ) + self.prop = attr.property + if not isinstance(self.prop, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' + ) - for entity in entities: - set_committed_value( - entity, prop.key, parent_dict[entity.id] - ) - if populate_backrefs: - backref_dict = dict( - (entity.id, []) for entity, parent_id in related_entities - ) - for entity, parent_id in related_entities: - backref_dict[entity.id].append( - session.query(first.__class__).get(parent_id) - ) - for entity, parent_id in related_entities: - set_committed_value( - entity, prop.back_populates, backref_dict[entity.id] - ) + self.model = self.prop.mapper.class_ + + if self.prop.secondary is None: + self.fetch_relation_entities() + else: + self.fetch_association_entities() + self.populate_entities() diff --git a/tests/batch_fetch/test_deep_relationships.py b/tests/batch_fetch/test_deep_relationships.py index e4fa619..b837c29 100644 --- a/tests/batch_fetch/test_deep_relationships.py +++ b/tests/batch_fetch/test_deep_relationships.py @@ -1,5 +1,5 @@ import sqlalchemy as sa -from sqlalchemy_utils import batch_fetch +from sqlalchemy_utils import batch_fetch, with_backrefs from tests import TestCase @@ -107,7 +107,7 @@ class TestBatchFetch(TestCase): batch_fetch( categories, 'articles', - 'articles.tags -pb', + with_backrefs('articles.tags'), ) query_count = self.connection.query_count tags = categories[0].articles[0].tags