From 09cb5049467b91fd68b5e3600c29b94bb3656231 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Thu, 26 Dec 2013 00:15:17 +0200 Subject: [PATCH] Refactor batch fetch module --- sqlalchemy_utils/batch.py | 40 +++++++++++-------------------- sqlalchemy_utils/functions/orm.py | 18 +++++++++++++- 2 files changed, 31 insertions(+), 27 deletions(-) diff --git a/sqlalchemy_utils/batch.py b/sqlalchemy_utils/batch.py index c58d272..b83a967 100644 --- a/sqlalchemy_utils/batch.py +++ b/sqlalchemy_utils/batch.py @@ -11,7 +11,10 @@ from sqlalchemy_utils.generic import ( GenericRelationshipProperty, class_from_table_name ) from sqlalchemy_utils.functions.orm import ( - local_column_names, remote_column_names + local_values, + local_column_names, + remote_column_names, + remote_values ) @@ -225,10 +228,7 @@ class CompositeFetcher(object): def fetch(self): for entity in self.related_entities: for fetcher in self.fetchers: - if any( - getattr(entity, name) - for name in remote_column_names(fetcher.prop) - ): + if any(remote_values(entity, fetcher.prop)): fetcher.append_entity(entity) def populate(self): @@ -240,16 +240,10 @@ class AbstractFetcher(object): @property def local_values_list(self): return [ - self.local_values(entity) + local_values(entity, self.prop) for entity in self.path.entities ] - def local_values(self, entity): - return tuple( - getattr(entity, name) - for name in local_column_names(self.prop) - ) - class Fetcher(AbstractFetcher): def __init__(self, path): @@ -260,12 +254,6 @@ class Fetcher(AbstractFetcher): else: self.parent_dict = defaultdict(lambda: None) - def parent_key(self, entity): - return tuple( - getattr(entity, name) - for name in remote_column_names(self.prop) - ) - @property def relation_query_base(self): return self.path.session.query(self.path.model) @@ -279,11 +267,11 @@ class Fetcher(AbstractFetcher): Populates backrefs for given related entities. """ backref_dict = dict( - (self.local_values(value[0]), []) + (local_values(value[0], self.prop), []) for value in related_entities ) for value in related_entities: - backref_dict[self.local_values(value[0])].append( + backref_dict[local_values(value[0], self.prop)].append( self.path.session.query(self.path.parent_model).get( tuple(value[1:]) ) @@ -292,7 +280,7 @@ class Fetcher(AbstractFetcher): set_committed_value( value[0], self.prop.back_populates, - backref_dict[self.local_values(value[0])] + backref_dict[local_values(value[0], self.prop)] ) def populate(self): @@ -303,7 +291,7 @@ class Fetcher(AbstractFetcher): set_committed_value( entity, self.prop.key, - self.parent_dict[self.local_values(entity)] + self.parent_dict[local_values(entity, self.prop)] ) if self.path.populate_backrefs: @@ -359,7 +347,7 @@ class GenericRelationshipFetcher(AbstractFetcher): return (entity.__tablename__, getattr(entity, 'id')) def append_entity(self, entity): - self.parent_dict[self.parent_key(entity)] = entity + self.parent_dict[remote_values(entity, self.prop)] = entity def populate(self): """ @@ -369,7 +357,7 @@ class GenericRelationshipFetcher(AbstractFetcher): set_committed_value( entity, self.prop.key, - self.parent_dict[self.local_values(entity)] + self.parent_dict[local_values(entity, self.prop)] ) @property @@ -424,11 +412,11 @@ class ManyToManyFetcher(Fetcher): class ManyToOneFetcher(Fetcher): def append_entity(self, entity): - self.parent_dict[self.parent_key(entity)] = entity + self.parent_dict[remote_values(entity, self.prop)] = entity class OneToManyFetcher(Fetcher): def append_entity(self, entity): - self.parent_dict[self.parent_key(entity)].append( + self.parent_dict[remote_values(entity, self.prop)].append( entity ) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index df11677..1886a73 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -1,3 +1,4 @@ +from functools import partial import sqlalchemy as sa @@ -46,6 +47,18 @@ def table_name(obj): pass +def getattrs(obj, attrs): + return map(partial(getattr, obj), attrs) + + +def local_values(entity, prop): + return tuple(getattrs(entity, local_column_names(prop))) + + +def remote_values(entity, prop): + return tuple(getattrs(entity, remote_column_names(prop))) + + def local_column_names(prop): if not hasattr(prop, 'secondary'): yield prop._discriminator_col.key @@ -62,7 +75,10 @@ def local_column_names(prop): def remote_column_names(prop): - if not hasattr(prop, 'secondary') or prop.secondary is None: + if not hasattr(prop, 'secondary'): + yield '__tablename__' + yield 'id' + elif prop.secondary is None: for _, remote in prop.local_remote_pairs: yield remote.name else: