oslo.db/oslo_db/tests/sqlalchemy/test_utils.py
Jeremy Stanley 9b552046f5 Switch from MySQL-python to PyMySQL
As discussed in the Liberty Design Summit "Moving apps to Python 3"
cross-project workshop, the way forward in the near future is to
switch to the pure-python PyMySQL library as a default.

Added a special test environment to keep MySQL-python support.
Documentation modified.

https://etherpad.openstack.org/p/liberty-cross-project-python3

Change-Id: I12b32dc097a121bd43991bc38dd4d289b65e86c1
2015-06-18 15:42:23 +03:00

1255 lines
48 KiB
Python

# Copyright (c) 2013 Boris Pavlovic (boris@pavlovic.me).
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import uuid
import fixtures
import mock
from oslotest import base as test_base
from oslotest import moxstubout
import six
from six.moves.urllib import parse
import sqlalchemy
from sqlalchemy.dialects import mysql
from sqlalchemy import Boolean, Index, Integer, DateTime, String, SmallInteger
from sqlalchemy import MetaData, Table, Column, ForeignKey
from sqlalchemy.engine import reflection
from sqlalchemy.engine import url as sa_url
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import Session
from sqlalchemy.sql import select
from sqlalchemy.types import UserDefinedType, NullType
from oslo_db import exception
from oslo_db.sqlalchemy.compat import utils as compat_utils
from oslo_db.sqlalchemy import models
from oslo_db.sqlalchemy import provision
from oslo_db.sqlalchemy import session
from oslo_db.sqlalchemy import test_base as db_test_base
from oslo_db.sqlalchemy import utils
from oslo_db.tests import utils as test_utils
# Installed SQLAlchemy version tuple; some assertions below are gated on
# it because reflection of custom column types changed across releases.
SA_VERSION = compat_utils.SQLA_VERSION

# Declarative base for the FakeTable fixture model defined below.
Base = declarative_base()
class TestSanitizeDbUrl(test_base.BaseTestCase):
    """Tests for ``utils.sanitize_db_url``."""

    def test_url_with_cred(self):
        # Both the user name and the password are expected to be masked.
        sanitized = utils.sanitize_db_url(
            'myproto://johndoe:secret@localhost/myschema')
        self.assertEqual('myproto://****:****@localhost/myschema', sanitized)

    def test_url_with_no_cred(self):
        # A URL that carries no credentials is returned unchanged.
        url = 'sqlite:///mysqlitefile'
        self.assertEqual(url, utils.sanitize_db_url(url))
class CustomType(UserDefinedType):
    """Placeholder column type standing in for an unsupported type.

    Tests use it for a column whose type cannot be handled generically,
    so that callers must pass an explicit replacement Column.
    """

    # DDL type name emitted for columns declared with this type.
    _COL_SPEC = "CustomType"

    def get_col_spec(self):
        return self._COL_SPEC
class FakeTable(Base):
    """Declarative model passed as ``model`` to the pagination tests."""

    __tablename__ = 'fake_table'
    user_id = Column(String(50), primary_key=True)
    project_id = Column(String(50))
    snapshot_id = Column(String(50))

    # mox is comparing in some awkward way that
    # in this case requires the same identity of object
    _expr_to_appease_mox = project_id + snapshot_id

    @hybrid_property
    def some_hybrid(self):
        # Instance-level access is not supported; the tests only use the
        # class-level SQL expression defined below.
        raise NotImplementedError()

    @some_hybrid.expression
    def some_hybrid(cls):
        # Always hand back the same pre-built expression object, so the
        # argument recorded with mox and the one seen on replay are the
        # identical object.
        return cls._expr_to_appease_mox

    def foo(self):
        # Ordinary method, not a mapped column; the pagination tests use
        # 'foo' as a sort key that must be rejected with InvalidSortKey.
        pass
class FakeModel(object):
    """Dict-backed stand-in for a model instance.

    Both attribute access and item access read from ``values``. It is not
    a SQLAlchemy model, and is used where ``model_query`` should reject
    its argument.
    """

    def __init__(self, values):
        self.values = values

    def __getattr__(self, name):
        # Called only for names not found by normal lookup; translate a
        # missing key into the AttributeError that getattr()/hasattr()
        # expect.
        try:
            return self.values[name]
        except KeyError:
            raise AttributeError(name)

    def __getitem__(self, key):
        # Unknown keys deliberately raise NotImplementedError rather than
        # KeyError.
        if key not in self.values:
            raise NotImplementedError()
        return self.values[key]

    def __repr__(self):
        return '<FakeModel: %s>' % self.values
class TestPaginateQuery(test_base.BaseTestCase):
    """Record/replay (mox) tests for ``utils.paginate_query``.

    Each test first records the calls it expects ``paginate_query`` to
    make on the mock query and on the stubbed ``sqlalchemy.asc`` /
    ``sqlalchemy.desc``, then calls ``ReplayAll`` and runs the function.
    The order of the recording statements is therefore significant.
    """

    def setUp(self):
        super(TestPaginateQuery, self).setUp()
        mox_fixture = self.useFixture(moxstubout.MoxStubout())
        self.mox = mox_fixture.mox
        # Mock query; tests record order_by/filter/limit calls on it.
        self.query = self.mox.CreateMockAnything()
        self.mox.StubOutWithMock(sqlalchemy, 'asc')
        self.mox.StubOutWithMock(sqlalchemy, 'desc')
        # Row passed as ``marker=`` in the tests that paginate after a
        # given row.
        self.marker = FakeTable(user_id='user',
                                project_id='p',
                                snapshot_id='s')
        self.model = FakeTable

    def test_paginate_query_no_pagination_no_sort_dirs(self):
        # Without sort_dirs every key is expected to sort ascending, in
        # key order, followed by limit(5).
        sqlalchemy.asc(self.model.user_id).AndReturn('asc_3')
        self.query.order_by('asc_3').AndReturn(self.query)
        sqlalchemy.asc(self.model.project_id).AndReturn('asc_2')
        self.query.order_by('asc_2').AndReturn(self.query)
        sqlalchemy.asc(self.model.snapshot_id).AndReturn('asc_1')
        self.query.order_by('asc_1').AndReturn(self.query)
        self.query.limit(5).AndReturn(self.query)
        self.mox.ReplayAll()
        utils.paginate_query(self.query, self.model, 5,
                             ['user_id', 'project_id', 'snapshot_id'])

    def test_paginate_query_no_pagination(self):
        # sort_dirs supplies one direction per sort key.
        sqlalchemy.asc(self.model.user_id).AndReturn('asc')
        self.query.order_by('asc').AndReturn(self.query)
        sqlalchemy.desc(self.model.project_id).AndReturn('desc')
        self.query.order_by('desc').AndReturn(self.query)
        self.query.limit(5).AndReturn(self.query)
        self.mox.ReplayAll()
        utils.paginate_query(self.query, self.model, 5,
                             ['user_id', 'project_id'],
                             sort_dirs=['asc', 'desc'])

    def test_paginate_query_attribute_error(self):
        # Only the first (valid) key is expected to reach order_by before
        # the unknown second key raises InvalidSortKey.
        sqlalchemy.asc(self.model.user_id).AndReturn('asc')
        self.query.order_by('asc').AndReturn(self.query)
        self.mox.ReplayAll()
        self.assertRaises(exception.InvalidSortKey,
                          utils.paginate_query, self.query,
                          self.model, 5, ['user_id', 'non-existent key'])

    def test_paginate_query_attribute_error_invalid_sortkey(self):
        # 'bad_user_id' is not an attribute of FakeTable at all.
        self.assertRaises(exception.InvalidSortKey,
                          utils.paginate_query, self.query,
                          self.model, 5, ['bad_user_id'])

    def test_paginate_query_attribute_error_invalid_sortkey_2(self):
        # 'foo' exists on FakeTable but is a plain method, not a column.
        self.assertRaises(exception.InvalidSortKey,
                          utils.paginate_query, self.query,
                          self.model, 5, ['foo'])

    def test_paginate_query_assertion_error(self):
        # Passing both sort_dir and sort_dirs is expected to be rejected
        # with an AssertionError.
        self.mox.ReplayAll()
        self.assertRaises(AssertionError,
                          utils.paginate_query, self.query,
                          self.model, 5, ['user_id'],
                          marker=self.marker,
                          sort_dir='asc', sort_dirs=['asc'])

    def test_paginate_query_assertion_error_2(self):
        # One sort key but two sort_dirs: expected AssertionError.
        # NOTE(review): 'desk' looks like a typo for 'desc'; presumably it
        # does not matter because the length mismatch is what is being
        # tested. Confirm before relying on it.
        self.mox.ReplayAll()
        self.assertRaises(AssertionError,
                          utils.paginate_query, self.query,
                          self.model, 5, ['user_id'],
                          marker=self.marker,
                          sort_dir=None, sort_dirs=['asc', 'desk'])

    def test_paginate_query(self):
        # With a marker, and_ is expected to be called once per sort key
        # (with one more criterion each time), and or_ to combine those
        # results into the single criterion passed to filter().
        sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
        self.query.order_by('asc_1').AndReturn(self.query)
        sqlalchemy.desc(self.model.project_id).AndReturn('desc_1')
        self.query.order_by('desc_1').AndReturn(self.query)
        self.mox.StubOutWithMock(sqlalchemy.sql, 'and_')
        sqlalchemy.sql.and_(mock.ANY).AndReturn('some_crit')
        sqlalchemy.sql.and_(mock.ANY, mock.ANY).AndReturn('another_crit')
        self.mox.StubOutWithMock(sqlalchemy.sql, 'or_')
        sqlalchemy.sql.or_('some_crit', 'another_crit').AndReturn('some_f')
        self.query.filter('some_f').AndReturn(self.query)
        self.query.limit(5).AndReturn(self.query)
        self.mox.ReplayAll()
        utils.paginate_query(self.query, self.model, 5,
                             ['user_id', 'project_id'],
                             marker=self.marker,
                             sort_dirs=['asc', 'desc'])

    def test_paginate_query_value_error(self):
        # 'mixed' is not a valid sort direction; only the first key is
        # expected to be ordered before the ValueError.
        sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
        self.query.order_by('asc_1').AndReturn(self.query)
        self.mox.ReplayAll()
        self.assertRaises(ValueError, utils.paginate_query,
                          self.query, self.model, 5, ['user_id', 'project_id'],
                          marker=self.marker, sort_dirs=['asc', 'mixed'])

    def test_paginate_on_hybrid(self):
        # A hybrid property's class-level expression is accepted as a
        # sort key like a regular column.
        sqlalchemy.asc(self.model.user_id).AndReturn('asc_1')
        self.query.order_by('asc_1').AndReturn(self.query)
        sqlalchemy.desc(self.model.some_hybrid).AndReturn('desc_1')
        self.query.order_by('desc_1').AndReturn(self.query)
        self.query.limit(5).AndReturn(self.query)
        self.mox.ReplayAll()
        utils.paginate_query(self.query, self.model, 5,
                             ['user_id', 'some_hybrid'],
                             sort_dirs=['asc', 'desc'])
class TestPaginateQueryActualSQL(test_base.BaseTestCase):
    """Compare the SQL built by paginate_query with a hand-written query."""

    def test_paginate_on_hybrid_assert_stmt(self):
        s = Session()
        q = s.query(FakeTable)
        q = utils.paginate_query(
            q, FakeTable, 5,
            ['user_id', 'some_hybrid'],
            sort_dirs=['asc', 'desc'])
        # The same statement written directly in SQLAlchemy Core: one
        # ORDER BY per sort key, in sort-key order, then the LIMIT.
        expected_core_sql = (
            select([FakeTable]).
            order_by(sqlalchemy.asc(FakeTable.user_id)).
            order_by(sqlalchemy.desc(FakeTable.some_hybrid)).
            limit(5)
        )
        # testtools convention: expected value first, observed value
        # second, so failure messages label them correctly.
        self.assertEqual(
            str(expected_core_sql.compile()),
            str(q.statement.compile())
        )
class TestMigrationUtils(db_test_base.DbTestCase):

    """Class for testing utils that are used in db migrations."""

    def setUp(self):
        super(TestMigrationUtils, self).setUp()
        self.meta = MetaData(bind=self.engine)
        self.conn = self.engine.connect()
        # Cleanups run last-in-first-out: the connection is closed first,
        # then every table created through self.meta is dropped.
        self.addCleanup(self.meta.drop_all)
        self.addCleanup(self.conn.close)

    def _populate_db_for_drop_duplicate_entries(self, engine, meta,
                                                table_name):
        """Create ``table_name`` and fill it with rows that repeat on (b, c).

        Four rows share (10, 'abcdef') and three share (30, 'aef'), while
        (20, 'aa') and (20, 'bb') occur once each.

        :returns: tuple of (the created Table, the list of inserted rows).
        """
        values = [
            {'id': 11, 'a': 3, 'b': 10, 'c': 'abcdef'},
            {'id': 12, 'a': 5, 'b': 10, 'c': 'abcdef'},
            {'id': 13, 'a': 6, 'b': 10, 'c': 'abcdef'},
            {'id': 14, 'a': 7, 'b': 10, 'c': 'abcdef'},
            {'id': 21, 'a': 1, 'b': 20, 'c': 'aa'},
            {'id': 31, 'a': 1, 'b': 20, 'c': 'bb'},
            {'id': 41, 'a': 1, 'b': 30, 'c': 'aef'},
            {'id': 42, 'a': 2, 'b': 30, 'c': 'aef'},
            {'id': 43, 'a': 3, 'b': 30, 'c': 'aef'}
        ]

        test_table = Table(table_name, meta,
                           Column('id', Integer, primary_key=True,
                                  nullable=False),
                           Column('a', Integer),
                           Column('b', Integer),
                           Column('c', String(255)),
                           Column('deleted', Integer, default=0),
                           Column('deleted_at', DateTime),
                           Column('updated_at', DateTime))

        test_table.create()
        engine.execute(test_table.insert(), values)
        return test_table, values

    @staticmethod
    def _expected_survivors(values):
        """Split ``values`` into rows kept and rows dropped by (b, c) dedup.

        For each (b, c) pair, the row with the highest id is the one
        expected to survive ``drop_old_duplicate_entries_from_table``.

        :returns: tuple of (kept rows, dropped rows).
        """
        seen = set()
        kept = []
        dropped = []
        for value in sorted(values, key=lambda x: x['id'], reverse=True):
            uniq_value = (value['b'], value['c'])
            if uniq_value in seen:
                dropped.append(value)
            else:
                seen.add(uniq_value)
                kept.append(value)
        return kept, dropped

    def test_drop_old_duplicate_entries_from_table(self):
        table_name = "__test_tmp_table__"

        test_table, values = self._populate_db_for_drop_duplicate_entries(
            self.engine, self.meta, table_name)
        utils.drop_old_duplicate_entries_from_table(
            self.engine, table_name, False, 'b', 'c')

        kept, _ = self._expected_survivors(values)
        real_ids = sorted(
            row[0] for row in
            self.engine.execute(select([test_table.c.id])).fetchall())
        self.assertEqual(sorted(value['id'] for value in kept), real_ids)

    def test_drop_dup_entries_in_file_conn(self):
        table_name = "__test_tmp_table__"
        tmp_db_file = self.create_tempfiles([['name', '']], ext='.sql')[0]
        in_file_engine = session.EngineFacade(
            'sqlite:///%s' % tmp_db_file).get_engine()
        # Release the pooled connections to the temporary database file
        # once the test is over.
        self.addCleanup(in_file_engine.dispose)
        meta = MetaData()
        meta.bind = in_file_engine
        test_table, values = self._populate_db_for_drop_duplicate_entries(
            in_file_engine, meta, table_name)
        utils.drop_old_duplicate_entries_from_table(
            in_file_engine, table_name, False, 'b', 'c')

        # The deduplication must actually have taken effect on the
        # file-based database, not merely completed without an error.
        kept, _ = self._expected_survivors(values)
        real_ids = sorted(
            row[0] for row in
            in_file_engine.execute(select([test_table.c.id])).fetchall())
        self.assertEqual(sorted(value['id'] for value in kept), real_ids)

    def test_drop_old_duplicate_entries_from_table_soft_delete(self):
        table_name = "__test_tmp_table__"

        table, values = self._populate_db_for_drop_duplicate_entries(
            self.engine, self.meta, table_name)
        utils.drop_old_duplicate_entries_from_table(self.engine, table_name,
                                                    True, 'b', 'c')
        expected_values, soft_deleted_values = self._expected_survivors(
            values)

        base_select = table.select()

        # Soft-deleted rows are the ones whose 'deleted' equals their id.
        rows_select = base_select.where(table.c.deleted != table.c.id)
        row_ids = sorted(row['id'] for row in
                         self.engine.execute(rows_select).fetchall())
        self.assertEqual(sorted(value['id'] for value in expected_values),
                         row_ids)

        deleted_rows_select = base_select.where(
            table.c.deleted == table.c.id)
        deleted_rows_ids = sorted(row['id'] for row in
                                  self.engine.execute(
                                      deleted_rows_select).fetchall())
        self.assertEqual(
            sorted(value['id'] for value in soft_deleted_values),
            deleted_rows_ids)

    def test_change_deleted_column_type_does_not_drop_index(self):
        table_name = 'abc'

        indexes = {
            'idx_a_deleted': ['a', 'deleted'],
            'idx_b_deleted': ['b', 'deleted'],
            'idx_a': ['a']
        }

        index_instances = [Index(name, *columns)
                           for name, columns in six.iteritems(indexes)]

        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('a', String(255)),
                      Column('b', String(255)),
                      Column('deleted', Boolean),
                      *index_instances)
        table.create()
        utils.change_deleted_column_type_to_id_type(self.engine, table_name)
        utils.change_deleted_column_type_to_boolean(self.engine, table_name)

        insp = reflection.Inspector.from_engine(self.engine)
        real_indexes = insp.get_indexes(table_name)
        self.assertEqual(len(indexes), len(real_indexes))
        for index in real_indexes:
            name = index['name']
            self.assertIn(name, indexes)
            self.assertEqual(set(indexes[name]),
                             set(index['column_names']))

    def test_change_deleted_column_type_to_id_type_integer(self):
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('deleted', Boolean))
        table.create()
        utils.change_deleted_column_type_to_id_type(self.engine, table_name)

        table = utils.get_table(self.engine, table_name)
        self.assertIsInstance(table.c.deleted.type, Integer)

    def test_change_deleted_column_type_to_id_type_string(self):
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', String(255), primary_key=True),
                      Column('deleted', Boolean))
        table.create()
        utils.change_deleted_column_type_to_id_type(self.engine, table_name)

        table = utils.get_table(self.engine, table_name)
        self.assertIsInstance(table.c.deleted.type, String)

    @db_test_base.backend_specific('sqlite')
    def test_change_deleted_column_type_to_id_type_custom(self):
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('foo', CustomType),
                      Column('deleted', Boolean))
        table.create()

        # reflection of custom types has been fixed upstream
        if SA_VERSION < (0, 9, 0):
            self.assertRaises(exception.ColumnError,
                              utils.change_deleted_column_type_to_id_type,
                              self.engine, table_name)

        foo_column = Column('foo', CustomType())
        utils.change_deleted_column_type_to_id_type(self.engine, table_name,
                                                    foo=foo_column)

        table = utils.get_table(self.engine, table_name)
        # NOTE(boris-42): There is no way to check has foo type CustomType.
        #                 but sqlalchemy will set it to NullType. This has
        #                 been fixed upstream in recent SA versions
        if SA_VERSION < (0, 9, 0):
            self.assertIsInstance(table.c.foo.type, NullType)
        self.assertIsInstance(table.c.deleted.type, Integer)

    def test_change_deleted_column_type_to_boolean(self):
        expected_types = {'mysql': mysql.TINYINT,
                          'ibm_db_sa': SmallInteger}
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('deleted', Integer))
        table.create()

        utils.change_deleted_column_type_to_boolean(self.engine, table_name)

        table = utils.get_table(self.engine, table_name)
        self.assertIsInstance(table.c.deleted.type,
                              expected_types.get(self.engine.name, Boolean))

    def test_change_deleted_column_type_to_boolean_with_fc(self):
        expected_types = {'mysql': mysql.TINYINT,
                          'ibm_db_sa': SmallInteger}
        table_name_1 = 'abc'
        table_name_2 = 'bcd'

        table_1 = Table(table_name_1, self.meta,
                        Column('id', Integer, primary_key=True),
                        Column('deleted', Integer))
        table_1.create()

        table_2 = Table(table_name_2, self.meta,
                        Column('id', Integer, primary_key=True),
                        Column('foreign_id', Integer,
                               ForeignKey('%s.id' % table_name_1)),
                        Column('deleted', Integer))
        table_2.create()

        utils.change_deleted_column_type_to_boolean(self.engine, table_name_2)

        table = utils.get_table(self.engine, table_name_2)
        self.assertIsInstance(table.c.deleted.type,
                              expected_types.get(self.engine.name, Boolean))

    @db_test_base.backend_specific('sqlite')
    def test_change_deleted_column_type_to_boolean_type_custom(self):
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('foo', CustomType),
                      Column('deleted', Integer))
        table.create()

        # reflection of custom types has been fixed upstream
        if SA_VERSION < (0, 9, 0):
            self.assertRaises(exception.ColumnError,
                              utils.change_deleted_column_type_to_boolean,
                              self.engine, table_name)

        foo_column = Column('foo', CustomType())
        utils.change_deleted_column_type_to_boolean(self.engine, table_name,
                                                    foo=foo_column)

        table = utils.get_table(self.engine, table_name)
        # NOTE(boris-42): There is no way to check has foo type CustomType.
        #                 but sqlalchemy will set it to NullType. This has
        #                 been fixed upstream in recent SA versions
        if SA_VERSION < (0, 9, 0):
            self.assertIsInstance(table.c.foo.type, NullType)
        self.assertIsInstance(table.c.deleted.type, Boolean)

    @db_test_base.backend_specific('sqlite')
    def test_change_deleted_column_type_sqlite_drops_check_constraint(self):
        table_name = 'abc'
        table = Table(table_name, self.meta,
                      Column('id', Integer, primary_key=True),
                      Column('deleted', Boolean))
        table.create()

        utils._change_deleted_column_type_to_id_type_sqlite(self.engine,
                                                            table_name)
        # Reflect the migrated table through a fresh MetaData. The table
        # is already defined in self.meta, so Table(..., self.meta,
        # autoload=True) would just return that stale Boolean-typed
        # definition instead of reading the new schema.
        table = utils.get_table(self.engine, table_name)
        # NOTE(I159): if the CHECK constraint has been dropped (expected
        # behavior), any integer value can be inserted, otherwise only 1 or 0.
        self.engine.execute(table.insert({'deleted': 10}))

    def _create_insert_from_select_fixture(self):
        """Create and populate the tables used by the InsertFromSelect tests.

        Ten rows with random hex uuids are inserted into the source table;
        the target table starts empty.

        :returns: tuple of (target Table, select over source rows with
                  id < 5 ordered by id, list of the inserted uuids).
        """
        insert_table_name = "__test_insert_to_table__"
        select_table_name = "__test_select_from_table__"
        uuidstrs = [uuid.uuid4().hex for _ in range(10)]
        insert_table = Table(
            insert_table_name, self.meta,
            Column('id', Integer, primary_key=True,
                   nullable=False, autoincrement=True),
            Column('uuid', String(36), nullable=False))
        select_table = Table(
            select_table_name, self.meta,
            Column('id', Integer, primary_key=True,
                   nullable=False, autoincrement=True),
            Column('uuid', String(36), nullable=False))

        insert_table.create()
        select_table.create()
        # Add 10 rows to select_table
        for uuidstr in uuidstrs:
            ins_stmt = select_table.insert().values(uuid=uuidstr)
            self.conn.execute(ins_stmt)

        # Select 4 rows in one chunk from select_table
        column = select_table.c.id
        query_insert = select([select_table],
                              select_table.c.id < 5).order_by(column)
        return insert_table, query_insert, uuidstrs

    def _assert_four_rows_inserted(self, insert_table, result_insert,
                                   uuidstrs):
        """Check that exactly the 4 selected rows ended up in insert_table."""
        # Verify we insert 4 rows
        self.assertEqual(4, result_insert.rowcount)

        query_all = select([insert_table]).where(
            insert_table.c.uuid.in_(uuidstrs))
        rows = self.conn.execute(query_all).fetchall()
        # Verify we really have 4 rows in insert_table
        self.assertEqual(4, len(rows))

    def test_insert_from_select(self):
        insert_table, query_insert, uuidstrs = (
            self._create_insert_from_select_fixture())
        insert_statement = utils.InsertFromSelect(insert_table,
                                                  query_insert)
        result_insert = self.conn.execute(insert_statement)
        self._assert_four_rows_inserted(insert_table, result_insert,
                                        uuidstrs)

    def test_insert_from_select_with_specified_columns(self):
        insert_table, query_insert, uuidstrs = (
            self._create_insert_from_select_fixture())
        insert_statement = utils.InsertFromSelect(insert_table,
                                                  query_insert, ['id', 'uuid'])
        result_insert = self.conn.execute(insert_statement)
        self._assert_four_rows_inserted(insert_table, result_insert,
                                        uuidstrs)

    def test_insert_from_select_with_specified_columns_negative(self):
        insert_table, query_insert, _ = (
            self._create_insert_from_select_fixture())
        # The target column list ('uuid', 'id') is the reverse of the
        # selected column order (id, uuid), so uuid strings are directed
        # at the integer id column. The test expects this to surface as
        # a DBError.
        insert_statement = utils.InsertFromSelect(insert_table,
                                                  query_insert, ['uuid', 'id'])
        self.assertRaises(exception.DBError, self.conn.execute,
                          insert_statement)
class PostgesqlTestMigrations(TestMigrationUtils,
                              db_test_base.PostgreSQLOpportunisticTestCase):
    """Run every TestMigrationUtils test against a PostgreSQL backend.

    NOTE: the class name is misspelt ("Postgesql"); it is kept as-is so
    existing test ids stay stable.
    """
class MySQLTestMigrations(TestMigrationUtils,
                          db_test_base.MySQLOpportunisticTestCase):
    """Run every TestMigrationUtils test against a MySQL backend."""
class TestConnectionUtils(test_utils.BaseTestCase):
    """Tests for connection-string helpers and backend availability checks."""

    def setUp(self):
        super(TestConnectionUtils, self).setUp()

        self.full_credentials = {'backend': 'postgresql',
                                 'database': 'test',
                                 'user': 'dude',
                                 'passwd': 'pass'}

        self.connect_string = 'postgresql://dude:pass@localhost/test'

    def test_connect_string(self):
        connect_string = utils.get_connect_string(**self.full_credentials)
        # Expected value first, observed second (testtools convention).
        self.assertEqual(self.connect_string, connect_string)

    def test_connect_string_sqlite(self):
        sqlite_credentials = {'backend': 'sqlite', 'database': 'test.db'}
        connect_string = utils.get_connect_string(**sqlite_credentials)
        self.assertEqual('sqlite:///test.db', connect_string)

    def test_is_backend_avail(self):
        self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
        fake_connection = self.mox.CreateMockAnything()
        fake_connection.close()
        sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection)
        self.mox.ReplayAll()

        self.assertTrue(utils.is_backend_avail(**self.full_credentials))

    def test_is_backend_unavail(self):
        log = self.useFixture(fixtures.FakeLogger())
        err = OperationalError("Can't connect to database", None, None)
        error_msg = "The postgresql backend is unavailable: %s\n" % err
        self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
        sqlalchemy.engine.base.Engine.connect().AndRaise(err)
        self.mox.ReplayAll()
        self.assertFalse(utils.is_backend_avail(**self.full_credentials))
        self.assertEqual(error_msg, log.output)

    def test_ensure_backend_available(self):
        self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
        fake_connection = self.mox.CreateMockAnything()
        fake_connection.close()
        sqlalchemy.engine.base.Engine.connect().AndReturn(fake_connection)
        self.mox.ReplayAll()

        eng = provision.Backend._ensure_backend_available(self.connect_string)
        self.assertIsInstance(eng, sqlalchemy.engine.base.Engine)
        self.assertEqual(self.connect_string, str(eng.url))

    def test_ensure_backend_available_no_connection_raises(self):
        log = self.useFixture(fixtures.FakeLogger())
        err = OperationalError("Can't connect to database", None, None)
        self.mox.StubOutWithMock(sqlalchemy.engine.base.Engine, 'connect')
        sqlalchemy.engine.base.Engine.connect().AndRaise(err)
        self.mox.ReplayAll()

        exc = self.assertRaises(
            exception.BackendNotAvailable,
            provision.Backend._ensure_backend_available, self.connect_string
        )
        self.assertEqual("Could not connect", str(exc))
        self.assertEqual(
            "The postgresql backend is unavailable: %s" % err,
            log.output.strip())

    def test_ensure_backend_available_no_dbapi_raises(self):
        log = self.useFixture(fixtures.FakeLogger())
        self.mox.StubOutWithMock(sqlalchemy, 'create_engine')
        sqlalchemy.create_engine(
            sa_url.make_url(self.connect_string)).AndRaise(
            ImportError("Can't import DBAPI module foobar"))
        self.mox.ReplayAll()

        exc = self.assertRaises(
            exception.BackendNotAvailable,
            provision.Backend._ensure_backend_available, self.connect_string
        )
        self.assertEqual("No DBAPI installed", str(exc))
        self.assertEqual(
            "The postgresql backend is unavailable: Can't import "
            "DBAPI module foobar", log.output.strip())

    def test_get_db_connection_info(self):
        conn_pieces = parse.urlparse(self.connect_string)
        self.assertEqual(('dude', 'pass', 'test', 'localhost'),
                         utils.get_db_connection_info(conn_pieces))

    def test_connect_string_host(self):
        self.full_credentials['host'] = 'myhost'
        connect_string = utils.get_connect_string(**self.full_credentials)
        self.assertEqual(connect_string, 'postgresql://dude:pass@myhost/test')
class MyModelSoftDeletedProjectId(declarative_base(), models.ModelBase,
                                  models.SoftDeleteMixin):
    """Soft-delete model that also has a ``project_id`` column.

    Used by the model_query project-filtering tests.
    """

    __tablename__ = 'soft_deleted_project_id_test_model'

    id = Column(Integer, primary_key=True)
    project_id = Column(Integer)
class MyModel(declarative_base(), models.ModelBase):
    """Plain model without SoftDeleteMixin (so it has no deleted column)."""

    __tablename__ = 'test_model'

    id = Column(Integer, primary_key=True)
class MyModelSoftDeleted(declarative_base(), models.ModelBase,
                         models.SoftDeleteMixin):
    """Soft-delete model without a ``project_id`` column."""

    __tablename__ = 'soft_deleted_test_model'

    id = Column(Integer, primary_key=True)
class TestModelQuery(test_base.BaseTestCase):
    """Tests for ``utils.model_query`` against a mocked session."""

    def setUp(self):
        super(TestModelQuery, self).setUp()

        # query() and query.filter() both return the same mock, so the
        # criteria passed to filter() can be inspected via call_args.
        self.session = mock.MagicMock()
        self.session.query.return_value = self.session.query
        self.session.query.filter.return_value = self.session.query

    def test_wrong_model(self):
        self.assertRaises(TypeError, utils.model_query,
                          FakeModel, session=self.session)

    def test_no_soft_deleted(self):
        self.assertRaises(ValueError, utils.model_query,
                          MyModel, session=self.session, deleted=True)

    def test_deleted_false(self):
        mock_query = utils.model_query(
            MyModelSoftDeleted, session=self.session, deleted=False)

        deleted_filter = mock_query.filter.call_args[0][0]
        default_deleted = MyModelSoftDeleted.__mapper__.c.deleted.default.arg
        # Expected value first, observed second (testtools convention).
        self.assertEqual('soft_deleted_test_model.deleted = :deleted_1',
                         str(deleted_filter))
        self.assertEqual(default_deleted, deleted_filter.right.value)

    def test_deleted_true(self):
        mock_query = utils.model_query(
            MyModelSoftDeleted, session=self.session, deleted=True)

        deleted_filter = mock_query.filter.call_args[0][0]
        default_deleted = MyModelSoftDeleted.__mapper__.c.deleted.default.arg
        self.assertEqual('soft_deleted_test_model.deleted != :deleted_1',
                         str(deleted_filter))
        self.assertEqual(default_deleted, deleted_filter.right.value)

    @mock.patch.object(utils, "_read_deleted_filter")
    def test_no_deleted_value(self, _read_deleted_filter):
        utils.model_query(MyModelSoftDeleted, session=self.session)
        self.assertEqual(0, _read_deleted_filter.call_count)

    def test_project_filter(self):
        project_id = 10

        mock_query = utils.model_query(
            MyModelSoftDeletedProjectId, session=self.session,
            project_only=True, project_id=project_id)

        deleted_filter = mock_query.filter.call_args[0][0]
        self.assertEqual(
            'soft_deleted_project_id_test_model.project_id = :project_id_1',
            str(deleted_filter))
        self.assertEqual(project_id, deleted_filter.right.value)

    def test_project_filter_wrong_model(self):
        self.assertRaises(ValueError, utils.model_query,
                          MyModelSoftDeleted, session=self.session,
                          project_id=10)

    def test_project_filter_allow_none(self):
        mock_query = utils.model_query(
            MyModelSoftDeletedProjectId,
            session=self.session, project_id=(10, None))

        self.assertEqual(
            'soft_deleted_project_id_test_model.project_id'
            ' IN (:project_id_1, NULL)',
            str(mock_query.filter.call_args[0][0])
        )

    def test_model_query_common(self):
        utils.model_query(MyModel, args=(MyModel.id,), session=self.session)
        self.session.query.assert_called_with(MyModel.id)
class TestUtils(db_test_base.DbTestCase):
    """Tests for the index/column helpers in ``utils``."""

    def setUp(self):
        super(TestUtils, self).setUp()
        meta = MetaData(bind=self.engine)
        self.test_table = Table(
            'test_table',
            meta,
            Column('a', Integer),
            Column('b', Integer)
        )
        self.test_table.create()
        self.addCleanup(meta.drop_all)

    def test_index_exists(self):
        self.assertFalse(utils.index_exists(self.engine, 'test_table',
                                            'new_index'))
        Index('new_index', self.test_table.c.a).create(self.engine)
        self.assertTrue(utils.index_exists(self.engine, 'test_table',
                                           'new_index'))

    def test_add_index(self):
        self.assertFalse(utils.index_exists(self.engine, 'test_table',
                                            'new_index'))
        utils.add_index(self.engine, 'test_table', 'new_index', ('a',))
        self.assertTrue(utils.index_exists(self.engine, 'test_table',
                                           'new_index'))

    def test_add_existing_index(self):
        Index('new_index', self.test_table.c.a).create(self.engine)
        self.assertRaises(ValueError, utils.add_index, self.engine,
                          'test_table', 'new_index', ('a',))

    def test_drop_index(self):
        Index('new_index', self.test_table.c.a).create(self.engine)
        utils.drop_index(self.engine, 'test_table', 'new_index')
        self.assertFalse(utils.index_exists(self.engine, 'test_table',
                                            'new_index'))

    def test_drop_unexisting_index(self):
        self.assertRaises(ValueError, utils.drop_index, self.engine,
                          'test_table', 'new_index')

    @mock.patch('oslo_db.sqlalchemy.utils.drop_index')
    @mock.patch('oslo_db.sqlalchemy.utils.add_index')
    def test_change_index_columns(self, add_index, drop_index):
        # Patch decorators apply bottom-up: the add_index mock is injected
        # first and the drop_index mock second. Assert on the injected
        # mocks directly instead of looking them up again on the module.
        utils.change_index_columns(self.engine, 'test_table', 'a_index',
                                   ('a',))
        drop_index.assert_called_once_with(self.engine, 'test_table',
                                           'a_index')
        add_index.assert_called_once_with(self.engine, 'test_table',
                                          'a_index', ('a',))

    def test_column_exists(self):
        for col in ['a', 'b']:
            self.assertTrue(utils.column_exists(self.engine, 'test_table',
                                                col))
        self.assertFalse(utils.column_exists(self.engine, 'test_table',
                                             'fake_column'))
class TestUtilsMysqlOpportunistically(
        TestUtils, db_test_base.MySQLOpportunisticTestCase):
    """Run every TestUtils test against a MySQL backend."""
class TestUtilsPostgresqlOpportunistically(
        TestUtils, db_test_base.PostgreSQLOpportunisticTestCase):
    """Run every TestUtils test against a PostgreSQL backend."""
class TestDialectFunctionDispatcher(test_base.BaseTestCase):
def _single_fixture(self):
callable_fn = mock.Mock()
dispatcher = orig = utils.dispatch_for_dialect("*")(
callable_fn.default)
dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
dispatcher = dispatcher.dispatch_for("mysql+pymysql")(
callable_fn.mysql_pymysql)
dispatcher = dispatcher.dispatch_for("mysql")(
callable_fn.mysql)
dispatcher = dispatcher.dispatch_for("postgresql")(
callable_fn.postgresql)
self.assertTrue(dispatcher is orig)
return dispatcher, callable_fn
def _multiple_fixture(self):
callable_fn = mock.Mock()
for targ in [
callable_fn.default,
callable_fn.sqlite,
callable_fn.mysql,
callable_fn.mysql_pymysql,
callable_fn.postgresql,
callable_fn.postgresql_psycopg2,
callable_fn.pyodbc
]:
targ.return_value = None
dispatcher = orig = utils.dispatch_for_dialect("*", multiple=True)(
callable_fn.default)
dispatcher = dispatcher.dispatch_for("sqlite")(callable_fn.sqlite)
dispatcher = dispatcher.dispatch_for("mysql+pymysql")(
callable_fn.mysql_pymysql)
dispatcher = dispatcher.dispatch_for("mysql")(
callable_fn.mysql)
dispatcher = dispatcher.dispatch_for("postgresql+*")(
callable_fn.postgresql)
dispatcher = dispatcher.dispatch_for("postgresql+psycopg2")(
callable_fn.postgresql_psycopg2)
dispatcher = dispatcher.dispatch_for("*+pyodbc")(
callable_fn.pyodbc)
self.assertTrue(dispatcher is orig)
return dispatcher, callable_fn
def test_single(self):
dispatcher, callable_fn = self._single_fixture()
dispatcher("sqlite://", 1)
dispatcher("postgresql+psycopg2://u:p@h/t", 2)
dispatcher("mysql+pymysql://u:p@h/t", 3)
dispatcher("mysql://u:p@h/t", 4)
dispatcher("mysql+mysqlconnector://u:p@h/t", 5)
self.assertEqual(
[
mock.call.sqlite('sqlite://', 1),
mock.call.postgresql("postgresql+psycopg2://u:p@h/t", 2),
mock.call.mysql_pymysql("mysql+pymysql://u:p@h/t", 3),
mock.call.mysql("mysql://u:p@h/t", 4),
mock.call.mysql("mysql+mysqlconnector://u:p@h/t", 5),
],
callable_fn.mock_calls)
def test_single_kwarg(self):
dispatcher, callable_fn = self._single_fixture()
dispatcher("sqlite://", foo='bar')
dispatcher("postgresql+psycopg2://u:p@h/t", 1, x='y')
self.assertEqual(
[
mock.call.sqlite('sqlite://', foo='bar'),
mock.call.postgresql(
"postgresql+psycopg2://u:p@h/t",
1, x='y'),
],
callable_fn.mock_calls)
def test_dispatch_on_target(self):
callable_fn = mock.Mock()
@utils.dispatch_for_dialect("*")
def default_fn(url, x, y):
callable_fn.default(url, x, y)
@default_fn.dispatch_for("sqlite")
def sqlite_fn(url, x, y):
callable_fn.sqlite(url, x, y)
default_fn.dispatch_on_drivername("*")(url, x, y)
default_fn("sqlite://", 4, 5)
self.assertEqual(
[
mock.call.sqlite("sqlite://", 4, 5),
mock.call.default("sqlite://", 4, 5)
],
callable_fn.mock_calls
)
def test_single_no_dispatcher(self):
    """Single dispatch with no "*" target raises for unmatched dialects."""
    recorder = mock.Mock()
    dispatcher = utils.dispatch_for_dialect("sqlite")(recorder.sqlite)
    dispatcher = dispatcher.dispatch_for("mysql")(recorder.mysql)

    exc = self.assertRaises(
        ValueError, dispatcher, "postgresql://s:t@localhost/test")

    # The message names the resolved default driver, not the bare URL.
    expected_msg = (
        "No default function found for driver: 'postgresql+psycopg2'")
    self.assertEqual(expected_msg, str(exc))
def test_multiple_no_dispatcher(self):
    """Multiple dispatch with no matching target calls nothing, no error."""
    recorder = mock.Mock()
    dispatcher = utils.dispatch_for_dialect(
        "sqlite", multiple=True)(recorder.sqlite)
    dispatcher = dispatcher.dispatch_for("mysql")(recorder.mysql)

    dispatcher("postgresql://s:t@localhost/test")

    self.assertEqual([], recorder.mock_calls)
def test_multiple_no_driver(self):
    """dispatch_on_drivername() routes a non-URL argument by name."""
    # Targets return None: multiple=True dispatch rejects return values.
    recorder = mock.Mock(
        default=mock.Mock(return_value=None),
        sqlite=mock.Mock(return_value=None),
    )
    dispatcher = utils.dispatch_for_dialect(
        "*", multiple=True)(recorder.default)
    dispatcher = dispatcher.dispatch_for("sqlite")(recorder.sqlite)

    dispatcher.dispatch_on_drivername("sqlite")("foo")

    expected = [mock.call.sqlite("foo"), mock.call.default("foo")]
    self.assertEqual(expected, recorder.mock_calls)
def test_multiple_nesting(self):
    """Stacked dispatch_for() calls register one callable for both drivers."""
    # Targets return None: multiple=True dispatch rejects return values.
    recorder = mock.Mock(
        default=mock.Mock(return_value=None),
        mysql=mock.Mock(return_value=None),
    )
    dispatcher = utils.dispatch_for_dialect(
        "*", multiple=True)(recorder.default)
    # The inner registration's result is fed to the outer one; the
    # expected calls below show recorder.mysql serves both drivers.
    dispatcher = dispatcher.dispatch_for("mysql+mysqlconnector")(
        dispatcher.dispatch_for("mysql+mysqldb")(recorder.mysql))

    make_url = sqlalchemy.engine.url.make_url
    mysqldb_url = make_url("mysql+mysqldb://")
    mysqlconnector_url = make_url("mysql+mysqlconnector://")
    sqlite_url = make_url("sqlite://")

    targets = (mysqldb_url, mysqlconnector_url, sqlite_url)
    for arg, url in enumerate(targets, start=1):
        dispatcher(url, arg)

    expected = [
        mock.call.mysql(mysqldb_url, 1),
        mock.call.default(mysqldb_url, 1),
        mock.call.mysql(mysqlconnector_url, 2),
        mock.call.default(mysqlconnector_url, 2),
        mock.call.default(sqlite_url, 3),
    ]
    self.assertEqual(expected, recorder.mock_calls)
def test_single_retval(self):
    """Single dispatch returns whatever the selected target returns."""
    dispatcher, callable_fn = self._single_fixture()
    callable_fn.mysql_pymysql.return_value = 5

    result = dispatcher("mysql+pymysql://u:p@h/t", 3)

    # testtools' assertEqual is (expected, observed); keep that order so
    # failure output labels the two sides correctly, as elsewhere here.
    self.assertEqual(5, result)
    # Confirm the value really came from the pymysql-specific target.
    self.assertEqual(
        [mock.call.mysql_pymysql("mysql+pymysql://u:p@h/t", 3)],
        callable_fn.mock_calls)
def test_engine(self):
    """An Engine target is dispatched by its dialect and passed through."""
    engine = sqlalchemy.create_engine("sqlite:///path/to/my/db.db")
    dispatcher, callable_fn = self._single_fixture()

    dispatcher(engine)

    expected = [mock.call.sqlite(engine)]
    self.assertEqual(expected, callable_fn.mock_calls)
def test_url_pymysql(self):
    """A parsed mysql+pymysql URL selects the pymysql-specific target."""
    target_url = sqlalchemy.engine.url.make_url(
        "mysql+pymysql://scott:tiger@localhost/test")
    dispatcher, callable_fn = self._single_fixture()

    dispatcher(target_url, 15)

    expected = [mock.call.mysql_pymysql(target_url, 15)]
    self.assertEqual(expected, callable_fn.mock_calls)
def test_url_mysql_generic(self):
    """A parsed driverless mysql URL selects the generic mysql target."""
    target_url = sqlalchemy.engine.url.make_url(
        "mysql://scott:tiger@localhost/test")
    dispatcher, callable_fn = self._single_fixture()

    dispatcher(target_url, 15)

    expected = [mock.call.mysql(target_url, 15)]
    self.assertEqual(expected, callable_fn.mock_calls)
def test_invalid_target(self):
    """A target that is neither URL nor Engine raises ValueError."""
    dispatcher, _unused_fn = self._single_fixture()

    exc = self.assertRaises(ValueError, dispatcher, 20)

    self.assertEqual("Invalid target type: 20", str(exc))
def test_invalid_dispatch(self):
    """A dispatch expression with no database name is rejected."""
    recorder = mock.Mock()
    dispatcher = utils.dispatch_for_dialect("*")(recorder.default)

    # dispatch_for() itself does not raise here; the error surfaces
    # only when the returned decorator is applied to a function.
    register = dispatcher.dispatch_for("+pyodbc")
    exc = self.assertRaises(ValueError, register, recorder.pyodbc)

    expected_msg = "Couldn't parse database[+driver]: '+pyodbc'"
    self.assertEqual(expected_msg, str(exc))
def test_single_only_one_target(self):
    """Single dispatch refuses a second function for the same expression."""
    recorder = mock.Mock()
    dispatcher = utils.dispatch_for_dialect("*")(recorder.default)
    dispatcher = dispatcher.dispatch_for("sqlite")(recorder.sqlite)

    register_again = dispatcher.dispatch_for("sqlite")
    exc = self.assertRaises(TypeError, register_again, recorder.sqlite2)

    expected_msg = "Multiple functions for expression 'sqlite'"
    self.assertEqual(expected_msg, str(exc))
def test_multiple(self):
    """Multiple dispatch calls every matching target, in a fixed order."""
    dispatcher, callable_fn = self._multiple_fixture()
    urls = [
        "postgresql+pyodbc://",
        "mysql+pymysql://",
        "ibm_db_sa+db2://",
        "postgresql+psycopg2://",
        "postgresql://",
    ]
    for number, url in enumerate(urls, start=1):
        dispatcher(url, number)

    # TODO(zzzeek): there is a deterministic order here, but we might
    # want to tweak it, or maybe provide options. default first?
    # most specific first? is *+pyodbc or postgresql+* more specific?
    expected = [
        mock.call.postgresql('postgresql+pyodbc://', 1),
        mock.call.pyodbc('postgresql+pyodbc://', 1),
        mock.call.default('postgresql+pyodbc://', 1),

        mock.call.mysql_pymysql('mysql+pymysql://', 2),
        mock.call.mysql('mysql+pymysql://', 2),
        mock.call.default('mysql+pymysql://', 2),

        mock.call.default('ibm_db_sa+db2://', 3),

        mock.call.postgresql_psycopg2('postgresql+psycopg2://', 4),
        mock.call.postgresql('postgresql+psycopg2://', 4),
        mock.call.default('postgresql+psycopg2://', 4),

        # note this is called because we resolve the default
        # DBAPI for the url
        mock.call.postgresql_psycopg2('postgresql://', 5),
        mock.call.postgresql('postgresql://', 5),
        mock.call.default('postgresql://', 5),
    ]
    self.assertEqual(expected, callable_fn.mock_calls)
def test_multiple_no_return_value(self):
    """Multiple dispatch raises TypeError when a target returns a value."""
    dispatcher, callable_fn = self._multiple_fixture()
    callable_fn.sqlite.return_value = 5

    exc = self.assertRaises(TypeError, dispatcher, "sqlite://")

    expected_msg = "Return value not allowed for multiple filtered function"
    self.assertEqual(expected_msg, str(exc))
class TestGetInnoDBTables(db_test_base.MySQLOpportunisticTestCase):
    """Exercise utils.get_non_innodb_tables() against a MySQL engine."""

    def test_all_tables_use_innodb(self):
        ddl = ("CREATE TABLE customers "
               "(a INT, b CHAR (20), INDEX (a)) ENGINE=InnoDB")
        self.engine.execute(ddl)
        found = utils.get_non_innodb_tables(self.engine)
        self.assertEqual([], found)

    def test_all_tables_use_innodb_false(self):
        self.engine.execute("CREATE TABLE employee (i INT) ENGINE=MEMORY")
        found = utils.get_non_innodb_tables(self.engine)
        self.assertEqual(['employee'], found)

    def test_skip_tables_use_default_value(self):
        # migrate_version is excluded when skip_tables is not passed.
        ddl = "CREATE TABLE migrate_version (i INT) ENGINE=MEMORY"
        self.engine.execute(ddl)
        found = utils.get_non_innodb_tables(self.engine)
        self.assertEqual([], found)

    def test_skip_tables_use_passed_value(self):
        self.engine.execute("CREATE TABLE some_table (i INT) ENGINE=MEMORY")
        found = utils.get_non_innodb_tables(
            self.engine, skip_tables=('some_table',))
        self.assertEqual([], found)

    def test_skip_tables_use_empty_list(self):
        # An empty skip list means every non-InnoDB table is reported.
        ddl = "CREATE TABLE some_table_3 (i INT) ENGINE=MEMORY"
        self.engine.execute(ddl)
        found = utils.get_non_innodb_tables(self.engine, skip_tables=())
        self.assertEqual(['some_table_3'], found)

    def test_skip_tables_use_several_values(self):
        statements = (
            "CREATE TABLE some_table_1 (i INT) ENGINE=MEMORY",
            "CREATE TABLE some_table_2 (i INT) ENGINE=MEMORY",
        )
        for ddl in statements:
            self.engine.execute(ddl)
        found = utils.get_non_innodb_tables(
            self.engine, skip_tables=('some_table_1', 'some_table_2'))
        self.assertEqual([], found)