diff --git a/sqlalchemy_utils/batch.py b/sqlalchemy_utils/batch.py index afdd6ba..c42bec7 100644 --- a/sqlalchemy_utils/batch.py +++ b/sqlalchemy_utils/batch.py @@ -14,8 +14,6 @@ from sqlalchemy_utils.functions.orm import ( list_local_values, list_local_remote_exprs, local_values, - local_column_names, - local_remote_expr, remote_column_names, remote_values, remote @@ -339,9 +337,10 @@ class GenericRelationshipFetcher(object): id_dict = defaultdict(list) for entity in self.path.entities: discriminator = getattr(entity, self.prop._discriminator_col.key) - id_dict[discriminator].append( - getattr(entity, self.prop._id_col.key) - ) + for id_col in self.prop._id_cols: + id_dict[discriminator].append( + getattr(entity, id_col.key) + ) return chain(*self._queries(sa.inspect(entity), id_dict)) def _queries(self, state, id_dict): diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index c94a868..e7f4edd 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -77,7 +77,8 @@ def remote(prop): def local_column_names(prop): if not hasattr(prop, 'secondary'): yield prop._discriminator_col.key - yield prop._id_col.key + for id_col in prop._id_cols: + yield id_col.key elif prop.secondary is None: for local, _ in prop.local_remote_pairs: yield local.name