Refactor batch fetch

This commit is contained in:
Konsta Vesterinen
2013-12-26 02:26:03 +02:00
parent 485c810fb3
commit c911f50581
2 changed files with 64 additions and 52 deletions

View File

@@ -11,10 +11,15 @@ from sqlalchemy_utils.generic import (
GenericRelationshipProperty, class_from_table_name
)
from sqlalchemy_utils.functions.orm import (
list_local_values,
list_local_remote_exprs,
local_values,
local_column_names,
local_remote_expr,
mapfirst,
remote_column_names,
remote_values
remote_values,
remote
)
@@ -228,7 +233,7 @@ class CompositeFetcher(object):
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
if any(remote_values(entity, fetcher.prop)):
if any(remote_values(fetcher.prop, entity)):
fetcher.append_entity(entity)
def populate(self):
@@ -236,23 +241,13 @@ class CompositeFetcher(object):
fetcher.populate()
class AbstractFetcher(object):
@property
def local_values_list(self):
return [
local_values(entity, self.prop)
for entity in self.path.entities
]
class Fetcher(AbstractFetcher):
class Fetcher(object):
def __init__(self, path):
self.path = path
self.prop = self.path.property
if self.prop.uselist:
self.parent_dict = defaultdict(list)
else:
self.parent_dict = defaultdict(lambda: None)
default = list if self.prop.uselist else lambda: None
self.parent_dict = defaultdict(default)
@property
def relation_query_base(self):
@@ -267,11 +262,11 @@ class Fetcher(AbstractFetcher):
Populates backrefs for given related entities.
"""
backref_dict = dict(
(local_values(value[0], self.prop), [])
(local_values(self.prop, value[0]), [])
for value in related_entities
)
for value in related_entities:
backref_dict[local_values(value[0], self.prop)].append(
backref_dict[local_values(self.prop, value[0])].append(
self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
)
@@ -280,7 +275,7 @@ class Fetcher(AbstractFetcher):
set_committed_value(
value[0],
self.prop.back_populates,
backref_dict[local_values(value[0], self.prop)]
backref_dict[local_values(self.prop, value[0])]
)
def populate(self):
@@ -291,38 +286,24 @@ class Fetcher(AbstractFetcher):
set_committed_value(
entity,
self.prop.key,
self.parent_dict[local_values(entity, self.prop)]
self.parent_dict[local_values(self.prop, entity)]
)
if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote(self):
return self.path.model
@property
def condition(self):
names = list(remote_column_names(self.prop))
if len(names) == 1:
return getattr(self.remote, names[0]).in_(
value[0] for value in self.local_values_list
attr = getattr(remote(self.prop), names[0])
return attr.in_(
mapfirst(list_local_values(self.prop, self.path.entities))
)
elif len(names) > 1:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.remote, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote in names
]
)
)
return sa.or_(*conditions)
return sa.or_(
*list_local_remote_exprs(self.prop, self.path.entities)
)
else:
raise PathException(
'Could not obtain remote column names.'
@@ -333,7 +314,7 @@ class Fetcher(AbstractFetcher):
self.append_entity(entity)
class GenericRelationshipFetcher(AbstractFetcher):
class GenericRelationshipFetcher(object):
def __init__(self, path):
self.path = path
self.prop = self.path.property
@@ -344,7 +325,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
self.append_entity(entity)
def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)] = entity
self.parent_dict[remote_values(self.prop, entity)] = entity
def populate(self):
"""
@@ -354,7 +335,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
set_committed_value(
entity,
self.prop.key,
self.parent_dict[local_values(entity, self.prop)]
self.parent_dict[local_values(self.prop, entity)]
)
@property
@@ -380,10 +361,6 @@ class GenericRelationshipFetcher(AbstractFetcher):
class ManyToManyFetcher(Fetcher):
@property
def remote(self):
return self.prop.secondary.c
@property
def relation_query_base(self):
return (
@@ -391,7 +368,7 @@ class ManyToManyFetcher(Fetcher):
.query(
self.path.model,
*[
getattr(self.prop.secondary.c, name)
getattr(remote(self.prop), name)
for name in remote_column_names(self.prop)
]
)
@@ -409,11 +386,11 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)] = entity
self.parent_dict[remote_values(self.prop, entity)] = entity
class OneToManyFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)].append(
self.parent_dict[remote_values(self.prop, entity)].append(
entity
)

View File

@@ -1,4 +1,6 @@
from functools import partial
from funcy import first
from toolz import curry
import sqlalchemy as sa
@@ -51,14 +53,47 @@ def getattrs(obj, attrs):
return map(partial(getattr, obj), attrs)
def local_values(entity, prop):
def mapfirst(iterable):
return map(first, iterable)
@curry
def local_values(prop, entity):
return tuple(getattrs(entity, local_column_names(prop)))
def remote_values(entity, prop):
def list_local_values(prop, entities):
return map(local_values(prop), entities)
def remote_values(prop, entity):
return tuple(getattrs(entity, remote_column_names(prop)))
@curry
def local_remote_expr(prop, entity):
return sa.and_(
*[
getattr(remote(prop), r.name)
==
getattr(entity, l.name)
for l, r in prop.local_remote_pairs
if r in remote_column_names(prop)
]
)
def list_local_remote_exprs(prop, entities):
return map(local_remote_expr(prop), entities)
def remote(prop):
try:
return prop.secondary.c
except AttributeError:
return prop.mapper.class_
def local_column_names(prop):
if not hasattr(prop, 'secondary'):
yield prop._discriminator_col.key