Refactor docs, add merge_references
This commit is contained in:
		| @@ -4,12 +4,13 @@ Changelog | ||||
| Here you can see the full list of changes between each SQLAlchemy-Utils release. | ||||
|  | ||||
|  | ||||
| 0.26.1 (2014-05-xx) | ||||
| 0.26.1 (2014-05-14) | ||||
| ^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| - Added get_bind | ||||
| - Added group_foreign_keys | ||||
| - Added get_mapper | ||||
| - Added merge_references | ||||
|  | ||||
|  | ||||
| 0.26.0 (2014-05-07) | ||||
|   | ||||
| @@ -5,18 +5,6 @@ Database helpers | ||||
| .. module:: sqlalchemy_utils.functions | ||||
|  | ||||
|  | ||||
| is_indexed_foreign_key | ||||
| ^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: is_indexed_foreign_key | ||||
|  | ||||
|  | ||||
| non_indexed_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: non_indexed_foreign_keys | ||||
|  | ||||
|  | ||||
| database_exists | ||||
| ^^^^^^^^^^^^^^^ | ||||
|  | ||||
|   | ||||
							
								
								
									
										40
									
								
								docs/foreign_key_helpers.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								docs/foreign_key_helpers.rst
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
| Foreign key helpers | ||||
| =================== | ||||
|  | ||||
| .. module:: sqlalchemy_utils.functions | ||||
|  | ||||
|  | ||||
| dependent_objects | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: dependent_objects | ||||
|  | ||||
|  | ||||
| get_referencing_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: get_referencing_foreign_keys | ||||
|  | ||||
|  | ||||
| group_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: group_foreign_keys | ||||
|  | ||||
|  | ||||
| is_indexed_foreign_key | ||||
| ^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: is_indexed_foreign_key | ||||
|  | ||||
|  | ||||
| merge_references | ||||
| ^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: merge_references | ||||
|  | ||||
|  | ||||
| non_indexed_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
|  | ||||
| .. autofunction:: non_indexed_foreign_keys | ||||
| @@ -1,5 +1,5 @@ | ||||
| Generic relationship | ||||
| ==================== | ||||
| Generic relationships | ||||
| ===================== | ||||
|  | ||||
| Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model. | ||||
|  | ||||
|   | ||||
| @@ -16,6 +16,7 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ | ||||
|    decorators | ||||
|    generic_relationship | ||||
|    database_helpers | ||||
|    model_helpers | ||||
|    foreign_key_helpers | ||||
|    orm_helpers | ||||
|    utility_classes | ||||
|    license | ||||
|   | ||||
| @@ -1,15 +1,9 @@ | ||||
| Model helpers | ||||
| ============= | ||||
| ORM helpers | ||||
| =========== | ||||
| 
 | ||||
| .. module:: sqlalchemy_utils.functions | ||||
| 
 | ||||
| 
 | ||||
| dependent_objects | ||||
| ^^^^^^^^^^^^^^^^^ | ||||
| 
 | ||||
| .. autofunction:: dependent_objects | ||||
| 
 | ||||
| 
 | ||||
| escape_like | ||||
| ^^^^^^^^^^^ | ||||
| 
 | ||||
| @@ -46,24 +40,12 @@ get_primary_keys | ||||
| .. autofunction:: get_primary_keys | ||||
| 
 | ||||
| 
 | ||||
| get_referencing_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||
| 
 | ||||
| .. autofunction:: get_referencing_foreign_keys | ||||
| 
 | ||||
| 
 | ||||
| get_tables | ||||
| ^^^^^^^^^^ | ||||
| 
 | ||||
| .. autofunction:: get_tables | ||||
| 
 | ||||
| 
 | ||||
| group_foreign_keys | ||||
| ^^^^^^^^^^^^^^^^^^ | ||||
| 
 | ||||
| .. autofunction:: group_foreign_keys | ||||
| 
 | ||||
| 
 | ||||
| query_entities | ||||
| ^^^^^^^^^^^^^^ | ||||
| 
 | ||||
							
								
								
									
										2
									
								
								setup.py
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								setup.py
									
									
									
									
									
								
							| @@ -44,7 +44,7 @@ for name, requirements in extras_require.items(): | ||||
|  | ||||
| setup( | ||||
|     name='SQLAlchemy-Utils', | ||||
|     version='0.26.0', | ||||
|     version='0.26.1', | ||||
|     url='https://github.com/kvesteri/sqlalchemy-utils', | ||||
|     license='BSD', | ||||
|     author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen', | ||||
|   | ||||
| @@ -20,6 +20,7 @@ from .functions import ( | ||||
|     get_tables, | ||||
|     group_foreign_keys, | ||||
|     identity, | ||||
|     merge_references, | ||||
|     mock_engine, | ||||
|     naturally_equivalent, | ||||
|     render_expression, | ||||
| @@ -32,7 +33,6 @@ from .listeners import ( | ||||
|     force_auto_coercion, | ||||
|     force_instant_defaults | ||||
| ) | ||||
| from .merge import merge, Merger | ||||
| from .generic import generic_relationship | ||||
| from .proxy_dict import ProxyDict, proxy_dict | ||||
| from .query_chain import QueryChain | ||||
| @@ -67,7 +67,7 @@ from .types import ( | ||||
| ) | ||||
|  | ||||
|  | ||||
| __version__ = '0.26.0' | ||||
| __version__ = '0.26.1' | ||||
|  | ||||
|  | ||||
| __all__ = ( | ||||
| @@ -95,7 +95,7 @@ __all__ = ( | ||||
|     group_foreign_keys, | ||||
|     identity, | ||||
|     instrumented_list, | ||||
|     merge, | ||||
|     merge_references, | ||||
|     mock_engine, | ||||
|     naturally_equivalent, | ||||
|     proxy_dict, | ||||
| @@ -120,7 +120,6 @@ __all__ = ( | ||||
|     IPAddressType, | ||||
|     JSONType, | ||||
|     LocaleType, | ||||
|     Merger, | ||||
|     NumericRangeType, | ||||
|     Password, | ||||
|     PasswordType, | ||||
|   | ||||
| @@ -8,20 +8,23 @@ from .database import ( | ||||
|     drop_database, | ||||
|     escape_like, | ||||
|     is_auto_assigned_date_column, | ||||
| ) | ||||
| from .foreign_keys import ( | ||||
|     dependent_objects, | ||||
|     get_referencing_foreign_keys, | ||||
|     group_foreign_keys, | ||||
|     is_indexed_foreign_key, | ||||
|     merge_references, | ||||
|     non_indexed_foreign_keys, | ||||
| ) | ||||
| from .orm import ( | ||||
|     dependent_objects, | ||||
|     get_bind, | ||||
|     get_columns, | ||||
|     get_declarative_base, | ||||
|     get_mapper, | ||||
|     get_primary_keys, | ||||
|     get_referencing_foreign_keys, | ||||
|     get_tables, | ||||
|     getdotattr, | ||||
|     group_foreign_keys, | ||||
|     has_changes, | ||||
|     identity, | ||||
|     naturally_equivalent, | ||||
|   | ||||
| @@ -1,4 +1,3 @@ | ||||
| from collections import defaultdict | ||||
| from sqlalchemy.engine.url import make_url | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint | ||||
| @@ -186,55 +185,3 @@ def drop_database(url): | ||||
|     else: | ||||
|         text = "DROP DATABASE %s" % database | ||||
|         engine.execute(text) | ||||
|  | ||||
|  | ||||
| def non_indexed_foreign_keys(metadata, engine=None): | ||||
|     """ | ||||
|     Finds all non indexed foreign keys from all tables of given MetaData. | ||||
|  | ||||
|     Very useful for optimizing postgresql database and finding out which | ||||
|     foreign keys need indexes. | ||||
|  | ||||
|     :param metadata: MetaData object to inspect tables from | ||||
|     """ | ||||
|     reflected_metadata = MetaData() | ||||
|  | ||||
|     if metadata.bind is None and engine is None: | ||||
|         raise Exception( | ||||
|             'Either pass a metadata object with bind or ' | ||||
|             'pass engine as a second parameter' | ||||
|         ) | ||||
|  | ||||
|     constraints = defaultdict(list) | ||||
|  | ||||
|     for table_name in metadata.tables.keys(): | ||||
|         table = Table( | ||||
|             table_name, | ||||
|             reflected_metadata, | ||||
|             autoload=True, | ||||
|             autoload_with=metadata.bind or engine | ||||
|         ) | ||||
|  | ||||
|         for constraint in table.constraints: | ||||
|             if not isinstance(constraint, ForeignKeyConstraint): | ||||
|                 continue | ||||
|  | ||||
|             if not is_indexed_foreign_key(constraint): | ||||
|                 constraints[table.name].append(constraint) | ||||
|  | ||||
|     return dict(constraints) | ||||
|  | ||||
|  | ||||
| def is_indexed_foreign_key(constraint): | ||||
|     """ | ||||
|     Whether or not given foreign key constraint's columns have been indexed. | ||||
|  | ||||
|     :param constraint: ForeignKeyConstraint object to check the indexes | ||||
|     """ | ||||
|     for index in constraint.table.indexes: | ||||
|         index_column_names = set( | ||||
|             column.name for column in index.columns | ||||
|         ) | ||||
|         if index_column_names == set(constraint.columns): | ||||
|             return True | ||||
|     return False | ||||
|   | ||||
							
								
								
									
										350
									
								
								sqlalchemy_utils/functions/foreign_keys.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										350
									
								
								sqlalchemy_utils/functions/foreign_keys.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,350 @@ | ||||
| from collections import defaultdict | ||||
| from itertools import groupby | ||||
|  | ||||
| import six | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy.engine import reflection | ||||
| from sqlalchemy.orm import object_session, mapperlib | ||||
| from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint | ||||
|  | ||||
| from .orm import get_mapper, get_tables | ||||
| from ..query_chain import QueryChain | ||||
|  | ||||
|  | ||||
| def get_foreign_key_values(fk, obj): | ||||
|     return { | ||||
|         fk.constraint.columns[index].key: | ||||
|         getattr(obj, element.column.key) | ||||
|         for | ||||
|         index, element | ||||
|         in | ||||
|         enumerate(fk.constraint.elements) | ||||
|     } | ||||
|  | ||||
|  | ||||
| def group_foreign_keys(foreign_keys): | ||||
|     """ | ||||
|     Return a groupby iterator that groups given foreign keys by table. | ||||
|  | ||||
|     :param foreign_keys: a sequence of foreign keys | ||||
|  | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         foreign_keys = get_referencing_foreign_keys(User) | ||||
|  | ||||
|         for table, fks in group_foreign_keys(foreign_keys): | ||||
|             # do something | ||||
|             pass | ||||
|  | ||||
|  | ||||
|     .. seealso:: :func:`get_referencing_foreign_keys` | ||||
|  | ||||
|     .. versionadded: 0.26.1 | ||||
|     """ | ||||
|     foreign_keys = sorted( | ||||
|         foreign_keys, key=lambda key: key.constraint.table.name | ||||
|     ) | ||||
|     return groupby(foreign_keys, lambda key: key.constraint.table) | ||||
|  | ||||
|  | ||||
| def get_referencing_foreign_keys(mixed): | ||||
|     """ | ||||
|     Returns referencing foreign keys for given Table object or declarative | ||||
|     class. | ||||
|  | ||||
|     :param mixed: | ||||
|         SA Table object or SA declarative class | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')]) | ||||
|  | ||||
|         get_referencing_foreign_keys(User.__table__) | ||||
|  | ||||
|  | ||||
|     This function also understands inheritance. This means it returns | ||||
|     all foreign keys that reference any table in the class inheritance tree. | ||||
|  | ||||
|     Let's say you have three classes which use joined table inheritance, | ||||
|     namely TextItem, Article and BlogPost with Article and BlogPost inheriting | ||||
|     TextItem. | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         # This will check all foreign keys that reference either article table | ||||
|         # or textitem table. | ||||
|         get_referencing_foreign_keys(Article) | ||||
|  | ||||
|     .. seealso:: :func:`get_tables` | ||||
|     """ | ||||
|     if isinstance(mixed, sa.Table): | ||||
|         tables = [mixed] | ||||
|     else: | ||||
|         tables = get_tables(mixed) | ||||
|  | ||||
|     referencing_foreign_keys = set() | ||||
|  | ||||
|     for table in mixed.metadata.tables.values(): | ||||
|         if table not in tables: | ||||
|             for constraint in table.constraints: | ||||
|                 if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): | ||||
|                     for fk in constraint.elements: | ||||
|                         if any(fk.references(t) for t in tables): | ||||
|                             referencing_foreign_keys.add(fk) | ||||
|     return referencing_foreign_keys | ||||
|  | ||||
|  | ||||
| def merge_references(from_, to, foreign_keys=None): | ||||
|     """ | ||||
|     Merge the references of an entity into another entity. | ||||
|  | ||||
|     Consider the following models:: | ||||
|  | ||||
|         class User(self.Base): | ||||
|             __tablename__ = 'user' | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             name = sa.Column(sa.Unicode(255)) | ||||
|  | ||||
|             def __repr__(self): | ||||
|                 return 'User(name=%r)' % self.name | ||||
|  | ||||
|         class BlogPost(self.Base): | ||||
|             __tablename__ = 'blog_post' | ||||
|             id = sa.Column(sa.Integer, primary_key=True) | ||||
|             title = sa.Column(sa.Unicode(255)) | ||||
|             author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) | ||||
|  | ||||
|             author = sa.orm.relationship(User) | ||||
|  | ||||
|  | ||||
|     Now lets add some data:: | ||||
|  | ||||
|         john = self.User(name=u'John') | ||||
|         jack = self.User(name=u'Jack') | ||||
|         post = self.BlogPost(title=u'Some title', author=john) | ||||
|         post2 = self.BlogPost(title=u'Other title', author=jack) | ||||
|         self.session.add_all([ | ||||
|             john, | ||||
|             jack, | ||||
|             post, | ||||
|             post2 | ||||
|         ]) | ||||
|         self.session.commit() | ||||
|  | ||||
|  | ||||
|     If we wanted to merge all John's references to Jack it would be as easy as | ||||
|     :: | ||||
|  | ||||
|         merge_references(john, jack) | ||||
|         self.session.commit() | ||||
|  | ||||
|         post.author     # User(name='Jack') | ||||
|         post2.author    # User(name='Jack') | ||||
|  | ||||
|  | ||||
|  | ||||
|     :param from_: an entity to merge into another entity | ||||
|     :param to: an entity to merge another entity into | ||||
|     :param foreign_keys: A sequence of foreign keys. By default this is None | ||||
|         indicating all referencing foreign keys should be used. | ||||
|  | ||||
|     .. seealso: :func:`dependent_objects` | ||||
|  | ||||
|     .. versionadded: 0.26.1 | ||||
|     """ | ||||
|     if from_.__tablename__ != to.__tablename__: | ||||
|         raise TypeError('The tables of given arguments do not match.') | ||||
|  | ||||
|     session = object_session(from_) | ||||
|     foreign_keys = get_referencing_foreign_keys(from_) | ||||
|  | ||||
|     for fk in foreign_keys: | ||||
|         old_values = get_foreign_key_values(fk, from_) | ||||
|         new_values = get_foreign_key_values(fk, to) | ||||
|         criteria = ( | ||||
|             getattr(fk.constraint.table.c, key) == value | ||||
|             for key, value in six.iteritems(old_values) | ||||
|         ) | ||||
|         try: | ||||
|             mapper = get_mapper(fk.constraint.table) | ||||
|         except ValueError: | ||||
|             query = ( | ||||
|                 fk.constraint.table | ||||
|                 .update() | ||||
|                 .where(sa.and_(*criteria)) | ||||
|                 .values(new_values) | ||||
|             ) | ||||
|             session.execute(query) | ||||
|         else: | ||||
|             print old_values, new_values | ||||
|             ( | ||||
|                 session.query(mapper.class_) | ||||
|                 .filter_by(**old_values) | ||||
|                 .update( | ||||
|                     new_values, | ||||
|                     'evaluate' | ||||
|                 ) | ||||
|             ) | ||||
|  | ||||
|  | ||||
| def dependent_objects(obj, foreign_keys=None): | ||||
|     """ | ||||
|     Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates | ||||
|     through all dependent objects for given SQLAlchemy object. | ||||
|  | ||||
|     Consider a User object is referenced in various articles and also in | ||||
|     various orders. Getting all these dependent objects is as easy as: | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         from sqlalchemy_utils import dependent_objects | ||||
|  | ||||
|  | ||||
|         dependent_objects(user) | ||||
|  | ||||
|  | ||||
|     If you expect an object to have lots of dependent_objects it might be good | ||||
|     to limit the results:: | ||||
|  | ||||
|  | ||||
|         dependent_objects(user).limit(5) | ||||
|  | ||||
|  | ||||
|  | ||||
|     The common use case is checking for all restrict dependent objects before | ||||
|     deleting parent object and inform the user if there are dependent objects | ||||
|     with ondelete='RESTRICT' foreign keys. If this kind of checking is not used | ||||
|     it will lead to nasty IntegrityErrors being raised. | ||||
|  | ||||
|     In the following example we delete given user if it doesn't have any | ||||
|     foreign key restricted dependent objects. | ||||
|  | ||||
|     :: | ||||
|  | ||||
|  | ||||
|         from sqlalchemy_utils import get_referencing_foreign_keys | ||||
|  | ||||
|  | ||||
|         user = session.query(User).get(some_user_id) | ||||
|  | ||||
|  | ||||
|         deps = list( | ||||
|             dependent_objects( | ||||
|                 user, | ||||
|                 ( | ||||
|                     fk for fk in get_referencing_foreign_keys(User) | ||||
|                     # On most databases RESTRICT is the default mode hence we | ||||
|                     # check for None values also | ||||
|                     if fk.ondelete == 'RESTRICT' or fk.ondelete is None | ||||
|                 ) | ||||
|             ).limit(5) | ||||
|         ) | ||||
|  | ||||
|         if deps: | ||||
|             # Do something to inform the user | ||||
|             pass | ||||
|         else: | ||||
|             session.delete(user) | ||||
|  | ||||
|  | ||||
|     :param obj: SQLAlchemy declarative model object | ||||
|     :param foreign_keys: | ||||
|         A sequence of foreign keys to use for searching the dependent_objects | ||||
|         for given object. By default this is None, indicating that all foreign | ||||
|         keys referencing the object will be used. | ||||
|  | ||||
|     .. note:: | ||||
|         This function does not support exotic mappers that use multiple tables | ||||
|  | ||||
|     .. seealso:: :func:`get_referencing_foreign_keys` | ||||
|     .. seealso:: :func:`merge_references` | ||||
|  | ||||
|     .. versionadded: 0.26.0 | ||||
|     """ | ||||
|     if foreign_keys is None: | ||||
|         foreign_keys = get_referencing_foreign_keys(obj) | ||||
|  | ||||
|     session = object_session(obj) | ||||
|  | ||||
|     chain = QueryChain([]) | ||||
|     classes = obj.__class__._decl_class_registry | ||||
|  | ||||
|     for table, keys in group_foreign_keys(foreign_keys): | ||||
|         for class_ in classes.values(): | ||||
|             if hasattr(class_, '__table__') and class_.__table__ == table: | ||||
|                 criteria = [] | ||||
|                 visited_constraints = [] | ||||
|                 for key in keys: | ||||
|                     if key.constraint not in visited_constraints: | ||||
|                         visited_constraints.append(key.constraint) | ||||
|                         subcriteria = [ | ||||
|                             getattr(class_, column.key) == | ||||
|                             getattr( | ||||
|                                 obj, | ||||
|                                 key.constraint.elements[index].column.key | ||||
|                             ) | ||||
|                             for index, column | ||||
|                             in enumerate(key.constraint.columns) | ||||
|                         ] | ||||
|                         criteria.append(sa.and_(*subcriteria)) | ||||
|  | ||||
|                 query = session.query(class_).filter( | ||||
|                     sa.or_( | ||||
|                         *criteria | ||||
|                     ) | ||||
|                 ) | ||||
|                 chain.queries.append(query) | ||||
|         return chain | ||||
|  | ||||
|  | ||||
| def non_indexed_foreign_keys(metadata, engine=None): | ||||
|     """ | ||||
|     Finds all non indexed foreign keys from all tables of given MetaData. | ||||
|  | ||||
|     Very useful for optimizing postgresql database and finding out which | ||||
|     foreign keys need indexes. | ||||
|  | ||||
|     :param metadata: MetaData object to inspect tables from | ||||
|     """ | ||||
|     reflected_metadata = MetaData() | ||||
|  | ||||
|     if metadata.bind is None and engine is None: | ||||
|         raise Exception( | ||||
|             'Either pass a metadata object with bind or ' | ||||
|             'pass engine as a second parameter' | ||||
|         ) | ||||
|  | ||||
|     constraints = defaultdict(list) | ||||
|  | ||||
|     for table_name in metadata.tables.keys(): | ||||
|         table = Table( | ||||
|             table_name, | ||||
|             reflected_metadata, | ||||
|             autoload=True, | ||||
|             autoload_with=metadata.bind or engine | ||||
|         ) | ||||
|  | ||||
|         for constraint in table.constraints: | ||||
|             if not isinstance(constraint, ForeignKeyConstraint): | ||||
|                 continue | ||||
|  | ||||
|             if not is_indexed_foreign_key(constraint): | ||||
|                 constraints[table.name].append(constraint) | ||||
|  | ||||
|     return dict(constraints) | ||||
|  | ||||
|  | ||||
| def is_indexed_foreign_key(constraint): | ||||
|     """ | ||||
|     Whether or not given foreign key constraint's columns have been indexed. | ||||
|  | ||||
|     :param constraint: ForeignKeyConstraint object to check the indexes | ||||
|     """ | ||||
|     for index in constraint.table.indexes: | ||||
|         index_column_names = set( | ||||
|             column.name for column in index.columns | ||||
|         ) | ||||
|         if index_column_names == set(constraint.columns): | ||||
|             return True | ||||
|     return False | ||||
| @@ -61,8 +61,12 @@ def get_mapper(mixed): | ||||
|         ] | ||||
|         if len(mappers) > 1: | ||||
|             raise ValueError( | ||||
|                 "Could not get mapper for '%r'. Multiple mappers found." | ||||
|                 % mixed | ||||
|                 "Multiple mappers found for table '%s'." | ||||
|                 % mixed.name | ||||
|             ) | ||||
|         elif not mappers: | ||||
|             raise ValueError( | ||||
|                 "Could not get mapper for table '%s'." | ||||
|             ) | ||||
|         else: | ||||
|             return mappers[0] | ||||
| @@ -104,188 +108,6 @@ def get_bind(obj): | ||||
|     return conn | ||||
|  | ||||
|  | ||||
| def dependent_objects(obj, foreign_keys=None): | ||||
|     """ | ||||
|     Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates | ||||
|     through all dependent objects for given SQLAlchemy object. | ||||
|  | ||||
|     Consider a User object is referenced in various articles and also in | ||||
|     various orders. Getting all these dependent objects is as easy as: | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         from sqlalchemy_utils import dependent_objects | ||||
|  | ||||
|  | ||||
|         dependent_objects(user) | ||||
|  | ||||
|  | ||||
|     If you expect an object to have lots of dependent_objects it might be good | ||||
|     to limit the results:: | ||||
|  | ||||
|  | ||||
|         dependent_objects(user).limit(5) | ||||
|  | ||||
|  | ||||
|  | ||||
|     The common use case is checking for all restrict dependent objects before | ||||
|     deleting parent object and inform the user if there are dependent objects | ||||
|     with ondelete='RESTRICT' foreign keys. If this kind of checking is not used | ||||
|     it will lead to nasty IntegrityErrors being raised. | ||||
|  | ||||
|     In the following example we delete given user if it doesn't have any | ||||
|     foreign key restricted dependent objects. | ||||
|  | ||||
|     :: | ||||
|  | ||||
|  | ||||
|         from sqlalchemy_utils import get_referencing_foreign_keys | ||||
|  | ||||
|  | ||||
|         user = session.query(User).get(some_user_id) | ||||
|  | ||||
|  | ||||
|         deps = list( | ||||
|             dependent_objects( | ||||
|                 user, | ||||
|                 ( | ||||
|                     fk for fk in get_referencing_foreign_keys(User) | ||||
|                     # On most databases RESTRICT is the default mode hence we | ||||
|                     # check for None values also | ||||
|                     if fk.ondelete == 'RESTRICT' or fk.ondelete is None | ||||
|                 ) | ||||
|             ).limit(5) | ||||
|         ) | ||||
|  | ||||
|         if deps: | ||||
|             # Do something to inform the user | ||||
|             pass | ||||
|         else: | ||||
|             session.delete(user) | ||||
|  | ||||
|  | ||||
|     :param obj: SQLAlchemy declarative model object | ||||
|     :param foreign_keys: | ||||
|         A sequence of foreign keys to use for searching the dependent_objects | ||||
|         for given object. By default this is None, indicating that all foreign | ||||
|         keys referencing the object will be used. | ||||
|  | ||||
|     .. note:: | ||||
|         This function does not support exotic mappers that use multiple tables | ||||
|  | ||||
|     .. seealso:: :func:`get_referencing_foreign_keys` | ||||
|  | ||||
|     .. versionadded: 0.26.0 | ||||
|     """ | ||||
|     if foreign_keys is None: | ||||
|         foreign_keys = get_referencing_foreign_keys(obj) | ||||
|  | ||||
|     session = object_session(obj) | ||||
|  | ||||
|     chain = QueryChain([]) | ||||
|     classes = obj.__class__._decl_class_registry | ||||
|  | ||||
|     for table, keys in group_foreign_keys(foreign_keys): | ||||
|         for class_ in classes.values(): | ||||
|             if hasattr(class_, '__table__') and class_.__table__ == table: | ||||
|                 criteria = [] | ||||
|                 visited_constraints = [] | ||||
|                 for key in keys: | ||||
|                     if key.constraint not in visited_constraints: | ||||
|                         visited_constraints.append(key.constraint) | ||||
|                         subcriteria = [ | ||||
|                             getattr(class_, column.key) == | ||||
|                             getattr( | ||||
|                                 obj, | ||||
|                                 key.constraint.elements[index].column.key | ||||
|                             ) | ||||
|                             for index, column | ||||
|                             in enumerate(key.constraint.columns) | ||||
|                         ] | ||||
|                         criteria.append(sa.and_(*subcriteria)) | ||||
|  | ||||
|                 query = session.query(class_).filter( | ||||
|                     sa.or_( | ||||
|                         *criteria | ||||
|                     ) | ||||
|                 ) | ||||
|                 chain.queries.append(query) | ||||
|         return chain | ||||
|  | ||||
|  | ||||
| def group_foreign_keys(foreign_keys): | ||||
|     """ | ||||
|     Return a groupby iterator that groups given foreign keys by table. | ||||
|  | ||||
|     :param foreign_keys: a sequence of foreign keys | ||||
|  | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         foreign_keys = get_referencing_foreign_keys(User) | ||||
|  | ||||
|         for table, fks in group_foreign_keys(foreign_keys): | ||||
|             # do something | ||||
|             pass | ||||
|  | ||||
|  | ||||
|     .. also:: :func:`get_referencing_foreign_keys` | ||||
|  | ||||
|     .. versionadded: 0.26.1 | ||||
|     """ | ||||
|     foreign_keys = sorted( | ||||
|         foreign_keys, key=lambda key: key.constraint.table.name | ||||
|     ) | ||||
|     return groupby(foreign_keys, lambda key: key.constraint.table) | ||||
|  | ||||
|  | ||||
| def get_referencing_foreign_keys(mixed): | ||||
|     """ | ||||
|     Returns referencing foreign keys for given Table object or declarative | ||||
|     class. | ||||
|  | ||||
|     :param mixed: | ||||
|         SA Table object or SA declarative class | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')]) | ||||
|  | ||||
|         get_referencing_foreign_keys(User.__table__) | ||||
|  | ||||
|  | ||||
|     This function also understands inheritance. This means it returns | ||||
|     all foreign keys that reference any table in the class inheritance tree. | ||||
|  | ||||
|     Let's say you have three classes which use joined table inheritance, | ||||
|     namely TextItem, Article and BlogPost with Article and BlogPost inheriting | ||||
|     TextItem. | ||||
|  | ||||
|     :: | ||||
|  | ||||
|         # This will check all foreign keys that reference either article table | ||||
|         # or textitem table. | ||||
|         get_referencing_foreign_keys(Article) | ||||
|  | ||||
|     .. seealso:: :func:`get_tables` | ||||
|     """ | ||||
|     if isinstance(mixed, sa.Table): | ||||
|         tables = [mixed] | ||||
|     else: | ||||
|         tables = get_tables(mixed) | ||||
|  | ||||
|     referencing_foreign_keys = set() | ||||
|  | ||||
|     for table in mixed.metadata.tables.values(): | ||||
|         if table not in tables: | ||||
|             for constraint in table.constraints: | ||||
|                 if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): | ||||
|                     for fk in constraint.elements: | ||||
|                         if any(fk.references(t) for t in tables): | ||||
|                             referencing_foreign_keys.add(fk) | ||||
|     return referencing_foreign_keys | ||||
|  | ||||
|  | ||||
| def get_primary_keys(mixed): | ||||
|     """ | ||||
|     Return an OrderedDict of all primary keys for given Table object, | ||||
|   | ||||
| @@ -1,124 +0,0 @@ | ||||
| import six | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy.engine import reflection | ||||
| from sqlalchemy.orm import object_session, mapperlib | ||||
|  | ||||
|  | ||||
| def dependent_foreign_keys(model_class): | ||||
|     """ | ||||
|     Returns dependent foreign keys as dicts for given model class. | ||||
|  | ||||
|     ** Experimental function ** | ||||
|     """ | ||||
|     session = object_session(model_class) | ||||
|  | ||||
|     engine = session.bind | ||||
|     inspector = reflection.Inspector.from_engine(engine) | ||||
|     table_names = inspector.get_table_names() | ||||
|  | ||||
|     dependent_foreign_keys = {} | ||||
|  | ||||
|     for table_name in table_names: | ||||
|         fks = inspector.get_foreign_keys(table_name) | ||||
|         if fks: | ||||
|             dependent_foreign_keys[table_name] = [] | ||||
|             for fk in fks: | ||||
|                 if fk['referred_table'] == model_class.__tablename__: | ||||
|                     dependent_foreign_keys[table_name].append(fk) | ||||
|     return dependent_foreign_keys | ||||
|  | ||||
|  | ||||
| class Merger(object): | ||||
|     def memory_merge(self, session, table_name, old_values, new_values): | ||||
|         # try to fetch mappers for given table and update in memory objects as | ||||
|         # well as database table | ||||
|         found = False | ||||
|         for mapper in mapperlib._mapper_registry: | ||||
|             class_ = mapper.class_ | ||||
|             if table_name == class_.__table__.name: | ||||
|                 try: | ||||
|                     ( | ||||
|                         session.query(mapper.class_) | ||||
|                         .filter_by(**old_values) | ||||
|                         .update( | ||||
|                             new_values, | ||||
|                             'fetch' | ||||
|                         ) | ||||
|                     ) | ||||
|                 except sa.exc.IntegrityError: | ||||
|                     pass | ||||
|                 found = True | ||||
|         return found | ||||
|  | ||||
|     def raw_merge(self, session, table, old_values, new_values): | ||||
|         conditions = [] | ||||
|         for key, value in six.iteritems(old_values): | ||||
|             conditions.append(getattr(table.c, key) == value) | ||||
|         sql = ( | ||||
|             table | ||||
|             .update() | ||||
|             .where(sa.and_( | ||||
|                 *conditions | ||||
|             )) | ||||
|             .values( | ||||
|                 new_values | ||||
|             ) | ||||
|         ) | ||||
|         try: | ||||
|             session.execute(sql) | ||||
|         except sa.exc.IntegrityError: | ||||
|             pass | ||||
|  | ||||
|     def merge_update(self, table_name, from_, to, foreign_key): | ||||
|         session = object_session(from_) | ||||
|         constrained_columns = foreign_key['constrained_columns'] | ||||
|         referred_columns = foreign_key['referred_columns'] | ||||
|         metadata = from_.metadata | ||||
|         table = metadata.tables[table_name] | ||||
|  | ||||
|         new_values = {} | ||||
|         for index, column in enumerate(constrained_columns): | ||||
|             new_values[column] = getattr( | ||||
|                 to, referred_columns[index] | ||||
|             ) | ||||
|  | ||||
|         old_values = {} | ||||
|         for index, column in enumerate(constrained_columns): | ||||
|             old_values[column] = getattr( | ||||
|                 from_, referred_columns[index] | ||||
|             ) | ||||
|  | ||||
|         if not self.memory_merge(session, table_name, old_values, new_values): | ||||
|             self.raw_merge(session, table, old_values, new_values) | ||||
|  | ||||
|     def __call__(self, from_, to): | ||||
|         """ | ||||
|         Merges entity into another entity. After merging deletes the from_ | ||||
|         argument entity. | ||||
|         """ | ||||
|         if from_.__tablename__ != to.__tablename__: | ||||
|             raise Exception() | ||||
|  | ||||
|         session = object_session(from_) | ||||
|         foreign_keys = dependent_foreign_keys(from_) | ||||
|  | ||||
|         for table_name in foreign_keys: | ||||
|             for foreign_key in foreign_keys[table_name]: | ||||
|                 self.merge_update(table_name, from_, to, foreign_key) | ||||
|  | ||||
|         session.delete(from_) | ||||
|  | ||||
|  | ||||
| def merge(from_, to, merger=Merger): | ||||
|     """ | ||||
|     Merges entity into another entity. After merging deletes the from_ argument | ||||
|     entity. | ||||
|  | ||||
|     After merging the from_ entity is deleted from database. | ||||
|  | ||||
|     :param from_: an entity to merge into another entity | ||||
|     :param to: an entity to merge another entity into | ||||
|     :param merger: Merger class, by default this is sqlalchemy_utils.Merger | ||||
|         class | ||||
|     """ | ||||
|     return Merger()(from_, to) | ||||
| @@ -71,3 +71,18 @@ class TestGetMapperWithMultipleMappersFound(object): | ||||
|         alias = sa.orm.aliased(self.Building.__table__) | ||||
|         with raises(ValueError): | ||||
|             get_mapper(alias) | ||||
|  | ||||
|  | ||||
| class TestGetMapperForTableWithoutMapper(object): | ||||
|     def setup_method(self, method): | ||||
|         metadata = sa.MetaData() | ||||
|         self.building = sa.Table('building', metadata) | ||||
|  | ||||
|     def test_table(self): | ||||
|         with raises(ValueError): | ||||
|             get_mapper(self.building) | ||||
|  | ||||
|     def test_table_alias(self): | ||||
|         alias = sa.orm.aliased(self.building) | ||||
|         with raises(ValueError): | ||||
|             get_mapper(alias) | ||||
|   | ||||
| @@ -1,10 +1,10 @@ | ||||
| import sqlalchemy as sa | ||||
| from sqlalchemy_utils import merge | ||||
| from sqlalchemy_utils import merge_references | ||||
| 
 | ||||
| from tests import TestCase | ||||
| 
 | ||||
| 
 | ||||
| class TestMerge(TestCase): | ||||
| class TestMergeReferences(TestCase): | ||||
|     def create_models(self): | ||||
|         class User(self.Base): | ||||
|             __tablename__ = 'user' | ||||
| @@ -36,21 +36,29 @@ class TestMerge(TestCase): | ||||
|         self.session.add(post) | ||||
|         self.session.add(post2) | ||||
|         self.session.commit() | ||||
|         merge(john, jack) | ||||
|         merge_references(john, jack) | ||||
|         self.session.commit() | ||||
|         assert post.author == jack | ||||
|         assert post2.author == jack | ||||
| 
 | ||||
|     def test_deletes_from_entity(self): | ||||
|     def test_object_merging_whenever_possible(self): | ||||
|         john = self.User(name=u'John') | ||||
|         jack = self.User(name=u'Jack') | ||||
|         post = self.BlogPost(title=u'Some title', author=john) | ||||
|         post2 = self.BlogPost(title=u'Other title', author=jack) | ||||
|         self.session.add(john) | ||||
|         self.session.add(jack) | ||||
|         self.session.add(post) | ||||
|         self.session.add(post2) | ||||
|         self.session.commit() | ||||
|         merge(john, jack) | ||||
|         assert john in self.session.deleted | ||||
|         # Load the author for post | ||||
|         assert post.author_id == john.id | ||||
|         merge_references(john, jack) | ||||
|         assert post.author_id == jack.id | ||||
|         assert post2.author_id == jack.id | ||||
| 
 | ||||
| 
 | ||||
| class TestMergeManyToManyAssociations(TestCase): | ||||
| class TestMergeReferencesWithManyToManyAssociations(TestCase): | ||||
|     def create_models(self): | ||||
|         class User(self.Base): | ||||
|             __tablename__ = 'user' | ||||
| @@ -88,7 +96,7 @@ class TestMergeManyToManyAssociations(TestCase): | ||||
|         self.User = User | ||||
|         self.Team = Team | ||||
| 
 | ||||
|     def test_when_association_only_exists_in_from_entity(self): | ||||
|     def test_supports_associations(self): | ||||
|         john = self.User(name=u'John') | ||||
|         jack = self.User(name=u'Jack') | ||||
|         team = self.Team(name=u'Team') | ||||
| @@ -96,29 +104,12 @@ class TestMergeManyToManyAssociations(TestCase): | ||||
|         self.session.add(john) | ||||
|         self.session.add(jack) | ||||
|         self.session.commit() | ||||
|         merge(john, jack) | ||||
|         merge_references(john, jack) | ||||
|         assert john not in team.members | ||||
|         assert jack in team.members | ||||
| 
 | ||||
|     # def test_when_association_exists_in_both(self): | ||||
|     #     john = self.User(name=u'John') | ||||
|     #     jack = self.User(name=u'Jack') | ||||
|     #     team = self.Team(name=u'Team') | ||||
|     #     team.members.append(john) | ||||
|     #     team.members.append(jack) | ||||
|     #     self.session.add(john) | ||||
|     #     self.session.add(jack) | ||||
|     #     self.session.commit() | ||||
|     #     merge(john, jack) | ||||
|     #     assert john not in team.members | ||||
|     #     assert jack in team.members | ||||
|     #     count = self.session.execute( | ||||
|     #         'SELECT COUNT(1) FROM team_member' | ||||
|     #     ).fetchone()[0] | ||||
|     #     assert count == 1 | ||||
| 
 | ||||
| 
 | ||||
| class TestMergeManyToManyAssociationObjects(TestCase): | ||||
| class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): | ||||
|     def create_models(self): | ||||
|         class Team(self.Base): | ||||
|             __tablename__ = 'team' | ||||
| @@ -164,7 +155,7 @@ class TestMergeManyToManyAssociationObjects(TestCase): | ||||
|         self.TeamMember = TeamMember | ||||
|         self.Team = Team | ||||
| 
 | ||||
|     def test_when_association_only_exists_in_from_entity(self): | ||||
|     def test_supports_associations(self): | ||||
|         john = self.User(name=u'John') | ||||
|         jack = self.User(name=u'Jack') | ||||
|         team = self.Team(name=u'Team') | ||||
| @@ -173,24 +164,8 @@ class TestMergeManyToManyAssociationObjects(TestCase): | ||||
|         self.session.add(jack) | ||||
|         self.session.add(team) | ||||
|         self.session.commit() | ||||
|         merge(john, jack) | ||||
|         merge_references(john, jack) | ||||
|         self.session.commit() | ||||
|         users = [member.user for member in team.members] | ||||
|         assert john not in users | ||||
|         assert jack in users | ||||
| 
 | ||||
|     # def test_when_association_exists_in_both(self): | ||||
|     #     john = self.User(name=u'John') | ||||
|     #     jack = self.User(name=u'Jack') | ||||
|     #     team = self.Team(name=u'Team') | ||||
|     #     team.members.append(self.TeamMember(user=john)) | ||||
|     #     team.members.append(self.TeamMember(user=jack)) | ||||
|     #     self.session.add(john) | ||||
|     #     self.session.add(jack) | ||||
|     #     self.session.add(team) | ||||
|     #     self.session.commit() | ||||
|     #     merge(john, jack) | ||||
|     #     users = [member.user for member in team.members] | ||||
|     #     assert john not in users | ||||
|     #     assert jack in users | ||||
|     #     assert self.session.query(self.TeamMember).count() == 1 | ||||
		Reference in New Issue
	
	Block a user
	 Konsta Vesterinen
					Konsta Vesterinen