Refactor session handling and align test contexts

This change refactors database session management and aligns unit
tests with the production DB context handling.

- Replace _session_for_read() and _session_for_write() helpers with
  @main_context_manager.reader / @main_context_manager.writer decorators
  to simplify and standardize DB access.
- Update DB API methods to use context.session directly and rely on
  centralized transaction management through enginefacade.
- Switch tests to use cyborg.context.RequestContext instead of
  oslo_context.RequestContext for consistent context propagation.

Closes-Bug: #2061130
Change-Id: Idf7714ec9fa57b4885bd5679f431cdeac2ad1497
Signed-off-by: Sooyoung Kim <sykim.etri@gmail.com>
This commit is contained in:
Sooyoung Kim
2025-10-28 20:13:25 +09:00
parent e408b4c1b7
commit 75467de9f7
6 changed files with 335 additions and 312 deletions

View File

@@ -16,7 +16,6 @@
"""SQLAlchemy storage backend.""" """SQLAlchemy storage backend."""
import copy import copy
import threading
import uuid import uuid
from oslo_db import api as oslo_db_api from oslo_db import api as oslo_db_api
@@ -35,7 +34,6 @@ from cyborg.common.i18n import _
from cyborg.db import api from cyborg.db import api
from cyborg.db.sqlalchemy import models from cyborg.db.sqlalchemy import models
_CONTEXT = threading.local()
LOG = log.getLogger(__name__) LOG = log.getLogger(__name__)
main_context_manager = enginefacade.transaction_context() main_context_manager = enginefacade.transaction_context()
@@ -46,14 +44,6 @@ def get_backend():
return Connection() return Connection()
def _session_for_read():
return enginefacade.reader.using(_CONTEXT)
def _session_for_write():
return enginefacade.writer.using(_CONTEXT)
def get_session(use_slave=False, **kwargs): def get_session(use_slave=False, **kwargs):
return main_context_manager._factory.get_legacy_facade().get_session( return main_context_manager._factory.get_legacy_facade().get_session(
use_slave=use_slave, **kwargs) use_slave=use_slave, **kwargs)
@@ -78,10 +68,10 @@ def model_query(context, model, *args, **kwargs):
if kwargs.pop("project_only", False): if kwargs.pop("project_only", False):
kwargs["project_id"] = context.project_id kwargs["project_id"] = context.project_id
with _session_for_read() as session: query = sqlalchemyutils.model_query(
query = sqlalchemyutils.model_query( model, context.session, args, **kwargs)
model, session, args, **kwargs)
return query return query
def add_identity_filter(query, value): def add_identity_filter(query, value):
@@ -124,6 +114,7 @@ class Connection(api.Connection):
def __init__(self): def __init__(self):
pass pass
@main_context_manager.writer
def attach_handle_create(self, context, values): def attach_handle_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -131,14 +122,14 @@ class Connection(api.Connection):
attach_handle = models.AttachHandle() attach_handle = models.AttachHandle()
attach_handle.update(values) attach_handle.update(values)
with _session_for_write() as session: try:
try: context.session.add(attach_handle)
session.add(attach_handle) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.AttachHandleAlreadyExists(uuid=values['uuid'])
raise exception.AttachHandleAlreadyExists(uuid=values['uuid']) return attach_handle
return attach_handle
@main_context_manager.reader
def attach_handle_get_by_uuid(self, context, uuid): def attach_handle_get_by_uuid(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -150,6 +141,7 @@ class Connection(api.Connection):
resource='AttachHandle', resource='AttachHandle',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def attach_handle_get_by_id(self, context, id): def attach_handle_get_by_id(self, context, id):
query = model_query( query = model_query(
context, context,
@@ -161,6 +153,7 @@ class Connection(api.Connection):
resource='AttachHandle', resource='AttachHandle',
msg='with id=%s' % id) msg='with id=%s' % id)
@main_context_manager.reader
def attach_handle_list_by_type(self, context, attach_type='PCI'): def attach_handle_list_by_type(self, context, attach_type='PCI'):
query = model_query(context, models.AttachHandle). \ query = model_query(context, models.AttachHandle). \
filter_by(attach_type=attach_type) filter_by(attach_type=attach_type)
@@ -171,6 +164,7 @@ class Connection(api.Connection):
resource='AttachHandle', resource='AttachHandle',
msg='with type=%s' % attach_type) msg='with type=%s' % attach_type)
@main_context_manager.reader
def attach_handle_get_by_filters(self, context, def attach_handle_get_by_filters(self, context,
filters, sort_key='created_at', filters, sort_key='created_at',
sort_dir='desc', limit=None, sort_dir='desc', limit=None,
@@ -238,6 +232,7 @@ class Connection(api.Connection):
for k, v in filter_dict.items()]) for k, v in filter_dict.items()])
return query return query
@main_context_manager.reader
def attach_handle_list(self, context): def attach_handle_list(self, context):
query = model_query(context, models.AttachHandle) query = model_query(context, models.AttachHandle)
return _paginate_query(context, models.AttachHandle, query=query) return _paginate_query(context, models.AttachHandle, query=query)
@@ -249,36 +244,36 @@ class Connection(api.Connection):
return self._do_update_attach_handle(context, uuid, values) return self._do_update_attach_handle(context, uuid, values)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_attach_handle(self, context, uuid, values): def _do_update_attach_handle(self, context, uuid, values):
with _session_for_write(): query = model_query(context, models.AttachHandle)
query = model_query(context, models.AttachHandle) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='AttachHandle',
resource='AttachHandle', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid) ref.update(values)
ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_allocate_attach_handle(self, context, deployable_id): def _do_allocate_attach_handle(self, context, deployable_id):
"""Atomically get a set of attach handles that match the query """Atomically get a set of attach handles that match the query
and mark one of those as in_use. and mark one of those as in_use.
""" """
with _session_for_write() as session: query = model_query(context, models.AttachHandle). \
query = model_query(context, models.AttachHandle). \ filter_by(deployable_id=deployable_id,
filter_by(deployable_id=deployable_id, in_use=False)
in_use=False) values = {"in_use": True}
values = {"in_use": True} ref = query.with_for_update().first()
ref = query.with_for_update().first() if not ref:
if not ref: msg = 'Matching deployable_id {0}'.format(deployable_id)
msg = 'Matching deployable_id {0}'.format(deployable_id) raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='AttachHandle', msg=msg)
resource='AttachHandle', msg=msg) ref.update(values)
ref.update(values) context.session.flush()
session.flush()
return ref return ref
def attach_handle_allocate(self, context, deployable_id): def attach_handle_allocate(self, context, deployable_id):
@@ -298,16 +293,17 @@ class Connection(api.Connection):
# NOTE: For deallocate, we use attach_handle_update() # NOTE: For deallocate, we use attach_handle_update()
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def attach_handle_delete(self, context, uuid): def attach_handle_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.AttachHandle)
query = model_query(context, models.AttachHandle) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='AttachHandle',
resource='AttachHandle', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
@main_context_manager.writer
def control_path_create(self, context, values): def control_path_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -315,14 +311,14 @@ class Connection(api.Connection):
control_path_id = models.ControlpathID() control_path_id = models.ControlpathID()
control_path_id.update(values) control_path_id.update(values)
with _session_for_write() as session: try:
try: context.session.add(control_path_id)
session.add(control_path_id) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.ControlpathIDAlreadyExists(uuid=values['uuid'])
raise exception.ControlpathIDAlreadyExists(uuid=values['uuid']) return control_path_id
return control_path_id
@main_context_manager.reader
def control_path_get_by_uuid(self, context, uuid): def control_path_get_by_uuid(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -334,6 +330,7 @@ class Connection(api.Connection):
resource='ControlpathID', resource='ControlpathID',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def control_path_get_by_filters(self, context, def control_path_get_by_filters(self, context,
filters, sort_key='created_at', filters, sort_key='created_at',
sort_dir='desc', limit=None, sort_dir='desc', limit=None,
@@ -360,6 +357,7 @@ class Connection(api.Connection):
return _paginate_query(context, models.ControlpathID, query_prefix, return _paginate_query(context, models.ControlpathID, query_prefix,
limit, marker, sort_key, sort_dir) limit, marker, sort_key, sort_dir)
@main_context_manager.reader
def control_path_list(self, context): def control_path_list(self, context):
query = model_query(context, models.ControlpathID) query = model_query(context, models.ControlpathID)
return _paginate_query(context, models.ControlpathID, query=query) return _paginate_query(context, models.ControlpathID, query=query)
@@ -371,28 +369,29 @@ class Connection(api.Connection):
return self._do_update_control_path(context, uuid, values) return self._do_update_control_path(context, uuid, values)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_control_path(self, context, uuid, values): def _do_update_control_path(self, context, uuid, values):
with _session_for_write(): query = model_query(context, models.ControlpathID)
query = model_query(context, models.ControlpathID) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='ControlpathID',
resource='ControlpathID', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid) ref.update(values)
ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def control_path_delete(self, context, uuid): def control_path_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.ControlpathID)
query = model_query(context, models.ControlpathID) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ControlpathNotFound(uuid=uuid)
raise exception.ControlpathNotFound(uuid=uuid)
@main_context_manager.writer
def device_create(self, context, values): def device_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -400,14 +399,14 @@ class Connection(api.Connection):
device = models.Device() device = models.Device()
device.update(values) device.update(values)
with _session_for_write() as session: try:
try: context.session.add(device)
session.add(device) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.DeviceAlreadyExists(uuid=values['uuid'])
raise exception.DeviceAlreadyExists(uuid=values['uuid']) return device
return device
@main_context_manager.reader
def device_get(self, context, uuid): def device_get(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -419,6 +418,7 @@ class Connection(api.Connection):
resource='Device', resource='Device',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def device_get_by_id(self, context, id): def device_get_by_id(self, context, id):
query = model_query( query = model_query(
context, context,
@@ -430,6 +430,7 @@ class Connection(api.Connection):
resource='Device', resource='Device',
msg='with id=%s' % id) msg='with id=%s' % id)
@main_context_manager.reader
def device_list_by_filters(self, context, def device_list_by_filters(self, context,
filters, sort_key='created_at', filters, sort_key='created_at',
sort_dir='desc', limit=None, sort_dir='desc', limit=None,
@@ -453,6 +454,7 @@ class Connection(api.Connection):
return _paginate_query(context, models.Device, query_prefix, return _paginate_query(context, models.Device, query_prefix,
limit, marker, sort_key, sort_dir) limit, marker, sort_key, sort_dir)
@main_context_manager.reader
def device_list(self, context, limit=None, marker=None, sort_key=None, def device_list(self, context, limit=None, marker=None, sort_key=None,
sort_dir=None): sort_dir=None):
query = model_query(context, models.Device) query = model_query(context, models.Device)
@@ -471,31 +473,32 @@ class Connection(api.Connection):
raise exception.DuplicateDeviceName(name=values['name']) raise exception.DuplicateDeviceName(name=values['name'])
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_device(self, context, uuid, values): def _do_update_device(self, context, uuid, values):
with _session_for_write(): query = model_query(context, models.Device)
query = model_query(context, models.Device) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Device',
resource='Device', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
ref.update(values) ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def device_delete(self, context, uuid): def device_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.Device)
query = model_query(context, models.Device) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Device',
resource='Device', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
@main_context_manager.writer
def device_profile_create(self, context, values): def device_profile_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -503,24 +506,24 @@ class Connection(api.Connection):
device_profile = models.DeviceProfile() device_profile = models.DeviceProfile()
device_profile.update(values) device_profile.update(values)
with _session_for_write() as session: try:
try: context.session.add(device_profile)
session.add(device_profile) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry as e:
except db_exc.DBDuplicateEntry as e: # mysql duplicate key error changed as reference link below:
# mysql duplicate key error changed as reference link below: # https://review.opendev.org/c/openstack/oslo.db/+/792124
# https://review.opendev.org/c/openstack/oslo.db/+/792124 LOG.info('Duplicate columns are: ', e.columns)
LOG.info('Duplicate columns are: ', e.columns) columns = [column.split('0')[1] if 'uniq_' in column else
columns = [column.split('0')[1] if 'uniq_' in column else column for column in e.columns]
column for column in e.columns] if 'name' in columns:
if 'name' in columns: raise exception.DuplicateDeviceProfileName(
raise exception.DuplicateDeviceProfileName( name=values['name'])
name=values['name']) else:
else: raise exception.DeviceProfileAlreadyExists(
raise exception.DeviceProfileAlreadyExists( uuid=values['uuid'])
uuid=values['uuid']) return device_profile
return device_profile
@main_context_manager.reader
def device_profile_get_by_uuid(self, context, uuid): def device_profile_get_by_uuid(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -532,6 +535,7 @@ class Connection(api.Connection):
resource='Device Profile', resource='Device Profile',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def device_profile_get_by_id(self, context, id): def device_profile_get_by_id(self, context, id):
query = model_query( query = model_query(
context, context,
@@ -543,6 +547,7 @@ class Connection(api.Connection):
resource='Device Profile', resource='Device Profile',
msg='with id=%s' % id) msg='with id=%s' % id)
@main_context_manager.reader
def device_profile_get(self, context, name): def device_profile_get(self, context, name):
query = model_query( query = model_query(
context, models.DeviceProfile).filter_by(name=name) context, models.DeviceProfile).filter_by(name=name)
@@ -553,6 +558,7 @@ class Connection(api.Connection):
resource='Device Profile', resource='Device Profile',
msg='with name=%s' % name) msg='with name=%s' % name)
@main_context_manager.reader
def device_profile_list_by_filters( def device_profile_list_by_filters(
self, context, filters, sort_key='created_at', sort_dir='desc', self, context, filters, sort_key='created_at', sort_dir='desc',
limit=None, marker=None, join_columns=None): limit=None, marker=None, join_columns=None):
@@ -572,6 +578,7 @@ class Connection(api.Connection):
return _paginate_query(context, models.DeviceProfile, query_prefix, return _paginate_query(context, models.DeviceProfile, query_prefix,
limit, marker, sort_key, sort_dir) limit, marker, sort_key, sort_dir)
@main_context_manager.reader
def device_profile_list(self, context): def device_profile_list(self, context):
query = model_query(context, models.DeviceProfile) query = model_query(context, models.DeviceProfile)
return _paginate_query(context, models.DeviceProfile, query=query) return _paginate_query(context, models.DeviceProfile, query=query)
@@ -588,31 +595,32 @@ class Connection(api.Connection):
raise exception.DuplicateDeviceProfileName(name=values['name']) raise exception.DuplicateDeviceProfileName(name=values['name'])
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_device_profile(self, context, uuid, values): def _do_update_device_profile(self, context, uuid, values):
with _session_for_write(): query = model_query(context, models.DeviceProfile)
query = model_query(context, models.DeviceProfile) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Device Profile',
resource='Device Profile', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
ref.update(values) ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def device_profile_delete(self, context, uuid): def device_profile_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.DeviceProfile)
query = model_query(context, models.DeviceProfile) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Device Profile',
resource='Device Profile', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
@main_context_manager.writer
def deployable_create(self, context, values): def deployable_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -621,14 +629,14 @@ class Connection(api.Connection):
deployable = models.Deployable() deployable = models.Deployable()
deployable.update(values) deployable.update(values)
with _session_for_write() as session: try:
try: context.session.add(deployable)
session.add(deployable) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.DeployableAlreadyExists(uuid=values['uuid'])
raise exception.DeployableAlreadyExists(uuid=values['uuid']) return deployable
return deployable
@main_context_manager.reader
def deployable_get(self, context, uuid): def deployable_get(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -640,6 +648,7 @@ class Connection(api.Connection):
resource='Deployable', resource='Deployable',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def deployable_get_by_rp_uuid(self, context, rp_uuid): def deployable_get_by_rp_uuid(self, context, rp_uuid):
"""Get a deployable by resource provider UUID.""" """Get a deployable by resource provider UUID."""
query = model_query( query = model_query(
@@ -652,6 +661,7 @@ class Connection(api.Connection):
resource='Deployable', resource='Deployable',
msg='with resource provider uuid=%s' % rp_uuid) msg='with resource provider uuid=%s' % rp_uuid)
@main_context_manager.reader
def deployable_list(self, context): def deployable_list(self, context):
query = model_query(context, models.Deployable) query = model_query(context, models.Deployable)
return query.all() return query.all()
@@ -668,32 +678,32 @@ class Connection(api.Connection):
raise exception.DuplicateDeployableName(name=values['name']) raise exception.DuplicateDeployableName(name=values['name'])
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_deployable(self, context, uuid, values): def _do_update_deployable(self, context, uuid, values):
with _session_for_write(): query = model_query(context, models.Deployable)
query = model_query(context, models.Deployable) # query = add_identity_filter(query, uuid)
# query = add_identity_filter(query, uuid) query = query.filter_by(uuid=uuid)
query = query.filter_by(uuid=uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Deployable',
resource='Deployable', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
ref.update(values) ref.update(values)
return ref return ref
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def deployable_delete(self, context, uuid): def deployable_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.Deployable)
query = model_query(context, models.Deployable) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) query.update({'root_id': None})
query.update({'root_id': None}) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Deployable',
resource='Deployable', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
def deployable_get_by_filters(self, context, def deployable_get_by_filters(self, context,
filters, sort_key='created_at', filters, sort_key='created_at',
@@ -709,6 +719,7 @@ class Connection(api.Connection):
sort_key=sort_key, sort_key=sort_key,
sort_dir=sort_dir) sort_dir=sort_dir)
@main_context_manager.reader
def deployable_get_by_filters_sort(self, context, filters, limit=None, def deployable_get_by_filters_sort(self, context, filters, limit=None,
marker=None, join_columns=None, marker=None, join_columns=None,
sort_key=None, sort_dir=None): sort_key=None, sort_dir=None):
@@ -736,6 +747,7 @@ class Connection(api.Connection):
return _paginate_query(context, models.Deployable, query_prefix, return _paginate_query(context, models.Deployable, query_prefix,
limit, marker, sort_key, sort_dir) limit, marker, sort_key, sort_dir)
@main_context_manager.writer
def attribute_create(self, context, values): def attribute_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -744,15 +756,15 @@ class Connection(api.Connection):
attribute = models.Attribute() attribute = models.Attribute()
attribute.update(values) attribute.update(values)
with _session_for_write() as session: try:
try: context.session.add(attribute)
session.add(attribute) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.AttributeAlreadyExists(
raise exception.AttributeAlreadyExists( uuid=values['uuid'])
uuid=values['uuid']) return attribute
return attribute
@main_context_manager.reader
def attribute_get(self, context, uuid): def attribute_get(self, context, uuid):
query = model_query( query = model_query(
context, context,
@@ -764,12 +776,14 @@ class Connection(api.Connection):
resource='Attribute', resource='Attribute',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.reader
def attribute_get_by_deployable_id(self, context, deployable_id): def attribute_get_by_deployable_id(self, context, deployable_id):
query = model_query( query = model_query(
context, context,
models.Attribute).filter_by(deployable_id=deployable_id) models.Attribute).filter_by(deployable_id=deployable_id)
return query.all() return query.all()
@main_context_manager.reader
def attribute_get_by_filter(self, context, filters): def attribute_get_by_filter(self, context, filters):
"""Return attributes that matches the filters """Return attributes that matches the filters
""" """
@@ -802,31 +816,32 @@ class Connection(api.Connection):
return self._do_update_attribute(context, uuid, key, value) return self._do_update_attribute(context, uuid, key, value)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_attribute(self, context, uuid, key, value): def _do_update_attribute(self, context, uuid, key, value):
update_fields = {'key': key, 'value': value} update_fields = {'key': key, 'value': value}
with _session_for_write(): query = model_query(context, models.Attribute)
query = model_query(context, models.Attribute) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) try:
try: ref = query.with_for_update().one()
ref = query.with_for_update().one() except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Attribute',
resource='Attribute', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
ref.update(update_fields) ref.update(update_fields)
return ref return ref
@main_context_manager.writer
def attribute_delete(self, context, uuid): def attribute_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.Attribute)
query = model_query(context, models.Attribute) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='Attribute',
resource='Attribute', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
@main_context_manager.writer
def extarq_create(self, context, values): def extarq_create(self, context, values):
if not values.get('uuid'): if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid() values['uuid'] = uuidutils.generate_uuid()
@@ -845,24 +860,23 @@ class Connection(api.Connection):
extarq = models.ExtArq() extarq = models.ExtArq()
extarq.update(values) extarq.update(values)
with _session_for_write() as session: try:
try: context.session.add(extarq)
session.add(extarq) context.session.flush()
session.flush() except db_exc.DBDuplicateEntry:
except db_exc.DBDuplicateEntry: raise exception.ExtArqAlreadyExists(uuid=values['uuid'])
raise exception.ExtArqAlreadyExists(uuid=values['uuid']) return extarq
return extarq
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def extarq_delete(self, context, uuid): def extarq_delete(self, context, uuid):
with _session_for_write(): query = model_query(context, models.ExtArq)
query = model_query(context, models.ExtArq) query = add_identity_filter(query, uuid)
query = add_identity_filter(query, uuid) count = query.delete()
count = query.delete() if count != 1:
if count != 1: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='ExtArq',
resource='ExtArq', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid)
def extarq_update(self, context, uuid, values, state_scope=None): def extarq_update(self, context, uuid, values, state_scope=None):
if 'uuid' in values and values['uuid'] != uuid: if 'uuid' in values and values['uuid'] != uuid:
@@ -871,24 +885,25 @@ class Connection(api.Connection):
return self._do_update_extarq(context, uuid, values, state_scope) return self._do_update_extarq(context, uuid, values, state_scope)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def _do_update_extarq(self, context, uuid, values, state_scope=None): def _do_update_extarq(self, context, uuid, values, state_scope=None):
with _session_for_write(): query = model_query(context, models.ExtArq)
query = model_query(context, models.ExtArq) query = query_update = query.filter_by(
query = query_update = query.filter_by( uuid=uuid).with_for_update()
uuid=uuid).with_for_update() if type(state_scope) is list:
if type(state_scope) is list: query_update = query_update.filter(
query_update = query_update.filter( models.ExtArq.state.in_(state_scope))
models.ExtArq.state.in_(state_scope)) try:
try: query_update.update(
query_update.update( values, synchronize_session="fetch")
values, synchronize_session="fetch") except NoResultFound:
except NoResultFound: raise exception.ResourceNotFound(
raise exception.ResourceNotFound( resource='ExtArq',
resource='ExtArq', msg='with uuid=%s' % uuid)
msg='with uuid=%s' % uuid) ref = query.first()
ref = query.first()
return ref return ref
@main_context_manager.reader
def extarq_list(self, context, uuid_range=None): def extarq_list(self, context, uuid_range=None):
query = model_query(context, models.ExtArq) query = model_query(context, models.ExtArq)
if type(uuid_range) is list: if type(uuid_range) is list:
@@ -897,6 +912,7 @@ class Connection(api.Connection):
return _paginate_query(context, models.ExtArq, query) return _paginate_query(context, models.ExtArq, query)
@oslo_db_api.retry_on_deadlock @oslo_db_api.retry_on_deadlock
@main_context_manager.writer
def extarq_get(self, context, uuid, lock=False): def extarq_get(self, context, uuid, lock=False):
query = model_query( query = model_query(
context, context,
@@ -911,6 +927,7 @@ class Connection(api.Connection):
resource='ExtArq', resource='ExtArq',
msg='with uuid=%s' % uuid) msg='with uuid=%s' % uuid)
@main_context_manager.writer
def _get_quota_usages(self, context, project_id, resources=None): def _get_quota_usages(self, context, project_id, resources=None):
# Broken out for testability # Broken out for testability
query = model_query(context, models.QuotaUsage,).filter_by( query = model_query(context, models.QuotaUsage,).filter_by(
@@ -966,86 +983,86 @@ class Connection(api.Connection):
with_for_update(). \ with_for_update(). \
all() all()
@main_context_manager.writer
def quota_reserve(self, context, resources, deltas, expire, def quota_reserve(self, context, resources, deltas, expire,
until_refresh, max_age, project_id=None, until_refresh, max_age, project_id=None,
is_allocated_reserve=False): is_allocated_reserve=False):
"""Create reservation record in DB according to params""" """Create reservation record in DB according to params"""
with _session_for_write() as session: if project_id is None:
if project_id is None: project_id = context.project_id
project_id = context.project_id usages = self._get_quota_usages(context, project_id,
usages = self._get_quota_usages(context, project_id, resources=deltas.keys())
resources=deltas.keys()) work = set(deltas.keys())
work = set(deltas.keys()) while work:
while work: resource = work.pop()
resource = work.pop()
# Do we need to refresh the usage? # Do we need to refresh the usage?
refresh = False refresh = False
# create quota usage in DB if there is no record of this type # create quota usage in DB if there is no record of this type
# of resource # of resource
if resource not in usages: if resource not in usages:
usages[resource] = self._quota_usage_create( usages[resource] = self._quota_usage_create(
project_id, resource, until_refresh or None, project_id, resource, until_refresh or None,
in_use=0, reserved=0, session=session) in_use=0, reserved=0, session=context.session)
refresh = True refresh = True
elif usages[resource].in_use < 0: elif usages[resource].in_use < 0:
# Negative in_use count indicates a desync, so try to # Negative in_use count indicates a desync, so try to
# heal from that... # heal from that...
refresh = True refresh = True
elif usages[resource].until_refresh is not None: elif usages[resource].until_refresh is not None:
usages[resource].until_refresh -= 1 usages[resource].until_refresh -= 1
if usages[resource].until_refresh <= 0: if usages[resource].until_refresh <= 0:
refresh = True
elif max_age and usages[resource].updated_at is not None and (
(timeutils.utcnow() -
usages[resource].updated_at).total_seconds() >=
max_age):
refresh = True refresh = True
elif max_age and usages[resource].updated_at is not None and (
(timeutils.utcnow() -
usages[resource].updated_at).total_seconds() >=
max_age):
refresh = True
# refresh the usage # refresh the usage
if refresh: if refresh:
# Grab the sync routine # Grab the sync routine
updates = self._sync_acc_res(context, updates = self._sync_acc_res(context,
resource, project_id) resource, project_id)
for res, in_use in updates.items(): for res, in_use in updates.items():
# Make sure we have a destination for the usage! # Make sure we have a destination for the usage!
if res not in usages: if res not in usages:
usages[res] = self._quota_usage_create( usages[res] = self._quota_usage_create(
project_id, project_id,
res, res,
until_refresh or None, until_refresh or None,
in_use=0, in_use=0,
reserved=0, reserved=0,
session=session session=context.session
) )
# Update the usage # Update the usage
usages[res].in_use = in_use usages[res].in_use = in_use
usages[res].until_refresh = until_refresh or None usages[res].until_refresh = until_refresh or None
# Because more than one resource may be refreshed # Because more than one resource may be refreshed
# by the call to the sync routine, and we don't # by the call to the sync routine, and we don't
# want to double-sync, we make sure all refreshed # want to double-sync, we make sure all refreshed
# resources are dropped from the work set. # resources are dropped from the work set.
work.discard(res) work.discard(res)
# NOTE(Vek): We make the assumption that the sync # NOTE(Vek): We make the assumption that the sync
# routine actually refreshes the # routine actually refreshes the
# resources that it is the sync routine # resources that it is the sync routine
# for. We don't check, because this is # for. We don't check, because this is
# a best-effort mechanism. # a best-effort mechanism.
unders = [r for r, delta in deltas.items() unders = [r for r, delta in deltas.items()
if delta < 0 and delta + usages[r].in_use < 0] if delta < 0 and delta + usages[r].in_use < 0]
reservations = [] reservations = []
for resource, delta in deltas.items(): for resource, delta in deltas.items():
usage = usages[resource] usage = usages[resource]
reservation = self._reservation_create( reservation = self._reservation_create(
str(uuid.uuid4()), usage, project_id, resource, str(uuid.uuid4()), usage, project_id, resource,
delta, expire, session=session) delta, expire, session=context.session)
reservations.append(reservation.uuid) reservations.append(reservation.uuid)
usages[resource].reserved += delta usages[resource].reserved += delta
session.flush() context.session.flush()
if unders: if unders:
LOG.warning("Change will make usage less than 0 for the " LOG.warning("Change will make usage less than 0 for the "
"following resources: %s", unders) "following resources: %s", unders)
@@ -1057,6 +1074,7 @@ class Connection(api.Connection):
project_id) project_id)
return {resource: res_in_use} return {resource: res_in_use}
@main_context_manager.reader
def _device_data_get_for_project(self, context, resource, project_id): def _device_data_get_for_project(self, context, resource, project_id):
"""Return the number of resource which is being used by a project""" """Return the number of resource which is being used by a project"""
query = model_query(context, models.Device).filter_by(type=resource) query = model_query(context, models.Device).filter_by(type=resource)
@@ -1066,24 +1084,24 @@ class Connection(api.Connection):
def _dict_with_usage_id(self, usages): def _dict_with_usage_id(self, usages):
return {row.id: row for row in usages.values()} return {row.id: row for row in usages.values()}
@main_context_manager.writer
def reservation_commit(self, context, reservations, project_id=None): def reservation_commit(self, context, reservations, project_id=None):
"""Commit quota reservation to quota usage table""" """Commit quota reservation to quota usage table"""
with _session_for_write() as session: quota_usage = self._get_quota_usages(
quota_usage = self._get_quota_usages( context, project_id,
context, project_id, resources=self._get_reservation_resources(context,
resources=self._get_reservation_resources(context, reservations))
reservations)) usages = self._dict_with_usage_id(quota_usage)
usages = self._dict_with_usage_id(quota_usage)
for reservation in self._quota_reservations(session, context, for reservation in self._quota_reservations(context.session, context,
reservations): reservations):
usage = usages[reservation.usage_id] usage = usages[reservation.usage_id]
if reservation.delta >= 0: if reservation.delta >= 0:
usage.reserved -= reservation.delta usage.reserved -= reservation.delta
usage.in_use += reservation.delta usage.in_use += reservation.delta
session.flush() context.session.flush()
reservation.delete(session=session) reservation.delete(session=context.session)
def process_sort_params(self, sort_keys, sort_dirs, def process_sort_params(self, sort_keys, sort_dirs,
default_keys=['created_at', 'id'], default_keys=['created_at', 'id'],

View File

@@ -18,7 +18,6 @@ from unittest import mock
from oslo_config import cfg from oslo_config import cfg
from oslo_config import fixture as config_fixture from oslo_config import fixture as config_fixture
from oslo_context import context
from oslo_db import options from oslo_db import options
from oslo_log import log from oslo_log import log
from oslo_utils import excutils from oslo_utils import excutils
@@ -29,6 +28,7 @@ import eventlet
import testtools import testtools
from cyborg.common import config as cyborg_config from cyborg.common import config as cyborg_config
from cyborg import context as cyborg_context
from cyborg.tests import post_mortem_debug from cyborg.tests import post_mortem_debug
from cyborg.tests.unit import policy_fixture from cyborg.tests.unit import policy_fixture
@@ -46,7 +46,7 @@ class TestCase(base.BaseTestCase):
def setUp(self): def setUp(self):
super(TestCase, self).setUp() super(TestCase, self).setUp()
self.context = context.get_admin_context() self.context = cyborg_context.get_admin_context()
self._set_config() self._set_config()
self.policy = self.useFixture(policy_fixture.PolicyFixture()) self.policy = self.useFixture(policy_fixture.PolicyFixture())
@@ -100,7 +100,7 @@ class DietTestCase(base.BaseTestCase):
def setUp(self): def setUp(self):
super(DietTestCase, self).setUp() super(DietTestCase, self).setUp()
self.context = context.get_admin_context() self.context = cyborg_context.get_admin_context()
options.set_defaults(cfg.CONF, connection='sqlite://') options.set_defaults(cfg.CONF, connection='sqlite://')

View File

@@ -17,9 +17,9 @@ import oslo_messaging as messaging
from cyborg.agent.rpcapi import AgentAPI from cyborg.agent.rpcapi import AgentAPI
from cyborg.common import constants from cyborg.common import constants
from cyborg.common import rpc from cyborg.common import rpc
from cyborg import context as cyborg_context
from cyborg.objects import base as objects_base from cyborg.objects import base as objects_base
from cyborg.tests import base from cyborg.tests import base
from oslo_context import context as oslo_context
from unittest import mock from unittest import mock
@@ -40,8 +40,8 @@ class TestRPCAPI(base.TestCase):
serializer=self.serializer) serializer=self.serializer)
def _test_rpc_call(self, method): def _test_rpc_call(self, method):
ctxt = oslo_context.RequestContext(user_id='fake_user', ctxt = cyborg_context.RequestContext(user_id='fake_user',
project_id='fake_project') project_id='fake_project')
expect_val = True expect_val = True
with mock.patch.object(self.agent_rpcapi, with mock.patch.object(self.agent_rpcapi,
'fpga_program') as mock_program: 'fpga_program') as mock_program:

View File

@@ -16,10 +16,10 @@
"""Base classes for API tests.""" """Base classes for API tests."""
from oslo_config import cfg from oslo_config import cfg
from oslo_context import context
import pecan import pecan
import pecan.testing import pecan.testing
from cyborg import context as cyborg_context
from cyborg.tests.unit.db import base from cyborg.tests.unit.db import base
cfg.CONF.import_group('keystone_authtoken', 'keystonemiddleware.auth_token') cfg.CONF.import_group('keystone_authtoken', 'keystonemiddleware.auth_token')
@@ -114,7 +114,7 @@ class BaseApiTest(base.DbTestCase):
status=status, method="post") status=status, method="post")
def gen_context(self, value, **kwargs): def gen_context(self, value, **kwargs):
ct = context.RequestContext.from_dict(value, **kwargs) ct = cyborg_context.RequestContext.from_dict(value, **kwargs)
return ct return ct
def gen_headers(self, context, **kw): def gen_headers(self, context, **kw):

View File

@@ -17,9 +17,9 @@
import fixtures import fixtures
from oslo_config import cfg from oslo_config import cfg
from oslo_db.sqlalchemy import enginefacade
from cyborg.db import api as dbapi from cyborg.db import api as dbapi
from cyborg.db.sqlalchemy import api as sqlalchemy_api
from cyborg.db.sqlalchemy import migration from cyborg.db.sqlalchemy import migration
from cyborg.db.sqlalchemy import models from cyborg.db.sqlalchemy import models
from cyborg.tests import base from cyborg.tests import base
@@ -65,7 +65,11 @@ class DbTestCase(base.TestCase):
global _DB_CACHE global _DB_CACHE
if not _DB_CACHE: if not _DB_CACHE:
engine = enginefacade.get_legacy_facade().get_engine() engine = (
sqlalchemy_api.main_context_manager
.get_legacy_facade()
.get_engine()
)
_DB_CACHE = Database(engine, migration, _DB_CACHE = Database(engine, migration,
sql_connection=CONF.database.connection) sql_connection=CONF.database.connection)
self.useFixture(_DB_CACHE) self.useFixture(_DB_CACHE)

View File

@@ -16,8 +16,8 @@ import datetime
from oslo_log import log from oslo_log import log
from oslo_context import context
from cyborg import context as cyborg_context
from cyborg.objects import base from cyborg.objects import base
from cyborg.objects import fields from cyborg.objects import fields
from cyborg import tests as test from cyborg import tests as test
@@ -192,7 +192,8 @@ class _BaseTestCase(test.base.TestCase):
super(_BaseTestCase, self).setUp() super(_BaseTestCase, self).setUp()
self.user_id = 'fake-user' self.user_id = 'fake-user'
self.project_id = 'fake-project' self.project_id = 'fake-project'
self.context = context.RequestContext(self.user_id, self.project_id) self.context = cyborg_context.RequestContext(self.user_id,
self.project_id)
base.CyborgObjectRegistry.register(MyObj) base.CyborgObjectRegistry.register(MyObj)
base.CyborgObjectRegistry.register(MyOwnedObject) base.CyborgObjectRegistry.register(MyOwnedObject)