Add batch fetch support for generic relationships

This commit is contained in:
Konsta Vesterinen
2013-12-04 15:53:37 +02:00
parent 48b357cbb8
commit 22fbcfac3d
6 changed files with 143 additions and 686 deletions

View File

@@ -1,9 +1,9 @@
from .aggregates import aggregated from .aggregates import aggregated
from .batch_fetch import batch_fetch, with_backrefs
from .decorators import generates from .decorators import generates
from .eav import MetaValue, MetaType from .eav import MetaValue, MetaType
from .exceptions import ImproperlyConfigured from .exceptions import ImproperlyConfigured
from .functions import ( from .functions import (
batch_fetch,
defer_except, defer_except,
escape_like, escape_like,
identity, identity,
@@ -15,7 +15,6 @@ from .functions import (
mock_engine, mock_engine,
sort_query, sort_query,
table_name, table_name,
with_backrefs,
database_exists, database_exists,
create_database, create_database,
drop_database drop_database

View File

@@ -1,398 +0,0 @@
from collections import defaultdict
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
class PathException(Exception):
pass
class with_backrefs(object):
"""
Marks given attribute path so that whenever its fetched with batch_fetch
the backref relations are force set too. Very useful when dealing with
certain many-to-many relationship scenarios.
"""
def __init__(self, path):
self.path = path
class Path(object):
"""
A class that represents an attribute path.
"""
def __init__(self, entities, prop, populate_backrefs=False):
self.property = prop
self.entities = entities
self.populate_backrefs = populate_backrefs
if not isinstance(self.property, RelationshipProperty):
raise PathException(
'Given attribute is not a relationship property.'
)
self.fetcher = self.fetcher_class(self)
@property
def session(self):
return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property
def model(self):
return self.property.mapper.class_
@classmethod
def parse(cls, entities, path, populate_backrefs=False):
if isinstance(path, six.string_types):
attrs = path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
if not related_entities:
return
subpath = '.'.join(attrs[1:])
return Path.parse(related_entities, subpath, populate_backrefs)
else:
attr = getattr(
entities[0].__class__, attrs[0]
)
elif isinstance(path, InstrumentedAttribute):
attr = path
else:
raise PathException('Unknown path type.')
return Path(entities, attr.property, populate_backrefs)
@property
def fetcher_class(self):
if self.property.secondary is not None:
return ManyToManyFetcher
else:
if self.property.direction.name == 'MANYTOONE':
return ManyToOneFetcher
else:
return OneToManyFetcher
class CompositePath(object):
def __init__(self, *paths):
self.paths = paths
def batch_fetch(entities, *attr_paths):
"""
Batch fetch given relationship attribute for collection of entities.
This function is in many cases a valid alternative for SQLAlchemy's
subqueryload and performs lot better.
:param entities: list of entities of the same type
:param attr_paths:
List of either InstrumentedAttribute objects or a strings representing
the name of the instrumented attribute
Example::
from sqlalchemy_utils import batch_fetch
users = session.query(User).limit(20).all()
batch_fetch(users, User.phonenumbers)
Function also accepts strings as attribute names: ::
users = session.query(User).limit(20).all()
batch_fetch(users, 'phonenumbers')
Multiple attributes may be provided: ::
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
'teams.players.user_groups'
)
You can also force populate backrefs: ::
from sqlalchemy_utils import with_backrefs
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
with_backrefs('teams.players.user_groups')
)
"""
if entities:
for path in attr_paths:
fetcher = fetcher_factory(entities, path)
if fetcher:
fetcher.fetch()
fetcher.populate()
def fetcher_factory(entities, path):
populate_backrefs = False
if isinstance(path, with_backrefs):
path = path.path
populate_backrefs = True
if isinstance(path, CompositePath):
fetchers = []
for path in path.paths:
path = Path.parse(entities, path, populate_backrefs)
if path:
fetchers.append(
path.fetcher
)
return CompositeFetcher(*fetchers)
else:
path = Path.parse(entities, path, populate_backrefs)
if path:
return path.fetcher
class CompositeFetcher(object):
def __init__(self, *fetchers):
if not all(
fetchers[0].path.model == fetcher.path.model
for fetcher in fetchers
):
raise PathException(
'Each relationship property must have the same class when '
'using CompositeFetcher.'
)
self.fetchers = fetchers
@property
def session(self):
return self.fetchers[0].path.session
@property
def model(self):
return self.fetchers[0].path.model
@property
def condition(self):
return sa.or_(
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
if any(
getattr(entity, name)
for name in fetcher.remote_column_names
):
fetcher.append_entity(entity)
def populate(self):
for fetcher in self.fetchers:
fetcher.populate()
class Fetcher(object):
def __init__(self, path):
self.path = path
self.prop = self.path.property
if self.prop.uselist:
self.parent_dict = defaultdict(list)
else:
self.parent_dict = defaultdict(lambda: None)
@property
def local_values_list(self):
return [
self.local_values(entity)
for entity in self.path.entities
]
@property
def relation_query_base(self):
return self.path.session.query(self.path.model)
@property
def related_entities(self):
return self.relation_query_base.filter(self.condition)
@property
def local_column_names(self):
return [local.name for local, remote in self.prop.local_remote_pairs]
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in self.remote_column_names
)
def local_values(self, entity):
return tuple(
getattr(entity, name)
for name in self.local_column_names
)
def populate_backrefs(self, related_entities):
"""
Populates backrefs for given related entities.
"""
backref_dict = dict(
(self.local_values(value[0]), [])
for value in related_entities
)
for value in related_entities:
backref_dict[self.local_values(value[0])].append(
self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
)
)
for value in related_entities:
set_committed_value(
value[0],
self.prop.back_populates,
backref_dict[self.local_values(value[0])]
)
def populate(self):
"""
Populate batch fetched entities to parent objects.
"""
for entity in self.path.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote(self):
return self.path.model
@property
def condition(self):
names = self.remote_column_names
if len(names) == 1:
return getattr(self.remote, names[0]).in_(
value[0] for value in self.local_values_list
)
elif len(names) > 1:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.remote, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote in self.remote_column_names
]
)
)
return sa.or_(*conditions)
else:
raise PathException(
'Could not obtain remote column names.'
)
def fetch(self):
for entity in self.related_entities:
self.append_entity(entity)
@property
def remote_column_names(self):
return [remote.name for local, remote in self.prop.local_remote_pairs]
class ManyToManyFetcher(Fetcher):
@property
def remote(self):
return self.prop.secondary.c
@property
def local_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
names.append(local.name)
return names
@property
def remote_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
names.append(remote.name)
return names
@property
def relation_query_base(self):
return (
self.path.session
.query(
self.path.model,
*[
getattr(self.prop.secondary.c, name)
for name in self.remote_column_names
]
)
.join(
self.prop.secondary, self.prop.secondaryjoin
)
)
def fetch(self):
for value in self.related_entities:
self.parent_dict[tuple(value[1:])].append(
value[0]
)
class ManyToOneFetcher(Fetcher):
def append_entity(self, entity):
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
self.parent_dict[self.parent_key(entity)] = entity
class OneToManyFetcher(Fetcher):
def append_entity(self, entity):
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
self.parent_dict[self.parent_key(entity)].append(
entity
)

View File

@@ -1,292 +1,48 @@
from collections import defaultdict
import sqlalchemy as sa
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
from .defer_except import defer_except from .defer_except import defer_except
from .mock import create_mock_engine, mock_engine from .mock import create_mock_engine, mock_engine
from .render import render_expression, render_statement from .render import render_expression, render_statement
from .sort_query import sort_query, QuerySorterException from .sort_query import sort_query, QuerySorterException
from .database import database_exists, create_database, drop_database from .database import (
database_exists,
create_database,
drop_database,
escape_like,
is_auto_assigned_date_column,
is_indexed_foreign_key,
non_indexed_foreign_keys,
)
from .orm import (
primary_keys,
table_name,
declarative_base,
has_changes,
identity,
naturally_equivalent,
remove_property
)
__all__ = ( __all__ = (
batch_fetch,
create_mock_engine, create_mock_engine,
defer_except, defer_except,
mock_engine, mock_engine,
sort_query, sort_query,
render_expression, render_expression,
render_statement, render_statement,
with_backrefs,
CompositePath,
QuerySorterException, QuerySorterException,
database_exists, database_exists,
create_database, create_database,
drop_database drop_database,
escape_like,
is_auto_assigned_date_column,
is_indexed_foreign_key,
non_indexed_foreign_keys,
remove_property,
primary_keys,
table_name,
declarative_base,
has_changes,
identity,
naturally_equivalent,
) )
def escape_like(string, escape_char='*'):
"""
Escapes the string paremeter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
def remove_property(class_, name):
"""
**Experimental function**
Remove property from declarative class
"""
mapper = class_.mapper
table = class_.__table__
columns = class_.mapper.c
column = columns[name]
del columns._data[name]
del mapper.columns[name]
columns._all_cols.remove(column)
mapper._cols_by_table[table].remove(column)
mapper.class_manager.uninstrument_attribute(name)
del mapper._props[name]
def primary_keys(class_):
"""
Returns all primary keys for given declarative class.
"""
for column in class_.__table__.c:
if column.primary_key:
yield column
def table_name(obj):
"""
Return table name of given target, declarative class or the
table name where the declarative attribute is bound to.
"""
class_ = getattr(obj, 'class_', obj)
try:
return class_.__tablename__
except AttributeError:
pass
try:
return class_.__table__.name
except AttributeError:
pass
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non indexed foreign keys from all tables of given MetaData.
Very useful for optimizing postgresql database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set([
column.name for column in index.columns
])
if index_column_names == set(constraint.columns):
return True
return False
def declarative_base(model):
"""
Returns the declarative base for given model class.
:param model: SQLAlchemy declarative model
"""
for parent in model.__bases__:
try:
parent.metadata
return declarative_base(parent)
except AttributeError:
pass
return model
def is_auto_assigned_date_column(column):
"""
Returns whether or not given SQLAlchemy Column object's is auto assigned
DateTime or Date.
:param column: SQLAlchemy Column object
"""
return (
(
isinstance(column.type, sa.DateTime) or
isinstance(column.type, sa.Date)
)
and
(
column.default or
column.server_default or
column.onupdate or
column.server_onupdate
)
)
def has_changes(obj, attr):
"""
Simple shortcut function for checking if given attribute of given
declarative model object has changed during the transaction.
::
from sqlalchemy_utils import has_changes
user = User()
has_changes(user, 'name') # False
user.name = u'someone'
has_changes(user, 'name') # True
:param obj: SQLAlchemy declarative model object
:param attr: Name of the attribute
"""
return (
sa.inspect(obj)
.attrs
.get(attr)
.history
.has_changes()
)
def identity(obj):
"""
Return the identity of given sqlalchemy declarative model instance as a
tuple. This differs from obj._sa_instance_state.identity in a way that it
always returns the identity even if object is still in transient state (
new object that is not yet persisted into database).
::
from sqlalchemy import inspect
from sqlalchemy_utils import identity
user = User(name=u'John Matrix')
session.add(user)
identity(user) # None
inspect(user).identity # None
session.flush() # User now has id but is still in transient state
identity(user) # (1,)
inspect(user).identity # None
session.commit()
identity(user) # (1,)
inspect(user).identity # (1, )
.. versionadded: 0.21.0
:param obj: SQLAlchemy declarative model object
"""
id_ = []
for column in sa.inspect(obj.__class__).columns:
if column.primary_key:
id_.append(getattr(obj, column.name))
if all(value is None for value in id_):
return None
else:
return tuple(id_)
def naturally_equivalent(obj, obj2):
"""
Returns whether or not two given SQLAlchemy declarative instances are
naturally equivalent (all their non primary key properties are equivalent).
::
from sqlalchemy_utils import naturally_equivalent
user = User(name=u'someone')
user2 = User(name=u'someone')
user == user2 # False
naturally_equivalent(user, user2) # True
:param obj: SQLAlchemy declarative model object
:param obj2: SQLAlchemy declarative model object to compare with `obj`
"""
for prop in sa.inspect(obj.__class__).iterate_properties:
if not isinstance(prop, sa.orm.ColumnProperty):
continue
if prop.columns[0].primary_key:
continue
if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
return False
return True

View File

@@ -1,10 +1,55 @@
from collections import defaultdict
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from sqlalchemy.exc import ProgrammingError, OperationalError from sqlalchemy.exc import ProgrammingError, OperationalError
import os import os
from copy import copy from copy import copy
def escape_like(string, escape_char='*'):
"""
Escapes the string paremeter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
def is_auto_assigned_date_column(column):
"""
Returns whether or not given SQLAlchemy Column object's is auto assigned
DateTime or Date.
:param column: SQLAlchemy Column object
"""
return (
(
isinstance(column.type, sa.DateTime) or
isinstance(column.type, sa.Date)
)
and
(
column.default or
column.server_default or
column.onupdate or
column.server_onupdate
)
)
def database_exists(url): def database_exists(url):
"""Check if a database exists. """Check if a database exists.
@@ -137,3 +182,55 @@ def drop_database(url):
else: else:
text = "DROP DATABASE %s" % database text = "DROP DATABASE %s" % database
engine.execute(text) engine.execute(text)
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non indexed foreign keys from all tables of given MetaData.
Very useful for optimizing postgresql database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set([
column.name for column in index.columns
])
if index_column_names == set(constraint.columns):
return True
return False

View File

@@ -3,11 +3,18 @@ from sqlalchemy.orm.session import _state_session
from sqlalchemy.orm import attributes, class_mapper from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.util import set_creation_order from sqlalchemy.util import set_creation_order
from sqlalchemy import exc as sa_exc from sqlalchemy import exc as sa_exc
from .functions import table_name from sqlalchemy_utils.functions import table_name
def class_from_table_name(state, table):
for class_ in state.class_._decl_class_registry.values():
name = table_name(class_)
if name and name == table:
return class_
return None
class GenericAttributeImpl(attributes.ScalarAttributeImpl): class GenericAttributeImpl(attributes.ScalarAttributeImpl):
def get(self, state, dict_, passive=attributes.PASSIVE_OFF): def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if self.key in dict_: if self.key in dict_:
return dict_[self.key] return dict_[self.key]
@@ -22,11 +29,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Find class for discriminator. # Find class for discriminator.
# TODO: Perhaps optimize with some sort of lookup? # TODO: Perhaps optimize with some sort of lookup?
discriminator = state.attrs[self.parent_token.discriminator.key].value discriminator = state.attrs[self.parent_token.discriminator.key].value
target_class = None target_class = class_from_table_name(state, discriminator)
for class_ in state.class_._decl_class_registry.values():
name = table_name(class_)
if name and name == discriminator:
target_class = class_
if target_class is None: if target_class is None:
# Unknown discriminator; return nothing. # Unknown discriminator; return nothing.
@@ -96,13 +99,13 @@ class GenericRelationshipProperty(MapperProperty):
class Comparator(PropComparator): class Comparator(PropComparator):
def __init__(self, prop, parentmapper): def __init__(self, prop, parentmapper):
self.prop = prop self.property = prop
self._parentmapper = parentmapper self._parentmapper = parentmapper
def __eq__(self, other): def __eq__(self, other):
discriminator = table_name(other) discriminator = table_name(other)
q = self.prop._discriminator_col == discriminator q = self.property._discriminator_col == discriminator
q &= self.prop._id_col == other.id q &= self.property._id_col == other.id
return q return q
def __ne__(self, other): def __ne__(self, other):
@@ -110,7 +113,7 @@ class GenericRelationshipProperty(MapperProperty):
def is_type(self, other): def is_type(self, other):
discriminator = table_name(other) discriminator = table_name(other)
return self.prop._discriminator_col == discriminator return self.property._discriminator_col == discriminator
def instrument_class(self, mapper): def instrument_class(self, mapper):
attributes.register_attribute( attributes.register_attribute(

View File

@@ -1,6 +1,6 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import CompositePath from sqlalchemy_utils.batch import CompositePath
from tests import TestCase from tests import TestCase