Switch to using enginefacade

Closes-Bug: #2067345
Change-Id: If9a2c96628cfcb819fee5e19f872ea015979b30f
(cherry picked from commit 0ce2c41404)
This commit is contained in:
Mohammed Naser 2024-08-28 19:16:07 -04:00 committed by Michal Nasiadka
parent b82bd6ae51
commit d5f6dcf9ff
6 changed files with 437 additions and 377 deletions

View File

@@ -12,6 +12,7 @@
from eventlet.green import threading from eventlet.green import threading
from oslo_context import context from oslo_context import context
from oslo_db.sqlalchemy import enginefacade
from magnum.common import policy from magnum.common import policy
@@ -20,6 +21,7 @@ import magnum.conf
CONF = magnum.conf.CONF CONF = magnum.conf.CONF
@enginefacade.transaction_context_provider
class RequestContext(context.RequestContext): class RequestContext(context.RequestContext):
"""Extends security contexts from the OpenStack common library.""" """Extends security contexts from the OpenStack common library."""

View File

@@ -13,8 +13,8 @@
from logging import config as log_config from logging import config as log_config
from alembic import context from alembic import context
from oslo_db.sqlalchemy import enginefacade
from magnum.db.sqlalchemy import api as sqla_api
from magnum.db.sqlalchemy import models from magnum.db.sqlalchemy import models
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
@@ -43,7 +43,7 @@ def run_migrations_online():
and associate a connection with the context. and associate a connection with the context.
""" """
engine = sqla_api.get_engine() engine = enginefacade.writer.get_engine()
with engine.connect() as connection: with engine.connect() as connection:
context.configure(connection=connection, context.configure(connection=connection,
target_metadata=target_metadata) target_metadata=target_metadata)

View File

@@ -14,8 +14,11 @@
"""SQLAlchemy storage backend.""" """SQLAlchemy storage backend."""
import threading
from oslo_db import api as oslo_db_api
from oslo_db import exception as db_exc from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import enginefacade
from oslo_db.sqlalchemy import utils as db_utils from oslo_db.sqlalchemy import utils as db_utils
from oslo_log import log from oslo_log import log
from oslo_utils import importutils from oslo_utils import importutils
@@ -35,34 +38,13 @@ from magnum.db import api
from magnum.db.sqlalchemy import models from magnum.db.sqlalchemy import models
from magnum.i18n import _ from magnum.i18n import _
profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy') profiler_sqlalchemy = importutils.try_import("osprofiler.sqlalchemy")
CONF = magnum.conf.CONF CONF = magnum.conf.CONF
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
_FACADE = None _CONTEXT = threading.local()
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade.from_config(CONF)
if profiler_sqlalchemy:
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db")
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
return facade.get_session(**kwargs)
def get_backend(): def get_backend():
@@ -70,15 +52,21 @@ def get_backend():
return Connection() return Connection()
def model_query(model, *args, **kwargs): def _session_for_read():
"""Query helper for simpler session usage. return _wrap_session(enginefacade.reader.using(_CONTEXT))
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session() # NOTE(tylerchristie) Please add @oslo_db_api.retry_on_deadlock decorator to
query = session.query(model, *args) # any new methods using _session_for_write (as deadlocks happen on write), so
return query # that oslo_db is able to retry in case of deadlocks.
def _session_for_write():
return _wrap_session(enginefacade.writer.using(_CONTEXT))
def _wrap_session(session):
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
session = profiler_sqlalchemy.wrap_session(sa, session)
return session
def add_identity_filter(query, value): def add_identity_filter(query, value):
@@ -101,8 +89,6 @@ def add_identity_filter(query, value):
def _paginate_query(model, limit=None, marker=None, sort_key=None, def _paginate_query(model, limit=None, marker=None, sort_key=None,
sort_dir=None, query=None): sort_dir=None, query=None):
if not query:
query = model_query(model)
sort_keys = ['id'] sort_keys = ['id']
if sort_key and sort_key not in sort_keys: if sort_key and sort_key not in sort_keys:
sort_keys.insert(0, sort_key) sort_keys.insert(0, sort_key)
@@ -166,7 +152,8 @@ class Connection(api.Connection):
# Helper to filter based on node_count field from nodegroups # Helper to filter based on node_count field from nodegroups
def filter_node_count(query, node_count, is_master=False): def filter_node_count(query, node_count, is_master=False):
nfunc = func.sum(models.NodeGroup.node_count) nfunc = func.sum(models.NodeGroup.node_count)
nquery = model_query(models.NodeGroup) with _session_for_read() as session:
nquery = session.query(models.NodeGroup)
if is_master: if is_master:
nquery = nquery.filter(models.NodeGroup.role == 'master') nquery = nquery.filter(models.NodeGroup.role == 'master')
else: else:
@@ -187,12 +174,14 @@ class Connection(api.Connection):
def get_cluster_list(self, context, filters=None, limit=None, marker=None, def get_cluster_list(self, context, filters=None, limit=None, marker=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = self._add_clusters_filters(query, filters) query = self._add_clusters_filters(query, filters)
return _paginate_query(models.Cluster, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.Cluster, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_cluster(self, values): def create_cluster(self, values):
# ensure defaults are present for new clusters # ensure defaults are present for new clusters
if not values.get('uuid'): if not values.get('uuid'):
@@ -200,14 +189,18 @@ class Connection(api.Connection):
cluster = models.Cluster() cluster = models.Cluster()
cluster.update(values) cluster.update(values)
with _session_for_write() as session:
try: try:
cluster.save() session.add(cluster)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.ClusterAlreadyExists(uuid=values['uuid']) raise exception.ClusterAlreadyExists(uuid=values['uuid'])
return cluster return cluster
def get_cluster_by_id(self, context, cluster_id): def get_cluster_by_id(self, context, cluster_id):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(id=cluster_id) query = query.filter_by(id=cluster_id)
try: try:
@@ -216,19 +209,22 @@ class Connection(api.Connection):
raise exception.ClusterNotFound(cluster=cluster_id) raise exception.ClusterNotFound(cluster=cluster_id)
def get_cluster_by_name(self, context, cluster_name): def get_cluster_by_name(self, context, cluster_name):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(name=cluster_name) query = query.filter_by(name=cluster_name)
try: try:
return query.one() return query.one()
except MultipleResultsFound: except MultipleResultsFound:
raise exception.Conflict('Multiple clusters exist with same name.' raise exception.Conflict(
'Multiple clusters exist with same name.'
' Please use the cluster uuid instead.') ' Please use the cluster uuid instead.')
except NoResultFound: except NoResultFound:
raise exception.ClusterNotFound(cluster=cluster_name) raise exception.ClusterNotFound(cluster=cluster_name)
def get_cluster_by_uuid(self, context, cluster_uuid): def get_cluster_by_uuid(self, context, cluster_uuid):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=cluster_uuid) query = query.filter_by(uuid=cluster_uuid)
try: try:
@@ -237,7 +233,8 @@ class Connection(api.Connection):
raise exception.ClusterNotFound(cluster=cluster_uuid) raise exception.ClusterNotFound(cluster=cluster_uuid)
def get_cluster_stats(self, context, project_id=None): def get_cluster_stats(self, context, project_id=None):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
node_count_col = models.NodeGroup.node_count node_count_col = models.NodeGroup.node_count
ncfunc = func.sum(node_count_col) ncfunc = func.sum(node_count_col)
@@ -253,15 +250,16 @@ class Connection(api.Connection):
return clusters, nodes return clusters, nodes
def get_cluster_count_all(self, context, filters=None): def get_cluster_count_all(self, context, filters=None):
query = model_query(models.Cluster) with _session_for_read() as session:
query = session.query(models.Cluster)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = self._add_clusters_filters(query, filters) query = self._add_clusters_filters(query, filters)
return query.count() return query.count()
@oslo_db_api.retry_on_deadlock
def destroy_cluster(self, cluster_id): def destroy_cluster(self, cluster_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.Cluster)
query = model_query(models.Cluster, session=session)
query = add_identity_filter(query, cluster_id) query = add_identity_filter(query, cluster_id)
try: try:
@@ -279,10 +277,10 @@ class Connection(api.Connection):
return self._do_update_cluster(cluster_id, values) return self._do_update_cluster(cluster_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_cluster(self, cluster_id, values): def _do_update_cluster(self, cluster_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.Cluster)
query = model_query(models.Cluster, session=session)
query = add_identity_filter(query, cluster_id) query = add_identity_filter(query, cluster_id)
try: try:
ref = query.with_for_update().one() ref = query.with_for_update().one()
@@ -309,22 +307,24 @@ class Connection(api.Connection):
def get_cluster_template_list(self, context, filters=None, limit=None, def get_cluster_template_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None): marker=None, sort_key=None, sort_dir=None):
query = model_query(models.ClusterTemplate) with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = self._add_cluster_template_filters(query, filters) query = self._add_cluster_template_filters(query, filters)
# include public (and not hidden) ClusterTemplates # include public (and not hidden) ClusterTemplates
public_q = model_query(models.ClusterTemplate).filter_by( public_q = session.query(models.ClusterTemplate).filter_by(
public=True, hidden=False) public=True, hidden=False)
query = query.union(public_q) query = query.union(public_q)
# include hidden and public ClusterTemplate if admin # include hidden and public ClusterTemplate if admin
if context.is_admin: if context.is_admin:
hidden_q = model_query(models.ClusterTemplate).filter_by( hidden_q = session.query(models.ClusterTemplate).filter_by(
public=True, hidden=True) public=True, hidden=True)
query = query.union(hidden_q) query = query.union(hidden_q)
return _paginate_query(models.ClusterTemplate, limit, marker, return _paginate_query(models.ClusterTemplate, limit, marker,
sort_key, sort_dir, query) sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_cluster_template(self, values): def create_cluster_template(self, values):
# ensure defaults are present for new ClusterTemplates # ensure defaults are present for new ClusterTemplates
if not values.get('uuid'): if not values.get('uuid'):
@@ -332,18 +332,25 @@ class Connection(api.Connection):
cluster_template = models.ClusterTemplate() cluster_template = models.ClusterTemplate()
cluster_template.update(values) cluster_template.update(values)
with _session_for_write() as session:
try: try:
cluster_template.save() session.add(cluster_template)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.ClusterTemplateAlreadyExists(uuid=values['uuid']) raise exception.ClusterTemplateAlreadyExists(
uuid=values['uuid'])
return cluster_template return cluster_template
def get_cluster_template_by_id(self, context, cluster_template_id): def get_cluster_template_by_id(self, context, cluster_template_id):
query = model_query(models.ClusterTemplate) with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True) public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q) query = query.union(public_q)
query = query.filter(models.ClusterTemplate.id == cluster_template_id) query = query.filter(
models.ClusterTemplate.id == cluster_template_id)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
@@ -351,9 +358,11 @@ class Connection(api.Connection):
clustertemplate=cluster_template_id) clustertemplate=cluster_template_id)
def get_cluster_template_by_uuid(self, context, cluster_template_uuid): def get_cluster_template_by_uuid(self, context, cluster_template_uuid):
query = model_query(models.ClusterTemplate) with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True) public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q) query = query.union(public_q)
query = query.filter( query = query.filter(
models.ClusterTemplate.uuid == cluster_template_uuid) models.ClusterTemplate.uuid == cluster_template_uuid)
@@ -364,16 +373,19 @@ class Connection(api.Connection):
clustertemplate=cluster_template_uuid) clustertemplate=cluster_template_uuid)
def get_cluster_template_by_name(self, context, cluster_template_name): def get_cluster_template_by_name(self, context, cluster_template_name):
query = model_query(models.ClusterTemplate) with _session_for_read() as session:
query = session.query(models.ClusterTemplate)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
public_q = model_query(models.ClusterTemplate).filter_by(public=True) public_q = session.query(models.ClusterTemplate).filter_by(
public=True)
query = query.union(public_q) query = query.union(public_q)
query = query.filter( query = query.filter(
models.ClusterTemplate.name == cluster_template_name) models.ClusterTemplate.name == cluster_template_name)
try: try:
return query.one() return query.one()
except MultipleResultsFound: except MultipleResultsFound:
raise exception.Conflict('Multiple ClusterTemplates exist with' raise exception.Conflict(
'Multiple ClusterTemplates exist with'
' same name. Please use the ' ' same name. Please use the '
'ClusterTemplate uuid instead.') 'ClusterTemplate uuid instead.')
except NoResultFound: except NoResultFound:
@@ -382,9 +394,9 @@ class Connection(api.Connection):
def _is_cluster_template_referenced(self, session, cluster_template_uuid): def _is_cluster_template_referenced(self, session, cluster_template_uuid):
"""Checks whether the ClusterTemplate is referenced by cluster(s).""" """Checks whether the ClusterTemplate is referenced by cluster(s)."""
query = model_query(models.Cluster, session=session) query = session.query(models.Cluster)
query = self._add_clusters_filters(query, {'cluster_template_id': query = self._add_clusters_filters(
cluster_template_uuid}) query, {'cluster_template_id': cluster_template_uuid})
return query.count() != 0 return query.count() != 0
def _is_publishing_cluster_template(self, values): def _is_publishing_cluster_template(self, values):
@@ -395,10 +407,10 @@ class Connection(api.Connection):
return True return True
return False return False
@oslo_db_api.retry_on_deadlock
def destroy_cluster_template(self, cluster_template_id): def destroy_cluster_template(self, cluster_template_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.ClusterTemplate)
query = model_query(models.ClusterTemplate, session=session)
query = add_identity_filter(query, cluster_template_id) query = add_identity_filter(query, cluster_template_id)
try: try:
@@ -422,10 +434,10 @@ class Connection(api.Connection):
return self._do_update_cluster_template(cluster_template_id, values) return self._do_update_cluster_template(cluster_template_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_cluster_template(self, cluster_template_id, values): def _do_update_cluster_template(self, cluster_template_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.ClusterTemplate)
query = model_query(models.ClusterTemplate, session=session)
query = add_identity_filter(query, cluster_template_id) query = add_identity_filter(query, cluster_template_id)
try: try:
ref = query.with_for_update().one() ref = query.with_for_update().one()
@@ -444,6 +456,7 @@ class Connection(api.Connection):
ref.update(values) ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock
def create_x509keypair(self, values): def create_x509keypair(self, values):
# ensure defaults are present for new x509keypairs # ensure defaults are present for new x509keypairs
if not values.get('uuid'): if not values.get('uuid'):
@@ -451,14 +464,18 @@ class Connection(api.Connection):
x509keypair = models.X509KeyPair() x509keypair = models.X509KeyPair()
x509keypair.update(values) x509keypair.update(values)
with _session_for_write() as session:
try: try:
x509keypair.save() session.add(x509keypair)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.X509KeyPairAlreadyExists(uuid=values['uuid']) raise exception.X509KeyPairAlreadyExists(uuid=values['uuid'])
return x509keypair return x509keypair
def get_x509keypair_by_id(self, context, x509keypair_id): def get_x509keypair_by_id(self, context, x509keypair_id):
query = model_query(models.X509KeyPair) with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(id=x509keypair_id) query = query.filter_by(id=x509keypair_id)
try: try:
@@ -467,18 +484,20 @@ class Connection(api.Connection):
raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id) raise exception.X509KeyPairNotFound(x509keypair=x509keypair_id)
def get_x509keypair_by_uuid(self, context, x509keypair_uuid): def get_x509keypair_by_uuid(self, context, x509keypair_uuid):
query = model_query(models.X509KeyPair) with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=x509keypair_uuid) query = query.filter_by(uuid=x509keypair_uuid)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
raise exception.X509KeyPairNotFound(x509keypair=x509keypair_uuid) raise exception.X509KeyPairNotFound(
x509keypair=x509keypair_uuid)
@oslo_db_api.retry_on_deadlock
def destroy_x509keypair(self, x509keypair_id): def destroy_x509keypair(self, x509keypair_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.X509KeyPair)
query = model_query(models.X509KeyPair, session=session)
query = add_identity_filter(query, x509keypair_id) query = add_identity_filter(query, x509keypair_id)
count = query.delete() count = query.delete()
if count != 1: if count != 1:
@@ -492,10 +511,10 @@ class Connection(api.Connection):
return self._do_update_x509keypair(x509keypair_id, values) return self._do_update_x509keypair(x509keypair_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_x509keypair(self, x509keypair_id, values): def _do_update_x509keypair(self, x509keypair_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.X509KeyPair)
query = model_query(models.X509KeyPair, session=session)
query = add_identity_filter(query, x509keypair_id) query = add_identity_filter(query, x509keypair_id)
try: try:
ref = query.with_for_update().one() ref = query.with_for_update().one()
@@ -518,26 +537,27 @@ class Connection(api.Connection):
def get_x509keypair_list(self, context, filters=None, limit=None, def get_x509keypair_list(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None): marker=None, sort_key=None, sort_dir=None):
query = model_query(models.X509KeyPair) with _session_for_read() as session:
query = session.query(models.X509KeyPair)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = self._add_x509keypairs_filters(query, filters) query = self._add_x509keypairs_filters(query, filters)
return _paginate_query(models.X509KeyPair, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.X509KeyPair, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def destroy_magnum_service(self, magnum_service_id): def destroy_magnum_service(self, magnum_service_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.MagnumService)
query = model_query(models.MagnumService, session=session)
query = add_identity_filter(query, magnum_service_id) query = add_identity_filter(query, magnum_service_id)
count = query.delete() count = query.delete()
if count != 1: if count != 1:
raise exception.MagnumServiceNotFound( raise exception.MagnumServiceNotFound(
magnum_service_id=magnum_service_id) magnum_service_id=magnum_service_id)
@oslo_db_api.retry_on_deadlock
def update_magnum_service(self, magnum_service_id, values): def update_magnum_service(self, magnum_service_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.MagnumService)
query = model_query(models.MagnumService, session=session)
query = add_identity_filter(query, magnum_service_id) query = add_identity_filter(query, magnum_service_id)
try: try:
ref = query.with_for_update().one() ref = query.with_for_update().one()
@@ -553,25 +573,32 @@ class Connection(api.Connection):
return ref return ref
def get_magnum_service_by_host_and_binary(self, host, binary): def get_magnum_service_by_host_and_binary(self, host, binary):
query = model_query(models.MagnumService) with _session_for_read() as session:
query = session.query(models.MagnumService)
query = query.filter_by(host=host, binary=binary) query = query.filter_by(host=host, binary=binary)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
return None return None
@oslo_db_api.retry_on_deadlock
def create_magnum_service(self, values): def create_magnum_service(self, values):
magnum_service = models.MagnumService() magnum_service = models.MagnumService()
magnum_service.update(values) magnum_service.update(values)
with _session_for_write() as session:
try: try:
magnum_service.save() session.add(magnum_service)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
host = values["host"] host = values["host"]
binary = values["binary"] binary = values["binary"]
LOG.warning("Magnum service with same host:%(host)s and" LOG.warning(
"Magnum service with same host:%(host)s and"
" binary:%(binary)s had been saved into DB", " binary:%(binary)s had been saved into DB",
{'host': host, 'binary': binary}) {'host': host, 'binary': binary})
query = model_query(models.MagnumService) with _session_for_read() as read_session:
query = read_session.query(models.MagnumService)
query = query.filter_by(host=host, binary=binary) query = query.filter_by(host=host, binary=binary)
return query.one() return query.one()
return magnum_service return magnum_service
@@ -579,20 +606,26 @@ class Connection(api.Connection):
def get_magnum_service_list(self, disabled=None, limit=None, def get_magnum_service_list(self, disabled=None, limit=None,
marker=None, sort_key=None, sort_dir=None marker=None, sort_key=None, sort_dir=None
): ):
query = model_query(models.MagnumService) with _session_for_read() as session:
query = session.query(models.MagnumService)
if disabled: if disabled:
query = query.filter_by(disabled=disabled) query = query.filter_by(disabled=disabled)
return _paginate_query(models.MagnumService, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.MagnumService, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_quota(self, values): def create_quota(self, values):
quotas = models.Quota() quotas = models.Quota()
quotas.update(values) quotas.update(values)
with _session_for_write() as session:
try: try:
quotas.save() session.add(quotas)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.QuotaAlreadyExists(project_id=values['project_id'], raise exception.QuotaAlreadyExists(
project_id=values['project_id'],
resource=values['resource']) resource=values['resource'])
return quotas return quotas
@@ -611,15 +644,16 @@ class Connection(api.Connection):
def get_quota_list(self, context, filters=None, limit=None, marker=None, def get_quota_list(self, context, filters=None, limit=None, marker=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
query = model_query(models.Quota) with _session_for_read() as session:
query = session.query(models.Quota)
query = self._add_quota_filters(query, filters) query = self._add_quota_filters(query, filters)
return _paginate_query(models.Quota, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.Quota, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def update_quota(self, project_id, values): def update_quota(self, project_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.Quota)
query = model_query(models.Quota, session=session)
resource = values['resource'] resource = values['resource']
try: try:
query = query.filter_by(project_id=project_id).filter_by( query = query.filter_by(project_id=project_id).filter_by(
@@ -633,12 +667,13 @@ class Connection(api.Connection):
ref.update(values) ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock
def delete_quota(self, project_id, resource): def delete_quota(self, project_id, resource):
session = get_session() with _session_for_write() as session:
with session.begin(): query = (
query = model_query(models.Quota, session=session) \ session.query(models.Quota)
.filter_by(project_id=project_id) \ .filter_by(project_id=project_id)
.filter_by(resource=resource) .filter_by(resource=resource))
try: try:
query.one() query.one()
@@ -650,7 +685,8 @@ class Connection(api.Connection):
query.delete() query.delete()
def get_quota_by_id(self, context, quota_id): def get_quota_by_id(self, context, quota_id):
query = model_query(models.Quota) with _session_for_read() as session:
query = session.query(models.Quota)
query = query.filter_by(id=quota_id) query = query.filter_by(id=quota_id)
try: try:
return query.one() return query.one()
@@ -659,13 +695,15 @@ class Connection(api.Connection):
raise exception.QuotaNotFound(msg=msg) raise exception.QuotaNotFound(msg=msg)
def quota_get_all_by_project_id(self, project_id): def quota_get_all_by_project_id(self, project_id):
query = model_query(models.Quota) with _session_for_read() as session:
query = session.query(models.Quota)
result = query.filter_by(project_id=project_id).all() result = query.filter_by(project_id=project_id).all()
return result return result
def get_quota_by_project_id_resource(self, project_id, resource): def get_quota_by_project_id_resource(self, project_id, resource):
query = model_query(models.Quota) with _session_for_read() as session:
query = session.query(models.Quota)
query = query.filter_by(project_id=project_id).filter_by( query = query.filter_by(project_id=project_id).filter_by(
resource=resource) resource=resource)
@@ -701,7 +739,8 @@ class Connection(api.Connection):
return query return query
def get_federation_by_id(self, context, federation_id): def get_federation_by_id(self, context, federation_id):
query = model_query(models.Federation) with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(id=federation_id) query = query.filter_by(id=federation_id)
try: try:
@@ -710,7 +749,8 @@ class Connection(api.Connection):
raise exception.FederationNotFound(federation=federation_id) raise exception.FederationNotFound(federation=federation_id)
def get_federation_by_uuid(self, context, federation_uuid): def get_federation_by_uuid(self, context, federation_uuid):
query = model_query(models.Federation) with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(uuid=federation_uuid) query = query.filter_by(uuid=federation_uuid)
try: try:
@@ -719,13 +759,15 @@ class Connection(api.Connection):
raise exception.FederationNotFound(federation=federation_uuid) raise exception.FederationNotFound(federation=federation_uuid)
def get_federation_by_name(self, context, federation_name): def get_federation_by_name(self, context, federation_name):
query = model_query(models.Federation) with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = query.filter_by(name=federation_name) query = query.filter_by(name=federation_name)
try: try:
return query.one() return query.one()
except MultipleResultsFound: except MultipleResultsFound:
raise exception.Conflict('Multiple federations exist with same ' raise exception.Conflict(
'Multiple federations exist with same '
'name. Please use the federation uuid ' 'name. Please use the federation uuid '
'instead.') 'instead.')
except NoResultFound: except NoResultFound:
@@ -733,28 +775,33 @@ class Connection(api.Connection):
def get_federation_list(self, context, limit=None, marker=None, def get_federation_list(self, context, limit=None, marker=None,
sort_key=None, sort_dir=None, filters=None): sort_key=None, sort_dir=None, filters=None):
query = model_query(models.Federation) with _session_for_read() as session:
query = session.query(models.Federation)
query = self._add_tenant_filters(context, query) query = self._add_tenant_filters(context, query)
query = self._add_federation_filters(query, filters) query = self._add_federation_filters(query, filters)
return _paginate_query(models.Federation, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.Federation, limit, marker, sort_key, sort_dir, query)
@oslo_db_api.retry_on_deadlock
def create_federation(self, values): def create_federation(self, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
federation = models.Federation() federation = models.Federation()
federation.update(values) federation.update(values)
with _session_for_write() as session:
try: try:
federation.save() session.add(federation)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.FederationAlreadyExists(uuid=values['uuid']) raise exception.FederationAlreadyExists(uuid=values['uuid'])
return federation return federation
@oslo_db_api.retry_on_deadlock
def destroy_federation(self, federation_id): def destroy_federation(self, federation_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.Federation)
query = model_query(models.Federation, session=session)
query = add_identity_filter(query, federation_id) query = add_identity_filter(query, federation_id)
try: try:
@ -771,10 +818,10 @@ class Connection(api.Connection):
return self._do_update_federation(federation_id, values) return self._do_update_federation(federation_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_federation(self, federation_id, values): def _do_update_federation(self, federation_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.Federation)
query = model_query(models.Federation, session=session)
query = add_identity_filter(query, federation_id) query = add_identity_filter(query, federation_id)
try: try:
ref = query.with_for_update().one() ref = query.with_for_update().one()
@ -804,23 +851,27 @@ class Connection(api.Connection):
return query return query
@oslo_db_api.retry_on_deadlock
def create_nodegroup(self, values): def create_nodegroup(self, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
nodegroup = models.NodeGroup() nodegroup = models.NodeGroup()
nodegroup.update(values) nodegroup.update(values)
with _session_for_write() as session:
try: try:
nodegroup.save() session.add(nodegroup)
session.flush()
except db_exc.DBDuplicateEntry: except db_exc.DBDuplicateEntry:
raise exception.NodeGroupAlreadyExists( raise exception.NodeGroupAlreadyExists(
cluster_id=values['cluster_id'], name=values['name']) cluster_id=values['cluster_id'], name=values['name'])
return nodegroup return nodegroup
@oslo_db_api.retry_on_deadlock
def destroy_nodegroup(self, cluster_id, nodegroup_id): def destroy_nodegroup(self, cluster_id, nodegroup_id):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.NodeGroup)
query = model_query(models.NodeGroup, session=session)
query = add_identity_filter(query, nodegroup_id) query = add_identity_filter(query, nodegroup_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
try: try:
@ -832,10 +883,10 @@ class Connection(api.Connection):
def update_nodegroup(self, cluster_id, nodegroup_id, values): def update_nodegroup(self, cluster_id, nodegroup_id, values):
return self._do_update_nodegroup(cluster_id, nodegroup_id, values) return self._do_update_nodegroup(cluster_id, nodegroup_id, values)
@oslo_db_api.retry_on_deadlock
def _do_update_nodegroup(self, cluster_id, nodegroup_id, values): def _do_update_nodegroup(self, cluster_id, nodegroup_id, values):
session = get_session() with _session_for_write() as session:
with session.begin(): query = session.query(models.NodeGroup)
query = model_query(models.NodeGroup, session=session)
query = add_identity_filter(query, nodegroup_id) query = add_identity_filter(query, nodegroup_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
try: try:
@ -847,7 +898,8 @@ class Connection(api.Connection):
return ref return ref
def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id): def get_nodegroup_by_id(self, context, cluster_id, nodegroup_id):
query = model_query(models.NodeGroup) with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin: if not context.is_admin:
query = query.filter_by(project_id=context.project_id) query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
@ -858,7 +910,8 @@ class Connection(api.Connection):
raise exception.NodeGroupNotFound(nodegroup=nodegroup_id) raise exception.NodeGroupNotFound(nodegroup=nodegroup_id)
def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid): def get_nodegroup_by_uuid(self, context, cluster_id, nodegroup_uuid):
query = model_query(models.NodeGroup) with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin: if not context.is_admin:
query = query.filter_by(project_id=context.project_id) query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
@ -869,7 +922,8 @@ class Connection(api.Connection):
raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid) raise exception.NodeGroupNotFound(nodegroup=nodegroup_uuid)
def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name): def get_nodegroup_by_name(self, context, cluster_id, nodegroup_name):
query = model_query(models.NodeGroup) with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin: if not context.is_admin:
query = query.filter_by(project_id=context.project_id) query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
@ -877,7 +931,8 @@ class Connection(api.Connection):
try: try:
return query.one() return query.one()
except MultipleResultsFound: except MultipleResultsFound:
raise exception.Conflict('Multiple nodegroups exist with same ' raise exception.Conflict(
'Multiple nodegroups exist with same '
'name. Please use the nodegroup uuid ' 'name. Please use the nodegroup uuid '
'instead.') 'instead.')
except NoResultFound: except NoResultFound:
@ -886,16 +941,18 @@ class Connection(api.Connection):
def list_cluster_nodegroups(self, context, cluster_id, filters=None, def list_cluster_nodegroups(self, context, cluster_id, filters=None,
limit=None, marker=None, sort_key=None, limit=None, marker=None, sort_key=None,
sort_dir=None): sort_dir=None):
query = model_query(models.NodeGroup) with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin: if not context.is_admin:
query = query.filter_by(project_id=context.project_id) query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)
query = self._add_nodegoup_filters(query, filters) query = self._add_nodegoup_filters(query, filters)
return _paginate_query(models.NodeGroup, limit, marker, return _paginate_query(
sort_key, sort_dir, query) models.NodeGroup, limit, marker, sort_key, sort_dir, query)
def get_cluster_nodegroup_count(self, context, cluster_id): def get_cluster_nodegroup_count(self, context, cluster_id):
query = model_query(models.NodeGroup) with _session_for_read() as session:
query = session.query(models.NodeGroup)
if not context.is_admin: if not context.is_admin:
query = query.filter_by(project_id=context.project_id) query = query.filter_by(project_id=context.project_id)
query = query.filter_by(cluster_id=cluster_id) query = query.filter_by(cluster_id=cluster_id)

View File

@ -87,15 +87,6 @@ class MagnumBase(models.TimestampMixin,
d[c.name] = self[c.name] d[c.name] = self[c.name]
return d return d
def save(self, session=None):
import magnum.db.sqlalchemy.api as db_api
if session is None:
session = db_api.get_session()
with session.begin():
super(MagnumBase, self).save(session)
Base = declarative_base(cls=MagnumBase) Base = declarative_base(cls=MagnumBase)

View File

@ -16,10 +16,10 @@
"""Magnum DB test base class.""" """Magnum DB test base class."""
import fixtures import fixtures
from oslo_db.sqlalchemy import enginefacade
import magnum.conf import magnum.conf
from magnum.db import api as dbapi from magnum.db import api as dbapi
from magnum.db.sqlalchemy import api as sqla_api
from magnum.db.sqlalchemy import migration from magnum.db.sqlalchemy import migration
from magnum.db.sqlalchemy import models from magnum.db.sqlalchemy import models
from magnum.tests import base from magnum.tests import base
@ -32,15 +32,14 @@ _DB_CACHE = None
class Database(fixtures.Fixture): class Database(fixtures.Fixture):
def __init__(self, db_api, db_migrate, sql_connection): def __init__(self, engine, db_migrate, sql_connection):
self.sql_connection = sql_connection self.sql_connection = sql_connection
self.engine = db_api.get_engine() self.engine = engine
self.engine.dispose() self.engine.dispose()
conn = self.engine.connect() with self.engine.connect() as conn:
self.setup_sqlite(db_migrate) self.setup_sqlite(db_migrate)
self.post_migrations() self.post_migrations()
self._DB = "".join(line for line in conn.connection.iterdump()) self._DB = "".join(line for line in conn.connection.iterdump())
self.engine.dispose() self.engine.dispose()
@ -50,8 +49,9 @@ class Database(fixtures.Fixture):
models.Base.metadata.create_all(self.engine) models.Base.metadata.create_all(self.engine)
db_migrate.stamp('head') db_migrate.stamp('head')
def _setUp(self): def setUp(self):
conn = self.engine.connect() super(Database, self).setUp()
with self.engine.connect() as conn:
conn.connection.executescript(self._DB) conn.connection.executescript(self._DB)
self.addCleanup(self.engine.dispose) self.addCleanup(self.engine.dispose)
@ -68,6 +68,8 @@ class DbTestCase(base.TestCase):
global _DB_CACHE global _DB_CACHE
if not _DB_CACHE: if not _DB_CACHE:
_DB_CACHE = Database(sqla_api, migration, engine = enginefacade.writer.get_engine()
_DB_CACHE = Database(engine, migration,
sql_connection=CONF.database.connection) sql_connection=CONF.database.connection)
engine.dispose()
self.useFixture(_DB_CACHE) self.useFixture(_DB_CACHE)

View File

@ -26,16 +26,22 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
# Create ClusterTemplate w/o labels # Create ClusterTemplate w/o labels
cluster_template1_id = uuidutils.generate_uuid() cluster_template1_id = uuidutils.generate_uuid()
self.dbapi.create_cluster_template({'uuid': cluster_template1_id}) self.dbapi.create_cluster_template({'uuid': cluster_template1_id})
cluster_template1 = sa_api.model_query( with sa_api._session_for_read() as session:
models.ClusterTemplate).filter_by(uuid=cluster_template1_id).one() cluster_template1 = (session.query(
models.ClusterTemplate)
.filter_by(uuid=cluster_template1_id)
.one())
self.assertEqual({}, cluster_template1.labels) self.assertEqual({}, cluster_template1.labels)
# Create ClusterTemplate with labels # Create ClusterTemplate with labels
cluster_template2_id = uuidutils.generate_uuid() cluster_template2_id = uuidutils.generate_uuid()
self.dbapi.create_cluster_template( self.dbapi.create_cluster_template(
{'uuid': cluster_template2_id, 'labels': {'bar': 'foo'}}) {'uuid': cluster_template2_id, 'labels': {'bar': 'foo'}})
cluster_template2 = sa_api.model_query( with sa_api._session_for_read() as session:
models.ClusterTemplate).filter_by(uuid=cluster_template2_id).one() cluster_template2 = (session.query(
models.ClusterTemplate)
.filter_by(uuid=cluster_template2_id)
.one())
self.assertEqual('foo', cluster_template2.labels['bar']) self.assertEqual('foo', cluster_template2.labels['bar'])
def test_JSONEncodedDict_type_check(self): def test_JSONEncodedDict_type_check(self):
@ -48,7 +54,8 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
# Create nodegroup w/o node_addresses # Create nodegroup w/o node_addresses
nodegroup1_id = uuidutils.generate_uuid() nodegroup1_id = uuidutils.generate_uuid()
self.dbapi.create_nodegroup({'uuid': nodegroup1_id}) self.dbapi.create_nodegroup({'uuid': nodegroup1_id})
nodegroup1 = sa_api.model_query( with sa_api._session_for_read() as session:
nodegroup1 = session.query(
models.NodeGroup).filter_by(uuid=nodegroup1_id).one() models.NodeGroup).filter_by(uuid=nodegroup1_id).one()
self.assertEqual([], nodegroup1.node_addresses) self.assertEqual([], nodegroup1.node_addresses)
@ -59,7 +66,8 @@ class SqlAlchemyCustomTypesTestCase(base.DbTestCase):
'node_addresses': ['mynode_address1', 'node_addresses': ['mynode_address1',
'mynode_address2'] 'mynode_address2']
}) })
nodegroup2 = sa_api.model_query( with sa_api._session_for_read() as session:
nodegroup2 = session.query(
models.NodeGroup).filter_by(uuid=nodegroup2_id).one() models.NodeGroup).filter_by(uuid=nodegroup2_id).one()
self.assertEqual(['mynode_address1', 'mynode_address2'], self.assertEqual(['mynode_address1', 'mynode_address2'],
nodegroup2.node_addresses) nodegroup2.node_addresses)