Switch to using enginefacade

Closes-Bug: #2067345
Change-Id: If9a2c96628cfcb819fee5e19f872ea015979b30f
(cherry picked from commit 0ce2c41404f1f8dcd1bcd19d36a885edc34926a2)
This commit is contained in:
Mohammed Naser 2024-08-28 19:16:07 -04:00 committed by Michal Nasiadka
parent b82bd6ae51
commit d5f6dcf9ff
6 changed files with 437 additions and 377 deletions

View File

@ -12,6 +12,7 @@
from eventlet.green import threading
from oslo_context import context
from oslo_db.sqlalchemy import enginefacade
from magnum.common import policy
@ -20,6 +21,7 @@ import magnum.conf
CONF = magnum.conf.CONF
@enginefacade.transaction_context_provider
class RequestContext(context.RequestContext):
"""Extends security contexts from the OpenStack common library."""

View File

@ -13,8 +13,8 @@
from logging import config as log_config
from alembic import context
from oslo_db.sqlalchemy import enginefacade
from magnum.db.sqlalchemy import api as sqla_api
from magnum.db.sqlalchemy import models
# this is the Alembic Config object, which provides
@ -43,7 +43,7 @@ def run_migrations_online():
and associate a connection with the context.
"""
engine = sqla_api.get_engine()
engine = enginefacade.writer.get_engine()
with engine.connect() as connection:
context.configure(connection=connection,
target_metadata=target_metadata)

View File

@ -14,8 +14,11 @@
"""SQLAlchemy storage backend."""
import threading
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session
from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import utils as db_utils
from oslo_log import log
from oslo_utils import importutils
@ -35,34 +38,13 @@ from magnum.db import api
from magnum.db.sqlalchemy import models
from magnum.i18n import _
profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
profiler_sqlalchemy = importutils.try_import("osprofiler.sqlalchemy")
CONF = magnum.conf.CONF
LOG = log.getLogger(__name__)
_FACADE = None
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade.from_config(CONF)
if profiler_sqlalchemy:
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db")
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
return facade.get_session(**kwargs)
_CONTEXT = threading.local()
def get_backend():
@ -70,15 +52,21 @@ def get_backend():
return Connection()
def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage.
def _session_for_read():
return _wrap_session(enginefacade.reader.using(_CONTEXT))
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(model, *args)
return query
# NOTE(tylerchristie) Please add @oslo_db_api.retry_on_deadlock decorator to
# any new methods using _session_for_write (as deadlocks happen on write), so
# that oslo_db is able to retry in case of deadlocks.
def _session_for_write():
return _wrap_session(enginefacade.writer.using(_CONTEXT))
def _wrap_session(session):
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
session = profiler_sqlalchemy.wrap_session(sa, session)
return session
def add_identity_filter(query, value):
@ -101,8 +89,6 @@ def add_identity_filter(query, value):
def _paginate_query(model, limit=None, marker=None, sort_key=None,
sort_dir=None, query=None):
if not query:
query = model_query(model)
sort_keys = ['id']
if sort_key and sort_key not in sort_keys:
sort_keys.insert(0, sort_key)
@ -166,7 +152,8 @@ class Connection(api.Connection):
# Helper to filter based on node_count field from nodegroups
def filter_node_count(query, node_count, is_master=False):
nfunc = func.sum(models.NodeGroup.node_count)
nquery = model_query(models.NodeGroup)
with _session_for_read() as session:
nquery = session.query(models.NodeGroup)
if is_master:
nquery = nquery.filter(models.NodeGroup.role == 'master')
else:
@ -187,12 +174,14 @@ class Connection(api.Connection):
def get_cluster_list(self, context, filters=None, limit=None, marker=None,
sort_key=None, sort_dir=None):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query)
query = self._add_clusters_filters(query, filters)
return _paginate_query(models.Cluster, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.Cluster, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_cluster(self, values):
# ensure defaults are present for new clusters
if not values.get('uuid'):
@ -200,14 +189,18 @@ class Connection(api.Connection):
cluster = models.Cluster()
cluster.update(values)
with _session_for_write() as session:
try:
cluster.save()
session.add(cluster)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.ClusterAlreadyExists(uuid=values['uuid'])
return cluster
def get_cluster_by_id(self, context, cluster_id):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query)
query = query.filter_by(id=cluster_id)
try:
@ -216,19 +209,22 @@ class Connection(api.Connection):
raise exception.ClusterNotFound(cluster=cluster_id)
def get_cluster_by_name(self, context, cluster_name):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query)
query = query.filter_by(name=cluster_name)
try:
return query.one()
except MultipleResultsFound:
raise exception.Conflict('Multiple clusters exist with same name.'
raise exception.Conflict(
'Multiple clusters exist with same name.'
' Please use the cluster uuid instead.')
except NoResultFound:
raise exception.ClusterNotFound(cluster=cluster_name)
def get_cluster_by_uuid(self, context, cluster_uuid):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=cluster_uuid)
try:
@ -237,7 +233,8 @@ class Connection(api.Connection):
raise exception.ClusterNotFound(cluster=cluster_uuid)
def get_cluster_stats(self, context, project_id=None):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
node_count_col = models.NodeGroup.node_count
ncfunc = func.sum(node_count_col)
@ -253,15 +250,16 @@ class Connection(api.Connection):
return clusters, nodes
def get_cluster_count_all(self, context, filters=None):
query = model_query(models.Cluster)
with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query)
query = self._add_clusters_filters(query, filters)
return query.count()
@oslo_db_api.retry_on_deadlock
def destroy_cluster(self, cluster_id):
session = get_session()
with session.begin():
query = model_query(models.Cluster, session=session)
with _session_for_write() as session:
query = session.query(models.Cluster)
query = add_identity_filter(query, cluster_id)
try:
@ -279,10 +277,10 @@ class Connection(api.Connection):
return self._do_update_cluster(cluster_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_cluster(self, cluster_id, values):
session = get_session()
with session.begin():
query = model_query(models.Cluster, session=session)
with _session_for_write() as session:
query = session.query(models.Cluster)
query = add_identity_filter(query, cluster_id)
try:
ref = query.with_for_update().one()
@ -309,22 +307,24 @@ class Connection(api.Connection):
def get_cluster_template_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.ClusterTemplate)
with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query)
query = self._add_cluster_template_filters(query, filters)
# include public (and not hidden) ClusterTemplates
public_q = model_query(models.ClusterTemplate).filter_by(
public_q = session.query(models.ClusterTemplate).filter_by(
public=True, hidden=False)
query = query.union(public_q)
# include hidden and public ClusterTemplate if admin
if context.is_admin:
hidden_q = model_query(models.ClusterTemplate).filter_by(
hidden_q = session.query(models.ClusterTemplate).filter_by(
public=True, hidden=True)
query = query.union(hidden_q)
return _paginate_query(models.ClusterTemplate, limit, marker,
sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_cluster_template(self, values):
# ensure defaults are present for new ClusterTemplates
if not values.get('uuid'):
@ -332,18 +332,25 @@ class Connection(api.Connection):
cluster_template = models.ClusterTemplate()
cluster_template.update(values)
with _session_for_write() as session:
try:
cluster_template.save()
session.add(cluster_template)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.ClusterTemplateAlreadyExists(uuid=values['uuid'])
raise exception.ClusterTemplateAlreadyExists(
uuid=values['uuid'])
return cluster_template
def get_cluster_template_by_id(self, context, cluster_template_id):
query = model_query(models.ClusterTemplate)
with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True)
public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q)
query = query.filter(models.ClusterTemplate.id == cluster_template_id)
query = query.filter(
models.ClusterTemplate.id == cluster_template_id)
try:
return query.one()
except NoResultFound:
@ -351,9 +358,11 @@ class Connection(api.Connection):
clustertemplate=cluster_template_id)
def get_cluster_template_by_uuid(self, context, cluster_template_uuid):
query = model_query(models.ClusterTemplate)
with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True)
public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q)
query = query.filter(
models.ClusterTemplate.uuid == cluster_template_uuid)
@ -364,16 +373,19 @@ class Connection(api.Connection):
clustertemplate=cluster_template_uuid)
def get_cluster_template_by_name(self, context, cluster_template_name):
query = model_query(models.ClusterTemplate)
with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True)
public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q)
query = query.filter(
models.ClusterTemplate.name == cluster_template_name)
try:
return query.one()
except MultipleResultsFound:
raise exception.Conflict('Multiple ClusterTemplates exist with'
raise exception.Conflict(
'Multiple ClusterTemplates exist with'
' same name. Please use the '
'ClusterTemplate uuid instead.')
except NoResultFound:
@ -382,9 +394,9 @@ class Connection(api.Connection):
def _is_cluster_template_referenced(self, session, cluster_template_uuid):
"""Checks whether the ClusterTemplate is referenced by cluster(s)."""
query = model_query(models.Cluster, session=session)
query = self._add_clusters_filters(query, {'cluster_template_id':
cluster_template_uuid})
query = session.query(models.Cluster)
query = self._add_clusters_filters(
query, {'cluster_template_id': cluster_template_uuid})
return query.count() != 0
def _is_publishing_cluster_template(self, values):
@ -395,10 +407,10 @@ class Connection(api.Connection):
return True
return False
@oslo_db_api.retry_on_deadlock
def destroy_cluster_template(self, cluster_template_id):
session = get_session()
with session.begin():
query = model_query(models.ClusterTemplate, session=session)
with _session_for_write() as session:
query = session.query(models.ClusterTemplate)
query = add_identity_filter(query, cluster_template_id)
try:
@ -422,10 +434,10 @@ class Connection(api.Connection):
return self._do_update_cluster_template(cluster_template_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_cluster_template(self, cluster_template_id, values):
session = get_session()
with session.begin():
query = model_query(models.ClusterTemplate, session=session)
with _session_for_write() as session:
query = session.query(models.ClusterTemplate)
query = add_identity_filter(query, cluster_template_id)
try:
ref = query.with_for_update().one()
@ -444,6 +456,7 @@ class Connection(api.Connection):
ref.update(values)
return ref
@oslo_db_api.retry_on_deadlock
def create_x509keypair(self, values):
# ensure defaults are present for new x509keypairs
if not values.get('uuid'):
@ -451,14 +464,18 @@ class Connection(api.Connection):
x509keypair = models.X509KeyPair()
x509keypair.update(values)
with _session_for_write() as session:
try:
x509keypair.save()
session.add(x509keypair)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.X509KeyPairAlreadyExists(uuid=values['uuid'])
return x509keypair
def get_x509keypair_by_id(self, context, x509keypair_id):
query = model_query(models.X509KeyPair)
with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query)
query = query.filter_by(id=x509keypair_id)
try:
@ -467,18 +484,20 @@ class Connection(api.Connection):
raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id)
def get_x509keypair_by_uuid(self, context, x509keypair_uuid):
query = model_query(models.X509KeyPair)
with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=x509keypair_uuid)
try:
return query.one()
except NoResultFound:
raise exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid)
raise exception.X509KeyPairNotFound(
x509keypair=x509keypair_uuid)
@oslo_db_api.retry_on_deadlock
def destroy_x509keypair(self, x509keypair_id):
session = get_session()
with session.begin():
query = model_query(models.X509KeyPair, session=session)
with _session_for_write() as session:
query = session.query(models.X509KeyPair)
query = add_identity_filter(query, x509keypair_id)
count = query.delete()
if count != 1:
@ -492,10 +511,10 @@ class Connection(api.Connection):
return self._do_update_x509keypair(x509keypair_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_x509keypair(self, x509keypair_id, values):
session = get_session()
with session.begin():
query = model_query(models.X509KeyPair, session=session)
with _session_for_write() as session:
query = session.query(models.X509KeyPair)
query = add_identity_filter(query, x509keypair_id)
try:
ref = query.with_for_update().one()
@ -518,26 +537,27 @@ class Connection(api.Connection):
def get_x509keypair_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.X509KeyPair)
with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query)
query = self._add_x509keypairs_filters(query, filters)
return _paginate_query(models.X509KeyPair, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.X509KeyPair, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def destroy_magnum_service(self, magnum_service_id):
session = get_session()
with session.begin():
query = model_query(models.MagnumService, session=session)
with _session_for_write() as session:
query = session.query(models.MagnumService)
query = add_identity_filter(query, magnum_service_id)
count = query.delete()
if count != 1:
raise exception.MagnumServiceNotFound(
magnum_service_id=magnum_service_id)
@oslo_db_api.retry_on_deadlock
def update_magnum_service(self, magnum_service_id, values):
session = get_session()
with session.begin():
query = model_query(models.MagnumService, session=session)
with _session_for_write() as session:
query = session.query(models.MagnumService)
query = add_identity_filter(query, magnum_service_id)
try:
ref = query.with_for_update().one()
@ -553,25 +573,32 @@ class Connection(api.Connection):
return ref
def get_magnum_service_by_host_and_binary(self, host, binary):
query = model_query(models.MagnumService)
with _session_for_read() as session:
query = session.query(models.MagnumService)
query = query.filter_by(host=host, binary=binary)
try:
return query.one()
except NoResultFound:
return None
@oslo_db_api.retry_on_deadlock
def create_magnum_service(self, values):
magnum_service = models.MagnumService()
magnum_service.update(values)
with _session_for_write() as session:
try:
magnum_service.save()
session.add(magnum_service)
session.flush()
except db_exc.DBDuplicateEntry:
host = values["host"]
binary = values["binary"]
LOG.warning("Magnum service with same host:%(host)s and"
LOG.warning(
"Magnum service with same host:%(host)s and"
" binary:%(binary)s had been saved into DB",
{'host': host, 'binary': binary})
query = model_query(models.MagnumService)
with _session_for_read() as read_session:
query = read_session.query(models.MagnumService)
query = query.filter_by(host=host, binary=binary)
return query.one()
return magnum_service
@ -579,20 +606,26 @@ class Connection(api.Connection):
def get_magnum_service_list(self, disabled=None, limit=None,
marker=None, sort_key=None, sort_dir=None
):
query = model_query(models.MagnumService)
with _session_for_read() as session:
query = session.query(models.MagnumService)
if disabled:
query = query.filter_by(disabled=disabled)
return _paginate_query(models.MagnumService, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.MagnumService, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_quota(self, values):
quotas = models.Quota()
quotas.update(values)
with _session_for_write() as session:
try:
quotas.save()
session.add(quotas)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.QuotaAlreadyExists(project_id=values['project_id'],
raise exception.QuotaAlreadyExists(
project_id=values['project_id'],
resource=values['resource'])
return quotas
@ -611,15 +644,16 @@ class Connection(api.Connection):
def get_quota_list(self, context, filters=None, limit=None, marker=None,
sort_key=None, sort_dir=None):
query = model_query(models.Quota)
with _session_for_read() as session:
query = session.query(models.Quota)
query = self._add_quota_filters(query, filters)
return _paginate_query(models.Quota, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.Quota, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def update_quota(self, project_id, values):
session = get_session()
with session.begin():
query = model_query(models.Quota, session=session)
with _session_for_write() as session:
query = session.query(models.Quota)
resource = values['resource']
try:
query = query.filter_by(project_id=project_id).filter_by(
@ -633,12 +667,13 @@ class Connection(api.Connection):
ref.update(values)
return ref
@oslo_db_api.retry_on_deadlock
def delete_quota(self, project_id, resource):
session = get_session()
with session.begin():
query = model_query(models.Quota, session=session) \
.filter_by(project_id=project_id) \
.filter_by(resource=resource)
with _session_for_write() as session:
query = (
session.query(models.Quota)
.filter_by(project_id=project_id)
.filter_by(resource=resource))
try:
query.one()
@ -650,7 +685,8 @@ class Connection(api.Connection):
query.delete()
def get_quota_by_id(self, context, quota_id):
query = model_query(models.Quota)
with _session_for_read() as session:
query = session.query(models.Quota)
query = query.filter_by(id=quota_id)
try:
return query.one()
@ -659,13 +695,15 @@ class Connection(api.Connection):
raise exception.QuotaNotFound(msg=msg)
def quota_get_all_by_project_id(self, project_id):
query = model_query(models.Quota)
with _session_for_read() as session:
query = session.query(models.Quota)
result = query.filter_by(project_id=project_id).all()
return result
def get_quota_by_project_id_resource(self, project_id, resource):
query = model_query(models.Quota)
with _session_for_read() as session:
query = session.query(models.Quota)
query = query.filter_by(project_id=project_id).filter_by(
resource=resource)
@ -701,7 +739,8 @@ class Connection(api.Connection):
return query
def get_federation_by_id(self, context, federation_id):
query = model_query(models.Federation)
with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query)
query = query.filter_by(id=federation_id)
try:
@ -710,7 +749,8 @@ class Connection(api.Connection):
raise exception.FederationNotFound(federation=federation_id)
def get_federation_by_uuid(self, context, federation_uuid):
query = model_query(models.Federation)
with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=federation_uuid)
try:
@ -719,13 +759,15 @@ class Connection(api.Connection):
raise exception.FederationNotFound(federation=federation_uuid)
def get_federation_by_name(self, context, federation_name):
query = model_query(models.Federation)
with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query)
query = query.filter_by(name=federation_name)
try:
return query.one()
except MultipleResultsFound:
raise exception.Conflict('Multiple federations exist with same '
raise exception.Conflict(
'Multiple federations exist with same '
'name. Please use the federation uuid '
'instead.')
except NoResultFound:
@ -733,28 +775,33 @@ class Connection(api.Connection):
def get_federation_list(self, context, limit=None, marker=None,
sort_key=None, sort_dir=None, filters=None):
query = model_query(models.Federation)
with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query)
query = self._add_federation_filters(query, filters)
return _paginate_query(models.Federation, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.Federation, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_federation(self, values):
if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid()
federation = models.Federation()
federation.update(values)
with _session_for_write() as session:
try:
federation.save()
session.add(federation)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.FederationAlreadyExists(uuid=values['uuid'])
return federation
@oslo_db_api.retry_on_deadlock
def destroy_federation(self, federation_id):
session = get_session()
with session.begin():
query = model_query(models.Federation, session=session)
with _session_for_write() as session:
query = session.query(models.Federation)
query = add_identity_filter(query, federation_id)
try:
@ -771,10 +818,10 @@ class Connection(api.Connection):
return self._do_update_federation(federation_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_federation(self, federation_id, values):
session = get_session()
with session.begin():
query = model_query(models.Federation, session=session)
with _session_for_write() as session:
query = session.query(models.Federation)
query = add_identity_filter(query, federation_id)
try:
ref = query.with_for_update().one()
@ -804,23 +851,27 @@ class Connection(api.Connection):
return query
@oslo_db_api.retry_on_deadlock
def create_nodegroup(self, values):
if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid()
nodegroup = models.NodeGroup()
nodegroup.update(values)
with _session_for_write() as session:
try:
nodegroup.save()
session.add(nodegroup)
session.flush()
except db_exc.DBDuplicateEntry:
raise exception.NodeGroupAlreadyExists(
cluster_id=values['cluster_id'], name=values['name'])
return nodegroup
@oslo_db_api.retry_on_deadlock
def destroy_nodegroup(self, cluster_id, nodegroup_id):
session = get_session()
with session.begin():
query = model_query(models.NodeGroup, session=session)
with _session_for_write() as session:
query = session.query(models.NodeGroup)
query = add_identity_filter(query, nodegroup_id)
query = query.filter_by(cluster_id=cluster_id)
try:
@ -832,10 +883,10 @@ class Connection(api.Connection):
def update_nodegroup(self, cluster_id, nodegroup_id, values):
return self._do_update_nodegroup(cluster_id, nodegroup_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_nodegroup(self, cluster_id, nodegroup_id, values):
session = get_session()
with session.begin():
query = model_query(models.NodeGroup, session=session)
with _session_for_write() as session:
query = session.query(models.NodeGroup)
query = add_identity_filter(query, nodegroup_id)
query = query.filter_by(cluster_id=cluster_id)
try:
@ -847,7 +898,8 @@ class Connection(api.Connection):
return ref
def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id):
query = model_query(models.NodeGroup)
with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id)
@ -858,7 +910,8 @@ class Connection(api.Connection):
raise exception.NodeGroupNotFound(nodegroup=nodegroup_id)
def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid):
query = model_query(models.NodeGroup)
with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id)
@ -869,7 +922,8 @@ class Connection(api.Connection):
raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid)
def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name):
query = model_query(models.NodeGroup)
with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id)
@ -877,7 +931,8 @@ class Connection(api.Connection):
try:
return query.one()
except MultipleResultsFound:
raise exception.Conflict('Multiple nodegroups exist with same '
raise exception.Conflict(
'Multiple nodegroups exist with same '
'name. Please use the nodegroup uuid '
'instead.')
except NoResultFound:
@ -886,16 +941,18 @@ class Connection(api.Connection):
def list_cluster_nodegroups(self, context, cluster_id, filters=None,
limit=None, marker=None, sort_key=None,
sort_dir=None):
query = model_query(models.NodeGroup)
with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id)
query = self._add_nodegoup_filters(query, filters)
return _paginate_query(models.NodeGroup, limit, marker,
sort_key, sort_dir, query)
return _paginate_query(
models.NodeGroup, limit, marker, sort_key, sort_dir, query)
def get_cluster_nodegroup_count(self, context, cluster_id):
query = model_query(models.NodeGroup)
with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin:
query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id)

View File

@ -87,15 +87,6 @@ class MagnumBase(models.TimestampMixin,
d[c.name] = self[c.name]
return d
def save(self, session=None):
import magnum.db.sqlalchemy.api as db_api
if session is None:
session = db_api.get_session()
with session.begin():
super(MagnumBase, self).save(session)
Base = declarative_base(cls=MagnumBase)

View File

@ -16,10 +16,10 @@
"""Magnum DB test base class."""
import fixtures
from oslo_db.sqlalchemy import enginefacade
import magnum.conf
from magnum.db import api as dbapi
from magnum.db.sqlalchemy import api as sqla_api
from magnum.db.sqlalchemy import migration
from magnum.db.sqlalchemy import models
from magnum.tests import base
@ -32,15 +32,14 @@ _DB_CACHE = None
class Database(fixtures.Fixture):
def __init__(self, db_api, db_migrate, sql_connection):
def __init__(self, engine, db_migrate, sql_connection):
self.sql_connection = sql_connection
self.engine = db_api.get_engine()
self.engine = engine
self.engine.dispose()
conn = self.engine.connect()
with self.engine.connect() as conn:
self.setup_sqlite(db_migrate)
self.post_migrations()
self._DB = "".join(line for line in conn.connection.iterdump())
self.engine.dispose()
@ -50,8 +49,9 @@ class Database(fixtures.Fixture):
models.Base.metadata.create_all(self.engine)
db_migrate.stamp('head')
def _setUp(self):
conn = self.engine.connect()
def setUp(self):
super(Database, self).setUp()
with self.engine.connect() as conn:
conn.connection.executescript(self._DB)
self.addCleanup(self.engine.dispose)
@ -68,6 +68,8 @@ class DbTestCase(base.TestCase):
global _DB_CACHE
if not _DB_CACHE:
_DB_CACHE = Database(sqla_api, migration,
engine = enginefacade.writer.get_engine()
_DB_CACHE = Database(engine, migration,
sql_connection=CONF.database.connection)
engine.dispose()
self.useFixture(_DB_CACHE)

View File

@ -26,16 +26,22 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
# Create ClusterTemplate w/o labels
cluster_template1_id = uuidutils.generate_uuid()
self.dbapi.create_cluster_template({'uuid': cluster_template1_id})
cluster_template1 = sa_api.model_query(
models.ClusterTemplate).filter_by(uuid=cluster_template1_id).one()
with sa_api._session_for_read() as session:
cluster_template1 = (session.query(
models.ClusterTemplate)
.filter_by(uuid=cluster_template1_id)
.one())
self.assertEqual({}, cluster_template1.labels)
# Create ClusterTemplate with labels
cluster_template2_id = uuidutils.generate_uuid()
self.dbapi.create_cluster_template(
{'uuid': cluster_template2_id, 'labels': {'bar': 'foo'}})
cluster_template2 = sa_api.model_query(
models.ClusterTemplate).filter_by(uuid=cluster_template2_id).one()
with sa_api._session_for_read() as session:
cluster_template2 = (session.query(
models.ClusterTemplate)
.filter_by(uuid=cluster_template2_id)
.one())
self.assertEqual('foo', cluster_template2.labels['bar'])
def test_JSONEncodedDict_type_check(self):
@ -48,7 +54,8 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
# Create nodegroup w/o node_addresses
nodegroup1_id = uuidutils.generate_uuid()
self.dbapi.create_nodegroup({'uuid': nodegroup1_id})
nodegroup1 = sa_api.model_query(
with sa_api._session_for_read() as session:
nodegroup1 = session.query(
models.NodeGroup).filter_by(uuid=nodegroup1_id).one()
self.assertEqual([], nodegroup1.node_addresses)
@ -59,7 +66,8 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
'node_addresses': ['mynode_address1',
'mynode_address2']
})
nodegroup2 = sa_api.model_query(
with sa_api._session_for_read() as session:
nodegroup2 = session.query(
models.NodeGroup).filter_by(uuid=nodegroup2_id).one()
self.assertEqual(['mynode_address1', 'mynode_address2'],
nodegroup2.node_addresses)