Refactor batch fetch
This commit is contained in:
@@ -11,10 +11,15 @@ from sqlalchemy_utils.generic import (
|
|||||||
GenericRelationshipProperty, class_from_table_name
|
GenericRelationshipProperty, class_from_table_name
|
||||||
)
|
)
|
||||||
from sqlalchemy_utils.functions.orm import (
|
from sqlalchemy_utils.functions.orm import (
|
||||||
|
list_local_values,
|
||||||
|
list_local_remote_exprs,
|
||||||
local_values,
|
local_values,
|
||||||
local_column_names,
|
local_column_names,
|
||||||
|
local_remote_expr,
|
||||||
|
mapfirst,
|
||||||
remote_column_names,
|
remote_column_names,
|
||||||
remote_values
|
remote_values,
|
||||||
|
remote
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -228,7 +233,7 @@ class CompositeFetcher(object):
|
|||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
for fetcher in self.fetchers:
|
for fetcher in self.fetchers:
|
||||||
if any(remote_values(entity, fetcher.prop)):
|
if any(remote_values(fetcher.prop, entity)):
|
||||||
fetcher.append_entity(entity)
|
fetcher.append_entity(entity)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -236,23 +241,13 @@ class CompositeFetcher(object):
|
|||||||
fetcher.populate()
|
fetcher.populate()
|
||||||
|
|
||||||
|
|
||||||
class AbstractFetcher(object):
|
class Fetcher(object):
|
||||||
@property
|
|
||||||
def local_values_list(self):
|
|
||||||
return [
|
|
||||||
local_values(entity, self.prop)
|
|
||||||
for entity in self.path.entities
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(AbstractFetcher):
|
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.prop = self.path.property
|
self.prop = self.path.property
|
||||||
if self.prop.uselist:
|
|
||||||
self.parent_dict = defaultdict(list)
|
default = list if self.prop.uselist else lambda: None
|
||||||
else:
|
self.parent_dict = defaultdict(default)
|
||||||
self.parent_dict = defaultdict(lambda: None)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relation_query_base(self):
|
def relation_query_base(self):
|
||||||
@@ -267,11 +262,11 @@ class Fetcher(AbstractFetcher):
|
|||||||
Populates backrefs for given related entities.
|
Populates backrefs for given related entities.
|
||||||
"""
|
"""
|
||||||
backref_dict = dict(
|
backref_dict = dict(
|
||||||
(local_values(value[0], self.prop), [])
|
(local_values(self.prop, value[0]), [])
|
||||||
for value in related_entities
|
for value in related_entities
|
||||||
)
|
)
|
||||||
for value in related_entities:
|
for value in related_entities:
|
||||||
backref_dict[local_values(value[0], self.prop)].append(
|
backref_dict[local_values(self.prop, value[0])].append(
|
||||||
self.path.session.query(self.path.parent_model).get(
|
self.path.session.query(self.path.parent_model).get(
|
||||||
tuple(value[1:])
|
tuple(value[1:])
|
||||||
)
|
)
|
||||||
@@ -280,7 +275,7 @@ class Fetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
value[0],
|
value[0],
|
||||||
self.prop.back_populates,
|
self.prop.back_populates,
|
||||||
backref_dict[local_values(value[0], self.prop)]
|
backref_dict[local_values(self.prop, value[0])]
|
||||||
)
|
)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -291,38 +286,24 @@ class Fetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
self.parent_dict[local_values(entity, self.prop)]
|
self.parent_dict[local_values(self.prop, entity)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.path.populate_backrefs:
|
if self.path.populate_backrefs:
|
||||||
self.populate_backrefs(self.related_entities)
|
self.populate_backrefs(self.related_entities)
|
||||||
|
|
||||||
@property
|
|
||||||
def remote(self):
|
|
||||||
return self.path.model
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def condition(self):
|
def condition(self):
|
||||||
names = list(remote_column_names(self.prop))
|
names = list(remote_column_names(self.prop))
|
||||||
if len(names) == 1:
|
if len(names) == 1:
|
||||||
return getattr(self.remote, names[0]).in_(
|
attr = getattr(remote(self.prop), names[0])
|
||||||
value[0] for value in self.local_values_list
|
return attr.in_(
|
||||||
|
mapfirst(list_local_values(self.prop, self.path.entities))
|
||||||
)
|
)
|
||||||
elif len(names) > 1:
|
elif len(names) > 1:
|
||||||
conditions = []
|
return sa.or_(
|
||||||
for entity in self.path.entities:
|
*list_local_remote_exprs(self.prop, self.path.entities)
|
||||||
conditions.append(
|
)
|
||||||
sa.and_(
|
|
||||||
*[
|
|
||||||
getattr(self.remote, remote.name)
|
|
||||||
==
|
|
||||||
getattr(entity, local.name)
|
|
||||||
for local, remote in self.prop.local_remote_pairs
|
|
||||||
if remote in names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return sa.or_(*conditions)
|
|
||||||
else:
|
else:
|
||||||
raise PathException(
|
raise PathException(
|
||||||
'Could not obtain remote column names.'
|
'Could not obtain remote column names.'
|
||||||
@@ -333,7 +314,7 @@ class Fetcher(AbstractFetcher):
|
|||||||
self.append_entity(entity)
|
self.append_entity(entity)
|
||||||
|
|
||||||
|
|
||||||
class GenericRelationshipFetcher(AbstractFetcher):
|
class GenericRelationshipFetcher(object):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
self.path = path
|
self.path = path
|
||||||
self.prop = self.path.property
|
self.prop = self.path.property
|
||||||
@@ -344,7 +325,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
|||||||
self.append_entity(entity)
|
self.append_entity(entity)
|
||||||
|
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
self.parent_dict[remote_values(self.prop, entity)] = entity
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
"""
|
"""
|
||||||
@@ -354,7 +335,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
self.parent_dict[local_values(entity, self.prop)]
|
self.parent_dict[local_values(self.prop, entity)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -380,10 +361,6 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
|||||||
|
|
||||||
|
|
||||||
class ManyToManyFetcher(Fetcher):
|
class ManyToManyFetcher(Fetcher):
|
||||||
@property
|
|
||||||
def remote(self):
|
|
||||||
return self.prop.secondary.c
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relation_query_base(self):
|
def relation_query_base(self):
|
||||||
return (
|
return (
|
||||||
@@ -391,7 +368,7 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
.query(
|
.query(
|
||||||
self.path.model,
|
self.path.model,
|
||||||
*[
|
*[
|
||||||
getattr(self.prop.secondary.c, name)
|
getattr(remote(self.prop), name)
|
||||||
for name in remote_column_names(self.prop)
|
for name in remote_column_names(self.prop)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -409,11 +386,11 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
|
|
||||||
class ManyToOneFetcher(Fetcher):
|
class ManyToOneFetcher(Fetcher):
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
self.parent_dict[remote_values(self.prop, entity)] = entity
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[remote_values(entity, self.prop)].append(
|
self.parent_dict[remote_values(self.prop, entity)].append(
|
||||||
entity
|
entity
|
||||||
)
|
)
|
||||||
|
@@ -1,4 +1,6 @@
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
from funcy import first
|
||||||
|
from toolz import curry
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
@@ -51,14 +53,47 @@ def getattrs(obj, attrs):
|
|||||||
return map(partial(getattr, obj), attrs)
|
return map(partial(getattr, obj), attrs)
|
||||||
|
|
||||||
|
|
||||||
def local_values(entity, prop):
|
def mapfirst(iterable):
|
||||||
|
return map(first, iterable)
|
||||||
|
|
||||||
|
|
||||||
|
@curry
|
||||||
|
def local_values(prop, entity):
|
||||||
return tuple(getattrs(entity, local_column_names(prop)))
|
return tuple(getattrs(entity, local_column_names(prop)))
|
||||||
|
|
||||||
|
|
||||||
def remote_values(entity, prop):
|
def list_local_values(prop, entities):
|
||||||
|
return map(local_values(prop), entities)
|
||||||
|
|
||||||
|
|
||||||
|
def remote_values(prop, entity):
|
||||||
return tuple(getattrs(entity, remote_column_names(prop)))
|
return tuple(getattrs(entity, remote_column_names(prop)))
|
||||||
|
|
||||||
|
|
||||||
|
@curry
|
||||||
|
def local_remote_expr(prop, entity):
|
||||||
|
return sa.and_(
|
||||||
|
*[
|
||||||
|
getattr(remote(prop), r.name)
|
||||||
|
==
|
||||||
|
getattr(entity, l.name)
|
||||||
|
for l, r in prop.local_remote_pairs
|
||||||
|
if r in remote_column_names(prop)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_local_remote_exprs(prop, entities):
|
||||||
|
return map(local_remote_expr(prop), entities)
|
||||||
|
|
||||||
|
|
||||||
|
def remote(prop):
|
||||||
|
try:
|
||||||
|
return prop.secondary.c
|
||||||
|
except AttributeError:
|
||||||
|
return prop.mapper.class_
|
||||||
|
|
||||||
|
|
||||||
def local_column_names(prop):
|
def local_column_names(prop):
|
||||||
if not hasattr(prop, 'secondary'):
|
if not hasattr(prop, 'secondary'):
|
||||||
yield prop._discriminator_col.key
|
yield prop._discriminator_col.key
|
||||||
|
Reference in New Issue
Block a user