Merge "Re-factored central and rpc decorators"

This commit is contained in:
Zuul 2022-08-04 00:08:54 +00:00 committed by Gerrit Code Review
commit ef187ddec8
12 changed files with 419 additions and 308 deletions

View File

@ -14,16 +14,12 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import copy
import functools
import itertools
import random
from random import SystemRandom
import re
import signal
import string
import threading
import time
from dns import exception as dnsexception
@ -33,16 +29,16 @@ from oslo_log import log as logging
import oslo_messaging as messaging
from designate.common import constants
from designate import context as dcontext
from designate.common.decorators import lock
from designate.common.decorators import notification
from designate.common.decorators import rpc
from designate import coordination
from designate import dnsutils
from designate import exceptions
from designate import network_api
from designate import notifications
from designate import objects
from designate import policy
from designate import quota
from designate import rpc
from designate import scheduler
from designate import service
from designate import storage
@ -51,135 +47,7 @@ from designate.storage import transaction_shallow_copy
from designate import utils
from designate.worker import rpcapi as worker_rpcapi
LOG = logging.getLogger(__name__)
ZONE_LOCKS = threading.local()
NOTIFICATION_BUFFER = threading.local()
def synchronized_zone(zone_arg=1, new_zone=False):
"""Ensures only a single operation is in progress for each zone
A Decorator which ensures only a single operation can be happening
on a single zone at once, within the current designate-central instance
"""
def outer(f):
@functools.wraps(f)
def sync_wrapper(self, *args, **kwargs):
if not hasattr(ZONE_LOCKS, 'held'):
# Create the held set if necessary
ZONE_LOCKS.held = set()
zone_id = None
if 'zone_id' in kwargs:
zone_id = kwargs['zone_id']
elif 'zone' in kwargs:
zone_id = kwargs['zone'].id
elif 'recordset' in kwargs:
zone_id = kwargs['recordset'].zone_id
elif 'record' in kwargs:
zone_id = kwargs['record'].zone_id
# The various objects won't always have an ID set, we should
# attempt to locate an Object containing the ID.
if zone_id is None:
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, objects.Zone):
zone_id = arg.id
if zone_id:
break
elif (isinstance(arg, objects.RecordSet) or
isinstance(arg, objects.Record) or
isinstance(arg, objects.ZoneTransferRequest) or
isinstance(arg, objects.ZoneTransferAccept)):
zone_id = arg.zone_id
if zone_id:
break
# If we still don't have an ID, find the Nth argument as
# defined by the zone_arg decorator option.
if not zone_id and len(args) > zone_arg:
zone_id = args[zone_arg]
if isinstance(zone_id, objects.Zone):
# If the value is a Zone object, extract it's ID.
zone_id = zone_id.id
if new_zone and not zone_id:
lock_name = 'create-new-zone'
elif not new_zone and zone_id:
lock_name = 'zone-%s' % zone_id
else:
raise Exception('Failed to determine zone id for '
'synchronized operation')
if zone_id in ZONE_LOCKS.held:
return f(self, *args, **kwargs)
with self.coordination.get_lock(lock_name):
try:
ZONE_LOCKS.held.add(zone_id)
return f(self, *args, **kwargs)
finally:
ZONE_LOCKS.held.remove(zone_id)
sync_wrapper.__wrapped_function = f
sync_wrapper.__wrapper_name = 'synchronized_zone'
return sync_wrapper
return outer
def notification(notification_type):
def outer(f):
@functools.wraps(f)
def notification_wrapper(self, *args, **kwargs):
if not hasattr(NOTIFICATION_BUFFER, 'queue'):
# Create the notifications queue if necessary
NOTIFICATION_BUFFER.stack = 0
NOTIFICATION_BUFFER.queue = collections.deque()
NOTIFICATION_BUFFER.stack += 1
try:
# Find the context argument
context = dcontext.DesignateContext.\
get_context_from_function_and_args(f, args, kwargs)
# Call the wrapped function
result = f(self, *args, **kwargs)
# Feed the args/result to a notification plugin
# to determine what is emitted
payloads = notifications.get_plugin().emit(
notification_type, context, result, args, kwargs)
# Enqueue the notification
for payload in payloads:
LOG.debug('Queueing notification for %(type)s ',
{'type': notification_type})
NOTIFICATION_BUFFER.queue.appendleft(
(context, notification_type, payload,))
return result
finally:
NOTIFICATION_BUFFER.stack -= 1
if NOTIFICATION_BUFFER.stack == 0:
LOG.debug('Emitting %(count)d notifications',
{'count': len(NOTIFICATION_BUFFER.queue)})
# Send the queued notifications, in order.
for value in NOTIFICATION_BUFFER.queue:
LOG.debug('Emitting %(type)s notification',
{'type': value[1]})
self.notifier.info(value[0], value[1], value[2])
# Reset the queue
NOTIFICATION_BUFFER.queue.clear()
return notification_wrapper
return outer
class Service(service.RPCService):
@ -188,6 +56,9 @@ class Service(service.RPCService):
target = messaging.Target(version=RPC_API_VERSION)
def __init__(self):
self.zone_lock_local = lock.ZoneLockLocal()
self.notification_thread_local = notification.NotificationThreadLocal()
self._scheduler = None
self._storage = None
self._quota = None
@ -196,11 +67,9 @@ class Service(service.RPCService):
self.service_name, cfg.CONF['service:central'].topic,
threads=cfg.CONF['service:central'].threads,
)
self.coordination = coordination.Coordination(
self.service_name, self.tg, grouping_enabled=False
)
self.network_api = network_api.get_network_api(cfg.CONF.network_api)
@property
@ -713,7 +582,7 @@ class Service(service.RPCService):
# TLD Methods
@rpc.expected_exceptions()
@notification('dns.tld.create')
@notification.notify_type('dns.tld.create')
@transaction
def create_tld(self, context, tld):
policy.check('create_tld', context)
@ -738,7 +607,7 @@ class Service(service.RPCService):
return self.storage.get_tld(context, tld_id)
@rpc.expected_exceptions()
@notification('dns.tld.update')
@notification.notify_type('dns.tld.update')
@transaction
def update_tld(self, context, tld):
target = {
@ -751,7 +620,7 @@ class Service(service.RPCService):
return tld
@rpc.expected_exceptions()
@notification('dns.tld.delete')
@notification.notify_type('dns.tld.delete')
@transaction
def delete_tld(self, context, tld_id):
policy.check('delete_tld', context, {'tld_id': tld_id})
@ -762,7 +631,7 @@ class Service(service.RPCService):
# TSIG Key Methods
@rpc.expected_exceptions()
@notification('dns.tsigkey.create')
@notification.notify_type('dns.tsigkey.create')
@transaction
def create_tsigkey(self, context, tsigkey):
policy.check('create_tsigkey', context)
@ -788,7 +657,7 @@ class Service(service.RPCService):
return self.storage.get_tsigkey(context, tsigkey_id)
@rpc.expected_exceptions()
@notification('dns.tsigkey.update')
@notification.notify_type('dns.tsigkey.update')
@transaction
def update_tsigkey(self, context, tsigkey):
target = {
@ -803,7 +672,7 @@ class Service(service.RPCService):
return tsigkey
@rpc.expected_exceptions()
@notification('dns.tsigkey.delete')
@notification.notify_type('dns.tsigkey.delete')
@transaction
def delete_tsigkey(self, context, tsigkey_id):
policy.check('delete_tsigkey', context, {'tsigkey_id': tsigkey_id})
@ -862,9 +731,9 @@ class Service(service.RPCService):
return pool.ns_records
@rpc.expected_exceptions()
@notification('dns.domain.create')
@notification('dns.zone.create')
@synchronized_zone(new_zone=True)
@notification.notify_type('dns.domain.create')
@notification.notify_type('dns.zone.create')
@lock.synchronized_zone(new_zone=True)
def create_zone(self, context, zone):
"""Create zone: perform checks and then call _create_zone()
"""
@ -1060,9 +929,9 @@ class Service(service.RPCService):
sort_key, sort_dir)
@rpc.expected_exceptions()
@notification('dns.domain.update')
@notification('dns.zone.update')
@synchronized_zone()
@notification.notify_type('dns.domain.update')
@notification.notify_type('dns.zone.update')
@lock.synchronized_zone()
def update_zone(self, context, zone, increment_serial=True):
"""Update zone. Perform checks and then call _update_zone()
@ -1134,9 +1003,9 @@ class Service(service.RPCService):
return zone
@rpc.expected_exceptions()
@notification('dns.domain.delete')
@notification('dns.zone.delete')
@synchronized_zone()
@notification.notify_type('dns.domain.delete')
@notification.notify_type('dns.zone.delete')
@lock.synchronized_zone()
def delete_zone(self, context, zone_id):
"""Delete or abandon a zone
On abandon, delete the zone from the DB immediately.
@ -1294,8 +1163,8 @@ class Service(service.RPCService):
# RecordSet Methods
@rpc.expected_exceptions()
@notification('dns.recordset.create')
@synchronized_zone()
@notification.notify_type('dns.recordset.create')
@lock.synchronized_zone()
def create_recordset(self, context, zone_id, recordset,
increment_serial=True):
zone = self.storage.get_zone(context, zone_id)
@ -1467,8 +1336,8 @@ class Service(service.RPCService):
recordsets=recordsets)
@rpc.expected_exceptions()
@notification('dns.recordset.update')
@synchronized_zone()
@notification.notify_type('dns.recordset.update')
@lock.synchronized_zone()
def update_recordset(self, context, recordset, increment_serial=True):
zone_id = recordset.obj_get_original_value('zone_id')
zone = self.storage.get_zone(context, zone_id)
@ -1550,8 +1419,8 @@ class Service(service.RPCService):
return recordset, zone
@rpc.expected_exceptions()
@notification('dns.recordset.delete')
@synchronized_zone()
@notification.notify_type('dns.recordset.delete')
@lock.synchronized_zone()
def delete_recordset(self, context, zone_id, recordset_id,
increment_serial=True):
zone = self.storage.get_zone(context, zone_id)
@ -2049,7 +1918,7 @@ class Service(service.RPCService):
# Blacklisted zones
@rpc.expected_exceptions()
@notification('dns.blacklist.create')
@notification.notify_type('dns.blacklist.create')
@transaction
def create_blacklist(self, context, blacklist):
policy.check('create_blacklist', context)
@ -2078,7 +1947,7 @@ class Service(service.RPCService):
return blacklists
@rpc.expected_exceptions()
@notification('dns.blacklist.update')
@notification.notify_type('dns.blacklist.update')
@transaction
def update_blacklist(self, context, blacklist):
target = {
@ -2091,7 +1960,7 @@ class Service(service.RPCService):
return blacklist
@rpc.expected_exceptions()
@notification('dns.blacklist.delete')
@notification.notify_type('dns.blacklist.delete')
@transaction
def delete_blacklist(self, context, blacklist_id):
policy.check('delete_blacklist', context)
@ -2102,7 +1971,7 @@ class Service(service.RPCService):
# Server Pools
@rpc.expected_exceptions()
@notification('dns.pool.create')
@notification.notify_type('dns.pool.create')
@transaction
def create_pool(self, context, pool):
# Verify that there is a tenant_id
@ -2141,7 +2010,7 @@ class Service(service.RPCService):
return self.storage.get_pool(context, pool_id)
@rpc.expected_exceptions()
@notification('dns.pool.update')
@notification.notify_type('dns.pool.update')
@transaction
def update_pool(self, context, pool):
policy.check('update_pool', context)
@ -2202,7 +2071,7 @@ class Service(service.RPCService):
return updated_pool
@rpc.expected_exceptions()
@notification('dns.pool.delete')
@notification.notify_type('dns.pool.delete')
@transaction
def delete_pool(self, context, pool_id):
@ -2225,10 +2094,10 @@ class Service(service.RPCService):
# Pool Manager Integration
@rpc.expected_exceptions()
@notification('dns.domain.update')
@notification('dns.zone.update')
@notification.notify_type('dns.domain.update')
@notification.notify_type('dns.zone.update')
@transaction
@synchronized_zone()
@lock.synchronized_zone()
def update_status(self, context, zone_id, status, serial, action=None):
"""
:param context: Security context information.
@ -2356,7 +2225,7 @@ class Service(service.RPCService):
return ''.join(sysrand.choice(chars) for _ in range(size))
@rpc.expected_exceptions()
@notification('dns.zone_transfer_request.create')
@notification.notify_type('dns.zone_transfer_request.create')
@transaction
def create_zone_transfer_request(self, context, zone_transfer_request):
@ -2427,7 +2296,7 @@ class Service(service.RPCService):
return requests
@rpc.expected_exceptions()
@notification('dns.zone_transfer_request.update')
@notification.notify_type('dns.zone_transfer_request.update')
@transaction
def update_zone_transfer_request(self, context, zone_transfer_request):
@ -2449,7 +2318,7 @@ class Service(service.RPCService):
return request
@rpc.expected_exceptions()
@notification('dns.zone_transfer_request.delete')
@notification.notify_type('dns.zone_transfer_request.delete')
@transaction
def delete_zone_transfer_request(self, context, zone_transfer_request_id):
# Get zone transfer request
@ -2469,7 +2338,7 @@ class Service(service.RPCService):
zone_transfer_request_id)
@rpc.expected_exceptions()
@notification('dns.zone_transfer_accept.create')
@notification.notify_type('dns.zone_transfer_accept.create')
@transaction
def create_zone_transfer_accept(self, context, zone_transfer_accept):
elevated_context = context.elevated(all_tenants=True)
@ -2571,7 +2440,7 @@ class Service(service.RPCService):
# Zone Import Methods
@rpc.expected_exceptions()
@notification('dns.zone_import.create')
@notification.notify_type('dns.zone_import.create')
def create_zone_import(self, context, request_body):
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: context.project_id}
@ -2667,7 +2536,7 @@ class Service(service.RPCService):
self.update_zone_import(context, zone_import)
@rpc.expected_exceptions()
@notification('dns.zone_import.update')
@notification.notify_type('dns.zone_import.update')
def update_zone_import(self, context, zone_import):
if policy.enforce_new_defaults():
target = {constants.RBAC_PROJECT_ID: zone_import.tenant_id}
@ -2710,7 +2579,7 @@ class Service(service.RPCService):
return self.storage.get_zone_import(context, zone_import_id)
@rpc.expected_exceptions()
@notification('dns.zone_import.delete')
@notification.notify_type('dns.zone_import.delete')
@transaction
def delete_zone_import(self, context, zone_import_id):
@ -2733,7 +2602,7 @@ class Service(service.RPCService):
# Zone Export Methods
@rpc.expected_exceptions()
@notification('dns.zone_export.create')
@notification.notify_type('dns.zone_export.create')
def create_zone_export(self, context, zone_id):
# Try getting the zone to ensure it exists
zone = self.storage.get_zone(context, zone_id)
@ -2797,7 +2666,7 @@ class Service(service.RPCService):
return self.storage.get_zone_export(context, zone_export_id)
@rpc.expected_exceptions()
@notification('dns.zone_export.update')
@notification.notify_type('dns.zone_export.update')
def update_zone_export(self, context, zone_export):
if policy.enforce_new_defaults():
@ -2810,7 +2679,7 @@ class Service(service.RPCService):
return self.storage.update_zone_export(context, zone_export)
@rpc.expected_exceptions()
@notification('dns.zone_export.delete')
@notification.notify_type('dns.zone_export.delete')
@transaction
def delete_zone_export(self, context, zone_export_id):

View File

View File

@ -0,0 +1,107 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import itertools
import threading
from oslo_log import log as logging
from designate import objects
LOG = logging.getLogger(__name__)
class ZoneLockLocal(threading.local):
def __init__(self):
super(ZoneLockLocal, self).__init__()
self._held = set()
def hold(self, name):
self._held.add(name)
def release(self, name):
self._held.remove(name)
def has_lock(self, name):
return name in self._held
def extract_zone_id(args, kwargs):
zone_id = None
if 'zone_id' in kwargs:
zone_id = kwargs['zone_id']
elif 'zone' in kwargs:
zone_id = kwargs['zone'].id
elif 'recordset' in kwargs:
zone_id = kwargs['recordset'].zone_id
elif 'record' in kwargs:
zone_id = kwargs['record'].zone_id
if not zone_id:
for arg in itertools.chain(args, kwargs.values()):
if not isinstance(arg, objects.DesignateObject):
continue
if isinstance(arg, objects.Zone):
zone_id = arg.id
if zone_id:
break
elif isinstance(arg, (objects.RecordSet,
objects.Record,
objects.ZoneTransferRequest,
objects.ZoneTransferAccept)):
zone_id = arg.zone_id
if zone_id:
break
if not zone_id and len(args) > 1:
arg = args[1]
if isinstance(arg, str):
zone_id = arg
elif isinstance(zone_id, objects.Zone):
zone_id = arg.id
return zone_id
def synchronized_zone(new_zone=False):
"""Ensures only a single operation is in progress for each zone
A Decorator which ensures only a single operation can be happening
on a single zone at once, within the current designate-central instance
"""
def outer(f):
@functools.wraps(f)
def sync_wrapper(cls, *args, **kwargs):
if new_zone is True:
lock_name = 'create-new-zone'
else:
zone_id = extract_zone_id(args, kwargs)
if zone_id:
lock_name = 'zone-%s' % zone_id
else:
raise Exception('Failed to determine zone id for '
'synchronized operation')
if cls.zone_lock_local.has_lock(lock_name):
return f(cls, *args, **kwargs)
with cls.coordination.get_lock(lock_name):
try:
cls.zone_lock_local.hold(lock_name)
return f(cls, *args, **kwargs)
finally:
cls.zone_lock_local.release(lock_name)
return sync_wrapper
return outer

View File

@ -0,0 +1,90 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import collections
import functools
import itertools
import threading
from oslo_log import log as logging
from designate import context as designate_context
from designate import notifications
LOG = logging.getLogger(__name__)
class NotificationThreadLocal(threading.local):
def __init__(self):
super(NotificationThreadLocal, self).__init__()
self.stack = 0
self.queue = collections.deque()
def reset_queue(self):
self.queue.clear()
def notify_type(notification_type):
def outer(f):
@functools.wraps(f)
def notification_wrapper(cls, *args, **kwargs):
cls.notification_thread_local.stack += 1
context = None
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, designate_context.DesignateContext):
context = arg
break
try:
result = f(cls, *args, **kwargs)
payloads = notifications.get_plugin().emit(
notification_type, context, result, args, kwargs
)
for payload in payloads:
LOG.debug(
'Queueing notification for %(type)s',
{
'type': notification_type
}
)
cls.notification_thread_local.queue.appendleft(
(context, notification_type, payload,)
)
return result
finally:
cls.notification_thread_local.stack -= 1
if cls.notification_thread_local.stack == 0:
LOG.debug(
'Emitting %(count)d notifications',
{
'count': len(cls.notification_thread_local.queue)
}
)
for message in cls.notification_thread_local.queue:
LOG.debug(
'Emitting %(type)s notification',
{
'type': message[1]
}
)
cls.notifier.info(message[0], message[1], message[2])
cls.notification_thread_local.reset_queue()
return notification_wrapper
return outer

View File

@ -0,0 +1,49 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import threading
from oslo_messaging.rpc import dispatcher as rpc_dispatcher
import designate.exceptions
class ExceptionThreadLocal(threading.local):
def __init__(self):
super(ExceptionThreadLocal, self).__init__()
self.depth = 0
def reset_depth(self):
self.depth = 0
def expected_exceptions():
def outer(f):
@functools.wraps(f)
def exception_wrapper(cls, *args, **kwargs):
cls.exception_thread_local.depth += 1
# We only want to wrap the first function wrapped.
if cls.exception_thread_local.depth > 1:
return f(cls, *args, **kwargs)
try:
return f(cls, *args, **kwargs)
except designate.exceptions.DesignateException as e:
if e.expected:
raise rpc_dispatcher.ExpectedException()
raise
finally:
cls.exception_thread_local.reset_depth()
return exception_wrapper
return outer

View File

@ -14,7 +14,6 @@
# License for the specific language governing permissions and limitations
# under the License.
import copy
import itertools
from keystoneauth1.access import service_catalog as ksa_service_catalog
from keystoneauth1 import plugin
@ -145,21 +144,6 @@ class DesignateContext(context.RequestContext):
return cls(None, **kwargs)
@classmethod
def get_context_from_function_and_args(cls, function, args, kwargs):
"""
Find an arg of type DesignateContext and return it.
This is useful in a couple of decorators where we don't
know much about the function we're wrapping.
"""
for arg in itertools.chain(kwargs.values(), args):
if isinstance(arg, cls):
return arg
return None
@property
def all_tenants(self):
return self._all_tenants

View File

@ -11,8 +11,6 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import functools
import threading
from oslo_config import cfg
import oslo_messaging as messaging
@ -40,7 +38,6 @@ __all__ = [
]
CONF = cfg.CONF
EXPECTED_EXCEPTION = threading.local()
NOTIFICATION_TRANSPORT = None
NOTIFIER = None
TRANSPORT = None
@ -237,27 +234,3 @@ def create_transport(url):
return messaging.get_rpc_transport(CONF,
url=url,
allowed_remote_exmods=exmods)
def expected_exceptions():
def outer(f):
@functools.wraps(f)
def exception_wrapper(self, *args, **kwargs):
if not hasattr(EXPECTED_EXCEPTION, 'depth'):
EXPECTED_EXCEPTION.depth = 0
EXPECTED_EXCEPTION.depth += 1
# We only want to wrap the first function wrapped.
if EXPECTED_EXCEPTION.depth > 1:
return f(self, *args, **kwargs)
try:
return f(self, *args, **kwargs)
except designate.exceptions.DesignateException as e:
if e.expected:
raise rpc_dispatcher.ExpectedException()
raise
finally:
EXPECTED_EXCEPTION.depth = 0
return exception_wrapper
return outer

View File

@ -30,6 +30,7 @@ from oslo_service import sslutils
from oslo_service import wsgi
from oslo_utils import netutils
from designate.common.decorators import rpc as rpc_decorator
from designate.common import profiler
import designate.conf
from designate.i18n import _
@ -77,6 +78,7 @@ class RPCService(Service):
rpc_topic, self.name)
self.endpoints = [self]
self.exception_thread_local = rpc_decorator.ExceptionThreadLocal()
self.notifier = None
self.rpc_server = None
self.rpc_topic = rpc_topic

View File

@ -1,75 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
from oslo_concurrency import lockutils
from oslo_log import log as logging
from designate.central import service
from designate import exceptions
from designate.objects import record
from designate.objects import zone
from designate.tests.test_central import CentralTestCase
from designate import utils
LOG = logging.getLogger(__name__)
class FakeCoordination(object):
def get_lock(self, name):
return lockutils.lock(name)
class CentralDecoratorTests(CentralTestCase):
def test_synchronized_zone_exception_raised(self):
@service.synchronized_zone()
def mock_get_zone(cls, index, zone):
self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
if index % 3 == 0:
raise exceptions.ZoneNotFound()
for index in range(9):
try:
mock_get_zone(mock.Mock(coordination=FakeCoordination()),
index,
zone.Zone(id=utils.generate_uuid()))
except exceptions.ZoneNotFound:
pass
def test_synchronized_zone_recursive_decorator_call(self):
@service.synchronized_zone()
def mock_create_record(cls, context, record):
self.assertEqual(service.ZONE_LOCKS.held, {record.zone_id})
mock_get_zone(cls, context, zone.Zone(id=record.zone_id))
@service.synchronized_zone()
def mock_get_zone(cls, context, zone):
self.assertEqual(service.ZONE_LOCKS.held, {zone.id})
mock_create_record(mock.Mock(coordination=FakeCoordination()),
self.get_context(),
record=record.Record(zone_id=utils.generate_uuid()))
mock_get_zone(mock.Mock(coordination=FakeCoordination()),
self.get_context(),
zone=zone.Zone(id=utils.generate_uuid()))
def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
@service.synchronized_zone(new_zone=False)
def mock_not_creating_new_zone(cls, context, record):
pass
self.assertRaisesRegex(
Exception,
'Failed to determine zone id for '
'synchronized operation',
mock_not_creating_new_zone, self.get_context(), None
)

View File

@ -392,7 +392,7 @@ class CentralServiceTestCase(CentralBasic):
def test_create_recordset_in_storage(self):
self.service._enforce_recordset_quota = mock.Mock()
self.service._validate_recordset = mock.Mock()
self.service._validate_recordset = mock.Mock(spec=objects.RecordSet)
self.service.storage.create_recordset = mock.Mock(return_value='rs')
self.service._update_zone_in_storage = mock.Mock()
@ -416,7 +416,7 @@ class CentralServiceTestCase(CentralBasic):
central_service.storage.create_recordset = mock.Mock(return_value='rs')
central_service._update_zone_in_storage = mock.Mock()
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()]
@ -441,7 +441,7 @@ class CentralServiceTestCase(CentralBasic):
# NOTE(thirose): Since this is a race condition we assume that
# we will hit it if we try to do the operations in a loop 100 times.
for num in range(100):
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.name = "b{}".format(num)
recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()]
@ -1148,7 +1148,7 @@ class CentralZoneTestCase(CentralBasic):
def test_update_recordset_fail_on_changes(self):
self.service.storage.get_zone.return_value = RoObject()
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_original_value.return_value = '1'
recordset.obj_get_changes.return_value = ['tenant_id', 'foo']
@ -1179,7 +1179,7 @@ class CentralZoneTestCase(CentralBasic):
self.service.storage.get_zone.return_value = RoObject(
action='DELETE',
)
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
exc = self.assertRaises(rpc_dispatcher.ExpectedException,
@ -1196,7 +1196,7 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2',
action='bogus',
)
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
recordset.managed = True
self.context = mock.Mock()
@ -1216,10 +1216,11 @@ class CentralZoneTestCase(CentralBasic):
tenant_id='2',
action='bogus',
)
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_get_changes.return_value = ['foo']
recordset.obj_get_original_value.return_value =\
recordset.obj_get_original_value.return_value = (
'9c85d9b0-1e9d-4e99-aede-a06664f1af2e'
)
recordset.managed = False
self.service._update_recordset_in_storage = mock.Mock(
return_value=('x', 'y')
@ -1239,7 +1240,7 @@ class CentralZoneTestCase(CentralBasic):
'recordset_id': '9c85d9b0-1e9d-4e99-aede-a06664f1af2e',
'project_id': '2'}, target)
def test__update_recordset_in_storage(self):
def test_update_recordset_in_storage(self):
recordset = mock.Mock()
recordset.name = 'n'
recordset.type = 't'
@ -1426,7 +1427,7 @@ class CentralZoneTestCase(CentralBasic):
self.assertTrue(
self.service._delete_recordset_in_storage.called)
def test__delete_recordset_in_storage(self):
def test_delete_recordset_in_storage(self):
def mock_uds(c, zone, inc):
return zone
self.service._update_zone_in_storage = mock_uds
@ -1730,7 +1731,7 @@ class CentralQuotaTest(unittest.TestCase):
service = Service()
service.storage.count_records.return_value = 10
recordset = mock.Mock()
recordset = mock.Mock(spec=objects.RecordSet)
recordset.managed = False
recordset.records = ['1.1.1.%i' % (i + 1) for i in range(5)]
@ -1801,7 +1802,7 @@ class CentralQuotaTest(unittest.TestCase):
1, 1,
]
managed_recordset = mock.Mock()
managed_recordset = mock.Mock(spec=objects.RecordSet)
managed_recordset.managed = True
recordset_one_record = mock.Mock()

View File

@ -0,0 +1,111 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from unittest import mock
from oslo_concurrency import lockutils
from oslo_log import log as logging
import oslotest.base
from designate.common.decorators import lock
from designate import exceptions
from designate.objects import record
from designate.objects import zone
from designate import utils
LOG = logging.getLogger(__name__)
class FakeCoordination:
def get_lock(self, name):
return lockutils.lock(name)
class FakeService:
def __init__(self):
self.zone_lock_local = lock.ZoneLockLocal()
self.coordination = FakeCoordination()
class CentralDecoratorTests(oslotest.base.BaseTestCase):
def setUp(self):
super().setUp()
self.context = mock.Mock()
self.service = FakeService()
def test_synchronized_zone_exception_raised(self):
@lock.synchronized_zone()
def mock_get_zone(cls, current_index, zone_obj):
self.assertEqual(
{'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
)
if current_index % 3 == 0:
raise exceptions.ZoneNotFound()
for index in range(9):
try:
mock_get_zone(
self.service, index, zone.Zone(id=utils.generate_uuid())
)
except exceptions.ZoneNotFound:
pass
def test_synchronized_new_zone_with_recursion(self):
@lock.synchronized_zone(new_zone=True)
def mock_create_zone(cls, context):
self.assertEqual({'create-new-zone'}, cls.zone_lock_local._held)
mock_create_record(
cls, context, zone.Zone(id=utils.generate_uuid())
)
@lock.synchronized_zone()
def mock_create_record(cls, context, zone_obj):
self.assertIn('zone-%s' % zone_obj.id, cls.zone_lock_local._held)
self.assertIn('create-new-zone', cls.zone_lock_local._held)
mock_create_zone(
self.service, self.context
)
def test_synchronized_zone_recursive_decorator_call(self):
@lock.synchronized_zone()
def mock_create_record(cls, context, record_obj):
self.assertEqual(
{'zone-%s' % record_obj.zone_id}, cls.zone_lock_local._held
)
mock_get_zone(cls, context, zone.Zone(id=record_obj.zone_id))
@lock.synchronized_zone()
def mock_get_zone(cls, context, zone_obj):
self.assertEqual(
{'zone-%s' % zone_obj.id}, cls.zone_lock_local._held
)
mock_create_record(
self.service, self.context,
record_obj=record.Record(zone_id=utils.generate_uuid())
)
mock_get_zone(
self.service, self.context,
zone_obj=zone.Zone(id=utils.generate_uuid())
)
def test_synchronized_zone_raises_exception_when_no_zone_provided(self):
@lock.synchronized_zone(new_zone=False)
def mock_not_creating_new_zone(cls, context, record_obj):
pass
self.assertRaisesRegex(
Exception,
'Failed to determine zone id for synchronized operation',
mock_not_creating_new_zone, self.service, mock.Mock(), None
)

View File

@ -21,9 +21,9 @@ import oslo_messaging as messaging
from designate import backend
from designate.central import rpcapi as central_api
from designate.common.decorators import rpc
from designate.context import DesignateContext
from designate import exceptions
from designate import rpc
from designate import service
from designate import storage
from designate.worker import processing