Refactored batch fetch, rewrote backref population syntax

This commit is contained in:
Konsta Vesterinen
2013-08-13 10:17:47 +03:00
parent 46e0b41d55
commit 39feea5a6f
4 changed files with 144 additions and 94 deletions

View File

@@ -7,6 +7,7 @@ from .functions import (
render_statement, render_statement,
sort_query, sort_query,
table_name, table_name,
with_backrefs
) )
from .listeners import coercion_listener from .listeners import coercion_listener
from .merge import merge, Merger from .merge import merge, Merger
@@ -49,6 +50,7 @@ __all__ = (
render_statement, render_statement,
sort_query, sort_query,
table_name, table_name,
with_backrefs,
ArrowType, ArrowType,
ColorType, ColorType,
EmailType, EmailType,

View File

@@ -6,13 +6,14 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch from .batch_fetch import batch_fetch, with_backrefs
from .sort_query import sort_query from .sort_query import sort_query
__all__ = ( __all__ = (
batch_fetch, batch_fetch,
sort_query sort_query,
with_backrefs
) )

View File

@@ -4,6 +4,11 @@ from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.session import object_session from sqlalchemy.orm.session import object_session
class with_backrefs(object):
    """
    Marker wrapper telling ``batch_fetch`` to also populate backrefs.

    Wrap an attribute path in this class and pass the wrapper to
    ``batch_fetch`` in place of the bare path: ::

        batch_fetch(clubs, with_backrefs('teams.players.user_groups'))

    :param attr_path:
        The attribute path to fetch; stored unchanged on the instance as
        ``attr_path`` and unwrapped again by the batch fetcher.
    """

    def __init__(self, attr_path):
        self.attr_path = attr_path

    def __repr__(self):
        # %-formatting keeps this usable on the Python 2 / 3 range the
        # package targets (it relies on ``six`` elsewhere).
        return '%s(%r)' % (self.__class__.__name__, self.attr_path)
def batch_fetch(entities, *attr_paths): def batch_fetch(entities, *attr_paths):
""" """
Batch fetch given relationship attribute for collection of entities. Batch fetch given relationship attribute for collection of entities.
@@ -50,121 +55,163 @@ def batch_fetch(entities, *attr_paths):
You can also force populate backrefs: :: You can also force populate backrefs: ::
from sqlalchemy_utils import with_backrefs
clubs = session.query(Club).limit(20).all() clubs = session.query(Club).limit(20).all()
batch_fetch( batch_fetch(
clubs, clubs,
'teams', 'teams',
'teams.players', 'teams.players',
'teams.players.user_groups -pb' with_backrefs('teams.players.user_groups')
) )
""" """
if entities: if entities:
first = entities[0] fetcher = BatchFetcher(entities)
parent_ids = [entity.id for entity in entities]
for attr_path in attr_paths: for attr_path in attr_paths:
parent_dict = dict((entity.id, []) for entity in entities) fetcher(attr_path)
populate_backrefs = False
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1: class BatchFetcher(object):
related_entities = [] def __init__(self, entities):
for entity in entities: self.entities = entities
related_entities.extend(getattr(entity, attrs[0])) self.first = entities[0]
self.parent_ids = [entity.id for entity in entities]
self.session = object_session(self.first)
batch_fetch( def populate_backrefs(self, related_entities):
related_entities, """
'.'.join(attrs[1:]) Populates backrefs for given related entities.
) """
continue
else:
args = attrs[-1].split(' ')
if '-pb' in args:
populate_backrefs = True
attr = getattr( backref_dict = dict(
first.__class__, args[0] (entity.id, []) for entity, parent_id in related_entities
) )
for entity, parent_id in related_entities:
backref_dict[entity.id].append(
self.session.query(self.first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity, self.prop.back_populates, backref_dict[entity.id]
)
def populate_entities(self):
    """
    Attach the batch fetched related objects to their parent entities.

    Every entity in ``self.entities`` gets the list collected for its id in
    ``self.parent_dict`` assigned to the relationship attribute named by
    ``self.prop.key`` using ``set_committed_value``. Afterwards, when
    ``self.should_populate_backrefs`` is set, the backrefs of the fetched
    related entities are populated as well.
    """
    for parent in self.entities:
        fetched = self.parent_dict[parent.id]
        set_committed_value(parent, self.prop.key, fetched)

    if not self.should_populate_backrefs:
        return

    self.populate_backrefs(self.related_entities)
def parse_attr_path(self, attr_path, should_populate_backrefs):
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in self.entities:
related_entities.extend(getattr(entity, attrs[0]))
subpath = '.'.join(attrs[1:])
if should_populate_backrefs:
subpath = with_backrefs(subpath)
batch_fetch(
related_entities,
subpath
)
return
else: else:
attr = attr_path return getattr(
self.first.__class__, attrs[0]
prop = attr.property
if not isinstance(prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
) )
else:
return attr_path
model = prop.mapper.class_ def fetch_relation_entities(self):
if len(self.prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
session = object_session(first) column_name = list(self.prop.remote_side)[0].name
if prop.secondary is None: self.related_entities = (
if len(prop.remote_side) > 1: self.session.query(self.model)
raise Exception( .filter(
'Only relationships with single remote side columns ' getattr(self.model, column_name).in_(self.parent_ids)
'are supported.' )
) )
column_name = list(prop.remote_side)[0].name for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)].append(
entity
)
related_entities = ( def fetch_association_entities(self):
session.query(model) column_name = None
.filter( for column in self.prop.remote_side:
getattr(model, column_name).in_(parent_ids) for fk in column.foreign_keys:
) # TODO: make this support inherited tables
if fk.column.table == self.first.__class__.__table__:
column_name = fk.parent.name
break
if column_name:
break
self.related_entities = (
self.session
.query(self.model, getattr(self.prop.secondary.c, column_name))
.join(
self.prop.secondary, self.prop.secondaryjoin
)
.filter(
getattr(self.prop.secondary.c, column_name).in_(
self.parent_ids
) )
)
)
for entity, parent_id in self.related_entities:
self.parent_dict[parent_id].append(
entity
)
for entity in related_entities: def __call__(self, attr_path):
parent_dict[getattr(entity, column_name)].append( self.parent_dict = dict(
entity (entity.id, []) for entity in self.entities
) )
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
else: attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
column_name = None if not attr:
for column in prop.remote_side: return
for fk in column.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == first.__class__.__table__:
column_name = fk.parent.name
break
if column_name:
break
related_entities = ( self.prop = attr.property
session if not isinstance(self.prop, RelationshipProperty):
.query(model, getattr(prop.secondary.c, column_name)) raise Exception(
.join( 'Given attribute is not a relationship property.'
prop.secondary, prop.secondaryjoin )
)
.filter(
getattr(prop.secondary.c, column_name).in_(
parent_ids
)
)
)
for entity, parent_id in related_entities:
parent_dict[parent_id].append(
entity
)
for entity in entities: self.model = self.prop.mapper.class_
set_committed_value(
entity, prop.key, parent_dict[entity.id] if self.prop.secondary is None:
) self.fetch_relation_entities()
if populate_backrefs: else:
backref_dict = dict( self.fetch_association_entities()
(entity.id, []) for entity, parent_id in related_entities self.populate_entities()
)
for entity, parent_id in related_entities:
backref_dict[entity.id].append(
session.query(first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity, prop.back_populates, backref_dict[entity.id]
)

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch from sqlalchemy_utils import batch_fetch, with_backrefs
from tests import TestCase from tests import TestCase
@@ -107,7 +107,7 @@ class TestBatchFetch(TestCase):
batch_fetch( batch_fetch(
categories, categories,
'articles', 'articles',
'articles.tags -pb', with_backrefs('articles.tags'),
) )
query_count = self.connection.query_count query_count = self.connection.query_count
tags = categories[0].articles[0].tags tags = categories[0].articles[0].tags