[sqlalchemy2] subtransactions & autocommit removal

Subtransactions are being removed from SQLAlchemy, as well as autocommit
sessions [0] [1].

[0] https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#\
    session-subtransaction-behavior-removed
[1] https://docs.sqlalchemy.org/en/14/changelog/migration_20.html#\
    library-level-but-not-driver-level-autocommit-removed-from-both-core-and-orm

Change-Id: Ie32a17c9bf91b05752bc0457151b6b0aa4151264
This commit is contained in:
Gregory Thiemonge 2022-10-08 13:06:11 +02:00
parent 12d8e0de5d
commit b15a6e8c7d
54 changed files with 3113 additions and 2194 deletions

View File

@ -37,7 +37,7 @@ import octavia.common.jinja.haproxy.combined_listeners.jinja_cfg as jinja_combo
from octavia.common.jinja.lvs import jinja_cfg as jinja_udp_cfg
from octavia.common.tls_utils import cert_parser
from octavia.common import utils
from octavia.db import api as db_apis
from octavia.db import api as db_api
from octavia.db import models as db_models
from octavia.db import repositories as repo
@ -182,9 +182,10 @@ class HaproxyAmphoraLoadBalancerDriver(
'"%s". Skipping this listener.',
listener.id, str(e))
listener_repo = repo.ListenerRepository()
listener_repo.update(db_apis.get_session(), listener.id,
provisioning_status=consts.ERROR,
operating_status=consts.ERROR)
with db_api.session().begin() as session:
listener_repo.update(session, listener.id,
provisioning_status=consts.ERROR,
operating_status=consts.ERROR)
if has_tcp:
if listeners_to_update:

View File

@ -331,8 +331,9 @@ class UpdateHealthDb:
session = db_api.get_session()
# We need to see if all of the listeners are reporting in
db_lb = self.amphora_repo.get_lb_for_health_update(session,
health['id'])
with session.begin():
db_lb = self.amphora_repo.get_lb_for_health_update(session,
health['id'])
ignore_listener_count = False
if db_lb:
@ -353,11 +354,13 @@ class UpdateHealthDb:
l for k, l in db_lb.get('listeners', {}).items()
if l['protocol'] == constants.PROTOCOL_UDP]
if udp_listeners:
expected_listener_count = (
self._update_listener_count_for_UDP(
session, db_lb, expected_listener_count))
with session.begin():
expected_listener_count = (
self._update_listener_count_for_UDP(
session, db_lb, expected_listener_count))
else:
amp = self.amphora_repo.get(session, id=health['id'])
with session.begin():
amp = self.amphora_repo.get(session, id=health['id'])
# This is debug and not warning because this can happen under
# normal deleting operations.
LOG.debug('Received a health heartbeat from amphora %s with '
@ -392,8 +395,6 @@ class UpdateHealthDb:
# does not match the expected listener count
if len(listeners) == expected_listener_count or ignore_listener_count:
lock_session = db_api.get_session(autocommit=False)
# if we're running too far behind, warn and bail
proc_delay = time.time() - health['recv_time']
hb_interval = CONF.health_manager.heartbeat_interval
@ -409,6 +410,9 @@ class UpdateHealthDb:
{'id': health['id'], 'delay': proc_delay})
return
lock_session = db_api.get_session()
lock_session.begin()
# if the input amphora is healthy, we update its db info
try:
self.amphora_health_repo.replace(
@ -472,9 +476,10 @@ class UpdateHealthDb:
try:
if (listener_status is not None and
listener_status != db_op_status):
self._update_status(
session, self.listener_repo, constants.LISTENER,
listener_id, listener_status, db_op_status)
with session.begin():
self._update_status(
session, self.listener_repo, constants.LISTENER,
listener_id, listener_status, db_op_status)
except sqlalchemy.orm.exc.NoResultFound:
LOG.error("Listener %s is not in DB", listener_id)
@ -496,9 +501,11 @@ class UpdateHealthDb:
if db_pool_id in processed_pools:
continue
db_pool_dict = db_lb['pools'][db_pool_id]
lb_status = self._process_pool_status(
session, db_pool_id, db_pool_dict, pools,
lb_status, processed_pools, potential_offline_pools)
with session.begin():
lb_status = self._process_pool_status(
session, db_pool_id, db_pool_dict, pools,
lb_status, processed_pools,
potential_offline_pools)
if health_msg_version >= 2:
raw_pools = health['pools']
@ -514,9 +521,10 @@ class UpdateHealthDb:
if db_pool_id in processed_pools:
continue
db_pool_dict = db_lb['pools'][db_pool_id]
lb_status = self._process_pool_status(
session, db_pool_id, db_pool_dict, pools,
lb_status, processed_pools, potential_offline_pools)
with session.begin():
lb_status = self._process_pool_status(
session, db_pool_id, db_pool_dict, pools,
lb_status, processed_pools, potential_offline_pools)
for pool_id, pool in potential_offline_pools.items():
# Skip if we eventually found a status for this pool
@ -525,19 +533,21 @@ class UpdateHealthDb:
try:
# If the database doesn't already show the pool offline, update
if pool != constants.OFFLINE:
self._update_status(
session, self.pool_repo, constants.POOL,
pool_id, constants.OFFLINE, pool)
with session.begin():
self._update_status(
session, self.pool_repo, constants.POOL,
pool_id, constants.OFFLINE, pool)
except sqlalchemy.orm.exc.NoResultFound:
LOG.error("Pool %s is not in DB", pool_id)
# Update the load balancer status last
try:
if lb_status != db_lb['operating_status']:
self._update_status(
session, self.loadbalancer_repo,
constants.LOADBALANCER, db_lb['id'], lb_status,
db_lb[constants.OPERATING_STATUS])
with session.begin():
self._update_status(
session, self.loadbalancer_repo,
constants.LOADBALANCER, db_lb['id'], lb_status,
db_lb[constants.OPERATING_STATUS])
except sqlalchemy.orm.exc.NoResultFound:
LOG.error("Load balancer %s is not in DB", db_lb.id)

View File

@ -273,8 +273,10 @@ class AmphoraProviderDriver(driver_base.ProviderDriver):
# Member
def member_create(self, member):
pool_id = member.pool_id
db_pool = self.repositories.pool.get(db_apis.get_session(),
id=pool_id)
session = db_apis.get_session()
with session.begin():
db_pool = self.repositories.pool.get(session,
id=pool_id)
self._validate_members(db_pool, [member])
payload = {consts.MEMBER: member.to_dict()}
@ -296,7 +298,9 @@ class AmphoraProviderDriver(driver_base.ProviderDriver):
def member_batch_update(self, pool_id, members):
# The DB should not have updated yet, so we can still use the pool
db_pool = self.repositories.pool.get(db_apis.get_session(), id=pool_id)
session = db_apis.get_session()
with session.begin():
db_pool = self.repositories.pool.get(session, id=pool_id)
self._validate_members(db_pool, members)
@ -385,8 +389,10 @@ class AmphoraProviderDriver(driver_base.ProviderDriver):
# L7 Policy
def l7policy_create(self, l7policy):
db_listener = self.repositories.listener.get(db_apis.get_session(),
id=l7policy.listener_id)
session = db_apis.get_session()
with session.begin():
db_listener = self.repositories.listener.get(
session, id=l7policy.listener_id)
if db_listener.protocol not in VALID_L7POLICY_LISTENER_PROTOCOLS:
msg = ('%s protocol listeners do not support L7 policies' % (
db_listener.protocol))

View File

@ -25,56 +25,64 @@ def process_get(get_data):
if get_data[constants.OBJECT] == lib_consts.LOADBALANCERS:
lb_repo = repositories.LoadBalancerRepository()
db_lb = lb_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_lb = lb_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_lb:
provider_lb = (
driver_utils.db_loadbalancer_to_provider_loadbalancer(db_lb))
return provider_lb.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.LISTENERS:
listener_repo = repositories.ListenerRepository()
db_listener = listener_repo.get(
session, id=get_data[lib_consts.ID], show_deleted=False)
with session.begin():
db_listener = listener_repo.get(
session, id=get_data[lib_consts.ID], show_deleted=False)
if db_listener:
provider_listener = (
driver_utils.db_listener_to_provider_listener(db_listener))
return provider_listener.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.POOLS:
pool_repo = repositories.PoolRepository()
db_pool = pool_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_pool = pool_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_pool:
provider_pool = (
driver_utils.db_pool_to_provider_pool(db_pool))
return provider_pool.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.MEMBERS:
member_repo = repositories.MemberRepository()
db_member = member_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_member = member_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_member:
provider_member = (
driver_utils.db_member_to_provider_member(db_member))
return provider_member.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.HEALTHMONITORS:
hm_repo = repositories.HealthMonitorRepository()
db_hm = hm_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_hm = hm_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_hm:
provider_hm = (
driver_utils.db_HM_to_provider_HM(db_hm))
return provider_hm.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.L7POLICIES:
l7policy_repo = repositories.L7PolicyRepository()
db_l7policy = l7policy_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_l7policy = l7policy_repo.get(session,
id=get_data[lib_consts.ID],
show_deleted=False)
if db_l7policy:
provider_l7policy = (
driver_utils.db_l7policy_to_provider_l7policy(db_l7policy))
return provider_l7policy.to_dict(recurse=True, render_unsets=True)
elif get_data[constants.OBJECT] == lib_consts.L7RULES:
l7rule_repo = repositories.L7RuleRepository()
db_l7rule = l7rule_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
with session.begin():
db_l7rule = l7rule_repo.get(session, id=get_data[lib_consts.ID],
show_deleted=False)
if db_l7rule:
provider_l7rule = (
driver_utils.db_l7rule_to_provider_l7rule(db_l7rule))

View File

@ -46,7 +46,8 @@ class DriverUpdater(object):
super().__init__(**kwargs)
def _check_for_lb_vip_deallocate(self, repo, lb_id):
lb = repo.get(self.db_session, id=lb_id)
with self.db_session.begin():
lb = repo.get(self.db_session, id=lb_id)
if lb.vip.octavia_owned:
vip = lb.vip
# We need a backreference
@ -56,7 +57,8 @@ class DriverUpdater(object):
network_driver.deallocate_vip(vip)
def _decrement_quota(self, repo, object_name, record_id):
lock_session = db_apis.get_session(autocommit=False)
lock_session = self.db_session
lock_session.begin()
db_object = repo.get(lock_session, id=record_id)
if db_object is None:
lock_session.rollback()
@ -106,7 +108,8 @@ class DriverUpdater(object):
return
if delete_record and object_name != consts.LOADBALANCERS:
repo.delete(self.db_session, id=record_id)
with self.db_session.begin():
repo.delete(self.db_session, id=record_id)
return
record_kwargs[consts.PROVISIONING_STATUS] = prov_status
@ -114,7 +117,8 @@ class DriverUpdater(object):
if op_status:
record_kwargs[consts.OPERATING_STATUS] = op_status
if prov_status or op_status:
repo.update(self.db_session, record_id, **record_kwargs)
with self.db_session.begin():
repo.update(self.db_session, record_id, **record_kwargs)
except Exception as e:
# We need to raise a failure here to notify the driver it is
# sending bad status data.

View File

@ -132,8 +132,10 @@ def lb_dict_to_provider_dict(lb_dict, vip=None, add_vips=None, db_pools=None,
new_lb_dict['vip_qos_policy_id'] = vip.qos_policy_id
if 'flavor_id' in lb_dict and lb_dict['flavor_id']:
flavor_repo = repositories.FlavorRepository()
new_lb_dict['flavor'] = flavor_repo.get_flavor_metadata_dict(
db_api.get_session(), lb_dict['flavor_id'])
session = db_api.get_session()
with session.begin():
new_lb_dict['flavor'] = flavor_repo.get_flavor_metadata_dict(
session, lb_dict['flavor_id'])
if add_vips:
new_lb_dict['additional_vips'] = db_additional_vips_to_provider_vips(
add_vips)

View File

@ -42,8 +42,10 @@ class OctaviaDBHealthcheck(pluginbase.HealthcheckBaseExtension):
result = self.last_result
message = self.last_message
else:
result, message = healthcheck.check_database_connection(
db_apis.get_session())
session = db_apis.get_session()
with session.begin():
result, message = healthcheck.check_database_connection(
session)
self.last_check = datetime.datetime.now()
self.last_result = result
self.last_message = message

View File

@ -28,7 +28,6 @@ from octavia.api.v2.types import amphora as amp_types
from octavia.common import constants
from octavia.common import exceptions
from octavia.common import rpc
from octavia.db import api as db_api
CONF = cfg.CONF
LOG = logging.getLogger(__name__)
@ -50,7 +49,8 @@ class AmphoraController(base.BaseController):
def get_one(self, id, fields=None):
"""Gets a single amphora's details."""
context = pecan_request.context.get('octavia_context')
db_amp = self._get_db_amp(context.session, id, show_deleted=False)
with context.session.begin():
db_amp = self._get_db_amp(context.session, id, show_deleted=False)
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ONE)
@ -71,9 +71,10 @@ class AmphoraController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ALL)
db_amp, links = self.repositories.amphora.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
with context.session.begin():
db_amp, links = self.repositories.amphora.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_amp, [amp_types.AmphoraResponse])
if fields is not None:
@ -89,10 +90,10 @@ class AmphoraController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_DELETE)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
try:
self.repositories.amphora.test_and_set_status_for_delete(
lock_session, id)
context.session, id)
except sa_exception.NoResultFound as e:
raise exceptions.NotFound(resource='Amphora', id=id) from e
@ -137,27 +138,29 @@ class FailoverController(base.BaseController):
"""Fails over an amphora"""
pcontext = pecan_request.context
context = pcontext.get('octavia_context')
db_amp = self._get_db_amp(context.session, self.amp_id,
show_deleted=False)
with context.session.begin():
db_amp = self._get_db_amp(context.session, self.amp_id,
show_deleted=False)
self._auth_validate_action(
context, db_amp.load_balancer.project_id,
constants.RBAC_PUT_FAILOVER)
self.repositories.load_balancer.test_and_set_provisioning_status(
context.session, db_amp.load_balancer_id,
status=constants.PENDING_UPDATE, raise_exception=True)
with context.session.begin():
self.repositories.load_balancer.test_and_set_provisioning_status(
context.session, db_amp.load_balancer_id,
status=constants.PENDING_UPDATE, raise_exception=True)
try:
LOG.info("Sending failover request for amphora %s to the queue",
self.amp_id)
payload = {constants.AMPHORA_ID: db_amp.id}
self.client.cast({}, 'failover_amphora', **payload)
except Exception:
with excutils.save_and_reraise_exception(reraise=False):
self.repositories.load_balancer.update(
context.session, db_amp.load_balancer.id,
provisioning_status=constants.ERROR)
try:
LOG.info("Sending failover request for amphora %s to the "
"queue", self.amp_id)
payload = {constants.AMPHORA_ID: db_amp.id}
self.client.cast({}, 'failover_amphora', **payload)
except Exception:
with excutils.save_and_reraise_exception(reraise=False):
self.repositories.load_balancer.update(
context.session, db_amp.load_balancer.id,
provisioning_status=constants.ERROR)
class AmphoraUpdateController(base.BaseController):
@ -179,8 +182,9 @@ class AmphoraUpdateController(base.BaseController):
"""Update amphora agent configuration"""
pcontext = pecan_request.context
context = pcontext.get('octavia_context')
db_amp = self._get_db_amp(context.session, self.amp_id,
show_deleted=False)
with context.session.begin():
db_amp = self._get_db_amp(context.session, self.amp_id,
show_deleted=False)
self._auth_validate_action(
context, db_amp.load_balancer.project_id,
@ -212,8 +216,9 @@ class AmphoraStatsController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_STATS)
stats = self.repositories.get_amphora_stats(context.session,
self.amp_id)
with context.session.begin():
stats = self.repositories.get_amphora_stats(context.session,
self.amp_id)
if not stats:
raise exceptions.NotFound(resource='Amphora stats for',
id=self.amp_id)

View File

@ -28,7 +28,6 @@ from octavia.api.v2.controllers import base
from octavia.api.v2.types import availability_zone_profile as profile_types
from octavia.common import constants
from octavia.common import exceptions
from octavia.db import api as db_api
LOG = logging.getLogger(__name__)
@ -49,8 +48,9 @@ class AvailabilityZoneProfileController(base.BaseController):
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Availability Zone Profile',
id=constants.NIL_UUID)
db_availability_zone_profile = self._get_db_availability_zone_profile(
context.session, id)
with context.session.begin():
db_availability_zone_profile = (
self._get_db_availability_zone_profile(context.session, id))
result = self._convert_db_to_type(
db_availability_zone_profile,
profile_types.AvailabilityZoneProfileResponse)
@ -67,10 +67,12 @@ class AvailabilityZoneProfileController(base.BaseController):
context = pcontext.get('octavia_context')
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ALL)
db_availability_zone_profiles, links = (
self.repositories.availability_zone_profile.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER)))
with context.session.begin():
db_availability_zone_profiles, links = (
self.repositories.availability_zone_profile.get_all(
context.session,
pagination_helper=pcontext.get(
constants.PAGINATION_HELPER)))
result = self._convert_db_to_type(
db_availability_zone_profiles,
[profile_types.AvailabilityZoneProfileResponse])
@ -106,21 +108,21 @@ class AvailabilityZoneProfileController(base.BaseController):
driver.name, driver.validate_availability_zone,
availability_zone_data_dict)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
availability_zone_profile_dict = availability_zone_profile.to_dict(
render_unsets=True)
availability_zone_profile_dict['id'] = uuidutils.generate_uuid()
db_availability_zone_profile = (
self.repositories.availability_zone_profile.create(
lock_session, **availability_zone_profile_dict))
lock_session.commit()
context.session, **availability_zone_profile_dict))
context.session.commit()
except odb_exceptions.DBDuplicateEntry as e:
lock_session.rollback()
context.session.rollback()
raise exceptions.IDAlreadyExists() from e
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
result = self._convert_db_to_type(
db_availability_zone_profile,
profile_types.AvailabilityZoneProfileResponse)
@ -159,7 +161,8 @@ class AvailabilityZoneProfileController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_PUT)
self._validate_update_azp(context, id, availability_zone_profile)
with context.session.begin():
self._validate_update_azp(context, id, availability_zone_profile)
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Availability Zone Profile',
id=constants.NIL_UUID)
@ -177,9 +180,10 @@ class AvailabilityZoneProfileController(base.BaseController):
if isinstance(availability_zone_profile.provider_name,
wtypes.UnsetType):
db_availability_zone_profile = (
self._get_db_availability_zone_profile(
context.session, id))
with context.session.begin():
db_availability_zone_profile = (
self._get_db_availability_zone_profile(
context.session, id))
provider_driver = db_availability_zone_profile.provider_name
else:
provider_driver = availability_zone_profile.provider_name
@ -190,23 +194,25 @@ class AvailabilityZoneProfileController(base.BaseController):
driver.name, driver.validate_availability_zone,
availability_zone_data_dict)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
availability_zone_profile_dict = availability_zone_profile.to_dict(
render_unsets=False)
if availability_zone_profile_dict:
self.repositories.availability_zone_profile.update(
lock_session, id, **availability_zone_profile_dict)
lock_session.commit()
context.session, id, **availability_zone_profile_dict)
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_availability_zone_profile = self._get_db_availability_zone_profile(
context.session, id)
with context.session.begin():
db_availability_zone_profile = (
self._get_db_availability_zone_profile(context.session,
id))
result = self._convert_db_to_type(
db_availability_zone_profile,
profile_types.AvailabilityZoneProfileResponse)
@ -224,14 +230,18 @@ class AvailabilityZoneProfileController(base.BaseController):
raise exceptions.NotFound(resource='Availability Zone Profile',
id=constants.NIL_UUID)
# Don't allow it to be deleted if it is in use by an availability zone
if self.repositories.availability_zone.count(
context.session,
availability_zone_profile_id=availability_zone_profile_id) > 0:
raise exceptions.ObjectInUse(object='Availability Zone Profile',
id=availability_zone_profile_id)
try:
self.repositories.availability_zone_profile.delete(
context.session, id=availability_zone_profile_id)
except sa_exception.NoResultFound as e:
raise exceptions.NotFound(resource='Availability Zone Profile',
id=availability_zone_profile_id) from e
with context.session.begin():
if self.repositories.availability_zone.count(
context.session,
availability_zone_profile_id=availability_zone_profile_id
) > 0:
raise exceptions.ObjectInUse(
object='Availability Zone Profile',
id=availability_zone_profile_id)
try:
self.repositories.availability_zone_profile.delete(
context.session, id=availability_zone_profile_id)
except sa_exception.NoResultFound as e:
raise exceptions.NotFound(
resource='Availability Zone Profile',
id=availability_zone_profile_id) from e

View File

@ -46,8 +46,9 @@ class AvailabilityZonesController(base.BaseController):
if name == constants.NIL_UUID:
raise exceptions.NotFound(resource='Availability Zone',
id=constants.NIL_UUID)
db_availability_zone = self._get_db_availability_zone(
context.session, name)
with context.session.begin():
db_availability_zone = self._get_db_availability_zone(
context.session, name)
result = self._convert_db_to_type(
db_availability_zone,
availability_zone_types.AvailabilityZoneResponse)
@ -64,10 +65,12 @@ class AvailabilityZonesController(base.BaseController):
context = pcontext.get('octavia_context')
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ALL)
db_availability_zones, links = (
self.repositories.availability_zone.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER)))
with context.session.begin():
db_availability_zones, links = (
self.repositories.availability_zone.get_all(
context.session,
pagination_helper=pcontext.get(
constants.PAGINATION_HELPER)))
result = self._convert_db_to_type(
db_availability_zones,
[availability_zone_types.AvailabilityZoneResponse])
@ -86,20 +89,20 @@ class AvailabilityZonesController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_POST)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
availability_zone_dict = availability_zone.to_dict(
render_unsets=True)
db_availability_zone = self.repositories.availability_zone.create(
lock_session, **availability_zone_dict)
lock_session.commit()
context.session, **availability_zone_dict)
context.session.commit()
except odb_exceptions.DBDuplicateEntry as e:
lock_session.rollback()
context.session.rollback()
raise exceptions.RecordAlreadyExists(
field='availability zone', name=availability_zone.name) from e
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
result = self._convert_db_to_type(
db_availability_zone,
availability_zone_types.AvailabilityZoneResponse)
@ -117,23 +120,24 @@ class AvailabilityZonesController(base.BaseController):
if name == constants.NIL_UUID:
raise exceptions.NotFound(resource='Availability Zone',
id=constants.NIL_UUID)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
availability_zone_dict = availability_zone.to_dict(
render_unsets=False)
if availability_zone_dict:
self.repositories.availability_zone.update(
lock_session, name, **availability_zone_dict)
lock_session.commit()
context.session, name, **availability_zone_dict)
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_availability_zone = self._get_db_availability_zone(
context.session, name)
with context.session.begin():
db_availability_zone = self._get_db_availability_zone(
context.session, name)
result = self._convert_db_to_type(
db_availability_zone,
availability_zone_types.AvailabilityZoneResponse)
@ -151,7 +155,7 @@ class AvailabilityZonesController(base.BaseController):
if availability_zone_name == constants.NIL_UUID:
raise exceptions.NotFound(resource='Availability Zone',
id=constants.NIL_UUID)
serial_session = db_api.get_session(autocommit=False)
serial_session = db_api.get_session()
serial_session.connection(
execution_options={'isolation_level': 'SERIALIZABLE'})
try:

View File

@ -29,7 +29,6 @@ from octavia.api.v2.controllers import base
from octavia.api.v2.types import flavor_profile as profile_types
from octavia.common import constants
from octavia.common import exceptions
from octavia.db import api as db_api
LOG = logging.getLogger(__name__)
@ -50,7 +49,9 @@ class FlavorProfileController(base.BaseController):
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Flavor profile',
id=constants.NIL_UUID)
db_flavor_profile = self._get_db_flavor_profile(context.session, id)
with context.session.begin():
db_flavor_profile = self._get_db_flavor_profile(context.session,
id)
result = self._convert_db_to_type(db_flavor_profile,
profile_types.FlavorProfileResponse)
if fields is not None:
@ -65,9 +66,12 @@ class FlavorProfileController(base.BaseController):
context = pcontext.get('octavia_context')
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ALL)
db_flavor_profiles, links = self.repositories.flavor_profile.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
with context.session.begin():
db_flavor_profiles, links = (
self.repositories.flavor_profile.get_all(
context.session,
pagination_helper=pcontext.get(
constants.PAGINATION_HELPER)))
result = self._convert_db_to_type(
db_flavor_profiles, [profile_types.FlavorProfileResponse])
if fields is not None:
@ -97,19 +101,19 @@ class FlavorProfileController(base.BaseController):
driver_utils.call_provider(driver.name, driver.validate_flavor,
flavor_data_dict)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
flavorprofile_dict = flavorprofile.to_dict(render_unsets=True)
flavorprofile_dict['id'] = uuidutils.generate_uuid()
db_flavor_profile = self.repositories.flavor_profile.create(
lock_session, **flavorprofile_dict)
lock_session.commit()
context.session, **flavorprofile_dict)
context.session.commit()
except odb_exceptions.DBDuplicateEntry as e:
lock_session.rollback()
context.session.rollback()
raise exceptions.IDAlreadyExists() from e
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
result = self._convert_db_to_type(
db_flavor_profile, profile_types.FlavorProfileResponse)
return profile_types.FlavorProfileRootResponse(flavorprofile=result)
@ -142,7 +146,8 @@ class FlavorProfileController(base.BaseController):
self._auth_validate_action(context, context.project_id,
constants.RBAC_PUT)
self._validate_update_fp(context, id, flavorprofile)
with context.session.begin():
self._validate_update_fp(context, id, flavorprofile)
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Flavor profile',
id=constants.NIL_UUID)
@ -157,8 +162,9 @@ class FlavorProfileController(base.BaseController):
option=constants.FLAVOR_DATA) from e
if isinstance(flavorprofile.provider_name, wtypes.UnsetType):
db_flavor_profile = self._get_db_flavor_profile(
context.session, id)
with context.session.begin():
db_flavor_profile = self._get_db_flavor_profile(
context.session, id)
provider_driver = db_flavor_profile.provider_name
else:
provider_driver = flavorprofile.provider_name
@ -168,21 +174,23 @@ class FlavorProfileController(base.BaseController):
driver_utils.call_provider(driver.name, driver.validate_flavor,
flavor_data_dict)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
flavorprofile_dict = flavorprofile.to_dict(render_unsets=False)
if flavorprofile_dict:
self.repositories.flavor_profile.update(lock_session, id,
self.repositories.flavor_profile.update(context.session, id,
**flavorprofile_dict)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_flavor_profile = self._get_db_flavor_profile(context.session, id)
with context.session.begin():
db_flavor_profile = self._get_db_flavor_profile(context.session,
id)
result = self._convert_db_to_type(
db_flavor_profile, profile_types.FlavorProfileResponse)
return profile_types.FlavorProfileRootResponse(flavorprofile=result)
@ -200,13 +208,14 @@ class FlavorProfileController(base.BaseController):
id=constants.NIL_UUID)
# Don't allow it to be deleted if it is in use by a flavor
if self.repositories.flavor.count(
context.session, flavor_profile_id=flavor_profile_id) > 0:
raise exceptions.ObjectInUse(object='Flavor profile',
id=flavor_profile_id)
try:
self.repositories.flavor_profile.delete(context.session,
id=flavor_profile_id)
except sa_exception.NoResultFound as e:
raise exceptions.NotFound(
resource='Flavor profile', id=flavor_profile_id) from e
with context.session.begin():
if self.repositories.flavor.count(
context.session, flavor_profile_id=flavor_profile_id) > 0:
raise exceptions.ObjectInUse(object='Flavor profile',
id=flavor_profile_id)
try:
self.repositories.flavor_profile.delete(context.session,
id=flavor_profile_id)
except sa_exception.NoResultFound as e:
raise exceptions.NotFound(
resource='Flavor profile', id=flavor_profile_id) from e

View File

@ -47,7 +47,8 @@ class FlavorsController(base.BaseController):
constants.RBAC_GET_ONE)
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Flavor', id=constants.NIL_UUID)
db_flavor = self._get_db_flavor(context.session, id)
with context.session.begin():
db_flavor = self._get_db_flavor(context.session, id)
result = self._convert_db_to_type(db_flavor,
flavor_types.FlavorResponse)
if fields is not None:
@ -62,9 +63,10 @@ class FlavorsController(base.BaseController):
context = pcontext.get('octavia_context')
self._auth_validate_action(context, context.project_id,
constants.RBAC_GET_ALL)
db_flavors, links = self.repositories.flavor.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
with context.session.begin():
db_flavors, links = self.repositories.flavor.get_all(
context.session,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_flavors, [flavor_types.FlavorResponse])
if fields is not None:
@ -83,20 +85,20 @@ class FlavorsController(base.BaseController):
# TODO(johnsom) Validate the flavor profile ID
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
flavor_dict = flavor.to_dict(render_unsets=True)
flavor_dict['id'] = uuidutils.generate_uuid()
db_flavor = self.repositories.flavor.create(lock_session,
db_flavor = self.repositories.flavor.create(context.session,
**flavor_dict)
lock_session.commit()
context.session.commit()
except odb_exceptions.DBDuplicateEntry as e:
lock_session.rollback()
context.session.rollback()
raise exceptions.RecordAlreadyExists(field='flavor',
name=flavor.name) from e
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
result = self._convert_db_to_type(db_flavor,
flavor_types.FlavorResponse)
return flavor_types.FlavorRootResponse(flavor=result)
@ -111,21 +113,22 @@ class FlavorsController(base.BaseController):
constants.RBAC_PUT)
if id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Flavor', id=constants.NIL_UUID)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
flavor_dict = flavor.to_dict(render_unsets=False)
if flavor_dict:
self.repositories.flavor.update(lock_session, id,
self.repositories.flavor.update(context.session, id,
**flavor_dict)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_flavor = self._get_db_flavor(context.session, id)
with context.session.begin():
db_flavor = self._get_db_flavor(context.session, id)
result = self._convert_db_to_type(db_flavor,
flavor_types.FlavorResponse)
return flavor_types.FlavorRootResponse(flavor=result)
@ -140,7 +143,8 @@ class FlavorsController(base.BaseController):
constants.RBAC_DELETE)
if flavor_id == constants.NIL_UUID:
raise exceptions.NotFound(resource='Flavor', id=constants.NIL_UUID)
serial_session = db_api.get_session(autocommit=False)
serial_session = db_api.get_session()
serial_session.begin()
serial_session.connection(
execution_options={'isolation_level': 'SERIALIZABLE'})
try:

View File

@ -30,7 +30,6 @@ from octavia.api.v2.types import health_monitor as hm_types
from octavia.common import constants as consts
from octavia.common import data_models
from octavia.common import exceptions
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
from octavia.i18n import _
@ -50,7 +49,8 @@ class HealthMonitorController(base.BaseController):
def get_one(self, id, fields=None):
"""Gets a single healthmonitor's details."""
context = pecan_request.context.get('octavia_context')
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
with context.session.begin():
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
self._auth_validate_action(context, db_hm.project_id,
consts.RBAC_GET_ONE)
@ -70,10 +70,11 @@ class HealthMonitorController(base.BaseController):
query_filter = self._auth_get_all(context, project_id)
db_hm, links = self.repositories.health_monitor.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(consts.PAGINATION_HELPER),
**query_filter)
with context.session.begin():
db_hm, links = self.repositories.health_monitor.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(consts.PAGINATION_HELPER),
**query_filter)
result = self._convert_db_to_type(
db_hm, [hm_types.HealthMonitorResponse])
if fields is not None:
@ -155,8 +156,10 @@ class HealthMonitorController(base.BaseController):
option='health monitors HTTP 1.1 domain name health check')
try:
return self.repositories.health_monitor.create(
ret = self.repositories.health_monitor.create(
lock_session, **hm_dict)
lock_session.flush()
return ret
except odb_exceptions.DBDuplicateEntry as e:
raise exceptions.DuplicateHealthMonitor() from e
except odb_exceptions.DBReferenceError as e:
@ -200,10 +203,12 @@ class HealthMonitorController(base.BaseController):
context = pecan_request.context.get('octavia_context')
health_monitor = health_monitor_.healthmonitor
pool = self._get_db_pool(context.session, health_monitor.pool_id)
with context.session.begin():
pool = self._get_db_pool(context.session, health_monitor.pool_id)
health_monitor.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
health_monitor.project_id, provider = (
self._get_lb_project_id_provider(context.session,
pool.load_balancer_id))
self._auth_validate_action(context, health_monitor.project_id,
consts.RBAC_POST)
@ -230,11 +235,11 @@ class HealthMonitorController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
if self.repositories.check_quota_met(
context.session,
lock_session,
context.session,
data_models.HealthMonitor,
health_monitor.project_id):
raise exceptions.QuotaException(
@ -244,8 +249,8 @@ class HealthMonitorController(base.BaseController):
health_monitor.to_dict(render_unsets=True))
self._test_lb_and_listener_and_pool_statuses(
lock_session, health_monitor)
db_hm = self._validate_create_hm(lock_session, hm_dict)
context.session, health_monitor)
db_hm = self._validate_create_hm(context.session, hm_dict)
# Prepare the data for the driver data model
provider_healthmon = driver_utils.db_HM_to_provider_HM(db_hm)
@ -256,16 +261,17 @@ class HealthMonitorController(base.BaseController):
driver_utils.call_provider(
driver.name, driver.health_monitor_create, provider_healthmon)
lock_session.commit()
context.session.commit()
except odb_exceptions.DBError as e:
lock_session.rollback()
context.session.rollback()
raise exceptions.InvalidOption(
value=hm_dict.get('type'), option='type') from e
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
db_hm = self._get_db_hm(context.session, db_hm.id)
with context.session.begin():
db_hm = self._get_db_hm(context.session, db_hm.id)
result = self._convert_db_to_type(
db_hm, hm_types.HealthMonitorResponse)
return hm_types.HealthMonitorRootResponse(healthmonitor=result)
@ -342,11 +348,12 @@ class HealthMonitorController(base.BaseController):
"""Updates a health monitor."""
context = pecan_request.context.get('octavia_context')
health_monitor = health_monitor_.healthmonitor
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
with context.session.begin():
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
pool = self._get_db_pool(context.session, db_hm.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
pool = self._get_db_pool(context.session, db_hm.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
self._auth_validate_action(context, project_id, consts.RBAC_PUT)
@ -363,9 +370,10 @@ class HealthMonitorController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
self._test_lb_and_listener_and_pool_statuses(lock_session, db_hm)
self._test_lb_and_listener_and_pool_statuses(context.session,
db_hm)
# Prepare the data for the driver data model
healthmon_dict = health_monitor.to_dict(render_unsets=False)
@ -387,13 +395,14 @@ class HealthMonitorController(base.BaseController):
# Update the database to reflect what the driver just accepted
health_monitor.provisioning_status = consts.PENDING_UPDATE
db_hm_dict = health_monitor.to_dict(render_unsets=False)
self.repositories.health_monitor.update(lock_session, id,
self.repositories.health_monitor.update(context.session, id,
**db_hm_dict)
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_hm = self._get_db_hm(context.session, id)
with context.session.begin():
db_hm = self._get_db_hm(context.session, id)
result = self._convert_db_to_type(
db_hm, hm_types.HealthMonitorResponse)
return hm_types.HealthMonitorRootResponse(healthmonitor=result)
@ -402,11 +411,12 @@ class HealthMonitorController(base.BaseController):
def delete(self, id):
"""Deletes a health monitor."""
context = pecan_request.context.get('octavia_context')
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
with context.session.begin():
db_hm = self._get_db_hm(context.session, id, show_deleted=False)
pool = self._get_db_pool(context.session, db_hm.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
pool = self._get_db_pool(context.session, db_hm.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
self._auth_validate_action(context, project_id, consts.RBAC_DELETE)
@ -416,12 +426,13 @@ class HealthMonitorController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
self._test_lb_and_listener_and_pool_statuses(lock_session, db_hm)
self._test_lb_and_listener_and_pool_statuses(context.session,
db_hm)
self.repositories.health_monitor.update(
lock_session, db_hm.id,
context.session, db_hm.id,
provisioning_status=consts.PENDING_DELETE)
LOG.info("Sending delete Health Monitor %s to provider %s",

View File

@ -32,7 +32,6 @@ from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
@ -51,8 +50,9 @@ class L7PolicyController(base.BaseController):
def get(self, id, fields=None):
"""Gets a single l7policy's details."""
context = pecan_request.context.get('octavia_context')
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
self._auth_validate_action(context, db_l7policy.project_id,
constants.RBAC_GET_ONE)
@ -72,10 +72,11 @@ class L7PolicyController(base.BaseController):
query_filter = self._auth_get_all(context, project_id)
db_l7policies, links = self.repositories.l7policy.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
with context.session.begin():
db_l7policies, links = self.repositories.l7policy.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
result = self._convert_db_to_type(
db_l7policies, [l7policy_types.L7PolicyResponse])
if fields is not None:
@ -121,11 +122,12 @@ class L7PolicyController(base.BaseController):
# Verify the parent listener exists
listener_id = l7policy.listener_id
listener = self._get_db_listener(
context.session, listener_id)
load_balancer_id = listener.load_balancer_id
l7policy.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
with context.session.begin():
listener = self._get_db_listener(
context.session, listener_id)
load_balancer_id = listener.load_balancer_id
l7policy.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, l7policy.project_id,
constants.RBAC_POST)
@ -137,14 +139,16 @@ class L7PolicyController(base.BaseController):
# Make sure any pool specified by redirect_pool_id exists
if l7policy.redirect_pool_id:
db_pool = self._get_db_pool(
context.session, l7policy.redirect_pool_id)
with context.session.begin():
db_pool = self._get_db_pool(
context.session, l7policy.redirect_pool_id)
self._validate_protocol(listener.protocol, db_pool.protocol)
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
lock_session = context.session
lock_session.begin()
try:
if self.repositories.check_quota_met(
context.session,
@ -179,7 +183,9 @@ class L7PolicyController(base.BaseController):
with excutils.save_and_reraise_exception():
lock_session.rollback()
db_l7policy = self._get_db_l7policy(context.session, db_l7policy.id)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session,
db_l7policy.id)
result = self._convert_db_to_type(db_l7policy,
l7policy_types.L7PolicyResponse)
return l7policy_types.L7PolicyRootResponse(l7policy=result)
@ -209,12 +215,13 @@ class L7PolicyController(base.BaseController):
"""Updates a l7policy."""
l7policy = l7policy_.l7policy
context = pecan_request.context.get('octavia_context')
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = self._get_listener_and_loadbalancer_id(
db_l7policy)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = (
self._get_listener_and_loadbalancer_id(db_l7policy))
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
@ -226,18 +233,21 @@ class L7PolicyController(base.BaseController):
l7policy_dict[attr] = l7policy_dict.pop(val)
sanitized_l7policy = l7policy_types.L7PolicyPUT(**l7policy_dict)
listener = self._get_db_listener(
context.session, db_l7policy.listener_id)
with context.session.begin():
listener = self._get_db_listener(
context.session, db_l7policy.listener_id)
# Make sure any specified redirect_pool_id exists
if l7policy_dict.get('redirect_pool_id'):
db_pool = self._get_db_pool(
context.session, l7policy_dict['redirect_pool_id'])
with context.session.begin():
db_pool = self._get_db_pool(
context.session, l7policy_dict['redirect_pool_id'])
self._validate_protocol(listener.protocol, db_pool.protocol)
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
lock_session = context.session
self._test_lb_and_listener_statuses(lock_session,
lb_id=load_balancer_id,
@ -270,7 +280,8 @@ class L7PolicyController(base.BaseController):
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_l7policy = self._get_db_l7policy(context.session, id)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session, id)
result = self._convert_db_to_type(db_l7policy,
l7policy_types.L7PolicyResponse)
return l7policy_types.L7PolicyRootResponse(l7policy=result)
@ -279,12 +290,13 @@ class L7PolicyController(base.BaseController):
def delete(self, id):
"""Deletes a l7policy."""
context = pecan_request.context.get('octavia_context')
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = self._get_listener_and_loadbalancer_id(
db_l7policy)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = (
self._get_listener_and_loadbalancer_id(db_l7policy))
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
@ -294,13 +306,12 @@ class L7PolicyController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_and_listener_statuses(lock_session,
with context.session.begin():
self._test_lb_and_listener_statuses(context.session,
lb_id=load_balancer_id,
listener_ids=[listener_id])
self.repositories.l7policy.update(
lock_session, db_l7policy.id,
context.session, db_l7policy.id,
provisioning_status=constants.PENDING_DELETE)
LOG.info("Sending delete L7 Policy %s to provider %s",
@ -320,8 +331,9 @@ class L7PolicyController(base.BaseController):
context = pecan_request.context.get('octavia_context')
if l7policy_id and remainder and remainder[0] == 'rules':
remainder = remainder[1:]
db_l7policy = self.repositories.l7policy.get(
context.session, id=l7policy_id)
with context.session.begin():
db_l7policy = self.repositories.l7policy.get(
context.session, id=l7policy_id)
if not db_l7policy:
LOG.info("L7Policy %s not found.", l7policy_id)
raise exceptions.NotFound(

View File

@ -28,7 +28,6 @@ from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
@ -47,8 +46,9 @@ class L7RuleController(base.BaseController):
def get(self, id, fields=None):
"""Gets a single l7rule's details."""
context = pecan_request.context.get('octavia_context')
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
with context.session.begin():
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
self._auth_validate_action(context, db_l7rule.project_id,
constants.RBAC_GET_ONE)
@ -66,15 +66,17 @@ class L7RuleController(base.BaseController):
pcontext = pecan_request.context
context = pcontext.get('octavia_context')
l7policy = self._get_db_l7policy(context.session, self.l7policy_id,
show_deleted=False)
with context.session.begin():
l7policy = self._get_db_l7policy(context.session, self.l7policy_id,
show_deleted=False)
self._auth_validate_action(context, l7policy.project_id,
constants.RBAC_GET_ALL)
self._auth_validate_action(context, l7policy.project_id,
constants.RBAC_GET_ALL)
db_l7rules, links = self.repositories.l7rule.get_all_API_list(
context.session, show_deleted=False, l7policy_id=self.l7policy_id,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
db_l7rules, links = self.repositories.l7rule.get_all_API_list(
context.session, show_deleted=False,
l7policy_id=self.l7policy_id,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_l7rules, [l7rule_types.L7RuleResponse])
if fields is not None:
@ -109,7 +111,9 @@ class L7RuleController(base.BaseController):
def _validate_create_l7rule(self, lock_session, l7rule_dict):
try:
return self.repositories.l7rule.create(lock_session, **l7rule_dict)
ret = self.repositories.l7rule.create(lock_session, **l7rule_dict)
lock_session.flush()
return ret
except odb_exceptions.DBDuplicateEntry as e:
raise exceptions.IDAlreadyExists() from e
except odb_exceptions.DBReferenceError as e:
@ -125,30 +129,33 @@ class L7RuleController(base.BaseController):
l7rule = rule_.rule
context = pecan_request.context.get('octavia_context')
db_l7policy = self._get_db_l7policy(context.session, self.l7policy_id,
show_deleted=False)
load_balancer_id, listener_id = self._get_listener_and_loadbalancer_id(
db_l7policy)
l7rule.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, l7rule.project_id,
constants.RBAC_POST)
with context.session.begin():
db_l7policy = self._get_db_l7policy(context.session,
self.l7policy_id,
show_deleted=False)
load_balancer_id, listener_id = (
self._get_listener_and_loadbalancer_id(db_l7policy))
l7rule.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
try:
validate.l7rule_data(l7rule)
except Exception as e:
raise exceptions.L7RuleValidation(error=e)
self._auth_validate_action(context, l7rule.project_id,
constants.RBAC_POST)
self._check_l7policy_max_rules(context.session)
try:
validate.l7rule_data(l7rule)
except Exception as e:
raise exceptions.L7RuleValidation(error=e)
self._check_l7policy_max_rules(context.session)
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
if self.repositories.check_quota_met(
context.session,
lock_session,
context.session,
data_models.L7Rule,
l7rule.project_id):
raise exceptions.QuotaException(
@ -157,9 +164,10 @@ class L7RuleController(base.BaseController):
l7rule_dict = db_prepare.create_l7rule(
l7rule.to_dict(render_unsets=True), self.l7policy_id)
self._test_lb_listener_policy_statuses(lock_session)
self._test_lb_listener_policy_statuses(context.session)
db_l7rule = self._validate_create_l7rule(lock_session, l7rule_dict)
db_l7rule = self._validate_create_l7rule(context.session,
l7rule_dict)
# Prepare the data for the driver data model
provider_l7rule = (
@ -171,12 +179,13 @@ class L7RuleController(base.BaseController):
driver_utils.call_provider(
driver.name, driver.l7rule_create, provider_l7rule)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
db_l7rule = self._get_db_l7rule(context.session, db_l7rule.id)
with context.session.begin():
db_l7rule = self._get_db_l7rule(context.session, db_l7rule.id)
result = self._convert_db_to_type(db_l7rule,
l7rule_types.L7RuleResponse)
return l7rule_types.L7RuleRootResponse(rule=result)
@ -198,14 +207,16 @@ class L7RuleController(base.BaseController):
"""Updates a l7rule."""
l7rule = l7rule_.rule
context = pecan_request.context.get('octavia_context')
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
db_l7policy = self._get_db_l7policy(context.session, self.l7policy_id,
with context.session.begin():
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = self._get_listener_and_loadbalancer_id(
db_l7policy)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
db_l7policy = self._get_db_l7policy(context.session,
self.l7policy_id,
show_deleted=False)
load_balancer_id, listener_id = (
self._get_listener_and_loadbalancer_id(db_l7policy))
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
@ -225,9 +236,8 @@ class L7RuleController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_listener_policy_statuses(lock_session)
with context.session.begin():
self._test_lb_listener_policy_statuses(context.session)
# Prepare the data for the driver data model
l7rule_dict = l7rule.to_dict(render_unsets=False)
@ -250,12 +260,14 @@ class L7RuleController(base.BaseController):
# Update the database to reflect what the driver just accepted
l7rule.provisioning_status = constants.PENDING_UPDATE
db_l7rule_dict = l7rule.to_dict(render_unsets=False)
self.repositories.l7rule.update(lock_session, id, **db_l7rule_dict)
self.repositories.l7rule.update(context.session, id,
**db_l7rule_dict)
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_l7rule = self._get_db_l7rule(context.session, id)
with context.session.begin():
db_l7rule = self._get_db_l7rule(context.session, id)
result = self._convert_db_to_type(db_l7rule,
l7rule_types.L7RuleResponse)
return l7rule_types.L7RuleRootResponse(rule=result)
@ -264,15 +276,17 @@ class L7RuleController(base.BaseController):
def delete(self, id):
"""Deletes a l7rule."""
context = pecan_request.context.get('octavia_context')
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
db_l7policy = self._get_db_l7policy(context.session, self.l7policy_id,
with context.session.begin():
db_l7rule = self._get_db_l7rule(context.session, id,
show_deleted=False)
load_balancer_id, listener_id = self._get_listener_and_loadbalancer_id(
db_l7policy)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
db_l7policy = self._get_db_l7policy(context.session,
self.l7policy_id,
show_deleted=False)
load_balancer_id, listener_id = (
self._get_listener_and_loadbalancer_id(db_l7policy))
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
@ -282,12 +296,11 @@ class L7RuleController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_listener_policy_statuses(lock_session)
with context.session.begin():
self._test_lb_listener_policy_statuses(context.session)
self.repositories.l7rule.update(
lock_session, db_l7rule.id,
context.session, db_l7rule.id,
provisioning_status=constants.PENDING_DELETE)
LOG.info("Sending delete L7 Rule %s to provider %s", id,

View File

@ -35,7 +35,6 @@ from octavia.common import exceptions
from octavia.common import stats
from octavia.common import utils as common_utils
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
from octavia.i18n import _
@ -55,8 +54,9 @@ class ListenersController(base.BaseController):
def get_one(self, id, fields=None):
"""Gets a single listener's details."""
context = pecan_request.context.get('octavia_context')
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
with context.session.begin():
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
if not db_listener:
raise exceptions.NotFound(resource=data_models.Listener._name(),
@ -80,10 +80,11 @@ class ListenersController(base.BaseController):
query_filter = self._auth_get_all(context, project_id)
db_listeners, links = self.repositories.listener.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
with context.session.begin():
db_listeners, links = self.repositories.listener.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
result = self._convert_db_to_type(
db_listeners, [listener_types.ListenerResponse])
if fields is not None:
@ -327,12 +328,14 @@ class ListenersController(base.BaseController):
try:
db_listener = self.repositories.listener.create(
lock_session, **listener_dict)
lock_session.flush()
if sni_containers:
for container in sni_containers:
sni_dict = {'listener_id': db_listener.id,
'tls_container_id': container.get(
'tls_container_id')}
self.repositories.sni.create(lock_session, **sni_dict)
lock_session.flush()
# DB listener needs to be refreshed
db_listener = self.repositories.listener.get(
lock_session, id=db_listener.id)
@ -355,8 +358,9 @@ class ListenersController(base.BaseController):
context = pecan_request.context.get('octavia_context')
load_balancer_id = listener.loadbalancer_id
listener.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
with context.session.begin():
listener.project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, listener.project_id,
constants.RBAC_POST)
@ -364,11 +368,11 @@ class ListenersController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
if self.repositories.check_quota_met(
context.session,
lock_session,
context.session,
data_models.Listener,
listener.project_id):
raise exceptions.QuotaException(
@ -387,10 +391,10 @@ class ListenersController(base.BaseController):
listener.protocol)
self._test_lb_and_listener_statuses(
lock_session, lb_id=load_balancer_id)
context.session, lb_id=load_balancer_id)
db_listener = self._validate_create_listener(
lock_session, listener_dict)
context.session, listener_dict)
# Prepare the data for the driver data model
provider_listener = (
@ -408,12 +412,14 @@ class ListenersController(base.BaseController):
driver_utils.call_provider(
driver.name, driver.listener_create, provider_listener)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
db_listener = self._get_db_listener(context.session, db_listener.id)
with context.session.begin():
db_listener = self._get_db_listener(context.session,
db_listener.id)
result = self._convert_db_to_type(db_listener,
listener_types.ListenerResponse)
return listener_types.ListenerRootResponse(listener=result)
@ -599,12 +605,13 @@ class ListenersController(base.BaseController):
"""Updates a listener on a load balancer."""
listener = listener_.listener
context = pecan_request.context.get('octavia_context')
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
load_balancer_id = db_listener.load_balancer_id
with context.session.begin():
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
load_balancer_id = db_listener.load_balancer_id
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
@ -616,14 +623,16 @@ class ListenersController(base.BaseController):
if db_listener.protocol == lib_consts.PROTOCOL_PROMETHEUS:
raise exceptions.ListenerNoChildren(
protocol=lib_consts.PROTOCOL_PROMETHEUS)
self._validate_pool(context.session, load_balancer_id,
listener.default_pool_id, db_listener.protocol)
with context.session.begin():
self._validate_pool(context.session, load_balancer_id,
listener.default_pool_id,
db_listener.protocol)
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_and_listener_statuses(lock_session,
with context.session.begin():
self._test_lb_and_listener_statuses(context.session,
load_balancer_id, id=id)
# Prepare the data for the driver data model
@ -648,12 +657,13 @@ class ListenersController(base.BaseController):
# Update the database to reflect what the driver just accepted
self.repositories.listener.update(
lock_session, id, **listener.to_dict(render_unsets=False))
context.session, id, **listener.to_dict(render_unsets=False))
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_listener = self._get_db_listener(context.session, id)
with context.session.begin():
db_listener = self._get_db_listener(context.session, id)
result = self._convert_db_to_type(db_listener,
listener_types.ListenerResponse)
return listener_types.ListenerRootResponse(listener=result)
@ -662,22 +672,22 @@ class ListenersController(base.BaseController):
def delete(self, id):
"""Deletes a listener from a load balancer."""
context = pecan_request.context.get('octavia_context')
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
load_balancer_id = db_listener.load_balancer_id
with context.session.begin():
db_listener = self._get_db_listener(context.session, id,
show_deleted=False)
load_balancer_id = db_listener.load_balancer_id
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
self._test_lb_and_listener_statuses(
lock_session, load_balancer_id,
context.session, load_balancer_id,
id=id, listener_status=constants.PENDING_DELETE)
LOG.info("Sending delete Listener %s to provider %s", id,
@ -711,18 +721,19 @@ class StatisticsController(base.BaseController, stats.StatsMixin):
status_code=200)
def get(self):
context = pecan_request.context.get('octavia_context')
db_listener = self._get_db_listener(context.session, self.id,
show_deleted=False)
if not db_listener:
LOG.info("Listener %s not found.", id)
raise exceptions.NotFound(
resource=data_models.Listener._name(),
id=id)
with context.session.begin():
db_listener = self._get_db_listener(context.session, self.id,
show_deleted=False)
if not db_listener:
LOG.info("Listener %s not found.", id)
raise exceptions.NotFound(
resource=data_models.Listener._name(),
id=id)
self._auth_validate_action(context, db_listener.project_id,
constants.RBAC_GET_STATS)
self._auth_validate_action(context, db_listener.project_id,
constants.RBAC_GET_STATS)
listener_stats = self.get_listener_stats(context.session, self.id)
listener_stats = self.get_listener_stats(context.session, self.id)
result = self._convert_db_to_type(
listener_stats, listener_types.ListenerStatisticsResponse)

View File

@ -38,7 +38,6 @@ from octavia.common import exceptions
from octavia.common import stats
from octavia.common import utils
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
from octavia.i18n import _
from octavia.network import base as network_base
@ -59,8 +58,9 @@ class LoadBalancersController(base.BaseController):
def get_one(self, id, fields=None):
"""Gets a single load balancer's details."""
context = pecan_request.context.get('octavia_context')
load_balancer = self._get_db_lb(context.session, id,
show_deleted=False)
with context.session.begin():
load_balancer = self._get_db_lb(context.session, id,
show_deleted=False)
if not load_balancer:
raise exceptions.NotFound(
@ -85,11 +85,13 @@ class LoadBalancersController(base.BaseController):
query_filter = self._auth_get_all(context, project_id)
load_balancers, links = (
self.repositories.load_balancer.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter))
with context.session.begin():
load_balancers, links = (
self.repositories.load_balancer.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(
constants.PAGINATION_HELPER),
**query_filter))
result = self._convert_db_to_type(
load_balancers, [lb_types.LoadBalancerResponse])
if fields is not None:
@ -327,8 +329,9 @@ class LoadBalancersController(base.BaseController):
provider = None
if not isinstance(load_balancer.flavor_id, wtypes.UnsetType):
try:
provider = self.repositories.flavor.get_flavor_provider(
session, load_balancer.flavor_id)
with session.begin():
provider = self.repositories.flavor.get_flavor_provider(
session, load_balancer.flavor_id)
except sa_exception.NoResultFound as e:
raise exceptions.ValidationException(
detail=_("Invalid flavor_id.")) from e
@ -377,8 +380,9 @@ class LoadBalancersController(base.BaseController):
def _validate_flavor(self, session, load_balancer):
if not isinstance(load_balancer.flavor_id, wtypes.UnsetType):
flavor = self.repositories.flavor.get(session,
id=load_balancer.flavor_id)
with session.begin():
flavor = self.repositories.flavor.get(
session, id=load_balancer.flavor_id)
if not flavor:
raise exceptions.ValidationException(
detail=_("Invalid flavor_id."))
@ -416,8 +420,9 @@ class LoadBalancersController(base.BaseController):
def _validate_availability_zone(self, session, load_balancer):
if not isinstance(load_balancer.availability_zone, wtypes.UnsetType):
az = self.repositories.availability_zone.get(
session, name=load_balancer.availability_zone)
with session.begin():
az = self.repositories.availability_zone.get(
session, name=load_balancer.availability_zone)
if not az:
raise exceptions.ValidationException(
detail=_("Invalid availability zone."))
@ -456,7 +461,8 @@ class LoadBalancersController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
lock_session = context.session
lock_session.begin()
try:
if self.repositories.check_quota_met(
context.session,
@ -553,6 +559,8 @@ class LoadBalancersController(base.BaseController):
driver_lb_dict = driver_utils.lb_dict_to_provider_dict(
lb_dict, vip, add_vips, db_pools, db_lists)
lock_session.flush()
# Dispatch to the driver
LOG.info("Sending create Load Balancer %s to provider %s",
db_lb.id, driver.name)
@ -568,7 +576,8 @@ class LoadBalancersController(base.BaseController):
with excutils.save_and_reraise_exception():
lock_session.rollback()
db_lb = self._get_db_lb(context.session, db_lb.id)
with context.session.begin():
db_lb = self._get_db_lb(context.session, db_lb.id)
result = self._convert_db_to_type(
db_lb, lb_types.LoadBalancerFullResponse)
@ -689,7 +698,8 @@ class LoadBalancersController(base.BaseController):
"""Updates a load balancer."""
load_balancer = load_balancer.loadbalancer
context = pecan_request.context.get('octavia_context')
db_lb = self._get_db_lb(context.session, id, show_deleted=False)
with context.session.begin():
db_lb = self._get_db_lb(context.session, id, show_deleted=False)
self._auth_validate_action(context, db_lb.project_id,
constants.RBAC_PUT)
@ -704,8 +714,8 @@ class LoadBalancersController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(db_lb.provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_status(lock_session, id)
with context.session.begin():
self._test_lb_status(context.session, id)
# Prepare the data for the driver data model
lb_dict = load_balancer.to_dict(render_unsets=False)
@ -731,15 +741,17 @@ class LoadBalancersController(base.BaseController):
db_lb_dict = load_balancer.to_dict(render_unsets=False)
if 'vip' in db_lb_dict:
db_vip_dict = db_lb_dict.pop('vip')
self.repositories.vip.update(lock_session, id, **db_vip_dict)
self.repositories.vip.update(context.session, id,
**db_vip_dict)
if db_lb_dict:
self.repositories.load_balancer.update(lock_session, id,
self.repositories.load_balancer.update(context.session, id,
**db_lb_dict)
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_lb = self._get_db_lb(context.session, id)
with context.session.begin():
db_lb = self._get_db_lb(context.session, id)
result = self._convert_db_to_type(db_lb, lb_types.LoadBalancerResponse)
return lb_types.LoadBalancerRootResponse(loadbalancer=result)
@ -748,7 +760,8 @@ class LoadBalancersController(base.BaseController):
"""Deletes a load balancer."""
context = pecan_request.context.get('octavia_context')
cascade = strutils.bool_from_string(cascade)
db_lb = self._get_db_lb(context.session, id, show_deleted=False)
with context.session.begin():
db_lb = self._get_db_lb(context.session, id, show_deleted=False)
self._auth_validate_action(context, db_lb.project_id,
constants.RBAC_DELETE)
@ -756,13 +769,13 @@ class LoadBalancersController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(db_lb.provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
if (db_lb.listeners or db_lb.pools) and not cascade:
msg = _("Cannot delete Load Balancer %s - "
"it has children") % id
LOG.warning(msg)
raise exceptions.ValidationException(detail=msg)
self._test_lb_status(lock_session, id,
self._test_lb_status(context.session, id,
lb_status=constants.PENDING_DELETE)
LOG.info("Sending delete Load Balancer %s to provider %s",
@ -813,8 +826,9 @@ class StatusController(base.BaseController):
status_code=200)
def get(self):
context = pecan_request.context.get('octavia_context')
load_balancer = self._get_db_lb(context.session, self.id,
show_deleted=False)
with context.session.begin():
load_balancer = self._get_db_lb(context.session, self.id,
show_deleted=False)
if not load_balancer:
LOG.info("Load balancer %s not found.", id)
raise exceptions.NotFound(
@ -841,8 +855,9 @@ class StatisticsController(base.BaseController, stats.StatsMixin):
status_code=200)
def get(self):
context = pecan_request.context.get('octavia_context')
load_balancer = self._get_db_lb(context.session, self.id,
show_deleted=False)
with context.session.begin():
load_balancer = self._get_db_lb(context.session, self.id,
show_deleted=False)
if not load_balancer:
LOG.info("Load balancer %s not found.", id)
raise exceptions.NotFound(
@ -852,7 +867,8 @@ class StatisticsController(base.BaseController, stats.StatsMixin):
self._auth_validate_action(context, load_balancer.project_id,
constants.RBAC_GET_STATS)
lb_stats = self.get_loadbalancer_stats(context.session, self.id)
with context.session.begin():
lb_stats = self.get_loadbalancer_stats(context.session, self.id)
result = self._convert_db_to_type(
lb_stats, lb_types.LoadBalancerStatisticsResponse)
@ -869,8 +885,9 @@ class FailoverController(LoadBalancersController):
def put(self, **kwargs):
"""Fails over a loadbalancer"""
context = pecan_request.context.get('octavia_context')
db_lb = self._get_db_lb(context.session, self.lb_id,
show_deleted=False)
with context.session.begin():
db_lb = self._get_db_lb(context.session, self.lb_id,
show_deleted=False)
self._auth_validate_action(context, db_lb.project_id,
constants.RBAC_PUT_FAILOVER)
@ -878,8 +895,9 @@ class FailoverController(LoadBalancersController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(db_lb.provider)
with db_api.get_lock_session() as lock_session:
self._test_and_set_failover_prov_status(lock_session, self.lb_id)
with context.session.begin():
self._test_and_set_failover_prov_status(context.session,
self.lb_id)
LOG.info("Sending failover request for load balancer %s to the "
"provider %s", self.lb_id, driver.name)
driver_utils.call_provider(

View File

@ -30,7 +30,6 @@ from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
@ -49,8 +48,9 @@ class MemberController(base.BaseController):
def get(self, id, fields=None):
"""Gets a single pool member's details."""
context = pecan_request.context.get('octavia_context')
db_member = self._get_db_member(context.session, id,
show_deleted=False)
with context.session.begin():
db_member = self._get_db_member(context.session, id,
show_deleted=False)
self._auth_validate_action(context, db_member.project_id,
constants.RBAC_GET_ONE)
@ -70,16 +70,17 @@ class MemberController(base.BaseController):
pcontext = pecan_request.context
context = pcontext.get('octavia_context')
pool = self._get_db_pool(context.session, self.pool_id,
show_deleted=False)
with context.session.begin():
pool = self._get_db_pool(context.session, self.pool_id,
show_deleted=False)
self._auth_validate_action(context, pool.project_id,
constants.RBAC_GET_ALL)
self._auth_validate_action(context, pool.project_id,
constants.RBAC_GET_ALL)
db_members, links = self.repositories.member.get_all_API_list(
context.session, show_deleted=False,
pool_id=self.pool_id,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
db_members, links = self.repositories.member.get_all_API_list(
context.session, show_deleted=False,
pool_id=self.pool_id,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER))
result = self._convert_db_to_type(
db_members, [member_types.MemberResponse])
if fields is not None:
@ -118,7 +119,9 @@ class MemberController(base.BaseController):
def _validate_create_member(self, lock_session, member_dict):
"""Validate creating member on pool."""
try:
return self.repositories.member.create(lock_session, **member_dict)
ret = self.repositories.member.create(lock_session, **member_dict)
lock_session.flush()
return ret
except odb_exceptions.DBDuplicateEntry as e:
raise exceptions.DuplicateMemberEntry(
ip_address=member_dict.get('ip_address'),
@ -141,9 +144,10 @@ class MemberController(base.BaseController):
member = member_.member
context = pecan_request.context.get('octavia_context')
pool = self.repositories.pool.get(context.session, id=self.pool_id)
member.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
with context.session.begin():
pool = self.repositories.pool.get(context.session, id=self.pool_id)
member.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
self._auth_validate_action(context, member.project_id,
constants.RBAC_POST)
@ -158,11 +162,11 @@ class MemberController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
if self.repositories.check_quota_met(
context.session,
lock_session,
context.session,
data_models.Member,
member.project_id):
raise exceptions.QuotaException(
@ -171,9 +175,10 @@ class MemberController(base.BaseController):
member_dict = db_prepare.create_member(member.to_dict(
render_unsets=True), self.pool_id, bool(pool.health_monitor))
self._test_lb_and_listener_and_pool_statuses(lock_session)
self._test_lb_and_listener_and_pool_statuses(context.session)
db_member = self._validate_create_member(lock_session, member_dict)
db_member = self._validate_create_member(context.session,
member_dict)
# Prepare the data for the driver data model
provider_member = (
@ -185,12 +190,13 @@ class MemberController(base.BaseController):
driver_utils.call_provider(
driver.name, driver.member_create, provider_member)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
db_member = self._get_db_member(context.session, db_member.id)
with context.session.begin():
db_member = self._get_db_member(context.session, db_member.id)
result = self._convert_db_to_type(db_member,
member_types.MemberResponse)
return member_types.MemberRootResponse(member=result)
@ -226,12 +232,13 @@ class MemberController(base.BaseController):
"""Updates a pool member."""
member = member_.member
context = pecan_request.context.get('octavia_context')
db_member = self._get_db_member(context.session, id,
show_deleted=False)
pool = self.repositories.pool.get(context.session,
id=db_member.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
with context.session.begin():
db_member = self._get_db_member(context.session, id,
show_deleted=False)
pool = self.repositories.pool.get(context.session,
id=db_member.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
@ -242,8 +249,8 @@ class MemberController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_and_listener_and_pool_statuses(lock_session,
with context.session.begin():
self._test_lb_and_listener_and_pool_statuses(context.session,
member=db_member)
# Prepare the data for the driver data model
@ -267,12 +274,14 @@ class MemberController(base.BaseController):
# Update the database to reflect what the driver just accepted
member.provisioning_status = constants.PENDING_UPDATE
db_member_dict = member.to_dict(render_unsets=False)
self.repositories.member.update(lock_session, id, **db_member_dict)
self.repositories.member.update(context.session, id,
**db_member_dict)
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_member = self._get_db_member(context.session, id)
with context.session.begin():
db_member = self._get_db_member(context.session, id)
result = self._convert_db_to_type(db_member,
member_types.MemberResponse)
return member_types.MemberRootResponse(member=result)
@ -281,13 +290,14 @@ class MemberController(base.BaseController):
def delete(self, id):
"""Deletes a pool member."""
context = pecan_request.context.get('octavia_context')
db_member = self._get_db_member(context.session, id,
show_deleted=False)
with context.session.begin():
db_member = self._get_db_member(context.session, id,
show_deleted=False)
pool = self.repositories.pool.get(context.session,
id=db_member.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
pool = self.repositories.pool.get(context.session,
id=db_member.pool_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, pool.load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
@ -296,11 +306,11 @@ class MemberController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_and_listener_and_pool_statuses(lock_session,
with context.session.begin():
self._test_lb_and_listener_and_pool_statuses(context.session,
member=db_member)
self.repositories.member.update(
lock_session, db_member.id,
context.session, db_member.id,
provisioning_status=constants.PENDING_DELETE)
LOG.info("Sending delete Member %s to provider %s", id,
@ -324,11 +334,12 @@ class MembersController(MemberController):
additive_only = strutils.bool_from_string(additive_only)
context = pecan_request.context.get('octavia_context')
db_pool = self._get_db_pool(context.session, self.pool_id)
old_members = db_pool.members
with context.session.begin():
db_pool = self._get_db_pool(context.session, self.pool_id)
old_members = db_pool.members
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
# Check POST+PUT+DELETE since this operation is all of 'CUD'
self._auth_validate_action(context, project_id, constants.RBAC_POST)
@ -340,8 +351,8 @@ class MembersController(MemberController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
self._test_lb_and_listener_and_pool_statuses(lock_session)
with context.session.begin():
self._test_lb_and_listener_and_pool_statuses(context.session)
old_member_uniques = {
(m.ip_address, m.protocol_port): m.id for m in old_members}
@ -368,7 +379,7 @@ class MembersController(MemberController):
if not (deleted_members or new_members or updated_members):
LOG.info("Member batch update is a noop, rolling back and "
"returning early.")
lock_session.rollback()
context.session.rollback()
return
if additive_only:
@ -376,7 +387,7 @@ class MembersController(MemberController):
else:
member_count_diff = len(new_members) - len(deleted_members)
if member_count_diff > 0 and self.repositories.check_quota_met(
context.session, lock_session, data_models.Member,
context.session, context.session, data_models.Member,
db_pool.project_id, count=member_count_diff):
raise exceptions.QuotaException(
resource=data_models.Member._name())
@ -405,7 +416,7 @@ class MembersController(MemberController):
m = m.to_dict(render_unsets=False)
m['project_id'] = db_pool.project_id
created_member = self._graph_create(lock_session, m)
created_member = self._graph_create(context.session, m)
provider_member = driver_utils.db_member_to_provider_member(
created_member)
provider_members.append(provider_member)
@ -416,7 +427,7 @@ class MembersController(MemberController):
db_member_dict = m.to_dict(render_unsets=False)
db_member_dict.pop('id')
self.repositories.member.update(
lock_session, m.id, **db_member_dict)
context.session, m.id, **db_member_dict)
m.pool_id = self.pool_id
provider_members.append(
@ -434,7 +445,7 @@ class MembersController(MemberController):
else:
# Members are changed to PENDING_DELETE and not passed.
self.repositories.member.update(
lock_session, m.id,
context.session, m.id,
provisioning_status=constants.PENDING_DELETE)
# Dispatch to the driver

View File

@ -34,7 +34,6 @@ from octavia.common import constants
from octavia.common import data_models
from octavia.common import exceptions
from octavia.common import validate
from octavia.db import api as db_api
from octavia.db import prepare as db_prepare
from octavia.i18n import _
@ -54,7 +53,9 @@ class PoolsController(base.BaseController):
def get(self, id, fields=None):
"""Gets a pool's details."""
context = pecan_request.context.get('octavia_context')
db_pool = self._get_db_pool(context.session, id, show_deleted=False)
with context.session.begin():
db_pool = self._get_db_pool(context.session, id,
show_deleted=False)
self._auth_validate_action(context, db_pool.project_id,
constants.RBAC_GET_ONE)
@ -73,10 +74,11 @@ class PoolsController(base.BaseController):
query_filter = self._auth_get_all(context, project_id)
db_pools, links = self.repositories.pool.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
with context.session.begin():
db_pools, links = self.repositories.pool.get_all_API_list(
context.session, show_deleted=False,
pagination_helper=pcontext.get(constants.PAGINATION_HELPER),
**query_filter)
result = self._convert_db_to_type(db_pools, [pool_types.PoolResponse])
if fields is not None:
result = self._filter_fields(result, fields)
@ -141,9 +143,11 @@ class PoolsController(base.BaseController):
validate.check_alpn_protocols(pool_dict['alpn_protocols'])
try:
return self.repositories.create_pool_on_load_balancer(
ret = self.repositories.create_pool_on_load_balancer(
lock_session, pool_dict,
listener_id=listener_id)
lock_session.flush()
return ret
except odb_exceptions.DBDuplicateEntry as e:
raise exceptions.IDAlreadyExists() from e
except odb_exceptions.DBReferenceError as e:
@ -211,19 +215,20 @@ class PoolsController(base.BaseController):
pool = pool_.pool
context = pecan_request.context.get('octavia_context')
listener = None
if pool.loadbalancer_id:
pool.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.loadbalancer_id)
elif pool.listener_id:
listener = self.repositories.listener.get(
context.session, id=pool.listener_id)
pool.loadbalancer_id = listener.load_balancer_id
pool.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.loadbalancer_id)
else:
msg = _("Must provide at least one of: "
"loadbalancer_id, listener_id")
raise exceptions.ValidationException(detail=msg)
with context.session.begin():
if pool.loadbalancer_id:
pool.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.loadbalancer_id)
elif pool.listener_id:
listener = self.repositories.listener.get(
context.session, id=pool.listener_id)
pool.loadbalancer_id = listener.load_balancer_id
pool.project_id, provider = self._get_lb_project_id_provider(
context.session, pool.loadbalancer_id)
else:
msg = _("Must provide at least one of: "
"loadbalancer_id, listener_id")
raise exceptions.ValidationException(detail=msg)
self._auth_validate_action(context, pool.project_id,
constants.RBAC_POST)
@ -252,11 +257,11 @@ class PoolsController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
lock_session = db_api.get_session(autocommit=False)
context.session.begin()
try:
if self.repositories.check_quota_met(
context.session,
lock_session,
context.session,
data_models.Pool,
pool.project_id):
raise exceptions.QuotaException(
@ -268,16 +273,16 @@ class PoolsController(base.BaseController):
listener_id = pool_dict.pop('listener_id', None)
if listener_id:
if listener_repo.has_default_pool(lock_session,
if listener_repo.has_default_pool(context.session,
listener_id):
raise exceptions.DuplicatePoolEntry()
self._test_lb_and_listener_statuses(
lock_session, lb_id=pool_dict['load_balancer_id'],
context.session, lb_id=pool_dict['load_balancer_id'],
listener_ids=[listener_id] if listener_id else [])
db_pool = self._validate_create_pool(
lock_session, pool_dict, listener_id)
context.session, pool_dict, listener_id)
# Prepare the data for the driver data model
provider_pool = (
@ -289,12 +294,13 @@ class PoolsController(base.BaseController):
driver_utils.call_provider(
driver.name, driver.pool_create, provider_pool)
lock_session.commit()
context.session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
context.session.rollback()
db_pool = self._get_db_pool(context.session, db_pool.id)
with context.session.begin():
db_pool = self._get_db_pool(context.session, db_pool.id)
result = self._convert_db_to_type(db_pool, pool_types.PoolResponse)
return pool_types.PoolRootResponse(pool=result)
@ -435,10 +441,12 @@ class PoolsController(base.BaseController):
"""Updates a pool on a load balancer."""
pool = pool_.pool
context = pecan_request.context.get('octavia_context')
db_pool = self._get_db_pool(context.session, id, show_deleted=False)
with context.session.begin():
db_pool = self._get_db_pool(context.session, id,
show_deleted=False)
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
@ -458,9 +466,9 @@ class PoolsController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
self._test_lb_and_listener_statuses(
lock_session, lb_id=db_pool.load_balancer_id,
context.session, lb_id=db_pool.load_balancer_id,
listener_ids=self._get_affected_listener_ids(db_pool))
# Prepare the data for the driver data model
@ -483,13 +491,14 @@ class PoolsController(base.BaseController):
# Update the database to reflect what the driver just accepted
pool.provisioning_status = constants.PENDING_UPDATE
db_pool_dict = pool.to_dict(render_unsets=False)
self.repositories.update_pool_and_sp(lock_session, id,
self.repositories.update_pool_and_sp(context.session, id,
db_pool_dict)
# Force SQL alchemy to query the DB, otherwise we get inconsistent
# results
context.session.expire_all()
db_pool = self._get_db_pool(context.session, id)
with context.session.begin():
db_pool = self._get_db_pool(context.session, id)
result = self._convert_db_to_type(db_pool, pool_types.PoolResponse)
return pool_types.PoolRootResponse(pool=result)
@ -497,10 +506,12 @@ class PoolsController(base.BaseController):
def delete(self, id):
"""Deletes a pool from a load balancer."""
context = pecan_request.context.get('octavia_context')
db_pool = self._get_db_pool(context.session, id, show_deleted=False)
with context.session.begin():
db_pool = self._get_db_pool(context.session, id,
show_deleted=False)
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
project_id, provider = self._get_lb_project_id_provider(
context.session, db_pool.load_balancer_id)
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
@ -511,12 +522,12 @@ class PoolsController(base.BaseController):
# Load the driver early as it also provides validation
driver = driver_factory.get_driver(provider)
with db_api.get_lock_session() as lock_session:
with context.session.begin():
self._test_lb_and_listener_statuses(
lock_session, lb_id=db_pool.load_balancer_id,
context.session, lb_id=db_pool.load_balancer_id,
listener_ids=self._get_affected_listener_ids(db_pool))
self.repositories.pool.update(
lock_session, db_pool.id,
context.session, db_pool.id,
provisioning_status=constants.PENDING_DELETE)
LOG.info("Sending delete Pool %s to provider %s", id, driver.name)
@ -536,7 +547,9 @@ class PoolsController(base.BaseController):
context = pecan_request.context.get('octavia_context')
if pool_id and remainder and remainder[0] == 'members':
remainder = remainder[1:]
db_pool = self.repositories.pool.get(context.session, id=pool_id)
with context.session.begin():
db_pool = self.repositories.pool.get(context.session,
id=pool_id)
if not db_pool:
LOG.info("Pool %s not found.", pool_id)
raise exceptions.NotFound(resource=data_models.Pool._name(),

View File

@ -72,8 +72,9 @@ class QuotasController(base.BaseController):
self._auth_validate_action(context, project_id, constants.RBAC_PUT)
quotas_dict = quotas.to_dict()
self.repositories.quotas.update(context.session, project_id,
**quotas_dict)
with context.session.begin():
self.repositories.quotas.update(context.session, project_id,
**quotas_dict)
db_quotas = self._get_db_quotas(context.session, project_id)
return self._convert_db_to_type(db_quotas, quota_types.QuotaResponse)
@ -87,7 +88,8 @@ class QuotasController(base.BaseController):
self._auth_validate_action(context, project_id, constants.RBAC_DELETE)
self.repositories.quotas.delete(context.session, project_id)
with context.session.begin():
self.repositories.quotas.delete(context.session, project_id)
db_quotas = self._get_db_quotas(context.session, project_id)
return self._convert_db_to_type(db_quotas, quota_types.QuotaResponse)

View File

@ -87,7 +87,8 @@ class HealthManager(object):
amp_health = None
lock_session = None
try:
lock_session = db_api.get_session(autocommit=False)
lock_session = db_api.get_session()
lock_session.begin()
amp_health = self.amp_health_repo.get_stale_amphora(
lock_session)
if amp_health:

View File

@ -39,24 +39,26 @@ class DatabaseCleanup(object):
seconds=CONF.house_keeping.amphora_expiry_age)
session = db_api.get_session()
amp_ids = self.amp_repo.get_all_deleted_expiring(session,
exp_age=exp_age)
with session.begin():
amp_ids = self.amp_repo.get_all_deleted_expiring(session,
exp_age=exp_age)
for amp_id in amp_ids:
# If we're here, we already think the amp is expiring according to
# the amphora table. Now check it is expired in the health table.
# In this way, we ensure that amps aren't deleted unless they are
# both expired AND no longer receiving zombie heartbeats.
if self.amp_health_repo.check_amphora_health_expired(
session, amp_id, exp_age):
LOG.debug('Attempting to purge db record for Amphora ID: %s',
amp_id)
self.amp_repo.delete(session, id=amp_id)
try:
self.amp_health_repo.delete(session, amphora_id=amp_id)
except sqlalchemy_exceptions.NoResultFound:
pass # Best effort delete, this record might not exist
LOG.info('Purged db record for Amphora ID: %s', amp_id)
for amp_id in amp_ids:
# If we're here, we already think the amp is expiring according
# to the amphora table. Now check it is expired in the health
# table.
# In this way, we ensure that amps aren't deleted unless they
# are both expired AND no longer receiving zombie heartbeats.
if self.amp_health_repo.check_amphora_health_expired(
session, amp_id, exp_age):
LOG.debug('Attempting to purge db record for Amphora ID: '
'%s', amp_id)
self.amp_repo.delete(session, id=amp_id)
try:
self.amp_health_repo.delete(session, amphora_id=amp_id)
except sqlalchemy_exceptions.NoResultFound:
pass # Best effort delete, this record might not exist
LOG.info('Purged db record for Amphora ID: %s', amp_id)
def cleanup_load_balancers(self):
"""Checks the DB for old load balancers and triggers their removal."""
@ -64,13 +66,14 @@ class DatabaseCleanup(object):
seconds=CONF.house_keeping.load_balancer_expiry_age)
session = db_api.get_session()
lb_ids = self.lb_repo.get_all_deleted_expiring(session,
exp_age=exp_age)
with session.begin():
lb_ids = self.lb_repo.get_all_deleted_expiring(session,
exp_age=exp_age)
for lb_id in lb_ids:
LOG.info('Attempting to delete load balancer id : %s', lb_id)
self.lb_repo.delete(session, id=lb_id)
LOG.info('Deleted load balancer id : %s', lb_id)
for lb_id in lb_ids:
LOG.info('Attempting to delete load balancer id : %s', lb_id)
self.lb_repo.delete(session, id=lb_id)
LOG.info('Deleted load balancer id : %s', lb_id)
class CertRotation(object):
@ -83,14 +86,15 @@ class CertRotation(object):
amp_repo = repo.AmphoraRepository()
with futures.ThreadPoolExecutor(max_workers=self.threads) as executor:
session = db_api.get_session()
rotation_count = 0
while True:
amp = amp_repo.get_cert_expiring_amphora(session)
if not amp:
break
rotation_count += 1
LOG.debug("Cert expired amphora's id is: %s", amp.id)
executor.submit(self.cw.amphora_cert_rotation, amp.id)
session = db_api.get_session()
with session.begin():
amp = amp_repo.get_cert_expiring_amphora(session)
if not amp:
break
rotation_count += 1
LOG.debug("Cert expired amphora's id is: %s", amp.id)
executor.submit(self.cw.amphora_cert_rotation, amp.id)
if rotation_count > 0:
LOG.info("Rotated certificates for %s amphora", rotation_count)

View File

@ -35,10 +35,11 @@ class AmphoraBuildRateLimit(object):
self.amp_build_req_repo = repo.AmphoraBuildReqRepository()
def add_to_build_request_queue(self, amphora_id, build_priority):
self.amp_build_req_repo.add_to_build_queue(
db_apis.get_session(),
amphora_id=amphora_id,
priority=build_priority)
with db_apis.session().begin() as session:
self.amp_build_req_repo.add_to_build_queue(
session,
amphora_id=amphora_id,
priority=build_priority)
LOG.debug("Added build request for amphora %s to the queue",
amphora_id)
self.wait_for_build_slot(amphora_id)

View File

@ -48,9 +48,10 @@ class TaskUtils(object):
LOG.debug('Unmarking health monitoring busy on amphora: %s',
amphora_id)
try:
self.amp_health_repo.update(db_apis.get_session(),
amphora_id=amphora_id,
busy=False)
with db_apis.session().begin() as session:
self.amp_health_repo.update(session,
amphora_id=amphora_id,
busy=False)
except Exception as e:
LOG.debug('Failed to update amphora health record %(amp)s '
'due to: %(except)s',
@ -64,9 +65,10 @@ class TaskUtils(object):
:param amphora_id: Amphora ID to set the status to ERROR
"""
try:
self.amphora_repo.update(db_apis.get_session(),
id=amphora_id,
status=constants.ERROR)
with db_apis.session().begin() as session:
self.amphora_repo.update(session,
id=amphora_id,
status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update amphora %(amp)s "
"status to ERROR due to: "
@ -80,9 +82,10 @@ class TaskUtils(object):
:param health_mon_id: Health Monitor ID to set prov status to ERROR
"""
try:
self.health_mon_repo.update(db_apis.get_session(),
id=health_mon_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.health_mon_repo.update(
session, id=health_mon_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update health monitor %(health)s "
"provisioning status to ERROR due to: "
@ -97,9 +100,10 @@ class TaskUtils(object):
:param l7policy_id: L7 Policy ID to set provisioning status to ACTIVE
"""
try:
self.l7policy_repo.update(db_apis.get_session(),
id=l7policy_id,
provisioning_status=constants.ACTIVE)
with db_apis.session().begin() as session:
self.l7policy_repo.update(session,
id=l7policy_id,
provisioning_status=constants.ACTIVE)
except Exception as e:
LOG.error("Failed to update l7policy %(l7p)s "
"provisioning status to ACTIVE due to: "
@ -113,9 +117,10 @@ class TaskUtils(object):
:param l7policy_id: L7 Policy ID to set provisioning status to ERROR
"""
try:
self.l7policy_repo.update(db_apis.get_session(),
id=l7policy_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.l7policy_repo.update(session,
id=l7policy_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update l7policy %(l7p)s "
"provisioning status to ERROR due to: "
@ -129,9 +134,10 @@ class TaskUtils(object):
:param l7rule_id: L7 Rule ID to set provisioning status to ERROR
"""
try:
self.l7rule_repo.update(db_apis.get_session(),
id=l7rule_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.l7rule_repo.update(session,
id=l7rule_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update l7rule %(l7r)s "
"provisioning status to ERROR due to: "
@ -145,9 +151,10 @@ class TaskUtils(object):
:param listener_id: Listener ID to set provisioning status to ERROR
"""
try:
self.listener_repo.update(db_apis.get_session(),
id=listener_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.listener_repo.update(session,
id=listener_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update listener %(list)s "
"provisioning status to ERROR due to: "
@ -162,9 +169,11 @@ class TaskUtils(object):
status to ERROR
"""
try:
self.loadbalancer_repo.update(db_apis.get_session(),
id=loadbalancer_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.loadbalancer_repo.update(
session,
id=loadbalancer_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update load balancer %(lb)s "
"provisioning status to ERROR due to: "
@ -179,9 +188,10 @@ class TaskUtils(object):
status to ACTIVE
"""
try:
self.listener_repo.update(db_apis.get_session(),
id=listener_id,
provisioning_status=constants.ACTIVE)
with db_apis.session().begin() as session:
self.listener_repo.update(session,
id=listener_id,
provisioning_status=constants.ACTIVE)
except Exception as e:
LOG.error("Failed to update listener %(list)s "
"provisioning status to ACTIVE due to: "
@ -195,9 +205,10 @@ class TaskUtils(object):
:param pool_id: Pool ID to set provisioning status to ACTIVE
"""
try:
self.pool_repo.update(db_apis.get_session(),
id=pool_id,
provisioning_status=constants.ACTIVE)
with db_apis.session().begin() as session:
self.pool_repo.update(session,
id=pool_id,
provisioning_status=constants.ACTIVE)
except Exception as e:
LOG.error("Failed to update pool %(pool)s provisioning status "
"to ACTIVE due to: %(except)s", {'pool': pool_id,
@ -212,9 +223,11 @@ class TaskUtils(object):
status to ACTIVE
"""
try:
self.loadbalancer_repo.update(db_apis.get_session(),
id=loadbalancer_id,
provisioning_status=constants.ACTIVE)
with db_apis.session().begin() as session:
self.loadbalancer_repo.update(
session,
id=loadbalancer_id,
provisioning_status=constants.ACTIVE)
except Exception as e:
LOG.error("Failed to update load balancer %(lb)s "
"provisioning status to ACTIVE due to: "
@ -228,9 +241,10 @@ class TaskUtils(object):
:param member_id: Member ID to set provisioning status to ERROR
"""
try:
self.member_repo.update(db_apis.get_session(),
id=member_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.member_repo.update(session,
id=member_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update member %(member)s "
"provisioning status to ERROR due to: "
@ -244,9 +258,10 @@ class TaskUtils(object):
:param pool_id: Pool ID to set provisioning status to ERROR
"""
try:
self.pool_repo.update(db_apis.get_session(),
id=pool_id,
provisioning_status=constants.ERROR)
with db_apis.session().begin() as session:
self.pool_repo.update(session,
id=pool_id,
provisioning_status=constants.ERROR)
except Exception as e:
LOG.error("Failed to update pool %(pool)s "
"provisioning status to ERROR due to: "
@ -258,8 +273,9 @@ class TaskUtils(object):
:param: loadbalancer_id: Load balancer ID which to get from db
"""
try:
return self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
with db_apis.session().begin() as session:
return self.loadbalancer_repo.get(session,
id=loadbalancer_id)
except Exception as e:
LOG.error("Failed to get loadbalancer %(loadbalancer)s "
"due to: %(except)s",

View File

@ -94,7 +94,9 @@ class ControllerWorker(object):
CONF.haproxy_amphora.api_db_commit_retry_attempts))
def _get_db_obj_until_pending_update(self, repo, id):
return repo.get(db_apis.get_session(), id=id)
session = db_apis.get_session()
with session.begin():
return repo.get(session, id=id)
@property
def services_controller(self):
@ -118,8 +120,10 @@ class ControllerWorker(object):
:raises AmphoraNotFound: The referenced Amphora was not found
"""
try:
amphora = self._amphora_repo.get(db_apis.get_session(),
id=amphora_id)
session = db_apis.get_session()
with session.begin():
amphora = self._amphora_repo.get(session,
id=amphora_id)
store = {constants.AMPHORA: amphora.to_dict()}
self.run_flow(
flow_utils.get_delete_amphora_flow,
@ -145,9 +149,11 @@ class ControllerWorker(object):
:returns: None
:raises NoResultFound: Unable to find the object
"""
db_health_monitor = self._health_mon_repo.get(
db_apis.get_session(),
id=health_monitor[constants.HEALTHMONITOR_ID])
session = db_apis.get_session()
with session.begin():
db_health_monitor = self._health_mon_repo.get(
session,
id=health_monitor[constants.HEALTHMONITOR_ID])
if not db_health_monitor:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
@ -178,9 +184,11 @@ class ControllerWorker(object):
:returns: None
:raises HMNotFound: The referenced health monitor was not found
"""
db_health_monitor = self._health_mon_repo.get(
db_apis.get_session(),
id=health_monitor[constants.HEALTHMONITOR_ID])
session = db_apis.get_session()
with session.begin():
db_health_monitor = self._health_mon_repo.get(
session,
id=health_monitor[constants.HEALTHMONITOR_ID])
pool = db_health_monitor.pool
load_balancer = pool.load_balancer
@ -251,8 +259,10 @@ class ControllerWorker(object):
:returns: None
:raises NoResultFound: Unable to find the object
"""
db_listener = self._listener_repo.get(
db_apis.get_session(), id=listener[constants.LISTENER_ID])
session = db_apis.get_session()
with session.begin():
db_listener = self._listener_repo.get(
session, id=listener[constants.LISTENER_ID])
if not db_listener:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'listener',
@ -333,8 +343,10 @@ class ControllerWorker(object):
:returns: None
:raises NoResultFound: Unable to find the object
"""
lb = self._lb_repo.get(db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
lb = self._lb_repo.get(session,
id=loadbalancer[constants.LOADBALANCER_ID])
if not lb:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'load_balancer',
@ -374,7 +386,9 @@ class ControllerWorker(object):
:raises LBNotFound: The referenced load balancer was not found
"""
loadbalancer_id = load_balancer[constants.LOADBALANCER_ID]
db_lb = self._lb_repo.get(db_apis.get_session(), id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
db_lb = self._lb_repo.get(session, id=loadbalancer_id)
store = {constants.LOADBALANCER: load_balancer,
constants.LOADBALANCER_ID: loadbalancer_id,
constants.SERVER_GROUP_ID: db_lb.server_group_id,
@ -436,8 +450,10 @@ class ControllerWorker(object):
:returns: None
:raises NoSuitablePool: Unable to find the node pool
"""
db_member = self._member_repo.get(db_apis.get_session(),
id=member[constants.MEMBER_ID])
session = db_apis.get_session()
with session.begin():
db_member = self._member_repo.get(session,
id=member[constants.MEMBER_ID])
if not db_member:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'l7member',
@ -457,9 +473,10 @@ class ControllerWorker(object):
constants.LOADBALANCER: provider_lb,
constants.POOL_ID: pool.id}
if load_balancer.availability_zone:
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(), load_balancer.availability_zone))
with session.begin():
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
session, load_balancer.availability_zone))
else:
store[constants.AVAILABILITY_ZONE] = {}
@ -474,8 +491,10 @@ class ControllerWorker(object):
:returns: None
:raises MemberNotFound: The referenced member was not found
"""
pool = self._pool_repo.get(db_apis.get_session(),
id=member[constants.POOL_ID])
session = db_apis.get_session()
with session.begin():
pool = self._pool_repo.get(session,
id=member[constants.POOL_ID])
load_balancer = pool.load_balancer
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
@ -490,9 +509,10 @@ class ControllerWorker(object):
constants.POOL_ID: pool.id,
constants.PROJECT_ID: load_balancer.project_id}
if load_balancer.availability_zone:
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(), load_balancer.availability_zone))
with session.begin():
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
session, load_balancer.availability_zone))
else:
store[constants.AVAILABILITY_ZONE] = {}
@ -510,9 +530,12 @@ class ControllerWorker(object):
CONF.haproxy_amphora.api_db_commit_retry_attempts))
def batch_update_members(self, old_members, new_members,
updated_members):
db_new_members = [self._member_repo.get(db_apis.get_session(),
id=member[constants.MEMBER_ID])
for member in new_members]
session = db_apis.get_session()
with session.begin():
db_new_members = [
self._member_repo.get(
session, id=member[constants.MEMBER_ID])
for member in new_members]
        # The API may not have committed all of the new member records yet.
# Make sure we retry looking them up.
if None in db_new_members or len(db_new_members) != len(new_members):
@ -520,27 +543,28 @@ class ControllerWorker(object):
'Retrying for up to 60 seconds.')
raise db_exceptions.NoResultFound
updated_members = [
(provider_utils.db_member_to_provider_member(
self._member_repo.get(db_apis.get_session(),
id=m.get(constants.ID))).to_dict(),
m)
for m in updated_members]
provider_old_members = [
provider_utils.db_member_to_provider_member(
self._member_repo.get(db_apis.get_session(),
id=m.get(constants.ID))).to_dict()
for m in old_members]
if old_members:
pool = self._pool_repo.get(db_apis.get_session(),
id=old_members[0][constants.POOL_ID])
elif new_members:
pool = self._pool_repo.get(db_apis.get_session(),
id=new_members[0][constants.POOL_ID])
else:
pool = self._pool_repo.get(
db_apis.get_session(),
id=updated_members[0][0][constants.POOL_ID])
with session.begin():
updated_members = [
(provider_utils.db_member_to_provider_member(
self._member_repo.get(session,
id=m.get(constants.ID))).to_dict(),
m)
for m in updated_members]
provider_old_members = [
provider_utils.db_member_to_provider_member(
self._member_repo.get(session,
id=m.get(constants.ID))).to_dict()
for m in old_members]
if old_members:
pool = self._pool_repo.get(
session, id=old_members[0][constants.POOL_ID])
elif new_members:
pool = self._pool_repo.get(
session, id=new_members[0][constants.POOL_ID])
else:
pool = self._pool_repo.get(
session,
id=updated_members[0][0][constants.POOL_ID])
load_balancer = pool.load_balancer
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
@ -554,9 +578,10 @@ class ControllerWorker(object):
constants.POOL_ID: pool.id,
constants.PROJECT_ID: load_balancer.project_id}
if load_balancer.availability_zone:
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(), load_balancer.availability_zone))
with session.begin():
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
session, load_balancer.availability_zone))
else:
store[constants.AVAILABILITY_ZONE] = {}
@ -598,9 +623,11 @@ class ControllerWorker(object):
constants.POOL_ID: pool.id,
constants.UPDATE_DICT: member_updates}
if load_balancer.availability_zone:
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(), load_balancer.availability_zone))
session = db_apis.get_session()
with session.begin():
store[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
session, load_balancer.availability_zone))
else:
store[constants.AVAILABILITY_ZONE] = {}
@ -626,8 +653,10 @@ class ControllerWorker(object):
# TODO(ataraday) It seems we need to get db pool here anyway to get
# proper listeners
db_pool = self._pool_repo.get(db_apis.get_session(),
id=pool[constants.POOL_ID])
session = db_apis.get_session()
with session.begin():
db_pool = self._pool_repo.get(session,
id=pool[constants.POOL_ID])
if not db_pool:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'pool', pool[constants.POOL_ID])
@ -653,8 +682,10 @@ class ControllerWorker(object):
:returns: None
:raises PoolNotFound: The referenced pool was not found
"""
db_pool = self._pool_repo.get(db_apis.get_session(),
id=pool[constants.POOL_ID])
session = db_apis.get_session()
with session.begin():
db_pool = self._pool_repo.get(session,
id=pool[constants.POOL_ID])
load_balancer = db_pool.load_balancer
provider_lb = provider_utils.db_loadbalancer_to_provider_loadbalancer(
@ -718,8 +749,10 @@ class ControllerWorker(object):
:returns: None
:raises NoResultFound: Unable to find the object
"""
db_l7policy = self._l7policy_repo.get(
db_apis.get_session(), id=l7policy[constants.L7POLICY_ID])
session = db_apis.get_session()
with session.begin():
db_l7policy = self._l7policy_repo.get(
session, id=l7policy[constants.L7POLICY_ID])
if not db_l7policy:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'l7policy',
@ -747,8 +780,10 @@ class ControllerWorker(object):
:returns: None
:raises L7PolicyNotFound: The referenced l7policy was not found
"""
db_listener = self._listener_repo.get(
db_apis.get_session(), id=l7policy[constants.LISTENER_ID])
session = db_apis.get_session()
with session.begin():
db_listener = self._listener_repo.get(
session, id=l7policy[constants.LISTENER_ID])
listeners_dicts = (
provider_utils.db_listeners_to_provider_dicts_list_of_dicts(
[db_listener]))
@ -809,8 +844,10 @@ class ControllerWorker(object):
:returns: None
:raises NoResultFound: Unable to find the object
"""
db_l7rule = self._l7rule_repo.get(db_apis.get_session(),
id=l7rule[constants.L7RULE_ID])
session = db_apis.get_session()
with session.begin():
db_l7rule = self._l7rule_repo.get(session,
id=l7rule[constants.L7RULE_ID])
if not db_l7rule:
LOG.warning('Failed to fetch %s %s from DB. Retrying for up to '
'60 seconds.', 'l7rule',
@ -844,8 +881,10 @@ class ControllerWorker(object):
:returns: None
:raises L7RuleNotFound: The referenced l7rule was not found
"""
db_l7policy = self._l7policy_repo.get(db_apis.get_session(),
id=l7rule[constants.L7POLICY_ID])
session = db_apis.get_session()
with session.begin():
db_l7policy = self._l7policy_repo.get(
session, id=l7rule[constants.L7POLICY_ID])
l7policy = provider_utils.db_l7policy_to_provider_l7policy(db_l7policy)
load_balancer = db_l7policy.listener.load_balancer
@ -914,8 +953,10 @@ class ControllerWorker(object):
"""
amphora = None
try:
amphora = self._amphora_repo.get(db_apis.get_session(),
id=amphora_id)
session = db_apis.get_session()
with session.begin():
amphora = self._amphora_repo.get(session,
id=amphora_id)
if amphora is None:
LOG.error('Amphora failover for amphora %s failed because '
'there is no record of this amphora in the '
@ -930,14 +971,16 @@ class ControllerWorker(object):
'was submitted for failover. Deleting it from the '
'amphora health table to exclude it from health '
'checks and skipping the failover.', amphora.id)
self._amphora_health_repo.delete(db_apis.get_session(),
amphora_id=amphora.id)
with session.begin():
self._amphora_health_repo.delete(session,
amphora_id=amphora.id)
return
loadbalancer = None
if amphora.load_balancer_id:
loadbalancer = self._lb_repo.get(db_apis.get_session(),
id=amphora.load_balancer_id)
with session.begin():
loadbalancer = self._lb_repo.get(
session, id=amphora.load_balancer_id)
lb_amp_count = None
if loadbalancer:
if loadbalancer.topology == constants.TOPOLOGY_ACTIVE_STANDBY:
@ -956,18 +999,21 @@ class ControllerWorker(object):
# Even if the LB doesn't have a flavor, create one and
# pass through the topology.
if loadbalancer.flavor_id:
flavor_dict = self._flavor_repo.get_flavor_metadata_dict(
db_apis.get_session(), loadbalancer.flavor_id)
with session.begin():
flavor_dict = (
self._flavor_repo.get_flavor_metadata_dict(
session, loadbalancer.flavor_id))
flavor_dict[constants.LOADBALANCER_TOPOLOGY] = (
loadbalancer.topology)
else:
flavor_dict = {constants.LOADBALANCER_TOPOLOGY:
loadbalancer.topology}
if loadbalancer.availability_zone:
az_metadata = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(),
loadbalancer.availability_zone))
with session.begin():
az_metadata = (
self._az_repo.get_availability_zone_metadata_dict(
session,
loadbalancer.availability_zone))
vip_dict = loadbalancer.vip.to_dict()
additional_vip_dicts = [
av.to_dict()
@ -1003,12 +1049,14 @@ class ControllerWorker(object):
with excutils.save_and_reraise_exception(reraise=reraise):
LOG.exception("Amphora %s failover exception: %s",
amphora_id, str(e))
self._amphora_repo.update(db_apis.get_session(),
amphora_id, status=constants.ERROR)
if amphora and amphora.load_balancer_id:
self._lb_repo.update(
db_apis.get_session(), amphora.load_balancer_id,
provisioning_status=constants.ERROR)
with session.begin():
self._amphora_repo.update(session,
amphora_id,
status=constants.ERROR)
if amphora and amphora.load_balancer_id:
self._lb_repo.update(
session, amphora.load_balancer_id,
provisioning_status=constants.ERROR)
@staticmethod
def _get_amphorae_for_failover(load_balancer):
@ -1084,8 +1132,10 @@ class ControllerWorker(object):
found.
"""
try:
lb = self._lb_repo.get(db_apis.get_session(),
id=load_balancer_id)
session = db_apis.get_session()
with session.begin():
lb = self._lb_repo.get(session,
id=load_balancer_id)
if lb is None:
raise exceptions.NotFound(resource=constants.LOADBALANCER,
id=load_balancer_id)
@ -1113,8 +1163,9 @@ class ControllerWorker(object):
# here for the amphora to be created with the correct
# configuration.
if lb.flavor_id:
flavor = self._flavor_repo.get_flavor_metadata_dict(
db_apis.get_session(), lb.flavor_id)
with session.begin():
flavor = self._flavor_repo.get_flavor_metadata_dict(
session, lb.flavor_id)
flavor[constants.LOADBALANCER_TOPOLOGY] = lb.topology
else:
flavor = {constants.LOADBALANCER_TOPOLOGY: lb.topology}
@ -1136,9 +1187,10 @@ class ControllerWorker(object):
constants.FLAVOR: flavor}
if lb.availability_zone:
stored_params[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
db_apis.get_session(), lb.availability_zone))
with session.begin():
stored_params[constants.AVAILABILITY_ZONE] = (
self._az_repo.get_availability_zone_metadata_dict(
session, lb.availability_zone))
else:
stored_params[constants.AVAILABILITY_ZONE] = {}
@ -1153,9 +1205,10 @@ class ControllerWorker(object):
with excutils.save_and_reraise_exception(reraise=False):
LOG.exception("LB %(lbid)s failover exception: %(exc)s",
{'lbid': load_balancer_id, 'exc': str(e)})
self._lb_repo.update(
db_apis.get_session(), load_balancer_id,
provisioning_status=constants.ERROR)
with session.begin():
self._lb_repo.update(
session, load_balancer_id,
provisioning_status=constants.ERROR)
def amphora_cert_rotation(self, amphora_id):
"""Perform cert rotation for an amphora.
@ -1165,8 +1218,10 @@ class ControllerWorker(object):
:raises AmphoraNotFound: The referenced amphora was not found
"""
amp = self._amphora_repo.get(db_apis.get_session(),
id=amphora_id)
session = db_apis.get_session()
with session.begin():
amp = self._amphora_repo.get(session,
id=amphora_id)
LOG.info("Start amphora cert rotation, amphora's id is: %s",
amphora_id)
@ -1191,13 +1246,15 @@ class ControllerWorker(object):
"""
LOG.info("Start amphora agent configuration update, amphora's id "
"is: %s", amphora_id)
amp = self._amphora_repo.get(db_apis.get_session(), id=amphora_id)
lb = self._amphora_repo.get_lb_for_amphora(db_apis.get_session(),
amphora_id)
flavor = {}
if lb.flavor_id:
flavor = self._flavor_repo.get_flavor_metadata_dict(
db_apis.get_session(), lb.flavor_id)
session = db_apis.get_session()
with session.begin():
amp = self._amphora_repo.get(session, id=amphora_id)
lb = self._amphora_repo.get_lb_for_amphora(session,
amphora_id)
flavor = {}
if lb.flavor_id:
flavor = self._flavor_repo.get_flavor_metadata_dict(
session, lb.flavor_id)
store = {constants.AMPHORA: amp.to_dict(),
constants.FLAVOR: flavor}

View File

@ -80,20 +80,23 @@ class AmpListenersUpdate(BaseAmphoraTask):
# health manager fix it.
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
try:
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
self.amphora_driver.update_amphora_listeners(
db_lb, db_amp, timeout_dict)
except Exception as e:
LOG.error('Failed to update listeners on amphora %s. Skipping '
'this amphora as it is failing to update due to: %s',
db_amp.id, str(e))
self.amphora_repo.update(db_apis.get_session(), db_amp.id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, db_amp.id,
status=constants.ERROR)
class AmphoraIndexListenerUpdate(BaseAmphoraTask):
@ -106,12 +109,14 @@ class AmphoraIndexListenerUpdate(BaseAmphoraTask):
try:
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(
db_apis.get_session(),
id=amphorae[amphora_index][constants.ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(
session,
id=amphorae[amphora_index][constants.ID])
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
self.amphora_driver.update_amphora_listeners(
db_lb, db_amp, timeout_dict)
except Exception as e:
@ -119,8 +124,10 @@ class AmphoraIndexListenerUpdate(BaseAmphoraTask):
LOG.error('Failed to update listeners on amphora %s. Skipping '
'this amphora as it is failing to update due to: %s',
amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
session = db_apis.get_session()
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
class ListenersUpdate(BaseAmphoraTask):
@ -128,8 +135,10 @@ class ListenersUpdate(BaseAmphoraTask):
def execute(self, loadbalancer_id):
"""Execute updates per listener for an amphora."""
loadbalancer = self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
loadbalancer = self.loadbalancer_repo.get(session,
id=loadbalancer_id)
if loadbalancer:
self.amphora_driver.update(loadbalancer)
else:
@ -140,8 +149,10 @@ class ListenersUpdate(BaseAmphoraTask):
"""Handle failed listeners updates."""
LOG.warning("Reverting listeners updates.")
loadbalancer = self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
loadbalancer = self.loadbalancer_repo.get(session,
id=loadbalancer_id)
for listener in loadbalancer.listeners:
self.task_utils.mark_listener_prov_status_error(
listener.id)
@ -152,12 +163,15 @@ class ListenersStart(BaseAmphoraTask):
def execute(self, loadbalancer, amphora=None):
"""Execute listener start routines for listeners on an amphora."""
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
if db_lb.listeners:
if amphora is not None:
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
else:
db_amp = amphora
self.amphora_driver.start(db_lb, db_amp)
@ -167,8 +181,10 @@ class ListenersStart(BaseAmphoraTask):
"""Handle failed listeners starts."""
LOG.warning("Reverting listeners starts.")
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
for listener in db_lb.listeners:
self.task_utils.mark_listener_prov_status_error(listener.id)
@ -183,11 +199,13 @@ class AmphoraIndexListenersReload(BaseAmphoraTask):
return
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(
db_apis.get_session(), id=amphorae[amphora_index][constants.ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(
session, id=amphorae[amphora_index][constants.ID])
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
if db_lb.listeners:
try:
self.amphora_driver.reload(db_lb, db_amp, timeout_dict)
@ -196,8 +214,9 @@ class AmphoraIndexListenersReload(BaseAmphoraTask):
LOG.warning('Failed to reload listeners on amphora %s. '
'Skipping this amphora as it is failing to '
'reload due to: %s', amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
class ListenerDelete(BaseAmphoraTask):
@ -205,8 +224,10 @@ class ListenerDelete(BaseAmphoraTask):
def execute(self, listener):
"""Execute listener delete routines for an amphora."""
db_listener = self.listener_repo.get(
db_apis.get_session(), id=listener[constants.LISTENER_ID])
session = db_apis.get_session()
with session.begin():
db_listener = self.listener_repo.get(
session, id=listener[constants.LISTENER_ID])
self.amphora_driver.delete(db_listener)
LOG.debug("Deleted the listener on the vip")
@ -224,8 +245,10 @@ class AmphoraGetInfo(BaseAmphoraTask):
def execute(self, amphora):
"""Execute get_info routine for an amphora."""
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
self.amphora_driver.get_info(db_amp)
@ -242,8 +265,10 @@ class AmphoraFinalize(BaseAmphoraTask):
def execute(self, amphora):
"""Execute finalize_amphora routine."""
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
self.amphora_driver.finalize_amphora(db_amp)
LOG.debug("Finalized the amphora.")
@ -261,8 +286,10 @@ class AmphoraPostNetworkPlug(BaseAmphoraTask):
def execute(self, amphora, ports, amphora_network_config):
"""Execute post_network_plug routine."""
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
for port in ports:
net = data_models.Network(**port.pop(constants.NETWORK))
@ -302,8 +329,10 @@ class AmphoraePostNetworkPlug(BaseAmphoraTask):
def execute(self, loadbalancer, updated_ports, amphorae_network_config):
"""Execute post_network_plug routine."""
amp_post_plug = AmphoraPostNetworkPlug()
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
for amphora in db_lb.amphorae:
if amphora.id in updated_ports:
amp_post_plug.execute(amphora.to_dict(),
@ -314,8 +343,10 @@ class AmphoraePostNetworkPlug(BaseAmphoraTask):
"""Handle a failed post network plug."""
if isinstance(result, failure.Failure):
return
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
LOG.warning("Reverting post network plug.")
for amphora in filter(
lambda amp: amp.status == constants.AMPHORA_ALLOCATED,
@ -329,10 +360,12 @@ class AmphoraPostVIPPlug(BaseAmphoraTask):
def execute(self, amphora, loadbalancer, amphorae_network_config):
"""Execute post_vip_routine."""
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
vrrp_port = data_models.Port(
**amphorae_network_config[
amphora.get(constants.ID)][constants.VRRP_PORT])
@ -385,8 +418,10 @@ class AmphoraePostVIPPlug(BaseAmphoraTask):
def execute(self, loadbalancer, amphorae_network_config):
"""Execute post_vip_plug across the amphorae."""
amp_post_vip_plug = AmphoraPostVIPPlug()
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
for amphora in db_lb.amphorae:
amp_post_vip_plug.execute(amphora.to_dict(),
loadbalancer,
@ -401,8 +436,10 @@ class AmphoraCertUpload(BaseAmphoraTask):
LOG.debug("Upload cert in amphora REST driver")
key = utils.get_compatible_server_certs_key_passphrase()
fer = fernet.Fernet(key)
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
self.amphora_driver.upload_cert_amp(
db_amp, fer.decrypt(server_pem.encode('utf-8')))
@ -415,8 +452,10 @@ class AmphoraUpdateVRRPInterface(BaseAmphoraTask):
try:
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
interface = self.amphora_driver.get_interface_from_ip(
db_amp, db_amp.vrrp_ip, timeout_dict=timeout_dict)
except Exception as e:
@ -424,13 +463,15 @@ class AmphoraUpdateVRRPInterface(BaseAmphoraTask):
LOG.error('Failed to get amphora VRRP interface on amphora '
'%s. Skipping this amphora as it is failing due to: '
'%s', amphora.get(constants.ID), str(e))
self.amphora_repo.update(db_apis.get_session(),
amphora.get(constants.ID),
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session,
amphora.get(constants.ID),
status=constants.ERROR)
return None
self.amphora_repo.update(db_apis.get_session(), amphora[constants.ID],
vrrp_interface=interface)
with session.begin():
self.amphora_repo.update(session, amphora[constants.ID],
vrrp_interface=interface)
return interface
@ -442,8 +483,10 @@ class AmphoraIndexUpdateVRRPInterface(BaseAmphoraTask):
try:
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora_id)
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora_id)
interface = self.amphora_driver.get_interface_from_ip(
db_amp, db_amp.vrrp_ip, timeout_dict=timeout_dict)
except Exception as e:
@ -451,12 +494,14 @@ class AmphoraIndexUpdateVRRPInterface(BaseAmphoraTask):
LOG.error('Failed to get amphora VRRP interface on amphora '
'%s. Skipping this amphora as it is failing due to: '
'%s', amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
return None
self.amphora_repo.update(db_apis.get_session(), amphora_id,
vrrp_interface=interface)
with session.begin():
self.amphora_repo.update(session, amphora_id,
vrrp_interface=interface)
return interface
@ -473,10 +518,12 @@ class AmphoraVRRPUpdate(BaseAmphoraTask):
try:
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora_id)
loadbalancer = self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora_id)
loadbalancer = self.loadbalancer_repo.get(session,
id=loadbalancer_id)
db_amp.vrrp_interface = amp_vrrp_int
self.amphora_driver.update_vrrp_conf(
loadbalancer, amphorae_network_config, db_amp, timeout_dict)
@ -484,8 +531,9 @@ class AmphoraVRRPUpdate(BaseAmphoraTask):
LOG.error('Failed to update VRRP configuration amphora %s. '
'Skipping this amphora as it is failing to update due '
'to: %s', amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
LOG.debug("Uploaded VRRP configuration of amphora %s.", amphora_id)
@ -503,10 +551,12 @@ class AmphoraIndexVRRPUpdate(BaseAmphoraTask):
try:
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora_id)
loadbalancer = self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora_id)
loadbalancer = self.loadbalancer_repo.get(session,
id=loadbalancer_id)
db_amp.vrrp_interface = amp_vrrp_int
self.amphora_driver.update_vrrp_conf(
loadbalancer, amphorae_network_config, db_amp, timeout_dict)
@ -514,8 +564,9 @@ class AmphoraIndexVRRPUpdate(BaseAmphoraTask):
LOG.error('Failed to update VRRP configuration amphora %s. '
'Skipping this amphora as it is failing to update due '
'to: %s', amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
return
LOG.debug("Uploaded VRRP configuration of amphora %s.", amphora_id)
@ -529,8 +580,10 @@ class AmphoraVRRPStart(BaseAmphoraTask):
def execute(self, amphora, timeout_dict=None):
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
db_amp = self.amphora_repo.get(
db_apis.get_session(), id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(
session, id=amphora[constants.ID])
self.amphora_driver.start_vrrp_service(db_amp, timeout_dict)
LOG.debug("Started VRRP on amphora %s.", amphora[constants.ID])
@ -545,15 +598,18 @@ class AmphoraIndexVRRPStart(BaseAmphoraTask):
# TODO(johnsom) Optimize this to use the dicts and not need the
# DB lookups
amphora_id = amphorae[amphora_index][constants.ID]
db_amp = self.amphora_repo.get(db_apis.get_session(), id=amphora_id)
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session, id=amphora_id)
try:
self.amphora_driver.start_vrrp_service(db_amp, timeout_dict)
except Exception as e:
LOG.error('Failed to start VRRP on amphora %s. '
'Skipping this amphora as it is failing to start due '
'to: %s', amphora_id, str(e))
self.amphora_repo.update(db_apis.get_session(), amphora_id,
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session, amphora_id,
status=constants.ERROR)
return
LOG.debug("Started VRRP on amphora %s.",
amphorae[amphora_index][constants.ID])
@ -565,8 +621,10 @@ class AmphoraComputeConnectivityWait(BaseAmphoraTask):
def execute(self, amphora, raise_retry_exception=False):
"""Execute get_info routine for an amphora until it responds."""
try:
db_amphora = self.amphora_repo.get(
db_apis.get_session(), id=amphora.get(constants.ID))
session = db_apis.get_session()
with session.begin():
db_amphora = self.amphora_repo.get(
session, id=amphora.get(constants.ID))
amp_info = self.amphora_driver.get_info(
db_amphora, raise_retry_exception=raise_retry_exception)
LOG.debug('Successfuly connected to amphora %s: %s',
@ -576,9 +634,10 @@ class AmphoraComputeConnectivityWait(BaseAmphoraTask):
"This either means the compute driver failed to fully "
"boot the instance inside the timeout interval or the "
"instance is not reachable via the lb-mgmt-net.")
self.amphora_repo.update(db_apis.get_session(),
amphora.get(constants.ID),
status=constants.ERROR)
with session.begin():
self.amphora_repo.update(session,
amphora.get(constants.ID),
status=constants.ERROR)
raise
@ -597,8 +656,10 @@ class AmphoraConfigUpdate(BaseAmphoraTask):
agent_cfg_tmpl = agent_jinja_cfg.AgentJinjaTemplater()
agent_config = agent_cfg_tmpl.build_agent_config(
amphora.get(constants.ID), topology)
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
# Push the new configuration to the amphora
try:
self.amphora_driver.update_amphora_agent_config(db_amp,

View File

@ -210,8 +210,10 @@ class DeleteAmphoraeOnLoadBalancer(BaseComputeTask):
"""
def execute(self, loadbalancer):
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
for amp in db_lb.amphorae:
# The compute driver will already handle NotFound
try:

File diff suppressed because it is too large Load Diff

View File

@ -74,8 +74,10 @@ class CalculateAmphoraDelta(BaseNetworkTask):
else:
management_nets = CONF.controller_worker.amp_boot_network_list
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
desired_subnet_to_net_map = {}
for mgmt_net_id in management_nets:
@ -183,8 +185,10 @@ class CalculateDelta(BaseNetworkTask):
calculate_amp = CalculateAmphoraDelta()
deltas = {}
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
for amphora in filter(
lambda amp: amp.status == constants.AMPHORA_ALLOCATED,
db_lb.amphorae):
@ -317,8 +321,10 @@ class HandleNetworkDelta(BaseNetworkTask):
def execute(self, amphora, delta):
"""Handle network plugging based off deltas."""
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
updated_ports = {}
for nic in delta[constants.ADD_NICS]:
subnet_id = nic[constants.FIXED_IPS][0][constants.SUBNET_ID]
@ -432,8 +438,10 @@ class HandleNetworkDeltas(BaseNetworkTask):
def execute(self, deltas, loadbalancer):
"""Handle network plugging based off deltas."""
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
amphorae = {amp.id: amp for amp in db_lb.amphorae}
updated_ports = {}
@ -481,9 +489,11 @@ class PlugVIP(BaseNetworkTask):
LOG.debug("Plumbing VIP for loadbalancer id: %s",
loadbalancer[constants.LOADBALANCER_ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
amps_data = self.network_driver.plug_vip(db_lb,
db_lb.vip)
return [amp.to_dict() for amp in amps_data]
@ -496,9 +506,11 @@ class PlugVIP(BaseNetworkTask):
LOG.warning("Unable to plug VIP for loadbalancer id %s",
loadbalancer[constants.LOADBALANCER_ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
try:
# Make sure we have the current port IDs for cleanup
for amp_data in result:
@ -524,8 +536,10 @@ class UpdateVIPSecurityGroup(BaseNetworkTask):
LOG.debug("Setting up VIP SG for load balancer id: %s",
loadbalancer_id)
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer_id)
sg_id = self.network_driver.update_vip_sg(db_lb, db_lb.vip)
LOG.info("Set up VIP SG %s for load balancer %s complete",
@ -557,11 +571,13 @@ class PlugVIPAmphora(BaseNetworkTask):
LOG.debug("Plumbing VIP for amphora id: %s",
amphora.get(constants.ID))
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
db_subnet = self.network_driver.get_subnet(subnet[constants.ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
db_subnet = self.network_driver.get_subnet(subnet[constants.ID])
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
amp_data = self.network_driver.plug_aap_port(
db_lb, db_lb.vip, db_amp, db_subnet)
return amp_data.to_dict()
@ -576,14 +592,17 @@ class PlugVIPAmphora(BaseNetworkTask):
loadbalancer[constants.LOADBALANCER_ID])
try:
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
db_amp.vrrp_port_id = result[constants.VRRP_PORT_ID]
db_amp.ha_port_id = result[constants.HA_PORT_ID]
db_subnet = self.network_driver.get_subnet(subnet[constants.ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
db_amp.vrrp_port_id = result[constants.VRRP_PORT_ID]
db_amp.ha_port_id = result[constants.HA_PORT_ID]
db_subnet = self.network_driver.get_subnet(
subnet[constants.ID])
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
self.network_driver.unplug_aap_port(db_lb.vip,
db_amp, db_subnet)
@ -600,9 +619,11 @@ class UnplugVIP(BaseNetworkTask):
LOG.debug("Unplug vip on amphora")
try:
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
self.network_driver.unplug_vip(db_lb, db_lb.vip)
except Exception:
LOG.exception("Unable to unplug vip from load balancer %s",
@ -621,8 +642,10 @@ class AllocateVIP(BaseNetworkTask):
loadbalancer[constants.VIP_SUBNET_ID],
loadbalancer[constants.VIP_ADDRESS],
loadbalancer[constants.LOADBALANCER_ID])
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
vip, additional_vips = self.network_driver.allocate_vip(db_lb)
LOG.info("Allocated vip with port id %s, subnet id %s, ip address %s "
"for load balancer %s",
@ -682,10 +705,12 @@ class DeallocateVIP(BaseNetworkTask):
# will need access to the load balancer that the vip is/was attached
# to. However the data model serialization for the vip does not give a
# backref to the loadbalancer if accessed through the loadbalancer.
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
vip = db_lb.vip
vip.load_balancer = db_lb
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
vip = db_lb.vip
vip.load_balancer = db_lb
self.network_driver.deallocate_vip(vip)
@ -693,8 +718,10 @@ class UpdateVIP(BaseNetworkTask):
"""Task to update a VIP."""
def execute(self, listeners):
loadbalancer = self.loadbalancer_repo.get(
db_apis.get_session(), id=listeners[0][constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
loadbalancer = self.loadbalancer_repo.get(
session, id=listeners[0][constants.LOADBALANCER_ID])
LOG.debug("Updating VIP of load_balancer %s.", loadbalancer.id)
@ -705,8 +732,10 @@ class UpdateVIPForDelete(BaseNetworkTask):
"""Task to update a VIP for listener delete flows."""
def execute(self, loadbalancer_id):
loadbalancer = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
loadbalancer = self.loadbalancer_repo.get(
session, id=loadbalancer_id)
LOG.debug("Updating VIP for listener delete on load_balancer %s.",
loadbalancer.id)
self.network_driver.update_vip(loadbalancer, for_delete=True)
@ -717,10 +746,12 @@ class GetAmphoraNetworkConfigs(BaseNetworkTask):
def execute(self, loadbalancer, amphora=None):
LOG.debug("Retrieving vip network details.")
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora.get(constants.ID))
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora.get(constants.ID))
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer[constants.LOADBALANCER_ID])
db_configs = self.network_driver.get_network_configs(
db_lb, amphora=db_amp)
provider_dict = {}
@ -734,9 +765,11 @@ class GetAmphoraNetworkConfigsByID(BaseNetworkTask):
def execute(self, loadbalancer_id, amphora_id=None):
LOG.debug("Retrieving vip network details.")
loadbalancer = self.loadbalancer_repo.get(db_apis.get_session(),
id=loadbalancer_id)
amphora = self.amphora_repo.get(db_apis.get_session(), id=amphora_id)
session = db_apis.get_session()
with session.begin():
loadbalancer = self.loadbalancer_repo.get(session,
id=loadbalancer_id)
amphora = self.amphora_repo.get(session, id=amphora_id)
db_configs = self.network_driver.get_network_configs(loadbalancer,
amphora=amphora)
provider_dict = {}
@ -750,8 +783,10 @@ class GetAmphoraeNetworkConfigs(BaseNetworkTask):
def execute(self, loadbalancer_id):
LOG.debug("Retrieving vip network details.")
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(), id=loadbalancer_id)
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session, id=loadbalancer_id)
db_configs = self.network_driver.get_network_configs(db_lb)
provider_dict = {}
for amp_id, amp_conf in db_configs.items():
@ -763,8 +798,10 @@ class FailoverPreparationForAmphora(BaseNetworkTask):
"""Task to prepare an amphora for failover."""
def execute(self, amphora):
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
LOG.debug("Prepare amphora %s for failover.", amphora[constants.ID])
self.network_driver.failover_preparation(db_amp)
@ -799,8 +836,10 @@ class PlugPorts(BaseNetworkTask):
"""Task to plug neutron ports into a compute instance."""
def execute(self, amphora, ports):
db_amp = self.amphora_repo.get(db_apis.get_session(),
id=amphora[constants.ID])
session = db_apis.get_session()
with session.begin():
db_amp = self.amphora_repo.get(session,
id=amphora[constants.ID])
for port in ports:
LOG.debug('Plugging port ID: %(port_id)s into compute instance: '
'%(compute_id)s.',
@ -816,15 +855,17 @@ class ApplyQos(BaseNetworkTask):
is_revert=False, request_qos_id=None):
"""Call network driver to apply QoS Policy on the vrrp ports."""
if not amps_data:
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
amps_data = db_lb.amphorae
session = db_apis.get_session()
with session.begin():
if not amps_data:
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
amps_data = db_lb.amphorae
amps_data = [amp
for amp in amps_data
if amp.status == constants.AMPHORA_ALLOCATED]
amps_data = [amp
for amp in amps_data
if amp.status == constants.AMPHORA_ALLOCATED]
apply_qos = ApplyQosAmphora()
for amp_data in amps_data:
@ -833,9 +874,11 @@ class ApplyQos(BaseNetworkTask):
def execute(self, loadbalancer, amps_data=None, update_dict=None):
"""Apply qos policy on the vrrp ports which are related with vip."""
db_lb = self.loadbalancer_repo.get(
db_apis.get_session(),
id=loadbalancer[constants.LOADBALANCER_ID])
session = db_apis.get_session()
with session.begin():
db_lb = self.loadbalancer_repo.get(
session,
id=loadbalancer[constants.LOADBALANCER_ID])
qos_policy_id = db_lb.vip.qos_policy_id
if not qos_policy_id and (

View File

@ -12,15 +12,13 @@
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import time
from sqlalchemy.sql.expression import select
from oslo_config import cfg
from oslo_db.sqlalchemy import session as db_session
from oslo_db.sqlalchemy import enginefacade
from oslo_log import log as logging
from oslo_utils import excutils
LOG = logging.getLogger(__name__)
_FACADE = None
@ -29,32 +27,37 @@ _FACADE = None
def _create_facade_lazily():
global _FACADE
if _FACADE is None:
_FACADE = db_session.EngineFacade.from_config(cfg.CONF, sqlite_fk=True)
return _FACADE
_FACADE = True
enginefacade.configure(sqlite_fk=True, expire_on_commit=True)
def _get_transaction_context(reader=False):
_create_facade_lazily()
# TODO(gthiemonge) Create and use new functions to get read-only sessions
if reader:
context = enginefacade.reader
else:
context = enginefacade.writer
return context
def _get_sessionmaker(reader=False):
context = _get_transaction_context(reader)
return context.get_sessionmaker()
def get_engine():
facade = _create_facade_lazily()
return facade.get_engine()
context = _get_transaction_context()
return context.get_engine()
def get_session(expire_on_commit=True, autocommit=True):
def get_session():
"""Helper method to grab session."""
facade = _create_facade_lazily()
return facade.get_session(expire_on_commit=expire_on_commit,
autocommit=autocommit)
return _get_sessionmaker()()
@contextlib.contextmanager
def get_lock_session():
"""Context manager for using a locking (not auto-commit) session."""
lock_session = get_session(autocommit=False)
try:
yield lock_session
lock_session.commit()
except Exception:
with excutils.save_and_reraise_exception():
lock_session.rollback()
def session():
return _get_sessionmaker()
def wait_for_connection(exit_event):

View File

@ -169,6 +169,8 @@ def create_pool(pool_dict, lb_id=None):
def create_member(member_dict, pool_id, has_health_monitor=False):
if not member_dict.get('id'):
member_dict['id'] = uuidutils.generate_uuid()
member_dict['pool_id'] = pool_id
member_dict[constants.PROVISIONING_STATUS] = constants.PENDING_CREATE
if has_health_monitor:

File diff suppressed because it is too large Load Diff

View File

@ -31,13 +31,13 @@ class StatsUpdateDb(stats_base.StatsDriverMixin):
def update_stats(self, listener_stats, deltas=False):
"""This function is to update the db with listener stats"""
session = db_api.get_session()
for stats_object in listener_stats:
LOG.debug("Updating listener stats in db for listener `%s` / "
"amphora `%s`: %s",
stats_object.listener_id, stats_object.amphora_id,
stats_object.get_stats())
if deltas:
self.listener_stats_repo.increment(session, stats_object)
else:
self.listener_stats_repo.replace(session, stats_object)
with db_api.session().begin() as session:
for stats_object in listener_stats:
LOG.debug("Updating listener stats in db for listener `%s` / "
"amphora `%s`: %s",
stats_object.listener_id, stats_object.amphora_id,
stats_object.get_stats())
if deltas:
self.listener_stats_repo.increment(session, stats_object)
else:
self.listener_stats_repo.replace(session, stats_object)

View File

@ -29,21 +29,6 @@ class WarningsFixture(fixtures.Fixture):
'error',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
message='The Session.begin.subtransactions flag is deprecated ',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
message='The Session.autocommit parameter is deprecated ',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
message='The current statement is being autocommitted ',
category=sqla_exc.SADeprecationWarning)
warnings.filterwarnings(
'ignore',
message='The create_engine.convert_unicode parameter and ',

View File

@ -31,6 +31,9 @@ from octavia.tests.common import sample_certs
from octavia.tests.common import sample_data_models
from octavia.tests.functional.db import base
from oslo_log import log as logging
LOG = logging.getLogger(__name__)
CONF = cfg.CONF
@ -267,6 +270,8 @@ class DriverAgentTest(base.OctaviaDBTestBase):
l7rule2_dict = copy.deepcopy(self.sample_data.test_l7rule2_dict)
self.repos.l7rule.create(self.session, **l7rule2_dict)
self.session.commit()
self.provider_lb_dict = copy.deepcopy(
self.sample_data.provider_loadbalancer_tree_dict)
self.provider_lb_dict[lib_consts.POOLS] = [self.provider_pool_dict]
@ -594,6 +599,7 @@ class DriverAgentTest(base.OctaviaDBTestBase):
# Add a new member
member_dict = copy.deepcopy(self.sample_data.test_member2_dict)
self.repos.member.create(self.session, **member_dict)
self.session.commit()
result = self.driver_lib.get_member(member_dict[lib_consts.ID])
self.assertEqual(self.sample_data.provider_member2_dict,

View File

@ -269,7 +269,7 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
def create_listener_stats(self, listener_id, amphora_id):
db_ls = self.listener_stats_repo.create(
db_api.get_session(), listener_id=listener_id,
self.session, listener_id=listener_id,
amphora_id=amphora_id, bytes_in=0,
bytes_out=0, active_connections=0, total_connections=0,
request_errors=0)
@ -280,7 +280,7 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
active_connections=0,
total_connections=0, request_errors=0):
db_ls = self.listener_stats_repo.create(
db_api.get_session(), listener_id=listener_id,
self.session, listener_id=listener_id,
amphora_id=amphora_id, bytes_in=bytes_in,
bytes_out=bytes_out, active_connections=active_connections,
total_connections=total_connections, request_errors=request_errors)
@ -478,9 +478,11 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
@staticmethod
def set_object_status(repo, id_, provisioning_status=constants.ACTIVE,
operating_status=constants.ONLINE):
repo.update(db_api.get_session(), id_,
session = db_api.get_session()
repo.update(session, id_,
provisioning_status=provisioning_status,
operating_status=operating_status)
session.commit()
def assert_final_listener_statuses(self, lb_id, listener_id, delete=False):
expected_prov_status = constants.ACTIVE

View File

@ -40,6 +40,7 @@ class TestAmphora(base.BaseAPITest):
self.project_id = self.lb.get('project_id')
self.set_lb_status(self.lb_id)
self.amp_args = {
'id': uuidutils.generate_uuid(),
'load_balancer_id': self.lb_id,
'compute_id': uuidutils.generate_uuid(),
'lb_network_ip': '192.168.1.2',
@ -91,9 +92,11 @@ class TestAmphora(base.BaseAPITest):
'total_connections': 9}
self.ref_amp_stats = [self.listener1_amp_stats,
self.listener2_amp_stats]
self.session.commit()
def _create_additional_amp(self):
amp_args = {
'id': uuidutils.generate_uuid(),
'load_balancer_id': None,
'compute_id': uuidutils.generate_uuid(),
'lb_network_ip': '192.168.1.2',
@ -109,7 +112,8 @@ class TestAmphora(base.BaseAPITest):
'vrrp_id': 1,
'vrrp_priority': 100,
}
return self.amphora_repo.create(self.session, **amp_args)
with self.session.begin():
return self.amphora_repo.create(self.session, **amp_args)
def _assert_amp_equal(self, source, response):
self.assertEqual(source.pop('load_balancer_id'),
@ -130,9 +134,11 @@ class TestAmphora(base.BaseAPITest):
@mock.patch('oslo_messaging.RPCClient.cast')
def test_delete(self, mock_cast):
self.amp_args = {
'id': uuidutils.generate_uuid(),
'status': constants.AMPHORA_READY,
}
amp = self.amphora_repo.create(self.session, **self.amp_args)
with self.session.begin():
amp = self.amphora_repo.create(self.session, **self.amp_args)
self.delete(self.AMPHORA_PATH.format(
amphora_id=amp.id), status=204)
@ -154,9 +160,11 @@ class TestAmphora(base.BaseAPITest):
@mock.patch('oslo_messaging.RPCClient.cast')
def test_delete_immutable(self, mock_cast):
self.amp_args = {
'id': uuidutils.generate_uuid(),
'status': constants.AMPHORA_ALLOCATED,
}
amp = self.amphora_repo.create(self.session, **self.amp_args)
with self.session.begin():
amp = self.amphora_repo.create(self.session, **self.amp_args)
self.delete(self.AMPHORA_PATH.format(
amphora_id=amp.id), status=409)
@ -166,9 +174,11 @@ class TestAmphora(base.BaseAPITest):
@mock.patch('oslo_messaging.RPCClient.cast')
def test_delete_authorized(self, mock_cast):
self.amp_args = {
'id': uuidutils.generate_uuid(),
'status': constants.AMPHORA_READY,
}
amp = self.amphora_repo.create(self.session, **self.amp_args)
with self.session.begin():
amp = self.amphora_repo.create(self.session, **self.amp_args)
self.conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
auth_strategy = self.conf.conf.api_settings.get('auth_strategy')
@ -208,9 +218,11 @@ class TestAmphora(base.BaseAPITest):
@mock.patch('oslo_messaging.RPCClient.cast')
def test_delete_not_authorized(self, mock_cast):
self.amp_args = {
'id': uuidutils.generate_uuid(),
'status': constants.AMPHORA_READY,
}
amp = self.amphora_repo.create(self.session, **self.amp_args)
with self.session.begin():
amp = self.amphora_repo.create(self.session, **self.amp_args)
self.conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
auth_strategy = self.conf.conf.api_settings.get('auth_strategy')
@ -567,6 +579,7 @@ class TestAmphora(base.BaseAPITest):
self.lb2_id = self.lb2.get('id')
self.set_lb_status(self.lb2_id)
self.amp2_args = {
'id': uuidutils.generate_uuid(),
'load_balancer_id': self.lb2_id,
'compute_id': uuidutils.generate_uuid(),
'lb_network_ip': '192.168.1.20',

View File

@ -531,6 +531,7 @@ class TestAvailabilityZoneProfiles(base.BaseAPITest):
def test_delete_authorized(self):
azp = self.create_availability_zone_profile(
'test1', 'noop_driver', '{"compute_zone": "my_az_1"}')
self.session.commit()
self.assertTrue(uuidutils.is_uuid_like(azp.get('id')))
self.conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
auth_strategy = self.conf.conf.api_settings.get('auth_strategy')

View File

@ -2832,6 +2832,7 @@ class TestListener(base.BaseAPITest):
bytes_out=random.randint(1, 9),
total_connections=random.randint(1, 9),
request_errors=random.randint(1, 9))
self.session.commit()
response = self._getStats(li['id'])
self.assertEqual(ls['bytes_in'], response['bytes_in'])
@ -2862,6 +2863,7 @@ class TestListener(base.BaseAPITest):
self.conf = self.useFixture(oslo_fixture.Config(cfg.CONF))
auth_strategy = self.conf.conf.api_settings.get('auth_strategy')
self.conf.config(group='api_settings', auth_strategy=constants.TESTING)
self.session.commit()
with mock.patch.object(octavia.common.context.RequestContext,
'project_id',

View File

@ -2657,6 +2657,10 @@ class TestLoadBalancer(base.BaseAPITest):
prov='bad_driver', user_msg='broken')
self.delete(self.LB_PATH.format(lb_id=api_lb.get('id')), status=501)
response = self.get(self.LB_PATH.format(
lb_id=api_lb.get('id'))).json.get(self.root_tag)
self.assertEqual(constants.ACTIVE, response['provisioning_status'])
@mock.patch('octavia.api.drivers.utils.call_provider')
def test_create_with_provider_unsupport_option(self, mock_provider):
mock_provider.side_effect = exceptions.ProviderUnsupportedOptionError(
@ -4045,6 +4049,7 @@ class TestLoadBalancerGraph(base.BaseAPITest):
bytes_out=random.randint(1, 9),
total_connections=random.randint(1, 9),
request_errors=random.randint(1, 9))
self.session.commit()
response = self._getStats(lb['id'])
self.assertEqual(ls['bytes_in'], response['bytes_in'])
@ -4072,6 +4077,7 @@ class TestLoadBalancerGraph(base.BaseAPITest):
bytes_out=random.randint(1, 9),
total_connections=random.randint(1, 9),
request_errors=random.randint(1, 9))
self.session.commit()
auth_strategy = self.conf.conf.api_settings.get('auth_strategy')
self.conf.config(group='api_settings', auth_strategy=constants.TESTING)

View File

@ -33,6 +33,7 @@ class OctaviaDBTestBase(test_base.BaseTestCase):
def setUp(self, connection_string='sqlite://'):
super().setUp()
self.connection_string = connection_string
self.warning_fixture = self.useFixture(oc_fixtures.WarningsFixture())
# NOTE(blogan): doing this for now because using the engine and
@ -42,21 +43,12 @@ class OctaviaDBTestBase(test_base.BaseTestCase):
conf = self.useFixture(oslo_fixture.Config(config.cfg.CONF))
conf.config(group="database", connection=connection_string)
# We need to get our own Facade so that the file backed sqlite tests
# don't use the _FACADE singleton. Some tests will use in-memory
# sqlite, some will use a file backed sqlite.
if 'sqlite:///' in connection_string:
facade = db_session.EngineFacade.from_config(cfg.CONF,
sqlite_fk=True)
engine = facade.get_engine()
self.session = facade.get_session(expire_on_commit=True,
autocommit=True)
else:
engine = db_api.get_engine()
self.session = db_api.get_session()
engine, self.session = self._get_db_engine_session()
base_models.BASE.metadata.create_all(engine)
self._seed_lookup_tables(self.session)
with self.session.begin():
self._seed_lookup_tables(self.session)
def clear_tables():
"""Unregister all data models."""
@ -67,6 +59,20 @@ class OctaviaDBTestBase(test_base.BaseTestCase):
self.addCleanup(clear_tables)
def _get_db_engine_session(self):
# We need to get our own Facade so that the file backed sqlite tests
# don't use the _FACADE singleton. Some tests will use in-memory
# sqlite, some will use a file backed sqlite.
if 'sqlite:///' in self.connection_string:
facade = db_session.EngineFacade.from_config(cfg.CONF,
sqlite_fk=True)
engine = facade.get_engine()
session = facade.get_session(expire_on_commit=True)
else:
engine = db_api.get_engine()
session = db_api.get_session()
return engine, session
def _seed_lookup_tables(self, session):
self._seed_lookup_table(
session, constants.SUPPORTED_PROVISIONING_STATUSES,
@ -128,6 +134,5 @@ class OctaviaDBTestBase(test_base.BaseTestCase):
def _seed_lookup_table(self, session, name_list, model_cls):
for name in name_list:
with session.begin():
model = model_cls(name=name)
session.add(model)
model = model_cls(name=name)
session.add(model)

View File

@ -32,9 +32,9 @@ class ModelTestMixin(object):
FAKE_AZ = 'zone1'
def _insert(self, session, model_cls, model_kwargs):
with session.begin():
model = model_cls(**model_kwargs)
session.add(model)
model = model_cls(**model_kwargs)
session.add(model)
session.commit()
return model
def create_flavor_profile(self, session, **overrides):
@ -230,9 +230,8 @@ class PoolModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
pool = self.create_pool(self.session)
id = pool.id
with self.session.begin():
self.session.delete(pool)
self.session.flush()
self.session.delete(pool)
self.session.commit()
new_pool = self.session.query(
models.Pool).filter_by(id=id).first()
self.assertIsNone(new_pool)
@ -302,6 +301,7 @@ class MemberModelTest(base.OctaviaDBTestBase, ModelTestMixin):
member_id = member.id
member.enabled = False
self.session.commit()
new_member = self.session.query(
models.Member).filter_by(id=member_id).first()
@ -311,9 +311,11 @@ class MemberModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
member = self.create_member(self.session, self.pool.id)
member_id = member.id
with self.session.begin():
self.session.delete(member)
self.session.flush()
self.session.commit()
self.session.delete(member)
self.session.commit()
new_member = self.session.query(
models.Member).filter_by(id=member_id).first()
self.assertIsNone(new_member)
@ -424,9 +426,8 @@ class ListenerModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
listener = self.create_listener(self.session)
listener_id = listener.id
with self.session.begin():
self.session.delete(listener)
self.session.flush()
self.session.delete(listener)
self.session.commit()
new_listener = self.session.query(
models.Listener).filter_by(id=listener_id).first()
self.assertIsNone(new_listener)
@ -578,9 +579,8 @@ class LoadBalancerModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
load_balancer = self.create_load_balancer(self.session)
lb_id = load_balancer.id
with self.session.begin():
self.session.delete(load_balancer)
self.session.flush()
self.session.delete(load_balancer)
self.session.commit()
new_load_balancer = self.session.query(
models.LoadBalancer).filter_by(id=lb_id).first()
self.assertIsNone(new_load_balancer)
@ -745,7 +745,7 @@ class AmphoraHealthModelTest(base.OctaviaDBTestBase, ModelTestMixin):
new_amphora_health = self.session.query(
models.AmphoraHealth).filter_by(
amphora_id=amphora_health.amphora_id).first()
self.assertEqual(newdate, new_amphora_health.last_update.date())
self.assertEqual(newdate, new_amphora_health.last_update)
def test_delete(self):
amphora_health = self.create_amphora_health(
@ -787,9 +787,8 @@ class L7PolicyModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
l7policy = self.create_l7policy(self.session, self.listener.id)
l7policy_id = l7policy.id
with self.session.begin():
self.session.delete(l7policy)
self.session.flush()
self.session.delete(l7policy)
self.session.commit()
new_l7policy = self.session.query(
models.L7Policy).filter_by(id=l7policy_id).first()
self.assertIsNone(new_l7policy)
@ -908,9 +907,8 @@ class L7RuleModelTest(base.OctaviaDBTestBase, ModelTestMixin):
def test_delete(self):
l7rule = self.create_l7rule(self.session, self.l7policy.id)
l7rule_id = l7rule.id
with self.session.begin():
self.session.delete(l7rule)
self.session.flush()
self.session.delete(l7rule)
self.session.commit()
new_l7rule = self.session.query(
models.L7Rule).filter_by(id=l7rule_id).first()
self.assertIsNone(new_l7rule)
@ -1825,10 +1823,11 @@ class FlavorModelTest(base.OctaviaDBTestBase, ModelTestMixin):
flavor = self.create_flavor(self.session, self.profile.id)
self.assertIsNotNone(flavor.id)
id = flavor.id
self.session.commit()
self.session.delete(flavor)
self.session.commit()
with self.session.begin():
self.session.delete(flavor)
self.session.flush()
new_flavor = self.session.query(
models.Flavor).filter_by(id=id).first()
self.assertIsNone(new_flavor)
@ -1847,10 +1846,10 @@ class FlavorProfileModelTest(base.OctaviaDBTestBase, ModelTestMixin):
fp = self.create_flavor_profile(self.session)
self.assertIsNotNone(fp.id)
id = fp.id
self.session.commit()
with self.session.begin():
self.session.delete(fp)
self.session.flush()
self.session.delete(fp)
self.session.commit()
new_fp = self.session.query(
models.FlavorProfile).filter_by(id=id).first()
self.assertIsNone(new_fp)

File diff suppressed because it is too large Load Diff

View File

@ -167,7 +167,7 @@ class TestHaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
self.driver.clients[API_VERSION].upload_config.assert_not_called()
self.driver.clients[API_VERSION].reload_listener.assert_not_called()
@mock.patch('octavia.db.api.get_session')
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.ListenerRepository.update')
@mock.patch('octavia.common.tls_utils.cert_parser.load_certificates_data')
def test_update_amphora_listeners_bad_cert(
@ -176,12 +176,13 @@ class TestHaproxyAmphoraLoadBalancerDriverTest(base.TestCase):
mock_amphora.id = 'mock_amphora_id'
mock_amphora.api_version = API_VERSION
mock_get_session.return_value = 'fake_session'
mock_session = mock_get_session().begin().__enter__()
mock_load_cert.side_effect = [Exception]
self.driver.update_amphora_listeners(self.lb,
mock_amphora, self.timeout_dict)
mock_list_update.assert_called_once_with(
'fake_session', self.lb.listeners[0].id,
mock_session, self.lb.listeners[0].id,
provisioning_status=constants.ERROR,
operating_status=constants.ERROR)
self.driver.jinja_combo.build_config.assert_not_called()

View File

@ -27,7 +27,7 @@ class TestDriverGet(base.TestCase):
@mock.patch('octavia.db.api.get_session')
def _test_process_get_object(self, object_name, mock_object_repo,
mock_object_to_provider, mock_get_session):
mock_get_session.return_value = 'bogus_session'
mock_get_session.return_value = mock.MagicMock()
object_repo_mock = mock.MagicMock()
mock_object_repo.return_value = object_repo_mock
db_object_mock = mock.MagicMock()
@ -47,7 +47,7 @@ class TestDriverGet(base.TestCase):
mock_object_repo.assert_called_once_with()
object_repo_mock.get.assert_called_once_with(
'bogus_session', id=object_id, show_deleted=False)
mock_get_session(), id=object_id, show_deleted=False)
mock_object_to_provider.assert_called_once_with(db_object_mock)
self.assertEqual(ref_prov_dict, result)
@ -61,7 +61,7 @@ class TestDriverGet(base.TestCase):
mock_object_repo.assert_called_once_with()
object_repo_mock.get.assert_called_once_with(
'bogus_session', id=object_id, show_deleted=False)
mock_get_session(), id=object_id, show_deleted=False)
mock_object_to_provider.assert_not_called()
self.assertEqual({}, result)

View File

@ -39,7 +39,7 @@ class TestDriverUpdater(base.TestCase):
mock_pool_repo, mock_l7r_repo, mock_l7p_repo, mock_list_repo,
mock_lb_repo):
super().setUp()
self.mock_session = "FAKE_DB_SESSION"
self.mock_session = mock.MagicMock()
mock_get_session.return_value = self.mock_session
member_mock = mock.MagicMock()
@ -98,34 +98,32 @@ class TestDriverUpdater(base.TestCase):
@mock.patch('octavia.db.repositories.Repositories.decrement_quota')
@mock.patch('octavia.db.api.get_session')
def test_decrement_quota(self, mock_get_session, mock_dec_quota):
mock_session = mock.MagicMock()
mock_get_session.return_value = mock_session
mock_dec_quota.side_effect = [mock.DEFAULT,
exceptions.OctaviaException('Boom')]
self.driver_updater._decrement_quota(self.mock_lb_repo,
'FakeName', self.lb_id)
mock_dec_quota.assert_called_once_with(
mock_session, self.mock_lb_repo.model_class.__data_model__,
self.mock_session, self.mock_lb_repo.model_class.__data_model__,
self.lb_project_id)
mock_session.commit.assert_called_once()
mock_session.rollback.assert_not_called()
self.mock_session.commit.assert_called_once()
self.mock_session.rollback.assert_not_called()
# Test exception path
mock_dec_quota.reset_mock()
mock_session.reset_mock()
self.mock_session.reset_mock()
self.assertRaises(exceptions.OctaviaException,
self.driver_updater._decrement_quota,
self.mock_lb_repo, 'FakeName', self.lb_id)
mock_dec_quota.assert_called_once_with(
mock_session, self.mock_lb_repo.model_class.__data_model__,
self.mock_session, self.mock_lb_repo.model_class.__data_model__,
self.lb_project_id)
mock_session.commit.assert_not_called()
mock_session.rollback.assert_called_once()
self.mock_session.commit.assert_not_called()
self.mock_session.rollback.assert_called_once()
# Test already deleted path
mock_dec_quota.reset_mock()
mock_session.reset_mock()
self.mock_session.reset_mock()
# Create a local mock LB and LB_repo for this test
mock_lb = mock.MagicMock()
mock_lb.id = self.lb_id
@ -136,8 +134,8 @@ class TestDriverUpdater(base.TestCase):
self.driver_updater._decrement_quota(mock_lb_repo,
'FakeName', self.lb_id)
mock_dec_quota.assert_not_called()
mock_session.commit.assert_not_called()
mock_session.rollback.assert_called_once()
self.mock_session.commit.assert_not_called()
self.mock_session.rollback.assert_called_once()
@mock.patch('octavia.api.drivers.driver_agent.driver_updater.'
'DriverUpdater._decrement_quota')

View File

@ -99,7 +99,7 @@ class TestHealthManager(base.TestCase):
hm = healthmanager.HealthManager(exit_event)
hm.health_check()
session_mock.assert_called_once_with(autocommit=False)
session_mock.assert_called_once_with()
self.assertFalse(failover_mock.called)
@mock.patch('octavia.controller.worker.v2.controller_worker.'
@ -115,7 +115,7 @@ class TestHealthManager(base.TestCase):
hm = healthmanager.HealthManager(exit_event)
hm.health_check()
session_mock.assert_called_once_with(autocommit=False)
session_mock.assert_called_once_with()
self.assertFalse(failover_mock.called)
@mock.patch('octavia.controller.worker.v2.controller_worker.'

View File

@ -36,7 +36,7 @@ class TestAmphoraBuildRateLimit(base.TestCase):
self.amp_build_req_repo = mock.MagicMock()
self.conf.config(group='haproxy_amphora', build_rate_limit=1)
@mock.patch('octavia.db.api.get_session', mock.MagicMock())
@mock.patch('octavia.db.api.session', mock.MagicMock())
@mock.patch('octavia.controller.worker.amphora_rate_limit'
'.AmphoraBuildRateLimit.wait_for_build_slot')
@mock.patch('octavia.db.repositories.AmphoraBuildReqRepository'

View File

@ -19,8 +19,6 @@ from octavia.common import constants
from octavia.controller.worker import task_utils as task_utilities
import octavia.tests.unit.base as base
TEST_SESSION = 'TEST_SESSION'
class TestTaskUtils(base.TestCase):
@ -39,7 +37,7 @@ class TestTaskUtils(base.TestCase):
super().setUp()
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.AmphoraRepository.update')
def test_mark_amphora_status_error(self,
mock_amphora_repo_update,
@ -48,8 +46,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_amphora_status_error(self.AMPHORA_ID)
mock_session = mock_get_session().begin().__enter__()
mock_amphora_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.AMPHORA_ID,
status=constants.ERROR)
@ -61,7 +61,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_amphora_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.HealthMonitorRepository.update')
def test_mark_health_mon_prov_status_error(self,
mock_health_mon_repo_update,
@ -70,8 +70,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_health_mon_prov_status_error(self.HEALTH_MON_ID)
mock_session = mock_get_session().begin().__enter__()
mock_health_mon_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.HEALTH_MON_ID,
provisioning_status=constants.ERROR)
@ -83,7 +85,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_health_mon_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.L7PolicyRepository.update')
def test_mark_l7policy_prov_status_error(self,
mock_l7policy_repo_update,
@ -92,8 +94,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_l7policy_prov_status_error(self.L7POLICY_ID)
mock_session = mock_get_session().begin().__enter__()
mock_l7policy_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.L7POLICY_ID,
provisioning_status=constants.ERROR)
@ -105,7 +109,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_l7policy_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.L7RuleRepository.update')
def test_mark_l7rule_prov_status_error(self,
mock_l7rule_repo_update,
@ -114,8 +118,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_l7rule_prov_status_error(self.L7RULE_ID)
mock_session = mock_get_session().begin().__enter__()
mock_l7rule_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.L7RULE_ID,
provisioning_status=constants.ERROR)
@ -127,7 +133,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_l7rule_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.ListenerRepository.update')
def test_mark_listener_prov_status_active(self,
mock_listener_repo_update,
@ -136,8 +142,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_listener_prov_status_active(self.LISTENER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_listener_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.LISTENER_ID,
provisioning_status=constants.ACTIVE)
@ -149,7 +157,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_listener_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.ListenerRepository.update')
def test_mark_listener_prov_status_error(self,
mock_listener_repo_update,
@ -158,8 +166,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_listener_prov_status_error(self.LISTENER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_listener_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.LISTENER_ID,
provisioning_status=constants.ERROR)
@ -171,7 +181,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_listener_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.LoadBalancerRepository.update')
def test_mark_loadbalancer_prov_status_active(self,
mock_lb_repo_update,
@ -181,8 +191,10 @@ class TestTaskUtils(base.TestCase):
self.task_utils.mark_loadbalancer_prov_status_active(
self.LOADBALANCER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_lb_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.LOADBALANCER_ID,
provisioning_status=constants.ACTIVE)
@ -195,7 +207,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_lb_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.LoadBalancerRepository.update')
def test_mark_loadbalancer_prov_status_error(self,
mock_lb_repo_update,
@ -205,8 +217,10 @@ class TestTaskUtils(base.TestCase):
self.task_utils.mark_loadbalancer_prov_status_error(
self.LOADBALANCER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_lb_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.LOADBALANCER_ID,
provisioning_status=constants.ERROR)
@ -219,7 +233,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_lb_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.MemberRepository.update')
def test_mark_member_prov_status_error(self,
mock_member_repo_update,
@ -228,8 +242,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_member_prov_status_error(self.MEMBER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_member_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.MEMBER_ID,
provisioning_status=constants.ERROR)
@ -241,7 +257,7 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_member_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.PoolRepository.update')
def test_mark_pool_prov_status_error(self,
mock_pool_repo_update,
@ -250,8 +266,10 @@ class TestTaskUtils(base.TestCase):
# Happy path
self.task_utils.mark_pool_prov_status_error(self.POOL_ID)
mock_session = mock_get_session().begin().__enter__()
mock_pool_repo_update.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.POOL_ID,
provisioning_status=constants.ERROR)
@ -263,15 +281,17 @@ class TestTaskUtils(base.TestCase):
self.assertFalse(mock_pool_repo_update.called)
@mock.patch('octavia.db.api.get_session', return_value=TEST_SESSION)
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
def test_get_current_loadbalancer_from_db(self, mock_lb_repo_get,
mock_get_session):
# Happy path
self.task_utils.get_current_loadbalancer_from_db(self.LOADBALANCER_ID)
mock_session = mock_get_session().begin().__enter__()
mock_lb_repo_get.assert_called_once_with(
TEST_SESSION,
mock_session,
id=self.LOADBALANCER_ID)
# Exception path

View File

@ -148,7 +148,9 @@ class TestAmphoraDriverTasks(base.TestCase):
_session_mock, AMP_ID, status=constants.ERROR)
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
@mock.patch('octavia.db.api.session')
def test_listeners_update(self,
mock_get_session_ctx,
mock_lb_get,
mock_driver,
mock_generate_uuid,
@ -174,12 +176,14 @@ class TestAmphoraDriverTasks(base.TestCase):
listeners_update_obj.execute(None)
mock_driver.update.assert_not_called()
mock_session = mock_get_session_ctx().begin().__enter__()
# Test the revert
amp = listeners_update_obj.revert(_LB_mock)
expected_db_calls = [mock.call(_session_mock,
expected_db_calls = [mock.call(mock_session,
id=listeners[0].id,
provisioning_status=constants.ERROR),
mock.call(_session_mock,
mock.call(mock_session,
id=listeners[1].id,
provisioning_status=constants.ERROR)]
repo.ListenerRepository.update.assert_has_calls(expected_db_calls)
@ -262,7 +266,9 @@ class TestAmphoraDriverTasks(base.TestCase):
listeners_start_obj.revert(_LB_mock)
mock_prov_status_error.assert_called_once_with('12345')
@mock.patch('octavia.db.api.session')
def test_listener_delete(self,
mock_get_session_ctx,
mock_driver,
mock_generate_uuid,
mock_log,
@ -278,10 +284,12 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_driver.delete.assert_called_once_with(_listener_mock)
mock_session = mock_get_session_ctx().begin().__enter__()
# Test the revert
amp = listener_delete_obj.revert(listener_dict)
repo.ListenerRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=LISTENER_ID,
provisioning_status=constants.ERROR)
self.assertIsNone(amp)
@ -291,7 +299,7 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_listener_repo_update.side_effect = Exception('fail')
amp = listener_delete_obj.revert(listener_dict)
repo.ListenerRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=LISTENER_ID,
provisioning_status=constants.ERROR)
self.assertIsNone(amp)
@ -330,7 +338,9 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_driver.get_diagnostics.assert_called_once_with(
_amphora_mock)
@mock.patch('octavia.db.api.session')
def test_amphora_finalize(self,
mock_get_session_ctx,
mock_driver,
mock_generate_uuid,
mock_log,
@ -347,10 +357,12 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_driver.finalize_amphora.assert_called_once_with(
_db_amphora_mock)
mock_session = mock_get_session_ctx().begin().__enter__()
# Test revert
amp = amphora_finalize_obj.revert(None, _amphora_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
self.assertIsNone(amp)
@ -360,7 +372,7 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_amphora_repo_update.side_effect = Exception('fail')
amp = amphora_finalize_obj.revert(None, _amphora_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
self.assertIsNone(amp)
@ -371,7 +383,9 @@ class TestAmphoraDriverTasks(base.TestCase):
failure.Failure.from_exception(Exception('boom')), _amphora_mock)
repo.AmphoraRepository.update.assert_not_called()
@mock.patch('octavia.db.api.session')
def test_amphora_post_network_plug(self,
mock_get_session_ctx,
mock_driver,
mock_generate_uuid,
mock_log,
@ -396,10 +410,12 @@ class TestAmphoraDriverTasks(base.TestCase):
network_data_models.Port(**port_mock),
_amphora_network_config_mock)
mock_session = mock_get_session_ctx().begin().__enter__()
# Test revert
amp = amphora_post_network_plug_obj.revert(None, _amphora_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
@ -410,7 +426,7 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_amphora_repo_update.side_effect = Exception('fail')
amp = amphora_post_network_plug_obj.revert(None, _amphora_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
@ -456,7 +472,10 @@ class TestAmphoraDriverTasks(base.TestCase):
self.assertEqual(hr1['nexthop'], hr2.nexthop)
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
def test_amphorae_post_network_plug(self, mock_lb_get,
@mock.patch('octavia.db.api.session')
def test_amphorae_post_network_plug(self,
mock_get_session_ctx,
mock_lb_get,
mock_driver,
mock_generate_uuid,
mock_log,
@ -496,11 +515,13 @@ class TestAmphoraDriverTasks(base.TestCase):
_amphora_network_config_mock)
mock_driver.post_network_plug.assert_not_called()
mock_session = mock_get_session_ctx().begin().__enter__()
# Test revert
amp = amphora_post_network_plug_obj.revert(None, _LB_mock,
_deltas_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
@ -512,7 +533,7 @@ class TestAmphoraDriverTasks(base.TestCase):
amp = amphora_post_network_plug_obj.revert(None, _LB_mock,
_deltas_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
@ -527,7 +548,10 @@ class TestAmphoraDriverTasks(base.TestCase):
@mock.patch('octavia.db.repositories.LoadBalancerRepository.update')
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
def test_amphora_post_vip_plug(self, mock_lb_get,
@mock.patch('octavia.db.api.session')
def test_amphora_post_vip_plug(self,
mock_get_session_ctx,
mock_lb_get,
mock_loadbalancer_repo_update,
mock_driver,
mock_generate_uuid,
@ -562,10 +586,12 @@ class TestAmphoraDriverTasks(base.TestCase):
_db_amphora_mock, _db_load_balancer_mock, amphorae_net_config_mock,
vrrp_port, vip_subnet, additional_vip_data=[])
mock_session = mock_get_session_ctx().begin().__enter__()
# Test revert
amp = amphora_post_vip_plug_obj.revert(None, _amphora_mock, _LB_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
repo.LoadBalancerRepository.update.assert_not_called()
@ -579,7 +605,7 @@ class TestAmphoraDriverTasks(base.TestCase):
mock_loadbalancer_repo_update.side_effect = Exception('fail')
amp = amphora_post_vip_plug_obj.revert(None, _amphora_mock, _LB_mock)
repo.AmphoraRepository.update.assert_called_once_with(
_session_mock,
mock_session,
id=AMP_ID,
status=constants.ERROR)
repo.LoadBalancerRepository.update.assert_not_called()

View File

@ -123,10 +123,8 @@ class TestDatabaseTasksQuota(base.TestCase):
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_session = mock.MagicMock()
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
if data_model == data_models.Pool:
task.revert(test_object, self.zero_pool_child_count, None)
@ -136,19 +134,17 @@ class TestDatabaseTasksQuota(base.TestCase):
task.revert(test_object, None)
mock_check_quota_met.assert_called_once_with(
mock_session, mock_lock_session, data_model,
mock_session, mock_session, data_model,
project_id)
mock_lock_session.commit.assert_called_once_with()
mock_session.commit.assert_called_once_with()
# revert with rollback
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_session = mock.MagicMock()
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
mock_check_quota_met.side_effect = (
exceptions.OctaviaException('fail'))
@ -160,7 +156,7 @@ class TestDatabaseTasksQuota(base.TestCase):
else:
task.revert(test_object, None)
mock_lock_session.rollback.assert_called_once_with()
mock_session.rollback.assert_called_once_with()
# revert with db exception
mock_check_quota_met.reset_mock()
@ -238,27 +234,22 @@ class TestDatabaseTasksQuota(base.TestCase):
mock_session.reset_mock()
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
task.revert(project_id, pool_child_count, None)
calls = [mock.call(mock_session, mock_lock_session,
calls = [mock.call(mock_session, mock_session,
data_models.Pool, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.HealthMonitor, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id)]
mock_check_quota_met.assert_has_calls(calls)
self.assertEqual(4, mock_lock_session.commit.call_count)
self.assertEqual(4, mock_session.commit.call_count)
# revert with health monitor quota exception
mock_session.reset_mock()
@ -266,28 +257,23 @@ class TestDatabaseTasksQuota(base.TestCase):
None]
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
task.revert(project_id, pool_child_count, None)
calls = [mock.call(mock_session, mock_lock_session,
calls = [mock.call(mock_session, mock_session,
data_models.Pool, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.HealthMonitor, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id)]
mock_check_quota_met.assert_has_calls(calls)
self.assertEqual(3, mock_lock_session.commit.call_count)
self.assertEqual(1, mock_lock_session.rollback.call_count)
self.assertEqual(3, mock_session.commit.call_count)
self.assertEqual(1, mock_session.rollback.call_count)
# revert with member quota exception
mock_session.reset_mock()
@ -295,30 +281,25 @@ class TestDatabaseTasksQuota(base.TestCase):
Exception('fail')]
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
task.revert(project_id, pool_child_count, None)
calls = [mock.call(mock_session, mock_lock_session,
calls = [mock.call(mock_session, mock_session,
data_models.Pool, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.HealthMonitor, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.Member, project_id)]
mock_check_quota_met.assert_has_calls(calls)
self.assertEqual(3, mock_lock_session.commit.call_count)
self.assertEqual(1, mock_lock_session.rollback.call_count)
self.assertEqual(3, mock_session.commit.call_count)
self.assertEqual(1, mock_session.rollback.call_count)
@mock.patch('octavia.db.api.get_session')
@mock.patch('octavia.db.api.session')
@mock.patch('octavia.db.repositories.PoolRepository.get_children_count')
def test_count_pool_children_for_quota(self, repo_mock, session_mock):
project_id = uuidutils.generate_uuid()
@ -385,24 +366,20 @@ class TestDatabaseTasksQuota(base.TestCase):
mock_session.reset_mock()
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
task.revert(test_object, None)
calls = [mock.call(mock_session, mock_lock_session,
calls = [mock.call(mock_session, mock_session,
data_models.L7Policy, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.L7Rule, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.L7Rule, project_id)]
mock_check_quota_met.assert_has_calls(calls)
self.assertEqual(3, mock_lock_session.commit.call_count)
self.assertEqual(3, mock_session.commit.call_count)
# revert with l7rule quota exception
mock_session.reset_mock()
@ -410,25 +387,21 @@ class TestDatabaseTasksQuota(base.TestCase):
Exception('fail')]
with mock.patch('octavia.db.api.'
'get_session') as mock_get_session_local:
mock_lock_session = mock.MagicMock()
mock_get_session_local.side_effect = [mock_session,
mock_lock_session,
mock_lock_session,
mock_lock_session]
mock_get_session_local.return_value = mock_session
task.revert(test_object, None)
calls = [mock.call(mock_session, mock_lock_session,
calls = [mock.call(mock_session, mock_session,
data_models.L7Policy, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.L7Rule, project_id),
mock.call(mock_session, mock_lock_session,
mock.call(mock_session, mock_session,
data_models.L7Rule, project_id)]
mock_check_quota_met.assert_has_calls(calls)
self.assertEqual(2, mock_lock_session.commit.call_count)
self.assertEqual(1, mock_lock_session.rollback.call_count)
self.assertEqual(2, mock_session.commit.call_count)
self.assertEqual(1, mock_session.rollback.call_count)
def test_decrement_l7rule_quota(self):
project_id = uuidutils.generate_uuid()

View File

@ -1411,7 +1411,7 @@ class TestNetworkTasks(base.TestCase):
net_task.execute(listener)
mock_driver.update_vip.assert_called_once_with(lb, for_delete=True)
@mock.patch('octavia.db.api.get_session', return_value='TEST')
@mock.patch('octavia.db.api.get_session')
@mock.patch('octavia.db.repositories.AmphoraRepository.get')
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
def test_get_amphora_network_configs_by_id(
@ -1430,8 +1430,8 @@ class TestNetworkTasks(base.TestCase):
mock_driver.get_network_configs.assert_called_once_with(
'mock load balancer', amphora='mock amphora')
mock_amp_get.assert_called_once_with('TEST', id=AMP_ID)
mock_lb_get.assert_called_once_with('TEST', id=LB_ID)
mock_amp_get.assert_called_once_with(mock_get_session(), id=AMP_ID)
mock_lb_get.assert_called_once_with(mock_get_session(), id=LB_ID)
@mock.patch('octavia.db.repositories.LoadBalancerRepository.get')
@mock.patch('octavia.db.api.get_session', return_value=_session_mock)

View File

@ -30,7 +30,7 @@ class TestStatsUpdateDb(base.TestCase):
self.listener_id = uuidutils.generate_uuid()
@mock.patch('octavia.db.repositories.ListenerStatisticsRepository')
@mock.patch('octavia.db.api.get_session')
@mock.patch('octavia.db.api.session')
def test_update_stats(self, mock_get_session, mock_listener_stats_repo):
bytes_in1 = random.randrange(1000000000)
bytes_out1 = random.randrange(1000000000)
@ -61,18 +61,20 @@ class TestStatsUpdateDb(base.TestCase):
request_errors=request_errors2
)
mock_session = mock_get_session().begin().__enter__()
update_db.StatsUpdateDb().update_stats(
[stats_1, stats_2], deltas=False)
mock_listener_stats_repo().replace.assert_has_calls([
mock.call(mock_get_session(), stats_1),
mock.call(mock_get_session(), stats_2)
mock.call(mock_session, stats_1),
mock.call(mock_session, stats_2)
])
update_db.StatsUpdateDb().update_stats(
[stats_1, stats_2], deltas=True)
mock_listener_stats_repo().increment.assert_has_calls([
mock.call(mock_get_session(), stats_1),
mock.call(mock_get_session(), stats_2)
mock.call(mock_session, stats_1),
mock.call(mock_session, stats_2)
])