Refactored batch fetch, rewrote backref population syntax
This commit is contained in:
@@ -7,6 +7,7 @@ from .functions import (
|
|||||||
render_statement,
|
render_statement,
|
||||||
sort_query,
|
sort_query,
|
||||||
table_name,
|
table_name,
|
||||||
|
with_backrefs
|
||||||
)
|
)
|
||||||
from .listeners import coercion_listener
|
from .listeners import coercion_listener
|
||||||
from .merge import merge, Merger
|
from .merge import merge, Merger
|
||||||
@@ -49,6 +50,7 @@ __all__ = (
|
|||||||
render_statement,
|
render_statement,
|
||||||
sort_query,
|
sort_query,
|
||||||
table_name,
|
table_name,
|
||||||
|
with_backrefs,
|
||||||
ArrowType,
|
ArrowType,
|
||||||
ColorType,
|
ColorType,
|
||||||
EmailType,
|
EmailType,
|
||||||
|
@@ -6,13 +6,14 @@ from sqlalchemy.orm import defer
|
|||||||
from sqlalchemy.orm.properties import ColumnProperty
|
from sqlalchemy.orm.properties import ColumnProperty
|
||||||
from sqlalchemy.orm.query import Query
|
from sqlalchemy.orm.query import Query
|
||||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||||
from .batch_fetch import batch_fetch
|
from .batch_fetch import batch_fetch, with_backrefs
|
||||||
from .sort_query import sort_query
|
from .sort_query import sort_query
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
batch_fetch,
|
batch_fetch,
|
||||||
sort_query
|
sort_query,
|
||||||
|
with_backrefs
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@@ -4,6 +4,11 @@ from sqlalchemy.orm.attributes import set_committed_value
|
|||||||
from sqlalchemy.orm.session import object_session
|
from sqlalchemy.orm.session import object_session
|
||||||
|
|
||||||
|
|
||||||
|
class with_backrefs(object):
|
||||||
|
def __init__(self, attr_path):
|
||||||
|
self.attr_path = attr_path
|
||||||
|
|
||||||
|
|
||||||
def batch_fetch(entities, *attr_paths):
|
def batch_fetch(entities, *attr_paths):
|
||||||
"""
|
"""
|
||||||
Batch fetch given relationship attribute for collection of entities.
|
Batch fetch given relationship attribute for collection of entities.
|
||||||
@@ -50,121 +55,163 @@ def batch_fetch(entities, *attr_paths):
|
|||||||
You can also force populate backrefs: ::
|
You can also force populate backrefs: ::
|
||||||
|
|
||||||
|
|
||||||
|
from sqlalchemy_utils import with_backrefs
|
||||||
|
|
||||||
|
|
||||||
clubs = session.query(Club).limit(20).all()
|
clubs = session.query(Club).limit(20).all()
|
||||||
|
|
||||||
batch_fetch(
|
batch_fetch(
|
||||||
clubs,
|
clubs,
|
||||||
'teams',
|
'teams',
|
||||||
'teams.players',
|
'teams.players',
|
||||||
'teams.players.user_groups -pb'
|
with_backrefs('teams.players.user_groups')
|
||||||
)
|
)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if entities:
|
if entities:
|
||||||
first = entities[0]
|
fetcher = BatchFetcher(entities)
|
||||||
parent_ids = [entity.id for entity in entities]
|
|
||||||
|
|
||||||
for attr_path in attr_paths:
|
for attr_path in attr_paths:
|
||||||
parent_dict = dict((entity.id, []) for entity in entities)
|
fetcher(attr_path)
|
||||||
populate_backrefs = False
|
|
||||||
|
|
||||||
if isinstance(attr_path, six.string_types):
|
|
||||||
attrs = attr_path.split('.')
|
|
||||||
|
|
||||||
if len(attrs) > 1:
|
class BatchFetcher(object):
|
||||||
related_entities = []
|
def __init__(self, entities):
|
||||||
for entity in entities:
|
self.entities = entities
|
||||||
related_entities.extend(getattr(entity, attrs[0]))
|
self.first = entities[0]
|
||||||
|
self.parent_ids = [entity.id for entity in entities]
|
||||||
|
self.session = object_session(self.first)
|
||||||
|
|
||||||
batch_fetch(
|
def populate_backrefs(self, related_entities):
|
||||||
related_entities,
|
"""
|
||||||
'.'.join(attrs[1:])
|
Populates backrefs for given related entities.
|
||||||
)
|
"""
|
||||||
continue
|
|
||||||
else:
|
|
||||||
args = attrs[-1].split(' ')
|
|
||||||
if '-pb' in args:
|
|
||||||
populate_backrefs = True
|
|
||||||
|
|
||||||
attr = getattr(
|
backref_dict = dict(
|
||||||
first.__class__, args[0]
|
(entity.id, []) for entity, parent_id in related_entities
|
||||||
)
|
)
|
||||||
|
for entity, parent_id in related_entities:
|
||||||
|
backref_dict[entity.id].append(
|
||||||
|
self.session.query(self.first.__class__).get(parent_id)
|
||||||
|
)
|
||||||
|
for entity, parent_id in related_entities:
|
||||||
|
set_committed_value(
|
||||||
|
entity, self.prop.back_populates, backref_dict[entity.id]
|
||||||
|
)
|
||||||
|
|
||||||
|
def populate_entities(self):
|
||||||
|
"""
|
||||||
|
Populate batch fetched entities to parent objects.
|
||||||
|
"""
|
||||||
|
for entity in self.entities:
|
||||||
|
set_committed_value(
|
||||||
|
entity,
|
||||||
|
self.prop.key,
|
||||||
|
self.parent_dict[entity.id]
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.should_populate_backrefs:
|
||||||
|
self.populate_backrefs(self.related_entities)
|
||||||
|
|
||||||
|
def parse_attr_path(self, attr_path, should_populate_backrefs):
|
||||||
|
if isinstance(attr_path, six.string_types):
|
||||||
|
attrs = attr_path.split('.')
|
||||||
|
|
||||||
|
if len(attrs) > 1:
|
||||||
|
related_entities = []
|
||||||
|
for entity in self.entities:
|
||||||
|
related_entities.extend(getattr(entity, attrs[0]))
|
||||||
|
|
||||||
|
subpath = '.'.join(attrs[1:])
|
||||||
|
|
||||||
|
if should_populate_backrefs:
|
||||||
|
subpath = with_backrefs(subpath)
|
||||||
|
|
||||||
|
batch_fetch(
|
||||||
|
related_entities,
|
||||||
|
subpath
|
||||||
|
)
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
attr = attr_path
|
return getattr(
|
||||||
|
self.first.__class__, attrs[0]
|
||||||
prop = attr.property
|
|
||||||
if not isinstance(prop, RelationshipProperty):
|
|
||||||
raise Exception(
|
|
||||||
'Given attribute is not a relationship property.'
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return attr_path
|
||||||
|
|
||||||
model = prop.mapper.class_
|
def fetch_relation_entities(self):
|
||||||
|
if len(self.prop.remote_side) > 1:
|
||||||
|
raise Exception(
|
||||||
|
'Only relationships with single remote side columns '
|
||||||
|
'are supported.'
|
||||||
|
)
|
||||||
|
|
||||||
session = object_session(first)
|
column_name = list(self.prop.remote_side)[0].name
|
||||||
|
|
||||||
if prop.secondary is None:
|
self.related_entities = (
|
||||||
if len(prop.remote_side) > 1:
|
self.session.query(self.model)
|
||||||
raise Exception(
|
.filter(
|
||||||
'Only relationships with single remote side columns '
|
getattr(self.model, column_name).in_(self.parent_ids)
|
||||||
'are supported.'
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
column_name = list(prop.remote_side)[0].name
|
for entity in self.related_entities:
|
||||||
|
self.parent_dict[getattr(entity, column_name)].append(
|
||||||
|
entity
|
||||||
|
)
|
||||||
|
|
||||||
related_entities = (
|
def fetch_association_entities(self):
|
||||||
session.query(model)
|
column_name = None
|
||||||
.filter(
|
for column in self.prop.remote_side:
|
||||||
getattr(model, column_name).in_(parent_ids)
|
for fk in column.foreign_keys:
|
||||||
)
|
# TODO: make this support inherited tables
|
||||||
|
if fk.column.table == self.first.__class__.__table__:
|
||||||
|
column_name = fk.parent.name
|
||||||
|
break
|
||||||
|
if column_name:
|
||||||
|
break
|
||||||
|
|
||||||
|
self.related_entities = (
|
||||||
|
self.session
|
||||||
|
.query(self.model, getattr(self.prop.secondary.c, column_name))
|
||||||
|
.join(
|
||||||
|
self.prop.secondary, self.prop.secondaryjoin
|
||||||
|
)
|
||||||
|
.filter(
|
||||||
|
getattr(self.prop.secondary.c, column_name).in_(
|
||||||
|
self.parent_ids
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for entity, parent_id in self.related_entities:
|
||||||
|
self.parent_dict[parent_id].append(
|
||||||
|
entity
|
||||||
|
)
|
||||||
|
|
||||||
for entity in related_entities:
|
def __call__(self, attr_path):
|
||||||
parent_dict[getattr(entity, column_name)].append(
|
self.parent_dict = dict(
|
||||||
entity
|
(entity.id, []) for entity in self.entities
|
||||||
)
|
)
|
||||||
|
if isinstance(attr_path, with_backrefs):
|
||||||
|
self.should_populate_backrefs = True
|
||||||
|
attr_path = attr_path.attr_path
|
||||||
|
else:
|
||||||
|
self.should_populate_backrefs = False
|
||||||
|
|
||||||
else:
|
attr = self.parse_attr_path(attr_path, self.should_populate_backrefs)
|
||||||
column_name = None
|
if not attr:
|
||||||
for column in prop.remote_side:
|
return
|
||||||
for fk in column.foreign_keys:
|
|
||||||
# TODO: make this support inherited tables
|
|
||||||
if fk.column.table == first.__class__.__table__:
|
|
||||||
column_name = fk.parent.name
|
|
||||||
break
|
|
||||||
if column_name:
|
|
||||||
break
|
|
||||||
|
|
||||||
related_entities = (
|
self.prop = attr.property
|
||||||
session
|
if not isinstance(self.prop, RelationshipProperty):
|
||||||
.query(model, getattr(prop.secondary.c, column_name))
|
raise Exception(
|
||||||
.join(
|
'Given attribute is not a relationship property.'
|
||||||
prop.secondary, prop.secondaryjoin
|
)
|
||||||
)
|
|
||||||
.filter(
|
|
||||||
getattr(prop.secondary.c, column_name).in_(
|
|
||||||
parent_ids
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for entity, parent_id in related_entities:
|
|
||||||
parent_dict[parent_id].append(
|
|
||||||
entity
|
|
||||||
)
|
|
||||||
|
|
||||||
for entity in entities:
|
self.model = self.prop.mapper.class_
|
||||||
set_committed_value(
|
|
||||||
entity, prop.key, parent_dict[entity.id]
|
if self.prop.secondary is None:
|
||||||
)
|
self.fetch_relation_entities()
|
||||||
if populate_backrefs:
|
else:
|
||||||
backref_dict = dict(
|
self.fetch_association_entities()
|
||||||
(entity.id, []) for entity, parent_id in related_entities
|
self.populate_entities()
|
||||||
)
|
|
||||||
for entity, parent_id in related_entities:
|
|
||||||
backref_dict[entity.id].append(
|
|
||||||
session.query(first.__class__).get(parent_id)
|
|
||||||
)
|
|
||||||
for entity, parent_id in related_entities:
|
|
||||||
set_committed_value(
|
|
||||||
entity, prop.back_populates, backref_dict[entity.id]
|
|
||||||
)
|
|
||||||
|
@@ -1,5 +1,5 @@
|
|||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy_utils import batch_fetch
|
from sqlalchemy_utils import batch_fetch, with_backrefs
|
||||||
from tests import TestCase
|
from tests import TestCase
|
||||||
|
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ class TestBatchFetch(TestCase):
|
|||||||
batch_fetch(
|
batch_fetch(
|
||||||
categories,
|
categories,
|
||||||
'articles',
|
'articles',
|
||||||
'articles.tags -pb',
|
with_backrefs('articles.tags'),
|
||||||
)
|
)
|
||||||
query_count = self.connection.query_count
|
query_count = self.connection.query_count
|
||||||
tags = categories[0].articles[0].tags
|
tags = categories[0].articles[0].tags
|
||||||
|
Reference in New Issue
Block a user