From 42650d9f32ab94dd9edeea3d758b65f3ea8a08f4 Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 14 May 2014 17:20:34 +0300 Subject: [PATCH] Refactor docs, add merge_references --- CHANGES.rst | 3 +- docs/database_helpers.rst | 12 - docs/foreign_key_helpers.rst | 40 ++ docs/generic_relationship.rst | 4 +- docs/index.rst | 3 +- docs/{model_helpers.rst => orm_helpers.rst} | 22 +- setup.py | 2 +- sqlalchemy_utils/__init__.py | 7 +- sqlalchemy_utils/functions/__init__.py | 9 +- sqlalchemy_utils/functions/database.py | 53 --- sqlalchemy_utils/functions/foreign_keys.py | 350 ++++++++++++++++++ sqlalchemy_utils/functions/orm.py | 190 +--------- sqlalchemy_utils/merge.py | 124 ------- tests/functions/test_get_mapper.py | 15 + .../test_merge_references.py} | 65 +--- 15 files changed, 449 insertions(+), 450 deletions(-) create mode 100644 docs/foreign_key_helpers.rst rename docs/{model_helpers.rst => orm_helpers.rst} (72%) create mode 100644 sqlalchemy_utils/functions/foreign_keys.py delete mode 100644 sqlalchemy_utils/merge.py rename tests/{test_merge.py => functions/test_merge_references.py} (73%) diff --git a/CHANGES.rst b/CHANGES.rst index 761db32..93570ce 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,12 +4,13 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. -0.26.1 (2014-05-xx) +0.26.1 (2014-05-14) ^^^^^^^^^^^^^^^^^^^ - Added get_bind - Added group_foreign_keys - Added get_mapper +- Added merge_references 0.26.0 (2014-05-07) diff --git a/docs/database_helpers.rst b/docs/database_helpers.rst index 7a67515..6d5387c 100644 --- a/docs/database_helpers.rst +++ b/docs/database_helpers.rst @@ -5,18 +5,6 @@ Database helpers .. module:: sqlalchemy_utils.functions -is_indexed_foreign_key -^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: is_indexed_foreign_key - - -non_indexed_foreign_keys -^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: non_indexed_foreign_keys - - database_exists ^^^^^^^^^^^^^^^ diff --git a/docs/foreign_key_helpers.rst b/docs/foreign_key_helpers.rst new file mode 100644 index 0000000..3afa61e --- /dev/null +++ b/docs/foreign_key_helpers.rst @@ -0,0 +1,40 @@ +Foreign key helpers +=================== + +.. module:: sqlalchemy_utils.functions + + +dependent_objects +^^^^^^^^^^^^^^^^^ + +.. autofunction:: dependent_objects + + +get_referencing_foreign_keys +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: get_referencing_foreign_keys + + +group_foreign_keys +^^^^^^^^^^^^^^^^^^ + +.. autofunction:: group_foreign_keys + + +is_indexed_foreign_key +^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: is_indexed_foreign_key + + +merge_references +^^^^^^^^^^^^^^^^ + +.. autofunction:: merge_references + + +non_indexed_foreign_keys +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autofunction:: non_indexed_foreign_keys diff --git a/docs/generic_relationship.rst b/docs/generic_relationship.rst index aacb09e..6f0d425 100644 --- a/docs/generic_relationship.rst +++ b/docs/generic_relationship.rst @@ -1,5 +1,5 @@ -Generic relationship -==================== +Generic relationships +===================== Generic relationship is a form of relationship that supports creating a 1 to many relationship to any target model. diff --git a/docs/index.rst b/docs/index.rst index 33463bd..02a672f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -16,6 +16,7 @@ SQLAlchemy-Utils provides custom data types and various utility functions for SQ decorators generic_relationship database_helpers - model_helpers + foreign_key_helpers + orm_helpers utility_classes license diff --git a/docs/model_helpers.rst b/docs/orm_helpers.rst similarity index 72% rename from docs/model_helpers.rst rename to docs/orm_helpers.rst index 55cd91c..ad413ee 100644 --- a/docs/model_helpers.rst +++ b/docs/orm_helpers.rst @@ -1,15 +1,9 @@ -Model helpers -============= +ORM helpers +=========== .. module:: sqlalchemy_utils.functions -dependent_objects -^^^^^^^^^^^^^^^^^ - -.. autofunction:: dependent_objects - - escape_like ^^^^^^^^^^^ @@ -46,24 +40,12 @@ get_primary_keys .. autofunction:: get_primary_keys -get_referencing_foreign_keys -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. autofunction:: get_referencing_foreign_keys - - get_tables ^^^^^^^^^^ .. autofunction:: get_tables -group_foreign_keys -^^^^^^^^^^^^^^^^^^ - -.. autofunction:: group_foreign_keys - - query_entities ^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index ed9f665..509f36c 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ for name, requirements in extras_require.items(): setup( name='SQLAlchemy-Utils', - version='0.26.0', + version='0.26.1', url='https://github.com/kvesteri/sqlalchemy-utils', license='BSD', author='Konsta Vesterinen, Ryan Leckey, Janne Vanhala, Vesa Uimonen', diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 196f030..63b20ee 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -20,6 +20,7 @@ from .functions import ( get_tables, group_foreign_keys, identity, + merge_references, mock_engine, naturally_equivalent, render_expression, @@ -32,7 +33,6 @@ from .listeners import ( force_auto_coercion, force_instant_defaults ) -from .merge import merge, Merger from .generic import generic_relationship from .proxy_dict import ProxyDict, proxy_dict from .query_chain import QueryChain @@ -67,7 +67,7 @@ from .types import ( ) -__version__ = '0.26.0' +__version__ = '0.26.1' __all__ = ( @@ -95,7 +95,7 @@ __all__ = ( group_foreign_keys, identity, instrumented_list, - merge, + merge_references, mock_engine, naturally_equivalent, proxy_dict, @@ -120,7 +120,6 @@ __all__ = ( IPAddressType, JSONType, LocaleType, - Merger, NumericRangeType, Password, PasswordType, diff --git a/sqlalchemy_utils/functions/__init__.py b/sqlalchemy_utils/functions/__init__.py index 39e5b7e..d8e2ef9 100644 --- a/sqlalchemy_utils/functions/__init__.py +++ b/sqlalchemy_utils/functions/__init__.py @@ -8,20 +8,23 @@ from .database import ( drop_database, escape_like, is_auto_assigned_date_column, +) +from .foreign_keys import ( + dependent_objects, + get_referencing_foreign_keys, + group_foreign_keys, is_indexed_foreign_key, + merge_references, non_indexed_foreign_keys, ) from .orm import ( - dependent_objects, get_bind, get_columns, get_declarative_base, get_mapper, get_primary_keys, - get_referencing_foreign_keys, get_tables, getdotattr, - group_foreign_keys, has_changes, identity, naturally_equivalent, diff --git a/sqlalchemy_utils/functions/database.py b/sqlalchemy_utils/functions/database.py index 0888502..5986cc3 100644 --- a/sqlalchemy_utils/functions/database.py +++ b/sqlalchemy_utils/functions/database.py @@ -1,4 +1,3 @@ -from collections import defaultdict from sqlalchemy.engine.url import make_url import sqlalchemy as sa from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint @@ -186,55 +185,3 @@ def drop_database(url): else: text = "DROP DATABASE %s" % database engine.execute(text) - - -def non_indexed_foreign_keys(metadata, engine=None): - """ - Finds all non indexed foreign keys from all tables of given MetaData. - - Very useful for optimizing postgresql database and finding out which - foreign keys need indexes. - - :param metadata: MetaData object to inspect tables from - """ - reflected_metadata = MetaData() - - if metadata.bind is None and engine is None: - raise Exception( - 'Either pass a metadata object with bind or ' - 'pass engine as a second parameter' - ) - - constraints = defaultdict(list) - - for table_name in metadata.tables.keys(): - table = Table( - table_name, - reflected_metadata, - autoload=True, - autoload_with=metadata.bind or engine - ) - - for constraint in table.constraints: - if not isinstance(constraint, ForeignKeyConstraint): - continue - - if not is_indexed_foreign_key(constraint): - constraints[table.name].append(constraint) - - return dict(constraints) - - -def is_indexed_foreign_key(constraint): - """ - Whether or not given foreign key constraint's columns have been indexed. - - :param constraint: ForeignKeyConstraint object to check the indexes - """ - for index in constraint.table.indexes: - index_column_names = set( - column.name for column in index.columns - ) - if index_column_names == set(constraint.columns): - return True - return False diff --git a/sqlalchemy_utils/functions/foreign_keys.py b/sqlalchemy_utils/functions/foreign_keys.py new file mode 100644 index 0000000..dfb465c --- /dev/null +++ b/sqlalchemy_utils/functions/foreign_keys.py @@ -0,0 +1,350 @@ +from collections import defaultdict +from itertools import groupby + +import six +import sqlalchemy as sa +from sqlalchemy.engine import reflection +from sqlalchemy.orm import object_session, mapperlib +from sqlalchemy.schema import MetaData, Table, ForeignKeyConstraint + +from .orm import get_mapper, get_tables +from ..query_chain import QueryChain + + +def get_foreign_key_values(fk, obj): + return { + fk.constraint.columns[index].key: + getattr(obj, element.column.key) + for + index, element + in + enumerate(fk.constraint.elements) + } + + +def group_foreign_keys(foreign_keys): + """ + Return a groupby iterator that groups given foreign keys by table. + + :param foreign_keys: a sequence of foreign keys + + + :: + + foreign_keys = get_referencing_foreign_keys(User) + + for table, fks in group_foreign_keys(foreign_keys): + # do something + pass + + + .. seealso:: :func:`get_referencing_foreign_keys` + + .. versionadded: 0.26.1 + """ + foreign_keys = sorted( + foreign_keys, key=lambda key: key.constraint.table.name + ) + return groupby(foreign_keys, lambda key: key.constraint.table) + + +def get_referencing_foreign_keys(mixed): + """ + Returns referencing foreign keys for given Table object or declarative + class. + + :param mixed: + SA Table object or SA declarative class + + :: + + get_referencing_foreign_keys(User) # set([ForeignKey('user.id')]) + + get_referencing_foreign_keys(User.__table__) + + + This function also understands inheritance. This means it returns + all foreign keys that reference any table in the class inheritance tree. + + Let's say you have three classes which use joined table inheritance, + namely TextItem, Article and BlogPost with Article and BlogPost inheriting + TextItem. + + :: + + # This will check all foreign keys that reference either article table + # or textitem table. + get_referencing_foreign_keys(Article) + + .. seealso:: :func:`get_tables` + """ + if isinstance(mixed, sa.Table): + tables = [mixed] + else: + tables = get_tables(mixed) + + referencing_foreign_keys = set() + + for table in mixed.metadata.tables.values(): + if table not in tables: + for constraint in table.constraints: + if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): + for fk in constraint.elements: + if any(fk.references(t) for t in tables): + referencing_foreign_keys.add(fk) + return referencing_foreign_keys + + +def merge_references(from_, to, foreign_keys=None): + """ + Merge the references of an entity into another entity. + + Consider the following models:: + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(name=%r)' % self.name + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.Unicode(255)) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + + Now lets add some data:: + + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) + self.session.add_all([ + john, + jack, + post, + post2 + ]) + self.session.commit() + + + If we wanted to merge all John's references to Jack it would be as easy as + :: + + merge_references(john, jack) + self.session.commit() + + post.author # User(name='Jack') + post2.author # User(name='Jack') + + + + :param from_: an entity to merge into another entity + :param to: an entity to merge another entity into + :param foreign_keys: A sequence of foreign keys. By default this is None + indicating all referencing foreign keys should be used. + + .. seealso: :func:`dependent_objects` + + .. versionadded: 0.26.1 + """ + if from_.__tablename__ != to.__tablename__: + raise TypeError('The tables of given arguments do not match.') + + session = object_session(from_) + foreign_keys = get_referencing_foreign_keys(from_) + + for fk in foreign_keys: + old_values = get_foreign_key_values(fk, from_) + new_values = get_foreign_key_values(fk, to) + criteria = ( + getattr(fk.constraint.table.c, key) == value + for key, value in six.iteritems(old_values) + ) + try: + mapper = get_mapper(fk.constraint.table) + except ValueError: + query = ( + fk.constraint.table + .update() + .where(sa.and_(*criteria)) + .values(new_values) + ) + session.execute(query) + else: + print old_values, new_values + ( + session.query(mapper.class_) + .filter_by(**old_values) + .update( + new_values, + 'evaluate' + ) + ) + + +def dependent_objects(obj, foreign_keys=None): + """ + Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates + through all dependent objects for given SQLAlchemy object. + + Consider a User object is referenced in various articles and also in + various orders. Getting all these dependent objects is as easy as: + + :: + + from sqlalchemy_utils import dependent_objects + + + dependent_objects(user) + + + If you expect an object to have lots of dependent_objects it might be good + to limit the results:: + + + dependent_objects(user).limit(5) + + + + The common use case is checking for all restrict dependent objects before + deleting parent object and inform the user if there are dependent objects + with ondelete='RESTRICT' foreign keys. If this kind of checking is not used + it will lead to nasty IntegrityErrors being raised. + + In the following example we delete given user if it doesn't have any + foreign key restricted dependent objects. + + :: + + + from sqlalchemy_utils import get_referencing_foreign_keys + + + user = session.query(User).get(some_user_id) + + + deps = list( + dependent_objects( + user, + ( + fk for fk in get_referencing_foreign_keys(User) + # On most databases RESTRICT is the default mode hence we + # check for None values also + if fk.ondelete == 'RESTRICT' or fk.ondelete is None + ) + ).limit(5) + ) + + if deps: + # Do something to inform the user + pass + else: + session.delete(user) + + + :param obj: SQLAlchemy declarative model object + :param foreign_keys: + A sequence of foreign keys to use for searching the dependent_objects + for given object. By default this is None, indicating that all foreign + keys referencing the object will be used. + + .. note:: + This function does not support exotic mappers that use multiple tables + + .. seealso:: :func:`get_referencing_foreign_keys` + .. seealso:: :func:`merge_references` + + .. versionadded: 0.26.0 + """ + if foreign_keys is None: + foreign_keys = get_referencing_foreign_keys(obj) + + session = object_session(obj) + + chain = QueryChain([]) + classes = obj.__class__._decl_class_registry + + for table, keys in group_foreign_keys(foreign_keys): + for class_ in classes.values(): + if hasattr(class_, '__table__') and class_.__table__ == table: + criteria = [] + visited_constraints = [] + for key in keys: + if key.constraint not in visited_constraints: + visited_constraints.append(key.constraint) + subcriteria = [ + getattr(class_, column.key) == + getattr( + obj, + key.constraint.elements[index].column.key + ) + for index, column + in enumerate(key.constraint.columns) + ] + criteria.append(sa.and_(*subcriteria)) + + query = session.query(class_).filter( + sa.or_( + *criteria + ) + ) + chain.queries.append(query) + return chain + + +def non_indexed_foreign_keys(metadata, engine=None): + """ + Finds all non indexed foreign keys from all tables of given MetaData. + + Very useful for optimizing postgresql database and finding out which + foreign keys need indexes. + + :param metadata: MetaData object to inspect tables from + """ + reflected_metadata = MetaData() + + if metadata.bind is None and engine is None: + raise Exception( + 'Either pass a metadata object with bind or ' + 'pass engine as a second parameter' + ) + + constraints = defaultdict(list) + + for table_name in metadata.tables.keys(): + table = Table( + table_name, + reflected_metadata, + autoload=True, + autoload_with=metadata.bind or engine + ) + + for constraint in table.constraints: + if not isinstance(constraint, ForeignKeyConstraint): + continue + + if not is_indexed_foreign_key(constraint): + constraints[table.name].append(constraint) + + return dict(constraints) + + +def is_indexed_foreign_key(constraint): + """ + Whether or not given foreign key constraint's columns have been indexed. + + :param constraint: ForeignKeyConstraint object to check the indexes + """ + for index in constraint.table.indexes: + index_column_names = set( + column.name for column in index.columns + ) + if index_column_names == set(constraint.columns): + return True + return False diff --git a/sqlalchemy_utils/functions/orm.py b/sqlalchemy_utils/functions/orm.py index 594e483..d16e008 100644 --- a/sqlalchemy_utils/functions/orm.py +++ b/sqlalchemy_utils/functions/orm.py @@ -61,8 +61,12 @@ def get_mapper(mixed): ] if len(mappers) > 1: raise ValueError( - "Could not get mapper for '%r'. Multiple mappers found." - % mixed + "Multiple mappers found for table '%s'." + % mixed.name + ) + elif not mappers: + raise ValueError( + "Could not get mapper for table '%s'." ) else: return mappers[0] @@ -104,188 +108,6 @@ def get_bind(obj): return conn -def dependent_objects(obj, foreign_keys=None): - """ - Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates - through all dependent objects for given SQLAlchemy object. - - Consider a User object is referenced in various articles and also in - various orders. Getting all these dependent objects is as easy as: - - :: - - from sqlalchemy_utils import dependent_objects - - - dependent_objects(user) - - - If you expect an object to have lots of dependent_objects it might be good - to limit the results:: - - - dependent_objects(user).limit(5) - - - - The common use case is checking for all restrict dependent objects before - deleting parent object and inform the user if there are dependent objects - with ondelete='RESTRICT' foreign keys. If this kind of checking is not used - it will lead to nasty IntegrityErrors being raised. - - In the following example we delete given user if it doesn't have any - foreign key restricted dependent objects. - - :: - - - from sqlalchemy_utils import get_referencing_foreign_keys - - - user = session.query(User).get(some_user_id) - - - deps = list( - dependent_objects( - user, - ( - fk for fk in get_referencing_foreign_keys(User) - # On most databases RESTRICT is the default mode hence we - # check for None values also - if fk.ondelete == 'RESTRICT' or fk.ondelete is None - ) - ).limit(5) - ) - - if deps: - # Do something to inform the user - pass - else: - session.delete(user) - - - :param obj: SQLAlchemy declarative model object - :param foreign_keys: - A sequence of foreign keys to use for searching the dependent_objects - for given object. By default this is None, indicating that all foreign - keys referencing the object will be used. - - .. note:: - This function does not support exotic mappers that use multiple tables - - .. seealso:: :func:`get_referencing_foreign_keys` - - .. versionadded: 0.26.0 - """ - if foreign_keys is None: - foreign_keys = get_referencing_foreign_keys(obj) - - session = object_session(obj) - - chain = QueryChain([]) - classes = obj.__class__._decl_class_registry - - for table, keys in group_foreign_keys(foreign_keys): - for class_ in classes.values(): - if hasattr(class_, '__table__') and class_.__table__ == table: - criteria = [] - visited_constraints = [] - for key in keys: - if key.constraint not in visited_constraints: - visited_constraints.append(key.constraint) - subcriteria = [ - getattr(class_, column.key) == - getattr( - obj, - key.constraint.elements[index].column.key - ) - for index, column - in enumerate(key.constraint.columns) - ] - criteria.append(sa.and_(*subcriteria)) - - query = session.query(class_).filter( - sa.or_( - *criteria - ) - ) - chain.queries.append(query) - return chain - - -def group_foreign_keys(foreign_keys): - """ - Return a groupby iterator that groups given foreign keys by table. - - :param foreign_keys: a sequence of foreign keys - - - :: - - foreign_keys = get_referencing_foreign_keys(User) - - for table, fks in group_foreign_keys(foreign_keys): - # do something - pass - - - .. also:: :func:`get_referencing_foreign_keys` - - .. versionadded: 0.26.1 - """ - foreign_keys = sorted( - foreign_keys, key=lambda key: key.constraint.table.name - ) - return groupby(foreign_keys, lambda key: key.constraint.table) - - -def get_referencing_foreign_keys(mixed): - """ - Returns referencing foreign keys for given Table object or declarative - class. - - :param mixed: - SA Table object or SA declarative class - - :: - - get_referencing_foreign_keys(User) # set([ForeignKey('user.id')]) - - get_referencing_foreign_keys(User.__table__) - - - This function also understands inheritance. This means it returns - all foreign keys that reference any table in the class inheritance tree. - - Let's say you have three classes which use joined table inheritance, - namely TextItem, Article and BlogPost with Article and BlogPost inheriting - TextItem. - - :: - - # This will check all foreign keys that reference either article table - # or textitem table. - get_referencing_foreign_keys(Article) - - .. seealso:: :func:`get_tables` - """ - if isinstance(mixed, sa.Table): - tables = [mixed] - else: - tables = get_tables(mixed) - - referencing_foreign_keys = set() - - for table in mixed.metadata.tables.values(): - if table not in tables: - for constraint in table.constraints: - if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint): - for fk in constraint.elements: - if any(fk.references(t) for t in tables): - referencing_foreign_keys.add(fk) - return referencing_foreign_keys - - def get_primary_keys(mixed): """ Return an OrderedDict of all primary keys for given Table object, diff --git a/sqlalchemy_utils/merge.py b/sqlalchemy_utils/merge.py deleted file mode 100644 index 5e7f2c8..0000000 --- a/sqlalchemy_utils/merge.py +++ /dev/null @@ -1,124 +0,0 @@ -import six -import sqlalchemy as sa -from sqlalchemy.engine import reflection -from sqlalchemy.orm import object_session, mapperlib - - -def dependent_foreign_keys(model_class): - """ - Returns dependent foreign keys as dicts for given model class. - - ** Experimental function ** - """ - session = object_session(model_class) - - engine = session.bind - inspector = reflection.Inspector.from_engine(engine) - table_names = inspector.get_table_names() - - dependent_foreign_keys = {} - - for table_name in table_names: - fks = inspector.get_foreign_keys(table_name) - if fks: - dependent_foreign_keys[table_name] = [] - for fk in fks: - if fk['referred_table'] == model_class.__tablename__: - dependent_foreign_keys[table_name].append(fk) - return dependent_foreign_keys - - -class Merger(object): - def memory_merge(self, session, table_name, old_values, new_values): - # try to fetch mappers for given table and update in memory objects as - # well as database table - found = False - for mapper in mapperlib._mapper_registry: - class_ = mapper.class_ - if table_name == class_.__table__.name: - try: - ( - session.query(mapper.class_) - .filter_by(**old_values) - .update( - new_values, - 'fetch' - ) - ) - except sa.exc.IntegrityError: - pass - found = True - return found - - def raw_merge(self, session, table, old_values, new_values): - conditions = [] - for key, value in six.iteritems(old_values): - conditions.append(getattr(table.c, key) == value) - sql = ( - table - .update() - .where(sa.and_( - *conditions - )) - .values( - new_values - ) - ) - try: - session.execute(sql) - except sa.exc.IntegrityError: - pass - - def merge_update(self, table_name, from_, to, foreign_key): - session = object_session(from_) - constrained_columns = foreign_key['constrained_columns'] - referred_columns = foreign_key['referred_columns'] - metadata = from_.metadata - table = metadata.tables[table_name] - - new_values = {} - for index, column in enumerate(constrained_columns): - new_values[column] = getattr( - to, referred_columns[index] - ) - - old_values = {} - for index, column in enumerate(constrained_columns): - old_values[column] = getattr( - from_, referred_columns[index] - ) - - if not self.memory_merge(session, table_name, old_values, new_values): - self.raw_merge(session, table, old_values, new_values) - - def __call__(self, from_, to): - """ - Merges entity into another entity. After merging deletes the from_ - argument entity. - """ - if from_.__tablename__ != to.__tablename__: - raise Exception() - - session = object_session(from_) - foreign_keys = dependent_foreign_keys(from_) - - for table_name in foreign_keys: - for foreign_key in foreign_keys[table_name]: - self.merge_update(table_name, from_, to, foreign_key) - - session.delete(from_) - - -def merge(from_, to, merger=Merger): - """ - Merges entity into another entity. After merging deletes the from_ argument - entity. - - After merging the from_ entity is deleted from database. - - :param from_: an entity to merge into another entity - :param to: an entity to merge another entity into - :param merger: Merger class, by default this is sqlalchemy_utils.Merger - class - """ - return Merger()(from_, to) diff --git a/tests/functions/test_get_mapper.py b/tests/functions/test_get_mapper.py index 3d8bee4..17646a0 100644 --- a/tests/functions/test_get_mapper.py +++ b/tests/functions/test_get_mapper.py @@ -71,3 +71,18 @@ class TestGetMapperWithMultipleMappersFound(object): alias = sa.orm.aliased(self.Building.__table__) with raises(ValueError): get_mapper(alias) + + +class TestGetMapperForTableWithoutMapper(object): + def setup_method(self, method): + metadata = sa.MetaData() + self.building = sa.Table('building', metadata) + + def test_table(self): + with raises(ValueError): + get_mapper(self.building) + + def test_table_alias(self): + alias = sa.orm.aliased(self.building) + with raises(ValueError): + get_mapper(alias) diff --git a/tests/test_merge.py b/tests/functions/test_merge_references.py similarity index 73% rename from tests/test_merge.py rename to tests/functions/test_merge_references.py index ec9ec78..e982bd8 100644 --- a/tests/test_merge.py +++ b/tests/functions/test_merge_references.py @@ -1,10 +1,10 @@ import sqlalchemy as sa -from sqlalchemy_utils import merge +from sqlalchemy_utils import merge_references from tests import TestCase -class TestMerge(TestCase): +class TestMergeReferences(TestCase): def create_models(self): class User(self.Base): __tablename__ = 'user' @@ -36,21 +36,29 @@ class TestMerge(TestCase): self.session.add(post) self.session.add(post2) self.session.commit() - merge(john, jack) + merge_references(john, jack) + self.session.commit() assert post.author == jack assert post2.author == jack - def test_deletes_from_entity(self): + def test_object_merging_whenever_possible(self): john = self.User(name=u'John') jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) self.session.add(john) self.session.add(jack) + self.session.add(post) + self.session.add(post2) self.session.commit() - merge(john, jack) - assert john in self.session.deleted + # Load the author for post + assert post.author_id == john.id + merge_references(john, jack) + assert post.author_id == jack.id + assert post2.author_id == jack.id -class TestMergeManyToManyAssociations(TestCase): +class TestMergeReferencesWithManyToManyAssociations(TestCase): def create_models(self): class User(self.Base): __tablename__ = 'user' @@ -88,7 +96,7 @@ class TestMergeManyToManyAssociations(TestCase): self.User = User self.Team = Team - def test_when_association_only_exists_in_from_entity(self): + def test_supports_associations(self): john = self.User(name=u'John') jack = self.User(name=u'Jack') team = self.Team(name=u'Team') @@ -96,29 +104,12 @@ class TestMergeManyToManyAssociations(TestCase): self.session.add(john) self.session.add(jack) self.session.commit() - merge(john, jack) + merge_references(john, jack) assert john not in team.members assert jack in team.members - # def test_when_association_exists_in_both(self): - # john = self.User(name=u'John') - # jack = self.User(name=u'Jack') - # team = self.Team(name=u'Team') - # team.members.append(john) - # team.members.append(jack) - # self.session.add(john) - # self.session.add(jack) - # self.session.commit() - # merge(john, jack) - # assert john not in team.members - # assert jack in team.members - # count = self.session.execute( - # 'SELECT COUNT(1) FROM team_member' - # ).fetchone()[0] - # assert count == 1 - -class TestMergeManyToManyAssociationObjects(TestCase): +class TestMergeReferencesWithManyToManyAssociationObjects(TestCase): def create_models(self): class Team(self.Base): __tablename__ = 'team' @@ -164,7 +155,7 @@ class TestMergeManyToManyAssociationObjects(TestCase): self.TeamMember = TeamMember self.Team = Team - def test_when_association_only_exists_in_from_entity(self): + def test_supports_associations(self): john = self.User(name=u'John') jack = self.User(name=u'Jack') team = self.Team(name=u'Team') @@ -173,24 +164,8 @@ class TestMergeManyToManyAssociationObjects(TestCase): self.session.add(jack) self.session.add(team) self.session.commit() - merge(john, jack) + merge_references(john, jack) self.session.commit() users = [member.user for member in team.members] assert john not in users assert jack in users - - # def test_when_association_exists_in_both(self): - # john = self.User(name=u'John') - # jack = self.User(name=u'Jack') - # team = self.Team(name=u'Team') - # team.members.append(self.TeamMember(user=john)) - # team.members.append(self.TeamMember(user=jack)) - # self.session.add(john) - # self.session.add(jack) - # self.session.add(team) - # self.session.commit() - # merge(john, jack) - # users = [member.user for member in team.members] - # assert john not in users - # assert jack in users - # assert self.session.query(self.TeamMember).count() == 1