Re-factored central and rpc decorators

- Moved central and rpc decorators to common location.
- Cleaned up decorator code.

Change-Id: I79d21df7d17a2f706b8747e600e79a1ef1762e2b
This commit is contained in:
Erik Olof Gunnar Andersson 2022-07-25 23:23:46 -07:00
parent d58514ec84
commit 857b4c4e63
12 changed files with 419 additions and 308 deletions

View File

@ -14,16 +14,12 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import collections
import copy import copy
import functools
import itertools
import random import random
from random import SystemRandom from random import SystemRandom
import re import re
import signal import signal
import string import string
import threading
import time import time
from dns import exception as dnsexception from dns import exception as dnsexception
@ -33,16 +29,16 @@ from oslo_log import log as logging
import oslo_messaging as messaging import oslo_messaging as messaging
from designate.common import constants from designate.common import constants
from designate import context as dcontext from designate.common.decorators import lock
from designate.common.decorators import notification
from designate.common.decorators import rpc
from designate import coordination from designate import coordination
from designate import dnsutils from designate import dnsutils
from designate import exceptions from designate import exceptions
from designate import network_api from designate import network_api
from designate import notifications
from designate import objects from designate import objects
from designate import policy from designate import policy
from designate import quota from designate import quota
from designate import rpc
from designate import scheduler from designate import scheduler
from designate import service from designate import service
from designate import storage from designate import storage
@ -51,135 +47,7 @@ from designate.storage import transaction_shallow_copy
from designate import utils from designate import utils
from designate.worker import rpcapi as worker_rpcapi from designate.worker import rpcapi as worker_rpcapi
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
ZONE_LOCKS = threading.local()
NOTIFICATION_BUFFER = threading.local()
def synchronized_zone(zone_arg=1, new_zone=False):
"""Ensures only a single operation is in progress for each zone
A Decorator which ensures only a single operation can be happening
on a single zone at once, within the current designate-central instance
"""
def outer(f):
@functools.wraps(f)
def sync_wrapper(self, *args, **kwargs):
if not hasattr(ZONE_LOCKS, 'held'):
# Create the held set if necessary
ZONE_LOCKS.held = set()
zone_id = None
if 'zone_id' in kwargs:
zone_id = kwargs['zone_id']
elif 'zone' in kwargs:
zone_id = kwargs['zone'].id
elif 'recordset' in kwargs:
zone_id = kwargs['recordset'].zone_id
elif 'record' in kwargs:
zone_id = kwargs['record'].zone_id
# The various objects won't always have an ID set, we should
# attempt to locate an Object containing the ID.
if zone_id is None:
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, objects.Zone):
zone_id = arg.id
if zone_id:
break
elif (isinstance(arg, objects.RecordSet) or
isinstance(arg, objects.Record) or
isinstance(arg, objects.ZoneTransferRequest) or
isinstance(arg, objects.ZoneTransferAccept)):
zone_id = arg.zone_id
if zone_id:
break
# If we still don't have an ID, find the Nth argument as
# defined by the zone_arg decorator option.
if not zone_id and len(args) > zone_arg:
zone_id = args[zone_arg]
if isinstance(zone_id, objects.Zone):
# If the value is a Zone object, extract it's ID.
zone_id = zone_id.id
if new_zone and not zone_id:
lock_name = 'create-new-zone'
elif not new_zone and zone_id:
lock_name = 'zone-%s' % zone_id
else:
raise Exception('Failed to determine zone id for '
'synchronized operation')
if zone_id in ZONE_LOCKS.held:
return f(self, *args, **kwargs)
with self.coordination.get_lock(lock_name):
try:
ZONE_LOCKS.held.add(zone_id)
return f(self, *args, **kwargs)
finally:
ZONE_LOCKS.held.remove(zone_id)
sync_wrapper.__wrapped_function = f
sync_wrapper.__wrapper_name = 'synchronized_zone'
return sync_wrapper
return outer
def notification(notification_type):
def outer(f):
@functools.wraps(f)
def notification_wrapper(self, *args, **kwargs):
if not hasattr(NOTIFICATION_BUFFER, 'queue'):
# Create the notifications queue if necessary
NOTIFICATION_BUFFER.stack = 0
NOTIFICATION_BUFFER.queue = collections.deque()
NOTIFICATION_BUFFER.stack += 1
try:
# Find the context argument
context = dcontext.DesignateContext.\
get_context_from_function_and_args(f, args, kwargs)
# Call the wrapped function
result = f(self, *args, **kwargs)
# Feed the args/result to a notification plugin
# to determine what is emitted
payloads = notifications.get_plugin().emit(
notification_type, context, result, args, kwargs)
# Enqueue the notification
for payload in payloads:
LOG.debug('Queueing notification for %(type)s ',
{'type': notification_type})
NOTIFICATION_BUFFER.queue.appendleft(
(context, notification_type, payload,))
return result
finally:
NOTIFICATION_BUFFER.stack -= 1
if NOTIFICATION_BUFFER.stack == 0:
LOG.debug('Emitting %(count)d notifications',
{'count': len(NOTIFICATION_BUFFER.queue)})
# Send the queued notifications, in order.
for value in NOTIFICATION_BUFFER.queue:
LOG.debug('Emitting %(type)s notification',
{'type': value[1]})
self.notifier.info(value[0], value[1], value[2])
# Reset the queue
NOTIFICATION_BUFFER.queue.clear()
return notification_wrapper
return outer
class Service(service.RPCService): class Service(service.RPCService):
@ -188,6 +56,9 @@ class Service(service.RPCService):
target = messaging.Target(version=RPC_API_VERSION) target = messaging.Target(version=RPC_API_VERSION)
def __init__(self): def __init__(self):
self.zone_lock_local = lock.ZoneLockLocal()
self.notification_thread_local = notification.NotificationThreadLocal()
self._scheduler = None self._scheduler = None
self._storage = None self._storage = None
self._quota = None self._quota = None
@ -196,11 +67,9 @@ class Service(service.RPCService):
self.service_name, cfg.CONF['service:central'].topic, self.service_name, cfg.CONF['service:central'].topic,
threads=cfg.CONF['service:central'].threads, threads=cfg.CONF['service:central'].threads,
) )
self.coordination = coordination.Coordination( self.coordination = coordination.Coordination(
self.service_name, self.tg, grouping_enabled=False self.service_name, self.tg, grouping_enabled=False
) )
self.network_api = network_api.get_network_api(cfg.CONF.network_api) self.network_api = network_api.get_network_api(cfg.CONF.network_api)
@property @property
@ -713,7 +582,7 @@ class Service(service.RPCService):
# TLD Methods # TLD Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tld.create') @notification.notify_type('dns.tld.create')
@transaction @transaction
def create_tld(self, context, tld): def create_tld(self, context, tld):
policy.check('create_tld', context) policy.check('create_tld', context)
@ -738,7 +607,7 @@ class Service(service.RPCService):
return self.storage.get_tld(context, tld_id) return self.storage.get_tld(context, tld_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tld.update') @notification.notify_type('dns.tld.update')
@transaction @transaction
def update_tld(self, context, tld): def update_tld(self, context, tld):
target = { target = {
@ -751,7 +620,7 @@ class Service(service.RPCService):
return tld return tld
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tld.delete') @notification.notify_type('dns.tld.delete')
@transaction @transaction
def delete_tld(self, context, tld_id): def delete_tld(self, context, tld_id):
policy.check('delete_tld', context, {'tld_id': tld_id}) policy.check('delete_tld', context, {'tld_id': tld_id})
@ -762,7 +631,7 @@ class Service(service.RPCService):
# TSIG Key Methods # TSIG Key Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tsigkey.create') @notification.notify_type('dns.tsigkey.create')
@transaction @transaction
def create_tsigkey(self, context, tsigkey): def create_tsigkey(self, context, tsigkey):
policy.check('create_tsigkey', context) policy.check('create_tsigkey', context)
@ -788,7 +657,7 @@ class Service(service.RPCService):
return self.storage.get_tsigkey(context, tsigkey_id) return self.storage.get_tsigkey(context, tsigkey_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tsigkey.update') @notification.notify_type('dns.tsigkey.update')
@transaction @transaction
def update_tsigkey(self, context, tsigkey): def update_tsigkey(self, context, tsigkey):
target = { target = {
@ -803,7 +672,7 @@ class Service(service.RPCService):
return tsigkey return tsigkey
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.tsigkey.delete') @notification.notify_type('dns.tsigkey.delete')
@transaction @transaction
def delete_tsigkey(self, context, tsigkey_id): def delete_tsigkey(self, context, tsigkey_id):
policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id}) policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id})
@ -862,9 +731,9 @@ class Service(service.RPCService):
return pool.ns_records return pool.ns_records
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.domain.create') @notification.notify_type('dns.domain.create')
@notification('dns.zone.create') @notification.notify_type('dns.zone.create')
@synchronized_zone(new_zone=True) @lock.synchronized_zone(new_zone=True)
def create_zone(self, context, zone): def create_zone(self, context, zone):
"""Create zone: perform checks and then call _create_zone() """Create zone: perform checks and then call _create_zone()
""" """
@ -1060,9 +929,9 @@ class Service(service.RPCService):
sort_key, sort_dir) sort_key, sort_dir)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.domain.update') @notification.notify_type('dns.domain.update')
@notification('dns.zone.update') @notification.notify_type('dns.zone.update')
@synchronized_zone() @lock.synchronized_zone()
def update_zone(self, context, zone, increment_serial=True): def update_zone(self, context, zone, increment_serial=True):
"""Update zone. Perform checks and then call _update_zone() """Update zone. Perform checks and then call _update_zone()
@ -1134,9 +1003,9 @@ class Service(service.RPCService):
return zone return zone
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.domain.delete') @notification.notify_type('dns.domain.delete')
@notification('dns.zone.delete') @notification.notify_type('dns.zone.delete')
@synchronized_zone() @lock.synchronized_zone()
def delete_zone(self, context, zone_id): def delete_zone(self, context, zone_id):
"""Delete or abandon a zone """Delete or abandon a zone
On abandon, delete the zone from the DB immediately. On abandon, delete the zone from the DB immediately.
@ -1294,8 +1163,8 @@ class Service(service.RPCService):
# RecordSet Methods # RecordSet Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.recordset.create') @notification.notify_type('dns.recordset.create')
@synchronized_zone() @lock.synchronized_zone()
def create_recordset(self, context, zone_id, recordset, def create_recordset(self, context, zone_id, recordset,
increment_serial=True): increment_serial=True):
zone = self.storage.get_zone(context, zone_id) zone = self.storage.get_zone(context, zone_id)
@ -1467,8 +1336,8 @@ class Service(service.RPCService):
recordsets=recordsets) recordsets=recordsets)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.recordset.update') @notification.notify_type('dns.recordset.update')
@synchronized_zone() @lock.synchronized_zone()
def update_recordset(self, context, recordset, increment_serial=True): def update_recordset(self, context, recordset, increment_serial=True):
zone_id = recordset.obj_get_original_value('zone_id') zone_id = recordset.obj_get_original_value('zone_id')
zone = self.storage.get_zone(context, zone_id) zone = self.storage.get_zone(context, zone_id)
@ -1550,8 +1419,8 @@ class Service(service.RPCService):
return recordset, zone return recordset, zone
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.recordset.delete') @notification.notify_type('dns.recordset.delete')
@synchronized_zone() @lock.synchronized_zone()
def delete_recordset(self, context, zone_id, recordset_id, def delete_recordset(self, context, zone_id, recordset_id,
increment_serial=True): increment_serial=True):
zone = self.storage.get_zone(context, zone_id) zone = self.storage.get_zone(context, zone_id)
@ -2049,7 +1918,7 @@ class Service(service.RPCService):
# Blacklisted zones # Blacklisted zones
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.blacklist.create') @notification.notify_type('dns.blacklist.create')
@transaction @transaction
def create_blacklist(self, context, blacklist): def create_blacklist(self, context, blacklist):
policy.check('create_blacklist', context) policy.check('create_blacklist', context)
@ -2078,7 +1947,7 @@ class Service(service.RPCService):
return blacklists return blacklists
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.blacklist.update') @notification.notify_type('dns.blacklist.update')
@transaction @transaction
def update_blacklist(self, context, blacklist): def update_blacklist(self, context, blacklist):
target = { target = {
@ -2091,7 +1960,7 @@ class Service(service.RPCService):
return blacklist return blacklist
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.blacklist.delete') @notification.notify_type('dns.blacklist.delete')
@transaction @transaction
def delete_blacklist(self, context, blacklist_id): def delete_blacklist(self, context, blacklist_id):
policy.check('delete_blacklist', context) policy.check('delete_blacklist', context)
@ -2102,7 +1971,7 @@ class Service(service.RPCService):
# Server Pools # Server Pools
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.pool.create') @notification.notify_type('dns.pool.create')
@transaction @transaction
def create_pool(self, context, pool): def create_pool(self, context, pool):
# Verify that there is a tenant_id # Verify that there is a tenant_id
@ -2141,7 +2010,7 @@ class Service(service.RPCService):
return self.storage.get_pool(context, pool_id) return self.storage.get_pool(context, pool_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.pool.update') @notification.notify_type('dns.pool.update')
@transaction @transaction
def update_pool(self, context, pool): def update_pool(self, context, pool):
policy.check('update_pool', context) policy.check('update_pool', context)
@ -2202,7 +2071,7 @@ class Service(service.RPCService):
return updated_pool return updated_pool
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.pool.delete') @notification.notify_type('dns.pool.delete')
@transaction @transaction
def delete_pool(self, context, pool_id): def delete_pool(self, context, pool_id):
@ -2225,10 +2094,10 @@ class Service(service.RPCService):
# Pool Manager Integration # Pool Manager Integration
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.domain.update') @notification.notify_type('dns.domain.update')
@notification('dns.zone.update') @notification.notify_type('dns.zone.update')
@transaction @transaction
@synchronized_zone() @lock.synchronized_zone()
def update_status(self, context, zone_id, status, serial, action=None): def update_status(self, context, zone_id, status, serial, action=None):
""" """
:param context: Security context information. :param context: Security context information.
@ -2356,7 +2225,7 @@ class Service(service.RPCService):
return ''.join(sysrand.choice(chars) for _ in range(size)) return ''.join(sysrand.choice(chars) for _ in range(size))
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_transfer_request.create') @notification.notify_type('dns.zone_transfer_request.create')
@transaction @transaction
def create_zone_transfer_request(self, context, zone_transfer_request): def create_zone_transfer_request(self, context, zone_transfer_request):
@ -2427,7 +2296,7 @@ class Service(service.RPCService):
return requests return requests
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_transfer_request.update') @notification.notify_type('dns.zone_transfer_request.update')
@transaction @transaction
def update_zone_transfer_request(self, context, zone_transfer_request): def update_zone_transfer_request(self, context, zone_transfer_request):
@ -2449,7 +2318,7 @@ class Service(service.RPCService):
return request return request
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_transfer_request.delete') @notification.notify_type('dns.zone_transfer_request.delete')
@transaction @transaction
def delete_zone_transfer_request(self, context, zone_transfer_request_id): def delete_zone_transfer_request(self, context, zone_transfer_request_id):
# Get zone transfer request # Get zone transfer request
@ -2469,7 +2338,7 @@ class Service(service.RPCService):
zone_transfer_request_id) zone_transfer_request_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_transfer_accept.create') @notification.notify_type('dns.zone_transfer_accept.create')
@transaction @transaction
def create_zone_transfer_accept(self, context, zone_transfer_accept): def create_zone_transfer_accept(self, context, zone_transfer_accept):
elevated_context = context.elevated(all_tenants=True) elevated_context = context.elevated(all_tenants=True)
@ -2571,7 +2440,7 @@ class Service(service.RPCService):
# Zone Import Methods # Zone Import Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_import.create') @notification.notify_type('dns.zone_import.create')
def create_zone_import(self, context, request_body): def create_zone_import(self, context, request_body):
if policy.enforce_new_defaults(): if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: context.project_id} target = {constants.RBAC_PROJECT_ID: context.project_id}
@ -2667,7 +2536,7 @@ class Service(service.RPCService):
self.update_zone_import(context, zone_import) self.update_zone_import(context, zone_import)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_import.update') @notification.notify_type('dns.zone_import.update')
def update_zone_import(self, context, zone_import): def update_zone_import(self, context, zone_import):
if policy.enforce_new_defaults(): if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: zone_import.tenant_id} target = {constants.RBAC_PROJECT_ID: zone_import.tenant_id}
@ -2710,7 +2579,7 @@ class Service(service.RPCService):
return self.storage.get_zone_import(context, zone_import_id) return self.storage.get_zone_import(context, zone_import_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_import.delete') @notification.notify_type('dns.zone_import.delete')
@transaction @transaction
def delete_zone_import(self, context, zone_import_id): def delete_zone_import(self, context, zone_import_id):
@ -2733,7 +2602,7 @@ class Service(service.RPCService):
# Zone Export Methods # Zone Export Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_export.create') @notification.notify_type('dns.zone_export.create')
def create_zone_export(self, context, zone_id): def create_zone_export(self, context, zone_id):
# Try getting the zone to ensure it exists # Try getting the zone to ensure it exists
zone = self.storage.get_zone(context, zone_id) zone = self.storage.get_zone(context, zone_id)
@ -2797,7 +2666,7 @@ class Service(service.RPCService):
return self.storage.get_zone_export(context, zone_export_id) return self.storage.get_zone_export(context, zone_export_id)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_export.update') @notification.notify_type('dns.zone_export.update')
def update_zone_export(self, context, zone_export): def update_zone_export(self, context, zone_export):
if policy.enforce_new_defaults(): if policy.enforce_new_defaults():
@ -2810,7 +2679,7 @@ class Service(service.RPCService):
return self.storage.update_zone_export(context, zone_export) return self.storage.update_zone_export(context, zone_export)
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification('dns.zone_export.delete') @notification.notify_type('dns.zone_export.delete')
@transaction @transaction
def delete_zone_export(self, context, zone_export_id): def delete_zone_export(self, context, zone_export_id):

View File

View File

@ -0,0 +1,107 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import itertools
import threading
from oslo_log import log as logging
from designate import objects
LOG = logging.getLogger(__name__)
class ZoneLockLocal(threading.local):
def __init__(self):
super(ZoneLockLocal, self).__init__()
self._held = set()
def hold(self, name):
self._held.add(name)
def release(self, name):
self._held.remove(name)
def has_lock(self, name):
return name in self._held
def extract_zone_id(args, kwargs):
zone_id = None
if 'zone_id' in kwargs:
zone_id = kwargs['zone_id']
elif 'zone' in kwargs:
zone_id = kwargs['zone'].id
elif 'recordset' in kwargs:
zone_id = kwargs['recordset'].zone_id
elif 'record' in kwargs:
zone_id = kwargs['record'].zone_id
if not zone_id:
for arg in itertools.chain(args, kwargs.values()):
if not isinstance(arg, objects.DesignateObject):
continue
if isinstance(arg, objects.Zone):
zone_id = arg.id
if zone_id:
break
elif isinstance(arg, (objects.RecordSet,
objects.Record,
objects.ZoneTransferRequest,
objects.ZoneTransferAccept)):
zone_id = arg.zone_id
if zone_id:
break
if not zone_id and len(args) > 1:
arg = args[1]
if isinstance(arg, str):
zone_id = arg
elif isinstance(zone_id, objects.Zone):
zone_id = arg.id
return zone_id
def synchronized_zone(new_zone=False):
"""Ensures only a single operation is in progress for each zone
A Decorator which ensures only a single operation can be happening
on a single zone at once, within the current designate-central instance
"""
def outer(f):
@functools.wraps(f)
def sync_wrapper(cls, *args, **kwargs):
if new_zone is True:
lock_name = 'create-new-zone'
else:
zone_id = extract_zone_id(args, kwargs)
if zone_id:
lock_name = 'zone-%s' % zone_id
else:
raise Exception('Failed to determine zone id for '
'synchronized operation')
if cls.zone_lock_local.has_lock(lock_name):
return f(cls, *args, **kwargs)
with cls.coordination.get_lock(lock_name):
try:
cls.zone_lock_local.hold(lock_name)
return f(cls, *args, **kwargs)
finally:
cls.zone_lock_local.release(lock_name)
return sync_wrapper
return outer

View File

@ -0,0 +1,90 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import functools
import itertools
import threading
from oslo_log import log as logging
from designate import context as designate_context
from designate import notifications
LOG = logging.getLogger(__name__)
class NotificationThreadLocal(threading.local):
def __init__(self):
super(NotificationThreadLocal, self).__init__()
self.stack = 0
self.queue = collections.deque()
def reset_queue(self):
self.queue.clear()
def notify_type(notification_type):
def outer(f):
@functools.wraps(f)
def notification_wrapper(cls, *args, **kwargs):
cls.notification_thread_local.stack += 1
context = None
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, designate_context.DesignateContext):
context = arg
break
try:
result = f(cls, *args, **kwargs)
payloads = notifications.get_plugin().emit(
notification_type, context, result, args, kwargs
)
for payload in payloads:
LOG.debug(
'Queueing notification for %(type)s',
{
'type': notification_type
}
)
cls.notification_thread_local.queue.appendleft(
(context, notification_type, payload,)
)
return result
finally:
cls.notification_thread_local.stack -= 1
if cls.notification_thread_local.stack == 0:
LOG.debug(
'Emitting %(count)d notifications',
{
'count': len(cls.notification_thread_local.queue)
}
)
for message in cls.notification_thread_local.queue:
LOG.debug(
'Emitting %(type)s notification',
{
'type': message[1]
}
)
cls.notifier.info(message[0], message[1], message[2])
cls.notification_thread_local.reset_queue()
return notification_wrapper
return outer

View File

@ -0,0 +1,49 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import threading
from oslo_messaging.rpc import dispatcher as rpc_dispatcher
import designate.exceptions
class ExceptionThreadLocal(threading.local):
def __init__(self):
super(ExceptionThreadLocal, self).__init__()
self.depth = 0
def reset_depth(self):
self.depth = 0
def expected_exceptions():
def outer(f):
@functools.wraps(f)
def exception_wrapper(cls, *args, **kwargs):
cls.exception_thread_local.depth += 1
# We only want to wrap the first function wrapped.
if cls.exception_thread_local.depth > 1:
return f(cls, *args, **kwargs)
try:
return f(cls, *args, **kwargs)
except designate.exceptions.DesignateException as e:
if e.expected:
raise rpc_dispatcher.ExpectedException()
raise
finally:
cls.exception_thread_local.reset_depth()
return exception_wrapper
return outer

View File

@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import copy import copy
import itertools
from keystoneauth1.access import service_catalog as ksa_service_catalog from keystoneauth1.access import service_catalog as ksa_service_catalog
from keystoneauth1 import plugin from keystoneauth1 import plugin
@ -145,21 +144,6 @@ class DesignateContext(context.RequestContext):
return cls(None, **kwargs) return cls(None, **kwargs)
@classmethod
def get_context_from_function_and_args(cls, function, args, kwargs):
"""
Find an arg of type DesignateContext and return it.
This is useful in a couple of decorators where we don't
know much about the function we're wrapping.
"""
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, cls):
return arg
return None
@property @property
def all_tenants(self): def all_tenants(self):
return self._all_tenants return self._all_tenants

View File

@ -11,8 +11,6 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import functools
import threading
from oslo_config import cfg from oslo_config import cfg
import oslo_messaging as messaging import oslo_messaging as messaging
@ -40,7 +38,6 @@ __all__ = [
] ]
CONF = cfg.CONF CONF = cfg.CONF
EXPECTED_EXCEPTION = threading.local()
NOTIFICATION_TRANSPORT = None NOTIFICATION_TRANSPORT = None
NOTIFIER = None NOTIFIER = None
TRANSPORT = None TRANSPORT = None
@ -237,27 +234,3 @@ def create_transport(url):
return messaging.get_rpc_transport(CONF, return messaging.get_rpc_transport(CONF,
url=url, url=url,
allowed_remote_exmods=exmods) allowed_remote_exmods=exmods)
def expected_exceptions():
def outer(f):
@functools.wraps(f)
def exception_wrapper(self, *args, **kwargs):
if not hasattr(EXPECTED_EXCEPTION, 'depth'):
EXPECTED_EXCEPTION.depth = 0
EXPECTED_EXCEPTION.depth += 1
# We only want to wrap the first function wrapped.
if EXPECTED_EXCEPTION.depth > 1:
return f(self, *args, **kwargs)
try:
return f(self, *args, **kwargs)
except designate.exceptions.DesignateException as e:
if e.expected:
raise rpc_dispatcher.ExpectedException()
raise
finally:
EXPECTED_EXCEPTION.depth = 0
return exception_wrapper
return outer

View File

@ -30,6 +30,7 @@ from oslo_service import sslutils
from oslo_service import wsgi from oslo_service import wsgi
from oslo_utils import netutils from oslo_utils import netutils
from designate.common.decorators import rpc as rpc_decorator
from designate.common import profiler from designate.common import profiler
import designate.conf import designate.conf
from designate.i18n import _ from designate.i18n import _
@ -77,6 +78,7 @@ class RPCService(Service):
rpc_topic, self.name) rpc_topic, self.name)
self.endpoints = [self] self.endpoints = [self]
self.exception_thread_local = rpc_decorator.ExceptionThreadLocal()
self.notifier = None self.notifier = None
self.rpc_server = None self.rpc_server = None
self.rpc_topic = rpc_topic self.rpc_topic = rpc_topic

View File

@ -1,75 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
from oslo_concurrency import lockutils
from oslo_log import log as logging
from designate.central import service
from designate import exceptions
from designate.objects import record
from designate.objects import zone
from designate.tests.test_central import CentralTestCase
from designate import utils
LOG = logging.getLogger(__name__)
class FakeCoordination(object):
def get_lock(self, name):
return lockutils.lock(name)
class CentralDecoratorTests(CentralTestCase):
def test_synchronized_zone_exception_raised(self):
@service.synchronized_zone()
def mock_get_zone(cls, index, zone):
self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
if index % 3 == 0:
raise exceptions.ZoneNotFound()
for index in range(9):
try:
mock_get_zone(mock.Mock(coordination=FakeCoordination()),
index,
zone.Zone(id=utils.generate_uuid()))
except exceptions.ZoneNotFound:
pass
def test_synchronized_zone_recursive_decorator_call(self):
@service.synchronized_zone()
def mock_create_record(cls, context, record):
self.assertEqual(service.ZONE_LOCKS.held, {record.zone_id})
mock_get_zone(cls, context, zone.Zone(id=record.zone_id))
@service.synchronized_zone()
def mock_get_zone(cls, context, zone):
self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
mock_create_record(mock.Mock(coordination=FakeCoordination()),
self.get_context(),
record=record.Record(zone_id=utils.generate_uuid()))
mock_get_zone(mock.Mock(coordination=FakeCoordination()),
self.get_context(),
zone=zone.Zone(id=utils.generate_uuid()))
def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
@service.synchronized_zone(new_zone=False)
def mock_not_creating_new_zone(cls, context, record):
pass
self.assertRaisesRegex(
Exception,
'Failed to determine zone id for '
'synchronized operation',
mock_not_creating_new_zone, self.get_context(), None
)

View File

@ -392,7 +392,7 @@ class CentralServiceTestCase(CentralBasic):
def test_create_recordset_in_storage(self): def test_create_recordset_in_storage(self):
self.service._enforce_recordset_quota = mock.Mock() self.service._enforce_recordset_quota = mock.Mock()
self.service._validate_recordset = mock.Mock() self.service._validate_recordset = mock.Mock(spec=objects.RecordSet)
self.service.storage.create_recordset = mock.Mock(return_value='rs') self.service.storage.create_recordset = mock.Mock(return_value='rs')
self.service._update_zone_in_storage = mock.Mock() self.service._update_zone_in_storage = mock.Mock()
@ -416,7 +416,7 @@ class CentralServiceTestCase(CentralBasic):
central_service.storage.create_recordset = mock.Mock(return_value='rs') central_service.storage.create_recordset = mock.Mock(return_value='rs')
central_service._update_zone_in_storage = mock.Mock() central_service._update_zone_in_storage = mock.Mock()
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_attr_is_set.return_value = True recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()] recordset.records = [MockRecord()]
@ -441,7 +441,7 @@ class CentralServiceTestCase(CentralBasic):
# NOTE(thirose): Since this is a race condition we assume that # NOTE(thirose): Since this is a race condition we assume that
# we will hit it if we try to do the operations in a loop 100 times. # we will hit it if we try to do the operations in a loop 100 times.
for num in range(100): for num in range(100):
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.name = "b{}".format(num) recordset.name = "b{}".format(num)
recordset.obj_attr_is_set.return_value = True recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()] recordset.records = [MockRecord()]
@ -1148,7 +1148,7 @@ class CentralZoneTestCase(CentralBasic):
def test_update_recordset_fail_on_changes(self): def test_update_recordset_fail_on_changes(self):
self.service.storage.get_zone.return_value = RoObject() self.service.storage.get_zone.return_value = RoObject()
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_original_value.return_value = '1' recordset.obj_get_original_value.return_value = '1'
recordset.obj_get_changes.return_value = ['tenant_id', 'foo'] recordset.obj_get_changes.return_value = ['tenant_id', 'foo']
@ -1179,7 +1179,7 @@ class CentralZoneTestCase(CentralBasic):
self.service.storage.get_zone.return_value = RoObject( self.service.storage.get_zone.return_value = RoObject(
action='DELETE', action='DELETE',
) )
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo'] recordset.obj_get_changes.return_value = ['foo']
exc = self.assertRaises(rpc_dispatcher.ExpectedException, exc = self.assertRaises(rpc_dispatcher.ExpectedException,
@ -1196,7 +1196,7 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2', tenant_id='2',
action='bogus', action='bogus',
) )
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo'] recordset.obj_get_changes.return_value = ['foo']
recordset.managed = True recordset.managed = True
self.context = mock.Mock() self.context = mock.Mock()
@ -1216,10 +1216,11 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2', tenant_id='2',
action='bogus', action='bogus',
) )
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo'] recordset.obj_get_changes.return_value = ['foo']
recordset.obj_get_original_value.return_value =\ recordset.obj_get_original_value.return_value = (
'9c85d9b0-1e9d-4e99-aede-a06664f1af2e' '9c85d9b0-1e9d-4e99-aede-a06664f1af2e'
)
recordset.managed = False recordset.managed = False
self.service._update_recordset_in_storage = mock.Mock( self.service._update_recordset_in_storage = mock.Mock(
return_value=('x', 'y') return_value=('x', 'y')
@ -1239,7 +1240,7 @@ class CentralZoneTestCase(CentralBasic):
'recordset_id': '9c85d9b0-1e9d-4e99-aede-a06664f1af2e', 'recordset_id': '9c85d9b0-1e9d-4e99-aede-a06664f1af2e',
'project_id': '2'}, target) 'project_id': '2'}, target)
def test__update_recordset_in_storage(self): def test_update_recordset_in_storage(self):
recordset = mock.Mock() recordset = mock.Mock()
recordset.name = 'n' recordset.name = 'n'
recordset.type = 't' recordset.type = 't'
@ -1426,7 +1427,7 @@ class CentralZoneTestCase(CentralBasic):
self.assertTrue( self.assertTrue(
self.service._delete_recordset_in_storage.called) self.service._delete_recordset_in_storage.called)
def test__delete_recordset_in_storage(self): def test_delete_recordset_in_storage(self):
def mock_uds(c, zone, inc): def mock_uds(c, zone, inc):
return zone return zone
self.service._update_zone_in_storage = mock_uds self.service._update_zone_in_storage = mock_uds
@ -1730,7 +1731,7 @@ class CentralQuotaTest(unittest.TestCase):
service = Service() service = Service()
service.storage.count_records.return_value = 10 service.storage.count_records.return_value = 10
recordset = mock.Mock() recordset = mock.Mock(spec=objects.RecordSet)
recordset.managed = False recordset.managed = False
recordset.records = ['1.1.1.%i' % (i + 1) for i in range(5)] recordset.records = ['1.1.1.%i' % (i + 1) for i in range(5)]
@ -1801,7 +1802,7 @@ class CentralQuotaTest(unittest.TestCase):
1, 1, 1, 1,
] ]
managed_recordset = mock.Mock() managed_recordset = mock.Mock(spec=objects.RecordSet)
managed_recordset.managed = True managed_recordset.managed = True
recordset_one_record = mock.Mock() recordset_one_record = mock.Mock()

View File

@ -0,0 +1,111 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
from oslo_concurrency import lockutils
from oslo_log import log as logging
import oslotest.base
from designate.common.decorators import lock
from designate import exceptions
from designate.objects import record
from designate.objects import zone
from designate import utils
LOG = logging.getLogger(__name__)
class FakeCoordination:
def get_lock(self, name):
return lockutils.lock(name)
class FakeService:
def __init__(self):
self.zone_lock_local = lock.ZoneLockLocal()
self.coordination = FakeCoordination()
class CentralDecoratorTests(oslotest.base.BaseTestCase):
def setUp(self):
super().setUp()
self.context = mock.Mock()
self.service = FakeService()
def test_synchronized_zone_exception_raised(self):
@lock.synchronized_zone()
def mock_get_zone(cls, current_index, zone_obj):
self.assertEqual(
{'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
)
if current_index % 3 == 0:
raise exceptions.ZoneNotFound()
for index in range(9):
try:
mock_get_zone(
self.service, index, zone.Zone(id=utils.generate_uuid())
)
except exceptions.ZoneNotFound:
pass
def test_synchronized_new_zone_with_recursion(self):
@lock.synchronized_zone(new_zone=True)
def mock_create_zone(cls, context):
self.assertEqual({'create-new-zone'}, cls.zone_lock_local._held)
mock_create_record(
cls, context, zone.Zone(id=utils.generate_uuid())
)
@lock.synchronized_zone()
def mock_create_record(cls, context, zone_obj):
self.assertIn('zone-%s' % zone_obj.id, cls.zone_lock_local._held)
self.assertIn('create-new-zone', cls.zone_lock_local._held)
mock_create_zone(
self.service, self.context
)
def test_synchronized_zone_recursive_decorator_call(self):
@lock.synchronized_zone()
def mock_create_record(cls, context, record_obj):
self.assertEqual(
{'zone-%s' % record_obj.zone_id}, cls.zone_lock_local._held
)
mock_get_zone(cls, context, zone.Zone(id=record_obj.zone_id))
@lock.synchronized_zone()
def mock_get_zone(cls, context, zone_obj):
self.assertEqual(
{'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
)
mock_create_record(
self.service, self.context,
record_obj=record.Record(zone_id=utils.generate_uuid())
)
mock_get_zone(
self.service, self.context,
zone_obj=zone.Zone(id=utils.generate_uuid())
)
def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
@lock.synchronized_zone(new_zone=False)
def mock_not_creating_new_zone(cls, context, record_obj):
pass
self.assertRaisesRegex(
Exception,
'Failed to determine zone id for synchronized operation',
mock_not_creating_new_zone, self.service, mock.Mock(), None
)

View File

@ -21,9 +21,9 @@ import oslo_messaging as messaging
from designate import backend from designate import backend
from designate.central import rpcapi as central_api from designate.central import rpcapi as central_api
from designate.common.decorators import rpc
from designate.context import DesignateContext from designate.context import DesignateContext
from designate import exceptions from designate import exceptions
from designate import rpc
from designate import service from designate import service
from designate import storage from designate import storage
from designate.worker import processing from designate.worker import processing