From d5f6dcf9ffd71366b1c3baaccacbf4cbdf05c94b Mon Sep 17 00:00:00 2001 From: Mohammed Naser Date: Wed, 28 Aug 2024 19:16:07 -0400 Subject: [PATCH] Switch to using enginefacade Closes-Bug: #2067345 Change-Id: If9a2c96628cfcb819fee5e19f872ea015979b30f (cherry picked from commit 0ce2c41404f1f8dcd1bcd19d36a885edc34926a2) --- magnum/common/context.py | 2 + magnum/db/sqlalchemy/alembic/env.py | 4 +- magnum/db/sqlalchemy/api.py | 749 ++++++++++-------- magnum/db/sqlalchemy/models.py | 9 - magnum/tests/unit/db/base.py | 26 +- magnum/tests/unit/db/sqlalchemy/test_types.py | 24 +- 6 files changed, 437 insertions(+), 377 deletions(-) diff --git a/magnum/common/context.py b/magnum/common/context.py index 7d6c4011a8..a225c2194a 100644 --- a/magnum/common/context.py +++ b/magnum/common/context.py @@ -12,6 +12,7 @@ from eventlet.green import threading from oslo_context import context +from oslo_db.sqlalchemy import enginefacade from magnum.common import policy @@ -20,6 +21,7 @@ import magnum.conf CONF = magnum.conf.CONF +@enginefacade.transaction_context_provider class RequestContext(context.RequestContext): """Extends security contexts from the OpenStack common library.""" diff --git a/magnum/db/sqlalchemy/alembic/env.py b/magnum/db/sqlalchemy/alembic/env.py index ff264b7652..e7690eee4e 100644 --- a/magnum/db/sqlalchemy/alembic/env.py +++ b/magnum/db/sqlalchemy/alembic/env.py @@ -13,8 +13,8 @@ from logging import config as log_config from alembic import context +from oslo_db.sqlalchemy import enginefacade -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import models # this is the Alembic Config object, which provides @@ -43,7 +43,7 @@ def run_migrations_online(): and associate a connection with the context. """ - engine = sqla_api.get_engine() + engine = enginefacade.writer.get_engine() with engine.connect() as connection: context.configure(connection=connection, target_metadata=target_metadata) diff --git a/magnum/db/sqlalchemy/api.py b/magnum/db/sqlalchemy/api.py index 0ec438063d..5b5f0ac451 100644 --- a/magnum/db/sqlalchemy/api.py +++ b/magnum/db/sqlalchemy/api.py @@ -14,8 +14,11 @@ """SQLAlchemy storage backend.""" +import threading + +from oslo_db import api as oslo_db_api from oslo_db import exception as db_exc -from oslo_db.sqlalchemy import session as db_session +from oslo_db.sqlalchemy import enginefacade from oslo_db.sqlalchemy import utils as db_utils from oslo_log import log from oslo_utils import importutils @@ -35,34 +38,13 @@ from magnum.db import api from magnum.db.sqlalchemy import models from magnum.i18n import _ -profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') +profiler_sqlalchemy = importutils.try_import("osprofiler.sqlalchemy") CONF = magnum.conf.CONF LOG = log.getLogger(__name__) -_FACADE = None - - -def _create_facade_lazily(): - global _FACADE - if _FACADE is None: - _FACADE = db_session.EngineFacade.from_config(CONF) - if profiler_sqlalchemy: - if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: - profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db") - - return _FACADE - - -def get_engine(): - facade = _create_facade_lazily() - return facade.get_engine() - - -def get_session(**kwargs): - facade = _create_facade_lazily() - return facade.get_session(**kwargs) +_CONTEXT = threading.local() def get_backend(): @@ -70,15 +52,21 @@ def get_backend(): return Connection() -def model_query(model, *args, **kwargs): - """Query helper for simpler session usage. +def _session_for_read(): + return _wrap_session(enginefacade.reader.using(_CONTEXT)) - :param session: if present, the session to use - """ - session = kwargs.get('session') or get_session() - query = session.query(model, *args) - return query +# NOTE(tylerchristie) Please add @oslo_db_api.retry_on_deadlock decorator to +# any new methods using _session_for_write (as deadlocks happen on write), so +# that oslo_db is able to retry in case of deadlocks. +def _session_for_write(): + return _wrap_session(enginefacade.writer.using(_CONTEXT)) + + +def _wrap_session(session): + if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy: + session = profiler_sqlalchemy.wrap_session(sa, session) + return session def add_identity_filter(query, value): @@ -101,8 +89,6 @@ def add_identity_filter(query, value): def _paginate_query(model, limit=None, marker=None, sort_key=None, sort_dir=None, query=None): - if not query: - query = model_query(model) sort_keys = ['id'] if sort_key and sort_key not in sort_keys: sort_keys.insert(0, sort_key) @@ -166,15 +152,16 @@ class Connection(api.Connection): # Helper to filter based on node_count field from nodegroups def filter_node_count(query, node_count, is_master=False): nfunc = func.sum(models.NodeGroup.node_count) - nquery = model_query(models.NodeGroup) - if is_master: - nquery = nquery.filter(models.NodeGroup.role == 'master') - else: - nquery = nquery.filter(models.NodeGroup.role != 'master') - nquery = nquery.group_by(models.NodeGroup.cluster_id) - nquery = nquery.having(nfunc == node_count) - uuids = [ng.cluster_id for ng in nquery.all()] - return query.filter(models.Cluster.uuid.in_(uuids)) + with _session_for_read() as session: + nquery = session.query(models.NodeGroup) + if is_master: + nquery = nquery.filter(models.NodeGroup.role == 'master') + else: + nquery = nquery.filter(models.NodeGroup.role != 'master') + nquery = nquery.group_by(models.NodeGroup.cluster_id) + nquery = nquery.having(nfunc == node_count) + uuids = [ng.cluster_id for ng in nquery.all()] + return query.filter(models.Cluster.uuid.in_(uuids)) if 'node_count' in filters: query = filter_node_count( @@ -187,12 +174,14 @@ class Connection(api.Connection): def get_cluster_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return _paginate_query(models.Cluster, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return _paginate_query( + models.Cluster, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_cluster(self, values): # ensure defaults are present for new clusters if not values.get('uuid'): @@ -200,68 +189,77 @@ class Connection(api.Connection): cluster = models.Cluster() cluster.update(values) - try: - cluster.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterAlreadyExists(uuid=values['uuid']) - return cluster + + with _session_for_write() as session: + try: + session.add(cluster) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterAlreadyExists(uuid=values['uuid']) + return cluster def get_cluster_by_id(self, context, cluster_id): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=cluster_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_id) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=cluster_id) + try: + return query.one() + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_id) def get_cluster_by_name(self, context, cluster_name): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(name=cluster_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple clusters exist with same name.' - ' Please use the cluster uuid instead.') - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_name) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=cluster_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + 'Multiple clusters exist with same name.' + ' Please use the cluster uuid instead.') + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_name) def get_cluster_by_uuid(self, context, cluster_uuid): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=cluster_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterNotFound(cluster=cluster_uuid) + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=cluster_uuid) + try: + return query.one() + except NoResultFound: + raise exception.ClusterNotFound(cluster=cluster_uuid) def get_cluster_stats(self, context, project_id=None): - query = model_query(models.Cluster) - node_count_col = models.NodeGroup.node_count - ncfunc = func.sum(node_count_col) + with _session_for_read() as session: + query = session.query(models.Cluster) + node_count_col = models.NodeGroup.node_count + ncfunc = func.sum(node_count_col) - if project_id: - query = query.filter_by(project_id=project_id) - nquery = query.session.query(ncfunc.label("nodes")).filter_by( - project_id=project_id) - else: - nquery = query.session.query(ncfunc.label("nodes")) + if project_id: + query = query.filter_by(project_id=project_id) + nquery = query.session.query(ncfunc.label("nodes")).filter_by( + project_id=project_id) + else: + nquery = query.session.query(ncfunc.label("nodes")) clusters = query.count() nodes = int(nquery.one()[0]) if nquery.one()[0] else 0 return clusters, nodes def get_cluster_count_all(self, context, filters=None): - query = model_query(models.Cluster) - query = self._add_tenant_filters(context, query) - query = self._add_clusters_filters(query, filters) - return query.count() + with _session_for_read() as session: + query = session.query(models.Cluster) + query = self._add_tenant_filters(context, query) + query = self._add_clusters_filters(query, filters) + return query.count() + @oslo_db_api.retry_on_deadlock def destroy_cluster(self, cluster_id): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: @@ -279,10 +277,10 @@ class Connection(api.Connection): return self._do_update_cluster(cluster_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster(self, cluster_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Cluster, session=session) + with _session_for_write() as session: + query = session.query(models.Cluster) query = add_identity_filter(query, cluster_id) try: ref = query.with_for_update().one() @@ -309,22 +307,24 @@ class Connection(api.Connection): def get_cluster_template_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - query = self._add_cluster_template_filters(query, filters) - # include public (and not hidden) ClusterTemplates - public_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=False) - query = query.union(public_q) - # include hidden and public ClusterTemplate if admin - if context.is_admin: - hidden_q = model_query(models.ClusterTemplate).filter_by( - public=True, hidden=True) - query = query.union(hidden_q) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + query = self._add_cluster_template_filters(query, filters) + # include public (and not hidden) ClusterTemplates + public_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=False) + query = query.union(public_q) + # include hidden and public ClusterTemplate if admin + if context.is_admin: + hidden_q = session.query(models.ClusterTemplate).filter_by( + public=True, hidden=True) + query = query.union(hidden_q) return _paginate_query(models.ClusterTemplate, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_cluster_template(self, values): # ensure defaults are present for new ClusterTemplates if not values.get('uuid'): @@ -332,59 +332,71 @@ class Connection(api.Connection): cluster_template = models.ClusterTemplate() cluster_template.update(values) - try: - cluster_template.save() - except db_exc.DBDuplicateEntry: - raise exception.ClusterTemplateAlreadyExists(uuid=values['uuid']) - return cluster_template + + with _session_for_write() as session: + try: + session.add(cluster_template) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.ClusterTemplateAlreadyExists( + uuid=values['uuid']) + return cluster_template def get_cluster_template_by_id(self, context, cluster_template_id): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter(models.ClusterTemplate.id == cluster_template_id) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_id) + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True) + query = query.union(public_q) + query = query.filter( + models.ClusterTemplate.id == cluster_template_id) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_id) def get_cluster_template_by_uuid(self, context, cluster_template_uuid): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True) + query = query.union(public_q) + query = query.filter( models.ClusterTemplate.uuid == cluster_template_uuid) - try: - return query.one() - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_uuid) + try: + return query.one() + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_uuid) def get_cluster_template_by_name(self, context, cluster_template_name): - query = model_query(models.ClusterTemplate) - query = self._add_tenant_filters(context, query) - public_q = model_query(models.ClusterTemplate).filter_by(public=True) - query = query.union(public_q) - query = query.filter( + with _session_for_read() as session: + query = session.query(models.ClusterTemplate) + query = self._add_tenant_filters(context, query) + public_q = session.query(models.ClusterTemplate).filter_by( + public=True) + query = query.union(public_q) + query = query.filter( models.ClusterTemplate.name == cluster_template_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple ClusterTemplates exist with' - ' same name. Please use the ' - 'ClusterTemplate uuid instead.') - except NoResultFound: - raise exception.ClusterTemplateNotFound( - clustertemplate=cluster_template_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + 'Multiple ClusterTemplates exist with' + ' same name. Please use the ' + 'ClusterTemplate uuid instead.') + except NoResultFound: + raise exception.ClusterTemplateNotFound( + clustertemplate=cluster_template_name) def _is_cluster_template_referenced(self, session, cluster_template_uuid): """Checks whether the ClusterTemplate is referenced by cluster(s).""" - query = model_query(models.Cluster, session=session) - query = self._add_clusters_filters(query, {'cluster_template_id': - cluster_template_uuid}) + query = session.query(models.Cluster) + query = self._add_clusters_filters( + query, {'cluster_template_id': cluster_template_uuid}) return query.count() != 0 def _is_publishing_cluster_template(self, values): @@ -395,10 +407,10 @@ class Connection(api.Connection): return True return False + @oslo_db_api.retry_on_deadlock def destroy_cluster_template(self, cluster_template_id): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: @@ -422,10 +434,10 @@ class Connection(api.Connection): return self._do_update_cluster_template(cluster_template_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_cluster_template(self, cluster_template_id, values): - session = get_session() - with session.begin(): - query = model_query(models.ClusterTemplate, session=session) + with _session_for_write() as session: + query = session.query(models.ClusterTemplate) query = add_identity_filter(query, cluster_template_id) try: ref = query.with_for_update().one() @@ -444,6 +456,7 @@ class Connection(api.Connection): ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def create_x509keypair(self, values): # ensure defaults are present for new x509keypairs if not values.get('uuid'): @@ -451,34 +464,40 @@ class Connection(api.Connection): x509keypair = models.X509KeyPair() x509keypair.update(values) - try: - x509keypair.save() - except db_exc.DBDuplicateEntry: - raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) - return x509keypair + + with _session_for_write() as session: + try: + session.add(x509keypair) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) + return x509keypair def get_x509keypair_by_id(self, context, x509keypair_id): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=x509keypair_id) - try: - return query.one() - except NoResultFound: - raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=x509keypair_id) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) def get_x509keypair_by_uuid(self, context, x509keypair_uuid): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=x509keypair_uuid) - try: - return query.one() - except NoResultFound: - raise exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=x509keypair_uuid) + try: + return query.one() + except NoResultFound: + raise exception.X509KeyPairNotFound( + x509keypair=x509keypair_uuid) + @oslo_db_api.retry_on_deadlock def destroy_x509keypair(self, x509keypair_id): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) count = query.delete() if count != 1: @@ -492,10 +511,10 @@ class Connection(api.Connection): return self._do_update_x509keypair(x509keypair_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_x509keypair(self, x509keypair_id, values): - session = get_session() - with session.begin(): - query = model_query(models.X509KeyPair, session=session) + with _session_for_write() as session: + query = session.query(models.X509KeyPair) query = add_identity_filter(query, x509keypair_id) try: ref = query.with_for_update().one() @@ -518,26 +537,27 @@ class Connection(api.Connection): def get_x509keypair_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.X509KeyPair) - query = self._add_tenant_filters(context, query) - query = self._add_x509keypairs_filters(query, filters) - return _paginate_query(models.X509KeyPair, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.X509KeyPair) + query = self._add_tenant_filters(context, query) + query = self._add_x509keypairs_filters(query, filters) + return _paginate_query( + models.X509KeyPair, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def destroy_magnum_service(self, magnum_service_id): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) count = query.delete() if count != 1: raise exception.MagnumServiceNotFound( magnum_service_id=magnum_service_id) + @oslo_db_api.retry_on_deadlock def update_magnum_service(self, magnum_service_id, values): - session = get_session() - with session.begin(): - query = model_query(models.MagnumService, session=session) + with _session_for_write() as session: + query = session.query(models.MagnumService) query = add_identity_filter(query, magnum_service_id) try: ref = query.with_for_update().one() @@ -553,48 +573,61 @@ class Connection(api.Connection): return ref def get_magnum_service_by_host_and_binary(self, host, binary): - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - try: - return query.one() - except NoResultFound: - return None + with _session_for_read() as session: + query = session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + try: + return query.one() + except NoResultFound: + return None + @oslo_db_api.retry_on_deadlock def create_magnum_service(self, values): magnum_service = models.MagnumService() magnum_service.update(values) - try: - magnum_service.save() - except db_exc.DBDuplicateEntry: - host = values["host"] - binary = values["binary"] - LOG.warning("Magnum service with same host:%(host)s and" - " binary:%(binary)s had been saved into DB", - {'host': host, 'binary': binary}) - query = model_query(models.MagnumService) - query = query.filter_by(host=host, binary=binary) - return query.one() - return magnum_service + + with _session_for_write() as session: + try: + session.add(magnum_service) + session.flush() + except db_exc.DBDuplicateEntry: + host = values["host"] + binary = values["binary"] + LOG.warning( + "Magnum service with same host:%(host)s and" + " binary:%(binary)s had been saved into DB", + {'host': host, 'binary': binary}) + with _session_for_read() as read_session: + query = read_session.query(models.MagnumService) + query = query.filter_by(host=host, binary=binary) + return query.one() + return magnum_service def get_magnum_service_list(self, disabled=None, limit=None, marker=None, sort_key=None, sort_dir=None ): - query = model_query(models.MagnumService) - if disabled: - query = query.filter_by(disabled=disabled) + with _session_for_read() as session: + query = session.query(models.MagnumService) + if disabled: + query = query.filter_by(disabled=disabled) - return _paginate_query(models.MagnumService, limit, marker, - sort_key, sort_dir, query) + return _paginate_query( + models.MagnumService, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_quota(self, values): quotas = models.Quota() quotas.update(values) - try: - quotas.save() - except db_exc.DBDuplicateEntry: - raise exception.QuotaAlreadyExists(project_id=values['project_id'], - resource=values['resource']) - return quotas + + with _session_for_write() as session: + try: + session.add(quotas) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.QuotaAlreadyExists( + project_id=values['project_id'], + resource=values['resource']) + return quotas def _add_quota_filters(self, query, filters): if filters is None: @@ -611,15 +644,16 @@ class Connection(api.Connection): def get_quota_list(self, context, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.Quota) - query = self._add_quota_filters(query, filters) - return _paginate_query(models.Quota, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.Quota) + query = self._add_quota_filters(query, filters) + return _paginate_query( + models.Quota, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def update_quota(self, project_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) + with _session_for_write() as session: + query = session.query(models.Quota) resource = values['resource'] try: query = query.filter_by(project_id=project_id).filter_by( @@ -633,12 +667,13 @@ class Connection(api.Connection): ref.update(values) return ref + @oslo_db_api.retry_on_deadlock def delete_quota(self, project_id, resource): - session = get_session() - with session.begin(): - query = model_query(models.Quota, session=session) \ - .filter_by(project_id=project_id) \ - .filter_by(resource=resource) + with _session_for_write() as session: + query = ( + session.query(models.Quota) + .filter_by(project_id=project_id) + .filter_by(resource=resource)) try: query.one() @@ -650,31 +685,34 @@ class Connection(api.Connection): query.delete() def get_quota_by_id(self, context, quota_id): - query = model_query(models.Quota) - query = query.filter_by(id=quota_id) - try: - return query.one() - except NoResultFound: - msg = _('quota id %s .') % quota_id - raise exception.QuotaNotFound(msg=msg) + with _session_for_read() as session: + query = session.query(models.Quota) + query = query.filter_by(id=quota_id) + try: + return query.one() + except NoResultFound: + msg = _('quota id %s .') % quota_id + raise exception.QuotaNotFound(msg=msg) def quota_get_all_by_project_id(self, project_id): - query = model_query(models.Quota) - result = query.filter_by(project_id=project_id).all() + with _session_for_read() as session: + query = session.query(models.Quota) + result = query.filter_by(project_id=project_id).all() return result def get_quota_by_project_id_resource(self, project_id, resource): - query = model_query(models.Quota) - query = query.filter_by(project_id=project_id).filter_by( - resource=resource) + with _session_for_read() as session: + query = session.query(models.Quota) + query = query.filter_by(project_id=project_id).filter_by( + resource=resource) - try: - return query.one() - except NoResultFound: - msg = (_('project_id %(project_id)s resource %(resource)s.') % - {'project_id': project_id, 'resource': resource}) - raise exception.QuotaNotFound(msg=msg) + try: + return query.one() + except NoResultFound: + msg = (_('project_id %(project_id)s resource %(resource)s.') % + {'project_id': project_id, 'resource': resource}) + raise exception.QuotaNotFound(msg=msg) def _add_federation_filters(self, query, filters): if filters is None: @@ -701,60 +739,69 @@ class Connection(api.Connection): return query def get_federation_by_id(self, context, federation_id): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(id=federation_id) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_id) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(id=federation_id) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_id) def get_federation_by_uuid(self, context, federation_uuid): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(uuid=federation_uuid) - try: - return query.one() - except NoResultFound: - raise exception.FederationNotFound(federation=federation_uuid) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(uuid=federation_uuid) + try: + return query.one() + except NoResultFound: + raise exception.FederationNotFound(federation=federation_uuid) def get_federation_by_name(self, context, federation_name): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = query.filter_by(name=federation_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple federations exist with same ' - 'name. Please use the federation uuid ' - 'instead.') - except NoResultFound: - raise exception.FederationNotFound(federation=federation_name) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = query.filter_by(name=federation_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + 'Multiple federations exist with same ' + 'name. Please use the federation uuid ' + 'instead.') + except NoResultFound: + raise exception.FederationNotFound(federation=federation_name) def get_federation_list(self, context, limit=None, marker=None, sort_key=None, sort_dir=None, filters=None): - query = model_query(models.Federation) - query = self._add_tenant_filters(context, query) - query = self._add_federation_filters(query, filters) - return _paginate_query(models.Federation, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.Federation) + query = self._add_tenant_filters(context, query) + query = self._add_federation_filters(query, filters) + return _paginate_query( + models.Federation, limit, marker, sort_key, sort_dir, query) + @oslo_db_api.retry_on_deadlock def create_federation(self, values): if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() federation = models.Federation() federation.update(values) - try: - federation.save() - except db_exc.DBDuplicateEntry: - raise exception.FederationAlreadyExists(uuid=values['uuid']) - return federation + with _session_for_write() as session: + try: + session.add(federation) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.FederationAlreadyExists(uuid=values['uuid']) + return federation + + @oslo_db_api.retry_on_deadlock def destroy_federation(self, federation_id): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: @@ -771,10 +818,10 @@ class Connection(api.Connection): return self._do_update_federation(federation_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_federation(self, federation_id, values): - session = get_session() - with session.begin(): - query = model_query(models.Federation, session=session) + with _session_for_write() as session: + query = session.query(models.Federation) query = add_identity_filter(query, federation_id) try: ref = query.with_for_update().one() @@ -804,23 +851,27 @@ class Connection(api.Connection): return query + @oslo_db_api.retry_on_deadlock def create_nodegroup(self, values): if not values.get('uuid'): values['uuid'] = uuidutils.generate_uuid() nodegroup = models.NodeGroup() nodegroup.update(values) - try: - nodegroup.save() - except db_exc.DBDuplicateEntry: - raise exception.NodeGroupAlreadyExists( - cluster_id=values['cluster_id'], name=values['name']) - return nodegroup + with _session_for_write() as session: + try: + session.add(nodegroup) + session.flush() + except db_exc.DBDuplicateEntry: + raise exception.NodeGroupAlreadyExists( + cluster_id=values['cluster_id'], name=values['name']) + return nodegroup + + @oslo_db_api.retry_on_deadlock def destroy_nodegroup(self, cluster_id, nodegroup_id): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -832,10 +883,10 @@ class Connection(api.Connection): def update_nodegroup(self, cluster_id, nodegroup_id, values): return self._do_update_nodegroup(cluster_id, nodegroup_id, values) + @oslo_db_api.retry_on_deadlock def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): - session = get_session() - with session.begin(): - query = model_query(models.NodeGroup, session=session) + with _session_for_write() as session: + query = session.query(models.NodeGroup) query = add_identity_filter(query, nodegroup_id) query = query.filter_by(cluster_id=cluster_id) try: @@ -847,56 +898,62 @@ class Connection(api.Connection): return ref def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(id=nodegroup_id) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(id=nodegroup_id) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(uuid=nodegroup_uuid) - try: - return query.one() - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(uuid=nodegroup_uuid) + try: + return query.one() + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = query.filter_by(name=nodegroup_name) - try: - return query.one() - except MultipleResultsFound: - raise exception.Conflict('Multiple nodegroups exist with same ' - 'name. Please use the nodegroup uuid ' - 'instead.') - except NoResultFound: - raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = query.filter_by(name=nodegroup_name) + try: + return query.one() + except MultipleResultsFound: + raise exception.Conflict( + 'Multiple nodegroups exist with same ' + 'name. Please use the nodegroup uuid ' + 'instead.') + except NoResultFound: + raise exception.NodeGroupNotFound(nodegroup=nodegroup_name) def list_cluster_nodegroups(self, context, cluster_id, filters=None, limit=None, marker=None, sort_key=None, sort_dir=None): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - query = self._add_nodegoup_filters(query, filters) - return _paginate_query(models.NodeGroup, limit, marker, - sort_key, sort_dir, query) + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + query = self._add_nodegoup_filters(query, filters) + return _paginate_query( + models.NodeGroup, limit, marker, sort_key, sort_dir, query) def get_cluster_nodegroup_count(self, context, cluster_id): - query = model_query(models.NodeGroup) - if not context.is_admin: - query = query.filter_by(project_id=context.project_id) - query = query.filter_by(cluster_id=cluster_id) - return query.count() + with _session_for_read() as session: + query = session.query(models.NodeGroup) + if not context.is_admin: + query = query.filter_by(project_id=context.project_id) + query = query.filter_by(cluster_id=cluster_id) + return query.count() diff --git a/magnum/db/sqlalchemy/models.py b/magnum/db/sqlalchemy/models.py index 92b474da37..2d83093010 100644 --- a/magnum/db/sqlalchemy/models.py +++ b/magnum/db/sqlalchemy/models.py @@ -87,15 +87,6 @@ class MagnumBase(models.TimestampMixin, d[c.name] = self[c.name] return d - def save(self, session=None): - import magnum.db.sqlalchemy.api as db_api - - if session is None: - session = db_api.get_session() - - with session.begin(): - super(MagnumBase, self).save(session) - Base = declarative_base(cls=MagnumBase) diff --git a/magnum/tests/unit/db/base.py b/magnum/tests/unit/db/base.py index 711d30caeb..d78d8fa378 100644 --- a/magnum/tests/unit/db/base.py +++ b/magnum/tests/unit/db/base.py @@ -16,10 +16,10 @@ """Magnum DB test base class.""" import fixtures +from oslo_db.sqlalchemy import enginefacade import magnum.conf from magnum.db import api as dbapi -from magnum.db.sqlalchemy import api as sqla_api from magnum.db.sqlalchemy import migration from magnum.db.sqlalchemy import models from magnum.tests import base @@ -32,16 +32,15 @@ _DB_CACHE = None class Database(fixtures.Fixture): - def __init__(self, db_api, db_migrate, sql_connection): + def __init__(self, engine, db_migrate, sql_connection): self.sql_connection = sql_connection - self.engine = db_api.get_engine() + self.engine = engine self.engine.dispose() - conn = self.engine.connect() - self.setup_sqlite(db_migrate) - self.post_migrations() - - self._DB = "".join(line for line in conn.connection.iterdump()) + with self.engine.connect() as conn: + self.setup_sqlite(db_migrate) + self.post_migrations() + self._DB = "".join(line for line in conn.connection.iterdump()) self.engine.dispose() def setup_sqlite(self, db_migrate): @@ -50,9 +49,10 @@ class Database(fixtures.Fixture): models.Base.metadata.create_all(self.engine) db_migrate.stamp('head') - def _setUp(self): - conn = self.engine.connect() - conn.connection.executescript(self._DB) + def setUp(self): + super(Database, self).setUp() + with self.engine.connect() as conn: + conn.connection.executescript(self._DB) self.addCleanup(self.engine.dispose) def post_migrations(self): @@ -68,6 +68,8 @@ class DbTestCase(base.TestCase): global _DB_CACHE if not _DB_CACHE: - _DB_CACHE = Database(sqla_api, migration, + engine = enginefacade.writer.get_engine() + _DB_CACHE = Database(engine, migration, sql_connection=CONF.database.connection) + engine.dispose() self.useFixture(_DB_CACHE) diff --git a/magnum/tests/unit/db/sqlalchemy/test_types.py b/magnum/tests/unit/db/sqlalchemy/test_types.py index b9a2c1103a..d89bea97ca 100644 --- a/magnum/tests/unit/db/sqlalchemy/test_types.py +++ b/magnum/tests/unit/db/sqlalchemy/test_types.py @@ -26,16 +26,22 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase): # Create ClusterTemplate w/o labels cluster_template1_id = uuidutils.generate_uuid() self.dbapi.create_cluster_template({'uuid': cluster_template1_id}) - cluster_template1 = sa_api.model_query( - models.ClusterTemplate).filter_by(uuid=cluster_template1_id).one() + with sa_api._session_for_read() as session: + cluster_template1 = (session.query( + models.ClusterTemplate) + .filter_by(uuid=cluster_template1_id) + .one()) self.assertEqual({}, cluster_template1.labels) # Create ClusterTemplate with labels cluster_template2_id = uuidutils.generate_uuid() self.dbapi.create_cluster_template( {'uuid': cluster_template2_id, 'labels': {'bar': 'foo'}}) - cluster_template2 = sa_api.model_query( - models.ClusterTemplate).filter_by(uuid=cluster_template2_id).one() + with sa_api._session_for_read() as session: + cluster_template2 = (session.query( + models.ClusterTemplate) + .filter_by(uuid=cluster_template2_id) + .one()) self.assertEqual('foo', cluster_template2.labels['bar']) def test_JSONEncodedDict_type_check(self): @@ -48,8 +54,9 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase): # Create nodegroup w/o node_addresses nodegroup1_id = uuidutils.generate_uuid() self.dbapi.create_nodegroup({'uuid': nodegroup1_id}) - nodegroup1 = sa_api.model_query( - models.NodeGroup).filter_by(uuid=nodegroup1_id).one() + with sa_api._session_for_read() as session: + nodegroup1 = session.query( + models.NodeGroup).filter_by(uuid=nodegroup1_id).one() self.assertEqual([], nodegroup1.node_addresses) # Create nodegroup with node_addresses @@ -59,8 +66,9 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase): 'node_addresses': ['mynode_address1', 'mynode_address2'] }) - nodegroup2 = sa_api.model_query( - models.NodeGroup).filter_by(uuid=nodegroup2_id).one() + with sa_api._session_for_read() as session: + nodegroup2 = session.query( + models.NodeGroup).filter_by(uuid=nodegroup2_id).one() self.assertEqual(['mynode_address1', 'mynode_address2'], nodegroup2.node_addresses)