Refactor batch fetch
This commit is contained in:
@@ -11,10 +11,15 @@ from sqlalchemy_utils.generic import (
|
||||
GenericRelationshipProperty, class_from_table_name
|
||||
)
|
||||
from sqlalchemy_utils.functions.orm import (
|
||||
list_local_values,
|
||||
list_local_remote_exprs,
|
||||
local_values,
|
||||
local_column_names,
|
||||
local_remote_expr,
|
||||
mapfirst,
|
||||
remote_column_names,
|
||||
remote_values
|
||||
remote_values,
|
||||
remote
|
||||
)
|
||||
|
||||
|
||||
@@ -228,7 +233,7 @@ class CompositeFetcher(object):
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
for fetcher in self.fetchers:
|
||||
if any(remote_values(entity, fetcher.prop)):
|
||||
if any(remote_values(fetcher.prop, entity)):
|
||||
fetcher.append_entity(entity)
|
||||
|
||||
def populate(self):
|
||||
@@ -236,23 +241,13 @@ class CompositeFetcher(object):
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
class AbstractFetcher(object):
|
||||
@property
|
||||
def local_values_list(self):
|
||||
return [
|
||||
local_values(entity, self.prop)
|
||||
for entity in self.path.entities
|
||||
]
|
||||
|
||||
|
||||
class Fetcher(AbstractFetcher):
|
||||
class Fetcher(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.prop = self.path.property
|
||||
if self.prop.uselist:
|
||||
self.parent_dict = defaultdict(list)
|
||||
else:
|
||||
self.parent_dict = defaultdict(lambda: None)
|
||||
|
||||
default = list if self.prop.uselist else lambda: None
|
||||
self.parent_dict = defaultdict(default)
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
@@ -267,11 +262,11 @@ class Fetcher(AbstractFetcher):
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
backref_dict = dict(
|
||||
(local_values(value[0], self.prop), [])
|
||||
(local_values(self.prop, value[0]), [])
|
||||
for value in related_entities
|
||||
)
|
||||
for value in related_entities:
|
||||
backref_dict[local_values(value[0], self.prop)].append(
|
||||
backref_dict[local_values(self.prop, value[0])].append(
|
||||
self.path.session.query(self.path.parent_model).get(
|
||||
tuple(value[1:])
|
||||
)
|
||||
@@ -280,7 +275,7 @@ class Fetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
value[0],
|
||||
self.prop.back_populates,
|
||||
backref_dict[local_values(value[0], self.prop)]
|
||||
backref_dict[local_values(self.prop, value[0])]
|
||||
)
|
||||
|
||||
def populate(self):
|
||||
@@ -291,38 +286,24 @@ class Fetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[local_values(entity, self.prop)]
|
||||
self.parent_dict[local_values(self.prop, entity)]
|
||||
)
|
||||
|
||||
if self.path.populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
@property
|
||||
def remote(self):
|
||||
return self.path.model
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
names = list(remote_column_names(self.prop))
|
||||
if len(names) == 1:
|
||||
return getattr(self.remote, names[0]).in_(
|
||||
value[0] for value in self.local_values_list
|
||||
attr = getattr(remote(self.prop), names[0])
|
||||
return attr.in_(
|
||||
mapfirst(list_local_values(self.prop, self.path.entities))
|
||||
)
|
||||
elif len(names) > 1:
|
||||
conditions = []
|
||||
for entity in self.path.entities:
|
||||
conditions.append(
|
||||
sa.and_(
|
||||
*[
|
||||
getattr(self.remote, remote.name)
|
||||
==
|
||||
getattr(entity, local.name)
|
||||
for local, remote in self.prop.local_remote_pairs
|
||||
if remote in names
|
||||
]
|
||||
)
|
||||
)
|
||||
return sa.or_(*conditions)
|
||||
return sa.or_(
|
||||
*list_local_remote_exprs(self.prop, self.path.entities)
|
||||
)
|
||||
else:
|
||||
raise PathException(
|
||||
'Could not obtain remote column names.'
|
||||
@@ -333,7 +314,7 @@ class Fetcher(AbstractFetcher):
|
||||
self.append_entity(entity)
|
||||
|
||||
|
||||
class GenericRelationshipFetcher(AbstractFetcher):
|
||||
class GenericRelationshipFetcher(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.prop = self.path.property
|
||||
@@ -344,7 +325,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
||||
self.append_entity(entity)
|
||||
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||
self.parent_dict[remote_values(self.prop, entity)] = entity
|
||||
|
||||
def populate(self):
|
||||
"""
|
||||
@@ -354,7 +335,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[local_values(entity, self.prop)]
|
||||
self.parent_dict[local_values(self.prop, entity)]
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -380,10 +361,6 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
@property
|
||||
def remote(self):
|
||||
return self.prop.secondary.c
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return (
|
||||
@@ -391,7 +368,7 @@ class ManyToManyFetcher(Fetcher):
|
||||
.query(
|
||||
self.path.model,
|
||||
*[
|
||||
getattr(self.prop.secondary.c, name)
|
||||
getattr(remote(self.prop), name)
|
||||
for name in remote_column_names(self.prop)
|
||||
]
|
||||
)
|
||||
@@ -409,11 +386,11 @@ class ManyToManyFetcher(Fetcher):
|
||||
|
||||
class ManyToOneFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||
self.parent_dict[remote_values(self.prop, entity)] = entity
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[remote_values(entity, self.prop)].append(
|
||||
self.parent_dict[remote_values(self.prop, entity)].append(
|
||||
entity
|
||||
)
|
||||
|
@@ -1,4 +1,6 @@
|
||||
from functools import partial
|
||||
from funcy import first
|
||||
from toolz import curry
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -51,14 +53,47 @@ def getattrs(obj, attrs):
|
||||
return map(partial(getattr, obj), attrs)
|
||||
|
||||
|
||||
def local_values(entity, prop):
|
||||
def mapfirst(iterable):
|
||||
return map(first, iterable)
|
||||
|
||||
|
||||
@curry
|
||||
def local_values(prop, entity):
|
||||
return tuple(getattrs(entity, local_column_names(prop)))
|
||||
|
||||
|
||||
def remote_values(entity, prop):
|
||||
def list_local_values(prop, entities):
|
||||
return map(local_values(prop), entities)
|
||||
|
||||
|
||||
def remote_values(prop, entity):
|
||||
return tuple(getattrs(entity, remote_column_names(prop)))
|
||||
|
||||
|
||||
@curry
|
||||
def local_remote_expr(prop, entity):
|
||||
return sa.and_(
|
||||
*[
|
||||
getattr(remote(prop), r.name)
|
||||
==
|
||||
getattr(entity, l.name)
|
||||
for l, r in prop.local_remote_pairs
|
||||
if r in remote_column_names(prop)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def list_local_remote_exprs(prop, entities):
|
||||
return map(local_remote_expr(prop), entities)
|
||||
|
||||
|
||||
def remote(prop):
|
||||
try:
|
||||
return prop.secondary.c
|
||||
except AttributeError:
|
||||
return prop.mapper.class_
|
||||
|
||||
|
||||
def local_column_names(prop):
|
||||
if not hasattr(prop, 'secondary'):
|
||||
yield prop._discriminator_col.key
|
||||
|
Reference in New Issue
Block a user