add support for SA 0.6 by Michael Bayer
This commit is contained in:
		@@ -1,6 +1,7 @@
 | 
				
			|||||||
0.5.5
 | 
					0.5.5
 | 
				
			||||||
-----
 | 
					-----
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					- added support for SQLAlchemy 0.6 (missing oracle and firebird) by Michael Bayer
 | 
				
			||||||
- alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched.
 | 
					- alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched.
 | 
				
			||||||
- complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23)
 | 
					- complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23)
 | 
				
			||||||
- added support for :ref:`firebird <firebird-d>`
 | 
					- added support for :ref:`firebird <firebird-d>`
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,9 +6,21 @@
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
import sqlalchemy
 | 
					import sqlalchemy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from sqlalchemy import __version__ as _sa_version
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
 | 
				
			||||||
 | 
					SQLA_06 = _sa_version >= (0, 6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					del re
 | 
				
			||||||
 | 
					del _sa_version
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset.schema import *
 | 
					from migrate.changeset.schema import *
 | 
				
			||||||
from migrate.changeset.constraint import *
 | 
					from migrate.changeset.constraint import *
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
 | 
					sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
 | 
				
			||||||
sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
 | 
					sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
 | 
				
			||||||
sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
 | 
					sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5,23 +5,53 @@
 | 
				
			|||||||
   things that just happen to work with multiple databases.
 | 
					   things that just happen to work with multiple databases.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
import sqlalchemy as sa
 | 
					import sqlalchemy as sa
 | 
				
			||||||
 | 
					from sqlalchemy.schema import SchemaVisitor
 | 
				
			||||||
from sqlalchemy.engine.default import DefaultDialect
 | 
					from sqlalchemy.engine.default import DefaultDialect
 | 
				
			||||||
from sqlalchemy.schema import (ForeignKeyConstraint,
 | 
					from sqlalchemy.schema import (ForeignKeyConstraint,
 | 
				
			||||||
                               PrimaryKeyConstraint,
 | 
					                               PrimaryKeyConstraint,
 | 
				
			||||||
                               CheckConstraint,
 | 
					                               CheckConstraint,
 | 
				
			||||||
                               UniqueConstraint,
 | 
					                               UniqueConstraint,
 | 
				
			||||||
                               Index)
 | 
					                               Index)
 | 
				
			||||||
from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset import exceptions, constraint
 | 
					from migrate.changeset import exceptions, constraint, SQLA_06
 | 
				
			||||||
 | 
					import StringIO
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not SQLA_06:
 | 
				
			||||||
 | 
					    from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    from sqlalchemy.schema import AddConstraint, DropConstraint
 | 
				
			||||||
 | 
					    from sqlalchemy.sql.compiler import DDLCompiler
 | 
				
			||||||
 | 
					    SchemaGenerator = SchemaDropper = DDLCompiler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
SchemaIterator = sa.engine.SchemaIterator
 | 
					class AlterTableVisitor(SchemaVisitor):
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
class AlterTableVisitor(SchemaIterator):
 | 
					 | 
				
			||||||
    """Common operations for ``ALTER TABLE`` statements."""
 | 
					    """Common operations for ``ALTER TABLE`` statements."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def append(self, s):
 | 
				
			||||||
 | 
					        """Append content to the SchemaIterator's query buffer."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.buffer.write(s)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def execute(self):
 | 
				
			||||||
 | 
					        """Execute the contents of the SchemaIterator's buffer."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            return self.connection.execute(self.buffer.getvalue())
 | 
				
			||||||
 | 
					        finally:
 | 
				
			||||||
 | 
					            self.buffer.truncate(0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, dialect, connection, **kw):
 | 
				
			||||||
 | 
					        self.connection = connection
 | 
				
			||||||
 | 
					        self.buffer = StringIO.StringIO()
 | 
				
			||||||
 | 
					        self.preparer = dialect.identifier_preparer
 | 
				
			||||||
 | 
					        self.dialect = dialect
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def traverse_single(self, elem):
 | 
				
			||||||
 | 
					        ret = super(AlterTableVisitor, self).traverse_single(elem)
 | 
				
			||||||
 | 
					        if ret:
 | 
				
			||||||
 | 
					            # adapt to 0.6 which uses a string-returning
 | 
				
			||||||
 | 
					            # object
 | 
				
			||||||
 | 
					            self.append(ret)
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
    def _to_table(self, param):
 | 
					    def _to_table(self, param):
 | 
				
			||||||
        """Returns the table object for the given param object."""
 | 
					        """Returns the table object for the given param object."""
 | 
				
			||||||
        if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
 | 
					        if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
 | 
				
			||||||
@@ -88,6 +118,9 @@ class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
 | 
				
			|||||||
                                                   name=column.primary_key_name)
 | 
					                                                   name=column.primary_key_name)
 | 
				
			||||||
            cons.create()
 | 
					            cons.create()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if SQLA_06:
 | 
				
			||||||
 | 
					        def add_foreignkey(self, fk):
 | 
				
			||||||
 | 
					            self.connection.execute(AddConstraint(fk))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
 | 
					class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
 | 
				
			||||||
    """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
 | 
					    """Extends ANSI SQL dropper for column dropping (``ALTER TABLE
 | 
				
			||||||
@@ -181,7 +214,10 @@ class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    def _visit_column_type(self, table, column, delta):
 | 
					    def _visit_column_type(self, table, column, delta):
 | 
				
			||||||
        type_ = delta['type']
 | 
					        type_ = delta['type']
 | 
				
			||||||
        type_text = type_.dialect_impl(self.dialect).get_col_spec()
 | 
					        if SQLA_06:
 | 
				
			||||||
 | 
					            type_text = str(type_.compile(dialect=self.dialect))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            type_text = type_.dialect_impl(self.dialect).get_col_spec()
 | 
				
			||||||
        self.append("TYPE %s" % type_text)
 | 
					        self.append("TYPE %s" % type_text)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _visit_column_name(self, table, column, delta):
 | 
					    def _visit_column_name(self, table, column, delta):
 | 
				
			||||||
@@ -225,60 +261,75 @@ class ANSIConstraintCommon(AlterTableVisitor):
 | 
				
			|||||||
    def visit_migrate_unique_constraint(self, *p, **k):
 | 
					    def visit_migrate_unique_constraint(self, *p, **k):
 | 
				
			||||||
        self._visit_constraint(*p, **k)
 | 
					        self._visit_constraint(*p, **k)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if SQLA_06:
 | 
				
			||||||
 | 
					    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
 | 
				
			||||||
 | 
					        def _visit_constraint(self, constraint):
 | 
				
			||||||
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
 | 
					            self.append(self.process(AddConstraint(constraint)))
 | 
				
			||||||
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
 | 
					    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
 | 
				
			||||||
 | 
					        def _visit_constraint(self, constraint):
 | 
				
			||||||
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
 | 
					            self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
 | 
				
			||||||
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_constraint_specification(self, cons, **kwargs):
 | 
					else:
 | 
				
			||||||
        """Constaint SQL generators.
 | 
					    class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        def get_constraint_specification(self, cons, **kwargs):
 | 
				
			||||||
 | 
					            """Constaint SQL generators.
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        We cannot use SA visitors because they append comma.
 | 
					            We cannot use SA visitors because they append comma.
 | 
				
			||||||
        """
 | 
					            """
 | 
				
			||||||
        if isinstance(cons, PrimaryKeyConstraint):
 | 
					        
 | 
				
			||||||
            if cons.name is not None:
 | 
					            if isinstance(cons, PrimaryKeyConstraint):
 | 
				
			||||||
                self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
 | 
					                if cons.name is not None:
 | 
				
			||||||
            self.append("PRIMARY KEY ")
 | 
					                    self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
 | 
				
			||||||
            self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
 | 
					                self.append("PRIMARY KEY ")
 | 
				
			||||||
                                           for c in cons))
 | 
					                self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
 | 
				
			||||||
            self.define_constraint_deferrability(cons)
 | 
					                                               for c in cons))
 | 
				
			||||||
        elif isinstance(cons, ForeignKeyConstraint):
 | 
					                self.define_constraint_deferrability(cons)
 | 
				
			||||||
            self.define_foreign_key(cons)
 | 
					            elif isinstance(cons, ForeignKeyConstraint):
 | 
				
			||||||
        elif isinstance(cons, CheckConstraint):
 | 
					                self.define_foreign_key(cons)
 | 
				
			||||||
            if cons.name is not None:
 | 
					            elif isinstance(cons, CheckConstraint):
 | 
				
			||||||
                self.append("CONSTRAINT %s " %
 | 
					                if cons.name is not None:
 | 
				
			||||||
                            self.preparer.format_constraint(cons))
 | 
					                    self.append("CONSTRAINT %s " %
 | 
				
			||||||
            self.append("CHECK (%s)" % cons.sqltext)
 | 
					                                self.preparer.format_constraint(cons))
 | 
				
			||||||
            self.define_constraint_deferrability(cons)
 | 
					                self.append("CHECK (%s)" % cons.sqltext)
 | 
				
			||||||
        elif isinstance(cons, UniqueConstraint):
 | 
					                self.define_constraint_deferrability(cons)
 | 
				
			||||||
            if cons.name is not None:
 | 
					            elif isinstance(cons, UniqueConstraint):
 | 
				
			||||||
                self.append("CONSTRAINT %s " %
 | 
					                if cons.name is not None:
 | 
				
			||||||
                            self.preparer.format_constraint(cons))
 | 
					                    self.append("CONSTRAINT %s " %
 | 
				
			||||||
            self.append("UNIQUE (%s)" % \
 | 
					                                self.preparer.format_constraint(cons))
 | 
				
			||||||
                (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
 | 
					                self.append("UNIQUE (%s)" % \
 | 
				
			||||||
            self.define_constraint_deferrability(cons)
 | 
					                    (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
 | 
				
			||||||
        else:
 | 
					                self.define_constraint_deferrability(cons)
 | 
				
			||||||
            raise exceptions.InvalidConstraintError(cons)
 | 
					            else:
 | 
				
			||||||
 | 
					                raise exceptions.InvalidConstraintError(cons)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _visit_constraint(self, constraint):
 | 
					        def _visit_constraint(self, constraint):
 | 
				
			||||||
        table = self.start_alter_table(constraint)
 | 
					        
 | 
				
			||||||
        constraint.name = self.get_constraint_name(constraint)
 | 
					            table = self.start_alter_table(constraint)
 | 
				
			||||||
        self.append("ADD ")
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
        self.get_constraint_specification(constraint)
 | 
					            self.append("ADD ")
 | 
				
			||||||
        self.execute()
 | 
					            self.get_constraint_specification(constraint)
 | 
				
			||||||
 | 
					            self.execute()
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
 | 
					        def _visit_constraint(self, constraint):
 | 
				
			||||||
 | 
					            self.start_alter_table(constraint)
 | 
				
			||||||
 | 
					            self.append("DROP CONSTRAINT ")
 | 
				
			||||||
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
 | 
					            self.append(self.preparer.format_constraint(constraint))
 | 
				
			||||||
 | 
					            if constraint.cascade:
 | 
				
			||||||
 | 
					                self.cascade_constraint(constraint)
 | 
				
			||||||
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _visit_constraint(self, constraint):
 | 
					        def cascade_constraint(self, constraint):
 | 
				
			||||||
        self.start_alter_table(constraint)
 | 
					            self.append(" CASCADE")
 | 
				
			||||||
        self.append("DROP CONSTRAINT ")
 | 
					 | 
				
			||||||
        constraint.name = self.get_constraint_name(constraint)
 | 
					 | 
				
			||||||
        self.append(self.preparer.format_constraint(constraint))
 | 
					 | 
				
			||||||
        if constraint.cascade:
 | 
					 | 
				
			||||||
            self.cascade_constraint(constraint)
 | 
					 | 
				
			||||||
        self.execute()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    def cascade_constraint(self, constraint):
 | 
					 | 
				
			||||||
        self.append(" CASCADE")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ANSIDialect(DefaultDialect):
 | 
					class ANSIDialect(DefaultDialect):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,7 +4,7 @@
 | 
				
			|||||||
from sqlalchemy import schema
 | 
					from sqlalchemy import schema
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset.exceptions import *
 | 
					from migrate.changeset.exceptions import *
 | 
				
			||||||
 | 
					from migrate.changeset import SQLA_06
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ConstraintChangeset(object):
 | 
					class ConstraintChangeset(object):
 | 
				
			||||||
    """Base class for Constraint classes."""
 | 
					    """Base class for Constraint classes."""
 | 
				
			||||||
@@ -54,7 +54,10 @@ class ConstraintChangeset(object):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        self.cascade = kw.pop('cascade', False)
 | 
					        self.cascade = kw.pop('cascade', False)
 | 
				
			||||||
        self.__do_imports('constraintdropper', *a, **kw)
 | 
					        self.__do_imports('constraintdropper', *a, **kw)
 | 
				
			||||||
        self.columns.clear()
 | 
					        # the spirit of Constraint objects is that they
 | 
				
			||||||
 | 
					        # are immutable (just like in a DB.  they're only ADDed
 | 
				
			||||||
 | 
					        # or DROPped).
 | 
				
			||||||
 | 
					        #self.columns.clear()
 | 
				
			||||||
        return self
 | 
					        return self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -69,7 +72,7 @@ class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
 | 
				
			|||||||
    :type cols: strings or Column instances
 | 
					    :type cols: strings or Column instances
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __visit_name__ = 'migrate_primary_key_constraint'
 | 
					    __migrate_visit_name__ = 'migrate_primary_key_constraint'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, *cols, **kwargs):
 | 
					    def __init__(self, *cols, **kwargs):
 | 
				
			||||||
        colnames, table = self._normalize_columns(cols)
 | 
					        colnames, table = self._normalize_columns(cols)
 | 
				
			||||||
@@ -97,7 +100,7 @@ class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
 | 
				
			|||||||
    :type refcolumns: list of strings or Column instances
 | 
					    :type refcolumns: list of strings or Column instances
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __visit_name__ = 'migrate_foreign_key_constraint'
 | 
					    __migrate_visit_name__ = 'migrate_foreign_key_constraint'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, columns, refcolumns, *args, **kwargs):
 | 
					    def __init__(self, columns, refcolumns, *args, **kwargs):
 | 
				
			||||||
        colnames, table = self._normalize_columns(columns)
 | 
					        colnames, table = self._normalize_columns(columns)
 | 
				
			||||||
@@ -139,7 +142,7 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
 | 
				
			|||||||
    :type sqltext: string
 | 
					    :type sqltext: string
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __visit_name__ = 'migrate_check_constraint'
 | 
					    __migrate_visit_name__ = 'migrate_check_constraint'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, sqltext, *args, **kwargs):
 | 
					    def __init__(self, sqltext, *args, **kwargs):
 | 
				
			||||||
        cols = kwargs.pop('columns', [])
 | 
					        cols = kwargs.pop('columns', [])
 | 
				
			||||||
@@ -150,7 +153,8 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
 | 
				
			|||||||
        table = kwargs.pop('table', table)
 | 
					        table = kwargs.pop('table', table)
 | 
				
			||||||
        schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
 | 
					        schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
 | 
				
			||||||
        if table is not None:
 | 
					        if table is not None:
 | 
				
			||||||
            self.table = table
 | 
					            if not SQLA_06:
 | 
				
			||||||
 | 
					                self.table = table
 | 
				
			||||||
            self._set_parent(table)
 | 
					            self._set_parent(table)
 | 
				
			||||||
        self.colnames = colnames
 | 
					        self.colnames = colnames
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -172,7 +176,7 @@ class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
 | 
				
			|||||||
    .. versionadded:: 0.5.5
 | 
					    .. versionadded:: 0.5.5
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    __visit_name__ = 'migrate_unique_constraint'
 | 
					    __migrate_visit_name__ = 'migrate_unique_constraint'
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __init__(self, *cols, **kwargs):
 | 
					    def __init__(self, *cols, **kwargs):
 | 
				
			||||||
        self.colnames, table = self._normalize_columns(cols)
 | 
					        self.colnames, table = self._normalize_columns(cols)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,14 +1,13 @@
 | 
				
			|||||||
"""
 | 
					"""
 | 
				
			||||||
   Firebird database specific implementations of changeset classes.
 | 
					   Firebird database specific implementations of changeset classes.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
from sqlalchemy.databases import firebird as sa_base
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset import ansisql, exceptions
 | 
					from migrate.changeset import ansisql, exceptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					# TODO: SQLA 0.6 has not migrated the FB dialect over yet
 | 
				
			||||||
 | 
					from sqlalchemy.databases import firebird as sa_base
 | 
				
			||||||
FBSchemaGenerator = sa_base.FBSchemaGenerator
 | 
					FBSchemaGenerator = sa_base.FBSchemaGenerator
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
					class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
				
			||||||
    """Firebird column generator implementation."""
 | 
					    """Firebird column generator implementation."""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,13 +2,13 @@
 | 
				
			|||||||
   MySQL database specific implementations of changeset classes.
 | 
					   MySQL database specific implementations of changeset classes.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from migrate.changeset import ansisql, exceptions, SQLA_06
 | 
				
			||||||
from sqlalchemy.databases import mysql as sa_base
 | 
					from sqlalchemy.databases import mysql as sa_base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset import ansisql, exceptions
 | 
					if not SQLA_06:
 | 
				
			||||||
 | 
					    MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
 | 
					    MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
					class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
@@ -39,31 +39,37 @@ class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
 | 
				
			|||||||
class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
 | 
					class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
 | 
				
			||||||
    pass
 | 
					    pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if SQLA_06:
 | 
				
			||||||
 | 
					    class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
 | 
				
			||||||
 | 
					        def visit_migrate_check_constraint(self, *p, **k):
 | 
				
			||||||
 | 
					            raise exceptions.NotSupportedError("MySQL does not support CHECK"
 | 
				
			||||||
 | 
					                " constraints, use triggers instead.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
 | 
					else:
 | 
				
			||||||
 | 
					    class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_migrate_primary_key_constraint(self, constraint):
 | 
					        def visit_migrate_primary_key_constraint(self, constraint):
 | 
				
			||||||
        self.start_alter_table(constraint)
 | 
					            self.start_alter_table(constraint)
 | 
				
			||||||
        self.append("DROP PRIMARY KEY")
 | 
					            self.append("DROP PRIMARY KEY")
 | 
				
			||||||
        self.execute()
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_migrate_foreign_key_constraint(self, constraint):
 | 
					        def visit_migrate_foreign_key_constraint(self, constraint):
 | 
				
			||||||
        self.start_alter_table(constraint)
 | 
					            self.start_alter_table(constraint)
 | 
				
			||||||
        self.append("DROP FOREIGN KEY ")
 | 
					            self.append("DROP FOREIGN KEY ")
 | 
				
			||||||
        constraint.name = self.get_constraint_name(constraint)
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
        self.append(self.preparer.format_constraint(constraint))
 | 
					            self.append(self.preparer.format_constraint(constraint))
 | 
				
			||||||
        self.execute()
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_migrate_check_constraint(self, *p, **k):
 | 
					        def visit_migrate_check_constraint(self, *p, **k):
 | 
				
			||||||
        raise exceptions.NotSupportedError("MySQL does not support CHECK"
 | 
					            raise exceptions.NotSupportedError("MySQL does not support CHECK"
 | 
				
			||||||
            " constraints, use triggers instead.")
 | 
					                " constraints, use triggers instead.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def visit_migrate_unique_constraint(self, constraint, *p, **k):
 | 
					        def visit_migrate_unique_constraint(self, constraint, *p, **k):
 | 
				
			||||||
        self.start_alter_table(constraint)
 | 
					            self.start_alter_table(constraint)
 | 
				
			||||||
        self.append('DROP INDEX ')
 | 
					            self.append('DROP INDEX ')
 | 
				
			||||||
        constraint.name = self.get_constraint_name(constraint)
 | 
					            constraint.name = self.get_constraint_name(constraint)
 | 
				
			||||||
        self.append(self.preparer.format_constraint(constraint))
 | 
					            self.append(self.preparer.format_constraint(constraint))
 | 
				
			||||||
        self.execute()
 | 
					            self.execute()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class MySQLDialect(ansisql.ANSIDialect):
 | 
					class MySQLDialect(ansisql.ANSIDialect):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,12 +2,17 @@
 | 
				
			|||||||
   Oracle database specific implementations of changeset classes.
 | 
					   Oracle database specific implementations of changeset classes.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
import sqlalchemy as sa
 | 
					import sqlalchemy as sa
 | 
				
			||||||
from sqlalchemy.databases import oracle as sa_base
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset import ansisql, exceptions
 | 
					from migrate.changeset import ansisql, exceptions
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from sqlalchemy.databases import oracle as sa_base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
OracleSchemaGenerator = sa_base.OracleSchemaGenerator
 | 
					from migrate.changeset import ansisql, exceptions, SQLA_06
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not SQLA_06:
 | 
				
			||||||
 | 
					    OracleSchemaGenerator = sa_base.OracleSchemaGenerator
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    OracleSchemaGenerator = sa_base.OracleDDLCompiler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
					class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,12 +3,13 @@
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
   .. _`PostgreSQL`: http://www.postgresql.org/
 | 
					   .. _`PostgreSQL`: http://www.postgresql.org/
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
from migrate.changeset import ansisql
 | 
					from migrate.changeset import ansisql, SQLA_06
 | 
				
			||||||
from sqlalchemy.databases import postgres as sa_base
 | 
					from sqlalchemy.databases import postgres as sa_base
 | 
				
			||||||
#import sqlalchemy as sa
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not SQLA_06:
 | 
				
			||||||
PGSchemaGenerator = sa_base.PGSchemaGenerator
 | 
					    PGSchemaGenerator = sa_base.PGSchemaGenerator
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    PGSchemaGenerator = sa_base.PGDDLCompiler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
					class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,10 +8,12 @@ from copy import copy
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from sqlalchemy.databases import sqlite as sa_base
 | 
					from sqlalchemy.databases import sqlite as sa_base
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from migrate.changeset import ansisql, exceptions
 | 
					from migrate.changeset import ansisql, exceptions, SQLA_06
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					if not SQLA_06:
 | 
				
			||||||
SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
 | 
					    SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
 | 
				
			||||||
 | 
					else:
 | 
				
			||||||
 | 
					    SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SQLiteCommon(object):
 | 
					class SQLiteCommon(object):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -52,8 +54,7 @@ class SQLiteHelper(SQLiteCommon):
 | 
				
			|||||||
        table.indexes = ixbackup
 | 
					        table.indexes = ixbackup
 | 
				
			||||||
        table.constraints = consbackup
 | 
					        table.constraints = consbackup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon, 
 | 
				
			||||||
class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon,
 | 
					 | 
				
			||||||
                            ansisql.ANSIColumnGenerator):
 | 
					                            ansisql.ANSIColumnGenerator):
 | 
				
			||||||
    """SQLite ColumnGenerator"""
 | 
					    """SQLite ColumnGenerator"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -13,12 +13,12 @@ from migrate.changeset.databases import (sqlite,
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
# Map SA dialects to the corresponding Migrate extensions
 | 
					# Map SA dialects to the corresponding Migrate extensions
 | 
				
			||||||
DIALECTS = {
 | 
					DIALECTS = {
 | 
				
			||||||
    sa.engine.default.DefaultDialect: ansisql.ANSIDialect,
 | 
					    "default": ansisql.ANSIDialect,
 | 
				
			||||||
    sa.databases.sqlite.SQLiteDialect: sqlite.SQLiteDialect,
 | 
					    "sqlite": sqlite.SQLiteDialect,
 | 
				
			||||||
    sa.databases.postgres.PGDialect: postgres.PGDialect,
 | 
					    "postgres": postgres.PGDialect,
 | 
				
			||||||
    sa.databases.mysql.MySQLDialect: mysql.MySQLDialect,
 | 
					    "mysql": mysql.MySQLDialect,
 | 
				
			||||||
    sa.databases.oracle.OracleDialect: oracle.OracleDialect,
 | 
					    "oracle": oracle.OracleDialect,
 | 
				
			||||||
    sa.databases.firebird.FBDialect: firebird.FBDialect,
 | 
					    "firebird": firebird.FBDialect,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -47,8 +47,8 @@ def get_dialect_visitor(sa_dialect, name):
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # map sa dialect to migrate dialect and return visitor
 | 
					    # map sa dialect to migrate dialect and return visitor
 | 
				
			||||||
    sa_dialect_cls = sa_dialect.__class__
 | 
					    sa_dialect_name = getattr(sa_dialect, 'name', 'default')
 | 
				
			||||||
    migrate_dialect_cls = DIALECTS[sa_dialect_cls]
 | 
					    migrate_dialect_cls = DIALECTS[sa_dialect_name]
 | 
				
			||||||
    visitor = getattr(migrate_dialect_cls, name)
 | 
					    visitor = getattr(migrate_dialect_cls, name)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # bind preparer
 | 
					    # bind preparer
 | 
				
			||||||
@@ -61,6 +61,10 @@ def run_single_visitor(engine, visitorcallable, element, **kwargs):
 | 
				
			|||||||
    conn = engine.contextual_connect(close_with_result=False)
 | 
					    conn = engine.contextual_connect(close_with_result=False)
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        visitor = visitorcallable(engine.dialect, conn)
 | 
					        visitor = visitorcallable(engine.dialect, conn)
 | 
				
			||||||
        getattr(visitor, 'visit_' + element.__visit_name__)(element, **kwargs)
 | 
					        if hasattr(element, '__migrate_visit_name__'):
 | 
				
			||||||
 | 
					            fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            fn = getattr(visitor, 'visit_' + element.__visit_name__)
 | 
				
			||||||
 | 
					        fn(element, **kwargs)
 | 
				
			||||||
    finally:
 | 
					    finally:
 | 
				
			||||||
        conn.close()
 | 
					        conn.close()
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,6 +4,7 @@
 | 
				
			|||||||
from UserDict import DictMixin
 | 
					from UserDict import DictMixin
 | 
				
			||||||
import sqlalchemy
 | 
					import sqlalchemy
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from migrate.changeset import SQLA_06
 | 
				
			||||||
from migrate.changeset.exceptions import *
 | 
					from migrate.changeset.exceptions import *
 | 
				
			||||||
from migrate.changeset.databases.visitor import (get_engine_visitor,
 | 
					from migrate.changeset.databases.visitor import (get_engine_visitor,
 | 
				
			||||||
                                                 run_single_visitor)
 | 
					                                                 run_single_visitor)
 | 
				
			||||||
@@ -310,7 +311,7 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
 | 
				
			|||||||
    def process_column(self, column):
 | 
					    def process_column(self, column):
 | 
				
			||||||
        """Processes default values for column"""
 | 
					        """Processes default values for column"""
 | 
				
			||||||
        # XXX: this is a snippet from SA processing of positional parameters
 | 
					        # XXX: this is a snippet from SA processing of positional parameters
 | 
				
			||||||
        if column.args:
 | 
					        if not SQLA_06 and column.args:
 | 
				
			||||||
            toinit = list(column.args)
 | 
					            toinit = list(column.args)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            toinit = list()
 | 
					            toinit = list()
 | 
				
			||||||
@@ -328,7 +329,9 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
 | 
				
			|||||||
                                            for_update=True))
 | 
					                                            for_update=True))
 | 
				
			||||||
        if toinit:
 | 
					        if toinit:
 | 
				
			||||||
            column._init_items(*toinit)
 | 
					            column._init_items(*toinit)
 | 
				
			||||||
        column.args = []
 | 
					            
 | 
				
			||||||
 | 
					        if not SQLA_06:
 | 
				
			||||||
 | 
					            column.args = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_table(self):
 | 
					    def _get_table(self):
 | 
				
			||||||
        return getattr(self, '_table', None)
 | 
					        return getattr(self, '_table', None)
 | 
				
			||||||
@@ -365,9 +368,6 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
 | 
				
			|||||||
            self.current_name = column.name
 | 
					            self.current_name = column.name
 | 
				
			||||||
        if self.alter_metadata:
 | 
					        if self.alter_metadata:
 | 
				
			||||||
            self._result_column = column
 | 
					            self._result_column = column
 | 
				
			||||||
            # remove column from table, nothing has changed yet
 | 
					 | 
				
			||||||
            if self.table:
 | 
					 | 
				
			||||||
                column.remove_from_table(self.table)
 | 
					 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            self._result_column = column.copy_fixed()
 | 
					            self._result_column = column.copy_fixed()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -34,10 +34,7 @@ class ModelGenerator(object):
 | 
				
			|||||||
    def __init__(self, diff, declarative=False):
 | 
					    def __init__(self, diff, declarative=False):
 | 
				
			||||||
        self.diff = diff
 | 
					        self.diff = diff
 | 
				
			||||||
        self.declarative = declarative
 | 
					        self.declarative = declarative
 | 
				
			||||||
        # is there an easier way to get this?
 | 
					
 | 
				
			||||||
        dialectModule = sys.modules[self.diff.conn.dialect.__module__]
 | 
					 | 
				
			||||||
        self.colTypeMappings = dict((v, k) for k, v in \
 | 
					 | 
				
			||||||
                                        dialectModule.colspecs.items())
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def column_repr(self, col):
 | 
					    def column_repr(self, col):
 | 
				
			||||||
        kwarg = []
 | 
					        kwarg = []
 | 
				
			||||||
@@ -63,18 +60,18 @@ class ModelGenerator(object):
 | 
				
			|||||||
        # crs: not sure if this is good idea, but it gets rid of extra
 | 
					        # crs: not sure if this is good idea, but it gets rid of extra
 | 
				
			||||||
        # u''
 | 
					        # u''
 | 
				
			||||||
        name = col.name.encode('utf8')
 | 
					        name = col.name.encode('utf8')
 | 
				
			||||||
        type = self.colTypeMappings.get(col.type.__class__, None)
 | 
					
 | 
				
			||||||
        if type:
 | 
					        type_ = col.type
 | 
				
			||||||
            # Make the column type be an instance of this type.
 | 
					        for cls in col.type.__class__.__mro__:
 | 
				
			||||||
            type = type()
 | 
					            if cls.__module__ == 'sqlalchemy.types' and \
 | 
				
			||||||
        else:
 | 
					                not cls.__name__.isupper():
 | 
				
			||||||
            # We must already be a model type, no need to map from the
 | 
					                if cls is not type_.__class__:
 | 
				
			||||||
            # database-specific types.
 | 
					                    type_ = cls()
 | 
				
			||||||
            type = col.type
 | 
					                break
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        data = {
 | 
					        data = {
 | 
				
			||||||
            'name': name,
 | 
					            'name': name,
 | 
				
			||||||
            'type': type,
 | 
					            'type': type_,
 | 
				
			||||||
            'constraints': ', '.join([repr(cn) for cn in col.constraints]),
 | 
					            'constraints': ', '.join([repr(cn) for cn in col.constraints]),
 | 
				
			||||||
            'args': ks and ks or ''}
 | 
					            'args': ks and ks or ''}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,7 +2,7 @@
 | 
				
			|||||||
   Schema differencing support.
 | 
					   Schema differencing support.
 | 
				
			||||||
"""
 | 
					"""
 | 
				
			||||||
import sqlalchemy
 | 
					import sqlalchemy
 | 
				
			||||||
 | 
					from migrate.changeset import SQLA_06
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
 | 
					def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
@@ -55,9 +55,25 @@ class SchemaDiff(object):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        # Setup common variables.
 | 
					        # Setup common variables.
 | 
				
			||||||
        cc = self.conn.contextual_connect()
 | 
					        cc = self.conn.contextual_connect()
 | 
				
			||||||
        schemagenerator = self.conn.dialect.schemagenerator(
 | 
					        if SQLA_06:
 | 
				
			||||||
            self.conn.dialect, cc)
 | 
					            from sqlalchemy.ext import compiler
 | 
				
			||||||
 | 
					            from sqlalchemy.schema import DDLElement
 | 
				
			||||||
 | 
					            class DefineColumn(DDLElement):
 | 
				
			||||||
 | 
					                def __init__(self, col):
 | 
				
			||||||
 | 
					                    self.col = col
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            @compiler.compiles(DefineColumn)
 | 
				
			||||||
 | 
					            def compile(elem, compiler, **kw):
 | 
				
			||||||
 | 
					                return compiler.get_column_specification(elem.col)
 | 
				
			||||||
 | 
					            
 | 
				
			||||||
 | 
					            def get_column_specification(col):
 | 
				
			||||||
 | 
					                return str(DefineColumn(col).compile(dialect=self.conn.dialect))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            schemagenerator = self.conn.dialect.schemagenerator(
 | 
				
			||||||
 | 
					                self.conn.dialect, cc)
 | 
				
			||||||
 | 
					            def get_column_specification(col):
 | 
				
			||||||
 | 
					                return schemagenerator.get_column_specification(col)
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
        # For each in model, find missing in database.
 | 
					        # For each in model, find missing in database.
 | 
				
			||||||
        for modelName, modelTable in self.model.tables.items():
 | 
					        for modelName, modelTable in self.model.tables.items():
 | 
				
			||||||
            if modelName in self.excludeTables:
 | 
					            if modelName in self.excludeTables:
 | 
				
			||||||
@@ -89,15 +105,16 @@ class SchemaDiff(object):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
                # Find missing columns in model.
 | 
					                # Find missing columns in model.
 | 
				
			||||||
                for databaseCol in reflectedTable.columns:
 | 
					                for databaseCol in reflectedTable.columns:
 | 
				
			||||||
 | 
					                    
 | 
				
			||||||
 | 
					                    # TODO: no test coverage here?   (mrb)
 | 
				
			||||||
 | 
					                    
 | 
				
			||||||
                    modelCol = modelTable.columns.get(databaseCol.name, None)
 | 
					                    modelCol = modelTable.columns.get(databaseCol.name, None)
 | 
				
			||||||
                    if modelCol:
 | 
					                    if modelCol:
 | 
				
			||||||
                        # Compare attributes of column.
 | 
					                        # Compare attributes of column.
 | 
				
			||||||
                        modelDecl = \
 | 
					                        modelDecl = \
 | 
				
			||||||
                            schemagenerator.get_column_specification(
 | 
					                            get_column_specification(modelCol)
 | 
				
			||||||
                            modelCol)
 | 
					 | 
				
			||||||
                        databaseDecl = \
 | 
					                        databaseDecl = \
 | 
				
			||||||
                            schemagenerator.get_column_specification(
 | 
					                            get_column_specification(databaseCol)
 | 
				
			||||||
                            databaseCol)
 | 
					 | 
				
			||||||
                        if modelDecl != databaseDecl:
 | 
					                        if modelDecl != databaseDecl:
 | 
				
			||||||
                            # Unfortunately, sometimes the database
 | 
					                            # Unfortunately, sometimes the database
 | 
				
			||||||
                            # decl won't quite match the model, even
 | 
					                            # decl won't quite match the model, even
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -112,7 +112,7 @@ class PythonScript(base.BaseScript):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        buf = StringIO()
 | 
					        buf = StringIO()
 | 
				
			||||||
        args['engine_arg_strategy'] = 'mock'
 | 
					        args['engine_arg_strategy'] = 'mock'
 | 
				
			||||||
        args['engine_arg_executor'] = lambda s, p = '': buf.write(s + p)
 | 
					        args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
 | 
				
			||||||
        engine = construct_engine(url, **args)
 | 
					        engine = construct_engine(url, **args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.run(engine, step)
 | 
					        self.run(engine, step)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,7 +18,7 @@ from test import fixture
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
class Shell(fixture.Shell):
 | 
					class Shell(fixture.Shell):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    _cmd = os.path.join('python migrate', 'versioning', 'shell.py')
 | 
					    _cmd = os.path.join(sys.executable + ' migrate', 'versioning', 'shell.py')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @classmethod
 | 
					    @classmethod
 | 
				
			||||||
    def cmd(cls, *args):
 | 
					    def cmd(cls, *args):
 | 
				
			||||||
@@ -509,21 +509,21 @@ class TestShellDatabase(Shell, fixture.DB):
 | 
				
			|||||||
        open(model_path, 'w').write(script_preamble + script_text)
 | 
					        open(model_path, 'w').write(script_preamble + script_text)
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        # Model is defined but database is empty.
 | 
					        # Model is defined but database is empty.
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path))
 | 
				
			||||||
        assert "tables missing in database: tmp_account_rundiffs" in output, output
 | 
					        assert "tables missing in database: tmp_account_rundiffs" in output, output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Test Deprecation
 | 
					        # Test Deprecation
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s compare_model_to_db --model=testmodel.meta' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db --model=testmodel.meta' % (sys.executable, script_path))
 | 
				
			||||||
        assert "tables missing in database: tmp_account_rundiffs" in output, output
 | 
					        assert "tables missing in database: tmp_account_rundiffs" in output, output
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        # Update db to latest model.
 | 
					        # Update db to latest model.
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path))
 | 
				
			||||||
        self.assertEquals(exitcode, None)
 | 
					        self.assertEquals(exitcode, None)
 | 
				
			||||||
        self.assertEquals(self.cmd_version(repos_path),0)
 | 
					        self.assertEquals(self.cmd_version(repos_path),0)
 | 
				
			||||||
        self.assertEquals(self.cmd_db_version(self.url,repos_path),0)  # version did not get bumped yet because new version not yet created
 | 
					        self.assertEquals(self.cmd_db_version(self.url,repos_path),0)  # version did not get bumped yet because new version not yet created
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path))
 | 
				
			||||||
        assert "No schema diffs" in output, output
 | 
					        assert "No schema diffs" in output, output
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s create_model' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s create_model' % (sys.executable, script_path))
 | 
				
			||||||
        output = output.replace(genmodel.HEADER.strip(), '')  # need strip b/c output_and_exitcode called strip
 | 
					        output = output.replace(genmodel.HEADER.strip(), '')  # need strip b/c output_and_exitcode called strip
 | 
				
			||||||
        assert """tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
 | 
					        assert """tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
 | 
				
			||||||
  Column('id', Integer(),  primary_key=True, nullable=False),
 | 
					  Column('id', Integer(),  primary_key=True, nullable=False),
 | 
				
			||||||
@@ -531,9 +531,9 @@ class TestShellDatabase(Shell, fixture.DB):
 | 
				
			|||||||
  Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),""" in output.strip(), output
 | 
					  Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),""" in output.strip(), output
 | 
				
			||||||
        
 | 
					        
 | 
				
			||||||
        # We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
 | 
					        # We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model' % script_path)  # intentionally omit a parameter
 | 
					        output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model' % (sys.executable, script_path))  # intentionally omit a parameter
 | 
				
			||||||
        self.assertEquals('Not enough arguments' in output, True)
 | 
					        self.assertEquals('Not enough arguments' in output, True)
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % script_path)
 | 
					        output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % (sys.executable, script_path))
 | 
				
			||||||
        self.assertEqualsIgnoreWhitespace(output,
 | 
					        self.assertEqualsIgnoreWhitespace(output,
 | 
				
			||||||
        """from sqlalchemy import *
 | 
					        """from sqlalchemy import *
 | 
				
			||||||
from migrate import *
 | 
					from migrate import *
 | 
				
			||||||
@@ -560,9 +560,9 @@ def downgrade(migrate_engine):
 | 
				
			|||||||
        self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
 | 
					        self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
 | 
				
			||||||
        upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
 | 
					        upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
 | 
				
			||||||
        open(upgrade_script_path, 'w').write(output)
 | 
					        open(upgrade_script_path, 'w').write(output)
 | 
				
			||||||
        #output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path))  # no, we already upgraded the db above
 | 
					        #output, exitcode = self.output_and_exitcode('%s %s test %s' % (sys.executable, script_path, upgrade_script_path))  # no, we already upgraded the db above
 | 
				
			||||||
        #self.assertEquals(output, "")
 | 
					        #self.assertEquals(output, "")
 | 
				
			||||||
        output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path)  # bump the db_version
 | 
					        output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path))  # bump the db_version
 | 
				
			||||||
        self.assertEquals(exitcode, None)
 | 
					        self.assertEquals(exitcode, None)
 | 
				
			||||||
        self.assertEquals(self.cmd_version(repos_path),1)
 | 
					        self.assertEquals(self.cmd_version(repos_path),1)
 | 
				
			||||||
        self.assertEquals(self.cmd_db_version(self.url,repos_path),1)
 | 
					        self.assertEquals(self.cmd_db_version(self.url,repos_path),1)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user