more tests and now its sort of working

This commit is contained in:
Mike Bayer
2011-11-27 17:43:31 -05:00
parent 47e6fcedb4
commit 1cea038080
11 changed files with 205 additions and 69 deletions

View File

@@ -5,6 +5,7 @@ from alembic.context import _context_opts, get_bind
from alembic import util
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy import types as sqltypes, schema
import re
###################################################
# top level
@@ -19,8 +20,8 @@ def produce_migration_diffs(template_args):
connection = get_bind()
diffs = []
_produce_net_changes(connection, metadata, diffs)
_set_upgrade(template_args, _produce_upgrade_commands(diffs))
_set_downgrade(template_args, _produce_downgrade_commands(diffs))
_set_upgrade(template_args, _indent(_produce_upgrade_commands(diffs)))
_set_downgrade(template_args, _indent(_produce_downgrade_commands(diffs)))
def _set_upgrade(template_args, text):
template_args[_context_opts['upgrade_token']] = text
@@ -28,6 +29,12 @@ def _set_upgrade(template_args, text):
def _set_downgrade(template_args, text):
template_args[_context_opts['downgrade_token']] = text
def _indent(text):
text = "### commands auto generated by Alembic - please adjust! ###\n" + text
text += "\n### end Alembic commands ###"
text = re.compile(r'^', re.M).sub(" ", text).strip()
return text
###################################################
# walk structures
@@ -85,7 +92,12 @@ def _compare_columns(tname, conn_table, metadata_table, diffs):
for cname in metadata_col_names.difference(conn_col_names)
)
diffs.extend(
("remove_column", tname, cname)
("remove_column", tname, schema.Column(
cname,
conn_table[cname]['type'],
nullable=conn_table[cname]['nullable'],
server_default=conn_table[cname]['default']
))
for cname in conn_col_names.difference(metadata_col_names)
)
@@ -145,28 +157,49 @@ _type_comparators = {
}
###################################################
# render python
# produce command structure
def _produce_upgrade_commands(diffs):
buf = []
for diff in diffs:
cmd = _commands[diff[0]]
buf.append(cmd(*diff[1:]))
buf.append(_invoke_command("upgrade", diff))
return "\n".join(buf)
def _produce_downgrade_commands(diffs):
buf = []
for diff in diffs:
cmd = _commands[diff[0]]
buf.append(cmd(*diff[1:]))
buf.append(_invoke_command("downgrade", diff))
return "\n".join(buf)
def _invoke_command(updown, args):
cmd_type = args[0]
adddrop, cmd_type = cmd_type.split("_")
cmd_args = args[1:]
cmd_callables = _commands[cmd_type]
if len(cmd_callables) == 2:
if (
updown == "upgrade" and adddrop == "add"
) or (
updown == "downgrade" and adddrop == "remove"
):
return cmd_callables[1](*cmd_args)
else:
return cmd_callables[0](*cmd_args)
else:
if updown == "upgrade":
return cmd_callables[0](
cmd_args[0], cmd_args[1], cmd_args[3])
else:
return cmd_callables[0](
cmd_args[0], cmd_args[1], cmd_args[2])
###################################################
# render python
def _add_table(table):
return \
"""create_table(%(tablename)r,
%(args)s
)
""" % {
return "create_table(%(tablename)r,\n%(args)s\n)" % {
'tablename':table.name,
'args':',\n'.join(
[_render_column(col) for col in table.c] +
@@ -178,16 +211,16 @@ def _add_table(table):
),
}
def _drop_table(tname):
return "drop_table(%r)" % tname
def _drop_table(table):
return "drop_table(%r)" % table.name
def _add_column(tname, column):
return "add_column(%r, %s)" % (
tname,
_render_column(column))
def _drop_column(tname, cname):
return "drop_column(%r, %r)" % (tname, cname)
def _drop_column(tname, column):
return "drop_column(%r, %r)" % (tname, column.name)
def _modify_type(tname, cname, type_):
return "alter_column(%r, %r, type=%r)" % (
@@ -200,22 +233,37 @@ def _modify_nullable(tname, cname, nullable):
)
_commands = {
"table":(_drop_table, _add_table),
"column":(_drop_column, _add_column),
"type":(_modify_type,),
"nullable":(_modify_nullable,),
}
def _autogenerate_prefix():
return _context_opts['autogenerate_sqlalchemy_prefix']
def _render_column(column):
opts = []
if column.server_default:
opts.append(("server_default", column.server_default))
opts.append(("server_default", _render_server_default(column.server_default)))
if column.nullable is not None:
opts.append(("nullable", column.nullable))
# TODO: for non-ascii colname, assign a "key"
return "Column(%(name)r, %(type)r, %(kw)s)" % {
return "%(prefix)sColumn(%(name)r, %(prefix)s%(type)r, %(kw)s)" % {
'prefix':_autogenerate_prefix(),
'name':column.name,
'type':column.type,
'kw':", ".join(["%s=%s" % (kwname, val) for kwname, val in opts])
}
def _render_server_default(default):
assert isinstance(default, schema.DefaultClause)
return "%(prefix)sDefaultClause(%(arg)r)" % {
'prefix':_autogenerate_prefix(),
'arg':str(default.arg)
}
def _render_constraint(constraint):
renderer = _constraint_renderers.get(type(constraint), None)
if renderer:
@@ -226,10 +274,11 @@ def _render_constraint(constraint):
def _render_primary_key(constraint):
opts = []
if constraint.name:
opts.append(("name", constraint.name))
return "PrimaryKeyConstraint(%(args)s)" % {
opts.append(("name", repr(constraint.name)))
return "%(prefix)sPrimaryKeyConstraint(%(args)s)" % {
"prefix":_autogenerate_prefix(),
"args":", ".join(
[c.key for c in constraint.columns] +
[repr(c.key) for c in constraint.columns] +
["%s=%s" % (kwname, val) for kwname, val in opts]
),
}
@@ -237,9 +286,10 @@ def _render_primary_key(constraint):
def _render_foreign_key(constraint):
opts = []
if constraint.name:
opts.append(("name", constraint.name))
opts.append(("name", repr(constraint.name)))
# TODO: deferrable, initially, etc.
return "ForeignKeyConstraint([%(cols)s], [%(refcols)s], %(args)s)" % {
return "%(prefix)sForeignKeyConstraint([%(cols)s], [%(refcols)s], %(args)s)" % {
"prefix":_autogenerate_prefix(),
"cols":", ".join(f.parent.key for f in constraint.elements),
"refcols":", ".join(repr(f._get_colspec()) for f in constraint.elements),
"args":", ".join(
@@ -250,8 +300,10 @@ def _render_foreign_key(constraint):
def _render_check_constraint(constraint):
opts = []
if constraint.name:
opts.append(("name", constraint.name))
return "CheckConstraint('TODO')"
opts.append(("name", repr(constraint.name)))
return "%(prefix)sCheckConstraint('TODO')" % {
"prefix":_autogenerate_prefix()
}
_constraint_renderers = {
schema.PrimaryKeyConstraint:_render_primary_key,

View File

@@ -149,7 +149,7 @@ _context = None
_script = None
def _opts(cfg, script, **kw):
"""Set up options that will be used by the :func:`.configure_connection`
"""Set up options that will be used by the :func:`.configure`
function.
This basically sets some global variables.
@@ -263,7 +263,8 @@ def configure(
tag=None,
autogenerate_metadata=None,
upgrade_token="upgrades",
downgrade_token="downgrades"
downgrade_token="downgrades",
autogenerate_sqlalchemy_prefix="sa.",
):
"""Configure the migration environment.
@@ -311,6 +312,10 @@ def configure(
:param downgrade_token: when running "alembic revision" with the ``--autogenerate``
option, the text of the candidate downgrade operations will be present in this
template variable when script.py.mako is rendered.
:param autogenerate_sqlalchemy_prefix: When autogenerate refers to SQLAlchemy
:class:`~sqlalchemy.schema.Column` or type classes, this prefix will be used
(i.e. ``sa.Column("somename", sa.Integer)``)
"""
if connection:
@@ -339,6 +344,7 @@ def configure(
opts['autogenerate_metadata'] = autogenerate_metadata
opts['upgrade_token'] = upgrade_token
opts['downgrade_token'] = downgrade_token
opts['autogenerate_sqlalchemy_prefix'] = autogenerate_sqlalchemy_prefix
_context = Context(
dialect, _script, connection,
opts['fn'],

View File

@@ -180,7 +180,7 @@ class ScriptDirectory(object):
shutil.copy,
src, dest)
def generate_rev(self, revid, message, **kw):
def generate_rev(self, revid, message, refresh=False, **kw):
current_head = self._current_head()
path = self._rev_path(revid)
self.generate_template(
@@ -192,12 +192,16 @@ class ScriptDirectory(object):
message=message if message is not None else ("empty message"),
**kw
)
script = Script.from_path(path)
self._revision_map[script.revision] = script
if script.down_revision:
self._revision_map[script.down_revision].\
add_nextrev(script.revision)
return script
if refresh:
script = Script.from_path(path)
self._revision_map[script.revision] = script
if script.down_revision:
self._revision_map[script.down_revision].\
add_nextrev(script.revision)
return script
else:
return revid
class Script(object):
"""Represent a single revision file in a ``versions/`` directory."""

View File

@@ -10,6 +10,7 @@ Create Date: ${create_date}
down_revision = ${repr(down_revision)}
from alembic.op import *
import sqlalchemy as sa
def upgrade():
${upgrades if upgrades else "pass"}

View File

@@ -10,6 +10,7 @@ Create Date: ${create_date}
down_revision = ${repr(down_revision)}
from alembic.op import *
import sqlalchemy as sa
def upgrade(engine):
eval("upgrade_%s" % engine.name)()

View File

@@ -10,6 +10,7 @@ Create Date: ${create_date}
down_revision = ${repr(down_revision)}
from alembic.op import *
import sqlalchemy as sa
def upgrade():
${upgrades if upgrades else "pass"}

View File

@@ -2,7 +2,7 @@ from sqlalchemy.engine import url, default
import shutil
import os
import itertools
from sqlalchemy import create_engine, text
from sqlalchemy import create_engine, text, MetaData
from alembic import context, util
import re
from alembic.script import ScriptDirectory
@@ -218,7 +218,9 @@ def staging_env(create=True, template="generic"):
cfg = _testing_config()
if create:
command.init(cfg, os.path.join(staging_directory, 'scripts'))
return script.ScriptDirectory.from_config(cfg)
sc = script.ScriptDirectory.from_config(cfg)
context._opts(cfg,sc, fn=lambda:None)
return sc
def clear_staging_env():
shutil.rmtree(staging_directory, True)
@@ -230,7 +232,7 @@ def three_rev_fixture(cfg):
c = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_rev(a, None)
script.generate_rev(a, None, refresh=True)
script.write(a, """
down_revision = None
@@ -244,7 +246,7 @@ def downgrade():
""")
script.generate_rev(b, None)
script.generate_rev(b, None, refresh=True)
script.write(b, """
down_revision = '%s'
@@ -258,7 +260,7 @@ def downgrade():
""" % a)
script.generate_rev(c, None)
script.generate_rev(c, None, refresh=True)
script.write(c, """
down_revision = '%s'

View File

@@ -1,8 +1,8 @@
from sqlalchemy import MetaData, Column, Table, Integer, String, Text, \
Numeric, CHAR, NUMERIC, ForeignKey, DATETIME
from alembic import autogenerate
from alembic import autogenerate, context
from unittest import TestCase
from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace
from tests import staging_env, sqlite_db, clear_staging_env, eq_, eq_ignore_whitespace, capture_context_buffer, _no_sql_testing_config, _testing_config
def _model_one():
m = MetaData()
@@ -76,14 +76,21 @@ class AutogenerateDiffTest(TestCase):
extra = diffs[1][1]
eq_(extra.name, "extra")
del diffs[1]
eq_(repr(diffs[3][3]), "NUMERIC(precision=8, scale=2)")
eq_(repr(diffs[3][4]), "Numeric(precision=10, scale=2)")
del diffs[3]
dropcol = diffs[1][2]
del diffs[1]
eq_(dropcol.name, "pw")
eq_(dropcol.nullable, True)
eq_(dropcol.type._type_affinity, String)
eq_(dropcol.type.length, 50)
eq_(repr(diffs[2][3]), "NUMERIC(precision=8, scale=2)")
eq_(repr(diffs[2][4]), "Numeric(precision=10, scale=2)")
del diffs[2]
eq_(
diffs,
[
('add_table', metadata.tables['item']),
('remove_column', 'user', u'pw'),
('modify_nullable', 'user', 'name', True, False),
('modify_nullable', 'order', u'amount', False, True),
('add_column', 'address',
@@ -91,7 +98,47 @@ class AutogenerateDiffTest(TestCase):
]
)
def test_render_diffs(self):
metadata = _model_two()
connection = self.bind.connect()
template_args = {}
context.configure(
connection=self.bind.connect(),
autogenerate_metadata=metadata)
autogenerate.produce_migration_diffs(template_args)
eq_(template_args['upgrades'],
"""### commands auto generated by Alembic - please adjust! ###
create_table('item',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('description', sa.String(length=100), nullable=True),
sa.PrimaryKeyConstraint('id')
)
drop_table(u'extra')
drop_column('user', u'pw')
alter_column('user', 'name', nullable=False)
alter_column('order', u'amount', type=Numeric(precision=10, scale=2))
alter_column('order', u'amount', nullable=True)
add_column('address', sa.Column('street', sa.String(length=50), nullable=True))
### end Alembic commands ###""")
eq_(template_args['downgrades'],
"""### commands auto generated by Alembic - please adjust! ###
drop_table('item')
create_table(u'extra',
sa.Column(u'x', sa.CHAR(), nullable=True),
sa.PrimaryKeyConstraint()
)
add_column('user', sa.Column(u'pw', sa.VARCHAR(length=50), nullable=True))
alter_column('user', 'name', nullable=True)
alter_column('order', u'amount', type=NUMERIC(precision=8, scale=2))
alter_column('order', u'amount', nullable=False)
drop_column('address', 'street')
### end Alembic commands ###""")
class AutogenRenderTest(TestCase):
@classmethod
def setup_class(cls):
context._context_opts['autogenerate_sqlalchemy_prefix'] = 'sa.'
def test_render_table_upgrade(self):
m = MetaData()
t = Table('test', m,
@@ -102,32 +149,47 @@ class AutogenRenderTest(TestCase):
)
eq_ignore_whitespace(
autogenerate._add_table(t),
"create_table('test', "
"Column('id', Integer(), nullable=False),"
"Column('address_id', Integer(), nullable=True),"
"Column('timestamp', DATETIME(), "
"server_default=DefaultClause('NOW()', for_update=False), "
"create_table('test',"
"sa.Column('id', sa.Integer(), nullable=False),"
"sa.Column('address_id', sa.Integer(), nullable=True),"
"sa.Column('timestamp', sa.DATETIME(), "
"server_default=sa.DefaultClause('NOW()'), "
"nullable=True),"
"Column('amount', Numeric(precision=5, scale=2), nullable=True),"
"ForeignKeyConstraint([address_id], ['address.id'], ),"
"PrimaryKeyConstraint(id)"
" )"
"sa.Column('amount', sa.Numeric(precision=5, scale=2), nullable=True),"
"sa.ForeignKeyConstraint([address_id], ['address.id'], ),"
"sa.PrimaryKeyConstraint('id')"
")"
)
def test_render_table_downgrade(self):
def test_render_drop_table(self):
eq_(
autogenerate._drop_table("sometable"),
autogenerate._drop_table(Table("sometable", MetaData())),
"drop_table('sometable')"
)
def test_render_type_upgrade(self):
def test_render_add_column(self):
eq_(
autogenerate._add_column(
"foo", Column("x", Integer, server_default="5")),
"add_column('foo', sa.Column('x', sa.Integer(), "
"server_default=sa.DefaultClause('5'), nullable=True))"
)
def test_render_drop_column(self):
eq_(
autogenerate._drop_column(
"foo", Column("x", Integer, server_default="5")),
"drop_column('foo', 'x')"
)
def test_render_modify_type(self):
eq_(
autogenerate._modify_type(
"sometable", "somecolumn", CHAR(10)),
"alter_column('sometable', 'somecolumn', type=CHAR(length=10))"
)
def test_render_nullable_upgrade(self):
def test_render_modify_nullable(self):
eq_(
autogenerate._modify_nullable(
"sometable", "somecolumn", True),

View File

@@ -19,7 +19,7 @@ def test_003_heads():
eq_(env._get_heads(), [])
def test_004_rev():
script = env.generate_rev(abc, "this is a message")
script = env.generate_rev(abc, "this is a message", refresh=True)
eq_(script.doc, "this is a message")
eq_(script.revision, abc)
eq_(script.down_revision, None)
@@ -29,7 +29,7 @@ def test_004_rev():
eq_(env._get_heads(), [abc])
def test_005_nextrev():
script = env.generate_rev(def_, "this is the next rev")
script = env.generate_rev(def_, "this is the next rev", refresh=True)
eq_(script.revision, def_)
eq_(script.down_revision, abc)
eq_(env._revision_map[abc].nextrev, set([def_]))
@@ -50,6 +50,13 @@ def test_006_from_clean_env():
eq_(def_rev.down_revision, abc)
eq_(env._get_heads(), [def_])
def test_007_no_refresh():
script = env.generate_rev(util.rev_id(), "dont' refresh")
ne_(script, env._as_rev_number("head"))
env2 = staging_env(create=False)
eq_(script, env2._as_rev_number("head"))
def setup():
global env
env = staging_env()

View File

@@ -6,11 +6,11 @@ def setup():
global env
env = staging_env()
global a, b, c, d, e
a = env.generate_rev(util.rev_id(), None)
b = env.generate_rev(util.rev_id(), None)
c = env.generate_rev(util.rev_id(), None)
d = env.generate_rev(util.rev_id(), None)
e = env.generate_rev(util.rev_id(), None)
a = env.generate_rev(util.rev_id(), None, refresh=True)
b = env.generate_rev(util.rev_id(), None, refresh=True)
c = env.generate_rev(util.rev_id(), None, refresh=True)
d = env.generate_rev(util.rev_id(), None, refresh=True)
e = env.generate_rev(util.rev_id(), None, refresh=True)
def teardown():
clear_staging_env()

View File

@@ -10,7 +10,7 @@ def test_001_revisions():
c = util.rev_id()
script = ScriptDirectory.from_config(cfg)
script.generate_rev(a, None)
script.generate_rev(a, None, refresh=True)
script.write(a, """
down_revision = None
@@ -24,7 +24,7 @@ def downgrade():
""")
script.generate_rev(b, None)
script.generate_rev(b, None, refresh=True)
script.write(b, """
down_revision = '%s'
@@ -38,7 +38,7 @@ def downgrade():
""" % a)
script.generate_rev(c, None)
script.generate_rev(c, None, refresh=True)
script.write(c, """
down_revision = '%s'