gyan/gyan/db/sqlalchemy/api.py

386 lines
14 KiB
Python

# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""SQLAlchemy storage backend."""
from oslo_log import log as logging
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session
from oslo_db.sqlalchemy import utils as db_utils
from oslo_utils import importutils
from oslo_utils import strutils
from oslo_utils import timeutils
from oslo_utils import uuidutils
import sqlalchemy as sa
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.sql.expression import desc
from sqlalchemy.sql import func
from gyan.common import consts
from gyan.common import exception
from gyan.common.i18n import _
import gyan.conf
from gyan.db.sqlalchemy import models
profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
CONF = gyan.conf.CONF
_FACADE = None
LOG = logging.getLogger(__name__)
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = db_session.enginefacade.get_legacy_facade()
if profiler_sqlalchemy:
if CONF.profiler.enabled and CONF.profiler.trace_sqlalchemy:
profiler_sqlalchemy.add_tracing(sa, _FACADE.get_engine(), "db")
return _FACADE
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
def get_session(**kwargs):
facade = _create_facade_lazily()
return facade.get_session(**kwargs)
def get_backend():
"""The backend is this module itself."""
return Connection()
def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage.
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(model, *args)
return query
def add_identity_filter(query, value):
"""Adds an identity filter to a query.
Filters results by ID, if supplied value is a valid integer.
Otherwise attempts to filter results by UUID.
:param query: Initial query to add filter to.
:param value: Value for filtering results by.
:return: Modified query.
"""
if strutils.is_int_like(value):
return query.filter_by(id=value)
elif uuidutils.is_uuid_like(value):
return query.filter_by(id=value)
elif "gyan-" in value:
return query.filter_by(id=value)
else:
raise exception.InvalidIdentity(identity=value)
def _paginate_query(model, limit=None, marker=None, sort_key=None,
sort_dir=None, query=None, default_sort_key='id'):
if not query:
query = model_query(model)
sort_keys = [default_sort_key]
if sort_key and sort_key not in sort_keys:
sort_keys.insert(0, sort_key)
try:
query = db_utils.paginate_query(query, model, limit, sort_keys,
marker=marker, sort_dir=sort_dir)
except db_exc.InvalidSortKey:
raise exception.InvalidParameterValue(
_('The sort_key value "%(key)s" is an invalid field for sorting')
% {'key': sort_key})
return query.all()
class Connection(object):
"""SqlAlchemy connection."""
def __init__(self):
pass
def _add_project_filters(self, context, query):
if context.is_admin and context.all_projects:
return query
if context.project_id:
query = query.filter_by(project_id=context.project_id)
else:
query = query.filter_by(user_id=context.user_id)
return query
def _add_filters(self, query, model, filters=None, filter_names=None):
"""Generic way to add filters to a Gyan model"""
if not filters:
return query
if not filter_names:
filter_names = []
for name in filter_names:
if name in filters:
value = filters[name]
if isinstance(value, list):
column = getattr(model, name)
query = query.filter(column.in_(value))
else:
column = getattr(model, name)
query = query.filter(column == value)
return query
def _add_compute_hosts_filters(self, query, filters):
filter_names = None
return self._add_filters(query, models.ComputeHost, filters=filters,
filter_names=filter_names)
def list_compute_hosts(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.ComputeHost)
query = self._add_compute_hosts_filters(query, filters)
return _paginate_query(models.ComputeHost, limit, marker,
sort_key, sort_dir, query,
default_sort_key='id')
def create_compute_host(self, context, values):
# ensure defaults are present for new compute hosts
if not values.get('id'):
values['id'] = uuidutils.generate_uuid()
compute_host = models.ComputeHost()
compute_host.update(values)
try:
compute_host.save()
except db_exc.DBDuplicateEntry:
raise exception.ComputeHostAlreadyExists(
field='UUID', value=values['uuid'])
return compute_host
def get_compute_host(self, context, host_uuid):
query = model_query(models.ComputeHost)
query = query.filter_by(id=host_uuid)
try:
return query.one()
except NoResultFound:
raise exception.ComputeHostNotFound(
compute_host=host_uuid)
def get_compute_host_by_hostname(self, context, hostname):
query = model_query(models.ComputeHost)
query = query.filter_by(hostname=hostname)
try:
return query.one()
except NoResultFound:
raise exception.ComputeHostNotFound(
compute_host=hostname)
except MultipleResultsFound:
raise exception.Conflict('Multiple compute hosts exist with same '
'hostname. Please use the uuid instead.')
def destroy_compute_host(self, context, host_uuid):
session = get_session()
with session.begin():
query = model_query(models.ComputeHost, session=session)
query = query.filter_by(uuid=host_uuid)
count = query.delete()
if count != 1:
raise exception.ComputeHostNotFound(
compute_host=host_uuid)
def update_compute_host(self, context, host_uuid, values):
if 'uuid' in values:
msg = _("Cannot overwrite UUID for an existing ComputeHost.")
raise exception.InvalidParameterValue(err=msg)
return self._do_update_compute_host(host_uuid, values)
def _do_update_compute_host(self, host_uuid, values):
session = get_session()
with session.begin():
query = model_query(models.ComputeHost, session=session)
query = query.filter_by(uuid=host_uuid)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
raise exception.ComputeHostNotFound(
compute_host=host_uuid)
ref.update(values)
return ref
def list_ml_models(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = self._add_ml_models_filters(query, filters)
LOG.debug(filters)
return _paginate_query(models.ML_Model, limit, marker,
sort_key, sort_dir, query)
def create_ml_model(self, context, values):
# ensure defaults are present for new ml_models
if not values.get('id'):
values['id'] = 'gyan-' + uuidutils.generate_uuid()
ml_model = models.ML_Model()
ml_model.update(values)
try:
ml_model.save()
except db_exc.DBDuplicateEntry:
raise exception.MLModelAlreadyExists(field='UUID',
value=values['uuid'])
return ml_model
def get_ml_model_by_uuid(self, context, ml_model_uuid):
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = query.filter_by(id=ml_model_uuid)
try:
return query.one()
except NoResultFound:
raise exception.MLModelNotFound(ml_model=ml_model_uuid)
def get_ml_model_by_name(self, context, ml_model_name):
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = query.filter_by(name=ml_model_name)
try:
return query.one()
except NoResultFound:
raise exception.MLModelNotFound(ml_model=ml_model_name)
except MultipleResultsFound:
raise exception.Conflict('Multiple ml_models exist with same '
'name. Please use the ml_model uuid '
'instead.')
def destroy_ml_model(self, context, ml_model_id):
session = get_session()
with session.begin():
query = model_query(models.ML_Model, session=session)
query = add_identity_filter(query, ml_model_id)
count = query.delete()
if count != 1:
raise exception.MLModelNotFound(ml_model_id)
def update_ml_model(self, context, ml_model_id, values):
if 'uuid' in values:
msg = _("Cannot overwrite UUID for an existing ML Model.")
raise exception.InvalidParameterValue(err=msg)
return self._do_update_ml_model_id(ml_model_id, values)
def _do_update_ml_model_id(self, ml_model_id, values):
session = get_session()
with session.begin():
query = model_query(models.ML_Model, session=session)
query = add_identity_filter(query, ml_model_id)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
raise exception.MLModelNotFound(ml_model=ml_model_id)
ref.update(values)
return ref
def _add_ml_models_filters(self, query, filters):
filter_names = ['uuid', 'project_id', 'user_id']
return self._add_filters(query, models.ML_Model, filters=filters,
filter_names=filter_names)
def list_flavors(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.Flavor)
query = self._add_flavors_filters(query, filters)
LOG.debug(filters)
return _paginate_query(models.Flavor, limit, marker,
sort_key, sort_dir, query)
def create_flavor(self, context, values):
# ensure defaults are present for new flavors
if not values.get('id'):
values['id'] = uuidutils.generate_uuid()
flavor = models.Flavor()
flavor.update(values)
try:
flavor.save()
except db_exc.DBDuplicateEntry:
raise exception.FlavorAlreadyExists(field='UUID',
value=values['uuid'])
return flavor
def get_flavor_by_uuid(self, context, flavor_uuid):
query = model_query(models.Flavor)
query = self._add_project_filters(context, query)
query = query.filter_by(id=flavor_uuid)
try:
return query.one()
except NoResultFound:
raise exception.FlavorNotFound(flavor=flavor_uuid)
def get_flavor_by_name(self, context, flavor_name):
query = model_query(models.Flavor)
query = self._add_project_filters(context, query)
query = query.filter_by(name=flavor_name)
try:
return query.one()
except NoResultFound:
raise exception.FlavorNotFound(flavor=flavor_name)
except MultipleResultsFound:
raise exception.Conflict('Multiple flavors exist with same '
'name. Please use the flavor uuid '
'instead.')
def destroy_flavor(self, context, flavor_id):
session = get_session()
with session.begin():
query = model_query(models.Flavor, session=session)
query = add_identity_filter(query, flavor_id)
count = query.delete()
if count != 1:
raise exception.FlavorNotFound(flavor_id)
def update_flavor(self, context, flavor_id, values):
if 'id' in values:
msg = _("Cannot overwrite UUID for an existing ML Model.")
raise exception.InvalidParameterValue(err=msg)
return self._do_update_flavor_id(flavor_id, values)
def _do_update_flavor_id(self, flavor_id, values):
session = get_session()
with session.begin():
query = model_query(models.Flavor, session=session)
query = add_identity_filter(query, flavor_id)
try:
ref = query.with_lockmode('update').one()
except NoResultFound:
raise exception.FlavorNotFound(flavor=flavor_id)
ref.update(values)
return ref
def _add_flavors_filters(self, query, filters):
filter_names = ['id']
return self._add_filters(query, models.Flavor, filters=filters,
filter_names=filter_names)