# coding=utf-8

# Copyright (c) 2012 Rackspace Hosting
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

"""Unit tests for SQLAlchemy specific code."""

import logging

import fixtures
import mock
from oslo_config import cfg
from oslotest import base as oslo_test
import six
import sqlalchemy
from sqlalchemy import Column, MetaData, Table
from sqlalchemy.engine import url
from sqlalchemy import Integer, String
from sqlalchemy.ext.declarative import declarative_base

from oslo_db import exception
from oslo_db import options as db_options
from oslo_db.sqlalchemy import engines
from oslo_db.sqlalchemy import models
from oslo_db.sqlalchemy import session
from oslo_db.sqlalchemy import test_base


BASE = declarative_base()
_TABLE_NAME = '__tmp__test__tmp__'

_REGEXP_TABLE_NAME = _TABLE_NAME + "regexp"


class RegexpTable(BASE, models.ModelBase):
    __tablename__ = _REGEXP_TABLE_NAME
    id = Column(Integer, primary_key=True)
    bar = Column(String(255))


class RegexpFilterTestCase(test_base.DbTestCase):

    def setUp(self):
        super(RegexpFilterTestCase, self).setUp()
        meta = MetaData()
        meta.bind = self.engine
        test_table = Table(_REGEXP_TABLE_NAME, meta,
                           Column('id', Integer, primary_key=True,
                                  nullable=False),
                           Column('bar', String(255)))
        test_table.create()
        self.addCleanup(test_table.drop)

    def _test_regexp_filter(self, regexp, expected):
        _session = self.sessionmaker()
        with _session.begin():
            for i in ['10', '20', u'♥']:
                tbl = RegexpTable()
                tbl.update({'bar': i})
                tbl.save(session=_session)

        regexp_op = RegexpTable.bar.op('REGEXP')(regexp)
        result = _session.query(RegexpTable).filter(regexp_op).all()
        self.assertEqual([r.bar for r in result], expected)

    def test_regexp_filter(self):
        self._test_regexp_filter('10', ['10'])

    def test_regexp_filter_nomatch(self):
        self._test_regexp_filter('11', [])

    def test_regexp_filter_unicode(self):
        self._test_regexp_filter(u'♥', [u'♥'])

    def test_regexp_filter_unicode_nomatch(self):
        self._test_regexp_filter(u'♦', [])


class SQLiteSavepointTest(test_base.DbTestCase):
    def setUp(self):
        super(SQLiteSavepointTest, self).setUp()
        meta = MetaData()
        self.test_table = Table(
            "test_table", meta,
            Column('id', Integer, primary_key=True),
            Column('data', String(10)))
        self.test_table.create(self.engine)
        self.addCleanup(self.test_table.drop, self.engine)

    def test_plain_transaction(self):
        conn = self.engine.connect()
        trans = conn.begin()
        conn.execute(
            self.test_table.insert(),
            {'data': 'data 1'}
        )
        self.assertEqual(
            [(1, 'data 1')],
            self.engine.execute(
                self.test_table.select().
                order_by(self.test_table.c.id)
            ).fetchall()
        )
        trans.rollback()
        self.assertEqual(
            0,
            self.engine.scalar(self.test_table.count())
        )

    def test_savepoint_middle(self):
        with self.engine.begin() as conn:
            conn.execute(
                self.test_table.insert(),
                {'data': 'data 1'}
            )

            savepoint = conn.begin_nested()
            conn.execute(
                self.test_table.insert(),
                {'data': 'data 2'}
            )
            savepoint.rollback()

            conn.execute(
                self.test_table.insert(),
                {'data': 'data 3'}
            )

        self.assertEqual(
            [(1, 'data 1'), (2, 'data 3')],
            self.engine.execute(
                self.test_table.select().
                order_by(self.test_table.c.id)
            ).fetchall()
        )

    def test_savepoint_beginning(self):
        with self.engine.begin() as conn:
            savepoint = conn.begin_nested()
            conn.execute(
                self.test_table.insert(),
                {'data': 'data 1'}
            )
            savepoint.rollback()

            conn.execute(
                self.test_table.insert(),
                {'data': 'data 2'}
            )

        self.assertEqual(
            [(1, 'data 2')],
            self.engine.execute(
                self.test_table.select().
                order_by(self.test_table.c.id)
            ).fetchall()
        )


class FakeDBAPIConnection(object):
    def cursor(self):
        return FakeCursor()


class FakeCursor(object):
    def execute(self, sql):
        pass


class FakeConnectionProxy(object):
    pass


class FakeConnectionRec(object):
    pass


class OperationalError(Exception):
    pass


class ProgrammingError(Exception):
    pass


class FakeDB2Engine(object):

    class Dialect(object):

        def is_disconnect(self, e, *args):
            expected_error = ('SQL30081N: DB2 Server connection is no longer '
                              'active')
            return (str(e) == expected_error)

    dialect = Dialect()
    name = 'ibm_db_sa'

    def dispose(self):
        pass


class MySQLDefaultModeTestCase(test_base.MySQLOpportunisticTestCase):
    def test_default_is_traditional(self):
        with self.engine.connect() as conn:
            sql_mode = conn.execute(
                "SHOW VARIABLES LIKE 'sql_mode'"
            ).first()[1]

        self.assertTrue("TRADITIONAL" in sql_mode)


class MySQLModeTestCase(test_base.MySQLOpportunisticTestCase):

    def __init__(self, *args, **kwargs):
        super(MySQLModeTestCase, self).__init__(*args, **kwargs)
        # By default, run in empty SQL mode.
        # Subclasses override this with specific modes.
        self.mysql_mode = ''

    def setUp(self):
        super(MySQLModeTestCase, self).setUp()
        mode_engine = session.create_engine(
            self.engine.url,
            mysql_sql_mode=self.mysql_mode)
        self.connection = mode_engine.connect()

        meta = MetaData()
        self.test_table = Table(_TABLE_NAME + "mode", meta,
                                Column('id', Integer, primary_key=True),
                                Column('bar', String(255)))
        self.test_table.create(self.connection)

        def cleanup():
            self.test_table.drop(self.connection)
            self.connection.close()
            mode_engine.dispose()
        self.addCleanup(cleanup)

    def _test_string_too_long(self, value):
        with self.connection.begin():
            self.connection.execute(self.test_table.insert(),
                                    bar=value)
            result = self.connection.execute(self.test_table.select())
            return result.fetchone()['bar']

    def test_string_too_long(self):
        value = 'a' * 512
        # String is too long.
        # With no SQL mode set, this gets truncated.
        self.assertNotEqual(value,
                            self._test_string_too_long(value))


class MySQLStrictAllTablesModeTestCase(MySQLModeTestCase):
    "Test data integrity enforcement in MySQL STRICT_ALL_TABLES mode."

    def __init__(self, *args, **kwargs):
        super(MySQLStrictAllTablesModeTestCase, self).__init__(*args, **kwargs)
        self.mysql_mode = 'STRICT_ALL_TABLES'

    def test_string_too_long(self):
        value = 'a' * 512
        # String is too long.
        # With STRICT_ALL_TABLES or TRADITIONAL mode set, this is an error.
        self.assertRaises(exception.DBError,
                          self._test_string_too_long, value)


class MySQLTraditionalModeTestCase(MySQLStrictAllTablesModeTestCase):
    """Test data integrity enforcement in MySQL TRADITIONAL mode.

    Since TRADITIONAL includes STRICT_ALL_TABLES, this inherits all
    STRICT_ALL_TABLES mode tests.
    """

    def __init__(self, *args, **kwargs):
        super(MySQLTraditionalModeTestCase, self).__init__(*args, **kwargs)
        self.mysql_mode = 'TRADITIONAL'


class EngineFacadeTestCase(oslo_test.BaseTestCase):
    def setUp(self):
        super(EngineFacadeTestCase, self).setUp()

        self.facade = session.EngineFacade('sqlite://')

    def test_get_engine(self):
        eng1 = self.facade.get_engine()
        eng2 = self.facade.get_engine()

        self.assertIs(eng1, eng2)

    def test_get_session(self):
        ses1 = self.facade.get_session()
        ses2 = self.facade.get_session()

        self.assertIsNot(ses1, ses2)

    def test_get_session_arguments_override_default_settings(self):
        ses = self.facade.get_session(autocommit=False, expire_on_commit=True)

        self.assertFalse(ses.autocommit)
        self.assertTrue(ses.expire_on_commit)

    @mock.patch('oslo_db.sqlalchemy.orm.get_maker')
    @mock.patch('oslo_db.sqlalchemy.engines.create_engine')
    def test_creation_from_config(self, create_engine, get_maker):
        conf = cfg.ConfigOpts()
        conf.register_opts(db_options.database_opts, group='database')

        overrides = {
            'connection': 'sqlite:///:memory:',
            'slave_connection': None,
            'connection_debug': 100,
            'max_pool_size': 10,
            'mysql_sql_mode': 'TRADITIONAL',
        }
        for optname, optvalue in overrides.items():
            conf.set_override(optname, optvalue, group='database')

        session.EngineFacade.from_config(conf,
                                         autocommit=False,
                                         expire_on_commit=True)

        create_engine.assert_called_once_with(
            sql_connection='sqlite:///:memory:',
            connection_debug=100,
            max_pool_size=10,
            mysql_sql_mode='TRADITIONAL',
            sqlite_fk=False,
            idle_timeout=mock.ANY,
            retry_interval=mock.ANY,
            max_retries=mock.ANY,
            max_overflow=mock.ANY,
            connection_trace=mock.ANY,
            sqlite_synchronous=mock.ANY,
            pool_timeout=mock.ANY,
            thread_checkin=mock.ANY,
        )
        get_maker.assert_called_once_with(engine=create_engine(),
                                          autocommit=False,
                                          expire_on_commit=True)

    def test_slave_connection(self):
        paths = self.create_tempfiles([('db.master', ''), ('db.slave', '')],
                                      ext='')
        master_path = 'sqlite:///' + paths[0]
        slave_path = 'sqlite:///' + paths[1]

        facade = session.EngineFacade(
            sql_connection=master_path,
            slave_connection=slave_path
        )

        master = facade.get_engine()
        self.assertEqual(master_path, str(master.url))
        slave = facade.get_engine(use_slave=True)
        self.assertEqual(slave_path, str(slave.url))

        master_session = facade.get_session()
        self.assertEqual(master_path, str(master_session.bind.url))
        slave_session = facade.get_session(use_slave=True)
        self.assertEqual(slave_path, str(slave_session.bind.url))

    def test_slave_connection_string_not_provided(self):
        master_path = 'sqlite:///' + self.create_tempfiles(
            [('db.master', '')], ext='')[0]

        facade = session.EngineFacade(sql_connection=master_path)

        master = facade.get_engine()
        slave = facade.get_engine(use_slave=True)
        self.assertIs(master, slave)
        self.assertEqual(master_path, str(master.url))

        master_session = facade.get_session()
        self.assertEqual(master_path, str(master_session.bind.url))
        slave_session = facade.get_session(use_slave=True)
        self.assertEqual(master_path, str(slave_session.bind.url))


class SQLiteConnectTest(oslo_test.BaseTestCase):

    def _fixture(self, **kw):
        return session.create_engine("sqlite://", **kw)

    def test_sqlite_fk_listener(self):
        engine = self._fixture(sqlite_fk=True)
        self.assertEqual(
            engine.scalar("pragma foreign_keys"),
            1
        )

        engine = self._fixture(sqlite_fk=False)

        self.assertEqual(
            engine.scalar("pragma foreign_keys"),
            0
        )

    def test_sqlite_synchronous_listener(self):
        engine = self._fixture()

        # "The default setting is synchronous=FULL." (e.g. 2)
        # http://www.sqlite.org/pragma.html#pragma_synchronous
        self.assertEqual(
            engine.scalar("pragma synchronous"),
            2
        )

        engine = self._fixture(sqlite_synchronous=False)

        self.assertEqual(
            engine.scalar("pragma synchronous"),
            0
        )


class MysqlConnectTest(test_base.MySQLOpportunisticTestCase):

    def _fixture(self, sql_mode):
        return session.create_engine(self.engine.url, mysql_sql_mode=sql_mode)

    def _assert_sql_mode(self, engine, sql_mode_present, sql_mode_non_present):
        mode = engine.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
        self.assertTrue(
            sql_mode_present in mode
        )
        if sql_mode_non_present:
            self.assertTrue(
                sql_mode_non_present not in mode
            )

    def test_set_mode_traditional(self):
        engine = self._fixture(sql_mode='TRADITIONAL')
        self._assert_sql_mode(engine, "TRADITIONAL", "ANSI")

    def test_set_mode_ansi(self):
        engine = self._fixture(sql_mode='ANSI')
        self._assert_sql_mode(engine, "ANSI", "TRADITIONAL")

    def test_set_mode_no_mode(self):
        # If _mysql_set_mode_callback is called with sql_mode=None, then
        # the SQL mode is NOT set on the connection.

        # get the GLOBAL sql_mode, not the @@SESSION, so that
        # we get what is configured for the MySQL database, as opposed
        # to what our own session.create_engine() has set it to.
        expected = self.engine.execute(
            "SELECT @@GLOBAL.sql_mode").scalar()

        engine = self._fixture(sql_mode=None)
        self._assert_sql_mode(engine, expected, None)

    def test_fail_detect_mode(self):
        # If "SHOW VARIABLES LIKE 'sql_mode'" results in no row, then
        # we get a log indicating can't detect the mode.

        log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))

        mysql_conn = self.engine.raw_connection()
        self.addCleanup(mysql_conn.close)
        mysql_conn.detach()
        mysql_cursor = mysql_conn.cursor()

        def execute(statement, parameters=()):
            if "SHOW VARIABLES LIKE 'sql_mode'" in statement:
                statement = "SHOW VARIABLES LIKE 'i_dont_exist'"
            return mysql_cursor.execute(statement, parameters)

        test_engine = sqlalchemy.create_engine(self.engine.url,
                                               _initialize=False)

        with mock.patch.object(
            test_engine.pool, '_creator',
            mock.Mock(
                return_value=mock.Mock(
                    cursor=mock.Mock(
                        return_value=mock.Mock(
                            execute=execute,
                            fetchone=mysql_cursor.fetchone,
                            fetchall=mysql_cursor.fetchall
                        )
                    )
                )
            )
        ):
            engines._init_events.dispatch_on_drivername("mysql")(test_engine)

            test_engine.raw_connection()
        self.assertIn('Unable to detect effective SQL mode',
                      log.output)

    def test_logs_real_mode(self):
        # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value, then
        # we get a log with the value.

        log = self.useFixture(fixtures.FakeLogger(level=logging.DEBUG))

        engine = self._fixture(sql_mode='TRADITIONAL')

        actual_mode = engine.execute(
            "SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]

        self.assertIn('MySQL server mode set to %s' % actual_mode,
                      log.output)

    def test_warning_when_not_traditional(self):
        # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that doesn't
        # include 'TRADITIONAL', then a warning is logged.

        log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
        self._fixture(sql_mode='ANSI')

        self.assertIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
                      log.output)

    def test_no_warning_when_traditional(self):
        # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes
        # 'TRADITIONAL', then no warning is logged.

        log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
        self._fixture(sql_mode='TRADITIONAL')

        self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
                         log.output)

    def test_no_warning_when_strict_all_tables(self):
        # If "SHOW VARIABLES LIKE 'sql_mode'" results in a value that includes
        # 'STRICT_ALL_TABLES', then no warning is logged.

        log = self.useFixture(fixtures.FakeLogger(level=logging.WARN))
        self._fixture(sql_mode='TRADITIONAL')

        self.assertNotIn("consider enabling TRADITIONAL or STRICT_ALL_TABLES",
                         log.output)


class CreateEngineTest(oslo_test.BaseTestCase):
    """Test that dialect-specific arguments/ listeners are set up correctly.

    """

    def setUp(self):
        super(CreateEngineTest, self).setUp()
        self.args = {'connect_args': {}}

    def test_queuepool_args(self):
        engines._init_connection_args(
            url.make_url("mysql+pymysql://u:p@host/test"), self.args,
            max_pool_size=10, max_overflow=10)
        self.assertEqual(self.args['pool_size'], 10)
        self.assertEqual(self.args['max_overflow'], 10)

    def test_sqlite_memory_pool_args(self):
        for _url in ("sqlite://", "sqlite:///:memory:"):
            engines._init_connection_args(
                url.make_url(_url), self.args,
                max_pool_size=10, max_overflow=10)

            # queuepool arguments are not peresnet
            self.assertTrue(
                'pool_size' not in self.args)
            self.assertTrue(
                'max_overflow' not in self.args)

            self.assertEqual(self.args['connect_args']['check_same_thread'],
                             False)

            # due to memory connection
            self.assertTrue('poolclass' in self.args)

    def test_sqlite_file_pool_args(self):
        engines._init_connection_args(
            url.make_url("sqlite:///somefile.db"), self.args,
            max_pool_size=10, max_overflow=10)

        # queuepool arguments are not peresnet
        self.assertTrue('pool_size' not in self.args)
        self.assertTrue(
            'max_overflow' not in self.args)

        self.assertFalse(self.args['connect_args'])

        # NullPool is the default for file based connections,
        # no need to specify this
        self.assertTrue('poolclass' not in self.args)

    def _test_mysql_connect_args_default(self, connect_args):
        if six.PY3:
            self.assertEqual(connect_args,
                             {'charset': 'utf8', 'use_unicode': 1})
        else:
            self.assertEqual(connect_args,
                             {'charset': 'utf8', 'use_unicode': 0})

    def test_mysql_connect_args_default(self):
        engines._init_connection_args(
            url.make_url("mysql://u:p@host/test"), self.args)
        self._test_mysql_connect_args_default(self.args['connect_args'])

    def test_mysql_oursql_connect_args_default(self):
        engines._init_connection_args(
            url.make_url("mysql+oursql://u:p@host/test"), self.args)
        self._test_mysql_connect_args_default(self.args['connect_args'])

    def test_mysql_pymysql_connect_args_default(self):
        engines._init_connection_args(
            url.make_url("mysql+pymysql://u:p@host/test"), self.args)
        self.assertEqual(self.args['connect_args'],
                         {'charset': 'utf8'})

    def test_mysql_mysqldb_connect_args_default(self):
        engines._init_connection_args(
            url.make_url("mysql+mysqldb://u:p@host/test"), self.args)
        self._test_mysql_connect_args_default(self.args['connect_args'])

    def test_postgresql_connect_args_default(self):
        engines._init_connection_args(
            url.make_url("postgresql://u:p@host/test"), self.args)
        self.assertEqual(self.args['client_encoding'], 'utf8')
        self.assertFalse(self.args['connect_args'])

    def test_mysqlconnector_raise_on_warnings_default(self):
        engines._init_connection_args(
            url.make_url("mysql+mysqlconnector://u:p@host/test"),
            self.args)
        self.assertEqual(self.args['connect_args']['raise_on_warnings'], False)

    def test_mysqlconnector_raise_on_warnings_override(self):
        engines._init_connection_args(
            url.make_url(
                "mysql+mysqlconnector://u:p@host/test"
                "?raise_on_warnings=true"),
            self.args
        )

        self.assertFalse('raise_on_warnings' in self.args['connect_args'])

    def test_thread_checkin(self):
        with mock.patch("sqlalchemy.event.listens_for"):
            with mock.patch("sqlalchemy.event.listen") as listen_evt:
                engines._init_events.dispatch_on_drivername(
                    "sqlite")(mock.Mock())

        self.assertEqual(
            listen_evt.mock_calls[0][1][-1],
            engines._thread_yield
        )


class ProcessGuardTest(test_base.DbTestCase):
    def test_process_guard(self):
        self.engine.dispose()

        def get_parent_pid():
            return 4

        def get_child_pid():
            return 5

        with mock.patch("os.getpid", get_parent_pid):
            with self.engine.connect() as conn:
                dbapi_id = id(conn.connection.connection)

        with mock.patch("os.getpid", get_child_pid):
            with self.engine.connect() as conn:
                new_dbapi_id = id(conn.connection.connection)

        self.assertNotEqual(dbapi_id, new_dbapi_id)

        # ensure it doesn't trip again
        with mock.patch("os.getpid", get_child_pid):
            with self.engine.connect() as conn:
                newer_dbapi_id = id(conn.connection.connection)

        self.assertEqual(new_dbapi_id, newer_dbapi_id)


class PatchStacktraceTest(test_base.DbTestCase):

    def test_trace(self):
        engine = self.engine

        # NOTE(viktors): The code in oslo_db.sqlalchemy.session filters out
        #                lines from modules under oslo_db, so we should remove
        #                 "oslo_db/" from file path in traceback.
        import traceback
        orig_extract_stack = traceback.extract_stack

        def extract_stack():
            return [(row[0].replace("oslo_db/", ""), row[1], row[2], row[3])
                    for row in orig_extract_stack()]

        with mock.patch("traceback.extract_stack", side_effect=extract_stack):

            engines._add_trace_comments(engine)
            conn = engine.connect()
            orig_do_exec = engine.dialect.do_execute
            with mock.patch.object(engine.dialect, "do_execute") as mock_exec:

                mock_exec.side_effect = orig_do_exec
                conn.execute("select 1;")

            call = mock_exec.mock_calls[0]

            # we're the caller, see that we're in there
            self.assertIn("tests/sqlalchemy/test_sqlalchemy.py", call[1][1])