From 3a5b80ea5e32d6146acf10824b6a8b99a226b63f Mon Sep 17 00:00:00 2001 From: Konsta Vesterinen Date: Wed, 20 Mar 2013 13:45:12 +0200 Subject: [PATCH] Added experimental merge function --- requirements-dev.txt | 1 + sqlalchemy_utils/__init__.py | 124 +++++++++++++++++++- tests.py | 217 ++++++++++++++++++++++++++++++++++- 3 files changed, 340 insertions(+), 2 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index e2929c0..23b7fac 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,3 +3,4 @@ pytest==2.2.3 Pygments==1.2 Jinja2==2.3 docutils>=0.10 +phonenumbers>=5.4b1 diff --git a/sqlalchemy_utils/__init__.py b/sqlalchemy_utils/__init__.py index 2028ab0..91be58f 100644 --- a/sqlalchemy_utils/__init__.py +++ b/sqlalchemy_utils/__init__.py @@ -1,6 +1,8 @@ import phonenumbers from functools import wraps -from sqlalchemy.orm import defer +import sqlalchemy as sa +from sqlalchemy.engine import reflection +from sqlalchemy.orm import defer, object_session, mapperlib from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList from sqlalchemy.orm.mapper import Mapper from sqlalchemy.orm.query import _ColumnEntity @@ -210,3 +212,123 @@ def escape_like(string, escape_char='*'): .replace('%', escape_char + '%') .replace('_', escape_char + '_') ) + + +def dependent_foreign_keys(model_class): + """ + Returns dependent foreign keys as dicts for given model class. + + ** Experimental function ** + """ + session = object_session(model_class) + + engine = session.bind + inspector = reflection.Inspector.from_engine(engine) + table_names = inspector.get_table_names() + + dependent_foreign_keys = {} + + for table_name in table_names: + fks = inspector.get_foreign_keys(table_name) + if fks: + dependent_foreign_keys[table_name] = [] + for fk in fks: + if fk['referred_table'] == model_class.__tablename__: + dependent_foreign_keys[table_name].append(fk) + return dependent_foreign_keys + + +class Merger(object): + def memory_merge(self, session, table_name, old_values, new_values): + # try to fetch mappers for given table and update in memory objects as + # well as database table + found = False + for mapper in mapperlib._mapper_registry: + class_ = mapper.class_ + if table_name == class_.__table__.name: + try: + ( + session.query(mapper.class_) + .filter_by(**old_values) + .update( + new_values, + 'fetch' + ) + ) + except sa.exc.IntegrityError: + pass + found = True + return found + + def raw_merge(self, session, table, old_values, new_values): + conditions = [] + for key, value in old_values.items(): + conditions.append(getattr(table.c, key) == value) + sql = ( + table + .update() + .where(sa.and_( + *conditions + )) + .values( + new_values + ) + ) + try: + session.execute(sql) + except sa.exc.IntegrityError: + pass + + def merge_update(self, table_name, from_, to, foreign_key): + session = object_session(from_) + constrained_columns = foreign_key['constrained_columns'] + referred_columns = foreign_key['referred_columns'] + metadata = from_.metadata + table = metadata.tables[table_name] + + new_values = {} + for index, column in enumerate(constrained_columns): + new_values[column] = getattr( + to, referred_columns[index] + ) + + old_values = {} + for index, column in enumerate(constrained_columns): + old_values[column] = getattr( + from_, referred_columns[index] + ) + + if not self.memory_merge(session, table_name, old_values, new_values): + self.raw_merge(session, table, old_values, new_values) + + def __call__(self, from_, to): + """ + Merges entity into another entity. After merging deletes the from_ + argument entity. + """ + if from_.__tablename__ != to.__tablename__: + raise Exception() + + session = object_session(from_) + foreign_keys = dependent_foreign_keys(from_) + + for table_name in foreign_keys: + for foreign_key in foreign_keys[table_name]: + self.merge_update(table_name, from_, to, foreign_key) + + session.delete(from_) + + +def merge(from_, to, merger=Merger): + """ + Merges entity into another entity. After merging deletes the from_ argument + entity. + + After merging the from_ entity is deleted from database. + + :param from_: an entity to merge into another entity + :param to: an entity to merge another entity into + :param merger: Merger class, by default this is sqlalchemy_utils.Merger + class + """ + return Merger()(from_, to) diff --git a/tests.py b/tests.py index d0cbfc6..36d0070 100644 --- a/tests.py +++ b/tests.py @@ -9,7 +9,8 @@ from sqlalchemy_utils import ( escape_like, sort_query, InstrumentedList, - PhoneNumberType + PhoneNumberType, + merge ) @@ -142,3 +143,217 @@ class TestPhoneNumberType(TestCase): {'param': self.user.id} ) assert result.first()[0] == u'+358401234567' + + +class DatabaseTestCase(object): + def create_models(self): + pass + + def setup_method(self, method): + self.engine = create_engine( + 'sqlite:///' + ) + #self.engine.echo = True + self.Base = declarative_base() + + self.create_models() + self.Base.metadata.create_all(self.engine) + Session = sessionmaker(bind=self.engine) + self.session = Session() + + def teardown_method(self, method): + self.engine.dispose() + self.Base.metadata.drop_all(self.engine) + self.session.expunge_all() + + +class TestMerge(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + class BlogPost(self.Base): + __tablename__ = 'blog_post' + id = sa.Column(sa.Integer, primary_key=True) + title = sa.Column(sa.Unicode(255)) + content = sa.Column(sa.UnicodeText) + author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id')) + + author = sa.orm.relationship(User) + + self.User = User + self.BlogPost = BlogPost + + def test_updates_foreign_keys(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + post = self.BlogPost(title=u'Some title', author=john) + post2 = self.BlogPost(title=u'Other title', author=jack) + self.session.add(john) + self.session.add(jack) + self.session.add(post) + self.session.add(post2) + self.session.commit() + merge(john, jack) + assert post.author == jack + assert post2.author == jack + + def test_deletes_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john in self.session.deleted + + +class TestMergeManyToManyAssociations(DatabaseTestCase): + def create_models(self): + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + def __repr__(self): + return 'User(%r)' % self.name + + team_member = sa.Table( + 'team_member', self.Base.metadata, + sa.Column( + 'user_id', sa.Integer, + sa.ForeignKey('user.id', ondelete='CASCADE'), + primary_key=True + ), + sa.Column( + 'team_id', sa.Integer, + sa.ForeignKey('team.id', ondelete='CASCADE'), + primary_key=True + ) + ) + + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + members = sa.orm.relationship( + User, + secondary=team_member, + backref='teams' + ) + + self.User = User + self.Team = Team + + def test_when_association_only_exists_in_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(john) + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john not in team.members + assert jack in team.members + + def test_when_association_exists_in_both(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(john) + team.members.append(jack) + self.session.add(john) + self.session.add(jack) + self.session.commit() + merge(john, jack) + assert john not in team.members + assert jack in team.members + count = self.session.execute( + 'SELECT COUNT(1) FROM team_member' + ).fetchone()[0] + assert count == 1 + + +class TestMergeManyToManyAssociationObjects(DatabaseTestCase): + def create_models(self): + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + class TeamMember(self.Base): + __tablename__ = 'team_member' + user_id = sa.Column( + sa.Integer, + sa.ForeignKey(User.id, ondelete='CASCADE'), + primary_key=True + ) + team_id = sa.Column( + sa.Integer, + sa.ForeignKey(Team.id, ondelete='CASCADE'), + primary_key=True + ) + role = sa.Column(sa.Unicode(255)) + team = sa.orm.relationship( + Team, + backref=sa.orm.backref( + 'members', + cascade='all, delete-orphan' + ), + primaryjoin=team_id == Team.id, + ) + user = sa.orm.relationship( + User, + backref=sa.orm.backref( + 'memberships', + cascade='all, delete-orphan' + ), + primaryjoin=user_id == User.id, + ) + + self.User = User + self.TeamMember = TeamMember + self.Team = Team + + def test_when_association_only_exists_in_from_entity(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(self.TeamMember(user=john)) + self.session.add(john) + self.session.add(jack) + self.session.add(team) + self.session.commit() + merge(john, jack) + self.session.commit() + users = [member.user for member in team.members] + assert john not in users + assert jack in users + + def test_when_association_exists_in_both(self): + john = self.User(name=u'John') + jack = self.User(name=u'Jack') + team = self.Team(name=u'Team') + team.members.append(self.TeamMember(user=john)) + team.members.append(self.TeamMember(user=jack)) + self.session.add(john) + self.session.add(jack) + self.session.add(team) + self.session.commit() + merge(john, jack) + users = [member.user for member in team.members] + assert john not in users + assert jack in users + assert self.session.query(self.TeamMember).count() == 1