Merge "Accommodate immutable URL api"

This commit is contained in:
Zuul 2021-04-28 11:12:59 +00:00 committed by Gerrit Code Review
commit 432ee2f34e
4 changed files with 46 additions and 9 deletions

View File

@ -105,6 +105,14 @@ def _setup_logging(connection_debug=0):
def _extend_url_parameters(url, connection_parameters): def _extend_url_parameters(url, connection_parameters):
# TODO(zzzeek): remove hasattr() conditional when SQLAlchemy 1.4 is the
# minimum version in requirements; call update_query_string()
# unconditionally
if hasattr(url, "update_query_string"):
return url.update_query_string(connection_parameters, append=True)
# TODO(zzzeek): remove the remainder of this method when SQLAlchemy 1.4
# is the minimum version in requirements
for key, value in parse.parse_qs( for key, value in parse.parse_qs(
connection_parameters).items(): connection_parameters).items():
if key in url.query: if key in url.query:
@ -118,6 +126,8 @@ def _extend_url_parameters(url, connection_parameters):
if len(value) == 1: if len(value) == 1:
url.query[key] = value[0] url.query[key] = value[0]
return url
def _vet_url(url): def _vet_url(url):
if "+" not in url.drivername and not url.drivername.startswith("sqlite"): if "+" not in url.drivername and not url.drivername.startswith("sqlite"):
@ -153,7 +163,7 @@ def create_engine(sql_connection, sqlite_fk=False, mysql_sql_mode=None,
url = sqlalchemy.engine.url.make_url(sql_connection) url = sqlalchemy.engine.url.make_url(sql_connection)
if connection_parameters: if connection_parameters:
_extend_url_parameters(url, connection_parameters) url = _extend_url_parameters(url, connection_parameters)
_vet_url(url) _vet_url(url)

View File

@ -495,7 +495,16 @@ class BackendImpl(object, metaclass=abc.ABCMeta):
""" """
url = sa_url.make_url(str(base_url)) url = sa_url.make_url(str(base_url))
url.database = ident
# TODO(zzzeek): remove hasattr() conditional in favor of "url.set()"
# when SQLAlchemy 1.4 is the minimum version in requirements
if hasattr(url, "set"):
url = url.set(database=ident)
else:
# TODO(zzzeek): remove when SQLAlchemy 1.4
# is the minimum version in requirements
url.database = ident
return url return url

View File

@ -386,8 +386,16 @@ class TestNonExistentDatabase(
super(TestNonExistentDatabase, self).setUp() super(TestNonExistentDatabase, self).setUp()
url = sqla_url.make_url(str(self.engine.url)) url = sqla_url.make_url(str(self.engine.url))
url.database = 'non_existent_database'
self.url = url # TODO(zzzeek): remove hasattr() conditional in favor of "url.set()"
# when SQLAlchemy 1.4 is the minimum version in requirements
if hasattr(url, "set"):
self.url = url.set(database="non_existent_database")
else:
# TODO(zzzeek): remove when SQLAlchemy 1.4
# is the minimum version in requirements
url.database = 'non_existent_database'
self.url = url
def test_raise(self): def test_raise(self):
matched = self.assertRaises( matched = self.assertRaises(

View File

@ -229,6 +229,16 @@ class QueryParamTest(test_base.DbTestCase):
"oslo_db.sqlalchemy.engines.sqlalchemy.create_engine", "oslo_db.sqlalchemy.engines.sqlalchemy.create_engine",
side_effect=_mock_create_engine) side_effect=_mock_create_engine)
def _normalize_query_dict(self, qdict):
# SQLAlchemy 1.4 returns url.query as:
# immutabledict({k1: v1, k2: (v2a, v2b, ...), ...})
# that is with tuples not lists for multiparams
return {
k: list(v) if isinstance(v, tuple) else v
for k, v in qdict.items()
}
def test_add_assorted_params(self): def test_add_assorted_params(self):
with self._fixture() as ce: with self._fixture() as ce:
engines.create_engine( engines.create_engine(
@ -236,7 +246,7 @@ class QueryParamTest(test_base.DbTestCase):
connection_parameters="foo=bar&bat=hoho&bat=param2") connection_parameters="foo=bar&bat=hoho&bat=param2")
self.assertEqual( self.assertEqual(
ce.mock_calls[0][1][0].query, self._normalize_query_dict(ce.mock_calls[0][1][0].query),
{'bat': ['hoho', 'param2'], 'foo': 'bar'} {'bat': ['hoho', 'param2'], 'foo': 'bar'}
) )
@ -247,7 +257,7 @@ class QueryParamTest(test_base.DbTestCase):
self.assertEqual( self.assertEqual(
ce.mock_calls[0][1][0].query, ce.mock_calls[0][1][0].query,
{} self._normalize_query_dict({})
) )
def test_combine_params(self): def test_combine_params(self):
@ -260,7 +270,7 @@ class QueryParamTest(test_base.DbTestCase):
"bind_host=192.168.1.5") "bind_host=192.168.1.5")
self.assertEqual( self.assertEqual(
ce.mock_calls[0][1][0].query, self._normalize_query_dict(ce.mock_calls[0][1][0].query),
{ {
'bind_host': '192.168.1.5', 'bind_host': '192.168.1.5',
'charset': 'utf8', 'charset': 'utf8',
@ -280,7 +290,7 @@ class QueryParamTest(test_base.DbTestCase):
"bind_host=192.168.1.5") "bind_host=192.168.1.5")
self.assertEqual( self.assertEqual(
ce.mock_calls[0][1][0].query, self._normalize_query_dict(ce.mock_calls[0][1][0].query),
{ {
'bind_host': '192.168.1.5', 'bind_host': '192.168.1.5',
'charset': 'utf8', 'charset': 'utf8',
@ -751,7 +761,7 @@ class CreateEngineTest(oslo_test.BaseTestCase):
def warn_interpolate(msg, args): def warn_interpolate(msg, args):
# test the interpolation itself to ensure the password # test the interpolation itself to ensure the password
# is concealed # is concealed
warnings.warning(msg % args) warnings.warning(msg % (args, ))
with mock.patch( with mock.patch(
"oslo_db.sqlalchemy.engines.LOG.warning", "oslo_db.sqlalchemy.engines.LOG.warning",