Refactor docs, add merge_references
This commit is contained in:
@@ -4,12 +4,13 @@ Changelog
|
||||
Here you can see the full list of changes between each SQLAlchemy-Utils release.
|
||||
|
||||
|
||||
0.26.1 (2014-05-xx)
|
||||
0.26.1 (2014-05-14)
|
||||
^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
- Added get_bind
|
||||
- Added group_foreign_keys
|
||||
- Added get_mapper
|
||||
- Added merge_references
|
||||
|
||||
|
||||
0.26.0 (2014-05-07)
|
||||
|
@@ -5,18 +5,6 @@ Database helpers
|
||||
.. module:: sqlalchemy_utils.functions
|
||||
|
||||
|
||||
is_indexed_foreign_key
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: is_indexed_foreign_key
|
||||
|
||||
|
||||
non_indexed_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: non_indexed_foreign_keys
|
||||
|
||||
|
||||
database_exists
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
|
40
docs/foreign_key_helpers.rst
Normal file
40
docs/foreign_key_helpers.rst
Normal file
@@ -0,0 +1,40 @@
|
||||
Foreign key helpers
|
||||
===================
|
||||
|
||||
.. module:: sqlalchemy_utils.functions
|
||||
|
||||
|
||||
dependent_objects
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: dependent_objects
|
||||
|
||||
|
||||
get_referencing_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_referencing_foreign_keys
|
||||
|
||||
|
||||
group_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: group_foreign_keys
|
||||
|
||||
|
||||
is_indexed_foreign_key
|
||||
^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: is_indexed_foreign_key
|
||||
|
||||
|
||||
merge_references
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: merge_references
|
||||
|
||||
|
||||
non_indexed_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: non_indexed_foreign_keys
|
@@ -1,5 +1,5 @@
|
||||
Generic relationship
|
||||
====================
|
||||
Generic relationships
|
||||
=====================
|
||||
|
||||
Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model.
|
||||
|
||||
|
@@ -16,6 +16,7 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ
|
||||
decorators
|
||||
generic_relationship
|
||||
database_helpers
|
||||
model_helpers
|
||||
foreign_key_helpers
|
||||
orm_helpers
|
||||
utility_classes
|
||||
license
|
||||
|
@@ -1,15 +1,9 @@
|
||||
Model helpers
|
||||
=============
|
||||
ORM helpers
|
||||
===========
|
||||
|
||||
.. module:: sqlalchemy_utils.functions
|
||||
|
||||
|
||||
dependent_objects
|
||||
^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: dependent_objects
|
||||
|
||||
|
||||
escape_like
|
||||
^^^^^^^^^^^
|
||||
|
||||
@@ -46,24 +40,12 @@ get_primary_keys
|
||||
.. autofunction:: get_primary_keys
|
||||
|
||||
|
||||
get_referencing_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_referencing_foreign_keys
|
||||
|
||||
|
||||
get_tables
|
||||
^^^^^^^^^^
|
||||
|
||||
.. autofunction:: get_tables
|
||||
|
||||
|
||||
group_foreign_keys
|
||||
^^^^^^^^^^^^^^^^^^
|
||||
|
||||
.. autofunction:: group_foreign_keys
|
||||
|
||||
|
||||
query_entities
|
||||
^^^^^^^^^^^^^^
|
||||
|
2
setup.py
2
setup.py
@@ -44,7 +44,7 @@ for name, requirements in extras_require.items():
|
||||
|
||||
setup(
|
||||
name='SQLAlchemy-Utils',
|
||||
version='0.26.0',
|
||||
version='0.26.1',
|
||||
url='https://github.com/kvesteri/sqlalchemy-utils',
|
||||
license='BSD',
|
||||
author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen',
|
||||
|
@@ -20,6 +20,7 @@ from .functions import (
|
||||
get_tables,
|
||||
group_foreign_keys,
|
||||
identity,
|
||||
merge_references,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
render_expression,
|
||||
@@ -32,7 +33,6 @@ from .listeners import (
|
||||
force_auto_coercion,
|
||||
force_instant_defaults
|
||||
)
|
||||
from .merge import merge, Merger
|
||||
from .generic import generic_relationship
|
||||
from .proxy_dict import ProxyDict, proxy_dict
|
||||
from .query_chain import QueryChain
|
||||
@@ -67,7 +67,7 @@ from .types import (
|
||||
)
|
||||
|
||||
|
||||
__version__ = '0.26.0'
|
||||
__version__ = '0.26.1'
|
||||
|
||||
|
||||
__all__ = (
|
||||
@@ -95,7 +95,7 @@ __all__ = (
|
||||
group_foreign_keys,
|
||||
identity,
|
||||
instrumented_list,
|
||||
merge,
|
||||
merge_references,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
proxy_dict,
|
||||
@@ -120,7 +120,6 @@ __all__ = (
|
||||
IPAddressType,
|
||||
JSONType,
|
||||
LocaleType,
|
||||
Merger,
|
||||
NumericRangeType,
|
||||
Password,
|
||||
PasswordType,
|
||||
|
@@ -8,20 +8,23 @@ from .database import (
|
||||
drop_database,
|
||||
escape_like,
|
||||
is_auto_assigned_date_column,
|
||||
)
|
||||
from .foreign_keys import (
|
||||
dependent_objects,
|
||||
get_referencing_foreign_keys,
|
||||
group_foreign_keys,
|
||||
is_indexed_foreign_key,
|
||||
merge_references,
|
||||
non_indexed_foreign_keys,
|
||||
)
|
||||
from .orm import (
|
||||
dependent_objects,
|
||||
get_bind,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
get_mapper,
|
||||
get_primary_keys,
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
getdotattr,
|
||||
group_foreign_keys,
|
||||
has_changes,
|
||||
identity,
|
||||
naturally_equivalent,
|
||||
|
@@ -1,4 +1,3 @@
|
||||
from collections import defaultdict
|
||||
from sqlalchemy.engine.url import make_url
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
@@ -186,55 +185,3 @@ def drop_database(url):
|
||||
else:
|
||||
text = "DROP DATABASE %s" % database
|
||||
engine.execute(text)
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
|
||||
Very useful for optimizing postgresql database and finding out which
|
||||
foreign keys need indexes.
|
||||
|
||||
:param metadata: MetaData object to inspect tables from
|
||||
"""
|
||||
reflected_metadata = MetaData()
|
||||
|
||||
if metadata.bind is None and engine is None:
|
||||
raise Exception(
|
||||
'Either pass a metadata object with bind or '
|
||||
'pass engine as a second parameter'
|
||||
)
|
||||
|
||||
constraints = defaultdict(list)
|
||||
|
||||
for table_name in metadata.tables.keys():
|
||||
table = Table(
|
||||
table_name,
|
||||
reflected_metadata,
|
||||
autoload=True,
|
||||
autoload_with=metadata.bind or engine
|
||||
)
|
||||
|
||||
for constraint in table.constraints:
|
||||
if not isinstance(constraint, ForeignKeyConstraint):
|
||||
continue
|
||||
|
||||
if not is_indexed_foreign_key(constraint):
|
||||
constraints[table.name].append(constraint)
|
||||
|
||||
return dict(constraints)
|
||||
|
||||
|
||||
def is_indexed_foreign_key(constraint):
|
||||
"""
|
||||
Whether or not given foreign key constraint's columns have been indexed.
|
||||
|
||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||
"""
|
||||
for index in constraint.table.indexes:
|
||||
index_column_names = set(
|
||||
column.name for column in index.columns
|
||||
)
|
||||
if index_column_names == set(constraint.columns):
|
||||
return True
|
||||
return False
|
||||
|
350
sqlalchemy_utils/functions/foreign_keys.py
Normal file
350
sqlalchemy_utils/functions/foreign_keys.py
Normal file
@@ -0,0 +1,350 @@
|
||||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.orm import object_session, mapperlib
|
||||
from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint
|
||||
|
||||
from .orm import get_mapper, get_tables
|
||||
from ..query_chain import QueryChain
|
||||
|
||||
|
||||
def get_foreign_key_values(fk, obj):
|
||||
return {
|
||||
fk.constraint.columns[index].key:
|
||||
getattr(obj, element.column.key)
|
||||
for
|
||||
index, element
|
||||
in
|
||||
enumerate(fk.constraint.elements)
|
||||
}
|
||||
|
||||
|
||||
def group_foreign_keys(foreign_keys):
|
||||
"""
|
||||
Return a groupby iterator that groups given foreign keys by table.
|
||||
|
||||
:param foreign_keys: a sequence of foreign keys
|
||||
|
||||
|
||||
::
|
||||
|
||||
foreign_keys = get_referencing_foreign_keys(User)
|
||||
|
||||
for table, fks in group_foreign_keys(foreign_keys):
|
||||
# do something
|
||||
pass
|
||||
|
||||
|
||||
.. seealso:: :func:`get_referencing_foreign_keys`
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
"""
|
||||
foreign_keys = sorted(
|
||||
foreign_keys, key=lambda key: key.constraint.table.name
|
||||
)
|
||||
return groupby(foreign_keys, lambda key: key.constraint.table)
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
|
||||
"""
|
||||
Returns referencing foreign keys for given Table object or declarative
|
||||
class.
|
||||
|
||||
:param mixed:
|
||||
SA Table object or SA declarative class
|
||||
|
||||
::
|
||||
|
||||
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
|
||||
|
||||
get_referencing_foreign_keys(User.__table__)
|
||||
|
||||
|
||||
This function also understands inheritance. This means it returns
|
||||
all foreign keys that reference any table in the class inheritance tree.
|
||||
|
||||
Let's say you have three classes which use joined table inheritance,
|
||||
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
|
||||
TextItem.
|
||||
|
||||
::
|
||||
|
||||
# This will check all foreign keys that reference either article table
|
||||
# or textitem table.
|
||||
get_referencing_foreign_keys(Article)
|
||||
|
||||
.. seealso:: :func:`get_tables`
|
||||
"""
|
||||
if isinstance(mixed, sa.Table):
|
||||
tables = [mixed]
|
||||
else:
|
||||
tables = get_tables(mixed)
|
||||
|
||||
referencing_foreign_keys = set()
|
||||
|
||||
for table in mixed.metadata.tables.values():
|
||||
if table not in tables:
|
||||
for constraint in table.constraints:
|
||||
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
|
||||
for fk in constraint.elements:
|
||||
if any(fk.references(t) for t in tables):
|
||||
referencing_foreign_keys.add(fk)
|
||||
return referencing_foreign_keys
|
||||
|
||||
|
||||
def merge_references(from_, to, foreign_keys=None):
|
||||
"""
|
||||
Merge the references of an entity into another entity.
|
||||
|
||||
Consider the following models::
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(name=%r)' % self.name
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.Unicode(255))
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
|
||||
Now lets add some data::
|
||||
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add_all([
|
||||
john,
|
||||
jack,
|
||||
post,
|
||||
post2
|
||||
])
|
||||
self.session.commit()
|
||||
|
||||
|
||||
If we wanted to merge all John's references to Jack it would be as easy as
|
||||
::
|
||||
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
|
||||
post.author # User(name='Jack')
|
||||
post2.author # User(name='Jack')
|
||||
|
||||
|
||||
|
||||
:param from_: an entity to merge into another entity
|
||||
:param to: an entity to merge another entity into
|
||||
:param foreign_keys: A sequence of foreign keys. By default this is None
|
||||
indicating all referencing foreign keys should be used.
|
||||
|
||||
.. seealso: :func:`dependent_objects`
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise TypeError('The tables of given arguments do not match.')
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = get_referencing_foreign_keys(from_)
|
||||
|
||||
for fk in foreign_keys:
|
||||
old_values = get_foreign_key_values(fk, from_)
|
||||
new_values = get_foreign_key_values(fk, to)
|
||||
criteria = (
|
||||
getattr(fk.constraint.table.c, key) == value
|
||||
for key, value in six.iteritems(old_values)
|
||||
)
|
||||
try:
|
||||
mapper = get_mapper(fk.constraint.table)
|
||||
except ValueError:
|
||||
query = (
|
||||
fk.constraint.table
|
||||
.update()
|
||||
.where(sa.and_(*criteria))
|
||||
.values(new_values)
|
||||
)
|
||||
session.execute(query)
|
||||
else:
|
||||
print old_values, new_values
|
||||
(
|
||||
session.query(mapper.class_)
|
||||
.filter_by(**old_values)
|
||||
.update(
|
||||
new_values,
|
||||
'evaluate'
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def dependent_objects(obj, foreign_keys=None):
|
||||
"""
|
||||
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
|
||||
through all dependent objects for given SQLAlchemy object.
|
||||
|
||||
Consider a User object is referenced in various articles and also in
|
||||
various orders. Getting all these dependent objects is as easy as:
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import dependent_objects
|
||||
|
||||
|
||||
dependent_objects(user)
|
||||
|
||||
|
||||
If you expect an object to have lots of dependent_objects it might be good
|
||||
to limit the results::
|
||||
|
||||
|
||||
dependent_objects(user).limit(5)
|
||||
|
||||
|
||||
|
||||
The common use case is checking for all restrict dependent objects before
|
||||
deleting parent object and inform the user if there are dependent objects
|
||||
with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
|
||||
it will lead to nasty IntegrityErrors being raised.
|
||||
|
||||
In the following example we delete given user if it doesn't have any
|
||||
foreign key restricted dependent objects.
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import get_referencing_foreign_keys
|
||||
|
||||
|
||||
user = session.query(User).get(some_user_id)
|
||||
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(User)
|
||||
# On most databases RESTRICT is the default mode hence we
|
||||
# check for None values also
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
)
|
||||
|
||||
if deps:
|
||||
# Do something to inform the user
|
||||
pass
|
||||
else:
|
||||
session.delete(user)
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param foreign_keys:
|
||||
A sequence of foreign keys to use for searching the dependent_objects
|
||||
for given object. By default this is None, indicating that all foreign
|
||||
keys referencing the object will be used.
|
||||
|
||||
.. note::
|
||||
This function does not support exotic mappers that use multiple tables
|
||||
|
||||
.. seealso:: :func:`get_referencing_foreign_keys`
|
||||
.. seealso:: :func:`merge_references`
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
"""
|
||||
if foreign_keys is None:
|
||||
foreign_keys = get_referencing_foreign_keys(obj)
|
||||
|
||||
session = object_session(obj)
|
||||
|
||||
chain = QueryChain([])
|
||||
classes = obj.__class__._decl_class_registry
|
||||
|
||||
for table, keys in group_foreign_keys(foreign_keys):
|
||||
for class_ in classes.values():
|
||||
if hasattr(class_, '__table__') and class_.__table__ == table:
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint not in visited_constraints:
|
||||
visited_constraints.append(key.constraint)
|
||||
subcriteria = [
|
||||
getattr(class_, column.key) ==
|
||||
getattr(
|
||||
obj,
|
||||
key.constraint.elements[index].column.key
|
||||
)
|
||||
for index, column
|
||||
in enumerate(key.constraint.columns)
|
||||
]
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
|
||||
query = session.query(class_).filter(
|
||||
sa.or_(
|
||||
*criteria
|
||||
)
|
||||
)
|
||||
chain.queries.append(query)
|
||||
return chain
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
|
||||
Very useful for optimizing postgresql database and finding out which
|
||||
foreign keys need indexes.
|
||||
|
||||
:param metadata: MetaData object to inspect tables from
|
||||
"""
|
||||
reflected_metadata = MetaData()
|
||||
|
||||
if metadata.bind is None and engine is None:
|
||||
raise Exception(
|
||||
'Either pass a metadata object with bind or '
|
||||
'pass engine as a second parameter'
|
||||
)
|
||||
|
||||
constraints = defaultdict(list)
|
||||
|
||||
for table_name in metadata.tables.keys():
|
||||
table = Table(
|
||||
table_name,
|
||||
reflected_metadata,
|
||||
autoload=True,
|
||||
autoload_with=metadata.bind or engine
|
||||
)
|
||||
|
||||
for constraint in table.constraints:
|
||||
if not isinstance(constraint, ForeignKeyConstraint):
|
||||
continue
|
||||
|
||||
if not is_indexed_foreign_key(constraint):
|
||||
constraints[table.name].append(constraint)
|
||||
|
||||
return dict(constraints)
|
||||
|
||||
|
||||
def is_indexed_foreign_key(constraint):
|
||||
"""
|
||||
Whether or not given foreign key constraint's columns have been indexed.
|
||||
|
||||
:param constraint: ForeignKeyConstraint object to check the indexes
|
||||
"""
|
||||
for index in constraint.table.indexes:
|
||||
index_column_names = set(
|
||||
column.name for column in index.columns
|
||||
)
|
||||
if index_column_names == set(constraint.columns):
|
||||
return True
|
||||
return False
|
@@ -61,8 +61,12 @@ def get_mapper(mixed):
|
||||
]
|
||||
if len(mappers) > 1:
|
||||
raise ValueError(
|
||||
"Could not get mapper for '%r'. Multiple mappers found."
|
||||
% mixed
|
||||
"Multiple mappers found for table '%s'."
|
||||
% mixed.name
|
||||
)
|
||||
elif not mappers:
|
||||
raise ValueError(
|
||||
"Could not get mapper for table '%s'."
|
||||
)
|
||||
else:
|
||||
return mappers[0]
|
||||
@@ -104,188 +108,6 @@ def get_bind(obj):
|
||||
return conn
|
||||
|
||||
|
||||
def dependent_objects(obj, foreign_keys=None):
|
||||
"""
|
||||
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
|
||||
through all dependent objects for given SQLAlchemy object.
|
||||
|
||||
Consider a User object is referenced in various articles and also in
|
||||
various orders. Getting all these dependent objects is as easy as:
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import dependent_objects
|
||||
|
||||
|
||||
dependent_objects(user)
|
||||
|
||||
|
||||
If you expect an object to have lots of dependent_objects it might be good
|
||||
to limit the results::
|
||||
|
||||
|
||||
dependent_objects(user).limit(5)
|
||||
|
||||
|
||||
|
||||
The common use case is checking for all restrict dependent objects before
|
||||
deleting parent object and inform the user if there are dependent objects
|
||||
with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
|
||||
it will lead to nasty IntegrityErrors being raised.
|
||||
|
||||
In the following example we delete given user if it doesn't have any
|
||||
foreign key restricted dependent objects.
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import get_referencing_foreign_keys
|
||||
|
||||
|
||||
user = session.query(User).get(some_user_id)
|
||||
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(User)
|
||||
# On most databases RESTRICT is the default mode hence we
|
||||
# check for None values also
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
)
|
||||
|
||||
if deps:
|
||||
# Do something to inform the user
|
||||
pass
|
||||
else:
|
||||
session.delete(user)
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param foreign_keys:
|
||||
A sequence of foreign keys to use for searching the dependent_objects
|
||||
for given object. By default this is None, indicating that all foreign
|
||||
keys referencing the object will be used.
|
||||
|
||||
.. note::
|
||||
This function does not support exotic mappers that use multiple tables
|
||||
|
||||
.. seealso:: :func:`get_referencing_foreign_keys`
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
"""
|
||||
if foreign_keys is None:
|
||||
foreign_keys = get_referencing_foreign_keys(obj)
|
||||
|
||||
session = object_session(obj)
|
||||
|
||||
chain = QueryChain([])
|
||||
classes = obj.__class__._decl_class_registry
|
||||
|
||||
for table, keys in group_foreign_keys(foreign_keys):
|
||||
for class_ in classes.values():
|
||||
if hasattr(class_, '__table__') and class_.__table__ == table:
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint not in visited_constraints:
|
||||
visited_constraints.append(key.constraint)
|
||||
subcriteria = [
|
||||
getattr(class_, column.key) ==
|
||||
getattr(
|
||||
obj,
|
||||
key.constraint.elements[index].column.key
|
||||
)
|
||||
for index, column
|
||||
in enumerate(key.constraint.columns)
|
||||
]
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
|
||||
query = session.query(class_).filter(
|
||||
sa.or_(
|
||||
*criteria
|
||||
)
|
||||
)
|
||||
chain.queries.append(query)
|
||||
return chain
|
||||
|
||||
|
||||
def group_foreign_keys(foreign_keys):
|
||||
"""
|
||||
Return a groupby iterator that groups given foreign keys by table.
|
||||
|
||||
:param foreign_keys: a sequence of foreign keys
|
||||
|
||||
|
||||
::
|
||||
|
||||
foreign_keys = get_referencing_foreign_keys(User)
|
||||
|
||||
for table, fks in group_foreign_keys(foreign_keys):
|
||||
# do something
|
||||
pass
|
||||
|
||||
|
||||
.. also:: :func:`get_referencing_foreign_keys`
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
"""
|
||||
foreign_keys = sorted(
|
||||
foreign_keys, key=lambda key: key.constraint.table.name
|
||||
)
|
||||
return groupby(foreign_keys, lambda key: key.constraint.table)
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
|
||||
"""
|
||||
Returns referencing foreign keys for given Table object or declarative
|
||||
class.
|
||||
|
||||
:param mixed:
|
||||
SA Table object or SA declarative class
|
||||
|
||||
::
|
||||
|
||||
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
|
||||
|
||||
get_referencing_foreign_keys(User.__table__)
|
||||
|
||||
|
||||
This function also understands inheritance. This means it returns
|
||||
all foreign keys that reference any table in the class inheritance tree.
|
||||
|
||||
Let's say you have three classes which use joined table inheritance,
|
||||
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
|
||||
TextItem.
|
||||
|
||||
::
|
||||
|
||||
# This will check all foreign keys that reference either article table
|
||||
# or textitem table.
|
||||
get_referencing_foreign_keys(Article)
|
||||
|
||||
.. seealso:: :func:`get_tables`
|
||||
"""
|
||||
if isinstance(mixed, sa.Table):
|
||||
tables = [mixed]
|
||||
else:
|
||||
tables = get_tables(mixed)
|
||||
|
||||
referencing_foreign_keys = set()
|
||||
|
||||
for table in mixed.metadata.tables.values():
|
||||
if table not in tables:
|
||||
for constraint in table.constraints:
|
||||
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
|
||||
for fk in constraint.elements:
|
||||
if any(fk.references(t) for t in tables):
|
||||
referencing_foreign_keys.add(fk)
|
||||
return referencing_foreign_keys
|
||||
|
||||
|
||||
def get_primary_keys(mixed):
|
||||
"""
|
||||
Return an OrderedDict of all primary keys for given Table object,
|
||||
|
@@ -1,124 +0,0 @@
|
||||
import six
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine import reflection
|
||||
from sqlalchemy.orm import object_session, mapperlib
|
||||
|
||||
|
||||
def dependent_foreign_keys(model_class):
|
||||
"""
|
||||
Returns dependent foreign keys as dicts for given model class.
|
||||
|
||||
** Experimental function **
|
||||
"""
|
||||
session = object_session(model_class)
|
||||
|
||||
engine = session.bind
|
||||
inspector = reflection.Inspector.from_engine(engine)
|
||||
table_names = inspector.get_table_names()
|
||||
|
||||
dependent_foreign_keys = {}
|
||||
|
||||
for table_name in table_names:
|
||||
fks = inspector.get_foreign_keys(table_name)
|
||||
if fks:
|
||||
dependent_foreign_keys[table_name] = []
|
||||
for fk in fks:
|
||||
if fk['referred_table'] == model_class.__tablename__:
|
||||
dependent_foreign_keys[table_name].append(fk)
|
||||
return dependent_foreign_keys
|
||||
|
||||
|
||||
class Merger(object):
|
||||
def memory_merge(self, session, table_name, old_values, new_values):
|
||||
# try to fetch mappers for given table and update in memory objects as
|
||||
# well as database table
|
||||
found = False
|
||||
for mapper in mapperlib._mapper_registry:
|
||||
class_ = mapper.class_
|
||||
if table_name == class_.__table__.name:
|
||||
try:
|
||||
(
|
||||
session.query(mapper.class_)
|
||||
.filter_by(**old_values)
|
||||
.update(
|
||||
new_values,
|
||||
'fetch'
|
||||
)
|
||||
)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
found = True
|
||||
return found
|
||||
|
||||
def raw_merge(self, session, table, old_values, new_values):
|
||||
conditions = []
|
||||
for key, value in six.iteritems(old_values):
|
||||
conditions.append(getattr(table.c, key) == value)
|
||||
sql = (
|
||||
table
|
||||
.update()
|
||||
.where(sa.and_(
|
||||
*conditions
|
||||
))
|
||||
.values(
|
||||
new_values
|
||||
)
|
||||
)
|
||||
try:
|
||||
session.execute(sql)
|
||||
except sa.exc.IntegrityError:
|
||||
pass
|
||||
|
||||
def merge_update(self, table_name, from_, to, foreign_key):
|
||||
session = object_session(from_)
|
||||
constrained_columns = foreign_key['constrained_columns']
|
||||
referred_columns = foreign_key['referred_columns']
|
||||
metadata = from_.metadata
|
||||
table = metadata.tables[table_name]
|
||||
|
||||
new_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
new_values[column] = getattr(
|
||||
to, referred_columns[index]
|
||||
)
|
||||
|
||||
old_values = {}
|
||||
for index, column in enumerate(constrained_columns):
|
||||
old_values[column] = getattr(
|
||||
from_, referred_columns[index]
|
||||
)
|
||||
|
||||
if not self.memory_merge(session, table_name, old_values, new_values):
|
||||
self.raw_merge(session, table, old_values, new_values)
|
||||
|
||||
def __call__(self, from_, to):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_
|
||||
argument entity.
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise Exception()
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = dependent_foreign_keys(from_)
|
||||
|
||||
for table_name in foreign_keys:
|
||||
for foreign_key in foreign_keys[table_name]:
|
||||
self.merge_update(table_name, from_, to, foreign_key)
|
||||
|
||||
session.delete(from_)
|
||||
|
||||
|
||||
def merge(from_, to, merger=Merger):
|
||||
"""
|
||||
Merges entity into another entity. After merging deletes the from_ argument
|
||||
entity.
|
||||
|
||||
After merging the from_ entity is deleted from database.
|
||||
|
||||
:param from_: an entity to merge into another entity
|
||||
:param to: an entity to merge another entity into
|
||||
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
||||
class
|
||||
"""
|
||||
return Merger()(from_, to)
|
@@ -71,3 +71,18 @@ class TestGetMapperWithMultipleMappersFound(object):
|
||||
alias = sa.orm.aliased(self.Building.__table__)
|
||||
with raises(ValueError):
|
||||
get_mapper(alias)
|
||||
|
||||
|
||||
class TestGetMapperForTableWithoutMapper(object):
|
||||
def setup_method(self, method):
|
||||
metadata = sa.MetaData()
|
||||
self.building = sa.Table('building', metadata)
|
||||
|
||||
def test_table(self):
|
||||
with raises(ValueError):
|
||||
get_mapper(self.building)
|
||||
|
||||
def test_table_alias(self):
|
||||
alias = sa.orm.aliased(self.building)
|
||||
with raises(ValueError):
|
||||
get_mapper(alias)
|
||||
|
@@ -1,10 +1,10 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import merge
|
||||
from sqlalchemy_utils import merge_references
|
||||
|
||||
from tests import TestCase
|
||||
|
||||
|
||||
class TestMerge(TestCase):
|
||||
class TestMergeReferences(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
@@ -36,21 +36,29 @@ class TestMerge(TestCase):
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
assert post.author == jack
|
||||
assert post2.author == jack
|
||||
|
||||
def test_deletes_from_entity(self):
|
||||
def test_object_merging_whenever_possible(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
post = self.BlogPost(title=u'Some title', author=john)
|
||||
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.add(post)
|
||||
self.session.add(post2)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
assert john in self.session.deleted
|
||||
# Load the author for post
|
||||
assert post.author_id == john.id
|
||||
merge_references(john, jack)
|
||||
assert post.author_id == jack.id
|
||||
assert post2.author_id == jack.id
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociations(TestCase):
|
||||
class TestMergeReferencesWithManyToManyAssociations(TestCase):
|
||||
def create_models(self):
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
@@ -88,7 +96,7 @@ class TestMergeManyToManyAssociations(TestCase):
|
||||
self.User = User
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
def test_supports_associations(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
@@ -96,29 +104,12 @@ class TestMergeManyToManyAssociations(TestCase):
|
||||
self.session.add(john)
|
||||
self.session.add(jack)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
merge_references(john, jack)
|
||||
assert john not in team.members
|
||||
assert jack in team.members
|
||||
|
||||
# def test_when_association_exists_in_both(self):
|
||||
# john = self.User(name=u'John')
|
||||
# jack = self.User(name=u'Jack')
|
||||
# team = self.Team(name=u'Team')
|
||||
# team.members.append(john)
|
||||
# team.members.append(jack)
|
||||
# self.session.add(john)
|
||||
# self.session.add(jack)
|
||||
# self.session.commit()
|
||||
# merge(john, jack)
|
||||
# assert john not in team.members
|
||||
# assert jack in team.members
|
||||
# count = self.session.execute(
|
||||
# 'SELECT COUNT(1) FROM team_member'
|
||||
# ).fetchone()[0]
|
||||
# assert count == 1
|
||||
|
||||
|
||||
class TestMergeManyToManyAssociationObjects(TestCase):
|
||||
class TestMergeReferencesWithManyToManyAssociationObjects(TestCase):
|
||||
def create_models(self):
|
||||
class Team(self.Base):
|
||||
__tablename__ = 'team'
|
||||
@@ -164,7 +155,7 @@ class TestMergeManyToManyAssociationObjects(TestCase):
|
||||
self.TeamMember = TeamMember
|
||||
self.Team = Team
|
||||
|
||||
def test_when_association_only_exists_in_from_entity(self):
|
||||
def test_supports_associations(self):
|
||||
john = self.User(name=u'John')
|
||||
jack = self.User(name=u'Jack')
|
||||
team = self.Team(name=u'Team')
|
||||
@@ -173,24 +164,8 @@ class TestMergeManyToManyAssociationObjects(TestCase):
|
||||
self.session.add(jack)
|
||||
self.session.add(team)
|
||||
self.session.commit()
|
||||
merge(john, jack)
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
users = [member.user for member in team.members]
|
||||
assert john not in users
|
||||
assert jack in users
|
||||
|
||||
# def test_when_association_exists_in_both(self):
|
||||
# john = self.User(name=u'John')
|
||||
# jack = self.User(name=u'Jack')
|
||||
# team = self.Team(name=u'Team')
|
||||
# team.members.append(self.TeamMember(user=john))
|
||||
# team.members.append(self.TeamMember(user=jack))
|
||||
# self.session.add(john)
|
||||
# self.session.add(jack)
|
||||
# self.session.add(team)
|
||||
# self.session.commit()
|
||||
# merge(john, jack)
|
||||
# users = [member.user for member in team.members]
|
||||
# assert john not in users
|
||||
# assert jack in users
|
||||
# assert self.session.query(self.TeamMember).count() == 1
|
Reference in New Issue
Block a user