Refactored batch fetch, rewrote backref population syntax

This commit is contained in:
Konsta Vesterinen
2013-08-13 10:17:47 +03:00
parent 46e0b41d55
commit 39feea5a6f
4 changed files with 144 additions and 94 deletions

View File

@@ -7,6 +7,7 @@ from .functions import (
render_statement,
sort_query,
table_name,
with_backrefs
)
from .listeners import coercion_listener
from .merge import merge, Merger
@@ -49,6 +50,7 @@ __all__ = (
render_statement,
sort_query,
table_name,
with_backrefs,
ArrowType,
ColorType,
EmailType,

View File

@@ -6,13 +6,14 @@ from sqlalchemy.orm import defer
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.query import Query
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch
from .batch_fetch import batch_fetch, with_backrefs
from .sort_query import sort_query
__all__ = (
batch_fetch,
sort_query
sort_query,
with_backrefs
)

View File

@@ -4,6 +4,11 @@ from sqlalchemy.orm.attributes import set_committed_value
from sqlalchemy.orm.session import object_session
class with_backrefs(object):
    """
    Marker that wraps an attribute path whose backrefs should be populated.

    Pass an instance to ``batch_fetch`` in place of a plain attribute path.
    ``BatchFetcher.__call__`` checks for this type, unwraps ``attr_path`` and
    turns on backref population for that path.

    :param attr_path: the attribute path to wrap; stored unchanged
    """
    def __init__(self, attr_path):
        self.attr_path = attr_path

    def __repr__(self):
        # %-formatting keeps this usable on the Python 2 / six code base.
        return '%s(%r)' % (self.__class__.__name__, self.attr_path)
def batch_fetch(entities, *attr_paths):
"""
Batch fetch given relationship attribute for collection of entities.
@@ -50,121 +55,163 @@ def batch_fetch(entities, *attr_paths):
You can also force populate backrefs: ::
from sqlalchemy_utils import with_backrefs
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
'teams.players.user_groups -pb'
with_backrefs('teams.players.user_groups')
)
"""
if entities:
first = entities[0]
parent_ids = [entity.id for entity in entities]
fetcher = BatchFetcher(entities)
for attr_path in attr_paths:
parent_dict = dict((entity.id, []) for entity in entities)
populate_backrefs = False
fetcher(attr_path)
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
class BatchFetcher(object):
def __init__(self, entities):
    """
    Bind the fetcher to a collection of parent entities.

    Stores the collection itself, its first element (``self.first``), the
    ``id`` of every element (``self.parent_ids``) and the session that
    ``object_session`` reports for the first element (``self.session``).

    :param entities: indexable, iterable collection of objects exposing an
        ``id`` attribute; indexing ``[0]`` means it must not be empty
    """
    self.entities = entities
    head = entities[0]
    self.first = head
    self.parent_ids = [item.id for item in entities]
    # The session is looked up from the first entity only.
    self.session = object_session(head)
batch_fetch(
related_entities,
'.'.join(attrs[1:])
)
continue
else:
args = attrs[-1].split(' ')
if '-pb' in args:
populate_backrefs = True
def populate_backrefs(self, related_entities):
"""
Populates backrefs for given related entities.
"""
attr = getattr(
first.__class__, args[0]
)
backref_dict = dict(
(entity.id, []) for entity, parent_id in related_entities
)
for entity, parent_id in related_entities:
backref_dict[entity.id].append(
self.session.query(self.first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity, self.prop.back_populates, backref_dict[entity.id]
)
def populate_entities(self):
    """
    Assign the batch fetched entities to their parent objects.

    For every parent in ``self.entities`` the list stored under its ``id``
    in ``self.parent_dict`` is written to the relationship attribute
    (``self.prop.key``) with ``set_committed_value``. If
    ``self.should_populate_backrefs`` is set, ``self.related_entities`` is
    then handed to ``populate_backrefs``.
    """
    for parent in self.entities:
        attr_name = self.prop.key
        children = self.parent_dict[parent.id]
        set_committed_value(parent, attr_name, children)

    if not self.should_populate_backrefs:
        return
    self.populate_backrefs(self.related_entities)
def parse_attr_path(self, attr_path, should_populate_backrefs):
if isinstance(attr_path, six.string_types):
attrs = attr_path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in self.entities:
related_entities.extend(getattr(entity, attrs[0]))
subpath = '.'.join(attrs[1:])
if should_populate_backrefs:
subpath = with_backrefs(subpath)
batch_fetch(
related_entities,
subpath
)
return
else:
attr = attr_path
prop = attr.property
if not isinstance(prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
return getattr(
self.first.__class__, attrs[0]
)
else:
return attr_path
model = prop.mapper.class_
def fetch_relation_entities(self):
if len(self.prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
session = object_session(first)
column_name = list(self.prop.remote_side)[0].name
if prop.secondary is None:
if len(prop.remote_side) > 1:
raise Exception(
'Only relationships with single remote side columns '
'are supported.'
)
self.related_entities = (
self.session.query(self.model)
.filter(
getattr(self.model, column_name).in_(self.parent_ids)
)
)
column_name = list(prop.remote_side)[0].name
for entity in self.related_entities:
self.parent_dict[getattr(entity, column_name)].append(
entity
)
related_entities = (
session.query(model)
.filter(
getattr(model, column_name).in_(parent_ids)
)
def fetch_association_entities(self):
    """
    Fetch related entities through an association (secondary) table.

    Picks the first column in ``self.prop.remote_side`` that has a foreign
    key pointing at the table of the parent class (``self.first.__class__``),
    queries ``self.model`` together with the same-named column of
    ``self.prop.secondary`` for all ``self.parent_ids``, and appends each
    fetched entity to ``self.parent_dict`` under its parent id.

    The fetched ``(entity, parent_id)`` rows are stored as a list in
    ``self.related_entities``.

    :raises Exception: if no remote side column has a foreign key that
        references the parent class's table
    """
    # TODO: make this support inherited tables
    column_name = next(
        (
            fk.parent.name
            for column in self.prop.remote_side
            for fk in column.foreign_keys
            if fk.column.table == self.first.__class__.__table__
        ),
        None
    )
    if column_name is None:
        # Fail with a clear message instead of getattr(..., None) blowing up.
        raise Exception(
            'Could not find a foreign key column referencing the parent '
            'table from the remote side of the relationship.'
        )

    parent_column = getattr(self.prop.secondary.c, column_name)
    # Materialise the rows once: the result is iterated here and again when
    # backrefs are populated, and iterating a Query re-runs the SQL each time.
    self.related_entities = (
        self.session
        .query(self.model, parent_column)
        .join(self.prop.secondary, self.prop.secondaryjoin)
        .filter(parent_column.in_(self.parent_ids))
        .all()
    )
    for entity, parent_id in self.related_entities:
        self.parent_dict[parent_id].append(entity)
for entity in related_entities:
parent_dict[getattr(entity, column_name)].append(
entity
)
def __call__(self, attr_path):
self.parent_dict = dict(
(entity.id, []) for entity in self.entities
)
if isinstance(attr_path, with_backrefs):
self.should_populate_backrefs = True
attr_path = attr_path.attr_path
else:
self.should_populate_backrefs = False
else:
column_name = None
for column in prop.remote_side:
for fk in column.foreign_keys:
# TODO: make this support inherited tables
if fk.column.table == first.__class__.__table__:
column_name = fk.parent.name
break
if column_name:
break
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
if not attr:
return
related_entities = (
session
.query(model, getattr(prop.secondary.c, column_name))
.join(
prop.secondary, prop.secondaryjoin
)
.filter(
getattr(prop.secondary.c, column_name).in_(
parent_ids
)
)
)
for entity, parent_id in related_entities:
parent_dict[parent_id].append(
entity
)
self.prop = attr.property
if not isinstance(self.prop, RelationshipProperty):
raise Exception(
'Given attribute is not a relationship property.'
)
for entity in entities:
set_committed_value(
entity, prop.key, parent_dict[entity.id]
)
if populate_backrefs:
backref_dict = dict(
(entity.id, []) for entity, parent_id in related_entities
)
for entity, parent_id in related_entities:
backref_dict[entity.id].append(
session.query(first.__class__).get(parent_id)
)
for entity, parent_id in related_entities:
set_committed_value(
entity, prop.back_populates, backref_dict[entity.id]
)
self.model = self.prop.mapper.class_
if self.prop.secondary is None:
self.fetch_relation_entities()
else:
self.fetch_association_entities()
self.populate_entities()

View File

@@ -1,5 +1,5 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils import batch_fetch, with_backrefs
from tests import TestCase
@@ -107,7 +107,7 @@ class TestBatchFetch(TestCase):
batch_fetch(
categories,
'articles',
'articles.tags -pb',
with_backrefs('articles.tags'),
)
query_count = self.connection.query_count
tags = categories[0].articles[0].tags