309 lines
8.6 KiB
Python
309 lines
8.6 KiB
Python
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
|
# not use this file except in compliance with the License. You may obtain
|
|
# a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
# License for the specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""Implementation of SQLAlchemy backend."""
|
|
|
|
|
|
import functools
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
from oslo_config import cfg
|
|
from oslo_db import exception as db_exc
|
|
from oslo_db import options
|
|
from oslo_db.sqlalchemy import session as db_session
|
|
from oslo_log import log as logging
|
|
from oslo_utils import timeutils
|
|
from sqlalchemy.sql.expression import literal_column
|
|
from sqlalchemy.sql import func
|
|
|
|
from smaug.db.sqlalchemy import models
|
|
from smaug import exception
|
|
from smaug.i18n import _, _LW
|
|
|
|
|
|
# Module logger and the process-wide oslo.config object.
LOG = logging.getLogger(__name__)
CONF = cfg.CONF

# Register a default [database]/connection: a SQLite file under $state_path.
options.set_defaults(CONF, connection='sqlite:///$state_path/smaug.sqlite')

# Lazily-built EngineFacade (populated by _create_facade_lazily) and the
# lock that serializes its construction.
_FACADE = None
_LOCK = threading.Lock()
|
|
|
|
|
|
def _create_facade_lazily():
    """Return the shared EngineFacade, building it on the first call.

    Construction happens while holding ``_LOCK`` so that concurrent first
    callers create only one facade; later calls reuse the cached
    ``_FACADE``. The facade is built from ``CONF.database.connection`` plus
    every option in the ``[database]`` group.

    :returns: the module-wide ``db_session.EngineFacade`` instance.
    """
    # Only _FACADE is rebound here. _LOCK is merely read, so it needs no
    # ``global`` declaration.
    global _FACADE
    with _LOCK:
        if _FACADE is None:
            _FACADE = db_session.EngineFacade(
                CONF.database.connection,
                **dict(CONF.database)
            )
    return _FACADE
|
|
|
|
|
|
def get_engine():
    """Return the engine held by the shared EngineFacade."""
    return _create_facade_lazily().get_engine()
|
|
|
|
|
|
def get_session(**kwargs):
    """Return a session from the shared EngineFacade.

    All keyword arguments are forwarded unchanged to the facade's
    ``get_session``.
    """
    return _create_facade_lazily().get_session(**kwargs)
|
|
|
|
|
|
def dispose_engine():
    """Call ``dispose()`` on the shared engine returned by ``get_engine``."""
    engine = get_engine()
    engine.dispose()
|
|
|
|
_DEFAULT_QUOTA_NAME = 'default'


def get_backend():
    """Return the DB API backend object, which is this very module."""
    this_module = sys.modules[__name__]
    return this_module
|
|
|
|
|
|
def is_admin_context(context):
    """Indicates if the request context is an administrator.

    :param context: the request context; must be truthy.
    :returns: ``context.is_admin`` unchanged.
    :raises ValueError: if ``context`` is empty/falsy. A deprecation warning
        is logged first. ValueError subclasses Exception, so existing
        ``except Exception`` handlers still catch it.
    """
    if not context:
        # The message has no %-placeholders, so no extra logging argument may
        # be passed. A stray positional argument makes the logging module fail
        # to format the record instead of emitting this warning.
        LOG.warning(_LW('Use of empty request context is deprecated'))
        raise ValueError(_('Request context must not be empty'))
    return context.is_admin
|
|
|
|
|
|
def is_user_context(context):
    """Indicates if the request context is a normal (non-admin) user.

    Returns True only when the context is truthy, is not an admin, and
    carries both a truthy ``user_id`` and a truthy ``project_id``;
    otherwise returns False.
    """
    return (bool(context)
            and not context.is_admin
            and bool(context.user_id)
            and bool(context.project_id))
|
|
|
|
|
|
def authorize_project_context(context, project_id):
    """Ensures a request has permission to access the given project.

    Only normal user contexts (as judged by ``is_user_context``) are
    checked; admin and empty contexts are let through unchanged.

    :param context: the request context.
    :param project_id: the project being accessed.
    :raises exception.NotAuthorized: if a user context's ``project_id``
        differs from ``project_id``.
    """
    # is_user_context() already guarantees a truthy context.project_id, so a
    # separate "missing project_id" check here would be unreachable.
    if is_user_context(context) and context.project_id != project_id:
        raise exception.NotAuthorized()
|
|
|
|
|
|
def authorize_user_context(context, user_id):
    """Ensures a request has permission to access the given user.

    Only normal user contexts (as judged by ``is_user_context``) are
    checked; admin and empty contexts are let through unchanged.

    :param context: the request context.
    :param user_id: the user being accessed.
    :raises exception.NotAuthorized: if a user context's ``user_id`` differs
        from ``user_id``.
    """
    # is_user_context() already guarantees a truthy context.user_id, so a
    # separate "missing user_id" check here would be unreachable.
    if is_user_context(context) and context.user_id != user_id:
        raise exception.NotAuthorized()
|
|
|
|
|
|
def require_admin_context(f):
    """Decorator to require admin request context.

    The first argument to the wrapped function must be the context. The
    wrapper raises ``exception.AdminRequired`` when ``is_admin_context``
    reports a non-admin context; otherwise it calls ``f`` unchanged.
    """

    # functools.wraps keeps f's __name__/__doc__ on the wrapper, matching
    # _retry_on_deadlock and keeping logs/introspection accurate.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not is_admin_context(args[0]):
            raise exception.AdminRequired()
        return f(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def require_context(f):
    """Decorator to require *any* user or admin context.

    This does no authorization for user or project access matching, see
    :py:func:`authorize_project_context` and
    :py:func:`authorize_user_context`.

    The first argument to the wrapped function must be the context. The
    wrapper raises ``exception.NotAuthorized`` when that context is neither
    an admin context nor a user context.
    """

    # functools.wraps keeps f's __name__/__doc__ on the wrapper, matching
    # _retry_on_deadlock and keeping logs/introspection accurate.
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not is_admin_context(args[0]) and not is_user_context(args[0]):
            raise exception.NotAuthorized()
        return f(*args, **kwargs)
    return wrapper
|
|
|
|
|
|
def _retry_on_deadlock(f):
    """Decorator to retry a DB API call if Deadlock was received.

    Each time the wrapped call raises ``db_exc.DBDeadlock`` a warning is
    logged, the wrapper sleeps 0.5 seconds and calls ``f`` again with the
    same arguments. Any other exception propagates immediately.

    NOTE(review): there is no retry limit, so a deadlock that never clears
    keeps the caller looping indefinitely.
    """
    # functools.wraps already copies __name__/__doc__/etc.; no separate
    # update_wrapper call is needed.
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        while True:
            try:
                return f(*args, **kwargs)
            except db_exc.DBDeadlock:
                LOG.warning(_LW("Deadlock detected when running "
                                "'%(func_name)s': Retrying..."),
                            dict(func_name=f.__name__))
                # Back off briefly before the next attempt.
                time.sleep(0.5)
    return wrapped
|
|
|
|
|
|
def model_query(context, *args, **kwargs):
    """Build a query that honours the context's ``read_deleted`` setting.

    :param context: request context; supplies ``read_deleted`` unless it is
        overridden, and ``project_id`` when project scoping applies.
    :param args: entities/columns passed straight to ``session.query``.
    :param session: optional session to query with; when absent or falsy,
        one is obtained from ``get_session()``.
    :param read_deleted: optional override of ``context.read_deleted``:
        ``'no'`` filters ``deleted=False``, ``'only'`` filters
        ``deleted=True``, ``'yes'`` applies no deleted filter.
    :param project_only: when truthy and the context is a normal user
        context, additionally filter on ``context.project_id``.
    :raises Exception: for any other ``read_deleted`` value.
    """
    session = kwargs.get('session') or get_session()
    read_deleted = kwargs.get('read_deleted') or context.read_deleted
    project_only = kwargs.get('project_only')

    query = session.query(*args)

    if read_deleted in ('no', 'only'):
        # 'no' keeps rows with deleted=False; 'only' keeps deleted=True.
        query = query.filter_by(deleted=(read_deleted == 'only'))
    elif read_deleted != 'yes':
        raise Exception(
            _("Unrecognized read_deleted value '%s'") % read_deleted)
    # 'yes' falls through with no deleted filter at all.

    if project_only and is_user_context(context):
        query = query.filter_by(project_id=context.project_id)

    return query
|
|
|
|
|
|
@require_admin_context
def service_destroy(context, service_id):
    """Call ``delete()`` on service ``service_id`` within one transaction.

    :raises exception.ServiceNotFound: if no such service exists.
    """
    session = get_session()
    with session.begin():
        target = _service_get(context, service_id, session=session)
        target.delete(session=session)
|
|
|
|
|
|
@require_admin_context
def _service_get(context, service_id, session=None):
    """Return the first Service whose ``id`` equals ``service_id``.

    :param session: optional session forwarded to ``model_query``.
    :raises exception.ServiceNotFound: when no row matches.
    """
    query = model_query(context, models.Service, session=session)
    service = query.filter_by(id=service_id).first()
    if service:
        return service
    raise exception.ServiceNotFound(service_id=service_id)
|
|
|
|
|
|
@require_admin_context
def service_get(context, service_id):
    """Return the service with id ``service_id`` (admin only).

    :raises exception.ServiceNotFound: when it does not exist.
    """
    service = _service_get(context, service_id)
    return service
|
|
|
|
|
|
@require_admin_context
def service_get_all(context, disabled=None):
    """List services, optionally filtered on their ``disabled`` flag.

    :param disabled: None applies no filter; any other value is matched
        against the ``disabled`` column.
    :returns: list of Service rows.
    """
    query = model_query(context, models.Service)
    if disabled is None:
        return query.all()
    return query.filter_by(disabled=disabled).all()
|
|
|
|
|
|
@require_admin_context
def service_get_all_by_topic(context, topic, disabled=None):
    """List non-deleted services for ``topic``.

    :param disabled: None applies no filter; any other value is matched
        against the ``disabled`` column.
    :returns: list of Service rows.
    """
    base = model_query(context, models.Service, read_deleted="no")
    by_topic = base.filter_by(topic=topic)
    if disabled is None:
        return by_topic.all()
    return by_topic.filter_by(disabled=disabled).all()
|
|
|
|
|
|
@require_admin_context
def service_get_by_host_and_topic(context, host, topic):
    """Return the first non-deleted, enabled service on ``host`` for ``topic``.

    :raises exception.ServiceNotFound: when nothing matches.
    """
    query = model_query(context, models.Service, read_deleted="no")
    query = query.filter_by(disabled=False)
    query = query.filter_by(host=host)
    query = query.filter_by(topic=topic)
    service = query.first()
    if service:
        return service
    raise exception.ServiceNotFound(service_id=None)
|
|
|
|
|
|
@require_admin_context
def _service_get_all_topic_subquery(context, session, topic, subq, label):
    """Return enabled, non-deleted ``topic`` services joined to a subquery.

    Each result row is ``(Service, coalesce(subq.<label>, 0))``. Services
    with no matching subquery row are kept (outer join on ``host``) and get
    0 in the second column. Rows are ordered by the raw ``subq.<label>``
    column.

    :param session: session forwarded to ``model_query``.
    :param subq: subquery exposing ``host`` and ``label`` columns via ``.c``.
    :param label: name of the ``subq`` column to coalesce and sort on.
    """
    sort_value = getattr(subq.c, label)
    # Pass target and ON clause as separate arguments; the legacy tuple form
    # ``outerjoin((target, onclause))`` is deprecated in newer SQLAlchemy.
    return model_query(context, models.Service,
                       func.coalesce(sort_value, 0),
                       session=session, read_deleted="no").\
        filter_by(topic=topic).\
        filter_by(disabled=False).\
        outerjoin(subq, models.Service.host == subq.c.host).\
        order_by(sort_value).\
        all()
|
|
|
|
|
|
@require_admin_context
def service_get_by_args(context, host, binary):
    """Return the service on ``host`` running ``binary``.

    Rows from the SQL filter are re-checked in Python with ``==`` on the
    host, and the first exact match is returned.
    NOTE(review): presumably this guards against case-insensitive DB
    collations matching a differently-cased host name; confirm.

    :raises exception.HostBinaryNotFound: if no exact match exists.
    """
    candidates = model_query(context, models.Service).\
        filter_by(host=host).\
        filter_by(binary=binary).\
        all()

    for candidate in candidates:
        if host == candidate['host']:
            break
    else:
        raise exception.HostBinaryNotFound(host=host, binary=binary)
    return candidate
|
|
|
|
|
|
@require_admin_context
def service_create(context, values):
    """Build a Service from ``values`` and save it in a transaction.

    When ``CONF.enable_new_services`` is falsy the new service is saved
    with ``disabled = True``.

    :returns: the saved Service model instance.
    """
    new_service = models.Service()
    new_service.update(values)
    if not CONF.enable_new_services:
        new_service.disabled = True

    session = get_session()
    with session.begin():
        new_service.save(session)
    return new_service
|
|
|
|
|
|
@require_admin_context
def service_update(context, service_id, values):
    """Apply ``values`` to service ``service_id`` inside one transaction.

    When ``values`` contains a ``disabled`` key, ``modified_at`` is stamped
    with ``timeutils.utcnow()`` and ``updated_at`` is set to the SQL
    expression ``updated_at`` (i.e. to its own current value). Presumably
    this stops enable/disable changes from bumping ``updated_at``; confirm.

    :raises exception.ServiceNotFound: if the service does not exist.
    :returns: the updated Service model instance.
    """
    session = get_session()
    with session.begin():
        service = _service_get(context, service_id, session=session)
        if 'disabled' in values:
            service['modified_at'] = timeutils.utcnow()
            service['updated_at'] = literal_column('updated_at')
        service.update(values)
        return service
|