Files
deb-python-sqlalchemy-utils/sqlalchemy_utils/merge.py
2013-06-21 00:08:51 -07:00

125 lines
3.9 KiB
Python

import six
import sqlalchemy as sa
from sqlalchemy.engine import reflection
from sqlalchemy.orm import object_session, mapperlib
def dependent_foreign_keys(model_class):
"""
Returns dependent foreign keys as dicts for given model class.
** Experimental function **
"""
session = object_session(model_class)
engine = session.bind
inspector = reflection.Inspector.from_engine(engine)
table_names = inspector.get_table_names()
dependent_foreign_keys = {}
for table_name in table_names:
fks = inspector.get_foreign_keys(table_name)
if fks:
dependent_foreign_keys[table_name] = []
for fk in fks:
if fk['referred_table'] == model_class.__tablename__:
dependent_foreign_keys[table_name].append(fk)
return dependent_foreign_keys
class Merger(object):
def memory_merge(self, session, table_name, old_values, new_values):
# try to fetch mappers for given table and update in memory objects as
# well as database table
found = False
for mapper in mapperlib._mapper_registry:
class_ = mapper.class_
if table_name == class_.__table__.name:
try:
(
session.query(mapper.class_)
.filter_by(**old_values)
.update(
new_values,
'fetch'
)
)
except sa.exc.IntegrityError:
pass
found = True
return found
def raw_merge(self, session, table, old_values, new_values):
conditions = []
for key, value in six.iteritems(old_values):
conditions.append(getattr(table.c, key) == value)
sql = (
table
.update()
.where(sa.and_(
*conditions
))
.values(
new_values
)
)
try:
session.execute(sql)
except sa.exc.IntegrityError:
pass
def merge_update(self, table_name, from_, to, foreign_key):
session = object_session(from_)
constrained_columns = foreign_key['constrained_columns']
referred_columns = foreign_key['referred_columns']
metadata = from_.metadata
table = metadata.tables[table_name]
new_values = {}
for index, column in enumerate(constrained_columns):
new_values[column] = getattr(
to, referred_columns[index]
)
old_values = {}
for index, column in enumerate(constrained_columns):
old_values[column] = getattr(
from_, referred_columns[index]
)
if not self.memory_merge(session, table_name, old_values, new_values):
self.raw_merge(session, table, old_values, new_values)
def __call__(self, from_, to):
"""
Merges entity into another entity. After merging deletes the from_
argument entity.
"""
if from_.__tablename__ != to.__tablename__:
raise Exception()
session = object_session(from_)
foreign_keys = dependent_foreign_keys(from_)
for table_name in foreign_keys:
for foreign_key in foreign_keys[table_name]:
self.merge_update(table_name, from_, to, foreign_key)
session.delete(from_)
def merge(from_, to, merger=Merger):
"""
Merges entity into another entity. After merging deletes the from_ argument
entity.
After merging the from_ entity is deleted from database.
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param merger: Merger class, by default this is sqlalchemy_utils.Merger
class
"""
return Merger()(from_, to)