Refactor batch fetch module

This commit is contained in:
Konsta Vesterinen
2013-12-26 00:15:17 +02:00
parent f2cd9ffdf5
commit 09cb504946
2 changed files with 31 additions and 27 deletions

View File

@@ -11,7 +11,10 @@ from sqlalchemy_utils.generic import (
GenericRelationshipProperty, class_from_table_name GenericRelationshipProperty, class_from_table_name
) )
from sqlalchemy_utils.functions.orm import ( from sqlalchemy_utils.functions.orm import (
local_column_names, remote_column_names local_values,
local_column_names,
remote_column_names,
remote_values
) )
@@ -225,10 +228,7 @@ class CompositeFetcher(object):
def fetch(self): def fetch(self):
for entity in self.related_entities: for entity in self.related_entities:
for fetcher in self.fetchers: for fetcher in self.fetchers:
if any( if any(remote_values(entity, fetcher.prop)):
getattr(entity, name)
for name in remote_column_names(fetcher.prop)
):
fetcher.append_entity(entity) fetcher.append_entity(entity)
def populate(self): def populate(self):
@@ -240,16 +240,10 @@ class AbstractFetcher(object):
@property @property
def local_values_list(self): def local_values_list(self):
return [ return [
self.local_values(entity) local_values(entity, self.prop)
for entity in self.path.entities for entity in self.path.entities
] ]
def local_values(self, entity):
return tuple(
getattr(entity, name)
for name in local_column_names(self.prop)
)
class Fetcher(AbstractFetcher): class Fetcher(AbstractFetcher):
def __init__(self, path): def __init__(self, path):
@@ -260,12 +254,6 @@ class Fetcher(AbstractFetcher):
else: else:
self.parent_dict = defaultdict(lambda: None) self.parent_dict = defaultdict(lambda: None)
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in remote_column_names(self.prop)
)
@property @property
def relation_query_base(self): def relation_query_base(self):
return self.path.session.query(self.path.model) return self.path.session.query(self.path.model)
@@ -279,11 +267,11 @@ class Fetcher(AbstractFetcher):
Populates backrefs for given related entities. Populates backrefs for given related entities.
""" """
backref_dict = dict( backref_dict = dict(
(self.local_values(value[0]), []) (local_values(value[0], self.prop), [])
for value in related_entities for value in related_entities
) )
for value in related_entities: for value in related_entities:
backref_dict[self.local_values(value[0])].append( backref_dict[local_values(value[0], self.prop)].append(
self.path.session.query(self.path.parent_model).get( self.path.session.query(self.path.parent_model).get(
tuple(value[1:]) tuple(value[1:])
) )
@@ -292,7 +280,7 @@ class Fetcher(AbstractFetcher):
set_committed_value( set_committed_value(
value[0], value[0],
self.prop.back_populates, self.prop.back_populates,
backref_dict[self.local_values(value[0])] backref_dict[local_values(value[0], self.prop)]
) )
def populate(self): def populate(self):
@@ -303,7 +291,7 @@ class Fetcher(AbstractFetcher):
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
self.parent_dict[self.local_values(entity)] self.parent_dict[local_values(entity, self.prop)]
) )
if self.path.populate_backrefs: if self.path.populate_backrefs:
@@ -359,7 +347,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
return (entity.__tablename__, getattr(entity, 'id')) return (entity.__tablename__, getattr(entity, 'id'))
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity self.parent_dict[remote_values(entity, self.prop)] = entity
def populate(self): def populate(self):
""" """
@@ -369,7 +357,7 @@ class GenericRelationshipFetcher(AbstractFetcher):
set_committed_value( set_committed_value(
entity, entity,
self.prop.key, self.prop.key,
self.parent_dict[self.local_values(entity)] self.parent_dict[local_values(entity, self.prop)]
) )
@property @property
@@ -424,11 +412,11 @@ class ManyToManyFetcher(Fetcher):
class ManyToOneFetcher(Fetcher): class ManyToOneFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)] = entity self.parent_dict[remote_values(entity, self.prop)] = entity
class OneToManyFetcher(Fetcher): class OneToManyFetcher(Fetcher):
def append_entity(self, entity): def append_entity(self, entity):
self.parent_dict[self.parent_key(entity)].append( self.parent_dict[remote_values(entity, self.prop)].append(
entity entity
) )

View File

@@ -1,3 +1,4 @@
from functools import partial
import sqlalchemy as sa import sqlalchemy as sa
@@ -46,6 +47,18 @@ def table_name(obj):
pass pass
def getattrs(obj, attrs):
return map(partial(getattr, obj), attrs)
def local_values(entity, prop):
return tuple(getattrs(entity, local_column_names(prop)))
def remote_values(entity, prop):
return tuple(getattrs(entity, remote_column_names(prop)))
def local_column_names(prop): def local_column_names(prop):
if not hasattr(prop, 'secondary'): if not hasattr(prop, 'secondary'):
yield prop._discriminator_col.key yield prop._discriminator_col.key
@@ -62,7 +75,10 @@ def local_column_names(prop):
def remote_column_names(prop): def remote_column_names(prop):
if not hasattr(prop, 'secondary') or prop.secondary is None: if not hasattr(prop, 'secondary'):
yield '__tablename__'
yield 'id'
elif prop.secondary is None:
for _, remote in prop.local_remote_pairs: for _, remote in prop.local_remote_pairs:
yield remote.name yield remote.name
else: else: