Add batch fetch support for generic relationships

Konsta Vesterinen
2013-12-04 15:53:37 +02:00
parent 48b357cbb8
commit 22fbcfac3d
6 changed files with 143 additions and 686 deletions

View File

@@ -1,9 +1,9 @@
from .aggregates import aggregated
from .batch_fetch import batch_fetch, with_backrefs
from .decorators import generates
from .eav import MetaValue, MetaType
from .exceptions import ImproperlyConfigured
from .functions import (
batch_fetch,
defer_except,
escape_like,
identity,
@@ -15,7 +15,6 @@ from .functions import (
mock_engine,
sort_query,
table_name,
with_backrefs,
database_exists,
create_database,
drop_database

View File

@@ -1,398 +0,0 @@
from collections import defaultdict
import six
import sqlalchemy as sa
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm.attributes import (
set_committed_value, InstrumentedAttribute
)
from sqlalchemy.orm.session import object_session
class PathException(Exception):
pass
class with_backrefs(object):
"""
Marks the given attribute path so that whenever it is fetched with
batch_fetch, the backref relations are force-set too. Very useful when
dealing with certain many-to-many relationship scenarios.
"""
def __init__(self, path):
self.path = path
class Path(object):
"""
A class that represents an attribute path.
"""
def __init__(self, entities, prop, populate_backrefs=False):
self.property = prop
self.entities = entities
self.populate_backrefs = populate_backrefs
if not isinstance(self.property, RelationshipProperty):
raise PathException(
'Given attribute is not a relationship property.'
)
self.fetcher = self.fetcher_class(self)
@property
def session(self):
return object_session(self.entities[0])
@property
def parent_model(self):
return self.entities[0].__class__
@property
def model(self):
return self.property.mapper.class_
@classmethod
def parse(cls, entities, path, populate_backrefs=False):
if isinstance(path, six.string_types):
attrs = path.split('.')
if len(attrs) > 1:
related_entities = []
for entity in entities:
related_entities.extend(getattr(entity, attrs[0]))
if not related_entities:
return
subpath = '.'.join(attrs[1:])
return Path.parse(related_entities, subpath, populate_backrefs)
else:
attr = getattr(
entities[0].__class__, attrs[0]
)
elif isinstance(path, InstrumentedAttribute):
attr = path
else:
raise PathException('Unknown path type.')
return Path(entities, attr.property, populate_backrefs)
@property
def fetcher_class(self):
if self.property.secondary is not None:
return ManyToManyFetcher
else:
if self.property.direction.name == 'MANYTOONE':
return ManyToOneFetcher
else:
return OneToManyFetcher
class CompositePath(object):
def __init__(self, *paths):
self.paths = paths
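# Editorial sketch, not part of the original module: CompositePath bundles
# several attribute paths that resolve to the same target class so that
# batch_fetch can load them with a single query (CompositeFetcher ORs the
# individual conditions together). Article, User and session below are
# hypothetical names:
#
#     articles = session.query(Article).limit(20).all()
#     batch_fetch(
#         articles,
#         CompositePath('author', 'last_editor')  # both Article -> User
#     )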
def batch_fetch(entities, *attr_paths):
"""
Batch fetch the given relationship attributes for a collection of entities.
This function is in many cases a valid alternative to SQLAlchemy's
subqueryload and performs a lot better.
:param entities: list of entities of the same type
:param attr_paths:
List of InstrumentedAttribute objects or strings representing the
names of the instrumented attributes
Example::
from sqlalchemy_utils import batch_fetch
users = session.query(User).limit(20).all()
batch_fetch(users, User.phonenumbers)
Function also accepts strings as attribute names: ::
users = session.query(User).limit(20).all()
batch_fetch(users, 'phonenumbers')
Multiple attributes may be provided: ::
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
'teams.players.user_groups'
)
You can also force populate backrefs: ::
from sqlalchemy_utils import with_backrefs
clubs = session.query(Club).limit(20).all()
batch_fetch(
clubs,
'teams',
'teams.players',
with_backrefs('teams.players.user_groups')
)
"""
if entities:
for path in attr_paths:
fetcher = fetcher_factory(entities, path)
if fetcher:
fetcher.fetch()
fetcher.populate()
def fetcher_factory(entities, path):
populate_backrefs = False
if isinstance(path, with_backrefs):
path = path.path
populate_backrefs = True
if isinstance(path, CompositePath):
fetchers = []
for path in path.paths:
path = Path.parse(entities, path, populate_backrefs)
if path:
fetchers.append(
path.fetcher
)
return CompositeFetcher(*fetchers)
else:
path = Path.parse(entities, path, populate_backrefs)
if path:
return path.fetcher
class CompositeFetcher(object):
def __init__(self, *fetchers):
if not all(
fetchers[0].path.model == fetcher.path.model
for fetcher in fetchers
):
raise PathException(
'Each relationship property must have the same class when '
'using CompositeFetcher.'
)
self.fetchers = fetchers
@property
def session(self):
return self.fetchers[0].path.session
@property
def model(self):
return self.fetchers[0].path.model
@property
def condition(self):
return sa.or_(
*[fetcher.condition for fetcher in self.fetchers]
)
@property
def related_entities(self):
return self.session.query(self.model).filter(self.condition)
def fetch(self):
for entity in self.related_entities:
for fetcher in self.fetchers:
if any(
getattr(entity, name)
for name in fetcher.remote_column_names
):
fetcher.append_entity(entity)
def populate(self):
for fetcher in self.fetchers:
fetcher.populate()
class Fetcher(object):
def __init__(self, path):
self.path = path
self.prop = self.path.property
if self.prop.uselist:
self.parent_dict = defaultdict(list)
else:
self.parent_dict = defaultdict(lambda: None)
@property
def local_values_list(self):
return [
self.local_values(entity)
for entity in self.path.entities
]
@property
def relation_query_base(self):
return self.path.session.query(self.path.model)
@property
def related_entities(self):
return self.relation_query_base.filter(self.condition)
@property
def local_column_names(self):
return [local.name for local, remote in self.prop.local_remote_pairs]
def parent_key(self, entity):
return tuple(
getattr(entity, name)
for name in self.remote_column_names
)
def local_values(self, entity):
return tuple(
getattr(entity, name)
for name in self.local_column_names
)
def populate_backrefs(self, related_entities):
"""
Populates backrefs for the given related entities.
"""
backref_dict = dict(
(self.local_values(value[0]), [])
for value in related_entities
)
for value in related_entities:
backref_dict[self.local_values(value[0])].append(
self.path.session.query(self.path.parent_model).get(
tuple(value[1:])
)
)
for value in related_entities:
set_committed_value(
value[0],
self.prop.back_populates,
backref_dict[self.local_values(value[0])]
)
def populate(self):
"""
Populate batch-fetched entities onto their parent objects.
"""
for entity in self.path.entities:
set_committed_value(
entity,
self.prop.key,
self.parent_dict[self.local_values(entity)]
)
if self.path.populate_backrefs:
self.populate_backrefs(self.related_entities)
@property
def remote(self):
return self.path.model
@property
def condition(self):
names = self.remote_column_names
if len(names) == 1:
return getattr(self.remote, names[0]).in_(
value[0] for value in self.local_values_list
)
elif len(names) > 1:
conditions = []
for entity in self.path.entities:
conditions.append(
sa.and_(
*[
getattr(self.remote, remote.name)
==
getattr(entity, local.name)
for local, remote in self.prop.local_remote_pairs
if remote in self.remote_column_names
]
)
)
return sa.or_(*conditions)
else:
raise PathException(
'Could not obtain remote column names.'
)
def fetch(self):
for entity in self.related_entities:
self.append_entity(entity)
@property
def remote_column_names(self):
return [remote.name for local, remote in self.prop.local_remote_pairs]
class ManyToManyFetcher(Fetcher):
@property
def remote(self):
return self.prop.secondary.c
@property
def local_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
names.append(local.name)
return names
@property
def remote_column_names(self):
names = []
for local, remote in self.prop.local_remote_pairs:
for fk in remote.foreign_keys:
if fk.column.table in self.prop.parent.tables:
names.append(remote.name)
return names
@property
def relation_query_base(self):
return (
self.path.session
.query(
self.path.model,
*[
getattr(self.prop.secondary.c, name)
for name in self.remote_column_names
]
)
.join(
self.prop.secondary, self.prop.secondaryjoin
)
)
def fetch(self):
for value in self.related_entities:
self.parent_dict[tuple(value[1:])].append(
value[0]
)
class ManyToOneFetcher(Fetcher):
def append_entity(self, entity):
# Many-to-one: each parent key maps to a single related entity.
self.parent_dict[self.parent_key(entity)] = entity
class OneToManyFetcher(Fetcher):
def append_entity(self, entity):
# One-to-many: group related entities in a list under their parent key.
self.parent_dict[self.parent_key(entity)].append(
entity
)
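# Editorial note, not part of the original module: for a one-to-many path
# such as batch_fetch(users, 'phonenumbers'), OneToManyFetcher builds a
# single IN condition from the parents' key values, so roughly one query
#
#     SELECT * FROM phonenumber WHERE user_id IN (:id_1, ..., :id_20)
#
# is emitted (user_id is a hypothetical foreign key column) and the grouped
# results are assigned to each parent with set_committed_value, instead of
# one lazy-load query per parent.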

View File

@@ -1,292 +1,48 @@
from collections import defaultdict
import sqlalchemy as sa
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
from .defer_except import defer_except
from .mock import create_mock_engine, mock_engine
from .render import render_expression, render_statement
from .sort_query import sort_query, QuerySorterException
from .database import database_exists, create_database, drop_database
from .database import (
database_exists,
create_database,
drop_database,
escape_like,
is_auto_assigned_date_column,
is_indexed_foreign_key,
non_indexed_foreign_keys,
)
from .orm import (
primary_keys,
table_name,
declarative_base,
has_changes,
identity,
naturally_equivalent,
remove_property
)
__all__ = (
batch_fetch,
create_mock_engine,
defer_except,
mock_engine,
sort_query,
render_expression,
render_statement,
with_backrefs,
CompositePath,
QuerySorterException,
database_exists,
create_database,
drop_database
drop_database,
escape_like,
is_auto_assigned_date_column,
is_indexed_foreign_key,
non_indexed_foreign_keys,
remove_property,
primary_keys,
table_name,
declarative_base,
has_changes,
identity,
naturally_equivalent,
)
def escape_like(string, escape_char='*'):
"""
Escapes the string parameter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
def remove_property(class_, name):
"""
**Experimental function**
Removes the given property from a declarative class
"""
mapper = class_.mapper
table = class_.__table__
columns = class_.mapper.c
column = columns[name]
del columns._data[name]
del mapper.columns[name]
columns._all_cols.remove(column)
mapper._cols_by_table[table].remove(column)
mapper.class_manager.uninstrument_attribute(name)
del mapper._props[name]
def primary_keys(class_):
"""
Yields all primary key columns of the given declarative class.
"""
for column in class_.__table__.c:
if column.primary_key:
yield column
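# Editorial sketch with a hypothetical User model, not part of the original
# module:
#
#     list(primary_keys(User))  # e.g. [Column('id', Integer, ...)]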
def table_name(obj):
"""
Return the table name of the given target: either a declarative class or a
declarative attribute, in which case the name of the table the attribute is
bound to is returned.
"""
class_ = getattr(obj, 'class_', obj)
try:
return class_.__tablename__
except AttributeError:
pass
try:
return class_.__table__.name
except AttributeError:
pass
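# Editorial sketch, not part of the original module; User is a hypothetical
# declarative model:
#
#     table_name(User)     # 'user'
#     table_name(User.id)  # also 'user', resolved via the attribute's class_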
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non-indexed foreign keys from all tables of the given MetaData.
Very useful for optimizing a PostgreSQL database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not the given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set([
column.name for column in index.columns
])
if index_column_names == set(constraint.columns):
return True
return False
def declarative_base(model):
"""
Returns the declarative base for the given model class.
:param model: SQLAlchemy declarative model
"""
for parent in model.__bases__:
try:
parent.metadata
return declarative_base(parent)
except AttributeError:
pass
return model
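# Editorial sketch, not part of the original module: the function walks up
# the bases until it reaches the class that carries `metadata`, i.e. the
# declarative Base itself.
#
#     from sqlalchemy.ext.declarative import declarative_base as make_base
#
#     Base = make_base()
#
#     class User(Base):
#         __tablename__ = 'user'
#         id = sa.Column(sa.Integer, primary_key=True)
#
#     assert declarative_base(User) is Base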
def is_auto_assigned_date_column(column):
"""
Returns whether or not the given SQLAlchemy Column object's type is an
auto-assigned DateTime or Date.
:param column: SQLAlchemy Column object
"""
return (
(
isinstance(column.type, sa.DateTime) or
isinstance(column.type, sa.Date)
)
and
(
column.default or
column.server_default or
column.onupdate or
column.server_onupdate
)
)
def has_changes(obj, attr):
"""
Simple shortcut function for checking whether the given attribute of the
given declarative model object has changed during the transaction.
::
from sqlalchemy_utils import has_changes
user = User()
has_changes(user, 'name') # False
user.name = u'someone'
has_changes(user, 'name') # True
:param obj: SQLAlchemy declarative model object
:param attr: Name of the attribute
"""
return (
sa.inspect(obj)
.attrs
.get(attr)
.history
.has_changes()
)
def identity(obj):
"""
Return the identity of the given SQLAlchemy declarative model instance as a
tuple. This differs from obj._sa_instance_state.identity in that it always
returns the identity even if the object is still transient (a new object
that has not yet been persisted to the database).
::
from sqlalchemy import inspect
from sqlalchemy_utils import identity
user = User(name=u'John Matrix')
session.add(user)
identity(user) # None
inspect(user).identity # None
session.flush() # User now has an id, but inspect(user).identity is still None
identity(user) # (1,)
inspect(user).identity # None
session.commit()
identity(user) # (1,)
inspect(user).identity # (1, )
.. versionadded: 0.21.0
:param obj: SQLAlchemy declarative model object
"""
id_ = []
for column in sa.inspect(obj.__class__).columns:
if column.primary_key:
id_.append(getattr(obj, column.name))
if all(value is None for value in id_):
return None
else:
return tuple(id_)
def naturally_equivalent(obj, obj2):
"""
Returns whether or not the two given SQLAlchemy declarative instances are
naturally equivalent (all their non-primary-key properties are equivalent).
::
from sqlalchemy_utils import naturally_equivalent
user = User(name=u'someone')
user2 = User(name=u'someone')
user == user2 # False
naturally_equivalent(user, user2) # True
:param obj: SQLAlchemy declarative model object
:param obj2: SQLAlchemy declarative model object to compare with `obj`
"""
for prop in sa.inspect(obj.__class__).iterate_properties:
if not isinstance(prop, sa.orm.ColumnProperty):
continue
if prop.columns[0].primary_key:
continue
if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
return False
return True

View File

@@ -1,10 +1,55 @@
from collections import defaultdict
from sqlalchemy.engine.url import make_url
import sqlalchemy as sa
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
from sqlalchemy.exc import ProgrammingError, OperationalError
import os
from copy import copy
def escape_like(string, escape_char='*'):
"""
Escapes the string parameter used in SQL LIKE expressions
>>> from sqlalchemy_utils import escape_like
>>> query = session.query(User).filter(
... User.name.ilike(escape_like('John'))
... )
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
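# Editorial sketch, not part of the original module: only the wildcard
# characters are escaped, so add your own wildcards around the result and
# pass the matching escape character to the LIKE/ILIKE call (User and
# session are hypothetical).
#
#     term = escape_like('50%_off')   # -> '50*%*_off'
#     session.query(User).filter(
#         User.name.ilike('%' + term + '%', escape='*')
#     )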
def is_auto_assigned_date_column(column):
"""
Returns whether or not the given SQLAlchemy Column object's type is an
auto-assigned DateTime or Date.
:param column: SQLAlchemy Column object
"""
return (
(
isinstance(column.type, sa.DateTime) or
isinstance(column.type, sa.Date)
)
and
(
column.default or
column.server_default or
column.onupdate or
column.server_onupdate
)
)
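# Editorial sketch, not part of the original module:
#
#     created_at = sa.Column(sa.DateTime, default=sa.func.now())
#     updated_at = sa.Column(sa.DateTime)
#     is_auto_assigned_date_column(created_at)  # truthy
#     is_auto_assigned_date_column(updated_at)  # falsy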
def database_exists(url):
"""Check if a database exists.
@@ -137,3 +182,55 @@ def drop_database(url):
else:
text = "DROP DATABASE %s" % database
engine.execute(text)
def non_indexed_foreign_keys(metadata, engine=None):
"""
Finds all non-indexed foreign keys from all tables of the given MetaData.
Very useful for optimizing a PostgreSQL database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
if metadata.bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload=True,
autoload_with=metadata.bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not is_indexed_foreign_key(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def is_indexed_foreign_key(constraint):
"""
Whether or not the given foreign key constraint's columns have been indexed.
:param constraint: ForeignKeyConstraint object to check the indexes
"""
for index in constraint.table.indexes:
index_column_names = set([
column.name for column in index.columns
])
if index_column_names == set(constraint.columns):
return True
return False
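# Editorial sketch, not part of the original module; Base and engine are
# assumed to exist in the calling code:
#
#     missing = non_indexed_foreign_keys(Base.metadata, engine=engine)
#     for table, constraints in missing.items():
#         print(table, [list(fk.columns) for fk in constraints])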

View File

@@ -3,11 +3,18 @@ from sqlalchemy.orm.session import _state_session
from sqlalchemy.orm import attributes, class_mapper
from sqlalchemy.util import set_creation_order
from sqlalchemy import exc as sa_exc
from .functions import table_name
from sqlalchemy_utils.functions import table_name
def class_from_table_name(state, table):
for class_ in state.class_._decl_class_registry.values():
name = table_name(class_)
if name and name == table:
return class_
return None
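# Editorial sketch, not part of the original diff: the kind of model these
# helpers serve stores the target's table name in a discriminator column and
# its primary key in an id column, combined by the module's
# generic_relationship helper (not shown in this hunk). Event, Base and the
# column names are hypothetical:
#
#     class Event(Base):
#         __tablename__ = 'event'
#         id = sa.Column(sa.Integer, primary_key=True)
#         object_type = sa.Column(sa.Unicode(255))
#         object_id = sa.Column(sa.Integer)
#         object = generic_relationship(object_type, object_id)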
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
if self.key in dict_:
return dict_[self.key]
@@ -22,11 +29,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
# Find class for discriminator.
# TODO: Perhaps optimize with some sort of lookup?
discriminator = state.attrs[self.parent_token.discriminator.key].value
target_class = None
for class_ in state.class_._decl_class_registry.values():
name = table_name(class_)
if name and name == discriminator:
target_class = class_
target_class = class_from_table_name(state, discriminator)
if target_class is None:
# Unknown discriminator; return nothing.
@@ -96,13 +99,13 @@ class GenericRelationshipProperty(MapperProperty):
class Comparator(PropComparator):
def __init__(self, prop, parentmapper):
self.prop = prop
self.property = prop
self._parentmapper = parentmapper
def __eq__(self, other):
discriminator = table_name(other)
q = self.prop._discriminator_col == discriminator
q &= self.prop._id_col == other.id
q = self.property._discriminator_col == discriminator
q &= self.property._id_col == other.id
return q
def __ne__(self, other):
@@ -110,7 +113,7 @@ class GenericRelationshipProperty(MapperProperty):
def is_type(self, other):
discriminator = table_name(other)
return self.prop._discriminator_col == discriminator
return self.property._discriminator_col == discriminator
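# Editorial sketch, not part of the original diff: with a hypothetical
# Event.object generic relationship and a persisted `user`, the comparator
# above supports
#
#     session.query(Event).filter(Event.object == user)         # __eq__
#     session.query(Event).filter(Event.object.is_type(User))   # is_type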
def instrument_class(self, mapper):
attributes.register_attribute(

View File

@@ -1,6 +1,6 @@
import sqlalchemy as sa
from sqlalchemy_utils import batch_fetch
from sqlalchemy_utils.functions import CompositePath
from sqlalchemy_utils.batch import CompositePath
from tests import TestCase