Added SELECT FOR UPDATE to prevent deadlocks in tests

Closes-Bug: #1308115

Change-Id: Ib72bb7a054a13c6fadfa548bfc4640b198c47cfe
This commit is contained in:
Nikolay Markov 2014-05-14 15:05:57 +04:00
parent f93e782951
commit 518ccf2a36
6 changed files with 90 additions and 26 deletions

View File

@ -27,6 +27,7 @@ default_messages = {
"CannotCreate": "Can't create object",
"NotAllowed": "Action is not allowed",
"InvalidField": "Invalid field specified for object",
"ObjectNotFound": "Object not found in DB",
# node discovering errors
"InvalidInterfacesInfo": "Invalid interfaces info",

View File

@ -63,13 +63,26 @@ class NailgunObject(object):
)
@classmethod
def get_by_uid(cls, uid):
def get_by_uid(cls, uid, fail_if_not_found=False, lock_for_update=False):
"""Get instance by it's uid (PK in case of SQLAlchemy)
:param uid: uid of object
:param fail_if_not_found: raise an exception if object is not found
:param lock_for_update: lock returned object for update (DB mutex)
:returns: instance of an object (model)
"""
return db().query(cls.model).get(uid)
q = db().query(cls.model)
if lock_for_update:
q = q.with_lockmode('update')
res = q.get(uid)
if not res and fail_if_not_found:
raise errors.ObjectNotFound(
"Object '{0}' with UID={1} is not found in DB".format(
cls.__name__,
uid
)
)
return res
@classmethod
def create(cls, data):
@ -199,6 +212,25 @@ class NailgunCollection(object):
else:
raise TypeError("First argument should be iterable")
@classmethod
def lock_for_update(cls, iterable, yield_per=100):
"""Use SELECT FOR UPDATE on a given iterable (query).
In case if iterable=None returns all object instances
:param iterable: iterable (SQLAlchemy query)
:param yield_per: SQLAlchemy's yield_per() clause
:returns: filtered iterable (SQLAlchemy query)
"""
use_iterable = iterable or cls.all(yield_per=yield_per)
if cls._is_query(use_iterable):
return use_iterable.with_lockmode('update')
elif cls._is_iterable(use_iterable):
# we can't lock abstract iterable, so returning as is
# for compatibility
return use_iterable
else:
raise TypeError("First argument should be iterable")
@classmethod
def get_by_id_list(cls, iterable, uid_list, yield_per=100):
"""Filter given iterable by list of uids.

View File

@ -71,9 +71,18 @@ class Task(NailgunObject):
})
@classmethod
def get_by_uuid(cls, uuid):
def get_by_uuid(cls, uuid, fail_if_not_found=False, lock_for_update=False):
# maybe consider using uuid as pk?
return db().query(cls.model).filter_by(uuid=uuid).first()
q = db().query(cls.model).filter_by(uuid=uuid)
if lock_for_update:
q = q.with_lockmode('update')
res = q.first()
if not res and fail_if_not_found:
raise errors.ObjectNotFound(
"Task with UUID={0} is not found in DB".format(uuid)
)
return res
class TaskCollection(NailgunCollection):

View File

@ -125,10 +125,10 @@ class NailgunReceiver(object):
IPAddr.network.in_([n.id for n in cluster.network_groups])
)
map(db().delete, ips)
db().commit()
db().flush()
db().delete(cluster)
db().commit()
db().flush()
notifier.notify(
"done",
@ -140,7 +140,7 @@ class NailgunReceiver(object):
elif task.status in ('error',):
cluster.status = 'error'
db().add(cluster)
db().commit()
db().flush()
if not task.message:
task.message = "Failed to delete nodes:\n{0}".format(
cls._generate_error_message(
@ -166,21 +166,35 @@ class NailgunReceiver(object):
status = kwargs.get('status')
progress = kwargs.get('progress')
task = TaskHelper.get_task_by_uuid(task_uuid)
if not task:
# No task found - nothing to do here, returning
logger.warning(
u"No task with uuid '{0}'' found - nothing changed".format(
task_uuid
task = objects.Task.get_by_uuid(
task_uuid,
fail_if_not_found=True,
lock_for_update=True
)
# lock cluster for updating so it can't be deleted
objects.Cluster.get_by_uid(
task.cluster_id,
fail_if_not_found=True,
lock_for_update=True
)
return
if not status:
status = task.status
# lock nodes for updating so they can't be deleted
list(
objects.NodeCollection.lock_for_update(
objects.NodeCollection.get_by_id_list(
None,
[n['uid'] for n in nodes]
)
)
)
# First of all, let's update nodes in database
for node in nodes:
node_db = db().query(Node).get(node['uid'])
node_db = objects.Node.get_by_uid(node['uid'])
if not node_db:
logger.warning(
@ -230,10 +244,7 @@ class NailgunReceiver(object):
)
db().add(node_db)
db().commit()
# We should calculate task progress by nodes info
task = TaskHelper.get_task_by_uuid(task_uuid)
db().flush()
if nodes and not progress:
progress = TaskHelper.recalculate_deployment_task_progress(task)

View File

@ -230,6 +230,7 @@ class FakeAmpqThread(FakeThread):
resp_method = getattr(receiver, self.respond_to)
for msg in self.message_gen():
resp_method(**msg)
db().commit()
class FakeDeploymentThread(FakeAmpqThread):

View File

@ -672,12 +672,19 @@ class TestConsumer(BaseIntegrationTest):
self.receiver = rcvr.NailgunReceiver()
def test_node_deploy_resp(self):
node = self.env.create_node(api=False)
node2 = self.env.create_node(api=False)
self.env.create(
cluster_kwargs={},
nodes_kwargs=[
{"api": False},
{"api": False}]
)
node, node2 = self.env.nodes
task = Task(
uuid=str(uuid.uuid4()),
name="deploy"
name="deploy",
cluster_id=self.env.clusters[0].id
)
self.db.add(task)
self.db.commit()
@ -725,11 +732,13 @@ class TestConsumer(BaseIntegrationTest):
self.assertEqual(task.progress, 50)
def test_task_progress(self):
self.env.create_cluster()
task = Task(
uuid=str(uuid.uuid4()),
name="super",
status="running"
status="running",
cluster_id=self.env.clusters[0].id
)
self.db.add(task)
self.db.commit()
@ -846,7 +855,8 @@ class TestConsumer(BaseIntegrationTest):
task = Task(
uuid=str(uuid.uuid4()),
name="super",
status="running"
status="running",
cluster_id=self.env.clusters[0].id
)
self.db.add(task)
self.db.commit()