diff --git a/oslo/db/sqlalchemy/session.py b/oslo/db/sqlalchemy/session.py index 57b5fb8..b24e06b 100644 --- a/oslo/db/sqlalchemy/session.py +++ b/oslo/db/sqlalchemy/session.py @@ -376,6 +376,7 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None, engine_args = { "pool_recycle": idle_timeout, 'convert_unicode': True, + 'connect_args': {}, } _setup_logging(connection_debug) @@ -433,7 +434,22 @@ def _init_connection_args(url, engine_args, **kw): # replace it with StaticPool. if issubclass(pool_class, pool.SingletonThreadPool): engine_args["poolclass"] = pool.StaticPool - engine_args["connect_args"] = {'check_same_thread': False} + engine_args['connect_args']['check_same_thread'] = False + + +@_init_connection_args.dispatch_for("postgresql") +def _init_connection_args(url, engine_args, **kw): + if 'client_encoding' not in url.query: + # Set encoding using engine_args instead of connect_args since + # it's supported for PostgreSQL 8.*. More details at: + # http://docs.sqlalchemy.org/en/rel_0_9/dialects/postgresql.html + engine_args['client_encoding'] = 'utf8' + + +@_init_connection_args.dispatch_for("mysql") +def _init_connection_args(url, engine_args, **kw): + if 'charset' not in url.query: + engine_args['connect_args']['charset'] = 'utf8' @_init_connection_args.dispatch_for("mysql+mysqlconnector") @@ -441,8 +457,18 @@ def _init_connection_args(url, engine_args, **kw): # mysqlconnector engine (<1.0) incorrectly defaults to # raise_on_warnings=True # https://bitbucket.org/zzzeek/sqlalchemy/issue/2515 - if "raise_on_warnings" not in url.query: - engine_args['connect_args'] = {'raise_on_warnings': False} + if 'raise_on_warnings' not in url.query: + engine_args['connect_args']['raise_on_warnings'] = False + + +@_init_connection_args.dispatch_for("mysql+mysqldb") +@_init_connection_args.dispatch_for("mysql+oursql") +def _init_connection_args(url, engine_args, **kw): + # Those drivers require use_unicode=0 to avoid performance drop due + # to internal usage of Python unicode objects in the driver + # http://docs.sqlalchemy.org/en/rel_0_9/dialects/mysql.html + if 'use_unicode' not in url.query: + engine_args['connect_args']['use_unicode'] = 0 @utils.dispatch_for_dialect('*', multiple=True) diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 0aeeb05..1c46399 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -536,84 +536,90 @@ class CreateEngineTest(oslo_test.BaseTestCase): """ + def setUp(self): + super(CreateEngineTest, self).setUp() + self.args = {'connect_args': {}} + def test_queuepool_args(self): - args = {} session._init_connection_args( - url.make_url("mysql://u:p@host/test"), args, + url.make_url("mysql://u:p@host/test"), self.args, max_pool_size=10, max_overflow=10) - self.assertEqual(args['pool_size'], 10) - self.assertEqual(args['max_overflow'], 10) + self.assertEqual(self.args['pool_size'], 10) + self.assertEqual(self.args['max_overflow'], 10) def test_sqlite_memory_pool_args(self): for _url in ("sqlite://", "sqlite:///:memory:"): - args = {} session._init_connection_args( - url.make_url(_url), args, + url.make_url(_url), self.args, max_pool_size=10, max_overflow=10) # queuepool arguments are not peresnet self.assertTrue( - 'pool_size' not in args) + 'pool_size' not in self.args) self.assertTrue( - 'max_overflow' not in args) + 'max_overflow' not in self.args) - self.assertEqual( - args['connect_args'], - {'check_same_thread': False} - ) + self.assertEqual(self.args['connect_args']['check_same_thread'], + False) # due to memory connection - self.assertTrue('poolclass' in args) + self.assertTrue('poolclass' in self.args) def test_sqlite_file_pool_args(self): - args = {} session._init_connection_args( - url.make_url("sqlite:///somefile.db"), args, + url.make_url("sqlite:///somefile.db"), self.args, max_pool_size=10, max_overflow=10) # queuepool arguments are not peresnet - self.assertTrue('pool_size' not in args) + self.assertTrue('pool_size' not in self.args) self.assertTrue( - 'max_overflow' not in args) + 'max_overflow' not in self.args) - self.assertTrue( - 'connect_args' not in args - ) + self.assertFalse(self.args['connect_args']) # NullPool is the default for file based connections, # no need to specify this - self.assertTrue('poolclass' not in args) + self.assertTrue('poolclass' not in self.args) def test_mysql_connect_args_default(self): - args = {} session._init_connection_args( - url.make_url("mysql+mysqldb://u:p@host/test"), args) - self.assertTrue( - 'connect_args' not in args - ) + url.make_url("mysql://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_mysql_oursql_connect_args_default(self): + session._init_connection_args( + url.make_url("mysql+oursql://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_mysql_mysqldb_connect_args_default(self): + session._init_connection_args( + url.make_url("mysql+mysqldb://u:p@host/test"), self.args) + self.assertEqual(self.args['connect_args'], + {'charset': 'utf8', 'use_unicode': 0}) + + def test_postgresql_connect_args_default(self): + session._init_connection_args( + url.make_url("postgresql://u:p@host/test"), self.args) + self.assertEqual(self.args['client_encoding'], 'utf8') + self.assertFalse(self.args['connect_args']) def test_mysqlconnector_raise_on_warnings_default(self): - args = {} session._init_connection_args( url.make_url("mysql+mysqlconnector://u:p@host/test"), - args) - self.assertEqual( - args, - {'connect_args': {'raise_on_warnings': False}} - ) + self.args) + self.assertEqual(self.args['connect_args']['raise_on_warnings'], False) def test_mysqlconnector_raise_on_warnings_override(self): - args = {} session._init_connection_args( url.make_url( "mysql+mysqlconnector://u:p@host/test" "?raise_on_warnings=true"), - args + self.args ) - self.assertTrue( - 'connect_args' not in args - ) + self.assertFalse('raise_on_warnings' in self.args['connect_args']) def test_thread_checkin(self): with mock.patch("sqlalchemy.event.listens_for"):