Add batch fetch support for generic relationships
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
from .aggregates import aggregated
|
||||
from .batch_fetch import batch_fetch, with_backrefs
|
||||
from .decorators import generates
|
||||
from .eav import MetaValue, MetaType
|
||||
from .exceptions import ImproperlyConfigured
|
||||
from .functions import (
|
||||
batch_fetch,
|
||||
defer_except,
|
||||
escape_like,
|
||||
identity,
|
||||
@@ -15,7 +15,6 @@ from .functions import (
|
||||
mock_engine,
|
||||
sort_query,
|
||||
table_name,
|
||||
with_backrefs,
|
||||
database_exists,
|
||||
create_database,
|
||||
drop_database
|
||||
|
@@ -1,398 +0,0 @@
|
||||
from collections import defaultdict
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm import RelationshipProperty
|
||||
from sqlalchemy.orm.attributes import (
|
||||
set_committed_value, InstrumentedAttribute
|
||||
)
|
||||
from sqlalchemy.orm.session import object_session
|
||||
|
||||
|
||||
class PathException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class with_backrefs(object):
|
||||
"""
|
||||
Marks given attribute path so that whenever its fetched with batch_fetch
|
||||
the backref relations are force set too. Very useful when dealing with
|
||||
certain many-to-many relationship scenarios.
|
||||
"""
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
|
||||
|
||||
class Path(object):
|
||||
"""
|
||||
A class that represents an attribute path.
|
||||
"""
|
||||
def __init__(self, entities, prop, populate_backrefs=False):
|
||||
self.property = prop
|
||||
self.entities = entities
|
||||
self.populate_backrefs = populate_backrefs
|
||||
if not isinstance(self.property, RelationshipProperty):
|
||||
raise PathException(
|
||||
'Given attribute is not a relationship property.'
|
||||
)
|
||||
self.fetcher = self.fetcher_class(self)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return object_session(self.entities[0])
|
||||
|
||||
@property
|
||||
def parent_model(self):
|
||||
return self.entities[0].__class__
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.property.mapper.class_
|
||||
|
||||
@classmethod
|
||||
def parse(cls, entities, path, populate_backrefs=False):
|
||||
if isinstance(path, six.string_types):
|
||||
attrs = path.split('.')
|
||||
|
||||
if len(attrs) > 1:
|
||||
related_entities = []
|
||||
for entity in entities:
|
||||
related_entities.extend(getattr(entity, attrs[0]))
|
||||
|
||||
if not related_entities:
|
||||
return
|
||||
subpath = '.'.join(attrs[1:])
|
||||
return Path.parse(related_entities, subpath, populate_backrefs)
|
||||
else:
|
||||
attr = getattr(
|
||||
entities[0].__class__, attrs[0]
|
||||
)
|
||||
elif isinstance(path, InstrumentedAttribute):
|
||||
attr = path
|
||||
else:
|
||||
raise PathException('Unknown path type.')
|
||||
|
||||
return Path(entities, attr.property, populate_backrefs)
|
||||
|
||||
@property
|
||||
def fetcher_class(self):
|
||||
if self.property.secondary is not None:
|
||||
return ManyToManyFetcher
|
||||
else:
|
||||
if self.property.direction.name == 'MANYTOONE':
|
||||
return ManyToOneFetcher
|
||||
else:
|
||||
return OneToManyFetcher
|
||||
|
||||
|
||||
class CompositePath(object):
|
||||
def __init__(self, *paths):
|
||||
self.paths = paths
|
||||
|
||||
|
||||
def batch_fetch(entities, *attr_paths):
|
||||
"""
|
||||
Batch fetch given relationship attribute for collection of entities.
|
||||
|
||||
This function is in many cases a valid alternative for SQLAlchemy's
|
||||
subqueryload and performs lot better.
|
||||
|
||||
:param entities: list of entities of the same type
|
||||
:param attr_paths:
|
||||
List of either InstrumentedAttribute objects or a strings representing
|
||||
the name of the instrumented attribute
|
||||
|
||||
Example::
|
||||
|
||||
|
||||
from sqlalchemy_utils import batch_fetch
|
||||
|
||||
|
||||
users = session.query(User).limit(20).all()
|
||||
|
||||
batch_fetch(users, User.phonenumbers)
|
||||
|
||||
|
||||
Function also accepts strings as attribute names: ::
|
||||
|
||||
|
||||
users = session.query(User).limit(20).all()
|
||||
|
||||
batch_fetch(users, 'phonenumbers')
|
||||
|
||||
|
||||
Multiple attributes may be provided: ::
|
||||
|
||||
|
||||
clubs = session.query(Club).limit(20).all()
|
||||
|
||||
batch_fetch(
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
'teams.players.user_groups'
|
||||
)
|
||||
|
||||
You can also force populate backrefs: ::
|
||||
|
||||
|
||||
from sqlalchemy_utils import with_backrefs
|
||||
|
||||
|
||||
clubs = session.query(Club).limit(20).all()
|
||||
|
||||
batch_fetch(
|
||||
clubs,
|
||||
'teams',
|
||||
'teams.players',
|
||||
with_backrefs('teams.players.user_groups')
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
if entities:
|
||||
for path in attr_paths:
|
||||
fetcher = fetcher_factory(entities, path)
|
||||
if fetcher:
|
||||
fetcher.fetch()
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
def fetcher_factory(entities, path):
|
||||
populate_backrefs = False
|
||||
if isinstance(path, with_backrefs):
|
||||
path = path.path
|
||||
populate_backrefs = True
|
||||
|
||||
if isinstance(path, CompositePath):
|
||||
fetchers = []
|
||||
for path in path.paths:
|
||||
path = Path.parse(entities, path, populate_backrefs)
|
||||
if path:
|
||||
fetchers.append(
|
||||
path.fetcher
|
||||
)
|
||||
|
||||
return CompositeFetcher(*fetchers)
|
||||
else:
|
||||
path = Path.parse(entities, path, populate_backrefs)
|
||||
if path:
|
||||
return path.fetcher
|
||||
|
||||
|
||||
class CompositeFetcher(object):
|
||||
def __init__(self, *fetchers):
|
||||
if not all(
|
||||
fetchers[0].path.model == fetcher.path.model
|
||||
for fetcher in fetchers
|
||||
):
|
||||
raise PathException(
|
||||
'Each relationship property must have the same class when '
|
||||
'using CompositeFetcher.'
|
||||
)
|
||||
self.fetchers = fetchers
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self.fetchers[0].path.session
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
return self.fetchers[0].path.model
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
return sa.or_(
|
||||
*[fetcher.condition for fetcher in self.fetchers]
|
||||
)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
return self.session.query(self.model).filter(self.condition)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
for fetcher in self.fetchers:
|
||||
if any(
|
||||
getattr(entity, name)
|
||||
for name in fetcher.remote_column_names
|
||||
):
|
||||
fetcher.append_entity(entity)
|
||||
|
||||
def populate(self):
|
||||
for fetcher in self.fetchers:
|
||||
fetcher.populate()
|
||||
|
||||
|
||||
class Fetcher(object):
|
||||
def __init__(self, path):
|
||||
self.path = path
|
||||
self.prop = self.path.property
|
||||
if self.prop.uselist:
|
||||
self.parent_dict = defaultdict(list)
|
||||
else:
|
||||
self.parent_dict = defaultdict(lambda: None)
|
||||
|
||||
@property
|
||||
def local_values_list(self):
|
||||
return [
|
||||
self.local_values(entity)
|
||||
for entity in self.path.entities
|
||||
]
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return self.path.session.query(self.path.model)
|
||||
|
||||
@property
|
||||
def related_entities(self):
|
||||
return self.relation_query_base.filter(self.condition)
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
return [local.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
def parent_key(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in self.remote_column_names
|
||||
)
|
||||
|
||||
def local_values(self, entity):
|
||||
return tuple(
|
||||
getattr(entity, name)
|
||||
for name in self.local_column_names
|
||||
)
|
||||
|
||||
def populate_backrefs(self, related_entities):
|
||||
"""
|
||||
Populates backrefs for given related entities.
|
||||
"""
|
||||
backref_dict = dict(
|
||||
(self.local_values(value[0]), [])
|
||||
for value in related_entities
|
||||
)
|
||||
for value in related_entities:
|
||||
backref_dict[self.local_values(value[0])].append(
|
||||
self.path.session.query(self.path.parent_model).get(
|
||||
tuple(value[1:])
|
||||
)
|
||||
)
|
||||
for value in related_entities:
|
||||
set_committed_value(
|
||||
value[0],
|
||||
self.prop.back_populates,
|
||||
backref_dict[self.local_values(value[0])]
|
||||
)
|
||||
|
||||
def populate(self):
|
||||
"""
|
||||
Populate batch fetched entities to parent objects.
|
||||
"""
|
||||
for entity in self.path.entities:
|
||||
set_committed_value(
|
||||
entity,
|
||||
self.prop.key,
|
||||
self.parent_dict[self.local_values(entity)]
|
||||
)
|
||||
|
||||
if self.path.populate_backrefs:
|
||||
self.populate_backrefs(self.related_entities)
|
||||
|
||||
@property
|
||||
def remote(self):
|
||||
return self.path.model
|
||||
|
||||
@property
|
||||
def condition(self):
|
||||
names = self.remote_column_names
|
||||
if len(names) == 1:
|
||||
return getattr(self.remote, names[0]).in_(
|
||||
value[0] for value in self.local_values_list
|
||||
)
|
||||
elif len(names) > 1:
|
||||
conditions = []
|
||||
for entity in self.path.entities:
|
||||
conditions.append(
|
||||
sa.and_(
|
||||
*[
|
||||
getattr(self.remote, remote.name)
|
||||
==
|
||||
getattr(entity, local.name)
|
||||
for local, remote in self.prop.local_remote_pairs
|
||||
if remote in self.remote_column_names
|
||||
]
|
||||
)
|
||||
)
|
||||
return sa.or_(*conditions)
|
||||
else:
|
||||
raise PathException(
|
||||
'Could not obtain remote column names.'
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for entity in self.related_entities:
|
||||
self.append_entity(entity)
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
return [remote.name for local, remote in self.prop.local_remote_pairs]
|
||||
|
||||
|
||||
class ManyToManyFetcher(Fetcher):
|
||||
@property
|
||||
def remote(self):
|
||||
return self.prop.secondary.c
|
||||
|
||||
@property
|
||||
def local_column_names(self):
|
||||
names = []
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
names.append(local.name)
|
||||
return names
|
||||
|
||||
@property
|
||||
def remote_column_names(self):
|
||||
names = []
|
||||
for local, remote in self.prop.local_remote_pairs:
|
||||
for fk in remote.foreign_keys:
|
||||
if fk.column.table in self.prop.parent.tables:
|
||||
names.append(remote.name)
|
||||
return names
|
||||
|
||||
@property
|
||||
def relation_query_base(self):
|
||||
return (
|
||||
self.path.session
|
||||
.query(
|
||||
self.path.model,
|
||||
*[
|
||||
getattr(self.prop.secondary.c, name)
|
||||
for name in self.remote_column_names
|
||||
]
|
||||
)
|
||||
.join(
|
||||
self.prop.secondary, self.prop.secondaryjoin
|
||||
)
|
||||
)
|
||||
|
||||
def fetch(self):
|
||||
for value in self.related_entities:
|
||||
self.parent_dict[tuple(value[1:])].append(
|
||||
value[0]
|
||||
)
|
||||
|
||||
|
||||
class ManyToOneFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
||||
self.parent_dict[self.parent_key(entity)] = entity
|
||||
|
||||
|
||||
class OneToManyFetcher(Fetcher):
|
||||
def append_entity(self, entity):
|
||||
#print 'appending entity ', entity, ' to key ', self.parent_key(entity)
|
||||
self.parent_dict[self.parent_key(entity)].append(
|
||||
entity
|
||||
)
|
@@ -1,292 +1,48 @@
|
||||
from collections import defaultdict
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
from .batch_fetch import batch_fetch, with_backrefs, CompositePath
|
||||
from .defer_except import defer_except
|
||||
from .mock import create_mock_engine, mock_engine
|
||||
from .render import render_expression, render_statement
|
||||
from .sort_query import sort_query, QuerySorterException
|
||||
from .database import database_exists, create_database, drop_database
|
||||
|
||||
from .database import (
|
||||
database_exists,
|
||||
create_database,
|
||||
drop_database,
|
||||
escape_like,
|
||||
is_auto_assigned_date_column,
|
||||
is_indexed_foreign_key,
|
||||
non_indexed_foreign_keys,
|
||||
)
|
||||
from .orm import (
|
||||
primary_keys,
|
||||
table_name,
|
||||
declarative_base,
|
||||
has_changes,
|
||||
identity,
|
||||
naturally_equivalent,
|
||||
remove_property
|
||||
)
|
||||
|
||||
__all__ = (
|
||||
batch_fetch,
|
||||
create_mock_engine,
|
||||
defer_except,
|
||||
mock_engine,
|
||||
sort_query,
|
||||
render_expression,
|
||||
render_statement,
|
||||
with_backrefs,
|
||||
CompositePath,
|
||||
QuerySorterException,
|
||||
database_exists,
|
||||
create_database,
|
||||
drop_database
|
||||
)
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
|
||||
"""
|
||||
Escapes the string paremeter used in SQL LIKE expressions
|
||||
|
||||
>>> from sqlalchemy_utils import escape_like
|
||||
>>> query = session.query(User).filter(
|
||||
... User.name.ilike(escape_like('John'))
|
||||
... )
|
||||
|
||||
|
||||
:param string: a string to escape
|
||||
:param escape_char: escape character
|
||||
"""
|
||||
return (
|
||||
string
|
||||
.replace(escape_char, escape_char * 2)
|
||||
.replace('%', escape_char + '%')
|
||||
.replace('_', escape_char + '_')
|
||||
)
|
||||
|
||||
|
||||
def remove_property(class_, name):
|
||||
"""
|
||||
**Experimental function**
|
||||
|
||||
Remove property from declarative class
|
||||
"""
|
||||
mapper = class_.mapper
|
||||
table = class_.__table__
|
||||
columns = class_.mapper.c
|
||||
column = columns[name]
|
||||
del columns._data[name]
|
||||
del mapper.columns[name]
|
||||
columns._all_cols.remove(column)
|
||||
mapper._cols_by_table[table].remove(column)
|
||||
mapper.class_manager.uninstrument_attribute(name)
|
||||
del mapper._props[name]
|
||||
|
||||
|
||||
def primary_keys(class_):
|
||||
"""
|
||||
Returns all primary keys for given declarative class.
|
||||
"""
|
||||
for column in class_.__table__.c:
|
||||
if column.primary_key:
|
||||
yield column
|
||||
|
||||
|
||||
def table_name(obj):
|
||||
"""
|
||||
Return table name of given target, declarative class or the
|
||||
table name where the declarative attribute is bound to.
|
||||
"""
|
||||
class_ = getattr(obj, 'class_', obj)
|
||||
|
||||
try:
|
||||
return class_.__tablename__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return class_.__table__.name
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
|
||||
Very useful for optimizing postgresql database and finding out which
|
||||
foreign keys need indexes.
|
||||
|
||||
:param metadata: MetaData object to inspect tables from
|
||||
"""
|
||||
reflected_metadata = MetaData()
|
||||
|
||||
if metadata.bind is None and engine is None:
|
||||
raise Exception(
|
||||
'Either pass a metadata object with bind or '
|
||||
'pass engine as a second parameter'
|
||||
)
|
||||
|
||||
constraints = defaultdict(list)
|
||||
|
||||
for table_name in metadata.tables.keys():
|
||||
table = Table(
|
||||
drop_database,
|
||||
escape_like,
|
||||
is_auto_assigned_date_column,
|
||||
is_indexed_foreign_key,
|
||||
non_indexed_foreign_keys,
|
||||
remove_property,
|
||||
primary_keys,
|
||||
table_name,
|
||||
reflected_metadata,
|
||||
autoload=True,
|
||||
autoload_with=metadata.bind or engine
|
||||
)
|
||||
|
||||
for constraint in table.constraints:
|
||||
if not isinstance(constraint, ForeignKeyConstraint):
|
||||
continue
|
||||
|
||||
if not is_indexed_foreign_key(constraint):
|
||||
constraints[table.name].append(constraint)
|
||||
|
||||
return dict(constraints)
|
||||
|
||||
|
||||
def is_indexed_foreign_key(constraint):
|
||||
"""
|
||||
Whether or not given foreign key constraint's columns have been indexed.
|
||||
|
||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||
"""
|
||||
for index in constraint.table.indexes:
|
||||
index_column_names = set([
|
||||
column.name for column in index.columns
|
||||
])
|
||||
if index_column_names == set(constraint.columns):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def declarative_base(model):
|
||||
"""
|
||||
Returns the declarative base for given model class.
|
||||
|
||||
:param model: SQLAlchemy declarative model
|
||||
"""
|
||||
for parent in model.__bases__:
|
||||
try:
|
||||
parent.metadata
|
||||
return declarative_base(parent)
|
||||
except AttributeError:
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
def is_auto_assigned_date_column(column):
|
||||
"""
|
||||
Returns whether or not given SQLAlchemy Column object's is auto assigned
|
||||
DateTime or Date.
|
||||
|
||||
:param column: SQLAlchemy Column object
|
||||
"""
|
||||
return (
|
||||
(
|
||||
isinstance(column.type, sa.DateTime) or
|
||||
isinstance(column.type, sa.Date)
|
||||
)
|
||||
and
|
||||
(
|
||||
column.default or
|
||||
column.server_default or
|
||||
column.onupdate or
|
||||
column.server_onupdate
|
||||
)
|
||||
declarative_base,
|
||||
has_changes,
|
||||
identity,
|
||||
naturally_equivalent,
|
||||
)
|
||||
|
||||
|
||||
def has_changes(obj, attr):
|
||||
"""
|
||||
Simple shortcut function for checking if given attribute of given
|
||||
declarative model object has changed during the transaction.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import has_changes
|
||||
|
||||
|
||||
user = User()
|
||||
|
||||
has_changes(user, 'name') # False
|
||||
|
||||
user.name = u'someone'
|
||||
|
||||
has_changes(user, 'name') # True
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param attr: Name of the attribute
|
||||
"""
|
||||
return (
|
||||
sa.inspect(obj)
|
||||
.attrs
|
||||
.get(attr)
|
||||
.history
|
||||
.has_changes()
|
||||
)
|
||||
|
||||
|
||||
def identity(obj):
|
||||
"""
|
||||
Return the identity of given sqlalchemy declarative model instance as a
|
||||
tuple. This differs from obj._sa_instance_state.identity in a way that it
|
||||
always returns the identity even if object is still in transient state (
|
||||
new object that is not yet persisted into database).
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy_utils import identity
|
||||
|
||||
|
||||
user = User(name=u'John Matrix')
|
||||
session.add(user)
|
||||
identity(user) # None
|
||||
inspect(user).identity # None
|
||||
|
||||
session.flush() # User now has id but is still in transient state
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # None
|
||||
|
||||
session.commit()
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # (1, )
|
||||
|
||||
|
||||
.. versionadded: 0.21.0
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
"""
|
||||
id_ = []
|
||||
for column in sa.inspect(obj.__class__).columns:
|
||||
if column.primary_key:
|
||||
id_.append(getattr(obj, column.name))
|
||||
|
||||
if all(value is None for value in id_):
|
||||
return None
|
||||
else:
|
||||
return tuple(id_)
|
||||
|
||||
|
||||
def naturally_equivalent(obj, obj2):
|
||||
"""
|
||||
Returns whether or not two given SQLAlchemy declarative instances are
|
||||
naturally equivalent (all their non primary key properties are equivalent).
|
||||
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import naturally_equivalent
|
||||
|
||||
|
||||
user = User(name=u'someone')
|
||||
user2 = User(name=u'someone')
|
||||
|
||||
user == user2 # False
|
||||
|
||||
naturally_equivalent(user, user2) # True
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param obj2: SQLAlchemy declarative model object to compare with `obj`
|
||||
"""
|
||||
for prop in sa.inspect(obj.__class__).iterate_properties:
|
||||
if not isinstance(prop, sa.orm.ColumnProperty):
|
||||
continue
|
||||
|
||||
if prop.columns[0].primary_key:
|
||||
continue
|
||||
|
||||
if not (getattr(obj, prop.key) == getattr(obj2, prop.key)):
|
||||
return False
|
||||
return True
|
||||
|
@@ -1,10 +1,55 @@
|
||||
from collections import defaultdict
|
||||
from sqlalchemy.engine.url import make_url
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
from sqlalchemy.exc import ProgrammingError, OperationalError
|
||||
import os
|
||||
from copy import copy
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
|
||||
"""
|
||||
Escapes the string paremeter used in SQL LIKE expressions
|
||||
|
||||
>>> from sqlalchemy_utils import escape_like
|
||||
>>> query = session.query(User).filter(
|
||||
... User.name.ilike(escape_like('John'))
|
||||
... )
|
||||
|
||||
|
||||
:param string: a string to escape
|
||||
:param escape_char: escape character
|
||||
"""
|
||||
return (
|
||||
string
|
||||
.replace(escape_char, escape_char * 2)
|
||||
.replace('%', escape_char + '%')
|
||||
.replace('_', escape_char + '_')
|
||||
)
|
||||
|
||||
|
||||
def is_auto_assigned_date_column(column):
|
||||
"""
|
||||
Returns whether or not given SQLAlchemy Column object's is auto assigned
|
||||
DateTime or Date.
|
||||
|
||||
:param column: SQLAlchemy Column object
|
||||
"""
|
||||
return (
|
||||
(
|
||||
isinstance(column.type, sa.DateTime) or
|
||||
isinstance(column.type, sa.Date)
|
||||
)
|
||||
and
|
||||
(
|
||||
column.default or
|
||||
column.server_default or
|
||||
column.onupdate or
|
||||
column.server_onupdate
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def database_exists(url):
|
||||
"""Check if a database exists.
|
||||
|
||||
@@ -137,3 +182,55 @@ def drop_database(url):
|
||||
else:
|
||||
text = "DROP DATABASE %s" % database
|
||||
engine.execute(text)
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
|
||||
Very useful for optimizing postgresql database and finding out which
|
||||
foreign keys need indexes.
|
||||
|
||||
:param metadata: MetaData object to inspect tables from
|
||||
"""
|
||||
reflected_metadata = MetaData()
|
||||
|
||||
if metadata.bind is None and engine is None:
|
||||
raise Exception(
|
||||
'Either pass a metadata object with bind or '
|
||||
'pass engine as a second parameter'
|
||||
)
|
||||
|
||||
constraints = defaultdict(list)
|
||||
|
||||
for table_name in metadata.tables.keys():
|
||||
table = Table(
|
||||
table_name,
|
||||
reflected_metadata,
|
||||
autoload=True,
|
||||
autoload_with=metadata.bind or engine
|
||||
)
|
||||
|
||||
for constraint in table.constraints:
|
||||
if not isinstance(constraint, ForeignKeyConstraint):
|
||||
continue
|
||||
|
||||
if not is_indexed_foreign_key(constraint):
|
||||
constraints[table.name].append(constraint)
|
||||
|
||||
return dict(constraints)
|
||||
|
||||
|
||||
def is_indexed_foreign_key(constraint):
|
||||
"""
|
||||
Whether or not given foreign key constraint's columns have been indexed.
|
||||
|
||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||
"""
|
||||
for index in constraint.table.indexes:
|
||||
index_column_names = set([
|
||||
column.name for column in index.columns
|
||||
])
|
||||
if index_column_names == set(constraint.columns):
|
||||
return True
|
||||
return False
|
||||
|
@@ -3,11 +3,18 @@ from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.orm import attributes, class_mapper
|
||||
from sqlalchemy.util import set_creation_order
|
||||
from sqlalchemy import exc as sa_exc
|
||||
from .functions import table_name
|
||||
from sqlalchemy_utils.functions import table_name
|
||||
|
||||
|
||||
def class_from_table_name(state, table):
|
||||
for class_ in state.class_._decl_class_registry.values():
|
||||
name = table_name(class_)
|
||||
if name and name == table:
|
||||
return class_
|
||||
return None
|
||||
|
||||
|
||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
|
||||
def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
|
||||
if self.key in dict_:
|
||||
return dict_[self.key]
|
||||
@@ -22,11 +29,7 @@ class GenericAttributeImpl(attributes.ScalarAttributeImpl):
|
||||
# Find class for discriminator.
|
||||
# TODO: Perhaps optimize with some sort of lookup?
|
||||
discriminator = state.attrs[self.parent_token.discriminator.key].value
|
||||
target_class = None
|
||||
for class_ in state.class_._decl_class_registry.values():
|
||||
name = table_name(class_)
|
||||
if name and name == discriminator:
|
||||
target_class = class_
|
||||
target_class = class_from_table_name(state, discriminator)
|
||||
|
||||
if target_class is None:
|
||||
# Unknown discriminator; return nothing.
|
||||
@@ -96,13 +99,13 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
class Comparator(PropComparator):
|
||||
|
||||
def __init__(self, prop, parentmapper):
|
||||
self.prop = prop
|
||||
self.property = prop
|
||||
self._parentmapper = parentmapper
|
||||
|
||||
def __eq__(self, other):
|
||||
discriminator = table_name(other)
|
||||
q = self.prop._discriminator_col == discriminator
|
||||
q &= self.prop._id_col == other.id
|
||||
q = self.property._discriminator_col == discriminator
|
||||
q &= self.property._id_col == other.id
|
||||
return q
|
||||
|
||||
def __ne__(self, other):
|
||||
@@ -110,7 +113,7 @@ class GenericRelationshipProperty(MapperProperty):
|
||||
|
||||
def is_type(self, other):
|
||||
discriminator = table_name(other)
|
||||
return self.prop._discriminator_col == discriminator
|
||||
return self.property._discriminator_col == discriminator
|
||||
|
||||
def instrument_class(self, mapper):
|
||||
attributes.register_attribute(
|
||||
|
@@ -1,6 +1,6 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import batch_fetch
|
||||
from sqlalchemy_utils.functions import CompositePath
|
||||
from sqlalchemy_utils.batch import CompositePath
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user