Merge "Move to a batch model for incrementing serial"

This commit is contained in:
Zuul 2023-04-06 22:51:09 +00:00 committed by Gerrit Code Review
commit f4ce71c8f8
21 changed files with 518 additions and 190 deletions
designate
central
conf
objects
producer
sqlalchemy
storage
tests
test_api/test_v2
test_central
test_producer
test_storage
unit
utils.py
worker/tasks
releasenotes/notes
setup.cfg

View File

@ -68,8 +68,9 @@ class CentralAPI(object):
6.4 - Removed unused record and diagnostic methods 6.4 - Removed unused record and diagnostic methods
6.5 - Removed additional unused methods 6.5 - Removed additional unused methods
6.6 - Add methods for shared zones 6.6 - Add methods for shared zones
6.7 - Add increment_zone_serial
""" """
RPC_API_VERSION = '6.6' RPC_API_VERSION = '6.7'
# This allows us to mark some methods as not logged. # This allows us to mark some methods as not logged.
# This can be for a few reasons - some methods my not actually call over # This can be for a few reasons - some methods my not actually call over
@ -82,7 +83,7 @@ class CentralAPI(object):
target = messaging.Target(topic=self.topic, target = messaging.Target(topic=self.topic,
version=self.RPC_API_VERSION) version=self.RPC_API_VERSION)
self.client = rpc.get_client(target, version_cap='6.6') self.client = rpc.get_client(target, version_cap='6.7')
@classmethod @classmethod
def get_instance(cls): def get_instance(cls):
@ -141,6 +142,9 @@ class CentralAPI(object):
return self.client.call(context, 'get_tenant', tenant_id=tenant_id) return self.client.call(context, 'get_tenant', tenant_id=tenant_id)
# Zone Methods # Zone Methods
def increment_zone_serial(self, context, zone):
return self.client.call(context, 'increment_zone_serial', zone=zone)
def create_zone(self, context, zone): def create_zone(self, context, zone):
return self.client.call(context, 'create_zone', zone=zone) return self.client.call(context, 'create_zone', zone=zone)

View File

@ -27,6 +27,7 @@ from dns import zone as dnszone
from oslo_config import cfg from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
import oslo_messaging as messaging import oslo_messaging as messaging
from oslo_utils import timeutils
from designate.common import constants from designate.common import constants
from designate.common.decorators import lock from designate.common.decorators import lock
@ -51,7 +52,7 @@ LOG = logging.getLogger(__name__)
class Service(service.RPCService): class Service(service.RPCService):
RPC_API_VERSION = '6.6' RPC_API_VERSION = '6.7'
target = messaging.Target(version=RPC_API_VERSION) target = messaging.Target(version=RPC_API_VERSION)
@ -349,48 +350,35 @@ class Service(service.RPCService):
"A project ID must be specified when not using a project " "A project ID must be specified when not using a project "
"scoped token.") "scoped token.")
def _increment_zone_serial(self, context, zone, set_delayed_notify=False):
"""Update the zone serial and the SOA record
Optionally set delayed_notify to have PM issue delayed notify
"""
# Increment the serial number
zone.serial = utils.increment_serial(zone.serial)
if set_delayed_notify:
zone.delayed_notify = True
zone = self.storage.update_zone(context, zone)
# Update SOA record
self._update_soa(context, zone)
return zone
# SOA Recordset Methods # SOA Recordset Methods
def _build_soa_record(self, zone, ns_records): @staticmethod
return "%s %s. %d %d %d %d %d" % (ns_records[0]['hostname'], def _build_soa_record(zone, ns_records):
zone['email'].replace("@", "."), return '%s %s. %d %d %d %d %d' % (
zone['serial'], ns_records[0]['hostname'],
zone['refresh'], zone['email'].replace('@', '.'),
zone['retry'], zone['serial'],
zone['expire'], zone['refresh'],
zone['minimum']) zone['retry'],
zone['expire'],
zone['minimum']
)
def _create_soa(self, context, zone): def _create_soa(self, context, zone):
pool_ns_records = self._get_pool_ns_records(context, zone.pool_id) pool_ns_records = self._get_pool_ns_records(context, zone.pool_id)
records = objects.RecordList(objects=[
soa_values = [self._build_soa_record(zone, pool_ns_records)] objects.Record(
recordlist = objects.RecordList(objects=[ data=self._build_soa_record(zone, pool_ns_records),
objects.Record(data=r, managed=True) for r in soa_values]) managed=True
values = { )
'name': zone['name'], ])
'type': "SOA", return self._create_recordset_in_storage(
'records': recordlist context, zone,
} objects.RecordSet(
soa, zone = self._create_recordset_in_storage( name=zone['name'],
context, zone, objects.RecordSet(**values), type='SOA',
increment_serial=False) records=records
return soa ), increment_serial=False
)[0]
def _update_soa(self, context, zone): def _update_soa(self, context, zone):
# NOTE: We should not be updating SOA records when a zone is SECONDARY. # NOTE: We should not be updating SOA records when a zone is SECONDARY.
@ -400,14 +388,18 @@ class Service(service.RPCService):
# Get the pool for it's list of ns_records # Get the pool for it's list of ns_records
pool_ns_records = self._get_pool_ns_records(context, zone.pool_id) pool_ns_records = self._get_pool_ns_records(context, zone.pool_id)
soa = self.find_recordset(context, soa = self.find_recordset(
criterion={'zone_id': zone['id'], context, criterion={
'type': "SOA"}) 'zone_id': zone['id'],
'type': 'SOA'
}
)
soa.records[0].data = self._build_soa_record(zone, pool_ns_records) soa.records[0].data = self._build_soa_record(zone, pool_ns_records)
self._update_recordset_in_storage(context, zone, soa, self._update_recordset_in_storage(
increment_serial=False) context, zone, soa, increment_serial=False
)
# NS Recordset Methods # NS Recordset Methods
def _create_ns(self, context, zone, ns_records): def _create_ns(self, context, zone, ns_records):
@ -730,6 +722,14 @@ class Service(service.RPCService):
pool = self.storage.get_pool(elevated_context, pool_id) pool = self.storage.get_pool(elevated_context, pool_id)
return pool.ns_records return pool.ns_records
@rpc.expected_exceptions()
@transaction
@lock.synchronized_zone()
def increment_zone_serial(self, context, zone):
zone.serial = self.storage.increment_serial(context, zone.id)
self._update_soa(context, zone)
return zone.serial
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification.notify_type('dns.domain.create') @notification.notify_type('dns.domain.create')
@notification.notify_type('dns.zone.create') @notification.notify_type('dns.zone.create')
@ -853,7 +853,8 @@ class Service(service.RPCService):
# can be very long-lived. # can be very long-lived.
time.sleep(0) time.sleep(0)
self._create_recordset_in_storage( self._create_recordset_in_storage(
context, zone, rrset, increment_serial=False) context, zone, rrset, increment_serial=False
)
return zone return zone
@ -992,29 +993,28 @@ class Service(service.RPCService):
"""Update zone """Update zone
""" """
zone = self._update_zone_in_storage( zone = self._update_zone_in_storage(
context, zone, increment_serial=increment_serial) context, zone, increment_serial=increment_serial
)
# Fire off a XFR # Fire off a XFR
if 'masters' in changes: if 'masters' in changes:
self.worker_api.perform_zone_xfr(context, zone) self.worker_api.perform_zone_xfr(context, zone)
self.worker_api.update_zone(context, zone)
return zone return zone
@transaction @transaction
def _update_zone_in_storage(self, context, zone, def _update_zone_in_storage(self, context, zone,
increment_serial=True, set_delayed_notify=False): increment_serial=True,
set_delayed_notify=False):
zone.action = 'UPDATE' zone.action = 'UPDATE'
zone.status = 'PENDING' zone.status = 'PENDING'
if increment_serial: if increment_serial:
# _increment_zone_serial increments and updates the zone zone.increment_serial = True
zone = self._increment_zone_serial( if set_delayed_notify:
context, zone, set_delayed_notify=set_delayed_notify) zone.delayed_notify = True
else:
zone = self.storage.update_zone(context, zone) zone = self.storage.update_zone(context, zone)
return zone return zone
@ -1333,7 +1333,6 @@ class Service(service.RPCService):
# RecordSet Methods # RecordSet Methods
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification.notify_type('dns.recordset.create') @notification.notify_type('dns.recordset.create')
@lock.synchronized_zone()
def create_recordset(self, context, zone_id, recordset, def create_recordset(self, context, zone_id, recordset,
increment_serial=True): increment_serial=True):
zone = self.storage.get_zone(context, zone_id, zone = self.storage.get_zone(context, zone_id,
@ -1376,9 +1375,8 @@ class Service(service.RPCService):
context = context.elevated(all_tenants=True) context = context.elevated(all_tenants=True)
recordset, zone = self._create_recordset_in_storage( recordset, zone = self._create_recordset_in_storage(
context, zone, recordset, increment_serial=increment_serial) context, zone, recordset, increment_serial=increment_serial
)
self.worker_api.update_zone(context, zone)
recordset.zone_name = zone.name recordset.zone_name = zone.name
recordset.obj_reset_changes(['zone_name']) recordset.obj_reset_changes(['zone_name'])
@ -1416,33 +1414,33 @@ class Service(service.RPCService):
@transaction_shallow_copy @transaction_shallow_copy
def _create_recordset_in_storage(self, context, zone, recordset, def _create_recordset_in_storage(self, context, zone, recordset,
increment_serial=True): increment_serial=True):
# Ensure the tenant has enough quota to continue # Ensure the tenant has enough quota to continue
self._enforce_recordset_quota(context, zone) self._enforce_recordset_quota(context, zone)
self._validate_recordset(context, zone, recordset) self._validate_recordset(context, zone, recordset)
if recordset.obj_attr_is_set('records') and len(recordset.records) > 0: if recordset.obj_attr_is_set('records') and recordset.records:
# Ensure the tenant has enough zone record quotas to # Ensure the tenant has enough zone record quotas to
# create new records # create new records
self._enforce_record_quota(context, zone, recordset) self._enforce_record_quota(context, zone, recordset)
if increment_serial:
# update the zone's status and increment the serial
zone = self._update_zone_in_storage(
context, zone, increment_serial)
for record in recordset.records: for record in recordset.records:
record.action = 'CREATE' record.action = 'CREATE'
record.status = 'PENDING' record.status = 'PENDING'
record.serial = zone.serial if not increment_serial:
record.serial = zone.serial
else:
record.serial = timeutils.utcnow_ts()
recordset = self.storage.create_recordset(context, zone.id, new_recordset = self.storage.create_recordset(context, zone.id,
recordset) recordset)
if recordset.records and increment_serial:
# update the zone's status and increment the serial
zone = self._update_zone_in_storage(
context, zone, increment_serial
)
# Return the zone too in case it was updated # Return the zone too in case it was updated
return (recordset, zone) return new_recordset, zone
@rpc.expected_exceptions() @rpc.expected_exceptions()
def get_recordset(self, context, zone_id, recordset_id): def get_recordset(self, context, zone_id, recordset_id):
@ -1557,7 +1555,6 @@ class Service(service.RPCService):
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification.notify_type('dns.recordset.update') @notification.notify_type('dns.recordset.update')
@lock.synchronized_zone()
def update_recordset(self, context, recordset, increment_serial=True): def update_recordset(self, context, recordset, increment_serial=True):
zone_id = recordset.obj_get_original_value('zone_id') zone_id = recordset.obj_get_original_value('zone_id')
changes = recordset.obj_get_changes() changes = recordset.obj_get_changes()
@ -1626,41 +1623,44 @@ class Service(service.RPCService):
recordset, zone = self._update_recordset_in_storage( recordset, zone = self._update_recordset_in_storage(
context, zone, recordset, increment_serial=increment_serial) context, zone, recordset, increment_serial=increment_serial)
self.worker_api.update_zone(context, zone)
return recordset return recordset
@transaction @transaction
def _update_recordset_in_storage(self, context, zone, recordset, def _update_recordset_in_storage(self, context, zone, recordset,
increment_serial=True, set_delayed_notify=False): increment_serial=True,
set_delayed_notify=False):
self._validate_recordset(context, zone, recordset) self._validate_recordset(context, zone, recordset)
if increment_serial:
# update the zone's status and increment the serial
zone = self._update_zone_in_storage(
context, zone, increment_serial,
set_delayed_notify=set_delayed_notify)
if recordset.records: if recordset.records:
for record in recordset.records: for record in recordset.records:
if record.action != 'DELETE': if record.action == 'DELETE':
record.action = 'UPDATE' continue
record.status = 'PENDING' record.action = 'UPDATE'
record.status = 'PENDING'
if not increment_serial:
record.serial = zone.serial record.serial = zone.serial
else:
record.serial = timeutils.utcnow_ts()
# Ensure the tenant has enough zone record quotas to # Ensure the tenant has enough zone record quotas to
# create new records # create new records
self._enforce_record_quota(context, zone, recordset) self._enforce_record_quota(context, zone, recordset)
# Update the recordset # Update the recordset
recordset = self.storage.update_recordset(context, recordset) new_recordset = self.storage.update_recordset(context, recordset)
return recordset, zone if increment_serial:
# update the zone's status and increment the serial
zone = self._update_zone_in_storage(
context, zone,
increment_serial=increment_serial,
set_delayed_notify=set_delayed_notify)
return new_recordset, zone
@rpc.expected_exceptions() @rpc.expected_exceptions()
@notification.notify_type('dns.recordset.delete') @notification.notify_type('dns.recordset.delete')
@lock.synchronized_zone()
def delete_recordset(self, context, zone_id, recordset_id, def delete_recordset(self, context, zone_id, recordset_id,
increment_serial=True): increment_serial=True):
# apply_tenant_criteria=False here as we will gate this delete # apply_tenant_criteria=False here as we will gate this delete
@ -1712,8 +1712,6 @@ class Service(service.RPCService):
recordset, zone = self._delete_recordset_in_storage( recordset, zone = self._delete_recordset_in_storage(
context, zone, recordset, increment_serial=increment_serial) context, zone, recordset, increment_serial=increment_serial)
self.worker_api.update_zone(context, zone)
recordset.zone_name = zone.name recordset.zone_name = zone.name
recordset.obj_reset_changes(['zone_name']) recordset.obj_reset_changes(['zone_name'])
@ -1722,23 +1720,26 @@ class Service(service.RPCService):
@transaction @transaction
def _delete_recordset_in_storage(self, context, zone, recordset, def _delete_recordset_in_storage(self, context, zone, recordset,
increment_serial=True): increment_serial=True):
if recordset.records:
for record in recordset.records:
record.action = 'DELETE'
record.status = 'PENDING'
if not increment_serial:
record.serial = zone.serial
else:
record.serial = timeutils.utcnow_ts()
# Update the recordset's action/status and then delete it
self.storage.update_recordset(context, recordset)
if increment_serial: if increment_serial:
# update the zone's status and increment the serial # update the zone's status and increment the serial
zone = self._update_zone_in_storage( zone = self._update_zone_in_storage(
context, zone, increment_serial) context, zone, increment_serial)
if recordset.records: new_recordset = self.storage.delete_recordset(context, recordset.id)
for record in recordset.records:
record.action = 'DELETE'
record.status = 'PENDING'
record.serial = zone.serial
# Update the recordset's action/status and then delete it return new_recordset, zone
self.storage.update_recordset(context, recordset)
recordset = self.storage.delete_recordset(context, recordset.id)
return (recordset, zone)
@rpc.expected_exceptions() @rpc.expected_exceptions()
def count_recordsets(self, context, criterion=None): def count_recordsets(self, context, criterion=None):

View File

@ -20,6 +20,11 @@ PRODUCER_GROUP = cfg.OptGroup(
title='Configuration for Producer Service' title='Configuration for Producer Service'
) )
PRODUCER_TASK_INCREMENT_SERIAL_GROUP = cfg.OptGroup(
name='producer_task:increment_serial',
title='Configuration for Producer Task: Increment Serial'
)
PRODUCER_TASK_DELAYED_NOTIFY_GROUP = cfg.OptGroup( PRODUCER_TASK_DELAYED_NOTIFY_GROUP = cfg.OptGroup(
name='producer_task:delayed_notify', name='producer_task:delayed_notify',
title='Configuration for Producer Task: Delayed Notify' title='Configuration for Producer Task: Delayed Notify'
@ -62,6 +67,15 @@ PRODUCER_OPTS = [
help='RPC topic name for producer'), help='RPC topic name for producer'),
] ]
PRODUCER_TASK_INCREMENT_SERIAL_OPTS = [
cfg.IntOpt('interval', default=5,
help='Run interval in seconds'),
cfg.IntOpt('per_page', default=100,
help='Default amount of results returned per page'),
cfg.IntOpt('batch_size', default=100,
help='How many zones to increment serial for on each run'),
]
PRODUCER_TASK_DELAYED_NOTIFY_OPTS = [ PRODUCER_TASK_DELAYED_NOTIFY_OPTS = [
cfg.IntOpt('interval', default=5, cfg.IntOpt('interval', default=5,
help='Run interval in seconds'), help='Run interval in seconds'),
@ -111,6 +125,9 @@ def register_opts(conf):
conf.register_group(PRODUCER_TASK_DELAYED_NOTIFY_GROUP) conf.register_group(PRODUCER_TASK_DELAYED_NOTIFY_GROUP)
conf.register_opts(PRODUCER_TASK_DELAYED_NOTIFY_OPTS, conf.register_opts(PRODUCER_TASK_DELAYED_NOTIFY_OPTS,
group=PRODUCER_TASK_DELAYED_NOTIFY_GROUP) group=PRODUCER_TASK_DELAYED_NOTIFY_GROUP)
conf.register_group(PRODUCER_TASK_INCREMENT_SERIAL_GROUP)
conf.register_opts(PRODUCER_TASK_INCREMENT_SERIAL_OPTS,
group=PRODUCER_TASK_INCREMENT_SERIAL_GROUP)
conf.register_group(PRODUCER_TASK_PERIODIC_EXISTS_GROUP) conf.register_group(PRODUCER_TASK_PERIODIC_EXISTS_GROUP)
conf.register_opts(PRODUCER_TASK_PERIODIC_EXISTS_OPTS, conf.register_opts(PRODUCER_TASK_PERIODIC_EXISTS_OPTS,
group=PRODUCER_TASK_PERIODIC_EXISTS_GROUP) group=PRODUCER_TASK_PERIODIC_EXISTS_GROUP)

View File

@ -66,6 +66,7 @@ class Zone(base.DesignateObject, base.DictObjectMixin,
), ),
'transferred_at': fields.DateTimeField(nullable=True, read_only=False), 'transferred_at': fields.DateTimeField(nullable=True, read_only=False),
'delayed_notify': fields.BooleanField(nullable=True), 'delayed_notify': fields.BooleanField(nullable=True),
'increment_serial': fields.BooleanField(nullable=True),
} }
STRING_KEYS = [ STRING_KEYS = [

View File

@ -227,8 +227,6 @@ class PeriodicGenerateDelayedNotifyTask(PeriodicTask):
Call Worker to emit NOTIFY transactions, Call Worker to emit NOTIFY transactions,
Reset the flag. Reset the flag.
""" """
pstart, pend = self._my_range()
ctxt = context.DesignateContext.get_admin_context() ctxt = context.DesignateContext.get_admin_context()
ctxt.all_tenants = True ctxt.all_tenants = True
@ -237,6 +235,7 @@ class PeriodicGenerateDelayedNotifyTask(PeriodicTask):
# There's an index on delayed_notify. # There's an index on delayed_notify.
criterion = self._filter_between('shard') criterion = self._filter_between('shard')
criterion['delayed_notify'] = True criterion['delayed_notify'] = True
criterion['increment_serial'] = False
zones = self.central_api.find_zones( zones = self.central_api.find_zones(
ctxt, ctxt,
criterion, criterion,
@ -246,6 +245,17 @@ class PeriodicGenerateDelayedNotifyTask(PeriodicTask):
) )
for zone in zones: for zone in zones:
if zone.action == 'NONE':
zone.action = 'UPDATE'
zone.status = 'PENDING'
elif zone.action == 'DELETE':
LOG.debug(
'Skipping delayed NOTIFY for %(id)s being DELETED',
{
'id': zone.id,
}
)
continue
self.worker_api.update_zone(ctxt, zone) self.worker_api.update_zone(ctxt, zone)
zone.delayed_notify = False zone.delayed_notify = False
self.central_api.update_zone(ctxt, zone) self.central_api.update_zone(ctxt, zone)
@ -257,6 +267,54 @@ class PeriodicGenerateDelayedNotifyTask(PeriodicTask):
) )
class PeriodicIncrementSerialTask(PeriodicTask):
__plugin_name__ = 'increment_serial'
def __init__(self):
super(PeriodicIncrementSerialTask, self).__init__()
def __call__(self):
ctxt = context.DesignateContext.get_admin_context()
ctxt.all_tenants = True
# Select zones where "increment_serial" is set and starting from the
# oldest "updated_at".
# There's an index on increment_serial.
criterion = self._filter_between('shard')
criterion['increment_serial'] = True
zones = self.central_api.find_zones(
ctxt,
criterion,
limit=CONF[self.name].batch_size,
sort_key='updated_at',
sort_dir='asc',
)
for zone in zones:
if zone.action == 'DELETE':
LOG.debug(
'Skipping increment serial for %(id)s being DELETED',
{
'id': zone.id,
}
)
continue
serial = self.central_api.increment_zone_serial(ctxt, zone)
LOG.debug(
'Incremented serial for %(id)s to %(serial)d',
{
'id': zone.id,
'serial': serial,
}
)
if not zone.delayed_notify:
# Notify the backend.
if zone.action == 'NONE':
zone.action = 'UPDATE'
zone.status = 'PENDING'
self.worker_api.update_zone(ctxt, zone)
class WorkerPeriodicRecovery(PeriodicTask): class WorkerPeriodicRecovery(PeriodicTask):
__plugin_name__ = 'worker_periodic_recovery' __plugin_name__ = 'worker_periodic_recovery'

View File

@ -152,7 +152,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
# Ensure the Object is valid # Ensure the Object is valid
# obj.validate() # obj.validate()
values = obj.obj_get_changes() values = dict(obj)
if skip_values is not None: if skip_values is not None:
for skip_value in skip_values: for skip_value in skip_values:
@ -166,7 +166,7 @@ class SQLAlchemy(object, metaclass=abc.ABCMeta):
with sql.get_write_session() as session: with sql.get_write_session() as session:
try: try:
resultproxy = session.execute(query, [dict(values)]) resultproxy = session.execute(query, [values])
except oslo_db_exception.DBDuplicateEntry: except oslo_db_exception.DBDuplicateEntry:
raise exc_dup("Duplicate %s" % obj.obj_name()) raise exc_dup("Duplicate %s" % obj.obj_name())

View File

@ -744,6 +744,15 @@ class Storage(DriverPlugin, metaclass=abc.ABCMeta):
:param zone_import: Zone Import to update. :param zone_import: Zone Import to update.
""" """
@abc.abstractmethod
def increment_serial(self, context, zone_id):
"""
Increment serial of a Zone
:param context: RPC Context.
:param zone_id: ID of the Zone.
"""
@abc.abstractmethod @abc.abstractmethod
def delete_zone_import(self, context, zone_import_id): def delete_zone_import(self, context, zone_import_id):
""" """

View File

@ -15,6 +15,7 @@
# under the License. # under the License.
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils.secretutils import md5 from oslo_utils.secretutils import md5
from oslo_utils import timeutils
from sqlalchemy import case, select, distinct, func from sqlalchemy import case, select, distinct, func
from sqlalchemy.sql.expression import or_, literal_column from sqlalchemy.sql.expression import or_, literal_column
@ -435,6 +436,19 @@ class SQLAlchemyStorage(sqlalchemy_base.SQLAlchemy, storage_base.Storage):
return updated_zone return updated_zone
def increment_serial(self, context, zone_id):
"""Increment the zone's serial number.
"""
new_serial = timeutils.utcnow_ts()
query = tables.zones.update().where(
tables.zones.c.id == zone_id).values(
{'serial': new_serial, 'increment_serial': False}
)
with sql.get_write_session() as session:
session.execute(query)
LOG.debug('Incremented zone serial for %s to %d', zone_id, new_serial)
return new_serial
def delete_zone(self, context, zone_id): def delete_zone(self, context, zone_id):
""" """
""" """

View File

@ -0,0 +1,38 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Add increment serial
Revision ID: a005af3aa38e
Revises: b20189fd288e
Create Date: 2023-01-21 17:39:00.822775
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'a005af3aa38e'
down_revision = 'b20189fd288e'
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
'zones',
sa.Column('increment_serial', sa.Boolean, default=False)
)
op.create_index(
'increment_serial', 'zones', ['increment_serial']
)

View File

@ -139,6 +139,7 @@ zones = Table('zones', metadata,
Column('pool_id', UUID, default=None, nullable=True), Column('pool_id', UUID, default=None, nullable=True),
Column('reverse_name', String(255), nullable=False), Column('reverse_name', String(255), nullable=False),
Column('delayed_notify', Boolean, default=False), Column('delayed_notify', Boolean, default=False),
Column('increment_serial', Boolean, default=False),
UniqueConstraint('name', 'deleted', 'pool_id', name='unique_zone_name'), UniqueConstraint('name', 'deleted', 'pool_id', name='unique_zone_name'),
ForeignKeyConstraint(['parent_zone_id'], ForeignKeyConstraint(['parent_zone_id'],

View File

@ -17,6 +17,7 @@ from unittest.mock import patch
from oslo_log import log as logging from oslo_log import log as logging
import oslo_messaging as messaging import oslo_messaging as messaging
from oslo_utils import timeutils
from designate.central import service as central_service from designate.central import service as central_service
from designate import exceptions from designate import exceptions
@ -438,10 +439,9 @@ class ApiV2RecordSetsTest(ApiV2TestCase):
self.client.delete(url, status=202, headers={'X-Test-Role': 'member'}) self.client.delete(url, status=202, headers={'X-Test-Role': 'member'})
# Simulate the zone having been deleted on the backend # Simulate the zone having been deleted on the backend
zone_serial = self.central_service.get_zone(
self.admin_context, zone['id']).serial
self.central_service.update_status( self.central_service.update_status(
self.admin_context, zone['id'], 'SUCCESS', zone_serial, 'UPDATE' self.admin_context, zone['id'], 'SUCCESS', timeutils.utcnow_ts(),
'DELETE'
) )
# Try to get the record and ensure that we get a # Try to get the record and ensure that we get a

View File

@ -19,6 +19,7 @@ from collections import namedtuple
from concurrent import futures from concurrent import futures
import copy import copy
import datetime import datetime
import futurist import futurist
import random import random
import unittest import unittest
@ -33,7 +34,6 @@ from oslo_messaging.rpc import dispatcher as rpc_dispatcher
from oslo_utils import timeutils from oslo_utils import timeutils
from oslo_versionedobjects import exception as ovo_exc from oslo_versionedobjects import exception as ovo_exc
import testtools import testtools
from testtools.matchers import GreaterThan
from designate import exceptions from designate import exceptions
from designate import objects from designate import objects
@ -890,7 +890,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_zone(self, mock_notifier): def test_update_zone(self, mock_notifier):
# Create a zone # Create a zone
zone = self.create_zone(email='info@example.org') zone = self.create_zone(email='info@example.org')
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Update the object # Update the object
zone.email = 'info@example.net' zone.email = 'info@example.net'
@ -906,7 +906,7 @@ class CentralServiceTest(CentralTestCase):
self.admin_context, zone.id) self.admin_context, zone.id)
# Ensure the zone was updated correctly # Ensure the zone was updated correctly
self.assertGreater(zone.serial, original_serial) self.assertTrue(zone.increment_serial)
self.assertEqual('info@example.net', zone.email) self.assertEqual('info@example.net', zone.email)
self.assertEqual(2, mock_notifier.call_count) self.assertEqual(2, mock_notifier.call_count)
@ -931,7 +931,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_zone_without_incrementing_serial(self): def test_update_zone_without_incrementing_serial(self):
# Create a zone # Create a zone
zone = self.create_zone(email='info@example.org') zone = self.create_zone(email='info@example.org')
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Update the object # Update the object
zone.email = 'info@example.net' zone.email = 'info@example.net'
@ -944,7 +944,7 @@ class CentralServiceTest(CentralTestCase):
zone = self.central_service.get_zone(self.admin_context, zone.id) zone = self.central_service.get_zone(self.admin_context, zone.id)
# Ensure the zone was updated correctly # Ensure the zone was updated correctly
self.assertEqual(original_serial, zone.serial) self.assertFalse(zone.increment_serial)
self.assertEqual('info@example.net', zone.email) self.assertEqual('info@example.net', zone.email)
def test_update_zone_name_fail(self): def test_update_zone_name_fail(self):
@ -965,7 +965,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_zone_deadlock_retry(self): def test_update_zone_deadlock_retry(self):
# Create a zone # Create a zone
zone = self.create_zone(name='example.org.') zone = self.create_zone(name='example.org.')
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Update the Object # Update the Object
zone.email = 'info@example.net' zone.email = 'info@example.net'
@ -992,7 +992,7 @@ class CentralServiceTest(CentralTestCase):
self.assertTrue(i[0]) self.assertTrue(i[0])
# Ensure the zone was updated correctly # Ensure the zone was updated correctly
self.assertGreater(zone.serial, original_serial) self.assertTrue(zone.increment_serial)
self.assertEqual('info@example.net', zone.email) self.assertEqual('info@example.net', zone.email)
@mock.patch.object(notifier.Notifier, "info") @mock.patch.object(notifier.Notifier, "info")
@ -1457,7 +1457,7 @@ class CentralServiceTest(CentralTestCase):
# RecordSet Tests # RecordSet Tests
def test_create_recordset(self): def test_create_recordset(self):
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Create the Object # Create the Object
recordset = objects.RecordSet(name='www.%s' % zone.name, type='A') recordset = objects.RecordSet(name='www.%s' % zone.name, type='A')
@ -1469,7 +1469,6 @@ class CentralServiceTest(CentralTestCase):
# Get the zone again to check if serial increased # Get the zone again to check if serial increased
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial
# Ensure all values have been set correctly # Ensure all values have been set correctly
self.assertIsNotNone(recordset.id) self.assertIsNotNone(recordset.id)
@ -1479,7 +1478,7 @@ class CentralServiceTest(CentralTestCase):
self.assertIsNotNone(recordset.records) self.assertIsNotNone(recordset.records)
# The serial number does not get updated is there are no records # The serial number does not get updated is there are no records
# in the recordset # in the recordset
self.assertEqual(original_serial, new_serial) self.assertFalse(updated_zone.increment_serial)
def test_create_recordset_shared_zone(self): def test_create_recordset_shared_zone(self):
zone = self.create_zone() zone = self.create_zone()
@ -1546,15 +1545,15 @@ class CentralServiceTest(CentralTestCase):
def test_create_recordset_with_records(self): def test_create_recordset_with_records(self):
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Create the Object # Create the Object
recordset = objects.RecordSet( recordset = objects.RecordSet(
name='www.%s' % zone.name, name='www.%s' % zone.name,
type='A', type='A',
records=objects.RecordList(objects=[ records=objects.RecordList(objects=[
objects.Record(data='192.3.3.15'), objects.Record(data='192.0.2.15'),
objects.Record(data='192.3.3.16'), objects.Record(data='192.0.2.16'),
]) ])
) )
@ -1565,14 +1564,13 @@ class CentralServiceTest(CentralTestCase):
# Get updated serial number # Get updated serial number
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial
# Ensure all values have been set correctly # Ensure all values have been set correctly
self.assertIsNotNone(recordset.records) self.assertIsNotNone(recordset.records)
self.assertEqual(2, len(recordset.records)) self.assertEqual(2, len(recordset.records))
self.assertIsNotNone(recordset.records[0].id) self.assertIsNotNone(recordset.records[0].id)
self.assertIsNotNone(recordset.records[1].id) self.assertIsNotNone(recordset.records[1].id)
self.assertThat(new_serial, GreaterThan(original_serial)) self.assertTrue(updated_zone.increment_serial)
def test_create_recordset_over_quota(self): def test_create_recordset_over_quota(self):
# SOA, NS recordsets exist by default. # SOA, NS recordsets exist by default.
@ -1851,7 +1849,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_recordset(self): def test_update_recordset(self):
# Create a zone # Create a zone
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Create a recordset # Create a recordset
recordset = self.create_recordset(zone) recordset = self.create_recordset(zone)
@ -1865,7 +1863,7 @@ class CentralServiceTest(CentralTestCase):
# Get zone again to verify that serial number was updated # Get zone again to verify that serial number was updated
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial self.assertTrue(updated_zone.increment_serial)
# Fetch the resource again # Fetch the resource again
recordset = self.central_service.get_recordset( recordset = self.central_service.get_recordset(
@ -1873,7 +1871,6 @@ class CentralServiceTest(CentralTestCase):
# Ensure the new value took # Ensure the new value took
self.assertEqual(1800, recordset.ttl) self.assertEqual(1800, recordset.ttl)
self.assertThat(new_serial, GreaterThan(original_serial))
@unittest.expectedFailure # FIXME @unittest.expectedFailure # FIXME
def test_update_recordset_deadlock_retry(self): def test_update_recordset_deadlock_retry(self):
@ -1936,7 +1933,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_recordset_with_record_delete(self): def test_update_recordset_with_record_delete(self):
# Create a zone # Create a zone
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Create a recordset and two records # Create a recordset and two records
records = [ records = [
@ -1960,12 +1957,11 @@ class CentralServiceTest(CentralTestCase):
# Fetch the Zone again # Fetch the Zone again
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial
# Ensure two Records are attached to the RecordSet correctly # Ensure two Records are attached to the RecordSet correctly
self.assertEqual(1, len(recordset.records)) self.assertEqual(1, len(recordset.records))
self.assertIsNotNone(recordset.records[0].id) self.assertIsNotNone(recordset.records[0].id)
self.assertThat(new_serial, GreaterThan(original_serial)) self.assertTrue(updated_zone.increment_serial)
def test_update_recordset_with_record_update(self): def test_update_recordset_with_record_update(self):
# Create a zone # Create a zone
@ -2066,7 +2062,7 @@ class CentralServiceTest(CentralTestCase):
def test_update_recordset_shared_zone(self): def test_update_recordset_shared_zone(self):
# Create a zone # Create a zone
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
context = self.get_context(project_id='1', roles=['member', 'reader']) context = self.get_context(project_id='1', roles=['member', 'reader'])
self.share_zone(context=self.admin_context, zone_id=zone.id, self.share_zone(context=self.admin_context, zone_id=zone.id,
@ -2084,7 +2080,9 @@ class CentralServiceTest(CentralTestCase):
# Get zone again to verify that serial number was updated # Get zone again to verify that serial number was updated
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial
# Ensure that we are incrementing the zone serial
self.assertTrue(updated_zone.increment_serial)
# Fetch the resource again # Fetch the resource again
recordset = self.central_service.get_recordset( recordset = self.central_service.get_recordset(
@ -2092,11 +2090,10 @@ class CentralServiceTest(CentralTestCase):
# Ensure the new value took # Ensure the new value took
self.assertEqual(1800, recordset.ttl) self.assertEqual(1800, recordset.ttl)
self.assertThat(new_serial, GreaterThan(original_serial))
def test_delete_recordset(self): def test_delete_recordset(self):
zone = self.create_zone() zone = self.create_zone()
original_serial = zone.serial self.assertFalse(zone.increment_serial)
# Create a recordset # Create a recordset
recordset = self.create_recordset(zone) recordset = self.create_recordset(zone)
@ -2116,8 +2113,7 @@ class CentralServiceTest(CentralTestCase):
# Fetch the zone again to verify serial number increased # Fetch the zone again to verify serial number increased
updated_zone = self.central_service.get_zone(self.admin_context, updated_zone = self.central_service.get_zone(self.admin_context,
zone.id) zone.id)
new_serial = updated_zone.serial self.assertTrue(updated_zone.increment_serial)
self.assertThat(new_serial, GreaterThan(original_serial))
def test_delete_recordset_without_incrementing_serial(self): def test_delete_recordset_without_incrementing_serial(self):
zone = self.create_zone() zone = self.create_zone()
@ -2219,8 +2215,8 @@ class CentralServiceTest(CentralTestCase):
name='www.%s' % zone.name, name='www.%s' % zone.name,
type='A', type='A',
records=objects.RecordList(objects=[ records=objects.RecordList(objects=[
objects.Record(data='192.3.3.15'), objects.Record(data='203.0.113.15'),
objects.Record(data='192.3.3.16'), objects.Record(data='203.0.113.16'),
]) ])
) )
@ -2401,7 +2397,7 @@ class CentralServiceTest(CentralTestCase):
# Ensure that the record is still in DB (No invalidation) # Ensure that the record is still in DB (No invalidation)
self.central_service.find_records(elevated_a, criterion) self.central_service.find_records(elevated_a, criterion)
# Now give the fip id to tenant 'b' and see that it get's deleted # Now give the fip id to tenant 'b' and see that it gets deleted
self.network_api.fake.allocate_floatingip( self.network_api.fake.allocate_floatingip(
context_b.project_id, fip['id']) context_b.project_id, fip['id'])
@ -2411,8 +2407,10 @@ class CentralServiceTest(CentralTestCase):
self.assertIsNone(fip_ptr['ptrdname']) self.assertIsNone(fip_ptr['ptrdname'])
# Simulate the invalidation on the backend # Simulate the invalidation on the backend
zone_serial = self.central_service.get_zone( zone = self.central_service.get_zone(
elevated_a, zone_id).serial elevated_a, zone_id)
zone_serial = self.central_service.increment_zone_serial(
elevated_a, zone)
self.central_service.update_status( self.central_service.update_status(
elevated_a, zone_id, 'SUCCESS', zone_serial, 'UPDATE') elevated_a, zone_id, 'SUCCESS', zone_serial, 'UPDATE')
@ -2482,10 +2480,8 @@ class CentralServiceTest(CentralTestCase):
elevated_a, criterion)[0].zone_id elevated_a, criterion)[0].zone_id
# Simulate the update on the backend # Simulate the update on the backend
zone_serial = self.central_service.get_zone(
elevated_a, zone_id).serial
self.central_service.update_status( self.central_service.update_status(
elevated_a, zone_id, 'SUCCESS', zone_serial, 'UPDATE') elevated_a, zone_id, 'SUCCESS', timeutils.utcnow_ts(), 'UPDATE')
self.network_api.fake.deallocate_floatingip(fip['id']) self.network_api.fake.deallocate_floatingip(fip['id'])
@ -2495,7 +2491,7 @@ class CentralServiceTest(CentralTestCase):
# Ensure that the record is still in DB (No invalidation) # Ensure that the record is still in DB (No invalidation)
self.central_service.find_records(elevated_a, criterion) self.central_service.find_records(elevated_a, criterion)
# Now give the fip id to tenant 'b' and see that it get's deleted # Now give the fip id to tenant 'b' and see that it gets deleted
self.network_api.fake.allocate_floatingip( self.network_api.fake.allocate_floatingip(
context_b.project_id, fip['id']) context_b.project_id, fip['id'])
@ -2505,10 +2501,8 @@ class CentralServiceTest(CentralTestCase):
self.assertIsNone(fips[0]['ptrdname']) self.assertIsNone(fips[0]['ptrdname'])
# Simulate the invalidation on the backend # Simulate the invalidation on the backend
zone_serial = self.central_service.get_zone(
elevated_a, zone_id).serial
self.central_service.update_status( self.central_service.update_status(
elevated_a, zone_id, 'SUCCESS', zone_serial, 'UPDATE') elevated_a, zone_id, 'SUCCESS', timeutils.utcnow_ts(), 'UPDATE')
record = self.central_service.find_records(elevated_a, criterion)[0] record = self.central_service.find_records(elevated_a, criterion)[0]
self.assertEqual('NONE', record.action) self.assertEqual('NONE', record.action)
@ -3970,3 +3964,91 @@ class CentralServiceTest(CentralTestCase):
retrived_shared_zone.target_project_id) retrived_shared_zone.target_project_id)
self.assertEqual(shared_zone.project_id, self.assertEqual(shared_zone.project_id,
retrived_shared_zone.project_id) retrived_shared_zone.project_id)
def test_batch_increment_serial(self):
zone = self.create_zone()
zone_serial = zone.serial
self.assertFalse(zone.increment_serial)
for index in range(10):
recordset = objects.RecordSet(
name='www.%d.%s' % (index, zone.name),
type='A',
records=objects.RecordList(objects=[
objects.Record(data='192.0.2.%d' % index),
objects.Record(data='198.51.100.%d' % index),
])
)
self.central_service.create_recordset(
self.admin_context, zone.id, recordset=recordset
)
updated_zone = self.central_service.get_zone(
self.admin_context, zone.id
)
recordsets = self.central_service.find_recordsets(
self.admin_context,
criterion={'zone_id': zone.id, 'type': 'A'}
)
# Increment serial hasn't been triggered yet.
self.assertEqual(zone_serial, updated_zone.serial)
self.assertTrue(updated_zone.increment_serial)
self.assertEqual('PENDING', updated_zone.status)
self.assertEqual(10, len(recordsets))
for recordset in recordsets:
self.assertEqual('PENDING', recordset.status)
self.assertEqual(2, len(recordset.records))
for record in recordset.records:
self.assertEqual('PENDING', record.status)
# Increment serial (Producer -> Central) for zone.
with mock.patch.object(timeutils, 'utcnow_ts',
return_value=zone_serial + 5):
self.central_service.increment_zone_serial(
self.admin_context, zone
)
updated_zone = self.central_service.get_zone(
self.admin_context, zone.id
)
recordsets = self.central_service.find_recordsets(
self.admin_context,
criterion={'zone_id': zone.id, 'type': 'A'}
)
# Ensure that serial is now correct.
self.assertEqual(zone_serial + 5, updated_zone.serial)
self.assertFalse(updated_zone.increment_serial)
# But the zone is still in pending status as we haven't notified
# the upstream dns servers yet.
self.assertEqual('PENDING', updated_zone.status)
for recordset in recordsets:
self.assertEqual('PENDING', recordset.status)
for record in recordset.records:
self.assertEqual('PENDING', record.status)
# Trigger update_status (Producer -> Worker -> Central).
# This happens after the upstream DNS servers have been notified
# and updated.
self.central_service.update_status(
self.admin_context, zone.id, 'SUCCESS', updated_zone.serial
)
updated_zone = self.central_service.get_zone(
self.admin_context, zone.id
)
recordsets = self.central_service.find_recordsets(
self.admin_context,
criterion={'zone_id': zone.id, 'type': 'A'}
)
# Validate that the status is now ACTIVE.
self.assertEqual('ACTIVE', updated_zone.status)
self.assertEqual(zone_serial + 5, updated_zone.serial)
for recordset in recordsets:
self.assertEqual('ACTIVE', recordset.status)
for record in recordset.records:
self.assertEqual('ACTIVE', record.status)

View File

@ -15,6 +15,7 @@
# under the License. # under the License.
import datetime import datetime
from unittest import mock
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import timeutils from oslo_utils import timeutils
@ -24,6 +25,7 @@ from designate.storage.impl_sqlalchemy import tables
from designate.storage import sql from designate.storage import sql
from designate.tests import fixtures from designate.tests import fixtures
from designate.tests import TestCase from designate.tests import TestCase
from designate.worker import rpcapi as worker_api
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -39,7 +41,7 @@ class DeletedZonePurgeTest(TestCase):
self.config( self.config(
time_threshold=self.time_threshold, time_threshold=self.time_threshold,
batch_size=self.batch_size, batch_size=self.batch_size,
group="producer_task:zone_purge" group='producer_task:zone_purge'
) )
self.purge_task_fixture = self.useFixture( self.purge_task_fixture = self.useFixture(
fixtures.ZoneManagerTaskFixture(tasks.DeletedZonePurgeTask) fixtures.ZoneManagerTaskFixture(tasks.DeletedZonePurgeTask)
@ -77,7 +79,7 @@ class DeletedZonePurgeTest(TestCase):
age = index * (self.time_threshold // self.number_of_zones * 2) - 1 age = index * (self.time_threshold // self.number_of_zones * 2) - 1
delta = datetime.timedelta(seconds=age) delta = datetime.timedelta(seconds=age)
deletion_time = now - delta deletion_time = now - delta
name = "example%d.org." % index name = 'example%d.org.' % index
self._create_deleted_zone(name, deletion_time) self._create_deleted_zone(name, deletion_time)
def test_purge_zones(self): def test_purge_zones(self):
@ -101,9 +103,8 @@ class PeriodicGenerateDelayedNotifyTaskTest(TestCase):
super(PeriodicGenerateDelayedNotifyTaskTest, self).setUp() super(PeriodicGenerateDelayedNotifyTaskTest, self).setUp()
self.config(quota_zones=self.number_of_zones) self.config(quota_zones=self.number_of_zones)
self.config( self.config(
interval=1,
batch_size=self.batch_size, batch_size=self.batch_size,
group="producer_task:delayed_notify" group='producer_task:delayed_notify'
) )
self.generate_delayed_notify_task_fixture = self.useFixture( self.generate_delayed_notify_task_fixture = self.useFixture(
fixtures.ZoneManagerTaskFixture( fixtures.ZoneManagerTaskFixture(
@ -123,7 +124,7 @@ class PeriodicGenerateDelayedNotifyTaskTest(TestCase):
def _create_zones(self): def _create_zones(self):
# Create a number of zones; half of them with delayed_notify set. # Create a number of zones; half of them with delayed_notify set.
for index in range(self.number_of_zones): for index in range(self.number_of_zones):
name = "example%d.org." % index name = 'example%d.org.' % index
delayed_notify = (index % 2 == 0) delayed_notify = (index % 2 == 0)
self.create_zone( self.create_zone(
name=name, name=name,
@ -149,3 +150,43 @@ class PeriodicGenerateDelayedNotifyTaskTest(TestCase):
remaining, len(zones), remaining, len(zones),
message='Remaining zones: %s' % zones message='Remaining zones: %s' % zones
) )
class PeriodicIncrementSerialTaskTest(TestCase):
number_of_zones = 20
batch_size = 20
def setUp(self):
super(PeriodicIncrementSerialTaskTest, self).setUp()
self.worker_api = mock.Mock()
mock.patch.object(worker_api.WorkerAPI, 'get_instance',
return_value=self.worker_api).start()
self.config(quota_zones=self.number_of_zones)
self.config(
batch_size=self.batch_size,
group='producer_task:increment_serial'
)
self.increment_serial_task_fixture = self.useFixture(
fixtures.ZoneManagerTaskFixture(
tasks.PeriodicIncrementSerialTask
)
)
def _create_zones(self):
for index in range(self.number_of_zones):
name = 'example%d.org.' % index
increment_serial = (index % 2 == 0)
delayed_notify = (index % 4 == 0)
self.create_zone(
name=name,
increment_serial=increment_serial,
delayed_notify=delayed_notify,
)
def test_increment_serial(self):
self._create_zones()
self.increment_serial_task_fixture.task()
self.worker_api.update_zone.assert_called()
self.assertEqual(5, self.worker_api.update_zone.call_count)

View File

@ -113,6 +113,7 @@ class SqlalchemyStorageTest(StorageTestCase, TestCase):
}, },
"zones": { "zones": {
"delayed_notify": "CREATE INDEX delayed_notify ON zones (delayed_notify)", # noqa "delayed_notify": "CREATE INDEX delayed_notify ON zones (delayed_notify)", # noqa
"increment_serial": "CREATE INDEX increment_serial ON zones (increment_serial)", # noqa
"reverse_name_deleted": "CREATE INDEX reverse_name_deleted ON zones (reverse_name, deleted)", # noqa "reverse_name_deleted": "CREATE INDEX reverse_name_deleted ON zones (reverse_name, deleted)", # noqa
"zone_created_at": "CREATE INDEX zone_created_at ON zones (created_at)", # noqa "zone_created_at": "CREATE INDEX zone_created_at ON zones (created_at)", # noqa
"zone_deleted": "CREATE INDEX zone_deleted ON zones (deleted)", "zone_deleted": "CREATE INDEX zone_deleted ON zones (deleted)",

View File

@ -31,7 +31,9 @@ from designate import context
from designate.producer import tasks from designate.producer import tasks
from designate import rpc from designate import rpc
from designate.tests.unit import RoObject from designate.tests.unit import RoObject
from designate.tests.unit import RwObject
from designate.utils import generate_uuid from designate.utils import generate_uuid
from designate.worker import rpcapi as worker_api
DUMMY_TASK_GROUP = cfg.OptGroup( DUMMY_TASK_GROUP = cfg.OptGroup(
name='producer_task:dummy', name='producer_task:dummy',
@ -244,3 +246,82 @@ class PeriodicSecondaryRefreshTest(oslotest.base.BaseTestCase):
self.task() self.task()
self.assertFalse(self.central.xfr_zone.called) self.assertFalse(self.central.xfr_zone.called)
class PeriodicIncrementSerialTest(oslotest.base.BaseTestCase):
def setUp(self):
super(PeriodicIncrementSerialTest, self).setUp()
self.useFixture(cfg_fixture.Config(CONF))
self.central_api = mock.Mock()
self.context = mock.Mock()
self.worker_api = mock.Mock()
mock.patch.object(worker_api.WorkerAPI, 'get_instance',
return_value=self.worker_api).start()
mock.patch.object(central_api.CentralAPI, 'get_instance',
return_value=self.central_api).start()
mock.patch.object(context.DesignateContext, 'get_admin_context',
return_value=self.context).start()
self.central_api.increment_zone_serial.return_value = 123
self.task = tasks.PeriodicIncrementSerialTask()
self.task.my_partitions = 0, 9
def test_increment_zone(self):
zone = RoObject(
id=generate_uuid(),
action='CREATE',
increment_serial=True,
delayed_notify=False,
)
self.central_api.find_zones.return_value = [zone]
self.task()
self.central_api.increment_zone_serial.assert_called()
self.worker_api.update_zone.assert_called()
def test_increment_zone_with_action_none(self):
zone = RwObject(
id=generate_uuid(),
action='NONE',
status='ACTIVE',
increment_serial=True,
delayed_notify=False,
)
self.central_api.find_zones.return_value = [zone]
self.task()
self.central_api.increment_zone_serial.assert_called()
self.worker_api.update_zone.assert_called()
self.assertEqual('UPDATE', zone.action)
self.assertEqual('PENDING', zone.status)
def test_increment_zone_with_delayed_notify(self):
zone = RoObject(
id=generate_uuid(),
action='CREATE',
increment_serial=True,
delayed_notify=True,
)
self.central_api.find_zones.return_value = [zone]
self.task()
self.central_api.increment_zone_serial.assert_called()
self.worker_api.update_zone.assert_not_called()
def test_increment_zone_skip_deleted(self):
zone = RoObject(
id=generate_uuid(),
action='DELETE',
increment_serial=True,
delayed_notify=False,
)
self.central_api.find_zones.return_value = [zone]
self.task()
self.central_api.increment_zone_serial.assert_not_called()
self.worker_api.update_zone.assert_not_called()

View File

@ -179,6 +179,7 @@ class MockRecordSet(object):
ttl = 1 ttl = 1
type = "PRIMARY" type = "PRIMARY"
serial = 123 serial = 123
records = []
def obj_attr_is_set(self, n): def obj_attr_is_set(self, n):
if n == 'records': if n == 'records':
@ -418,8 +419,9 @@ class CentralServiceTestCase(CentralBasic):
central_service._is_valid_ttl = mock.Mock() central_service._is_valid_ttl = mock.Mock()
central_service.storage.create_recordset = mock.Mock(return_value='rs') central_service.storage.create_recordset = mock.Mock(return_value='rs')
central_service._update_zone_in_storage = mock.Mock() central_service._update_zone_in_storage = mock.Mock(
return_value=Mockzone()
)
recordset = mock.Mock(spec=objects.RecordSet) recordset = mock.Mock(spec=objects.RecordSet)
recordset.obj_attr_is_set.return_value = True recordset.obj_attr_is_set.return_value = True
recordset.records = [MockRecord()] recordset.records = [MockRecord()]
@ -440,7 +442,9 @@ class CentralServiceTestCase(CentralBasic):
self.service._is_valid_ttl = mock.Mock() self.service._is_valid_ttl = mock.Mock()
self.service.storage.create_recordset = mock.Mock(return_value='rs') self.service.storage.create_recordset = mock.Mock(return_value='rs')
self.service._update_zone_in_storage = mock.Mock() self.service._update_zone_in_storage = mock.Mock(
return_value=Mockzone()
)
# NOTE(thirose): Since this is a race condition we assume that # NOTE(thirose): Since this is a race condition we assume that
# we will hit it if we try to do the operations in a loop 100 times. # we will hit it if we try to do the operations in a loop 100 times.
@ -1506,8 +1510,6 @@ class CentralZoneTestCase(CentralBasic):
self.service.delete_recordset(self.context, self.service.delete_recordset(self.context,
CentralZoneTestCase.zone__id_2, CentralZoneTestCase.zone__id_2,
CentralZoneTestCase.recordset__id) CentralZoneTestCase.recordset__id)
self.assertTrue(
self.service.worker_api.update_zone.called)
self.assertTrue( self.assertTrue(
self.service._delete_recordset_in_storage.called) self.service._delete_recordset_in_storage.called)
@ -1524,6 +1526,7 @@ class CentralZoneTestCase(CentralBasic):
action='', action='',
status='', status='',
serial=0, serial=0,
increment_serial=False,
) )
]) ])
) )
@ -1533,7 +1536,7 @@ class CentralZoneTestCase(CentralBasic):
self.assertEqual(1, len(rs.records)) self.assertEqual(1, len(rs.records))
self.assertEqual('DELETE', rs.records[0].action) self.assertEqual('DELETE', rs.records[0].action)
self.assertEqual('PENDING', rs.records[0].status) self.assertEqual('PENDING', rs.records[0].status)
self.assertEqual(1, rs.records[0].serial) self.assertTrue(rs.records[0].serial, 1)
def test_delete_recordset_in_storage_no_increment_serial(self): def test_delete_recordset_in_storage_no_increment_serial(self):
self.service._update_zone_in_storage = mock.Mock() self.service._update_zone_in_storage = mock.Mock()

View File

@ -16,7 +16,6 @@ import jinja2
from oslo_concurrency import processutils from oslo_concurrency import processutils
from oslo_config import cfg from oslo_config import cfg
from oslo_config import fixture as cfg_fixture from oslo_config import fixture as cfg_fixture
from oslo_utils import timeutils
import oslotest.base import oslotest.base
from designate import exceptions from designate import exceptions
@ -213,22 +212,6 @@ class TestUtils(oslotest.base.BaseTestCase):
self.assertEqual('Hello World', result) self.assertEqual('Hello World', result)
@mock.patch.object(timeutils, 'utcnow_ts')
def test_increment_serial_lower_than_ts(self, mock_utcnow_ts):
mock_utcnow_ts.return_value = 1561698354
ret_serial = utils.increment_serial(serial=1)
self.assertEqual(1561698354, ret_serial)
@mock.patch.object(timeutils, 'utcnow_ts')
def test_increment_serial_higher_than_ts(self, mock_utcnow_ts):
mock_utcnow_ts.return_value = 1561698354
ret_serial = utils.increment_serial(serial=1561698354 * 2)
self.assertEqual(1561698354 * 2 + 1, ret_serial)
def test_is_uuid_like(self): def test_is_uuid_like(self):
self.assertTrue( self.assertTrue(
utils.is_uuid_like('ce9fcd6b-d546-4397-8a49-8ceaec37cb64') utils.is_uuid_like('ce9fcd6b-d546-4397-8a49-8ceaec37cb64')

View File

@ -26,7 +26,6 @@ from oslo_config import cfg
from oslo_log import log as logging from oslo_log import log as logging
from oslo_serialization import jsonutils from oslo_serialization import jsonutils
from oslo_utils.netutils import is_valid_ipv6 from oslo_utils.netutils import is_valid_ipv6
from oslo_utils import timeutils
from oslo_utils import uuidutils from oslo_utils import uuidutils
import pkg_resources import pkg_resources
@ -119,16 +118,6 @@ def execute(*cmd, **kw):
root_helper=root_helper, **kw) root_helper=root_helper, **kw)
def increment_serial(serial=0):
# This provides for *roughly* unix timestamp based serial numbers
new_serial = timeutils.utcnow_ts()
if new_serial <= serial:
new_serial = serial + 1
return new_serial
def deep_dict_merge(a, b): def deep_dict_merge(a, b):
if not isinstance(b, dict): if not isinstance(b, dict):
return b return b

View File

@ -26,7 +26,6 @@ from oslo_utils import timeutils
from designate import dnsutils from designate import dnsutils
from designate import exceptions from designate import exceptions
from designate import objects from designate import objects
from designate import utils
from designate.worker.tasks import base from designate.worker.tasks import base
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@ -795,11 +794,10 @@ class RecoverShard(base.Task):
# Include things that have been hanging out in PENDING # Include things that have been hanging out in PENDING
# status for longer than they should # status for longer than they should
# Generate the current serial, will provide a UTC Unix TS. # Generate the current serial, will provide a UTC Unix TS.
current = utils.increment_serial()
stale_criterion = { stale_criterion = {
'shard': "BETWEEN %s,%s" % (self.begin_shard, self.end_shard), 'shard': "BETWEEN %s,%s" % (self.begin_shard, self.end_shard),
'status': 'PENDING', 'status': 'PENDING',
'serial': "<%s" % (current - self.max_prop_time) 'serial': "<%s" % (timeutils.utcnow_ts() - self.max_prop_time)
} }
stale_zones = self.storage.find_zones(self.context, stale_criterion) stale_zones = self.storage.find_zones(self.context, stale_criterion)

View File

@ -0,0 +1,6 @@
---
features:
- |
Moved zone serial updates to a `designate-producer` task called
`increment_serial` to fix race conditions and to reduce the number of
updates to the upstream DNS servers when performing multiple DNS updates.

View File

@ -118,6 +118,7 @@ designate.producer_tasks =
periodic_exists = designate.producer.tasks:PeriodicExistsTask periodic_exists = designate.producer.tasks:PeriodicExistsTask
periodic_secondary_refresh = designate.producer.tasks:PeriodicSecondaryRefreshTask periodic_secondary_refresh = designate.producer.tasks:PeriodicSecondaryRefreshTask
delayed_notify = designate.producer.tasks:PeriodicGenerateDelayedNotifyTask delayed_notify = designate.producer.tasks:PeriodicGenerateDelayedNotifyTask
increment_serial = designate.producer.tasks:PeriodicIncrementSerialTask
worker_periodic_recovery = designate.producer.tasks:WorkerPeriodicRecovery worker_periodic_recovery = designate.producer.tasks:WorkerPeriodicRecovery
designate.heartbeat_emitter = designate.heartbeat_emitter =