Added experimental merge function
This commit is contained in:
@@ -3,3 +3,4 @@ pytest==2.2.3
|
|||||||
Pygments==1.2
|
Pygments==1.2
|
||||||
Jinja2==2.3
|
Jinja2==2.3
|
||||||
docutils>=0.10
|
docutils>=0.10
|
||||||
|
phonenumbers>=5.4b1
|
||||||
|
@@ -1,6 +1,8 @@
|
|||||||
import phonenumbers
|
import phonenumbers
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from sqlalchemy.orm import defer
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.engine import reflection
|
||||||
|
from sqlalchemy.orm import defer, object_session, mapperlib
|
||||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||||
from sqlalchemy.orm.mapper import Mapper
|
from sqlalchemy.orm.mapper import Mapper
|
||||||
from sqlalchemy.orm.query import _ColumnEntity
|
from sqlalchemy.orm.query import _ColumnEntity
|
||||||
@@ -210,3 +212,123 @@ def escape_like(string, escape_char='*'):
|
|||||||
.replace('%', escape_char + '%')
|
.replace('%', escape_char + '%')
|
||||||
.replace('_', escape_char + '_')
|
.replace('_', escape_char + '_')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def dependent_foreign_keys(model_class):
|
||||||
|
"""
|
||||||
|
Returns dependent foreign keys as dicts for given model class.
|
||||||
|
|
||||||
|
** Experimental function **
|
||||||
|
"""
|
||||||
|
session = object_session(model_class)
|
||||||
|
|
||||||
|
engine = session.bind
|
||||||
|
inspector = reflection.Inspector.from_engine(engine)
|
||||||
|
table_names = inspector.get_table_names()
|
||||||
|
|
||||||
|
dependent_foreign_keys = {}
|
||||||
|
|
||||||
|
for table_name in table_names:
|
||||||
|
fks = inspector.get_foreign_keys(table_name)
|
||||||
|
if fks:
|
||||||
|
dependent_foreign_keys[table_name] = []
|
||||||
|
for fk in fks:
|
||||||
|
if fk['referred_table'] == model_class.__tablename__:
|
||||||
|
dependent_foreign_keys[table_name].append(fk)
|
||||||
|
return dependent_foreign_keys
|
||||||
|
|
||||||
|
|
||||||
|
class Merger(object):
|
||||||
|
def memory_merge(self, session, table_name, old_values, new_values):
|
||||||
|
# try to fetch mappers for given table and update in memory objects as
|
||||||
|
# well as database table
|
||||||
|
found = False
|
||||||
|
for mapper in mapperlib._mapper_registry:
|
||||||
|
class_ = mapper.class_
|
||||||
|
if table_name == class_.__table__.name:
|
||||||
|
try:
|
||||||
|
(
|
||||||
|
session.query(mapper.class_)
|
||||||
|
.filter_by(**old_values)
|
||||||
|
.update(
|
||||||
|
new_values,
|
||||||
|
'fetch'
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except sa.exc.IntegrityError:
|
||||||
|
pass
|
||||||
|
found = True
|
||||||
|
return found
|
||||||
|
|
||||||
|
def raw_merge(self, session, table, old_values, new_values):
|
||||||
|
conditions = []
|
||||||
|
for key, value in old_values.items():
|
||||||
|
conditions.append(getattr(table.c, key) == value)
|
||||||
|
sql = (
|
||||||
|
table
|
||||||
|
.update()
|
||||||
|
.where(sa.and_(
|
||||||
|
*conditions
|
||||||
|
))
|
||||||
|
.values(
|
||||||
|
new_values
|
||||||
|
)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
session.execute(sql)
|
||||||
|
except sa.exc.IntegrityError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def merge_update(self, table_name, from_, to, foreign_key):
|
||||||
|
session = object_session(from_)
|
||||||
|
constrained_columns = foreign_key['constrained_columns']
|
||||||
|
referred_columns = foreign_key['referred_columns']
|
||||||
|
metadata = from_.metadata
|
||||||
|
table = metadata.tables[table_name]
|
||||||
|
|
||||||
|
new_values = {}
|
||||||
|
for index, column in enumerate(constrained_columns):
|
||||||
|
new_values[column] = getattr(
|
||||||
|
to, referred_columns[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
old_values = {}
|
||||||
|
for index, column in enumerate(constrained_columns):
|
||||||
|
old_values[column] = getattr(
|
||||||
|
from_, referred_columns[index]
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.memory_merge(session, table_name, old_values, new_values):
|
||||||
|
self.raw_merge(session, table, old_values, new_values)
|
||||||
|
|
||||||
|
def __call__(self, from_, to):
|
||||||
|
"""
|
||||||
|
Merges entity into another entity. After merging deletes the from_
|
||||||
|
argument entity.
|
||||||
|
"""
|
||||||
|
if from_.__tablename__ != to.__tablename__:
|
||||||
|
raise Exception()
|
||||||
|
|
||||||
|
session = object_session(from_)
|
||||||
|
foreign_keys = dependent_foreign_keys(from_)
|
||||||
|
|
||||||
|
for table_name in foreign_keys:
|
||||||
|
for foreign_key in foreign_keys[table_name]:
|
||||||
|
self.merge_update(table_name, from_, to, foreign_key)
|
||||||
|
|
||||||
|
session.delete(from_)
|
||||||
|
|
||||||
|
|
||||||
|
def merge(from_, to, merger=Merger):
|
||||||
|
"""
|
||||||
|
Merges entity into another entity. After merging deletes the from_ argument
|
||||||
|
entity.
|
||||||
|
|
||||||
|
After merging the from_ entity is deleted from database.
|
||||||
|
|
||||||
|
:param from_: an entity to merge into another entity
|
||||||
|
:param to: an entity to merge another entity into
|
||||||
|
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
|
||||||
|
class
|
||||||
|
"""
|
||||||
|
return Merger()(from_, to)
|
||||||
|
217
tests.py
217
tests.py
@@ -9,7 +9,8 @@ from sqlalchemy_utils import (
|
|||||||
escape_like,
|
escape_like,
|
||||||
sort_query,
|
sort_query,
|
||||||
InstrumentedList,
|
InstrumentedList,
|
||||||
PhoneNumberType
|
PhoneNumberType,
|
||||||
|
merge
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -142,3 +143,217 @@ class TestPhoneNumberType(TestCase):
|
|||||||
{'param': self.user.id}
|
{'param': self.user.id}
|
||||||
)
|
)
|
||||||
assert result.first()[0] == u'+358401234567'
|
assert result.first()[0] == u'+358401234567'
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseTestCase(object):
|
||||||
|
def create_models(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def setup_method(self, method):
|
||||||
|
self.engine = create_engine(
|
||||||
|
'sqlite:///'
|
||||||
|
)
|
||||||
|
#self.engine.echo = True
|
||||||
|
self.Base = declarative_base()
|
||||||
|
|
||||||
|
self.create_models()
|
||||||
|
self.Base.metadata.create_all(self.engine)
|
||||||
|
Session = sessionmaker(bind=self.engine)
|
||||||
|
self.session = Session()
|
||||||
|
|
||||||
|
def teardown_method(self, method):
|
||||||
|
self.engine.dispose()
|
||||||
|
self.Base.metadata.drop_all(self.engine)
|
||||||
|
self.session.expunge_all()
|
||||||
|
|
||||||
|
|
||||||
|
class TestMerge(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'User(%r)' % self.name
|
||||||
|
|
||||||
|
class BlogPost(self.Base):
|
||||||
|
__tablename__ = 'blog_post'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
title = sa.Column(sa.Unicode(255))
|
||||||
|
content = sa.Column(sa.UnicodeText)
|
||||||
|
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||||
|
|
||||||
|
author = sa.orm.relationship(User)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.BlogPost = BlogPost
|
||||||
|
|
||||||
|
def test_updates_foreign_keys(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
post = self.BlogPost(title=u'Some title', author=john)
|
||||||
|
post2 = self.BlogPost(title=u'Other title', author=jack)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(post)
|
||||||
|
self.session.add(post2)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert post.author == jack
|
||||||
|
assert post2.author == jack
|
||||||
|
|
||||||
|
def test_deletes_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john in self.session.deleted
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeManyToManyAssociations(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'User(%r)' % self.name
|
||||||
|
|
||||||
|
team_member = sa.Table(
|
||||||
|
'team_member', self.Base.metadata,
|
||||||
|
sa.Column(
|
||||||
|
'user_id', sa.Integer,
|
||||||
|
sa.ForeignKey('user.id', ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
'team_id', sa.Integer,
|
||||||
|
sa.ForeignKey('team.id', ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
class Team(self.Base):
|
||||||
|
__tablename__ = 'team'
|
||||||
|
id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
members = sa.orm.relationship(
|
||||||
|
User,
|
||||||
|
secondary=team_member,
|
||||||
|
backref='teams'
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.Team = Team
|
||||||
|
|
||||||
|
def test_when_association_only_exists_in_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(john)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john not in team.members
|
||||||
|
assert jack in team.members
|
||||||
|
|
||||||
|
def test_when_association_exists_in_both(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(john)
|
||||||
|
team.members.append(jack)
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
assert john not in team.members
|
||||||
|
assert jack in team.members
|
||||||
|
count = self.session.execute(
|
||||||
|
'SELECT COUNT(1) FROM team_member'
|
||||||
|
).fetchone()[0]
|
||||||
|
assert count == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestMergeManyToManyAssociationObjects(DatabaseTestCase):
|
||||||
|
def create_models(self):
|
||||||
|
class Team(self.Base):
|
||||||
|
__tablename__ = 'team'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class User(self.Base):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
|
||||||
|
name = sa.Column(sa.Unicode(255))
|
||||||
|
|
||||||
|
class TeamMember(self.Base):
|
||||||
|
__tablename__ = 'team_member'
|
||||||
|
user_id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey(User.id, ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
team_id = sa.Column(
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey(Team.id, ondelete='CASCADE'),
|
||||||
|
primary_key=True
|
||||||
|
)
|
||||||
|
role = sa.Column(sa.Unicode(255))
|
||||||
|
team = sa.orm.relationship(
|
||||||
|
Team,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'members',
|
||||||
|
cascade='all, delete-orphan'
|
||||||
|
),
|
||||||
|
primaryjoin=team_id == Team.id,
|
||||||
|
)
|
||||||
|
user = sa.orm.relationship(
|
||||||
|
User,
|
||||||
|
backref=sa.orm.backref(
|
||||||
|
'memberships',
|
||||||
|
cascade='all, delete-orphan'
|
||||||
|
),
|
||||||
|
primaryjoin=user_id == User.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.User = User
|
||||||
|
self.TeamMember = TeamMember
|
||||||
|
self.Team = Team
|
||||||
|
|
||||||
|
def test_when_association_only_exists_in_from_entity(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(self.TeamMember(user=john))
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(team)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
self.session.commit()
|
||||||
|
users = [member.user for member in team.members]
|
||||||
|
assert john not in users
|
||||||
|
assert jack in users
|
||||||
|
|
||||||
|
def test_when_association_exists_in_both(self):
|
||||||
|
john = self.User(name=u'John')
|
||||||
|
jack = self.User(name=u'Jack')
|
||||||
|
team = self.Team(name=u'Team')
|
||||||
|
team.members.append(self.TeamMember(user=john))
|
||||||
|
team.members.append(self.TeamMember(user=jack))
|
||||||
|
self.session.add(john)
|
||||||
|
self.session.add(jack)
|
||||||
|
self.session.add(team)
|
||||||
|
self.session.commit()
|
||||||
|
merge(john, jack)
|
||||||
|
users = [member.user for member in team.members]
|
||||||
|
assert john not in users
|
||||||
|
assert jack in users
|
||||||
|
assert self.session.query(self.TeamMember).count() == 1
|
||||||
|
Reference in New Issue
Block a user