Refactor batch fetch module
This commit is contained in:
@@ -11,7 +11,10 @@ from sqlalchemy_utils.generic import (
|
||||
GenericRelationshipProperty, class_from_table_name
|
||||
)
|
||||
from sqlalchemy_utils.functions.orm import (
|
||||
local_column_names, remote_column_names
|
||||
local_values,
|
||||
local_column_names,
|
||||
remote_column_names,
|
||||
remote_values
|
||||
)
|
||||
|
||||
|
||||
@@ -225,10 +228,7 @@ class CompositeFetcher(object):
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
for fetcher in self.fetchers:
|
||||
if any(
|
||||
getattr(entity, name)
|
||||
for name in remote_column_names(fetcher.prop)
|
||||
):
|
||||
if any(remote_values(entity, fetcher.prop)):
|
||||
fetcher.append_entity(entity)
|
||||
|
||||
def populate(self):
|
||||
@@ -240,16 +240,10 @@ class AbstractFetcher(object):
|
||||
@property
|
||||
def local_values_list(self):
|
||||
return [
|
||||
self.local_values(entity)
|
||||
local_values(entity, self.prop)
|
||||
for entity in self.path.entities
|
||||
]
|
||||
|
||||
def local_values(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in local_column_names(self.prop)
|
||||
)
|
||||
|
||||
|
||||
class Fetcher(AbstractFetcher):
|
||||
def __init__(self, path):
|
||||
@@ -260,12 +254,6 @@ class Fetcher(AbstractFetcher):
|
||||
else:
|
||||
self.parent_dict = defaultdict(lambda: None)
|
||||
|
||||
def parent_key(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in remote_column_names(self.prop)
|
||||
)
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return self.path.session.query(self.path.model)
|
||||
@@ -279,11 +267,11 @@ class Fetcher(AbstractFetcher):
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
backref_dict = dict(
|
||||
(self.local_values(value[0]), [])
|
||||
(local_values(value[0], self.prop), [])
|
||||
for value in related_entities
|
||||
)
|
||||
for value in related_entities:
|
||||
backref_dict[self.local_values(value[0])].append(
|
||||
backref_dict[local_values(value[0], self.prop)].append(
|
||||
self.path.session.query(self.path.parent_model).get(
|
||||
tuple(value[1:])
|
||||
)
|
||||
@@ -292,7 +280,7 @@ class Fetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
value[0],
|
||||
self.prop.back_populates,
|
||||
backref_dict[self.local_values(value[0])]
|
||||
backref_dict[local_values(value[0], self.prop)]
|
||||
)
|
||||
|
||||
def populate(self):
|
||||
@@ -303,7 +291,7 @@ class Fetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
self.parent_dict[local_values(entity, self.prop)]
|
||||
)
|
||||
|
||||
if self.path.populate_backrefs:
|
||||
@@ -359,7 +347,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
||||
return (entity.__tablename__, getattr(entity, 'id'))
|
||||
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||
|
||||
def populate(self):
|
||||
"""
|
||||
@@ -369,7 +357,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
self.parent_dict[local_values(entity, self.prop)]
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -424,11 +412,11 @@ class ManyToManyFetcher(Fetcher):
|
||||
|
||||
class ManyToOneFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
self.parent_dict[self.parent_key(entity)].append(
|
||||
self.parent_dict[remote_values(entity, self.prop)].append(
|
||||
entity
|
||||
)
|
||||
|
@@ -1,3 +1,4 @@
|
||||
from functools import partial
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
@@ -46,6 +47,18 @@ def table_name(obj):
|
||||
pass
|
||||
|
||||
|
||||
def getattrs(obj, attrs):
|
||||
return map(partial(getattr, obj), attrs)
|
||||
|
||||
|
||||
def local_values(entity, prop):
|
||||
return tuple(getattrs(entity, local_column_names(prop)))
|
||||
|
||||
|
||||
def remote_values(entity, prop):
|
||||
return tuple(getattrs(entity, remote_column_names(prop)))
|
||||
|
||||
|
||||
def local_column_names(prop):
|
||||
if not hasattr(prop, 'secondary'):
|
||||
yield prop._discriminator_col.key
|
||||
@@ -62,7 +75,10 @@ def local_column_names(prop):
|
||||
|
||||
|
||||
def remote_column_names(prop):
|
||||
if not hasattr(prop, 'secondary') or prop.secondary is None:
|
||||
if not hasattr(prop, 'secondary'):
|
||||
yield '__tablename__'
|
||||
yield 'id'
|
||||
elif prop.secondary is None:
|
||||
for _, remote in prop.local_remote_pairs:
|
||||
yield remote.name
|
||||
else:
|
||||
|
Reference in New Issue
Block a user