diff --git a/sqlalchemy_utils/batch.py b/sqlalchemy_utils/batch.py index cea052d..620fa95 100644 --- a/sqlalchemy_utils/batch.py +++ b/sqlalchemy_utils/batch.py @@ -11,10 +11,15 @@ from sqlalchemy_utils.generic import ( GenericRelationshipProperty, class_from_table_name ) from sqlalchemy_utils.functions.orm import ( + list_local_values, + list_local_remote_exprs, local_values, local_column_names, + local_remote_expr, + mapfirst, remote_column_names, - remote_values + remote_values, + remote ) @@ -228,7 +233,7 @@ class CompositeFetcher(object): def fetch(self): for entity in self.related_entities: for fetcher in self.fetchers: - if any(remote_values(entity, fetcher.prop)): + if any(remote_values(fetcher.prop, entity)): fetcher.append_entity(entity) def populate(self): @@ -236,23 +241,13 @@ class CompositeFetcher(object): fetcher.populate() -class AbstractFetcher(object): - @property - def local_values_list(self): - return [ - local_values(entity, self.prop) - for entity in self.path.entities - ] - - -class Fetcher(AbstractFetcher): +class Fetcher(object): def __init__(self, path): self.path = path self.prop = self.path.property - if self.prop.uselist: - self.parent_dict = defaultdict(list) - else: - self.parent_dict = defaultdict(lambda: None) + + default = list if self.prop.uselist else lambda: None + self.parent_dict = defaultdict(default) @property def relation_query_base(self): @@ -267,11 +262,11 @@ class Fetcher(AbstractFetcher): Populates backrefs for given related entities. """ backref_dict = dict( - (local_values(value[0], self.prop), []) + (local_values(self.prop, value[0]), []) for value in related_entities ) for value in related_entities: - backref_dict[local_values(value[0], self.prop)].append( + backref_dict[local_values(self.prop, value[0])].append( self.path.session.query(self.path.parent_model).get( tuple(value[1:]) ) @@ -280,7 +275,7 @@ class Fetcher(AbstractFetcher): set_committed_value( value[0], self.prop.back_populates, - backref_dict[local_values(value[0], self.prop)] + backref_dict[local_values(self.prop, value[0])] ) def populate(self): @@ -291,38 +286,24 @@ class Fetcher(AbstractFetcher): set_committed_value( entity, self.prop.key, - self.parent_dict[local_values(entity, self.prop)] + self.parent_dict[local_values(self.prop, entity)] ) if self.path.populate_backrefs: self.populate_backrefs(self.related_entities) - @property - def remote(self): - return self.path.model - @property def condition(self): names = list(remote_column_names(self.prop)) if len(names) == 1: - return getattr(self.remote, names[0]).in_( - value[0] for value in self.local_values_list + attr = getattr(remote(self.prop), names[0]) + return attr.in_( + mapfirst(list_local_values(self.prop, self.path.entities)) ) elif len(names) > 1: - conditions = [] - for entity in self.path.entities: - conditions.append( - sa.and_( - *[ - getattr(self.remote, remote.name) - == - getattr(entity, local.name) - for local, remote in self.prop.local_remote_pairs - if remote in names - ] - ) - ) - return sa.or_(*conditions) + return sa.or_( + *list_local_remote_exprs(self.prop, self.path.entities) + ) else: raise PathException( 'Could not obtain remote column names.' @@ -333,7 +314,7 @@ class Fetcher(AbstractFetcher): self.append_entity(entity) -class GenericRelationshipFetcher(AbstractFetcher): +class GenericRelationshipFetcher(object): def __init__(self, path): self.path = path self.prop = self.path.property @@ -344,7 +325,7 @@ class GenericRelationshipFetcher(AbstractFetcher): self.append_entity(entity) def append_entity(self, entity): - self.parent_dict[remote_values(entity, self.prop)] = entity + self.parent_dict[remote_values(self.prop, entity)] = entity def populate(self): """ @@ -354,7 +335,7 @@ class GenericRelationshipFetcher(AbstractFetcher): set_committed_value( entity, self.prop.key, - self.parent_dict[local_values(entity, self.prop)] + self.parent_dict[local_values(self.prop, entity)] ) @property @@ -380,10 +361,6 @@ class GenericRelationshipFetcher(AbstractFetcher): class ManyToManyFetcher(Fetcher): - @property - def remote(self): - return self.prop.secondary.c - @property def relation_query_base(self): return ( @@ -391,7 +368,7 @@ class ManyToManyFetcher(Fetcher): .query( self.path.model, *[ - getattr(self.prop.secondary.c, name) + getattr(remote(self.prop), name) for name in remote_column_names(self.prop) ] ) @@ -409,11 +386,11 @@ class ManyToManyFetcher(Fetcher): class ManyToOneFetcher(Fetcher): def append_entity(self, entity): - self.parent_dict[remote_values(entity, self.prop)] = entity + self.parent_dict[remote_values(self.prop, entity)] = entity class OneToManyFetcher(Fetcher): def append_entity(self, entity): - self.parent_dict[remote_values(entity, self.prop)].append( + self.parent_dict[remote_values(self.prop, entity)].append( entity ) diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 1886a73..14f4e09 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -1,4 +1,6 @@ from functools import partial +from funcy import first +from toolz import curry import sqlalchemy as sa @@ -51,14 +53,47 @@ def getattrs(obj, attrs): return map(partial(getattr, obj), attrs) -def local_values(entity, prop): +def mapfirst(iterable): + return map(first, iterable) + + +@curry +def local_values(prop, entity): return tuple(getattrs(entity, local_column_names(prop))) -def remote_values(entity, prop): +def list_local_values(prop, entities): + return map(local_values(prop), entities) + + +def remote_values(prop, entity): return tuple(getattrs(entity, remote_column_names(prop))) +@curry +def local_remote_expr(prop, entity): + return sa.and_( + *[ + getattr(remote(prop), r.name) + == + getattr(entity, l.name) + for l, r in prop.local_remote_pairs + if r in remote_column_names(prop) + ] + ) + + +def list_local_remote_exprs(prop, entities): + return map(local_remote_expr(prop), entities) + + +def remote(prop): + try: + return prop.secondary.c + except AttributeError: + return prop.mapper.class_ + + def local_column_names(prop): if not hasattr(prop, 'secondary'): yield prop._discriminator_col.key