Refactor batch fetch module

This commit is contained in:
Konsta Vesterinen
2013-12-26 00:15:17 +02:00
parent f2cd9ffdf5
commit 09cb504946
2 changed files with 31 additions and 27 deletions

View File

@@ -11,7 +11,10 @@ from sqlalchemy_utils.generic import (
GenericRelationshipProperty, class_from_table_name
)
from sqlalchemy_utils.functions.orm import (
local_column_names, remote_column_names
local_values,
local_column_names,
remote_column_names,
remote_values
)
@@ -225,10 +228,7 @@ class CompositeFetcher(object):
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
if any(
getattr(entity, name)
for name in remote_column_names(fetcher.prop)
):
if any(remote_values(entity, fetcher.prop)):
fetcher.append_entity(entity)
def populate(self):
@@ -240,16 +240,10 @@ class AbstractFetcher(object):
@property
def local_values_list(self):
return [
self.local_values(entity)
local_values(entity, self.prop)
for entity in self.path.entities
]
def local_values(self, entity):
return tuple(
getattr(entity, name)
for name in local_column_names(self.prop)
)
class Fetcher(AbstractFetcher):
def __init__(self, path):
@@ -260,12 +254,6 @@ class Fetcher(AbstractFetcher):
else:
self.parent_dict = defaultdict(lambda: None)
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in remote_column_names(self.prop)
)
@property
def relation_query_base(self):
return self.path.session.query(self.path.model)
@@ -279,11 +267,11 @@ class Fetcher(AbstractFetcher):
Populates backrefs for given related entities.
"""
backref_dict = dict(
(self.local_values(value[0]), [])
(local_values(value[0], self.prop), [])
for value in related_entities
)
for value in related_entities:
backref_dict[self.local_values(value[0])].append(
backref_dict[local_values(value[0], self.prop)].append(
self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
)
@@ -292,7 +280,7 @@ class Fetcher(AbstractFetcher):
set_committed_value(
value[0],
self.prop.back_populates,
backref_dict[self.local_values(value[0])]
backref_dict[local_values(value[0], self.prop)]
)
def populate(self):
@@ -303,7 +291,7 @@ class Fetcher(AbstractFetcher):
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
self.parent_dict[local_values(entity, self.prop)]
)
if self.path.populate_backrefs:
@@ -359,7 +347,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
return (entity.__tablename__, getattr(entity, 'id'))
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity
self.parent_dict[remote_values(entity, self.prop)] = entity
def populate(self):
"""
@@ -369,7 +357,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
self.parent_dict[local_values(entity, self.prop)]
)
@property
@@ -424,11 +412,11 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity
self.parent_dict[remote_values(entity, self.prop)] = entity
class OneToManyFetcher(Fetcher):
def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)].append(
self.parent_dict[remote_values(entity, self.prop)].append(
entity
)

View File

@@ -1,3 +1,4 @@
from functools import partial
import sqlalchemy as sa
@@ -46,6 +47,18 @@ def table_name(obj):
pass
def getattrs(obj, attrs):
return map(partial(getattr, obj), attrs)
def local_values(entity, prop):
return tuple(getattrs(entity, local_column_names(prop)))
def remote_values(entity, prop):
return tuple(getattrs(entity, remote_column_names(prop)))
def local_column_names(prop):
if not hasattr(prop, 'secondary'):
yield prop._discriminator_col.key
@@ -62,7 +75,10 @@ def local_column_names(prop):
def remote_column_names(prop):
if not hasattr(prop, 'secondary') or prop.secondary is None:
if not hasattr(prop, 'secondary'):
yield '__tablename__'
yield 'id'
elif prop.secondary is None:
for _, remote in prop.local_remote_pairs:
yield remote.name
else: