Refactor batch fetch module
This commit is contained in:
@@ -11,7 +11,10 @@ from sqlalchemy_utils.generic import (
|
|||||||
GenericRelationshipProperty, class_from_table_name
|
GenericRelationshipProperty, class_from_table_name
|
||||||
)
|
)
|
||||||
from sqlalchemy_utils.functions.orm import (
|
from sqlalchemy_utils.functions.orm import (
|
||||||
local_column_names, remote_column_names
|
local_values,
|
||||||
|
local_column_names,
|
||||||
|
remote_column_names,
|
||||||
|
remote_values
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -225,10 +228,7 @@ class CompositeFetcher(object):
|
|||||||
def fetch(self):
|
def fetch(self):
|
||||||
for entity in self.related_entities:
|
for entity in self.related_entities:
|
||||||
for fetcher in self.fetchers:
|
for fetcher in self.fetchers:
|
||||||
if any(
|
if any(remote_values(entity, fetcher.prop)):
|
||||||
getattr(entity, name)
|
|
||||||
for name in remote_column_names(fetcher.prop)
|
|
||||||
):
|
|
||||||
fetcher.append_entity(entity)
|
fetcher.append_entity(entity)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -240,16 +240,10 @@ class AbstractFetcher(object):
|
|||||||
@property
|
@property
|
||||||
def local_values_list(self):
|
def local_values_list(self):
|
||||||
return [
|
return [
|
||||||
self.local_values(entity)
|
local_values(entity, self.prop)
|
||||||
for entity in self.path.entities
|
for entity in self.path.entities
|
||||||
]
|
]
|
||||||
|
|
||||||
def local_values(self, entity):
|
|
||||||
return tuple(
|
|
||||||
getattr(entity, name)
|
|
||||||
for name in local_column_names(self.prop)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Fetcher(AbstractFetcher):
|
class Fetcher(AbstractFetcher):
|
||||||
def __init__(self, path):
|
def __init__(self, path):
|
||||||
@@ -260,12 +254,6 @@ class Fetcher(AbstractFetcher):
|
|||||||
else:
|
else:
|
||||||
self.parent_dict = defaultdict(lambda: None)
|
self.parent_dict = defaultdict(lambda: None)
|
||||||
|
|
||||||
def parent_key(self, entity):
|
|
||||||
return tuple(
|
|
||||||
getattr(entity, name)
|
|
||||||
for name in remote_column_names(self.prop)
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def relation_query_base(self):
|
def relation_query_base(self):
|
||||||
return self.path.session.query(self.path.model)
|
return self.path.session.query(self.path.model)
|
||||||
@@ -279,11 +267,11 @@ class Fetcher(AbstractFetcher):
|
|||||||
Populates backrefs for given related entities.
|
Populates backrefs for given related entities.
|
||||||
"""
|
"""
|
||||||
backref_dict = dict(
|
backref_dict = dict(
|
||||||
(self.local_values(value[0]), [])
|
(local_values(value[0], self.prop), [])
|
||||||
for value in related_entities
|
for value in related_entities
|
||||||
)
|
)
|
||||||
for value in related_entities:
|
for value in related_entities:
|
||||||
backref_dict[self.local_values(value[0])].append(
|
backref_dict[local_values(value[0], self.prop)].append(
|
||||||
self.path.session.query(self.path.parent_model).get(
|
self.path.session.query(self.path.parent_model).get(
|
||||||
tuple(value[1:])
|
tuple(value[1:])
|
||||||
)
|
)
|
||||||
@@ -292,7 +280,7 @@ class Fetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
value[0],
|
value[0],
|
||||||
self.prop.back_populates,
|
self.prop.back_populates,
|
||||||
backref_dict[self.local_values(value[0])]
|
backref_dict[local_values(value[0], self.prop)]
|
||||||
)
|
)
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
@@ -303,7 +291,7 @@ class Fetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
self.parent_dict[self.local_values(entity)]
|
self.parent_dict[local_values(entity, self.prop)]
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.path.populate_backrefs:
|
if self.path.populate_backrefs:
|
||||||
@@ -359,7 +347,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
|||||||
return (entity.__tablename__, getattr(entity, 'id'))
|
return (entity.__tablename__, getattr(entity, 'id'))
|
||||||
|
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[self.parent_key(entity)] = entity
|
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||||
|
|
||||||
def populate(self):
|
def populate(self):
|
||||||
"""
|
"""
|
||||||
@@ -369,7 +357,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
|
|||||||
set_committed_value(
|
set_committed_value(
|
||||||
entity,
|
entity,
|
||||||
self.prop.key,
|
self.prop.key,
|
||||||
self.parent_dict[self.local_values(entity)]
|
self.parent_dict[local_values(entity, self.prop)]
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -424,11 +412,11 @@ class ManyToManyFetcher(Fetcher):
|
|||||||
|
|
||||||
class ManyToOneFetcher(Fetcher):
|
class ManyToOneFetcher(Fetcher):
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[self.parent_key(entity)] = entity
|
self.parent_dict[remote_values(entity, self.prop)] = entity
|
||||||
|
|
||||||
|
|
||||||
class OneToManyFetcher(Fetcher):
|
class OneToManyFetcher(Fetcher):
|
||||||
def append_entity(self, entity):
|
def append_entity(self, entity):
|
||||||
self.parent_dict[self.parent_key(entity)].append(
|
self.parent_dict[remote_values(entity, self.prop)].append(
|
||||||
entity
|
entity
|
||||||
)
|
)
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
from functools import partial
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
@@ -46,6 +47,18 @@ def table_name(obj):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def getattrs(obj, attrs):
|
||||||
|
return map(partial(getattr, obj), attrs)
|
||||||
|
|
||||||
|
|
||||||
|
def local_values(entity, prop):
|
||||||
|
return tuple(getattrs(entity, local_column_names(prop)))
|
||||||
|
|
||||||
|
|
||||||
|
def remote_values(entity, prop):
|
||||||
|
return tuple(getattrs(entity, remote_column_names(prop)))
|
||||||
|
|
||||||
|
|
||||||
def local_column_names(prop):
|
def local_column_names(prop):
|
||||||
if not hasattr(prop, 'secondary'):
|
if not hasattr(prop, 'secondary'):
|
||||||
yield prop._discriminator_col.key
|
yield prop._discriminator_col.key
|
||||||
@@ -62,7 +75,10 @@ def local_column_names(prop):
|
|||||||
|
|
||||||
|
|
||||||
def remote_column_names(prop):
|
def remote_column_names(prop):
|
||||||
if not hasattr(prop, 'secondary') or prop.secondary is None:
|
if not hasattr(prop, 'secondary'):
|
||||||
|
yield '__tablename__'
|
||||||
|
yield 'id'
|
||||||
|
elif prop.secondary is None:
|
||||||
for _, remote in prop.local_remote_pairs:
|
for _, remote in prop.local_remote_pairs:
|
||||||
yield remote.name
|
yield remote.name
|
||||||
else:
|
else:
|
||||||
|
Reference in New Issue
Block a user