diff --git a/designate/central/service.py b/designate/central/service.py
index d45e2f38a..9faf16edb 100644
--- a/designate/central/service.py
+++ b/designate/central/service.py
@@ -16,17 +16,20 @@
 # under the License.
 import re
 import collections
+import copy
 import functools
 import threading
 import itertools
 import string
 import random
+import time
 
 from oslo.config import cfg
 from oslo import messaging
 from oslo_log import log as logging
 from oslo_utils import excutils
 from oslo_concurrency import lockutils
+from oslo_db import exception as db_exception
 
 from designate.i18n import _LI
 from designate.i18n import _LC
@@ -46,21 +49,78 @@ from designate.pool_manager import rpcapi as pool_manager_rpcapi
 
 LOG = logging.getLogger(__name__)
 
 DOMAIN_LOCKS = threading.local()
 NOTIFICATION_BUFFER = threading.local()
+RETRY_STATE = threading.local()
+
+def _retry_on_deadlock(exc):
+    """Filter to trigger a retry when a Deadlock is received."""
+    # TODO(kiall): This is a total leak of the SQLA Driver, we'll need a better
+    #              way to handle this.
+    if isinstance(exc, db_exception.DBDeadlock):
+        LOG.warn(_LW("Deadlock detected. Retrying..."))
+        return True
+    return False
+
+
+def retry(cb=None, retries=50, delay=150):
+    """A retry decorator that ignores attempts at creating nested retries"""
+    def outer(f):
+        @functools.wraps(f)
+        def wrapper(self, *args, **kwargs):
+            if not hasattr(RETRY_STATE, 'held'):
+                # Create the state vars if necessary
+                RETRY_STATE.held = False
+                RETRY_STATE.retries = 0
+
+            if not RETRY_STATE.held:
+                # We're the outermost retry decorator
+                RETRY_STATE.held = True
+
+                try:
+                    while True:
+                        try:
+                            result = f(self, *copy.deepcopy(args),
+                                       **copy.deepcopy(kwargs))
+                            break
+                        except Exception as exc:
+                            RETRY_STATE.retries += 1
+                            if RETRY_STATE.retries >= retries:
+                                # Exceeded retry attempts, raise.
+                                raise
+                            elif cb is not None and cb(exc) is False:
+                                # We're not set up to retry on this exception.
+                                raise
+                            else:
+                                # Retry, with a delay.
+                                time.sleep(delay / float(1000))
+
+                finally:
+                    RETRY_STATE.held = False
+                    RETRY_STATE.retries = 0
+
+            else:
+                # We're an inner retry decorator, just pass on through.
+                result = f(self, *copy.deepcopy(args), **copy.deepcopy(kwargs))
+
+            return result
+        return wrapper
+    return outer
+
+
+# TODO(kiall): Get this a better home :)
 def transaction(f):
-    # TODO(kiall): Get this a better home :)
+    @retry(cb=_retry_on_deadlock)
     @functools.wraps(f)
     def wrapper(self, *args, **kwargs):
         self.storage.begin()
         try:
             result = f(self, *args, **kwargs)
+            self.storage.commit()
+            return result
         except Exception:
             with excutils.save_and_reraise_exception():
                 self.storage.rollback()
-        else:
-            self.storage.commit()
-            return result
 
     return wrapper
 
diff --git a/designate/tests/test_central/test_service.py b/designate/tests/test_central/test_service.py
index 52d10b2c7..0804aee2d 100644
--- a/designate/tests/test_central/test_service.py
+++ b/designate/tests/test_central/test_service.py
@@ -17,10 +17,11 @@
 import copy
 import random
 
-from mock import patch
-from oslo_log import log as logging
 import testtools
 from testtools.matchers import GreaterThan
+from mock import patch
+from oslo_log import log as logging
+from oslo_db import exception as db_exception
 
 from designate import exceptions
 from designate import objects
@@ -865,6 +866,39 @@ class CentralServiceTest(CentralTestCase):
         with testtools.ExpectedException(exceptions.BadRequest):
             self.central_service.update_domain(self.admin_context, domain)
 
+    def test_update_domain_deadlock_retry(self):
+        # Create a domain
+        domain = self.create_domain(name='example.org.')
+        original_serial = domain.serial
+
+        # Update the Object
+        domain.email = 'info@example.net'
+
+        # Due to Python's scoping of i - we need to make it a mutable type
+        # for the counter to work. In Py3, we can use the nonlocal keyword.
+        i = [False]
+
+        def fail_once_then_pass():
+            if i[0] is True:
+                return self.central_service.storage.session.commit()
+            else:
+                i[0] = True
+                raise db_exception.DBDeadlock()
+
+        with patch.object(self.central_service.storage, 'commit',
+                          side_effect=fail_once_then_pass):
+            # Perform the update
+            domain = self.central_service.update_domain(
+                self.admin_context, domain)
+
+        # Ensure i[0] is True, indicating the side_effect code above was
+        # triggered
+        self.assertTrue(i[0])
+
+        # Ensure the domain was updated correctly
+        self.assertTrue(domain.serial > original_serial)
+        self.assertEqual('info@example.net', domain.email)
+
     def test_delete_domain(self):
         # Create a domain
         domain = self.create_domain()
@@ -1242,6 +1276,40 @@ class CentralServiceTest(CentralTestCase):
         self.assertEqual(recordset.ttl, 1800)
         self.assertThat(new_serial, GreaterThan(original_serial))
 
+    def test_update_recordset_deadlock_retry(self):
+        # Create a domain
+        domain = self.create_domain()
+
+        # Create a recordset
+        recordset = self.create_recordset(domain)
+
+        # Update the recordset
+        recordset.ttl = 1800
+
+        # Due to Python's scoping of i - we need to make it a mutable type
+        # for the counter to work. In Py3, we can use the nonlocal keyword.
+        i = [False]
+
+        def fail_once_then_pass():
+            if i[0] is True:
+                return self.central_service.storage.session.commit()
+            else:
+                i[0] = True
+                raise db_exception.DBDeadlock()
+
+        with patch.object(self.central_service.storage, 'commit',
+                          side_effect=fail_once_then_pass):
+            # Perform the update
+            recordset = self.central_service.update_recordset(
+                self.admin_context, recordset)
+
+        # Ensure i[0] is True, indicating the side_effect code above was
+        # triggered
+        self.assertTrue(i[0])
+
+        # Ensure the recordset was updated correctly
+        self.assertEqual(1800, recordset.ttl)
+
     def test_update_recordset_with_record_create(self):
         # Create a domain
         domain = self.create_domain()
@@ -2400,16 +2468,20 @@ class CentralServiceTest(CentralTestCase):
 
         # Compare the actual values of attributes and nameservers
         for k in range(0, len(values['attributes'])):
-            self.assertEqual(
-                pool['attributes'][k].to_primitive()['designate_object.data'],
-                values['attributes'][k].to_primitive()['designate_object.data']
+            self.assertDictContainsSubset(
+                values['attributes'][k].to_primitive()
+                ['designate_object.data'],
+                pool['attributes'][k].to_primitive()
+                ['designate_object.data']
             )
 
         for k in range(0, len(values['nameservers'])):
-            self.assertEqual(
-                pool['nameservers'][k].to_primitive()['designate_object.data'],
+            self.assertDictContainsSubset(
                 values['nameservers'][k].to_primitive()
-                ['designate_object.data'])
+                ['designate_object.data'],
+                pool['nameservers'][k].to_primitive()
+                ['designate_object.data']
+            )
 
     def test_get_pool(self):
         # Create a server pool
@@ -2509,7 +2581,7 @@ class CentralServiceTest(CentralTestCase):
                               for r in nameserver_values])
 
         # Update pool
-        self.central_service.update_pool(self.admin_context, pool)
+        pool = self.central_service.update_pool(self.admin_context, pool)
 
         # GET the pool
         pool = self.central_service.get_pool(self.admin_context, pool.id)
@@ -2523,14 +2595,16 @@
             pool['attributes'][0].to_primitive()['designate_object.data']
         expected_attributes = \
             pool_attributes[0].to_primitive()['designate_object.data']
-        self.assertEqual(actual_attributes, expected_attributes)
+        self.assertDictContainsSubset(
+            expected_attributes, actual_attributes)
 
         for k in range(0, len(pool_nameservers)):
             actual_nameservers = \
                 pool['nameservers'][k].to_primitive()['designate_object.data']
             expected_nameservers = \
                 pool_nameservers[k].to_primitive()['designate_object.data']
-            self.assertEqual(actual_nameservers, expected_nameservers)
+            self.assertDictContainsSubset(
+                expected_nameservers, actual_nameservers)
 
     def test_delete_pool(self):
         # Create a server pool