118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import shutil
|
|
from StringIO import StringIO
|
|
|
|
import migrate
|
|
from migrate.versioning import exceptions, genmodel, schemadiff
|
|
from migrate.versioning.base import operations
|
|
from migrate.versioning.template import template
|
|
from migrate.versioning.script import base
|
|
from migrate.versioning.util import import_path, loadModel, construct_engine
|
|
|
|
class PythonScript(base.BaseScript):
    """Migration script implemented as a Python module.

    The module found at ``self.path`` is imported lazily through the
    :attr:`module` property and must define a callable ``upgrade``.
    """

    @classmethod
    def create(cls, path, **opts):
        """Create an empty migration script at ``path``.

        The script template returned by ``template.get_script`` is copied to
        ``path`` unchanged. ``opts`` is accepted but not used here.
        """
        # Presumably raises when ``path`` already exists -- defined on the
        # base class, confirm there.
        cls.require_notfound(path)

        # TODO: Use the default script template (defined in the template
        # module) for now, but we might want to allow people to specify a
        # different one later.
        template_file = None
        src = template.get_script(template_file)
        shutil.copy(src, path)

    @classmethod
    def make_update_script_for_model(cls, engine, oldmodel,
                                     model, repository, **opts):
        """Create a migration script from the difference between two models.

        :param engine: engine handed to the schema diff
        :param oldmodel: the old model, in any form ``loadModel`` accepts
        :param model: the new model, in any form ``loadModel`` accepts
        :param repository: a repository object, or a string from which a
            ``Repository`` is constructed
        :returns: the source text of the generated script
        """
        # Compute differences.
        if isinstance(repository, basestring):
            # oh dear, an import cycle!
            from migrate.versioning.repository import Repository
            repository = Repository(repository)
        oldmodel = loadModel(oldmodel)
        model = loadModel(model)
        diff = schemadiff.getDiffOfModelAgainstModel(
            oldmodel,
            model,
            engine,
            excludeTables=[repository.version_table])
        decls, upgradeCommands, downgradeCommands = \
            genmodel.ModelGenerator(diff).toUpgradeDowngradePython()

        # Store differences into file.
        template_file = None
        src = template.get_script(template_file)
        # Close the template file deterministically instead of leaking the
        # handle until garbage collection.
        with open(src) as template_fd:
            contents = template_fd.read()

        # Put the generated declarations right above upgrade().
        search = 'def upgrade():'
        contents = contents.replace(search, '\n\n'.join((decls, search)), 1)

        # Each replace consumes the first remaining placeholder.
        # NOTE(review): the placeholder literal must match the template
        # exactly, including its indentation -- confirm the whitespace here.
        # NOTE(review): if upgradeCommands is empty while downgradeCommands
        # is not, the downgrade code takes the first placeholder -- verify
        # against the template layout.
        if upgradeCommands:
            contents = contents.replace(' pass', upgradeCommands, 1)
        if downgradeCommands:
            contents = contents.replace(' pass', downgradeCommands, 1)
        return contents

    @classmethod
    def verify_module(cls, path):
        """Ensure this is a valid script, or raise InvalidScriptError.

        :param path: location of the script, passed to ``import_path``
        :returns: the imported module
        :raises InvalidScriptError: if the module has no callable ``upgrade``

        Errors raised while importing the script itself propagate unchanged.
        """
        # If the script itself has errors, that's not our problem.
        module = import_path(path)

        try:
            upgrade = module.upgrade
        except Exception as e:
            raise exceptions.InvalidScriptError(path + ': %s' % str(e))

        # Explicit check rather than ``assert``: assertions are stripped
        # when Python runs with -O, which would let a bad script through.
        if not callable(upgrade):
            raise exceptions.InvalidScriptError(
                path + ': upgrade is not callable')
        return module

    def preview_sql(self, url, step, **args):
        """Execute the step on a mock engine and return what it executed.

        The engine is built with the ``mock`` strategy and an executor that
        appends every call to an in-memory buffer.

        :param url: passed to ``construct_engine``
        :param step: passed to :meth:`run`
        :returns: the buffered text
        """
        buf = StringIO()
        args['engine_arg_strategy'] = 'mock'
        args['engine_arg_executor'] = lambda s, p='': buf.write(s + p)
        engine = construct_engine(url, **args)

        self.run(engine, step)

        return buf.getvalue()

    def run(self, engine, step):
        """Core method of Script file.

        Executes the upgrade() or downgrade() function of the script.

        :param engine: published as ``migrate.migrate_engine`` while the
            script function runs
        :param step: positive to upgrade, negative to downgrade
        :raises ScriptError: if ``step`` is zero or the function is missing
        """
        if step > 0:
            op = 'upgrade'
        elif step < 0:
            op = 'downgrade'
        else:
            raise exceptions.ScriptError("%d is not a valid step" % step)
        funcname = base.operations[op]

        # The engine is set before the function lookup because the lookup
        # may import the script module for the first time.
        migrate.migrate_engine = engine
        try:
            func = self._func(funcname)
            func()
        finally:
            # Always clear the global, even when the script fails, so a
            # stale engine is never left behind.
            migrate.migrate_engine = None

    def _get_module(self):
        """Import and verify the script module on first access, then reuse it."""
        if not hasattr(self, '_module'):
            self._module = self.verify_module(self.path)
        return self._module
    module = property(_get_module)

    def _func(self, funcname):
        """Return the function ``funcname`` from the script module.

        :raises ScriptError: if the module has no such (truthy) attribute
        """
        fn = getattr(self.module, funcname, None)
        if not fn:
            msg = "The function %s is not defined in this script"
            raise exceptions.ScriptError(msg % funcname)
        return fn
|