Added SELECT FOR UPDATE to prevent deadlocks in tests
Closes-Bug: #1308115 Change-Id: Ib72bb7a054a13c6fadfa548bfc4640b198c47cfe
This commit is contained in:
parent
f93e782951
commit
518ccf2a36
@ -27,6 +27,7 @@ default_messages = {
|
||||
"CannotCreate": "Can't create object",
|
||||
"NotAllowed": "Action is not allowed",
|
||||
"InvalidField": "Invalid field specified for object",
|
||||
"ObjectNotFound": "Object not found in DB",
|
||||
|
||||
# node discovering errors
|
||||
"InvalidInterfacesInfo": "Invalid interfaces info",
|
||||
|
@ -63,13 +63,26 @@ class NailgunObject(object):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_by_uid(cls, uid):
|
||||
def get_by_uid(cls, uid, fail_if_not_found=False, lock_for_update=False):
|
||||
"""Get instance by it's uid (PK in case of SQLAlchemy)
|
||||
|
||||
:param uid: uid of object
|
||||
:param fail_if_not_found: raise an exception if object is not found
|
||||
:param lock_for_update: lock returned object for update (DB mutex)
|
||||
:returns: instance of an object (model)
|
||||
"""
|
||||
return db().query(cls.model).get(uid)
|
||||
q = db().query(cls.model)
|
||||
if lock_for_update:
|
||||
q = q.with_lockmode('update')
|
||||
res = q.get(uid)
|
||||
if not res and fail_if_not_found:
|
||||
raise errors.ObjectNotFound(
|
||||
"Object '{0}' with UID={1} is not found in DB".format(
|
||||
cls.__name__,
|
||||
uid
|
||||
)
|
||||
)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def create(cls, data):
|
||||
@ -199,6 +212,25 @@ class NailgunCollection(object):
|
||||
else:
|
||||
raise TypeError("First argument should be iterable")
|
||||
|
||||
@classmethod
|
||||
def lock_for_update(cls, iterable, yield_per=100):
|
||||
"""Use SELECT FOR UPDATE on a given iterable (query).
|
||||
In case if iterable=None returns all object instances
|
||||
|
||||
:param iterable: iterable (SQLAlchemy query)
|
||||
:param yield_per: SQLAlchemy's yield_per() clause
|
||||
:returns: filtered iterable (SQLAlchemy query)
|
||||
"""
|
||||
use_iterable = iterable or cls.all(yield_per=yield_per)
|
||||
if cls._is_query(use_iterable):
|
||||
return use_iterable.with_lockmode('update')
|
||||
elif cls._is_iterable(use_iterable):
|
||||
# we can't lock abstract iterable, so returning as is
|
||||
# for compatibility
|
||||
return use_iterable
|
||||
else:
|
||||
raise TypeError("First argument should be iterable")
|
||||
|
||||
@classmethod
|
||||
def get_by_id_list(cls, iterable, uid_list, yield_per=100):
|
||||
"""Filter given iterable by list of uids.
|
||||
|
@ -71,9 +71,18 @@ class Task(NailgunObject):
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def get_by_uuid(cls, uuid):
|
||||
def get_by_uuid(cls, uuid, fail_if_not_found=False, lock_for_update=False):
|
||||
# maybe consider using uuid as pk?
|
||||
return db().query(cls.model).filter_by(uuid=uuid).first()
|
||||
q = db().query(cls.model).filter_by(uuid=uuid)
|
||||
if lock_for_update:
|
||||
q = q.with_lockmode('update')
|
||||
res = q.first()
|
||||
|
||||
if not res and fail_if_not_found:
|
||||
raise errors.ObjectNotFound(
|
||||
"Task with UUID={0} is not found in DB".format(uuid)
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
class TaskCollection(NailgunCollection):
|
||||
|
@ -125,10 +125,10 @@ class NailgunReceiver(object):
|
||||
IPAddr.network.in_([n.id for n in cluster.network_groups])
|
||||
)
|
||||
map(db().delete, ips)
|
||||
db().commit()
|
||||
db().flush()
|
||||
|
||||
db().delete(cluster)
|
||||
db().commit()
|
||||
db().flush()
|
||||
|
||||
notifier.notify(
|
||||
"done",
|
||||
@ -140,7 +140,7 @@ class NailgunReceiver(object):
|
||||
elif task.status in ('error',):
|
||||
cluster.status = 'error'
|
||||
db().add(cluster)
|
||||
db().commit()
|
||||
db().flush()
|
||||
if not task.message:
|
||||
task.message = "Failed to delete nodes:\n{0}".format(
|
||||
cls._generate_error_message(
|
||||
@ -166,21 +166,35 @@ class NailgunReceiver(object):
|
||||
status = kwargs.get('status')
|
||||
progress = kwargs.get('progress')
|
||||
|
||||
task = TaskHelper.get_task_by_uuid(task_uuid)
|
||||
if not task:
|
||||
# No task found - nothing to do here, returning
|
||||
logger.warning(
|
||||
u"No task with uuid '{0}'' found - nothing changed".format(
|
||||
task_uuid
|
||||
)
|
||||
)
|
||||
return
|
||||
task = objects.Task.get_by_uuid(
|
||||
task_uuid,
|
||||
fail_if_not_found=True,
|
||||
lock_for_update=True
|
||||
)
|
||||
|
||||
# lock cluster for updating so it can't be deleted
|
||||
objects.Cluster.get_by_uid(
|
||||
task.cluster_id,
|
||||
fail_if_not_found=True,
|
||||
lock_for_update=True
|
||||
)
|
||||
|
||||
if not status:
|
||||
status = task.status
|
||||
|
||||
# lock nodes for updating so they can't be deleted
|
||||
list(
|
||||
objects.NodeCollection.lock_for_update(
|
||||
objects.NodeCollection.get_by_id_list(
|
||||
None,
|
||||
[n['uid'] for n in nodes]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# First of all, let's update nodes in database
|
||||
for node in nodes:
|
||||
node_db = db().query(Node).get(node['uid'])
|
||||
node_db = objects.Node.get_by_uid(node['uid'])
|
||||
|
||||
if not node_db:
|
||||
logger.warning(
|
||||
@ -230,10 +244,7 @@ class NailgunReceiver(object):
|
||||
)
|
||||
|
||||
db().add(node_db)
|
||||
db().commit()
|
||||
|
||||
# We should calculate task progress by nodes info
|
||||
task = TaskHelper.get_task_by_uuid(task_uuid)
|
||||
db().flush()
|
||||
|
||||
if nodes and not progress:
|
||||
progress = TaskHelper.recalculate_deployment_task_progress(task)
|
||||
|
@ -230,6 +230,7 @@ class FakeAmpqThread(FakeThread):
|
||||
resp_method = getattr(receiver, self.respond_to)
|
||||
for msg in self.message_gen():
|
||||
resp_method(**msg)
|
||||
db().commit()
|
||||
|
||||
|
||||
class FakeDeploymentThread(FakeAmpqThread):
|
||||
|
@ -672,12 +672,19 @@ class TestConsumer(BaseIntegrationTest):
|
||||
self.receiver = rcvr.NailgunReceiver()
|
||||
|
||||
def test_node_deploy_resp(self):
|
||||
node = self.env.create_node(api=False)
|
||||
node2 = self.env.create_node(api=False)
|
||||
self.env.create(
|
||||
cluster_kwargs={},
|
||||
nodes_kwargs=[
|
||||
{"api": False},
|
||||
{"api": False}]
|
||||
)
|
||||
|
||||
node, node2 = self.env.nodes
|
||||
|
||||
task = Task(
|
||||
uuid=str(uuid.uuid4()),
|
||||
name="deploy"
|
||||
name="deploy",
|
||||
cluster_id=self.env.clusters[0].id
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
@ -725,11 +732,13 @@ class TestConsumer(BaseIntegrationTest):
|
||||
self.assertEqual(task.progress, 50)
|
||||
|
||||
def test_task_progress(self):
|
||||
self.env.create_cluster()
|
||||
|
||||
task = Task(
|
||||
uuid=str(uuid.uuid4()),
|
||||
name="super",
|
||||
status="running"
|
||||
status="running",
|
||||
cluster_id=self.env.clusters[0].id
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
@ -846,7 +855,8 @@ class TestConsumer(BaseIntegrationTest):
|
||||
task = Task(
|
||||
uuid=str(uuid.uuid4()),
|
||||
name="super",
|
||||
status="running"
|
||||
status="running",
|
||||
cluster_id=self.env.clusters[0].id
|
||||
)
|
||||
self.db.add(task)
|
||||
self.db.commit()
|
||||
|
Loading…
Reference in New Issue
Block a user