Refactor batch fetch

This commit is contained in:
Konsta Vesterinen
2013-12-26 02:26:03 +02:00
parent 485c810fb3
commit c911f50581
2 changed files with 64 additions and 52 deletions

View File

@@ -11,10 +11,15 @@ from sqlalchemy_utils.generic import (
GenericRelationshipProperty, class_from_table_name GenericRelationshipProperty, class_from_table_name
) )
from sqlalchemy_utils.functions.orm import ( from sqlalchemy_utils.functions.orm import (
list_local_values,
list_local_remote_exprs,
local_values, local_values,
local_column_names, local_column_names,
local_remote_expr,
mapfirst,
remote_column_names, remote_column_names,
remote_values remote_values,
remote
) )
@@ -228,7 +233,7 @@ class CompositeFetcher(object):
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
for fetcher in self.fetchers: for fetcher in self.fetchers:
if any(remote_values(entity, fetcher.prop)): if any(remote_values(fetcher.prop, entity)):
fetcher.append_entity(entity) fetcher.append_entity(entity)
def populate(self): def populate(self):
@@ -236,23 +241,13 @@ class CompositeFetcher(object):
fetcher.populate() fetcher.populate()
class AbstractFetcher(object): class Fetcher(object):
@property
def local_values_list(self):
return [
local_values(entity, self.prop)
for entity in self.path.entities
]
class Fetcher(AbstractFetcher):
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
self.prop = self.path.property self.prop = self.path.property
if self.prop.uselist:
self.parent_dict = defaultdict(list) default = list if self.prop.uselist else lambda: None
else: self.parent_dict = defaultdict(default)
self.parent_dict = defaultdict(lambda: None)
@property @property
def relation_query_base(self): def relation_query_base(self):
@@ -267,11 +262,11 @@ class Fetcher(AbstractFetcher):
Populates backrefs for given related entities. Populates backrefs for given related entities.
""" """
backref_dict = dict( backref_dict = dict(
(local_values(value[0], self.prop), []) (local_values(self.prop, value[0]), [])
for value in related_entities for value in related_entities
) )
for value in related_entities: for value in related_entities:
backref_dict[local_values(value[0], self.prop)].append( backref_dict[local_values(self.prop, value[0])].append(
self.path.session.query(self.path.parent_model).get( self.path.session.query(self.path.parent_model).get(
tuple(value[1:]) tuple(value[1:])
) )
@@ -280,7 +275,7 @@ class Fetcher(AbstractFetcher):
set_committed_value( set_committed_value(
value[0], value[0],
self.prop.back_populates, self.prop.back_populates,
backref_dict[local_values(value[0], self.prop)] backref_dict[local_values(self.prop, value[0])]
) )
def populate(self): def populate(self):
@@ -291,38 +286,24 @@ class Fetcher(AbstractFetcher):
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
self.parent_dict[local_values(entity, self.prop)] self.parent_dict[local_values(self.prop, entity)]
) )
if self.path.populate_backrefs: if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities) self.populate_backrefs(self.related_entities)
@property
def remote(self):
return self.path.model
@property @property
def condition(self): def condition(self):
names = list(remote_column_names(self.prop)) names = list(remote_column_names(self.prop))
if len(names) == 1: if len(names) == 1:
return getattr(self.remote, names[0]).in_( attr = getattr(remote(self.prop), names[0])
value[0] for value in self.local_values_list return attr.in_(
mapfirst(list_local_values(self.prop, self.path.entities))
) )
elif len(names) > 1: elif len(names) > 1:
conditions = [] return sa.or_(
for entity in self.path.entities: *list_local_remote_exprs(self.prop, self.path.entities)
conditions.append( )
sa.and_(
*[
getattr(self.remote, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote in names
]
)
)
return sa.or_(*conditions)
else: else:
raise PathException( raise PathException(
'Could not obtain remote column names.' 'Could not obtain remote column names.'
@@ -333,7 +314,7 @@ class Fetcher(AbstractFetcher):
self.append_entity(entity) self.append_entity(entity)
class GenericRelationshipFetcher(AbstractFetcher): class GenericRelationshipFetcher(object):
def __init__(self, path): def __init__(self, path):
self.path = path self.path = path
self.prop = self.path.property self.prop = self.path.property
@@ -344,7 +325,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
self.append_entity(entity) self.append_entity(entity)
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)] = entity self.parent_dict[remote_values(self.prop, entity)] = entity
def populate(self): def populate(self):
""" """
@@ -354,7 +335,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
self.parent_dict[local_values(entity, self.prop)] self.parent_dict[local_values(self.prop, entity)]
) )
@property @property
@@ -380,10 +361,6 @@ class GenericRelationshipFetcher(AbstractFetcher):
class ManyToManyFetcher(Fetcher): class ManyToManyFetcher(Fetcher):
@property
def remote(self):
return self.prop.secondary.c
@property @property
def relation_query_base(self): def relation_query_base(self):
return ( return (
@@ -391,7 +368,7 @@ class ManyToManyFetcher(Fetcher):
.query( .query(
self.path.model, self.path.model,
*[ *[
getattr(self.prop.secondary.c, name) getattr(remote(self.prop), name)
for name in remote_column_names(self.prop) for name in remote_column_names(self.prop)
] ]
) )
@@ -409,11 +386,11 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher): class ManyToOneFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)] = entity self.parent_dict[remote_values(self.prop, entity)] = entity
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[remote_values(entity, self.prop)].append( self.parent_dict[remote_values(self.prop, entity)].append(
entity entity
) )

View File

@@ -1,4 +1,6 @@
from functools import partial from functools import partial
from funcy import first
from toolz import curry
import sqlalchemy as sa import sqlalchemy as sa
@@ -51,14 +53,47 @@ def getattrs(obj, attrs):
return map(partial(getattr, obj), attrs) return map(partial(getattr, obj), attrs)
def local_values(entity, prop): def mapfirst(iterable):
return map(first, iterable)
@curry
def local_values(prop, entity):
return tuple(getattrs(entity, local_column_names(prop))) return tuple(getattrs(entity, local_column_names(prop)))
def remote_values(entity, prop): def list_local_values(prop, entities):
return map(local_values(prop), entities)
def remote_values(prop, entity):
return tuple(getattrs(entity, remote_column_names(prop))) return tuple(getattrs(entity, remote_column_names(prop)))
@curry
def local_remote_expr(prop, entity):
return sa.and_(
*[
getattr(remote(prop), r.name)
==
getattr(entity, l.name)
for l, r in prop.local_remote_pairs
if r in remote_column_names(prop)
]
)
def list_local_remote_exprs(prop, entities):
return map(local_remote_expr(prop), entities)
def remote(prop):
try:
return prop.secondary.c
except AttributeError:
return prop.mapper.class_
def local_column_names(prop): def local_column_names(prop):
if not hasattr(prop, 'secondary'): if not hasattr(prop, 'secondary'):
yield prop._discriminator_col.key yield prop._discriminator_col.key