add tests for plain API, fixed some small bugs

This commit is contained in:
iElectric
2009-07-08 22:03:00 +02:00
parent 286a912e34
commit 67af81806d
18 changed files with 318 additions and 109 deletions

View File

@@ -23,6 +23,7 @@
**Backward incompatible changes**: **Backward incompatible changes**:
- :func:`api.test` and schema comparison functions now all accept `url` as first parameter and `repository` as second.
- python upgrade/downgrade scripts do not import `migrate_engine` magically, but recieve engine as the only parameter to function (eg. ``def upgrade(migrate_engine):``) - python upgrade/downgrade scripts do not import `migrate_engine` magically, but recieve engine as the only parameter to function (eg. ``def upgrade(migrate_engine):``)
- :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>` does not accept `current_name` anymore, it extracts name from the old column. - :meth:`Column.alter <migrate.changeset.schema.ChangesetColumn.alter>` does not accept `current_name` anymore, it extracts name from the old column.

View File

@@ -1,12 +1,17 @@
""" """
This module provides an external API to the versioning system. This module provides an external API to the versioning system.
.. versionchanged:: 0.4.5 .. versionchanged:: 0.6.0
:func:`migrate.versioning.api.test` and schema diff functions \
changed order of positional arguments so all accept `url` and `repository`\
as first arguments.
.. versionchanged:: 0.5.4
``--preview_sql`` displays source file when using SQL scripts. ``--preview_sql`` displays source file when using SQL scripts.
If Python script is used, it runs the action with mocked engine and If Python script is used, it runs the action with mocked engine and
returns captured SQL statements. returns captured SQL statements.
.. versionchanged:: 0.4.5 .. versionchanged:: 0.5.4
Deprecated ``--echo`` parameter in favour of new Deprecated ``--echo`` parameter in favour of new
:func:`migrate.versioning.util.construct_engine` behavior. :func:`migrate.versioning.util.construct_engine` behavior.
""" """
@@ -74,6 +79,7 @@ def help(cmd=None, **opts):
ret = ret.replace('%prog', sys.argv[0]) ret = ret.replace('%prog', sys.argv[0])
return ret return ret
@catch_known_errors @catch_known_errors
def create(repository, name, **opts): def create(repository, name, **opts):
"""%prog create REPOSITORY_PATH NAME [--table=TABLE] """%prog create REPOSITORY_PATH NAME [--table=TABLE]
@@ -84,7 +90,7 @@ def create(repository, name, **opts):
'migrate_version'. This table is created in all version-controlled 'migrate_version'. This table is created in all version-controlled
databases. databases.
""" """
repo_path = Repository.create(repository, name, **opts) Repository.create(repository, name, **opts)
@catch_known_errors @catch_known_errors
@@ -192,8 +198,8 @@ def downgrade(url, repository, version, **opts):
"Try 'upgrade' instead." "Try 'upgrade' instead."
return _migrate(url, repository, version, upgrade=False, err=err, **opts) return _migrate(url, repository, version, upgrade=False, err=err, **opts)
def test(repository, url, **opts): def test(url, repository, **opts):
"""%prog test REPOSITORY_PATH URL [VERSION] """%prog test URL REPOSITORY_PATH [VERSION]
Performs the upgrade and downgrade option on the given Performs the upgrade and downgrade option on the given
database. This is not a real test and may leave the database in a database. This is not a real test and may leave the database in a
@@ -267,8 +273,8 @@ def manage(file, **opts):
return Repository.create_manage_file(file, **opts) return Repository.create_manage_file(file, **opts)
def compare_model_to_db(url, model, repository, **opts): def compare_model_to_db(url, repository, model, **opts):
"""%prog compare_model_to_db URL MODEL REPOSITORY_PATH """%prog compare_model_to_db URL REPOSITORY_PATH MODEL
Compare the current model (assumed to be a module level variable Compare the current model (assumed to be a module level variable
of type sqlalchemy.MetaData) against the current database. of type sqlalchemy.MetaData) against the current database.
@@ -276,7 +282,7 @@ def compare_model_to_db(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
print ControlledSchema.compare_model_to_db(engine, model, repository) return ControlledSchema.compare_model_to_db(engine, model, repository)
def create_model(url, repository, **opts): def create_model(url, repository, **opts):
@@ -288,13 +294,12 @@ def create_model(url, repository, **opts):
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
declarative = opts.get('declarative', False) declarative = opts.get('declarative', False)
print ControlledSchema.create_model(engine, repository, declarative) return ControlledSchema.create_model(engine, repository, declarative)
# TODO: get rid of this? if we don't add back path param
@catch_known_errors @catch_known_errors
def make_update_script_for_model(url, oldmodel, model, repository, **opts): def make_update_script_for_model(url, repository, oldmodel, model, **opts):
"""%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH """%prog make_update_script_for_model URL REPOSITORY_PATH OLDMODEL MODEL
Create a script changing the old Python model to the new (current) Create a script changing the old Python model to the new (current)
Python model, sending to stdout. Python model, sending to stdout.
@@ -302,12 +307,12 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts):
NOTE: This is EXPERIMENTAL. NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label """ # TODO: get rid of EXPERIMENTAL label
engine = construct_engine(url, **opts) engine = construct_engine(url, **opts)
print PythonScript.make_update_script_for_model( return PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts) engine, oldmodel, model, repository, **opts)
def update_db_from_model(url, model, repository, **opts): def update_db_from_model(url, repository, model, **opts):
"""%prog update_db_from_model URL MODEL REPOSITORY_PATH """%prog update_db_from_model URL REPOSITORY_PATH MODEL
Modify the database to match the structure of the current Python Modify the database to match the structure of the current Python
model. This also sets the db_version number to the latest in the model. This also sets the db_version number to the latest in the
@@ -337,15 +342,14 @@ def _migrate(url, repository, version, upgrade, err, **opts):
print change.source() print change.source()
elif opts.get('preview_py'): elif opts.get('preview_py'):
if not isinstance(change, PythonScript):
raise exceptions.UsageError("Python source can be only displayed"
" for python migration files")
source_ver = max(ver, nextver) source_ver = max(ver, nextver)
module = schema.repository.version(source_ver).script().module module = schema.repository.version(source_ver).script().module
funcname = upgrade and "upgrade" or "downgrade" funcname = upgrade and "upgrade" or "downgrade"
func = getattr(module, funcname) func = getattr(module, funcname)
if isinstance(change, PythonScript):
print inspect.getsource(func) print inspect.getsource(func)
else:
raise UsageError("Python source can be only displayed"
" for python migration files")
else: else:
schema.runchange(ver, change, changeset.step) schema.runchange(ver, change, changeset.step)
print 'done' print 'done'

View File

@@ -1,6 +1,8 @@
""" """
Database schema version management. Database schema version management.
""" """
import sys
from sqlalchemy import (Table, Column, MetaData, String, Text, Integer, from sqlalchemy import (Table, Column, MetaData, String, Text, Integer,
create_engine) create_engine)
from sqlalchemy.sql import and_ from sqlalchemy.sql import and_
@@ -32,22 +34,17 @@ class ControlledSchema(object):
def load(self): def load(self):
"""Load controlled schema version info from DB""" """Load controlled schema version info from DB"""
tname = self.repository.version_table tname = self.repository.version_table
if not hasattr(self, 'table') or self.table is None:
try: try:
if not hasattr(self, 'table') or self.table is None:
self.table = Table(tname, self.meta, autoload=True) self.table = Table(tname, self.meta, autoload=True)
except (sa_exceptions.NoSuchTableError,
AssertionError):
# assertionerror is raised if no table is found in oracle db
raise exceptions.DatabaseNotControlledError(tname)
# TODO?: verify that the table is correct (# cols, etc.)
result = self.engine.execute(self.table.select( result = self.engine.execute(self.table.select(
self.table.c.repository_id == str(self.repository.id))) self.table.c.repository_id == str(self.repository.id)))
try:
data = list(result)[0] data = list(result)[0]
except IndexError: except Exception:
raise exceptions.DatabaseNotControlledError(tname) cls, exc, tb = sys.exc_info()
raise exceptions.DatabaseNotControlledError, exc.message, tb
self.version = data['version'] self.version = data['version']
return data return data

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import warnings
import shutil import shutil
from StringIO import StringIO from StringIO import StringIO
@@ -11,6 +12,7 @@ from migrate.versioning.template import template
from migrate.versioning.script import base from migrate.versioning.script import base
from migrate.versioning.util import import_path, load_model, construct_engine from migrate.versioning.util import import_path, load_model, construct_engine
__all__ = ['PythonScript']
class PythonScript(base.BaseScript): class PythonScript(base.BaseScript):
"""Base for Python scripts""" """Base for Python scripts"""
@@ -88,16 +90,11 @@ class PythonScript(base.BaseScript):
:param path: Script location :param path: Script location
:type path: string :type path: string
:raises: :exc:`InvalidScriptError <migrate.versioning.exceptions.InvalidScriptError>` :raises: :exc:`InvalidScriptError <migrate.versioning.exceptions.InvalidScriptError>`
:returns: Python module :returns: Python module
""" """
# Try to import and get the upgrade() func # Try to import and get the upgrade() func
try:
module = import_path(path) module = import_path(path)
except:
# If the script itself has errors, that's not our problem
raise
try: try:
assert callable(module.upgrade) assert callable(module.upgrade)
except Exception, e: except Exception, e:
@@ -134,13 +131,15 @@ class PythonScript(base.BaseScript):
op = 'downgrade' op = 'downgrade'
else: else:
raise exceptions.ScriptError("%d is not a valid step" % step) raise exceptions.ScriptError("%d is not a valid step" % step)
funcname = base.operations[op]
func = self._func(funcname) funcname = base.operations[op]
script_func = self._func(funcname)
try: try:
func(engine) script_func(engine)
except TypeError: except TypeError:
print "upgrade/downgrade functions must accept engine parameter (since ver 0.5.5)" warnings.warn("upgrade/downgrade functions must accept engine"
" parameter (since version > 0.5.4)")
raise raise
@property @property
@@ -148,7 +147,7 @@ class PythonScript(base.BaseScript):
"""Calls :meth:`migrate.versioning.script.py.verify_module` """Calls :meth:`migrate.versioning.script.py.verify_module`
and returns it. and returns it.
""" """
if not hasattr(self, '_module'): if not getattr(self, '_module', None):
self._module = self.verify_module(self.path) self._module = self.verify_module(self.path)
return self._module return self._module

View File

@@ -7,8 +7,16 @@ from migrate.versioning.script import base
class SqlScript(base.BaseScript): class SqlScript(base.BaseScript):
"""A file containing plain SQL statements.""" """A file containing plain SQL statements."""
@classmethod
def create(cls, path, **opts):
"""Create an empty migration script at specified path
:returns: :class:`SqlScript instance <migrate.versioning.script.sql.SqlScript>`"""
cls.require_notfound(path)
open(path, "w").close()
# TODO: why is step parameter even here? # TODO: why is step parameter even here?
def run(self, engine, step=None): def run(self, engine, step=None, executemany=True):
"""Runs SQL script through raw dbapi execute call""" """Runs SQL script through raw dbapi execute call"""
text = self.source() text = self.source()
# Don't rely on SA's autocommit here # Don't rely on SA's autocommit here
@@ -21,7 +29,7 @@ class SqlScript(base.BaseScript):
# HACK: SQLite doesn't allow multiple statements through # HACK: SQLite doesn't allow multiple statements through
# its execute() method, but it provides executescript() instead # its execute() method, but it provides executescript() instead
dbapi = conn.engine.raw_connection() dbapi = conn.engine.raw_connection()
if getattr(dbapi, 'executescript', None): if executemany and getattr(dbapi, 'executescript', None):
dbapi.executescript(text) dbapi.executescript(text)
else: else:
conn.execute(text) conn.execute(text)

View File

@@ -69,7 +69,6 @@ def main(argv=None, **kwargs):
parser = PassiveOptionParser(usage=usage) parser = PassiveOptionParser(usage=usage)
parser.add_option("-v", "--verbose", action="store_true", dest="verbose") parser.add_option("-v", "--verbose", action="store_true", dest="verbose")
parser.add_option("-d", "--debug", action="store_true", dest="debug") parser.add_option("-d", "--debug", action="store_true", dest="debug")
parser.add_option("-f", "--force", action="store_true", dest="force")
help_commands = ['help', '-h', '--help'] help_commands = ['help', '-h', '--help']
HELP = False HELP = False
@@ -156,8 +155,6 @@ def main(argv=None, **kwargs):
if ret is not None: if ret is not None:
print ret print ret
except (exceptions.UsageError, exceptions.KnownError), e: except (exceptions.UsageError, exceptions.KnownError), e:
if e.args[0] is None:
parser.print_help()
parser.error(e.args[0]) parser.error(e.args[0])
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -81,7 +81,7 @@ def catch_known_errors(f, *a, **kw):
""" """
try: try:
f(*a, **kw) return f(*a, **kw)
except exceptions.PathFoundError, e: except exceptions.PathFoundError, e:
raise exceptions.KnownError("The path %s already exists" % e.args[0]) raise exceptions.KnownError("The path %s already exists" % e.args[0])

View File

@@ -98,11 +98,7 @@ class Collection(pathed.Pathed):
filename = '%03d%s.py' % (ver, extra) filename = '%03d%s.py' % (ver, extra)
filepath = self._version_path(filename) filepath = self._version_path(filename)
if os.path.exists(filepath):
raise Exception('Script already exists: %s' % filepath)
else:
script.PythonScript.create(filepath) script.PythonScript.create(filepath)
self.versions[ver] = Version(ver, self.path, [filename]) self.versions[ver] = Version(ver, self.path, [filename])
def create_new_sql_version(self, database, **k): def create_new_sql_version(self, database, **k):
@@ -114,10 +110,7 @@ class Collection(pathed.Pathed):
for op in ('upgrade', 'downgrade'): for op in ('upgrade', 'downgrade'):
filename = '%03d_%s_%s.sql' % (ver, database, op) filename = '%03d_%s_%s.sql' % (ver, database, op)
filepath = self._version_path(filename) filepath = self._version_path(filename)
if os.path.exists(filepath): script.SqlScript.create(filepath)
raise Exception('Script already exists: %s' % filepath)
else:
open(filepath, "w").close()
self.versions[ver].add_script(filepath) self.versions[ver].add_script(filepath)
def version(self, vernum=None): def version(self, vernum=None):
@@ -137,7 +130,14 @@ class Collection(pathed.Pathed):
class Version(object): class Version(object):
"""A single version in a collection """ """A single version in a collection
:param vernum: Version Number
:param path: Path to script files
:param filelist: List of scripts
:type vernum: int, VerNum
:type path: string
:type filelist: list
"""
def __init__(self, vernum, path, filelist): def __init__(self, vernum, path, filelist):
self.version = VerNum(vernum) self.version = VerNum(vernum)
@@ -165,22 +165,6 @@ class Version(object):
"There is no script for %d version" % self.version "There is no script for %d version" % self.version
return ret return ret
# deprecated?
@classmethod
def create(cls, path):
os.mkdir(path)
# create the version as a proper Python package
initfile = os.path.join(path, "__init__.py")
if not os.path.exists(initfile):
# just touch the file
open(initfile, "w").close()
try:
ret = cls(path)
except:
os.rmdir(path)
raise
return ret
def add_script(self, path): def add_script(self, path):
"""Add script to Collection/Version""" """Add script to Collection/Version"""
if path.endswith(Extensions.py): if path.endswith(Extensions.py):
@@ -203,10 +187,11 @@ class Version(object):
def _add_script_py(self, path): def _add_script_py(self, path):
if self.python is not None: if self.python is not None:
raise Exception('You can only have one Python script per version,' raise exceptions.ScriptError('You can only have one Python script '
' but you have: %s and %s' % (self.python, path)) 'per version, but you have: %s and %s' % (self.python, path))
self.python = script.PythonScript(path) self.python = script.PythonScript(path)
class Extensions: class Extensions:
"""A namespace for file extensions""" """A namespace for file extensions"""
py = 'py' py = 'py'

View File

@@ -7,8 +7,8 @@ tag_svn_revision = 1
tag_build = .dev tag_build = .dev
[nosetests] [nosetests]
pdb = true #pdb = true
pdb-failures = true #pdb-failures = true
#stop = true #stop = true
[aliases] [aliases]

View File

@@ -3,9 +3,12 @@ from sqlalchemy import *
# test rundiffs in shell # test rundiffs in shell
meta_old_rundiffs = MetaData() meta_old_rundiffs = MetaData()
meta_rundiffs = MetaData() meta_rundiffs = MetaData()
meta = MetaData()
tmp_account_rundiffs = Table('tmp_account_rundiffs', meta_rundiffs, tmp_account_rundiffs = Table('tmp_account_rundiffs', meta_rundiffs,
Column('id', Integer, primary_key=True), Column('id', Integer, primary_key=True),
Column('login', String(40)), Column('login', String(40)),
Column('passwd', String(40)), Column('passwd', String(40)),
) )
tmp_sql_table = Table('tmp_sql_table', meta, Column('id', Integer))

View File

@@ -5,6 +5,8 @@ from migrate.versioning import api
from migrate.versioning.exceptions import * from migrate.versioning.exceptions import *
from test.fixture.pathed import * from test.fixture.pathed import *
from test.fixture import models
from test import fixture
class TestAPI(Pathed): class TestAPI(Pathed):
@@ -15,14 +17,104 @@ class TestAPI(Pathed):
self.assertRaises(UsageError, api.help, 'foobar') self.assertRaises(UsageError, api.help, 'foobar')
self.assert_(isinstance(api.help('create'), str)) self.assert_(isinstance(api.help('create'), str))
def test_help_commands(self): # test that all commands return some text
pass for cmd in api.__all__:
content = api.help(cmd)
self.assertTrue(content)
def test_create(self): def test_create(self):
pass tmprepo = self.tmp_repos()
api.create(tmprepo, 'temp')
# repository already exists
self.assertRaises(KnownError, api.create, tmprepo, 'temp')
def test_script(self): def test_script(self):
pass repo = self.tmp_repos()
api.create(repo, 'temp')
api.script('first version', repo)
def test_script_sql(self): def test_script_sql(self):
pass repo = self.tmp_repos()
api.create(repo, 'temp')
api.script_sql('postgres', repo)
def test_version(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.version(repo)
def test_source(self):
repo = self.tmp_repos()
api.create(repo, 'temp')
api.script('first version', repo)
api.script_sql('default', repo)
# no repository
self.assertRaises(UsageError, api.source, 1)
# stdout
out = api.source(1, dest=None, repository=repo)
self.assertTrue(out)
# file
out = api.source(1, dest=self.tmp_repos(), repository=repo)
self.assertFalse(out)
def test_manage(self):
output = api.manage(os.path.join(self.temp_usable_dir, 'manage.py'))
class TestSchemaAPI(fixture.DB, Pathed):
def _setup(self, url):
super(TestSchemaAPI, self)._setup(url)
self.repo = self.tmp_repos()
api.create(self.repo, 'temp')
self.schema = api.version_control(url, self.repo)
def _teardown(self):
self.schema = api.drop_version_control(self.url, self.repo)
super(TestSchemaAPI, self)._teardown()
@fixture.usedb()
def test_workflow(self):
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.script('First Version', self.repo)
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.upgrade(self.url, self.repo, 1)
self.assertEqual(api.db_version(self.url, self.repo), 1)
api.downgrade(self.url, self.repo, 0)
self.assertEqual(api.db_version(self.url, self.repo), 0)
api.test(self.url, self.repo)
self.assertEqual(api.db_version(self.url, self.repo), 0)
# preview
# TODO: test output
out = api.upgrade(self.url, self.repo, preview_py=True)
out = api.upgrade(self.url, self.repo, preview_sql=True)
api.upgrade(self.url, self.repo, 1)
api.script_sql('default', self.repo)
self.assertRaises(UsageError, api.upgrade, self.url, self.repo, 2, preview_py=True)
out = api.upgrade(self.url, self.repo, 2, preview_sql=True)
# cant upgrade to version 1, already at version 1
self.assertEqual(api.db_version(self.url, self.repo), 1)
self.assertRaises(KnownError, api.upgrade, self.url, self.repo, 0)
@fixture.usedb()
def test_compare_model_to_db(self):
diff = api.compare_model_to_db(self.url, self.repo, models.meta)
@fixture.usedb()
def test_create_model(self):
model = api.create_model(self.url, self.repo)
@fixture.usedb()
def test_make_update_script_for_model(self):
model = api.make_update_script_for_model(self.url, self.repo, models.meta_old_rundiffs, models.meta_rundiffs)
@fixture.usedb()
def test_update_db_from_model(self):
model = api.update_db_from_model(self.url, self.repo, models.meta_rundiffs)

View File

@@ -1,3 +1,6 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-
from test import fixture from test import fixture
from migrate.versioning.util.keyedinstance import * from migrate.versioning.util.keyedinstance import *
@@ -38,3 +41,5 @@ class TestKeydInstance(fixture.Base):
Uniq1.clear() Uniq1.clear()
a12 = Uniq1('a') a12 = Uniq1('a')
self.assert_(a10 is not a12) self.assert_(a10 is not a12)
self.assertRaises(NotImplementedError, KeyedInstance._key)

View File

@@ -101,13 +101,15 @@ class TestVersionedRepository(fixture.Pathed):
# Load repository and commit script # Load repository and commit script
repo = Repository(self.path_repos) repo = Repository(self.path_repos)
repo.create_script('') repo.create_script('')
repo.create_script_sql('postgres')
# Get script object
source = repo.version(1).script().source()
# Source is valid: script must have an upgrade function # Source is valid: script must have an upgrade function
# (not a very thorough test, but should be plenty) # (not a very thorough test, but should be plenty)
self.assert_(source.find('def upgrade') >= 0) source = repo.version(1).script().source()
self.assertTrue(source.find('def upgrade') >= 0)
source = repo.version(2).script('postgres', 'upgrade').source()
self.assertEqual(source.strip(), '')
def test_latestversion(self): def test_latestversion(self):
"""Repository.version() (no params) returns the latest version""" """Repository.version() (no params) returns the latest version"""

View File

@@ -16,11 +16,10 @@ class TestControlledSchema(fixture.Pathed, fixture.DB):
# Transactions break postgres in this test; we'll clean up after ourselves # Transactions break postgres in this test; we'll clean up after ourselves
level = fixture.DB.CONNECT level = fixture.DB.CONNECT
def setUp(self): def setUp(self):
super(TestControlledSchema, self).setUp() super(TestControlledSchema, self).setUp()
path_repos = self.temp_usable_dir + '/repo/' self.path_repos = self.temp_usable_dir + '/repo/'
self.repos = Repository.create(path_repos, 'repo_name') self.repos = Repository.create(self.path_repos, 'repo_name')
def _setup(self, url): def _setup(self, url):
self.setUp() self.setUp()
@@ -44,6 +43,19 @@ class TestControlledSchema(fixture.Pathed, fixture.DB):
self.cleanup() self.cleanup()
super(TestControlledSchema, self).tearDown() super(TestControlledSchema, self).tearDown()
@fixture.usedb()
def test_schema_table_fail(self):
"""Test scenarios when loading schema should fail"""
dbcontrol = ControlledSchema.create(self.engine, self.path_repos)
dbcontrol.table.drop()
try:
dbcontrol.load()
except exceptions.DatabaseNotControlledError:
pass
else:
self.fail()
@fixture.usedb() @fixture.usedb()
def test_version_control(self): def test_version_control(self):
"""Establish version control on a particular database""" """Establish version control on a particular database"""
@@ -116,7 +128,7 @@ class TestControlledSchema(fixture.Pathed, fixture.DB):
#self.assertRaises(ControlledSchema.InvalidVersionError, #self.assertRaises(ControlledSchema.InvalidVersionError,
# Can't have custom errors with assertRaises... # Can't have custom errors with assertRaises...
try: try:
ControlledSchema.create(self.engine,self.repos,version) ControlledSchema.create(self.engine, self.repos,version)
self.assert_(False, repr(version)) self.assert_(False, repr(version))
except exceptions.InvalidVersionError: except exceptions.InvalidVersionError:
pass pass

View File

@@ -10,6 +10,7 @@ from migrate.versioning.script import *
from migrate.versioning.util import * from migrate.versioning.util import *
from test import fixture from test import fixture
from test.fixture.models import tmp_sql_table
class TestBaseScript(fixture.Pathed): class TestBaseScript(fixture.Pathed):
@@ -48,6 +49,25 @@ class TestPyScript(fixture.Pathed, fixture.DB):
self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0) self.assertRaises(exceptions.ScriptError, pyscript.run, self.engine, 0)
self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar') self.assertRaises(exceptions.ScriptError, pyscript._func, 'foobar')
# clean pyc file
os.remove(script_path + 'c')
# test deprecated upgrade/downgrade with no arguments
contents = open(script_path, 'r').read()
f = open(script_path, 'w')
f.write(contents.replace("upgrade(migrate_engine)", "upgrade()"))
f.close()
pyscript = PythonScript(script_path)
pyscript._module = None
try:
pyscript.run(self.engine, 1)
pyscript.run(self.engine, -1)
except TypeError:
pass
else:
self.fail()
def test_verify_notfound(self): def test_verify_notfound(self):
"""Correctly verify a python migration script: nonexistant file""" """Correctly verify a python migration script: nonexistant file"""
path = self.tmp_py() path = self.tmp_py()
@@ -86,7 +106,7 @@ class TestPyScript(fixture.Pathed, fixture.DB):
path = self.tmp_py() path = self.tmp_py()
f = open(path, 'w') f = open(path, 'w')
content = """ content = '''
from migrate import * from migrate import *
from sqlalchemy import * from sqlalchemy import *
@@ -99,7 +119,7 @@ UserGroup = Table('Link', metadata,
def upgrade(migrate_engine): def upgrade(migrate_engine):
metadata.create_all(migrate_engine) metadata.create_all(migrate_engine)
""" '''
f.write(content) f.write(content)
f.close() f.close()
@@ -130,7 +150,6 @@ def upgrade(migrate_engine):
self.write_file(self.first_model_path, self.base_source) self.write_file(self.first_model_path, self.base_source)
self.write_file(self.second_model_path, self.base_source + self.model_source) self.write_file(self.second_model_path, self.base_source + self.model_source)
source_script = self.pyscript.make_update_script_for_model( source_script = self.pyscript.make_update_script_for_model(
engine=self.engine, engine=self.engine,
oldmodel=load_model('testmodel_first:meta'), oldmodel=load_model('testmodel_first:meta'),
@@ -195,3 +214,31 @@ class TestSqlScript(fixture.Pathed, fixture.DB):
sqls = SqlScript(src) sqls = SqlScript(src)
self.assertRaises(Exception, sqls.run, self.engine) self.assertRaises(Exception, sqls.run, self.engine)
@fixture.usedb()
def test_success(self):
"""Test sucessful SQL execution"""
# cleanup and prepare python script
tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)
script_path = self.tmp_py()
pyscript = PythonScript.create(script_path)
# populate python script
contents = open(script_path, 'r').read()
contents = contents.replace("pass", "tmp_sql_table.create(migrate_engine)")
contents = 'from test.fixture.models import tmp_sql_table\n' + contents
f = open(script_path, 'w')
f.write(contents)
f.close()
# write SQL script from python script preview
pyscript = PythonScript(script_path)
src = self.tmp()
f = open(src, 'w')
f.write(pyscript.preview_sql(self.url, 1))
f.close()
# run the change
sqls = SqlScript(src)
sqls.run(self.engine, executemany=False)
tmp_sql_table.metadata.drop_all(self.engine, checkfirst=True)

View File

@@ -3,13 +3,14 @@
import os import os
import tempfile import tempfile
from runpy import run_module
from sqlalchemy import MetaData, Table from sqlalchemy import MetaData, Table
from migrate.versioning import genmodel, shell, api
from migrate.versioning.repository import Repository from migrate.versioning.repository import Repository
from migrate.versioning.exceptions import * from migrate.versioning.exceptions import *
from test.fixture import * from test.fixture import *
from migrate.versioning import genmodel, shell, api
class TestShellCommands(Shell): class TestShellCommands(Shell):
@@ -30,6 +31,36 @@ class TestShellCommands(Shell):
self.assertTrue(result.stdout) self.assertTrue(result.stdout)
self.assertFalse(result.stderr) self.assertFalse(result.stderr)
def test_main(self):
"""Test main() function"""
# TODO: test output?
try:
run_module('migrate.versioning.shell', run_name='__main__')
except:
pass
repos = self.tmp_repos()
shell.main(['help'])
shell.main(['help', 'create'])
shell.main(['create', 'repo_name', '--preview_sql'], repository=repos)
shell.main(['version', '--', '--repository=%s' % repos])
shell.main(['version', '-d', '--repository=%s' % repos, '--version=2'])
try:
shell.main(['foobar'])
except SystemExit, e:
pass
try:
shell.main(['create', 'f', 'o', 'o'])
except SystemExit, e:
pass
try:
shell.main(['create'])
except SystemExit, e:
pass
try:
shell.main(['create', 'repo_name'], repository=repos)
except SystemExit, e:
pass
def test_create(self): def test_create(self):
"""Repositories are created successfully""" """Repositories are created successfully"""
repos = self.tmp_repos() repos = self.tmp_repos()
@@ -333,7 +364,7 @@ class TestShellDatabase(Shell, DB):
# Empty script should succeed # Empty script should succeed
result = self.env.run('migrate script Desc %s' % repos_path) result = self.env.run('migrate script Desc %s' % repos_path)
result = self.env.run('migrate test %s %s' % (repos_path, self.url)) result = self.env.run('migrate test %s %s' % (self.url, repos_path))
self.assertEquals(self.run_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.run_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
@@ -355,7 +386,7 @@ class TestShellDatabase(Shell, DB):
file.write(script_text) file.write(script_text)
file.close() file.close()
result = self.env.run('migrate test %s %s bla' % (repos_path, self.url), expect_error=True) result = self.env.run('migrate test %s %s bla' % (self.url, repos_path), expect_error=True)
self.assertEqual(result.returncode, 2) self.assertEqual(result.returncode, 2)
self.assertEquals(self.run_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.run_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
@@ -384,7 +415,7 @@ class TestShellDatabase(Shell, DB):
file = open(script_path, 'w') file = open(script_path, 'w')
file.write(script_text) file.write(script_text)
file.close() file.close()
result = self.env.run('migrate test %s %s' % (repos_path, self.url)) result = self.env.run('migrate test %s %s' % (self.url, repos_path))
self.assertEquals(self.run_version(repos_path), 1) self.assertEquals(self.run_version(repos_path), 1)
self.assertEquals(self.run_db_version(self.url, repos_path), 0) self.assertEquals(self.run_db_version(self.url, repos_path), 0)
@@ -426,12 +457,12 @@ class TestShellDatabase(Shell, DB):
# Update db to latest model. # Update db to latest model.
result = self.env.run('migrate update_db_from_model %s %s %s'\ result = self.env.run('migrate update_db_from_model %s %s %s'\
% (self.url, model_module, repos_path)) % (self.url, repos_path, model_module))
self.assertEquals(self.run_version(repos_path), 0) self.assertEquals(self.run_version(repos_path), 0)
self.assertEquals(self.run_db_version(self.url, repos_path), 0) # version did not get bumped yet because new version not yet created self.assertEquals(self.run_db_version(self.url, repos_path), 0) # version did not get bumped yet because new version not yet created
result = self.env.run('migrate compare_model_to_db %s %s %s'\ result = self.env.run('migrate compare_model_to_db %s %s %s'\
% (self.url, model_module, repos_path)) % (self.url, repos_path, model_module))
self.assert_("No schema diffs" in result.stdout) self.assert_("No schema diffs" in result.stdout)
result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True) result = self.env.run('migrate drop_version_control %s %s' % (self.url, repos_path), expect_error=True)
@@ -448,9 +479,9 @@ class TestShellDatabase(Shell, DB):
self.assertTrue('Not enough arguments' in result.stderr) self.assertTrue('Not enough arguments' in result.stderr)
result_script = self.env.run('migrate make_update_script_for_model %s %s %s %s'\ result_script = self.env.run('migrate make_update_script_for_model %s %s %s %s'\
% (self.url, old_model_module, model_module, repos_path)) % (self.url, repos_path, old_model_module, model_module))
self.assertEqualsIgnoreWhitespace(result_script.stdout, self.assertEqualsIgnoreWhitespace(result_script.stdout,
"""from sqlalchemy import * '''from sqlalchemy import *
from migrate import * from migrate import *
meta = MetaData() meta = MetaData()
@@ -469,7 +500,7 @@ class TestShellDatabase(Shell, DB):
def downgrade(migrate_engine): def downgrade(migrate_engine):
# Operations to reverse the above upgrade go here. # Operations to reverse the above upgrade go here.
meta.bind = migrate_engine meta.bind = migrate_engine
tmp_account_rundiffs.drop()""") tmp_account_rundiffs.drop()''')
# Save the upgrade script. # Save the upgrade script.
result = self.env.run('migrate script Desc %s' % repos_path) result = self.env.run('migrate script Desc %s' % repos_path)
@@ -477,7 +508,7 @@ class TestShellDatabase(Shell, DB):
open(upgrade_script_path, 'w').write(result_script.stdout) open(upgrade_script_path, 'w').write(result_script.stdout)
result = self.env.run('migrate compare_model_to_db %s %s %s'\ result = self.env.run('migrate compare_model_to_db %s %s %s'\
% (self.url, model_module, repos_path)) % (self.url, repos_path, model_module))
self.assert_("No schema diffs" in result.stdout) self.assert_("No schema diffs" in result.stdout)
self.meta.drop_all() # in case junk tables are lying around in the test database self.meta.drop_all() # in case junk tables are lying around in the test database

View File

@@ -36,10 +36,13 @@ class TestUtil(fixture.Pathed):
engine_arg_assert_unicode=True) engine_arg_assert_unicode=True)
self.assertTrue(engine.dialect.assert_unicode) self.assertTrue(engine.dialect.assert_unicode)
# deprecated echo= parameter # deprecated echo=True parameter
engine = construct_engine(url, echo='True') engine = construct_engine(url, echo='True')
self.assertTrue(engine.echo) self.assertTrue(engine.echo)
# unsupported argument
self.assertRaises(ValueError, construct_engine, 1)
def test_asbool(self): def test_asbool(self):
"""test asbool parsing""" """test asbool parsing"""
result = asbool(True) result = asbool(True)

View File

@@ -3,6 +3,7 @@
from test import fixture from test import fixture
from migrate.versioning.version import * from migrate.versioning.version import *
from migrate.versioning.exceptions import *
class TestVerNum(fixture.Base): class TestVerNum(fixture.Base):
@@ -12,6 +13,11 @@ class TestVerNum(fixture.Base):
for version in versions: for version in versions:
self.assertRaises(ValueError, VerNum, version) self.assertRaises(ValueError, VerNum, version)
def test_str(self):
"""Test str and repr version numbers"""
self.assertEqual(str(VerNum(2)), '2')
self.assertEqual(repr(VerNum(2)), '<VerNum(2)>')
def test_is(self): def test_is(self):
"""Two version with the same number should be equal""" """Two version with the same number should be equal"""
a = VerNum(1) a = VerNum(1)
@@ -62,12 +68,14 @@ class TestVerNum(fixture.Base):
self.assert_(VerNum(2) >= 1) self.assert_(VerNum(2) >= 1)
self.assertFalse(VerNum(1) >= 2) self.assertFalse(VerNum(1) >= 2)
class TestVersion(fixture.Pathed): class TestVersion(fixture.Pathed):
def setUp(self): def setUp(self):
super(TestVersion, self).setUp() super(TestVersion, self).setUp()
def test_str_to_filename(self): def test_str_to_filename(self):
self.assertEquals(str_to_filename(''), '')
self.assertEquals(str_to_filename(''), '') self.assertEquals(str_to_filename(''), '')
self.assertEquals(str_to_filename('__'), '_') self.assertEquals(str_to_filename('__'), '_')
self.assertEquals(str_to_filename('a'), 'a') self.assertEquals(str_to_filename('a'), 'a')
@@ -91,12 +99,18 @@ class TestVersion(fixture.Pathed):
coll2 = Collection(self.temp_usable_dir) coll2 = Collection(self.temp_usable_dir)
self.assertEqual(coll.versions, coll2.versions) self.assertEqual(coll.versions, coll2.versions)
Collection.clear()
def test_old_repository(self):
open(os.path.join(self.temp_usable_dir, '1'), 'w')
self.assertRaises(Exception, Collection, self.temp_usable_dir)
#def test_collection_unicode(self): #def test_collection_unicode(self):
# pass # pass
def test_create_new_python_version(self): def test_create_new_python_version(self):
coll = Collection(self.temp_usable_dir) coll = Collection(self.temp_usable_dir)
coll.create_new_python_version("foo bar") coll.create_new_python_version("'")
ver = coll.version() ver = coll.version()
self.assert_(ver.script().source()) self.assert_(ver.script().source())
@@ -140,3 +154,12 @@ class TestVersion(fixture.Pathed):
ver = Version(1, path, [sqlite_upgrade_file, python_file]) ver = Version(1, path, [sqlite_upgrade_file, python_file])
self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file) self.assertEquals(os.path.basename(ver.script('postgres', 'upgrade').path), python_file)
def test_bad_version(self):
ver = Version(1, self.temp_usable_dir, [])
self.assertRaises(ScriptError, ver.add_script, '123.sql')
pyscript = os.path.join(self.temp_usable_dir, 'bla.py')
open(pyscript, 'w')
ver.add_script(pyscript)
self.assertRaises(ScriptError, ver.add_script, 'bla.py')