diff --git a/docs/model_helpers.rst b/docs/model_helpers.rst index 0bfee53..55d8205 100644 --- a/docs/model_helpers.rst +++ b/docs/model_helpers.rst @@ -52,6 +52,12 @@ get_tables .. autofunction:: get_tables +group_foreign_keys +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: group_foreign_keys + + query_entities ^^^^^^^^^^^^^^ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index d5ed830..9ef5214 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -17,6 +17,7 @@ from .functions import ( get_primary_keys, get_referencing_foreign_keys, get_tables, + group_foreign_keys, identity, mock_engine, naturally_equivalent, @@ -89,6 +90,7 @@ __all__ = ( get_primary_keys, get_referencing_foreign_keys, get_tables, + group_foreign_keys, identity, instrumented_list, merge, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 7b56fb3..95891ba 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -20,6 +20,7 @@ from .orm import ( get_referencing_foreign_keys, get_tables, getdotattr, + group_foreign_keys, has_changes, identity, naturally_equivalent, @@ -42,6 +43,7 @@ __all__ = ( 'get_referencing_foreign_keys', 'get_tables', 'getdotattr', + 'group_foreign_keys', 'has_changes', 'identity', 'is_auto_assigned_date_column', diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index df58bd7..0b4fd38 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -129,13 +129,10 @@ def dependent_objects(obj, foreign_keys=None): session = object_session(obj) - foreign_keys = sorted( - foreign_keys, key=lambda key: key.constraint.table.name - ) chain = QueryChain([]) classes = obj.__class__._decl_class_registry - for table, keys in groupby(foreign_keys, lambda key: key.constraint.table): + for table, keys in group_foreign_keys(foreign_keys): for class_ in classes.values(): if hasattr(class_, '__table__') and class_.__table__ == table: criteria = [] @@ -163,6 +160,32 @@ def dependent_objects(obj, foreign_keys=None): return chain +def group_foreign_keys(foreign_keys): + """ + Return a groupby iterator that groups given foreign keys by table. + + :param foreign_keys: a sequence of foreign keys + + + :: + + foreign_keys = get_referencing_foreign_keys(User) + + for table, fks in group_foreign_keys(foreign_keys): + # do something + pass + + + .. also:: :func:`get_referencing_foreign_keys` + + .. versionadded: 0.26.1 + """ + foreign_keys = sorted( + foreign_keys, key=lambda key: key.constraint.table.name + ) + return groupby(foreign_keys, lambda key: key.constraint.table) + + def get_referencing_foreign_keys(mixed): """ Returns referencing foreign keys for given Table object or declarative