diff --git a/.travis.yml b/.travis.yml index 9725881..cfa3845 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,6 +8,6 @@ python: - 2.7 - 3.3 install: - - pip install -q -e ".[test]" --use-mirrors + - pip install -e ".[test]" script: - python setup.py test diff --git a/setup.py b/setup.py index 8b8b622..4190304 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ extras_require = { 'Jinja2>=2.3', 'docutils>=0.10', 'flexmock>=0.9.7', - 'psycopg2>=2.4.6' + 'psycopg2>=2.4.6', ], 'arrow': ['arrow>=0.3.4'], 'phone': [ diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index d116613..1be7838 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,7 +1,12 @@ from .exceptions import ImproperlyConfigured from .functions import ( - sort_query, defer_except, escape_like, primary_keys, table_name, - render_statement + batch_fetch, + defer_except, + escape_like, + primary_keys, + render_statement, + sort_query, + table_name, ) from .listeners import coercion_listener from .merge import merge, Merger @@ -33,9 +38,8 @@ __version__ = '0.16.2' __all__ = ( - ImproperlyConfigured, + batch_fetch, coercion_listener, - sort_query, defer_except, escape_like, instrumented_list, @@ -43,10 +47,12 @@ __all__ = ( primary_keys, proxy_dict, render_statement, + sort_query, table_name, ArrowType, ColorType, EmailType, + ImproperlyConfigured, InstrumentedList, IPAddressType, Merger, @@ -59,8 +65,8 @@ __all__ = ( PhoneNumber, PhoneNumberType, ProxyDict, - ScalarListType, ScalarListException, + ScalarListType, TimezoneType, TSVectorType, UUIDType, diff --git a/sqlalchemy_utils/functions.py b/sqlalchemy_utils/functions.py index 61a35f4..1b6b497 100644 --- a/sqlalchemy_utils/functions.py +++ b/sqlalchemy_utils/functions.py @@ -386,3 +386,49 @@ def render_statement(statement, bind=None): return super(Compiler, self).render_literal_value(value, type_) return Compiler(bind.dialect, statement).process(statement) + + +from sqlalchemy.orm.session import object_session +from sqlalchemy.orm import RelationshipProperty +from sqlalchemy.orm.attributes import set_committed_value + + +def batch_fetch(entities, attr): + if entities: + first = entities[0] + if isinstance(attr, six.string_types): + attr = getattr( + first.__class__, attr + ) + + prop = attr.property + if not isinstance(prop, RelationshipProperty): + raise Exception( + 'Given attribute is not a relationship property.' + ) + + model = prop.mapper.class_ + session = object_session(first) + + if len(prop.remote_side) > 1: + raise Exception( + 'Only relationships with single remote side columns are ' + 'supported.' + ) + + column_name = list(prop.remote_side)[0].name + parent_ids = [entity.id for entity in entities] + + related_entities = ( + session.query(model) + .filter( + getattr(model, column_name).in_(parent_ids) + ) + ) + + parent_dict = dict((entity.id, []) for entity in entities) + for entity in related_entities: + parent_dict[getattr(entity, column_name)].append(entity) + + for entity in entities: + set_committed_value(entity, prop.key, parent_dict[entity.id])