diff --git a/cyborg/common/constants.py b/cyborg/common/constants.py index dae77990..3bcac66c 100644 --- a/cyborg/common/constants.py +++ b/cyborg/common/constants.py @@ -25,15 +25,72 @@ ARQ_STATES = (ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, ARQ_UNBOUND, ARQ_BIND_FAILED, ARQ_DELETING) = ( 'Initial', 'BindStarted', 'Bound', 'Unbound', 'BindFailed', 'Deleting') + +ARQ_BIND_STAGE = (ARQ_PRE_BIND, ARQ_FINISH_BIND, + ARQ_OUFOF_BIND_FLOW) = ( + [ARQ_INITIAL, ARQ_BIND_STARTED], [ARQ_BOUND, ARQ_BIND_FAILED], + [ARQ_UNBOUND, ARQ_DELETING]) + + +ARQ_BIND_STATUS = (ARQ_BIND_STATUS_FINISH, ARQ_BIND_STATUS_FAILED) = ( + "completed", "failed") + + +ARQ_BIND_STATES_STATUS_MAP = { + ARQ_BOUND: ARQ_BIND_STATUS_FINISH, + ARQ_BIND_FAILED: ARQ_BIND_STATUS_FAILED +} + +# TODO(Shaohe): maybe we can use oslo automaton lib +# ref: https://docs.openstack.org/automaton/latest/user/examples.html +# The states in the value list can transform to the key state +ARQ_STATES_TRANSFORM_MATRIX = { + ARQ_INITIAL: [], + ARQ_BIND_STARTED: [ARQ_INITIAL, ARQ_UNBOUND], + ARQ_BOUND: [ARQ_BIND_STARTED], + ARQ_UNBOUND: [ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, ARQ_BIND_FAILED], + ARQ_BIND_FAILED: [ARQ_BIND_STARTED, ARQ_BOUND], + ARQ_DELETING: [ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, + ARQ_UNBOUND, ARQ_BIND_FAILED] +} + + # Device type DEVICE_TYPE = (DEVICE_GPU, DEVICE_FPGA, DEVICE_AICHIP) + # Attach handle type # 'TEST_PCI': used by fake driver, ignored by Nova virt driver.
ATTACH_HANDLE_TYPES = (AH_TYPE_PCI, AH_TYPE_MDEV, AH_TYPE_TEST_PCI) = ( "PCI", "MDEV", "TEST_PCI") + # Resource Class RESOURCES = { "FPGA": orc.FPGA } + + +ACCEL_SPECS = ( + ACCEL_BITSTREAM_ID, + ACCEL_FUNCTION_ID +) = ( + "accel:bitstream_id", + "accel:function_id" +) + + +SUPPORT_RESOURCES = ( + FPGA) = ( + "FPGA" +) + + +FPGA_TRAITS = ( + FPGA_FUNCTION_ID, +) = ( + "CUSTOM_FPGA_FUNCTION_ID", +) + + +RESOURCES_PREFIX = "resources:" diff --git a/cyborg/common/exception.py b/cyborg/common/exception.py index 7c05a3f3..99a8f8aa 100644 --- a/cyborg/common/exception.py +++ b/cyborg/common/exception.py @@ -225,6 +225,11 @@ class DeployableNotFound(NotFound): _msg_fmt = _("Deployable %(uuid)s could not be found.") +class DeployableNotFoundByRP(NotFound): + _msg_fmt = _("Deployable could not be found by resource provider " + "%(uuid)s.") + + class ExtArqNotFound(NotFound): _msg_fmt = _("ExtArq %(uuid)s could not be found.") diff --git a/cyborg/common/utils.py b/cyborg/common/utils.py index c60e3cfb..d200c7cc 100644 --- a/cyborg/common/utils.py +++ b/cyborg/common/utils.py @@ -15,7 +15,11 @@ """Utilities and helper functions.""" +from concurrent.futures import ThreadPoolExecutor +from functools import wraps import six +import time +import traceback from keystoneauth1 import exceptions as ks_exc from keystoneauth1 import loading as ks_loading @@ -163,3 +167,262 @@ def get_endpoint(ksa_adapter): raise ks_exc.EndpointNotFound( "Could not find requested endpoint for any of the following " "interfaces: %s" % interfaces) + + +class _Singleton(type): + """A metaclass that creates a Singleton base class when called.""" + + _instances = {} + + def __call__(cls, *args, **kwargs): + ins = cls._instances.get(cls) + if not ins or ( + hasattr(ins, "_reset") and isinstance(ins, cls) and ins._reset()): + cls._instances[cls] = super( + _Singleton, cls).__call__(*args, **kwargs) + + return cls._instances[cls] + + +class Singleton(_Singleton('SingletonMeta', (object,), {})): + """A class for 
Singleton pattern.""" + + pass + + +class ThreadWorks(Singleton): + """Passthrough method for ThreadPoolExecutor. + + It will also grab the context from the threadlocal store and add it to + the store on the new thread. This allows for continuity in logging the + context when using this method to spawn a new thread. + """ + + def __init__(self, pool_size=CONF.thread_pool_size): + """Singleton ThreadWorks init.""" + # Ref: https://pythonhosted.org/futures/ + # NOTE(Shaohe) We can let eventlet greening ThreadPoolExecutor + # eventlet.patcher.monkey_patch(os=False, socket=True, + # select=True, thread=True) + # futures = eventlet.import_patched('concurrent.futures') + # ThreadPoolExecutor = futures.ThreadPoolExecutor + self.executor = ThreadPoolExecutor(max_workers=pool_size) + self.masters = {} + + def spawn(self, func, *args, **kwargs): + """Put a job in thread pool.""" + LOG.debug("Add an async jobs. func: %s is with parameters args: %s, " + "kwargs: %s", func, args, kwargs) + future = self.executor.submit(func, *args, **kwargs) + return future + + def spawn_master(self, func, *args, **kwargs): + """Start a new thread for a job.""" + executor = ThreadPoolExecutor() + # TODO(Shaohe) every submit func should be wrapped with exception catch + job = executor.submit(func, *args, **kwargs) + # NOTE(Shaohe) shutdown should be after job submit + executor.shutdown(wait=False) + # TODO(Shaohe) we need to consider resource collection, such as the + # following code to record them with a timestamp? + # master = {tag: { + # "executor": executor, + # "job": f, + # "timestamp": time.time(), + # "timeout": timeout}} + # self.masters.update(master) + return job + + def _reset(self): + return self.executor._shutdown + + def map(self, func, *iterables, **kwargs): + """Batch for job function.""" + return self.executor.map(func, *iterables, **kwargs) + + @classmethod + def get_workers_result(cls, fs=(), **kwargs): + """get a jobs worker result.
+ + Waits workers util it finish or raise any Exception. + It will cancel the rest if one job worker fails. + If the future is cancelled before completing then CancelledError + will be raised. + + Parameters: + fs: the workers list spawn return. + timeout: Wait workers timeout, it can be an int or float. + If the worker hasn't yet completed then this method + will wait up to timeout seconds. If the worker hasn't + completed in timeout seconds, then a + concurrent.futures.TimeoutError will be raised. + If timeout is not specified or None, there is no limit + to the wait time. + return a generator which include: + result: the value returned by the job workers. + exception_info: the exception details raised from workers. + state: The work state. + """ + timeout = kwargs.get('timeout') + if timeout is not None: + end_time = timeout + time.time() + LOG.info("job timeout set as %s", timeout) + + # Yield must be hidden in closure so that the futures are submitted + # before the first iterator value is required. + def future_iterator(): + try: + # reverse to keep finishing order + fs.reverse() + while fs: + # Careful not to keep a reference to the popped future + if timeout is None: + f = fs.pop() + yield f.result(), f.exception_info(), f._state, None + else: + f = fs.pop() + yield (f.result(end_time - time.time()), + f.exception_info(), f._state, None) + except Exception as e: + err = traceback.format_exc() + LOG.error("Error during check the worker status. Exception " + "info: %s, result: %s, state: %s. Reason %s", + f.exception_info(), f._result, f._state, e.message) + if f: + yield f._result, f.exception_info(), f._state, err + finally: + # Do best to cancel remain jobs. + if fs: + LOG.info("Cancel the remained pending jobs") + for future in fs: + future.cancel() + return future_iterator() + + @classmethod + def check_workers_exception(cls, fs=(), **kwargs): + """check whether a jobs worker raise exception. + + Waits workers util it finish or raise any Exception. 
+ It will not cancel the rest if one job worker fails. As we discussed, + if the job has already started flashing the card, we shouldn't cancel + it then. + So in FPGA scenarios, that means we will let the remained FPGA program + go on, even one jobs failed. + + Parameters: + fs: the workers list spawn return. + timeout: Wait workers timeout, it can be an int or float. + If the worker hasn't yet completed then this method + will wait up to timeout seconds. If the worker hasn't + completed in timeout seconds, then a + concurrent.futures.TimeoutError will be raised. + If timeout is not specified or None, there is no limit + + return a generator which include: + exception: Return the exception raised by the workers. + exception_info: the exception details raised from workers. + result: the value returned by the job workers. + state: The work state. + to the wait time. + usage: + """ + timeout = kwargs.get('timeout') + if timeout is not None: + LOG.info("job timeout set as %s", timeout) + end_time = timeout + time.time() + + # Yield must be hidden in closure so that the futures are submitted + # before the first iterator value is required. + def exception_iterator(): + try: + # reverse to keep finishing order + fs.reverse() + while fs: + # Careful not to keep a reference to the popped future + if timeout is None: + f = fs.pop() + yield (f.exception(), f.exception_info(), + f._result, f._state) + else: + f = fs.pop() + yield (f.exception(end_time - time.time()), + f.exception_info(), f._result, f._state) + except Exception as e: + LOG.error("Error during check the worker status. Exception " + "info: %s, result: %s, state: %s. 
Reason %s", + f.exception_info(), f._result, f._state, e.message) + finally: + if fs: + LOG.info("Cancel the remained pending jobs") + for future in fs: + future.cancel() + return exception_iterator() + + +# info https://www.oreilly.com/library/view/python-cookbook/ +# 0596001673/ch14s05.html +def format_tb(tb, limit=None): + """Format traceback to a string list. + + Print the usual traceback information, followed by a listing of all the + local variables in each frame. + """ + if not tb: + return [] + tbs = ['Traceback (most recent call last):\n'] + while 1: + tbs = tbs + traceback.format_tb(tb, limit) + if not tb.tb_next: + break + tb = tb.tb_next + return tbs + + +def wrap_job_tb(msg="Reason: %s"): + """Wrap a function with an is_job tag added, and catch Exception.""" + def _wrap_job_tb(method): + @wraps(method) + def _impl(self, *args, **kwargs): + try: + output = method(self, *args, **kwargs) + except Exception as e: + LOG.error(msg, e.message) + LOG.error(traceback.format_exc()) + raise + return output + setattr(_impl, "is_job", True) + return _impl + return _wrap_job_tb + + +def factory_register(SuperClass, ClassName): + """Register a concrete class to a factory Class.""" + def decorator(Class): + # return Class + if not hasattr(SuperClass, "_factory"): + setattr(SuperClass, "_factory", {}) + SuperClass._factory[ClassName] = Class + setattr(Class, "_factory_type", ClassName) + return Class + return decorator + + +class FactoryMixin(object): + """A factory Mixin to create a concrete class.""" + + @classmethod + def factory(cls, typ, *args, **kwargs): + """factory to create a concrete class.""" + f = getattr(cls, "_factory", {}) + sclass = f.get(typ, None) + if sclass: + LOG.info("Find %s of concrete %s by %s.", + sclass.__name__, cls.__name__, typ) + return sclass + for sclass in cls.__subclasses__(): + if typ == getattr(cls, "_factory_type", None): + return sclass + else: + return cls + LOG.info("Use default %s, do not find concrete class" + "by %s.",
cls.__name__, typ) diff --git a/cyborg/conf/default.py b/cyborg/conf/default.py index 433d52ef..a408d269 100644 --- a/cyborg/conf/default.py +++ b/cyborg/conf/default.py @@ -48,6 +48,18 @@ service_opts = [ default=60, help=_('Default interval (in seconds) for running periodic ' 'tasks.')), + cfg.IntOpt( + 'thread_pool_size', + default=10, + help=_(""" + This option specifies the size of the pool of threads used by API + to do async jobs.It is possible to limit the number of concurrent + connections using this option.""")), + cfg.IntOpt( + 'bind_timeout', + default=60, + help=_(""" + This option specifies the timeout of async job for ARQ bind.""")), ] path_opts = [ diff --git a/cyborg/db/api.py b/cyborg/db/api.py index d07e8091..f31ea8d4 100644 --- a/cyborg/db/api.py +++ b/cyborg/db/api.py @@ -189,15 +189,15 @@ class Connection(object): """Delete an extarq.""" @abc.abstractmethod - def extarq_update(self, context, uuid, values): + def extarq_update(self, context, uuid, values, state_scope=None): """Update an extarq.""" @abc.abstractmethod - def extarq_list(self, context): + def extarq_list(self, context, uuid_range=None): """Get requested list of extarqs.""" @abc.abstractmethod - def extarq_get(self, context, uuid): + def extarq_get(self, context, uuid, lock=False): """Get requested extarq.""" # attach_handle diff --git a/cyborg/db/sqlalchemy/api.py b/cyborg/db/sqlalchemy/api.py index 54b63715..0b99d6f4 100644 --- a/cyborg/db/sqlalchemy/api.py +++ b/cyborg/db/sqlalchemy/api.py @@ -597,7 +597,10 @@ class Connection(api.Connection): query = model_query( context, models.Deployable).filter_by(rp_uuid=rp_uuid) - return query.one() + try: + return query.one() + except NoResultFound: + raise exception.DeployableNotFoundByRP(uuid=rp_uuid) def deployable_list(self, context): query = model_query(context, models.Deployable) @@ -871,32 +874,44 @@ class Connection(api.Connection): if count != 1: raise exception.ExtArqNotFound(uuid=uuid) - def extarq_update(self, context, uuid, 
values): + def extarq_update(self, context, uuid, values, state_scope=None): if 'uuid' in values and values['uuid'] != uuid: msg = _("Cannot overwrite UUID for an existing ExtArq.") raise exception.InvalidParameterValue(err=msg) - return self._do_update_extarq(context, uuid, values) + return self._do_update_extarq(context, uuid, values, state_scope) @oslo_db_api.retry_on_deadlock - def _do_update_extarq(self, context, uuid, values): + def _do_update_extarq(self, context, uuid, values, state_scope=None): with _session_for_write(): query = model_query(context, models.ExtArq) - query = query.filter_by(uuid=uuid) + query = query_update = query.filter_by( + uuid=uuid).with_lockmode('update') + if type(state_scope) is list: + query_update = query_update.filter( + models.ExtArq.state.in_(state_scope)) try: - ref = query.with_lockmode('update').one() + query_update.update( + values, synchronize_session="fetch") except NoResultFound: raise exception.ExtArqNotFound(uuid=uuid) - ref.update(values) + ref = query.first() return ref - def extarq_list(self, context): + def extarq_list(self, context, uuid_range=None): query = model_query(context, models.ExtArq) + if type(uuid_range) is list: + query = query.filter( + models.ExtArq.uuid.in_(uuid_range)) return _paginate_query(context, models.ExtArq, query) - def extarq_get(self, context, uuid): + @oslo_db_api.retry_on_deadlock + def extarq_get(self, context, uuid, lock=False): query = model_query( context, models.ExtArq).filter_by(uuid=uuid) + # NOTE we will support aync bind, so get query by lock + if lock: + query = query.with_for_update() try: return query.one() except NoResultFound: diff --git a/cyborg/objects/__init__.py b/cyborg/objects/__init__.py index ef7f1dd0..bafc3157 100644 --- a/cyborg/objects/__init__.py +++ b/cyborg/objects/__init__.py @@ -29,6 +29,7 @@ def register_all(): __import__('cyborg.objects.attribute') __import__('cyborg.objects.arq') __import__('cyborg.objects.ext_arq') + 
__import__('cyborg.objects.extarq.fpga_ext_arq') __import__('cyborg.objects.attach_handle') __import__('cyborg.objects.control_path') __import__('cyborg.objects.device') diff --git a/cyborg/objects/deployable.py b/cyborg/objects/deployable.py index 427511a7..46a5fffb 100644 --- a/cyborg/objects/deployable.py +++ b/cyborg/objects/deployable.py @@ -242,3 +242,11 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): return dep_obj_list[0] else: return None + + def get_cpid_list(self, context): + query_filter = {"device_id": self.device_id} + # TODO(Sundar) We should probably get cpid from objects layer, + # not db layer + cpid_list = self.dbapi.control_path_get_by_filters( + context, query_filter) + return cpid_list diff --git a/cyborg/objects/ext_arq.py b/cyborg/objects/ext_arq.py index 3c819c33..d97cee94 100644 --- a/cyborg/objects/ext_arq.py +++ b/cyborg/objects/ext_arq.py @@ -15,28 +15,28 @@ from openstack import connection from oslo_log import log as logging -from oslo_serialization import jsonutils from oslo_versionedobjects import base as object_base -from cyborg.agent.rpcapi import AgentAPI from cyborg.common import constants +from cyborg.common.constants import ARQ_STATES_TRANSFORM_MATRIX from cyborg.common import exception -from cyborg.common import nova_client -from cyborg.common import placement_client +from cyborg.common import utils from cyborg.conf import CONF from cyborg.db import api as dbapi from cyborg import objects from cyborg.objects.attach_handle import AttachHandle from cyborg.objects import base -from cyborg.objects.deployable import Deployable from cyborg.objects.device_profile import DeviceProfile +from cyborg.objects.extarq.ext_arq_job import ExtARQJobMixin from cyborg.objects import fields as object_fields + LOG = logging.getLogger(__name__) @base.CyborgObjectRegistry.register -class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat): +class ExtARQ(base.CyborgObject, 
object_base.VersionedObjectDictCompat, + utils.FactoryMixin, ExtARQJobMixin): """ExtARQ is a wrapper around ARQ with Cyborg-private fields. Each ExtARQ object contains exactly one ARQ object as a field. But, in the db layer, ExtARQ and ARQ are represented together @@ -89,21 +89,22 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat): return self @classmethod - def get(cls, context, uuid): + def get(cls, context, uuid, lock=False): """Find a DB ExtARQ and return an Obj ExtARQ.""" # TODO() Fix warnings that '' is not an UUID db_extarq = cls.dbapi.extarq_get(context, uuid) obj_arq = objects.ARQ(context) - obj_extarq = ExtARQ(context) + obj_extarq = cls(context) obj_extarq['arq'] = obj_arq obj_extarq = cls._from_db_object(obj_extarq, db_extarq, context) return obj_extarq @classmethod - def list(cls, context): + def list(cls, context, uuid_range=None): """Return a list of ExtARQ objects.""" - db_extarqs = cls.dbapi.extarq_list(context) - obj_extarq_list = cls._from_db_object_list(db_extarqs, context) + db_extarqs = cls.dbapi.extarq_list(context, uuid_range) + obj_extarq_list = cls._from_db_object_list( + db_extarqs, context) return obj_extarq_list def save(self, context): @@ -112,6 +113,34 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat): db_extarq = self.dbapi.extarq_update(context, self.arq.uuid, updates) self._from_db_object(self, db_extarq, context) + def update_state(self, context, state, scope=None): + """Update an ExtARQ state record in the DB.""" + updates = self.obj_get_changes() + updates["state"] = state + db_extarq = self.dbapi.extarq_update( + context, self.arq.uuid, updates, scope) + self._from_db_object(self, db_extarq, context) + + def update_check_state(self, context, state, scope=None): + if self.arq.state == state: + LOG.info("ExtARQ(%s) state is %s, no need to update", + self.arq.uuid, state) + return False + old = self.arq.state + scope = scope or ARQ_STATES_TRANSFORM_MATRIX[state] + 
self.update_state(context, state, scope) + ea = ExtARQ.get(context, self.arq.uuid, lock=True) + if not ea: + raise exception.ResourceNotFound( + "Can not find ExtARQ(%s)" % self.arq.uuid) + current = ea.arq.state + if state != current: + msg = ("Failed to change ARQ state from %s to %s, the current " + "state is %s" % (old, state, current)) + LOG.error(msg) + raise exception.ARQInvalidState(msg) + return True + def destroy(self, context): """Delete an ExtARQ from the DB.""" self.dbapi.extarq_delete(context, self.arq.uuid) @@ -145,203 +174,31 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat): auth_user = default_user return connection.Connection(cloud=auth_user) - def _get_bitstream_md_from_function_id(self, function_id): - """Get bitstream metadata given a function id.""" - conn = self._get_glance_connection() - properties = {'accel:function_id': function_id} - resp = conn.image.get('/images', params=properties) - if resp: - image_list = resp.json()['images'] - if type(image_list) != list: - raise exception.InvalidType( - obj='image', type=type(image_list), - expected='list') - if len(image_list) != 1: - raise exception.ExpectedOneObject(obj='image', - count=len(image_list)) - return image_list[0] - else: - LOG.warning('Failed to get image for function (%s)', - function_id) - return None + def _allocate_attach_handle(self, context, deployable): + try: + ah = AttachHandle.allocate(context, deployable.id) + self.attach_handle_id = ah.id + except Exception as e: + LOG.error("Failed to allocate attach handle for ARQ %s" + "from deployable %s. Reason: %s", + self.arq.uuid, deployable.uuid, e.message) + # TODO(Shaohe) Rollback? We have _update_placement, + # should cancel it. 
+ self.update_check_state( + context, constants.ARQ_BIND_FAILED) + raise + LOG.info('Attach handle(%s) for ARQ(%s) successfully.', + ah.uuid, self.arq.uuid) - def _get_bitstream_md_from_bitstream_id(self, bitstream_id): - """Get bitstream metadata given a bitstream id.""" - conn = self._get_glance_connection() - resp = conn.image.get('/images/' + bitstream_id) - if resp: - return resp.json() - else: - LOG.warning('Failed to get image for bitstream (%s)', - bitstream_id) - return None - - def _do_programming(self, context, hostname, - deployable, bitstream_id): - driver_name = deployable.driver_name - - query_filter = {"device_id": deployable.device_id} - # TODO() We should probably get cpid from objects layer, not db layer - cpid_list = self.dbapi.control_path_get_by_filters( - context, query_filter) - count = len(cpid_list) - if count != 1: - raise exception.ExpectedOneObject(type='controlpath_id', - count=count) - controlpath_id = cpid_list[0] - controlpath_id['cpid_info'] = jsonutils.loads( - controlpath_id['cpid_info']) - LOG.info('Found control path id: %s', controlpath_id.__dict__) - - LOG.info('Starting programming for host: (%s) deployable (%s) ' - 'bitstream_id (%s)', hostname, - deployable.uuid, bitstream_id) - agent = AgentAPI() - # TODO() do this asynchronously - # TODO() do this in the conductor - agent.fpga_program_v2(context, hostname, - controlpath_id, bitstream_id, - driver_name) - LOG.info('Finished programming for host: (%s) deployable (%s)', - hostname, deployable.uuid) - # TODO() propagate agent errors to caller - return True - - def _update_placement(self, devrp_uuid, function_id, - bitstream_md, driver_name): - placement = placement_client.PlacementClient() - placement.delete_traits_with_prefixes( - devrp_uuid, ['CUSTOM_FPGA_FUNCTION_ID']) - # TODO(Sundar) Don't apply function trait if bitstream is private - if not function_id: - function_id = bitstream_md.get('accel:function_id') - if function_id: - function_id = 
function_id.upper().replace('-', '_-') - # TODO(Sundar) Validate this is a valid trait name - # Assume driver name == vendor name for FPGA driver. - vendor = driver_name.upper() - trait_names = ['CUSTOM_FPGA_FUNCTION_ID_' + vendor + function_id] - placement.add_traits_to_rp(devrp_uuid, trait_names) - - def bind(self, context, hostname, devrp_uuid, instance_uuid): - """Given a device rp UUID, get the deployable UUID and - an attach handle. - """ - LOG.info('[arqs:objs] bind. hostname: %s, devrp_uuid: %s' - 'instance: %s', hostname, devrp_uuid, instance_uuid) - - bitstream_id = self.device_profile_group.get('accel:bitstream_id') - function_id = self.device_profile_group.get('accel:function_id') - programming_needed = (bitstream_id is not None or - function_id is not None) - if (programming_needed and - bitstream_id is not None and function_id is not None): - raise exception.InvalidParameterValue( - 'In device profile {0}, only one among bitstream_id ' - 'and function_id must be set, but both are set') - - deployable = Deployable.get_by_device_rp_uuid(context, devrp_uuid) - - # TODO() Check that deployable.device.hostname matches param hostname - - # Note(Sundar): We associate the ARQ with instance UUID before the - # programming starts. So, if programming fails and then Nova calls - # to delete all ARQs for a given instance, we can still pick all - # the relevant ARQs. - arq = self.arq - arq.hostname = hostname - arq.device_rp_uuid = devrp_uuid - arq.instance_uuid = instance_uuid - # If prog fails, we'll change the state - arq.state = constants.ARQ_BIND_STARTED - self.save(context) # ARQ changes get committed here - - if programming_needed: - LOG.info('[arqs:objs] bind. Programming needed. 
' - 'bitstream: (%s) function: (%s) Deployable UUID: (%s)', - bitstream_id or '', function_id or '', - deployable.uuid) - if bitstream_id is not None: # FPGA aaS - bitstream_md = self._get_bitstream_md_from_bitstream_id( - bitstream_id) - else: # Accelerated Function aaS - bitstream_md = self._get_bitstream_md_from_function_id( - function_id) - LOG.info('[arqs:objs] For function id (%s), got ' - 'bitstream id (%s)', function_id, - bitstream_md['id']) - bitstream_id = bitstream_md['id'] - - if deployable.bitstream_id == bitstream_id: - LOG.info('Deployable %(uuid)s already has the needed ' - 'bitstream %(stream_id)s. Skipping programming.', - {"uuid": deployable.uuid, "stream_id": bitstream_id}) - else: - ok = self._do_programming(context, hostname, - deployable, bitstream_id) - if ok: - self._update_placement(devrp_uuid, function_id, - bitstream_md, - deployable.driver_name) - deployable.update(context, {'bitstream_id': bitstream_id}) - arq.state = constants.ARQ_BOUND - else: - arq.state = constants.ARQ_BIND_FAILED - - # If programming was done, arq.state already got updated. - # If no programming was needed, transition to BOUND state. - if arq.state == constants.ARQ_BIND_STARTED: - arq.state = constants.ARQ_BOUND - - # We allocate attach handle after programming because, if latter - # fails, we need to deallocate the AH - if arq.state == constants.ARQ_BOUND: # still on happy path - try: - ah = AttachHandle.allocate(context, deployable.id) - self.attach_handle_id = ah.id - except Exception: - LOG.error("Failed to allocate attach handle for ARQ " - "%(arq_uuid)s from deployable %(deployable_uuid)s", - {"arq_uuid": arq.uuid, - "deployable_uuid": deployable.uuid}) - arq.state = constants.ARQ_BIND_FAILED - - self.arq = arq - self.save(context) # ARQ state changes get committed here - - @classmethod - def apply_patch(cls, context, patch_list, valid_fields): - """Apply JSON patch. See api/controllers/v1/arqs.py. 
""" - device_profile_name = None - instance_uuid = None - bind_action = False - status = "completed" - for arq_uuid, patch in patch_list.items(): - extarq = ExtARQ.get(context, arq_uuid) - if not device_profile_name: - device_profile_name = extarq.arq.device_profile_name - if not instance_uuid: - instance_uuid = valid_fields[arq_uuid]['instance_uuid'] - if patch[0]['op'] == 'add': # All ops are 'add' - # True if do binding, False if do unbinding. - bind_action = True - extarq.bind(context, - valid_fields[arq_uuid]['hostname'], - valid_fields[arq_uuid]['device_rp_uuid'], - valid_fields[arq_uuid]['instance_uuid']) - if extarq.arq.state == constants.ARQ_BIND_FAILED: - status = "failed" - elif extarq.arq.state == constants.ARQ_BOUND: - continue - else: - raise exception.ARQInvalidState(state=extarq.arq.state) - else: - bind_action = False - extarq.unbind(context) - if bind_action: - nova_api = nova_client.NovaAPI() - nova_api.notify_binding(instance_uuid, - device_profile_name, status) + def bind(self, context, deployable): + self._allocate_attach_handle(context, deployable) + # ARQ state changes get committed here + self.update_check_state(context, constants.ARQ_BOUND) + LOG.info('Update ARQ %s state to "Bound" successfully.', + self.arq.uuid) + # TODO(Shaohe) rollback self._unbind and self._delete + # if (self.arq.state == constants.ARQ_DELETING + # or self.arq.state == ARQ_UNBOUND): def unbind(self, context): arq = self.arq diff --git a/cyborg/objects/extarq/__init__.py b/cyborg/objects/extarq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cyborg/objects/extarq/ext_arq_job.py b/cyborg/objects/extarq/ext_arq_job.py new file mode 100644 index 00000000..7b07a129 --- /dev/null +++ b/cyborg/objects/extarq/ext_arq_job.py @@ -0,0 +1,205 @@ +# Copyright 2019 Intel Ltd. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. 
You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from oslo_log import log as logging + +from cyborg.common import constants +from cyborg.common.constants import ARQ_STATES_TRANSFORM_MATRIX +from cyborg.common import exception +from cyborg.common import nova_client +from cyborg.common import utils +from cyborg.conf import CONF +from cyborg import objects + + +LOG = logging.getLogger(__name__) + + +class ExtARQJobMixin(object): + """Mixin Class for ExtARQ async job management.""" + + def _bind_job(self, context, deployable): + """The bind process of an accelerator.""" + check_extra_job = getattr(self, "_need_extra_bind_job", None) + need_job = None + if check_extra_job: + need_job = check_extra_job(context, deployable) + if getattr(self.bind, "is_job", False) and need_job is not False: + LOG.info("Start job for ARQ(%s) bind.", self.arq.uuid) + works = utils.ThreadWorks() + job = works.spawn(self.bind, context, deployable) + return job + else: + LOG.info("ARQ(%s) bind process is instant.", self.arq.uuid) + self.bind(context, deployable) + + @classmethod + def get_suitable_ext_arq(cls, context, uuid): + """Find the suitable ExtARQ among the inherited subclasses.""" + extarq = cls.get(context, uuid) + typ, _ = extarq.get_resources_from_device_profile_group() + factory = cls.factory(typ) + if factory != cls: + return factory.get(context, uuid) + return extarq + + def start_bind_job(self, context, valid_fields): + """Check and start bind jobs for ARQ.""" + # Check whether the ARQ can be bound.
+        if (self.arq.state not in
+                ARQ_STATES_TRANSFORM_MATRIX[constants.ARQ_BIND_STARTED]):
+            raise exception.ARQInvalidState(state=self.arq.state)
+
+        hostname = valid_fields[self.arq.uuid]['hostname']
+        devrp_uuid = valid_fields[self.arq.uuid]['device_rp_uuid']
+        instance_uuid = valid_fields[self.arq.uuid]['instance_uuid']
+        LOG.info('[arqs:objs] bind. hostname: %s, devrp_uuid: %s'
+                 'instance: %s', hostname, devrp_uuid, instance_uuid)
+
+        self.arq.hostname = hostname
+        self.arq.device_rp_uuid = devrp_uuid
+        self.arq.instance_uuid = instance_uuid
+
+        # If prog fails, we'll change this ARQ state changes get committed here
+        self.update_check_state(context, constants.ARQ_BIND_STARTED)
+
+        dep = objects.Deployable.get_by_device_rp_uuid(context, devrp_uuid)
+        return self._bind_job(context, dep)
+
+    @classmethod
+    def master(cls, context, arq_binds):
+        """Start a master thread to monitor job workers.
+
+        :param context: request context, passed through to the checks.
+        :param arq_binds: dict mapping each ExtARQ being bound to its bind
+            job. A falsy job means that bind is "instant" and there is
+            nothing to wait for.
+        """
+        extarqs = arq_binds.keys()
+        arq_uuids = [ea.arq.uuid for ea in extarqs]
+        # Only the ARQs that actually returned a job have to be waited for.
+        jobs = [job for job in arq_binds.values() if job]
+        if not jobs:
+            LOG.info("All ARQ(%s) bind process are instant.", arq_uuids)
+            cls.check_bindings_result(context, extarqs)
+            return
+        th_workers = utils.ThreadWorks()
+        works_generator = th_workers.get_workers_result(
+            jobs, timeout=CONF.bind_timeout)
+        LOG.info("Check ARQ(%s) bind jobs status.", arq_uuids)
+        th_workers.spawn_master(
+            cls.job_monitor, context, works_generator, extarqs)
+
+    @classmethod
+    def check_bindings_result(cls, context, extarqs):
+        """Check the ARQ bind result and notify Nova exactly once.
+
+        The ARQs are re-read through ``cls.list`` and the overall status is
+        "failed" as soon as one of them is missing, still in a pre-bind
+        state, out of the bind flow, or in the BindFailed state.
+
+        :param context: request context.
+        :param extarqs: iterable of the ExtARQ objects that were bound
+            together; the device profile name and instance uuid are taken
+            from the first one.
+        """
+        extarqs = list(extarqs)
+        if not extarqs:
+            # Nothing was bound, so there is nothing to report.
+            return
+        # Batch get or get one by one? Maybe delete a ARQ
+        arq_uuids = [ea.arq.uuid for ea in extarqs]
+        device_profile_name = extarqs[0].arq.device_profile_name
+        instance_uuid = extarqs[0].arq.instance_uuid
+
+        extarqs = cls.list(context, arq_uuids)
+        if len(extarqs) < len(arq_uuids):
+            LOG.error("ARQs(%s) bind status sync error, status is %s. "
+                      "For some ARQs %s are deleted.",
+                      arq_uuids, constants.ARQ_BIND_STATUS_FAILED,
+                      set(arq_uuids) - {ea.arq.uuid for ea in extarqs})
+            cls.bind_notify(device_profile_name, instance_uuid,
+                            constants.ARQ_BIND_STATUS_FAILED)
+            # The failure has been reported; do not notify a second time.
+            return
+
+        status = constants.ARQ_BIND_STATUS_FINISH
+        for extarq in extarqs:
+            state = extarq.arq.state
+            uuid = extarq.arq.uuid
+            if state in constants.ARQ_PRE_BIND:
+                # OPEN ignore ARQ_OUFOF_BIND_FLOW?
+                status = constants.ARQ_BIND_STATUS_FAILED
+                LOG.error("ARQs(%s) bind has not finished, status is %s.",
+                          uuid, status)
+                break
+            # Compare against the BindFailed *state*; the "failed" status
+            # string never matches an ARQ state.
+            elif state in constants.ARQ_OUFOF_BIND_FLOW + [
+                    constants.ARQ_BIND_FAILED]:
+                # OPEN ignore ARQ_OUFOF_BIND_FLOW?
+                status = constants.ARQ_BIND_STATUS_FAILED
+                LOG.error("ARQs(%s) bind status sync error, status is %s.",
+                          uuid, status)
+                break
+            elif state == constants.ARQ_BOUND:
+                LOG.info("ARQs(%s) bind status sync finish, status is %s.",
+                         uuid, status)
+        cls.bind_notify(device_profile_name, instance_uuid, status)
+
+    @classmethod
+    @utils.wrap_job_tb("Error in ARQ bind async job_monitor. Reason: %s")
+    def job_monitor(cls, context, works_generator, extarqs):
+        """Monitor every deployable bind job, then check the bind result.
+
+        :param context: request context.
+        :param works_generator: iterable of 4-item job results; the second
+            item is an ``(exception, traceback)`` pair and the fourth an
+            error value, both logged when set.
+        :param extarqs: the ExtARQ objects whose jobs are monitored.
+        """
+        # result: f.result(), f.exception_info(), f._state
+        arq_uuids = [ea.arq.uuid for ea in extarqs]
+        LOG.info('Monitor master check ARQ %s async bind job.', arq_uuids)
+        for _, (exc, tb), _, err in works_generator:
+            msg = "".join(utils.format_tb(tb)) + str(exc) if exc else err
+            if msg:
+                LOG.error(msg)
+                # TODO(Shaohe) Rollback? Such as We have _update_placement,
+                # should cancel it.
+        LOG.info('All ARQs %s async bind jobs has finished.', arq_uuids)
+        if not arq_uuids:
+            return
+        cls.check_bindings_result(context, extarqs)
+
+    @classmethod
+    def bind_notify(cls, device_profile_name, instance_uuid, status):
+        """Notify the bind status to nova."""
+        nova_api = nova_client.NovaAPI()
+        nova_api.notify_binding(instance_uuid,
+                                device_profile_name, status)
+
+    def get_resources_from_device_profile_group(self):
+        """Parse the requested resource out of the device profile group.
+
+        Only the first ``resources:<type>`` entry of the group is used.
+
+        :returns: tuple ``(resource_type, resource_number)``; the number
+            is converted to int.
+        :raises: InvalidParameterValue if the group has no resources
+            entry, the type is not supported, or the amount is not an
+            integer.
+        """
+        group = self.device_profile_group
+        prefix = constants.RESOURCES_PREFIX
+        # example: {"resources:CUSTOM_ACCELERATOR_FPGA": "1"}
+        # Slice the prefix off: str.lstrip() strips a set of characters
+        # and could also eat the beginning of the resource name.
+        resources = [
+            (k[len(prefix):], v) for k, v in group.items()
+            if k.startswith(prefix)]
+        if not resources:
+            raise exception.InvalidParameterValue(
+                'No resources in device_profile_group: %s' % group)
+        res_type, res_num = resources[0]
+        supported = constants.SUPPORT_RESOURCES
+        if isinstance(supported, str):
+            # A single name in plain parentheses is a string, not a tuple;
+            # wrap it so that the check below is not a substring match.
+            supported = (supported,)
+        if res_type not in supported:
+            raise exception.InvalidParameterValue(
+                'Unsupport resources %s from device_profile_group: %s' %
+                (res_type, group))
+        try:
+            res_num = int(res_num)
+        except ValueError:
+            raise exception.InvalidParameterValue(
+                'Resources number is invalid in '
+                'device_profile_group: %s' % group)
+        return res_type, res_num
+
+    @classmethod
+    def apply_patch(cls, context, patch_list, valid_fields):
+        """Apply JSON patch. See api/controllers/v1/arqs.py.
+
+        :param context: request context.
+        :param patch_list: dict mapping an ARQ uuid to its list of patch
+            operations; an 'add' first operation starts a bind, anything
+            else unbinds.
+        :param valid_fields: per-ARQ bind fields handed to
+            ``start_bind_job``.
+        """
+        arq_binds = {}
+        for arq_uuid, patch in patch_list.items():
+            extarq = cls.get_suitable_ext_arq(context, arq_uuid)
+            if patch[0]['op'] == 'add':  # All ops are 'add'
+                # arq_notify_list.append(arq_uuid)
+                job = extarq.start_bind_job(context, valid_fields)
+                arq_binds[extarq] = job
+            else:
+                extarq.unbind(context, extarq)
+        if arq_binds:
+            cls.master(context, arq_binds)
diff --git a/cyborg/objects/extarq/fpga_ext_arq.py b/cyborg/objects/extarq/fpga_ext_arq.py
new file mode 100644
index 00000000..4c36bc37
--- /dev/null
+++ b/cyborg/objects/extarq/fpga_ext_arq.py
@@ -0,0 +1,244 @@
+# Copyright 2019 Intel Ltd.
+# All Rights Reserved.
+#
+#    Licensed under the Apache License, Version 2.0 (the "License"); you may
+#    not use this file except in compliance with the License. You may obtain
+#    a copy of the License at
+#
+#         http://www.apache.org/licenses/LICENSE-2.0
+#
+#    Unless required by applicable law or agreed to in writing, software
+#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+#    License for the specific language governing permissions and limitations
+#    under the License.
+
+"""
+Different accelerator handlers for conductor/agent/api/object to call.
+"""
+
+
+from openstack import connection
+from oslo_log import log as logging
+from oslo_serialization import jsonutils
+
+
+from cyborg.agent.rpcapi import AgentAPI
+from cyborg.common import constants
+from cyborg.common import exception
+from cyborg.common import placement_client
+from cyborg.common import utils
+from cyborg.objects import base
+from cyborg.objects.ext_arq import ExtARQ
+
+
+LOG = logging.getLogger(__name__)
+
+
+@utils.factory_register(ExtARQ, constants.FPGA)
+@base.CyborgObjectRegistry.register
+class FPGAExtARQ(ExtARQ):
+    """FPGA Extra ARQ.
+
+    Adds FPGA programming (by bitstream id or by function id) and the
+    matching placement trait update on top of the generic ExtARQ bind.
+    """
+
+    def _get_bitstream_id(self):
+        """Return the bitstream id of the device profile group, if any."""
+        bitstream_id = self.device_profile_group.get(
+            constants.ACCEL_BITSTREAM_ID)
+        return bitstream_id
+
+    def _get_function_id(self):
+        """Return the function id of the device profile group, if any."""
+        function_id = self.device_profile_group.get(
+            constants.ACCEL_FUNCTION_ID)
+        return function_id
+
+    def _get_bitstream_md_from_bitstream_id(self, bitstream_id):
+        """Get bitstream metadata given a bitstream id.
+
+        :returns: the decoded JSON body of the image request, or None
+            when the response is falsy.
+        """
+        conn = connection.Connection(cloud='devstack-admin')
+        resp = conn.image.get('/images/' + bitstream_id)
+        if resp:
+            return resp.json()
+        else:
+            LOG.warning('Failed to get image for bitstream (%s)',
+                        bitstream_id)
+            return None
+
+    # TODO(Shaohe) should move to spec handler.
+    def _get_bitstream_md_from_function_id(self, function_id):
+        """Get bitstream metadata given a function id.
+
+        :returns: the single image entry matching the function id, or
+            None when the response is falsy.
+        :raises: InvalidType if the 'images' value is not a list.
+        :raises: ExpectedOneObject unless exactly one image matches.
+        """
+        # TODO(Shaohe) parametrize this role in config file.
+        conn = connection.Connection(cloud='devstack-admin')
+        properties = {constants.ACCEL_FUNCTION_ID: function_id}
+        resp = conn.image.get('/images', params=properties)
+        if resp:
+            image_list = resp.json()['images']
+            if not isinstance(image_list, list):
+                raise exception.InvalidType(
+                    obj='image', type=type(image_list),
+                    expected='list')
+            if len(image_list) != 1:
+                raise exception.ExpectedOneObject(obj='image',
+                                                  count=len(image_list))
+            LOG.info('[arqs:objs] For function id (%s), got '
+                     'bitstream id (%s)', function_id,
+                     image_list[0]['id'])
+            return image_list[0]
+        else:
+            LOG.warning('Failed to get image for function (%s)',
+                        function_id)
+            return None
+
+    def _needs_programming(self, context, deployable):
+        """Tell whether the deployable has to be programmed for this ARQ.
+
+        :returns: True when a bitstream id or a function id is requested
+            and the deployable does not already hold the requested
+            bitstream; False otherwise.
+        :raises: InvalidParameterValue when both ids are set; the ARQ is
+            moved to the BindFailed state first.
+        """
+        bs_id = self._get_bitstream_id()
+        fun_id = self._get_function_id()
+        if all([bs_id, fun_id]):
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            raise exception.InvalidParameterValue(
+                'In device profile %s, only one among bitstream_id '
+                'and function_id must be set, but both are set' %
+                self.arq.device_profile_name)
+        if not any([bs_id, fun_id]):
+            # Neither id was requested, so there is nothing to program.
+            return False
+        # TODO(Shaohe) Optimize: check if deployable already has
+        # bitstream/function
+        LOG.info('[arqs:objs] bind. Programming needed. '
+                 'bitstream: (%s) function: (%s) Deployable UUID: (%s)',
+                 bs_id or '', fun_id or '', deployable.uuid)
+        # Only compare a real bitstream id. With a function-id request
+        # bs_id is None and would compare equal to a deployable whose
+        # bitstream_id is None as well, wrongly skipping the programming.
+        if bs_id and deployable.bitstream_id == bs_id:
+            LOG.info('Deployable %s already has the needed '
+                     'bitstream %s. Skipping programming.',
+                     deployable.uuid, bs_id)
+            return False
+
+        return True
+
+    def get_bitstream_md(self, context, deployable, function_id, bitstream_id):
+        """Get bitstream metadata from FPGA image.
+
+        :returns: the bitstream metadata, or None when no programming is
+            needed or when the metadata can not be fetched (the ARQ is
+            then moved to the BindFailed state).
+        """
+        LOG.info("Get bitstream metadata for deployable(uuid:%s).",
+                 deployable.uuid)
+        # TODO(Shaohe) Check that deployable.device.hostname matches param
+        # hostname out of here
+        if not self._needs_programming(context, deployable):
+            return
+
+        # FPGA aaS or Accelerated Function aaS
+        bitstream_md = (
+            self._get_bitstream_md_from_bitstream_id(bitstream_id)
+            if bitstream_id else
+            self._get_bitstream_md_from_function_id(function_id))
+        if bitstream_md:
+            LOG.info('ARQ %s get bitstream metadata:%s from image registry.',
+                     self.arq.uuid, bitstream_md)
+        else:
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            LOG.error('Can not get bitstream metadata from image registry '
+                      'for ARQ %s', self.arq.uuid)
+        return bitstream_md
+
+    def _need_extra_bind_job(self, context, deployable):
+        """An extra bind job is needed exactly when programming is."""
+        return self._needs_programming(context, deployable)
+
+    @utils.wrap_job_tb("Error during ARQ bind job. Reason: %s")
+    def bind(self, context, deployable):
+        """Program the deployable if needed, then run the base bind.
+
+        After a successful programming the function trait is updated in
+        placement and the bitstream id is recorded on the deployable.
+
+        :returns: True
+        """
+        LOG.info('Start bind jobs for ARQ(%s) with deployable(%s)',
+                 self.arq.uuid, deployable.uuid)
+        bs_id = self._get_bitstream_id()
+        fun_id = self._get_function_id()
+        bs_md = self.get_bitstream_md(context, deployable, fun_id, bs_id)
+        ok = False
+        if bs_md:
+            ok = self._do_programming(context, deployable, bs_md['id'])
+        if ok:
+            # The image may carry no function id; _update_placement copes
+            # with an empty value, so do not fail on a missing key.
+            fun_id = fun_id or bs_md.get(constants.ACCEL_FUNCTION_ID)
+            self._update_placement(context, fun_id, deployable.driver_name)
+            deployable.update(context, {'bitstream_id': bs_md['id']})
+
+        super(FPGAExtARQ, self).bind(context, deployable)
+
+        return True
+
+    def _unbind(self):
+        # TODO(Shaohe) add cancel _update_placement, unbind operation.
+        pass
+
+    def _delete(self):
+        # TODO(Shaohe) add cancel _update_placement, delete operation.
+        pass
+
+    def _update_placement(self, context, function_id, driver_name):
+        """Update the resource provider traits after programming.
+
+        Existing function-id traits are deleted from the device resource
+        provider and a new one built from the driver name and function id
+        is added. Does nothing when function_id is empty.
+
+        :raises: whatever the placement client raises; the ARQ is moved
+            to the BindFailed state before re-raising.
+        """
+        # TODO(Sundar) Don't apply function trait if bitstream is private
+        if not function_id:
+            LOG.info("Not get function id for resources provider %s.",
+                     self.arq.device_rp_uuid)
+            return
+
+        placement = placement_client.PlacementClient()
+        try:
+            placement.delete_traits_with_prefixes(
+                self.arq.device_rp_uuid, [constants.FPGA_FUNCTION_ID])
+        except Exception as e:
+            # Log the exception itself: Python 3 exceptions have no
+            # ``message`` attribute.
+            LOG.error("Failed to delete traits(%s) from resources provider "
+                      "%s. Reason: %s", constants.FPGA_FUNCTION_ID,
+                      self.arq.device_rp_uuid, e)
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            raise
+
+        # NOTE(review): '-' is replaced by '_-', which keeps a hyphen in
+        # the trait name -- confirm this is what placement expects.
+        function_id = function_id.upper().replace('-', '_-')
+        # TODO(Sundar) Validate this is a valid trait name
+        vendor = driver_name.upper()
+        trait_names = ["_".join((
+            constants.FPGA_FUNCTION_ID, vendor, function_id))]
+        try:
+            placement.add_traits_to_rp(
+                self.arq.device_rp_uuid, trait_names)
+        except Exception as e:
+            LOG.error("Failed to add traits(%s) to resources provider %s. "
+                      "Reason: %s", trait_names,
+                      self.arq.device_rp_uuid, e)
+            # TODO(Shaohe) Rollback? We have _update_placement,
+            # should cancel it.
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            raise
+        LOG.info("Add traits(%s) to resources provider %s.",
+                 trait_names, self.arq.device_rp_uuid)
+
+    def _do_programming(self, context, deployable, bitstream_id):
+        """Program the FPGA of the deployable through the agent.
+
+        :returns: True when the agent call returned without raising.
+        :raises: ExpectedOneObject unless the deployable has exactly one
+            control path id.
+        :raises: whatever the agent call raises; in both cases the ARQ is
+            moved to the BindFailed state first.
+        """
+        hostname = self.arq.hostname
+        driver_name = deployable.driver_name
+
+        # query_filter = {"device_id": deployable.device_id}
+        # TODO(Shaohe) We should probably get cpid from objects layer,
+        # not db layer
+        cpid_list = deployable.get_cpid_list(context)
+        count = len(cpid_list)
+        if count != 1:
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            # Same keyword as the other ExpectedOneObject call of this
+            # class.
+            raise exception.ExpectedOneObject(obj='controlpath_id',
+                                              count=count)
+        controlpath_id = cpid_list[0]
+        controlpath_id['cpid_info'] = jsonutils.loads(
+            controlpath_id['cpid_info'])
+        LOG.info('Found control path id: %s', controlpath_id.__dict__)
+
+        LOG.info('Starting programming for host: (%s) deployable (%s) '
+                 'bitstream_id (%s)', hostname,
+                 deployable.uuid, bitstream_id)
+        # TODO(Shaohe) do this asynchronously, do this in conductor or agent?
+        try:
+            agent = AgentAPI()
+            agent.fpga_program_v2(context, hostname,
+                                  controlpath_id, bitstream_id,
+                                  driver_name)
+        except Exception as e:
+            self.update_check_state(
+                context, constants.ARQ_BIND_FAILED)
+            # Log the exception itself: Python 3 exceptions have no
+            # ``message`` attribute.
+            LOG.error('Failed programming for host: (%s) deployable (%s). '
+                      'Error: %s', hostname, deployable.uuid, e)
+            raise
+        LOG.info('Finished programming for host: (%s) deployable (%s)',
+                 hostname, deployable.uuid)
+        # TODO(Shaohe) propagate agent errors to caller
+        return True
diff --git a/cyborg/tests/unit/fake_extarq.py b/cyborg/tests/unit/fake_extarq.py
index a2d18f43..c924eb52 100644
--- a/cyborg/tests/unit/fake_extarq.py
+++ b/cyborg/tests/unit/fake_extarq.py
@@ -14,6 +14,7 @@
 
 from cyborg.objects import arq
 from cyborg.objects import ext_arq
+from cyborg.objects.extarq import fpga_ext_arq
 
 
 def _get_arqs_as_dict():
@@ -30,6 +31,10 @@ def _get_arqs_as_dict():
             "domain": "0",
             "function": "0"
         },
+        "device_profile_group": {
+            "trait:CUSTOM_FPGA_INTEL": "required",
+            "resources:FPGA": "1",
+            "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c"}
     }
     arqs = [  # Corresponds to 1st device profile in fake_device)profile.py
         {"uuid": "a097fefa-da62-4630-8e8b-424c0e3426dc",
@@ -59,6 +64,7 @@ def _convert_from_dict_to_obj(arq_dict):
         obj_arq[field] = arq_dict[field]
     obj_extarq = ext_arq.ExtARQ()
     obj_extarq.arq = obj_arq
+    obj_extarq.device_profile_group = arq_dict["device_profile_group"]
     return obj_extarq
 
 
@@ -68,6 +74,22 @@ def get_fake_extarq_objs():
     return obj_extarqs
 
 
+def _convert_from_dict_to_fpga_obj(arq_dict):
+    obj_arq = arq.ARQ()
+    for field in arq_dict.keys():
+        obj_arq[field] = arq_dict[field]
+    obj_extarq = fpga_ext_arq.FPGAExtARQ()
+    obj_extarq.arq = obj_arq
+    obj_extarq.device_profile_group = arq_dict["device_profile_group"]
+    return obj_extarq
+
+
+def get_fake_fpga_extarq_objs():
+    arq_list = _get_arqs_as_dict()
+    obj_extarqs = list(map(_convert_from_dict_to_fpga_obj, arq_list))
+    return obj_extarqs
+
+
 def get_fake_db_extarqs():
     db_extarqs = []
     for db_extarq in _get_arqs_as_dict():
diff --git a/cyborg/tests/unit/objects/test_extarq.py b/cyborg/tests/unit/objects/test_extarq.py
index 313d5627..e70f0f32 100644
--- a/cyborg/tests/unit/objects/test_extarq.py
+++ b/cyborg/tests/unit/objects/test_extarq.py
@@ -17,8
+17,11 @@
 import mock
 from testtools.matchers import HasLength
 
+from cyborg.common import constants
+from cyborg.common import exception
 from cyborg import objects
 from cyborg.tests.unit.db import base
+from cyborg.tests.unit import fake_deployable
 from cyborg.tests.unit import fake_extarq
 
 
@@ -28,6 +31,8 @@ class TestExtARQObject(base.DbTestCase):
         super(TestExtARQObject, self).setUp()
         self.fake_db_extarqs = fake_extarq.get_fake_db_extarqs()
         self.fake_obj_extarqs = fake_extarq.get_fake_extarq_objs()
+        self.fake_obj_fpga_extarqs = fake_extarq.get_fake_fpga_extarq_objs()
+        self.deployable_uuids = ['0acbf8d6-e02a-4394-aae3-57557d209498']
 
     @mock.patch('cyborg.objects.ExtARQ._from_db_object')
     def test_get(self, mock_from_db_obj):
@@ -71,8 +76,10 @@ class TestExtARQObject(base.DbTestCase):
     @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding')
     @mock.patch('cyborg.objects.ExtARQ.bind')
     @mock.patch('cyborg.objects.ExtARQ.get')
-    def test_apply_patch(self, mock_get, mock_bind, mock_notify_bind,
-                         mock_conn):
+    def test_apply_patch_to_bad_arq_state(
+            self, mock_get, mock_bind, mock_notify_bind, mock_conn):
+        good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[
+            constants.ARQ_BIND_STARTED]
         mock_get.return_value = obj_extarq = self.fake_obj_extarqs[0]
         uuid = obj_extarq.arq.uuid
         instance_uuid = obj_extarq.arq.instance_uuid
@@ -91,11 +98,178 @@ class TestExtARQObject(base.DbTestCase):
                  "value": instance_uuid}
             ]
         }
+
+        for state in set(constants.ARQ_STATES) - set(good_states):
+            obj_extarq.arq.state = state
+            mock_get.return_value = obj_extarq
+            self.assertRaises(
+                exception.ARQInvalidState, objects.ExtARQ.apply_patch,
+                self.context, patch_list, valid_fields)
+
+    @mock.patch('openstack.connection.Connection')
+    @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding')
+    @mock.patch('cyborg.objects.ExtARQ._allocate_attach_handle')
+    @mock.patch('cyborg.objects.ExtARQ.get')
+    @mock.patch('cyborg.objects.ExtARQ.list')
+    @mock.patch('cyborg.objects.ExtARQ.update_check_state')
+    @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid')
+    def test_apply_patch_for_common_extarq(
+            self, mock_get_dep, mock_check_state, mock_list, mock_get,
+            mock_attach_handle, mock_notify_bind, mock_conn):
+
+        good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[
+            constants.ARQ_BIND_STARTED]
+        obj_extarq = self.fake_obj_extarqs[0]
+        obj_extarq.arq.state = good_states[0]
+
+        # TODO(Shaohe) control the state of the ARQ in order to write a
+        # better test case, for instance:
+        # bound_extarq = copy.deepcopy(obj_extarq)
+        # bound_extarq.arq.state = constants.ARQ_BOUND
+        # mock_get.side_effect = [obj_extarq, bound_extarq]
+        mock_get.side_effect = [obj_extarq] * 2
+        mock_list.return_value = [obj_extarq]
+        uuid = obj_extarq.arq.uuid
+        instance_uuid = obj_extarq.arq.instance_uuid
+
+        dep_uuid = self.deployable_uuids[0]
+        fake_dep = fake_deployable.fake_deployable_obj(self.context,
+                                                       uuid=dep_uuid)
+        mock_get_dep.return_value = fake_dep
+        valid_fields = {
+            uuid: {'hostname': obj_extarq.arq.hostname,
+                   'device_rp_uuid': obj_extarq.arq.device_rp_uuid,
+                   'instance_uuid': instance_uuid}
+        }
+        patch_list = {
+            str(uuid): [
+                {"path": "/hostname", "op": "add",
+                 "value": obj_extarq.arq.hostname},
+                {"path": "/device_rp_uuid", "op": "add",
+                 "value": obj_extarq.arq.device_rp_uuid},
+                {"path": "/instance_uuid", "op": "add",
+                 "value": instance_uuid}
+            ]
+        }
         objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields)
-        status = 'completed'
+        # NOTE(Shaohe) the state of fake_obj_extarqs is set to ARQ_INITIAL.
+        # TODO(Shaohe) control the state of the ARQ in order to also
+        # cover the completed status.
+        status = 'failed'
         mock_notify_bind.assert_called_once_with(
             instance_uuid, obj_extarq.arq.device_profile_name, status)
 
+    @mock.patch('openstack.connection.Connection')
+    @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding')
+    @mock.patch('cyborg.objects.ExtARQ._allocate_attach_handle')
+    @mock.patch('cyborg.objects.ExtARQ.get')
+    @mock.patch('cyborg.objects.ExtARQ.list')
+    @mock.patch('cyborg.objects.ExtARQ.update_check_state')
+    @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid')
+    @mock.patch('cyborg.common.utils.ThreadWorks.spawn')
+    def test_apply_patch_start_fpga_arq_job(
+            self, mock_spawn, mock_get_dep, mock_check_state, mock_list, mock_get,
+            mock_attach_handle, mock_notify_bind, mock_conn):
+        good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[
+            constants.ARQ_BIND_STARTED]
+        obj_extarq = self.fake_obj_extarqs[0]
+        obj_fpga_extarq = self.fake_obj_fpga_extarqs[0]
+        obj_fpga_extarq.state = self.fake_obj_fpga_extarqs[0]
+        obj_extarq.arq.state = good_states[0]
+        obj_fpga_extarq.arq.state = good_states[0]
+
+        # TODO(Shaohe) control the state of the ARQ in order to write a
+        # better test case, for instance:
+        # bound_extarq = copy.deepcopy(obj_extarq)
+        # bound_extarq.arq.state = constants.ARQ_BOUND
+        # mock_get.side_effect = [obj_extarq, bound_extarq]
+        mock_get.side_effect = [obj_extarq, obj_fpga_extarq]
+        mock_list.return_value = [obj_extarq]
+        uuid = obj_extarq.arq.uuid
+        instance_uuid = obj_extarq.arq.instance_uuid
+        # mock_job_get_ext_arq.side_effect = obj_extarq
+        dep_uuid = self.deployable_uuids[0]
+        fake_dep = fake_deployable.fake_deployable_obj(self.context,
+                                                       uuid=dep_uuid)
+        mock_get_dep.return_value = fake_dep
+        mock_spawn.return_value = None
+        valid_fields = {
+            uuid: {'hostname': obj_extarq.arq.hostname,
+                   'device_rp_uuid': obj_extarq.arq.device_rp_uuid,
+                   'instance_uuid': instance_uuid}
+        }
+        patch_list = {
+            str(uuid): [
+                {"path": "/hostname", "op": "add",
+                 "value": obj_extarq.arq.hostname},
+                {"path": "/device_rp_uuid", "op": "add",
+                 "value": obj_extarq.arq.device_rp_uuid},
+                {"path": "/instance_uuid", "op": "add",
+                 "value": instance_uuid}
+            ]
+        }
+        objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields)
+        # NOTE(Shaohe) the state of fake_obj_extarqs is set to ARQ_INITIAL.
+        # TODO(Shaohe) control the state of the ARQ in order to write a
+        # better test case.
+        status = 'failed'
+        mock_notify_bind.assert_called_once_with(
+            instance_uuid, obj_extarq.arq.device_profile_name, status)
+        # NOTE(Shaohe) check that spawn was called to start a job.
+        mock_spawn.assert_called_once_with(
+            obj_fpga_extarq.bind, self.context, fake_dep)
+
+    @mock.patch('openstack.connection.Connection')
+    @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding')
+    @mock.patch('cyborg.objects.ExtARQ._allocate_attach_handle')
+    @mock.patch('cyborg.objects.ExtARQ.get')
+    @mock.patch('cyborg.objects.ExtARQ.list')
+    @mock.patch('cyborg.objects.ExtARQ.update_check_state')
+    @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid')
+    @mock.patch('cyborg.common.utils.ThreadWorks.spawn_master')
+    def test_apply_patch_fpga_arq_monitor_job(
+            self, mock_master, mock_get_dep, mock_check_state, mock_list,
+            mock_get, mock_attach_handle, mock_notify_bind, mock_conn):
+
+        good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[
+            constants.ARQ_BIND_STARTED]
+        obj_extarq = self.fake_obj_extarqs[0]
+        obj_fpga_extarq = self.fake_obj_fpga_extarqs[0]
+        obj_fpga_extarq.state = self.fake_obj_fpga_extarqs[0]
+        obj_extarq.arq.state = good_states[0]
+        obj_fpga_extarq.arq.state = good_states[0]
+
+        # TODO(Shaohe) control the state of the ARQ in order to write a
+        # better test case, for instance:
+        # bound_extarq = copy.deepcopy(obj_extarq)
+        # bound_extarq.arq.state = constants.ARQ_BOUND
+        # mock_get.side_effect = [obj_extarq, bound_extarq]
+        mock_get.side_effect = [obj_extarq, obj_fpga_extarq]
+        mock_list.return_value = [obj_extarq]
+        uuid = obj_extarq.arq.uuid
+        instance_uuid = obj_extarq.arq.instance_uuid
+        dep_uuid = self.deployable_uuids[0]
+        fake_dep = fake_deployable.fake_deployable_obj(self.context,
+                                                       uuid=dep_uuid)
+        mock_get_dep.return_value = fake_dep
+        valid_fields = {
+            uuid: {'hostname': obj_extarq.arq.hostname,
+                   'device_rp_uuid': obj_extarq.arq.device_rp_uuid,
+                   'instance_uuid': instance_uuid}
+        }
+        patch_list = {
+            str(uuid): [
+                {"path": "/hostname", "op": "add",
+                 "value": obj_extarq.arq.hostname},
+                {"path": "/device_rp_uuid", "op": "add",
+                 "value": obj_extarq.arq.device_rp_uuid},
+                {"path": "/instance_uuid", "op": "add",
+                 "value": instance_uuid}
+            ]
+        }
+        objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields)
+        mock_master.assert_called_once()
+
     @mock.patch('cyborg.objects.ExtARQ.get')
     @mock.patch('cyborg.objects.ExtARQ._from_db_object')
     def test_destroy(self, mock_from_db_obj, mock_obj_extarq):