diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6716afa9..6de285b2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,17 +35,12 @@ repos: hooks: - id: ruff-check args: ['--fix', '--unsafe-fixes'] - - repo: https://github.com/hhatto/autopep8 - rev: v2.3.2 - hooks: - - id: autopep8 - files: '^.*\.py$' + - id: ruff-format - repo: https://opendev.org/openstack/hacking - rev: 7.0.0 + rev: 8.0.0 hooks: - id: hacking additional_dependencies: [] - exclude: '^(doc|releasenotes|tools)/.*$' - repo: https://github.com/codespell-project/codespell rev: v2.4.1 hooks: diff --git a/.zuul.yaml b/.zuul.yaml index 419195b8..0806c4bb 100644 --- a/.zuul.yaml +++ b/.zuul.yaml @@ -1,26 +1,3 @@ -- job: - name: cyborg-tox-bandit - parent: openstack-tox - timeout: 2400 - vars: - tox_envlist: bandit - required-projects: - - openstack/requirements - irrelevant-files: &gate-irrelevant-files - - ^(test-|)requirements.txt$ - - ^.*\.rst$ - - ^api-ref/.*$ - - ^cyborg/cmd/status\.py$ - - ^cyborg/hacking/.*$ - - ^cyborg/tests/functional.*$ - - ^cyborg/tests/unit.*$ - - ^doc/.*$ - - ^etc/.*$ - - ^releasenotes/.*$ - - ^setup.cfg$ - - ^tools/.*$ - - ^tox.ini$ - - project: templates: - openstack-cover-jobs @@ -32,7 +9,6 @@ jobs: - cyborg-tempest - cyborg-tempest-ipv6-only - - cyborg-tox-bandit gate: jobs: - cyborg-tempest diff --git a/HACKING.rst b/HACKING.rst index 7836e452..2b898fd4 100644 --- a/HACKING.rst +++ b/HACKING.rst @@ -7,8 +7,7 @@ Before you commit your code run tox against your patch using the command. tox . -If any of the tests fail correct the error and try again. If your code is valid -Python but not valid pep8 you may find autopep8 from pip useful. +If any of the tests fail correct the error and try again. Once you submit a patch integration tests will run and those may fail, -1'ing your patch you can make a gerrit comment 'recheck ci' if you have diff --git a/api-ref/source/conf.py b/api-ref/source/conf.py index 3d275fc8..abc9430a 100644 --- a/api-ref/source/conf.py +++ b/api-ref/source/conf.py @@ -68,6 +68,11 @@ html_theme_options = { # (source start file, target name, title, author, documentclass # [howto/manual]). latex_documents = [ - ('index', 'Cyborg.tex', 'OpenStack Acceleration API Documentation', - 'OpenStack Foundation', 'manual'), + ( + 'index', + 'Cyborg.tex', + 'OpenStack Acceleration API Documentation', + 'OpenStack Foundation', + 'manual', + ), ] diff --git a/cyborg/__init__.py b/cyborg/__init__.py index a6a23527..3c22f50d 100644 --- a/cyborg/__init__.py +++ b/cyborg/__init__.py @@ -1,4 +1,3 @@ - # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -14,5 +13,4 @@ import pbr.version -__version__ = pbr.version.VersionInfo( - 'openstack-cyborg').version_string() +__version__ = pbr.version.VersionInfo('openstack-cyborg').version_string() diff --git a/cyborg/accelerator/__init__.py b/cyborg/accelerator/__init__.py index a6a23527..3c22f50d 100644 --- a/cyborg/accelerator/__init__.py +++ b/cyborg/accelerator/__init__.py @@ -1,4 +1,3 @@ - # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at @@ -14,5 +13,4 @@ import pbr.version -__version__ = pbr.version.VersionInfo( - 'openstack-cyborg').version_string() +__version__ = pbr.version.VersionInfo('openstack-cyborg').version_string() diff --git a/cyborg/accelerator/accelerator.py b/cyborg/accelerator/accelerator.py index 99ca37cf..45dff5d1 100644 --- a/cyborg/accelerator/accelerator.py +++ b/cyborg/accelerator/accelerator.py @@ -1,4 +1,3 @@ - # Copyright 2016-2017 OpenStack Foundation # # Licensed under the Apache License, Version 2.0 (the "License"); you may @@ -15,6 +14,7 @@ from sqlalchemy import Column, Integer, String from sqlalchemy.ext.declarative import declarative_base + Base = declarative_base() diff --git a/cyborg/accelerator/common/exception.py b/cyborg/accelerator/common/exception.py index 2de52698..3cbd78a8 100644 --- a/cyborg/accelerator/common/exception.py +++ b/cyborg/accelerator/common/exception.py @@ -12,7 +12,7 @@ # License for the specific language governing permissions and limitations # under the License. -"""Accelerator base exception handling. """ +"""Accelerator base exception handling.""" import collections from http import HTTPStatus @@ -38,8 +38,10 @@ def _ensure_exception_kwargs_serializable(exc_class_name, kwargs): constructor. :returns: a dictionary of serializable keyword arguments. """ - serializers = [(jsonutils.dumps, _('when converting to JSON')), - (str, _('when converting to string'))] + serializers = [ + (jsonutils.dumps, _('when converting to JSON')), + (str, _('when converting to string')), + ] exceptions = collections.defaultdict(list) serializable_kwargs = {} for k, v in kwargs.items(): @@ -50,20 +52,29 @@ def _ensure_exception_kwargs_serializable(exc_class_name, kwargs): break except Exception as e: exceptions[k].append( - '(%(serializer_type)s) %(e_type)s: %(e_contents)s' % - {'serializer_type': msg, 'e_contents': e, - 'e_type': e.__class__.__name__}) + '(%(serializer_type)s) %(e_type)s: %(e_contents)s' + % { + 'serializer_type': msg, + 'e_contents': e, + 'e_type': e.__class__.__name__, + } + ) if exceptions: - LOG.error("One or more arguments passed to the %(exc_class)s " - "constructor as kwargs can not be serialized. The " - "serialized arguments: %(serialized)s. These " - "unserialized kwargs were dropped because of the " - "exceptions encountered during their " - "serialization:\n%(errors)s", - dict(errors=';\n'.join("%s: %s" % (k, '; '.join(v)) - for k, v in exceptions.items()), - exc_class=exc_class_name, - serialized=serializable_kwargs)) + LOG.error( + "One or more arguments passed to the %(exc_class)s " + "constructor as kwargs can not be serialized. The " + "serialized arguments: %(serialized)s. These " + "unserialized kwargs were dropped because of the " + "exceptions encountered during their " + "serialization:\n%(errors)s", + dict( + errors=';\n'.join( + "%s: %s" % (k, '; '.join(v)) for k, v in exceptions.items() + ), + exc_class=exc_class_name, + serialized=serializable_kwargs, + ), + ) # We might be able to actually put the following keys' values into # format string, but there is no guarantee, drop it just in case. for k in exceptions: @@ -81,15 +92,16 @@ class AcceleratorException(Exception): If you need to access the message from an exception you should use str(exc). """ + _msg_fmt = _("An unknown exception occurred.") code = HTTPStatus.INTERNAL_SERVER_ERROR headers = {} safe = False def __init__(self, message=None, **kwargs): - self.kwargs = _ensure_exception_kwargs_serializable( - self.__class__.__name__, kwargs) + self.__class__.__name__, kwargs + ) if 'code' not in self.kwargs: try: diff --git a/cyborg/accelerator/common/utils.py b/cyborg/accelerator/common/utils.py index 26cf00f7..7222b0ba 100644 --- a/cyborg/accelerator/common/utils.py +++ b/cyborg/accelerator/common/utils.py @@ -21,9 +21,9 @@ from oslo_serialization import jsonutils from cyborg.common import exception -_PCI_ADDRESS_PATTERN = ("^(hex{4}):(hex{2}):(hex{2}).(oct{1})$". - replace("hex", r"[\da-fA-F]"). - replace("oct", "[0-7]")) +_PCI_ADDRESS_PATTERN = "^(hex{4}):(hex{2}):(hex{2}).(oct{1})$".replace( + "hex", r"[\da-fA-F]" +).replace("oct", "[0-7]") _PCI_ADDRESS_REGEX = re.compile(_PCI_ADDRESS_PATTERN) @@ -93,11 +93,18 @@ def parse_mappings(mapping_list): raise ValueError(("Missing key in mapping: '%s'") % dev_mapping) if physnet_or_function in mapping: raise ValueError( - ("Key %(physnet_or_function)s in mapping: %(mapping)s " - "not unique") % {'physnet_or_function': physnet_or_function, - 'mapping': dev_mapping}) - mapping[physnet_or_function] = set(dev.strip() for dev in - devices.split("|") if dev.strip()) + ( + "Key %(physnet_or_function)s in mapping: %(mapping)s " + "not unique" + ) + % { + 'physnet_or_function': physnet_or_function, + 'mapping': dev_mapping, + } + ) + mapping[physnet_or_function] = set( + dev.strip() for dev in devices.split("|") if dev.strip() + ) return mapping @@ -106,13 +113,14 @@ def get_vendor_maps(): :return: vendor maps dict """ - return {"10de": "nvidia", - "102b": "matrox", - "1bd4": "inspur", - "8086": "intel", - "1099": "samsung", - "1cf2": "zte" - } + return { + "10de": "nvidia", + "102b": "matrox", + "1bd4": "inspur", + "8086": "intel", + "1099": "samsung", + "1cf2": "zte", + } def mdev_str_to_json(pci_address, asked_type, vgpu_mark): diff --git a/cyborg/accelerator/configuration.py b/cyborg/accelerator/configuration.py index cd4a3c5b..da6281b5 100644 --- a/cyborg/accelerator/configuration.py +++ b/cyborg/accelerator/configuration.py @@ -139,8 +139,9 @@ class Configuration: """ self.config_group = config_group if config_group: - self.conf = BackendGroupConfiguration(accelerator_opts, - config_group) + self.conf = BackendGroupConfiguration( + accelerator_opts, config_group + ) else: self.conf = DefaultGroupConfiguration() diff --git a/cyborg/accelerator/drivers/aichip/huawei/ascend.py b/cyborg/accelerator/drivers/aichip/huawei/ascend.py index c938cd6e..9f05a0d2 100644 --- a/cyborg/accelerator/drivers/aichip/huawei/ascend.py +++ b/cyborg/accelerator/drivers/aichip/huawei/ascend.py @@ -24,12 +24,14 @@ from cyborg.objects.driver_objects import driver_deployable from cyborg.objects.driver_objects import driver_device import cyborg.privsep -PCI_INFO_PATTERN = re.compile(r"(?P[0-9a-f]{4}:[0-9a-f]{2}:" - r"[0-9a-f]{2}\.[0-9a-f]) " - r"(?P.*) [\[].*]: (?P.*) .*" - r"[\[](?P[0-9a-fA-F]" - r"{4}):(?P[0-9a-fA-F]{4})].*" - r"[(rev ](?P[0-9a-f]{2})") +PCI_INFO_PATTERN = re.compile( + r"(?P[0-9a-f]{4}:[0-9a-f]{2}:" + r"[0-9a-f]{2}\.[0-9a-f]) " + r"(?P.*) [\[].*]: (?P.*) .*" + r"[\[](?P[0-9a-fA-F]" + r"{4}):(?P[0-9a-fA-F]{4})].*" + r"[(rev ](?P[0-9a-f]{2})" +) @cyborg.privsep.sys_admin_pctxt.entrypoint @@ -41,8 +43,9 @@ def lspci_privileged(): class AscendDriver(GenericDriver): """The class for Ascend AI Chip drivers. - This is the Huawei Ascend AI Chip drivers. + This is the Huawei Ascend AI Chip drivers. """ + VENDOR = "huawei" # TODO(yikun): can be extracted into PCIDeviceDriver @@ -106,8 +109,10 @@ class AscendDriver(GenericDriver): device.stub = False device.vendor = pci_dict["vendor_id"] device.model = pci_dict.get('model', '') - std_board_info = {'device_id': pci_dict.get('device_id', None), - 'class': pci_dict.get('class', None)} + std_board_info = { + 'device_id': pci_dict.get('device_id', None), + 'class': pci_dict.get('class', None), + } device.std_board_info = jsonutils.dumps(std_board_info) device.vendor_board_info = '' device.type = constants.DEVICE_AICHIP diff --git a/cyborg/accelerator/drivers/driver.py b/cyborg/accelerator/drivers/driver.py index b6038161..ed1083a6 100644 --- a/cyborg/accelerator/drivers/driver.py +++ b/cyborg/accelerator/drivers/driver.py @@ -15,7 +15,6 @@ import abc class GenericDriver(metaclass=abc.ABCMeta): - @abc.abstractmethod def discover(self): """Discover a specified accelerator. diff --git a/cyborg/accelerator/drivers/fake.py b/cyborg/accelerator/drivers/fake.py index 08798efd..5506fbcc 100644 --- a/cyborg/accelerator/drivers/fake.py +++ b/cyborg/accelerator/drivers/fake.py @@ -29,8 +29,9 @@ from cyborg.objects.driver_objects import driver_device class FakeDriver(GenericDriver): """Base class for Fake drivers. - This is just a Fake drivers interface. + This is just a Fake drivers interface. """ + VENDOR = "fake" NUM_ACCELERATORS = 16 @@ -76,7 +77,8 @@ class FakeDriver(GenericDriver): def _generate_dep_list(self, pci): driver_dep = driver_deployable.DriverDeployable() driver_dep.attach_handle_list = self._generate_attach_handles( - pci, self.NUM_ACCELERATORS) + pci, self.NUM_ACCELERATORS + ) # NOTE(sean-k-mooney): we need to prepend the host name to the # device name as this is used to generate the RP name and uuid in # the cyborg conductor when updating placement. As such this needs @@ -95,17 +97,19 @@ class FakeDriver(GenericDriver): fpga_list = [] pci_addr = '{"domain":"0000","bus":"0c","device":"00","function":"0"}' pci_dict = { - 'slot': pci_addr, # PCI slot address - 'device': 'FakeDevice', # Name of the device - 'vendor_id': '0xABCD', # ID of the vendor - 'class': 'Fake class', # Name of the class - 'device_id': '0xabcd' # ID of the device + 'slot': pci_addr, # PCI slot address + 'device': 'FakeDevice', # Name of the device + 'vendor_id': '0xABCD', # ID of the vendor + 'class': 'Fake class', # Name of the class + 'device_id': '0xabcd', # ID of the device } device = driver_device.DriverDevice() device.vendor = pci_dict["vendor_id"] device.model = pci_dict.get('model', 'miss model info') - std_board_info = {'device_id': pci_dict.get('device_id'), - 'class': pci_dict.get('class')} + std_board_info = { + 'device_id': pci_dict.get('device_id'), + 'class': pci_dict.get('class'), + } device.std_board_info = jsonutils.dumps(std_board_info) device.vendor_board_info = 'fake_vendor_info' device.type = constants.DEVICE_FPGA diff --git a/cyborg/accelerator/drivers/fpga/base.py b/cyborg/accelerator/drivers/fpga/base.py index 8794a5d4..9a6a7846 100644 --- a/cyborg/accelerator/drivers/fpga/base.py +++ b/cyborg/accelerator/drivers/fpga/base.py @@ -20,15 +20,14 @@ Cyborg FPGA driver implementation. from cyborg.accelerator.drivers.fpga import utils -VENDOR_MAPS = {"0x8086": "intel", - "1bd4": 'inspur'} +VENDOR_MAPS = {"0x8086": "intel", "1bd4": 'inspur'} class FPGADriver: """Base class for FPGA drivers. - This is just a virtual FPGA drivers interface. - Vendor should implement their specific drivers. + This is just a virtual FPGA drivers interface. + Vendor should implement their specific drivers. """ @classmethod diff --git a/cyborg/accelerator/drivers/fpga/inspur/driver.py b/cyborg/accelerator/drivers/fpga/inspur/driver.py index 8227ac33..fd63711c 100644 --- a/cyborg/accelerator/drivers/fpga/inspur/driver.py +++ b/cyborg/accelerator/drivers/fpga/inspur/driver.py @@ -27,8 +27,9 @@ LOG = logging.getLogger(__name__) class InspurFPGADriver(FPGADriver): """Class for Inspur FPGA drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "inspur" def __init__(self, *args, **kwargs): diff --git a/cyborg/accelerator/drivers/fpga/inspur/sysinfo.py b/cyborg/accelerator/drivers/fpga/inspur/sysinfo.py index e44bf7fb..9f12b25b 100644 --- a/cyborg/accelerator/drivers/fpga/inspur/sysinfo.py +++ b/cyborg/accelerator/drivers/fpga/inspur/sysinfo.py @@ -32,14 +32,17 @@ from cyborg.objects.driver_objects import driver_deployable from cyborg.objects.driver_objects import driver_device import cyborg.privsep -INSPUR_FPGA_FLAGS = ["Inspur Electronic Information Industry Co., Ltd.", - "Processing accelerators"] +INSPUR_FPGA_FLAGS = [ + "Inspur Electronic Information Industry Co., Ltd.", + "Processing accelerators", +] INSPUR_FPGA_INFO_PATTERN = re.compile( r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " r"(?P.*) [\[].*]: (?P.*) .*" r"[\[](?P[0-9a-fA-F]" - r"{4}):(?P[0-9a-fA-F]{4})].*") + r"{4}):(?P[0-9a-fA-F]{4})].*" +) VENDOR_ID = "1bd4" VENDOR_MAPS = {"1bd4": "inspur"} @@ -88,7 +91,8 @@ def fpga_tree(): fpga_dict = m.groupdict() # generate traits info traits = get_traits( - fpga_dict["vendor_id"], fpga_dict["product_id"]) + fpga_dict["vendor_id"], fpga_dict["product_id"] + ) fpga_dict["rc"] = constants.RESOURCES["FPGA"] fpga_dict.update(traits) fpga_list.append(_generate_driver_device(fpga_dict)) @@ -99,10 +103,13 @@ def _generate_driver_device(fpga): driver_device_obj = driver_device.DriverDevice() driver_device_obj.vendor = fpga["vendor_id"] driver_device_obj.model = fpga.get('model', 'miss model info') - std_board_info = {'product_id': fpga.get('product_id'), - 'controller': fpga.get('controller')} + std_board_info = { + 'product_id': fpga.get('product_id'), + 'controller': fpga.get('controller'), + } vendor_board_info = { - 'vendor_info': fpga.get('vendor_info', 'fpga_vb_info')} + 'vendor_info': fpga.get('vendor_info', 'fpga_vb_info') + } driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.vendor_board_info = jsonutils.dumps(vendor_board_info) driver_device_obj.type = constants.DEVICE_FPGA @@ -143,7 +150,8 @@ def _generate_attribute_list(fpga): if k == "traits": for index, val in enumerate(v): driver_attr = driver_attribute.DriverAttribute( - key="trait" + str(index), value=val) + key="trait" + str(index), value=val + ) attr_list.append(driver_attr) return attr_list diff --git a/cyborg/accelerator/drivers/fpga/intel/driver.py b/cyborg/accelerator/drivers/fpga/intel/driver.py index 4bf2f0b3..db71651e 100644 --- a/cyborg/accelerator/drivers/fpga/intel/driver.py +++ b/cyborg/accelerator/drivers/fpga/intel/driver.py @@ -45,8 +45,9 @@ def _fpga_program_privileged(cmd_args): class IntelFPGADriver(FPGADriver): """Class for Intel FPGA drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "intel" def __init__(self, *args, **kwargs): @@ -58,23 +59,27 @@ class IntelFPGADriver(FPGADriver): def program(self, controlpath_id, image_file_path): """Program the FPGA with the provided bitstream image. - TODO(Sundar): Need to handle retries. + TODO(Sundar): Need to handle retries. - :param: controlpath_id - Controlpath_id OVO - :param: image_file_path - String with the file path - :returns: True on success, False on failure + :param: controlpath_id + Controlpath_id OVO + :param: image_file_path + String with the file path + :returns: True on success, False on failure """ if controlpath_id['cpid_type'] != "PCI": - raise exception.InvalidType(obj='controlpath_id', - type=controlpath_id['cpid_type'], - expected='PCI') + raise exception.InvalidType( + obj='controlpath_id', + type=controlpath_id['cpid_type'], + expected='PCI', + ) cmd_args = [] bdf_dict = controlpath_id['cpid_info'] # fitting format to the OPAE command. - bdf = ['0x' + s for s in map(lambda x: bdf_dict[x], - ["bus", "device", "function"])] + bdf = [ + '0x' + s + for s in map(lambda x: bdf_dict[x], ["bus", "device", "function"]) + ] for i in zip(["--bus", "--device", "--function"], bdf): cmd_args.extend(i) cmd_args.append(image_file_path) diff --git a/cyborg/accelerator/drivers/fpga/intel/sysinfo.py b/cyborg/accelerator/drivers/fpga/intel/sysinfo.py index 3c46c9e4..e598acf2 100644 --- a/cyborg/accelerator/drivers/fpga/intel/sysinfo.py +++ b/cyborg/accelerator/drivers/fpga/intel/sysinfo.py @@ -17,7 +17,6 @@ Cyborg Intel FPGA driver implementation. """ - import glob import os import re @@ -44,10 +43,10 @@ DEVICE = "device" PF = "physfn" VF = "virtfn*" BDF_PATTERN = re.compile( - r"^[a-fA-F\d]{4}:[a-fA-F\d]{2}:[a-fA-F\d]{2}\.[a-fA-F\d]$") + r"^[a-fA-F\d]{4}:[a-fA-F\d]{2}:[a-fA-F\d]{2}\.[a-fA-F\d]$" +) -DEVICE_FILE_MAP = {"vendor": "vendor", - "device": "product_id"} +DEVICE_FILE_MAP = {"vendor": "vendor", "device": "product_id"} DEVICE_FILE_HANDLER = {} DEVICE_EXPOSED = ["vendor", "device"] @@ -62,15 +61,16 @@ def read_line(filename): def is_fpga(p): - infos = (read_line(os.path.join(p, "vendor")), - read_line(os.path.join(p, "device"))) + infos = ( + read_line(os.path.join(p, "vendor")), + read_line(os.path.join(p, "device")), + ) if infos in KNOWN_FPGAS: return os.path.realpath(p) def link_real_path(p): - return os.path.realpath( - os.path.join(os.path.dirname(p), os.readlink(p))) + return os.path.realpath(os.path.join(os.path.dirname(p), os.readlink(p))) # TODO(s_shogo) This function name should be reconsidered in py3 @@ -79,34 +79,43 @@ def find_fpgas_by_know_list(): return filter( lambda p: ( read_line(os.path.join(p, "vendor")), - read_line(os.path.join(p, "device")) - ) in KNOWN_FPGAS, - glob.glob(PCI_DEVICES_PATH_PATTERN)) + read_line(os.path.join(p, "device")), + ) + in KNOWN_FPGAS, + glob.glob(PCI_DEVICES_PATH_PATTERN), + ) def get_link_targets(links): return map( - lambda p: - os.path.realpath( - os.path.join(os.path.dirname(p), os.readlink(p))), - links) + lambda p: os.path.realpath( + os.path.join(os.path.dirname(p), os.readlink(p)) + ), + links, + ) def all_fpgas(): # glob.glob("/sys/class/fpga", "*") return set(get_link_targets(find_fpgas_by_know_list())) | set( - map(lambda p: p.rsplit("/", 2)[0], - get_link_targets(glob.glob(os.path.join(SYS_FPGA, "*"))))) + map( + lambda p: p.rsplit("/", 2)[0], + get_link_targets(glob.glob(os.path.join(SYS_FPGA, "*"))), + ) + ) def all_vf_fpgas(): - return [dev.rsplit("/", 2)[0] for dev in - glob.glob(os.path.join(SYS_FPGA, "*/device/physfn"))] + return [ + dev.rsplit("/", 2)[0] + for dev in glob.glob(os.path.join(SYS_FPGA, "*/device/physfn")) + ] def all_pfs_have_vf(): - return list(filter(lambda p: glob.glob(os.path.join(p, "virtfn0")), - all_fpgas())) + return list( + filter(lambda p: glob.glob(os.path.join(p, "virtfn0")), all_fpgas()) + ) def target_symbolic_map(): @@ -121,13 +130,13 @@ def bdf_path_map(): def all_vfs_in_pf_fpgas(pf_path): - return get_link_targets( - glob.glob(os.path.join(pf_path, "virtfn*"))) + return get_link_targets(glob.glob(os.path.join(pf_path, "virtfn*"))) def all_pf_fpgas(): - return filter(lambda p: glob.glob(os.path.join(p, "sriov_totalvfs")), - all_fpgas()) + return filter( + lambda p: glob.glob(os.path.join(p, "sriov_totalvfs")), all_fpgas() + ) def is_bdf(bdf): @@ -146,9 +155,13 @@ def get_afu_ids(device_name): read_line, glob.glob( os.path.join( - PCI_DEVICES_PATH_PATTERN, "fpga", - device_name, "intel-fpga-port.*", "afu_id") - ) + PCI_DEVICES_PATH_PATTERN, + "fpga", + device_name, + "intel-fpga-port.*", + "afu_id", + ) + ), ) @@ -157,9 +170,9 @@ def get_region_ids(device_name): read_line, glob.glob( os.path.join( - SYS_FPGA, device_name, - "intel-fpga-fme.*", "pr/interface_id") - ) + SYS_FPGA, device_name, "intel-fpga-fme.*", "pr/interface_id" + ) + ), ) @@ -187,14 +200,16 @@ def fpga_device(path): infos = {} # NOTE "In 3.x, os.path.walk is removed in favor of os.walk." - for (dirpath, dirnames, filenames) in os.walk(path): + for dirpath, dirnames, filenames in os.walk(path): for filename in filenames: if filename in DEVICE_EXPOSED: key = DEVICE_FILE_MAP.get(filename) or filename if key in DEVICE_FILE_HANDLER and callable( - DEVICE_FILE_HANDLER(key)): + DEVICE_FILE_HANDLER(key) + ): infos[key] = DEVICE_FILE_HANDLER(key)( - os.path.join(dirpath, filename)) + os.path.join(dirpath, filename) + ) else: infos[key] = read_line(os.path.join(dirpath, filename)) return infos @@ -204,9 +219,12 @@ def fpga_tree(): def gen_fpga_infos(path, vf=True): bdf = get_bdf_by_path(path) names = glob.glob1(os.path.join(path, "fpga"), "*") - fpga = {"type": constants.DEVICE_FPGA, - "devices": bdf, "stub": True, - "name": "_".join((socket.gethostname(), bdf))} + fpga = { + "type": constants.DEVICE_FPGA, + "devices": bdf, + "stub": True, + "name": "_".join((socket.gethostname(), bdf)), + } d_info = fpga_device(path) fpga.update(d_info) if names: @@ -238,8 +256,9 @@ def _generate_driver_device(fpga, pf_has_vf): driver_device_obj.vendor = fpga["vendor"] driver_device_obj.stub = fpga["stub"] driver_device_obj.model = fpga.get('model', "miss_model_info") - driver_device_obj.vendor_board_info = fpga.get('vendor_board_info', - "miss_vb_info") + driver_device_obj.vendor_board_info = fpga.get( + 'vendor_board_info', "miss_vb_info" + ) std_board_info = {'product_id': fpga.get('product_id')} driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.type = fpga["type"] @@ -262,8 +281,7 @@ def _generate_dep_list(fpga, pf_has_vf): # pf without sriov enabled. if not pf_has_vf: driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = \ - [_generate_attach_handle(fpga)] + driver_dep.attach_handle_list = [_generate_attach_handle(fpga)] driver_dep.name = fpga["name"] driver_dep.driver_name = DRIVER_NAME # pf with sriov enabled, may have several regions and several vfs. @@ -272,8 +290,7 @@ def _generate_dep_list(fpga, pf_has_vf): driver_dep.num_accelerators = len(fpga["regions"]) for vf in fpga["regions"]: # Only vfs in regions can be attach, no pf. - driver_dep.attach_handle_list.append( - _generate_attach_handle(vf)) + driver_dep.attach_handle_list.append(_generate_attach_handle(vf)) driver_dep.name = vf["name"] driver_dep.driver_name = DRIVER_NAME return [driver_dep] diff --git a/cyborg/accelerator/drivers/fpga/xilinx/driver.py b/cyborg/accelerator/drivers/fpga/xilinx/driver.py index 0dc1265b..1c642789 100644 --- a/cyborg/accelerator/drivers/fpga/xilinx/driver.py +++ b/cyborg/accelerator/drivers/fpga/xilinx/driver.py @@ -16,6 +16,7 @@ """ Cyborg Xilinx FPGA driver implementation. """ + from oslo_concurrency import processutils from cyborg.accelerator.drivers.fpga.base import FPGADriver @@ -35,8 +36,9 @@ def _fpga_program_privileged(cmd_args): class XilinxFPGADriver(FPGADriver): """Class for Xilinx FPGA drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "xilinx" def __init__(self, *args, **kwargs): @@ -53,15 +55,24 @@ class XilinxFPGADriver(FPGADriver): :returns: True on success, False on failure """ if controlpath_id['cpid_type'] != "PCI": - raise exception.InvalidType(obj='controlpath_id', - type=controlpath_id['cpid_type'], - expected='PCI') + raise exception.InvalidType( + obj='controlpath_id', + type=controlpath_id['cpid_type'], + expected='PCI', + ) cmd_args = ['program'] cmd_args.append('--device') bdf_dict = controlpath_id['cpid_info'] # BDF format: domain:bus:device:function - bdf = ':'.join([s for s in map(lambda x: bdf_dict[x], - ['domain', 'bus', 'device', 'function'])]) + bdf = ':'.join( + [ + s + for s in map( + lambda x: bdf_dict[x], + ['domain', 'bus', 'device', 'function'], + ) + ] + ) cmd_args.append(bdf) cmd_args.append('--base') cmd_args.append('--image') diff --git a/cyborg/accelerator/drivers/fpga/xilinx/sysinfo.py b/cyborg/accelerator/drivers/fpga/xilinx/sysinfo.py index c954ddad..cc61459d 100644 --- a/cyborg/accelerator/drivers/fpga/xilinx/sysinfo.py +++ b/cyborg/accelerator/drivers/fpga/xilinx/sysinfo.py @@ -34,15 +34,15 @@ from cyborg.privsep import sys_admin_pctxt LOG = logging.getLogger(__name__) -XILINX_FPGA_FLAGS = ["Xilinx Corporation Device", - "Processing accelerators"] +XILINX_FPGA_FLAGS = ["Xilinx Corporation Device", "Processing accelerators"] XILINX_FPGA_INFO_PATTERN = re.compile( r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " r"(?P.*) [\[].*]: (?P.*) .*" r"[\[](?P[0-9a-fA-F]" - r"{4}):(?P[0-9a-fA-F]{4})].*") + r"{4}):(?P[0-9a-fA-F]{4})].*" +) XILINX_PF_MAPS = {"mgmt": "xclmgmt", "user": "xocl"} @@ -101,13 +101,16 @@ def _combine_device_by_pci_func(pci_devices): for fpga in fpga_devices: existed_addr = fpga.get('pci_addr')[0] # compare domain:bus:slot - if existed_addr and \ - new_addr.split('.')[0] == existed_addr.split('.')[0]: + if ( + existed_addr + and new_addr.split('.')[0] == existed_addr.split('.')[0] + ): fpga.update({'pci_addr': [existed_addr, new_addr]}) is_existed = True if not is_existed: - traits = _generate_traits(pci_dict["vendor_id"], - pci_dict["product_id"]) + traits = _generate_traits( + pci_dict["vendor_id"], pci_dict["product_id"] + ) pci_dict["rc"] = constants.RESOURCES["FPGA"] pci_dict.update(traits) pci_dict.update({'pci_addr': [new_addr]}) @@ -144,7 +147,8 @@ def _generate_attribute_list(fpga): values = fpga.get(k, []) for index, val in enumerate(values): driver_attr = driver_attribute.DriverAttribute( - key="trait" + str(index), value=val) + key="trait" + str(index), value=val + ) attr_list.append(driver_attr) return attr_list @@ -171,13 +175,17 @@ def fpga_tree(): driver_device_obj = driver_device.DriverDevice() driver_device_obj.vendor = fpga["vendor_id"] driver_device_obj.model = fpga.get('model', 'miss model info') - std_board_info = {'product_id': fpga.get('product_id'), - 'controller': fpga.get('controller')} + std_board_info = { + 'product_id': fpga.get('product_id'), + 'controller': fpga.get('controller'), + } vendor_board_info = { - 'vendor_info': fpga.get('vendor_info', 'fpga_vb_info')} + 'vendor_info': fpga.get('vendor_info', 'fpga_vb_info') + } driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) - driver_device_obj.vendor_board_info = \ - jsonutils.dumps(vendor_board_info) + driver_device_obj.vendor_board_info = jsonutils.dumps( + vendor_board_info + ) driver_device_obj.type = constants.DEVICE_FPGA driver_device_obj.stub = fpga.get('stub', False) driver_device_obj.controlpath_id = _generate_controlpath_id(fpga) diff --git a/cyborg/accelerator/drivers/gpu/base.py b/cyborg/accelerator/drivers/gpu/base.py index ceb26ea5..85406b2a 100644 --- a/cyborg/accelerator/drivers/gpu/base.py +++ b/cyborg/accelerator/drivers/gpu/base.py @@ -16,6 +16,7 @@ """ Cyborg GPU driver implementation. """ + from oslo_log import log as logging from cyborg.accelerator.drivers.gpu import utils @@ -29,8 +30,8 @@ VENDOR_MAPS = {"10de": "nvidia", "102b": "matrox"} class GPUDriver: """Base class for GPU drivers. - This is just a virtual GPU drivers interface. - Vendor should implement their specific drivers. + This is just a virtual GPU drivers interface. + Vendor should implement their specific drivers. """ @classmethod diff --git a/cyborg/accelerator/drivers/gpu/nvidia/driver.py b/cyborg/accelerator/drivers/gpu/nvidia/driver.py index 5ea46dfc..f687cf24 100644 --- a/cyborg/accelerator/drivers/gpu/nvidia/driver.py +++ b/cyborg/accelerator/drivers/gpu/nvidia/driver.py @@ -24,8 +24,9 @@ from cyborg.accelerator.drivers.gpu.nvidia import sysinfo class NVIDIAGPUDriver(GPUDriver): """Class for Nvidia GPU drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "nvidia" VENDOR_ID = "10de" diff --git a/cyborg/accelerator/drivers/gpu/nvidia/sysinfo.py b/cyborg/accelerator/drivers/gpu/nvidia/sysinfo.py index bc76c8f9..98c325a6 100644 --- a/cyborg/accelerator/drivers/gpu/nvidia/sysinfo.py +++ b/cyborg/accelerator/drivers/gpu/nvidia/sysinfo.py @@ -17,6 +17,7 @@ """ Cyborg NVIDIA GPU driver implementation. """ + from oslo_log import log as logging from oslo_serialization import jsonutils @@ -52,8 +53,12 @@ def _get_traits(vendor_id, product_id, vgpu_type_name=None): traits = ["OWNER_CYBORG"] # PGPU trait gpu_trait = "_".join( - ('CUSTOM', gpu_utils.VENDOR_MAPS.get(vendor_id, "").upper(), - product_id.upper())) + ( + 'CUSTOM', + gpu_utils.VENDOR_MAPS.get(vendor_id, "").upper(), + product_id.upper(), + ) + ) # VGPU trait if vgpu_type_name: gpu_trait = "_".join((gpu_trait, vgpu_type_name.upper())) @@ -73,7 +78,8 @@ def _generate_attribute_list(gpu): values = gpu.get(k, []) for val in values: driver_attr = driver_attribute.DriverAttribute( - key="trait" + str(index), value=val) + key="trait" + str(index), value=val + ) index = index + 1 attr_list.append(driver_attr) return attr_list @@ -89,7 +95,8 @@ def _generate_attach_handle(gpu, num=None): vgpu_mark = gpu["vGPU_type"] + '_' + str(num) driver_ah.attach_type = constants.AH_TYPE_MDEV driver_ah.attach_info = utils.mdev_str_to_json( - gpu["devices"], gpu["vGPU_type"], vgpu_mark) + gpu["devices"], gpu["vGPU_type"], vgpu_mark + ) return driver_ah @@ -103,19 +110,21 @@ def _generate_dep_list(gpu): # NOTE(yumeng) Since Wallaby release, the deplpyable_name is named as # _ driver_dep.name = gpu.get('hostname', '') + '_' + gpu["devices"] - driver_dep.driver_name = \ - gpu_utils.VENDOR_MAPS.get(gpu["vendor_id"], '').upper() + driver_dep.driver_name = gpu_utils.VENDOR_MAPS.get( + gpu["vendor_id"], '' + ).upper() # if is pGPU, num_accelerators = 1 if gpu["rc"] == "PGPU": driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = \ - [_generate_attach_handle(gpu)] + driver_dep.attach_handle_list = [_generate_attach_handle(gpu)] else: # if is vGPU, num_accelerators is the total vGPU capability of # the asked vGPU type vGPU_path = os.path.expandvars( - '/sys/bus/pci/devices/{}/mdev_supported_types/{}/' - .format(gpu["devices"], gpu["vGPU_type"])) + '/sys/bus/pci/devices/{}/mdev_supported_types/{}/'.format( + gpu["devices"], gpu["vGPU_type"] + ) + ) num_available = 0 with open(vGPU_path + 'available_instances') as f: num_available = int(f.read().strip()) @@ -128,7 +137,8 @@ def _generate_dep_list(gpu): # example: echo "attach_handle_uuid" > nvidia-223/create for num in range(driver_dep.num_accelerators): driver_dep.attach_handle_list.append( - _generate_attach_handle(gpu, num)) + _generate_attach_handle(gpu, num) + ) return [driver_dep] @@ -143,13 +153,13 @@ def _generate_driver_device(gpu): driver_device_obj = driver_device.DriverDevice() driver_device_obj.vendor = gpu['vendor_id'] driver_device_obj.model = gpu.get('model', 'miss model info') - std_board_info = {'product_id': gpu.get('product_id'), - 'controller': gpu.get('controller'), } - vendor_board_info = {'vendor_info': gpu.get('vendor_info', - 'gpu_vb_info')} + std_board_info = { + 'product_id': gpu.get('product_id'), + 'controller': gpu.get('controller'), + } + vendor_board_info = {'vendor_info': gpu.get('vendor_info', 'gpu_vb_info')} driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) - driver_device_obj.vendor_board_info = jsonutils.dumps( - vendor_board_info) + driver_device_obj.vendor_board_info = jsonutils.dumps(vendor_board_info) driver_device_obj.type = constants.DEVICE_GPU driver_device_obj.stub = gpu.get('stub', False) driver_device_obj.controlpath_id = _generate_controlpath_id(gpu) @@ -204,8 +214,9 @@ def _get_supported_vgpu_types(): return CONF.gpu_devices.enabled_vgpu_types, pgpu_type_mapping -def _get_vgpu_type_per_pgpu(device_address, supported_vgpu_types, - pgpu_type_mapping): +def _get_vgpu_type_per_pgpu( + device_address, supported_vgpu_types, pgpu_type_mapping +): """Provides the vGPU type the pGPU supports. :param device_address: the PCI device address in config, @@ -214,9 +225,11 @@ def _get_vgpu_type_per_pgpu(device_address, supported_vgpu_types, supported_vgpu_types, pgpu_type_mapping = _get_supported_vgpu_types() # Bail out quickly if we don't support vGPUs if not supported_vgpu_types: - LOG.warning('Unable to load vGPU_type from [gpu_devices] ' - 'Ensure "enabled_vgpu_types" is set if the gpu' - 'is virtualized.') + LOG.warning( + 'Unable to load vGPU_type from [gpu_devices] ' + 'Ensure "enabled_vgpu_types" is set if the gpu' + 'is virtualized.' + ) return try: @@ -224,8 +237,10 @@ def _get_vgpu_type_per_pgpu(device_address, supported_vgpu_types, utils.parse_address(device_address) except (exception.PciDeviceWrongAddressFormat, IndexError): # this is not a valid PCI address - LOG.warning("The PCI address %s was invalid for getting the" - "related vGPU type", device_address) + LOG.warning( + "The PCI address %s was invalid for getting therelated vGPU type", + device_address, + ) return return pgpu_type_mapping.get(device_address) @@ -240,14 +255,17 @@ def _is_vf(pci_address): try: return os.path.exists(physfn_path) except OSError: - LOG.warning('Failed to check VF status for device %s via %s, ' - 'assuming it is not a VF.', pci_address, physfn_path) + LOG.warning( + 'Failed to check VF status for device %s via %s, ' + 'assuming it is not a VF.', + pci_address, + physfn_path, + ) return False def _discover_gpus(vendor_id): - """param: vendor_id=VENDOR_ID means only discover Nvidia GPU on the host - """ + """param: vendor_id=VENDOR_ID means only discover Nvidia GPU on the host""" # init vGPU conf cyborg.conf.devices.register_dynamic_opts(CONF) supported_vgpu_types, pgpu_type_mapping = _get_supported_vgpu_types() @@ -280,46 +298,56 @@ def _discover_gpus(vendor_id): LOG.warning( 'Unable to determine VF status for ' 'device %s, assuming it is not a VF.', - gpu_dict["devices"]) + gpu_dict["devices"], + ) is_vf = False if is_vf: LOG.info( 'Skipping VF device %s, only PFs and' ' mediated devices are reported.', - gpu_dict["devices"]) + gpu_dict["devices"], + ) continue # get hostname for deployable_name usage gpu_dict['hostname'] = CONF.host # get vgpu_type from cyborg.conf, otherwise vgpu_type=None vgpu_type = _get_vgpu_type_per_pgpu( - gpu_dict["devices"], supported_vgpu_types, pgpu_type_mapping) + gpu_dict["devices"], supported_vgpu_types, pgpu_type_mapping + ) # generate rc and trait for pGPU if not vgpu_type: gpu_dict["rc"] = constants.RESOURCES["PGPU"] - traits = _get_traits(gpu_dict["vendor_id"], - gpu_dict["product_id"]) + traits = _get_traits( + gpu_dict["vendor_id"], gpu_dict["product_id"] + ) # generate rc and trait for vGPU else: # get rc gpu_dict["rc"] = constants.RESOURCES["VGPU"] mdev_path = os.path.expandvars( - '/sys/bus/pci/devices/{}/mdev_supported_types'. - format(gpu_dict["devices"])) + '/sys/bus/pci/devices/{}/mdev_supported_types'.format( + gpu_dict["devices"] + ) + ) valid_types = os.listdir(mdev_path) if vgpu_type not in valid_types: raise exception.InvalidVGPUType(name=vgpu_type) gpu_dict["vGPU_type"] = vgpu_type vGPU_path = os.path.expandvars( - '/sys/bus/pci/devices/{}/mdev_supported_types/{}/' - .format(gpu_dict["devices"], gpu_dict["vGPU_type"])) + '/sys/bus/pci/devices/{}/mdev_supported_types/{}/'.format( + gpu_dict["devices"], gpu_dict["vGPU_type"] + ) + ) # transfer vgpu_type to vgpu_type_name. # eg. transfer 'nvidia-223' to 'T4_1B' with open(vGPU_path + 'name') as f: name = f.read().strip() vgpu_type_name = name.split(' ')[1].replace('-', '_') - traits = _get_traits(gpu_dict["vendor_id"], - gpu_dict["product_id"], - vgpu_type_name) + traits = _get_traits( + gpu_dict["vendor_id"], + gpu_dict["product_id"], + vgpu_type_name, + ) gpu_dict.update(traits) gpu_list.append(_generate_driver_device(gpu_dict)) return gpu_list diff --git a/cyborg/accelerator/drivers/gpu/utils.py b/cyborg/accelerator/drivers/gpu/utils.py index 1b9963fc..26d71669 100644 --- a/cyborg/accelerator/drivers/gpu/utils.py +++ b/cyborg/accelerator/drivers/gpu/utils.py @@ -14,6 +14,7 @@ """ Utils for GPU driver. """ + from oslo_concurrency import processutils from oslo_log import log as logging @@ -26,11 +27,13 @@ import cyborg.privsep LOG = logging.getLogger(__name__) GPU_FLAGS = ["VGA compatible controller", "3D controller"] -GPU_INFO_PATTERN = re.compile(r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" - r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " - r"(?P.*) [\[].*]: (?P.*) .*" - r"[\[](?P[0-9a-fA-F]" - r"{4}):(?P[0-9a-fA-F]{4})].*") +GPU_INFO_PATTERN = re.compile( + r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" + r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " + r"(?P.*) [\[].*]: (?P.*) .*" + r"[\[](?P[0-9a-fA-F]" + r"{4}):(?P[0-9a-fA-F]{4})].*" +) VENDOR_MAPS = {"10de": "nvidia", "102b": "matrox"} PRODUCT_ID_MAPS = {"1eb8": "T4", "15f7": "P100_PCIE_12GB"} @@ -56,8 +59,9 @@ def create_mdev_privileged(pci_addr, mdev_type, ah_uuid): @cyborg.privsep.sys_admin_pctxt.entrypoint def remove_mdev_privileged(physical_device, mdev_type, medv_uuid): - fpath = ('/sys/class/mdev_bus/{0}/mdev_supported_types/' - '{1}/devices/{2}/remove') + fpath = ( + '/sys/class/mdev_bus/{0}/mdev_supported_types/{1}/devices/{2}/remove' + ) fpath = fpath.format(physical_device, mdev_type, medv_uuid) with open(fpath, 'w') as f: f.write("1") diff --git a/cyborg/accelerator/drivers/modules/generic.py b/cyborg/accelerator/drivers/modules/generic.py index 73363209..e1d5f71d 100644 --- a/cyborg/accelerator/drivers/modules/generic.py +++ b/cyborg/accelerator/drivers/modules/generic.py @@ -37,20 +37,22 @@ def _check_for_missing_params(info_dict, error_msg, param_prefix=''): if missing_info: exc_msg = _("%(error_msg)s. Missing are: %(missing_info)s") raise exception.MissingParameterValue( - exc_msg % {'error_msg': error_msg, 'missing_info': missing_info}) + exc_msg % {'error_msg': error_msg, 'missing_info': missing_info} + ) def _parse_driver_info(driver): info = driver.driver_info d_info = {k: info.get(k) for k in COMMON_PROPERTIES} - error_msg = _("Cannot validate Generic Driver. Some parameters were" - " missing in the configuration file.") + error_msg = _( + "Cannot validate Generic Driver. Some parameters were" + " missing in the configuration file." + ) _check_for_missing_params(d_info, error_msg) return d_info class GENERICDRIVER: - def get_properties(self): """Return the properties of the generic driver. @@ -59,12 +61,10 @@ class GENERICDRIVER: return COMMON_PROPERTIES def attach(self, accelerator, instance): - def install(self, accelerator): pass def detach(self, accelerator, instance): - def uninstall(self, accelerator): pass diff --git a/cyborg/accelerator/drivers/nic/base.py b/cyborg/accelerator/drivers/nic/base.py index 073a3787..4c242b60 100644 --- a/cyborg/accelerator/drivers/nic/base.py +++ b/cyborg/accelerator/drivers/nic/base.py @@ -23,8 +23,8 @@ VENDOR_MAPS = {"0x8086": "intel"} class NICDriver: """Base class for Nic drivers. - This is just a virtual NIC drivers interface. - Vendor should implement their specific drivers. + This is just a virtual NIC drivers interface. + Vendor should implement their specific drivers. """ @classmethod diff --git a/cyborg/accelerator/drivers/nic/intel/driver.py b/cyborg/accelerator/drivers/nic/intel/driver.py index 65752e43..c743ae5d 100644 --- a/cyborg/accelerator/drivers/nic/intel/driver.py +++ b/cyborg/accelerator/drivers/nic/intel/driver.py @@ -23,8 +23,9 @@ from cyborg.accelerator.drivers.nic.intel import sysinfo class IntelNICDriver(NICDriver): """Class for Intel NIC drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "intel" def __init__(self, *args, **kwargs): diff --git a/cyborg/accelerator/drivers/nic/intel/sysinfo.py b/cyborg/accelerator/drivers/nic/intel/sysinfo.py index 0de8fcaf..0de037c3 100644 --- a/cyborg/accelerator/drivers/nic/intel/sysinfo.py +++ b/cyborg/accelerator/drivers/nic/intel/sysinfo.py @@ -17,7 +17,6 @@ Cyborg Intel NIC driver implementation. """ - import glob import os import socket @@ -58,8 +57,9 @@ def _parse_config(): return pdm, fdm -def get_physical_network_and_traits(pci_info, physnet_device_mappings, - function_device_mappings, pf_nic=None): +def get_physical_network_and_traits( + pci_info, physnet_device_mappings, function_device_mappings, pf_nic=None +): traits = [] physnet = None func_name = None @@ -99,20 +99,21 @@ def read_line(filename): def find_nics_by_know_list(): - return set(filter( - lambda p: ( - read_line(os.path.join(p, "vendor")), - read_line(os.path.join(p, "device")) - ) in KNOWN_NICS, - glob.glob(PCI_DEVICES_PATH_PATTERN))) + return set( + filter( + lambda p: ( + read_line(os.path.join(p, "vendor")), + read_line(os.path.join(p, "device")), + ) + in KNOWN_NICS, + glob.glob(PCI_DEVICES_PATH_PATTERN), + ) + ) def pci_attributes(path): with open(os.path.join(path, "uevent")) as f: - attributes = dict(map( - lambda p: p.strip().split("="), - f.readlines() - )) + attributes = dict(map(lambda p: p.strip().split("="), f.readlines())) with open(os.path.join(path, "vendor")) as f: attributes["VENDOR"] = f.readline().strip() @@ -123,8 +124,12 @@ def pci_attributes(path): return attributes -def nic_gen(path, physnet_device_mappings=None, function_device_mappings=None, - pf_nic=None): +def nic_gen( + path, + physnet_device_mappings=None, + function_device_mappings=None, + pf_nic=None, +): pci_info = pci_attributes(path) nic = { "name": "_".join((socket.gethostname(), pci_info["PCI_SLOT_NAME"])), @@ -134,30 +139,32 @@ def nic_gen(path, physnet_device_mappings=None, function_device_mappings=None, "product_id": pci_info["PRODUCT_ID"], "rc": "CUSTOM_NIC", "stub": False, - } + } # TODO(Xinran): need check device id and call get_traits differently. - updates = get_physical_network_and_traits(pci_info, - physnet_device_mappings, - function_device_mappings, - pf_nic) + updates = get_physical_network_and_traits( + pci_info, physnet_device_mappings, function_device_mappings, pf_nic + ) nic.update(updates) return nic def all_pfs_with_vf(): - return set(filter( - lambda p: glob.glob(os.path.join(p, VF)), - find_nics_by_know_list())) + return set( + filter( + lambda p: glob.glob(os.path.join(p, VF)), find_nics_by_know_list() + ) + ) def all_vfs_in_pf(pf_path): return map( - lambda p: - os.path.join( + lambda p: os.path.join( os.path.dirname(os.path.dirname(p)), - os.path.basename(os.readlink(p))), - glob.glob(os.path.join(pf_path, VF))) + os.path.basename(os.readlink(p)), + ), + glob.glob(os.path.join(pf_path, VF)), + ) def nic_tree(): @@ -169,8 +176,9 @@ def nic_tree(): if n in pfs_has_vf: vfs = [] for vf in all_vfs_in_pf(n): - vf_nic = nic_gen(vf, physnet_device_mappings, - function_device_mappings, nic) + vf_nic = nic_gen( + vf, physnet_device_mappings, function_device_mappings, nic + ) vfs.append(vf_nic) nic["vfs"] = vfs nics.append(_generate_driver_device(nic)) @@ -183,8 +191,8 @@ def _generate_driver_device(nic): driver_device_obj.stub = nic["stub"] driver_device_obj.model = nic.get("model", "miss_model_info") driver_device_obj.vendor_board_info = nic.get( - "vendor_board_info", - "miss_vb_info") + "vendor_board_info", "miss_vb_info" + ) std_board_info = {"product_id": nic.get("product_id")} driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.type = nic["type"] @@ -206,8 +214,7 @@ def _generate_dep_list(nic): if "vfs" not in nic: driver_dep = driver_deployable.DriverDeployable() driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = [ - _generate_attach_handle(nic)] + driver_dep.attach_handle_list = [_generate_attach_handle(nic)] driver_dep.name = nic["name"] driver_dep.driver_name = DRIVER_NAME driver_dep.attribute_list = _generate_attribute_list(nic) @@ -217,8 +224,7 @@ def _generate_dep_list(nic): for vf in nic["vfs"]: driver_dep = driver_deployable.DriverDeployable() driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = [ - _generate_attach_handle(vf)] + driver_dep.attach_handle_list = [_generate_attach_handle(vf)] driver_dep.name = vf["name"] driver_dep.driver_name = DRIVER_NAME driver_dep.attribute_list = _generate_attribute_list(vf) @@ -229,8 +235,9 @@ def _generate_dep_list(nic): def _generate_attach_handle(nic): driver_ah = driver_attach_handle.DriverAttachHandle() driver_ah.attach_type = "PCI" - driver_ah.attach_info = utils.pci_str_to_json(nic["device"], - nic["physical_network"]) + driver_ah.attach_info = utils.pci_str_to_json( + nic["device"], nic["physical_network"] + ) driver_ah.in_use = False return driver_ah diff --git a/cyborg/accelerator/drivers/pci/base.py b/cyborg/accelerator/drivers/pci/base.py index 871ffbaa..5230bf33 100644 --- a/cyborg/accelerator/drivers/pci/base.py +++ b/cyborg/accelerator/drivers/pci/base.py @@ -19,8 +19,7 @@ Cyborg Pci driver implementation. class PciDriver: - """Base class for Pci drivers. - """ + """Base class for Pci drivers.""" def __init__(self, *args, **kwargs): pass diff --git a/cyborg/accelerator/drivers/pci/devspec.py b/cyborg/accelerator/drivers/pci/devspec.py index f0818474..29f6b41e 100644 --- a/cyborg/accelerator/drivers/pci/devspec.py +++ b/cyborg/accelerator/drivers/pci/devspec.py @@ -41,11 +41,14 @@ class PciAddressSpec(metaclass=abc.ABCMeta): pass def is_single_address(self): - return all([ - all(c in string.hexdigits for c in self.domain), - all(c in string.hexdigits for c in self.bus), - all(c in string.hexdigits for c in self.slot), - all(c in string.hexdigits for c in self.func)]) + return all( + [ + all(c in string.hexdigits for c in self.domain), + all(c in string.hexdigits for c in self.bus), + all(c in string.hexdigits for c in self.slot), + all(c in string.hexdigits for c in self.func), + ] + ) def _set_pci_dev_info(self, prop, maxval, hex_value): a = getattr(self, prop) @@ -55,13 +58,20 @@ class PciAddressSpec(metaclass=abc.ABCMeta): v = int(a, 16) except ValueError: raise exception.PciConfigInvalidWhitelist( - reason=_("property %(property)s ('%(attr)s') does not parse " - "as a hex number.") % {'property': prop, 'attr': a}) + reason=_( + "property %(property)s ('%(attr)s') does not parse " + "as a hex number." + ) + % {'property': prop, 'attr': a} + ) if v > maxval: raise exception.PciConfigInvalidWhitelist( - reason=_("property %(property)s (%(attr)s) is greater than " - "the maximum allowable value (%(max)X).") % { - 'property': prop, 'attr': a, 'max': maxval}) + reason=_( + "property %(property)s (%(attr)s) is greater than " + "the maximum allowable value (%(max)X)." + ) + % {'property': prop, 'attr': a, 'max': maxval} + ) setattr(self, prop, hex_value % v) @@ -81,7 +91,8 @@ class PhysicalPciAddress(PciAddressSpec): self.func = pci_addr['function'] else: self.domain, self.bus, self.slot, self.func = ( - utils.get_pci_address_fields(pci_addr)) + utils.get_pci_address_fields(pci_addr) + ) self._set_pci_dev_info('func', MAX_FUNC, '%1x') self._set_pci_dev_info('domain', MAX_DOMAIN, '%04x') self._set_pci_dev_info('bus', MAX_BUS, '%02x') @@ -95,7 +106,7 @@ class PhysicalPciAddress(PciAddressSpec): self.bus == phys_pci_addr.bus, self.slot == phys_pci_addr.slot, self.func == phys_pci_addr.func, - ] + ] return all(conditions) @@ -136,8 +147,8 @@ class PciAddressGlobSpec(PciAddressSpec): self.domain in (ANY, phys_pci_addr.domain), self.bus in (ANY, phys_pci_addr.bus), self.slot in (ANY, phys_pci_addr.slot), - self.func in (ANY, phys_pci_addr.func) - ] + self.func in (ANY, phys_pci_addr.func), + ] return all(conditions) @@ -167,8 +178,8 @@ class PciAddressRegexSpec(PciAddressSpec): bool(self.domain_regex.match(phys_pci_addr.domain)), bool(self.bus_regex.match(phys_pci_addr.bus)), bool(self.slot_regex.match(phys_pci_addr.slot)), - bool(self.func_regex.match(phys_pci_addr.func)) - ] + bool(self.func_regex.match(phys_pci_addr.func)), + ] return all(conditions) @@ -197,12 +208,12 @@ class WhitelistPciAddress: def _check_physical_function(self): if self.pci_address_spec.is_single_address(): - self.is_physical_function = ( - utils.is_physical_function( - self.pci_address_spec.domain, - self.pci_address_spec.bus, - self.pci_address_spec.slot, - self.pci_address_spec.func)) + self.is_physical_function = utils.is_physical_function( + self.pci_address_spec.domain, + self.pci_address_spec.bus, + self.pci_address_spec.slot, + self.pci_address_spec.func, + ) def _init_address_fields(self, pci_addr): if not self.is_physical_function: @@ -266,8 +277,7 @@ class PciDeviceSpec(PciAddressSpec): def match(self, dev_dict): if self.dev_name: - address_str, pf = utils.get_function_by_ifname( - self.dev_name) + address_str, pf = utils.get_function_by_ifname(self.dev_name) if not address_str: return False # Note(moshele): In this case we always passing a string @@ -275,17 +285,25 @@ class PciDeviceSpec(PciAddressSpec): address_obj = WhitelistPciAddress(address_str, pf) elif self.address: address_obj = self.address - return all([ - self.vendor_id in (ANY, dev_dict['vendor_id']), - self.product_id in (ANY, dev_dict['product_id']), - address_obj.match(dev_dict['address'], - dev_dict.get('parent_addr'))]) + return all( + [ + self.vendor_id in (ANY, dev_dict['vendor_id']), + self.product_id in (ANY, dev_dict['product_id']), + address_obj.match( + dev_dict['address'], dev_dict.get('parent_addr') + ), + ] + ) def match_pci_obj(self, pci_obj): - return self.match({'vendor_id': pci_obj.vendor_id, - 'product_id': pci_obj.product_id, - 'address': pci_obj.address, - 'parent_addr': pci_obj.parent_addr}) + return self.match( + { + 'vendor_id': pci_obj.vendor_id, + 'product_id': pci_obj.product_id, + 'address': pci_obj.address, + 'parent_addr': pci_obj.parent_addr, + } + ) def get_tags(self): return self.tags diff --git a/cyborg/accelerator/drivers/pci/pci/driver.py b/cyborg/accelerator/drivers/pci/pci/driver.py index c77e1196..cd1167ba 100644 --- a/cyborg/accelerator/drivers/pci/pci/driver.py +++ b/cyborg/accelerator/drivers/pci/pci/driver.py @@ -23,7 +23,7 @@ from cyborg.accelerator.drivers.pci.pci import sysinfo class PCIDriver(PciDriver): """Class for Pci drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ def discover(self): diff --git a/cyborg/accelerator/drivers/pci/pci/sysinfo.py b/cyborg/accelerator/drivers/pci/pci/sysinfo.py index 768518e8..42ac9ccf 100644 --- a/cyborg/accelerator/drivers/pci/pci/sysinfo.py +++ b/cyborg/accelerator/drivers/pci/pci/sysinfo.py @@ -16,6 +16,7 @@ """ Cyborg PCI driver implementation. """ + import re from oslo_log import log as logging @@ -73,7 +74,8 @@ def _generate_attribute_list(pci): values = pci.get(k, []) for val in values: driver_attr = driver_attribute.DriverAttribute( - key="trait" + str(index), value=val) + key="trait" + str(index), value=val + ) index = index + 1 attr_list.append(driver_attr) return attr_list @@ -98,8 +100,7 @@ def _generate_dep_list(pci): # NOTE(yumeng) Since Wallaby release, the deplpyable_name is named as # _ driver_dep.name = pci.get('hostname', '') + '_' + pci["devices"] - vendor_name = pci_utils.VENDOR_MAPS.get( - pci["vendor_id"], pci["vendor_id"]) + vendor_name = pci_utils.VENDOR_MAPS.get(pci["vendor_id"], pci["vendor_id"]) driver_dep.driver_name = vendor_name.upper() driver_dep.num_accelerators = 1 driver_dep.attach_handle_list = [_generate_attach_handle(pci)] @@ -118,16 +119,18 @@ def _generate_driver_device(pci): driver_device_obj = driver_device.DriverDevice() driver_device_obj.vendor = pci['vendor_id'] driver_device_obj.model = pci['product_id'] - std_board_info = {'product_id': pci.get('product_id'), - 'controller': pci.get('controller'), - } + std_board_info = { + 'product_id': pci.get('product_id'), + 'controller': pci.get('controller'), + } driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.type = constants.DEVICE_GPU driver_device_obj.stub = pci.get('stub', False) driver_device_obj.controlpath_id = _generate_controlpath_id(pci) driver_device_obj.deployable_list, ais = _generate_dep_list(pci) - driver_device_obj.vendor_board_info = pci.get('vendor_board_info', - "miss_vb_info") + driver_device_obj.vendor_board_info = pci.get( + 'vendor_board_info', "miss_vb_info" + ) return driver_device_obj @@ -149,14 +152,13 @@ def _discover_pcis(): 'vendor_id': pci_dict['vendor_id'], 'product_id': pci_dict['product_id'], 'address': pci_dict['devices'], - 'parent_addr': None + 'parent_addr': None, } if dev_filter.device_assignable(dev_info): # get hostname for deployable_name usage pci_dict['hostname'] = CONF.host pci_dict["rc"] = constants.RESOURCES["PCI"] - traits = _get_traits(pci_dict["vendor_id"], - pci_dict["product_id"]) + traits = _get_traits(pci_dict["vendor_id"], pci_dict["product_id"]) pci_dict.update(traits) pci_list.append(_generate_driver_device(pci_dict)) LOG.info('pci_list: %s', pci_list) diff --git a/cyborg/accelerator/drivers/pci/utils.py b/cyborg/accelerator/drivers/pci/utils.py index f4b7c81c..aec5516d 100644 --- a/cyborg/accelerator/drivers/pci/utils.py +++ b/cyborg/accelerator/drivers/pci/utils.py @@ -25,9 +25,9 @@ import cyborg.privsep LOG = logging.getLogger(__name__) PCI_VENDOR_PATTERN = "^(hex{4})$".replace("hex", r"[\da-fA-F]") -_PCI_ADDRESS_PATTERN = ("^(hex{4}):(hex{2}):(hex{2}).(oct{1})$". - replace("hex", r"[\da-fA-F]"). - replace("oct", "[0-7]")) +_PCI_ADDRESS_PATTERN = "^(hex{4}):(hex{2}):(hex{2}).(oct{1})$".replace( + "hex", r"[\da-fA-F]" +).replace("oct", "[0-7]") _PCI_ADDRESS_REGEX = re.compile(_PCI_ADDRESS_PATTERN) _SRIOV_TOTALVFS = "sriov_totalvfs" @@ -61,6 +61,7 @@ def pci_device_prop_match(pci_dev, specs): "capabilities_network": ["rx", "tx", "tso", "gso"]}] """ + def _matching_devices(spec): for k, v in spec.items(): pci_dev_v = pci_dev.get(k) @@ -127,8 +128,7 @@ def get_function_by_ifname(ifname): # sriov_totalvfs contains the maximum possible VFs for this PF with open(os.path.join(dev_path, _SRIOV_TOTALVFS)) as fd: sriov_totalvfs = int(fd.read()) - return (os.readlink(dev_path).strip("./"), - sriov_totalvfs > 0) + return (os.readlink(dev_path).strip("./"), sriov_totalvfs > 0) except (OSError, ValueError): return os.readlink(dev_path).strip("./"), False return None, False @@ -136,7 +136,11 @@ def get_function_by_ifname(ifname): def is_physical_function(domain, bus, slot, function): dev_path = "/sys/bus/pci/devices/%(d)s:%(b)s:%(s)s.%(f)s/" % { - "d": domain, "b": bus, "s": slot, "f": function} + "d": domain, + "b": bus, + "s": slot, + "f": function, + } if os.path.isdir(dev_path): try: with open(dev_path + _SRIOV_TOTALVFS) as fd: @@ -185,10 +189,12 @@ def get_mac_by_pci_address(pci_addr, pf_interface=False): mac = next(f).strip() return mac except (OSError, StopIteration) as e: - LOG.warning("Could not find the expected sysfs file for " - "determining the MAC address of the PCI device " - "%(addr)s. May not be a NIC. Error: %(e)s", - {'addr': pci_addr, 'e': e}) + LOG.warning( + "Could not find the expected sysfs file for " + "determining the MAC address of the PCI device " + "%(addr)s. May not be a NIC. Error: %(e)s", + {'addr': pci_addr, 'e': e}, + ) raise exception.PciDeviceNotFoundById(id=pci_addr) @@ -230,9 +236,13 @@ def get_net_name_by_vf_pci_address(vfaddress): try: mac = get_mac_by_pci_address(vfaddress).split(':') ifname = get_ifname_by_pci_address(vfaddress) - return ("net_%(ifname)s_%(mac)s" % - {'ifname': ifname, 'mac': '_'.join(mac)}) + return "net_%(ifname)s_%(mac)s" % { + 'ifname': ifname, + 'mac': '_'.join(mac), + } except Exception: - LOG.warning("No net device was found for VF %(vfaddress)s", - {'vfaddress': vfaddress}) + LOG.warning( + "No net device was found for VF %(vfaddress)s", + {'vfaddress': vfaddress}, + ) return diff --git a/cyborg/accelerator/drivers/pci/whitelist.py b/cyborg/accelerator/drivers/pci/whitelist.py index 81b2379c..a8eb6707 100644 --- a/cyborg/accelerator/drivers/pci/whitelist.py +++ b/cyborg/accelerator/drivers/pci/whitelist.py @@ -56,19 +56,21 @@ class Whitelist: dev_spec = jsonutils.loads(jsonspec) except ValueError: raise exception.PciConfigInvalidWhitelist( - reason=_("Invalid entry: '%s'") % jsonspec) + reason=_("Invalid entry: '%s'") % jsonspec + ) if isinstance(dev_spec, dict): dev_spec = [dev_spec] elif not isinstance(dev_spec, list): raise exception.PciConfigInvalidWhitelist( - reason=_("Invalid entry: '%s'; " - "Expecting list or dict") % jsonspec) + reason=_("Invalid entry: '%s'; Expecting list or dict") + % jsonspec + ) for ds in dev_spec: if not isinstance(ds, dict): raise exception.PciConfigInvalidWhitelist( - reason=_("Invalid entry: '%s'; " - "Expecting dict") % ds) + reason=_("Invalid entry: '%s'; Expecting dict") % ds + ) spec = devspec.PciDeviceSpec(ds) specs.append(spec) diff --git a/cyborg/accelerator/drivers/qat/base.py b/cyborg/accelerator/drivers/qat/base.py index d60e0e5a..3d4febdb 100644 --- a/cyborg/accelerator/drivers/qat/base.py +++ b/cyborg/accelerator/drivers/qat/base.py @@ -23,8 +23,8 @@ VENDOR_MAPS = {"0x8086": "intel"} class QATDriver: """Base class for QAT drivers. - This is just a virtual QAT drivers interface. - Vendor should implement their specific drivers. + This is just a virtual QAT drivers interface. + Vendor should implement their specific drivers. """ @classmethod diff --git a/cyborg/accelerator/drivers/qat/intel/driver.py b/cyborg/accelerator/drivers/qat/intel/driver.py index 6f910752..c3fc02ce 100644 --- a/cyborg/accelerator/drivers/qat/intel/driver.py +++ b/cyborg/accelerator/drivers/qat/intel/driver.py @@ -23,8 +23,9 @@ from cyborg.accelerator.drivers.qat.intel import sysinfo class IntelQATDriver(QATDriver): """Class for Intel QAT drivers. - Vendor should implement their specific drivers in this class. + Vendor should implement their specific drivers in this class. """ + VENDOR = "intel" def __init__(self, *args, **kwargs): diff --git a/cyborg/accelerator/drivers/qat/intel/sysinfo.py b/cyborg/accelerator/drivers/qat/intel/sysinfo.py index 3f63b82d..9e615946 100644 --- a/cyborg/accelerator/drivers/qat/intel/sysinfo.py +++ b/cyborg/accelerator/drivers/qat/intel/sysinfo.py @@ -17,7 +17,6 @@ Cyborg Intel QAT driver implementation. """ - import glob import os import socket @@ -41,17 +40,12 @@ INTEL_QAT_DEV_PREFIX = "intel-qat-dev" RC_QAT = constants.RESOURCES["QAT"] DRIVER_NAME = "intel" -RESOURCES = { - "qat": RC_QAT -} +RESOURCES = {"qat": RC_QAT} def pci_attributes(path): with open(os.path.join(path, "uevent")) as f: - attributes = dict(map( - lambda p: p.strip().split("="), - f.readlines() - )) + attributes = dict(map(lambda p: p.strip().split("="), f.readlines())) with open(os.path.join(path, "vendor")) as f: attributes["VENDOR"] = f.readline().strip() @@ -64,47 +58,49 @@ def pci_attributes(path): def get_link_targets(links): return map( - lambda p: - os.path.realpath( - os.path.join(os.path.dirname(p), os.readlink(p))), - links) + lambda p: os.path.realpath( + os.path.join(os.path.dirname(p), os.readlink(p)) + ), + links, + ) def all_qats(): - return set(filter( - lambda p: ( - pci_attributes(p)["VENDOR"], - pci_attributes(p)["PRODUCT_ID"] - ) in KNOW_QATS, - glob.glob(os.path.join(PCI_DEVICES_PATH, "*")))) + return set( + filter( + lambda p: ( + pci_attributes(p)["VENDOR"], + pci_attributes(p)["PRODUCT_ID"], + ) + in KNOW_QATS, + glob.glob(os.path.join(PCI_DEVICES_PATH, "*")), + ) + ) def all_pfs_with_vf(): - return set(filter( - lambda p: glob.glob(os.path.join(p, VF)), - all_qats())) + return set(filter(lambda p: glob.glob(os.path.join(p, VF)), all_qats())) def all_vfs_in_pf(pf_path): return map( - lambda p: - os.path.join( + lambda p: os.path.join( os.path.dirname(os.path.dirname(p)), - os.path.basename(os.readlink(p))), - glob.glob(os.path.join(pf_path, VF))) + os.path.basename(os.readlink(p)), + ), + glob.glob(os.path.join(pf_path, VF)), + ) def find_pf_by_vf(vf_path): return os.path.join( os.path.dirname(vf_path), - os.path.basename(os.readlink( - os.path.join(vf_path, PF)))) + os.path.basename(os.readlink(os.path.join(vf_path, PF))), + ) def all_vfs(): - return map( - lambda p: all_vfs_in_pf(p), all_pfs_with_vf() - ) + return map(lambda p: all_vfs_in_pf(p), all_pfs_with_vf()) def qat_gen(path): @@ -142,8 +138,8 @@ def _generate_driver_device(qat): driver_device_obj.stub = qat["stub"] driver_device_obj.model = qat.get("model", "miss_model_info") driver_device_obj.vendor_board_info = qat.get( - "vendor_board_info", - "miss_vb_info") + "vendor_board_info", "miss_vb_info" + ) std_board_info = {"product_id": qat.get("product_id")} driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.type = qat["type"] @@ -165,8 +161,7 @@ def _generate_dep_list(qat): if "vfs" not in qat: driver_dep = driver_deployable.DriverDeployable() driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = [ - _generate_attach_handle(qat)] + driver_dep.attach_handle_list = [_generate_attach_handle(qat)] driver_dep.name = qat["name"] driver_dep.driver_name = DRIVER_NAME driver_dep.attribute_list = _generate_attribute_list(qat) @@ -176,8 +171,7 @@ def _generate_dep_list(qat): for vf in qat["vfs"]: driver_dep = driver_deployable.DriverDeployable() driver_dep.num_accelerators = 1 - driver_dep.attach_handle_list = [ - _generate_attach_handle(vf)] + driver_dep.attach_handle_list = [_generate_attach_handle(vf)] driver_dep.name = vf["name"] driver_dep.driver_name = DRIVER_NAME driver_dep.attribute_list = _generate_attribute_list(qat) diff --git a/cyborg/accelerator/drivers/spdk/nvmf/nvmf.py b/cyborg/accelerator/drivers/spdk/nvmf/nvmf.py index f51ec02a..919e52a8 100644 --- a/cyborg/accelerator/drivers/spdk/nvmf/nvmf.py +++ b/cyborg/accelerator/drivers/spdk/nvmf/nvmf.py @@ -52,7 +52,7 @@ class NVMFDRIVER(SPDKDRIVER): accelerator_obj = { 'server': self.SERVER, 'bdevs': bdevs, - 'subsystems': subsystems + 'subsystems': subsystems, } return accelerator_obj @@ -95,13 +95,9 @@ class NVMFDRIVER(SPDKDRIVER): else: raise exception.Invalid('Delete nvmf subsystem failed.') - def construct_subsystem(self, - nqn, - listen, - hosts, - serial_number, - namespaces - ): + def construct_subsystem( + self, nqn, listen, hosts, serial_number, namespaces + ): """Add a nvmf subsystem :param nqn: Target nqn(ASCII). @@ -117,14 +113,13 @@ class NVMFDRIVER(SPDKDRIVER): :param namespaces: Whitespace-separated list of namespaces. :raise exception: Invalid """ - if ((namespaces != '' and listen != '') and - (hosts != '' and serial_number != '')) and nqn != '': + if ( + (namespaces != '' and listen != '') + and (hosts != '' and serial_number != '') + ) and nqn != '': acc_client = NvmfTgt(self.py) - acc_client.construct_nvmf_subsystem(nqn, - listen, - hosts, - serial_number, - namespaces - ) + acc_client.construct_nvmf_subsystem( + nqn, listen, hosts, serial_number, namespaces + ) else: raise exception.Invalid('Construct nvmf subsystem failed.') diff --git a/cyborg/accelerator/drivers/spdk/spdk.py b/cyborg/accelerator/drivers/spdk/spdk.py index d2098999..70ecd573 100644 --- a/cyborg/accelerator/drivers/spdk/spdk.py +++ b/cyborg/accelerator/drivers/spdk/spdk.py @@ -16,15 +16,17 @@ Cyborg SPDK driver modules implementation. """ from oslo_log import log as logging + LOG = logging.getLogger(__name__) class SPDKDRIVER: """SPDKDRIVER - This is just a virtual SPDK drivers interface. - SPDK-based app server should implement their specific drivers. + This is just a virtual SPDK drivers interface. + SPDK-based app server should implement their specific drivers. """ + @classmethod def create(cls, server, *args, **kwargs): for subclass in cls.__subclasses__(): diff --git a/cyborg/accelerator/drivers/spdk/util/common_fun.py b/cyborg/accelerator/drivers/spdk/util/common_fun.py index 64fcf31d..6e3a2705 100644 --- a/cyborg/accelerator/drivers/spdk/util/common_fun.py +++ b/cyborg/accelerator/drivers/spdk/util/common_fun.py @@ -32,25 +32,29 @@ from cyborg.common.i18n import _ LOG = logging.getLogger(__name__) accelerator_opts = [ - cfg.StrOpt('spdk_conf_file', - default='/etc/cyborg/spdk.conf', - help=_('SPDK conf file to be used for the SPDK driver')), - - cfg.StrOpt('accelerator_servers', - default=['vhost', 'nvmf', 'iscsi'], - help=_('A list of accelerator servers to enable by default')), - - cfg.StrOpt('spdk_dir', - default='/home/wewe/spdk', - help=_('The SPDK directory is /home/{user_name}/spdk')), - - cfg.StrOpt('device_type', - default='NVMe', - help=_('Backend device type is NVMe by default')), - - cfg.BoolOpt('remoteable', - default=False, - help=_('Remoteable is false by default')) + cfg.StrOpt( + 'spdk_conf_file', + default='/etc/cyborg/spdk.conf', + help=_('SPDK conf file to be used for the SPDK driver'), + ), + cfg.StrOpt( + 'accelerator_servers', + default=['vhost', 'nvmf', 'iscsi'], + help=_('A list of accelerator servers to enable by default'), + ), + cfg.StrOpt( + 'spdk_dir', + default='/home/wewe/spdk', + help=_('The SPDK directory is /home/{user_name}/spdk'), + ), + cfg.StrOpt( + 'device_type', + default='NVMe', + help=_('Backend device type is NVMe by default'), + ), + cfg.BoolOpt( + 'remoteable', default=False, help=_('Remoteable is false by default') + ), ] CONF = cfg.CONF @@ -124,15 +128,9 @@ def construct_error_bdev(py, accelerator, basename): acc_client.construct_error_bdev(basename) -def construct_nvme_bdev(py, - accelerator, - name, - trtype, - traddr, - adrfam, - trsvcid, - subnqn - ): +def construct_nvme_bdev( + py, accelerator, name, trtype, traddr, adrfam, trsvcid, subnqn +): """Add a bdev with nvme backend :param py: py_client. @@ -148,22 +146,13 @@ def construct_nvme_bdev(py, :return: name. """ acc_client = get_accelerator_client(py, accelerator) - acc_client.construct_nvme_bdev(name, - trtype, - traddr, - adrfam, - trsvcid, - subnqn - ) + acc_client.construct_nvme_bdev( + name, trtype, traddr, adrfam, trsvcid, subnqn + ) return name -def construct_null_bdev(py, - accelerator, - name, - total_size, - block_size - ): +def construct_null_bdev(py, accelerator, name, total_size, block_size): """Add a bdev with null backend :param py: py_client. @@ -189,7 +178,7 @@ def get_py_client(server): py = PySPDK(server) return py else: - msg = (_("Could not find %s accelerator") % server) + msg = _("Could not find %s accelerator") % server raise exception.InvalidAccelerator(msg) @@ -204,7 +193,7 @@ def check_for_setup_error(py, server): if py.is_alive(): return True else: - msg = (_("%s accelerator is down") % server) + msg = _("%s accelerator is down") % server raise exception.AcceleratorException(msg) @@ -224,6 +213,5 @@ def get_accelerator_client(py, accelerator): acc_client = NvmfTgt(py) return acc_client else: - exc_msg = (_("accelerator_client %(acc_client) is missing") - % acc_client) + exc_msg = _("accelerator_client %(acc_client) is missing") % acc_client raise exception.InvalidAccelerator(exc_msg) diff --git a/cyborg/accelerator/drivers/spdk/util/pyspdk/nvmf_client.py b/cyborg/accelerator/drivers/spdk/util/pyspdk/nvmf_client.py index aee2c22a..32f1ecb5 100644 --- a/cyborg/accelerator/drivers/spdk/util/pyspdk/nvmf_client.py +++ b/cyborg/accelerator/drivers/spdk/util/pyspdk/nvmf_client.py @@ -18,19 +18,16 @@ LOG = logging.getLogger(__name__) class NvmfTgt: - def __init__(self, py): super().__init__() self.py = py def get_rpc_methods(self): - rpc_methods = self._get_json_objs( - 'get_rpc_methods', '10.0.2.15') + rpc_methods = self._get_json_objs('get_rpc_methods', '10.0.2.15') return rpc_methods def get_bdevs(self): - block_devices = self._get_json_objs( - 'get_bdevs', '10.0.2.15') + block_devices = self._get_json_objs('get_bdevs', '10.0.2.15') return block_devices def delete_bdev(self, name): @@ -46,27 +43,20 @@ class NvmfTgt: def construct_aio_bdev(self, filename, name, block_size): sub_args = [filename, name, str(block_size)] res = self.py.exec_rpc( - 'construct_aio_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_aio_bdev', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def construct_error_bdev(self, basename): sub_args = [basename] res = self.py.exec_rpc( - 'construct_error_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_error_bdev', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def construct_nvme_bdev( - self, - name, - trtype, - traddr, - adrfam=None, - trsvcid=None, - subnqn=None): + self, name, trtype, traddr, adrfam=None, trsvcid=None, subnqn=None + ): sub_args = ["-b", "-t", "-a"] sub_args.insert(1, name) sub_args.insert(2, trtype) @@ -81,52 +71,42 @@ class NvmfTgt: sub_args.append("-n") sub_args.append(subnqn) res = self.py.exec_rpc( - 'construct_nvme_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_nvme_bdev', '10.0.2.15', sub_args=sub_args + ) return res def construct_null_bdev(self, name, total_size, block_size): sub_args = [name, str(total_size), str(block_size)] res = self.py.exec_rpc( - 'construct_null_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_null_bdev', '10.0.2.15', sub_args=sub_args + ) return res def construct_malloc_bdev(self, total_size, block_size): sub_args = [str(total_size), str(block_size)] res = self.py.exec_rpc( - 'construct_malloc_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_malloc_bdev', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def delete_nvmf_subsystem(self, nqn): sub_args = [nqn] res = self.py.exec_rpc( - 'delete_nvmf_subsystem', - '10.0.2.15', - sub_args=sub_args) + 'delete_nvmf_subsystem', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def construct_nvmf_subsystem( - self, - nqn, - listen, - hosts, - serial_number, - namespaces): + self, nqn, listen, hosts, serial_number, namespaces + ): sub_args = [nqn, listen, hosts, serial_number, namespaces] res = self.py.exec_rpc( - 'construct_nvmf_subsystem', - '10.0.2.15', - sub_args=sub_args) + 'construct_nvmf_subsystem', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def get_nvmf_subsystems(self): - subsystems = self._get_json_objs( - 'get_nvmf_subsystems', '10.0.2.15') + subsystems = self._get_json_objs('get_nvmf_subsystems', '10.0.2.15') return subsystems def _get_json_objs(self, method, server_ip): diff --git a/cyborg/accelerator/drivers/spdk/util/pyspdk/py_spdk.py b/cyborg/accelerator/drivers/spdk/util/pyspdk/py_spdk.py index d48afedc..417d3ad4 100644 --- a/cyborg/accelerator/drivers/spdk/util/pyspdk/py_spdk.py +++ b/cyborg/accelerator/drivers/spdk/util/pyspdk/py_spdk.py @@ -24,7 +24,6 @@ LOG = logging.getLogger(__name__) class PySPDK: - def __init__(self, pname): super().__init__() self.pid = None diff --git a/cyborg/accelerator/drivers/spdk/util/pyspdk/vhost_client.py b/cyborg/accelerator/drivers/spdk/util/pyspdk/vhost_client.py index 15bf13f6..ade49d84 100644 --- a/cyborg/accelerator/drivers/spdk/util/pyspdk/vhost_client.py +++ b/cyborg/accelerator/drivers/spdk/util/pyspdk/vhost_client.py @@ -18,7 +18,6 @@ LOG = logging.getLogger(__name__) class VhostTgt: - def __init__(self, py): super().__init__() self.py = py @@ -28,8 +27,7 @@ class VhostTgt: return rpc_methods def get_scsi_devices(self): - scsi_devices = self._get_json_objs( - 'get_scsi_devices', '127.0.0.1') + scsi_devices = self._get_json_objs('get_scsi_devices', '127.0.0.1') return scsi_devices def get_luns(self): @@ -37,29 +35,25 @@ class VhostTgt: return luns def get_interfaces(self): - interfaces = self._get_json_objs( - 'get_interfaces', '127.0.0.1') + interfaces = self._get_json_objs('get_interfaces', '127.0.0.1') return interfaces def add_ip_address(self, ifc_index, ip_addr): sub_args = [ifc_index, ip_addr] res = self.py.exec_rpc( - 'add_ip_address', - '127.0.0.1', - sub_args=sub_args) + 'add_ip_address', '127.0.0.1', sub_args=sub_args + ) return res def delete_ip_address(self, ifc_index, ip_addr): sub_args = [ifc_index, ip_addr] res = self.py.exec_rpc( - 'delete_ip_address', - '127.0.0.1', - sub_args=sub_args) + 'delete_ip_address', '127.0.0.1', sub_args=sub_args + ) return res def get_bdevs(self): - block_devices = self._get_json_objs( - 'get_bdevs', '127.0.0.1') + block_devices = self._get_json_objs('get_bdevs', '127.0.0.1') return block_devices def delete_bdev(self, name): @@ -75,27 +69,20 @@ class VhostTgt: def construct_aio_bdev(self, filename, name, block_size): sub_args = [filename, name, str(block_size)] res = self.py.exec_rpc( - 'construct_aio_bdev', - '127.0.0.1', - sub_args=sub_args) + 'construct_aio_bdev', '127.0.0.1', sub_args=sub_args + ) LOG.info(res) def construct_error_bdev(self, basename): sub_args = [basename] res = self.py.exec_rpc( - 'construct_error_bdev', - '127.0.0.1', - sub_args=sub_args) + 'construct_error_bdev', '127.0.0.1', sub_args=sub_args + ) LOG.info(res) def construct_nvme_bdev( - self, - name, - trtype, - traddr, - adrfam=None, - trsvcid=None, - subnqn=None): + self, name, trtype, traddr, adrfam=None, trsvcid=None, subnqn=None + ): sub_args = ["-b", "-t", "-a"] sub_args.insert(1, name) sub_args.insert(2, trtype) @@ -110,25 +97,22 @@ class VhostTgt: sub_args.append("-n") sub_args.append(subnqn) res = self.py.exec_rpc( - 'construct_nvme_bdev', - '127.0.0.1', - sub_args=sub_args) + 'construct_nvme_bdev', '127.0.0.1', sub_args=sub_args + ) return res def construct_null_bdev(self, name, total_size, block_size): sub_args = [name, str(total_size), str(block_size)] res = self.py.exec_rpc( - 'construct_null_bdev', - '127.0.0.1', - sub_args=sub_args) + 'construct_null_bdev', '127.0.0.1', sub_args=sub_args + ) return res def construct_malloc_bdev(self, total_size, block_size): sub_args = [str(total_size), str(block_size)] res = self.py.exec_rpc( - 'construct_malloc_bdev', - '10.0.2.15', - sub_args=sub_args) + 'construct_malloc_bdev', '10.0.2.15', sub_args=sub_args + ) LOG.info(res) def _get_json_objs(self, method, server_ip): diff --git a/cyborg/accelerator/drivers/spdk/vhost/vhost.py b/cyborg/accelerator/drivers/spdk/vhost/vhost.py index 095bf768..3cda9056 100644 --- a/cyborg/accelerator/drivers/spdk/vhost/vhost.py +++ b/cyborg/accelerator/drivers/spdk/vhost/vhost.py @@ -28,7 +28,7 @@ LOG = logging.getLogger(__name__) class VHOSTDRIVER(SPDKDRIVER): """VHOSTDRIVER class. - vhost server app should be able to implement this driver. + vhost server app should be able to implement this driver. """ SERVER = 'vhost' @@ -57,7 +57,7 @@ class VHOSTDRIVER(SPDKDRIVER): 'bdevs': bdevs, 'scsi_devices': scsi_devices, 'luns': luns, - 'interfaces': interfaces + 'interfaces': interfaces, } return accelerator_obj diff --git a/cyborg/accelerator/drivers/ssd/base.py b/cyborg/accelerator/drivers/ssd/base.py index d2c64a05..ccb20537 100644 --- a/cyborg/accelerator/drivers/ssd/base.py +++ b/cyborg/accelerator/drivers/ssd/base.py @@ -16,6 +16,7 @@ """ Cyborg Generic SSD driver implementation. """ + from oslo_log import log as logging from cyborg.accelerator.common import utils as pci_utils @@ -28,8 +29,8 @@ VENDOR_MAPS = pci_utils.get_vendor_maps() class SSDDriver: """Generic class for SSD drivers. - This is just a virtual SSD drivers interface. - Vendor should implement their specific drivers. + This is just a virtual SSD drivers interface. + Vendor should implement their specific drivers. """ @classmethod diff --git a/cyborg/accelerator/drivers/ssd/utils.py b/cyborg/accelerator/drivers/ssd/utils.py index 9c1edb99..ae2242ec 100644 --- a/cyborg/accelerator/drivers/ssd/utils.py +++ b/cyborg/accelerator/drivers/ssd/utils.py @@ -36,11 +36,13 @@ import cyborg.privsep LOG = logging.getLogger(__name__) SSD_FLAGS = ["Non-Volatile memory controller"] -SSD_INFO_PATTERN = re.compile(r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" - r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " - r"(?P.*) [\[].*]: (?P.*) .*" - r"[\[](?P[0-9a-fA-F]" - r"{4}):(?P[0-9a-fA-F]{4})].*") +SSD_INFO_PATTERN = re.compile( + r"(?P[0-9a-fA-F]{4}:[0-9a-fA-F]{2}:" + r"[0-9a-fA-F]{2}\.[0-9a-fA-F]) " + r"(?P.*) [\[].*]: (?P.*) .*" + r"[\[](?P[0-9a-fA-F]" + r"{4}):(?P[0-9a-fA-F]{4})].*" +) VENDOR_MAPS = utils.get_vendor_maps() @@ -107,8 +109,10 @@ def _generate_driver_device(ssd): driver_device_obj = driver_device.DriverDevice() driver_device_obj.vendor = ssd["vendor_id"] driver_device_obj.model = ssd.get('model', 'miss model info') - std_board_info = {'product_id': ssd.get('product_id'), - 'controller': ssd.get('controller')} + std_board_info = { + 'product_id': ssd.get('product_id'), + 'controller': ssd.get('controller'), + } vendor_board_info = {'vendor_info': ssd.get('vendor_info', 'ssd_vb_info')} driver_device_obj.std_board_info = jsonutils.dumps(std_board_info) driver_device_obj.vendor_board_info = jsonutils.dumps(vendor_board_info) @@ -165,6 +169,7 @@ def _generate_attribute_list(ssd): values = ssd.get(k, []) for index, val in enumerate(values): driver_attr = driver_attribute.DriverAttribute( - key="trait" + str(index), value=val) + key="trait" + str(index), value=val + ) attr_list.append(driver_attr) return attr_list diff --git a/cyborg/agent/manager.py b/cyborg/agent/manager.py index 0629ae6b..82c228bc 100644 --- a/cyborg/agent/manager.py +++ b/cyborg/agent/manager.py @@ -70,18 +70,22 @@ class AgentManager(periodic_task.PeriodicTasks): for attempt in range(retries + 1): try: self.resource_provider_name = ( - self._get_resource_provider_name()) + self._get_resource_provider_name() + ) break except exception.PlacementResourceProviderNotFound: if attempt < retries: - wait = 2 ** attempt # 1, 2, 4, 8, ... + wait = 2**attempt # 1, 2, 4, 8, ... LOG.warning( "Resource provider not found in Placement, " "retrying in %(wait)ds (attempt %(attempt)d/" "%(total)d)", - {'wait': wait, - 'attempt': attempt + 1, - 'total': retries + 1}) + { + 'wait': wait, + 'attempt': attempt + 1, + 'total': retries + 1, + }, + ) time.sleep(wait) else: raise @@ -114,7 +118,8 @@ class AgentManager(periodic_task.PeriodicTasks): LOG.warning( "Resource provider not found with name '%(primary)s', " "using fallback '%(fallback)s'", - {'primary': primary, 'fallback': candidate}) + {'primary': primary, 'fallback': candidate}, + ) LOG.info("Using resource provider name: %s", candidate) return candidate @@ -122,9 +127,12 @@ class AgentManager(periodic_task.PeriodicTasks): "Could not find resource provider in Placement. Tried: %s. " "Ensure nova-compute is running and has registered with " "Placement, or set [agent] resource_provider_name to match " - "the compute node's hypervisor_hostname.", candidates) + "the compute node's hypervisor_hostname.", + candidates, + ) raise exception.PlacementResourceProviderNotFound( - resource_provider=primary) + resource_provider=primary + ) def _check_resource_provider_exists(self, hostname): """Check if a resource provider exists in Placement. @@ -134,36 +142,38 @@ class AgentManager(periodic_task.PeriodicTasks): """ try: resp = self.placement_client.get( - "/resource_providers?name=" - + urllib.parse.quote(hostname)) + "/resource_providers?name=" + urllib.parse.quote(hostname) + ) providers = resp.json().get("resource_providers", []) return len(providers) > 0 except (ValueError, AttributeError) as e: LOG.warning( - "Failed to parse Placement response for " - "'%(name)s': %(err)s", - {'name': hostname, 'err': e}) + "Failed to parse Placement response for '%(name)s': %(err)s", + {'name': hostname, 'err': e}, + ) return False - except (ks_exc.ClientException, - exception.PlacementServerError) as e: + except (ks_exc.ClientException, exception.PlacementServerError) as e: LOG.warning( "Failed to check resource provider '%(name)s': %(err)s", - {'name': hostname, 'err': e}) + {'name': hostname, 'err': e}, + ) return False def periodic_tasks(self, context, raise_on_error=False): return self.run_periodic_tasks(context, raise_on_error=raise_on_error) - def fpga_program(self, context, controlpath_id, - bitstream_uuid, driver_name): + def fpga_program( + self, context, controlpath_id, bitstream_uuid, driver_name + ): bitstream_uuid = str(bitstream_uuid) if not uuidutils.is_uuid_like(bitstream_uuid): raise exception.InvalidUUID(uuid=bitstream_uuid) - download_path = tempfile.NamedTemporaryFile(suffix=".gbs", - prefix=bitstream_uuid) - self.image_api.download(context, - bitstream_uuid, - dest_path=download_path.name) + download_path = tempfile.NamedTemporaryFile( + suffix=".gbs", prefix=bitstream_uuid + ) + self.image_api.download( + context, bitstream_uuid, dest_path=download_path.name + ) try: driver = self.fpga_driver.create(driver_name) ret = driver.program(controlpath_id, download_path.name) diff --git a/cyborg/agent/resource_tracker.py b/cyborg/agent/resource_tracker.py index 851c1628..1ce99cf5 100644 --- a/cyborg/agent/resource_tracker.py +++ b/cyborg/agent/resource_tracker.py @@ -53,20 +53,22 @@ class ResourceTracker: if not enabled_drivers: enabled_drivers = CONF.agent.enabled_drivers valid_drivers = ExtensionManager( - namespace='cyborg.accelerator.driver').names() + namespace='cyborg.accelerator.driver' + ).names() for d in enabled_drivers: if d not in valid_drivers: raise exception.InvalidDriver(name=d) acc_driver = driver.DriverManager( - namespace='cyborg.accelerator.driver', name=d, - invoke_on_load=True).driver + namespace='cyborg.accelerator.driver', + name=d, + invoke_on_load=True, + ).driver acc_drivers.append(acc_driver) self.acc_drivers = acc_drivers @utils.synchronized(AGENT_RESOURCE_SEMAPHORE) def update_usage(self, context): - """Update the resource usage periodically. - """ + """Update the resource usage periodically.""" acc_list = [] for acc_driver in self.acc_drivers: acc_list.extend(acc_driver.discover()) diff --git a/cyborg/agent/rpcapi.py b/cyborg/agent/rpcapi.py index 1845b4ec..2e282812 100644 --- a/cyborg/agent/rpcapi.py +++ b/cyborg/agent/rpcapi.py @@ -1,4 +1,3 @@ - # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain @@ -43,42 +42,64 @@ class AgentAPI: def __init__(self, topic=None): super().__init__() self.topic = topic or constants.AGENT_TOPIC - target = messaging.Target(topic=self.topic, - version='1.0') + target = messaging.Target(topic=self.topic, version='1.0') serializer = objects_base.CyborgObjectSerializer() - self.client = rpc.get_client(target, - version_cap=self.RPC_API_VERSION, - serializer=serializer) + self.client = rpc.get_client( + target, version_cap=self.RPC_API_VERSION, serializer=serializer + ) - def fpga_program(self, context, hostname, controlpath_id, - bitstream_uuid, driver_name): - LOG.info('Agent fpga_program: hostname: (%s) ' + - 'bitstream_id: (%s)', hostname, bitstream_uuid) + def fpga_program( + self, context, hostname, controlpath_id, bitstream_uuid, driver_name + ): + LOG.info( + 'Agent fpga_program: hostname: (%s) bitstream_id: (%s)', + hostname, + bitstream_uuid, + ) version = '1.0' cctxt = self.client.prepare(server=hostname, version=version) - return cctxt.call(context, 'fpga_program', - controlpath_id=controlpath_id, - bitstream_uuid=bitstream_uuid, - driver_name=driver_name) + return cctxt.call( + context, + 'fpga_program', + controlpath_id=controlpath_id, + bitstream_uuid=bitstream_uuid, + driver_name=driver_name, + ) - def create_vgpu_mdev(self, context, hostname, pci_addr, - asked_type, ah_uuid): - LOG.debug('Agent create_vgpu_mdev: hostname: (%s) , pci_address: (%s)' - 'gpu_id: (%s)', hostname, pci_addr, ah_uuid) + def create_vgpu_mdev( + self, context, hostname, pci_addr, asked_type, ah_uuid + ): + LOG.debug( + 'Agent create_vgpu_mdev: hostname: (%s) , pci_address: (%s)' + 'gpu_id: (%s)', + hostname, + pci_addr, + ah_uuid, + ) version = '1.0' cctxt = self.client.prepare(server=hostname, version=version) - return cctxt.call(context, 'create_vgpu_mdev', - pci_addr=pci_addr, - asked_type=asked_type, - ah_uuid=ah_uuid) + return cctxt.call( + context, + 'create_vgpu_mdev', + pci_addr=pci_addr, + asked_type=asked_type, + ah_uuid=ah_uuid, + ) - def remove_vgpu_mdev(self, context, hostname, pci_addr, - asked_type, ah_uuid): - LOG.debug('Agent remove_vgpu_mdev: hostname: (%s) ' - 'gpu_id: (%s)', hostname, ah_uuid) + def remove_vgpu_mdev( + self, context, hostname, pci_addr, asked_type, ah_uuid + ): + LOG.debug( + 'Agent remove_vgpu_mdev: hostname: (%s) gpu_id: (%s)', + hostname, + ah_uuid, + ) version = '1.0' cctxt = self.client.prepare(server=hostname, version=version) - return cctxt.call(context, 'remove_vgpu_mdev', - pci_addr=pci_addr, - asked_type=asked_type, - ah_uuid=ah_uuid) + return cctxt.call( + context, + 'remove_vgpu_mdev', + pci_addr=pci_addr, + asked_type=asked_type, + ah_uuid=ah_uuid, + ) diff --git a/cyborg/api/app.py b/cyborg/api/app.py index a677ba74..6c17b422 100644 --- a/cyborg/api/app.py +++ b/cyborg/api/app.py @@ -40,10 +40,12 @@ def setup_app(pecan_config=None, extra_hooks=None): if not pecan_config: pecan_config = get_pecan_config() - app_hooks = [hooks.ConfigHook(), - hooks.ConductorAPIHook(), - hooks.ContextHook(pecan_config.app.acl_public_routes), - hooks.PublicUrlHook()] + app_hooks = [ + hooks.ConfigHook(), + hooks.ConductorAPIHook(), + hooks.ContextHook(pecan_config.app.acl_public_routes), + hooks.PublicUrlHook(), + ] if extra_hooks: app_hooks.extend(extra_hooks) @@ -53,7 +55,7 @@ def setup_app(pecan_config=None, extra_hooks=None): force_canonical=getattr(pecan_config.app, 'force_canonical', True), hooks=app_hooks, wrap_app=middleware.ParsableErrorMiddleware, - **app_conf + **app_conf, ) return app diff --git a/cyborg/api/config.py b/cyborg/api/config.py index e5a50ee8..40db1de1 100644 --- a/cyborg/api/config.py +++ b/cyborg/api/config.py @@ -15,10 +15,7 @@ # Server Specific Configurations # See https://pecan.readthedocs.org/en/latest/configuration.html#server-configuration # noqa -server = { - 'port': '6666', - 'host': '127.0.0.1' -} +server = {'port': '6666', 'host': '127.0.0.1'} # Pecan Application Configurations # See https://pecan.readthedocs.org/en/latest/configuration.html#application-configuration # noqa @@ -27,14 +24,9 @@ app = { 'modules': ['cyborg.api'], 'static_root': '%(confdir)s/public', 'debug': False, - 'acl_public_routes': [ - '/', - '/v2' - ] + 'acl_public_routes': ['/', '/v2'], } # WSME Configurations # See https://wsme.readthedocs.org/en/latest/integrate.html#configuration -wsme = { - 'debug': False -} +wsme = {'debug': False} diff --git a/cyborg/api/controllers/base.py b/cyborg/api/controllers/base.py index c869cd96..68649226 100644 --- a/cyborg/api/controllers/base.py +++ b/cyborg/api/controllers/base.py @@ -37,12 +37,14 @@ class APIBase(wtypes.Base): def as_dict(self): """Render this object as a dict of its fields.""" - return {k: getattr(self, k) for k in self.fields - if hasattr(self, k) and getattr(self, k) != wsme.Unset} + return { + k: getattr(self, k) + for k in self.fields + if hasattr(self, k) and getattr(self, k) != wsme.Unset + } class CyborgController(rest.RestController): - def _handle_patch(self, method, remainder, request=None): """Routes ``PATCH`` _custom_actions.""" # route to a patch_all or get if no additional parts are available @@ -91,7 +93,8 @@ class Version: """ (self.major, self.minor) = Version.parse_headers( - headers, default_version, latest_version) + headers, default_version, latest_version + ) def __repr__(self): return '%s.%s' % (self.major, self.minor) @@ -127,7 +130,8 @@ class Version: if len(version) != 2: raise exc.HTTPNotAcceptable( - "Invalid value for %s header" % Version.current_api_version) + "Invalid value for %s header" % Version.current_api_version + ) return version def __gt__(self, other): diff --git a/cyborg/api/controllers/link.py b/cyborg/api/controllers/link.py index 35c02bd1..905cc962 100644 --- a/cyborg/api/controllers/link.py +++ b/cyborg/api/controllers/link.py @@ -24,11 +24,15 @@ def build_url(resource, resource_args, bookmark=False, base_url=None): base_url = pecan.request.public_url # TODO(Sundar) Return version etc. similar to other projects. - template = '%(url)s/accelerator/%(res)s' \ - if bookmark else '%(url)s/accelerator/' + base.API_V2 + '/%(res)s' + template = ( + '%(url)s/accelerator/%(res)s' + if bookmark + else '%(url)s/accelerator/' + base.API_V2 + '/%(res)s' + ) if resource_args: - template += ('%(args)s' if resource_args.startswith('?') - else '/%(args)s') + template += ( + '%(args)s' if resource_args.startswith('?') else '/%(args)s' + ) return template % {'url': base_url, 'res': resource, 'args': resource_args} @@ -45,10 +49,17 @@ class Link(base.APIBase): """Indicates the type of document/link.""" @staticmethod - def make_link(rel_name, url, resource, resource_args, - bookmark=False, type=wtypes.Unset): - href = build_url(resource, resource_args, - bookmark=bookmark, base_url=url) + def make_link( + rel_name, + url, + resource, + resource_args, + bookmark=False, + type=wtypes.Unset, + ): + href = build_url( + resource, resource_args, bookmark=bookmark, base_url=url + ) return Link(href=href, rel=rel_name, type=type) @staticmethod diff --git a/cyborg/api/controllers/root.py b/cyborg/api/controllers/root.py index 719578de..c4bf15ea 100644 --- a/cyborg/api/controllers/root.py +++ b/cyborg/api/controllers/root.py @@ -57,13 +57,17 @@ class Version(base.APIBase): version.min_version = None else: v = importlib.import_module( - 'cyborg.api.controllers.%s.versions' % id) + 'cyborg.api.controllers.%s.versions' % id + ) version.max_version = v.max_version_string() version.min_version = v.min_version_string() version.id = id version.status = status - version.links = [link.Link.make_link('self', pecan.request.host_url, - id, '', bookmark=True)] + version.links = [ + link.Link.make_link( + 'self', pecan.request.host_url, id, '', bookmark=True + ) + ] return version @@ -87,7 +91,8 @@ class Root(base.APIBase): root.description = ( "Cyborg is the OpenStack project for lifecycle " "management of hardware accelerators, such as GPUs," - "FPGAs, AI chips, security accelerators, etc.") + "FPGAs, AI chips, security accelerators, etc." + ) root.versions = [Version.convert('v2')] root.default_version = Version.convert('v2') return root diff --git a/cyborg/api/controllers/types.py b/cyborg/api/controllers/types.py index 24b325fb..6af2a8f6 100644 --- a/cyborg/api/controllers/types.py +++ b/cyborg/api/controllers/types.py @@ -27,26 +27,35 @@ from cyborg.common.i18n import _ class FilterType(wtypes.UserType): """Query filter.""" + name = 'filtertype' basetype = wtypes.text - _supported_fields = wtypes.Enum(wtypes.text, 'parent_uuid', 'root_uuid', - 'board', 'availability', 'interface_type', - 'instance_uuid', 'limit', 'marker', - 'sort_key', 'sort_dir', 'name') + _supported_fields = wtypes.Enum( + wtypes.text, + 'parent_uuid', + 'root_uuid', + 'board', + 'availability', + 'interface_type', + 'instance_uuid', + 'limit', + 'marker', + 'sort_key', + 'sort_dir', + 'name', + ) field = wsme.wsattr(_supported_fields, mandatory=True) value = wsme.wsattr(wtypes.text, mandatory=True) def __repr__(self): # for logging calls - return '' % (self.field, - self.value) + return '' % (self.field, self.value) @classmethod def sample(cls): - return cls(field='interface_type', - value='pci') + return cls(field='interface_type', value='pci') def as_dict(self): d = dict() @@ -131,10 +140,12 @@ integer = wtypes.IntegerType() class JsonPatchType(wtypes.Base): """A complex type that represents a single json-patch operation.""" - path = wtypes.wsattr(wtypes.StringType(pattern=r'^(/[\w-]+)+$'), - mandatory=True) - op = wtypes.wsattr(wtypes.Enum(str, 'add', 'replace', 'remove'), - mandatory=True) + path = wtypes.wsattr( + wtypes.StringType(pattern=r'^(/[\w-]+)+$'), mandatory=True + ) + op = wtypes.wsattr( + wtypes.Enum(str, 'add', 'replace', 'remove'), mandatory=True + ) value = wtypes.wsattr(jsontype, default=wtypes.Unset) # The class of the objects being patched. Override this in subclasses. @@ -170,8 +181,9 @@ class JsonPatchType(wtypes.Base): if cls._non_removable_attrs is None: cls._non_removable_attrs = cls._extra_non_removable_attrs.copy() if cls._api_base: - fields = inspect.getmembers(cls._api_base, - lambda a: not inspect.isroutine(a)) + fields = inspect.getmembers( + cls._api_base, lambda a: not inspect.isroutine(a) + ) for name, field in fields: if getattr(field, 'mandatory', False): cls._non_removable_attrs.add('/%s' % name) diff --git a/cyborg/api/controllers/utils.py b/cyborg/api/controllers/utils.py index 6ddca588..19bf0ef8 100644 --- a/cyborg/api/controllers/utils.py +++ b/cyborg/api/controllers/utils.py @@ -20,16 +20,20 @@ import wsme from cyborg.common.i18n import _ -JSONPATCH_EXCEPTIONS = (jsonpatch.JsonPatchException, - jsonpatch.JsonPointerException, - KeyError) +JSONPATCH_EXCEPTIONS = ( + jsonpatch.JsonPatchException, + jsonpatch.JsonPointerException, + KeyError, +) def apply_jsonpatch(doc, patch): for p in patch: if p['op'] == 'add' and p['path'].count('/') == 1: if p['path'].lstrip('/') not in doc: - msg = _('Adding a new attribute (%s) to the root of ' - ' the resource is not allowed') + msg = _( + 'Adding a new attribute (%s) to the root of ' + ' the resource is not allowed' + ) raise wsme.exc.ClientSideError(msg % p['path']) return jsonpatch.apply_patch(doc, jsonpatch.JsonPatch(patch)) diff --git a/cyborg/api/controllers/v2/__init__.py b/cyborg/api/controllers/v2/__init__.py index a5d221c2..7959a4af 100644 --- a/cyborg/api/controllers/v2/__init__.py +++ b/cyborg/api/controllers/v2/__init__.py @@ -36,13 +36,17 @@ from cyborg.api.controllers.v2 import versions def min_version(): return base.Version( {base.Version.current_api_version: versions.min_version_string()}, - versions.min_version_string(), versions.max_version_string()) + versions.min_version_string(), + versions.max_version_string(), + ) def max_version(): return base.Version( {base.Version.current_api_version: versions.max_version_string()}, - versions.min_version_string(), versions.max_version_string()) + versions.min_version_string(), + versions.max_version_string(), + ) class V2(base.APIBase): @@ -71,9 +75,8 @@ class V2(base.APIBase): v2.min_version = str(min_version()) v2.status = 'CURRENT' v2.links = [ - link.Link.make_link('self', pecan.request.public_url, - '', ''), - ] + link.Link.make_link('self', pecan.request.public_url, '', ''), + ] return v2 @@ -98,24 +101,35 @@ class Controller(rest.RestController): raise exc.HTTPNotAcceptable( "Mutually exclusive versions requested. Version %(ver)s " "requested but not supported by this service. The supported " - "version range is: [%(min)s, %(max)s]." % - {'ver': version, 'min': versions.min_version_string(), - 'max': versions.max_version_string()}, - headers=headers) + "version range is: [%(min)s, %(max)s]." + % { + 'ver': version, + 'min': versions.min_version_string(), + 'max': versions.max_version_string(), + }, + headers=headers, + ) # ensure the minor version is within the supported range if version < min_version() or version > max_version(): raise exc.HTTPNotAcceptable( "Version %(ver)s was requested but the minor version is not " "supported by this service. The supported version range is: " - "[%(min)s, %(max)s]." % - {'ver': version, 'min': versions.min_version_string(), - 'max': versions.max_version_string()}, - headers=headers) + "[%(min)s, %(max)s]." + % { + 'ver': version, + 'min': versions.min_version_string(), + 'max': versions.max_version_string(), + }, + headers=headers, + ) @pecan.expose() def _route(self, args, request=None): - v = base.Version(pecan.request.headers, versions.min_version_string(), - versions.max_version_string()) + v = base.Version( + pecan.request.headers, + versions.min_version_string(), + versions.max_version_string(), + ) # The Vary header is used as a hint to caching proxies and user agents # that the response is also dependent on the OpenStack-API-Version and @@ -124,9 +138,11 @@ class Controller(rest.RestController): # Always set the min and max headers pecan.response.headers[base.Version.min_api_version] = ( - versions.min_version_string()) + versions.min_version_string() + ) pecan.response.headers[base.Version.max_api_version] = ( - versions.max_version_string()) + versions.max_version_string() + ) # assert that requested version is supported self._check_version(v, pecan.response.headers) diff --git a/cyborg/api/controllers/v2/arqs.py b/cyborg/api/controllers/v2/arqs.py index 946e42ca..f28a8fe1 100644 --- a/cyborg/api/controllers/v2/arqs.py +++ b/cyborg/api/controllers/v2/arqs.py @@ -41,6 +41,7 @@ class ARQ(base.APIBase): This class enforces type checking and value constraints, and converts between the internal object model and the API representation. """ + uuid = types.uuid """The UUID of the ARQ""" @@ -77,9 +78,13 @@ class ARQ(base.APIBase): def convert_with_links(cls, obj_arq): api_arq = cls(**obj_arq.as_dict()) api_arq.links = [ - link.Link.make_link('self', pecan.request.public_url, - 'accelerator_requests', api_arq.uuid) - ] + link.Link.make_link( + 'self', + pecan.request.public_url, + 'accelerator_requests', + api_arq.uuid, + ) + ] return api_arq @@ -92,31 +97,33 @@ class ARQCollection(base.APIBase): @classmethod def convert_with_links(cls, obj_arqs): collection = cls() - collection.arqs = [ARQ.convert_with_links(obj_arq) - for obj_arq in obj_arqs] + collection.arqs = [ + ARQ.convert_with_links(obj_arq) for obj_arq in obj_arqs + ] return collection class ARQsController(base.CyborgController): """REST controller for ARQs. - For the relationship between ARQs and device profiles, see - nova/nova/accelerator/cyborg.py. + For the relationship between ARQs and device profiles, see + nova/nova/accelerator/cyborg.py. """ @authorize_wsgi.authorize_wsgi("cyborg:arq", "create", False) - @expose.expose(ARQCollection, body=types.jsontype, - status_code=HTTPStatus.CREATED) + @expose.expose( + ARQCollection, body=types.jsontype, status_code=HTTPStatus.CREATED + ) def post(self, req): """Create one or more ARQs for a single device profile. - Request body: - { 'device_profile_name': } - Future: - { 'device_profile_name': # required - 'device_profile_group_id': , # opt, default=0 - 'image_uuid': , #optional, for future - } - :param req: request body. + Request body: + { 'device_profile_name': } + Future: + { 'device_profile_name': # required + 'device_profile_group_id': , # opt, default=0 + 'image_uuid': , #optional, for future + } + :param req: request body. """ LOG.info("[arq] post req = (%s)", req) context = pecan.request.context @@ -126,8 +133,8 @@ class ARQsController(base.CyborgController): devprof = objects.DeviceProfile.get_by_name(context, dp_name) except exception.ResourceNotFound: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with name=%s' % dp_name) + resource='Device Profile', msg='with name=%s' % dp_name + ) except Exception as e: raise e else: @@ -145,8 +152,10 @@ class ARQsController(base.CyborgController): accel_resources = [int(group.get("resources:FPGA"))] * 2 else: accel_resources = [ - int(val) for key, val in group.items() - if key.startswith('resources')] + int(val) + for key, val in group.items() + if key.startswith('resources') + ] # If/when we introduce non-accelerator resources, like # device-local memory, the key search above needs to be @@ -161,11 +170,13 @@ class ARQsController(base.CyborgController): extarq_fields = {'arq': obj_arq} obj_extarq = objects.ExtARQ(context, **extarq_fields) new_extarq = pecan.request.conductor_api.arq_create( - context, obj_extarq, devprof.id) + context, obj_extarq, devprof.id + ) extarq_list.append(new_extarq) ret = ARQCollection.convert_with_links( - [extarq.arq for extarq in extarq_list]) + [extarq.arq for extarq in extarq_list] + ) LOG.info('[arqs] post returned: %s', ret) return ret @@ -182,8 +193,11 @@ class ARQsController(base.CyborgController): def get_all(self, bind_state=None, instance=None): """Retrieve a list of arqs.""" # TODO(Sundar) Need to implement 'arq=uuid1,...' query parameter - LOG.info('[arqs] get_all. bind_state:(%s), instance:(%s)', - bind_state or '', instance or '') + LOG.info( + '[arqs] get_all. bind_state:(%s), instance:(%s)', + bind_state or '', + instance or '', + ) context = pecan.request.context extarqs = objects.ExtARQ.list(context) state_map = constants.ARQ_BIND_STATES_STATUS_MAP @@ -193,32 +207,36 @@ class ARQsController(base.CyborgController): # Apply instance filter before state filter. if bind_state and bind_state != 'resolved': raise exception.ARQBadState( - state=bind_state, uuid=None, expected=['resolved']) + state=bind_state, uuid=None, expected=['resolved'] + ) if instance: - new_arqs = [arq for arq in arqs - if arq['instance_uuid'] == instance] + new_arqs = [ + arq for arq in arqs if arq['instance_uuid'] == instance + ] arqs = new_arqs if bind_state: for arq in new_arqs: if arq['state'] not in valid_bind_states: # NOTE(Sundar) This should return HTTP code 423 # if any ARQ for this instance is not resolved. - LOG.warning('Some of ARQs for instance %s is not ' - 'resolved', instance) + LOG.warning( + 'Some of ARQs for instance %s is not resolved', + instance, + ) return wsme.api.Response( - None, - status_code=HTTPStatus.LOCKED) + None, status_code=HTTPStatus.LOCKED + ) elif bind_state: - arqs = [arq for arq in arqs - if arq['state'] in valid_bind_states] + arqs = [arq for arq in arqs if arq['state'] in valid_bind_states] ret = ARQCollection.convert_with_links(arqs) LOG.info('[arqs:get_all] Returned: %s', ret) return ret @authorize_wsgi.authorize_wsgi("cyborg:arq", "delete", False) - @expose.expose(None, wtypes.text, types.uuid, - status_code=HTTPStatus.NO_CONTENT) + @expose.expose( + None, wtypes.text, types.uuid, status_code=HTTPStatus.NO_CONTENT + ) def delete(self, arqs=None, instance=None): """Delete one or more ARQS. @@ -241,14 +259,16 @@ class ARQsController(base.CyborgController): if (arqs and instance) or (not arqs and not instance): raise exception.ObjectActionError( action='delete', - reason='Provide either an ARQ uuid list or an instance UUID') + reason='Provide either an ARQ uuid list or an instance UUID', + ) elif arqs: LOG.info("[arqs] delete. arqs=(%s)", arqs) pecan.request.conductor_api.arq_delete_by_uuid(context, arqs) else: # instance is not None LOG.info("[arqs] delete. instance=(%s)", instance) pecan.request.conductor_api.arq_delete_by_instance_uuid( - context, instance) + context, instance + ) def _validate_arq_patch(self, patch): """Validate a single patch for an ARQ. @@ -258,31 +278,39 @@ class ARQsController(base.CyborgController): value field of arq_uuid in patch() method below. :returns: dict of valid fields """ - valid_fields = {'hostname': None, - 'device_rp_uuid': None, - 'instance_uuid': None} + valid_fields = { + 'hostname': None, + 'device_rp_uuid': None, + 'instance_uuid': None, + } if utils.allow_project_id(): valid_fields['project_id'] = None - if ((not all(p['op'] == 'add' for p in patch)) and - (not all(p['op'] == 'remove' for p in patch))): - raise exception.PatchError( - reason='Every op must be add or remove') + if (not all(p['op'] == 'add' for p in patch)) and ( + not all(p['op'] == 'remove' for p in patch) + ): + raise exception.PatchError(reason='Every op must be add or remove') for p in patch: path = p['path'].lstrip('/') if path == 'project_id' and not utils.allow_project_id(): - raise exception.NotAcceptable(_( - "Request not acceptable. The minimal required API " - "version should be %(base)s.%(opr)s") % - {'base': versions.BASE_VERSION, - 'opr': versions.MINOR_1_PROJECT_ID}) + raise exception.NotAcceptable( + _( + "Request not acceptable. The minimal required API " + "version should be %(base)s.%(opr)s" + ) + % { + 'base': versions.BASE_VERSION, + 'opr': versions.MINOR_1_PROJECT_ID, + } + ) if path not in valid_fields.keys(): reason = 'Invalid path in patch {}'.format(p['path']) raise exception.PatchError(reason=reason) if p['op'] == 'add': valid_fields[path] = p['value'] - not_found = [field for field, value in valid_fields.items() - if value is None] + not_found = [ + field for field, value in valid_fields.items() if value is None + ] if patch[0]['op'] == 'add' and len(not_found) > 0: msg = ','.join(not_found) reason = _('Fields absent in patch {}').format(msg) @@ -296,17 +324,20 @@ class ARQsController(base.CyborgController): instance_uuid = patch_fields['instance_uuid'] extarqs = objects.ExtARQ.list(context) extarqs_for_instance = [ - extarq for extarq in extarqs - if extarq.arq['instance_uuid'] == instance_uuid] + extarq + for extarq in extarqs + if extarq.arq['instance_uuid'] == instance_uuid + ] if extarqs_for_instance: # duplicate binding request - msg = _('Instance {} already has accelerator requests. ' - 'Cannot bind additional ARQs.') + msg = _( + 'Instance {} already has accelerator requests. ' + 'Cannot bind additional ARQs.' + ) reason = msg.format(instance_uuid) raise exception.PatchError(reason=reason) @authorize_wsgi.authorize_wsgi("cyborg:arq", "update", False) - @expose.expose(None, body=types.jsontype, - status_code=HTTPStatus.ACCEPTED) + @expose.expose(None, body=types.jsontype, status_code=HTTPStatus.ACCEPTED) def patch(self, patch_list): """Bind/Unbind one or more ARQs. @@ -345,4 +376,5 @@ class ARQsController(base.CyborgController): self._check_if_already_bound(context, valid_fields) pecan.request.conductor_api.arq_apply_patch( - context, patch_list, valid_fields) + context, patch_list, valid_fields + ) diff --git a/cyborg/api/controllers/v2/attributes.py b/cyborg/api/controllers/v2/attributes.py index 1cd58800..44255962 100644 --- a/cyborg/api/controllers/v2/attributes.py +++ b/cyborg/api/controllers/v2/attributes.py @@ -26,6 +26,7 @@ from cyborg.api.controllers import types from cyborg.api import expose from cyborg.common import authorize_wsgi from cyborg import objects + LOG = log.getLogger(__name__) @@ -66,9 +67,13 @@ class Attribute(base.APIBase): def convert_with_links(cls, obj_attribute): api_attribute = cls(**obj_attribute.as_dict()) api_attribute.links = [ - link.Link.make_link('self', pecan.request.public_url, - 'attributes', api_attribute.uuid) - ] + link.Link.make_link( + 'self', + pecan.request.public_url, + 'attributes', + api_attribute.uuid, + ) + ] return api_attribute def get_attribute(self, obj_attribute): @@ -79,7 +84,7 @@ class Attribute(base.APIBase): api_obj[field] = str(obj_attribute[field]) api_obj['links'] = [ link.Link.make_link_dict('attributes', api_obj['uuid']) - ] + ] return api_obj @@ -94,26 +99,30 @@ class AttributeCollection(Attribute): collection = cls() collection.attributes = [ Attribute.convert_with_links(obj_attribute) - for obj_attribute in obj_attributes] + for obj_attribute in obj_attributes + ] return collection def get_attributes(self, obj_attributes): api_obj_attributes = [ self.get_attribute(obj_attribute) - for obj_attribute in obj_attributes] + for obj_attribute in obj_attributes + ] return api_obj_attributes -class AttributesController(base.CyborgController, - AttributeCollection): +class AttributesController(base.CyborgController, AttributeCollection): """REST controller for Attributes.""" @authorize_wsgi.authorize_wsgi("cyborg:attribute", "get_all", False) @expose.expose(AttributeCollection, wtypes.IntegerType(), wtypes.text) def get_all(self, deployable_id=None, key=None): """Retrieve a list of attributes.""" - LOG.info('[attributes] get_all by deployable_id:(%s) and key:(%s).', - deployable_id, key) + LOG.info( + '[attributes] get_all by deployable_id:(%s) and key:(%s).', + deployable_id, + key, + ) search_opts = {} if deployable_id: search_opts['deployable_id'] = deployable_id @@ -137,8 +146,9 @@ class AttributesController(base.CyborgController, return ret @authorize_wsgi.authorize_wsgi("cyborg:attribute", "create", False) - @expose.expose(Attribute, body=types.jsontype, - status_code=HTTPStatus.CREATED) + @expose.expose( + Attribute, body=types.jsontype, status_code=HTTPStatus.CREATED + ) def post(self, req_attr): """Create one attribute. :param req_attr: attribute value. @@ -159,7 +169,7 @@ class AttributesController(base.CyborgController, @expose.expose(None, wtypes.text, status_code=HTTPStatus.NO_CONTENT) def delete(self, uuid): """Delete one attribute. - - UUID of a attribute. + - UUID of a attribute. """ LOG.info('[attributes] delete by uuid: %s.', uuid) context = pecan.request.context diff --git a/cyborg/api/controllers/v2/deployables.py b/cyborg/api/controllers/v2/deployables.py index 600bf854..dac7d448 100644 --- a/cyborg/api/controllers/v2/deployables.py +++ b/cyborg/api/controllers/v2/deployables.py @@ -82,12 +82,14 @@ class Deployable(base.APIBase): url = pecan.request.public_url api_dep.links = [ link.Link.make_link('self', url, 'deployables', api_dep.uuid), - link.Link.make_link('bookmark', url, 'deployables', api_dep.uuid, - bookmark=True) - ] + link.Link.make_link( + 'bookmark', url, 'deployables', api_dep.uuid, bookmark=True + ), + ] query = {"deployable_id": obj_dep.id} - attr_get_list = objects.Attribute.get_by_filter(pecan.request.context, - query) + attr_get_list = objects.Attribute.get_by_filter( + pecan.request.context, query + ) attributes_list = [] for exist_attr in attr_get_list: attributes_list.append({exist_attr.key: exist_attr.value}) @@ -104,12 +106,12 @@ class DeployableCollection(Deployable): def convert_with_links(self, obj_deps): collection = DeployableCollection() collection.deployables = [ - self.convert_with_link(obj_dep) for obj_dep in obj_deps] + self.convert_with_link(obj_dep) for obj_dep in obj_deps + ] return collection class DeployablePatchType(types.JsonPatchType): - _api_base = Deployable @staticmethod @@ -118,8 +120,7 @@ class DeployablePatchType(types.JsonPatchType): return defaults + ['/name', '/num_accelerators'] -class DeployablesController(base.CyborgController, - DeployableCollection): +class DeployablesController(base.CyborgController, DeployableCollection): """REST controller for Deployables.""" _custom_actions = {'program': ['PATCH']} @@ -148,15 +149,15 @@ class DeployablesController(base.CyborgController, obj_dep = objects.Deployable.get(pecan.request.context, uuid) obj_dev = objects.Device.get_by_device_id( - pecan.request.context, - obj_dep.device_id + pecan.request.context, obj_dep.device_id ) hostname = obj_dev.hostname driver_name = obj_dep.driver_name cpid_list = obj_dep.get_cpid_list(pecan.request.context) controlpath_id = cpid_list[0] controlpath_id['cpid_info'] = jsonutils.loads( - cpid_list[0]['cpid_info']) + cpid_list[0]['cpid_info'] + ) self.agent_rpcapi = AgentAPI() ret = self.agent_rpcapi.fpga_program( pecan.request.context, @@ -164,7 +165,7 @@ class DeployablesController(base.CyborgController, controlpath_id, image_uuid, driver_name, - ) + ) if ret: return self.convert_with_link(obj_dep) else: diff --git a/cyborg/api/controllers/v2/device_profiles.py b/cyborg/api/controllers/v2/device_profiles.py index 186236fe..9cbda195 100644 --- a/cyborg/api/controllers/v2/device_profiles.py +++ b/cyborg/api/controllers/v2/device_profiles.py @@ -51,6 +51,7 @@ from cyborg.common import constants from cyborg.common import exception from cyborg.common.i18n import _ from cyborg import objects + LOG = log.getLogger(__name__) @@ -91,9 +92,13 @@ class DeviceProfile(base.APIBase): def convert_with_links(cls, obj_devprof): api_devprof = cls(**obj_devprof.as_dict()) api_devprof.links = [ - link.Link.make_link('self', pecan.request.public_url, - 'device_profiles', api_devprof.uuid) - ] + link.Link.make_link( + 'self', + pecan.request.public_url, + 'device_profiles', + api_devprof.uuid, + ) + ] return api_devprof def get_device_profile(self, obj_devprof): @@ -104,7 +109,7 @@ class DeviceProfile(base.APIBase): api_obj[field] = str(obj_devprof[field]) api_obj['links'] = [ link.Link.make_link_dict('device_profiles', api_obj['uuid']) - ] + ] return api_obj @@ -119,23 +124,25 @@ class DeviceProfileCollection(DeviceProfile): collection = cls() collection.device_profiles = [ DeviceProfile.convert_with_links(obj_devprof) - for obj_devprof in obj_devprofs] + for obj_devprof in obj_devprofs + ] return collection def get_device_profiles(self, obj_devprofs): api_obj_devprofs = [ self.get_device_profile(obj_devprof) - for obj_devprof in obj_devprofs] + for obj_devprof in obj_devprofs + ] return api_obj_devprofs -class DeviceProfilesController(base.CyborgController, - DeviceProfileCollection): +class DeviceProfilesController(base.CyborgController, DeviceProfileCollection): """REST controller for Device Profiles.""" @authorize_wsgi.authorize_wsgi("cyborg:device_profile", "create", False) - @expose.expose(DeviceProfile, body=types.jsontype, - status_code=HTTPStatus.CREATED) + @expose.expose( + DeviceProfile, body=types.jsontype, status_code=HTTPStatus.CREATED + ) def post(self, req_devprof_list): """Create one or more device_profiles. @@ -154,8 +161,8 @@ class DeviceProfilesController(base.CyborgController, LOG.info("[device_profiles] POST request = (%s)", req_devprof_list) if len(req_devprof_list) != 1: raise exception.InvalidParameterValue( - err="Only one device profile allowed " - "per POST request for now.") + err="Only one device profile allowed per POST request for now." + ) req_devprof = req_devprof_list[0] self._validate_post_request(req_devprof) @@ -163,7 +170,8 @@ class DeviceProfilesController(base.CyborgController, obj_devprof = objects.DeviceProfile(context, **req_devprof) new_devprof = pecan.request.conductor_api.device_profile_create( - context, obj_devprof) + context, obj_devprof + ) return DeviceProfile.convert_with_links(new_devprof) def _validate_post_request(self, req_devprof): @@ -177,7 +185,8 @@ class DeviceProfilesController(base.CyborgController, raise exception.DeviceProfileNameNeeded() elif not re.match(NAME, name): raise exception.InvalidParameterValue( - err="Device profile name must be of the form %s" % NAME) + err="Device profile name must be of the form %s" % NAME + ) groups = req_devprof.get("groups") if not groups: @@ -190,7 +199,8 @@ class DeviceProfilesController(base.CyborgController, if not re.match(GROUP_KEYS, key): raise exception.InvalidParameterValue( err="Device profile group keys must be of" - " the form %s" % GROUP_KEYS) + " the form %s" % GROUP_KEYS + ) # check trait name and it's value if key.startswith("trait:"): inner_origin_trait = ":".join(key.split(":")[1:]) @@ -198,11 +208,13 @@ class DeviceProfilesController(base.CyborgController, if not inner_trait.startswith('CUSTOM_'): raise exception.InvalidParameterValue( err="Unsupported trait name format %s, should " - "start with CUSTOM_" % inner_trait) + "start with CUSTOM_" % inner_trait + ) if value not in TRAIT_VALUES: raise exception.InvalidParameterValue( err="Unsupported trait value %s, the value must" - " be one among %s" % (value, TRAIT_VALUES)) + " be one among %s" % (value, TRAIT_VALUES) + ) # strip " " and update old group key. if inner_origin_trait != inner_trait: del group[key] @@ -212,15 +224,19 @@ class DeviceProfilesController(base.CyborgController, if key.startswith("resources:"): inner_origin_rc = ":".join(key.split(":")[1:]) inner_rc = inner_origin_rc.strip(" ") - if inner_rc not in constants.SUPPORT_RESOURCES and \ - not inner_rc.startswith('CUSTOM_'): + if ( + inner_rc not in constants.SUPPORT_RESOURCES + and not inner_rc.startswith('CUSTOM_') + ): raise exception.InvalidParameterValue( - err="Unsupported resource class %s" % inner_rc) + err="Unsupported resource class %s" % inner_rc + ) try: int(value) except ValueError: raise exception.InvalidParameterValue( - err="Resources number %s is invalid" % value) + err="Resources number %s is invalid" % value + ) # strip " " and update old group key. if inner_origin_rc != inner_rc: del group[key] @@ -233,12 +249,14 @@ class DeviceProfilesController(base.CyborgController, context = pecan.request.context obj_devprofs = objects.DeviceProfile.list(context) if names: - new_obj_devprofs = [devprof for devprof in obj_devprofs - if devprof['name'] in names] + new_obj_devprofs = [ + devprof for devprof in obj_devprofs if devprof['name'] in names + ] obj_devprofs = new_obj_devprofs elif uuid is not None: - new_obj_devprofs = [devprof for devprof in obj_devprofs - if devprof['uuid'] == uuid] + new_obj_devprofs = [ + devprof for devprof in obj_devprofs if devprof['uuid'] == uuid + ] obj_devprofs = new_obj_devprofs return obj_devprofs @@ -265,32 +283,39 @@ class DeviceProfilesController(base.CyborgController, context = pecan.request.context if uuidutils.is_uuid_like(dp_uuid_or_name): LOG.info('[device_profiles] get_one. uuid=%s', dp_uuid_or_name) - obj_devprof = objects.DeviceProfile.get_by_uuid(context, - dp_uuid_or_name) + obj_devprof = objects.DeviceProfile.get_by_uuid( + context, dp_uuid_or_name + ) else: if api.request.version.minor >= versions.MINOR_2_DP_BY_NAME: LOG.info('[device_profiles] get_one. name=%s', dp_uuid_or_name) - obj_devprof = \ - objects.DeviceProfile.get_by_name(context, - dp_uuid_or_name) + obj_devprof = objects.DeviceProfile.get_by_name( + context, dp_uuid_or_name + ) else: - raise exception.NotAcceptable(_( - "Request not acceptable. The minimal required API " - "version should be %(base)s.%(opr)s") % - {'base': versions.BASE_VERSION, - 'opr': versions.MINOR_2_DP_BY_NAME}) + raise exception.NotAcceptable( + _( + "Request not acceptable. The minimal required API " + "version should be %(base)s.%(opr)s" + ) + % { + 'base': versions.BASE_VERSION, + 'opr': versions.MINOR_2_DP_BY_NAME, + } + ) if not obj_devprof: LOG.warning("Device profile with %s not found!", dp_uuid_or_name) raise exception.ResourceNotFound( - resource='Device profile', - msg='with %s' % dp_uuid_or_name) + resource='Device profile', msg='with %s' % dp_uuid_or_name + ) api_obj_devprof = self.get_device_profile(obj_devprof) ret = {"device_profile": api_obj_devprof} LOG.info('[device_profiles] get_one returned: %s', ret) # TODO(Sundar) Replace this with convert_with_links() - return wsme.api.Response(ret, status_code=HTTPStatus.OK, - return_type=wsme.types.DictType) + return wsme.api.Response( + ret, status_code=HTTPStatus.OK, return_type=wsme.types.DictType + ) @authorize_wsgi.authorize_wsgi("cyborg:device_profile", "delete") @expose.expose(None, wtypes.text, status_code=HTTPStatus.NO_CONTENT) @@ -309,11 +334,13 @@ class DeviceProfilesController(base.CyborgController, LOG.info('[device_profiles] delete uuid=%s', uuid) obj_devprof = objects.DeviceProfile.get_by_uuid(context, uuid) pecan.request.conductor_api.device_profile_delete( - context, obj_devprof) + context, obj_devprof + ) else: names = value.split(",") LOG.info('[device_profiles] delete names=(%s)', names) for name in names: obj_devprof = objects.DeviceProfile.get_by_name(context, name) pecan.request.conductor_api.device_profile_delete( - context, obj_devprof) + context, obj_devprof + ) diff --git a/cyborg/api/controllers/v2/devices.py b/cyborg/api/controllers/v2/devices.py index a4317044..295623c5 100644 --- a/cyborg/api/controllers/v2/devices.py +++ b/cyborg/api/controllers/v2/devices.py @@ -41,6 +41,7 @@ class Device(base.APIBase): This class enforces type checking and value constraints, and converts between the internal object model and the API representation. """ + uuid = types.uuid """The UUID of the device""" @@ -82,9 +83,10 @@ class Device(base.APIBase): if not api.request.version.minor >= versions.MINOR_3_DEVICE_STATUS: delattr(api_device, 'status') api_device.links = [ - link.Link.make_link('self', pecan.request.public_url, - 'devices', api_device.uuid) - ] + link.Link.make_link( + 'self', pecan.request.public_url, 'devices', api_device.uuid + ) + ] return api_device @@ -97,8 +99,9 @@ class DeviceCollection(base.APIBase): @classmethod def convert_with_links(cls, devices): collection = cls() - collection.devices = [Device.convert_with_links(device) - for device in devices] + collection.devices = [ + Device.convert_with_links(device) for device in devices + ] return collection @@ -118,8 +121,13 @@ class DevicesController(base.CyborgController): return Device.convert_with_links(device) @authorize_wsgi.authorize_wsgi("cyborg:device", "get_all", False) - @expose.expose(DeviceCollection, wtypes.text, wtypes.text, wtypes.text, - wtypes.ArrayType(types.FilterType)) + @expose.expose( + DeviceCollection, + wtypes.text, + wtypes.text, + wtypes.text, + wtypes.ArrayType(types.FilterType), + ) def get_all(self, type=None, vendor=None, hostname=None, filters=None): """Retrieve a list of devices. :param type: type of a device. @@ -144,8 +152,7 @@ class DevicesController(base.CyborgController): return DeviceCollection.convert_with_links(obj_devices) @authorize_wsgi.authorize_wsgi("cyborg:device", "disable") - @expose.expose(None, wtypes.text, types.uuid, - status_code=HTTPStatus.OK) + @expose.expose(None, wtypes.text, types.uuid, status_code=HTTPStatus.OK) def disable(self, uuid): context = pecan.request.context device = objects.Device.get(context, uuid) @@ -161,15 +168,17 @@ class DevicesController(base.CyborgController): else: raise exception.ResourceNotFound( resource='Attribute', - msg='with deployable_id=%s,key=%s' % (deployable.id, 'rc')) + msg='with deployable_id=%s,key=%s' % (deployable.id, 'rc'), + ) client.update_rp_inventory_reserved( - deployable.rp_uuid, att_type, + deployable.rp_uuid, + att_type, deployable.num_accelerators, - deployable.num_accelerators) + deployable.num_accelerators, + ) @authorize_wsgi.authorize_wsgi("cyborg:device", "enable") - @expose.expose(None, wtypes.text, types.uuid, - status_code=HTTPStatus.OK) + @expose.expose(None, wtypes.text, types.uuid, status_code=HTTPStatus.OK) def enable(self, uuid): context = pecan.request.context device = objects.Device.get(context, uuid) @@ -185,8 +194,8 @@ class DevicesController(base.CyborgController): else: raise exception.ResourceNotFound( resource='Attribute', - msg='with deployable_id=%s,key=%s' % (deployable.id, 'rc')) + msg='with deployable_id=%s,key=%s' % (deployable.id, 'rc'), + ) client.update_rp_inventory_reserved( - deployable.rp_uuid, att_type, - deployable.num_accelerators, - 0) + deployable.rp_uuid, att_type, deployable.num_accelerators, 0 + ) diff --git a/cyborg/api/hooks.py b/cyborg/api/hooks.py index 2369b329..f4ba50f7 100644 --- a/cyborg/api/hooks.py +++ b/cyborg/api/hooks.py @@ -38,7 +38,8 @@ class PublicUrlHook(hooks.PecanHook): def before(self, state): state.request.public_url = ( - cfg.CONF.api.public_endpoint or state.request.host_url) + cfg.CONF.api.public_endpoint or state.request.host_url + ) class ConductorAPIHook(hooks.PecanHook): @@ -92,10 +93,11 @@ class ContextHook(hooks.PecanHook): user_auth_plugin = req.environ.get('keystone.token_auth') roles = req.headers.get('X-Roles', '').split(',') - is_admin = ('admin' in roles or 'administrator' in roles) + is_admin = 'admin' in roles or 'administrator' in roles state.request.context = context.RequestContext.from_environ( req.environ, user_auth_plugin=user_auth_plugin, is_admin=is_admin, - service_catalog=service_catalog) + service_catalog=service_catalog, + ) diff --git a/cyborg/api/middleware/__init__.py b/cyborg/api/middleware/__init__.py index 95cc7401..bfbf9a42 100644 --- a/cyborg/api/middleware/__init__.py +++ b/cyborg/api/middleware/__init__.py @@ -20,5 +20,4 @@ from cyborg.api.middleware import parsable_error ParsableErrorMiddleware = parsable_error.ParsableErrorMiddleware AuthTokenMiddleware = auth_token.AuthTokenMiddleware -__all__ = ('ParsableErrorMiddleware', - 'AuthTokenMiddleware') +__all__ = ('ParsableErrorMiddleware', 'AuthTokenMiddleware') diff --git a/cyborg/api/middleware/auth_token.py b/cyborg/api/middleware/auth_token.py index 7e6186b3..45999cfb 100644 --- a/cyborg/api/middleware/auth_token.py +++ b/cyborg/api/middleware/auth_token.py @@ -40,8 +40,10 @@ class AuthTokenMiddleware(auth_token.AuthProtocol): route_pattern_tpl = r'%s(\.json)?$' try: - self.public_api_routes = [re.compile(route_pattern_tpl % route_tpl) - for route_tpl in public_api_routes] + self.public_api_routes = [ + re.compile(route_pattern_tpl % route_tpl) + for route_tpl in public_api_routes + ] except re.error as e: msg = _('Cannot compile public API routes: %s') % e @@ -56,8 +58,11 @@ class AuthTokenMiddleware(auth_token.AuthProtocol): # The information whether the API call is being performed against the # public API is required for some other components. Saving it to the # WSGI environment is reasonable thereby. - env['is_public_api'] = any(map(lambda pattern: re.match(pattern, path), - self.public_api_routes)) + env['is_public_api'] = any( + map( + lambda pattern: re.match(pattern, path), self.public_api_routes + ) + ) if env['is_public_api']: return self.app(env, start_response) diff --git a/cyborg/api/middleware/parsable_error.py b/cyborg/api/middleware/parsable_error.py index 1aa11fad..954c06cd 100644 --- a/cyborg/api/middleware/parsable_error.py +++ b/cyborg/api/middleware/parsable_error.py @@ -42,15 +42,18 @@ class ParsableErrorMiddleware: except (ValueError, TypeError): # pragma: nocover raise Exception( 'ParsableErrorMiddleware received an invalid ' - 'status %s' % status) + 'status %s' % status + ) if (state['status_code'] // 100) not in (2, 3): # Remove some headers so we can replace them later # when we have the full error message and can # compute the length. headers = [ - (h, v) for (h, v) in headers - if h not in ('Content-Length', 'Content-Type')] + (h, v) + for (h, v) in headers + if h not in ('Content-Length', 'Content-Type') + ] # Save the headers in case we need to modify them. state['headers'] = headers diff --git a/cyborg/cmd/agent.py b/cyborg/cmd/agent.py index 2358afe5..3ea3eccd 100644 --- a/cyborg/cmd/agent.py +++ b/cyborg/cmd/agent.py @@ -31,9 +31,9 @@ def main(): cyborg_service.prepare_service(sys.argv) priv_context.init(root_helper=shlex.split('sudo')) - mgr = cyborg_service.RPCService('cyborg.agent.manager', - 'AgentManager', - constants.AGENT_TOPIC) + mgr = cyborg_service.RPCService( + 'cyborg.agent.manager', 'AgentManager', constants.AGENT_TOPIC + ) launcher = service.launch(CONF, mgr, restart_method='mutate') launcher.wait() diff --git a/cyborg/cmd/conductor.py b/cyborg/cmd/conductor.py index 98aaff57..7b6b02d9 100644 --- a/cyborg/cmd/conductor.py +++ b/cyborg/cmd/conductor.py @@ -31,9 +31,11 @@ def main(): # Parse config file and command line options, then start logging cyborg_service.prepare_service(sys.argv) - mgr = cyborg_service.RPCService('cyborg.conductor.manager', - 'ConductorManager', - constants.CONDUCTOR_TOPIC) + mgr = cyborg_service.RPCService( + 'cyborg.conductor.manager', + 'ConductorManager', + constants.CONDUCTOR_TOPIC, + ) launcher = service.launch(CONF, mgr, restart_method='mutate') launcher.wait() diff --git a/cyborg/cmd/dbsync.py b/cyborg/cmd/dbsync.py index 0bf7831c..7b6b3f79 100644 --- a/cyborg/cmd/dbsync.py +++ b/cyborg/cmd/dbsync.py @@ -28,7 +28,6 @@ from cyborg.db import migration class DBCommand: - def upgrade(self): migration.upgrade(CONF.command.revision) @@ -50,16 +49,22 @@ def add_command_parsers(subparsers): parser = subparsers.add_parser( 'upgrade', - help=_("Upgrade the database schema to the latest version. " - "Optionally, use --revision to specify an alembic revision " - "string to upgrade to.")) + help=_( + "Upgrade the database schema to the latest version. " + "Optionally, use --revision to specify an alembic revision " + "string to upgrade to." + ), + ) parser.set_defaults(func=command_object.upgrade) parser.add_argument('--revision', nargs='?') parser = subparsers.add_parser( 'revision', - help=_("Create a new alembic revision. " - "Use --message to set the message string.")) + help=_( + "Create a new alembic revision. " + "Use --message to set the message string." + ), + ) parser.set_defaults(func=command_object.revision) parser.add_argument('-m', '--message') parser.add_argument('--autogenerate', action='store_true') @@ -69,21 +74,23 @@ def add_command_parsers(subparsers): parser.add_argument('--revision', nargs='?') parser = subparsers.add_parser( - 'version', - help=_("Print the current version information and exit.")) + 'version', help=_("Print the current version information and exit.") + ) parser.set_defaults(func=command_object.version) parser = subparsers.add_parser( - 'create_schema', - help=_("Create the database schema.")) + 'create_schema', help=_("Create the database schema.") + ) parser.set_defaults(func=command_object.create_schema) def main(): - command_opt = cfg.SubCommandOpt('command', - title='Command', - help=_('Available commands'), - handler=add_command_parsers) + command_opt = cfg.SubCommandOpt( + 'command', + title='Command', + help=_('Available commands'), + handler=add_command_parsers, + ) CONF.register_cli_opt(command_opt) diff --git a/cyborg/cmd/status.py b/cyborg/cmd/status.py index 63925b7a..1a0ba993 100644 --- a/cyborg/cmd/status.py +++ b/cyborg/cmd/status.py @@ -25,21 +25,22 @@ CONF = cfg.CONF class Checks(upgradecheck.UpgradeCommands): - """Various upgrade checks should be added as separate methods in this class and added to _upgrade_checks tuple. """ def _check_policy_json(self): "Checks to see if policy file is JSON-formatted policy file." - msg = _("Your policy file is JSON-formatted which is " - "deprecated since Victoria release (Cyborg 5.0.0). " - "You need to switch to YAML-formatted file. You can use the " - "``oslopolicy-convert-json-to-yaml`` tool to convert existing " - "JSON-formatted files to YAML-formatted files in a " - "backwards-compatible manner: " - "https://docs.openstack.org/oslo.policy/" - "latest/cli/oslopolicy-convert-json-to-yaml.html.") + msg = _( + "Your policy file is JSON-formatted which is " + "deprecated since Victoria release (Cyborg 5.0.0). " + "You need to switch to YAML-formatted file. You can use the " + "``oslopolicy-convert-json-to-yaml`` tool to convert existing " + "JSON-formatted files to YAML-formatted files in a " + "backwards-compatible manner: " + "https://docs.openstack.org/oslo.policy/" + "latest/cli/oslopolicy-convert-json-to-yaml.html." + ) status = upgradecheck.Result(upgradecheck.Code.SUCCESS) # NOTE(gmann): Check if policy file exist and is in # JSON format by actually loading the file not just @@ -64,7 +65,8 @@ class Checks(upgradecheck.UpgradeCommands): def main(): return upgradecheck.main( - cfg.CONF, project='cyborg', upgrade_command=Checks()) + cfg.CONF, project='cyborg', upgrade_command=Checks() + ) if __name__ == '__main__': diff --git a/cyborg/common/authorize_wsgi.py b/cyborg/common/authorize_wsgi.py index e6b5ee51..95aa20a0 100644 --- a/cyborg/common/authorize_wsgi.py +++ b/cyborg/common/authorize_wsgi.py @@ -10,6 +10,7 @@ # License for the specific language governing permissions and limitations # under the License. """Policy Authorize Engine For Cyborg.""" + import functools import sys @@ -31,9 +32,13 @@ LOG = log.getLogger(__name__) @lockutils.synchronized('policy_enforcer', 'cyborg-') -def init_enforcer(policy_file=None, rules=None, - default_rule=None, use_conf=True, - suppress_deprecation_warnings=False): +def init_enforcer( + policy_file=None, + rules=None, + default_rule=None, + use_conf=True, + suppress_deprecation_warnings=False, +): """Synchronously initializes the policy enforcer :param policy_file: Custom policy file to use, if none is specified, `CONF.oslo_policy.policy_file` will be used. @@ -60,7 +65,8 @@ def init_enforcer(policy_file=None, rules=None, policy_file=policy_file, rules=rules, default_rule=default_rule, - use_conf=use_conf) + use_conf=use_conf, + ) if suppress_deprecation_warnings: _ENFORCER.suppress_deprecation_warnings = True _ENFORCER.register_defaults(policies.list_policies()) @@ -87,8 +93,9 @@ def authorize(rule, target, creds, do_raise=False, *args, **kwargs): """ enforcer = get_enforcer() try: - return enforcer.authorize(rule, target, creds, do_raise=do_raise, - *args, **kwargs) + return enforcer.authorize( + rule, target, creds, do_raise=do_raise, *args, **kwargs + ) except policy.PolicyNotAuthorized: raise exception.HTTPForbidden(resource=rule) @@ -112,6 +119,7 @@ def authorize_wsgi(api_name, act=None, need_target=True): def post(self, values): ... """ + def wrapper(fn): action = '%s:%s' % (api_name, act or fn.__name__) @@ -124,8 +132,7 @@ def authorize_wsgi(api_name, act=None, need_target=True): orig_code = getattr(orig_exception, 'code', None) pecan.response.status = orig_code or resp_status data = wsme.api.format_exception( - exception_info, - pecan.conf.get('wsme', {}).get('debug', False) + exception_info, pecan.conf.get('wsme', {}).get('debug', False) ) del exception_info return data @@ -145,10 +152,13 @@ def authorize_wsgi(api_name, act=None, need_target=True): # just support object, other type will just keep target as # empty, then follow authorize method will fail and throw # an exception - if isinstance(resource, - object_base.VersionedObjectDictCompat): - target = {'project_id': resource.project_id, - 'user_id': resource.user_id} + if isinstance( + resource, object_base.VersionedObjectDictCompat + ): + target = { + 'project_id': resource.project_id, + 'user_id': resource.user_id, + } except Exception: return return_error(500) elif need_target: @@ -159,8 +169,10 @@ def authorize_wsgi(api_name, act=None, need_target=True): else: # for create method, before resource exsites, we can check the # the credentials with itself. - target = {'project_id': context.project_id, - 'user_id': context.user_id} + target = { + 'project_id': context.project_id, + 'user_id': context.user_id, + } try: authorize(action, target, credentials, do_raise=True) diff --git a/cyborg/common/config.py b/cyborg/common/config.py index 954e36d9..9663da3d 100644 --- a/cyborg/common/config.py +++ b/cyborg/common/config.py @@ -22,8 +22,10 @@ from cyborg import version def parse_args(argv, default_config_files=None): rpc.set_defaults(control_exchange='cyborg') version_string = version.version_info.release_string() - cfg.CONF(argv[1:], - project='cyborg', - version=version_string, - default_config_files=default_config_files) + cfg.CONF( + argv[1:], + project='cyborg', + version=version_string, + default_config_files=default_config_files, + ) rpc.init(cfg.CONF) diff --git a/cyborg/common/constants.py b/cyborg/common/constants.py index 3c5cd879..60074903 100644 --- a/cyborg/common/constants.py +++ b/cyborg/common/constants.py @@ -24,27 +24,42 @@ DEVICE_NIC = 'NIC' DEVICE_SSD = 'SSD' -ARQ_STATES = (ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, ARQ_UNBOUND, - ARQ_BIND_FAILED, ARQ_UNBIND_FAILED, ARQ_DELETING) = ( - 'Initial', 'BindStarted', 'Bound', 'Unbound', 'BindFailed', 'UnbindFailed', - 'Deleting') +ARQ_STATES = ( + ARQ_INITIAL, + ARQ_BIND_STARTED, + ARQ_BOUND, + ARQ_UNBOUND, + ARQ_BIND_FAILED, + ARQ_UNBIND_FAILED, + ARQ_DELETING, +) = ( + 'Initial', + 'BindStarted', + 'Bound', + 'Unbound', + 'BindFailed', + 'UnbindFailed', + 'Deleting', +) -ARQ_BIND_STAGE = (ARQ_PRE_BIND, ARQ_FINISH_BIND, - ARQ_OUFOF_BIND_FLOW) = ( +ARQ_BIND_STAGE = (ARQ_PRE_BIND, ARQ_FINISH_BIND, ARQ_OUFOF_BIND_FLOW) = ( [ARQ_INITIAL, ARQ_BIND_STARTED], [ARQ_BOUND, ARQ_BIND_FAILED], - [ARQ_UNBOUND, ARQ_DELETING]) + [ARQ_UNBOUND, ARQ_DELETING], +) ARQ_BIND_STATUS = (ARQ_BIND_STATUS_FINISH, ARQ_BIND_STATUS_FAILED) = ( - "completed", "failed") + "completed", + "failed", +) ARQ_BIND_STATES_STATUS_MAP = { ARQ_BOUND: ARQ_BIND_STATUS_FINISH, ARQ_BIND_FAILED: ARQ_BIND_STATUS_FAILED, - ARQ_DELETING: ARQ_BIND_STATUS_FAILED + ARQ_DELETING: ARQ_BIND_STATUS_FAILED, } # TODO(Shaohe): maybe we can use oslo automaton lib @@ -56,14 +71,25 @@ ARQ_STATES_TRANSFORM_MATRIX = { ARQ_BOUND: [ARQ_BIND_STARTED], ARQ_UNBOUND: [ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, ARQ_BIND_FAILED], ARQ_BIND_FAILED: [ARQ_BIND_STARTED, ARQ_BOUND], - ARQ_DELETING: [ARQ_INITIAL, ARQ_BIND_STARTED, ARQ_BOUND, - ARQ_UNBOUND, ARQ_BIND_FAILED] + ARQ_DELETING: [ + ARQ_INITIAL, + ARQ_BIND_STARTED, + ARQ_BOUND, + ARQ_UNBOUND, + ARQ_BIND_FAILED, + ], } # Device type -DEVICE_TYPE = (DEVICE_GPU, DEVICE_FPGA, DEVICE_AICHIP, DEVICE_QAT, DEVICE_NIC, - DEVICE_SSD) +DEVICE_TYPE = ( + DEVICE_GPU, + DEVICE_FPGA, + DEVICE_AICHIP, + DEVICE_QAT, + DEVICE_NIC, + DEVICE_SSD, +) # Device type @@ -73,11 +99,14 @@ DEVICE_STATUS = ("enabled", "maintaining") # Attach handle type # 'TEST_PCI': used by fake driver, ignored by Nova virt driver. ATTACH_HANDLE_TYPES = (AH_TYPE_PCI, AH_TYPE_MDEV, AH_TYPE_TEST_PCI) = ( - "PCI", "MDEV", "TEST_PCI") + "PCI", + "MDEV", + "TEST_PCI", +) # Control Path ID type -CPID_TYPE = (CPID_TYPE_PCI) = ("PCI") +CPID_TYPE = CPID_TYPE_PCI = "PCI" # Resource Class @@ -92,26 +121,24 @@ RESOURCES = { } -ACCEL_SPECS = ( - ACCEL_BITSTREAM_ID, - ACCEL_FUNCTION_ID -) = ( +ACCEL_SPECS = (ACCEL_BITSTREAM_ID, ACCEL_FUNCTION_ID) = ( "accel:bitstream_id", - "accel:function_id" + "accel:function_id", ) -SUPPORT_RESOURCES = ( - FPGA, GPU, VGPU, PGPU, QAT, NIC, SSD) = ( - "FPGA", "GPU", "VGPU", "PGPU", "CUSTOM_QAT", "CUSTOM_NIC", "CUSTOM_SSD" +SUPPORT_RESOURCES = (FPGA, GPU, VGPU, PGPU, QAT, NIC, SSD) = ( + "FPGA", + "GPU", + "VGPU", + "PGPU", + "CUSTOM_QAT", + "CUSTOM_NIC", + "CUSTOM_SSD", ) -FPGA_TRAITS = ( - FPGA_FUNCTION_ID, -) = ( - "CUSTOM_FPGA_FUNCTION_ID", -) +FPGA_TRAITS = (FPGA_FUNCTION_ID,) = ("CUSTOM_FPGA_FUNCTION_ID",) RESOURCES_PREFIX = "resources:" diff --git a/cyborg/common/exception.py b/cyborg/common/exception.py index 6358ba92..46795156 100644 --- a/cyborg/common/exception.py +++ b/cyborg/common/exception.py @@ -39,6 +39,7 @@ class CyborgException(Exception): If you need to access the message from an exception you should use str(exc). """ + _msg_fmt = _("An unknown exception occurred.") code = HTTPStatus.INTERNAL_SERVER_ERROR headers = {} @@ -61,8 +62,9 @@ class CyborgException(Exception): # log the issue and the kwargs LOG.exception('Exception in string format operation') for name, value in kwargs.items(): - LOG.error("%(name)s: %(value)s", - {"name": name, "value": value}) + LOG.error( + "%(name)s: %(value)s", {"name": name, "value": value} + ) if CONF.fatal_exception_format_errors: raise @@ -84,8 +86,10 @@ class Forbidden(CyborgException): class ARQBadState(CyborgException): - _msg_fmt = _('Bad state: %(state)s for ARQ: %(uuid)s. ' - 'Expected state(s): %(expected)s') + _msg_fmt = _( + 'Bad state: %(state)s for ARQ: %(uuid)s. ' + 'Expected state(s): %(expected)s' + ) class AttachHandleAlreadyExists(CyborgException): @@ -129,8 +133,7 @@ class ExtArqAlreadyExists(CyborgException): class ExpectedOneObject(CyborgException): - _msg_fmt = _("Expected one object of type %(obj)s " - "but got %(count)s.") + _msg_fmt = _("Expected one object of type %(obj)s but got %(count)s.") class InUse(CyborgException): @@ -155,14 +158,17 @@ class InvalidJsonType(Invalid): class InvalidAPIVersionString(Invalid): - _msg_fmt = _("API Version String %(version)s is of invalid format. Must " - "be of format MajorNum.MinorNum.") + _msg_fmt = _( + "API Version String %(version)s is of invalid format. Must " + "be of format MajorNum.MinorNum." + ) # TODO(All): Consider whether Placement/Image exceptions can be included here. class InvalidAPIResponse(Invalid): - _msg_fmt = _('Bad API response from %(service)s for %(api)s API. ' - 'Details: %(msg)s') + _msg_fmt = _( + 'Bad API response from %(service)s for %(api)s API. Details: %(msg)s' + ) # Cannot be templated as the error syntax varies. @@ -203,8 +209,9 @@ class ServiceNotFound(NotFound): class ConfGroupForServiceTypeNotFound(ServiceNotFound): - _msg_fmt = _("No conf group name could be found for service type " - "%(stype)s.") + _msg_fmt = _( + "No conf group name could be found for service type %(stype)s." + ) class InvalidDeployType(CyborgException): @@ -233,18 +240,23 @@ class PlacementEndpointNotFound(NotFound): class PlacementResourceProviderNotFound(NotFound): - _msg_fmt = _("Placement resource provider not found: " - "%(resource_provider)s.") + _msg_fmt = _( + "Placement resource provider not found: %(resource_provider)s." + ) class PlacementInventoryNotFound(NotFound): - _msg_fmt = _("Placement inventory not found for resource provider " - "%(resource_provider)s, resource class %(resource_class)s.") + _msg_fmt = _( + "Placement inventory not found for resource provider " + "%(resource_provider)s, resource class %(resource_class)s." + ) class PlacementInventoryUpdateConflict(Conflict): - _msg_fmt = _("Placement inventory update conflict for resource provider " - "%(resource_provider)s, resource class %(resource_class)s.") + _msg_fmt = _( + "Placement inventory update conflict for resource provider " + "%(resource_provider)s, resource class %(resource_class)s." + ) class ObjectActionError(CyborgException): @@ -274,13 +286,15 @@ class ResourceProviderRetrievalFailed(CyborgException): class ResourceProviderAggregateRetrievalFailed(CyborgException): - _msg_fmt = _("Failed to get aggregates for resource provider with UUID" - " %(uuid)s") + _msg_fmt = _( + "Failed to get aggregates for resource provider with UUID %(uuid)s" + ) class ResourceProviderTraitRetrievalFailed(CyborgException): - _msg_fmt = _("Failed to get traits for resource provider with UUID" - " %(uuid)s") + _msg_fmt = _( + "Failed to get traits for resource provider with UUID %(uuid)s" + ) class ResourceProviderCreationFailed(CyborgException): @@ -292,13 +306,16 @@ class ResourceProviderDeletionFailed(CyborgException): class ResourceProviderUpdateFailed(CyborgException): - _msg_fmt = _("Failed to update resource provider via URL %(url)s: " - "%(error)s") + _msg_fmt = _( + "Failed to update resource provider via URL %(url)s: %(error)s" + ) class ResourceProviderSyncFailed(CyborgException): - _msg_fmt = _("Failed to synchronize the placement service with resource " - "provider information supplied by the compute host.") + _msg_fmt = _( + "Failed to synchronize the placement service with resource " + "provider information supplied by the compute host." + ) class PlacementAPIConnectFailure(CyborgException): @@ -313,16 +330,22 @@ class PlacementAPIConflict(CyborgException): """Any 409 error from placement APIs should use (a subclass of) this exception. """ - _msg_fmt = _("A conflict was encountered attempting to invoke the " - "placement API at URL %(url)s: %(error)s") + + _msg_fmt = _( + "A conflict was encountered attempting to invoke the " + "placement API at URL %(url)s: %(error)s" + ) class ResourceProviderUpdateConflict(PlacementAPIConflict): """A 409 caused by generation mismatch from attempting to update an existing provider record or its associated data (aggregates, traits, etc.). """ - _msg_fmt = _("A conflict was encountered attempting to update resource " - "provider %(uuid)s (generation %(generation)d): %(error)s") + + _msg_fmt = _( + "A conflict was encountered attempting to update resource " + "provider %(uuid)s (generation %(generation)d): %(error)s" + ) class TraitCreationFailed(CyborgException): @@ -342,8 +365,10 @@ class InvalidResourceAmount(Invalid): class InvalidInventory(Invalid): - _msg_fmt = _("Inventory for '%(resource_class)s' on " - "resource provider '%(resource_provider)s' invalid.") + _msg_fmt = _( + "Inventory for '%(resource_class)s' on " + "resource provider '%(resource_provider)s' invalid." + ) # An exception with this name is used on both sides of the placement/ @@ -351,8 +376,10 @@ class InvalidInventory(Invalid): class InventoryInUse(InvalidInventory): # NOTE(mriedem): This message cannot change without impacting the # cyborg.services.client.report._RE_INV_IN_USE regex. - _msg_fmt = _("Inventory for '%(resource_classes)s' on " - "resource provider '%(resource_provider)s' in use.") + _msg_fmt = _( + "Inventory for '%(resource_classes)s' on " + "resource provider '%(resource_provider)s' in use." + ) class QuotaNotFound(NotFound): @@ -372,8 +399,7 @@ class InvalidReservationExpiration(Invalid): class GlanceConnectionFailed(CyborgException): - _msg_fmt = _("Connection to glance host %(server)s failed: " - "%(reason)s") + _msg_fmt = _("Connection to glance host %(server)s failed: %(reason)s") class ImageUnacceptable(Invalid): @@ -385,8 +411,9 @@ class ImageNotAuthorized(CyborgException): class ImageBadRequest(Invalid): - _msg_fmt = _("Request of image %(image_id)s got BadRequest response: " - "%(response)s") + _msg_fmt = _( + "Request of image %(image_id)s got BadRequest response: %(response)s" + ) class InvalidDriver(Invalid): @@ -406,8 +433,7 @@ class PciDeviceWrongAddressFormat(Invalid): class InvalidType(Invalid): - _msg_fmt = _("Invalid type for %(obj)s: %(type)s." - "Expected: %(expected)s") + _msg_fmt = _("Invalid type for %(obj)s: %(type)s.Expected: %(expected)s") class ResourceNotFound(NotFound): @@ -432,5 +458,7 @@ class PciConfigInvalidWhitelist(Invalid): class PciDeviceInvalidDeviceName(CyborgException): - _msg_fmt = _("Invalid PCI whitelist: The PCI whitelist can specify " - "devname or address, but not both.") + _msg_fmt = _( + "Invalid PCI whitelist: The PCI whitelist can specify " + "devname or address, but not both." + ) diff --git a/cyborg/common/nova_client.py b/cyborg/common/nova_client.py index 381b83cd..a1d76238 100644 --- a/cyborg/common/nova_client.py +++ b/cyborg/common/nova_client.py @@ -26,11 +26,15 @@ class NovaAPI: self.nova_client.default_microversion = '2.82' def _get_acc_changed_events(self, instance_uuid, arq_bind_statuses): - return [{'name': 'accelerator-request-bound', - 'server_uuid': instance_uuid, - 'tag': arq_uuid, - 'status': arq_bind_status, - } for (arq_uuid, arq_bind_status) in arq_bind_statuses] + return [ + { + 'name': 'accelerator-request-bound', + 'server_uuid': instance_uuid, + 'tag': arq_uuid, + 'status': arq_bind_status, + } + for (arq_uuid, arq_bind_status) in arq_bind_statuses + ] def _send_events(self, events): """Send events to Nova external events API. @@ -44,8 +48,10 @@ class NovaAPI: # NOTE(Sundar): Response status should always be 200/207. See # https://review.opendev.org/#/c/698037/ if response.status_code == 200: - LOG.info("Successfully sent events to Nova, events: %(events)s", - {"events": events}) + LOG.info( + "Successfully sent events to Nova, events: %(events)s", + {"events": events}, + ) elif response.status_code == 207: # NOTE(Sundar): If Nova returns per-event code of 422, that # is due to a race condition where Nova has not associated @@ -55,30 +61,42 @@ class NovaAPI: event_codes = {ev['code'] for ev in events} if len(event_codes) == 1: # all events have same event code if event_codes == {422}: - LOG.info('Ignoring Nova notification error that the ' - 'instance %s is not yet associated with a host.', - events[0]['server_uuid']) + LOG.info( + 'Ignoring Nova notification error that the ' + 'instance %s is not yet associated with a host.', + events[0]['server_uuid'], + ) else: - msg = _('Unexpected event code %(code)s ' - 'for instance %(inst)s') - msg = msg % {'code': event_codes.pop(), - 'inst': events[0]["server_uuid"]} + msg = _( + 'Unexpected event code %(code)s for instance %(inst)s' + ) + msg = msg % { + 'code': event_codes.pop(), + 'inst': events[0]["server_uuid"], + } raise exception.InvalidAPIResponse( - service='Nova', api=url[1:], msg=msg) + service='Nova', api=url[1:], msg=msg + ) else: - msg = _('All event responses are expected to ' - 'have the same event code. Instance: %(inst)s') + msg = _( + 'All event responses are expected to ' + 'have the same event code. Instance: %(inst)s' + ) msg = msg % {'inst': events[0]['server_uuid']} raise exception.InvalidAPIResponse( - service='Nova', api=url[1:], msg=msg) + service='Nova', api=url[1:], msg=msg + ) else: # Unexpected return code from Nova msg = _('Failed to send events %(ev)s: HTTP %(code)s: %(txt)s') - msg = msg % {'ev': events, - 'code': response.status_code, - 'txt': response.text} + msg = msg % { + 'ev': events, + 'code': response.status_code, + 'txt': response.text, + } raise exception.InvalidAPIResponse( - service='Nova', api=url[1:], msg=msg) + service='Nova', api=url[1:], msg=msg + ) def notify_binding(self, instance_uuid, arq_bind_statuses): """Notify Nova that ARQ bindings are resolved for a given instance. diff --git a/cyborg/common/placement_client.py b/cyborg/common/placement_client.py index e7b40f5b..8245f6ce 100644 --- a/cyborg/common/placement_client.py +++ b/cyborg/common/placement_client.py @@ -34,20 +34,27 @@ class PlacementClient: self._client = utils.get_sdk_adapter('placement') def get(self, url, version=None, global_request_id=None): - res = self._client.get(url, microversion=version, - global_request_id=global_request_id) + res = self._client.get( + url, microversion=version, global_request_id=global_request_id + ) if res and res.status_code >= 500: raise exception.PlacementServerError( - "Placement Server has some error at this time.") + "Placement Server has some error at this time." + ) LOG.debug('Successfully get resources from placement: %s', url) return res def post(self, url, data, version=None, global_request_id=None): - res = self._client.post(url, json=data, microversion=version, - global_request_id=global_request_id) + res = self._client.post( + url, + json=data, + microversion=version, + global_request_id=global_request_id, + ) if res.status_code >= 500: raise exception.PlacementServerError( - "Placement Server has some error at this time.") + "Placement Server has some error at this time." + ) LOG.debug('Successfully create resources from placement: %s', url) return res @@ -55,31 +62,37 @@ class PlacementClient: kwargs = {} if data is not None: kwargs['json'] = data - res = self._client.put(url, microversion=version, - global_request_id=global_request_id, - **kwargs) + res = self._client.put( + url, + microversion=version, + global_request_id=global_request_id, + **kwargs, + ) if res.status_code >= 500: raise exception.PlacementServerError( - "Placement Server has some error at this time.") + "Placement Server has some error at this time." + ) LOG.debug('Successfully update resources from placement: %s', url) return res def delete(self, url, version=None, global_request_id=None): - res = self._client.delete(url, microversion=version, - global_request_id=global_request_id) + res = self._client.delete( + url, microversion=version, global_request_id=global_request_id + ) if res.status_code >= 500: raise exception.PlacementServerError( - "Placement Server has some error at this time.") + "Placement Server has some error at this time." + ) LOG.debug('Successfully delete resources from placement: %s', url) return res def _get_rp_traits(self, rp_uuid): - resp = self.get(f"/resource_providers/{rp_uuid}/traits", - version='1.6') + resp = self.get(f"/resource_providers/{rp_uuid}/traits", version='1.6') if resp.status_code != 200: raise Exception( f"Failed to get traits for rp {rp_uuid}:" - f" HTTP {resp.status_code}: {resp.text}") + f" HTTP {resp.status_code}: {resp.text}" + ) return resp.json() def _ensure_traits(self, trait_names): @@ -88,8 +101,9 @@ class PlacementClient: for trait_name in trait_names: trait = self.get(f"/traits/{trait_name}", version='1.6') if trait: - LOG.info("Trait %(trait)s already existed", - {"trait": trait_name}) + LOG.info( + "Trait %(trait)s already existed", {"trait": trait_name} + ) continue resp = self.put(f"/traits/{trait_name}", None, version='1.6') if resp.status_code == 201: @@ -97,21 +111,25 @@ class PlacementClient: else: raise Exception( f"Failed to create trait {trait_name}:" - f" HTTP {resp.status_code}: {resp.text}") + f" HTTP {resp.status_code}: {resp.text}" + ) def _put_rp_traits(self, rp_uuid, traits_json): generation = self.get_resource_provider( - resource_provider_uuid=rp_uuid)['generation'] + resource_provider_uuid=rp_uuid + )['generation'] payload = { 'resource_provider_generation': generation, 'traits': traits_json["traits"], } resp = self.put( - f"/resource_providers/{rp_uuid}/traits", payload, version='1.6') + f"/resource_providers/{rp_uuid}/traits", payload, version='1.6' + ) if resp.status_code != 200: raise Exception( f"Failed to set traits to {traits_json} for rp {rp_uuid}:" - f" HTTP {resp.status_code}: {resp.text}") + f" HTTP {resp.status_code}: {resp.text}" + ) def add_traits_to_rp(self, rp_uuid, trait_names): self._ensure_traits(trait_names) @@ -123,9 +141,8 @@ class PlacementClient: def delete_trait_by_name(self, context, rp_uuid, trait_name): traits_json = self._get_rp_traits(rp_uuid) traits = [ - trait for trait in traits_json['traits'] - if trait != trait_name - ] + trait for trait in traits_json['traits'] if trait != trait_name + ] traits_json['traits'] = traits self._put_rp_traits(rp_uuid, traits_json) self._delete_trait(context, trait_name) @@ -133,9 +150,10 @@ class PlacementClient: def delete_traits_with_prefixes(self, context, rp_uuid, trait_prefixes): traits_json = self._get_rp_traits(rp_uuid) traits = [ - trait for trait in traits_json['traits'] - if not any(trait.startswith(prefix) - for prefix in trait_prefixes)] + trait + for trait in traits_json['traits'] + if not any(trait.startswith(prefix) for prefix in trait_prefixes) + ] delete_traits = set(traits_json['traits']) - set(traits) traits_json['traits'] = traits self._put_rp_traits(rp_uuid, traits_json) @@ -147,21 +165,27 @@ class PlacementClient: return response.headers.get(request_id.HTTP_RESP_HEADER_REQUEST_ID) def update_inventory( - self, resource_provider_uuid, inventories, - resource_provider_generation=None, version=None): + self, + resource_provider_uuid, + inventories, + resource_provider_generation=None, + version=None, + ): if resource_provider_generation is None: resource_provider_generation = self.get_resource_provider( - resource_provider_uuid=resource_provider_uuid)['generation'] + resource_provider_uuid=resource_provider_uuid + )['generation'] url = f'/resource_providers/{resource_provider_uuid}/inventories' body = { 'resource_provider_generation': resource_provider_generation, - 'inventories': inventories + 'inventories': inventories, } try: return self.put(url, body, version=version).json() except ks_exc.NotFound: raise exception.PlacementResourceProviderNotFound( - resource_provider=resource_provider_uuid) + resource_provider=resource_provider_uuid + ) def get_resource_provider(self, resource_provider_uuid): """Get resource provider by UUID. @@ -175,10 +199,12 @@ class PlacementClient: return self.get(url).json() except ks_exc.NotFound: raise exception.PlacementResourceProviderNotFound( - resource_provider=resource_provider_uuid) + resource_provider=resource_provider_uuid + ) - def _create_resource_provider(self, context, uuid, name, - parent_provider_uuid=None): + def _create_resource_provider( + self, context, uuid, name, parent_provider_uuid=None + ): """Calls the placement API to create a new resource provider record. :param context: The security context @@ -200,16 +226,21 @@ class PlacementClient: # Bug #1746075: First try the microversion that returns the new # provider's payload. - resp = self.post(url, payload, - version=POST_RPS_RETURNS_PAYLOAD_API_VERSION, - global_request_id=context.global_id) + resp = self.post( + url, + payload, + version=POST_RPS_RETURNS_PAYLOAD_API_VERSION, + global_request_id=context.global_id, + ) placement_req_id = self.get_placement_request_id(resp) if resp: - msg = ("[%(placement_req_id)s] Created resource provider record " - "via placement API for resource provider with UUID " - "%(uuid)s and name %(name)s.") + msg = ( + "[%(placement_req_id)s] Created resource provider record " + "via placement API for resource provider with UUID " + "%(uuid)s and name %(name)s." + ) args = { 'uuid': uuid, 'name': name, @@ -218,21 +249,27 @@ class PlacementClient: LOG.info(msg, args) return resp.json() - def ensure_resource_provider(self, context, uuid, name=None, - parent_provider_uuid=None): + def ensure_resource_provider( + self, context, uuid, name=None, parent_provider_uuid=None + ): resp = self.get(f"/resource_providers/{uuid}", version='1.6') if resp.status_code == 200: - LOG.info("Resource Provider %(uuid)s already exists", - {"uuid": uuid}) + LOG.info( + "Resource Provider %(uuid)s already exists", {"uuid": uuid} + ) else: - LOG.info("Creating resource provider %(provider)s", - {"provider": name or uuid}) + LOG.info( + "Creating resource provider %(provider)s", + {"provider": name or uuid}, + ) try: - resp = self._create_resource_provider(context, uuid, name, - parent_provider_uuid) + resp = self._create_resource_provider( + context, uuid, name, parent_provider_uuid + ) except Exception: raise exception.ResourceProviderCreationFailed( - name=name or uuid) + name=name or uuid + ) return uuid def ensure_resource_classes(self, context, names): @@ -245,12 +282,17 @@ class PlacementClient: if name in orc.STANDARDS: return resp = self.put( - f"/resource_classes/{name}", None, version=version, - global_request_id=context.global_id) + f"/resource_classes/{name}", + None, + version=version, + global_request_id=context.global_id, + ) if not resp: - msg = ("Failed to ensure resource class record with placement " - "API for resource class %(rc_name)s. Got " - "%(status_code)d: %(err_text)s.") + msg = ( + "Failed to ensure resource class record with placement " + "API for resource class %(rc_name)s. Got " + "%(status_code)d: %(err_text)s." + ) args = { 'rc_name': name, 'status_code': resp.status_code, @@ -259,11 +301,15 @@ class PlacementClient: LOG.error(msg, args) raise exception.InvalidResourceClass(resource_class=name) elif resp.status_code == 204: - LOG.info("Resource class %(rc_name)s already exists", - {"rc_name": name}) + LOG.info( + "Resource class %(rc_name)s already exists", + {"rc_name": name}, + ) elif resp.status_code == 201: - LOG.info("Successfully created resource class %(rc_name).", { - "rc_name", name}) + LOG.info( + "Successfully created resource class %(rc_name).", + {"rc_name", name}, + ) def get_providers_in_tree(self, context, uuid): """Queries the placement API for a list of the resource providers in @@ -275,18 +321,22 @@ class PlacementClient: empty if no provider exists with the specified UUID. :raise: ResourceProviderRetrievalFailed on error. """ - resp = self.get(f"/resource_providers?in_tree={uuid}", - version=NESTED_PROVIDER_API_VERSION, - global_request_id=context.global_id) + resp = self.get( + f"/resource_providers?in_tree={uuid}", + version=NESTED_PROVIDER_API_VERSION, + global_request_id=context.global_id, + ) if resp.status_code == 200: return resp.json()['resource_providers'] # Some unexpected error placement_req_id = self.get_placement_request_id(resp) - msg = ("[%(placement_req_id)s] Failed to retrieve resource provider " - "tree from placement API for UUID %(uuid)s. Got " - "%(status_code)d: %(err_text)s.") + msg = ( + "[%(placement_req_id)s] Failed to retrieve resource provider " + "tree from placement API for UUID %(uuid)s. Got " + "%(status_code)d: %(err_text)s." + ) args = { 'uuid': uuid, 'status_code': resp.status_code, @@ -297,22 +347,26 @@ class PlacementClient: raise exception.ResourceProviderRetrievalFailed(uuid=uuid) def delete_provider(self, rp_uuid, global_request_id=None): - resp = self.delete(f'/resource_providers/{rp_uuid}', - global_request_id=global_request_id) + resp = self.delete( + f'/resource_providers/{rp_uuid}', + global_request_id=global_request_id, + ) # Check for 404 since we don't need to warn/raise if we tried to delete # something which doesn"t actually exist. if resp.ok: LOG.info("Deleted resource provider %s", rp_uuid) return - msg = ("[%(placement_req_id)s] Failed to delete resource provider " - "with UUID %(uuid)s from the placement API. Got " - "%(status_code)d: %(err_text)s.") + msg = ( + "[%(placement_req_id)s] Failed to delete resource provider " + "with UUID %(uuid)s from the placement API. Got " + "%(status_code)d: %(err_text)s." + ) args = { 'placement_req_id': self.get_placement_request_id(resp), 'uuid': rp_uuid, 'status_code': resp.status_code, - 'err_text': resp.text + 'err_text': resp.text, } LOG.error(msg, args) # On conflict, the caller may wish to delete allocations and @@ -325,11 +379,14 @@ class PlacementClient: def delete_rc_by_name(self, context, name): """Delete resource class from placement by name.""" resp = self.delete( - f"/resource_classes/{name}", global_request_id=context.global_id) + f"/resource_classes/{name}", global_request_id=context.global_id + ) if not resp: - msg = ("Failed to delete resource class record with placement " - "API for resource class %(rc_name)s. Got " - "%(status_code)d: %(err_text)s.") + msg = ( + "Failed to delete resource class record with placement " + "API for resource class %(rc_name)s. Got " + "%(status_code)d: %(err_text)s." + ) args = { 'rc_name': name, 'status_code': resp.status_code, @@ -337,18 +394,25 @@ class PlacementClient: } LOG.error(msg, args) elif resp.status_code == 204: - LOG.info("Successfully delete resource class %(rc_name).", { - "rc_name", name}) + LOG.info( + "Successfully delete resource class %(rc_name).", + {"rc_name", name}, + ) def _delete_trait(self, context, name): """Delete trait from placement by name.""" version = '1.6' - resp = self.delete(f"/traits/{name}", version=version, - global_request_id=context.global_id) + resp = self.delete( + f"/traits/{name}", + version=version, + global_request_id=context.global_id, + ) if not resp: - msg = ("Failed to delete trait record with placement " - "API for trait %(trait_name)s. Got " - "%(status_code)d: %(err_text)s.") + msg = ( + "Failed to delete trait record with placement " + "API for trait %(trait_name)s. Got " + "%(status_code)d: %(err_text)s." + ) args = { 'trait_name': name, 'status_code': resp.status_code, @@ -356,8 +420,10 @@ class PlacementClient: } LOG.error(msg, args) elif resp.status_code == 204: - LOG.info("Successfully delete trait %(trait_name).", { - "trait_name", name}) + LOG.info( + "Successfully delete trait %(trait_name).", + {"trait_name", name}, + ) def update_rp_inventory_reserved(self, rp_uuid, resource, total, reserved): update_inventory = {resource: {"total": total, "reserved": reserved}} diff --git a/cyborg/common/policy.py b/cyborg/common/policy.py index 746bb115..a27c56d3 100644 --- a/cyborg/common/policy.py +++ b/cyborg/common/policy.py @@ -15,8 +15,9 @@ """legacy old_policies, the following old_policies will be removed once - new policies are implemented. +new policies are implemented. """ + from oslo_policy import policy # NOTE: to follow policy-in-code spec, we define defaults for # the granular policies in code, rather than in policy.yaml. @@ -24,73 +25,103 @@ from oslo_policy import policy # depend on their existence throughout the code. accelerator_request_policies = [ - policy.RuleDefault('cyborg:arq:get_all', - 'rule:default', - description='Retrieve accelerator request records.'), - policy.RuleDefault('cyborg:arq:get_one', - 'rule:default', - description='Get an accelerator request record.'), - policy.RuleDefault('cyborg:arq:create', - 'rule:allow', - description='Create accelerator request records.'), - policy.RuleDefault('cyborg:arq:delete', - 'rule:default', - description='Delete accelerator request records.'), - policy.RuleDefault('cyborg:arq:update', - 'rule:default', - description='Update accelerator request records.'), + policy.RuleDefault( + 'cyborg:arq:get_all', + 'rule:default', + description='Retrieve accelerator request records.', + ), + policy.RuleDefault( + 'cyborg:arq:get_one', + 'rule:default', + description='Get an accelerator request record.', + ), + policy.RuleDefault( + 'cyborg:arq:create', + 'rule:allow', + description='Create accelerator request records.', + ), + policy.RuleDefault( + 'cyborg:arq:delete', + 'rule:default', + description='Delete accelerator request records.', + ), + policy.RuleDefault( + 'cyborg:arq:update', + 'rule:default', + description='Update accelerator request records.', + ), ] device_policies = [ - policy.RuleDefault('cyborg:device:get_one', - 'rule:allow', - description='Show device detail'), - policy.RuleDefault('cyborg:device:get_all', - 'rule:allow', - description='Retrieve all device records'), - policy.RuleDefault('cyborg:device:disable', - 'rule:admin_api', - description='Disable a device'), - policy.RuleDefault('cyborg:device:enable', - 'rule:admin_api', - description='Enable a device'), + policy.RuleDefault( + 'cyborg:device:get_one', 'rule:allow', description='Show device detail' + ), + policy.RuleDefault( + 'cyborg:device:get_all', + 'rule:allow', + description='Retrieve all device records', + ), + policy.RuleDefault( + 'cyborg:device:disable', + 'rule:admin_api', + description='Disable a device', + ), + policy.RuleDefault( + 'cyborg:device:enable', 'rule:admin_api', description='Enable a device' + ), ] deployable_policies = [ - policy.RuleDefault('cyborg:deployable:get_one', - 'rule:allow', - description='Show deployable detail'), - policy.RuleDefault('cyborg:deployable:get_all', - 'rule:allow', - description='Retrieve all deployable records'), - policy.RuleDefault('cyborg:deployable:program', - 'rule:allow', - description='FPGA programming.'), + policy.RuleDefault( + 'cyborg:deployable:get_one', + 'rule:allow', + description='Show deployable detail', + ), + policy.RuleDefault( + 'cyborg:deployable:get_all', + 'rule:allow', + description='Retrieve all deployable records', + ), + policy.RuleDefault( + 'cyborg:deployable:program', + 'rule:allow', + description='FPGA programming.', + ), ] attribute_policies = [ - policy.RuleDefault('cyborg:attribute:get_one', - 'rule:allow', - description='Show attribute detail'), - policy.RuleDefault('cyborg:attribute:get_all', - 'rule:allow', - description='Retrieve all attribute records'), - policy.RuleDefault('cyborg:attribute:create', - 'rule:allow', - description='Create an attribute record'), - policy.RuleDefault('cyborg:attribute:delete', - 'rule:allow', - description='Delete attribute records.'), + policy.RuleDefault( + 'cyborg:attribute:get_one', + 'rule:allow', + description='Show attribute detail', + ), + policy.RuleDefault( + 'cyborg:attribute:get_all', + 'rule:allow', + description='Retrieve all attribute records', + ), + policy.RuleDefault( + 'cyborg:attribute:create', + 'rule:allow', + description='Create an attribute record', + ), + policy.RuleDefault( + 'cyborg:attribute:delete', + 'rule:allow', + description='Delete attribute records.', + ), ] fpga_policies = [ - policy.RuleDefault('cyborg:fpga:get_one', - 'rule:allow', - description='Show fpga detail'), - policy.RuleDefault('cyborg:fpga:get_all', - 'rule:allow', - description='Retrieve all fpga records'), - policy.RuleDefault('cyborg:fpga:update', - 'rule:allow', - description='Update fpga records'), + policy.RuleDefault( + 'cyborg:fpga:get_one', 'rule:allow', description='Show fpga detail' + ), + policy.RuleDefault( + 'cyborg:fpga:get_all', + 'rule:allow', + description='Retrieve all fpga records', + ), + policy.RuleDefault( + 'cyborg:fpga:update', 'rule:allow', description='Update fpga records' + ), ] diff --git a/cyborg/common/rpc.py b/cyborg/common/rpc.py index 61fd6be9..8c5d4e01 100644 --- a/cyborg/common/rpc.py +++ b/cyborg/common/rpc.py @@ -35,15 +35,14 @@ EXTRA_EXMODS = [] def init(conf): global TRANSPORT, NOTIFICATION_TRANSPORT, NOTIFIER exmods = get_allowed_exmods() - TRANSPORT = messaging.get_rpc_transport(conf, - allowed_remote_exmods=exmods) + TRANSPORT = messaging.get_rpc_transport(conf, allowed_remote_exmods=exmods) NOTIFICATION_TRANSPORT = messaging.get_notification_transport( - conf, - allowed_remote_exmods=exmods) + conf, allowed_remote_exmods=exmods + ) serializer = RequestContextSerializer(messaging.JsonPayloadSerializer()) - NOTIFIER = messaging.Notifier(NOTIFICATION_TRANSPORT, - serializer=serializer, - topics=['notifications']) + NOTIFIER = messaging.Notifier( + NOTIFICATION_TRANSPORT, serializer=serializer, topics=['notifications'] + ) def cleanup(): @@ -97,8 +96,8 @@ def get_client(target, version_cap=None, serializer=None): assert TRANSPORT is not None serializer = RequestContextSerializer(serializer) return messaging.get_rpc_client( - TRANSPORT, target, version_cap=version_cap, - serializer=serializer) + TRANSPORT, target, version_cap=version_cap, serializer=serializer + ) def get_server(target, endpoints, serializer=None): @@ -106,8 +105,12 @@ def get_server(target, endpoints, serializer=None): access_policy = dispatcher.DefaultRPCAccessPolicy serializer = RequestContextSerializer(serializer) return messaging.get_rpc_server( - TRANSPORT, target, endpoints, - serializer=serializer, access_policy=access_policy) + TRANSPORT, + target, + endpoints, + serializer=serializer, + access_policy=access_policy, + ) def get_notifier(service=None, host=None, publisher_id=None): diff --git a/cyborg/common/service.py b/cyborg/common/service.py index 85bee9a4..30bc693f 100644 --- a/cyborg/common/service.py +++ b/cyborg/common/service.py @@ -58,24 +58,30 @@ class RPCService(service.Service): self.tg.add_dynamic_timer_args( self.manager.periodic_tasks, kwargs={"context": admin_context}, - periodic_interval_max=CONF.periodic_interval) + periodic_interval_max=CONF.periodic_interval, + ) - LOG.info('Created RPC server for service %(service)s on host ' - '%(host)s.', - {'service': self.topic, 'host': self.host}) + LOG.info( + 'Created RPC server for service %(service)s on host %(host)s.', + {'service': self.topic, 'host': self.host}, + ) def stop(self, graceful=True): try: self.rpcserver.stop() self.rpcserver.wait() except Exception as e: - LOG.exception('Service error occurred when stopping the ' - 'RPC server. Error: %s', e) + LOG.exception( + 'Service error occurred when stopping the ' + 'RPC server. Error: %s', + e, + ) super().stop(graceful=graceful) - LOG.info('Stopped RPC server for service %(service)s on host ' - '%(host)s.', - {'service': self.topic, 'host': self.host}) + LOG.info( + 'Stopped RPC server for service %(service)s on host %(host)s.', + {'service': self.topic, 'host': self.host}, + ) def prepare_service(argv=None): @@ -105,17 +111,24 @@ class WSGIService(service.ServiceBase): """ self.name = name self.app = app.load_app() - self.workers = (CONF.api.api_workers or - processutils.get_worker_count()) + self.workers = CONF.api.api_workers or processutils.get_worker_count() if self.workers and self.workers < 1: raise exception.ConfigInvalid( - _("api_workers value of %d is invalid, " - "must be greater than 0.") % self.workers) + _( + "api_workers value of %d is invalid, " + "must be greater than 0." + ) + % self.workers + ) - self.server = wsgi.Server(CONF, self.name, self.app, - host=CONF.api.host_ip, - port=CONF.api.port, - use_ssl=use_ssl) + self.server = wsgi.Server( + CONF, + self.name, + self.app, + host=CONF.api.host_ip, + port=CONF.api.port, + use_ssl=use_ssl, + ) def start(self): """Start serving this service using loaded configuration. diff --git a/cyborg/common/utils.py b/cyborg/common/utils.py index 39613cf6..ffa7ba26 100644 --- a/cyborg/common/utils.py +++ b/cyborg/common/utils.py @@ -49,16 +49,24 @@ def safe_rstrip(value, chars=None): """ if not isinstance(value, str): - LOG.warning("Failed to remove trailing character. Returning " - "original object. Supplied object is not a string: " - "%s,", value) + LOG.warning( + "Failed to remove trailing character. Returning " + "original object. Supplied object is not a string: " + "%s,", + value, + ) return value return value.rstrip(chars) or value -def get_ksa_adapter(service_type, ksa_auth=None, ksa_session=None, - min_version=None, max_version=None): +def get_ksa_adapter( + service_type, + ksa_auth=None, + ksa_session=None, + min_version=None, + max_version=None, +): """Construct a keystoneauth1 Adapter for a given service type. We expect to find a conf group whose name corresponds to the service_type's @@ -109,11 +117,17 @@ def get_ksa_adapter(service_type, ksa_auth=None, ksa_session=None, if not ksa_session: ksa_session = ks_loading.load_session_from_conf_options( - CONF, confgrp, auth=ksa_auth) + CONF, confgrp, auth=ksa_auth + ) return ks_loading.load_adapter_from_conf_options( - CONF, confgrp, session=ksa_session, auth=ksa_auth, - min_version=min_version, max_version=max_version) + CONF, + confgrp, + session=ksa_session, + auth=ksa_auth, + min_version=min_version, + max_version=max_version, + ) def _get_conf_group(service_type): @@ -127,7 +141,8 @@ def _get_conf_group(service_type): def _get_auth_and_session(confgrp): ksa_auth = ks_loading.load_auth_from_conf_options(CONF, confgrp) return ks_loading.load_session_from_conf_options( - CONF, confgrp, auth=ksa_auth) + CONF, confgrp, auth=ksa_auth + ) def get_sdk_adapter(service_type, check_service=False): @@ -148,12 +163,16 @@ def get_sdk_adapter(service_type, check_service=False): sess = _get_auth_and_session(confgrp) try: conn = connection.Connection( - session=sess, oslo_conf=CONF, service_types={service_type}, - strict_proxies=check_service) + session=sess, + oslo_conf=CONF, + service_types={service_type}, + strict_proxies=check_service, + ) except sdk_exc.ServiceDiscoveryException as e: raise exception.ServiceUnavailable( - _("The %(service_type)s service is unavailable: %(error)s") % - {'service_type': service_type, 'error': str(e)}) + _("The %(service_type)s service is unavailable: %(error)s") + % {'service_type': service_type, 'error': str(e)} + ) return getattr(conn, service_type) @@ -195,7 +214,8 @@ def get_endpoint(ksa_adapter): pass raise ks_exc.EndpointNotFound( "Could not find requested endpoint for any of the following " - "interfaces: %s" % interfaces) + "interfaces: %s" % interfaces + ) class _Singleton(type): @@ -205,8 +225,9 @@ class _Singleton(type): def __call__(cls, *args, **kwargs): ins = cls._instances.get(cls) - if not ins or (hasattr(ins, "_reset") - and isinstance(ins, cls) and ins._reset()): + if not ins or ( + hasattr(ins, "_reset") and isinstance(ins, cls) and ins._reset() + ): cls._instances[cls] = super().__call__(*args, **kwargs) return cls._instances[cls] @@ -233,8 +254,13 @@ class ThreadWorks(Singleton): def spawn(self, func, *args, **kwargs): """Put a job in thread pool.""" - LOG.debug("Add an async jobs. func: %s is with parameters args: %s, " - "kwargs: %s", func, args, kwargs) + LOG.debug( + "Add an async jobs. func: %s is with parameters args: %s, " + "kwargs: %s", + func, + args, + kwargs, + ) future = self.executor.submit(func, *args, **kwargs) return future @@ -243,8 +269,13 @@ class ThreadWorks(Singleton): executor = futures.ThreadPoolExecutor() # TODO(Shaohe) every submit func should be wrapped with exception catch job = executor.submit(func, *args, **kwargs) - LOG.debug("Spawn master job. func: %s is with parameters args: %s, " - "kwargs: %s", func, args, kwargs) + LOG.debug( + "Spawn master job. func: %s is with parameters args: %s, " + "kwargs: %s", + func, + args, + kwargs, + ) # NOTE(Shaohe) shutdown should be after job submit executor.shutdown(wait=False) # TODO(Shaohe) we need to consider resource collection such as the @@ -302,17 +333,28 @@ class ThreadWorks(Singleton): yield f.result(), f.exception(), f._state, None else: f = fs.pop() - yield (f.result(end_time - time.time()), - f.exception(), f._state, None) + yield ( + f.result(end_time - time.time()), + f.exception(), + f._state, + None, + ) except Exception as e: err = traceback.format_exc() - LOG.error("Error during check the worker status. Exception " - "info: %s", err) + LOG.error( + "Error during check the worker status. Exception info: %s", + err, + ) if f: - LOG.error("Error during check the worker status. " - "Exception info: %s, result: %s, state: %s. " - "Reason %s", f.exception(), f._result, - f._state, str(e)) + LOG.error( + "Error during check the worker status. " + "Exception info: %s, result: %s, state: %s. " + "Reason %s", + f.exception(), + f._result, + f._state, + str(e), + ) yield f._result, f.exception(), f._state, err finally: # Do best to cancel remain jobs. @@ -351,6 +393,7 @@ def format_tb(tb, limit=None): def wrap_job_tb(msg="Reason: %s"): """Wrap a function with a is_job tag added, and catch Exception.""" + def _wrap_job_tb(method): @wraps(method) def _impl(self, *args, **kwargs): @@ -361,13 +404,16 @@ def wrap_job_tb(msg="Reason: %s"): LOG.error(traceback.format_exc()) raise return output + setattr(_impl, "is_job", True) return _impl + return _wrap_job_tb def factory_register(SuperClass, ClassName): """Register an concrete class to a factory Class.""" + def decorator(Class): # return Class if not hasattr(SuperClass, "_factory"): @@ -375,6 +421,7 @@ def factory_register(SuperClass, ClassName): SuperClass._factory[ClassName] = Class setattr(Class, "_factory_type", ClassName) return Class + return decorator @@ -387,16 +434,23 @@ class FactoryMixin: f = getattr(cls, "_factory", {}) sclass = f.get(typ, None) if sclass: - LOG.info("Find %s of concrete %s by %s.", - sclass.__name__, cls.__name__, typ) + LOG.info( + "Find %s of concrete %s by %s.", + sclass.__name__, + cls.__name__, + typ, + ) return sclass for sclass in cls.__subclasses__(): if typ == getattr(cls, "_factory_type", None): return sclass else: return cls - LOG.info("Use default %s, do not find concrete class" - "by %s.", cls.__name__, typ) + LOG.info( + "Use default %s, do not find concrete classby %s.", + cls.__name__, + typ, + ) def strtime(at): diff --git a/cyborg/conductor/handlers.py b/cyborg/conductor/handlers.py index 31f806f1..55e8657d 100644 --- a/cyborg/conductor/handlers.py +++ b/cyborg/conductor/handlers.py @@ -1,4 +1,3 @@ - # # Licensed under the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. You may obtain diff --git a/cyborg/conductor/manager.py b/cyborg/conductor/manager.py index cf83667e..c43f5d9e 100644 --- a/cyborg/conductor/manager.py +++ b/cyborg/conductor/manager.py @@ -117,24 +117,32 @@ class ConductorManager: old_driver_device_list = DriverDevice.list(context, hostname) # TODO(wangzhh): Remove invalid driver_devices without controlpath_id. # Then diff two driver device list. - self.drv_device_make_diff(context, hostname, - old_driver_device_list, driver_device_list) + self.drv_device_make_diff( + context, hostname, old_driver_device_list, driver_device_list + ) - def drv_device_make_diff(self, context, host, old_driver_device_list, - new_driver_device_list): + def drv_device_make_diff( + self, context, host, old_driver_device_list, new_driver_device_list + ): """Compare new driver-side device object list with the old one in one host. """ LOG.info("Start differing devices.") # TODO(): The placement report will be implemented here. # Use cpid.cpid_info to identify whether the device is the same. - stub_cpid_list = [driver_dev_obj.controlpath_id.cpid_info for - driver_dev_obj in new_driver_device_list - if driver_dev_obj.stub] - new_cpid_list = [driver_dev_obj.controlpath_id.cpid_info for - driver_dev_obj in new_driver_device_list] - old_cpid_list = [driver_dev_obj.controlpath_id.cpid_info for - driver_dev_obj in old_driver_device_list] + stub_cpid_list = [ + driver_dev_obj.controlpath_id.cpid_info + for driver_dev_obj in new_driver_device_list + if driver_dev_obj.stub + ] + new_cpid_list = [ + driver_dev_obj.controlpath_id.cpid_info + for driver_dev_obj in new_driver_device_list + ] + old_cpid_list = [ + driver_dev_obj.controlpath_id.cpid_info + for driver_dev_obj in old_driver_device_list + ] same = set(new_cpid_list) & set(old_cpid_list) - set(stub_cpid_list) added = set(new_cpid_list) - same - set(stub_cpid_list) deleted = set(old_cpid_list) - same - set(stub_cpid_list) @@ -152,10 +160,10 @@ class ConductorManager: try: new_driver_dev_obj.create(context, host) except Exception as exc: - LOG.exception("Failed to add device %(device)s. " - "Reason: %(reason)s", - {'device': new_driver_dev_obj, - 'reason': exc}) + LOG.exception( + "Failed to add device %(device)s. Reason: %(reason)s", + {'device': new_driver_dev_obj, 'reason': exc}, + ) new_driver_dev_obj.destroy(context, host) # TODO(All): If report device data to Placement raise exception, # we should revert driver device created in Cyborg and resources @@ -164,14 +172,14 @@ class ConductorManager: cleanup_inconsistency_resources = False for driver_dep_obj in new_driver_dev_obj.deployable_list: try: - self.get_placement_needed_info_and_report(context, - driver_dep_obj, - host_rp) + self.get_placement_needed_info_and_report( + context, driver_dep_obj, host_rp + ) except Exception as exc: - LOG.info("Failed to add device %(device)s. " - "Reason: %(reason)s", - {'device': new_driver_dev_obj, - 'reason': exc}) + LOG.info( + "Failed to add device %(device)s. Reason: %(reason)s", + {'device': new_driver_dev_obj, 'reason': exc}, + ) cleanup_inconsistency_resources = True break if cleanup_inconsistency_resources: @@ -190,36 +198,55 @@ class ConductorManager: for dev_obj in device_obj_list: # get cpid_obj, could be empty or only one value. cpid_obj = ControlpathID.get_by_device_id_cpidinfo( - context, dev_obj.id, cpid_info) + context, dev_obj.id, cpid_info + ) # find the one cpid_obj with cpid_info if cpid_obj is not None: break - changed_key = ['std_board_info', 'vendor', 'vendor_board_info', - 'model', 'type'] + changed_key = [ + 'std_board_info', + 'vendor', + 'vendor_board_info', + 'model', + 'type', + ] for c_k in changed_key: if getattr(new_driver_dev_obj, c_k) != getattr( - old_driver_dev_obj, c_k): + old_driver_dev_obj, c_k + ): setattr(dev_obj, c_k, getattr(new_driver_dev_obj, c_k)) dev_obj.save(context) # diff the internal layer: driver_deployable - self.drv_deployable_make_diff(context, dev_obj.id, cpid_obj.id, - old_driver_dev_obj.deployable_list, - new_driver_dev_obj.deployable_list, - host_rp) + self.drv_deployable_make_diff( + context, + dev_obj.id, + cpid_obj.id, + old_driver_dev_obj.deployable_list, + new_driver_dev_obj.deployable_list, + host_rp, + ) - def drv_deployable_make_diff(self, context, device_id, cpid_id, - old_driver_dep_list, new_driver_dep_list, - host_rp): + def drv_deployable_make_diff( + self, + context, + device_id, + cpid_id, + old_driver_dep_list, + new_driver_dep_list, + host_rp, + ): """Compare new driver-side deployable object list with the old one in one host. """ # use name to identify whether the deployable is the same. LOG.info("Start differing deploybles.") - new_name_list = [driver_dep_obj.name for driver_dep_obj in - new_driver_dep_list] - old_name_list = [driver_dep_obj.name for driver_dep_obj in - old_driver_dep_list] + new_name_list = [ + driver_dep_obj.name for driver_dep_obj in new_driver_dep_list + ] + old_name_list = [ + driver_dep_obj.name for driver_dep_obj in old_driver_dep_list + ] same = set(new_name_list) & set(old_name_list) added = set(new_name_list) - same deleted = set(old_name_list) - same @@ -234,14 +261,15 @@ class ConductorManager: new_driver_dep_obj = new_driver_dep_list[new_name_list.index(a)] new_driver_dep_obj.create(context, device_id, cpid_id) try: - self.get_placement_needed_info_and_report(context, - new_driver_dep_obj, - host_rp) + self.get_placement_needed_info_and_report( + context, new_driver_dep_obj, host_rp + ) except Exception as exc: - LOG.info("Failed to add deployable %(deployable)s. " - "Reason: %(reason)s", - {'deployable': new_driver_dep_obj, - 'reason': exc}) + LOG.info( + "Failed to add deployable %(deployable)s. " + "Reason: %(reason)s", + {'deployable': new_driver_dep_obj, 'reason': exc}, + ) new_driver_dep_obj.destroy(context, device_id) rp_uuid = self.get_rp_uuid_from_obj(new_driver_dep_obj) # TODO(All): If report deployable data to Placement raise @@ -263,38 +291,50 @@ class ConductorManager: attrs = new_driver_dep_obj.attribute_list resource_class = [i.value for i in attrs if i.key == 'rc'][0] inv_data = _gen_resource_inventory( - resource_class, dep_obj.num_accelerators) + resource_class, dep_obj.num_accelerators + ) self.placement_client.update_inventory(rp_uuid, inv_data) # diff the internal layer: driver_attribute_list new_attribute_list = [] if hasattr(new_driver_dep_obj, 'attribute_list'): new_attribute_list = new_driver_dep_obj.attribute_list - self.drv_attr_make_diff(context, dep_obj.id, - old_driver_dep_obj.attribute_list, - new_attribute_list) + self.drv_attr_make_diff( + context, + dep_obj.id, + old_driver_dep_obj.attribute_list, + new_attribute_list, + ) # diff the internal layer: driver_attach_hanle_list - self.drv_ah_make_diff(context, dep_obj.id, cpid_id, - old_driver_dep_obj.attach_handle_list, - new_driver_dep_obj.attach_handle_list) + self.drv_ah_make_diff( + context, + dep_obj.id, + cpid_id, + old_driver_dep_obj.attach_handle_list, + new_driver_dep_obj.attach_handle_list, + ) - def drv_attr_make_diff(self, context, dep_id, old_driver_attr_list, - new_driver_attr_list): + def drv_attr_make_diff( + self, context, dep_id, old_driver_attr_list, new_driver_attr_list + ): """Diff new driver-side Attribute Object lists with the old one.""" LOG.info("Start differing attributes.") dep_obj = Deployable.get_by_id(context, dep_id) driver_dep = DriverDeployable.get_by_name(context, dep_obj.name) rp_uuid = self.get_rp_uuid_from_obj(driver_dep) - new_key_list = [driver_attr_obj.key for driver_attr_obj in - new_driver_attr_list] - old_key_list = [driver_attr_obj.key for driver_attr_obj in - old_driver_attr_list] + new_key_list = [ + driver_attr_obj.key for driver_attr_obj in new_driver_attr_list + ] + old_key_list = [ + driver_attr_obj.key for driver_attr_obj in old_driver_attr_list + ] same = set(new_key_list) & set(old_key_list) # key is deleted. deleted = set(old_key_list) - same for d in deleted: old_driver_attr_obj = old_driver_attr_list[old_key_list.index(d)] self.placement_client.delete_trait_by_name( - context, rp_uuid, old_driver_attr_obj.value) + context, rp_uuid, old_driver_attr_obj.value + ) old_driver_attr_obj.delete_by_key(context, dep_id, d) # key is added. added = set(new_key_list) - same @@ -302,7 +342,8 @@ class ConductorManager: new_driver_attr_obj = new_driver_attr_list[new_key_list.index(a)] new_driver_attr_obj.create(context, dep_id) self.placement_client.add_traits_to_rp( - rp_uuid, [new_driver_attr_obj.value]) + rp_uuid, [new_driver_attr_obj.value] + ) # key is same, diff the value. for s in same: # value is not same, update @@ -315,28 +356,36 @@ class ConductorManager: # Update traits here. if new_driver_attr_obj.key.startswith("trait"): self.placement_client.delete_trait_by_name( - context, rp_uuid, old_driver_attr_obj.value) + context, rp_uuid, old_driver_attr_obj.value + ) self.placement_client.add_traits_to_rp( - rp_uuid, [new_driver_attr_obj.value]) + rp_uuid, [new_driver_attr_obj.value] + ) # Update resource classes here. if new_driver_attr_obj.key.startswith("rc"): self.placement_client.ensure_resource_classes( - context, [new_driver_attr_obj.value]) + context, [new_driver_attr_obj.value] + ) inv_data = _gen_resource_inventory( - new_driver_attr_obj.value, dep_obj.num_accelerators) + new_driver_attr_obj.value, dep_obj.num_accelerators + ) self.placement_client.update_inventory(rp_uuid, inv_data) self.placement_client.delete_rc_by_name( - context, old_driver_attr_obj.value) + context, old_driver_attr_obj.value + ) @classmethod - def drv_ah_make_diff(cls, context, dep_id, cpid_id, old_driver_ah_list, - new_driver_ah_list): + def drv_ah_make_diff( + cls, context, dep_id, cpid_id, old_driver_ah_list, new_driver_ah_list + ): """Diff new driver-side AttachHandle Object lists with the old one.""" LOG.info("Start differing attach_handles.") - new_info_list = [driver_ah_obj.attach_info for driver_ah_obj in - new_driver_ah_list] - old_info_list = [driver_ah_obj.attach_info for driver_ah_obj in - old_driver_ah_list] + new_info_list = [ + driver_ah_obj.attach_info for driver_ah_obj in new_driver_ah_list + ] + old_info_list = [ + driver_ah_obj.attach_info for driver_ah_obj in old_driver_ah_list + ] same = set(new_info_list) & set(old_info_list) LOG.info('new info list %s', new_info_list) LOG.info('old info list %s', old_info_list) @@ -356,41 +405,45 @@ class ConductorManager: new_driver_ah_obj = new_driver_ah_list[new_info_list.index(s)] old_driver_ah_obj = old_driver_ah_list[old_info_list.index(s)] changed_key = ['attach_type'] - ah_obj = AttachHandle.get_ah_by_depid_attachinfo(context, - dep_id, s) + ah_obj = AttachHandle.get_ah_by_depid_attachinfo( + context, dep_id, s + ) for c_k in changed_key: if getattr(new_driver_ah_obj, c_k) != getattr( - old_driver_ah_obj, c_k): + old_driver_ah_obj, c_k + ): setattr(ah_obj, c_k, getattr(new_driver_ah_obj, c_k)) ah_obj.save(context) def _get_root_provider(self, context, hostname): try: provider = self.placement_client.get( - "resource_providers?name=" + hostname).json() + "resource_providers?name=" + hostname + ).json() pr_uuid = provider["resource_providers"][0]["uuid"] return pr_uuid except (IndexError, KeyError): raise exception.PlacementResourceProviderNotFound( - resource_provider=hostname) + resource_provider=hostname + ) def _get_sub_provider(self, context, parent, name): - old_sub_pr_uuid = str(uuid.uuid3(uuid.NAMESPACE_DNS, - str(name))) + old_sub_pr_uuid = str(uuid.uuid3(uuid.NAMESPACE_DNS, str(name))) new_sub_pr_uuid = self.placement_client.ensure_resource_provider( - context, old_sub_pr_uuid, - name=name, parent_provider_uuid=parent) + context, old_sub_pr_uuid, name=name, parent_provider_uuid=parent + ) if old_sub_pr_uuid == new_sub_pr_uuid: return new_sub_pr_uuid else: raise exception.Conflict() - def provider_report(self, context, name, resource_class, traits, total, - parent): + def provider_report( + self, context, name, resource_class, traits, total, parent + ): self.placement_client.ensure_resource_classes( - context, [resource_class]) - sub_pr_uuid = self._get_sub_provider( - context, parent, name) + context, [resource_class] + ) + sub_pr_uuid = self._get_sub_provider(context, parent, name) result = _gen_resource_inventory(resource_class, total) self.placement_client.update_inventory(sub_pr_uuid, result) # traits = ["CUSTOM_FPGA_INTEL", "CUSTOM_FPGA_INTEL_ARRIA10", @@ -401,16 +454,17 @@ class ConductorManager: self.placement_client.add_traits_to_rp(sub_pr_uuid, traits) return sub_pr_uuid - def get_placement_needed_info_and_report(self, context, obj, - parent_uuid=None): + def get_placement_needed_info_and_report( + self, context, obj, parent_uuid=None + ): pr_name = obj.name attrs = obj.attribute_list resource_class = [i.value for i in attrs if i.key == 'rc'][0] - traits = [i.value for i in attrs - if str(i.key).startswith("trait")] + traits = [i.value for i in attrs if str(i.key).startswith("trait")] total = obj.num_accelerators - rp_uuid = self.provider_report(context, pr_name, resource_class, - traits, total, parent_uuid) + rp_uuid = self.provider_report( + context, pr_name, resource_class, traits, total, parent_uuid + ) dep_obj = Deployable.get_by_name(context, pr_name) dep_obj["rp_uuid"] = rp_uuid dep_obj.save(context) @@ -419,13 +473,16 @@ class ConductorManager: return str(uuid.uuid3(uuid.NAMESPACE_DNS, str(obj.name))) def _delete_provider_and_sub_providers(self, context, rp_uuid): - rp_in_tree = self.placement_client.get_providers_in_tree(context, - rp_uuid) + rp_in_tree = self.placement_client.get_providers_in_tree( + context, rp_uuid + ) for rp in rp_in_tree[::-1]: if rp["parent_provider_uuid"] == rp_uuid or rp["uuid"] == rp_uuid: self.placement_client.delete_provider(rp["uuid"]) - LOG.info("Successfully delete resource provider %(rp_uuid)s", - {"rp_uuid": rp["uuid"]}) + LOG.info( + "Successfully delete resource provider %(rp_uuid)s", + {"rp_uuid": rp["uuid"]}, + ) if rp["uuid"] == rp_uuid: break diff --git a/cyborg/conductor/rpcapi.py b/cyborg/conductor/rpcapi.py index cdc5f685..4216bb78 100644 --- a/cyborg/conductor/rpcapi.py +++ b/cyborg/conductor/rpcapi.py @@ -40,20 +40,23 @@ class ConductorAPI: def __init__(self, topic=None): super().__init__() self.topic = topic or constants.CONDUCTOR_TOPIC - target = messaging.Target(topic=self.topic, - version='1.0') + target = messaging.Target(topic=self.topic, version='1.0') serializer = objects_base.CyborgObjectSerializer() - self.client = rpc.get_client(target, - version_cap=self.RPC_API_VERSION, - serializer=serializer) + self.client = rpc.get_client( + target, version_cap=self.RPC_API_VERSION, serializer=serializer + ) def report_data(self, context, hostname, driver_device_list): """Signal to conductor service to update the cyborg DB :param context: request context. """ cctxt = self.client.prepare(topic=self.topic) - cctxt.call(context, 'report_data', hostname=hostname, - driver_device_list=driver_device_list) + cctxt.call( + context, + 'report_data', + hostname=hostname, + driver_device_list=driver_device_list, + ) def device_profile_create(self, context, obj_devprof): """Signal to conductor service to create a device_profile. @@ -63,8 +66,9 @@ class ConductorAPI: :returns: created device_profile object. """ cctxt = self.client.prepare(topic=self.topic) - return cctxt.call(context, 'device_profile_create', - obj_devprof=obj_devprof) + return cctxt.call( + context, 'device_profile_create', obj_devprof=obj_devprof + ) def device_profile_delete(self, context, obj_devprof): """Signal to conductor service to delete a device_profile. @@ -72,8 +76,7 @@ class ConductorAPI: :param obj_devprof: a device_profile object to delete. """ cctxt = self.client.prepare(topic=self.topic) - cctxt.call(context, 'device_profile_delete', - obj_devprof=obj_devprof) + cctxt.call(context, 'device_profile_delete', obj_devprof=obj_devprof) def arq_create(self, context, obj_extarq, devprof_id): """Signal to conductor service to create an accelerator requests. @@ -85,8 +88,9 @@ class ConductorAPI: :returns: saved accelerator_requests object. """ cctxt = self.client.prepare(topic=self.topic) - return cctxt.call(context, 'arq_create', obj_extarq=obj_extarq, - devprof_id=devprof_id) + return cctxt.call( + context, 'arq_create', obj_extarq=obj_extarq, devprof_id=devprof_id + ) def arq_delete_by_uuid(self, context, arqs): """Signal to conductor service to delete accelerator requests by @@ -116,5 +120,9 @@ class ConductorAPI: :param valid_fields: Dict of valid fields """ cctxt = self.client.prepare(topic=self.topic) - return cctxt.call(context, 'arq_apply_patch', patch_list=patch_list, - valid_fields=valid_fields) + return cctxt.call( + context, + 'arq_apply_patch', + patch_list=patch_list, + valid_fields=valid_fields, + ) diff --git a/cyborg/conf/agent.py b/cyborg/conf/agent.py index 43fbb213..ca96cff9 100644 --- a/cyborg/conf/agent.py +++ b/cyborg/conf/agent.py @@ -20,36 +20,49 @@ from cyborg.common.i18n import _ opts = [ - cfg.ListOpt('enabled_drivers', - default=['fake_driver'], - help=_('The accelerator drivers enabled on this agent. Such ' - 'as intel_fpga_driver, inspur_fpga_driver,' - 'nvidia_gpu_driver, intel_qat_driver,' - 'inspur_nvme_ssd_driver, xilinx_fpga_driver, etc.')), - cfg.IntOpt('resource_provider_startup_retries', - default=3, - min=0, - help=_('Number of times to retry looking up the resource ' - 'provider in Placement during agent startup. Uses ' - 'exponential backoff (1s, 2s, 4s, ...) between ' - 'attempts. Set to 0 to fail immediately without ' - 'retrying.')), - cfg.StrOpt('resource_provider_name', - default=socket.getfqdn(), - sample_default='compute.fully.qualified.name', - help=_('Name of the compute resource provider in Placement. ' - 'This should match the hypervisor_hostname used by Nova ' - 'for this compute host. Defaults to socket.getfqdn() ' - 'which typically matches libvirt behavior. If resource ' - 'provider lookup fails with this name, Cyborg will fall ' - 'back to using CONF.host.')), + cfg.ListOpt( + 'enabled_drivers', + default=['fake_driver'], + help=_( + 'The accelerator drivers enabled on this agent. Such ' + 'as intel_fpga_driver, inspur_fpga_driver,' + 'nvidia_gpu_driver, intel_qat_driver,' + 'inspur_nvme_ssd_driver, xilinx_fpga_driver, etc.' + ), + ), + cfg.IntOpt( + 'resource_provider_startup_retries', + default=3, + min=0, + help=_( + 'Number of times to retry looking up the resource ' + 'provider in Placement during agent startup. Uses ' + 'exponential backoff (1s, 2s, 4s, ...) between ' + 'attempts. Set to 0 to fail immediately without ' + 'retrying.' + ), + ), + cfg.StrOpt( + 'resource_provider_name', + default=socket.getfqdn(), + sample_default='compute.fully.qualified.name', + help=_( + 'Name of the compute resource provider in Placement. ' + 'This should match the hypervisor_hostname used by Nova ' + 'for this compute host. Defaults to socket.getfqdn() ' + 'which typically matches libvirt behavior. If resource ' + 'provider lookup fails with this name, Cyborg will fall ' + 'back to using CONF.host.' + ), + ), ] -opt_group = cfg.OptGroup(name='agent', - title='Options for the cyborg-agent service') +opt_group = cfg.OptGroup( + name='agent', title='Options for the cyborg-agent service' +) -AGENT_OPTS = (opts) +AGENT_OPTS = opts def register_opts(conf): @@ -58,6 +71,4 @@ def register_opts(conf): def list_opts(): - return { - opt_group: AGENT_OPTS - } + return {opt_group: AGENT_OPTS} diff --git a/cyborg/conf/api.py b/cyborg/conf/api.py index 59ef7d9f..b4a3d01c 100644 --- a/cyborg/conf/api.py +++ b/cyborg/conf/api.py @@ -19,42 +19,61 @@ from cyborg.common.i18n import _ opts = [ - cfg.HostAddressOpt('host_ip', - default='127.0.0.1', - help=_('The IP address on which cyborg-api listens.')), - cfg.PortOpt('port', - default=6666, - help=_('The TCP port on which cyborg-api listens.')), - cfg.IntOpt('api_workers', - help=_('Number of workers for OpenStack Cyborg API service. ' - 'The default is equal to the number of CPUs available ' - 'if that can be determined, else a default worker ' - 'count of 1 is returned.')), - cfg.BoolOpt('enable_ssl_api', - default=False, - help=_("Enable the integrated stand-alone API to service " - "requests via HTTPS instead of HTTP. If there is a " - "front-end service performing HTTPS offloading from " - "the service, this option should be False; note, you " - "will want to change public API endpoint to represent " - "SSL termination URL with 'public_endpoint' option.")), - cfg.StrOpt('public_endpoint', - help=_("Public URL to use when building the links to the API " - "resources (for example, \"https://cyborg.rocks:6666\")." - " If None the links will be built using the request's " - "host URL. If the API is operating behind a proxy, you " - "will want to change this to represent the proxy's URL. " - "Defaults to None.")), - cfg.StrOpt('api_paste_config', - default="api-paste.ini", - help="Configuration file for WSGI definition of API."), + cfg.HostAddressOpt( + 'host_ip', + default='127.0.0.1', + help=_('The IP address on which cyborg-api listens.'), + ), + cfg.PortOpt( + 'port', + default=6666, + help=_('The TCP port on which cyborg-api listens.'), + ), + cfg.IntOpt( + 'api_workers', + help=_( + 'Number of workers for OpenStack Cyborg API service. ' + 'The default is equal to the number of CPUs available ' + 'if that can be determined, else a default worker ' + 'count of 1 is returned.' + ), + ), + cfg.BoolOpt( + 'enable_ssl_api', + default=False, + help=_( + "Enable the integrated stand-alone API to service " + "requests via HTTPS instead of HTTP. If there is a " + "front-end service performing HTTPS offloading from " + "the service, this option should be False; note, you " + "will want to change public API endpoint to represent " + "SSL termination URL with 'public_endpoint' option." + ), + ), + cfg.StrOpt( + 'public_endpoint', + help=_( + "Public URL to use when building the links to the API " + "resources (for example, \"https://cyborg.rocks:6666\")." + " If None the links will be built using the request's " + "host URL. If the API is operating behind a proxy, you " + "will want to change this to represent the proxy's URL. " + "Defaults to None." + ), + ), + cfg.StrOpt( + 'api_paste_config', + default="api-paste.ini", + help="Configuration file for WSGI definition of API.", + ), ] -opt_group = cfg.OptGroup(name='api', - title='Options for the cyborg-api service') +opt_group = cfg.OptGroup( + name='api', title='Options for the cyborg-api service' +) -API_OPTS = (opts) +API_OPTS = opts def register_opts(conf): @@ -63,6 +82,4 @@ def register_opts(conf): def list_opts(): - return { - opt_group: API_OPTS - } + return {opt_group: API_OPTS} diff --git a/cyborg/conf/database.py b/cyborg/conf/database.py index 49ddecf3..4f4a860e 100644 --- a/cyborg/conf/database.py +++ b/cyborg/conf/database.py @@ -19,23 +19,22 @@ from cyborg.common.i18n import _ opts = [ - cfg.StrOpt('mysql_engine', - default='InnoDB', - help=_('MySQL engine to use.')) + cfg.StrOpt( + 'mysql_engine', default='InnoDB', help=_('MySQL engine to use.') + ) ] -opt_group = cfg.OptGroup(name='database', - title='Options for the database service') +opt_group = cfg.OptGroup( + name='database', title='Options for the database service' +) def register_opts(conf): conf.register_opts(opts, group=opt_group) -DB_OPTS = (opts) +DB_OPTS = opts def list_opts(): - return { - opt_group: DB_OPTS - } + return {opt_group: DB_OPTS} diff --git a/cyborg/conf/default.py b/cyborg/conf/default.py index 269342d5..a94e5803 100644 --- a/cyborg/conf/default.py +++ b/cyborg/conf/default.py @@ -24,53 +24,70 @@ from cyborg.common.i18n import _ exc_log_opts = [ - cfg.BoolOpt('fatal_exception_format_errors', - default=False, - help=_('Used if there is a formatting error when generating ' - 'an exception message (a programming error). If True, ' - 'raise an exception; if False, use the unformatted ' - 'message.')), + cfg.BoolOpt( + 'fatal_exception_format_errors', + default=False, + help=_( + 'Used if there is a formatting error when generating ' + 'an exception message (a programming error). If True, ' + 'raise an exception; if False, use the unformatted ' + 'message.' + ), + ), ] service_opts = [ - cfg.HostAddressOpt('host', - default=socket.gethostname(), - sample_default='localhost', - help=_('Name of this node. This can be an opaque ' - 'identifier. It is not necessarily a hostname, ' - 'FQDN, or IP address. However, the node name ' - 'must be valid within an AMQP key.') - ), - cfg.IntOpt('periodic_interval', - default=60, - help=_('Default interval (in seconds) for running periodic ' - 'tasks.')), + cfg.HostAddressOpt( + 'host', + default=socket.gethostname(), + sample_default='localhost', + help=_( + 'Name of this node. This can be an opaque ' + 'identifier. It is not necessarily a hostname, ' + 'FQDN, or IP address. However, the node name ' + 'must be valid within an AMQP key.' + ), + ), + cfg.IntOpt( + 'periodic_interval', + default=60, + help=_('Default interval (in seconds) for running periodic tasks.'), + ), cfg.IntOpt( 'thread_pool_size', default=10, - help=_('This option specifies the size of the pool of threads used ' - 'by API to do async jobs.It is possible to limit the number ' - 'of concurrent connections using this option.')), + help=_( + 'This option specifies the size of the pool of threads used ' + 'by API to do async jobs.It is possible to limit the number ' + 'of concurrent connections using this option.' + ), + ), cfg.IntOpt( 'bind_timeout', default=60, - help=_('This option specifies the timeout of async job for ARQ ' - 'bind.')), + help=_('This option specifies the timeout of async job for ARQ bind.'), + ), ] path_opts = [ - cfg.StrOpt('pybasedir', - default=os.path.abspath( - os.path.join(os.path.dirname(__file__), '../')), - sample_default='/usr/lib/python/site-packages/cyborg/cyborg', - help=_('Directory where the cyborg python module is ' - 'installed.')), - cfg.StrOpt('bindir', - default='$pybasedir/bin', - help=_('Directory where cyborg binaries are installed.')), - cfg.StrOpt('state_path', - default='$pybasedir', - help=_("Top-level directory for maintaining cyborg's state.")), + cfg.StrOpt( + 'pybasedir', + default=os.path.abspath( + os.path.join(os.path.dirname(__file__), '../') + ), + sample_default='/usr/lib/python/site-packages/cyborg/cyborg', + help=_('Directory where the cyborg python module is installed.'), + ), + cfg.StrOpt( + 'bindir', + default='$pybasedir/bin', + help=_('Directory where cyborg binaries are installed.'), + ), + cfg.StrOpt( + 'state_path', + default='$pybasedir', + help=_("Top-level directory for maintaining cyborg's state."), + ), ] @@ -80,10 +97,8 @@ def register_opts(conf): conf.register_opts(path_opts) -DEFAULT_OPTS = (exc_log_opts + service_opts + path_opts) +DEFAULT_OPTS = exc_log_opts + service_opts + path_opts def list_opts(): - return { - 'DEFAULT': DEFAULT_OPTS - } + return {'DEFAULT': DEFAULT_OPTS} diff --git a/cyborg/conf/devices.py b/cyborg/conf/devices.py index 13473023..77ec9377 100644 --- a/cyborg/conf/devices.py +++ b/cyborg/conf/devices.py @@ -14,42 +14,35 @@ from oslo_config import cfg -pci_group = cfg.OptGroup( - name='pci', - title='PCI passthrough options') +pci_group = cfg.OptGroup(name='pci', title='PCI passthrough options') -pci_opts = [ - cfg.MultiStrOpt('passthrough_whitelist', - default=[], - help=" ") -] +pci_opts = [cfg.MultiStrOpt('passthrough_whitelist', default=[], help=" ")] nic_group = cfg.OptGroup( name='nic_devices', title='nic device ID options', help="""This is used to config specific nic devices. - """) + """, +) -nic_opts = [ - cfg.ListOpt('enabled_nic_types', - default=[], - help=" ") -] +nic_opts = [cfg.ListOpt('enabled_nic_types', default=[], help=" ")] gpu_group = cfg.OptGroup( name='gpu_devices', title='virtual gpu options', help="""This is used to config vGPU types for nvidia GPU devices. - """) + """, +) vgpu_opts = [ # TODO(bogdando): After Cyborg ensures the safe removal of Placement # resource providers and deployables during upgrades and that can be tested # similar to Nova's test_pci_in_placement backed by the Placement # sqlite fixture, change this option's default to True. - cfg.BoolOpt('filter_sriov_vfs', - default=False, - help=""" + cfg.BoolOpt( + 'filter_sriov_vfs', + default=False, + help=""" Filter out SR-IOV Virtual Function (VF) devices from GPU discovery. When enabled, the NVIDIA GPU driver will skip PCI VF devices and only @@ -63,10 +56,12 @@ up first. Operators should ensure no instances hold VF allocations before enabling this option, as Cyborg does not yet have upgrade-safe protection equivalent to Nova's PCI tracker (which defers removal of allocated devices until the owning instance is deleted). -"""), - cfg.ListOpt('enabled_vgpu_types', - default=[], - help=""" +""", + ), + cfg.ListOpt( + 'enabled_vgpu_types', + default=[], + help=""" The vGPU types enabled in the compute node. Cyborg supports multiple vGPU types in one host. Usually, a single physical @@ -93,7 +88,8 @@ An example is as the following:: [vgpu_nvidia-36] device_addresses = 0000:86:00.0 -""") +""", + ), ] @@ -113,10 +109,16 @@ def register_dynamic_opts(conf): the initial configuration has been loaded. """ opts = [ - cfg.ListOpt('physical_device_mappings', default=[], - item_type=cfg.types.String()), - cfg.ListOpt('function_device_mappings', default=[], - item_type=cfg.types.String()), + cfg.ListOpt( + 'physical_device_mappings', + default=[], + item_type=cfg.types.String(), + ), + cfg.ListOpt( + 'function_device_mappings', + default=[], + item_type=cfg.types.String(), + ), ] # Register the '[nic_type]/physical_device_mappings' and @@ -126,8 +128,9 @@ def register_dynamic_opts(conf): conf.register_opts(opts, group=nic_type) # Register the '[vgpu_$(VGPU_TYPE)]/device_addresses' opts, implicitly # registering the '[vgpu_$(VGPU_TYPE)]' groups in the process - opt = cfg.ListOpt('device_addresses', default=[], - item_type=cfg.types.String()) + opt = cfg.ListOpt( + 'device_addresses', default=[], item_type=cfg.types.String() + ) for vgpu_type in conf.gpu_devices.enabled_vgpu_types: conf.register_opt(opt, group='vgpu_%s' % vgpu_type) diff --git a/cyborg/conf/glance.py b/cyborg/conf/glance.py index 0a9b9299..08085da8 100644 --- a/cyborg/conf/glance.py +++ b/cyborg/conf/glance.py @@ -23,21 +23,25 @@ DEFAULT_SERVICE_TYPE = 'image' glance_group = cfg.OptGroup( 'glance', title='Glance Options', - help='Configuration options for the Image service') + help='Configuration options for the Image service', +) glance_opts = [ - cfg.IntOpt('num_retries', - default=0, - min=0, - help=""" + cfg.IntOpt( + 'num_retries', + default=0, + min=0, + help=""" Enable glance operation retries. Specifies the number of retries when uploading / downloading an image to / from glance. 0 means no retries. -"""), - cfg.BoolOpt('verify_glance_signatures', - default=False, - help=""" +""", + ), + cfg.BoolOpt( + 'verify_glance_signatures', + default=False, + help=""" Enable image signature verification. cyborg uses the image signature metadata from glance and verifies the signature @@ -53,17 +57,19 @@ Related options: for the signature validation. * Both enable_certificate_validation and default_trusted_certificate_ids below depend on this option being enabled. -"""), - cfg.BoolOpt('enable_certificate_validation', - default=False, - deprecated_for_removal=True, - deprecated_since='16.0.0', - deprecated_reason=""" +""", + ), + cfg.BoolOpt( + 'enable_certificate_validation', + default=False, + deprecated_for_removal=True, + deprecated_since='16.0.0', + deprecated_reason=""" This option is intended to ease the transition for deployments leveraging image signature verification. The intended state long-term is for signature verification and certificate validation to always happen together. """, - help=""" + help=""" Enable certificate validation for image signature verification. During image signature verification cyborg will first verify the validity of @@ -81,10 +87,12 @@ Related options: * This option only takes effect if verify_glance_signatures is enabled. * The value of default_trusted_certificate_ids may be used when this option is enabled. -"""), - cfg.ListOpt('default_trusted_certificate_ids', - default=[], - help=""" +""", + ), + cfg.ListOpt( + 'default_trusted_certificate_ids', + default=[], + help=""" List of certificate IDs for certificates that should be trusted. May be used as a default list of trusted certificate IDs for certificate @@ -100,10 +108,13 @@ Related options: * The value of this option may be used if both verify_glance_signatures and enable_certificate_validation are enabled. -"""), - cfg.BoolOpt('debug', - default=False, - help='Enable or disable debug logging with glanceclient.') +""", + ), + cfg.BoolOpt( + 'debug', + default=False, + help='Enable or disable debug logging with glanceclient.', + ), ] deprecated_ksa_opts = { @@ -119,13 +130,21 @@ def register_opts(conf): conf.register_opts(glance_opts, group=glance_group) confutils.register_ksa_opts( - conf, glance_group, DEFAULT_SERVICE_TYPE, include_auth=False, - deprecated_opts=deprecated_ksa_opts) + conf, + glance_group, + DEFAULT_SERVICE_TYPE, + include_auth=False, + deprecated_opts=deprecated_ksa_opts, + ) def list_opts(): - return {glance_group: ( - glance_opts + - ks_loading.get_session_conf_options() + - confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE, - deprecated_opts=deprecated_ksa_opts))} + return { + glance_group: ( + glance_opts + + ks_loading.get_session_conf_options() + + confutils.get_ksa_adapter_opts( + DEFAULT_SERVICE_TYPE, deprecated_opts=deprecated_ksa_opts + ) + ) + } diff --git a/cyborg/conf/keystone.py b/cyborg/conf/keystone.py index a75beb6e..bf04203d 100644 --- a/cyborg/conf/keystone.py +++ b/cyborg/conf/keystone.py @@ -22,19 +22,21 @@ DEFAULT_SERVICE_TYPE = 'identity' keystone_group = cfg.OptGroup( 'keystone', title='Keystone Options', - help='Configuration options for the identity service') + help='Configuration options for the identity service', +) def register_opts(conf): conf.register_group(keystone_group) - confutils.register_ksa_opts(conf, keystone_group.name, - DEFAULT_SERVICE_TYPE, include_auth=False) + confutils.register_ksa_opts( + conf, keystone_group.name, DEFAULT_SERVICE_TYPE, include_auth=False + ) def list_opts(): return { keystone_group: ( - ks_loading.get_session_conf_options() + - confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE) + ks_loading.get_session_conf_options() + + confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE) ) } diff --git a/cyborg/conf/nova.py b/cyborg/conf/nova.py index edd5dab4..2b94892a 100644 --- a/cyborg/conf/nova.py +++ b/cyborg/conf/nova.py @@ -22,7 +22,8 @@ DEFAULT_SERVICE_TYPE = 'compute' nova_group = cfg.OptGroup( 'nova', title='Nova Service Options', - help="Configuration options for connecting to the Nova API service") + help="Configuration options for connecting to the Nova API service", +) def register_opts(conf): @@ -33,10 +34,11 @@ def register_opts(conf): def list_opts(): return { nova_group.name: ( - ks_loading.get_session_conf_options() + - ks_loading.get_auth_common_conf_options() + - ks_loading.get_auth_plugin_conf_options('password') + - ks_loading.get_auth_plugin_conf_options('v2password') + - ks_loading.get_auth_plugin_conf_options('v3password') + - confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE)) + ks_loading.get_session_conf_options() + + ks_loading.get_auth_common_conf_options() + + ks_loading.get_auth_plugin_conf_options('password') + + ks_loading.get_auth_plugin_conf_options('v2password') + + ks_loading.get_auth_plugin_conf_options('v3password') + + confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE) + ) } diff --git a/cyborg/conf/opts.py b/cyborg/conf/opts.py index c1eb3d6d..6cdc525e 100644 --- a/cyborg/conf/opts.py +++ b/cyborg/conf/opts.py @@ -51,9 +51,11 @@ def _import_modules(module_names): for modname in module_names: mod = importlib.import_module("cyborg.conf." + modname) if not hasattr(mod, LIST_OPTS_FUNC_NAME): - msg = "The module 'zun.conf.%s' should have a '%s' "\ - "function which returns the config options." % \ - (modname, LIST_OPTS_FUNC_NAME) + msg = ( + "The module 'zun.conf.%s' should have a '%s' " + "function which returns the config options." + % (modname, LIST_OPTS_FUNC_NAME) + ) raise AttributeError(msg) else: imported_modules.append(mod) diff --git a/cyborg/conf/placement.py b/cyborg/conf/placement.py index cf63314f..b6dfcf6d 100644 --- a/cyborg/conf/placement.py +++ b/cyborg/conf/placement.py @@ -22,7 +22,8 @@ DEFAULT_SERVICE_TYPE = 'placement' placement_group = cfg.OptGroup( PLACEMENT_CONF_SECTION, title='Placement Service Options', - help="Configuration options for connecting to the placement API service") + help="Configuration options for connecting to the placement API service", +) def register_opts(conf): @@ -33,10 +34,11 @@ def register_opts(conf): def list_opts(): return { PLACEMENT_CONF_SECTION: ( - ks_loading.get_session_conf_options() + - ks_loading.get_auth_common_conf_options() + - ks_loading.get_auth_plugin_conf_options('password') + - ks_loading.get_auth_plugin_conf_options('v2password') + - ks_loading.get_auth_plugin_conf_options('v3password') + - confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE)) + ks_loading.get_session_conf_options() + + ks_loading.get_auth_common_conf_options() + + ks_loading.get_auth_plugin_conf_options('password') + + ks_loading.get_auth_plugin_conf_options('v2password') + + ks_loading.get_auth_plugin_conf_options('v3password') + + confutils.get_ksa_adapter_opts(DEFAULT_SERVICE_TYPE) + ) } diff --git a/cyborg/conf/service_token.py b/cyborg/conf/service_token.py index 5fb3e959..5e4379df 100644 --- a/cyborg/conf/service_token.py +++ b/cyborg/conf/service_token.py @@ -22,15 +22,17 @@ service_user = cfg.OptGroup( Configuration options for service to service authentication using a service token. These options allow sending a service token along with the user's token when contacting external REST APIs. -""" +""", ) service_user_opts = [ - cfg.BoolOpt('send_service_user_token', - default=False, - help=""" + cfg.BoolOpt( + 'send_service_user_token', + default=False, + help=""" When True, if sending a user token to a REST API, also send a service token. -"""), +""", + ), ] @@ -45,10 +47,11 @@ def register_opts(conf): def list_opts(): return { service_user: ( - service_user_opts + - ks_loading.get_session_conf_options() + - ks_loading.get_auth_common_conf_options() + - ks_loading.get_auth_plugin_conf_options('password') + - ks_loading.get_auth_plugin_conf_options('v2password') + - ks_loading.get_auth_plugin_conf_options('v3password')) + service_user_opts + + ks_loading.get_session_conf_options() + + ks_loading.get_auth_common_conf_options() + + ks_loading.get_auth_plugin_conf_options('password') + + ks_loading.get_auth_plugin_conf_options('v2password') + + ks_loading.get_auth_plugin_conf_options('v3password') + ) } diff --git a/cyborg/conf/utils.py b/cyborg/conf/utils.py index da964440..cd065fbb 100644 --- a/cyborg/conf/utils.py +++ b/cyborg/conf/utils.py @@ -15,6 +15,7 @@ This module does not provide any actual conf options. """ + from keystoneauth1 import loading as ks_loading from oslo_config import cfg @@ -33,8 +34,9 @@ def get_ksa_adapter_opts(default_service_type, deprecated_opts=None): keystoneauth1.loading.session.Session.register_conf_options :return: List of cfg.Opts. """ - opts = ks_loading.get_adapter_conf_options(include_deprecated=False, - deprecated_opts=deprecated_opts) + opts = ks_loading.get_adapter_conf_options( + include_deprecated=False, deprecated_opts=deprecated_opts + ) for opt in opts[:]: # Remove version-related opts. Required/supported versions are @@ -43,9 +45,11 @@ def get_ksa_adapter_opts(default_service_type, deprecated_opts=None): opts.remove(opt) # Override defaults that make sense for nova - cfg.set_defaults(opts, - valid_interfaces=['internal', 'public'], - service_type=default_service_type) + cfg.set_defaults( + opts, + valid_interfaces=['internal', 'public'], + service_type=default_service_type, + ) return opts @@ -55,8 +59,9 @@ def _dummy_opt(name): return cfg.Opt(name, type=lambda x: None) -def register_ksa_opts(conf, group, default_service_type, include_auth=True, - deprecated_opts=None): +def register_ksa_opts( + conf, group, default_service_type, include_auth=True, deprecated_opts=None +): """Register keystoneauth auth, Session, and Adapter opts. :param conf: oslo_config.cfg.CONF in which to register the options @@ -76,11 +81,16 @@ def register_ksa_opts(conf, group, default_service_type, include_auth=True, # ksa register methods need the group name as a string. oslo doesn't care. group = getattr(group, 'name', group) ks_loading.register_session_conf_options( - conf, group, deprecated_opts=deprecated_opts) + conf, group, deprecated_opts=deprecated_opts + ) if include_auth: ks_loading.register_auth_conf_options(conf, group) - conf.register_opts(get_ksa_adapter_opts( - default_service_type, deprecated_opts=deprecated_opts), group=group) + conf.register_opts( + get_ksa_adapter_opts( + default_service_type, deprecated_opts=deprecated_opts + ), + group=group, + ) # Have to register dummies for the version-related opts we removed for name in _ADAPTER_VERSION_OPTS: conf.register_opt(_dummy_opt(name), group=group) diff --git a/cyborg/context.py b/cyborg/context.py index 9a174b03..bfbbbdef 100644 --- a/cyborg/context.py +++ b/cyborg/context.py @@ -42,12 +42,21 @@ class _ContextAuthPlugin(plugin.BaseAuthPlugin): def get_token(self, *args, **kwargs): return self.auth_token - def get_endpoint(self, session, service_type=None, interface=None, - region_name=None, service_name=None, **kwargs): - return self.service_catalog.url_for(service_type=service_type, - service_name=service_name, - interface=interface, - region_name=region_name) + def get_endpoint( + self, + session, + service_type=None, + interface=None, + region_name=None, + service_name=None, + **kwargs, + ): + return self.service_catalog.url_for( + service_type=service_type, + service_name=service_name, + interface=interface, + region_name=region_name, + ) @enginefacade.transaction_context_provider @@ -58,22 +67,31 @@ class RequestContext(context.RequestContext): """ - def __init__(self, user_id=None, project_id=None, is_admin=None, - read_deleted="no", remote_address=None, timestamp=None, - quota_class=None, service_catalog=None, - user_auth_plugin=None, **kwargs): + def __init__( + self, + user_id=None, + project_id=None, + is_admin=None, + read_deleted="no", + remote_address=None, + timestamp=None, + quota_class=None, + service_catalog=None, + user_auth_plugin=None, + **kwargs, + ): """:param read_deleted: 'no' indicates deleted records are hidden, - 'yes' indicates deleted records are visible, - 'only' indicates that *only* deleted records are visible. + 'yes' indicates deleted records are visible, + 'only' indicates that *only* deleted records are visible. - :param overwrite: Set to False to ensure that the thread-local - copy of the index is not overwritten. + :param overwrite: Set to False to ensure that the thread-local + copy of the index is not overwritten. - :param instance_lock_checked: This is not used and will be removed - in a future release. + :param instance_lock_checked: This is not used and will be removed + in a future release. - :param user_auth_plugin: The auth plugin for the current request's - authentication data. + :param user_auth_plugin: The auth plugin for the current request's + authentication data. """ if user_id: kwargs['user_id'] = user_id @@ -92,8 +110,9 @@ class RequestContext(context.RequestContext): if service_catalog: # Only include required parts of service_catalog - self.service_catalog = [s for s in service_catalog - if s.get('type') in ('image')] + self.service_catalog = [ + s for s in service_catalog if s.get('type') in ('image') + ] else: # if list is empty or none self.service_catalog = [] @@ -113,27 +132,32 @@ class RequestContext(context.RequestContext): # FIXME(dims): defensive hasattr() checks need to be # removed once we figure out why we are seeing stack # traces - values.update({ - 'user_id': getattr(self, 'user_id', None), - 'project_id': getattr(self, 'project_id', None), - 'is_admin': getattr(self, 'is_admin', None), - 'read_deleted': getattr(self, 'read_deleted', 'no'), - 'remote_address': getattr(self, 'remote_address', None), - 'timestamp': utils.strtime(self.timestamp) if hasattr( - self, 'timestamp') else None, - 'request_id': getattr(self, 'request_id', None), - 'quota_class': getattr(self, 'quota_class', None), - 'user_name': getattr(self, 'user_name', None), - 'service_catalog': getattr(self, 'service_catalog', None), - 'project_name': getattr(self, 'project_name', None), - }) + values.update( + { + 'user_id': getattr(self, 'user_id', None), + 'project_id': getattr(self, 'project_id', None), + 'is_admin': getattr(self, 'is_admin', None), + 'read_deleted': getattr(self, 'read_deleted', 'no'), + 'remote_address': getattr(self, 'remote_address', None), + 'timestamp': utils.strtime(self.timestamp) + if hasattr(self, 'timestamp') + else None, + 'request_id': getattr(self, 'request_id', None), + 'quota_class': getattr(self, 'quota_class', None), + 'user_name': getattr(self, 'user_name', None), + 'service_catalog': getattr(self, 'service_catalog', None), + 'project_name': getattr(self, 'project_name', None), + } + ) # NOTE(tonyb): This can be removed once we're certain to have a # RequestContext contains 'is_admin_project', We can only get away with # this because we "know" the default value of 'is_admin_project' which # is very fragile. - values.update({ - 'is_admin_project': getattr(self, 'is_admin_project', True), - }) + values.update( + { + 'is_admin_project': getattr(self, 'is_admin_project', True), + } + ) return values @classmethod @@ -159,10 +183,9 @@ def get_context(): Note that overwrite is False here so this context will not update the thread-local stored context that is used when logging. """ - return RequestContext(user_id=None, - project_id=None, - is_admin=False, - overwrite=False) + return RequestContext( + user_id=None, project_id=None, is_admin=False, overwrite=False + ) def get_admin_context(read_deleted="no"): @@ -172,11 +195,13 @@ def get_admin_context(read_deleted="no"): # use context.elevated() where necessary. Some periodic tasks may use # get_admin_context so that their database calls are not filtered on # project_id. - return RequestContext(user_id=None, - project_id=None, - is_admin=True, - read_deleted=read_deleted, - overwrite=False) + return RequestContext( + user_id=None, + project_id=None, + is_admin=True, + read_deleted=read_deleted, + overwrite=False, + ) def is_user_context(context): diff --git a/cyborg/db/api.py b/cyborg/db/api.py index 0db1309e..239eb0b8 100644 --- a/cyborg/db/api.py +++ b/cyborg/db/api.py @@ -22,9 +22,9 @@ from oslo_db import api as db_api _BACKEND_MAPPING = {'sqlalchemy': 'cyborg.db.sqlalchemy.api'} -IMPL = db_api.DBAPI.from_config(cfg.CONF, - backend_mapping=_BACKEND_MAPPING, - lazy=True) +IMPL = db_api.DBAPI.from_config( + cfg.CONF, backend_mapping=_BACKEND_MAPPING, lazy=True +) def get_instance(): @@ -49,15 +49,22 @@ class Connection(metaclass=abc.ABCMeta): """Get requested device.""" @abc.abstractmethod - def device_list(self, context, limit=None, marker=None, - sort_key=None, sort_dir=None): + def device_list( + self, context, limit=None, marker=None, sort_key=None, sort_dir=None + ): """Get requested list of devices.""" @abc.abstractmethod - def device_list_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, columns_to_join=None): + def device_list_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + columns_to_join=None, + ): """Get requested devices by filters.""" @abc.abstractmethod @@ -90,10 +97,16 @@ class Connection(metaclass=abc.ABCMeta): """Get requested list of device_profiles.""" @abc.abstractmethod - def device_profile_list_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, columns_to_join=None): + def device_profile_list_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + columns_to_join=None, + ): """Get requested list of device_profiles by filters.""" @abc.abstractmethod @@ -126,10 +139,16 @@ class Connection(metaclass=abc.ABCMeta): """Delete a deployable.""" @abc.abstractmethod - def deployable_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, columns_to_join=None): + def deployable_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + columns_to_join=None, + ): """Get requested deployable by filters.""" @abc.abstractmethod @@ -163,9 +182,17 @@ class Connection(metaclass=abc.ABCMeta): # quota @abc.abstractmethod - def quota_reserve(self, context, resources, deltas, expire, - until_refresh, max_age, project_id=None, - is_allocated_reserve=False): + def quota_reserve( + self, + context, + resources, + deltas, + expire, + until_refresh, + max_age, + project_id=None, + is_allocated_reserve=False, + ): """Check quotas and create appropriate reservations.""" @abc.abstractmethod @@ -207,10 +234,16 @@ class Connection(metaclass=abc.ABCMeta): """Get requested attach_handle""" @abc.abstractmethod - def attach_handle_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, columns_to_join=None): + def attach_handle_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + columns_to_join=None, + ): """Get requested deployable by filters.""" @abc.abstractmethod @@ -235,10 +268,16 @@ class Connection(metaclass=abc.ABCMeta): """Get requested control path id""" @abc.abstractmethod - def control_path_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, columns_to_join=None): + def control_path_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + columns_to_join=None, + ): """Get requested deployable by filters.""" @abc.abstractmethod diff --git a/cyborg/db/migration.py b/cyborg/db/migration.py index 5c7f580d..d140afd3 100644 --- a/cyborg/db/migration.py +++ b/cyborg/db/migration.py @@ -26,8 +26,9 @@ def get_backend(): global _IMPL if not _IMPL: cfg.CONF.import_opt('backend', 'oslo_db.options', group='database') - _IMPL = driver.DriverManager("cyborg.database.migration_backend", - cfg.CONF.database.backend).driver + _IMPL = driver.DriverManager( + "cyborg.database.migration_backend", cfg.CONF.database.backend + ).driver return _IMPL diff --git a/cyborg/db/sqlalchemy/alembic/env.py b/cyborg/db/sqlalchemy/alembic/env.py index 4c6a9c48..5f2a1178 100644 --- a/cyborg/db/sqlalchemy/alembic/env.py +++ b/cyborg/db/sqlalchemy/alembic/env.py @@ -52,8 +52,9 @@ def run_migrations_online(): """ engine = enginefacade.writer.get_engine() with engine.connect() as connection: - context.configure(connection=connection, - target_metadata=target_metadata) + context.configure( + connection=connection, target_metadata=target_metadata + ) with context.begin_transaction(): context.run_migrations() diff --git a/cyborg/db/sqlalchemy/alembic/versions/22fb1af2d51e_placeholder.py b/cyborg/db/sqlalchemy/alembic/versions/22fb1af2d51e_placeholder.py index c3f4fb73..9ce2a736 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/22fb1af2d51e_placeholder.py +++ b/cyborg/db/sqlalchemy/alembic/versions/22fb1af2d51e_placeholder.py @@ -11,7 +11,6 @@ revision = '22fb1af2d51e' down_revision = '57539722e5cf' - def upgrade(): # ### commands auto generated by Alembic - please adjust! ### pass diff --git a/cyborg/db/sqlalchemy/alembic/versions/4cc1d79978fc_add_ssd_type.py b/cyborg/db/sqlalchemy/alembic/versions/4cc1d79978fc_add_ssd_type.py index 7cf789b9..d2660b6f 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/4cc1d79978fc_add_ssd_type.py +++ b/cyborg/db/sqlalchemy/alembic/versions/4cc1d79978fc_add_ssd_type.py @@ -15,8 +15,9 @@ down_revision = '899cead40bc9' def upgrade(): - new_device_type = sa.Enum('GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', 'SSD', - name='device_type') - op.alter_column('devices', 'type', - existing_type=new_device_type, - nullable=False) + new_device_type = sa.Enum( + 'GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', 'SSD', name='device_type' + ) + op.alter_column( + 'devices', 'type', existing_type=new_device_type, nullable=False + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/57539722e5cf_placeholder.py b/cyborg/db/sqlalchemy/alembic/versions/57539722e5cf_placeholder.py index bef4840e..c1c062c5 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/57539722e5cf_placeholder.py +++ b/cyborg/db/sqlalchemy/alembic/versions/57539722e5cf_placeholder.py @@ -11,7 +11,6 @@ revision = '57539722e5cf' down_revision = 'c1b5abada09c' - def upgrade(): # ### commands auto generated by Alembic - please adjust! ### pass diff --git a/cyborg/db/sqlalchemy/alembic/versions/589ff20545b7_add_aichip_type.py b/cyborg/db/sqlalchemy/alembic/versions/589ff20545b7_add_aichip_type.py index 30e3ccc9..d9880d4e 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/589ff20545b7_add_aichip_type.py +++ b/cyborg/db/sqlalchemy/alembic/versions/589ff20545b7_add_aichip_type.py @@ -15,8 +15,7 @@ down_revision = 'ede4e3f1a232' def upgrade(): - new_device_type = sa.Enum('GPU', 'FPGA', 'AICHIP', - name='device_type') - op.alter_column('devices', 'type', - existing_type=new_device_type, - nullable=False) + new_device_type = sa.Enum('GPU', 'FPGA', 'AICHIP', name='device_type') + op.alter_column( + 'devices', 'type', existing_type=new_device_type, nullable=False + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/60d8ac91fd20_add_description_field_to_dps.py b/cyborg/db/sqlalchemy/alembic/versions/60d8ac91fd20_add_description_field_to_dps.py index f5d4611a..be0d96f9 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/60d8ac91fd20_add_description_field_to_dps.py +++ b/cyborg/db/sqlalchemy/alembic/versions/60d8ac91fd20_add_description_field_to_dps.py @@ -15,5 +15,6 @@ down_revision = '7a4fd0fc3f8c' def upgrade(): - op.add_column('device_profiles', sa.Column('description', - sa.Text(), nullable=True)) + op.add_column( + 'device_profiles', sa.Column('description', sa.Text(), nullable=True) + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/62bcf2610c5d_placeholder.py b/cyborg/db/sqlalchemy/alembic/versions/62bcf2610c5d_placeholder.py index ea3100bc..d8318b3c 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/62bcf2610c5d_placeholder.py +++ b/cyborg/db/sqlalchemy/alembic/versions/62bcf2610c5d_placeholder.py @@ -11,7 +11,6 @@ revision = '62bcf2610c5d' down_revision = '7b696fd94949' - def upgrade(): # ### commands auto generated by Alembic - please adjust! ### pass diff --git a/cyborg/db/sqlalchemy/alembic/versions/6c77bd6afea5_add_device_status.py b/cyborg/db/sqlalchemy/alembic/versions/6c77bd6afea5_add_device_status.py index 6267c07d..f5163e8b 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/6c77bd6afea5_add_device_status.py +++ b/cyborg/db/sqlalchemy/alembic/versions/6c77bd6afea5_add_device_status.py @@ -14,8 +14,11 @@ revision = '6c77bd6afea5' down_revision = '4cc1d79978fc' - def upgrade(): - new_column = sa.Column('status', sa.Enum('enabled', 'maintaining'), - nullable=False, default='enabled') + new_column = sa.Column( + 'status', + sa.Enum('enabled', 'maintaining'), + nullable=False, + default='enabled', + ) op.add_column('devices', new_column) diff --git a/cyborg/db/sqlalchemy/alembic/versions/7a4fd0fc3f8c_placeholder.py b/cyborg/db/sqlalchemy/alembic/versions/7a4fd0fc3f8c_placeholder.py index 22e55d73..a825920a 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/7a4fd0fc3f8c_placeholder.py +++ b/cyborg/db/sqlalchemy/alembic/versions/7a4fd0fc3f8c_placeholder.py @@ -11,7 +11,6 @@ revision = '7a4fd0fc3f8c' down_revision = '62bcf2610c5d' - def upgrade(): # ### commands auto generated by Alembic - please adjust! ### pass diff --git a/cyborg/db/sqlalchemy/alembic/versions/7b696fd94949_placeholder.py b/cyborg/db/sqlalchemy/alembic/versions/7b696fd94949_placeholder.py index d3199052..dc1ac65d 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/7b696fd94949_placeholder.py +++ b/cyborg/db/sqlalchemy/alembic/versions/7b696fd94949_placeholder.py @@ -11,7 +11,6 @@ revision = '7b696fd94949' down_revision = '22fb1af2d51e' - def upgrade(): # ### commands auto generated by Alembic - please adjust! ### pass diff --git a/cyborg/db/sqlalchemy/alembic/versions/7e6f1f107f2b_add_qat_type.py b/cyborg/db/sqlalchemy/alembic/versions/7e6f1f107f2b_add_qat_type.py index 6ed686ba..5ed309ad 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/7e6f1f107f2b_add_qat_type.py +++ b/cyborg/db/sqlalchemy/alembic/versions/7e6f1f107f2b_add_qat_type.py @@ -15,8 +15,9 @@ down_revision = '60d8ac91fd20' def upgrade(): - new_device_type = sa.Enum('GPU', 'FPGA', 'AICHIP', 'QAT', - name='device_type') - op.alter_column('devices', 'type', - existing_type=new_device_type, - nullable=False) + new_device_type = sa.Enum( + 'GPU', 'FPGA', 'AICHIP', 'QAT', name='device_type' + ) + op.alter_column( + 'devices', 'type', existing_type=new_device_type, nullable=False + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/899cead40bc9_add_nic_type.py b/cyborg/db/sqlalchemy/alembic/versions/899cead40bc9_add_nic_type.py index 929b7066..f2dd19cb 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/899cead40bc9_add_nic_type.py +++ b/cyborg/db/sqlalchemy/alembic/versions/899cead40bc9_add_nic_type.py @@ -15,8 +15,9 @@ down_revision = '7e6f1f107f2b' def upgrade(): - new_device_type = sa.Enum('GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', - name='device_type') - op.alter_column('devices', 'type', - existing_type=new_device_type, - nullable=False) + new_device_type = sa.Enum( + 'GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', name='device_type' + ) + op.alter_column( + 'devices', 'type', existing_type=new_device_type, nullable=False + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/c1b5abada09c_update_for_nova_integ.py b/cyborg/db/sqlalchemy/alembic/versions/c1b5abada09c_update_for_nova_integ.py index 5066a546..fc1c63c9 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/c1b5abada09c_update_for_nova_integ.py +++ b/cyborg/db/sqlalchemy/alembic/versions/c1b5abada09c_update_for_nova_integ.py @@ -33,55 +33,76 @@ def upgrade(): # Update Deployables op.add_column( 'deployables', - sa.Column('rp_uuid', sa.String(length=36), nullable=True)) + sa.Column('rp_uuid', sa.String(length=36), nullable=True), + ) op.add_column( 'deployables', - sa.Column('driver_name', sa.String(length=100), nullable=True)) + sa.Column('driver_name', sa.String(length=100), nullable=True), + ) op.add_column( 'deployables', - sa.Column('bitstream_id', sa.String(length=36), nullable=True)) + sa.Column('bitstream_id', sa.String(length=36), nullable=True), + ) # Update ExtARQ table op.add_column( 'extended_accelerator_requests', - sa.Column('device_profile_group_id', sa.Integer(), nullable=False)) + sa.Column('device_profile_group_id', sa.Integer(), nullable=False), + ) op.add_column( 'extended_accelerator_requests', - sa.Column('instance_uuid', sa.String(length=36), - nullable=True)) - op.create_index('extArqs_instance_uuid_idx', # index name - 'extended_accelerator_requests', # table name - ['instance_uuid'] # columns on which index is defined - ) - op.drop_index('extArqs_device_instance_uuid_idx', # index name - 'extended_accelerator_requests', # table name - ) + sa.Column('instance_uuid', sa.String(length=36), nullable=True), + ) + op.create_index( + 'extArqs_instance_uuid_idx', # index name + 'extended_accelerator_requests', # table name + ['instance_uuid'], # columns on which index is defined + ) + op.drop_index( + 'extArqs_device_instance_uuid_idx', # index name + 'extended_accelerator_requests', # table name + ) op.drop_column('extended_accelerator_requests', 'device_instance_uuid') # Add more valid states for 'state' field - ns = sa.Enum(constants.ARQ_INITIAL, - constants.ARQ_BIND_STARTED, - constants.ARQ_BOUND, - constants.ARQ_UNBOUND, - constants.ARQ_BIND_FAILED, - constants.ARQ_DELETING, name='state') + ns = sa.Enum( + constants.ARQ_INITIAL, + constants.ARQ_BIND_STARTED, + constants.ARQ_BOUND, + constants.ARQ_UNBOUND, + constants.ARQ_BIND_FAILED, + constants.ARQ_DELETING, + name='state', + ) op.alter_column( - 'extended_accelerator_requests', 'state', - existing_type=ns, nullable=False, default=constants.ARQ_INITIAL) + 'extended_accelerator_requests', + 'state', + existing_type=ns, + nullable=False, + default=constants.ARQ_INITIAL, + ) # update attach type fields - new_attach_type = sa.Enum(constants.AH_TYPE_PCI, - constants.AH_TYPE_MDEV, - constants.AH_TYPE_TEST_PCI, - name='attach_type') - op.alter_column('attach_handles', 'attach_type', - existing_type=new_attach_type, - nullable=False) + new_attach_type = sa.Enum( + constants.AH_TYPE_PCI, + constants.AH_TYPE_MDEV, + constants.AH_TYPE_TEST_PCI, + name='attach_type', + ) + op.alter_column( + 'attach_handles', + 'attach_type', + existing_type=new_attach_type, + nullable=False, + ) # Update device_profiles table to make name and uuid unique separately. # Previous schema made the pair unique. - op.create_unique_constraint('uniq_device_profiles0uuid', - 'device_profiles', ['uuid']) - op.create_unique_constraint('uniq_device_profiles0name', - 'device_profiles', ['name']) - op.drop_constraint('uniq_device_profiles0uuid0name', - 'device_profiles', type_='unique') + op.create_unique_constraint( + 'uniq_device_profiles0uuid', 'device_profiles', ['uuid'] + ) + op.create_unique_constraint( + 'uniq_device_profiles0name', 'device_profiles', ['name'] + ) + op.drop_constraint( + 'uniq_device_profiles0uuid0name', 'device_profiles', type_='unique' + ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/d6f033d8fa5b_add_quota_related_tables.py b/cyborg/db/sqlalchemy/alembic/versions/d6f033d8fa5b_add_quota_related_tables.py index 0301fa53..97155e65 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/d6f033d8fa5b_add_quota_related_tables.py +++ b/cyborg/db/sqlalchemy/alembic/versions/d6f033d8fa5b_add_quota_related_tables.py @@ -39,12 +39,17 @@ def upgrade(): sa.Column('in_use', sa.Integer(), nullable=False), sa.Column('reserved', sa.Integer(), nullable=False), sa.Column('until_refresh', sa.Integer(), nullable=True), - sa.PrimaryKeyConstraint('id') + sa.PrimaryKeyConstraint('id'), + ) + op.create_index( + 'ix_quota_usages_project_id', + 'quota_usages', + ['project_id'], + unique=False, + ) + op.create_index( + 'ix_quota_usages_user_id', 'quota_usages', ['user_id'], unique=False ) - op.create_index('ix_quota_usages_project_id', 'quota_usages', - ['project_id'], unique=False) - op.create_index('ix_quota_usages_user_id', 'quota_usages', ['user_id'], - unique=False) op.create_table( 'reservations', @@ -58,13 +63,22 @@ def upgrade(): sa.Column('resource', sa.String(length=255), nullable=True), sa.Column('delta', sa.Integer(), nullable=False), sa.Column('expire', sa.DateTime(), nullable=True), - sa.ForeignKeyConstraint(['usage_id'], ['quota_usages.id'], ), - sa.PrimaryKeyConstraint('id') + sa.ForeignKeyConstraint( + ['usage_id'], + ['quota_usages.id'], + ), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index( + 'ix_reservations_project_id', + 'reservations', + ['project_id'], + unique=False, + ) + op.create_index( + 'ix_reservations_user_id', 'reservations', ['user_id'], unique=False + ) + op.create_index( + 'reservations_uuid_idx', 'reservations', ['uuid'], unique=False ) - op.create_index('ix_reservations_project_id', 'reservations', - ['project_id'], unique=False) - op.create_index('ix_reservations_user_id', 'reservations', ['user_id'], - unique=False) - op.create_index('reservations_uuid_idx', 'reservations', ['uuid'], - unique=False) # ### end Alembic commands ### diff --git a/cyborg/db/sqlalchemy/alembic/versions/ede4e3f1a232_new_db_schema.py b/cyborg/db/sqlalchemy/alembic/versions/ede4e3f1a232_new_db_schema.py index 5827f73d..1c0407b3 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/ede4e3f1a232_new_db_schema.py +++ b/cyborg/db/sqlalchemy/alembic/versions/ede4e3f1a232_new_db_schema.py @@ -42,7 +42,7 @@ def upgrade(): sa.Column('hostname', sa.String(length=255), nullable=False), sa.PrimaryKeyConstraint('id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -51,23 +51,32 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False, unique=True), - sa.Column('parent_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete='CASCADE'), - nullable=True), - sa.Column('root_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete='CASCADE'), - nullable=True), + sa.Column( + 'parent_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete='CASCADE'), + nullable=True, + ), + sa.Column( + 'root_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete='CASCADE'), + nullable=True, + ), sa.Column('name', sa.String(length=255), nullable=False), sa.Column('num_accelerators', sa.Integer(), nullable=False), - sa.Column('device_id', sa.Integer(), - sa.ForeignKey('devices.id', ondelete="RESTRICT"), - nullable=False), + sa.Column( + 'device_id', + sa.Integer(), + sa.ForeignKey('devices.id', ondelete="RESTRICT"), + nullable=False, + ), sa.PrimaryKeyConstraint('id'), sa.Index('deployables_parent_id_idx', 'parent_id'), sa.Index('deployables_root_id_idx', 'root_id'), sa.Index('deployables_device_id_idx', 'device_id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -76,14 +85,18 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False, unique=True), - sa.Column('deployable_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=False, index=True), + sa.Column( + 'deployable_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=False, + index=True, + ), sa.Column('key', sa.Text(), nullable=False), sa.Column('value', sa.Text(), nullable=False), sa.PrimaryKeyConstraint('id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -92,14 +105,18 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False, unique=True), - sa.Column('device_id', sa.Integer(), - sa.ForeignKey('devices.id', ondelete="RESTRICT"), - nullable=False, index=True), + sa.Column( + 'device_id', + sa.Integer(), + sa.ForeignKey('devices.id', ondelete="RESTRICT"), + nullable=False, + index=True, + ), sa.Column('cpid_type', cpid_type, nullable=False), sa.Column('cpid_info', sa.String(length=255), nullable=False), sa.PrimaryKeyConstraint('id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -108,12 +125,18 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False, unique=True), - sa.Column('deployable_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=False), - sa.Column('cpid_id', sa.Integer(), - sa.ForeignKey('controlpath_ids.id', ondelete="RESTRICT"), - nullable=False), + sa.Column( + 'deployable_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=False, + ), + sa.Column( + 'cpid_id', + sa.Integer(), + sa.ForeignKey('controlpath_ids.id', ondelete="RESTRICT"), + nullable=False, + ), sa.Column('in_use', sa.Boolean(), default=False), sa.Column('attach_type', attach_type, nullable=False), sa.Column('attach_info', sa.String(length=255), nullable=False), @@ -121,7 +144,7 @@ def upgrade(): sa.Index('attach_handles_deployable_id_idx', 'deployable_id'), sa.Index('attach_handles_cpid_id_idx', 'cpid_id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -133,10 +156,11 @@ def upgrade(): sa.Column('name', sa.String(length=255), nullable=False), sa.Column('profile_json', sa.Text(), nullable=False), sa.PrimaryKeyConstraint('id'), - sa.UniqueConstraint('uuid', 'name', - name='uniq_device_profiles0uuid0name'), + sa.UniqueConstraint( + 'uuid', 'name', name='uniq_device_profiles0uuid0name' + ), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -149,21 +173,29 @@ def upgrade(): # set nullable=True but keep this field for further expansion. sa.Column('project_id', sa.String(length=255), nullable=True), sa.Column('state', state, nullable=False, default='Initial'), - sa.Column('device_profile_id', sa.Integer(), - sa.ForeignKey('device_profiles.id', ondelete="RESTRICT"), - nullable=False), + sa.Column( + 'device_profile_id', + sa.Integer(), + sa.ForeignKey('device_profiles.id', ondelete="RESTRICT"), + nullable=False, + ), sa.Column('hostname', sa.String(length=255), nullable=True), sa.Column('device_rp_uuid', sa.String(length=36), nullable=True), - sa.Column('device_instance_uuid', sa.String(length=36), - nullable=True), - sa.Column('attach_handle_id', sa.Integer(), - sa.ForeignKey('attach_handles.id', ondelete="RESTRICT"), - nullable=True), + sa.Column('device_instance_uuid', sa.String(length=36), nullable=True), + sa.Column( + 'attach_handle_id', + sa.Integer(), + sa.ForeignKey('attach_handles.id', ondelete="RESTRICT"), + nullable=True, + ), # Cyborg Private Fields begin here. sa.Column('substate', substate, nullable=False, default='Initial'), - sa.Column('deployable_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=True), + sa.Column( + 'deployable_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=True, + ), sa.PrimaryKeyConstraint('id'), sa.Index('extArqs_project_id_idx', 'project_id'), sa.Index('extArqs_device_profile_id_idx', 'device_profile_id'), @@ -172,5 +204,5 @@ def upgrade(): sa.Index('extArqs_attach_handle_id_idx', 'attach_handle_id'), sa.Index('extArqs_deployable_id_idx', 'deployable_id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) diff --git a/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py b/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py index 9f30be74..0c785228 100644 --- a/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py +++ b/cyborg/db/sqlalchemy/alembic/versions/f50980397351_initial_migration.py @@ -47,7 +47,7 @@ def upgrade(): sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('uuid', name='uniq_accelerators0uuid'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -57,10 +57,18 @@ def upgrade(): sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False), sa.Column('name', sa.String(length=255), nullable=False), - sa.Column('parent_uuid', sa.String(length=36), - sa.ForeignKey('deployables.uuid'), nullable=True), - sa.Column('root_uuid', sa.String(length=36), - sa.ForeignKey('deployables.uuid'), nullable=True), + sa.Column( + 'parent_uuid', + sa.String(length=36), + sa.ForeignKey('deployables.uuid'), + nullable=True, + ), + sa.Column( + 'root_uuid', + sa.String(length=36), + sa.ForeignKey('deployables.uuid'), + nullable=True, + ), sa.Column('address', sa.Text(), nullable=False), sa.Column('host', sa.Text(), nullable=False), sa.Column('board', sa.Text(), nullable=False), @@ -71,15 +79,18 @@ def upgrade(): sa.Column('assignable', sa.Boolean(), nullable=False), sa.Column('instance_uuid', sa.String(length=36), nullable=True), sa.Column('availability', sa.Text(), nullable=False), - sa.Column('accelerator_id', sa.Integer(), - sa.ForeignKey('accelerators.id', ondelete="CASCADE"), - nullable=False), + sa.Column( + 'accelerator_id', + sa.Integer(), + sa.ForeignKey('accelerators.id', ondelete="CASCADE"), + nullable=False, + ), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('uuid', name='uniq_deployables0uuid'), sa.Index('deployables_parent_uuid_idx', 'parent_uuid'), sa.Index('deployables_root_uuid_idx', 'root_uuid'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) op.create_table( @@ -88,14 +99,17 @@ def upgrade(): sa.Column('updated_at', sa.DateTime(), nullable=True), sa.Column('id', sa.Integer(), nullable=False), sa.Column('uuid', sa.String(length=36), nullable=False), - sa.Column('deployable_id', sa.Integer(), - sa.ForeignKey('deployables.id', ondelete="CASCADE"), - nullable=False), + sa.Column( + 'deployable_id', + sa.Integer(), + sa.ForeignKey('deployables.id', ondelete="CASCADE"), + nullable=False, + ), sa.Column('key', sa.Text(), nullable=False), sa.Column('value', sa.Text(), nullable=False), sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('uuid', name='uniq_attributes0uuid'), sa.Index('attributes_deployable_id_idx', 'deployable_id'), mysql_ENGINE='InnoDB', - mysql_DEFAULT_CHARSET='UTF8' + mysql_DEFAULT_CHARSET='UTF8', ) diff --git a/cyborg/db/sqlalchemy/api.py b/cyborg/db/sqlalchemy/api.py index 9deccb5d..284b12bf 100644 --- a/cyborg/db/sqlalchemy/api.py +++ b/cyborg/db/sqlalchemy/api.py @@ -63,8 +63,7 @@ def model_query(context, model, *args, **kwargs): if kwargs.pop("project_only", False): kwargs["project_id"] = context.project_id - query = sqlalchemyutils.model_query( - model, context.session, args, **kwargs) + query = sqlalchemyutils.model_query(model, context.session, args, **kwargs) return query @@ -87,19 +86,27 @@ def add_identity_filter(query, value): raise exception.InvalidIdentity(identity=value) -def _paginate_query(context, model, query, limit=None, marker=None, - sort_key=None, sort_dir=None): +def _paginate_query( + context, + model, + query, + limit=None, + marker=None, + sort_key=None, + sort_dir=None, +): sort_keys = ['id'] if sort_key and sort_key not in sort_keys: sort_keys.insert(0, sort_key) try: - query = sqlalchemyutils.paginate_query(query, model, limit, sort_keys, - marker=marker, - sort_dir=sort_dir) + query = sqlalchemyutils.paginate_query( + query, model, limit, sort_keys, marker=marker, sort_dir=sort_dir + ) except db_exc.InvalidSortKey: raise exception.InvalidParameterValue( _('The sort_key value "%(key)s" is an invalid field for sorting') - % {'key': sort_key}) + % {'key': sort_key} + ) return query.all() @@ -126,44 +133,47 @@ class Connection(api.Connection): @main_context_manager.reader def attach_handle_get_by_uuid(self, context, uuid): - query = model_query( - context, - models.AttachHandle).filter_by(uuid=uuid) + query = model_query(context, models.AttachHandle).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='AttachHandle', - msg='with uuid=%s' % uuid) + resource='AttachHandle', msg='with uuid=%s' % uuid + ) @main_context_manager.reader def attach_handle_get_by_id(self, context, id): - query = model_query( - context, - models.AttachHandle).filter_by(id=id) + query = model_query(context, models.AttachHandle).filter_by(id=id) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='AttachHandle', - msg='with id=%s' % id) + resource='AttachHandle', msg='with id=%s' % id + ) @main_context_manager.reader def attach_handle_list_by_type(self, context, attach_type='PCI'): - query = model_query(context, models.AttachHandle). \ - filter_by(attach_type=attach_type) + query = model_query(context, models.AttachHandle).filter_by( + attach_type=attach_type + ) try: return query.all() except NoResultFound: raise exception.ResourceNotFound( - resource='AttachHandle', - msg='with type=%s' % attach_type) + resource='AttachHandle', msg='with type=%s' % attach_type + ) @main_context_manager.reader - def attach_handle_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, join_columns=None): + def attach_handle_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + join_columns=None, + ): """Return attach_handle that match all filters sorted by the given keys. Deleted attach_handle will be returned by default, unless there's a filter that says otherwise. @@ -178,12 +188,23 @@ class Connection(api.Connection): exact_match_filter_names = ['uuid', 'id', 'deployable_id', 'cpid_id'] # Filter the query - query_prefix = self._exact_filter(models.AttachHandle, query_prefix, - filters, exact_match_filter_names) + query_prefix = self._exact_filter( + models.AttachHandle, + query_prefix, + filters, + exact_match_filter_names, + ) if query_prefix is None: return [] - return _paginate_query(context, models.AttachHandle, query_prefix, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, + models.AttachHandle, + query_prefix, + limit, + marker, + sort_key, + sort_dir, + ) def _exact_filter(self, model, query, filters, legal_keys=None): """Applies exact match filtering to a deployable query. @@ -223,8 +244,9 @@ class Connection(api.Connection): filter_dict[key] = value # Apply simple exact matches if filter_dict: - query = query.filter(*[getattr(model, k) == v - for k, v in filter_dict.items()]) + query = query.filter( + *[getattr(model, k) == v for k, v in filter_dict.items()] + ) return query @main_context_manager.reader @@ -247,8 +269,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='AttachHandle', - msg='with uuid=%s' % uuid) + resource='AttachHandle', msg='with uuid=%s' % uuid + ) ref.update(values) return ref @@ -256,17 +278,16 @@ class Connection(api.Connection): @main_context_manager.writer def _do_allocate_attach_handle(self, context, deployable_id): """Atomically get a set of attach handles that match the query - and mark one of those as in_use. + and mark one of those as in_use. """ - query = model_query(context, models.AttachHandle). \ - filter_by(deployable_id=deployable_id, - in_use=False) + query = model_query(context, models.AttachHandle).filter_by( + deployable_id=deployable_id, in_use=False + ) values = {"in_use": True} ref = query.with_for_update().first() if not ref: msg = 'Matching deployable_id {}'.format(deployable_id) - raise exception.ResourceNotFound( - resource='AttachHandle', msg=msg) + raise exception.ResourceNotFound(resource='AttachHandle', msg=msg) ref.update(values) context.session.flush() return ref @@ -274,15 +295,13 @@ class Connection(api.Connection): def attach_handle_allocate(self, context, deployable_id): """Allocate an attach handle with given deployable. - To allocate is to get an unused resource and mark it as in_use. + To allocate is to get an unused resource and mark it as in_use. """ try: - ah = self._do_allocate_attach_handle( - context, deployable_id) + ah = self._do_allocate_attach_handle(context, deployable_id) except NoResultFound: msg = 'Matching deployable_id {}'.format(deployable_id) - raise exception.ResourceNotFound( - resource='AttachHandle', msg=msg) + raise exception.ResourceNotFound(resource='AttachHandle', msg=msg) return ah # NOTE: For deallocate, we use attach_handle_update() @@ -295,8 +314,8 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='AttachHandle', - msg='with uuid=%s' % uuid) + resource='AttachHandle', msg='with uuid=%s' % uuid + ) @main_context_manager.writer def control_path_create(self, context, values): @@ -315,21 +334,25 @@ class Connection(api.Connection): @main_context_manager.reader def control_path_get_by_uuid(self, context, uuid): - query = model_query( - context, - models.ControlpathID).filter_by(uuid=uuid) + query = model_query(context, models.ControlpathID).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='ControlpathID', - msg='with uuid=%s' % uuid) + resource='ControlpathID', msg='with uuid=%s' % uuid + ) @main_context_manager.reader - def control_path_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, join_columns=None): + def control_path_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + join_columns=None, + ): """Return attach_handle that match all filters sorted by the given keys. Deleted attach_handle will be returned by default, unless there's a filter that says otherwise. @@ -341,16 +364,32 @@ class Connection(api.Connection): query_prefix = model_query(context, models.ControlpathID) filters = copy.deepcopy(filters) - exact_match_filter_names = ['uuid', 'id', 'device_id', 'cpid_info', - 'cpid_type'] + exact_match_filter_names = [ + 'uuid', + 'id', + 'device_id', + 'cpid_info', + 'cpid_type', + ] # Filter the query - query_prefix = self._exact_filter(models.ControlpathID, query_prefix, - filters, exact_match_filter_names) + query_prefix = self._exact_filter( + models.ControlpathID, + query_prefix, + filters, + exact_match_filter_names, + ) if query_prefix is None: return [] - return _paginate_query(context, models.ControlpathID, query_prefix, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, + models.ControlpathID, + query_prefix, + limit, + marker, + sort_key, + sort_dir, + ) @main_context_manager.reader def control_path_list(self, context): @@ -372,8 +411,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='ControlpathID', - msg='with uuid=%s' % uuid) + resource='ControlpathID', msg='with uuid=%s' % uuid + ) ref.update(values) return ref @@ -403,33 +442,35 @@ class Connection(api.Connection): @main_context_manager.reader def device_get(self, context, uuid): - query = model_query( - context, - models.Device).filter_by(uuid=uuid) + query = model_query(context, models.Device).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device', - msg='with uuid=%s' % uuid) + resource='Device', msg='with uuid=%s' % uuid + ) @main_context_manager.reader def device_get_by_id(self, context, id): - query = model_query( - context, - models.Device).filter_by(id=id) + query = model_query(context, models.Device).filter_by(id=id) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device', - msg='with id=%s' % id) + resource='Device', msg='with id=%s' % id + ) @main_context_manager.reader - def device_list_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, join_columns=None): + def device_list_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + join_columns=None, + ): """Return devices that match all filters sorted by the given keys.""" if limit == 0: @@ -438,23 +479,39 @@ class Connection(api.Connection): query_prefix = model_query(context, models.Device) filters = copy.deepcopy(filters) - exact_match_filter_names = ['uuid', 'id', 'type', - 'vendor', 'model', 'hostname'] + exact_match_filter_names = [ + 'uuid', + 'id', + 'type', + 'vendor', + 'model', + 'hostname', + ] # Filter the query - query_prefix = self._exact_filter(models.Device, query_prefix, - filters, exact_match_filter_names) + query_prefix = self._exact_filter( + models.Device, query_prefix, filters, exact_match_filter_names + ) if query_prefix is None: return [] - return _paginate_query(context, models.Device, query_prefix, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, + models.Device, + query_prefix, + limit, + marker, + sort_key, + sort_dir, + ) @main_context_manager.reader - def device_list(self, context, limit=None, marker=None, sort_key=None, - sort_dir=None): + def device_list( + self, context, limit=None, marker=None, sort_key=None, sort_dir=None + ): query = model_query(context, models.Device) - return _paginate_query(context, models.Device, query, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, models.Device, query, limit, marker, sort_key, sort_dir + ) def device_update(self, context, uuid, values): if 'uuid' in values: @@ -476,8 +533,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device', - msg='with uuid=%s' % uuid) + resource='Device', msg='with uuid=%s' % uuid + ) ref.update(values) return ref @@ -490,8 +547,8 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='Device', - msg='with uuid=%s' % uuid) + resource='Device', msg='with uuid=%s' % uuid + ) @main_context_manager.writer def device_profile_create(self, context, values): @@ -508,55 +565,57 @@ class Connection(api.Connection): # mysql duplicate key error changed as reference link below: # https://review.opendev.org/c/openstack/oslo.db/+/792124 LOG.info('Duplicate columns are: ', e.columns) - columns = [column.split('0')[1] if 'uniq_' in column else - column for column in e.columns] + columns = [ + column.split('0')[1] if 'uniq_' in column else column + for column in e.columns + ] if 'name' in columns: - raise exception.DuplicateDeviceProfileName( - name=values['name']) + raise exception.DuplicateDeviceProfileName(name=values['name']) else: - raise exception.DeviceProfileAlreadyExists( - uuid=values['uuid']) + raise exception.DeviceProfileAlreadyExists(uuid=values['uuid']) return device_profile @main_context_manager.reader def device_profile_get_by_uuid(self, context, uuid): - query = model_query( - context, - models.DeviceProfile).filter_by(uuid=uuid) + query = model_query(context, models.DeviceProfile).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with uuid=%s' % uuid) + resource='Device Profile', msg='with uuid=%s' % uuid + ) @main_context_manager.reader def device_profile_get_by_id(self, context, id): - query = model_query( - context, - models.DeviceProfile).filter_by(id=id) + query = model_query(context, models.DeviceProfile).filter_by(id=id) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with id=%s' % id) + resource='Device Profile', msg='with id=%s' % id + ) @main_context_manager.reader def device_profile_get(self, context, name): - query = model_query( - context, models.DeviceProfile).filter_by(name=name) + query = model_query(context, models.DeviceProfile).filter_by(name=name) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with name=%s' % name) + resource='Device Profile', msg='with name=%s' % name + ) @main_context_manager.reader def device_profile_list_by_filters( - self, context, filters, sort_key='created_at', sort_dir='desc', - limit=None, marker=None, join_columns=None): + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + join_columns=None, + ): if limit == 0: return [] @@ -566,12 +625,23 @@ class Connection(api.Connection): exact_match_filter_names = ['uuid', 'id', 'name'] # Filter the query - query_prefix = self._exact_filter(models.DeviceProfile, query_prefix, - filters, exact_match_filter_names) + query_prefix = self._exact_filter( + models.DeviceProfile, + query_prefix, + filters, + exact_match_filter_names, + ) if query_prefix is None: return [] - return _paginate_query(context, models.DeviceProfile, query_prefix, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, + models.DeviceProfile, + query_prefix, + limit, + marker, + sort_key, + sort_dir, + ) @main_context_manager.reader def device_profile_list(self, context): @@ -598,8 +668,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with uuid=%s' % uuid) + resource='Device Profile', msg='with uuid=%s' % uuid + ) ref.update(values) return ref @@ -612,8 +682,8 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='Device Profile', - msg='with uuid=%s' % uuid) + resource='Device Profile', msg='with uuid=%s' % uuid + ) @main_context_manager.writer def deployable_create(self, context, values): @@ -633,28 +703,27 @@ class Connection(api.Connection): @main_context_manager.reader def deployable_get(self, context, uuid): - query = model_query( - context, - models.Deployable).filter_by(uuid=uuid) + query = model_query(context, models.Deployable).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Deployable', - msg='with uuid=%s' % uuid) + resource='Deployable', msg='with uuid=%s' % uuid + ) @main_context_manager.reader def deployable_get_by_rp_uuid(self, context, rp_uuid): """Get a deployable by resource provider UUID.""" - query = model_query( - context, - models.Deployable).filter_by(rp_uuid=rp_uuid) + query = model_query(context, models.Deployable).filter_by( + rp_uuid=rp_uuid + ) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( resource='Deployable', - msg='with resource provider uuid=%s' % rp_uuid) + msg='with resource provider uuid=%s' % rp_uuid, + ) @main_context_manager.reader def deployable_list(self, context): @@ -682,8 +751,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='Deployable', - msg='with uuid=%s' % uuid) + resource='Deployable', msg='with uuid=%s' % uuid + ) ref.update(values) return ref @@ -697,27 +766,44 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='Deployable', - msg='with uuid=%s' % uuid) + resource='Deployable', msg='with uuid=%s' % uuid + ) - def deployable_get_by_filters(self, context, - filters, sort_key='created_at', - sort_dir='desc', limit=None, - marker=None, join_columns=None): + def deployable_get_by_filters( + self, + context, + filters, + sort_key='created_at', + sort_dir='desc', + limit=None, + marker=None, + join_columns=None, + ): """Return list of deployables matching all filters sorted by the sort_key. See deployable_get_by_filters_sort for more information. """ - return self.deployable_get_by_filters_sort(context, filters, - limit=limit, marker=marker, - join_columns=join_columns, - sort_key=sort_key, - sort_dir=sort_dir) + return self.deployable_get_by_filters_sort( + context, + filters, + limit=limit, + marker=marker, + join_columns=join_columns, + sort_key=sort_key, + sort_dir=sort_dir, + ) @main_context_manager.reader - def deployable_get_by_filters_sort(self, context, filters, limit=None, - marker=None, join_columns=None, - sort_key=None, sort_dir=None): + def deployable_get_by_filters_sort( + self, + context, + filters, + limit=None, + marker=None, + join_columns=None, + sort_key=None, + sort_dir=None, + ): """Return deployables that match all filters sorted by the given keys. Deleted deployables will be returned by default, unless there's a filter that says otherwise. @@ -728,19 +814,34 @@ class Connection(api.Connection): query_prefix = model_query(context, models.Deployable) filters = copy.deepcopy(filters) - exact_match_filter_names = ['id', 'uuid', 'name', - 'parent_id', 'root_id', - 'num_accelerators', 'device_id', - 'driver_name', 'rp_uuid', 'bitstream_id'] + exact_match_filter_names = [ + 'id', + 'uuid', + 'name', + 'parent_id', + 'root_id', + 'num_accelerators', + 'device_id', + 'driver_name', + 'rp_uuid', + 'bitstream_id', + ] # Filter the query - query_prefix = self._exact_filter(models.Deployable, query_prefix, - filters, - exact_match_filter_names) + query_prefix = self._exact_filter( + models.Deployable, query_prefix, filters, exact_match_filter_names + ) if query_prefix is None: return [] - return _paginate_query(context, models.Deployable, query_prefix, - limit, marker, sort_key, sort_dir) + return _paginate_query( + context, + models.Deployable, + query_prefix, + limit, + marker, + sort_key, + sort_dir, + ) @main_context_manager.writer def attribute_create(self, context, values): @@ -755,39 +856,36 @@ class Connection(api.Connection): context.session.add(attribute) context.session.flush() except db_exc.DBDuplicateEntry: - raise exception.AttributeAlreadyExists( - uuid=values['uuid']) + raise exception.AttributeAlreadyExists(uuid=values['uuid']) return attribute @main_context_manager.reader def attribute_get(self, context, uuid): - query = model_query( - context, - models.Attribute).filter_by(uuid=uuid) + query = model_query(context, models.Attribute).filter_by(uuid=uuid) try: return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='Attribute', - msg='with uuid=%s' % uuid) + resource='Attribute', msg='with uuid=%s' % uuid + ) @main_context_manager.reader def attribute_get_by_deployable_id(self, context, deployable_id): - query = model_query( - context, - models.Attribute).filter_by(deployable_id=deployable_id) + query = model_query(context, models.Attribute).filter_by( + deployable_id=deployable_id + ) return query.all() @main_context_manager.reader def attribute_get_by_filter(self, context, filters): - """Return attributes that matches the filters - """ + """Return attributes that matches the filters""" query_prefix = model_query(context, models.Attribute) exact_match_filter_names = ['deployable_id', 'key'] # Filter the query - query_prefix = self._exact_filter(models.Attribute, query_prefix, - filters, exact_match_filter_names) + query_prefix = self._exact_filter( + models.Attribute, query_prefix, filters, exact_match_filter_names + ) if query_prefix is None: return [] @@ -820,8 +918,8 @@ class Connection(api.Connection): ref = query.with_for_update().one() except NoResultFound: raise exception.ResourceNotFound( - resource='Attribute', - msg='with uuid=%s' % uuid) + resource='Attribute', msg='with uuid=%s' % uuid + ) ref.update(update_fields) return ref @@ -833,8 +931,8 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='Attribute', - msg='with uuid=%s' % uuid) + resource='Attribute', msg='with uuid=%s' % uuid + ) @main_context_manager.writer def extarq_create(self, context, values): @@ -846,8 +944,9 @@ class Connection(api.Connection): if values.get('device_profile_id'): pass # Already have the devprof id, so nothing to do elif values.get('device_profile_name'): - devprof = self.device_profile_get(context, - values['device_profile_name']) + devprof = self.device_profile_get( + context, values['device_profile_name'] + ) values['device_profile_id'] = devprof['id'] else: raise exception.DeviceProfileNameNeeded() @@ -870,8 +969,8 @@ class Connection(api.Connection): count = query.delete() if count != 1: raise exception.ResourceNotFound( - resource='ExtArq', - msg='with uuid=%s' % uuid) + resource='ExtArq', msg='with uuid=%s' % uuid + ) def extarq_update(self, context, uuid, values, state_scope=None): if 'uuid' in values and values['uuid'] != uuid: @@ -883,18 +982,17 @@ class Connection(api.Connection): @main_context_manager.writer def _do_update_extarq(self, context, uuid, values, state_scope=None): query = model_query(context, models.ExtArq) - query = query_update = query.filter_by( - uuid=uuid).with_for_update() + query = query_update = query.filter_by(uuid=uuid).with_for_update() if type(state_scope) is list: query_update = query_update.filter( - models.ExtArq.state.in_(state_scope)) + models.ExtArq.state.in_(state_scope) + ) try: - query_update.update( - values, synchronize_session="fetch") + query_update.update(values, synchronize_session="fetch") except NoResultFound: raise exception.ResourceNotFound( - resource='ExtArq', - msg='with uuid=%s' % uuid) + resource='ExtArq', msg='with uuid=%s' % uuid + ) ref = query.first() return ref @@ -902,16 +1000,13 @@ class Connection(api.Connection): def extarq_list(self, context, uuid_range=None): query = model_query(context, models.ExtArq) if type(uuid_range) is list: - query = query.filter( - models.ExtArq.uuid.in_(uuid_range)) + query = query.filter(models.ExtArq.uuid.in_(uuid_range)) return _paginate_query(context, models.ExtArq, query) @oslo_db_api.retry_on_deadlock @main_context_manager.writer def extarq_get(self, context, uuid, lock=False): - query = model_query( - context, - models.ExtArq).filter_by(uuid=uuid) + query = model_query(context, models.ExtArq).filter_by(uuid=uuid) # NOTE we will support aync bind, so get query by lock if lock: query = query.with_for_update() @@ -919,24 +1014,34 @@ class Connection(api.Connection): return query.one() except NoResultFound: raise exception.ResourceNotFound( - resource='ExtArq', - msg='with uuid=%s' % uuid) + resource='ExtArq', msg='with uuid=%s' % uuid + ) @main_context_manager.writer def _get_quota_usages(self, context, project_id, resources=None): # Broken out for testability - query = model_query(context, models.QuotaUsage,).filter_by( - project_id=project_id) + query = model_query( + context, + models.QuotaUsage, + ).filter_by(project_id=project_id) if resources: - query = query.filter(models.QuotaUsage.resource.in_( - list(resources))) - rows = query.order_by(models.QuotaUsage.id.asc()). \ - with_for_update().all() + query = query.filter( + models.QuotaUsage.resource.in_(list(resources)) + ) + rows = ( + query.order_by(models.QuotaUsage.id.asc()).with_for_update().all() + ) return {row.resource: row for row in rows} - def _quota_usage_create(self, project_id, resource, until_refresh, - in_use, reserved, session=None): - + def _quota_usage_create( + self, + project_id, + resource, + until_refresh, + in_use, + reserved, + session=None, + ): quota_usage_ref = models.QuotaUsage() quota_usage_ref.project_id = project_id quota_usage_ref.resource = resource @@ -947,8 +1052,9 @@ class Connection(api.Connection): return quota_usage_ref - def _reservation_create(self, uuid, usage, project_id, resource, delta, - expire, session=None): + def _reservation_create( + self, uuid, usage, project_id, resource, delta, expire, session=None + ): usage_id = usage['id'] if usage else None reservation_ref = models.Reservation() reservation_ref.uuid = uuid @@ -963,30 +1069,43 @@ class Connection(api.Connection): def _get_reservation_resources(self, context, reservation_ids): """Return the relevant resources by reservations.""" - reservations = model_query(context, models.Reservation). \ - options(load_only('resource')). \ - filter(models.Reservation.uuid.in_(reservation_ids)). \ - all() + reservations = ( + model_query(context, models.Reservation) + .options(load_only('resource')) + .filter(models.Reservation.uuid.in_(reservation_ids)) + .all() + ) return {r.resource for r in reservations} def _quota_reservations(self, session, context, reservations): """Return the relevant reservations.""" # Get the listed reservations - return model_query(context, models.Reservation). \ - filter(models.Reservation.uuid.in_(reservations)). \ - with_for_update(). \ - all() + return ( + model_query(context, models.Reservation) + .filter(models.Reservation.uuid.in_(reservations)) + .with_for_update() + .all() + ) @main_context_manager.writer - def quota_reserve(self, context, resources, deltas, expire, - until_refresh, max_age, project_id=None, - is_allocated_reserve=False): + def quota_reserve( + self, + context, + resources, + deltas, + expire, + until_refresh, + max_age, + project_id=None, + is_allocated_reserve=False, + ): """Create reservation record in DB according to params""" if project_id is None: project_id = context.project_id - usages = self._get_quota_usages(context, project_id, - resources=deltas.keys()) + usages = self._get_quota_usages( + context, project_id, resources=deltas.keys() + ) work = set(deltas.keys()) while work: resource = work.pop() @@ -997,8 +1116,13 @@ class Connection(api.Connection): # of resource if resource not in usages: usages[resource] = self._quota_usage_create( - project_id, resource, until_refresh or None, - in_use=0, reserved=0, session=context.session) + project_id, + resource, + until_refresh or None, + in_use=0, + reserved=0, + session=context.session, + ) refresh = True elif usages[resource].in_use < 0: # Negative in_use count indicates a desync, so try to @@ -1008,17 +1132,22 @@ class Connection(api.Connection): usages[resource].until_refresh -= 1 if usages[resource].until_refresh <= 0: refresh = True - elif max_age and usages[resource].updated_at is not None and ( - (timeutils.utcnow() - - usages[resource].updated_at).total_seconds() >= - max_age): + elif ( + max_age + and usages[resource].updated_at is not None + and ( + ( + timeutils.utcnow() - usages[resource].updated_at + ).total_seconds() + >= max_age + ) + ): refresh = True # refresh the usage if refresh: # Grab the sync routine - updates = self._sync_acc_res(context, - resource, project_id) + updates = self._sync_acc_res(context, resource, project_id) for res, in_use in updates.items(): # Make sure we have a destination for the usage! if res not in usages: @@ -1028,7 +1157,7 @@ class Connection(api.Connection): until_refresh or None, in_use=0, reserved=0, - session=context.session + session=context.session, ) # Update the usage @@ -1047,26 +1176,39 @@ class Connection(api.Connection): # for. We don't check, because this is # a best-effort mechanism. - unders = [r for r, delta in deltas.items() - if delta < 0 and delta + usages[r].in_use < 0] + unders = [ + r + for r, delta in deltas.items() + if delta < 0 and delta + usages[r].in_use < 0 + ] reservations = [] for resource, delta in deltas.items(): usage = usages[resource] reservation = self._reservation_create( - str(uuid.uuid4()), usage, project_id, resource, - delta, expire, session=context.session) + str(uuid.uuid4()), + usage, + project_id, + resource, + delta, + expire, + session=context.session, + ) reservations.append(reservation.uuid) usages[resource].reserved += delta context.session.flush() if unders: - LOG.warning("Change will make usage less than 0 for the " - "following resources: %s", unders) + LOG.warning( + "Change will make usage less than 0 for the " + "following resources: %s", + unders, + ) return reservations def _sync_acc_res(self, context, resource, project_id): """Quota sync function""" - res_in_use = self._device_data_get_for_project(context, resource, - project_id) + res_in_use = self._device_data_get_for_project( + context, resource, project_id + ) return {resource: res_in_use} @main_context_manager.reader @@ -1083,14 +1225,15 @@ class Connection(api.Connection): def reservation_commit(self, context, reservations, project_id=None): """Commit quota reservation to quota usage table""" quota_usage = self._get_quota_usages( - context, project_id, - resources=self._get_reservation_resources(context, - reservations)) + context, + project_id, + resources=self._get_reservation_resources(context, reservations), + ) usages = self._dict_with_usage_id(quota_usage) - for reservation in self._quota_reservations(context.session, context, - reservations): - + for reservation in self._quota_reservations( + context.session, context, reservations + ): usage = usages[reservation.usage_id] if reservation.delta >= 0: usage.reserved -= reservation.delta @@ -1098,10 +1241,13 @@ class Connection(api.Connection): context.session.flush() reservation.delete(session=context.session) - def process_sort_params(self, sort_keys, sort_dirs, - default_keys=['created_at', 'id'], - default_dir='asc'): - + def process_sort_params( + self, + sort_keys, + sort_dirs, + default_keys=['created_at', 'id'], + default_dir='asc', + ): # Determine direction to use for when adding default keys if sort_dirs and len(sort_dirs) != 0: default_dir_value = sort_dirs[0] diff --git a/cyborg/db/sqlalchemy/migration.py b/cyborg/db/sqlalchemy/migration.py index 6823c48f..38b3c28c 100644 --- a/cyborg/db/sqlalchemy/migration.py +++ b/cyborg/db/sqlalchemy/migration.py @@ -63,8 +63,9 @@ def create_schema(config=None, engine=None): engine = enginefacade.writer.get_engine() if version(engine=engine) is not None: - raise db_exc.DBMigrationError("DB schema is already under version" - " control. Use upgrade() instead") + raise db_exc.DBMigrationError( + "DB schema is already under version control. Use upgrade() instead" + ) models.Base.metadata.create_all(engine) stamp('head', config=config) @@ -104,5 +105,6 @@ def revision(message=None, autogenerate=False, config=None): :type autogenerate: bool """ config = config or _alembic_config() - return alembic.command.revision(config, message=message, - autogenerate=autogenerate) + return alembic.command.revision( + config, message=message, autogenerate=autogenerate + ) diff --git a/cyborg/db/sqlalchemy/models.py b/cyborg/db/sqlalchemy/models.py index 7ddbe059..6b451a19 100644 --- a/cyborg/db/sqlalchemy/models.py +++ b/cyborg/db/sqlalchemy/models.py @@ -44,8 +44,10 @@ db_options.set_defaults(CONF, connection=_DEFAULT_SQL_CONNECTION) def table_args(): engine_name = urlparse.urlparse(CONF.database.connection).scheme if engine_name == 'mysql': - return {'mysql_engine': CONF.database.mysql_engine, - 'mysql_charset': "utf8"} + return { + 'mysql_engine': CONF.database.mysql_engine, + 'mysql_charset': "utf8", + } return None @@ -60,8 +62,7 @@ class CyborgBase(models.TimestampMixin, models.ModelBase): @staticmethod def delete_values(): - return {'deleted': True, - 'deleted_at': timeutils.utcnow()} + return {'deleted': True, 'deleted_at': timeutils.utcnow()} def delete(self, session): """Delete this object.""" @@ -81,15 +82,20 @@ class Device(Base): id = Column(Integer, primary_key=True) uuid = Column(String(36), nullable=False, unique=True) - type = Column(Enum('GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', 'SSD', - name='device_type'), nullable=False) + type = Column( + Enum('GPU', 'FPGA', 'AICHIP', 'QAT', 'NIC', 'SSD', name='device_type'), + nullable=False, + ) vendor = Column(String(255), nullable=False) model = Column(String(255), nullable=False) std_board_info = Column(Text, nullable=True) vendor_board_info = Column(Text, nullable=True) hostname = Column(String(255), nullable=False) - status = Column(Enum("enabled", "maintaining", name='device_status'), - default='enabled', nullable=False) + status = Column( + Enum("enabled", "maintaining", name='device_status'), + default='enabled', + nullable=False, + ) class Deployable(Base): @@ -100,7 +106,7 @@ class Deployable(Base): Index('deployables_parent_id_idx', 'parent_id'), Index('deployables_root_id_idx', 'root_id'), Index('deployables_device_id_idx', 'device_id'), - table_args() + table_args(), ) id = Column(Integer, primary_key=True) @@ -109,8 +115,9 @@ class Deployable(Base): root_id = Column(Integer, ForeignKey('deployables.id'), nullable=True) name = Column(String(255), nullable=False) num_accelerators = Column(Integer, nullable=False) - device_id = Column(Integer, ForeignKey('devices.id', ondelete="RESTRICT"), - nullable=False) + device_id = Column( + Integer, ForeignKey('devices.id', ondelete="RESTRICT"), nullable=False + ) # The resource provider UUID is nullable for 2 reasons: # A. on creation, till Placement is populated, this will be null. # B. Sub-deployables (such as in networked FPGA cards) will have @@ -129,9 +136,12 @@ class Attribute(Base): id = Column(Integer, primary_key=True) uuid = Column(String(36), nullable=False, unique=True) - deployable_id = Column(Integer, - ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=False, index=True) + deployable_id = Column( + Integer, + ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=False, + index=True, + ) key = Column(Text, nullable=False) value = Column(Text, nullable=False) @@ -145,9 +155,12 @@ class ControlpathID(Base): id = Column(Integer, primary_key=True) uuid = Column(String(36), nullable=False, unique=True) - device_id = Column(Integer, - ForeignKey('devices.id', ondelete="RESTRICT"), - nullable=False, index=True) + device_id = Column( + Integer, + ForeignKey('devices.id', ondelete="RESTRICT"), + nullable=False, + index=True, + ) cpid_type = Column(Enum('PCI', name='cpid_type'), nullable=False) cpid_info = Column(String(255), nullable=False) @@ -159,23 +172,31 @@ class AttachHandle(Base): __table_args__ = ( Index('attach_handles_cpid_id_idx', 'cpid_id'), Index('attach_handles_deployable_id_idx', 'deployable_id'), - table_args() + table_args(), ) id = Column(Integer, primary_key=True) uuid = Column(String(36), nullable=False, unique=True) - deployable_id = Column(Integer, - ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=False) - cpid_id = Column(Integer, - ForeignKey('controlpath_ids.id', ondelete="RESTRICT"), - nullable=False) + deployable_id = Column( + Integer, + ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=False, + ) + cpid_id = Column( + Integer, + ForeignKey('controlpath_ids.id', ondelete="RESTRICT"), + nullable=False, + ) in_use = Column(Boolean, default=False) - attach_type = Column(Enum(constants.AH_TYPE_PCI, - constants.AH_TYPE_MDEV, - constants.AH_TYPE_TEST_PCI, - name='attach_type'), - nullable=False) + attach_type = Column( + Enum( + constants.AH_TYPE_PCI, + constants.AH_TYPE_MDEV, + constants.AH_TYPE_TEST_PCI, + name='attach_type', + ), + nullable=False, + ) attach_info = Column(String(255), nullable=False) @@ -184,9 +205,10 @@ class DeviceProfile(Base): __tablename__ = 'device_profiles' __table_args__ = ( - schema.UniqueConstraint('uuid', 'name', - name='uniq_device_profiles0uuid0name'), - table_args() + schema.UniqueConstraint( + 'uuid', 'name', name='uniq_device_profiles0uuid0name' + ), + table_args(), ) id = Column(Integer, primary_key=True) @@ -209,35 +231,46 @@ class ExtArq(Base): Index('extArqs_instance_uuid_idx', 'instance_uuid'), Index('extArqs_attach_handle_id_idx', 'attach_handle_id'), Index('extArqs_deployable_id_idx', 'deployable_id'), - table_args() + table_args(), ) id = Column(Integer, primary_key=True) uuid = Column(String(36), nullable=False, unique=True) project_id = Column(String(255), nullable=True) - state = Column(Enum(constants.ARQ_INITIAL, - constants.ARQ_BIND_STARTED, - constants.ARQ_BOUND, - constants.ARQ_BIND_FAILED, - constants.ARQ_UNBOUND, - constants.ARQ_DELETING), - nullable=False) - device_profile_id = Column(Integer, ForeignKey('device_profiles.id', - ondelete="RESTRICT"), - nullable=False) + state = Column( + Enum( + constants.ARQ_INITIAL, + constants.ARQ_BIND_STARTED, + constants.ARQ_BOUND, + constants.ARQ_BIND_FAILED, + constants.ARQ_UNBOUND, + constants.ARQ_DELETING, + ), + nullable=False, + ) + device_profile_id = Column( + Integer, + ForeignKey('device_profiles.id', ondelete="RESTRICT"), + nullable=False, + ) device_profile_group_id = Column(Integer, nullable=False, default=0) hostname = Column(String(255), nullable=True) device_rp_uuid = Column(String(36), nullable=True) instance_uuid = Column(String(36), nullable=True) - attach_handle_id = Column(Integer, ForeignKey('attach_handles.id', - ondelete="RESTRICT"), - nullable=True) + attach_handle_id = Column( + Integer, + ForeignKey('attach_handles.id', ondelete="RESTRICT"), + nullable=True, + ) # Cyborg Private Fields - substate = Column(Enum('Initial', name='substate'), nullable=False, - default='Initial') - deployable_id = Column(Integer, - ForeignKey('deployables.id', ondelete="RESTRICT"), - nullable=True) + substate = Column( + Enum('Initial', name='substate'), nullable=False, default='Initial' + ) + deployable_id = Column( + Integer, + ForeignKey('deployables.id', ondelete="RESTRICT"), + nullable=True, + ) class QuotaUsage(Base): @@ -288,4 +321,5 @@ class Reservation(Base): usage = orm.relationship( "QuotaUsage", foreign_keys=usage_id, - primaryjoin=usage_id == QuotaUsage.id) + primaryjoin=usage_id == QuotaUsage.id, + ) diff --git a/cyborg/hacking/__init__.py b/cyborg/hacking/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/cyborg/hacking/checks.py b/cyborg/hacking/checks.py deleted file mode 100644 index ec338c77..00000000 --- a/cyborg/hacking/checks.py +++ /dev/null @@ -1,103 +0,0 @@ -# -# All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -""" -Guidelines for writing new hacking checks - - - Use only for Magnum specific tests. OpenStack general tests - should be submitted to the common 'hacking' module. - - Pick numbers in the range M3xx. Find the current test with - the highest allocated number and then pick the next value. - If nova has an N3xx code for that test, use the same number. - - Keep the test method code in the source file ordered based - on the M3xx value. - - List the new rule in the top level HACKING.rst file - - Add test cases for each new rule to magnum/tests/unit/test_hacking.py - -""" - -import re - -from hacking import core -UNDERSCORE_IMPORT_FILES = [] - -mutable_default_args = re.compile(r"^\s*def .+\((.+=\{\}|.+=\[\])") -dict_constructor_with_list_copy_re = re.compile(r".*\bdict\((\[)?(\(|\[)") -log_translation = re.compile( - r"(.)*LOG\.(audit|error|critical)\(\s*('|\")") -log_translation_info = re.compile( - r"(.)*LOG\.(info)\(\s*(_\(|'|\")") -log_translation_exception = re.compile( - r"(.)*LOG\.(exception)\(\s*(_\(|'|\")") -log_translation_LW = re.compile( - r"(.)*LOG\.(warning|warn)\(\s*(_\(|'|\")") -custom_underscore_check = re.compile(r"(.)*_\s*=\s*(.)*") -underscore_import_check = re.compile(r"(.)*import _(.)*") -translated_log = re.compile( - r"(.)*LOG\.(audit|error|info|critical|exception)" - r"\(\s*_\(\s*('|\")") -string_translation = re.compile(r"[^_]*_\(\s*('|\")") - - -@core.flake8ext -def no_mutable_default_args(logical_line): - msg = "M322: Method's default argument shouldn't be mutable!" - if mutable_default_args.match(logical_line): - yield (0, msg) - - -@core.flake8ext -def use_timeutils_utcnow(logical_line, filename): - # tools are OK to use the standard datetime module - if "/tools/" in filename: - return - - msg = "M310: timeutils.utcnow() must be used instead of datetime.%s()" - datetime_funcs = ['now', 'utcnow'] - for f in datetime_funcs: - pos = logical_line.find(f'datetime.{f}') - if pos != -1: - yield (pos, msg % f) - - -@core.flake8ext -def dict_constructor_with_list_copy(logical_line): - msg = ("M336: Must use a dict comprehension instead of a dict constructor" - " with a sequence of key-value pairs." - ) - if dict_constructor_with_list_copy_re.match(logical_line): - yield (0, msg) - - -@core.flake8ext -def check_explicit_underscore_import(logical_line, filename): - """Check for explicit import of the _ function - - We need to ensure that any files that are using the _() function - to translate logs are explicitly importing the _ function. We - can't trust unit test to catch whether the import has been - added so we need to check for it here. - """ - - # Build a list of the files that have _ imported. No further - # checking needed once it is found. - if filename in UNDERSCORE_IMPORT_FILES: - pass - elif (underscore_import_check.match(logical_line) or - custom_underscore_check.match(logical_line)): - UNDERSCORE_IMPORT_FILES.append(filename) - elif (translated_log.match(logical_line) or - string_translation.match(logical_line)): - yield (0, "M340: Found use of _() without explicit import of _ !") diff --git a/cyborg/image/api.py b/cyborg/image/api.py index 096c36ff..988e4689 100644 --- a/cyborg/image/api.py +++ b/cyborg/image/api.py @@ -22,7 +22,6 @@ LOG = log.getLogger(__name__) class API: - """Responsible for exposing a relatively stable internal API for other modules in Cyborg to retrieve information about accelerator images. """ @@ -71,5 +70,6 @@ class API: """ session, image_id = self._get_session_and_image_id(context, id_or_uri) - return session.download(context, image_id, data=data, - dst_path=dest_path) + return session.download( + context, image_id, data=data, dst_path=dest_path + ) diff --git a/cyborg/image/glance.py b/cyborg/image/glance.py index 1000003d..3a15b84e 100644 --- a/cyborg/image/glance.py +++ b/cyborg/image/glance.py @@ -54,7 +54,8 @@ def _session_and_auth(context): if not _SESSION: _SESSION = ks_loading.load_session_from_conf_options( - CONF, cyborg.conf.glance.glance_group.name) + CONF, cyborg.conf.glance.glance_group.name + ) auth = service_auth.get_auth_plugin(context) @@ -64,9 +65,13 @@ def _session_and_auth(context): def _glanceclient_from_endpoint(context, endpoint, version): sess, auth = _session_and_auth(context) - return glanceclient.Client(version, session=sess, auth=auth, - endpoint_override=endpoint, - global_request_id=context.global_id) + return glanceclient.Client( + version, + session=sess, + auth=auth, + endpoint_override=endpoint, + global_request_id=context.global_id, + ) def generate_glance_url(context): @@ -103,8 +108,11 @@ def get_api_server(context): sess, auth = _session_and_auth(context) ksa_adap = utils.get_ksa_adapter( cyborg.conf.glance.DEFAULT_SERVICE_TYPE, - ksa_auth=auth, ksa_session=sess, - min_version='2.0', max_version='2.latest') + ksa_auth=auth, + ksa_session=sess, + min_version='2.0', + max_version='2.latest', + ) endpoint = utils.get_endpoint(ksa_adap) if endpoint: # NOTE(mriedem): Due to python-glanceclient bug 1707995 we have @@ -122,9 +130,9 @@ class GlanceClientWrapper: def __init__(self, context=None, endpoint=None): version = 2 if endpoint is not None: - self.client = self._create_static_client(context, - endpoint, - version) + self.client = self._create_static_client( + context, endpoint, version + ) else: self.client = None self.api_server = None @@ -144,17 +152,21 @@ class GlanceClientWrapper: """Call a glance client method. If we get a connection error, retry the request according to CONF.glance.num_retries. """ - retry_excs = (glanceclient.exc.HTTPServiceUnavailable, - glanceclient.exc.InvalidEndpoint, - glanceclient.exc.CommunicationError) + retry_excs = ( + glanceclient.exc.HTTPServiceUnavailable, + glanceclient.exc.InvalidEndpoint, + glanceclient.exc.CommunicationError, + ) num_attempts = 1 + CONF.glance.num_retries for attempt in range(1, num_attempts + 1): - client = self.client or self._create_onetime_client(context, - version) + client = self.client or self._create_onetime_client( + context, version + ) try: - controller = getattr(client, - kwargs.pop('controller', 'images')) + controller = getattr( + client, kwargs.pop('controller', 'images') + ) result = getattr(controller, method)(*args, **kwargs) if inspect.isgenerator(result): # Convert generator results to a list, so that we can @@ -167,14 +179,20 @@ class GlanceClientWrapper: else: extra = 'done trying' - LOG.exception("Error contacting glance server " - "'%(server)s' for '%(method)s', " - "%(extra)s.", - {'server': self.api_server, - 'method': method, 'extra': extra}) + LOG.exception( + "Error contacting glance server " + "'%(server)s' for '%(method)s', " + "%(extra)s.", + { + 'server': self.api_server, + 'method': method, + 'extra': extra, + }, + ) if attempt == num_attempts: raise exception.GlanceConnectionFailed( - server=str(self.api_server), reason=str(e)) + server=str(self.api_server), reason=str(e) + ) time.sleep(1) @@ -210,15 +228,18 @@ class GlanceImageServiceV2: if image_chunks.wrapped is None: # None is a valid return value, but there's nothing we can do with # a image with no associated data - raise exception.ImageUnacceptable(image_id=image_id, - reason='Image has no \ - associated data') + raise exception.ImageUnacceptable( + image_id=image_id, + reason='Image has no \ + associated data', + ) # Retrieve properties for verification of Glance image signature verifier = None if CONF.glance.verify_glance_signatures: - image_meta_dict = self.show(context, image_id, - include_locations=False) + image_meta_dict = self.show( + context, image_id, include_locations=False + ) image_meta = objects.ImageMeta.from_dict(image_meta_dict) img_signature = image_meta.properties.get('img_signature') img_sig_hash_method = image_meta.properties.get( @@ -240,8 +261,10 @@ class GlanceImageServiceV2: ) except cursive_exception.SignatureVerificationError: with excutils.save_and_reraise_exception(): - LOG.error('Image signature verification failed ' - 'for image: %s', image_id) + LOG.error( + 'Image signature verification failed for image: %s', + image_id, + ) close_file = False if data is None and dst_path: @@ -249,7 +272,6 @@ class GlanceImageServiceV2: close_file = True if data is None: - # Perform image signature verification if verifier: try: @@ -257,13 +279,18 @@ class GlanceImageServiceV2: verifier.update(chunk) verifier.verify() - LOG.info('Image signature verification succeeded ' - 'for image: %s', image_id) + LOG.info( + 'Image signature verification succeeded for image: %s', + image_id, + ) except cryptography.exceptions.InvalidSignature: with excutils.save_and_reraise_exception(): - LOG.error('Image signature verification failed ' - 'for image: %s', image_id) + LOG.error( + 'Image signature verification failed ' + 'for image: %s', + image_id, + ) return image_chunks else: try: @@ -273,17 +300,23 @@ class GlanceImageServiceV2: data.write(chunk) if verifier: verifier.verify() - LOG.info('Image signature verification succeeded ' - 'for image %s', image_id) + LOG.info( + 'Image signature verification succeeded for image %s', + image_id, + ) except cryptography.exceptions.InvalidSignature: data.truncate(0) with excutils.save_and_reraise_exception(): - LOG.error('Image signature verification failed ' - 'for image: %s', image_id) + LOG.error( + 'Image signature verification failed for image: %s', + image_id, + ) except Exception as ex: with excutils.save_and_reraise_exception(): - LOG.error("Error writing to %(path)s: %(exception)s", - {'path': dst_path, 'exception': ex}) + LOG.error( + "Error writing to %(path)s: %(exception)s", + {'path': dst_path, 'exception': ex}, + ) finally: if close_file: # Ensure that the data is pushed all the way down to @@ -297,8 +330,14 @@ class GlanceImageServiceV2: def _extract_query_params(params): _params = {} - accepted_params = ('filters', 'marker', 'limit', - 'page_size', 'sort_key', 'sort_dir') + accepted_params = ( + 'filters', + 'marker', + 'limit', + 'page_size', + 'sort_key', + 'sort_dir', + ) for param in accepted_params: if params.get(param): _params[param] = params.get(param) @@ -313,8 +352,14 @@ def _extract_query_params(params): def _extract_query_params_v2(params): _params = {} - accepted_params = ('filters', 'marker', 'limit', - 'page_size', 'sort_key', 'sort_dir') + accepted_params = ( + 'filters', + 'marker', + 'limit', + 'page_size', + 'sort_key', + 'sort_dir', + ) for param in accepted_params: if params.get(param): _params[param] = params.get(param) @@ -408,15 +453,16 @@ def _convert_to_v2(image_meta): # if allow_additional_image_properties is disabled we can't # define kernel_id and ramdisk_id as None, so we have to omit # these properties if they are not set. - if prop_name in ('kernel_id', 'ramdisk_id') and \ - prop_value is not None and \ - prop_value.strip().lower() in ('none', ''): + if ( + prop_name in ('kernel_id', 'ramdisk_id') + and prop_value is not None + and prop_value.strip().lower() in ('none', '') + ): continue # in glance only string and None property values are allowed, # v1 client accepts any values and converts them to string, # v2 doesn't - so we have to take care of it. - elif prop_value is None or isinstance( - prop_value, str): + elif prop_value is None or isinstance(prop_value, str): output[prop_name] = prop_value else: output[prop_name] = str(prop_value) @@ -434,7 +480,8 @@ def _convert_to_v2(image_meta): def _translate_from_glance(image, include_locations=False): image_meta = _extract_attributes_v2( - image, include_locations=include_locations) + image, include_locations=include_locations + ) image_meta = _convert_timestamps_to_datetimes(image_meta) image_meta = _convert_from_string(image_meta) @@ -492,12 +539,25 @@ def _extract_attributes(image, include_locations=False): # therefore sorted, with dependent attributes as the end # 'deleted_at' depends on 'deleted' # 'checksum' depends on 'status' == 'active' - IMAGE_ATTRIBUTES = ['size', 'disk_format', 'owner', - 'container_format', 'status', 'id', - 'name', 'created_at', 'updated_at', - 'deleted', 'deleted_at', 'checksum', - 'min_disk', 'min_ram', 'is_public', - 'direct_url', 'locations'] + IMAGE_ATTRIBUTES = [ + 'size', + 'disk_format', + 'owner', + 'container_format', + 'status', + 'id', + 'name', + 'created_at', + 'updated_at', + 'deleted', + 'deleted_at', + 'checksum', + 'min_disk', + 'min_ram', + 'is_public', + 'direct_url', + 'locations', + ] queued = getattr(image, 'status') == 'queued' queued_exclude_attrs = ['disk_format', 'container_format'] @@ -538,16 +598,31 @@ def _extract_attributes(image, include_locations=False): def _extract_attributes_v2(image, include_locations=False): include_locations_attrs = ['direct_url', 'locations'] - omit_attrs = ['self', 'schema', 'protected', 'virtual_size', 'file', - 'tags'] + omit_attrs = [ + 'self', + 'schema', + 'protected', + 'virtual_size', + 'file', + 'tags', + ] raw_schema = image.schema schema = schemas.Schema(raw_schema) - output = {'properties': {}, 'deleted': False, 'deleted_at': None, - 'disk_format': None, 'container_format': None, 'name': None, - 'checksum': None} + output = { + 'properties': {}, + 'deleted': False, + 'deleted_at': None, + 'disk_format': None, + 'container_format': None, + 'name': None, + 'checksum': None, + } for name, value in image.items(): - if (name in omit_attrs - or name in include_locations_attrs and not include_locations): + if ( + name in omit_attrs + or name in include_locations_attrs + and not include_locations + ): continue elif name == 'visibility': output['is_public'] = value == 'public' @@ -585,20 +660,27 @@ def _reraise_translated_exception(): def _translate_image_exception(image_id, exc_value): - if isinstance(exc_value, glanceclient.exc.HTTPForbidden | glanceclient.exc.HTTPUnauthorized): # noqa: E501 + if isinstance( + exc_value, + glanceclient.exc.HTTPForbidden | glanceclient.exc.HTTPUnauthorized, + ): # noqa: E501 return exception.ImageNotAuthorized(image_id=image_id) if isinstance(exc_value, glanceclient.exc.HTTPNotFound): return exception.ResourceNotFound( - resource='Image', - msg='with uuid=%s' % image_id) + resource='Image', msg='with uuid=%s' % image_id + ) if isinstance(exc_value, glanceclient.exc.HTTPBadRequest): - return exception.ImageBadRequest(image_id=image_id, - response=str(exc_value)) + return exception.ImageBadRequest( + image_id=image_id, response=str(exc_value) + ) return exc_value def _translate_plain_exception(exc_value): - if isinstance(exc_value, glanceclient.exc.HTTPForbidden | glanceclient.exc.HTTPUnauthorized): # noqa: E501 + if isinstance( + exc_value, + glanceclient.exc.HTTPForbidden | glanceclient.exc.HTTPUnauthorized, + ): # noqa: E501 return exception.HTTPForbidden(str(exc_value)) if isinstance(exc_value, glanceclient.exc.HTTPNotFound): return exception.HTTPNotFound(str(exc_value)) @@ -627,8 +709,7 @@ def get_remote_image_service(context, image_href): try: (image_id, endpoint) = _endpoint_from_image_ref(image_href) - glance_client = GlanceClientWrapper(context=context, - endpoint=endpoint) + glance_client = GlanceClientWrapper(context=context, endpoint=endpoint) except ValueError: raise exception.InvalidImageRef(image_href=image_href) @@ -648,7 +729,13 @@ class UpdateGlanceImage: self.image_stream = stream def start(self): - image_service, image_id = ( - get_remote_image_service(self.context, self.image_id)) - image_service.update(self.context, image_id, self.metadata, - self.image_stream, purge_props=False) + image_service, image_id = get_remote_image_service( + self.context, self.image_id + ) + image_service.update( + self.context, + image_id, + self.metadata, + self.image_stream, + purge_props=False, + ) diff --git a/cyborg/objects/arq.py b/cyborg/objects/arq.py index 17e81a7c..a376a715 100644 --- a/cyborg/objects/arq.py +++ b/cyborg/objects/arq.py @@ -37,15 +37,12 @@ class ARQ(base.CyborgObject, object_base.VersionedObjectDictCompat): 'uuid': object_fields.UUIDField(nullable=False), 'state': object_fields.ARQStateField(nullable=False), 'device_profile_name': object_fields.StringField(nullable=False), - 'device_profile_group_id': - object_fields.IntegerField(nullable=False), - + 'device_profile_group_id': object_fields.IntegerField(nullable=False), # Fields populated by Nova after scheduling for binding 'hostname': object_fields.StringField(nullable=True), 'device_rp_uuid': object_fields.StringField(nullable=True), 'instance_uuid': object_fields.StringField(nullable=True), 'project_id': object_fields.StringField(nullable=True), - # Fields populated by Cyborg after binding 'attach_handle_type': object_fields.StringField(nullable=True), 'attach_handle_uuid': object_fields.StringField(nullable=True), diff --git a/cyborg/objects/attach_handle.py b/cyborg/objects/attach_handle.py index f3029f96..9dd60ad9 100644 --- a/cyborg/objects/attach_handle.py +++ b/cyborg/objects/attach_handle.py @@ -37,11 +37,11 @@ class AttachHandle(base.CyborgObject, object_base.VersionedObjectDictCompat): 'deployable_id': object_fields.IntegerField(nullable=False), 'cpid_id': object_fields.IntegerField(nullable=False), 'attach_type': object_fields.EnumField( - valid_values=constants.ATTACH_HANDLE_TYPES, - nullable=False), + valid_values=constants.ATTACH_HANDLE_TYPES, nullable=False + ), # attach_info should be JSON here. 'attach_info': object_fields.StringField(nullable=False), - 'in_use': object_fields.BooleanField(nullable=False, default=False) + 'in_use': object_fields.BooleanField(nullable=False, default=False), } def create(self, context): @@ -72,11 +72,14 @@ class AttachHandle(base.CyborgObject, object_base.VersionedObjectDictCompat): sort_key = filters.pop('sort_key', 'created_at') limit = filters.pop('limit', None) marker = filters.pop('marker_obj', None) - db_ahs = cls.dbapi.attach_handle_get_by_filters(context, filters, - sort_dir=sort_dir, - sort_key=sort_key, - limit=limit, - marker=marker) + db_ahs = cls.dbapi.attach_handle_get_by_filters( + context, + filters, + sort_dir=sort_dir, + sort_key=sort_key, + limit=limit, + marker=marker, + ) else: db_ahs = cls.dbapi.attach_handle_list(context) obj_ah_list = cls._from_db_object_list(db_ahs, context) @@ -101,8 +104,10 @@ class AttachHandle(base.CyborgObject, object_base.VersionedObjectDictCompat): @classmethod def get_ah_by_depid_attachinfo(cls, context, deployable_id, attach_info): - ah_filter = {'deployable_id': deployable_id, - 'attach_info': attach_info} + ah_filter = { + 'deployable_id': deployable_id, + 'attach_info': attach_info, + } ah_obj_list = AttachHandle.list(context, ah_filter) if len(ah_obj_list) != 0: return ah_obj_list[0] diff --git a/cyborg/objects/attribute.py b/cyborg/objects/attribute.py index f0816cdf..8413fdc9 100644 --- a/cyborg/objects/attribute.py +++ b/cyborg/objects/attribute.py @@ -37,7 +37,7 @@ class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat): 'uuid': object_fields.UUIDField(nullable=False), 'deployable_id': object_fields.IntegerField(nullable=False), 'key': object_fields.StringField(nullable=False), - 'value': object_fields.StringField(nullable=False) + 'value': object_fields.StringField(nullable=False), } def create(self, context): @@ -46,8 +46,7 @@ class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat): raise exception.AttributeInvalid() values = self.obj_get_changes() - db_attr = self.dbapi.attribute_create(context, - values) + db_attr = self.dbapi.attribute_create(context, values) return self._from_db_object(self, db_attr) @classmethod @@ -60,8 +59,9 @@ class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat): @classmethod def get_by_deployable_id(cls, context, deployable_id): """Get an attribute by deployable_id""" - db_attr = cls.dbapi.attribute_get_by_deployable_id(context, - deployable_id) + db_attr = cls.dbapi.attribute_get_by_deployable_id( + context, deployable_id + ) return cls._from_db_object_list(db_attr, context) @classmethod @@ -72,10 +72,9 @@ class Attribute(base.CyborgObject, object_base.VersionedObjectDictCompat): def save(self, context): """Update an attribute record in the DB.""" - db_attr = self.dbapi.attribute_update(context, - self.uuid, - self.key, - self.value) + db_attr = self.dbapi.attribute_update( + context, self.uuid, self.key, self.value + ) self._from_db_object(self, db_attr) def destroy(self, context): diff --git a/cyborg/objects/base.py b/cyborg/objects/base.py index c9d8b12b..9aae16c3 100644 --- a/cyborg/objects/base.py +++ b/cyborg/objects/base.py @@ -38,7 +38,8 @@ class CyborgObjectRegistry(object_base.VersionedObjectRegistry): setattr(objects, cls.obj_name(), cls) else: cur_version = versionutils.convert_version_to_tuple( - getattr(objects, cls.obj_name()).VERSION) + getattr(objects, cls.obj_name()).VERSION + ) if version >= cur_version: setattr(objects, cls.obj_name(), cls) @@ -73,8 +74,9 @@ class CyborgObject(object_base.VersionedObject): attr = attr.as_dict() return attr - return {k: _attr_as_dict(k) - for k in self.fields if self.obj_attr_is_set(k)} + return { + k: _attr_as_dict(k) for k in self.fields if self.obj_attr_is_set(k) + } @staticmethod def _from_db_object(obj, db_obj): @@ -112,8 +114,7 @@ class CyborgObject(object_base.VersionedObject): of the object. """ _log_backport(self, target_version) - super().obj_make_compatible(primitive, - target_version) + super().obj_make_compatible(primitive, target_version) class CyborgObjectSerializer(object_base.VersionedObjectSerializer): @@ -129,23 +130,24 @@ class CyborgPersistentObject: This adds the fields that we use in common for most persistent objects. """ + fields = { 'created_at': object_fields.DateTimeField(nullable=True), 'updated_at': object_fields.DateTimeField(nullable=True), 'deleted_at': object_fields.DateTimeField(nullable=True), 'deleted': object_fields.BooleanField(default=False), - } + } class ObjectListBase(object_base.ObjectListBase): - @classmethod def _obj_primitive_key(cls, field): return 'cyborg_object.%s' % field @classmethod - def _obj_primitive_field(cls, primitive, field, - default=object_fields.UnspecifiedDefault): + def _obj_primitive_field( + cls, primitive, field, default=object_fields.UnspecifiedDefault + ): key = cls._obj_primitive_key(field) if default == object_fields.UnspecifiedDefault: return primitive[key] @@ -224,8 +226,12 @@ class DriverObjectBase(CyborgObject): def _log_backport(ovo, target_version): """Log backported versioned objects.""" if target_version and target_version != ovo.VERSION: - LOG.debug('Backporting %(obj_name)s from version %(src_vers)s ' - 'to version %(dst_vers)s', - {'obj_name': ovo.obj_name(), - 'src_vers': ovo.VERSION, - 'dst_vers': target_version}) + LOG.debug( + 'Backporting %(obj_name)s from version %(src_vers)s ' + 'to version %(dst_vers)s', + { + 'obj_name': ovo.obj_name(), + 'src_vers': ovo.VERSION, + 'dst_vers': target_version, + }, + ) diff --git a/cyborg/objects/control_path.py b/cyborg/objects/control_path.py index 123e3f87..6d09e30b 100644 --- a/cyborg/objects/control_path.py +++ b/cyborg/objects/control_path.py @@ -38,9 +38,9 @@ class ControlpathID(base.CyborgObject, object_base.VersionedObjectDictCompat): 'uuid': object_fields.UUIDField(nullable=False), 'device_id': object_fields.IntegerField(nullable=False), 'cpid_type': object_fields.EnumField( - valid_values=constants.CPID_TYPE, - nullable=False), - 'cpid_info': object_fields.StringField(nullable=False) + valid_values=constants.CPID_TYPE, nullable=False + ), + 'cpid_info': object_fields.StringField(nullable=False), } @property @@ -72,11 +72,14 @@ class ControlpathID(base.CyborgObject, object_base.VersionedObjectDictCompat): sort_key = filters.pop('sort_key', 'created_at') limit = filters.pop('limit', None) marker = filters.pop('marker_obj', None) - db_cps = cls.dbapi.control_path_get_by_filters(context, filters, - sort_dir=sort_dir, - sort_key=sort_key, - limit=limit, - marker=marker) + db_cps = cls.dbapi.control_path_get_by_filters( + context, + filters, + sort_dir=sort_dir, + sort_key=sort_key, + limit=limit, + marker=marker, + ) else: db_cps = cls.dbapi.control_path_list(context) obj_cp_list = cls._from_db_object_list(db_cps, context) @@ -105,8 +108,7 @@ class ControlpathID(base.CyborgObject, object_base.VersionedObjectDictCompat): @classmethod def get_by_device_id_cpidinfo(cls, context, device_id, cpid_info): - cpid_filter = {'device_id': device_id, - 'cpid_info': cpid_info} + cpid_filter = {'device_id': device_id, 'cpid_info': cpid_info} # the list could have one value or is empty. cpid_obj_list = ControlpathID.list(context, cpid_filter) if len(cpid_obj_list) != 0: diff --git a/cyborg/objects/deployable.py b/cyborg/objects/deployable.py index 477ccbbd..276cc6f5 100644 --- a/cyborg/objects/deployable.py +++ b/cyborg/objects/deployable.py @@ -96,11 +96,14 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): sort_key = filters.pop('sort_key', 'created_at') limit = filters.pop('limit', None) marker = filters.pop('marker_obj', None) - db_deps = cls.dbapi.deployable_get_by_filters(context, filters, - sort_dir=sort_dir, - sort_key=sort_key, - limit=limit, - marker=marker) + db_deps = cls.dbapi.deployable_get_by_filters( + context, + filters, + sort_dir=sort_dir, + sort_key=sort_key, + limit=limit, + marker=marker, + ) else: db_deps = cls.dbapi.deployable_list(context) obj_dpl_list = cls._from_db_object_list(db_deps, context) @@ -112,8 +115,10 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): # TODO(Xinran): Will remove this if find some better way. updates.pop("uuid", None) updates.pop("created_at", None) - if "updated_at" in updates.keys() and \ - updates["updated_at"] is not None: + if ( + "updated_at" in updates.keys() + and updates["updated_at"] is not None + ): updates["updated_at"] = updates["updated_at"].replace(tzinfo=None) db_dep = self.dbapi.deployable_update(context, self.uuid, updates) self.obj_reset_changes() @@ -121,8 +126,7 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): def update(self, context, updates): """Update provided key, value pairs""" - self.dbapi.deployable_update(context, self.uuid, - updates) + self.dbapi.deployable_update(context, self.uuid, updates) def destroy(self, context): """Delete a Deployable from the DB.""" @@ -130,12 +134,9 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): self.obj_reset_changes() @classmethod - def get_by_filter(cls, context, - filters): + def get_by_filter(cls, context, filters): obj_dpl_list = [] - db_dpl_list = cls.dbapi.deployable_get_by_filters( - context, - filters) + db_dpl_list = cls.dbapi.deployable_get_by_filters(context, filters) if db_dpl_list: for db_dpl in db_dpl_list: @@ -173,5 +174,6 @@ class Deployable(base.CyborgObject, object_base.VersionedObjectDictCompat): # TODO(Sundar) We should probably get cpid from objects layer, # not db layer cpid_list = self.dbapi.control_path_get_by_filters( - context, query_filter) + context, query_filter + ) return cpid_list diff --git a/cyborg/objects/device.py b/cyborg/objects/device.py index ac8aac59..9327cbd2 100644 --- a/cyborg/objects/device.py +++ b/cyborg/objects/device.py @@ -36,15 +36,19 @@ class Device(base.CyborgObject, object_base.VersionedObjectDictCompat): fields = { 'id': object_fields.IntegerField(nullable=False), 'uuid': object_fields.UUIDField(nullable=False), - 'type': object_fields.EnumField(valid_values=constants.DEVICE_TYPE, - nullable=False), + 'type': object_fields.EnumField( + valid_values=constants.DEVICE_TYPE, nullable=False + ), 'vendor': object_fields.StringField(nullable=False), 'model': object_fields.StringField(nullable=False), 'std_board_info': object_fields.StringField(nullable=True), 'vendor_board_info': object_fields.StringField(nullable=True), 'hostname': object_fields.StringField(nullable=False), - 'status': object_fields.EnumField(valid_values=constants.DEVICE_STATUS, - nullable=False, default="enabled"), + 'status': object_fields.EnumField( + valid_values=constants.DEVICE_STATUS, + nullable=False, + default="enabled", + ), } def create(self, context): @@ -69,8 +73,13 @@ class Device(base.CyborgObject, object_base.VersionedObjectDictCompat): limit = filters.pop('limit', None) marker = filters.pop('marker_obj', None) db_devices = cls.dbapi.device_list_by_filters( - context, filters, sort_dir=sort_dir, sort_key=sort_key, - limit=limit, marker=marker) + context, + filters, + sort_dir=sort_dir, + sort_key=sort_key, + limit=limit, + marker=marker, + ) else: db_devices = cls.dbapi.device_list(context) return cls._from_db_object_list(db_devices, context) diff --git a/cyborg/objects/device_profile.py b/cyborg/objects/device_profile.py index 3ef239fe..5bc01dad 100644 --- a/cyborg/objects/device_profile.py +++ b/cyborg/objects/device_profile.py @@ -43,8 +43,7 @@ class DeviceProfile(base.CyborgObject, object_base.VersionedObjectDictCompat): } def obj_make_compatible(self, primitive, target_version): - super().obj_make_compatible( - primitive, target_version) + super().obj_make_compatible(primitive, target_version) target_version = versionutils.convert_version_to_tuple(target_version) if target_version < (1, 1) and 'description' in primitive: del primitive['description'] @@ -62,8 +61,9 @@ class DeviceProfile(base.CyborgObject, object_base.VersionedObjectDictCompat): """Create a Device Profile record in the DB.""" # TODO() validate with a JSON schema if 'name' not in self: - raise exception.ObjectActionError(action='create', - reason='name is required') + raise exception.ObjectActionError( + action='create', reason='name is required' + ) values = self.obj_get_changes() self._to_profile_json(values) @@ -105,8 +105,9 @@ class DeviceProfile(base.CyborgObject, object_base.VersionedObjectDictCompat): updates = self.obj_get_changes() self._to_profile_json(updates) - db_devprof = self.dbapi.device_profile_update(context, - self.name, updates) + db_devprof = self.dbapi.device_profile_update( + context, self.name, updates + ) self._from_db_object(self, db_devprof) def destroy(self, context): diff --git a/cyborg/objects/driver_objects/driver_attach_handle.py b/cyborg/objects/driver_objects/driver_attach_handle.py index 91e4e7b6..2fa4ef99 100644 --- a/cyborg/objects/driver_objects/driver_attach_handle.py +++ b/cyborg/objects/driver_objects/driver_attach_handle.py @@ -19,12 +19,14 @@ from oslo_versionedobjects import base as object_base from cyborg.objects.attach_handle import AttachHandle from cyborg.objects import base from cyborg.objects import fields as object_fields + LOG = logging.getLogger(__name__) @base.CyborgObjectRegistry.register -class DriverAttachHandle(base.DriverObjectBase, - object_base.VersionedObjectDictCompat): +class DriverAttachHandle( + base.DriverObjectBase, object_base.VersionedObjectDictCompat +): # Version 1.0: Initial version VERSION = '1.0' @@ -33,26 +35,27 @@ class DriverAttachHandle(base.DriverObjectBase, # PCI BDF or mediated device ID... 'attach_info': object_fields.StringField(nullable=False), # The status of attach_handle, is in use or not. - 'in_use': object_fields.BooleanField(nullable=False, default=False) + 'in_use': object_fields.BooleanField(nullable=False, default=False), } def create(self, context, deployable_id, cpid_id): """Create a driver-side AttachHandle object, call AttachHandle Object to store in DB. """ - attach_handle_obj = AttachHandle(context=context, - deployable_id=deployable_id, - cpid_id=cpid_id, - attach_type=self.attach_type, - attach_info=self.attach_info, - in_use=self.in_use - ) + attach_handle_obj = AttachHandle( + context=context, + deployable_id=deployable_id, + cpid_id=cpid_id, + attach_type=self.attach_type, + attach_info=self.attach_info, + in_use=self.in_use, + ) attach_handle_obj.create(context) def destroy(self, context, deployable_id): - ah_obj = AttachHandle.get_ah_by_depid_attachinfo(context, - deployable_id, - self.attach_info) + ah_obj = AttachHandle.get_ah_by_depid_attachinfo( + context, deployable_id, self.attach_info + ) if ah_obj is not None: ah_obj.destroy(context) @@ -60,12 +63,15 @@ class DriverAttachHandle(base.DriverObjectBase, def list(cls, context, deployable_id): """Form a driver-side attach_handle list for one deployable.""" ah_obj_list = AttachHandle.get_ah_list_by_deployable_id( - context, deployable_id) + context, deployable_id + ) driver_ah_obj_list = [] for ah_obj in ah_obj_list: - driver_ah_obj = cls(context=context, - attach_type=ah_obj.attach_type, - attach_info=ah_obj.attach_info, - in_use=ah_obj.in_use) + driver_ah_obj = cls( + context=context, + attach_type=ah_obj.attach_type, + attach_info=ah_obj.attach_info, + in_use=ah_obj.in_use, + ) driver_ah_obj_list.append(driver_ah_obj) return driver_ah_obj_list diff --git a/cyborg/objects/driver_objects/driver_attribute.py b/cyborg/objects/driver_objects/driver_attribute.py index d32d79b5..96d532c8 100644 --- a/cyborg/objects/driver_objects/driver_attribute.py +++ b/cyborg/objects/driver_objects/driver_attribute.py @@ -21,14 +21,15 @@ from cyborg.objects import fields as object_fields @base.CyborgObjectRegistry.register -class DriverAttribute(base.DriverObjectBase, - object_base.VersionedObjectDictCompat): +class DriverAttribute( + base.DriverObjectBase, object_base.VersionedObjectDictCompat +): # Version 1.0: Initial version VERSION = '1.0' fields = { 'key': object_fields.StringField(nullable=False), - 'value': object_fields.StringField(nullable=False) + 'value': object_fields.StringField(nullable=False), } def create(self, context, deployable_id): @@ -61,8 +62,8 @@ class DriverAttribute(base.DriverObjectBase, attr_obj_list = Attribute.get_by_deployable_id(context, deployable_id) driver_attr_obj_list = [] for attr_obj in attr_obj_list: - driver_attr_obj = cls(context=context, - key=attr_obj.key, - value=attr_obj.value) + driver_attr_obj = cls( + context=context, key=attr_obj.key, value=attr_obj.value + ) driver_attr_obj_list.append(driver_attr_obj) return driver_attr_obj_list diff --git a/cyborg/objects/driver_objects/driver_controlpath_id.py b/cyborg/objects/driver_objects/driver_controlpath_id.py index ecba9d00..d6bc35d6 100644 --- a/cyborg/objects/driver_objects/driver_controlpath_id.py +++ b/cyborg/objects/driver_objects/driver_controlpath_id.py @@ -21,32 +21,35 @@ from cyborg.objects import fields as object_fields @base.CyborgObjectRegistry.register -class DriverControlPathID(base.DriverObjectBase, - object_base.VersionedObjectDictCompat): +class DriverControlPathID( + base.DriverObjectBase, object_base.VersionedObjectDictCompat +): # Version 1.0: Initial version VERSION = '1.0' fields = { 'cpid_type': object_fields.StringField(nullable=False), # PCI BDF, PowerVM device, etc. - 'cpid_info': object_fields.StringField(nullable=False) + 'cpid_info': object_fields.StringField(nullable=False), } def create(self, context, device_id): """Create a driver-side ControlPathID for drivers. Call ControlpathID object to store in DB. """ - cpid_obj = ControlpathID(context=context, - device_id=device_id, - cpid_type=self.cpid_type, - cpid_info=self.cpid_info) + cpid_obj = ControlpathID( + context=context, + device_id=device_id, + cpid_type=self.cpid_type, + cpid_info=self.cpid_info, + ) cpid_obj.create(context) return cpid_obj def destroy(self, context, device_id): - cpid_obj = ControlpathID.get_by_device_id_cpidinfo(context, - device_id, - self.cpid_info) + cpid_obj = ControlpathID.get_by_device_id_cpidinfo( + context, device_id, self.cpid_info + ) if cpid_obj is not None: cpid_obj.destroy(context) @@ -56,7 +59,9 @@ class DriverControlPathID(base.DriverObjectBase, cpid_obj = ControlpathID.get_by_device_id(context, device_id) driver_cpid_obj = None if cpid_obj is not None: - driver_cpid_obj = cls(context=context, - cpid_type=cpid_obj.cpid_type, - cpid_info=cpid_obj.cpid_info) + driver_cpid_obj = cls( + context=context, + cpid_type=cpid_obj.cpid_type, + cpid_info=cpid_obj.cpid_info, + ) return driver_cpid_obj diff --git a/cyborg/objects/driver_objects/driver_deployable.py b/cyborg/objects/driver_objects/driver_deployable.py index 6029bc6c..54c48756 100644 --- a/cyborg/objects/driver_objects/driver_deployable.py +++ b/cyborg/objects/driver_objects/driver_deployable.py @@ -17,15 +17,17 @@ from oslo_versionedobjects import base as object_base from cyborg.objects import base from cyborg.objects.deployable import Deployable -from cyborg.objects.driver_objects.driver_attach_handle import \ - DriverAttachHandle +from cyborg.objects.driver_objects.driver_attach_handle import ( + DriverAttachHandle, +) from cyborg.objects.driver_objects.driver_attribute import DriverAttribute from cyborg.objects import fields as object_fields @base.CyborgObjectRegistry.register -class DriverDeployable(base.DriverObjectBase, - object_base.VersionedObjectDictCompat): +class DriverDeployable( + base.DriverObjectBase, object_base.VersionedObjectDictCompat +): # Version 1.0: Initial version VERSION = '1.0' @@ -33,12 +35,14 @@ class DriverDeployable(base.DriverObjectBase, 'name': object_fields.StringField(nullable=False), 'num_accelerators': object_fields.IntegerField(nullable=False), 'attribute_list': object_fields.ListOfObjectsField( - 'DriverAttribute', default=[], nullable=True), + 'DriverAttribute', default=[], nullable=True + ), # TODO() add field related to local_memory or just store in the # attribute list? 'attach_handle_list': object_fields.ListOfObjectsField( - 'DriverAttachHandle', default=[], nullable=True), - 'driver_name': object_fields.StringField(nullable=True) + 'DriverAttachHandle', default=[], nullable=True + ), + 'driver_name': object_fields.StringField(nullable=True), } def create(self, context, device_id, cpid_id): @@ -48,12 +52,13 @@ class DriverDeployable(base.DriverObjectBase, """ # first store in deployable table through Deployable Object. - deployable_obj = Deployable(context=context, - name=self.name, - num_accelerators=self.num_accelerators, - device_id=device_id, - driver_name=self.driver_name - ) + deployable_obj = Deployable( + context=context, + name=self.name, + num_accelerators=self.num_accelerators, + device_id=device_id, + driver_name=self.driver_name, + ) deployable_obj.create(context) # create attribute_list for this deployable if hasattr(self, 'attribute_list'): @@ -63,8 +68,9 @@ class DriverDeployable(base.DriverObjectBase, # create attach_handle_list for this deployable if hasattr(self, 'attach_handle_list'): for driver_attach_handle in self.attach_handle_list: - driver_attach_handle.create(context, deployable_obj.id, - cpid_id) + driver_attach_handle.create( + context, deployable_obj.id, cpid_id + ) def destroy(self, context, device_id): """delete one driver-side deployable by calling existing Deployable @@ -73,8 +79,9 @@ class DriverDeployable(base.DriverObjectBase, """ # get deployable_id by name, get only one value. - dep_obj = Deployable.get_by_name_deviceid(context, self.name, - device_id) + dep_obj = Deployable.get_by_name_deviceid( + context, self.name, device_id + ) # delete attach_handle if hasattr(self, 'attach_handle_list'): for driver_ah_obj in self.attach_handle_list: @@ -98,11 +105,13 @@ class DriverDeployable(base.DriverObjectBase, driver_ah_obj_list = DriverAttachHandle.list(context, dep_obj.id) # get driver_attr_obj_list for this dep_obj driver_attr_obj_list = DriverAttribute.list(context, dep_obj.id) - driver_dep_obj = cls(context=context, - name=dep_obj.name, - num_accelerators=dep_obj.num_accelerators, - attribute_list=driver_attr_obj_list, - attach_handle_list=driver_ah_obj_list) + driver_dep_obj = cls( + context=context, + name=dep_obj.name, + num_accelerators=dep_obj.num_accelerators, + attribute_list=driver_attr_obj_list, + attach_handle_list=driver_ah_obj_list, + ) driver_dep_obj_list.append(driver_dep_obj) return driver_dep_obj_list @@ -114,8 +123,11 @@ class DriverDeployable(base.DriverObjectBase, driver_ah_obj_list = DriverAttachHandle.list(context, dep_obj.id) # get driver_attr_obj_list for this dep_obj driver_attr_obj_list = DriverAttribute.list(context, dep_obj.id) - driver_dep_obj = cls(context=context, name=dep_obj.name, - num_accelerators=dep_obj.num_accelerators, - attribute_list=driver_attr_obj_list, - attach_handle_list=driver_ah_obj_list) + driver_dep_obj = cls( + context=context, + name=dep_obj.name, + num_accelerators=dep_obj.num_accelerators, + attribute_list=driver_attr_obj_list, + attach_handle_list=driver_ah_obj_list, + ) return driver_dep_obj diff --git a/cyborg/objects/driver_objects/driver_device.py b/cyborg/objects/driver_objects/driver_device.py index 264c3e26..d607f062 100644 --- a/cyborg/objects/driver_objects/driver_device.py +++ b/cyborg/objects/driver_objects/driver_device.py @@ -18,15 +18,17 @@ from oslo_versionedobjects import base as object_base from cyborg.objects import base from cyborg.objects.control_path import ControlpathID from cyborg.objects.device import Device -from cyborg.objects.driver_objects.driver_controlpath_id import \ - DriverControlPathID +from cyborg.objects.driver_objects.driver_controlpath_id import ( + DriverControlPathID, +) from cyborg.objects.driver_objects.driver_deployable import DriverDeployable from cyborg.objects import fields as object_fields @base.CyborgObjectRegistry.register -class DriverDevice(base.DriverObjectBase, - object_base.VersionedObjectDictCompat): +class DriverDevice( + base.DriverObjectBase, object_base.VersionedObjectDictCompat +): # Version 1.0: Initial version VERSION = '1.0' @@ -41,12 +43,13 @@ class DriverDevice(base.DriverObjectBase, # hostname will be set by the agent, so driver don't need to report. # Each controlpath_id corresponds to a different PF. For now # we are sticking with a single cpid. - 'controlpath_id': object_fields.ObjectField('DriverControlPathID', - nullable=False), - 'deployable_list': object_fields.ListOfObjectsField('DriverDeployable', - default=[], - nullable=False), - 'stub': object_fields.BooleanField(nullable=False, default=False) + 'controlpath_id': object_fields.ObjectField( + 'DriverControlPathID', nullable=False + ), + 'deployable_list': object_fields.ListOfObjectsField( + 'DriverDeployable', default=[], nullable=False + ), + 'stub': object_fields.BooleanField(nullable=False, default=False), } def create(self, context, host): @@ -56,12 +59,13 @@ class DriverDevice(base.DriverObjectBase, """ # first store in device table through Device Object. - device_obj = Device(context=context, - type=self.type, - vendor=self.vendor, - model=self.model, - hostname=host - ) + device_obj = Device( + context=context, + type=self.type, + vendor=self.vendor, + model=self.model, + hostname=host, + ) if hasattr(self, 'std_board_info'): device_obj.std_board_info = self.std_board_info if hasattr(self, 'vendor_board_info'): @@ -86,7 +90,8 @@ class DriverDevice(base.DriverObjectBase, driver_deployable.destroy(context, device_obj.id) if hasattr(self.controlpath_id, 'cpid_info'): cpid_obj = ControlpathID.get_by_device_id_cpidinfo( - context, device_obj.id, self.controlpath_id.cpid_info) + context, device_obj.id, self.controlpath_id.cpid_info + ) # delete controlpath_id cpid_obj.destroy(context) # delete the device @@ -106,7 +111,8 @@ class DriverDevice(base.DriverObjectBase, for device_obj in device_obj_list: # get cpid_obj, could be empty or only one value. cpid_obj = ControlpathID.get_by_device_id_cpidinfo( - context, device_obj.id, self.controlpath_id.cpid_info) + context, device_obj.id, self.controlpath_id.cpid_info + ) # find the one cpid_obj with cpid_info if cpid_obj is not None: return device_obj @@ -125,15 +131,16 @@ class DriverDevice(base.DriverObjectBase, cpid = DriverControlPathID.get(context, dev_obj.id) # NOTE: will not return device without controlpath_id. if cpid is not None: - driver_dev_obj = \ - cls(context=context, vendor=dev_obj.vendor, - model=dev_obj.model, type=dev_obj.type, - std_board_info=dev_obj.std_board_info, - vendor_board_info=dev_obj.vendor_board_info, - controlpath_id=cpid, - deployable_list=DriverDeployable.list(context, - dev_obj.id) - ) + driver_dev_obj = cls( + context=context, + vendor=dev_obj.vendor, + model=dev_obj.model, + type=dev_obj.type, + std_board_info=dev_obj.std_board_info, + vendor_board_info=dev_obj.vendor_board_info, + controlpath_id=cpid, + deployable_list=DriverDeployable.list(context, dev_obj.id), + ) driver_dev_obj_list.append(driver_dev_obj) return driver_dev_obj_list @@ -150,6 +157,7 @@ class DriverDevice(base.DriverObjectBase, # use controlpath_id.cpid_info to identify one Device. # get cpid_obj, could be empty or only one value. ControlpathID.get_by_device_id_cpidinfo( - context, device_obj.id, self.controlpath_id.cpid_info) + context, device_obj.id, self.controlpath_id.cpid_info + ) # find the one cpid_obj with cpid_info return device_obj diff --git a/cyborg/objects/ext_arq.py b/cyborg/objects/ext_arq.py index 8ce8ade4..9a7bef2f 100644 --- a/cyborg/objects/ext_arq.py +++ b/cyborg/objects/ext_arq.py @@ -39,15 +39,20 @@ LOG = logging.getLogger(__name__) @base.CyborgObjectRegistry.register -class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, - utils.FactoryMixin, ExtARQJobMixin): +class ExtARQ( + base.CyborgObject, + object_base.VersionedObjectDictCompat, + utils.FactoryMixin, + ExtARQJobMixin, +): """ExtARQ is a wrapper around ARQ with Cyborg-private fields. - Each ExtARQ object contains exactly one ARQ object as a field. - But, in the db layer, ExtARQ and ARQ are represented together - as a row in a single table. Both share a single UUID. - ExtARQ version is bumped up either if any of its fields change - or if the ARQ version changes. + Each ExtARQ object contains exactly one ARQ object as a field. + But, in the db layer, ExtARQ and ARQ are represented together + as a row in a single table. Both share a single UUID. + ExtARQ version is bumped up either if any of its fields change + or if the ARQ version changes. """ + # Version 1.0: Initial version # 1.1: v2 API and Nova integration # 1.2: Fill the value of deployable_id @@ -62,11 +67,11 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, # later. 'substate': object_fields.StringField(), 'deployable_uuid': object_fields.UUIDField(nullable=True), - # The dp group is copied in to the extarq, so that any changes or # deletions to the device profile do not affect running VMs. 'device_profile_group': object_fields.DictOfStringsField( - nullable=True), + nullable=True + ), # For bound ARQs, we keep the attach handle ID and deployable ID here # so that it is easy to deallocate on unbind or delete. 'attach_handle_id': object_fields.IntegerField(nullable=True), @@ -74,8 +79,7 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, } def obj_make_compatible(self, primitive, target_version): - super().obj_make_compatible( - primitive, target_version) + super().obj_make_compatible(primitive, target_version) target_version = versionutils.convert_version_to_tuple(target_version) # TODO(eric): need to handle v1.1 changes if target_version < (1, 2) and 'deployable_id' in primitive: @@ -90,7 +94,8 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, if 'device_profile_name' not in self.arq and not device_profile_id: raise exception.ObjectActionError( action='create', - reason='Device profile name is required in ARQ') + reason='Device profile name is required in ARQ', + ) self.arq.state = constants.ARQ_INITIAL self.substate = constants.ARQ_INITIAL values = self.obj_get_changes() @@ -120,8 +125,7 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, def list(cls, context, uuid_range=None): """Return a list of ExtARQ objects.""" db_extarqs = cls.dbapi.extarq_list(context, uuid_range) - obj_extarq_list = cls._from_db_object_list( - db_extarqs, context) + obj_extarq_list = cls._from_db_object_list(db_extarqs, context) return obj_extarq_list def save(self, context): @@ -135,13 +139,17 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, updates = self.obj_get_changes() updates["state"] = state db_extarq = self.dbapi.extarq_update( - context, self.arq.uuid, updates, scope) + context, self.arq.uuid, updates, scope + ) self._from_db_object(self, db_extarq, context) def update_check_state(self, context, state, scope=None): if self.arq.state == state: - LOG.info("ExtARQ(%s) state is %s, no need to update", - self.arq.uuid, state) + LOG.info( + "ExtARQ(%s) state is %s, no need to update", + self.arq.uuid, + state, + ) return False old = self.arq.state scope = scope or ARQ_STATES_TRANSFORM_MATRIX[state] @@ -150,14 +158,18 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, if not ea: raise exception.ResourceNotFound( resources='ExtARQ', - msg="Can not find ExtARQ(%s)" % self.arq.uuid) + msg="Can not find ExtARQ(%s)" % self.arq.uuid, + ) current = ea.arq.state if state != current: - msg = ("Failed to change ARQ state from %s to %s, the current " - "state is %s" % (old, state, current)) + msg = ( + "Failed to change ARQ state from %s to %s, the current " + "state is %s" % (old, state, current) + ) LOG.error(msg) raise exception.ARQBadState( - state=current, uuid=self.arq.uuid, expected=list(state)) + state=current, uuid=self.arq.uuid, expected=list(state) + ) return True def destroy(self, context): @@ -188,8 +200,8 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, if unexisted: LOG.warning('There are unexisted arqs: %s', unexisted) raise exception.ResourceNotFound( - resource='ARQ', - msg='with uuids %s' % unexisted) + resource='ARQ', msg='with uuids %s' % unexisted + ) @classmethod def delete_by_instance(cls, context, instance_uuid): @@ -200,11 +212,17 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, not raise an error on the second and later attempts even if the first one has deleted the ARQs. """ - obj_extarqs = [extarq for extarq in objects.ExtARQ.list(context) - if extarq.arq['instance_uuid'] == instance_uuid] + obj_extarqs = [ + extarq + for extarq in objects.ExtARQ.list(context) + if extarq.arq['instance_uuid'] == instance_uuid + ] for obj_extarq in obj_extarqs: - LOG.info('Deleting obj_extarq uuid %s for instance %s', - obj_extarq.arq['uuid'], obj_extarq.arq['instance_uuid']) + LOG.info( + 'Deleting obj_extarq uuid %s for instance %s', + obj_extarq.arq['uuid'], + obj_extarq.arq['instance_uuid'], + ) obj_extarq.unbind(context) obj_extarq.destroy(context) @@ -224,23 +242,33 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, if ah.attach_type == 'MDEV': attach_info = json.loads(ah.attach_info) pci_addr = "{}:{}:{}.{}".format( - attach_info['domain'], attach_info['bus'], - attach_info['device'], attach_info['function']) + attach_info['domain'], + attach_info['bus'], + attach_info['device'], + attach_info['function'], + ) hostname = self.arq.hostname asked_type = attach_info['asked_type'] self.agent.create_vgpu_mdev( - context, hostname, pci_addr, asked_type, ah.uuid) + context, hostname, pci_addr, asked_type, ah.uuid + ) except Exception as e: - LOG.error("Failed to allocate attach handle for ARQ %s" - "from deployable %s. Reason: %s", - self.arq.uuid, deployable.uuid, str(e)) + LOG.error( + "Failed to allocate attach handle for ARQ %s" + "from deployable %s. Reason: %s", + self.arq.uuid, + deployable.uuid, + str(e), + ) # TODO(Shaohe) Rollback? We have _update_placement, # should cancel it. - self.update_check_state( - context, constants.ARQ_BIND_FAILED) + self.update_check_state(context, constants.ARQ_BIND_FAILED) raise - LOG.info('Attach handle(%s) allocate for ARQ(%s) successfully.', - ah.uuid, self.arq.uuid) + LOG.info( + 'Attach handle(%s) allocate for ARQ(%s) successfully.', + ah.uuid, + self.arq.uuid, + ) def bind(self, context, deployable): self._allocate_attach_handle(context, deployable) @@ -248,8 +276,7 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, self.save(context) # ARQ state changes get committed here self.update_check_state(context, constants.ARQ_BOUND) - LOG.info('Update ARQ %s state to "Bound" successfully.', - self.arq.uuid) + LOG.info('Update ARQ %s state to "Bound" successfully.', self.arq.uuid) # TODO(Shaohe) rollback self._unbind and self._delete # if (self.arq.state == constants.ARQ_DELETING # or self.arq.state == ARQ_UNBOUND): @@ -260,20 +287,33 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, if attach_handle.attach_type == 'MDEV': attach_info = json.loads(attach_handle.attach_info) pci_addr = "{}:{}:{}.{}".format( - attach_info['domain'], attach_info['bus'], - attach_info['device'], attach_info['function']) + attach_info['domain'], + attach_info['bus'], + attach_info['device'], + attach_info['function'], + ) self.agent.remove_vgpu_mdev( - context, hostname, pci_addr, - attach_info['asked_type'], attach_handle.uuid) + context, + hostname, + pci_addr, + attach_info['asked_type'], + attach_handle.uuid, + ) attach_handle.deallocate(context) except Exception as e: - LOG.error("Failed to deallocate attach handle %s for ARQ %s." - "Reason: %s", ah_id, self.arq.uuid, str(e)) - self.update_check_state( - context, constants.ARQ_UNBIND_FAILED) + LOG.error( + "Failed to deallocate attach handle %s for ARQ %s.Reason: %s", + ah_id, + self.arq.uuid, + str(e), + ) + self.update_check_state(context, constants.ARQ_UNBIND_FAILED) raise - LOG.info('Attach handle(%s) deallocate for ARQ(%s) successfully.', - ah_id, self.arq.uuid) + LOG.info( + 'Attach handle(%s) deallocate for ARQ(%s) successfully.', + ah_id, + self.arq.uuid, + ) def unbind(self, context): arq = self.arq @@ -294,7 +334,7 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, @classmethod def _fill_obj_extarq_fields(cls, context, db_extarq): """ExtARQ object has some fields that are not present - in db_extarq. We fill them out here. + in db_extarq. We fill them out here. """ # From the 2 fields in the ExtARQ, we obtain other fields. devprof_id = db_extarq['device_profile_id'] @@ -307,7 +347,8 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, db_extarq['attach_handle_info'] = '' if db_extarq['state'] == 'Bound': # TODO() Do proper bind db_ah = cls.dbapi.attach_handle_get_by_id( - context, db_extarq['attach_handle_id']) + context, db_extarq['attach_handle_id'] + ) if db_ah is not None: db_extarq['attach_handle_type'] = db_ah['attach_type'] db_extarq['attach_handle_info'] = db_ah['attach_info'] @@ -315,17 +356,22 @@ class ExtARQ(base.CyborgObject, object_base.VersionedObjectDictCompat, else: raise exception.ResourceNotFound( resource='Attach Handle', - msg='with uuid=%s' % db_extarq['attach_handle_id']) + msg='with uuid=%s' % db_extarq['attach_handle_id'], + ) if db_extarq['deployable_id']: - dep = objects.Deployable.get_by_id(context, - db_extarq['deployable_id']) + dep = objects.Deployable.get_by_id( + context, db_extarq['deployable_id'] + ) db_extarq['deployable_uuid'] = dep.uuid else: - LOG.debug('Setting deployable UUID to zeroes for db_extarq %s', - db_extarq['uuid']) + LOG.debug( + 'Setting deployable UUID to zeroes for db_extarq %s', + db_extarq['uuid'], + ) db_extarq['deployable_uuid'] = ( - '00000000-0000-0000-0000-000000000000') + '00000000-0000-0000-0000-000000000000' + ) groups = devprof['groups'] db_extarq['device_profile_group'] = groups[devprof_group_id] diff --git a/cyborg/objects/extarq/ext_arq_job.py b/cyborg/objects/extarq/ext_arq_job.py index 9b429ed9..1c9d94d7 100644 --- a/cyborg/objects/extarq/ext_arq_job.py +++ b/cyborg/objects/extarq/ext_arq_job.py @@ -61,18 +61,25 @@ class ExtARQJobMixin: # Check whether ARQ can be bound. if self.arq.state not in expected: raise exception.ARQBadState( - state=self.arq.state, uuid=self.arq.uuid, expected=expected) + state=self.arq.state, uuid=self.arq.uuid, expected=expected + ) hostname = valid_fields[self.arq.uuid]['hostname'] devrp_uuid = valid_fields[self.arq.uuid]['device_rp_uuid'] instance_uuid = valid_fields[self.arq.uuid]['instance_uuid'] project_id = valid_fields[self.arq.uuid].get('project_id') - LOG.info('[arqs:objs] bind. hostname: %(hostname)s, ' - 'devrp_uuid: %(devrp_uuid)s, ' - 'instance_uuid: %(instance_uuid)s, ' - 'project_id: %(project_id)s', - {'hostname': hostname, 'devrp_uuid': devrp_uuid, - 'instance_uuid': instance_uuid, 'project_id': project_id}) + LOG.info( + '[arqs:objs] bind. hostname: %(hostname)s, ' + 'devrp_uuid: %(devrp_uuid)s, ' + 'instance_uuid: %(instance_uuid)s, ' + 'project_id: %(project_id)s', + { + 'hostname': hostname, + 'devrp_uuid': devrp_uuid, + 'instance_uuid': instance_uuid, + 'project_id': project_id, + }, + ) self.arq.hostname = hostname self.arq.device_rp_uuid = devrp_uuid @@ -100,11 +107,13 @@ class ExtARQJobMixin: return th_workers = utils.ThreadWorks() works_generator = th_workers.get_workers_result( - jobs.values(), timeout=CONF.bind_timeout) + jobs.values(), timeout=CONF.bind_timeout + ) # arq_binds, timeout=1) LOG.info("Check ARQ(%s) bind jobs status.", arq_uuids) th_workers.spawn_master( - cls.job_monitor, context, works_generator, arq_binds.keys()) + cls.job_monitor, context, works_generator, arq_binds.keys() + ) @classmethod def get_arq_bind_statuses(cls, arq_list): @@ -124,7 +133,8 @@ class ExtARQJobMixin: for arq in arq_list: if arq.state not in good_states: raise exception.ARQBadState( - state=arq.state, uuid=arq.uuid, expected=good_states) + state=arq.state, uuid=arq.uuid, expected=good_states + ) arq_bind_status = (arq.uuid, state_map[arq.state]) arq_bind_statuses.append(arq_bind_status) return arq_bind_statuses @@ -158,10 +168,13 @@ class ExtARQJobMixin: # (not deleted) ARQs among the specified ones. extarqs = cls.list(context, arq_uuids) if len(extarqs) < len(arq_uuids): - LOG.error("ARQs(%s) bind status sync error, status is %s. " - "For some ARQs %s are deleted.", - arq_uuids, constants.ARQ_BIND_STATUS_FAILED, - set(arq_uuids) - set([ea.arq.uuid for ea in extarqs])) + LOG.error( + "ARQs(%s) bind status sync error, status is %s. " + "For some ARQs %s are deleted.", + arq_uuids, + constants.ARQ_BIND_STATUS_FAILED, + set(arq_uuids) - set([ea.arq.uuid for ea in extarqs]), + ) cls.bind_notify(instance_uuid, cls.get_arq_bind_statuses(arq_list)) return @@ -172,19 +185,29 @@ class ExtARQJobMixin: if state in constants.ARQ_PRE_BIND: # OPEN ignore ARQ_OUFOF_BIND_FLOW? status = constants.ARQ_BIND_STATUS_FAILED - LOG.error("ARQs(%s) bind has not finished, status is %s.", - uuid, status) + LOG.error( + "ARQs(%s) bind has not finished, status is %s.", + uuid, + status, + ) break elif state in constants.ARQ_OUFOF_BIND_FLOW + [ - constants.ARQ_BIND_FAILED]: + constants.ARQ_BIND_FAILED + ]: # OPEN ignore ARQ_OUFOF_BIND_FLOW? status = constants.ARQ_BIND_STATUS_FAILED - LOG.error("ARQs(%s) bind status sync error, status is %s.", - uuid, status) + LOG.error( + "ARQs(%s) bind status sync error, status is %s.", + uuid, + status, + ) break elif state == constants.ARQ_BOUND: - LOG.info("ARQs(%s) bind status sync finish, status is %s.", - uuid, status) + LOG.info( + "ARQs(%s) bind status sync finish, status is %s.", + uuid, + status, + ) if status == constants.ARQ_BIND_STATUS_FINISH: LOG.info('All ARQs %s async bind jobs has finished.', arq_uuids) cls.bind_notify(instance_uuid, cls.get_arq_bind_statuses(arq_list)) @@ -219,11 +242,14 @@ class ExtARQJobMixin: group = self.device_profile_group # example: {"resources:CUSTOM_ACCELERATOR_FPGA": "1"} resources = [ - (k.lstrip(constants.RESOURCES_PREFIX), v) for k, v in group.items() - if k.startswith(constants.RESOURCES_PREFIX)] + (k.lstrip(constants.RESOURCES_PREFIX), v) + for k, v in group.items() + if k.startswith(constants.RESOURCES_PREFIX) + ] if not resources: raise exception.InvalidParameterValue( - 'No resources in device_profile_group: %s' % group) + 'No resources in device_profile_group: %s' % group + ) res_type = resources[0][0] return res_type diff --git a/cyborg/objects/extarq/fpga_ext_arq.py b/cyborg/objects/extarq/fpga_ext_arq.py index 371025d4..f892fdcd 100644 --- a/cyborg/objects/extarq/fpga_ext_arq.py +++ b/cyborg/objects/extarq/fpga_ext_arq.py @@ -17,7 +17,6 @@ Different accelerator handlers for conductor/agent/api/object to call. """ - from openstack import connection from oslo_log import log as logging from oslo_serialization import jsonutils @@ -42,12 +41,14 @@ class FPGAExtARQ(ExtARQ): def _get_bitstream_id(self): bitstream_id = self.device_profile_group.get( - constants.ACCEL_BITSTREAM_ID) + constants.ACCEL_BITSTREAM_ID + ) return bitstream_id def _get_function_id(self): function_id = self.device_profile_group.get( - constants.ACCEL_FUNCTION_ID) + constants.ACCEL_FUNCTION_ID + ) return function_id def _get_bitstream_md_from_bitstream_id(self, bitstream_id): @@ -57,8 +58,7 @@ class FPGAExtARQ(ExtARQ): if resp: return resp.json() else: - LOG.warning('Failed to get image for bitstream (%s)', - bitstream_id) + LOG.warning('Failed to get image for bitstream (%s)', bitstream_id) return None # TODO(Shaohe) should move to spec handler. @@ -72,52 +72,62 @@ class FPGAExtARQ(ExtARQ): image_list = resp.json()['images'] if not isinstance(image_list, list): raise exception.InvalidType( - obj='image', type=type(image_list), - expected='list') + obj='image', type=type(image_list), expected='list' + ) if len(image_list) != 1: - raise exception.ExpectedOneObject(obj='image', - count=len(image_list)) - LOG.info('[arqs:objs] For function id (%s), got ' - 'bitstream id (%s)', function_id, - image_list[0]['id']) + raise exception.ExpectedOneObject( + obj='image', count=len(image_list) + ) + LOG.info( + '[arqs:objs] For function id (%s), got bitstream id (%s)', + function_id, + image_list[0]['id'], + ) return image_list[0] else: - LOG.warning('Failed to get image for function (%s)', - function_id) + LOG.warning('Failed to get image for function (%s)', function_id) return None def _needs_programming(self, context, deployable): bs_id = self._get_bitstream_id() fun_id = self._get_function_id() if all([bs_id, fun_id]): - self.update_check_state( - context, constants.ARQ_BIND_FAILED) + self.update_check_state(context, constants.ARQ_BIND_FAILED) raise exception.InvalidParameterValue( 'In device profile {0}, only one among bitstream_id ' - 'and function_id must be set, but both are set') + 'and function_id must be set, but both are set' + ) # TODO(Shaohe) Optimize: check if deployable already has # bitstream/function if any([bs_id, fun_id]): - LOG.info('[arqs:objs] bind. Programming needed. ' - 'bitstream: (%s) function: (%s) Deployable UUID: (%s)', - bs_id or '', fun_id or '', deployable.uuid) + LOG.info( + '[arqs:objs] bind. Programming needed. ' + 'bitstream: (%s) function: (%s) Deployable UUID: (%s)', + bs_id or '', + fun_id or '', + deployable.uuid, + ) else: # One situation is that fun_id is zero and device_profile # has't bitstream. We should return False. LOG.info('No programming is required. ') return False if bs_id and deployable.bitstream_id == bs_id: - LOG.info('Deployable %s already has the needed ' - 'bitstream %s. Skipping programming.', - deployable.uuid, bs_id) + LOG.info( + 'Deployable %s already has the needed ' + 'bitstream %s. Skipping programming.', + deployable.uuid, + bs_id, + ) return False return True def get_bitstream_md(self, context, deployable, function_id, bitstream_id): """Get bitstream metadata from FPGA image.""" - LOG.info("Get bitstream metadata for deployable(uuid:%s).", - deployable.uuid) + LOG.info( + "Get bitstream metadata for deployable(uuid:%s).", deployable.uuid + ) # TODO(Shaohe) Check that deployable.device.hostname matches param # hostname out of here if not self._needs_programming(context, deployable): @@ -126,16 +136,22 @@ class FPGAExtARQ(ExtARQ): # FPGA aaS or accelerated Function aaS bitstream_md = ( self._get_bitstream_md_from_bitstream_id(bitstream_id) - if bitstream_id else - self._get_bitstream_md_from_function_id(function_id)) + if bitstream_id + else self._get_bitstream_md_from_function_id(function_id) + ) if bitstream_md: - LOG.info('ARQ %s get bitstream metadata:%s from image registry.', - self.arq.uuid, bitstream_md) + LOG.info( + 'ARQ %s get bitstream metadata:%s from image registry.', + self.arq.uuid, + bitstream_md, + ) else: - self.update_check_state( - context, constants.ARQ_BIND_FAILED) - LOG.error('Can not get bitstream metadata from image registry ' - 'for ARQ %s', self.arq.uuid) + self.update_check_state(context, constants.ARQ_BIND_FAILED) + LOG.error( + 'Can not get bitstream metadata from image registry ' + 'for ARQ %s', + self.arq.uuid, + ) return bitstream_md def _need_extra_bind_job(self, context, deployable): @@ -143,8 +159,11 @@ class FPGAExtARQ(ExtARQ): @utils.wrap_job_tb("Error during ARQ bind job. Reason: %s") def bind(self, context, deployable): - LOG.info('Start bind jobs for ARQ(%s) with deployable(%s)', - self.arq.uuid, deployable.uuid) + LOG.info( + 'Start bind jobs for ARQ(%s) with deployable(%s)', + self.arq.uuid, + deployable.uuid, + ) bs_id = self._get_bitstream_id() fun_id = self._get_function_id() bs_md = self.get_bitstream_md(context, deployable, fun_id, bs_id) @@ -172,42 +191,52 @@ class FPGAExtARQ(ExtARQ): """update resources provider after program.""" # TODO(Sundar) Don't apply function trait if bitstream is private if not function_id: - LOG.info("Not get function id for resources provider %s.", - self.arq.device_rp_uuid) + LOG.info( + "Not get function id for resources provider %s.", + self.arq.device_rp_uuid, + ) return placement = placement_client.PlacementClient() try: placement.delete_traits_with_prefixes( - context, self.arq.device_rp_uuid, - [constants.FPGA_FUNCTION_ID]) + context, self.arq.device_rp_uuid, [constants.FPGA_FUNCTION_ID] + ) except Exception as e: - LOG.error("Failed to delete traits(%s) from resources provider %s." - "Reason: %s", constants.FPGA_FUNCTION_ID, - self.arq.device_rp_uuid, e.message) - self.update_check_state( - context, constants.ARQ_BIND_FAILED) + LOG.error( + "Failed to delete traits(%s) from resources provider %s." + "Reason: %s", + constants.FPGA_FUNCTION_ID, + self.arq.device_rp_uuid, + e.message, + ) + self.update_check_state(context, constants.ARQ_BIND_FAILED) raise function_id = function_id.upper().replace('-', '_-') # TODO(Sundar) Validate this is a valid trait name vendor = driver_name.upper() - trait_names = ["_".join(( - constants.FPGA_FUNCTION_ID, vendor, function_id))] + trait_names = [ + "_".join((constants.FPGA_FUNCTION_ID, vendor, function_id)) + ] try: - placement.add_traits_to_rp( - self.arq.device_rp_uuid, trait_names) + placement.add_traits_to_rp(self.arq.device_rp_uuid, trait_names) except Exception as e: - LOG.error("Failed to add traits(%s) to resources provider %s." - "Reason: %s", trait_names, - self.arq.device_rp_uuid, e.message) + LOG.error( + "Failed to add traits(%s) to resources provider %s.Reason: %s", + trait_names, + self.arq.device_rp_uuid, + e.message, + ) # TODO(Shaohe) Rollback? We have _update_placement, # should cancel it. - self.update_check_state( - context, constants.ARQ_BIND_FAILED) + self.update_check_state(context, constants.ARQ_BIND_FAILED) raise - LOG.info("Add traits(%s) to resources provider %s.", - trait_names, self.arq.device_rp_uuid) + LOG.info( + "Add traits(%s) to resources provider %s.", + trait_names, + self.arq.device_rp_uuid, + ) def _do_programming(self, context, deployable, bitstream_id): """FPGA program.""" @@ -220,31 +249,42 @@ class FPGAExtARQ(ExtARQ): cpid_list = deployable.get_cpid_list(context) count = len(cpid_list) if count != 1: - self.update_check_state( - context, constants.ARQ_BIND_FAILED) - raise exception.ExpectedOneObject(obj='controlpath_id', - count=count) + self.update_check_state(context, constants.ARQ_BIND_FAILED) + raise exception.ExpectedOneObject( + obj='controlpath_id', count=count + ) controlpath_id = cpid_list[0] controlpath_id['cpid_info'] = jsonutils.loads( - controlpath_id['cpid_info']) + controlpath_id['cpid_info'] + ) LOG.info('Found control path id: %s', controlpath_id) - LOG.info('Starting programming for host: (%s) deployable (%s) ' - 'bitstream_id (%s)', hostname, - deployable.uuid, bitstream_id) + LOG.info( + 'Starting programming for host: (%s) deployable (%s) ' + 'bitstream_id (%s)', + hostname, + deployable.uuid, + bitstream_id, + ) # TODO(Shaohe) do this asynchronously, do this in conductor or agent? try: agent = AgentAPI() - agent.fpga_program(context, hostname, - controlpath_id, bitstream_id, - driver_name) + agent.fpga_program( + context, hostname, controlpath_id, bitstream_id, driver_name + ) except Exception as e: - self.update_check_state( - context, constants.ARQ_BIND_FAILED) - LOG.error('Failed programming for host: (%s) deployable (%s). ' - 'Error: %s', hostname, deployable.uuid, e.message) + self.update_check_state(context, constants.ARQ_BIND_FAILED) + LOG.error( + 'Failed programming for host: (%s) deployable (%s). Error: %s', + hostname, + deployable.uuid, + e.message, + ) raise - LOG.info('Finished programming for host: (%s) deployable (%s)', - hostname, deployable.uuid) + LOG.info( + 'Finished programming for host: (%s) deployable (%s)', + hostname, + deployable.uuid, + ) # TODO(Shaohe) propagate agent errors to caller return True diff --git a/cyborg/objects/fields.py b/cyborg/objects/fields.py index 3ced4995..64b142fd 100644 --- a/cyborg/objects/fields.py +++ b/cyborg/objects/fields.py @@ -32,7 +32,8 @@ IPAddressField = object_fields.IPAddressField IPNetworkField = object_fields.IPNetworkField UnspecifiedDefault = object_fields.UnspecifiedDefault ListOfDictOfNullableStringsField = ( - object_fields.ListOfDictOfNullableStringsField) + object_fields.ListOfDictOfNullableStringsField +) class ARQState(object_fields.Enum): diff --git a/cyborg/policies/base.py b/cyborg/policies/base.py index 7a438010..3949ffdc 100644 --- a/cyborg/policies/base.py +++ b/cyborg/policies/base.py @@ -36,7 +36,7 @@ DEPRECATED_ADMIN_OR_OWNER = policy.DeprecatedRule( name=ADMIN_OR_OWNER, check_str='is_admin:True or project_id:%(project_id)s', deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY + deprecated_since=versionutils.deprecated.WALLABY, ) deprecated_default_policies = [ @@ -47,7 +47,8 @@ deprecated_default_policies = [ description='legacy rule of Internal flag for public API routes', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), # The policy check "@" will always accept an access. The empty list # (``[]``) or the empty string (``""``) is equivalent to the "@" policy.RuleDefault( @@ -56,7 +57,8 @@ deprecated_default_policies = [ description='legacy rule: any access will be passed', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), # the policy check "!" will always reject an access. policy.RuleDefault( name='deny', @@ -64,35 +66,40 @@ deprecated_default_policies = [ description='legacy rule: all access will be forbidden', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), policy.RuleDefault( name='default', check_str='rule:admin_or_owner', description='Legacy rule for default rule', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), policy.RuleDefault( name='is_admin', check_str='rule:admin_api', description='Full read/write API access', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), policy.RuleDefault( name='admin_or_owner', check_str='is_admin:True or project_id:%(project_id)s', description='Admin or owner API access', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), policy.RuleDefault( name='admin_or_user', check_str='is_admin:True or user_id:%(user_id)s', description='Admin or user API access', deprecated_for_removal=True, deprecated_reason=DEPRECATED_REASON, - deprecated_since=versionutils.deprecated.WALLABY), + deprecated_since=versionutils.deprecated.WALLABY, + ), ] ADMIN = 'rule:admin_api' @@ -120,32 +127,37 @@ default_policies = [ policy.RuleDefault( name='admin_api', check_str='role:admin or role:administrator', - description='Legacy rule for cloud admin access'), + description='Legacy rule for cloud admin access', + ), policy.RuleDefault( name="project_admin_api", check_str="role:admin and project_id:%(project_id)s", - description="Default rule for Project level admin APIs."), + description="Default rule for Project level admin APIs.", + ), policy.RuleDefault( name="project_member_api", check_str="role:member and project_id:%(project_id)s", - description="Default rule for Project level non admin APIs."), + description="Default rule for Project level non admin APIs.", + ), policy.RuleDefault( name="project_reader_api", check_str="role:reader and project_id:%(project_id)s", - description="Default rule for Project level read only APIs."), + description="Default rule for Project level read only APIs.", + ), policy.RuleDefault( "project_member_or_admin", "rule:project_member_api or rule:admin_api", "Default rule for Project Member or admin APIs.", - deprecated_rule=DEPRECATED_ADMIN_OR_OWNER), + deprecated_rule=DEPRECATED_ADMIN_OR_OWNER, + ), policy.RuleDefault( "project_reader_or_admin", "rule:project_reader_api or rule:admin_api", "Default rule for Project reader or admin APIs.", - deprecated_rule=DEPRECATED_ADMIN_OR_OWNER) + deprecated_rule=DEPRECATED_ADMIN_OR_OWNER, + ), ] def list_policies(): - return default_policies \ - + deprecated_default_policies + return default_policies + deprecated_default_policies diff --git a/cyborg/policies/device_profiles.py b/cyborg/policies/device_profiles.py index 9332cd07..580444fd 100644 --- a/cyborg/policies/device_profiles.py +++ b/cyborg/policies/device_profiles.py @@ -33,27 +33,38 @@ from cyborg.policies import base deprecated_get_all = policy.DeprecatedRule( name='cyborg:device_profile:get_all', check_str=base.deprecated_default, - deprecated_reason=('request admin_or_owmer rule is too strict for ' - 'listing device_profile'), - deprecated_since=versionutils.deprecated.WALLABY) + deprecated_reason=( + 'request admin_or_owmer rule is too strict for listing device_profile' + ), + deprecated_since=versionutils.deprecated.WALLABY, +) deprecated_get_one = policy.DeprecatedRule( name='cyborg:device_profile:get_one', check_str=base.deprecated_default, - deprecated_reason=('request admin_or_owmer rule is too strict for ' - 'retrieving a device_profile'), - deprecated_since=versionutils.deprecated.WALLABY) + deprecated_reason=( + 'request admin_or_owmer rule is too strict for ' + 'retrieving a device_profile' + ), + deprecated_since=versionutils.deprecated.WALLABY, +) deprecated_create = policy.DeprecatedRule( name='cyborg:device_profile:create', check_str=base.deprecated_is_admin, - deprecated_reason=('project_admin_or_owner is too permissive, ' - 'introduce admin for creation'), - deprecated_since=versionutils.deprecated.WALLABY) + deprecated_reason=( + 'project_admin_or_owner is too permissive, ' + 'introduce admin for creation' + ), + deprecated_since=versionutils.deprecated.WALLABY, +) deprecated_delete = policy.DeprecatedRule( name='cyborg:device_profile:delete', check_str=base.deprecated_default, - deprecated_reason=('project_admin_or_owner is too permissive, ' - 'introduce admin for deletion'), - deprecated_since=versionutils.deprecated.WALLABY) + deprecated_reason=( + 'project_admin_or_owner is too permissive, ' + 'introduce admin for deletion' + ), + deprecated_since=versionutils.deprecated.WALLABY, +) # new device_profile policies device_profile_policies = [ @@ -61,13 +72,10 @@ device_profile_policies = [ name='cyborg:device_profile:get_all', check_str=base.PROJECT_READER_OR_ADMIN, description='Retrieve all device_profiles', - operations=[ - { - 'path': '/v2/device_profiles', - 'method': 'GET' - }], + operations=[{'path': '/v2/device_profiles', 'method': 'GET'}], scope_types=['project'], - deprecated_rule=deprecated_get_all), + deprecated_rule=deprecated_get_all, + ), policy.DocumentedRuleDefault( name='cyborg:device_profile:get_one', check_str=base.PROJECT_READER_OR_ADMIN, @@ -75,21 +83,20 @@ device_profile_policies = [ operations=[ { 'path': '/v2/device_profiles/{device_profiles_uuid}', - 'method': 'GET' - }], + 'method': 'GET', + } + ], scope_types=['project'], - deprecated_rule=deprecated_get_one), + deprecated_rule=deprecated_get_one, + ), policy.DocumentedRuleDefault( name='cyborg:device_profile:create', check_str=base.ADMIN, description='Create a device_profile', - operations=[ - { - 'path': '/v2/device_profiles', - 'method': 'POST' - }], + operations=[{'path': '/v2/device_profiles', 'method': 'POST'}], scope_types=['project'], - deprecated_rule=deprecated_create), + deprecated_rule=deprecated_create, + ), policy.DocumentedRuleDefault( name='cyborg:device_profile:delete', check_str=base.ADMIN, @@ -97,13 +104,16 @@ device_profile_policies = [ operations=[ { 'path': '/v2/device_profiles/{device_profiles_uuid}', - 'method': 'DELETE'}, + 'method': 'DELETE', + }, { 'path': '/v2/device_profiles?value={device_profile_name1}', - 'method': 'DELETE'}, - ], + 'method': 'DELETE', + }, + ], scope_types=['project'], - deprecated_rule=deprecated_delete), + deprecated_rule=deprecated_delete, + ), ] diff --git a/cyborg/privsep/__init__.py b/cyborg/privsep/__init__.py index 7020a8eb..e6ff8bec 100644 --- a/cyborg/privsep/__init__.py +++ b/cyborg/privsep/__init__.py @@ -23,9 +23,11 @@ sys_admin_pctxt = priv_context.PrivContext( # TODO(yumeng): # CAP_SYS_ADMIN has a lot of scary powers, so # consider breaking this out into a separate minimal context. - capabilities=[capabilities.CAP_CHOWN, - capabilities.CAP_DAC_OVERRIDE, - capabilities.CAP_DAC_READ_SEARCH, - capabilities.CAP_FOWNER, - capabilities.CAP_SYS_ADMIN], + capabilities=[ + capabilities.CAP_CHOWN, + capabilities.CAP_DAC_OVERRIDE, + capabilities.CAP_DAC_READ_SEARCH, + capabilities.CAP_FOWNER, + capabilities.CAP_SYS_ADMIN, + ], ) diff --git a/cyborg/quota.py b/cyborg/quota.py index c0f6939e..6f48be54 100644 --- a/cyborg/quota.py +++ b/cyborg/quota.py @@ -23,25 +23,37 @@ from cyborg import db as db_api LOG = logging.getLogger(__name__) quota_opts = [ - cfg.IntOpt('reservation_expire', - default=86400, - help='Number of seconds until a reservation expires'), - cfg.IntOpt('until_refresh', - default=0, - help='Count of reservations until usage is refreshed'), - cfg.StrOpt('quota_driver', - default="cyborg.quota.DbQuotaDriver", - help='Default driver to use for quota checks'), - cfg.IntOpt('quota_fpgas', - default=10, - help='Total amount of fpga allowed per project'), - cfg.IntOpt('quota_gpus', - default=10, - help='Total amount of storage allowed per project'), - cfg.IntOpt('max_age', - default=0, - help='Number of seconds between subsequent usage refreshes') - ] + cfg.IntOpt( + 'reservation_expire', + default=86400, + help='Number of seconds until a reservation expires', + ), + cfg.IntOpt( + 'until_refresh', + default=0, + help='Count of reservations until usage is refreshed', + ), + cfg.StrOpt( + 'quota_driver', + default="cyborg.quota.DbQuotaDriver", + help='Default driver to use for quota checks', + ), + cfg.IntOpt( + 'quota_fpgas', + default=10, + help='Total amount of fpga allowed per project', + ), + cfg.IntOpt( + 'quota_gpus', + default=10, + help='Total amount of storage allowed per project', + ), + cfg.IntOpt( + 'max_age', + default=0, + help='Number of seconds between subsequent usage refreshes', + ), +] CONF = cfg.CONF CONF.register_opts(quota_opts) @@ -101,9 +113,13 @@ class QuotaEngine: """ if not project_id: project_id = context.project_id - reservations = self._driver.reserve(context, self._resources, deltas, - expire=expire, - project_id=project_id) + reservations = self._driver.reserve( + context, + self._resources, + deltas, + expire=expire, + project_id=project_id, + ) LOG.debug("Created reservations %s", reservations) @@ -139,10 +155,12 @@ class DbQuotaDriver: Also allows to obtain quota information. The default driver utilizes the local database. """ + dbapi = db_api.get_instance() - def reserve(self, context, resources, deltas, expire=None, - project_id=None): + def reserve( + self, context, resources, deltas, expire=None, project_id=None + ): # Set up the reservation expiration if expire is None: expire = CONF.reservation_expire @@ -157,13 +175,18 @@ class DbQuotaDriver: if project_id is None: project_id = context.project_id - return self._reserve(context, resources, deltas, expire, - project_id) + return self._reserve(context, resources, deltas, expire, project_id) def _reserve(self, context, resources, deltas, expire, project_id): - return self.dbapi.quota_reserve(context, resources, deltas, expire, - CONF.until_refresh, CONF.max_age, - project_id=project_id) + return self.dbapi.quota_reserve( + context, + resources, + deltas, + expire, + CONF.until_refresh, + CONF.max_age, + project_id=project_id, + ) def commit(self, context, reservations, project_id=None): """Commit reservations. @@ -177,8 +200,9 @@ class DbQuotaDriver: """ try: - self.dbapi.reservation_commit(context, reservations, - project_id=project_id) + self.dbapi.reservation_commit( + context, reservations, project_id=project_id + ) except Exception: # NOTE(Vek): Ignoring exceptions here is safe, because the # usage resynchronization and the reservation expiration diff --git a/cyborg/service_auth.py b/cyborg/service_auth.py index 5eb71583..e7d05398 100644 --- a/cyborg/service_auth.py +++ b/cyborg/service_auth.py @@ -36,19 +36,19 @@ def get_auth_plugin(context): if CONF.service_user.send_service_user_token: global _SERVICE_AUTH if not _SERVICE_AUTH: - _SERVICE_AUTH = ks_loading.\ - load_auth_from_conf_options(CONF, - group=cyborg. - conf.service_token. - SERVICE_USER_GROUP) + _SERVICE_AUTH = ks_loading.load_auth_from_conf_options( + CONF, group=cyborg.conf.service_token.SERVICE_USER_GROUP + ) if _SERVICE_AUTH is None: # This indicates a misconfiguration so log a warning and # return the user_auth. - LOG.warning('Unable to load auth from [service_user] ' - 'configuration. Ensure "auth_type" is set.') + LOG.warning( + 'Unable to load auth from [service_user] ' + 'configuration. Ensure "auth_type" is set.' + ) return user_auth - return service_token.\ - ServiceTokenAuthWrapper(user_auth=user_auth, - service_auth=_SERVICE_AUTH) + return service_token.ServiceTokenAuthWrapper( + user_auth=user_auth, service_auth=_SERVICE_AUTH + ) return user_auth diff --git a/cyborg/tests/base.py b/cyborg/tests/base.py index c2fdd348..98a6cc27 100644 --- a/cyborg/tests/base.py +++ b/cyborg/tests/base.py @@ -50,13 +50,11 @@ class TestCase(base.BaseTestCase): def _set_config(self): self.cfg_fixture = self.useFixture(config_fixture.Config(cfg.CONF)) - self.config(use_stderr=False, - fatal_exception_format_errors=True) - self.set_defaults(host='fake-mini', - debug=True) - self.set_defaults(connection="sqlite://", - sqlite_synchronous=False, - group='database') + self.config(use_stderr=False, fatal_exception_format_errors=True) + self.set_defaults(host='fake-mini', debug=True) + self.set_defaults( + connection="sqlite://", sqlite_synchronous=False, group='database' + ) cyborg_config.parse_args([], default_config_files=[]) def config(self, **kw): @@ -77,7 +75,7 @@ class TestCase(base.BaseTestCase): """ root = os.path.abspath( os.path.join(os.path.dirname(__file__), '..', '..') - ) + ) if project_file: return os.path.join(root, project_file) else: @@ -100,8 +98,9 @@ class DietTestCase(base.BaseTestCase): debugger = os.environ.get('OS_POST_MORTEM_DEBUGGER') if debugger: - self.addOnException(post_mortem_debug.get_exception_handler( - debugger)) + self.addOnException( + post_mortem_debug.get_exception_handler(debugger) + ) self.addCleanup(mock.patch.stopall) @@ -109,15 +108,17 @@ class DietTestCase(base.BaseTestCase): self.orig_pid = os.getpid() def addOnException(self, handler): - def safe_handler(*args, **kwargs): try: return handler(*args, **kwargs) except Exception: with excutils.save_and_reraise_exception(reraise=False) as ctx: - self.addDetail('Failure in exception handler %s' % handler, - testtools.content.TracebackContent( - (ctx.type_, ctx.value, ctx.tb), self)) + self.addDetail( + 'Failure in exception handler %s' % handler, + testtools.content.TracebackContent( + (ctx.type_, ctx.value, ctx.tb), self + ), + ) return super().addOnException(safe_handler) @@ -151,13 +152,20 @@ class DietTestCase(base.BaseTestCase): be reported upon failure. """ if not isinstance(expected_subset, dict): - self.fail("expected_subset (%s) is not an instance of dict" % - type(expected_subset)) + self.fail( + "expected_subset (%s) is not an instance of dict" + % type(expected_subset) + ) if not isinstance(actual_superset, dict): - self.fail("actual_superset (%s) is not an instance of dict" % - type(actual_superset)) + self.fail( + "actual_superset (%s) is not an instance of dict" + % type(actual_superset) + ) for k, v in expected_subset.items(): self.assertIn(k, actual_superset) - self.assertEqual(v, actual_superset[k], - "Key %(key)s expected: %(exp)r, actual %(act)r" % - {'key': k, 'exp': v, 'act': actual_superset[k]}) + self.assertEqual( + v, + actual_superset[k], + "Key %(key)s expected: %(exp)r, actual %(act)r" + % {'key': k, 'exp': v, 'act': actual_superset[k]}, + ) diff --git a/cyborg/tests/post_mortem_debug.py b/cyborg/tests/post_mortem_debug.py index 8424a594..068913d1 100644 --- a/cyborg/tests/post_mortem_debug.py +++ b/cyborg/tests/post_mortem_debug.py @@ -26,13 +26,15 @@ def _get_debugger(debugger_name): try: debugger = __import__(debugger_name) except ImportError: - raise ValueError("can't import %s module as a post mortem debugger" % - debugger_name) + raise ValueError( + "can't import %s module as a post mortem debugger" % debugger_name + ) if 'post_mortem' in dir(debugger): return debugger else: - raise ValueError("%s is not a supported post mortem debugger" % - debugger_name) + raise ValueError( + "%s is not a supported post mortem debugger" % debugger_name + ) def _exception_handler(debugger, exc_info): diff --git a/cyborg/tests/unit/accelerator/common/test_utils.py b/cyborg/tests/unit/accelerator/common/test_utils.py index 1f781f0e..9a7e220a 100644 --- a/cyborg/tests/unit/accelerator/common/test_utils.py +++ b/cyborg/tests/unit/accelerator/common/test_utils.py @@ -25,25 +25,32 @@ class TestUtils(unittest.TestCase): def test_pci_str_to_json(self): pci_address = '0000:0b:00.0' - json_str = '{"bus": "0b", "device": "00", "domain": "0000", ' \ - '"function": "0"}' + json_str = ( + '{"bus": "0b", "device": "00", "domain": "0000", "function": "0"}' + ) result = self.utils.pci_str_to_json(pci_address) self.assertEqual(result, json_str) pci_address = '0000:0b:00.1' - json_str = '{"bus": "0b", "device": "00", "domain": "0000", ' \ - '"function": "1", "physical_network": "physnet"}' + json_str = ( + '{"bus": "0b", "device": "00", "domain": "0000", ' + '"function": "1", "physical_network": "physnet"}' + ) result = self.utils.pci_str_to_json(pci_address, 'physnet') self.assertEqual(result, json_str) def test_mdev_str_to_json(self): - json_str = '{"asked_type": "type", "bus": "0b", "device": "00", ' \ - '"domain": "0000", "function": "1", "vgpu_mark": "mask"}' + json_str = ( + '{"asked_type": "type", "bus": "0b", "device": "00", ' + '"domain": "0000", "function": "1", "vgpu_mark": "mask"}' + ) result = self.utils.mdev_str_to_json('0000:0b:00.1', 'type', 'mask') self.assertEqual(result, json_str) - json_str = '{"asked_type": null, "bus": "0b", "device": "00", ' \ - '"domain": "0000", "function": "1", "vgpu_mark": null}' + json_str = ( + '{"asked_type": null, "bus": "0b", "device": "00", ' + '"domain": "0000", "function": "1", "vgpu_mark": null}' + ) result = self.utils.mdev_str_to_json('0000:0b:00.1', None, None) self.assertEqual(result, json_str) diff --git a/cyborg/tests/unit/accelerator/drivers/aichip/huawei/test_ascend.py b/cyborg/tests/unit/accelerator/drivers/aichip/huawei/test_ascend.py index cac8723b..2f8f2f0a 100644 --- a/cyborg/tests/unit/accelerator/drivers/aichip/huawei/test_ascend.py +++ b/cyborg/tests/unit/accelerator/drivers/aichip/huawei/test_ascend.py @@ -20,14 +20,15 @@ d100_pci_res = ( '0000:00:0c.0 Processing accelerators [1200]:' ' Device [19e5:d100] (rev 20)\n' '0000:00:0d.0 Processing accelerators [1200]:' - ' Device [19e5:d100] (rev 20)\n',) + ' Device [19e5:d100] (rev 20)\n', +) class TestAscendDriver(base.TestCase): - - @mock.patch('cyborg.accelerator.drivers.aichip.' - 'huawei.ascend.lspci_privileged', - return_value=d100_pci_res) + @mock.patch( + 'cyborg.accelerator.drivers.aichip.huawei.ascend.lspci_privileged', + return_value=d100_pci_res, + ) def test_discover(self, mock_pci): ascend_driver = AscendDriver() npu_list = ascend_driver.discover() @@ -35,19 +36,26 @@ class TestAscendDriver(base.TestCase): for ascend in npu_list: self.assertEqual('AICHIP', ascend.type) self.assertEqual('PCI', ascend.controlpath_id.cpid_type) - self.assertEqual(json.loads( - '{"class": "Processing accelerators", "device_id": "d100"}'), - json.loads(ascend.std_board_info)) + self.assertEqual( + json.loads( + '{"class": "Processing accelerators", "device_id": "d100"}' + ), + json.loads(ascend.std_board_info), + ) self.assertEqual('19e5', ascend.vendor) self.assertEqual( {"device": "0c", "bus": "00", "domain": "0000", "function": "0"}, - json.loads(npu_list[0].controlpath_id.cpid_info)) + json.loads(npu_list[0].controlpath_id.cpid_info), + ) self.assertEqual( {"device": "0d", "bus": "00", "domain": "0000", "function": "0"}, - json.loads(npu_list[1].controlpath_id.cpid_info)) + json.loads(npu_list[1].controlpath_id.cpid_info), + ) - self.assertEqual('Device_0000_00_0c_0', - npu_list[0].deployable_list[0].name) - self.assertEqual('Device_0000_00_0d_0', - npu_list[1].deployable_list[0].name) + self.assertEqual( + 'Device_0000_00_0c_0', npu_list[0].deployable_list[0].name + ) + self.assertEqual( + 'Device_0000_00_0d_0', npu_list[1].deployable_list[0].name + ) diff --git a/cyborg/tests/unit/accelerator/drivers/fpga/inspur/test_driver.py b/cyborg/tests/unit/accelerator/drivers/fpga/inspur/test_driver.py index 8bca5905..18705304 100644 --- a/cyborg/tests/unit/accelerator/drivers/fpga/inspur/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/fpga/inspur/test_driver.py @@ -19,9 +19,11 @@ from cyborg.accelerator.drivers.fpga.inspur.driver import InspurFPGADriver from cyborg.accelerator.drivers.fpga.inspur import sysinfo from cyborg.tests import base -INSPUR_FPGA_INFO = ("0000:86:00.0 Processing accelerators [1200]: " - "Inspur Electronic Information Industry Co., Ltd. " - "Device [1bd4:a115] (rev 04)") +INSPUR_FPGA_INFO = ( + "0000:86:00.0 Processing accelerators [1200]: " + "Inspur Electronic Information Industry Co., Ltd. " + "Device [1bd4:a115] (rev 04)" +) class stdout: @@ -38,25 +40,27 @@ class p: class TestInspurFPGADriver(base.TestCase): - def setUp(self): super().setUp() self.p = p() - @mock.patch('cyborg.accelerator.drivers.fpga.' - 'inspur.sysinfo.lspci_privileged') + @mock.patch( + 'cyborg.accelerator.drivers.fpga.inspur.sysinfo.lspci_privileged' + ) def test_discover(self, mock_devices_for_vendor): mock_devices_for_vendor.return_value = self.p.stdout.readlines() self.set_defaults(host='host-192-168-32-195', debug=True) fpga_list = InspurFPGADriver().discover() self.assertEqual(1, len(fpga_list)) attach_handle_list = [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "86", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "86", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } ] attribute_list = [ {'key': 'rc', 'value': 'FPGA'}, @@ -66,65 +70,90 @@ class TestInspurFPGADriver(base.TestCase): expected = { 'vendor': '1bd4', 'type': 'FPGA', - 'std_board_info': {"controller": "Processing accelerators", - "product_id": "a115"}, + 'std_board_info': { + "controller": "Processing accelerators", + "product_id": "a115", + }, 'vendor_board_info': {"vendor_info": "fpga_vb_info"}, - 'deployable_list': - [ - { - 'num_accelerators': 1, - 'driver_name': 'INSPUR', - 'name': 'host-192-168-32-195_0000:86:00.0', - 'attach_handle_list': attach_handle_list, - 'attribute_list': attribute_list - }, - ], - 'controlpath_id': {'cpid_info': '{"bus": "86", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} + 'deployable_list': [ + { + 'num_accelerators': 1, + 'driver_name': 'INSPUR', + 'name': 'host-192-168-32-195_0000:86:00.0', + 'attach_handle_list': attach_handle_list, + 'attribute_list': attribute_list, + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "86", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, } fpga_obj = fpga_list[0] fpga_dict = fpga_obj.as_dict() fpga_dep_list = fpga_dict['deployable_list'] - fpga_attach_handle_list = ( - fpga_dep_list[0].as_dict()['attach_handle_list']) + fpga_attach_handle_list = fpga_dep_list[0].as_dict()[ + 'attach_handle_list' + ] fpga_attribute_list = fpga_dep_list[0].as_dict()['attribute_list'] attri_obj_data = [] [attri_obj_data.append(attr.as_dict()) for attr in fpga_attribute_list] attribute_actual_data = sorted(attri_obj_data, key=lambda i: i['key']) self.assertEqual(expected['vendor'], fpga_dict['vendor']) - self.assertEqual(expected['controlpath_id'], - fpga_dict['controlpath_id']) - self.assertEqual(expected['std_board_info'], - jsonutils.loads(fpga_dict['std_board_info'])) - self.assertEqual(expected['vendor_board_info'], - jsonutils.loads(fpga_dict['vendor_board_info'])) - self.assertEqual(expected['deployable_list'][0]['num_accelerators'], - fpga_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(expected['deployable_list'][0]['name'], - fpga_dep_list[0].as_dict()['name']) - self.assertEqual(expected['deployable_list'][0]['driver_name'], - fpga_dep_list[0].as_dict()['driver_name']) - self.assertEqual(attach_handle_list[0], - fpga_attach_handle_list[0].as_dict()) + self.assertEqual( + expected['controlpath_id'], fpga_dict['controlpath_id'] + ) + self.assertEqual( + expected['std_board_info'], + jsonutils.loads(fpga_dict['std_board_info']), + ) + self.assertEqual( + expected['vendor_board_info'], + jsonutils.loads(fpga_dict['vendor_board_info']), + ) + self.assertEqual( + expected['deployable_list'][0]['num_accelerators'], + fpga_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + expected['deployable_list'][0]['name'], + fpga_dep_list[0].as_dict()['name'], + ) + self.assertEqual( + expected['deployable_list'][0]['driver_name'], + fpga_dep_list[0].as_dict()['driver_name'], + ) + self.assertEqual( + attach_handle_list[0], fpga_attach_handle_list[0].as_dict() + ) self.assertEqual(attribute_list, attribute_actual_data) - @mock.patch('cyborg.accelerator.drivers.fpga.' - 'inspur.sysinfo.lspci_privileged') + @mock.patch( + 'cyborg.accelerator.drivers.fpga.inspur.sysinfo.lspci_privileged' + ) def test_get_pci_devices_by_inspur_vendor(self, mock_devices_for_vendor): - - fake_pci_output = [("0000:86:00.0 Processing accelerators [1200]: " - "Inspur Electronic Information Industry Co., Ltd. " - "Device [1bd4:a115] (rev 04)\n" - "0000:86:00.0 Processing accelerators [1200]: " - "Xilinx Corporation Device [10ee:5000]")] + fake_pci_output = [ + ( + "0000:86:00.0 Processing accelerators [1200]: " + "Inspur Electronic Information Industry Co., Ltd. " + "Device [1bd4:a115] (rev 04)\n" + "0000:86:00.0 Processing accelerators [1200]: " + "Xilinx Corporation Device [10ee:5000]" + ) + ] mock_devices_for_vendor.return_value = fake_pci_output - pci_devices = sysinfo.get_pci_devices(sysinfo.INSPUR_FPGA_FLAGS, - vendor_id=sysinfo.VENDOR_ID) - expected = [("0000:86:00.0 Processing accelerators [1200]: " - "Inspur Electronic Information Industry Co., Ltd. " - "Device [1bd4:a115] (rev 04)")] + pci_devices = sysinfo.get_pci_devices( + sysinfo.INSPUR_FPGA_FLAGS, vendor_id=sysinfo.VENDOR_ID + ) + expected = [ + ( + "0000:86:00.0 Processing accelerators [1200]: " + "Inspur Electronic Information Industry Co., Ltd. " + "Device [1bd4:a115] (rev 04)" + ) + ] self.assertEqual(len(pci_devices), 1) self.assertEqual(pci_devices, expected) diff --git a/cyborg/tests/unit/accelerator/drivers/fpga/intel/prepare_test_data.py b/cyborg/tests/unit/accelerator/drivers/fpga/intel/prepare_test_data.py index e1093850..4ba6f13e 100644 --- a/cyborg/tests/unit/accelerator/drivers/fpga/intel/prepare_test_data.py +++ b/cyborg/tests/unit/accelerator/drivers/fpga/intel/prepare_test_data.py @@ -22,9 +22,9 @@ PF0_ADDR = "0000:5e:00.0" PF1_ADDR = "0000:be:00.0" VF0_ADDR = "0000:5e:00.1" FPGA_TREE = { - "dev.0": {"bdf": PF0_ADDR, - "regions": {"dev.2": {"bdf": VF0_ADDR}}}, - "dev.1": {"bdf": PF1_ADDR}} + "dev.0": {"bdf": PF0_ADDR, "regions": {"dev.2": {"bdf": VF0_ADDR}}}, + "dev.1": {"bdf": PF1_ADDR}, +} SYS_DEVICES = "sys/devices" PCI_DEVICES_PATH = "sys/bus/pci/devices" @@ -47,8 +47,8 @@ PGFA_DEVICE_COMMON_CONTENT = { "irq": "16", "local_cpulist": "0-111", "local_cpus": "00000000,00000000,00000000,00000000,00000000," - "00000000,00000000,00000000,00000000,00000000," - "0000ffff,ffffffff,ffffffff,ffffffff", + "00000000,00000000,00000000,00000000,00000000," + "0000ffff,ffffffff,ffffffff,ffffffff", "modalias": "pci:v00008086d0000BCC0sv00000000sd00000000bc12sc00i00", "msi_bus": "", "numa_node": "-1", @@ -65,7 +65,8 @@ PGFA_DEVICE_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource0": "", "resource0_wc": "", "subsystem_device": "0x0000", @@ -76,8 +77,10 @@ PGFA_DEVICE_COMMON_CONTENT = { "PCI_ID=8086:BCC0", "PCI_SUBSYS_ID=0000:0000", "PCI_SLOT_NAME=0000:5e:00.0", - "MODALIAS=pci:v00008086d0000BCC0sv00000000sd00000000bc12sc00i00"], - "vendor": "0x8086"} + "MODALIAS=pci:v00008086d0000BCC0sv00000000sd00000000bc12sc00i00", + ], + "vendor": "0x8086", +} PGFA_DEVICES_SPECIAL_COMMON_CONTENT = { "dev.0": { @@ -100,7 +103,8 @@ PGFA_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource2": "", "resource2_wc": "", "sriov_numvfs": "0", @@ -111,7 +115,8 @@ PGFA_DEVICES_SPECIAL_COMMON_CONTENT = { "PCI_ID=8086:BCC0", "PCI_SUBSYS_ID=0000:0000", "PCI_SLOT_NAME=0000:be:00.0", - "MODALIAS=pci:v00008086d0000BCC0sv00000000sd00000000bc12sc00i00"], + "MODALIAS=pci:v00008086d0000BCC0sv00000000sd00000000bc12sc00i00", + ], }, "dev.2": { "d3cold_allowed": "0", @@ -131,22 +136,24 @@ PGFA_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "uevent": [ "DRIVER=intel-fpga-pci", "PCI_CLASS=120000", "PCI_ID=8086:BCC1", "PCI_SUBSYS_ID=0000:0000", "PCI_SLOT_NAME=0000:5e:00.1", - "MODALIAS=pci:v00008086d0000BCC1sv00000000sd00000000bc12sc00i00"], - } + "MODALIAS=pci:v00008086d0000BCC1sv00000000sd00000000bc12sc00i00", + ], + }, } PGFA_DEVICE_COMMON_SOFT_LINK = { "driver": "../../../bus/pci/drivers/intel-fpga-pci", "iommu": "../../virtual/iommu/dmar8", "iommu_group": "../../../kernel/iommu_groups/75", - "subsystem": "../../../bus/pci" + "subsystem": "../../../bus/pci", } PGFA_DEVICES_SPECIAL_SOFT_LINK = { @@ -161,7 +168,7 @@ PGFA_DEVICES_SPECIAL_SOFT_LINK = { "dev.2": { "iommu": "../../virtual/iommu/dmar9", "iommu_group": "../../../kernel/iommu_groups/81", - } + }, } PGFA_DEVICES_SPECIAL_SOFT_LINK = { "dev.0": { @@ -175,17 +182,17 @@ PGFA_DEVICES_SPECIAL_SOFT_LINK = { "dev.2": { "iommu": "../../virtual/iommu/dmar9", "iommu_group": "../../../kernel/iommu_groups/81", - } + }, } PGFA_DEVICE_PF_SOFT_LINK = { - "virtfn": lambda k, v: (k + str(int(v.rsplit(".", 1)[-1]) - 1), - "/".join(["..", v])) + "virtfn": lambda k, v: ( + k + str(int(v.rsplit(".", 1)[-1]) - 1), + "/".join(["..", v]), + ) } -PGFA_DEVICE_VF_SOFT_LINK = { - "physfn": lambda k, v: (k, "/".join(["..", v])) -} +PGFA_DEVICE_VF_SOFT_LINK = {"physfn": lambda k, v: (k, "/".join(["..", v]))} def gen_fpga_content(path, dev): @@ -228,8 +235,9 @@ def gen_fpga_vf_soft_link(path, bdf): os.symlink(v, os.path.join(path, k)) -def create_devices_path_and_files(tree, device_path, class_fpga_path, - vf=False, pfinfo=None): +def create_devices_path_and_files( + tree, device_path, class_fpga_path, vf=False, pfinfo=None +): for k, v in tree.items(): bdf = v["bdf"] pci_path = "pci" + bdf.rsplit(":", 1)[0] @@ -245,16 +253,19 @@ def create_devices_path_and_files(tree, device_path, class_fpga_path, pfinfo = {"path": bdf_path, "bdf": bdf} if "regions" in v: create_devices_path_and_files( - v["regions"], device_path, class_fpga_path, True, pfinfo) + v["regions"], device_path, class_fpga_path, True, pfinfo + ) source = dev_path.split("sys")[-1] os.symlink("../.." + source, os.path.join(class_fpga_path, ln)) os.symlink("../../../" + bdf, os.path.join(dev_path, "device")) - pci_dev = os.path.join(device_path.split(SYS_DEVICES)[0], - PCI_DEVICES_PATH) + pci_dev = os.path.join( + device_path.split(SYS_DEVICES)[0], PCI_DEVICES_PATH + ) if not os.path.exists(pci_dev): os.makedirs(pci_dev) - os.symlink("../../.." + bdf_path.split("sys")[-1], - os.path.join(pci_dev, bdf)) + os.symlink( + "../../.." + bdf_path.split("sys")[-1], os.path.join(pci_dev, bdf) + ) def create_devices_soft_link(class_fpga_path): @@ -263,7 +274,8 @@ def create_devices_soft_link(class_fpga_path): path = os.path.realpath("%s/%s/device" % (class_fpga_path, dev)) softlinks = copy.copy(PGFA_DEVICE_COMMON_SOFT_LINK) softlinks.update( - PGFA_DEVICES_SPECIAL_SOFT_LINK[dev.rsplit("-", 1)[-1]]) + PGFA_DEVICES_SPECIAL_SOFT_LINK[dev.rsplit("-", 1)[-1]] + ) for k, v in softlinks.items(): source = os.path.normpath(os.path.join(path, v)) if not os.path.exists(source): diff --git a/cyborg/tests/unit/accelerator/drivers/fpga/intel/test_driver.py b/cyborg/tests/unit/accelerator/drivers/fpga/intel/test_driver.py index ded71bb5..baadf07d 100644 --- a/cyborg/tests/unit/accelerator/drivers/fpga/intel/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/fpga/intel/test_driver.py @@ -24,7 +24,6 @@ from cyborg.tests.unit.accelerator.drivers.fpga.intel import prepare_test_data class TestIntelFPGADriver(base.TestCase): - def setUp(self): super().setUp() self.syspath = sysinfo.SYS_FPGA @@ -34,11 +33,14 @@ class TestIntelFPGADriver(base.TestCase): prepare_test_data.create_fake_sysfs(tmp_sys_dir.path) tmp_path = tmp_sys_dir.path sysinfo.SYS_FPGA = os.path.join( - tmp_path, sysinfo.SYS_FPGA.split("/", 1)[-1]) + tmp_path, sysinfo.SYS_FPGA.split("/", 1)[-1] + ) sysinfo.PCI_DEVICES_PATH = os.path.join( - tmp_path, sysinfo.PCI_DEVICES_PATH.split("/", 1)[-1]) + tmp_path, sysinfo.PCI_DEVICES_PATH.split("/", 1)[-1] + ) sysinfo.PCI_DEVICES_PATH_PATTERN = os.path.join( - tmp_path, sysinfo.PCI_DEVICES_PATH_PATTERN.split("/", 1)[-1]) + tmp_path, sysinfo.PCI_DEVICES_PATH_PATTERN.split("/", 1)[-1] + ) def tearDown(self): super().tearDown() @@ -49,59 +51,66 @@ class TestIntelFPGADriver(base.TestCase): def test_discover(self): attach_handle_list = [ [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "5e", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "1"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "5e", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "1"}', + 'in_use': False, + } ], [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "be", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} - ] + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "be", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } + ], + ] + expected = [ + { + 'vendor': '0x8086', + 'type': 'FPGA', + 'model': '0xbcc0', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': 'intel-fpga-dev_0000:5e:00.1', + 'attach_handle_list': attach_handle_list[0], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "5e", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, + { + 'vendor': '0x8086', + 'type': 'FPGA', + 'model': '0xbcc0', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': 'intel-fpga-dev_0000:be:00.0', + 'attach_handle_list': attach_handle_list[1], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "be", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, ] - expected = [{'vendor': '0x8086', - 'type': 'FPGA', - 'model': '0xbcc0', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': 'intel-fpga-dev_0000:5e:00.1', - 'attach_handle_list': attach_handle_list[0] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "5e", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - }, - {'vendor': '0x8086', - 'type': 'FPGA', - 'model': '0xbcc0', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': 'intel-fpga-dev_0000:be:00.0', - 'attach_handle_list': attach_handle_list[1] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "be", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - } - ] intel = IntelFPGADriver() fpgas = intel.discover() list.sort(fpgas, key=lambda x: x._obj_deployable_list[0].name) @@ -109,29 +118,45 @@ class TestIntelFPGADriver(base.TestCase): for i in range(len(fpgas)): fpga_dict = fpgas[i].as_dict() fpga_dep_list = fpga_dict['deployable_list'] - fpga_attach_handle_list = \ - fpga_dep_list[0].as_dict()['attach_handle_list'] + fpga_attach_handle_list = fpga_dep_list[0].as_dict()[ + 'attach_handle_list' + ] self.assertEqual(expected[i]['vendor'], fpga_dict['vendor']) - self.assertEqual(expected[i]['controlpath_id'], - fpga_dict['controlpath_id']) - self.assertEqual(expected[i]['deployable_list'][0] - ['num_accelerators'], - fpga_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(attach_handle_list[i][0], - fpga_attach_handle_list[0].as_dict()) + self.assertEqual( + expected[i]['controlpath_id'], fpga_dict['controlpath_id'] + ) + self.assertEqual( + expected[i]['deployable_list'][0]['num_accelerators'], + fpga_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + attach_handle_list[i][0], fpga_attach_handle_list[0].as_dict() + ) - @mock.patch('cyborg.accelerator.drivers.fpga.intel.driver.' - '_fpga_program_privileged') + @mock.patch( + 'cyborg.accelerator.drivers.fpga.intel.driver._fpga_program_privileged' + ) def test_intel_program(self, mock_prog): b = "0x5e" d = "0x00" f = "0x0" - expect_cmd = ['--bus', b, '--device', d, '--function', f, - '/path/image'] + expect_cmd = [ + '--bus', + b, + '--device', + d, + '--function', + f, + '/path/image', + ] intel = IntelFPGADriver() - cpid_info = {"domain": "0000", "bus": "5e", - "device": "00", "function": "0"} + cpid_info = { + "domain": "0000", + "bus": "5e", + "device": "00", + "function": "0", + } cpid = {'cpid_type': 'PCI', 'cpid_info': cpid_info} # program PF diff --git a/cyborg/tests/unit/accelerator/drivers/fpga/xilinx/test_driver.py b/cyborg/tests/unit/accelerator/drivers/fpga/xilinx/test_driver.py index 753af0ad..820c207f 100644 --- a/cyborg/tests/unit/accelerator/drivers/fpga/xilinx/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/fpga/xilinx/test_driver.py @@ -18,10 +18,12 @@ from unittest import mock from cyborg.accelerator.drivers.fpga.xilinx.driver import XilinxFPGADriver from cyborg.tests import base -XILINX_FPGA_INFO = ["0000:3b:00.0 Processing accelerators [1200]: " - "Xilinx Corporation Device [10ee:5000]\n" - "0000:3b:00.1 Processing accelerators [1200]: " - "Xilinx Corporation Device [10ee:5001]"] +XILINX_FPGA_INFO = [ + "0000:3b:00.0 Processing accelerators [1200]: " + "Xilinx Corporation Device [10ee:5000]\n" + "0000:3b:00.1 Processing accelerators [1200]: " + "Xilinx Corporation Device [10ee:5001]" +] def fake_output(arg): @@ -30,30 +32,34 @@ def fake_output(arg): class TestXilinxFPGADriver(base.TestCase): - def setUp(self): super().setUp() - @mock.patch('cyborg.accelerator.drivers.fpga.' - 'xilinx.sysinfo.lspci_privileged') + @mock.patch( + 'cyborg.accelerator.drivers.fpga.xilinx.sysinfo.lspci_privileged' + ) def test_discover(self, mock_devices_for_vendor): mock_devices_for_vendor.side_effect = fake_output self.set_defaults(host='fake-host', debug=True) fpga_list = XilinxFPGADriver().discover() self.assertEqual(1, len(fpga_list)) attach_handle_list = [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "3b", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False}, - {'attach_type': 'PCI', - 'attach_info': '{"bus": "3b", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "1"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "3b", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + }, + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "3b", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "1"}', + 'in_use': False, + }, ] attribute_list = [ {'key': 'rc', 'value': 'FPGA'}, @@ -63,64 +69,93 @@ class TestXilinxFPGADriver(base.TestCase): expected = { 'vendor': '10ee', 'type': 'FPGA', - 'std_board_info': {"controller": "Processing accelerators", - "product_id": "5000"}, + 'std_board_info': { + "controller": "Processing accelerators", + "product_id": "5000", + }, 'vendor_board_info': {"vendor_info": "fpga_vb_info"}, - 'deployable_list': - [ - { - 'num_accelerators': 1, - 'driver_name': 'XILINX', - 'name': 'fake-host_0000:3b:00.0', - 'attach_handle_list': attach_handle_list, - 'attribute_list': attribute_list - }, - ], - 'controlpath_id': {'cpid_info': '{"bus": "3b", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} + 'deployable_list': [ + { + 'num_accelerators': 1, + 'driver_name': 'XILINX', + 'name': 'fake-host_0000:3b:00.0', + 'attach_handle_list': attach_handle_list, + 'attribute_list': attribute_list, + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "3b", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, } fpga_obj = fpga_list[0] fpga_dict = fpga_obj.as_dict() fpga_dep_list = fpga_dict['deployable_list'] - fpga_attach_handle_list = ( - fpga_dep_list[0].as_dict()['attach_handle_list']) + fpga_attach_handle_list = fpga_dep_list[0].as_dict()[ + 'attach_handle_list' + ] fpga_attribute_list = fpga_dep_list[0].as_dict()['attribute_list'] attri_obj_data = [] [attri_obj_data.append(attr.as_dict()) for attr in fpga_attribute_list] attribute_actual_data = sorted(attri_obj_data, key=lambda i: i['key']) self.assertEqual(expected['vendor'], fpga_dict['vendor']) - self.assertEqual(expected['controlpath_id'], - fpga_dict['controlpath_id']) - self.assertEqual(expected['std_board_info'], - jsonutils.loads(fpga_dict['std_board_info'])) - self.assertEqual(expected['vendor_board_info'], - jsonutils.loads(fpga_dict['vendor_board_info'])) - self.assertEqual(expected['deployable_list'][0]['num_accelerators'], - fpga_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(expected['deployable_list'][0]['name'], - fpga_dep_list[0].as_dict()['name']) - self.assertEqual(expected['deployable_list'][0]['driver_name'], - fpga_dep_list[0].as_dict()['driver_name']) + self.assertEqual( + expected['controlpath_id'], fpga_dict['controlpath_id'] + ) + self.assertEqual( + expected['std_board_info'], + jsonutils.loads(fpga_dict['std_board_info']), + ) + self.assertEqual( + expected['vendor_board_info'], + jsonutils.loads(fpga_dict['vendor_board_info']), + ) + self.assertEqual( + expected['deployable_list'][0]['num_accelerators'], + fpga_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + expected['deployable_list'][0]['name'], + fpga_dep_list[0].as_dict()['name'], + ) + self.assertEqual( + expected['deployable_list'][0]['driver_name'], + fpga_dep_list[0].as_dict()['driver_name'], + ) self.assertEqual(2, len(fpga_attach_handle_list)) - self.assertEqual(attach_handle_list[0], - fpga_attach_handle_list[0].as_dict()) - self.assertEqual(attach_handle_list[1], - fpga_attach_handle_list[1].as_dict()) + self.assertEqual( + attach_handle_list[0], fpga_attach_handle_list[0].as_dict() + ) + self.assertEqual( + attach_handle_list[1], fpga_attach_handle_list[1].as_dict() + ) self.assertEqual(attribute_list, attribute_actual_data) - @mock.patch('cyborg.accelerator.drivers.fpga.xilinx.driver.' - '_fpga_program_privileged') + @mock.patch( + 'cyborg.accelerator.drivers.fpga.xilinx.driver.' + '_fpga_program_privileged' + ) def test_program(self, mock_prog): bdf = '0000:3b:00:0' - expect_cmd_args = ['program', '--device', bdf, '--base', - '--image', '/path/image'] + expect_cmd_args = [ + 'program', + '--device', + bdf, + '--base', + '--image', + '/path/image', + ] xilinx_driver = XilinxFPGADriver() - cpid_info = {"domain": "0000", "bus": "3b", - "device": "00", "function": "0"} + cpid_info = { + "domain": "0000", + "bus": "3b", + "device": "00", + "function": "0", + } cpid = {'cpid_type': 'PCI', 'cpid_info': cpid_info} # program PF diff --git a/cyborg/tests/unit/accelerator/drivers/gpu/nvidia/test_driver.py b/cyborg/tests/unit/accelerator/drivers/gpu/nvidia/test_driver.py index b38ebb1d..8a210e36 100644 --- a/cyborg/tests/unit/accelerator/drivers/gpu/nvidia/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/gpu/nvidia/test_driver.py @@ -14,4 +14,4 @@ from cyborg.tests import base class TestNvidiaGPUDriver(base.TestCase): - """""""" + """""" "" diff --git a/cyborg/tests/unit/accelerator/drivers/gpu/test_utils.py b/cyborg/tests/unit/accelerator/drivers/gpu/test_utils.py index f87c00cd..81c26377 100644 --- a/cyborg/tests/unit/accelerator/drivers/gpu/test_utils.py +++ b/cyborg/tests/unit/accelerator/drivers/gpu/test_utils.py @@ -26,24 +26,45 @@ from cyborg.tests import base CONF = cyborg.conf.CONF -NVIDIA_GPU_INFO = "0000:00:06.0 3D controller [0302]: NVIDIA Corporation " \ - "GP100GL [Tesla P100 PCIe 12GB] [10de:15f7] (rev a1)" +NVIDIA_GPU_INFO = ( + "0000:00:06.0 3D controller [0302]: NVIDIA Corporation " + "GP100GL [Tesla P100 PCIe 12GB] [10de:15f7] (rev a1)" +) -NVIDIA_T4_GPU_INFO = "0000:af:00.0 3D controller [0302]: NVIDIA Corporation "\ - "TU104GL [Tesla T4] [10de:1eb8] (rev a1)" +NVIDIA_T4_GPU_INFO = ( + "0000:af:00.0 3D controller [0302]: NVIDIA Corporation " + "TU104GL [Tesla T4] [10de:1eb8] (rev a1)" +) -NVIDIA_A100_PF_INFO = "0000:3b:00.0 3D controller [0302]: NVIDIA Corporation "\ - "GA100 [A100 PCIe 40GB] [10de:20f1] (rev a1)" +NVIDIA_A100_PF_INFO = ( + "0000:3b:00.0 3D controller [0302]: NVIDIA Corporation " + "GA100 [A100 PCIe 40GB] [10de:20f1] (rev a1)" +) -NVIDIA_A100_VF_INFO = "0000:3b:00.4 3D controller [0302]: NVIDIA Corporation "\ - "GA100 [A100 PCIe 40GB] [10de:20f1] (rev a1)" +NVIDIA_A100_VF_INFO = ( + "0000:3b:00.4 3D controller [0302]: NVIDIA Corporation " + "GA100 [A100 PCIe 40GB] [10de:20f1] (rev a1)" +) -NVIDIA_T4_SUPPORTED_MDEV_TYPES = ['nvidia-222', 'nvidia-223', 'nvidia-224', - 'nvidia-225', 'nvidia-226', 'nvidia-227', - 'nvidia-228', 'nvidia-229', 'nvidia-230', - 'nvidia-231', 'nvidia-232', 'nvidia-233', - 'nvidia-234', 'nvidia-252', 'nvidia-319', - 'nvidia-320', 'nvidia-321'] +NVIDIA_T4_SUPPORTED_MDEV_TYPES = [ + 'nvidia-222', + 'nvidia-223', + 'nvidia-224', + 'nvidia-225', + 'nvidia-226', + 'nvidia-227', + 'nvidia-228', + 'nvidia-229', + 'nvidia-230', + 'nvidia-231', + 'nvidia-232', + 'nvidia-233', + 'nvidia-234', + 'nvidia-252', + 'nvidia-319', + 'nvidia-320', + 'nvidia-321', +] BUILTIN = '__builtin__' if (sys.version_info[0] < 3) else '__builtins__' @@ -68,7 +89,6 @@ class p: class TestGPUDriverUtils(base.TestCase): - def setUp(self): super().setUp() self.p = p() @@ -90,12 +110,14 @@ class TestGPUDriverUtils(base.TestCase): self.assertEqual(1, len(gpu_list)) attach_handle_list = [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "00", ' - '"device": "06", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "00", ' + '"device": "06", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } ] attribute_list = [ {'key': 'rc', 'value': 'PGPU'}, @@ -106,64 +128,85 @@ class TestGPUDriverUtils(base.TestCase): 'vendor': '10de', 'type': 'GPU', 'model': 'NVIDIA Corporation GP100GL [Tesla P100 PCIe 12GB]', - 'std_board_info': - {"controller": "3D controller", "product_id": "15f7"}, + 'std_board_info': { + "controller": "3D controller", + "product_id": "15f7", + }, 'vendor_board_info': {"vendor_info": "gpu_vb_info"}, - 'deployable_list': - [ - { - 'num_accelerators': 1, - 'driver_name': 'NVIDIA', - 'name': 'host-192-168-32-195_0000:00:06.0', - 'attach_handle_list': attach_handle_list, - 'attribute_list': attribute_list - }, - ], - 'controlpath_id': {'cpid_info': '{"bus": "00", ' - '"device": "06", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - } + 'deployable_list': [ + { + 'num_accelerators': 1, + 'driver_name': 'NVIDIA', + 'name': 'host-192-168-32-195_0000:00:06.0', + 'attach_handle_list': attach_handle_list, + 'attribute_list': attribute_list, + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "00", ' + '"device": "06", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + } gpu_obj = gpu_list[0] gpu_dict = gpu_obj.as_dict() gpu_dep_list = gpu_dict['deployable_list'] - gpu_attach_handle_list = \ - gpu_dep_list[0].as_dict()['attach_handle_list'] - gpu_attribute_list = \ - gpu_dep_list[0].as_dict()['attribute_list'] + gpu_attach_handle_list = gpu_dep_list[0].as_dict()[ + 'attach_handle_list' + ] + gpu_attribute_list = gpu_dep_list[0].as_dict()['attribute_list'] attri_obj_data = [] [attri_obj_data.append(attr.as_dict()) for attr in gpu_attribute_list] attribute_actual_data = sorted(attri_obj_data, key=lambda i: i['key']) self.assertEqual(expected['vendor'], gpu_dict['vendor']) self.assertEqual(expected['model'], gpu_dict['model']) - self.assertEqual(expected['controlpath_id'], - gpu_dict['controlpath_id']) - self.assertEqual(expected['std_board_info'], - jsonutils.loads(gpu_dict['std_board_info'])) - self.assertEqual(expected['vendor_board_info'], - jsonutils.loads(gpu_dict['vendor_board_info'])) - self.assertEqual(expected['deployable_list'][0]['num_accelerators'], - gpu_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(expected['deployable_list'][0]['name'], - gpu_dep_list[0].as_dict()['name']) - self.assertEqual(expected['deployable_list'][0]['driver_name'], - gpu_dep_list[0].as_dict()['driver_name']) - self.assertEqual(attach_handle_list[0], - gpu_attach_handle_list[0].as_dict()) + self.assertEqual( + expected['controlpath_id'], gpu_dict['controlpath_id'] + ) + self.assertEqual( + expected['std_board_info'], + jsonutils.loads(gpu_dict['std_board_info']), + ) + self.assertEqual( + expected['vendor_board_info'], + jsonutils.loads(gpu_dict['vendor_board_info']), + ) + self.assertEqual( + expected['deployable_list'][0]['num_accelerators'], + gpu_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + expected['deployable_list'][0]['name'], + gpu_dep_list[0].as_dict()['name'], + ) + self.assertEqual( + expected['deployable_list'][0]['driver_name'], + gpu_dep_list[0].as_dict()['driver_name'], + ) + self.assertEqual( + attach_handle_list[0], gpu_attach_handle_list[0].as_dict() + ) self.assertEqual(attribute_list, attribute_actual_data) - @mock.patch('cyborg.accelerator.drivers.gpu.nvidia.sysinfo._is_vf', - return_value=False, autospec=True) + @mock.patch( + 'cyborg.accelerator.drivers.gpu.nvidia.sysinfo._is_vf', + return_value=False, + autospec=True, + ) @mock.patch('builtins.open') @mock.patch('os.listdir') @mock.patch('os.path.exists') @mock.patch('cyborg.accelerator.drivers.gpu.utils.lspci_privileged') - def test_discover_gpus_report_vGPU(self, mock_devices_for_vendor, - mock_path_exists, - mock_supported_mdev_types, - mock_open, - mock_is_vf): + def test_discover_gpus_report_vGPU( + self, + mock_devices_for_vendor, + mock_path_exists, + mock_supported_mdev_types, + mock_open, + mock_is_vf, + ): """test nvidia vGPU discover""" mock_devices_for_vendor.return_value = self.p.stdout.readlines_T4() mock_path_exists.return_value = True @@ -174,20 +217,23 @@ class TestGPUDriverUtils(base.TestCase): self.set_defaults(enabled_vgpu_types='nvidia-223', group='gpu_devices') cyborg.conf.devices.register_dynamic_opts(CONF) self.set_defaults( - device_addresses=['0000:af:00.0'], group='vgpu_nvidia-223') + device_addresses=['0000:af:00.0'], group='vgpu_nvidia-223' + ) nvidia = NVIDIAGPUDriver() gpu_list = nvidia.discover() self.assertEqual(1, len(gpu_list)) attach_handle_list = [ - {'attach_type': 'MDEV', - 'attach_info': '{"asked_type": "nvidia-223", ' - '"bus": "af", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0", ' - '"vgpu_mark": "nvidia-223_0"}', - 'in_use': False} + { + 'attach_type': 'MDEV', + 'attach_info': '{"asked_type": "nvidia-223", ' + '"bus": "af", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0", ' + '"vgpu_mark": "nvidia-223_0"}', + 'in_use': False, + } ] * 8 attribute_list = [ {'key': 'rc', 'value': 'VGPU'}, @@ -197,58 +243,74 @@ class TestGPUDriverUtils(base.TestCase): expected = { 'vendor': '10de', 'type': 'GPU', - 'std_board_info': - {"controller": "3D controller", "product_id": "1eb8"}, + 'std_board_info': { + "controller": "3D controller", + "product_id": "1eb8", + }, 'vendor_board_info': {"vendor_info": "gpu_vb_info"}, - 'deployable_list': - [ - { - 'num_accelerators': 18, - 'driver_name': 'NVIDIA', - 'name': 'host-192-168-32-195_0000:af:00.0', - 'attach_handle_list': attach_handle_list, - 'attribute_list': attribute_list - }, - ], - 'controlpath_id': {'cpid_info': '{"bus": "af", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - } + 'deployable_list': [ + { + 'num_accelerators': 18, + 'driver_name': 'NVIDIA', + 'name': 'host-192-168-32-195_0000:af:00.0', + 'attach_handle_list': attach_handle_list, + 'attribute_list': attribute_list, + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "af", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + } gpu_obj = gpu_list[0] gpu_dict = gpu_obj.as_dict() gpu_dep_list = gpu_dict['deployable_list'] gpu_attach_handle_list = gpu_dep_list[0].as_dict()[ - 'attach_handle_list'] + 'attach_handle_list' + ] gpu_attribute_list = gpu_dep_list[0].as_dict()['attribute_list'] attri_obj_data = [] [attri_obj_data.append(attr.as_dict()) for attr in gpu_attribute_list] attribute_actual_data = sorted(attri_obj_data, key=lambda i: i['key']) self.assertEqual(expected['vendor'], gpu_dict['vendor']) - self.assertEqual(expected['controlpath_id'], - gpu_dict['controlpath_id']) - self.assertEqual(expected['std_board_info'], - jsonutils.loads(gpu_dict['std_board_info'])) - self.assertEqual(expected['vendor_board_info'], - jsonutils.loads(gpu_dict['vendor_board_info'])) - self.assertEqual(expected['deployable_list'][0]['num_accelerators'], - gpu_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(expected['deployable_list'][0]['name'], - gpu_dep_list[0].as_dict()['name']) - self.assertEqual(expected['deployable_list'][0]['driver_name'], - gpu_dep_list[0].as_dict()['driver_name']) - self.assertEqual(attach_handle_list[0], - gpu_attach_handle_list[0].as_dict()) + self.assertEqual( + expected['controlpath_id'], gpu_dict['controlpath_id'] + ) + self.assertEqual( + expected['std_board_info'], + jsonutils.loads(gpu_dict['std_board_info']), + ) + self.assertEqual( + expected['vendor_board_info'], + jsonutils.loads(gpu_dict['vendor_board_info']), + ) + self.assertEqual( + expected['deployable_list'][0]['num_accelerators'], + gpu_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + expected['deployable_list'][0]['name'], + gpu_dep_list[0].as_dict()['name'], + ) + self.assertEqual( + expected['deployable_list'][0]['driver_name'], + gpu_dep_list[0].as_dict()['driver_name'], + ) + self.assertEqual( + attach_handle_list[0], gpu_attach_handle_list[0].as_dict() + ) self.assertEqual(attribute_list, attribute_actual_data) @mock.patch('cyborg.accelerator.drivers.gpu.nvidia.sysinfo._is_vf') @mock.patch('cyborg.accelerator.drivers.gpu.utils.lspci_privileged') - def test_discover_gpus_filters_vf_devices(self, mock_devices_for_vendor, - mock_is_vf): + def test_discover_gpus_filters_vf_devices( + self, mock_devices_for_vendor, mock_is_vf + ): """Test that VF devices are filtered when filter_sriov_vfs=True.""" - mock_devices_for_vendor.return_value = ( - self.p.stdout.readlines_A100()) + mock_devices_for_vendor.return_value = self.p.stdout.readlines_A100() mock_is_vf.side_effect = lambda addr: addr == '0000:3b:00.4' self.set_defaults(host='host-192-168-32-195', debug=True) self.set_defaults(filter_sriov_vfs=True, group='gpu_devices') @@ -264,10 +326,10 @@ class TestGPUDriverUtils(base.TestCase): @mock.patch('cyborg.accelerator.drivers.gpu.nvidia.sysinfo._is_vf') @mock.patch('cyborg.accelerator.drivers.gpu.utils.lspci_privileged') def test_discover_gpus_vf_not_filtered_by_default( - self, mock_devices_for_vendor, mock_is_vf): + self, mock_devices_for_vendor, mock_is_vf + ): """Test that VFs are reported when filter_sriov_vfs=False (default).""" - mock_devices_for_vendor.return_value = ( - self.p.stdout.readlines_A100()) + mock_devices_for_vendor.return_value = self.p.stdout.readlines_A100() mock_is_vf.side_effect = lambda addr: addr == '0000:3b:00.4' self.set_defaults(host='host-192-168-32-195', debug=True) @@ -280,14 +342,13 @@ class TestGPUDriverUtils(base.TestCase): @mock.patch('cyborg.accelerator.drivers.gpu.nvidia.sysinfo._is_vf') @mock.patch('cyborg.accelerator.drivers.gpu.utils.lspci_privileged') def test_discover_gpus_continues_on_is_vf_oserror( - self, mock_devices_for_vendor, mock_is_vf): + self, mock_devices_for_vendor, mock_is_vf + ): """Test discovery continues when _is_vf raises OSError.""" - mock_devices_for_vendor.return_value = ( - self.p.stdout.readlines_A100()) + mock_devices_for_vendor.return_value = self.p.stdout.readlines_A100() mock_is_vf.side_effect = OSError('device busy') self.set_defaults(host='host-192-168-32-195', debug=True) - self.set_defaults( - filter_sriov_vfs=True, group='gpu_devices') + self.set_defaults(filter_sriov_vfs=True, group='gpu_devices') nvidia = NVIDIAGPUDriver() gpu_list = nvidia.discover() @@ -297,24 +358,26 @@ class TestGPUDriverUtils(base.TestCase): class TestIsVf(base.TestCase): - @mock.patch('os.path.exists', return_value=True) def test_is_vf_returns_true_when_physfn_exists(self, mock_exists): self.assertTrue(sysinfo._is_vf('0000:3b:00.4')) mock_exists.assert_called_once_with( - '/sys/bus/pci/devices/0000:3b:00.4/physfn') + '/sys/bus/pci/devices/0000:3b:00.4/physfn' + ) @mock.patch('os.path.exists', return_value=False) def test_is_vf_returns_false_for_pf(self, mock_exists): self.assertFalse(sysinfo._is_vf('0000:3b:00.0')) mock_exists.assert_called_once_with( - '/sys/bus/pci/devices/0000:3b:00.0/physfn') + '/sys/bus/pci/devices/0000:3b:00.0/physfn' + ) @mock.patch('os.path.exists', side_effect=OSError('device busy')) def test_is_vf_returns_false_on_oserror(self, mock_exists): self.assertFalse(sysinfo._is_vf('0000:3b:00.4')) mock_exists.assert_called_once_with( - '/sys/bus/pci/devices/0000:3b:00.4/physfn') + '/sys/bus/pci/devices/0000:3b:00.4/physfn' + ) def multi_mock_open(*file_contents): @@ -328,8 +391,9 @@ def multi_mock_open(*file_contents): """ mock_files = [ - mock.mock_open(read_data=content).return_value for content in - file_contents] + mock.mock_open(read_data=content).return_value + for content in file_contents + ] mock_opener = mock.mock_open() mock_opener.side_effect = mock_files return mock_opener diff --git a/cyborg/tests/unit/accelerator/drivers/nic/intel/prepare_test_data.py b/cyborg/tests/unit/accelerator/drivers/nic/intel/prepare_test_data.py index 9ef91e93..3008a5d7 100644 --- a/cyborg/tests/unit/accelerator/drivers/nic/intel/prepare_test_data.py +++ b/cyborg/tests/unit/accelerator/drivers/nic/intel/prepare_test_data.py @@ -21,9 +21,9 @@ PF0_ADDR = "0000:05:00.0" PF1_ADDR = "0000:06:00.0" VF0_ADDR = "0000:05:01.0" NIC_TREE = { - "dev.0": {"bdf": PF0_ADDR, - "vfs": {"dev.2": {"bdf": VF0_ADDR}}}, - "dev.1": {"bdf": PF1_ADDR}} + "dev.0": {"bdf": PF0_ADDR, "vfs": {"dev.2": {"bdf": VF0_ADDR}}}, + "dev.1": {"bdf": PF1_ADDR}, +} SYS_DEVICES = "sys/devices" PCI_DEVICES_PATH = "sys/bus/pci/devices" @@ -53,7 +53,8 @@ NIC_DEVICE_COMMON_CONTENT = { "resource0": "", "subsystem_device": "0x0002", "subsystem_vendor": "0x8086", - "vendor": "0x8086"} + "vendor": "0x8086", +} NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "dev.0": { @@ -70,7 +71,8 @@ NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "0x00000000d1390000 0x00000000d139ffff 0x0000000000140204", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource2": "", "resource4": "", "sriov_numvfs": "1", @@ -81,7 +83,8 @@ NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:05:00.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], }, "dev.1": { "resource": [ @@ -97,7 +100,8 @@ NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource2": "", "sriov_numvfs": "0", "sriov_totalvfs": "0", @@ -107,7 +111,8 @@ NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:06:00.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], }, "dev.2": { "d3cold_allowed": "0", @@ -127,21 +132,23 @@ NIC_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "uevent": [ "DRIVER=c6xx", "PCI_CLASS=B4000", "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:05:01.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], - } + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], + }, } NIC_DEVICE_COMMON_SOFT_LINK = { "driver": "../../../../../../bus/pci/drivers/c6xx", "iommu": "../../../../../virtual/iommu/dmar1", - "subsystem": "../../../../../../bus/pci" + "subsystem": "../../../../../../bus/pci", } NIC_DEVICES_SPECIAL_SOFT_LINK = { @@ -157,17 +164,17 @@ NIC_DEVICES_SPECIAL_SOFT_LINK = { "dev.2": { "iommu_group": "../../../../../../kernel/iommu_groups/67", # "physfn": "../0000:05:00.0/", - } + }, } NIC_DEVICE_PF_SOFT_LINK = { - "virtfn": lambda k, v: (k + str(int(v.rsplit(".", 1)[-1])), - "/".join(["..", v])) + "virtfn": lambda k, v: ( + k + str(int(v.rsplit(".", 1)[-1])), + "/".join(["..", v]), + ) } -NIC_DEVICE_VF_SOFT_LINK = { - "physfn": lambda k, v: (k, "/".join(["..", v])) -} +NIC_DEVICE_VF_SOFT_LINK = {"physfn": lambda k, v: (k, "/".join(["..", v]))} def gen_nic_content(path, dev): @@ -225,15 +232,16 @@ def create_devices_path_and_files(tree, device_path, vf=False, pfinfo=None): gen_nic_vf_soft_link(bdf_path, pfinfo["bdf"]) pfinfo = {"path": bdf_path, "bdf": bdf} if "vfs" in v: - create_devices_path_and_files( - v["vfs"], device_path, True, pfinfo) + create_devices_path_and_files(v["vfs"], device_path, True, pfinfo) os.symlink("../../../" + bdf, os.path.join(dev_path, "device")) - pci_dev = os.path.join(device_path.split(SYS_DEVICES)[0], - PCI_DEVICES_PATH) + pci_dev = os.path.join( + device_path.split(SYS_DEVICES)[0], PCI_DEVICES_PATH + ) if not os.path.exists(pci_dev): os.makedirs(pci_dev) - os.symlink("../../.." + bdf_path.split("sys")[-1], - os.path.join(pci_dev, bdf)) + os.symlink( + "../../.." + bdf_path.split("sys")[-1], os.path.join(pci_dev, bdf) + ) def create_fake_sysfs(prefix=""): diff --git a/cyborg/tests/unit/accelerator/drivers/nic/intel/test_driver.py b/cyborg/tests/unit/accelerator/drivers/nic/intel/test_driver.py index 96dc9bd7..856a5927 100644 --- a/cyborg/tests/unit/accelerator/drivers/nic/intel/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/nic/intel/test_driver.py @@ -30,7 +30,8 @@ class TestIntelNICDriver(base.TestCase): prepare_test_data.create_fake_sysfs(tmp_sys_dir.path) tmp_path = tmp_sys_dir.path sysinfo.PCI_DEVICES_PATH_PATTERN = os.path.join( - tmp_path, sysinfo.PCI_DEVICES_PATH_PATTERN.split("/", 1)[-1]) + tmp_path, sysinfo.PCI_DEVICES_PATH_PATTERN.split("/", 1)[-1] + ) def tearDown(self): super().tearDown() @@ -41,81 +42,76 @@ class TestIntelNICDriver(base.TestCase): mock_device_ifname.return_value = "ethx" attach_handle_list = [ [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "05", ' - '"device": "01", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "05", ' + '"device": "01", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } ], [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "06", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} - ] + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "06", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } + ], ] attribute_list = [ [ - { - "key": "rc", - "value": "CUSTOM_NIC" - }, - { - "key": "trait0", - "value": "CUSTOM_VF" - } + {"key": "rc", "value": "CUSTOM_NIC"}, + {"key": "trait0", "value": "CUSTOM_VF"}, ], [ - { - "key": "rc", - "value": "CUSTOM_NIC" - }, - { - "key": "trait0", - "value": "CUSTOM_PF" - } + {"key": "rc", "value": "CUSTOM_NIC"}, + {"key": "trait0", "value": "CUSTOM_PF"}, ], ] - expected = [{'vendor': '0x8086', - 'type': 'NIC', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': '0000:05:01.0', - 'attach_handle_list': attach_handle_list[0], - 'attribute_list': attribute_list[0] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "05", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - }, - {'vendor': '0x8086', - 'type': 'NIC', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': '0000:06:00.0', - 'attach_handle_list': attach_handle_list[1], - 'attribute_list': attribute_list[1] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "06", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - } - ] + expected = [ + { + 'vendor': '0x8086', + 'type': 'NIC', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': '0000:05:01.0', + 'attach_handle_list': attach_handle_list[0], + 'attribute_list': attribute_list[0], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "05", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, + { + 'vendor': '0x8086', + 'type': 'NIC', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': '0000:06:00.0', + 'attach_handle_list': attach_handle_list[1], + 'attribute_list': attribute_list[1], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "06", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, + ] intel = IntelNICDriver() nics = intel.discover() list.sort(nics, key=lambda x: x._obj_deployable_list[0].name) @@ -123,18 +119,22 @@ class TestIntelNICDriver(base.TestCase): for i in range(len(nics)): nic_dict = nics[i].as_dict() nic_dep_list = nic_dict['deployable_list'] - nic_attach_handle_list = \ - nic_dep_list[0].as_dict()['attach_handle_list'] - nic_attribute_list = \ - nic_dep_list[0].as_dict()['attribute_list'] + nic_attach_handle_list = nic_dep_list[0].as_dict()[ + 'attach_handle_list' + ] + nic_attribute_list = nic_dep_list[0].as_dict()['attribute_list'] self.assertEqual(expected[i]['vendor'], nic_dict['vendor']) - self.assertEqual(expected[i]['controlpath_id'], - nic_dict['controlpath_id']) - self.assertEqual(expected[i]['deployable_list'][0] - ['num_accelerators'], - nic_dep_list[0].as_dict()['num_accelerators']) + self.assertEqual( + expected[i]['controlpath_id'], nic_dict['controlpath_id'] + ) + self.assertEqual( + expected[i]['deployable_list'][0]['num_accelerators'], + nic_dep_list[0].as_dict()['num_accelerators'], + ) self.assertEqual(1, len(nic_attach_handle_list)) - self.assertEqual(attach_handle_list[i][0], - nic_attach_handle_list[0].as_dict()) - self.assertEqual(attribute_list[i][0], - nic_attribute_list[0].as_dict()) + self.assertEqual( + attach_handle_list[i][0], nic_attach_handle_list[0].as_dict() + ) + self.assertEqual( + attribute_list[i][0], nic_attribute_list[0].as_dict() + ) diff --git a/cyborg/tests/unit/accelerator/drivers/pci/pci/test_sysinfo.py b/cyborg/tests/unit/accelerator/drivers/pci/pci/test_sysinfo.py index 8c6359cd..ddf147c0 100644 --- a/cyborg/tests/unit/accelerator/drivers/pci/pci/test_sysinfo.py +++ b/cyborg/tests/unit/accelerator/drivers/pci/pci/test_sysinfo.py @@ -22,34 +22,38 @@ CONF = cyborg.conf.CONF NVIDIA_PCI_LINE = ( "0000:3b:00.0 3D controller [0302]: " - "NVIDIA Corporation Device [10de:1eb8] (rev a1)") + "NVIDIA Corporation Device [10de:1eb8] (rev a1)" +) INTEL_PCI_LINE = ( "0000:00:00.0 Host bridge [0600]: " - "Intel Corporation Device [8086:2020] (rev 07)") + "Intel Corporation Device [8086:2020] (rev 07)" +) UNKNOWN_VENDOR_LINE = ( "0000:05:00.0 Network controller [0280]: " - "Acme Corp Device [dead:beef] (rev 01)") + "Acme Corp Device [dead:beef] (rev 01)" +) MALFORMED_LINE = "garbage data that does not match" -_MOCK_GET_PCI = ( - 'cyborg.accelerator.drivers.pci.utils.get_pci_devices') +_MOCK_GET_PCI = 'cyborg.accelerator.drivers.pci.utils.get_pci_devices' class TestSysinfo(base.DietTestCase): """Unit tests for sysinfo module helpers.""" - def _make_pci_dict(self, vendor_id='10de', - product_id='1eb8', - devices='0000:3b:00.0', - hostname='testhost'): + def _make_pci_dict( + self, + vendor_id='10de', + product_id='1eb8', + devices='0000:3b:00.0', + hostname='testhost', + ): return { 'vendor_id': vendor_id, 'product_id': product_id, 'devices': devices, 'hostname': hostname, 'rc': constants.RESOURCES['PCI'], - 'traits': ['CUSTOM_PCI_NVIDIA', - 'CUSTOM_PCI_PRODUCT_ID_1EB8'], + 'traits': ['CUSTOM_PCI_NVIDIA', 'CUSTOM_PCI_PRODUCT_ID_1EB8'], } def test_match_standard_device(self): @@ -60,9 +64,7 @@ class TestSysinfo(base.DietTestCase): self.assertEqual('1eb8', m.group('product_id')) def test_match_uppercase_hex(self): - line = ( - "0000:3B:00.0 3D controller [0302]: " - "Device [10DE:1EB8] (rev a1)") + line = "0000:3B:00.0 3D controller [0302]: Device [10DE:1EB8] (rev a1)" m = sysinfo.LSPCI_PATTERN.match(line) self.assertIsNotNone(m) self.assertEqual('10DE', m.group('vendor_id')) @@ -71,7 +73,8 @@ class TestSysinfo(base.DietTestCase): def test_match_without_revision(self): line = ( "0000:3b:00.0 3D controller [0302]: " - "NVIDIA Corporation Device [10de:1eb8]") + "NVIDIA Corporation Device [10de:1eb8]" + ) m = sysinfo.LSPCI_PATTERN.match(line) self.assertIsNotNone(m) self.assertEqual('10de', m.group('vendor_id')) @@ -96,8 +99,8 @@ class TestSysinfo(base.DietTestCase): self.assertIn('CUSTOM_PCI_PRODUCT_ID_BEEF', traits) # No vendor trait should be present vendor_traits = [ - t for t in traits if not t.startswith( - 'CUSTOM_PCI_PRODUCT_ID')] + t for t in traits if not t.startswith('CUSTOM_PCI_PRODUCT_ID') + ] self.assertEqual([], vendor_traits) def test_product_id_uppercased(self): @@ -117,11 +120,9 @@ class TestSysinfo(base.DietTestCase): self.assertEqual('DEAD', dep_list[0].driver_name) def test_deployable_name_format(self): - pci = self._make_pci_dict( - hostname='myhost', devices='0000:3b:00.0') + pci = self._make_pci_dict(hostname='myhost', devices='0000:3b:00.0') dep_list, _num = sysinfo._generate_dep_list(pci) - self.assertEqual( - 'myhost_0000:3b:00.0', dep_list[0].name) + self.assertEqual('myhost_0000:3b:00.0', dep_list[0].name) class TestDiscoverPcis(base.TestCase): @@ -137,9 +138,7 @@ class TestDiscoverPcis(base.TestCase): :param specs: list of JSON-encoded whitelist entries, e.g. ['{"vendor_id":"10de"}'] """ - self.config( - passthrough_whitelist=specs, - group='pci') + self.config(passthrough_whitelist=specs, group='pci') @mock.patch(_MOCK_GET_PCI, autospec=True) def test_discover_single_known_device(self, mock_pci): @@ -158,22 +157,21 @@ class TestDiscoverPcis(base.TestCase): # Check deployable dep = dev_dict['deployable_list'][0] dep_dict = dep.as_dict() - self.assertEqual( - 'fake-pci-host_0000:3b:00.0', dep_dict['name']) + self.assertEqual('fake-pci-host_0000:3b:00.0', dep_dict['name']) self.assertEqual('NVIDIA', dep_dict['driver_name']) self.assertEqual(1, dep_dict['num_accelerators']) # Check attach handle ah = dep_dict['attach_handle_list'][0].as_dict() - self.assertEqual(constants.AH_TYPE_PCI, - ah['attach_type']) + self.assertEqual(constants.AH_TYPE_PCI, ah['attach_type']) self.assertFalse(ah['in_use']) # Check attributes contain traits attrs = dep_dict['attribute_list'] attr_data = [a.as_dict() for a in attrs] attr_keys = [a['key'] for a in attr_data] self.assertIn('rc', attr_keys) - trait_vals = [a['value'] for a in attr_data - if a['key'].startswith('trait')] + trait_vals = [ + a['value'] for a in attr_data if a['key'].startswith('trait') + ] self.assertIn('CUSTOM_PCI_NVIDIA', trait_vals) self.assertIn('CUSTOM_PCI_PRODUCT_ID_1EB8', trait_vals) @@ -190,25 +188,26 @@ class TestDiscoverPcis(base.TestCase): # No vendor trait, only product trait attrs = dep_dict['attribute_list'] attr_data = [a.as_dict() for a in attrs] - trait_vals = [a['value'] for a in attr_data - if a['key'].startswith('trait')] + trait_vals = [ + a['value'] for a in attr_data if a['key'].startswith('trait') + ] self.assertIn('CUSTOM_PCI_PRODUCT_ID_BEEF', trait_vals) # Ensure no vendor-specific trait is present vendor_traits = [ - t for t in trait_vals - if not t.startswith('CUSTOM_PCI_PRODUCT_ID')] + t for t in trait_vals if not t.startswith('CUSTOM_PCI_PRODUCT_ID') + ] self.assertEqual([], vendor_traits) @mock.patch(_MOCK_GET_PCI, autospec=True) def test_discover_multiple_devices_filtered(self, mock_pci): - stdout = '\n'.join([ - NVIDIA_PCI_LINE, INTEL_PCI_LINE, UNKNOWN_VENDOR_LINE]) + stdout = '\n'.join( + [NVIDIA_PCI_LINE, INTEL_PCI_LINE, UNKNOWN_VENDOR_LINE] + ) mock_pci.return_value = (stdout, '') self._set_whitelist(['{"vendor_id":"10de"}']) devs = sysinfo._discover_pcis() self.assertEqual(1, len(devs)) - self.assertEqual( - '10de', devs[0].as_dict()['vendor']) + self.assertEqual('10de', devs[0].as_dict()['vendor']) @mock.patch(_MOCK_GET_PCI, autospec=True) def test_discover_empty_output(self, mock_pci): @@ -226,14 +225,14 @@ class TestDiscoverPcis(base.TestCase): @mock.patch(_MOCK_GET_PCI, autospec=True) def test_discover_malformed_lines_skipped(self, mock_pci): - stdout = '\n'.join([ - MALFORMED_LINE, '', NVIDIA_PCI_LINE, MALFORMED_LINE]) + stdout = '\n'.join( + [MALFORMED_LINE, '', NVIDIA_PCI_LINE, MALFORMED_LINE] + ) mock_pci.return_value = (stdout, '') self._set_whitelist(['{"vendor_id":"10de"}']) devs = sysinfo._discover_pcis() self.assertEqual(1, len(devs)) - self.assertEqual( - '10de', devs[0].as_dict()['vendor']) + self.assertEqual('10de', devs[0].as_dict()['vendor']) @mock.patch(_MOCK_GET_PCI, autospec=True) def test_discover_device_not_in_whitelist(self, mock_pci): @@ -249,5 +248,4 @@ class TestDiscoverPcis(base.TestCase): driver = PCIDriver() devs = driver.discover() self.assertEqual(1, len(devs)) - self.assertEqual( - '10de', devs[0].as_dict()['vendor']) + self.assertEqual('10de', devs[0].as_dict()['vendor']) diff --git a/cyborg/tests/unit/accelerator/drivers/qat/intel/prepare_test_data.py b/cyborg/tests/unit/accelerator/drivers/qat/intel/prepare_test_data.py index 00bf2fad..73f76582 100755 --- a/cyborg/tests/unit/accelerator/drivers/qat/intel/prepare_test_data.py +++ b/cyborg/tests/unit/accelerator/drivers/qat/intel/prepare_test_data.py @@ -23,9 +23,9 @@ PF0_ADDR = "0000:05:00.0" PF1_ADDR = "0000:06:00.0" VF0_ADDR = "0000:05:01.0" QAT_TREE = { - "dev.0": {"bdf": PF0_ADDR, - "vfs": {"dev.2": {"bdf": VF0_ADDR}}}, - "dev.1": {"bdf": PF1_ADDR}} + "dev.0": {"bdf": PF0_ADDR, "vfs": {"dev.2": {"bdf": VF0_ADDR}}}, + "dev.1": {"bdf": PF1_ADDR}, +} SYS_DEVICES = "sys/devices" PCI_DEVICES_PATH = "sys/bus/pci/devices" @@ -55,7 +55,8 @@ QAT_DEVICE_COMMON_CONTENT = { "resource0": "", "subsystem_device": "0x0002", "subsystem_vendor": "0x8086", - "vendor": "0x8086"} + "vendor": "0x8086", +} QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "dev.0": { @@ -72,7 +73,8 @@ QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "0x00000000d1390000 0x00000000d139ffff 0x0000000000140204", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource2": "", "resource4": "", "sriov_numvfs": "1", @@ -83,7 +85,8 @@ QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:05:00.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], }, "dev.1": { "resource": [ @@ -99,7 +102,8 @@ QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "resource2": "", "sriov_numvfs": "0", "sriov_totalvfs": "0", @@ -109,7 +113,8 @@ QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:06:00.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], }, "dev.2": { "d3cold_allowed": "0", @@ -129,21 +134,23 @@ QAT_DEVICES_SPECIAL_COMMON_CONTENT = { "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", "0x0000000000000000 0x0000000000000000 0x0000000000000000", - "0x0000000000000000 0x0000000000000000 0x0000000000000000"], + "0x0000000000000000 0x0000000000000000 0x0000000000000000", + ], "uevent": [ "DRIVER=c6xx", "PCI_CLASS=B4000", "PCI_ID=8086:37C8", "PCI_SUBSYS_ID=8086:0002", "PCI_SLOT_NAME=0000:05:01.0", - "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00"], - } + "MODALIAS=pci:v00008086d000037C8sv00008086sd00000002bc0Bsc40i00", + ], + }, } QAT_DEVICE_COMMON_SOFT_LINK = { "driver": "../../../../../../bus/pci/drivers/c6xx", "iommu": "../../../../../virtual/iommu/dmar1", - "subsystem": "../../../../../../bus/pci" + "subsystem": "../../../../../../bus/pci", } QAT_DEVICES_SPECIAL_SOFT_LINK = { @@ -159,17 +166,17 @@ QAT_DEVICES_SPECIAL_SOFT_LINK = { "dev.2": { "iommu_group": "../../../../../../kernel/iommu_groups/67", # "physfn": "../0000:05:00.0/", - } + }, } QAT_DEVICE_PF_SOFT_LINK = { - "virtfn": lambda k, v: (k + str(int(v.rsplit(".", 1)[-1])), - "/".join(["..", v])) + "virtfn": lambda k, v: ( + k + str(int(v.rsplit(".", 1)[-1])), + "/".join(["..", v]), + ) } -QAT_DEVICE_VF_SOFT_LINK = { - "physfn": lambda k, v: (k, "/".join(["..", v])) -} +QAT_DEVICE_VF_SOFT_LINK = {"physfn": lambda k, v: (k, "/".join(["..", v]))} def gen_qat_content(path, dev): @@ -227,15 +234,16 @@ def create_devices_path_and_files(tree, device_path, vf=False, pfinfo=None): gen_qat_vf_soft_link(bdf_path, pfinfo["bdf"]) pfinfo = {"path": bdf_path, "bdf": bdf} if "vfs" in v: - create_devices_path_and_files( - v["vfs"], device_path, True, pfinfo) + create_devices_path_and_files(v["vfs"], device_path, True, pfinfo) os.symlink("../../../" + bdf, os.path.join(dev_path, "device")) - pci_dev = os.path.join(device_path.split(SYS_DEVICES)[0], - PCI_DEVICES_PATH) + pci_dev = os.path.join( + device_path.split(SYS_DEVICES)[0], PCI_DEVICES_PATH + ) if not os.path.exists(pci_dev): os.makedirs(pci_dev) - os.symlink("../../.." + bdf_path.split("sys")[-1], - os.path.join(pci_dev, bdf)) + os.symlink( + "../../.." + bdf_path.split("sys")[-1], os.path.join(pci_dev, bdf) + ) def create_fake_sysfs(prefix=""): @@ -252,14 +260,19 @@ def main(): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate a fake sysfs for intel QAT.") + description="Generate a fake sysfs for intel QAT." + ) group = parser.add_mutually_exclusive_group() group.add_argument("-v", "--verbose", action="store_true") group.add_argument("-q", "--quiet", action="store_true") - parser.add_argument("-p", "--prefix", type=str, - default="/tmp", dest="p", - help='Set the prefix path of the fake sysfs. ' - 'default "/tmp"') + parser.add_argument( + "-p", + "--prefix", + type=str, + default="/tmp", + dest="p", + help='Set the prefix path of the fake sysfs. default "/tmp"', + ) args = parser.parse_args() create_fake_sysfs(args.p) diff --git a/cyborg/tests/unit/accelerator/drivers/qat/intel/test_driver.py b/cyborg/tests/unit/accelerator/drivers/qat/intel/test_driver.py index c88f467f..2bb1e40d 100644 --- a/cyborg/tests/unit/accelerator/drivers/qat/intel/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/qat/intel/test_driver.py @@ -29,7 +29,8 @@ class TestIntelQATDriver(base.TestCase): prepare_test_data.create_fake_sysfs(tmp_sys_dir.path) tmp_path = tmp_sys_dir.path sysinfo.PCI_DEVICES_PATH = os.path.join( - tmp_path, sysinfo.PCI_DEVICES_PATH.split("/", 1)[-1]) + tmp_path, sysinfo.PCI_DEVICES_PATH.split("/", 1)[-1] + ) def tearDown(self): super().tearDown() @@ -38,57 +39,64 @@ class TestIntelQATDriver(base.TestCase): def test_discover(self): attach_handle_list = [ [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "05", ' - '"device": "01", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "05", ' + '"device": "01", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } ], [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "06", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} - ] + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "06", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } + ], + ] + expected = [ + { + 'vendor': '0x8086', + 'type': 'QAT', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': '0000:05:01.0', + 'attach_handle_list': attach_handle_list[0], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "05", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, + { + 'vendor': '0x8086', + 'type': 'QAT', + 'deployable_list': [ + { + 'num_accelerators': 1, + 'name': '0000:06:00.0', + 'attach_handle_list': attach_handle_list[1], + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "06", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, + }, ] - expected = [{'vendor': '0x8086', - 'type': 'QAT', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': '0000:05:01.0', - 'attach_handle_list': attach_handle_list[0] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "05", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - }, - {'vendor': '0x8086', - 'type': 'QAT', - 'deployable_list': - [ - {'num_accelerators': 1, - 'name': '0000:06:00.0', - 'attach_handle_list': attach_handle_list[1] - }, - ], - 'controlpath_id': - { - 'cpid_info': '{"bus": "06", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} - } - ] intel = IntelQATDriver() qats = intel.discover() list.sort(qats, key=lambda x: x._obj_deployable_list[0].name) @@ -96,14 +104,18 @@ class TestIntelQATDriver(base.TestCase): for i in range(len(qats)): qat_dict = qats[i].as_dict() qat_dep_list = qat_dict['deployable_list'] - qat_attach_handle_list = \ - qat_dep_list[0].as_dict()['attach_handle_list'] + qat_attach_handle_list = qat_dep_list[0].as_dict()[ + 'attach_handle_list' + ] self.assertEqual(expected[i]['vendor'], qat_dict['vendor']) - self.assertEqual(expected[i]['controlpath_id'], - qat_dict['controlpath_id']) - self.assertEqual(expected[i]['deployable_list'][0] - ['num_accelerators'], - qat_dep_list[0].as_dict()['num_accelerators']) + self.assertEqual( + expected[i]['controlpath_id'], qat_dict['controlpath_id'] + ) + self.assertEqual( + expected[i]['deployable_list'][0]['num_accelerators'], + qat_dep_list[0].as_dict()['num_accelerators'], + ) self.assertEqual(1, len(qat_attach_handle_list)) - self.assertEqual(attach_handle_list[i][0], - qat_attach_handle_list[0].as_dict()) + self.assertEqual( + attach_handle_list[i][0], qat_attach_handle_list[0].as_dict() + ) diff --git a/cyborg/tests/unit/accelerator/drivers/spdk/nvmf/test_nvmf.py b/cyborg/tests/unit/accelerator/drivers/spdk/nvmf/test_nvmf.py index 60ff451b..1a50f9a6 100644 --- a/cyborg/tests/unit/accelerator/drivers/spdk/nvmf/test_nvmf.py +++ b/cyborg/tests/unit/accelerator/drivers/spdk/nvmf/test_nvmf.py @@ -22,8 +22,9 @@ from cyborg.tests import base class TestNVMFDRIVER(base.TestCase): - - def setUp(self,): + def setUp( + self, + ): super().setUp() self.nvmf_driver = NVMFDRIVER() @@ -35,14 +36,16 @@ class TestNVMFDRIVER(base.TestCase): def test_discover_accelerator(self, mock_get_one_accelerator): expect_accelerator = { 'server': 'nvmf', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], - 'subsystems': [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], + 'subsystems': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], } alive = mock.Mock(return_value=False) self.nvmf_driver.py.is_alive = alive @@ -50,29 +53,31 @@ class TestNVMFDRIVER(base.TestCase): common_fun.check_for_setup_error = check_error self.assertFalse( mock_get_one_accelerator.called, - "Failed to discover_accelerator if py not alive." + "Failed to discover_accelerator if py not alive.", ) alive = mock.Mock(return_value=True) self.nvmf_driver.py.is_alive = alive check_error = mock.Mock(return_value=True) common_fun.check_for_setup_error = check_error acce_client = NvmfTgt(self.nvmf_driver.py) - bdevs_fake = [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }] + bdevs_fake = [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ] bdev_list = mock.Mock(return_value=bdevs_fake) acce_client.get_bdevs = bdev_list - subsystems_fake = [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] + subsystems_fake = [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ] subsystem_list = mock.Mock(return_value=subsystems_fake) acce_client.get_nvmf_subsystems = subsystem_list accelerator_fake = { 'server': self.nvmf_driver.SERVER, 'bdevs': acce_client.get_bdevs(), - 'subsystems': acce_client.get_nvmf_subsystems() + 'subsystems': acce_client.get_nvmf_subsystems(), } success_send = mock.Mock(return_value=accelerator_fake) self.nvmf_driver.get_one_accelerator = success_send @@ -80,35 +85,39 @@ class TestNVMFDRIVER(base.TestCase): self.assertEqual(accelerator, expect_accelerator) def test_accelerator_list(self): - expect_accelerators = [{ - 'server': 'nvmf', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], - 'subsystems': - [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] - }, + expect_accelerators = [ + { + 'server': 'nvmf', + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], + 'subsystems': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], + }, { 'server': 'nvnf_tgt', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], - 'subsystems': - [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] - } + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], + 'subsystems': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], + }, ] success_send = mock.Mock(return_value=expect_accelerators) self.nvmf_driver.get_all_accelerators = success_send - self.assertEqual(self.nvmf_driver.accelerator_list(), - expect_accelerators) + self.assertEqual( + self.nvmf_driver.accelerator_list(), expect_accelerators + ) def test_install_accelerator(self): pass diff --git a/cyborg/tests/unit/accelerator/drivers/spdk/vhost/test_vhost.py b/cyborg/tests/unit/accelerator/drivers/spdk/vhost/test_vhost.py index c2ec09b8..0700454e 100644 --- a/cyborg/tests/unit/accelerator/drivers/spdk/vhost/test_vhost.py +++ b/cyborg/tests/unit/accelerator/drivers/spdk/vhost/test_vhost.py @@ -22,7 +22,6 @@ from cyborg.tests import base class TestVHOSTDRIVER(base.TestCase): - def setUp(self): super().setUp() self.vhost_driver = VHOSTDRIVER() @@ -35,17 +34,18 @@ class TestVHOSTDRIVER(base.TestCase): def test_discover_accelerator(self, mock_get_one_accelerator): expect_accelerator = { 'server': 'vhost', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], 'scsi_devices': [], - 'luns': [{"claimed": True, - "name": "Malloc0"}], - 'interfaces': [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] + 'luns': [{"claimed": True, "name": "Malloc0"}], + 'interfaces': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], } alive = mock.Mock(return_value=True) self.vhost_driver.py.is_alive = alive @@ -53,27 +53,27 @@ class TestVHOSTDRIVER(base.TestCase): common_fun.check_for_setup_error = check_error self.assertFalse( mock_get_one_accelerator.called, - "Failed to discover_accelerator if py not alive." + "Failed to discover_accelerator if py not alive.", ) acce_client = VhostTgt(self.vhost_driver.py) - bdevs_fake = [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }] + bdevs_fake = [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ] bdev_list = mock.Mock(return_value=bdevs_fake) acce_client.get_bdevs = bdev_list scsi_devices_fake = [] scsi_device_list = mock.Mock(return_value=scsi_devices_fake) acce_client.get_scsi_devices = scsi_device_list - luns_fake = [{"claimed": True, - "name": "Malloc0"}] + luns_fake = [{"claimed": True, "name": "Malloc0"}] lun_list = mock.Mock(return_value=luns_fake) acce_client.get_luns = lun_list - interfaces_fake = \ - [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] + interfaces_fake = [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ] interface_list = mock.Mock(return_value=interfaces_fake) acce_client.get_interfaces = interface_list accelerator_fake = { @@ -81,7 +81,7 @@ class TestVHOSTDRIVER(base.TestCase): 'bdevs': acce_client.get_bdevs(), 'scsi_devices': acce_client.get_scsi_devices(), 'luns': acce_client.get_luns(), - 'interfaces': acce_client.get_interfaces() + 'interfaces': acce_client.get_interfaces(), } success_send = mock.Mock(return_value=accelerator_fake) self.vhost_driver.get_one_accelerator = success_send @@ -89,39 +89,43 @@ class TestVHOSTDRIVER(base.TestCase): self.assertEqual(accelerator, expect_accelerator) def test_accelerator_list(self): - expect_accelerators = [{ - 'server': 'vhost', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], - 'scsi_devices': [], - 'luns': [{"claimed": True, - "name": "Malloc0"}], - 'interfaces': [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] - }, + expect_accelerators = [ + { + 'server': 'vhost', + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], + 'scsi_devices': [], + 'luns': [{"claimed": True, "name": "Malloc0"}], + 'interfaces': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], + }, { 'server': 'vhost_tgt', - 'bdevs': [{"num_blocks": 131072, - "name": "nvme1", - "block_size": 512 - }], + 'bdevs': [ + {"num_blocks": 131072, "name": "nvme1", "block_size": 512} + ], 'scsi_devices': [], - 'luns': [{"claimed": True, - "name": "Malloc0"}], - 'interfaces': [{"core": 0, - "nqn": "nqn.2018-01.org.nvmexpress.discovery", - "hosts": [] - }] - } + 'luns': [{"claimed": True, "name": "Malloc0"}], + 'interfaces': [ + { + "core": 0, + "nqn": "nqn.2018-01.org.nvmexpress.discovery", + "hosts": [], + } + ], + }, ] success_send = mock.Mock(return_value=expect_accelerators) self.vhost_driver.get_all_accelerators = success_send - self.assertEqual(self.vhost_driver.accelerator_list(), - expect_accelerators) + self.assertEqual( + self.vhost_driver.accelerator_list(), expect_accelerators + ) def test_install_accelerator(self): pass diff --git a/cyborg/tests/unit/accelerator/drivers/ssd/test_utils.py b/cyborg/tests/unit/accelerator/drivers/ssd/test_utils.py index 88942db3..6321221e 100644 --- a/cyborg/tests/unit/accelerator/drivers/ssd/test_utils.py +++ b/cyborg/tests/unit/accelerator/drivers/ssd/test_utils.py @@ -21,12 +21,13 @@ from cyborg.accelerator.drivers.ssd.inspur.driver import InspurNVMeSSDDriver from cyborg.accelerator.drivers.ssd import utils from cyborg.tests import base -NVME_SSD_INFO = \ - "0000:db:00.0 Non-Volatile memory controller [0108]: Inspur " \ - "Electronic Information Industry Co., Ltd. Device [1bd4:1001]" \ - " (rev 02)\n0000:db:01.0 Non-Volatile memory controller " \ - "[0108]: Inspur Electronic Information Industry Co., Ltd. " \ +NVME_SSD_INFO = ( + "0000:db:00.0 Non-Volatile memory controller [0108]: Inspur " + "Electronic Information Industry Co., Ltd. Device [1bd4:1001]" + " (rev 02)\n0000:db:01.0 Non-Volatile memory controller " + "[0108]: Inspur Electronic Information Industry Co., Ltd. " "Device [1bd4:1001] (rev 02)" +) class stdout: @@ -43,7 +44,6 @@ class p: class TestSSDDriverUtils(base.TestCase): - def setUp(self): super().setUp() self.p = p() @@ -62,12 +62,14 @@ class TestSSDDriverUtils(base.TestCase): ssd_list = InspurNVMeSSDDriver.discover(vendor_id) self.assertEqual(2, len(ssd_list)) attach_handle_list = [ - {'attach_type': 'PCI', - 'attach_info': '{"bus": "db", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'in_use': False} + { + 'attach_type': 'PCI', + 'attach_info': '{"bus": "db", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'in_use': False, + } ] attribute_list = [ {'key': 'rc', 'value': 'CUSTOM_SSD'}, @@ -77,51 +79,65 @@ class TestSSDDriverUtils(base.TestCase): expected = { 'vendor': '1bd4', 'type': 'SSD', - 'std_board_info': - {"controller": "Non-Volatile memory controller", - "product_id": "1001"}, + 'std_board_info': { + "controller": "Non-Volatile memory controller", + "product_id": "1001", + }, 'vendor_board_info': {"vendor_info": "ssd_vb_info"}, - 'deployable_list': - [ - { - 'num_accelerators': 1, - 'driver_name': 'INSPUR', - 'name': 'host-192-168-32-195_0000:db:00.0', - 'attach_handle_list': attach_handle_list, - 'attribute_list': attribute_list - }, - ], - 'controlpath_id': {'cpid_info': '{"bus": "db", ' - '"device": "00", ' - '"domain": "0000", ' - '"function": "0"}', - 'cpid_type': 'PCI'} + 'deployable_list': [ + { + 'num_accelerators': 1, + 'driver_name': 'INSPUR', + 'name': 'host-192-168-32-195_0000:db:00.0', + 'attach_handle_list': attach_handle_list, + 'attribute_list': attribute_list, + }, + ], + 'controlpath_id': { + 'cpid_info': '{"bus": "db", ' + '"device": "00", ' + '"domain": "0000", ' + '"function": "0"}', + 'cpid_type': 'PCI', + }, } ssd_obj = ssd_list[0] ssd_dict = ssd_obj.as_dict() ssd_dep_list = ssd_dict['deployable_list'] - ssd_attach_handle_list = \ - ssd_dep_list[0].as_dict()['attach_handle_list'] - ssd_attribute_list = \ - ssd_dep_list[0].as_dict()['attribute_list'] + ssd_attach_handle_list = ssd_dep_list[0].as_dict()[ + 'attach_handle_list' + ] + ssd_attribute_list = ssd_dep_list[0].as_dict()['attribute_list'] attri_obj_data = [] [attri_obj_data.append(attr.as_dict()) for attr in ssd_attribute_list] attribute_actual_data = sorted(attri_obj_data, key=lambda i: i['key']) self.assertEqual(expected['vendor'], ssd_dict['vendor']) - self.assertEqual(expected['controlpath_id'], - ssd_dict['controlpath_id']) - self.assertEqual(expected['std_board_info'], - jsonutils.loads(ssd_dict['std_board_info'])) - self.assertEqual(expected['vendor_board_info'], - jsonutils.loads(ssd_dict['vendor_board_info'])) - self.assertEqual(expected['deployable_list'][0]['num_accelerators'], - ssd_dep_list[0].as_dict()['num_accelerators']) - self.assertEqual(expected['deployable_list'][0]['name'], - ssd_dep_list[0].as_dict()['name']) - self.assertEqual(expected['deployable_list'][0]['driver_name'], - ssd_dep_list[0].as_dict()['driver_name']) - self.assertEqual(attach_handle_list[0], - ssd_attach_handle_list[0].as_dict()) + self.assertEqual( + expected['controlpath_id'], ssd_dict['controlpath_id'] + ) + self.assertEqual( + expected['std_board_info'], + jsonutils.loads(ssd_dict['std_board_info']), + ) + self.assertEqual( + expected['vendor_board_info'], + jsonutils.loads(ssd_dict['vendor_board_info']), + ) + self.assertEqual( + expected['deployable_list'][0]['num_accelerators'], + ssd_dep_list[0].as_dict()['num_accelerators'], + ) + self.assertEqual( + expected['deployable_list'][0]['name'], + ssd_dep_list[0].as_dict()['name'], + ) + self.assertEqual( + expected['deployable_list'][0]['driver_name'], + ssd_dep_list[0].as_dict()['driver_name'], + ) + self.assertEqual( + attach_handle_list[0], ssd_attach_handle_list[0].as_dict() + ) self.assertEqual(attribute_list, attribute_actual_data) @mock.patch('cyborg.accelerator.drivers.ssd.utils.lspci_privileged') @@ -130,8 +146,12 @@ class TestSSDDriverUtils(base.TestCase): with self.assertLogs(None, level='INFO') as cm: d = SSDDriver.create() ssd_list = d.discover() - self.assertEqual(cm.output, - ['INFO:cyborg.accelerator.drivers.ssd.base:The ' - 'method "discover" is called in generic.SSDDriver']) + self.assertEqual( + cm.output, + [ + 'INFO:cyborg.accelerator.drivers.ssd.base:The ' + 'method "discover" is called in generic.SSDDriver' + ], + ) self.assertEqual(2, len(ssd_list)) self.assertEqual("1bd4", ssd_list[1].as_dict()['vendor']) diff --git a/cyborg/tests/unit/accelerator/drivers/test_driver.py b/cyborg/tests/unit/accelerator/drivers/test_driver.py index 7e246930..6492468d 100644 --- a/cyborg/tests/unit/accelerator/drivers/test_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/test_driver.py @@ -32,10 +32,8 @@ class NotCompleteDriver(GenericDriver): class TestGenericDriver(base.TestCase): - def test_generic_driver(self): # Can't instantiate abstract class NotCompleteDriver with # abstract methods get_stats, update result = self.assertRaises(TypeError, NotCompleteDriver) - self.assertIn("Can't instantiate abstract class", - str(result)) + self.assertIn("Can't instantiate abstract class", str(result)) diff --git a/cyborg/tests/unit/accelerator/drivers/test_fake_driver.py b/cyborg/tests/unit/accelerator/drivers/test_fake_driver.py index 9fa05a24..ae4f6d1c 100644 --- a/cyborg/tests/unit/accelerator/drivers/test_fake_driver.py +++ b/cyborg/tests/unit/accelerator/drivers/test_fake_driver.py @@ -22,7 +22,6 @@ from cyborg.tests import base class TestFakeDriver(base.TestCase): - def setUp(self): super().setUp() @@ -34,7 +33,10 @@ class TestFakeDriver(base.TestCase): for attach handles. """ pci_bdf_template = { - 'domain': '0000', 'bus': '0c', 'device': '00', } + 'domain': '0000', + 'bus': '0c', + 'device': '00', + } pci_bdf_list = [] for fn in range(num_pci_bdfs): pci_bdf = copy.deepcopy(pci_bdf_template) @@ -61,52 +63,60 @@ class TestFakeDriver(base.TestCase): return cpid, [ah1, ah2] def _validate_attach_handle_list(self, ah_list): - self.assertTrue(all( - ah['attach_type'] == 'TEST_PCI' for ah in ah_list)) + self.assertTrue(all(ah['attach_type'] == 'TEST_PCI' for ah in ah_list)) self.assertTrue(all(ah['in_use'] is False for ah in ah_list)) ah_bdfs = [jsonutils.loads(ah['attach_info']) for ah in ah_list] - bdf_list = [(bdf['domain'], bdf['bus'], - bdf['device'], bdf['function']) for bdf in ah_bdfs] + bdf_list = [ + (bdf['domain'], bdf['bus'], bdf['device'], bdf['function']) + for bdf in ah_bdfs + ] domains = [bdf[0] for bdf in bdf_list] buses = [bdf[1] for bdf in bdf_list] pci_devs = [bdf[2] for bdf in bdf_list] pci_fns = [bdf[3] for bdf in bdf_list] # Ensure BDFs are unique - self.assertEqual(len(set(bdf_list)), len(bdf_list), - "BDFs must be unique.") self.assertEqual( - len(set(domains)), 1, - "PCI domains are expected to be same for #accelerators < 256.") + len(set(bdf_list)), len(bdf_list), "BDFs must be unique." + ) self.assertEqual( - len(set(buses)), 1, - "PCI bus #s are expected to be same for #accelerators < 256.") + len(set(domains)), + 1, + "PCI domains are expected to be same for #accelerators < 256.", + ) + self.assertEqual( + len(set(buses)), + 1, + "PCI bus #s are expected to be same for #accelerators < 256.", + ) self.assertTrue( all([0 <= int(pci_dev) < 32 for pci_dev in pci_devs]), - "PCI device numbers must be between 0 and 31." - ) + "PCI device numbers must be between 0 and 31.", + ) self.assertTrue( all([0 <= int(pci_fn) < 8 for pci_fn in pci_fns]), - "PCI function numbers must be between 0 and 7." - ) + "PCI function numbers must be between 0 and 7.", + ) def test_discover(self): cpid, ah_list = self._get_cpid_attach_handles() # This is an example device from the fake driver. # We set num_accelerators = 2 for simplicity. - expected = {'vendor': '0xABCD', - 'type': 'FPGA', - 'model': 'miss model info', - 'deployable_list': - [{'num_accelerators': 2, - 'name': 'FakeDevice', - 'attach_handle_list': ah_list, - }, - ], - 'controlpath_id': cpid, - } + expected = { + 'vendor': '0xABCD', + 'type': 'FPGA', + 'model': 'miss model info', + 'deployable_list': [ + { + 'num_accelerators': 2, + 'name': 'FakeDevice', + 'attach_handle_list': ah_list, + }, + ], + 'controlpath_id': cpid, + } fake_drv = FakeDriver() devices = fake_drv.discover() @@ -119,7 +129,8 @@ class TestFakeDriver(base.TestCase): self.assertEqual(cpid['cpid_type'], expected_cpid['cpid_type']) self.assertEqual( jsonutils.loads(cpid['cpid_info']), - jsonutils.loads(expected_cpid['cpid_info'])) + jsonutils.loads(expected_cpid['cpid_info']), + ) deployables = devices[0].deployable_list self.assertEqual(1, len(deployables)) diff --git a/cyborg/tests/unit/agent/test_manager.py b/cyborg/tests/unit/agent/test_manager.py index aed0d55b..491bf947 100644 --- a/cyborg/tests/unit/agent/test_manager.py +++ b/cyborg/tests/unit/agent/test_manager.py @@ -27,17 +27,21 @@ class TestAgentManager(base.TestCase): def setUp(self): super().setUp() - self.placement_mock = self.useFixture(fixtures.MockPatch( - 'cyborg.agent.manager.placement.PlacementClient') + self.placement_mock = self.useFixture( + fixtures.MockPatch( + 'cyborg.agent.manager.placement.PlacementClient' + ) ).mock.return_value def _create_manager_with_mocks(self): """Create an AgentManager with all dependencies mocked.""" - with mock.patch('cyborg.agent.manager.FPGADriver'), \ - mock.patch('cyborg.agent.manager.cond_api.ConductorAPI'), \ - mock.patch('cyborg.agent.manager.AgentAPI'), \ - mock.patch('cyborg.agent.manager.ImageAPI'), \ - mock.patch('cyborg.agent.manager.ResourceTracker'): + with ( + mock.patch('cyborg.agent.manager.FPGADriver'), + mock.patch('cyborg.agent.manager.cond_api.ConductorAPI'), + mock.patch('cyborg.agent.manager.AgentAPI'), + mock.patch('cyborg.agent.manager.ImageAPI'), + mock.patch('cyborg.agent.manager.ResourceTracker'), + ): return manager.AgentManager('cyborg-agent-topic') @mock.patch('cyborg.agent.manager.CONF') @@ -99,7 +103,8 @@ class TestAgentManager(base.TestCase): self.assertRaises( exception.PlacementResourceProviderNotFound, - self._create_manager_with_mocks) + self._create_manager_with_mocks, + ) @mock.patch('cyborg.agent.manager.CONF') def test_get_resource_provider_name_config_override(self, mock_conf): @@ -135,7 +140,8 @@ class TestAgentManager(base.TestCase): self.assertRaises( exception.PlacementResourceProviderNotFound, - self._create_manager_with_mocks) + self._create_manager_with_mocks, + ) # Should only try once since hostname == CONF.host self.assertEqual(1, self.placement_mock.get.call_count) @@ -143,7 +149,8 @@ class TestAgentManager(base.TestCase): @mock.patch('cyborg.agent.manager.time') @mock.patch('cyborg.agent.manager.CONF') def test_get_resource_provider_name_retry_succeeds( - self, mock_conf, mock_time): + self, mock_conf, mock_time + ): """Test retry succeeds on second attempt.""" mock_conf.agent.resource_provider_name = 'compute-0.example.com' mock_conf.host = 'compute-0' @@ -164,7 +171,8 @@ class TestAgentManager(base.TestCase): @mock.patch('cyborg.agent.manager.time') @mock.patch('cyborg.agent.manager.CONF') def test_get_resource_provider_name_retry_exhausted( - self, mock_conf, mock_time): + self, mock_conf, mock_time + ): """Test all retry attempts exhausted raises exception.""" mock_conf.agent.resource_provider_name = 'compute-0.example.com' mock_conf.host = 'compute-0' @@ -177,17 +185,19 @@ class TestAgentManager(base.TestCase): self.assertRaises( exception.PlacementResourceProviderNotFound, - self._create_manager_with_mocks) + self._create_manager_with_mocks, + ) # 3 total attempts (initial + 2 retries), sleep between each self.assertEqual( - [mock.call(1), mock.call(2)], - mock_time.sleep.call_args_list) + [mock.call(1), mock.call(2)], mock_time.sleep.call_args_list + ) @mock.patch('cyborg.agent.manager.time') @mock.patch('cyborg.agent.manager.CONF') def test_get_resource_provider_name_no_retry_when_zero( - self, mock_conf, mock_time): + self, mock_conf, mock_time + ): """Test immediate failure when retries set to 0.""" mock_conf.agent.resource_provider_name = 'compute-0.example.com' mock_conf.host = 'compute-0' @@ -199,7 +209,8 @@ class TestAgentManager(base.TestCase): self.assertRaises( exception.PlacementResourceProviderNotFound, - self._create_manager_with_mocks) + self._create_manager_with_mocks, + ) mock_time.sleep.assert_not_called() @@ -212,8 +223,10 @@ class TestAgentManager(base.TestCase): # Placement raises an exception for all requests self.placement_mock.get.side_effect = ks_exc.ConnectFailure( - 'Placement unavailable') + 'Placement unavailable' + ) self.assertRaises( exception.PlacementResourceProviderNotFound, - self._create_manager_with_mocks) + self._create_manager_with_mocks, + ) diff --git a/cyborg/tests/unit/agent/test_resource_tracker.py b/cyborg/tests/unit/agent/test_resource_tracker.py index 4b496358..2a2c150c 100644 --- a/cyborg/tests/unit/agent/test_resource_tracker.py +++ b/cyborg/tests/unit/agent/test_resource_tracker.py @@ -25,7 +25,7 @@ from cyborg.tests import base class TestResourceTracker(base.TestCase): - """Test Agent ResourceTracker """ + """Test Agent ResourceTracker""" def setUp(self): super().setUp() @@ -49,15 +49,20 @@ class TestResourceTracker(base.TestCase): def test_initialize_invalid_driver(self): enabled_drivers = ['invalid_driver'] - self.assertRaises(exception.InvalidDriver, self.rt._initialize_drivers, - enabled_drivers) + self.assertRaises( + exception.InvalidDriver, + self.rt._initialize_drivers, + enabled_drivers, + ) @mock.patch('cyborg.agent.resource_tracker.LOG') def test_update_usage_failed_parent_provider(self, mock_log): with mock.patch.object(self.rt.conductor_api, 'report_data') as m: m.side_effect = exception.PlacementResourceProviderNotFound( - resource_provider='foo') + resource_provider='foo' + ) self.rt.update_usage(None) m.assert_called_once_with(None, 'fake-mini', mock.ANY) - mock_log.error.assert_called_once_with('Unable to report usage: %s', - m.side_effect) + mock_log.error.assert_called_once_with( + 'Unable to report usage: %s', m.side_effect + ) diff --git a/cyborg/tests/unit/agent/test_rpcapi.py b/cyborg/tests/unit/agent/test_rpcapi.py index 09cd38b8..538bc0ca 100644 --- a/cyborg/tests/unit/agent/test_rpcapi.py +++ b/cyborg/tests/unit/agent/test_rpcapi.py @@ -31,20 +31,25 @@ class TestRPCAPI(base.TestCase): def setUp(self, topic=None): super().setUp() self.topic = topic or constants.AGENT_TOPIC - target = messaging.Target(topic=self.topic, - version=self.RPC_API_VERSION) + target = messaging.Target( + topic=self.topic, version=self.RPC_API_VERSION + ) self.agent_rpcapi = AgentAPI() self.serializer = objects_base.CyborgObjectSerializer() - self.client = rpc.get_client(target, - version_cap=self.RPC_API_VERSION, - serializer=self.serializer) + self.client = rpc.get_client( + target, + version_cap=self.RPC_API_VERSION, + serializer=self.serializer, + ) def _test_rpc_call(self, method): - ctxt = cyborg_context.RequestContext(user_id='fake_user', - project_id='fake_project') + ctxt = cyborg_context.RequestContext( + user_id='fake_user', project_id='fake_project' + ) expect_val = True - with mock.patch.object(self.agent_rpcapi, - 'fpga_program') as mock_program: + with mock.patch.object( + self.agent_rpcapi, 'fpga_program' + ) as mock_program: func_obj = getattr(self.agent_rpcapi, method) mock_program.return_value = expect_val actual_val = func_obj(ctxt, 'fake_dep_uuid') diff --git a/cyborg/tests/unit/api/base.py b/cyborg/tests/unit/api/base.py index 7b0f6744..7171d206 100644 --- a/cyborg/tests/unit/api/base.py +++ b/cyborg/tests/unit/api/base.py @@ -37,10 +37,10 @@ class BaseApiTest(base.DbTestCase): def setUp(self): super().setUp() - cfg.CONF.set_override("auth_version", "v3", - group='keystone_authtoken') - cfg.CONF.set_override("admin_user", "admin", - group='keystone_authtoken') + cfg.CONF.set_override("auth_version", "v3", group='keystone_authtoken') + cfg.CONF.set_override( + "admin_user", "admin", group='keystone_authtoken' + ) self.app = self._make_app() def reset_pecan(): @@ -69,8 +69,16 @@ class BaseApiTest(base.DbTestCase): } return pecan.testing.load_test_app(self.app_config) - def _request_json(self, path, params, expect_errors=False, headers=None, - method="post", extra_environ=None, status=None): + def _request_json( + self, + path, + params, + expect_errors=False, + headers=None, + method="post", + extra_environ=None, + status=None, + ): """Sends simulated HTTP request to Pecan test app. :param path: url path of target service @@ -90,12 +98,19 @@ class BaseApiTest(base.DbTestCase): headers=headers, status=status, extra_environ=extra_environ, - expect_errors=expect_errors + expect_errors=expect_errors, ) return response - def post_json(self, path, params, expect_errors=False, headers=None, - extra_environ=None, status=None): + def post_json( + self, + path, + params, + expect_errors=False, + headers=None, + extra_environ=None, + status=None, + ): """Sends simulated HTTP POST request to Pecan test app. :param path: url path of target service @@ -108,10 +123,15 @@ class BaseApiTest(base.DbTestCase): :param status: expected status code of response """ full_path = self.PATH_PREFIX + path - return self._request_json(path=full_path, params=params, - expect_errors=expect_errors, - headers=headers, extra_environ=extra_environ, - status=status, method="post") + return self._request_json( + path=full_path, + params=params, + expect_errors=expect_errors, + headers=headers, + extra_environ=extra_environ, + status=status, + method="post", + ) def gen_context(self, value, **kwargs): ct = cyborg_context.RequestContext.from_dict(value, **kwargs) @@ -135,24 +155,32 @@ class BaseApiTest(base.DbTestCase): role = "user" headers = { 'X-User-Name': ct.get("user_name") or "user", - 'X-User-Id': - ct.get("user_id") or "1d6d686bc2c949ddb685ffb4682e0047", + 'X-User-Id': ct.get("user_id") + or "1d6d686bc2c949ddb685ffb4682e0047", 'X-Project-Name': ct.get("project_name") or "no_project_name", - 'X-Project-Id': - ct.get("project_id") or "86f64f561b6d4f479655384572727f70", - 'X-User-Domain-Id': - ct.get("domain_id") or "bd5eeb7d0fb046daaf694b36f4df5518", + 'X-Project-Id': ct.get("project_id") + or "86f64f561b6d4f479655384572727f70", + 'X-User-Domain-Id': ct.get("domain_id") + or "bd5eeb7d0fb046daaf694b36f4df5518", 'X-User-Domain-Name': ct.get("domain_name") or "no_domain", - 'X-Auth-Token': - ct.get("auth_token") or "b9764005b8c145bf972634fb16a826e8", + 'X-Auth-Token': ct.get("auth_token") + or "b9764005b8c145bf972634fb16a826e8", 'X-Roles': ct.get("roles") or role, } if ct.get('system_scope') == 'all': headers.update({'Openstack-System-Scope': 'all'}) return headers - def get_json(self, path, expect_errors=False, headers=None, - extra_environ=None, q=None, return_json=True, **params): + def get_json( + self, + path, + expect_errors=False, + headers=None, + extra_environ=None, + q=None, + return_json=True, + **params, + ): """Sends simulated HTTP GET request to Pecan test app. :param path: url path of target service @@ -172,7 +200,7 @@ class BaseApiTest(base.DbTestCase): 'q.field': [], 'q.value': [], 'q.op': [], - } + } for query in q: for name in ['field', 'op', 'value']: query_params['q.%s' % name].append(query.get(name, '')) @@ -180,17 +208,26 @@ class BaseApiTest(base.DbTestCase): all_params.update(params) if q: all_params.update(query_params) - response = self.app.get(full_path, - params=all_params, - headers=headers, - extra_environ=extra_environ, - expect_errors=expect_errors) + response = self.app.get( + full_path, + params=all_params, + headers=headers, + extra_environ=extra_environ, + expect_errors=expect_errors, + ) if return_json and not expect_errors: response = response.json return response - def patch_json(self, path, params, expect_errors=False, headers=None, - extra_environ=None, status=None): + def patch_json( + self, + path, + params, + expect_errors=False, + headers=None, + extra_environ=None, + status=None, + ): """Sends simulated HTTP PATCH request to Pecan test app. :param path: url path of target service @@ -203,13 +240,24 @@ class BaseApiTest(base.DbTestCase): :param status: expected status code of response """ full_path = self.PATH_PREFIX + path - return self._request_json(path=full_path, params=params, - expect_errors=expect_errors, - headers=headers, extra_environ=extra_environ, - status=status, method="patch") + return self._request_json( + path=full_path, + params=params, + expect_errors=expect_errors, + headers=headers, + extra_environ=extra_environ, + status=status, + method="patch", + ) - def delete(self, path, expect_errors=False, headers=None, - extra_environ=None, status=None): + def delete( + self, + path, + expect_errors=False, + headers=None, + extra_environ=None, + status=None, + ): """Sends simulated HTTP DELETE request to Pecan test app. :param path: url path of target service @@ -221,9 +269,11 @@ class BaseApiTest(base.DbTestCase): :param status: expected status code of response """ full_path = self.PATH_PREFIX + path - response = self.app.delete(full_path, - headers=headers, - status=status, - extra_environ=extra_environ, - expect_errors=expect_errors) + response = self.app.delete( + full_path, + headers=headers, + status=status, + extra_environ=extra_environ, + expect_errors=expect_errors, + ) return response diff --git a/cyborg/tests/unit/api/controllers/v2/base.py b/cyborg/tests/unit/api/controllers/v2/base.py index 889bd3a4..791cfa8d 100644 --- a/cyborg/tests/unit/api/controllers/v2/base.py +++ b/cyborg/tests/unit/api/controllers/v2/base.py @@ -17,5 +17,4 @@ from cyborg.tests.unit.api import base class APITestV2(base.BaseApiTest): - PATH_PREFIX = '/v2' diff --git a/cyborg/tests/unit/api/controllers/v2/test_api.py b/cyborg/tests/unit/api/controllers/v2/test_api.py index e93877ef..3c6767f2 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_api.py +++ b/cyborg/tests/unit/api/controllers/v2/test_api.py @@ -18,7 +18,6 @@ from cyborg.tests.unit.api.controllers.v2 import base as v2_test class TestAPI(v2_test.APITestV2): - def setUp(self): super().setUp() self.headers = self.gen_headers(self.context) diff --git a/cyborg/tests/unit/api/controllers/v2/test_arqs.py b/cyborg/tests/unit/api/controllers/v2/test_arqs.py index cd64f1c7..033c8d76 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_arqs.py +++ b/cyborg/tests/unit/api/controllers/v2/test_arqs.py @@ -28,7 +28,6 @@ from cyborg.tests.unit import fake_extarq class TestARQsController(v2_test.APITestV2): - ARQ_URL = '/accelerator_requests' def setUp(self): @@ -37,7 +36,8 @@ class TestARQsController(v2_test.APITestV2): self.fake_extarqs = fake_extarq.get_fake_extarq_objs() self.fake_bind_extarqs = fake_extarq.get_fake_extarq_bind_objs() self.fake_resolved_extarqs = ( - fake_extarq.get_fake_extarq_resolved_objs()) + fake_extarq.get_fake_extarq_resolved_objs() + ) self.arqs_controller = arqs.ARQsController() def _validate_links(self, links, arq_uuid): @@ -109,8 +109,9 @@ class TestARQsController(v2_test.APITestV2): result = isinstance(out_arqs, list) self.assertTrue(result) self.assertEqual(len(out_arqs), len(self.fake_resolved_extarqs[1:])) - for in_extarq, out_arq in zip(self.fake_resolved_extarqs[1:], - out_arqs): + for in_extarq, out_arq in zip( + self.fake_resolved_extarqs[1:], out_arqs + ): self._validate_arq(in_extarq.arq, out_arq) @mock.patch('cyborg.objects.ExtARQ.list') @@ -119,7 +120,9 @@ class TestARQsController(v2_test.APITestV2): mock_extarqs.return_value = self.fake_bind_extarqs[:3] instance_uuid = self.fake_bind_extarqs[0].arq.instance_uuid url = '%s?instance=%s&bind_state=resolved' % ( - self.ARQ_URL, instance_uuid) + self.ARQ_URL, + instance_uuid, + ) data = self.get_json(url, headers=self.headers) out_arqs = data['arqs'] @@ -135,7 +138,9 @@ class TestARQsController(v2_test.APITestV2): mock_extarqs.return_value = self.fake_bind_extarqs instance_uuid = self.fake_bind_extarqs[0].arq.instance_uuid url = '%s?instance=%s&bind_state=resolved' % ( - self.ARQ_URL, instance_uuid) + self.ARQ_URL, + instance_uuid, + ) try: self.get_json(url, headers=self.headers) except Exception as e: @@ -148,7 +153,9 @@ class TestARQsController(v2_test.APITestV2): mock_extarqs.return_value = self.fake_extarqs instance_uuid = self.fake_extarqs[0].arq.instance_uuid url = '%s?instance=%s&bind_state=started' % ( - self.ARQ_URL, instance_uuid) + self.ARQ_URL, + instance_uuid, + ) exc = None try: self.get_json(url, headers=self.headers) @@ -158,7 +165,9 @@ class TestARQsController(v2_test.APITestV2): # use assertIn here, improve this case with assertRaises later. self.assertIn( "Bad state: started for ARQ: None. Expected state(s): " - "[\\\'resolved\\\']", exc.args[0]) + "[\\'resolved\\']", + exc.args[0], + ) url = '%s?bind_state=started' % (self.ARQ_URL) exc = None @@ -170,7 +179,9 @@ class TestARQsController(v2_test.APITestV2): # use assertIn here, improve this case with assertRaises later. self.assertIn( "Bad state: started for ARQ: None. Expected state(s): " - "[\\\'resolved\\\']", exc.args[0]) + "[\\'resolved\\']", + exc.args[0], + ) @mock.patch('cyborg.objects.ExtARQ.list') def test_get_all_with_invalid_arq_state(self, mock_extarqs): @@ -180,7 +191,9 @@ class TestARQsController(v2_test.APITestV2): mock_extarqs.return_value = self.fake_extarqs instance_uuid = self.fake_extarqs[0].arq.instance_uuid url = '%s?instance=%s&bind_state=resolved' % ( - self.ARQ_URL, instance_uuid) + self.ARQ_URL, + instance_uuid, + ) response = self.get_json(url, headers=self.headers, expect_errors=True) self.assertEqual(HTTPStatus.LOCKED, response.status_int) @@ -229,7 +242,8 @@ class TestARQsController(v2_test.APITestV2): params = {'device_profile_name': 'wrong_device_profile_name'} mock_obj_dp.side_effect = exception.ResourceNotFound( resource='Device Profile', - msg='with name=%s' % params.get('device_profile_name')) + msg='with name=%s' % params.get('device_profile_name'), + ) mock_obj_extarq.side_effect = self.fake_extarqs exc = None try: @@ -237,12 +251,14 @@ class TestARQsController(v2_test.APITestV2): except Exception as e: exc = e self.assertIn( - "Device Profile not found with " - "name=wrong_device_profile_name", exc.args[0]) + "Device Profile not found with name=wrong_device_profile_name", + exc.args[0], + ) @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.arq_delete_by_uuid') - @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.' - 'arq_delete_by_instance_uuid') + @mock.patch( + 'cyborg.conductor.rpcapi.ConductorAPI.arq_delete_by_instance_uuid' + ) def test_delete(self, mock_by_inst, mock_by_arq): url = self.ARQ_URL arq = self.fake_extarqs[0].arq @@ -286,48 +302,69 @@ class TestARQsController(v2_test.APITestV2): arq_uuid: { 'hostname': obj_extarq.arq.hostname, 'device_rp_uuid': device_rp_uuid, - 'instance_uuid': obj_extarq.arq.instance_uuid} - for arq_uuid in arq_uuids} + 'instance_uuid': obj_extarq.arq.instance_uuid, + } + for arq_uuid in arq_uuids + } - self.patch_json(self.ARQ_URL, params=patch_list, - headers=self.headers) + self.patch_json(self.ARQ_URL, params=patch_list, headers=self.headers) - mock_apply_patch.assert_called_once_with(mock.ANY, patch_list, - valid_fields) + mock_apply_patch.assert_called_once_with( + mock.ANY, patch_list, valid_fields + ) mock_check_if_bound.assert_called_once_with(mock.ANY, valid_fields) @mock.patch.object(arqs.ARQsController, '_check_if_already_bound') @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.arq_apply_patch') def test_apply_patch_allow_project_id( - self, mock_apply_patch, mock_check_if_bound): + self, mock_apply_patch, mock_check_if_bound + ): patch_list, _ = fake_extarq.get_patch_list() for arq_uuid, patch in patch_list.items(): - patch.append({'path': '/project_id', 'op': 'add', - 'value': 'b1c76756ac2e482789a8e1c5f4bf065e'}) + patch.append( + { + 'path': '/project_id', + 'op': 'add', + 'value': 'b1c76756ac2e482789a8e1c5f4bf065e', + } + ) arq_uuids = list(patch_list.keys()) valid_fields = { arq_uuid: { 'hostname': 'myhost', 'device_rp_uuid': 'fb16c293-5739-4c84-8590-926f9ab16669', 'instance_uuid': '5922a70f-1e06-4cfd-88dd-a332120d7144', - 'project_id': 'b1c76756ac2e482789a8e1c5f4bf065e'} - for arq_uuid in arq_uuids} + 'project_id': 'b1c76756ac2e482789a8e1c5f4bf065e', + } + for arq_uuid in arq_uuids + } - self.patch_json(self.ARQ_URL, params=patch_list, - headers={base.Version.current_api_version: - '2.1'}) - mock_apply_patch.assert_called_once_with(mock.ANY, patch_list, - valid_fields) + self.patch_json( + self.ARQ_URL, + params=patch_list, + headers={base.Version.current_api_version: '2.1'}, + ) + mock_apply_patch.assert_called_once_with( + mock.ANY, patch_list, valid_fields + ) mock_check_if_bound.assert_called_once_with(mock.ANY, valid_fields) def test_apply_patch_not_allow_project_id(self): patch_list, _ = fake_extarq.get_patch_list() for arq_uuid, patch in patch_list.items(): - patch.append({'path': '/project_id', 'op': 'add', - 'value': 'b1c76756ac2e482789a8e1c5f4bf065e'}) - response = self.patch_json(self.ARQ_URL, params=patch_list, - headers=self.headers, - expect_errors=True) + patch.append( + { + 'path': '/project_id', + 'op': 'add', + 'value': 'b1c76756ac2e482789a8e1c5f4bf065e', + } + ) + response = self.patch_json( + self.ARQ_URL, + params=patch_list, + headers=self.headers, + expect_errors=True, + ) self.assertEqual(HTTPStatus.NOT_ACCEPTABLE, response.status_code) self.assertTrue(response.json['error_message']) @@ -345,17 +382,20 @@ class TestARQsController(v2_test.APITestV2): extarq.arq['uuid']: { 'hostname': 'myhost', 'device_rp_uuid': 'fb16c293-5739-4c84-8590-926f9ab16669', - 'instance_uuid': instance_uuid} - for extarq in extarqs} + 'instance_uuid': instance_uuid, + } + for extarq in extarqs + } self.arqs_controller._check_if_already_bound( - self.context, valid_fields) + self.context, valid_fields + ) mock_extarq_list.assert_called_once_with(self.context) @mock.patch('cyborg.objects.ExtARQ.list') def test_check_if_bound_exception(self, mock_extarq_list): """Test that an exception is raised if binding request specifies - an instance that already has ARQs. + an instance that already has ARQs. """ extarqs = fake_extarq.get_fake_extarq_objs() mock_extarq_list.return_value = extarqs @@ -366,13 +406,20 @@ class TestARQsController(v2_test.APITestV2): extarq.arq['uuid']: { 'hostname': 'myhost', 'device_rp_uuid': 'fb16c293-5739-4c84-8590-926f9ab16669', - 'instance_uuid': instance_uuid} - for extarq in extarqs} + 'instance_uuid': instance_uuid, + } + for extarq in extarqs + } - expected_err = ('Instance %s already has accelerator requests. ' - 'Cannot bind additional ARQs.') % instance_uuid + expected_err = ( + 'Instance %s already has accelerator requests. ' + 'Cannot bind additional ARQs.' + ) % instance_uuid self.assertRaisesRegex( - exception.PatchError, expected_err, + exception.PatchError, + expected_err, self.arqs_controller._check_if_already_bound, - self.context, valid_fields) + self.context, + valid_fields, + ) diff --git a/cyborg/tests/unit/api/controllers/v2/test_attributes.py b/cyborg/tests/unit/api/controllers/v2/test_attributes.py index 969e3421..24e6fd6f 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_attributes.py +++ b/cyborg/tests/unit/api/controllers/v2/test_attributes.py @@ -29,8 +29,9 @@ class TestAttributes(v2_test.APITestV2): super().setUp() self.headers = self.gen_headers(self.context) self.fake_attributes = fake_attribute.fake_db_attribute() - self.fake_attribute_objs = \ - [fake_attribute.fake_attribute_obj(self.context)] + self.fake_attribute_objs = [ + fake_attribute.fake_attribute_obj(self.context) + ] def _validate_links(self, links, attribute_uuid): has_self_link = False @@ -46,8 +47,9 @@ class TestAttributes(v2_test.APITestV2): self.assertEqual(in_attributes['uuid'], out_attributes['uuid']) self.assertEqual(in_attributes['key'], out_attributes['key']) self.assertEqual(in_attributes['value'], out_attributes['value']) - self.assertEqual(in_attributes['deployable_id'], - out_attributes['deployable_id']) + self.assertEqual( + in_attributes['deployable_id'], out_attributes['deployable_id'] + ) # Check that the link is properly set up self._validate_links(out_attributes['links'], in_attributes['uuid']) @@ -57,8 +59,9 @@ class TestAttributes(v2_test.APITestV2): attribute = self.fake_attribute_objs[0] mock_attributes_uuid.return_value = attribute url = self.ATTRIBUTE_URL + '/%s' - out_attribute = self.get_json(url % attribute['uuid'], - headers=self.headers) + out_attribute = self.get_json( + url % attribute['uuid'], headers=self.headers + ) mock_attributes_uuid.assert_called_once() self._validate_attributes(attribute, out_attribute) @@ -79,8 +82,9 @@ class TestAttributes(v2_test.APITestV2): out_attributes = data['attributes'] self.assertIsInstance(out_attributes, list) self.assertEqual(len(out_attributes), len(self.fake_attribute_objs)) - for in_attribute, out_attribute in zip(self.fake_attribute_objs, - out_attributes): + for in_attribute, out_attribute in zip( + self.fake_attribute_objs, out_attributes + ): self._validate_attributes(in_attribute, out_attribute) @mock.patch('cyborg.objects.Attribute.get_by_filter') @@ -95,8 +99,9 @@ class TestAttributes(v2_test.APITestV2): @mock.patch('cyborg.objects.Attribute.create') def test_create(self, mock_cond_attribute): mock_cond_attribute.return_value = self.fake_attribute_objs[0] - response = self.post_json(self.ATTRIBUTE_URL, self.fake_attributes, - headers=self.headers) + response = self.post_json( + self.ATTRIBUTE_URL, self.fake_attributes, headers=self.headers + ) out_attribute = jsonutils.loads(response.controller_output) self.assertEqual(HTTPStatus.CREATED, response.status_int) self._validate_attributes(self.fake_attribute_objs[0], out_attribute) diff --git a/cyborg/tests/unit/api/controllers/v2/test_deployables.py b/cyborg/tests/unit/api/controllers/v2/test_deployables.py index 2f1bb53b..8bef6832 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_deployables.py +++ b/cyborg/tests/unit/api/controllers/v2/test_deployables.py @@ -20,14 +20,14 @@ from cyborg.tests.unit import fake_deployable class TestDeployablesController(v2_test.APITestV2): - DEPLOYABLE_URL = '/deployables' def setUp(self): super().setUp() self.headers = self.gen_headers(self.context) self.fake_deployable = fake_deployable.fake_deployable_obj( - self.context) + self.context + ) def _validate_links(self, links, deployable_uuid): has_self_link = False @@ -76,10 +76,12 @@ class TestDeployablesController(v2_test.APITestV2): # order to list the deployables with limited number which is 1. data = self.get_json( self.DEPLOYABLE_URL + "?filters.field=limit&filters.value=1", - headers=self.headers) + headers=self.headers, + ) out_deployable = data['deployables'] - mock_deployables.assert_called_once_with(mock.ANY, - filters={"limit": "1"}) + mock_deployables.assert_called_once_with( + mock.ANY, filters={"limit": "1"} + ) self._validate_deployable(self.fake_deployable, out_deployable[0]) @mock.patch('cyborg.objects.Deployable.list') @@ -88,10 +90,12 @@ class TestDeployablesController(v2_test.APITestV2): # is "dp_name". mock_deployables.return_value = [] data = self.get_json( - self.DEPLOYABLE_URL + - "?filters.field=name&filters.value=wrongname", - headers=self.headers) + self.DEPLOYABLE_URL + + "?filters.field=name&filters.value=wrongname", + headers=self.headers, + ) out_deployable = data['deployables'] - mock_deployables.assert_called_once_with(mock.ANY, - filters={"name": "wrongname"}) + mock_deployables.assert_called_once_with( + mock.ANY, filters={"name": "wrongname"} + ) self.assertEqual(len(out_deployable), 0) diff --git a/cyborg/tests/unit/api/controllers/v2/test_device_profiles.py b/cyborg/tests/unit/api/controllers/v2/test_device_profiles.py index 85f3bd35..3c4812be 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_device_profiles.py +++ b/cyborg/tests/unit/api/controllers/v2/test_device_profiles.py @@ -74,7 +74,8 @@ class TestDeviceProfileController(v2_test.APITestV2): "Request not acceptable.*", self.get_json, url % dp['name'], - headers=headers) + headers=headers, + ) @mock.patch('cyborg.objects.DeviceProfile.get_by_name') def test_get_one_by_name(self, mock_dp_name): @@ -83,8 +84,7 @@ class TestDeviceProfileController(v2_test.APITestV2): url = self.DP_URL + '/%s' headers = self.headers headers[base.Version.current_api_version] = '2.2' - data = self.get_json(url % dp['name'], - headers=headers) + data = self.get_json(url % dp['name'], headers=headers) mock_dp_name.assert_called_once() out_dp = data['device_profile'] self._validate_dp(dp, out_dp) @@ -105,8 +105,9 @@ class TestDeviceProfileController(v2_test.APITestV2): def test_get_all_by_name(self, mock_dp): mock_dp.return_value = self.fake_dp_objs name = 'dp_example_1' - data = self.get_json(self.DP_URL + '?name=' + name, - headers=self.headers) + data = self.get_json( + self.DP_URL + '?name=' + name, headers=self.headers + ) out_dps = data['device_profiles'] expected_dps = [dp for dp in self.fake_dp_objs if dp.name in [name]] @@ -133,14 +134,16 @@ class TestDeviceProfileController(v2_test.APITestV2): # delete dp name for test del test_unsupported_dp['name'] test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, "DeviceProfile name needed.", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_unsupported_name(self): test_unsupported_dp = self.fake_dps[0] @@ -148,14 +151,16 @@ class TestDeviceProfileController(v2_test.APITestV2): # generate special dp name for test test_unsupported_dp['name'] = '!' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Device profile name must be of the form *", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_no_groups(self): test_unsupported_dp = self.fake_dps[0] @@ -163,14 +168,16 @@ class TestDeviceProfileController(v2_test.APITestV2): # delete dp groups for test del test_unsupported_dp['groups'] test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, "DeviceProfile needs groups field.", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_unsupported_group_key(self): test_unsupported_dp = self.fake_dps[0] @@ -179,47 +186,55 @@ class TestDeviceProfileController(v2_test.APITestV2): del test_unsupported_dp['groups'][0]['resources:FPGA'] test_unsupported_dp['groups'][0]['fake:FPGA'] = 'required' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Device profile group keys must be of the form *", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_unsupported_trait_value(self): test_unsupported_dp = self.fake_dps[0] # generate special dp trait value for test test_unsupported_dp['groups'][0][ - 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10'] = 'fake' + 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10' + ] = 'fake' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Unsupported trait value fake *", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_unsupported_trait_name(self): test_unsupported_dp = self.fake_dps[0] # generate special trait for test del test_unsupported_dp['groups'][0][ - 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10'] + 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10' + ] test_unsupported_dp['groups'][0]['trait:FAKE_TRAIT'] = 'required' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Unsupported trait name format FAKE_TRAIT.*", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.device_profile_create') def test_create_with_extra_space_in_trait(self, mock_cond_dp): @@ -227,16 +242,20 @@ class TestDeviceProfileController(v2_test.APITestV2): # generate a requested dp which has extra space in trait del test_unsupported_dp['groups'][0][ - 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10'] + 'trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10' + ] test_unsupported_dp['groups'][0][ - 'trait: CUSTOM_FPGA_INTEL_PAC_ARRIA10'] = 'required' + 'trait: CUSTOM_FPGA_INTEL_PAC_ARRIA10' + ] = 'required' mock_cond_dp.return_value = self.fake_dp_objs[0] test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) response = self.post_json( - self.DP_URL, [test_unsupported_dp], headers=self.headers) + self.DP_URL, [test_unsupported_dp], headers=self.headers + ) out_dp = jsonutils.loads(response.controller_output) # check that the extra space in trait: @@ -256,10 +275,12 @@ class TestDeviceProfileController(v2_test.APITestV2): mock_cond_dp.return_value = self.fake_dp_objs[0] test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) response = self.post_json( - self.DP_URL, [test_unsupported_dp], headers=self.headers) + self.DP_URL, [test_unsupported_dp], headers=self.headers + ) out_dp = jsonutils.loads(response.controller_output) # check that the extra space in rc:{'resources: FPGA ': '1'} is @@ -275,28 +296,32 @@ class TestDeviceProfileController(v2_test.APITestV2): test_unsupported_dp['groups'][0]["resources:FAKE_RC"] = '1' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Unsupported resource class FAKE_RC.*", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) def test_create_with_invalid_resource_value(self): test_unsupported_dp = self.fake_dps[0] del test_unsupported_dp['groups'][0]['resources:FPGA'] test_unsupported_dp['groups'][0]["resources:CUSTOM_FAKE_RC"] = 'fake' test_unsupported_dp['created_at'] = str( - test_unsupported_dp['created_at']) + test_unsupported_dp['created_at'] + ) self.assertRaisesRegex( webtest.app.AppError, ".*Resources number fake is invalid.*", self.post_json, self.DP_URL, [test_unsupported_dp], - headers=self.headers) + headers=self.headers, + ) @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.device_profile_delete') @mock.patch('cyborg.objects.DeviceProfile.get_by_name') diff --git a/cyborg/tests/unit/api/controllers/v2/test_devices.py b/cyborg/tests/unit/api/controllers/v2/test_devices.py index f45d77f0..01fa7b90 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_devices.py +++ b/cyborg/tests/unit/api/controllers/v2/test_devices.py @@ -20,7 +20,6 @@ from cyborg.tests.unit import fake_device class TestDevicesController(v2_test.APITestV2): - DEVICE_URL = '/devices' def setUp(self): @@ -74,7 +73,8 @@ class TestDevicesController(v2_test.APITestV2): mock_devices.return_value = in_devices[:1] data = self.get_json( self.DEVICE_URL + "?filters.field=limit&filters.value=1", - headers=self.headers) + headers=self.headers, + ) out_devices = data['devices'] mock_devices.assert_called_once_with(mock.ANY, filters={"limit": "1"}) for in_device, out_device in zip(self.fake_devices, out_devices): @@ -85,11 +85,12 @@ class TestDevicesController(v2_test.APITestV2): in_devices = self.fake_devices mock_devices.return_value = [in_devices[0]] data = self.get_json( - self.DEVICE_URL + "?type=FPGA", - headers=self.headers) + self.DEVICE_URL + "?type=FPGA", headers=self.headers + ) out_devices = data['devices'] - mock_devices.assert_called_once_with(mock.ANY, - filters={"type": "FPGA"}) + mock_devices.assert_called_once_with( + mock.ANY, filters={"type": "FPGA"} + ) for in_device, out_device in zip(self.fake_devices, out_devices): self._validate_device(in_device, out_device) @@ -98,11 +99,12 @@ class TestDevicesController(v2_test.APITestV2): in_devices = self.fake_devices mock_devices.return_value = [in_devices[0]] data = self.get_json( - self.DEVICE_URL + "?vendor=0xABCD", - headers=self.headers) + self.DEVICE_URL + "?vendor=0xABCD", headers=self.headers + ) out_devices = data['devices'] - mock_devices.assert_called_once_with(mock.ANY, - filters={"vendor": "0xABCD"}) + mock_devices.assert_called_once_with( + mock.ANY, filters={"vendor": "0xABCD"} + ) for in_device, out_device in zip(self.fake_devices, out_devices): self._validate_device(in_device, out_device) @@ -111,10 +113,11 @@ class TestDevicesController(v2_test.APITestV2): in_devices = self.fake_devices mock_devices.return_value = [in_devices[0]] data = self.get_json( - self.DEVICE_URL + "?hostname=test-node-1", - headers=self.headers) + self.DEVICE_URL + "?hostname=test-node-1", headers=self.headers + ) out_devices = data['devices'] mock_devices.assert_called_once_with( - mock.ANY, filters={"hostname": "test-node-1"}) + mock.ANY, filters={"hostname": "test-node-1"} + ) for in_device, out_device in zip(self.fake_devices, out_devices): self._validate_device(in_device, out_device) diff --git a/cyborg/tests/unit/api/controllers/v2/test_fpga_program.py b/cyborg/tests/unit/api/controllers/v2/test_fpga_program.py index b500b884..ebe3a7c3 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_fpga_program.py +++ b/cyborg/tests/unit/api/controllers/v2/test_fpga_program.py @@ -21,7 +21,6 @@ from cyborg.tests.unit import fake_device class TestFPGAProgramController(v2_test.APITestV2): - def setUp(self): super().setUp() self.headers = self.gen_headers(self.context) @@ -30,8 +29,9 @@ class TestFPGAProgramController(v2_test.APITestV2): self.nonexistent_image_uuid = "1234abcd-1234-1234-1234-abcde1234567" self.invalid_image_uuid = "abcd1234" dep_uuid = self.deployable_uuids[0] - self.dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + self.dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) self.dev = fake_device.get_fake_devices_objs()[0] bdf = {"domain": "0000", "bus": "00", "device": "01", "function": "1"} self.cpid = { @@ -39,15 +39,20 @@ class TestFPGAProgramController(v2_test.APITestV2): "uuid": "e4a66b0d-b377-40d6-9cdc-6bf7e720e596", "device_id": "1", "cpid_type": "PCI", - "cpid_info": jsonutils.dumps(bdf).encode('utf-8') + "cpid_info": jsonutils.dumps(bdf).encode('utf-8'), } @mock.patch('cyborg.objects.Device.get_by_device_id') @mock.patch('cyborg.objects.Deployable.get_cpid_list') @mock.patch('cyborg.objects.Deployable.get') @mock.patch('cyborg.agent.rpcapi.AgentAPI.fpga_program') - def test_program_success(self, mock_program, mock_get_dep, - mock_get_cpid_list, mock_get_by_device_id): + def test_program_success( + self, + mock_program, + mock_get_dep, + mock_get_cpid_list, + mock_get_by_device_id, + ): self.headers['X-Roles'] = 'admin' self.headers['Content-Type'] = 'application/json' dep_uuid = self.deployable_uuids[0] @@ -56,9 +61,11 @@ class TestFPGAProgramController(v2_test.APITestV2): mock_get_cpid_list.return_value = [self.cpid] mock_program.return_value = True body = [{"image_uuid": self.existent_image_uuid}] - response = self.patch_json('/deployables/%s/program' % dep_uuid, - [{'path': '/bitstream_id', 'value': body, - 'op': 'replace'}], headers=self.headers) + response = self.patch_json( + '/deployables/%s/program' % dep_uuid, + [{'path': '/bitstream_id', 'value': body, 'op': 'replace'}], + headers=self.headers, + ) self.assertEqual(HTTPStatus.OK, response.status_code) data = response.json_body self.assertEqual(dep_uuid, data['uuid']) @@ -67,8 +74,13 @@ class TestFPGAProgramController(v2_test.APITestV2): @mock.patch('cyborg.objects.Deployable.get_cpid_list') @mock.patch('cyborg.objects.Deployable.get') @mock.patch('cyborg.agent.rpcapi.AgentAPI.fpga_program') - def test_program_failed(self, mock_program, mock_get_dep, - mock_get_cpid_list, mock_get_by_device_id): + def test_program_failed( + self, + mock_program, + mock_get_dep, + mock_get_cpid_list, + mock_get_by_device_id, + ): self.headers['X-Roles'] = 'admin' self.headers['Content-Type'] = 'application/json' dep_uuid = self.deployable_uuids[0] @@ -78,22 +90,29 @@ class TestFPGAProgramController(v2_test.APITestV2): mock_program.return_value = False body = [{"image_uuid": self.existent_image_uuid}] try: - self.patch_json('/deployables/%s/program' % dep_uuid, - [{'path': '/bitstream_id', 'value': body, - 'op': 'replace'}], headers=self.headers) + self.patch_json( + '/deployables/%s/program' % dep_uuid, + [{'path': '/bitstream_id', 'value': body, 'op': 'replace'}], + headers=self.headers, + ) except Exception as e: exc = e - self.assertIn(exception.FPGAProgramError( - ret=mock_program.return_value).args[0], - exc.args[0] - ) + self.assertIn( + exception.FPGAProgramError(ret=mock_program.return_value).args[0], + exc.args[0], + ) @mock.patch('cyborg.objects.Device.get_by_device_id') @mock.patch('cyborg.objects.Deployable.get_cpid_list') @mock.patch('cyborg.objects.Deployable.get') @mock.patch('cyborg.agent.rpcapi.AgentAPI.fpga_program') - def test_program_invalid_uuid(self, mock_program, mock_get_dep, - mock_get_cpid_list, mock_get_by_device_id): + def test_program_invalid_uuid( + self, + mock_program, + mock_get_dep, + mock_get_cpid_list, + mock_get_by_device_id, + ): self.headers['X-Roles'] = 'admin' self.headers['Content-Type'] = 'application/json' dep_uuid = self.deployable_uuids[0] @@ -103,24 +122,28 @@ class TestFPGAProgramController(v2_test.APITestV2): mock_program.return_value = False body = [{"image_uuid": self.invalid_image_uuid}] try: - self.patch_json('/deployables/%s/program' % dep_uuid, - [{'path': '/bitstream_id', - 'value': body, - 'op': 'replace'}], - headers=self.headers) + self.patch_json( + '/deployables/%s/program' % dep_uuid, + [{'path': '/bitstream_id', 'value': body, 'op': 'replace'}], + headers=self.headers, + ) except Exception as e: exc = e - self.assertIn(exception.InvalidUUID(self.invalid_image_uuid).args[0], - exc.args[0]) + self.assertIn( + exception.InvalidUUID(self.invalid_image_uuid).args[0], exc.args[0] + ) @mock.patch('cyborg.objects.Device.get_by_device_id') @mock.patch('cyborg.objects.Deployable.get_cpid_list') @mock.patch('cyborg.objects.Deployable.get') @mock.patch('cyborg.agent.rpcapi.AgentAPI.fpga_program') - def test_program_wrong_image_uuid(self, mock_program, - mock_get_dep, - mock_get_cpid_list, - mock_get_by_device_id): + def test_program_wrong_image_uuid( + self, + mock_program, + mock_get_dep, + mock_get_cpid_list, + mock_get_by_device_id, + ): self.headers['X-Roles'] = 'admin' self.headers['Content-Type'] = 'application/json' dep_uuid = self.deployable_uuids[0] @@ -130,14 +153,14 @@ class TestFPGAProgramController(v2_test.APITestV2): mock_program.return_value = False body = [{"image_uuid": self.nonexistent_image_uuid}] try: - self.patch_json('/deployables/%s/program' % dep_uuid, - [{'path': '/bitstream_id', - 'value': body, - 'op': 'replace'}], - headers=self.headers) + self.patch_json( + '/deployables/%s/program' % dep_uuid, + [{'path': '/bitstream_id', 'value': body, 'op': 'replace'}], + headers=self.headers, + ) except Exception as e: exc = e - self.assertIn(exception.FPGAProgramError( - ret=mock_program.return_value).args[0], - exc.args[0] - ) + self.assertIn( + exception.FPGAProgramError(ret=mock_program.return_value).args[0], + exc.args[0], + ) diff --git a/cyborg/tests/unit/api/controllers/v2/test_microversion.py b/cyborg/tests/unit/api/controllers/v2/test_microversion.py index 231d6b67..1bc10aa2 100644 --- a/cyborg/tests/unit/api/controllers/v2/test_microversion.py +++ b/cyborg/tests/unit/api/controllers/v2/test_microversion.py @@ -23,9 +23,13 @@ MAX_VER = versions.max_version_string() class TestMicroversions(api_base.BaseApiTest): - controller_list_response = [ - 'id', 'links', 'max_version', 'min_version', 'status'] + 'id', + 'links', + 'max_version', + 'min_version', + 'status', + ] def setUp(self): super().setUp() @@ -34,11 +38,12 @@ class TestMicroversions(api_base.BaseApiTest): response = self.get_json( '/v2', headers={'OpenStack-API-Version': '10'}, - expect_errors=True, return_json=False) + expect_errors=True, + return_json=False, + ) self.assertEqual('application/json', response.content_type) self.assertEqual(406, response.status_int) - expected_error_msg = ('Invalid value for' - ' OpenStack-API-Version header') + expected_error_msg = 'Invalid value for OpenStack-API-Version header' self.assertTrue(response.json['error_message']) self.assertIn(expected_error_msg, response.json['error_message']) @@ -50,43 +55,57 @@ class TestMicroversions(api_base.BaseApiTest): self.assertEqual(response.headers[H_MIN_VER], MIN_VER) self.assertEqual(response.headers[H_MAX_VER], MAX_VER) self.assertEqual(response.headers[H_RESP_VER], MIN_VER) - self.assertTrue(all(x in response.json.keys() for x in - self.controller_list_response)) + self.assertTrue( + all( + x in response.json.keys() + for x in self.controller_list_response + ) + ) def test_new_client_new_api(self): response = self.get_json( - '/v2', - headers={'OpenStack-API-Version': '2.0'}, - return_json=False) + '/v2', headers={'OpenStack-API-Version': '2.0'}, return_json=False + ) self.assertEqual(response.headers[H_MIN_VER], MIN_VER) self.assertEqual(response.headers[H_MAX_VER], MAX_VER) self.assertEqual(response.headers[H_RESP_VER], '2.0') - self.assertTrue(all(x in response.json.keys() for x in - self.controller_list_response)) + self.assertTrue( + all( + x in response.json.keys() + for x in self.controller_list_response + ) + ) def test_latest_microversion(self): response = self.get_json( '/v2', headers={'OpenStack-API-Version': 'latest'}, - return_json=False) + return_json=False, + ) self.assertEqual(response.headers[H_MIN_VER], MIN_VER) self.assertEqual(response.headers[H_MAX_VER], MAX_VER) self.assertEqual(response.headers[H_RESP_VER], MAX_VER) - self.assertTrue(all(x in response.json.keys() for x in - self.controller_list_response)) + self.assertTrue( + all( + x in response.json.keys() + for x in self.controller_list_response + ) + ) def test_unsupported_version(self): unsupported_version = str(float(MAX_VER) + 0.1) response = self.get_json( '/v2', headers={'OpenStack-API-Version': unsupported_version}, - expect_errors=True) + expect_errors=True, + ) self.assertEqual(406, response.status_int) self.assertEqual(response.headers[H_MIN_VER], MIN_VER) self.assertEqual(response.headers[H_MAX_VER], MAX_VER) - expected_error_msg = ('Version %s was requested but the minor ' - 'version is not supported by this service. ' - 'The supported version range is' % - unsupported_version) + expected_error_msg = ( + 'Version %s was requested but the minor ' + 'version is not supported by this service. ' + 'The supported version range is' % unsupported_version + ) self.assertTrue(response.json['error_message']) self.assertIn(expected_error_msg, response.json['error_message']) diff --git a/cyborg/tests/unit/cmd/test_status.py b/cyborg/tests/unit/cmd/test_status.py index 03d2052e..5daf6c65 100644 --- a/cyborg/tests/unit/cmd/test_status.py +++ b/cyborg/tests/unit/cmd/test_status.py @@ -27,15 +27,11 @@ from cyborg.tests import base class TestUpgradeCheckPolicyJSON(base.TestCase): - def setUp(self): super().setUp() self.cmd = status.UpgradeCommands() authorize_wsgi.CONF.clear_override('policy_file', group='oslo_policy') - self.data = { - 'rule_admin': 'True', - 'rule_admin2': 'is_admin:True' - } + self.data = {'rule_admin': 'True', 'rule_admin2': 'is_admin:True'} self.temp_dir = self.useFixture(fixtures.TempDir()) fd, self.json_file = tempfile.mkstemp(dir=self.temp_dir.path) fd, self.yaml_file = tempfile.mkstemp(dir=self.temp_dir.path) @@ -51,38 +47,44 @@ class TestUpgradeCheckPolicyJSON(base.TestCase): dirs.append(self.temp_dir.path) return original_search_dirs(dirs, name) - self.mock_search = self.useFixture(fixtures.MockPatch( - 'oslo_config.cfg._search_dirs')).mock + self.mock_search = self.useFixture( + fixtures.MockPatch('oslo_config.cfg._search_dirs') + ).mock self.mock_search.side_effect = fake_search_dirs def test_policy_json_file_fail_upgrade(self): # Test with policy json file full path set in config. self.flags(policy_file=self.json_file, group="oslo_policy") - self.assertEqual(upgradecheck.Code.FAILURE, - self.cmd._check_policy_json().code) + self.assertEqual( + upgradecheck.Code.FAILURE, self.cmd._check_policy_json().code + ) def test_policy_yaml_file_pass_upgrade(self): # Test with full policy yaml file path set in config. self.flags(policy_file=self.yaml_file, group="oslo_policy") - self.assertEqual(upgradecheck.Code.SUCCESS, - self.cmd._check_policy_json().code) + self.assertEqual( + upgradecheck.Code.SUCCESS, self.cmd._check_policy_json().code + ) def test_no_policy_file_pass_upgrade(self): # Test with no policy file exist. - self.assertEqual(upgradecheck.Code.SUCCESS, - self.cmd._check_policy_json().code) + self.assertEqual( + upgradecheck.Code.SUCCESS, self.cmd._check_policy_json().code + ) def test_default_policy_yaml_file_pass_upgrade(self): tmpfilename = os.path.join(self.temp_dir.path, 'policy.yaml') with open(tmpfilename, 'w') as fh: yaml.dump(self.data, fh) - self.assertEqual(upgradecheck.Code.SUCCESS, - self.cmd._check_policy_json().code) + self.assertEqual( + upgradecheck.Code.SUCCESS, self.cmd._check_policy_json().code + ) def test_old_default_policy_json_file_fail_upgrade(self): self.flags(policy_file='policy.json', group="oslo_policy") tmpfilename = os.path.join(self.temp_dir.path, 'policy.json') with open(tmpfilename, 'w') as fh: jsonutils.dump(self.data, fh) - self.assertEqual(upgradecheck.Code.FAILURE, - self.cmd._check_policy_json().code) + self.assertEqual( + upgradecheck.Code.FAILURE, self.cmd._check_policy_json().code + ) diff --git a/cyborg/tests/unit/common/test_nova_client.py b/cyborg/tests/unit/common/test_nova_client.py index 8d138364..345ea76b 100644 --- a/cyborg/tests/unit/common/test_nova_client.py +++ b/cyborg/tests/unit/common/test_nova_client.py @@ -27,18 +27,24 @@ class NovaAPITest(base.TestCase): def setUp(self): super().setUp() self.instance_uuid = '00000000-0000-0000-0000-000000000001' - template = {'name': 'accelerator-request-bound', - 'server_uuid': self.instance_uuid, - 'code': 200, - 'status': 'completed'} - tags = ['00000000-0000-0000-0000-000000000002', - '00000000-0000-0000-0000-000000000003'] + template = { + 'name': 'accelerator-request-bound', + 'server_uuid': self.instance_uuid, + 'code': 200, + 'status': 'completed', + } + tags = [ + '00000000-0000-0000-0000-000000000002', + '00000000-0000-0000-0000-000000000003', + ] self.events = [dict(template, tag=tag) for tag in tags] - self.mock_sdk = self.useFixture(fixtures.MockPatch( - 'cyborg.common.utils.get_sdk_adapter')).mock.return_value - self.mock_log_info = self.useFixture(fixtures.MockPatch( - 'cyborg.common.nova_client.LOG.info')).mock + self.mock_sdk = self.useFixture( + fixtures.MockPatch('cyborg.common.utils.get_sdk_adapter') + ).mock.return_value + self.mock_log_info = self.useFixture( + fixtures.MockPatch('cyborg.common.nova_client.LOG.info') + ).mock def test_send_events(self): self.mock_sdk.post.return_value = mock.Mock(status_code=200) @@ -48,7 +54,8 @@ class NovaAPITest(base.TestCase): msg = 'Successfully sent events to Nova, events: %(events)s' self.mock_log_info.assert_called_once_with( - msg, {'events': self.events}) + msg, {'events': self.events} + ) def test_send_events_422(self): # If Nova returns HTTP 207 with event code 422 for all events, @@ -64,8 +71,10 @@ class NovaAPITest(base.TestCase): nova = nova_client.NovaAPI() nova._send_events(self.events) - msg = ('Ignoring Nova notification error that the instance %s is not ' - 'yet associated with a host.') + msg = ( + 'Ignoring Nova notification error that the instance %s is not ' + 'yet associated with a host.' + ) self.mock_log_info.assert_called_once_with(msg, self.instance_uuid) def test_send_events_with_event_code_422_exception(self): @@ -80,8 +89,9 @@ class NovaAPITest(base.TestCase): self.mock_sdk.post.return_value = mock_ret nova = nova_client.NovaAPI() - self.assertRaises(exception.InvalidAPIResponse, - nova._send_events, self.events) + self.assertRaises( + exception.InvalidAPIResponse, nova._send_events, self.events + ) def test_send_events_with_event_code_400_exception(self): # If Nova returns HTTP 207 with event code 400 for some events, @@ -94,8 +104,9 @@ class NovaAPITest(base.TestCase): self.mock_sdk.post.return_value = mock_ret nova = nova_client.NovaAPI() - self.assertRaises(exception.InvalidAPIResponse, - nova._send_events, self.events) + self.assertRaises( + exception.InvalidAPIResponse, nova._send_events, self.events + ) def test_send_events_with_all_event_code_400_exception(self): # If Nova returns HTTP 207 with event code 400 for all events, @@ -109,8 +120,9 @@ class NovaAPITest(base.TestCase): self.mock_sdk.post.return_value = mock_ret nova = nova_client.NovaAPI() - self.assertRaises(exception.InvalidAPIResponse, - nova._send_events, self.events) + self.assertRaises( + exception.InvalidAPIResponse, nova._send_events, self.events + ) def test_send_events_failure(self): # Nova is expected to return 200/207 but this is future-proofing. @@ -119,5 +131,6 @@ class NovaAPITest(base.TestCase): self.mock_sdk.post.return_value = mock_ret nova = nova_client.NovaAPI() - self.assertRaises(exception.InvalidAPIResponse, - nova._send_events, self.events) + self.assertRaises( + exception.InvalidAPIResponse, nova._send_events, self.events + ) diff --git a/cyborg/tests/unit/common/test_placement_client.py b/cyborg/tests/unit/common/test_placement_client.py index e4052f02..7e63cf2e 100644 --- a/cyborg/tests/unit/common/test_placement_client.py +++ b/cyborg/tests/unit/common/test_placement_client.py @@ -21,17 +21,19 @@ from cyborg.tests import base class PlacementAPITest(base.TestCase): - def setUp(self): super().setUp() self.instance_uuid = '00000000-0000-0000-0000-000000000001' - self.mock_sdk = self.useFixture(fixtures.MockPatch( - 'cyborg.common.utils.get_sdk_adapter')).mock.return_value - self.mock_log_info = self.useFixture(fixtures.MockPatch( - 'cyborg.common.placement_client.LOG.info')).mock - self.mock_log_debug = self.useFixture(fixtures.MockPatch( - 'cyborg.common.placement_client.LOG.debug')).mock + self.mock_sdk = self.useFixture( + fixtures.MockPatch('cyborg.common.utils.get_sdk_adapter') + ).mock.return_value + self.mock_log_info = self.useFixture( + fixtures.MockPatch('cyborg.common.placement_client.LOG.info') + ).mock + self.mock_log_debug = self.useFixture( + fixtures.MockPatch('cyborg.common.placement_client.LOG.debug') + ).mock def test_get(self): self.mock_sdk.get.return_value = mock.Mock(status_code=200) @@ -44,8 +46,9 @@ class PlacementAPITest(base.TestCase): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) self.mock_sdk.get.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement.get, mock.Mock()) + self.assertRaises( + exception.PlacementServerError, placement.get, mock.Mock() + ) def test_post(self): self.mock_sdk.post.return_value = mock.Mock(status_code=200) @@ -58,8 +61,12 @@ class PlacementAPITest(base.TestCase): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) self.mock_sdk.post.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement.post, mock.Mock(), mock.ANY) + self.assertRaises( + exception.PlacementServerError, + placement.post, + mock.Mock(), + mock.ANY, + ) def test_put(self): self.mock_sdk.put.return_value = mock.Mock(status_code=200) @@ -72,8 +79,12 @@ class PlacementAPITest(base.TestCase): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) self.mock_sdk.put.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement.put, mock.Mock(), mock.ANY) + self.assertRaises( + exception.PlacementServerError, + placement.put, + mock.Mock(), + mock.ANY, + ) def test_delete(self): self.mock_sdk.delete.return_value = mock.Mock(status_code=200) @@ -86,8 +97,12 @@ class PlacementAPITest(base.TestCase): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) self.mock_sdk.delete.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement.delete, mock.Mock(), mock.ANY) + self.assertRaises( + exception.PlacementServerError, + placement.delete, + mock.Mock(), + mock.ANY, + ) def test_get_rp_traits(self): self.mock_sdk.get.return_value = mock.Mock(status_code=200) @@ -100,8 +115,9 @@ class PlacementAPITest(base.TestCase): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) self.mock_sdk.get.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement._get_rp_traits, mock.ANY) + self.assertRaises( + exception.PlacementServerError, placement._get_rp_traits, mock.ANY + ) def test_ensure_traits(self): self.mock_sdk.put.return_value = mock.Mock(status_code=201) @@ -118,11 +134,15 @@ class PlacementAPITest(base.TestCase): mock_ret = mock.Mock(status_code=500) self.mock_sdk.get.return_value = None self.mock_sdk.put.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement._ensure_traits, [mock.ANY]) + self.assertRaises( + exception.PlacementServerError, + placement._ensure_traits, + [mock.ANY], + ) - @mock.patch('cyborg.common.placement_client.' - 'PlacementClient.get_resource_provider') + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.get_resource_provider' + ) def test_put_rp_traits(self, rp): self.mock_sdk.put.return_value = mock.Mock(status_code=200) placement = placement_client.PlacementClient() @@ -131,13 +151,17 @@ class PlacementAPITest(base.TestCase): msg = 'Successfully update resources from placement: %s' self.mock_log_debug.assert_called_once_with(msg, mock.ANY) - @mock.patch('cyborg.common.placement_client.' - 'PlacementClient.get_resource_provider') + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.get_resource_provider' + ) def test_put_rp_traits_exception(self, rp): placement = placement_client.PlacementClient() mock_ret = mock.Mock(status_code=500) rp.return_value = {'status_code': 200, 'generation': 0} self.mock_sdk.put.return_value = mock_ret - self.assertRaises(exception.PlacementServerError, - placement._put_rp_traits, - mock.ANY, {'traits': 'fake_trait'}) + self.assertRaises( + exception.PlacementServerError, + placement._put_rp_traits, + mock.ANY, + {'traits': 'fake_trait'}, + ) diff --git a/cyborg/tests/unit/common/test_service.py b/cyborg/tests/unit/common/test_service.py index 45f7ffee..aa919c9c 100644 --- a/cyborg/tests/unit/common/test_service.py +++ b/cyborg/tests/unit/common/test_service.py @@ -19,7 +19,6 @@ from cyborg.tests import base class TestRPCService(base.TestCase): - def setUp(self): super().setUp() self.topic = 'cyborg-conductor' @@ -27,35 +26,41 @@ class TestRPCService(base.TestCase): self.manager_module = 'cyborg.conductor.manager' self.manager_class = 'ConductorManager' - self.mock_try_import = self.useFixture(fixtures.MockPatch( - 'cyborg.common.service.importutils.try_import', - autospec=True)).mock + self.mock_try_import = self.useFixture( + fixtures.MockPatch( + 'cyborg.common.service.importutils.try_import', autospec=True + ) + ).mock mock_module = mock.MagicMock() mock_manager_cls = mock.MagicMock() self.mock_manager = mock_manager_cls.return_value setattr(mock_module, self.manager_class, mock_manager_cls) self.mock_try_import.return_value = mock_module - self.mock_get_server = self.useFixture(fixtures.MockPatch( - 'cyborg.common.service.rpc.get_server', - autospec=True)).mock + self.mock_get_server = self.useFixture( + fixtures.MockPatch( + 'cyborg.common.service.rpc.get_server', autospec=True + ) + ).mock self.mock_rpcserver = self.mock_get_server.return_value - self.mock_get_admin_context = self.useFixture(fixtures.MockPatch( - 'cyborg.common.service.context.get_admin_context', - autospec=True)).mock - self.mock_admin_context = ( - self.mock_get_admin_context.return_value) + self.mock_get_admin_context = self.useFixture( + fixtures.MockPatch( + 'cyborg.common.service.context.get_admin_context', + autospec=True, + ) + ).mock + self.mock_admin_context = self.mock_get_admin_context.return_value - self.mock_log_info = self.useFixture(fixtures.MockPatch( - 'cyborg.common.service.LOG.info')).mock + self.mock_log_info = self.useFixture( + fixtures.MockPatch('cyborg.common.service.LOG.info') + ).mock - @mock.patch.object( - cyborg_service.service.Service, 'start', autospec=True) + @mock.patch.object(cyborg_service.service.Service, 'start', autospec=True) def test_start(self, mock_super_start): svc = cyborg_service.RPCService( - self.manager_module, self.manager_class, - self.topic, host=self.host) + self.manager_module, self.manager_class, self.topic, host=self.host + ) svc.tg = mock.MagicMock() svc.start() @@ -67,9 +72,10 @@ class TestRPCService(base.TestCase): svc.tg.add_dynamic_timer_args.assert_called_once_with( self.mock_manager.periodic_tasks, kwargs={'context': self.mock_admin_context}, - periodic_interval_max=CONF.periodic_interval) + periodic_interval_max=CONF.periodic_interval, + ) self.mock_log_info.assert_called_once_with( - 'Created RPC server for service %(service)s on host ' - '%(host)s.', - {'service': self.topic, 'host': self.host}) + 'Created RPC server for service %(service)s on host %(host)s.', + {'service': self.topic, 'host': self.host}, + ) diff --git a/cyborg/tests/unit/conductor/test_manager.py b/cyborg/tests/unit/conductor/test_manager.py index 52a896de..7efbf417 100644 --- a/cyborg/tests/unit/conductor/test_manager.py +++ b/cyborg/tests/unit/conductor/test_manager.py @@ -23,15 +23,20 @@ from cyborg.tests.unit import fake_driver_device class ConductorManagerTest(base.TestCase): def setUp(self): super().setUp() - self.placement_mock = self.useFixture(fixtures.MockPatch( - 'cyborg.common.placement_client.PlacementClient') + self.placement_mock = self.useFixture( + fixtures.MockPatch( + 'cyborg.common.placement_client.PlacementClient' + ) ).mock.return_value self.cm = manager.ConductorManager( - mock.sentinel.topic, mock.sentinel.host) - self.fake_driver_devices = (fake_driver_device. - get_fake_driver_devices_objs()) - self.fake_driver_depolyables = (fake_driver_device. - get_fake_driver_deployable_objs()) + mock.sentinel.topic, mock.sentinel.host + ) + self.fake_driver_devices = ( + fake_driver_device.get_fake_driver_devices_objs() + ) + self.fake_driver_depolyables = ( + fake_driver_device.get_fake_driver_deployable_objs() + ) def test__gen_resource_inventory(self): expected = { @@ -46,12 +51,14 @@ class ConductorManagerTest(base.TestCase): @mock.patch('cyborg.conductor.manager.ConductorManager._get_sub_provider') def test_provider_report(self, mock_get_sub): rc = 'CUSTOM_ACCELERATOR' - traits = ["CUSTOM_FPGA_INTEL", - "CUSTOM_FPGA_INTEL_ARRIA10", - "CUSTOM_FPGA_INTEL_REGION_UUID", - "CUSTOM_FPGA_FUNCTION_ID_INTEL_UUID", - "CUSTOM_PROGRAMMABLE", - "CUSTOM_FPGA_NETWORK"] + traits = [ + "CUSTOM_FPGA_INTEL", + "CUSTOM_FPGA_INTEL_ARRIA10", + "CUSTOM_FPGA_INTEL_REGION_UUID", + "CUSTOM_FPGA_FUNCTION_ID_INTEL_UUID", + "CUSTOM_PROGRAMMABLE", + "CUSTOM_FPGA_NETWORK", + ] total = 42 expected_inv = { rc: { @@ -61,18 +68,27 @@ class ConductorManagerTest(base.TestCase): } self.placement_mock.ensure_resource_classes.return_value = None actual = self.cm.provider_report( - mock.sentinel.context, mock.sentinel.name, rc, traits, total, - mock.sentinel.parent) + mock.sentinel.context, + mock.sentinel.name, + rc, + traits, + total, + mock.sentinel.parent, + ) self.placement_mock.ensure_resource_classes.assert_called_once_with( - mock.sentinel.context, [rc]) + mock.sentinel.context, [rc] + ) mock_get_sub.assert_called_once_with( - mock.sentinel.context, mock.sentinel.parent, mock.sentinel.name) + mock.sentinel.context, mock.sentinel.parent, mock.sentinel.name + ) sub_pr_uuid = mock_get_sub.return_value self.placement_mock.update_inventory.assert_called_once_with( - sub_pr_uuid, expected_inv) + sub_pr_uuid, expected_inv + ) self.placement_mock.add_traits_to_rp.assert_called_once_with( - sub_pr_uuid, traits) + sub_pr_uuid, traits + ) self.assertEqual(sub_pr_uuid, actual) def test_get_root_provider(self): @@ -86,29 +102,45 @@ class ConductorManagerTest(base.TestCase): self.placement_mock.get.return_value.json.return_value = { 'resource_providers': [], } - self.assertRaises(exception.PlacementResourceProviderNotFound, - self.cm._get_root_provider, - mock.sentinel.context, 'foo') + self.assertRaises( + exception.PlacementResourceProviderNotFound, + self.cm._get_root_provider, + mock.sentinel.context, + 'foo', + ) def test_get_root_provider_unavailable(self): self.placement_mock.get.side_effect = exception.PlacementServerError( - "Placement Server has some error at this time.") - self.assertRaises(exception.PlacementServerError, - self.cm._get_root_provider, - mock.sentinel.context, 'foo') + "Placement Server has some error at this time." + ) + self.assertRaises( + exception.PlacementServerError, + self.cm._get_root_provider, + mock.sentinel.context, + 'foo', + ) - @mock.patch('cyborg.conductor.manager.ConductorManager.' - '_delete_provider_and_sub_providers') - @mock.patch('cyborg.conductor.manager.ConductorManager.' - 'get_placement_needed_info_and_report') - @mock.patch('cyborg.objects.driver_objects.driver_device.' - 'DriverDevice.destroy') - @mock.patch('cyborg.objects.driver_objects.driver_device.' - 'DriverDevice.create') - def test_drv_device_make_diff(self, mock_create_driver_device, - mock_destroy_driver_device, - mock_placement_report, - mock_placement_delete): + @mock.patch( + 'cyborg.conductor.manager.ConductorManager.' + '_delete_provider_and_sub_providers' + ) + @mock.patch( + 'cyborg.conductor.manager.ConductorManager.' + 'get_placement_needed_info_and_report' + ) + @mock.patch( + 'cyborg.objects.driver_objects.driver_device.DriverDevice.destroy' + ) + @mock.patch( + 'cyborg.objects.driver_objects.driver_device.DriverDevice.create' + ) + def test_drv_device_make_diff( + self, + mock_create_driver_device, + mock_destroy_driver_device, + mock_placement_report, + mock_placement_delete, + ): old_driver_attr_list = [] new_driver_attr_list = self.fake_driver_devices[:1] self.placement_mock.get.return_value.json.return_value = { @@ -116,28 +148,42 @@ class ConductorManagerTest(base.TestCase): } mock_placement_report.side_effect = ( - exception.ResourceProviderCreationFailed( - name=uuids.compute_node)) + exception.ResourceProviderCreationFailed(name=uuids.compute_node) + ) self.cm.drv_device_make_diff( - mock.sentinel.context, 'foo', - old_driver_attr_list, new_driver_attr_list) + mock.sentinel.context, + 'foo', + old_driver_attr_list, + new_driver_attr_list, + ) mock_destroy_driver_device.assert_called_once() mock_placement_delete.assert_called_once() - @mock.patch('cyborg.conductor.manager.ConductorManager.' - '_delete_provider_and_sub_providers') - @mock.patch('cyborg.conductor.manager.ConductorManager.' - 'get_placement_needed_info_and_report') - @mock.patch('cyborg.objects.driver_objects.driver_deployable.' - 'DriverDeployable.destroy') - @mock.patch('cyborg.objects.driver_objects.driver_deployable.' - 'DriverDeployable.create') - def test_drv_deployable_make_diff(self, mock_create_driver_deployable, - mock_destroy_driver_deployable, - mock_placement_report, - mock_placement_delete): + @mock.patch( + 'cyborg.conductor.manager.ConductorManager.' + '_delete_provider_and_sub_providers' + ) + @mock.patch( + 'cyborg.conductor.manager.ConductorManager.' + 'get_placement_needed_info_and_report' + ) + @mock.patch( + 'cyborg.objects.driver_objects.driver_deployable.' + 'DriverDeployable.destroy' + ) + @mock.patch( + 'cyborg.objects.driver_objects.driver_deployable.' + 'DriverDeployable.create' + ) + def test_drv_deployable_make_diff( + self, + mock_create_driver_deployable, + mock_destroy_driver_deployable, + mock_placement_report, + mock_placement_delete, + ): old_driver_dep_list = [] new_driver_dep_list = self.fake_driver_depolyables[:1] self.placement_mock.get.return_value.json.return_value = { @@ -145,12 +191,17 @@ class ConductorManagerTest(base.TestCase): } mock_placement_report.side_effect = ( - exception.ResourceProviderCreationFailed( - name=uuids.compute_node)) + exception.ResourceProviderCreationFailed(name=uuids.compute_node) + ) self.cm.drv_deployable_make_diff( - mock.sentinel.context, '1', '2', - old_driver_dep_list, new_driver_dep_list, 'foo') + mock.sentinel.context, + '1', + '2', + old_driver_dep_list, + new_driver_dep_list, + 'foo', + ) mock_destroy_driver_deployable.assert_called_once() mock_placement_delete.assert_called_once() diff --git a/cyborg/tests/unit/db/base.py b/cyborg/tests/unit/db/base.py index 6f37153f..914ee544 100644 --- a/cyborg/tests/unit/db/base.py +++ b/cyborg/tests/unit/db/base.py @@ -38,7 +38,6 @@ CONF = cfg.CONF class DbTestCase(base.TestCase): - def setUp(self): super().setUp() @@ -47,12 +46,11 @@ class DbTestCase(base.TestCase): self.useFixture(fixtures.NestedTempfile()) # File-backed SQLite so each thread gets its own connection. - fd, dbfile_path = tempfile.mkstemp( - prefix="cyborg_test_", suffix=".db") + fd, dbfile_path = tempfile.mkstemp(prefix="cyborg_test_", suffix=".db") os.close(fd) CONF.set_override( - "connection", "sqlite:///%s" % dbfile_path, - group="database") + "connection", "sqlite:///%s" % dbfile_path, group="database" + ) # WAL mode: readers don't block writers, writer doesn't # block readers. @@ -63,11 +61,13 @@ class DbTestCase(base.TestCase): local_enginefacade = enginefacade.transaction_context() local_enginefacade.configure( connection=CONF.database.connection, - sqlite_synchronous=CONF.database.sqlite_synchronous) + sqlite_synchronous=CONF.database.sqlite_synchronous, + ) self.useFixture( test_fixtures.ReplaceEngineFacadeFixture( - sqlalchemy_api.main_context_manager, - local_enginefacade)) + sqlalchemy_api.main_context_manager, local_enginefacade + ) + ) # Build schema from models directly, bypassing Alembic's env.py # which would create its own engine via the global enginefacade. diff --git a/cyborg/tests/unit/db/test_db_api.py b/cyborg/tests/unit/db/test_db_api.py index 86877ae4..d0fb0ff7 100644 --- a/cyborg/tests/unit/db/test_db_api.py +++ b/cyborg/tests/unit/db/test_db_api.py @@ -38,51 +38,59 @@ def _quota_reserve(context, project_id): for i, resource in enumerate(('fpga', 'gpu')): deltas[resource] = i + 1 return sqlalchemy_api.quota_reserve( - context, resources, deltas, - timeutils.utcnow(), timeutils.utcnow(), - datetime.timedelta(days=1), project_id + context, + resources, + deltas, + timeutils.utcnow(), + timeutils.utcnow(), + datetime.timedelta(days=1), + project_id, ) class DBAPIQuotaUsageTestCase(base.DbTestCase): - """Tests for db.api.quota_usage_* methods.""" def _test_quota_reserve(self): sqlalchemy_api = sqlalchemyapi.get_backend() reservations = _quota_reserve(self.context, 'project1') self.assertEqual(2, len(reservations)) - quota_usages = sqlalchemy_api._get_quota_usages(self.context, - 'project1') + quota_usages = sqlalchemy_api._get_quota_usages( + self.context, 'project1' + ) result = {'project_id': "project1"} for k, v in quota_usages.items(): result[v.resource] = dict(in_use=v.in_use, reserved=v.reserved) - self.assertEqual({'project_id': 'project1', - 'gpu': {'reserved': 2, 'in_use': 0}, - 'fpga': {'reserved': 1, 'in_use': 0}}, - result) + self.assertEqual( + { + 'project_id': 'project1', + 'gpu': {'reserved': 2, 'in_use': 0}, + 'fpga': {'reserved': 1, 'in_use': 0}, + }, + result, + ) def _test__get_quota_usages(self): _quota_reserve(self.context, 'project1') sqlalchemy_api = sqlalchemyapi.get_backend() - quota_usages = sqlalchemy_api._get_quota_usages(self.context, - 'project1') + quota_usages = sqlalchemy_api._get_quota_usages( + self.context, 'project1' + ) - self.assertEqual(['fpga', 'gpu'], - sorted(quota_usages.keys())) + self.assertEqual(['fpga', 'gpu'], sorted(quota_usages.keys())) def _test__get_quota_usages_with_resources(self): _quota_reserve(self.context, 'project1') sqlalchemy_api = sqlalchemyapi.get_backend() quota_usage = sqlalchemy_api._get_quota_usages( - self.context, 'project1', resources=['gpu']) + self.context, 'project1', resources=['gpu'] + ) self.assertEqual(['gpu'], list(quota_usage.keys())) class DBAPIReservationTestCase(base.DbTestCase): - """Tests for db.api.reservation_* methods.""" def setUp(self): @@ -92,9 +100,8 @@ class DBAPIReservationTestCase(base.DbTestCase): 'project_id': 'project1', 'resource': 'resource', 'delta': 42, - 'expire': (timeutils.utcnow() + - datetime.timedelta(days=1)), - 'usage': {'id': 1} + 'expire': (timeutils.utcnow() + datetime.timedelta(days=1)), + 'usage': {'id': 1}, } def _test__get_reservation_resources(self): @@ -102,16 +109,18 @@ class DBAPIReservationTestCase(base.DbTestCase): reservations = _quota_reserve(self.context, 'project1') expected = ['fpga', 'gpu'] resources = sqlalchemy_api._get_reservation_resources( - self.context, reservations) + self.context, reservations + ) self.assertEqual(expected, sorted(resources)) def _test_reservation_commit(self): db_api = dbapi.get_instance() reservations = _quota_reserve(self.context, 'project1') - expected = {'project_id': 'project1', - 'fpga': {'reserved': 1, 'in_use': 0}, - 'gpu': {'reserved': 2, 'in_use': 0}, - } + expected = { + 'project_id': 'project1', + 'fpga': {'reserved': 1, 'in_use': 0}, + 'gpu': {'reserved': 2, 'in_use': 0}, + } quota_usages = db_api._get_quota_usages(self.context, 'project1') result = {'project_id': "project1"} for k, v in quota_usages.items(): @@ -120,13 +129,13 @@ class DBAPIReservationTestCase(base.DbTestCase): self.assertEqual(expected, result) db_api.reservation_commit(self.context, reservations, 'project1') - expected = {'project_id': 'project1', - 'fpga': {'reserved': 0, 'in_use': 1}, - 'gpu': {'reserved': 0, 'in_use': 2}, - } + expected = { + 'project_id': 'project1', + 'fpga': {'reserved': 0, 'in_use': 1}, + 'gpu': {'reserved': 0, 'in_use': 2}, + } quota_usages = db_api._get_quota_usages(self.context, 'project1') result = {'project_id': "project1"} for k, v in quota_usages.items(): - result[v.resource] = dict(in_use=v.in_use, - reserved=v.reserved) + result[v.resource] = dict(in_use=v.in_use, reserved=v.reserved) self.assertEqual(expected, result) diff --git a/cyborg/tests/unit/db/test_db_attach_handle.py b/cyborg/tests/unit/db/test_db_attach_handle.py index de0cae3e..6c1352dc 100644 --- a/cyborg/tests/unit/db/test_db_attach_handle.py +++ b/cyborg/tests/unit/db/test_db_attach_handle.py @@ -23,7 +23,6 @@ from cyborg.tests.unit.db import utils class TestDbAttachHandle(base.DbTestCase): - def test_create(self): random_uuid = uuidutils.generate_uuid() kw = {'uuid': random_uuid} @@ -33,28 +32,30 @@ class TestDbAttachHandle(base.DbTestCase): def test_get_by_uuid(self): created_ah = utils.create_test_attach_handle(self.context) queried_ah = self.dbapi.attach_handle_get_by_uuid( - self.context, created_ah['uuid']) + self.context, created_ah['uuid'] + ) self.assertEqual(created_ah['uuid'], queried_ah['uuid']) def test_get_by_id(self): created_ah = utils.create_test_attach_handle(self.context) queried_ah = self.dbapi.attach_handle_get_by_id( - self.context, created_ah['id']) + self.context, created_ah['id'] + ) self.assertEqual(created_ah['id'], queried_ah['id']) def test_update(self): created_ah = utils.create_test_attach_handle(self.context) queried_ah = self.dbapi.attach_handle_update( - self.context, created_ah['uuid'], {'attach_type': 'TEST_PCI'}) + self.context, created_ah['uuid'], {'attach_type': 'TEST_PCI'} + ) self.assertEqual('TEST_PCI', queried_ah['attach_type']) def test_list(self): uuids = [] for i in range(1, 4): ah = utils.create_test_attach_handle( - self.context, - id=i, - uuid=uuidutils.generate_uuid()) + self.context, id=i, uuid=uuidutils.generate_uuid() + ) uuids.append(ah['uuid']) ahs = self.dbapi.attach_handle_list(self.context) ah_uuids = [item.uuid for item in ahs] @@ -65,48 +66,47 @@ class TestDbAttachHandle(base.DbTestCase): self.context, id=1, uuid=uuidutils.generate_uuid(), - attach_type='PCI') + attach_type='PCI', + ) utils.create_test_attach_handle( self.context, id=2, uuid=uuidutils.generate_uuid(), - attach_type='TEST_PCI') + attach_type='TEST_PCI', + ) res = self.dbapi.attach_handle_list_by_type( - self.context, attach_type='PCI') + self.context, attach_type='PCI' + ) self.assertEqual(1, len(res)) self.assertEqual(ah1['uuid'], res[0]['uuid']) def test_get_by_filters(self): ah1 = utils.create_test_attach_handle( - self.context, - id=1, - uuid=uuidutils.generate_uuid(), - deployable_id=1) + self.context, id=1, uuid=uuidutils.generate_uuid(), deployable_id=1 + ) utils.create_test_attach_handle( - self.context, - id=2, - uuid=uuidutils.generate_uuid(), - deployable_id=2) + self.context, id=2, uuid=uuidutils.generate_uuid(), deployable_id=2 + ) res = self.dbapi.attach_handle_get_by_filters( - self.context, filters={"deployable_id": 1}) + self.context, filters={"deployable_id": 1} + ) self.assertEqual(1, len(res)) self.assertEqual(ah1['uuid'], res[0]['uuid']) def test_allocate(self): utils.create_test_attach_handle( - self.context, - id=1, - uuid=uuidutils.generate_uuid(), - deployable_id=1) + self.context, id=1, uuid=uuidutils.generate_uuid(), deployable_id=1 + ) allocate_ah = self.dbapi.attach_handle_allocate( - self.context, deployable_id=1) + self.context, deployable_id=1 + ) self.assertTrue(allocate_ah['in_use']) def test_delete(self): created_ah = utils.create_test_attach_handle(self.context) return_value = self.dbapi.attach_handle_delete( - self.context, - created_ah['uuid']) + self.context, created_ah['uuid'] + ) self.assertIsNone(return_value) def test_list_filter_is_none(self): @@ -115,28 +115,37 @@ class TestDbAttachHandle(base.DbTestCase): handle same as the List Attach Handle API response. """ ah1 = utils.create_test_attach_handle( - self.context, - id=1, - uuid=uuidutils.generate_uuid()) + self.context, id=1, uuid=uuidutils.generate_uuid() + ) res = self.dbapi.attach_handle_get_by_filters( - self.context, filters=None) + self.context, filters=None + ) self.assertEqual(1, len(res)) self.assertEqual(ah1['uuid'], res[0]['uuid']) def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.attach_handle_get_by_uuid, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.attach_handle_get_by_uuid, + self.context, + random_uuid, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.attach_handle_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.attach_handle_delete, + self.context, + random_uuid, + ) def test_do_allocate_attach_handle(self): dep_id = 100 - self.assertRaises(exception.ResourceNotFound, - self.dbapi._do_allocate_attach_handle, - self.context, dep_id) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi._do_allocate_attach_handle, + self.context, + dep_id, + ) diff --git a/cyborg/tests/unit/db/test_db_attribute.py b/cyborg/tests/unit/db/test_db_attribute.py index 5a4eeaf5..fd6990a1 100644 --- a/cyborg/tests/unit/db/test_db_attribute.py +++ b/cyborg/tests/unit/db/test_db_attribute.py @@ -22,15 +22,20 @@ from cyborg.tests.unit.db import base class TestDbAttributeTestCase(base.DbTestCase): - def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.attribute_get, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.attribute_get, + self.context, + random_uuid, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.attribute_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.attribute_delete, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_db_control_path.py b/cyborg/tests/unit/db/test_db_control_path.py index 490da043..022d4b14 100644 --- a/cyborg/tests/unit/db/test_db_control_path.py +++ b/cyborg/tests/unit/db/test_db_control_path.py @@ -22,9 +22,11 @@ from cyborg.tests.unit.db import base class TestDbControlPath(base.DbTestCase): - def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.control_path_get_by_uuid, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.control_path_get_by_uuid, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_db_deployable.py b/cyborg/tests/unit/db/test_db_deployable.py index 1d4badc0..baa073ba 100644 --- a/cyborg/tests/unit/db/test_db_deployable.py +++ b/cyborg/tests/unit/db/test_db_deployable.py @@ -23,7 +23,6 @@ from cyborg.tests.unit.db import utils class TestDbDeployable(base.DbTestCase): - def test_create(self): kw = {'name': 'test_create_dep'} created_dep = utils.create_test_deployable(self.context, **kw) @@ -32,30 +31,31 @@ class TestDbDeployable(base.DbTestCase): def test_get_by_uuid(self): created_dep = utils.create_test_deployable(self.context) queried_dep = self.dbapi.deployable_get( - self.context, created_dep['uuid']) + self.context, created_dep['uuid'] + ) self.assertEqual(created_dep['uuid'], queried_dep['uuid']) def test_get_by_rp_uuid(self): created_dep = utils.create_test_deployable(self.context) queried_dep = self.dbapi.deployable_get_by_rp_uuid( - self.context, created_dep['rp_uuid']) + self.context, created_dep['rp_uuid'] + ) self.assertEqual(created_dep['uuid'], queried_dep['uuid']) def test_update(self): created_dep = utils.create_test_deployable(self.context) bit_stream_id = '10efe63d-dfea-4a37-ad94-4116fba5011' queried_dep = self.dbapi.deployable_update( - self.context, created_dep['uuid'], - {'bit_stream_id': bit_stream_id}) + self.context, created_dep['uuid'], {'bit_stream_id': bit_stream_id} + ) self.assertEqual(bit_stream_id, queried_dep['bit_stream_id']) def test_list(self): uuids = [] for i in range(1, 4): dep = utils.create_test_deployable( - self.context, - id=i, - uuid=uuidutils.generate_uuid()) + self.context, id=i, uuid=uuidutils.generate_uuid() + ) uuids.append(dep['uuid']) deps = self.dbapi.deployable_list(self.context) dep_uuids = [item.uuid for item in deps] @@ -64,23 +64,20 @@ class TestDbDeployable(base.DbTestCase): def test_delete(self): created_dep = utils.create_test_deployable(self.context) return_value = self.dbapi.deployable_delete( - self.context, - created_dep['uuid']) + self.context, created_dep['uuid'] + ) self.assertIsNone(return_value) def test_list_by_filters(self): dep1 = utils.create_test_deployable( - self.context, - id=1, - uuid=uuidutils.generate_uuid(), - name='mydep1') + self.context, id=1, uuid=uuidutils.generate_uuid(), name='mydep1' + ) utils.create_test_deployable( - self.context, - id=2, - uuid=uuidutils.generate_uuid(), - name='mydep2') + self.context, id=2, uuid=uuidutils.generate_uuid(), name='mydep2' + ) res = self.dbapi.deployable_get_by_filters( - self.context, filters={"name": "mydep1"}) + self.context, filters={"name": "mydep1"} + ) self.assertEqual(1, len(res)) self.assertEqual(dep1['name'], res[0]['name']) @@ -90,28 +87,35 @@ class TestDbDeployable(base.DbTestCase): same as the List Deployable API response. """ dep1 = utils.create_test_deployable( - self.context, - id=1, - uuid=uuidutils.generate_uuid()) - res = self.dbapi.deployable_get_by_filters( - self.context, filters=None) + self.context, id=1, uuid=uuidutils.generate_uuid() + ) + res = self.dbapi.deployable_get_by_filters(self.context, filters=None) self.assertEqual(1, len(res)) self.assertEqual(dep1['uuid'], res[0]['uuid']) def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.deployable_get, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.deployable_get, + self.context, + random_uuid, + ) def test_get_by_rp_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.deployable_get_by_rp_uuid, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.deployable_get_by_rp_uuid, + self.context, + random_uuid, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.deployable_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.deployable_delete, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_db_device.py b/cyborg/tests/unit/db/test_db_device.py index 8af829c6..3b3d3eb0 100644 --- a/cyborg/tests/unit/db/test_db_device.py +++ b/cyborg/tests/unit/db/test_db_device.py @@ -25,7 +25,6 @@ from cyborg.tests.unit.db import utils class TestDbDevice(base.DbTestCase): - def test_create(self): random_uuid = uuidutils.generate_uuid() kw = {'uuid': random_uuid} @@ -34,29 +33,29 @@ class TestDbDevice(base.DbTestCase): def test_get_by_uuid(self): created_dev = utils.create_test_device(self.context) - queried_dev = self.dbapi.device_get( - self.context, created_dev['uuid']) + queried_dev = self.dbapi.device_get(self.context, created_dev['uuid']) self.assertEqual(created_dev['uuid'], queried_dev['uuid']) def test_get_by_id(self): created_dev = utils.create_test_device(self.context) queried_dev = self.dbapi.device_get_by_id( - self.context, created_dev['id']) + self.context, created_dev['id'] + ) self.assertEqual(created_dev['id'], queried_dev['id']) def test_update(self): created_dev = utils.create_test_device(self.context) queried_dev = self.dbapi.device_update( - self.context, created_dev['uuid'], {'hostname': 'myhost'}) + self.context, created_dev['uuid'], {'hostname': 'myhost'} + ) self.assertEqual('myhost', queried_dev['hostname']) def test_list(self): uuids = [] for i in range(1, 4): dev = utils.create_test_device( - self.context, - id=i, - uuid=uuidutils.generate_uuid()) + self.context, id=i, uuid=uuidutils.generate_uuid() + ) uuids.append(dev['uuid']) devs = self.dbapi.device_list(self.context) dev_uuids = [item.uuid for item in devs] @@ -67,22 +66,25 @@ class TestDbDevice(base.DbTestCase): self.context, id=1, uuid=uuidutils.generate_uuid(), - hostname='myhost1') + hostname='myhost1', + ) utils.create_test_device( self.context, id=2, uuid=uuidutils.generate_uuid(), - hostname='myhost2') + hostname='myhost2', + ) res = self.dbapi.device_list_by_filters( - self.context, filters={"hostname": "myhost1"}) + self.context, filters={"hostname": "myhost1"} + ) self.assertEqual(1, len(res)) self.assertEqual(dev1['hostname'], res[0]['hostname']) def test_delete(self): created_dev = utils.create_test_device(self.context) return_value = self.dbapi.device_delete( - self.context, - created_dev['uuid']) + self.context, created_dev['uuid'] + ) self.assertIsNone(return_value) def test_list_filter_is_none(self): @@ -91,28 +93,35 @@ class TestDbDevice(base.DbTestCase): same as the List Device API response. """ dev1 = utils.create_test_device( - self.context, - id=1, - uuid=uuidutils.generate_uuid()) - res = self.dbapi.device_list_by_filters( - self.context, filters=None) + self.context, id=1, uuid=uuidutils.generate_uuid() + ) + res = self.dbapi.device_list_by_filters(self.context, filters=None) self.assertEqual(1, len(res)) self.assertEqual(dev1['uuid'], res[0]['uuid']) def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_get, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_get, + self.context, + random_uuid, + ) def test_get_by_id_not_exist(self): fake_id = sys.maxsize - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_get_by_id, - self.context, fake_id) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_get_by_id, + self.context, + fake_id, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_delete, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_db_device_profile.py b/cyborg/tests/unit/db/test_db_device_profile.py index 10b7de6d..745d9852 100644 --- a/cyborg/tests/unit/db/test_db_device_profile.py +++ b/cyborg/tests/unit/db/test_db_device_profile.py @@ -26,53 +26,64 @@ from cyborg.tests.unit.db import utils class TestDbDeviceProfile(base.DbTestCase): - def test_create_dp(self): created_dp = utils.create_test_device_profile(self.context) expected_dp = utils.get_test_device_profile() - self.assertEqual(json.loads(created_dp.profile_json)['groups'], - json.loads(expected_dp['profile_json'])['groups']) + self.assertEqual( + json.loads(created_dp.profile_json)['groups'], + json.loads(expected_dp['profile_json'])['groups'], + ) def test_create_dp_with_duplicate_name(self): utils.create_test_device_profile(self.context) duplicate_dp = utils.get_test_device_profile() duplicate_dp['id'] = 2 duplicate_dp['uuid'] = uuidutils.generate_uuid() - self.assertRaises(exception.DuplicateDeviceProfileName, - self.dbapi.device_profile_create, - self.context, duplicate_dp) + self.assertRaises( + exception.DuplicateDeviceProfileName, + self.dbapi.device_profile_create, + self.context, + duplicate_dp, + ) def test_create_dp_with_duplicate_uuid(self): utils.create_test_device_profile(self.context) duplicate_dp = utils.get_test_device_profile() - self.assertRaises(exception.DeviceProfileAlreadyExists, - self.dbapi.device_profile_create, - self.context, duplicate_dp) + self.assertRaises( + exception.DeviceProfileAlreadyExists, + self.dbapi.device_profile_create, + self.context, + duplicate_dp, + ) def test_get_by_uuid(self): created_dp = utils.create_test_device_profile(self.context) queried_dp = self.dbapi.device_profile_get_by_uuid( - self.context, created_dp['uuid']) + self.context, created_dp['uuid'] + ) self.assertEqual(created_dp['uuid'], queried_dp['uuid']) self.assertIn('description', queried_dp) def test_get_by_id(self): created_dp = utils.create_test_device_profile(self.context) queried_dp = self.dbapi.device_profile_get_by_id( - self.context, created_dp['id']) + self.context, created_dp['id'] + ) self.assertEqual(created_dp['id'], queried_dp['id']) self.assertIn('description', queried_dp) def test_update_with_name(self): created_dp = utils.create_test_device_profile(self.context) queried_dp = self.dbapi.device_profile_update( - self.context, created_dp['uuid'], {'name': 'updated_name'}) + self.context, created_dp['uuid'], {'name': 'updated_name'} + ) self.assertEqual('updated_name', queried_dp['name']) def test_update_with_description(self): created_dp = utils.create_test_device_profile(self.context) queried_dp = self.dbapi.device_profile_update( - self.context, created_dp['uuid'], {'description': 'fake-desc'}) + self.context, created_dp['uuid'], {'description': 'fake-desc'} + ) self.assertEqual('fake-desc', queried_dp['description']) def test_list(self): @@ -82,7 +93,8 @@ class TestDbDeviceProfile(base.DbTestCase): self.context, id=i, uuid=uuidutils.generate_uuid(), - name="device_profile_name_%s" % i) + name="device_profile_name_%s" % i, + ) uuids.append(dp['uuid']) dps = self.dbapi.device_profile_list(self.context) dp_uuids = [item.uuid for item in dps] @@ -90,32 +102,32 @@ class TestDbDeviceProfile(base.DbTestCase): def test_list_filter_by_name(self): utils.create_test_device_profile( - self.context, - id=1, - uuid=uuidutils.generate_uuid(), - name="name_1") + self.context, id=1, uuid=uuidutils.generate_uuid(), name="name_1" + ) utils.create_test_device_profile( - self.context, - id=2, - uuid=uuidutils.generate_uuid(), - name="name_2") + self.context, id=2, uuid=uuidutils.generate_uuid(), name="name_2" + ) res = self.dbapi.device_profile_list_by_filters( - self.context, filters={"name": "name_1"}) + self.context, filters={"name": "name_1"} + ) self.assertEqual(1, len(res)) self.assertEqual('name_1', res[0]['name']) def test_delete(self): created_dp = utils.create_test_device_profile(self.context) return_value = self.dbapi.device_profile_delete( - self.context, - created_dp['uuid']) + self.context, created_dp['uuid'] + ) self.assertIsNone(return_value) def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_profile_get_by_uuid, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_profile_get_by_uuid, + self.context, + random_uuid, + ) def test_list_filter_is_none(self): """The main test is filters=None. If filters=None, @@ -123,38 +135,48 @@ class TestDbDeviceProfile(base.DbTestCase): profiles same as the List Device Profiles API response. """ utils.create_test_device_profile( - self.context, - id=1, - uuid=uuidutils.generate_uuid(), - name="foo_dp") + self.context, id=1, uuid=uuidutils.generate_uuid(), name="foo_dp" + ) res = self.dbapi.device_profile_list_by_filters( - self.context, filters=None) + self.context, filters=None + ) self.assertEqual(1, len(res)) self.assertEqual('foo_dp', res[0]['name']) def test_update_with_uuid_not_exist(self): utils.create_test_device_profile(self.context) random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_profile_update, - self.context, - random_uuid, - {'name': 'updated_name'}) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_profile_update, + self.context, + random_uuid, + {'name': 'updated_name'}, + ) def test_get_by_id_not_exist(self): fake_id = sys.maxsize - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_profile_get_by_id, - self.context, fake_id) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_profile_get_by_id, + self.context, + fake_id, + ) def test_get_by_name_not_exist(self): random_name = 'fake' + uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_profile_get, - self.context, random_name) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_profile_get, + self.context, + random_name, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.device_profile_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.device_profile_delete, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_db_extarq.py b/cyborg/tests/unit/db/test_db_extarq.py index 9d523392..8614551f 100644 --- a/cyborg/tests/unit/db/test_db_extarq.py +++ b/cyborg/tests/unit/db/test_db_extarq.py @@ -23,7 +23,6 @@ from cyborg.tests.unit.db import utils class TestDbExtArq(base.DbTestCase): - def test_create(self): random_uuid = uuidutils.generate_uuid() kw = {'uuid': random_uuid} @@ -33,22 +32,23 @@ class TestDbExtArq(base.DbTestCase): def test_get_by_uuid(self): created_extarq = utils.create_test_extarq(self.context) queried_extarq = self.dbapi.extarq_get( - self.context, created_extarq['uuid']) + self.context, created_extarq['uuid'] + ) self.assertEqual(created_extarq['uuid'], queried_extarq['uuid']) def test_update(self): created_extarq = utils.create_test_extarq(self.context) queried_extarq = self.dbapi.extarq_update( - self.context, created_extarq['uuid'], {'state': 'Initial'}) + self.context, created_extarq['uuid'], {'state': 'Initial'} + ) self.assertEqual('Initial', queried_extarq['state']) def test_list(self): uuids = [] for i in range(1, 4): extarq = utils.create_test_extarq( - self.context, - id=i, - uuid=uuidutils.generate_uuid()) + self.context, id=i, uuid=uuidutils.generate_uuid() + ) uuids.append(extarq['uuid']) extarqs = self.dbapi.extarq_list(self.context) extarq_uuids = [item.uuid for item in extarqs] @@ -57,18 +57,24 @@ class TestDbExtArq(base.DbTestCase): def test_delete(self): created_extarq = utils.create_test_extarq(self.context) return_value = self.dbapi.extarq_delete( - self.context, - created_extarq['uuid']) + self.context, created_extarq['uuid'] + ) self.assertIsNone(return_value) def test_get_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.extarq_get, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.extarq_get, + self.context, + random_uuid, + ) def test_delete_by_uuid_not_exist(self): random_uuid = uuidutils.generate_uuid() - self.assertRaises(exception.ResourceNotFound, - self.dbapi.extarq_delete, - self.context, random_uuid) + self.assertRaises( + exception.ResourceNotFound, + self.dbapi.extarq_delete, + self.context, + random_uuid, + ) diff --git a/cyborg/tests/unit/db/test_migrations.py b/cyborg/tests/unit/db/test_migrations.py index 1fc3f351..fd1dd7b4 100644 --- a/cyborg/tests/unit/db/test_migrations.py +++ b/cyborg/tests/unit/db/test_migrations.py @@ -35,8 +35,7 @@ MIGRATIONS_TIMEOUT = 300 @contextlib.contextmanager def patch_with_engine(engine): - with mock.patch.object(enginefacade.writer, - 'get_engine') as patch_engine: + with mock.patch.object(enginefacade.writer, 'get_engine') as patch_engine: patch_engine.return_value = engine yield @@ -57,8 +56,9 @@ class WalkVersionsMixin: versions = [ver for ver in script_directory.walk_revisions()] for version in reversed(versions): - self._migrate_up(engine, alembic_cfg, - version.revision, with_data=True) + self._migrate_up( + engine, alembic_cfg, version.revision, with_data=True + ) def _skippable_migrations(self): # Some db scripts are not necessary to check @@ -80,7 +80,8 @@ class WalkVersionsMixin: self.assertIsNotNone( check, f'DB Migration {version} does not have ' - 'a test. Please add one!') + 'a test. Please add one!', + ) class TestWalkVersions(base.TestCase, WalkVersionsMixin): @@ -94,20 +95,19 @@ class TestWalkVersions(base.TestCase, WalkVersionsMixin): def test_migrate_up(self): self.migration_api.version.return_value = '6a7f90fc3s8c' self._migrate_up(self.engine, self.config, '6a7f90fc3s8c') - self.migration_api.upgrade.assert_called_with('6a7f90fc3s8c', - config=self.config) + self.migration_api.upgrade.assert_called_with( + '6a7f90fc3s8c', config=self.config + ) self.migration_api.version.assert_called_with(self.config) class CyborgMigrationsCheckers: - def setUp(self): super().setUp() self.engine = enginefacade.writer.get_engine() self.config = migration._alembic_config() self.migration_api = migration - self.useFixture(fixtures.Timeout(MIGRATIONS_TIMEOUT, - gentle=True)), + (self.useFixture(fixtures.Timeout(MIGRATIONS_TIMEOUT, gentle=True)),) def test_walk_versions(self): self._walk_versions(self.engine, self.config) @@ -116,8 +116,7 @@ class CyborgMigrationsCheckers: devices = db_utils.get_table(engine, 'devices') col_names = [column.name for column in devices.c] self.assertIn('type', col_names) - self.assertIsInstance(devices.c.type.type, - sqlalchemy.types.Enum) + self.assertIsInstance(devices.c.type.type, sqlalchemy.types.Enum) def test_upgrade_and_version(self): with patch_with_engine(self.engine): @@ -132,8 +131,9 @@ class CyborgMigrationsCheckers: def test_upgrade_and_create_schema(self): with patch_with_engine(self.engine): self.migration_api.upgrade('ede4e3f1a232') - self.assertRaises(db_exc.DBMigrationError, - self.migration_api.create_schema) + self.assertRaises( + db_exc.DBMigrationError, self.migration_api.create_schema + ) def test_upgrade_twice(self): with patch_with_engine(self.engine): @@ -144,15 +144,19 @@ class CyborgMigrationsCheckers: self.assertNotEqual(v1, v2) -class TestCyborgMigrationsMySQL(CyborgMigrationsCheckers, - WalkVersionsMixin, - test_fixtures.OpportunisticDBTestMixin, - test_base.BaseTestCase): +class TestCyborgMigrationsMySQL( + CyborgMigrationsCheckers, + WalkVersionsMixin, + test_fixtures.OpportunisticDBTestMixin, + test_base.BaseTestCase, +): FIXTURE = test_fixtures.MySQLOpportunisticFixture -class TestMigrationsPostgreSQL(CyborgMigrationsCheckers, - WalkVersionsMixin, - test_fixtures.OpportunisticDBTestMixin, - test_base.BaseTestCase): +class TestMigrationsPostgreSQL( + CyborgMigrationsCheckers, + WalkVersionsMixin, + test_fixtures.OpportunisticDBTestMixin, + test_base.BaseTestCase, +): FIXTURE = test_fixtures.PostgresqlOpportunisticFixture diff --git a/cyborg/tests/unit/db/utils.py b/cyborg/tests/unit/db/utils.py index 5c7fd19f..b2f2da38 100644 --- a/cyborg/tests/unit/db/utils.py +++ b/cyborg/tests/unit/db/utils.py @@ -31,7 +31,7 @@ def get_test_deployable(**kw): 'rp_uuid': kw.get('rp_uuid', '1c559644-2b56-3470-8427-d9d71f0f8621'), 'bitstream_id': kw.get('bitstream_id', None), 'created_at': kw.get('created_at', None), - 'updated_at': kw.get('updated_at', None) + 'updated_at': kw.get('updated_at', None), } @@ -49,10 +49,7 @@ def create_test_deployable(context, **kwargs): def get_test_device(**kw): - std_board_info = { - "class": "Fake class", - "device_id": "0xabcd" - } + std_board_info = {"class": "Fake class", "device_id": "0xabcd"} return { 'id': kw.get('id', 1), 'uuid': kw.get('uuid', '20efe63d-dfea-4a37-ad94-4116fba50122'), @@ -87,14 +84,15 @@ def get_test_extarq(**kwargs): 'state': kwargs.get('state', 'Bound'), 'device_profile_id': kwargs.get('id', 1), 'hostname': kwargs.get('hostname', 'testnode1'), - 'device_rp_uuid': kwargs.get('device_rp_uuid', - 'f2b96c5f-242a-41a0-a736-b6e1fada071b'), - 'device_instance_uuid': - kwargs.get('device_rp_uuid', - '6219e0fb-2935-4db2-a3c7-86a2ac3ac84e'), + 'device_rp_uuid': kwargs.get( + 'device_rp_uuid', 'f2b96c5f-242a-41a0-a736-b6e1fada071b' + ), + 'device_instance_uuid': kwargs.get( + 'device_rp_uuid', '6219e0fb-2935-4db2-a3c7-86a2ac3ac84e' + ), 'attach_handle_id': kwargs.get('id', 1), 'created_at': kwargs.get('created_at', None), - 'updated_at': kwargs.get('updated_at', None) + 'updated_at': kwargs.get('updated_at', None), } @@ -118,11 +116,12 @@ def get_test_arq(**kwargs): 'state': kwargs.get('state', 'Initial'), 'device_profile': kwargs.get('device_profile', None), 'hostname': kwargs.get('hostname', 'testnode1'), - 'device_rp_uuid': kwargs.get('device_rp_uuid', - 'f2b96c5f-242a-41a0-a736-b6e1fada071b'), - 'device_instance_uuid': - kwargs.get('device_rp_uuid', - '6219e0fb-2935-4db2-a3c7-86a2ac3ac84e'), + 'device_rp_uuid': kwargs.get( + 'device_rp_uuid', 'f2b96c5f-242a-41a0-a736-b6e1fada071b' + ), + 'device_instance_uuid': kwargs.get( + 'device_rp_uuid', '6219e0fb-2935-4db2-a3c7-86a2ac3ac84e' + ), 'attach_handle': kwargs.get('attach_handle', None), 'created_at': kwargs.get('created_at', None), 'updated_at': kwargs.get('updated_at', None), @@ -164,9 +163,10 @@ def get_test_control_path(**kw): 'id': kw.get('id', 1), 'device_id': kw.get('device_id', 1), 'cpid_type': kw.get('cpid_type', "PCI"), - 'cpid_info': kw.get('cpid_info', - '{"device": "2", "bus": "00", "function": "01", ' - '"domain": "0001"}'), + 'cpid_info': kw.get( + 'cpid_info', + '{"device": "2", "bus": "00", "function": "01", "domain": "0001"}', + ), 'created_at': kw.get('create_at', None), 'updated_at': kw.get('updated_at', None), } @@ -183,7 +183,8 @@ def get_test_device_profile(**kw): '{"version": "1.0", \ "groups": [{"resources:CUSTOM_ACCELERATOR_FPGA": "1"}, \ {"trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10": "required"}, \ - {"trait:CUSTOM_FUNCTION_ID_3AFB": "required"}]}'), + {"trait:CUSTOM_FUNCTION_ID_3AFB": "required"}]}', + ), 'created_at': kw.get('create_at', None), 'updated_at': kw.get('updated_at', None), } diff --git a/cyborg/tests/unit/db_lock_fixture.py b/cyborg/tests/unit/db_lock_fixture.py index d02abc7c..091552c2 100644 --- a/cyborg/tests/unit/db_lock_fixture.py +++ b/cyborg/tests/unit/db_lock_fixture.py @@ -36,8 +36,7 @@ class DatabaseWriteLock(fixtures.Fixture): """ def _setUp(self): - original = ( - enginefacade._TransactionContextManager._transaction_scope) + original = enginefacade._TransactionContextManager._transaction_scope @contextlib.contextmanager def _locked_scope(tcm_self, context): @@ -52,7 +51,10 @@ class DatabaseWriteLock(fixtures.Fixture): with original(tcm_self, context) as resource: yield resource - self.useFixture(fixtures.MockPatchObject( - enginefacade._TransactionContextManager, - '_transaction_scope', - _locked_scope)) + self.useFixture( + fixtures.MockPatchObject( + enginefacade._TransactionContextManager, + '_transaction_scope', + _locked_scope, + ) + ) diff --git a/cyborg/tests/unit/fake_attach_handle.py b/cyborg/tests/unit/fake_attach_handle.py index e203ff4a..6c39ba71 100644 --- a/cyborg/tests/unit/fake_attach_handle.py +++ b/cyborg/tests/unit/fake_attach_handle.py @@ -28,8 +28,8 @@ def get_fake_attach_handle_as_dict(): 'deployable_id': 1, 'attach_type': "PCI", 'attach_info': '{"domain": "0000", "bus": "0c",' - '"device": "0", "function": "1"}', - } + '"device": "0", "function": "1"}', + } attach_handle2 = { 'id': 2, @@ -39,8 +39,8 @@ def get_fake_attach_handle_as_dict(): 'deployable_id': 2, 'attach_type': "PCI", 'attach_info': '{"domain": "0000", "bus": "0c",' - '"device": "0", "function": "1"}', - } + '"device": "0", "function": "1"}', + } return [attach_handle1, attach_handle2] diff --git a/cyborg/tests/unit/fake_attribute.py b/cyborg/tests/unit/fake_attribute.py index cf1a4eba..c1aeadc1 100644 --- a/cyborg/tests/unit/fake_attribute.py +++ b/cyborg/tests/unit/fake_attribute.py @@ -25,8 +25,8 @@ def fake_db_attribute(**updates): 'uuid': attr_uuid, 'deployable_id': 1, 'key': 'rc', - 'value': 'FPGA' - } + 'value': 'FPGA', + } for name, field in objects.Attribute.fields.items(): if name in db_attribute: @@ -48,7 +48,7 @@ def fake_attribute_obj(context, obj_attr_class=None, **updates): if obj_attr_class is None: obj_attr_class = objects.Attribute attribute = obj_attr_class._from_db_object( - obj_attr_class(), - fake_db_attribute(**updates)) + obj_attr_class(), fake_db_attribute(**updates) + ) attribute.obj_reset_changes() return attribute diff --git a/cyborg/tests/unit/fake_deployable.py b/cyborg/tests/unit/fake_deployable.py index c545aa45..a225bafd 100644 --- a/cyborg/tests/unit/fake_deployable.py +++ b/cyborg/tests/unit/fake_deployable.py @@ -31,7 +31,7 @@ def fake_db_deployable(**updates): 'driver_name': "fake-driver-name", 'rp_uuid': None, 'bitstream_id': None, - } + } for name, field in objects.Deployable.fields.items(): if name in db_deployable: @@ -52,7 +52,8 @@ def fake_db_deployable(**updates): def fake_deployable_obj(context, obj_dpl_class=None, **updates): if obj_dpl_class is None: obj_dpl_class = objects.Deployable - deploy = obj_dpl_class._from_db_object(obj_dpl_class(), - fake_db_deployable(**updates)) + deploy = obj_dpl_class._from_db_object( + obj_dpl_class(), fake_db_deployable(**updates) + ) deploy.obj_reset_changes() return deploy diff --git a/cyborg/tests/unit/fake_device.py b/cyborg/tests/unit/fake_device.py index 629c5103..1099be9e 100644 --- a/cyborg/tests/unit/fake_device.py +++ b/cyborg/tests/unit/fake_device.py @@ -26,8 +26,8 @@ def get_fake_devices_as_dict(): "vendor_board_info": "fake_vendor_info", "model": "miss model info", "type": "FPGA", - "std_board_info": "{'class': 'Fake class', 'device_id': '0xabcd'}" - } + "std_board_info": "{'class': 'Fake class', 'device_id': '0xabcd'}", + } device2 = { "id": 2, "vendor": "0xDCBA", @@ -36,8 +36,8 @@ def get_fake_devices_as_dict(): "vendor_board_info": "fake_vendor_info", "model": "miss model info", "type": "GPU", - "std_board_info": "{'class': 'Fake class', 'device_id': '0xdcba'}" - } + "std_board_info": "{'class': 'Fake class', 'device_id': '0xdcba'}", + } return [device1, device2] diff --git a/cyborg/tests/unit/fake_device_profile.py b/cyborg/tests/unit/fake_device_profile.py index 3d87dc72..ca6aecaa 100644 --- a/cyborg/tests/unit/fake_device_profile.py +++ b/cyborg/tests/unit/fake_device_profile.py @@ -13,10 +13,10 @@ # under the License. """ - See note at the start of cyborg/api/controllers/v2/device_profiles.py. - Device profiles have an API format (which is provided to POST to - create one) and an object format. The code in this file can provide - fake device profiles in either format. +See note at the start of cyborg/api/controllers/v2/device_profiles.py. +Device profiles have an API format (which is provided to POST to +create one) and an object format. The code in this file can provide +fake device profiles in either format. """ import datetime @@ -28,11 +28,11 @@ from cyborg.objects import device_profile def _get_device_profiles_as_dict(): date1 = datetime.datetime( - 2019, 10, 9, 6, 31, 59, - tzinfo=datetime.timezone.utc) + 2019, 10, 9, 6, 31, 59, tzinfo=datetime.timezone.utc + ) date2 = datetime.datetime( - 2019, 11, 8, 5, 30, 49, - tzinfo=datetime.timezone.utc) + 2019, 11, 8, 5, 30, 49, tzinfo=datetime.timezone.utc + ) dp1 = { "id": 1, "uuid": "a95e10ae-b3e3-4eab-a513-1afae6f17c51", @@ -41,14 +41,16 @@ def _get_device_profiles_as_dict(): "created_at": date1, "updated_at": None, "groups": [ - {"resources:FPGA": "1", - "trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10": "required", - "trait:CUSTOM_FUNCTION_ID_3AFB": "required", - }, - {"resources:CUSTOM_ACCELERATOR_FOO": "2", - "trait:CUSTOM_TRAIT_ALWAYS": "required", - } - ] + { + "resources:FPGA": "1", + "trait:CUSTOM_FPGA_INTEL_PAC_ARRIA10": "required", + "trait:CUSTOM_FUNCTION_ID_3AFB": "required", + }, + { + "resources:CUSTOM_ACCELERATOR_FOO": "2", + "trait:CUSTOM_TRAIT_ALWAYS": "required", + }, + ], } dp2 = { "id": 2, @@ -58,11 +60,12 @@ def _get_device_profiles_as_dict(): "updated_at": None, "description": "fake-dp_example_2-desc", "groups": [ - {"resources:FPGA": "1", - "trait:CUSTOM_REGION_ID_3ACD": "required", - "accel:bitstream_id": "ea0d149c-8555-495b-bc79-608d7bab1260" - } - ] + { + "resources:FPGA": "1", + "trait:CUSTOM_REGION_ID_3ACD": "required", + "accel:bitstream_id": "ea0d149c-8555-495b-bc79-608d7bab1260", + } + ], } return [dp1, dp2] @@ -115,14 +118,15 @@ def get_xilinx_fpga_devprof(): "name": 'fake_xilinx_fpga_dp', "description": "fake_xilinx_fpga_dp-desc", "created_at": datetime.datetime( - 2022, 1, 22, 10, 40, 56, - tzinfo=datetime.timezone.utc), + 2022, 1, 22, 10, 40, 56, tzinfo=datetime.timezone.utc + ), "updated_at": None, "groups": [ - {"resources:FPGA": "1", - "trait:CUSTOM_FPGA_XILINX": "required", - "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required", - }, - ] + { + "resources:FPGA": "1", + "trait:CUSTOM_FPGA_XILINX": "required", + "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required", + }, + ], } return _convert_to_obj(xilinx_fpga_dp) diff --git a/cyborg/tests/unit/fake_driver_device.py b/cyborg/tests/unit/fake_driver_device.py index 9d078b7f..07e70450 100644 --- a/cyborg/tests/unit/fake_driver_device.py +++ b/cyborg/tests/unit/fake_driver_device.py @@ -28,7 +28,7 @@ def get_fake_driver_devices_as_dict(): "std_board_info": "{'class': 'Fake class', 'device_id': '0xabcd'}", "stub": False, "controlpath_id": get_fake_driver_controlpath_objs()[0], - "deployable_list": get_fake_driver_deployable_objs()[:1] + "deployable_list": get_fake_driver_deployable_objs()[:1], } driver_device2 = { "vendor": "0xDCBA", @@ -38,7 +38,7 @@ def get_fake_driver_devices_as_dict(): "std_board_info": "{'class': 'Fake class', 'device_id': '0xdcba'}", "stub": False, "controlpath_id": get_fake_driver_controlpath_objs()[1], - "deployable_list": get_fake_driver_deployable_objs()[1:] + "deployable_list": get_fake_driver_deployable_objs()[1:], } return [driver_device1, driver_device2] @@ -57,11 +57,11 @@ def get_fake_driver_devices_objs(): def get_fake_driver_controlpath_as_dict(): driver_controlpath1 = { "cpid_info": '{"bus": "af", "device":00, "domain":0000, "function":0}', - "cpid_type": "PCI" + "cpid_type": "PCI", } driver_controlpath2 = { "cpid_info": '{"bus": "db", "device":00, "domain":0000, "function":0}', - "cpid_type": "PCI" + "cpid_type": "PCI", } return [driver_controlpath1, driver_controlpath2] @@ -108,16 +108,14 @@ def get_fake_driver_deployable_objs(): def get_fake_driver_attach_handles_as_dict(): driver_attach_handle1 = { - "attach_info": - '{"bus": "af", "device":00, "domain":0000, "function":0}', + "attach_info": '{"bus": "af", "device":00, "domain":0000, "function":0}', "attach_type": "PCI", - "in_use": 0 + "in_use": 0, } driver_attach_handle2 = { - "attach_info": - '{"bus": "db", "device":00, "domain":0000, "function":0}', + "attach_info": '{"bus": "db", "device":00, "domain":0000, "function":0}', "attach_type": "PCI", - "in_use": 0 + "in_use": 0, } return [driver_attach_handle1, driver_attach_handle2] @@ -134,18 +132,12 @@ def get_fake_driver_attach_handle_objs(): def get_fake_driver_attributes_as_dict(): - driver_attribute1 = { - "key": "trait0", - "value": "CUSTOM_GPU_NVIDIA" - } + driver_attribute1 = {"key": "trait0", "value": "CUSTOM_GPU_NVIDIA"} driver_attribute2 = { "key": "trait1", - "value": "CUSTOM_GPU_PRODUCT_ID_1DB6" - } - driver_attribute3 = { - "key": "rc", - "value": "PGPU" + "value": "CUSTOM_GPU_PRODUCT_ID_1DB6", } + driver_attribute3 = {"key": "rc", "value": "PGPU"} return [driver_attribute1, driver_attribute2, driver_attribute3] diff --git a/cyborg/tests/unit/fake_extarq.py b/cyborg/tests/unit/fake_extarq.py index d191b37d..c540c4ec 100644 --- a/cyborg/tests/unit/fake_extarq.py +++ b/cyborg/tests/unit/fake_extarq.py @@ -29,53 +29,71 @@ def _get_arqs_as_dict(): "bus": "1", "device": "0", "domain": "0", - "function": "0" + "function": "0", }, "device_profile_group": { "trait:CUSTOM_FPGA_INTEL": "required", "resources:FPGA": "1", - "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c"} + "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c", + }, } dp_groups = [ {"device_profile_group": {"resources:GPU": "1"}}, - {"device_profile_group": { - "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1"}}, - {"device_profile_group": { - "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1", - "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c"}}, - {"device_profile_group": { - "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1", - "accel:function_id": "25453786-03e0-4ee7-a640-969eb5a5aa44"}}, - {"device_profile_group": { - "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1", - "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c", - "accel:function_id": "25453786-03e0-4ee7-a640-969eb5a5aa44"}}, + { + "device_profile_group": { + "trait:CUSTOM_FPGA_INTEL": "required", + "resources:FPGA": "1", + } + }, + { + "device_profile_group": { + "trait:CUSTOM_FPGA_INTEL": "required", + "resources:FPGA": "1", + "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c", + } + }, + { + "device_profile_group": { + "trait:CUSTOM_FPGA_INTEL": "required", + "resources:FPGA": "1", + "accel:function_id": "25453786-03e0-4ee7-a640-969eb5a5aa44", + } + }, + { + "device_profile_group": { + "trait:CUSTOM_FPGA_INTEL": "required", + "resources:FPGA": "1", + "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c", + "accel:function_id": "25453786-03e0-4ee7-a640-969eb5a5aa44", + } + }, ] arqs = [ # Corresponds to 1st device profile in fake_device)profile.py - {"uuid": "a097fefa-da62-4630-8e8b-424c0e3426dc", - "device_profile_group_id": 0, - "device_rp_uuid": "8787595e-9954-49f8-b5c1-cdb55b59062f", - }, - {"uuid": "aa140114-4869-45ea-8213-45f530804b0f", - "device_profile_group_id": 1, - "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9c", - }, - {"uuid": "292b2fa2-0831-484c-aeac-09c794428a5d", - "device_profile_group_id": 2, - "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9c", - }, - {"uuid": "3049ad04-a2b1-40a3-b9c8-480a5e661645", - "device_profile_group_id": 3, - "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bb", - }, - {"uuid": "3a9a07e7-d126-47a5-bf11-dcc04f9e60ff", - "device_profile_group_id": 4, - "device_rp_uuid": "fbd485e1-40b1-4a7e-84b9-f6b6959114a4", - }, + { + "uuid": "a097fefa-da62-4630-8e8b-424c0e3426dc", + "device_profile_group_id": 0, + "device_rp_uuid": "8787595e-9954-49f8-b5c1-cdb55b59062f", + }, + { + "uuid": "aa140114-4869-45ea-8213-45f530804b0f", + "device_profile_group_id": 1, + "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9c", + }, + { + "uuid": "292b2fa2-0831-484c-aeac-09c794428a5d", + "device_profile_group_id": 2, + "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9c", + }, + { + "uuid": "3049ad04-a2b1-40a3-b9c8-480a5e661645", + "device_profile_group_id": 3, + "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bb", + }, + { + "uuid": "3a9a07e7-d126-47a5-bf11-dcc04f9e60ff", + "device_profile_group_id": 4, + "device_rp_uuid": "fbd485e1-40b1-4a7e-84b9-f6b6959114a4", + }, ] new_arqs = [] for idx, new_arq in enumerate(arqs): @@ -88,38 +106,41 @@ def _get_arqs_as_dict(): def _get_arqs_resloved_as_dict(): arqs = [ # Corresponds to 1st device profile in fake_device)profile.py - {"uuid": 'a097fefa-da62-4630-8e8b-424c0e3426dd', - "device_profile_group_id": 0, - "state": "Initial", - "device_profile_name": "dp_example_1", - "device_rp_uuid": None, - "hostname": None, - "instance_uuid": None, - "attach_handle_type": None, - # attach_handle info should vary across ARQs but ignored for testing - "attach_handle_info": {}, - "device_profile_group": {} - }, - {"uuid": 'aa140114-4869-45ea-8213-45f530804b0e', - "device_profile_group_id": 1, - "device_rp_uuid": "fbd485e1-40b1-4a7e-84b9-f6b6959114a5", - "state": "Bound", - "device_profile_name": "dp_example_1", - "hostname": "myhost", - "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", - "attach_handle_type": "PCI", - # attach_handle info should vary across ARQs but ignored for testing - "attach_handle_info": { - "bus": "1", - "device": "0", - "domain": "0", - "function": "0" - }, - "device_profile_group": { - "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1", - "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c"} - }, + { + "uuid": 'a097fefa-da62-4630-8e8b-424c0e3426dd', + "device_profile_group_id": 0, + "state": "Initial", + "device_profile_name": "dp_example_1", + "device_rp_uuid": None, + "hostname": None, + "instance_uuid": None, + "attach_handle_type": None, + # attach_handle info should vary across ARQs but ignored for testing + "attach_handle_info": {}, + "device_profile_group": {}, + }, + { + "uuid": 'aa140114-4869-45ea-8213-45f530804b0e', + "device_profile_group_id": 1, + "device_rp_uuid": "fbd485e1-40b1-4a7e-84b9-f6b6959114a5", + "state": "Bound", + "device_profile_name": "dp_example_1", + "hostname": "myhost", + "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", + "attach_handle_type": "PCI", + # attach_handle info should vary across ARQs but ignored for testing + "attach_handle_info": { + "bus": "1", + "device": "0", + "domain": "0", + "function": "0", + }, + "device_profile_group": { + "trait:CUSTOM_FPGA_INTEL": "required", + "resources:FPGA": "1", + "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9c", + }, + }, ] new_arqs = [] for idx, new_arq in enumerate(arqs): @@ -142,10 +163,11 @@ def _get_arqs_bind_as_dict(): "bus": "1", "device": "0", "domain": "0", - "function": "0" + "function": "0", }, - "device_profile_group": {"resources:GPU": "1"} - }, { + "device_profile_group": {"resources:GPU": "1"}, + }, + { "state": "Deleting", "device_profile_name": "dp_example_2", "hostname": "myhost1", @@ -157,12 +179,14 @@ def _get_arqs_bind_as_dict(): "bus": "2", "device": "0", "domain": "0", - "function": "0" + "function": "0", }, "device_profile_group": { "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1"} - }, { + "resources:FPGA": "1", + }, + }, + { "state": "Deleting", "device_profile_name": "dp_example_3", "hostname": "myhost3", @@ -174,13 +198,15 @@ def _get_arqs_bind_as_dict(): "bus": "3", "device": "0", "domain": "0", - "function": "0" + "function": "0", }, "device_profile_group": { "trait:CUSTOM_FPGA_INTEL": "required", "resources:FPGA": "1", - "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9e"} - }, { + "accel:bitstream_id": "b069d97a-010a-4057-b70d-eca2b337fc9e", + }, + }, + { "state": "Unbound", "device_profile_name": "dp_example_2", "hostname": "myhost1", @@ -192,30 +218,35 @@ def _get_arqs_bind_as_dict(): "bus": "2", "device": "0", "domain": "0", - "function": "0" + "function": "0", }, "device_profile_group": { "trait:CUSTOM_FPGA_INTEL": "required", - "resources:FPGA": "1"} + "resources:FPGA": "1", + }, }, ] arqs = [ # Corresponds to 1st device profile in fake_device)profile.py - {"uuid": "a097fefa-da62-4630-8e8b-424c0e3426de", - "device_profile_group_id": 0, - "device_rp_uuid": "8787595e-9954-49f8-b5c1-cdb55b59062e", - }, - {"uuid": "aa140114-4869-45ea-8213-45f530804b0d", - "device_profile_group_id": 0, - "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9d", - }, - {"uuid": "292b2fa2-0831-484c-aeac-09c794428a5e", - "device_profile_group_id": 0, - "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bb", - }, - {"uuid": "292b2fa2-0831-484c-aeac-09c794428a5d", - "device_profile_group_id": 0, - "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bc", - } + { + "uuid": "a097fefa-da62-4630-8e8b-424c0e3426de", + "device_profile_group_id": 0, + "device_rp_uuid": "8787595e-9954-49f8-b5c1-cdb55b59062e", + }, + { + "uuid": "aa140114-4869-45ea-8213-45f530804b0d", + "device_profile_group_id": 0, + "device_rp_uuid": "a1ec17f2-0051-4737-bac4-f074d8a01a9d", + }, + { + "uuid": "292b2fa2-0831-484c-aeac-09c794428a5e", + "device_profile_group_id": 0, + "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bb", + }, + { + "uuid": "292b2fa2-0831-484c-aeac-09c794428a5d", + "device_profile_group_id": 0, + "device_rp_uuid": "57455a49-bde4-490e-9179-9aa84a3870bc", + }, ] new_arqs = [] for idx, new_arq in enumerate(arqs): @@ -288,59 +319,72 @@ def get_patch_list(same_device=True): must be for the same device. """ arqs = _get_arqs_as_dict() - host_binding = {'path': '/hostname', 'op': 'add', - 'value': arqs[0]['hostname']} - inst_binding = {'path': '/instance_uuid', 'op': 'add', - 'value': arqs[0]['instance_uuid']} + host_binding = { + 'path': '/hostname', + 'op': 'add', + 'value': arqs[0]['hostname'], + } + inst_binding = { + 'path': '/instance_uuid', + 'op': 'add', + 'value': arqs[0]['instance_uuid'], + } device_rp_uuid = 'fb16c293-5739-4c84-8590-926f9ab16669' patch_list = {} for newarq in arqs: dev_uuid = device_rp_uuid if same_device else newarq['device_rp_uuid'] - dev_binding = {'path': '/device_rp_uuid', 'op': 'add', - 'value': dev_uuid} + dev_binding = { + 'path': '/device_rp_uuid', + 'op': 'add', + 'value': dev_uuid, + } patch_list[newarq['uuid']] = [host_binding, inst_binding, dev_binding] return patch_list, device_rp_uuid def get_fake_xilinx_fpga_extarq_objs(): arqs = [ - {"uuid": 'b8c19eb2-e03c-47b4-b7cf-ced6086b2d11', - "device_profile_group_id": 0, - "state": "Initial", - "device_profile_name": "fake_xilinx_fpga_dp", - "hostname": "myhost", - "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", - "attach_handle_type": "PCI", - # attach_handle info should vary across ARQs but ignored for testing - "attach_handle_info": { - "bus": "3b", - "device": "00", - "domain": "0000", - "function": "0" - }, - "device_profile_group": { - "trait:CUSTOM_FPGA_XILINX": "required", - "resources:FPGA": "1", - "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required"} - }, - {"uuid": '012955c7-90f9-45a9-bb7d-7c2907d8997f', - "device_profile_group_id": 1, - "state": "Initial", - "device_profile_name": "fake_xilinx_fpga_dp", - "hostname": "myhost", - "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", - "attach_handle_type": "PCI", - # attach_handle info should vary across ARQs but ignored for testing - "attach_handle_info": { - "bus": "3b", - "device": "00", - "domain": "0000", - "function": "1" - }, - "device_profile_group": { - "trait:CUSTOM_FPGA_XILINX": "required", - "resources:FPGA": "1", - "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required"} - }, + { + "uuid": 'b8c19eb2-e03c-47b4-b7cf-ced6086b2d11', + "device_profile_group_id": 0, + "state": "Initial", + "device_profile_name": "fake_xilinx_fpga_dp", + "hostname": "myhost", + "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", + "attach_handle_type": "PCI", + # attach_handle info should vary across ARQs but ignored for testing + "attach_handle_info": { + "bus": "3b", + "device": "00", + "domain": "0000", + "function": "0", + }, + "device_profile_group": { + "trait:CUSTOM_FPGA_XILINX": "required", + "resources:FPGA": "1", + "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required", + }, + }, + { + "uuid": '012955c7-90f9-45a9-bb7d-7c2907d8997f', + "device_profile_group_id": 1, + "state": "Initial", + "device_profile_name": "fake_xilinx_fpga_dp", + "hostname": "myhost", + "instance_uuid": "5922a70f-1e06-4cfd-88dd-a332120d7144", + "attach_handle_type": "PCI", + # attach_handle info should vary across ARQs but ignored for testing + "attach_handle_info": { + "bus": "3b", + "device": "00", + "domain": "0000", + "function": "1", + }, + "device_profile_group": { + "trait:CUSTOM_FPGA_XILINX": "required", + "resources:FPGA": "1", + "trait:CUSTOM_FPGA_PRODUCT_ID_5000": "required", + }, + }, ] return list(map(_convert_from_dict_to_obj, arqs)) diff --git a/cyborg/tests/unit/fake_physical_function.py b/cyborg/tests/unit/fake_physical_function.py index 11d81deb..5eca8d87 100644 --- a/cyborg/tests/unit/fake_physical_function.py +++ b/cyborg/tests/unit/fake_physical_function.py @@ -38,8 +38,8 @@ def fake_db_physical_function(**updates): 'assignable': True, 'instance_uuid': None, 'availability': 'Available', - 'accelerator_id': 1 - } + 'accelerator_id': 1, + } for name, field in physical_function.PhysicalFunction.fields.items(): if name in db_physical_function: @@ -49,8 +49,9 @@ def fake_db_physical_function(**updates): elif field.default != fields.UnspecifiedDefault: db_physical_function[name] = field.default else: - raise Exception('fake_db_physical_function needs help with %s' - % name) + raise Exception( + 'fake_db_physical_function needs help with %s' % name + ) if updates: db_physical_function.update(updates) @@ -62,9 +63,11 @@ def fake_physical_function_obj(context, obj_pf_class=None, **updates): if obj_pf_class is None: obj_pf_class = objects.VirtualFunction expected_attrs = updates.pop('expected_attrs', None) - pf = obj_pf_class._from_db_object(context, - obj_pf_class(), - fake_db_physical_function(**updates), - expected_attrs=expected_attrs) + pf = obj_pf_class._from_db_object( + context, + obj_pf_class(), + fake_db_physical_function(**updates), + expected_attrs=expected_attrs, + ) pf.obj_reset_changes() return pf diff --git a/cyborg/tests/unit/fake_virtual_function.py b/cyborg/tests/unit/fake_virtual_function.py index 6bf06b7c..40c5728e 100644 --- a/cyborg/tests/unit/fake_virtual_function.py +++ b/cyborg/tests/unit/fake_virtual_function.py @@ -38,8 +38,8 @@ def fake_db_virtual_function(**updates): 'assignable': True, 'instance_uuid': None, 'availability': 'Available', - 'accelerator_id': 1 - } + 'accelerator_id': 1, + } for name, field in virtual_function.VirtualFunction.fields.items(): if name in db_virtual_function: @@ -49,8 +49,9 @@ def fake_db_virtual_function(**updates): elif field.default != fields.UnspecifiedDefault: db_virtual_function[name] = field.default else: - raise Exception('fake_db_virtual_function needs help with %s' - % name) + raise Exception( + 'fake_db_virtual_function needs help with %s' % name + ) if updates: db_virtual_function.update(updates) @@ -62,9 +63,11 @@ def fake_virtual_function_obj(context, obj_vf_class=None, **updates): if obj_vf_class is None: obj_vf_class = objects.VirtualFunction expected_attrs = updates.pop('expected_attrs', None) - vf = obj_vf_class._from_db_object(context, - obj_vf_class(), - fake_db_virtual_function(**updates), - expected_attrs=expected_attrs) + vf = obj_vf_class._from_db_object( + context, + obj_vf_class(), + fake_db_virtual_function(**updates), + expected_attrs=expected_attrs, + ) vf.obj_reset_changes() return vf diff --git a/cyborg/tests/unit/image/test_glance.py b/cyborg/tests/unit/image/test_glance.py index 706fcceb..73652cbe 100644 --- a/cyborg/tests/unit/image/test_glance.py +++ b/cyborg/tests/unit/image/test_glance.py @@ -21,7 +21,6 @@ from cyborg.tests.unit.db import base class TestExceptionTranslations(base.DbTestCase): - def test_client_notfound_converts_to_imagenotfound(self): in_exc = glanceclient.exc.HTTPNotFound('123') out_exc = glance._translate_image_exception('123', in_exc) diff --git a/cyborg/tests/unit/objects/test_attach_handle.py b/cyborg/tests/unit/objects/test_attach_handle.py index e0d7dd2b..682facf9 100644 --- a/cyborg/tests/unit/objects/test_attach_handle.py +++ b/cyborg/tests/unit/objects/test_attach_handle.py @@ -20,15 +20,15 @@ from cyborg.tests.unit.db import utils class TestAttachHandleObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_attach_handle = utils.get_test_attach_handle() def test_get(self): uuid = self.fake_attach_handle['uuid'] - with mock.patch.object(self.dbapi, 'attach_handle_get_by_uuid', - autospec=True) as mock_attach_handle_get: + with mock.patch.object( + self.dbapi, 'attach_handle_get_by_uuid', autospec=True + ) as mock_attach_handle_get: mock_attach_handle_get.return_value = self.fake_attach_handle attach_handle = objects.AttachHandle.get(self.context, uuid) mock_attach_handle_get.assert_called_once_with(self.context, uuid) @@ -36,8 +36,9 @@ class TestAttachHandleObject(base.DbTestCase): def test_get_by_id(self): id = self.fake_attach_handle['id'] - with mock.patch.object(self.dbapi, 'attach_handle_get_by_id', - autospec=True) as mock_attach_handle_get: + with mock.patch.object( + self.dbapi, 'attach_handle_get_by_id', autospec=True + ) as mock_attach_handle_get: mock_attach_handle_get.return_value = self.fake_attach_handle attach_handle = objects.AttachHandle.get_by_id(self.context, id) mock_attach_handle_get.assert_called_once_with(self.context, id) @@ -49,7 +50,8 @@ class TestAttachHandleObject(base.DbTestCase): deployable_id = self.fake_attach_handle['deployable_id'] ah_filter = {'deployable_id': deployable_id} attach_handles = objects.AttachHandle.get_ah_list_by_deployable_id( - self.context, deployable_id) + self.context, deployable_id + ) mock_list.assert_called_once_with(self.context, ah_filter) self.assertEqual(deployable_id, attach_handles[0]['deployable_id']) @@ -58,22 +60,27 @@ class TestAttachHandleObject(base.DbTestCase): mock_list.return_value = [self.fake_attach_handle] deployable_id = self.fake_attach_handle['deployable_id'] attach_info = self.fake_attach_handle['attach_info'] - ah_filter = {'deployable_id': deployable_id, - 'attach_info': attach_info} + ah_filter = { + 'deployable_id': deployable_id, + 'attach_info': attach_info, + } attach_handles = objects.AttachHandle.get_ah_by_depid_attachinfo( - self.context, deployable_id, attach_info) + self.context, deployable_id, attach_info + ) mock_list.assert_called_once_with(self.context, ah_filter) self.assertEqual(attach_info, attach_handles['attach_info']) # test objects.AttachHandle.list() return [] mock_list.return_value = [] attach_handle = objects.AttachHandle.get_ah_by_depid_attachinfo( - self.context, deployable_id, attach_info) + self.context, deployable_id, attach_info + ) self.assertIsNone(attach_handle) def test_list(self): - with mock.patch.object(self.dbapi, 'attach_handle_list', - autospec=True) as mock_attach_handle_list: + with mock.patch.object( + self.dbapi, 'attach_handle_list', autospec=True + ) as mock_attach_handle_list: mock_attach_handle_list.return_value = [self.fake_attach_handle] attach_handles = objects.AttachHandle.list(self.context) self.assertEqual(1, mock_attach_handle_list.call_count) @@ -82,8 +89,9 @@ class TestAttachHandleObject(base.DbTestCase): self.assertEqual(self.context, attach_handles[0]._context) def test_list_with_filter(self): - with mock.patch.object(self.dbapi, 'attach_handle_get_by_filters', - autospec=True) as mock_ah_with_filter_list: + with mock.patch.object( + self.dbapi, 'attach_handle_get_by_filters', autospec=True + ) as mock_ah_with_filter_list: mock_ah_with_filter_list.return_value = [self.fake_attach_handle] filters = {'limit': 1} attach_handles = objects.AttachHandle.list(self.context, filters) @@ -97,70 +105,80 @@ class TestAttachHandleObject(base.DbTestCase): sort_key='created_at', limit=1, marker=None, - ) + ) @mock.patch.object(objects.base.CyborgObject, "_from_db_object") def test_attach_handle_allocate(self, mock_from_db_obj): deployable_id = self.fake_attach_handle['deployable_id'] - with mock.patch.object(self.dbapi, 'attach_handle_allocate', - autospec=True) as mock_ah_allocate: + with mock.patch.object( + self.dbapi, 'attach_handle_allocate', autospec=True + ) as mock_ah_allocate: mock_ah_allocate.return_value = self.fake_attach_handle objects.AttachHandle.allocate(self.context, deployable_id) mock_from_db_obj.assert_called_once_with( - mock.ANY, - self.fake_attach_handle) + mock.ANY, self.fake_attach_handle + ) def test_attach_handle_deallocate(self): attach_handle_uuid = self.fake_attach_handle['uuid'] - with mock.patch.object(self.dbapi, 'attach_handle_update', - autospec=True) as mock_ah_update: + with mock.patch.object( + self.dbapi, 'attach_handle_update', autospec=True + ) as mock_ah_update: ah_obj = objects.AttachHandle(**self.fake_attach_handle) ah_obj.deallocate(self.context) mock_ah_update.assert_called_once_with( - self.context, - attach_handle_uuid, - {"in_use": False}) + self.context, attach_handle_uuid, {"in_use": False} + ) def test_create(self): - with mock.patch.object(self.dbapi, 'attach_handle_create', - autospec=True) as mock_attach_handle_create: + with mock.patch.object( + self.dbapi, 'attach_handle_create', autospec=True + ) as mock_attach_handle_create: mock_attach_handle_create.return_value = self.fake_attach_handle - attach_handle = objects.AttachHandle(self.context, - **self.fake_attach_handle) + attach_handle = objects.AttachHandle( + self.context, **self.fake_attach_handle + ) attach_handle.create(self.context) mock_attach_handle_create.assert_called_once_with( - self.context, self.fake_attach_handle) + self.context, self.fake_attach_handle + ) self.assertEqual(self.context, attach_handle._context) def test_destroy(self): uuid = self.fake_attach_handle['uuid'] - with mock.patch.object(self.dbapi, 'attach_handle_get_by_uuid', - autospec=True) as mock_attach_handle_get: + with mock.patch.object( + self.dbapi, 'attach_handle_get_by_uuid', autospec=True + ) as mock_attach_handle_get: mock_attach_handle_get.return_value = self.fake_attach_handle - with mock.patch.object(self.dbapi, 'attach_handle_delete', - autospec=True) as mock_attach_handle_delete: + with mock.patch.object( + self.dbapi, 'attach_handle_delete', autospec=True + ) as mock_attach_handle_delete: attach_handle = objects.AttachHandle.get(self.context, uuid) attach_handle.destroy(self.context) - mock_attach_handle_delete.assert_called_once_with(self.context, - uuid) + mock_attach_handle_delete.assert_called_once_with( + self.context, uuid + ) self.assertEqual(self.context, attach_handle._context) def test_update(self): uuid = self.fake_attach_handle['uuid'] - with mock.patch.object(self.dbapi, 'attach_handle_get_by_uuid', - autospec=True) as mock_attach_handle_get: + with mock.patch.object( + self.dbapi, 'attach_handle_get_by_uuid', autospec=True + ) as mock_attach_handle_get: mock_attach_handle_get.return_value = self.fake_attach_handle - with mock.patch.object(self.dbapi, 'attach_handle_update', - autospec=True) as mock_attach_handle_update: + with mock.patch.object( + self.dbapi, 'attach_handle_update', autospec=True + ) as mock_attach_handle_update: fake = self.fake_attach_handle fake["attach_info"] = "new_attach_info" mock_attach_handle_update.return_value = fake attach_handle = objects.AttachHandle.get(self.context, uuid) attach_handle.attach_info = 'new_attach_info' attach_handle.save(self.context) - mock_attach_handle_get.assert_called_once_with(self.context, - uuid) + mock_attach_handle_get.assert_called_once_with( + self.context, uuid + ) mock_attach_handle_update.assert_called_once_with( - self.context, uuid, - {'attach_info': 'new_attach_info'}) + self.context, uuid, {'attach_info': 'new_attach_info'} + ) self.assertEqual(self.context, attach_handle._context) diff --git a/cyborg/tests/unit/objects/test_control_path.py b/cyborg/tests/unit/objects/test_control_path.py index 30e4541a..7fb3edf9 100644 --- a/cyborg/tests/unit/objects/test_control_path.py +++ b/cyborg/tests/unit/objects/test_control_path.py @@ -22,15 +22,15 @@ from oslo_serialization import jsonutils class TestControlpathIDObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_control_path = utils.get_test_control_path() def test_get(self): uuid = self.fake_control_path['uuid'] - with mock.patch.object(self.dbapi, 'control_path_get_by_uuid', - autospec=True) as mock_control_path_get: + with mock.patch.object( + self.dbapi, 'control_path_get_by_uuid', autospec=True + ) as mock_control_path_get: mock_control_path_get.return_value = self.fake_control_path control_path = objects.ControlpathID.get(self.context, uuid) mock_control_path_get.assert_called_once_with(self.context, uuid) @@ -38,21 +38,27 @@ class TestControlpathIDObject(base.DbTestCase): def test_get_set_cpid_info_using_obj(self): uuid = self.fake_control_path['uuid'] - with mock.patch.object(self.dbapi, 'control_path_get_by_uuid', - autospec=True) as mock_control_path_get: + with mock.patch.object( + self.dbapi, 'control_path_get_by_uuid', autospec=True + ) as mock_control_path_get: mock_control_path_get.return_value = self.fake_control_path # test cpid_info_obj loader control_path = objects.ControlpathID.get(self.context, uuid) - self.assertEqual(jsonutils.loads(control_path.cpid_info), - control_path.cpid_info_obj) + self.assertEqual( + jsonutils.loads(control_path.cpid_info), + control_path.cpid_info_obj, + ) # test cpid_info_obj setter control_path.cpid_info_obj = {'bus': "fake"} - self.assertEqual(control_path.cpid_info, - jsonutils.dumps(control_path.cpid_info_obj)) + self.assertEqual( + control_path.cpid_info, + jsonutils.dumps(control_path.cpid_info_obj), + ) def test_list(self): - with mock.patch.object(self.dbapi, 'control_path_list', - autospec=True) as mock_control_path_list: + with mock.patch.object( + self.dbapi, 'control_path_list', autospec=True + ) as mock_control_path_list: mock_control_path_list.return_value = [self.fake_control_path] control_paths = objects.ControlpathID.list(self.context) self.assertEqual(1, mock_control_path_list.call_count) @@ -61,45 +67,54 @@ class TestControlpathIDObject(base.DbTestCase): self.assertEqual(self.context, control_paths[0]._context) def test_create(self): - with mock.patch.object(self.dbapi, 'control_path_create', - autospec=True) as mock_control_path_create: + with mock.patch.object( + self.dbapi, 'control_path_create', autospec=True + ) as mock_control_path_create: mock_control_path_create.return_value = self.fake_control_path - control_path = objects.ControlpathID(self.context, - **self.fake_control_path) + control_path = objects.ControlpathID( + self.context, **self.fake_control_path + ) control_path.create(self.context) mock_control_path_create.assert_called_once_with( - self.context, self.fake_control_path) + self.context, self.fake_control_path + ) self.assertEqual(self.context, control_path._context) def test_destroy(self): uuid = self.fake_control_path['uuid'] - with mock.patch.object(self.dbapi, 'control_path_get_by_uuid', - autospec=True) as mock_control_path_get: + with mock.patch.object( + self.dbapi, 'control_path_get_by_uuid', autospec=True + ) as mock_control_path_get: mock_control_path_get.return_value = self.fake_control_path - with mock.patch.object(self.dbapi, 'control_path_delete', - autospec=True) as mock_control_path_delete: + with mock.patch.object( + self.dbapi, 'control_path_delete', autospec=True + ) as mock_control_path_delete: control_path = objects.ControlpathID.get(self.context, uuid) control_path.destroy(self.context) - mock_control_path_delete.assert_called_once_with(self.context, - uuid) + mock_control_path_delete.assert_called_once_with( + self.context, uuid + ) self.assertEqual(self.context, control_path._context) def test_update(self): uuid = self.fake_control_path['uuid'] - with mock.patch.object(self.dbapi, 'control_path_get_by_uuid', - autospec=True) as mock_control_path_get: + with mock.patch.object( + self.dbapi, 'control_path_get_by_uuid', autospec=True + ) as mock_control_path_get: mock_control_path_get.return_value = self.fake_control_path - with mock.patch.object(self.dbapi, 'control_path_update', - autospec=True) as mock_control_path_update: + with mock.patch.object( + self.dbapi, 'control_path_update', autospec=True + ) as mock_control_path_update: fake = self.fake_control_path fake["cpid_info"] = "new_cpid_info" mock_control_path_update.return_value = fake control_path = objects.ControlpathID.get(self.context, uuid) control_path.cpid_info = 'new_cpid_info' control_path.save(self.context) - mock_control_path_get.assert_called_once_with(self.context, - uuid) + mock_control_path_get.assert_called_once_with( + self.context, uuid + ) mock_control_path_update.assert_called_once_with( - self.context, uuid, - {'cpid_info': 'new_cpid_info'}) + self.context, uuid, {'cpid_info': 'new_cpid_info'} + ) self.assertEqual(self.context, control_path._context) diff --git a/cyborg/tests/unit/objects/test_deployable.py b/cyborg/tests/unit/objects/test_deployable.py index 2b314823..3292e6d0 100644 --- a/cyborg/tests/unit/objects/test_deployable.py +++ b/cyborg/tests/unit/objects/test_deployable.py @@ -26,7 +26,6 @@ from cyborg.tests.unit.objects import test_objects class TestDeployableObject(DbTestCase): - @property def fake_device(self): db_device = fake_device.get_fake_devices_as_dict()[0] @@ -44,13 +43,11 @@ class TestDeployableObject(DbTestCase): def test_create(self): db_device = self.fake_device - device = objects.Device(context=self.context, - **db_device) + device = objects.Device(context=self.context, **db_device) device.create(self.context) device_get = objects.Device.get(self.context, device.uuid) db_dpl = self.fake_deployable - dpl = objects.Deployable(context=self.context, - **db_dpl) + dpl = objects.Deployable(context=self.context, **db_dpl) dpl.device_id = device_get.id dpl.create(self.context) @@ -59,13 +56,11 @@ class TestDeployableObject(DbTestCase): def test_get(self): db_device = self.fake_device - device = objects.Device(context=self.context, - **db_device) + device = objects.Device(context=self.context, **db_device) device.create(self.context) device_get = objects.Device.get(self.context, device.uuid) db_dpl = self.fake_deployable - dpl = objects.Deployable(context=self.context, - **db_dpl) + dpl = objects.Deployable(context=self.context, **db_dpl) dpl.device_id = device_get.id dpl.create(self.context) @@ -74,13 +69,11 @@ class TestDeployableObject(DbTestCase): def test_get_by_filter(self): db_device = self.fake_device - device = objects.Device(context=self.context, - **db_device) + device = objects.Device(context=self.context, **db_device) device.create(self.context) device_get = objects.Device.get(self.context, device.uuid) db_dpl = self.fake_deployable - dpl = objects.Deployable(context=self.context, - **db_dpl) + dpl = objects.Deployable(context=self.context, **db_dpl) dpl.device_id = device_get.id dpl.create(self.context) @@ -91,13 +84,11 @@ class TestDeployableObject(DbTestCase): def test_save(self): db_device = self.fake_device - device = objects.Device(context=self.context, - **db_device) + device = objects.Device(context=self.context, **db_device) device.create(self.context) device_get = objects.Device.get(self.context, device.uuid) db_dpl = self.fake_deployable - dpl = objects.Deployable(context=self.context, - **db_dpl) + dpl = objects.Deployable(context=self.context, **db_dpl) dpl.device_id = device_get.id dpl.create(self.context) @@ -108,35 +99,40 @@ class TestDeployableObject(DbTestCase): def test_destroy(self): db_device = self.fake_device - device = objects.Device(context=self.context, - **db_device) + device = objects.Device(context=self.context, **db_device) device.create(self.context) device_get = objects.Device.get(self.context, device.uuid) db_dpl = self.fake_deployable - dpl = objects.Deployable(context=self.context, - **db_dpl) + dpl = objects.Deployable(context=self.context, **db_dpl) dpl.device_id = device_get.id dpl.create(self.context) self.assertEqual(db_dpl['uuid'], dpl.uuid) dpl.destroy(self.context) - self.assertRaises(exception.ResourceNotFound, - objects.Deployable.get, self.context, - dpl.uuid) + self.assertRaises( + exception.ResourceNotFound, + objects.Deployable.get, + self.context, + dpl.uuid, + ) -class TestDeployableObject(test_objects._LocalTest, - TestDeployableObject): - def _test_save_objectfield_fk_constraint_fails(self, foreign_key, - expected_exception): - - error = db_exc.DBReferenceError('table', 'constraint', foreign_key, - 'key_table') +class TestDeployableObject(test_objects._LocalTest, TestDeployableObject): + def _test_save_objectfield_fk_constraint_fails( + self, foreign_key, expected_exception + ): + error = db_exc.DBReferenceError( + 'table', 'constraint', foreign_key, 'key_table' + ) # Prevent lazy-loading any fields, results in InstanceNotFound deployable = fake_deployable.fake_deployable_obj(self.context) - fields_with_save_methods = [field for field in deployable.fields - if hasattr(deployable, '_save_%s' % field)] + fields_with_save_methods = [ + field + for field in deployable.fields + if hasattr(deployable, '_save_%s' % field) + ] for field in fields_with_save_methods: + @mock.patch.object(deployable, '_save_%s' % field) @mock.patch.object(deployable, 'obj_attr_is_set') def _test(mock_is_set, mock_save_field): @@ -146,4 +142,5 @@ class TestDeployableObject(test_objects._LocalTest, deployable._changed_fields.add(field) self.assertRaises(expected_exception, deployable.save) deployable.obj_reset_changes(fields=[field]) + _test() diff --git a/cyborg/tests/unit/objects/test_device.py b/cyborg/tests/unit/objects/test_device.py index 30055f5a..a72bb6f8 100644 --- a/cyborg/tests/unit/objects/test_device.py +++ b/cyborg/tests/unit/objects/test_device.py @@ -21,15 +21,15 @@ from cyborg.tests.unit import fake_device class TestDeviceObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_device = fake_device.get_db_devices()[0] def test_get(self): uuid = self.fake_device['uuid'] - with mock.patch.object(self.dbapi, 'device_get', - autospec=True) as mock_device_get: + with mock.patch.object( + self.dbapi, 'device_get', autospec=True + ) as mock_device_get: mock_device_get.return_value = self.fake_device device = objects.Device.get(self.context, uuid) mock_device_get.assert_called_once_with(self.context, uuid) @@ -37,24 +37,30 @@ class TestDeviceObject(base.DbTestCase): def test_get_by_id(self): device_id = self.fake_device['id'] - with mock.patch.object(self.dbapi, 'device_get_by_id', - autospec=True) as mock_device_get_by_id: + with mock.patch.object( + self.dbapi, 'device_get_by_id', autospec=True + ) as mock_device_get_by_id: mock_device_get_by_id.return_value = self.fake_device device = objects.Device.get_by_device_id(self.context, device_id) mock_device_get_by_id.assert_called_once_with( - self.context, device_id) + self.context, device_id + ) self.assertEqual(self.context, device._context) def test_get_by_non_existed_id(self): device_id = self.fake_device['id'] - with mock.patch.object(self.dbapi, 'device_get_by_id', - autospec=True) as mock_device_get_by_id: + with mock.patch.object( + self.dbapi, 'device_get_by_id', autospec=True + ) as mock_device_get_by_id: mock_device_get_by_id.side_effect = exception.ResourceNotFound( - resource='Device', msg='with uuid=%s' % device_id) - self.assertRaises(exception.ResourceNotFound, - objects.Device.get_by_device_id, - self.context, - device_id) + resource='Device', msg='with uuid=%s' % device_id + ) + self.assertRaises( + exception.ResourceNotFound, + objects.Device.get_by_device_id, + self.context, + device_id, + ) @mock.patch.object(objects.Device, 'list') def test_get_by_hostname(self, mock_list): @@ -74,8 +80,9 @@ class TestDeviceObject(base.DbTestCase): self.assertEqual([], devices) def test_list(self): - with mock.patch.object(self.dbapi, 'device_list', - autospec=True) as mock_device_list: + with mock.patch.object( + self.dbapi, 'device_list', autospec=True + ) as mock_device_list: mock_device_list.return_value = [self.fake_device] devices = objects.Device.list(self.context) self.assertEqual(1, mock_device_list.call_count) @@ -84,8 +91,9 @@ class TestDeviceObject(base.DbTestCase): self.assertEqual(self.context, devices[0]._context) def test_list_with_filter(self): - with mock.patch.object(self.dbapi, 'device_list_by_filters', - autospec=True) as mock_device_with_filter_list: + with mock.patch.object( + self.dbapi, 'device_list_by_filters', autospec=True + ) as mock_device_with_filter_list: mock_device_with_filter_list.return_value = [self.fake_device] filters = {'limit': 1} devices = objects.Device.list(self.context, filters) @@ -99,50 +107,55 @@ class TestDeviceObject(base.DbTestCase): sort_key='created_at', limit=1, marker=None, - ) + ) def test_create(self): - with mock.patch.object(self.dbapi, 'device_create', - autospec=True) as mock_device_create: + with mock.patch.object( + self.dbapi, 'device_create', autospec=True + ) as mock_device_create: mock_device_create.return_value = self.fake_device - device = objects.Device(self.context, - **self.fake_device) + device = objects.Device(self.context, **self.fake_device) device.create(self.context) mock_device_create.assert_called_once_with( - self.context, self.fake_device) + self.context, self.fake_device + ) self.assertEqual(self.context, device._context) def test_destroy(self): uuid = self.fake_device['uuid'] - with mock.patch.object(self.dbapi, 'device_get', - autospec=True) as mock_device_get: + with mock.patch.object( + self.dbapi, 'device_get', autospec=True + ) as mock_device_get: mock_device_get.return_value = self.fake_device - with mock.patch.object(self.dbapi, 'device_delete', - autospec=True) as mock_device_delete: + with mock.patch.object( + self.dbapi, 'device_delete', autospec=True + ) as mock_device_delete: device = objects.Device.get(self.context, uuid) device.destroy(self.context) - mock_device_delete.assert_called_once_with(self.context, - uuid) + mock_device_delete.assert_called_once_with(self.context, uuid) self.assertEqual(self.context, device._context) def test_update(self): uuid = self.fake_device['uuid'] - with mock.patch.object(self.dbapi, 'device_get', - autospec=True) as mock_device_get: + with mock.patch.object( + self.dbapi, 'device_get', autospec=True + ) as mock_device_get: mock_device_get.return_value = self.fake_device - with mock.patch.object(self.dbapi, 'device_update', - autospec=True) as mock_device_update: + with mock.patch.object( + self.dbapi, 'device_update', autospec=True + ) as mock_device_update: fake = self.fake_device fake["vendor_board_info"] = "new_vendor_board_info" mock_device_update.return_value = fake device = objects.Device.get(self.context, uuid) device.vendor_board_info = 'new_vendor_board_info' device.save(self.context) - mock_device_get.assert_called_once_with(self.context, - uuid) + mock_device_get.assert_called_once_with(self.context, uuid) mock_device_update.assert_called_once_with( - self.context, uuid, - {'vendor_board_info': 'new_vendor_board_info'}) + self.context, + uuid, + {'vendor_board_info': 'new_vendor_board_info'}, + ) self.assertEqual(self.context, device._context) def test_device_type(self): @@ -150,5 +163,6 @@ class TestDeviceObject(base.DbTestCase): device = objects.Device(self.context, type=t) self.assertEqual(self.context, device._context) # Invalid type will raise ValueError - self.assertRaises(ValueError, objects.Device, - self.context, type='OTHER_TYPE') + self.assertRaises( + ValueError, objects.Device, self.context, type='OTHER_TYPE' + ) diff --git a/cyborg/tests/unit/objects/test_device_profile.py b/cyborg/tests/unit/objects/test_device_profile.py index 6bcbe075..a20e43c5 100644 --- a/cyborg/tests/unit/objects/test_device_profile.py +++ b/cyborg/tests/unit/objects/test_device_profile.py @@ -21,15 +21,15 @@ from cyborg.tests.unit import fake_device_profile class TestDeviceProfileObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_device_profile = utils.get_test_device_profile() def test_get_by_name(self): name = self.fake_device_profile['name'] - with mock.patch.object(self.dbapi, 'device_profile_get', - autospec=True) as mock_db_devprof_get: + with mock.patch.object( + self.dbapi, 'device_profile_get', autospec=True + ) as mock_db_devprof_get: mock_db_devprof_get.return_value = self.fake_device_profile obj_devprof = objects.DeviceProfile.get_by_name(self.context, name) mock_db_devprof_get.assert_called_once_with(self.context, name) @@ -39,8 +39,9 @@ class TestDeviceProfileObject(base.DbTestCase): def test_get_by_uuid(self): uuid = self.fake_device_profile['uuid'] - with mock.patch.object(self.dbapi, 'device_profile_get_by_uuid', - autospec=True) as mock_db_devprof_get: + with mock.patch.object( + self.dbapi, 'device_profile_get_by_uuid', autospec=True + ) as mock_db_devprof_get: mock_db_devprof_get.return_value = self.fake_device_profile obj_devprof = objects.DeviceProfile.get_by_uuid(self.context, uuid) mock_db_devprof_get.assert_called_once_with(self.context, uuid) @@ -49,42 +50,51 @@ class TestDeviceProfileObject(base.DbTestCase): self.assertIn('description', obj_devprof) def test_list(self): - with mock.patch.object(self.dbapi, 'device_profile_list', - autospec=True) as mock_db_devprof_list: + with mock.patch.object( + self.dbapi, 'device_profile_list', autospec=True + ) as mock_db_devprof_list: mock_db_devprof_list.return_value = [self.fake_device_profile] obj_devprofs = objects.DeviceProfile.list(self.context) self.assertEqual(1, mock_db_devprof_list.call_count) self.assertEqual(1, len(obj_devprofs)) self.assertIsInstance(obj_devprofs[0], objects.DeviceProfile) self.assertEqual(self.context, obj_devprofs[0]._context) - self.assertEqual(self.fake_device_profile['name'], - obj_devprofs[0].name) - self.assertEqual(self.fake_device_profile['description'], - obj_devprofs[0].description) + self.assertEqual( + self.fake_device_profile['name'], obj_devprofs[0].name + ) + self.assertEqual( + self.fake_device_profile['description'], + obj_devprofs[0].description, + ) def test_create(self): api_devprofs = fake_device_profile.get_api_devprofs() api_devprof = api_devprofs[0] db_devprofs = fake_device_profile.get_db_devprofs() db_devprof = db_devprofs[0] - with mock.patch.object(self.dbapi, 'device_profile_create', - autospec=True) as mock_db_devprof_create: + with mock.patch.object( + self.dbapi, 'device_profile_create', autospec=True + ) as mock_db_devprof_create: mock_db_devprof_create.return_value = self.fake_device_profile obj_devprof = objects.DeviceProfile(**api_devprof) obj_devprof.create(self.context) mock_db_devprof_create.assert_called_once_with( - self.context, db_devprof) + self.context, db_devprof + ) def test_destroy(self): uuid = self.fake_device_profile['uuid'] - with mock.patch.object(self.dbapi, 'device_profile_get_by_uuid', - autospec=True) as mock_dp_get: + with mock.patch.object( + self.dbapi, 'device_profile_get_by_uuid', autospec=True + ) as mock_dp_get: mock_dp_get.return_value = self.fake_device_profile - with mock.patch.object(self.dbapi, 'device_profile_delete', - autospec=True) as m_dp_delete: + with mock.patch.object( + self.dbapi, 'device_profile_delete', autospec=True + ) as m_dp_delete: m_dp_delete.return_value = None obj_devprof = objects.DeviceProfile.get_by_uuid( - self.context, uuid) + self.context, uuid + ) obj_devprof.destroy(self.context) m_dp_delete.assert_called_once_with(self.context, uuid) self.assertEqual(self.context, obj_devprof._context) @@ -95,8 +105,9 @@ class TestDeviceProfileObject(base.DbTestCase): db_devprof = fake_db_devprofs[0] db_devprof['created_at'] = None db_devprof['updated_at'] = None - with mock.patch.object(self.dbapi, 'device_profile_get_by_uuid', - autospec=True) as mock_dp_get: + with mock.patch.object( + self.dbapi, 'device_profile_get_by_uuid', autospec=True + ) as mock_dp_get: mock_dp_get.return_value = db_devprof uuid = fake_db_devprofs[0]['uuid'] # Start with db_devprofs[0], corr. to fake_obj_devprofs[0] @@ -104,8 +115,9 @@ class TestDeviceProfileObject(base.DbTestCase): # Change contents to fake_obj_devprofs[1] except uuid obj_devprof = fake_obj_devprofs[1] obj_devprof['uuid'] = uuid - with mock.patch.object(self.dbapi, 'device_profile_update', - autospec=True) as mock_dp_update: + with mock.patch.object( + self.dbapi, 'device_profile_update', autospec=True + ) as mock_dp_update: mock_dp_update.return_value = db_devprof obj_devprof.save(self.context) mock_dp_get.assert_called_once_with(self.context, uuid) diff --git a/cyborg/tests/unit/objects/test_ext_arq_job.py b/cyborg/tests/unit/objects/test_ext_arq_job.py index a461a5e8..35d0a8c0 100644 --- a/cyborg/tests/unit/objects/test_ext_arq_job.py +++ b/cyborg/tests/unit/objects/test_ext_arq_job.py @@ -30,39 +30,49 @@ from cyborg.tests.unit import fake_extarq class TestExtARQJobMixin(base.DbTestCase): - def setUp(self): super().setUp() self.fake_db_extarqs = fake_extarq.get_fake_db_extarqs() self.fake_obj_extarqs = fake_extarq.get_fake_extarq_objs() self.fake_obj_fpga_extarqs = fake_extarq.get_fake_fpga_extarq_objs() self.deployable_uuids = ['0acbf8d6-e02a-4394-aae3-57557d209498'] - self.classes = ["gpu", "no_program", "bitstream_program", - "function_program", "bad_program"] - self.class_objects = dict( - zip(self.classes, self.fake_obj_extarqs)) - self.class_dbs = dict( - zip(self.classes, self.fake_db_extarqs)) - self.fpga_classes = ["no_program", "bitstream_program", - "function_program", "bad_program"] + self.classes = [ + "gpu", + "no_program", + "bitstream_program", + "function_program", + "bad_program", + ] + self.class_objects = dict(zip(self.classes, self.fake_obj_extarqs)) + self.class_dbs = dict(zip(self.classes, self.fake_db_extarqs)) + self.fpga_classes = [ + "no_program", + "bitstream_program", + "function_program", + "bad_program", + ] self.fpga_class_objects = dict( - zip(self.fpga_classes, self.fake_obj_fpga_extarqs)) + zip(self.fpga_classes, self.fake_obj_fpga_extarqs) + ) self.bitstream_id = self.class_objects["bitstream_program"][ - "device_profile_group"][constants.ACCEL_BITSTREAM_ID] + "device_profile_group" + ][constants.ACCEL_BITSTREAM_ID] self.function_id = self.class_objects["function_program"][ - "device_profile_group"][constants.ACCEL_FUNCTION_ID] + "device_profile_group" + ][constants.ACCEL_FUNCTION_ID] def test_get_resources_from_device_profile_group(self): expect = [("GPU")] + [("FPGA")] * 4 - actual = [v.get_resources_from_device_profile_group() - for v in self.class_objects.values()] + actual = [ + v.get_resources_from_device_profile_group() + for v in self.class_objects.values() + ] self.assertEqual(expect, actual) def test_get_suitable_ext_arq(self): expect_type = [objects.ExtARQ] + [objects.FPGAExtARQ] * 4 uuid = uuidutils.generate_uuid() - groups = [v['device_profile_group'] - for v in self.fake_db_extarqs] + groups = [v['device_profile_group'] for v in self.fake_db_extarqs] dp_db = { "id": self.fake_db_extarqs[0]['device_profile_id'], "uuid": uuid, @@ -73,23 +83,26 @@ class TestExtARQJobMixin(base.DbTestCase): "updated_at": timeutils.utcnow().isoformat(), } for i, v in enumerate(self.classes): - with mock.patch.object(self.dbapi, 'extarq_get') \ - as mock_extarq_get, \ - mock.patch.object(self.dbapi, - "device_profile_get_by_id") \ - as mock_dp_get, \ - mock.patch.object(self.dbapi, - "attach_handle_get_by_id") \ - as mock_ah_get: + with ( + mock.patch.object(self.dbapi, 'extarq_get') as mock_extarq_get, + mock.patch.object( + self.dbapi, "device_profile_get_by_id" + ) as mock_dp_get, + mock.patch.object( + self.dbapi, "attach_handle_get_by_id" + ) as mock_ah_get, + ): mock_ah_get.return_value = None mock_dp_get.return_value = dp_db - self.fake_db_extarqs[i].update({ - "state": constants.ARQ_BIND_STARTED, - "substate": constants.ARQ_BIND_STARTED, - "attach_handle_id": None, - "created_at": timeutils.utcnow().isoformat(), - "updated_at": timeutils.utcnow().isoformat() - }) + self.fake_db_extarqs[i].update( + { + "state": constants.ARQ_BIND_STARTED, + "substate": constants.ARQ_BIND_STARTED, + "attach_handle_id": None, + "created_at": timeutils.utcnow().isoformat(), + "updated_at": timeutils.utcnow().isoformat(), + } + ) self.fake_db_extarqs[i]["deployable_id"] = None mock_extarq_get.return_value = self.fake_db_extarqs[i] obj = self.class_objects[v] @@ -102,14 +115,19 @@ class TestExtARQJobMixin(base.DbTestCase): instance_uuid = obj_extarq.arq.instance_uuid uuid = obj_extarq.arq.uuid valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } self.assertRaises( - exception.ARQBadState, obj_extarq.start_bind_job, - self.context, valid_fields) + exception.ARQBadState, + obj_extarq.start_bind_job, + self.context, + valid_fields, + ) @mock.patch('cyborg.objects.extarq.ext_arq_job.ExtARQJobMixin._bind_job') @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid') @@ -120,18 +138,22 @@ class TestExtARQJobMixin(base.DbTestCase): instance_uuid = obj_extarq.arq.instance_uuid uuid = obj_extarq.arq.uuid dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) mock_get_dep.return_value = fake_dep valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } obj_extarq.start_bind_job(self.context, valid_fields) mock_state.assert_called_once_with( - self.context, constants.ARQ_BIND_STARTED) + self.context, constants.ARQ_BIND_STARTED + ) mock_job.assert_called_once_with(self.context, fake_dep) @mock.patch('cyborg.objects.FPGAExtARQ.bind') @@ -140,8 +162,9 @@ class TestExtARQJobMixin(base.DbTestCase): obj_extarq = self.class_objects["gpu"] obj_extarq.arq.state = constants.ARQ_UNBOUND dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) is_job = getattr(obj_extarq.bind, "is_job", False) with mock.patch.object(obj_extarq, 'bind') as mock_bind: mock_bind.is_job = is_job @@ -155,8 +178,9 @@ class TestExtARQJobMixin(base.DbTestCase): obj_extarq = self.fpga_class_objects["no_program"] obj_extarq.arq.state = constants.ARQ_UNBOUND dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) need_bind = getattr(obj_extarq.bind, "is_job", False) with mock.patch.object(obj_extarq, 'bind') as mock_aysnc_bind: mock_aysnc_bind.is_job = need_bind @@ -171,14 +195,16 @@ class TestExtARQJobMixin(base.DbTestCase): obj_extarq = self.fpga_class_objects["bitstream_program"] obj_extarq.arq.state = constants.ARQ_UNBOUND dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) need_bind = getattr(obj_extarq.bind, "is_job", False) with mock.patch.object(obj_extarq, 'bind') as mock_aysnc_bind: mock_aysnc_bind.is_job = need_bind obj_extarq._bind_job(self.context, fake_dep) mock_spawn.assert_called_once_with( - mock_aysnc_bind, self.context, fake_dep) + mock_aysnc_bind, self.context, fake_dep + ) mock_bind.assert_not_called() @mock.patch('cyborg.common.utils.ThreadWorks.get_workers_result') @@ -190,13 +216,16 @@ class TestExtARQJobMixin(base.DbTestCase): works = utils.ThreadWorks() dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) arq_job_binds = { - self.class_objects["bitstream_program"]: - works.spawn(job, self.context, fake_dep), - self.class_objects["function_program"]: - works.spawn(job, self.context, fake_dep) + self.class_objects["bitstream_program"]: works.spawn( + job, self.context, fake_dep + ), + self.class_objects["function_program"]: works.spawn( + job, self.context, fake_dep + ), } arq_binds = { self.class_objects["gpu"]: None, @@ -221,30 +250,38 @@ class TestExtARQJobMixin(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.bind_notify') @mock.patch('cyborg.objects.ExtARQ.list') def test_check_bindings_result_with_arq_deleted( - self, mock_list, mock_notify): + self, mock_list, mock_notify + ): bind_status = [ (e.arq.uuid, constants.ARQ_BIND_STATES_STATUS_MAP[e.arq.state]) - for e in self.fake_obj_extarqs] + for e in self.fake_obj_extarqs + ] extarqs = self.fake_obj_extarqs mock_list.return_value = self.fake_obj_fpga_extarqs instance_uuid = extarqs[0].arq.instance_uuid - objects.ExtARQ.check_bindings_result( - self.context, extarqs) + objects.ExtARQ.check_bindings_result(self.context, extarqs) mock_notify.assert_called_once_with(instance_uuid, bind_status) @mock.patch('cyborg.objects.ExtARQ.bind_notify') @mock.patch('cyborg.objects.ExtARQ.list') def test_check_bindings_result_with_non_bound_arq( - self, mock_list, mock_notify): + self, mock_list, mock_notify + ): bind_status = [ (e.arq.uuid, constants.ARQ_BIND_STATES_STATUS_MAP[e.arq.state]) - for e in self.fake_obj_extarqs] + for e in self.fake_obj_extarqs + ] extarqs = self.fake_obj_extarqs mock_list.return_value = extarqs instance_uuid = extarqs[0].arq.instance_uuid - states = [constants.ARQ_BIND_STARTED, constants.ARQ_BIND_STARTED, - constants.ARQ_UNBOUND, constants.ARQ_BOUND, - constants.ARQ_BIND_FAILED, constants.ARQ_DELETING] + states = [ + constants.ARQ_BIND_STARTED, + constants.ARQ_BIND_STARTED, + constants.ARQ_UNBOUND, + constants.ARQ_BOUND, + constants.ARQ_BIND_FAILED, + constants.ARQ_DELETING, + ] # Completed objects.ExtARQ.check_bindings_result(self.context, extarqs) mock_notify.assert_called_once_with(instance_uuid, bind_status) @@ -252,8 +289,11 @@ class TestExtARQJobMixin(base.DbTestCase): for s in states[:2]: extarqs[0].arq.state = s self.assertRaises( - exception.ARQBadState, objects.ExtARQ.check_bindings_result, - self.context, extarqs) + exception.ARQBadState, + objects.ExtARQ.check_bindings_result, + self.context, + extarqs, + ) # Failed for s in states[4:]: extarqs[0].arq.state = s @@ -269,15 +309,16 @@ class TestExtARQJobMixin(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.bind_notify') @mock.patch('cyborg.objects.ExtARQ.list') def test_check_bindings_result_with_arq_bound( - self, mock_list, mock_notify): + self, mock_list, mock_notify + ): bind_status = [ (e.arq.uuid, constants.ARQ_BIND_STATES_STATUS_MAP[e.arq.state]) - for e in self.fake_obj_extarqs] + for e in self.fake_obj_extarqs + ] extarqs = self.fake_obj_extarqs mock_list.return_value = extarqs instance_uuid = extarqs[0].arq.instance_uuid - objects.ExtARQ.check_bindings_result( - self.context, extarqs) + objects.ExtARQ.check_bindings_result(self.context, extarqs) mock_notify.assert_called_once_with(instance_uuid, bind_status) @mock.patch('cyborg.objects.ext_arq.ExtARQJobMixin.check_bindings_result') @@ -286,13 +327,17 @@ class TestExtARQJobMixin(base.DbTestCase): err_job = works.spawn(lambda x: x / 0, 1) good_job = works.spawn(lambda x: x, 1) works_generator = works.get_workers_result( - [err_job, good_job], timeout=CONF.bind_timeout) + [err_job, good_job], timeout=CONF.bind_timeout + ) - extarqs = [self.class_objects["bitstream_program"], - self.class_objects["function_program"]] + extarqs = [ + self.class_objects["bitstream_program"], + self.class_objects["function_program"], + ] objects.ext_arq.ExtARQ.job_monitor( - self.context, works_generator, extarqs) + self.context, works_generator, extarqs + ) mock_result.assert_called_once_with(self.context, extarqs) @mock.patch('cyborg.objects.ext_arq.ExtARQJobMixin.check_bindings_result') @@ -301,26 +346,34 @@ class TestExtARQJobMixin(base.DbTestCase): job1 = works.spawn(lambda x: x, 1) job2 = works.spawn(lambda x: x, 2) works_generator = works.get_workers_result( - [job1, job2], timeout=CONF.bind_timeout) + [job1, job2], timeout=CONF.bind_timeout + ) - extarqs = [self.class_objects["bitstream_program"], - self.class_objects["function_program"]] + extarqs = [ + self.class_objects["bitstream_program"], + self.class_objects["function_program"], + ] objects.ext_arq.ExtARQ.job_monitor( - self.context, works_generator, extarqs) + self.context, works_generator, extarqs + ) mock_result.assert_called_once_with(self.context, extarqs) @mock.patch('cyborg.objects.ext_arq.ExtARQJobMixin.check_bindings_result') def test_job_monitor_without_jobs(self, mock_result): works = utils.ThreadWorks() works_generator = works.get_workers_result( - [], timeout=CONF.bind_timeout) + [], timeout=CONF.bind_timeout + ) - extarqs = [self.class_objects["bitstream_program"], - self.class_objects["function_program"]] + extarqs = [ + self.class_objects["bitstream_program"], + self.class_objects["function_program"], + ] objects.ext_arq.ExtARQ.job_monitor( - self.context, works_generator, extarqs) + self.context, works_generator, extarqs + ) mock_result.assert_called_once_with(self.context, extarqs) @mock.patch('cyborg.objects.ext_arq.ExtARQJobMixin.check_bindings_result') @@ -328,44 +381,53 @@ class TestExtARQJobMixin(base.DbTestCase): works = utils.ThreadWorks() good_job = works.spawn(lambda x: x, 1) works_generator = works.get_workers_result( - [good_job], timeout=CONF.bind_timeout) + [good_job], timeout=CONF.bind_timeout + ) extarqs = [] objects.ext_arq.ExtARQ.job_monitor( - self.context, works_generator, extarqs) + self.context, works_generator, extarqs + ) mock_result.assert_not_called() @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding') @mock.patch('cyborg.common.nova_client.NovaAPI') def test_bind_notify(self, mock_api, mock_notify): mock_api.return_value = type( - "NovaAPI", (object,), {"notify_binding": mock_notify}) + "NovaAPI", (object,), {"notify_binding": mock_notify} + ) objects.ext_arq.ExtARQJobMixin.bind_notify( '5922a70f-1e06-4cfd-88dd-a332120d7144', - [('a097fefa-da62-4630-8e8b-424c0e3426dc', 'completed')]) + [('a097fefa-da62-4630-8e8b-424c0e3426dc', 'completed')], + ) mock_api.assert_called_once_with() mock_notify.assert_called_once_with( '5922a70f-1e06-4cfd-88dd-a332120d7144', - [('a097fefa-da62-4630-8e8b-424c0e3426dc', 'completed')]) + [('a097fefa-da62-4630-8e8b-424c0e3426dc', 'completed')], + ) @mock.patch('cyborg.objects.ExtARQ.unbind') @mock.patch('cyborg.objects.ext_arq.ExtARQJobMixin.get_suitable_ext_arq') def test_apply_patch_with_op_remove(self, mock_get, mock_unbind): patch_list = {} - host_binding = {'path': '/hostname', 'op': 'remove', - 'value': 'myhost'} - inst_binding = {'path': '/instance_uuid', 'op': 'remove', - 'value': '5922a70f-1e06-4cfd-88dd-a332120d7144'} + host_binding = {'path': '/hostname', 'op': 'remove', 'value': 'myhost'} + inst_binding = { + 'path': '/instance_uuid', + 'op': 'remove', + 'value': '5922a70f-1e06-4cfd-88dd-a332120d7144', + } device_rp_uuid = 'fb16c293-5739-4c84-8590-926f9ab16669' arp_uuid = 'a097fefa-da62-4630-8e8b-424c0e3426dc' patch_list[arp_uuid] = [host_binding, inst_binding, device_rp_uuid] mock_get.return_value = self.fake_obj_extarqs[0] valid_fields = { - arp_uuid: {'hostname': 'myhost', - 'device_rp_uuid': device_rp_uuid, - 'instance_uuid': '5922a70f-1e06-4cfd-88dd-a332120d7144'} + arp_uuid: { + 'hostname': 'myhost', + 'device_rp_uuid': device_rp_uuid, + 'instance_uuid': '5922a70f-1e06-4cfd-88dd-a332120d7144', + } } - objects.extarq.ext_arq_job.ExtARQJobMixin.apply_patch(self.context, - patch_list, - valid_fields) + objects.extarq.ext_arq_job.ExtARQJobMixin.apply_patch( + self.context, patch_list, valid_fields + ) mock_unbind.assert_called_with(self.context) diff --git a/cyborg/tests/unit/objects/test_extarq.py b/cyborg/tests/unit/objects/test_extarq.py index 0faaa764..5e01e4ce 100644 --- a/cyborg/tests/unit/objects/test_extarq.py +++ b/cyborg/tests/unit/objects/test_extarq.py @@ -28,7 +28,6 @@ from cyborg.tests.unit import fake_extarq class TestExtARQObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_db_extarqs = fake_extarq.get_fake_db_extarqs() @@ -42,8 +41,9 @@ class TestExtARQObject(base.DbTestCase): db_extarq = self.fake_db_extarqs[0] uuid = db_extarq['uuid'] mock_from_db_obj.return_value = self.fake_obj_extarqs[0] - with mock.patch.object(self.dbapi, 'extarq_get', - autospec=True) as mock_extarq_get: + with mock.patch.object( + self.dbapi, 'extarq_get', autospec=True + ) as mock_extarq_get: mock_extarq_get.return_value = db_extarq obj_extarq = objects.ExtARQ.get(self.context, uuid) mock_extarq_get.assert_called_once_with(self.context, uuid) @@ -53,8 +53,9 @@ class TestExtARQObject(base.DbTestCase): def test_list(self, mock_from_db_obj): db_extarq = self.fake_db_extarqs[0] mock_from_db_obj.return_value = self.fake_obj_extarqs[0] - with mock.patch.object(self.dbapi, 'extarq_list', - autospec=True) as mock_get_list: + with mock.patch.object( + self.dbapi, 'extarq_list', autospec=True + ) as mock_get_list: mock_get_list.return_value = [db_extarq] obj_extarqs = objects.ExtARQ.list(self.context) self.assertEqual(1, mock_get_list.call_count) @@ -67,8 +68,9 @@ class TestExtARQObject(base.DbTestCase): def test_create(self, mock_from_db_obj): db_extarq = self.fake_db_extarqs[0] mock_from_db_obj.return_value = self.fake_obj_extarqs[0] - with mock.patch.object(self.dbapi, 'extarq_create', - autospec=True) as mock_extarq_create: + with mock.patch.object( + self.dbapi, 'extarq_create', autospec=True + ) as mock_extarq_create: mock_extarq_create.return_value = db_extarq extarq = objects.ExtARQ(self.context, **db_extarq) extarq.arq = objects.ARQ(self.context, **db_extarq) @@ -80,25 +82,38 @@ class TestExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.bind') @mock.patch('cyborg.objects.ExtARQ.get') def test_apply_patch_to_bad_arq_state( - self, mock_get, mock_bind, mock_notify_bind, mock_conn): + self, mock_get, mock_bind, mock_notify_bind, mock_conn + ): good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[ - constants.ARQ_BIND_STARTED] + constants.ARQ_BIND_STARTED + ] mock_get.return_value = obj_extarq = self.fake_obj_extarqs[0] uuid = obj_extarq.arq.uuid instance_uuid = obj_extarq.arq.instance_uuid valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } patch_list = { str(uuid): [ - {"path": "/hostname", "op": "add", - "value": obj_extarq.arq.hostname}, - {"path": "/device_rp_uuid", "op": "add", - "value": obj_extarq.arq.device_rp_uuid}, - {"path": "/instance_uuid", "op": "add", - "value": instance_uuid} + { + "path": "/hostname", + "op": "add", + "value": obj_extarq.arq.hostname, + }, + { + "path": "/device_rp_uuid", + "op": "add", + "value": obj_extarq.arq.device_rp_uuid, + }, + { + "path": "/instance_uuid", + "op": "add", + "value": instance_uuid, + }, ] } @@ -106,14 +121,20 @@ class TestExtARQObject(base.DbTestCase): obj_extarq.arq.state = state mock_get.return_value = obj_extarq self.assertRaises( - exception.ARQBadState, objects.ExtARQ.apply_patch, - self.context, patch_list, valid_fields) + exception.ARQBadState, + objects.ExtARQ.apply_patch, + self.context, + patch_list, + valid_fields, + ) mock_notify_bind.assert_not_called() @mock.patch('cyborg.objects.ExtARQ.save') - @mock.patch('cyborg.objects.extarq.ext_arq_job.ExtARQJobMixin.' - 'get_arq_bind_statuses') + @mock.patch( + 'cyborg.objects.extarq.ext_arq_job.ExtARQJobMixin.' + 'get_arq_bind_statuses' + ) @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding') @mock.patch('cyborg.objects.ExtARQ._allocate_attach_handle') @@ -122,12 +143,20 @@ class TestExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.update_check_state') @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid') def test_apply_patch_for_common_extarq( - self, mock_get_dep, mock_check_state, mock_list, mock_get, - mock_attach_handle, mock_notify_bind, mock_conn, mock_get_bind_st, - mock_save): - + self, + mock_get_dep, + mock_check_state, + mock_list, + mock_get, + mock_attach_handle, + mock_notify_bind, + mock_conn, + mock_get_bind_st, + mock_save, + ): good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[ - constants.ARQ_BIND_STARTED] + constants.ARQ_BIND_STARTED + ] obj_extarq = self.fake_obj_extarqs[0] obj_extarq.arq.state = good_states[0] @@ -135,7 +164,8 @@ class TestExtARQObject(base.DbTestCase): # remains as 'Initial'. So, we mock get_arq_bind_statuses to # prevent that from raising exception. mock_get_bind_st.return_value = [ - (obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)] + (obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH) + ] # TODO(Shaohe) we should control the state of arq to make # better testcase. @@ -149,22 +179,34 @@ class TestExtARQObject(base.DbTestCase): instance_uuid = obj_extarq.arq.instance_uuid dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) mock_get_dep.return_value = fake_dep valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } patch_list = { str(uuid): [ - {"path": "/hostname", "op": "add", - "value": obj_extarq.arq.hostname}, - {"path": "/device_rp_uuid", "op": "add", - "value": obj_extarq.arq.device_rp_uuid}, - {"path": "/instance_uuid", "op": "add", - "value": instance_uuid} + { + "path": "/hostname", + "op": "add", + "value": obj_extarq.arq.hostname, + }, + { + "path": "/device_rp_uuid", + "op": "add", + "value": obj_extarq.arq.device_rp_uuid, + }, + { + "path": "/instance_uuid", + "op": "add", + "value": instance_uuid, + }, ] } objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields) @@ -174,13 +216,16 @@ class TestExtARQObject(base.DbTestCase): self.assertEqual(obj_extarq.arq.state, 'Initial') mock_notify_bind.assert_called_once_with( instance_uuid, - [(obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)]) + [(obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)], + ) self.assertEqual(obj_extarq.deployable_id, fake_dep.id) mock_save.assert_called_once() - @mock.patch('cyborg.objects.extarq.ext_arq_job.ExtARQJobMixin.' - 'get_arq_bind_statuses') + @mock.patch( + 'cyborg.objects.extarq.ext_arq_job.ExtARQJobMixin.' + 'get_arq_bind_statuses' + ) @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding') @mock.patch('cyborg.objects.ExtARQ._allocate_attach_handle') @@ -190,10 +235,20 @@ class TestExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid') @mock.patch('cyborg.common.utils.ThreadWorks.spawn') def test_apply_patch_start_fpga_arq_job( - self, mock_spawn, mock_get_dep, mock_check_state, mock_list, mock_get, - mock_attach_handle, mock_notify_bind, mock_conn, mock_get_bind_st): + self, + mock_spawn, + mock_get_dep, + mock_check_state, + mock_list, + mock_get, + mock_attach_handle, + mock_notify_bind, + mock_conn, + mock_get_bind_st, + ): good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[ - constants.ARQ_BIND_STARTED] + constants.ARQ_BIND_STARTED + ] obj_extarq = self.fake_obj_extarqs[2] obj_fpga_extarq = self.fake_obj_fpga_extarqs[1] obj_fpga_extarq.state = self.fake_obj_fpga_extarqs[1] @@ -211,25 +266,38 @@ class TestExtARQObject(base.DbTestCase): instance_uuid = obj_extarq.arq.instance_uuid # mock_job_get_ext_arq.side_effect = obj_extarq dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) mock_get_bind_st.return_value = [ - (obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)] + (obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH) + ] mock_get_dep.return_value = fake_dep mock_spawn.return_value = None valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } patch_list = { str(uuid): [ - {"path": "/hostname", "op": "add", - "value": obj_extarq.arq.hostname}, - {"path": "/device_rp_uuid", "op": "add", - "value": obj_extarq.arq.device_rp_uuid}, - {"path": "/instance_uuid", "op": "add", - "value": instance_uuid} + { + "path": "/hostname", + "op": "add", + "value": obj_extarq.arq.hostname, + }, + { + "path": "/device_rp_uuid", + "op": "add", + "value": obj_extarq.arq.device_rp_uuid, + }, + { + "path": "/instance_uuid", + "op": "add", + "value": instance_uuid, + }, ] } objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields) @@ -239,10 +307,12 @@ class TestExtARQObject(base.DbTestCase): self.assertEqual(obj_extarq.arq.state, 'Initial') mock_notify_bind.assert_called_once_with( instance_uuid, - [(obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)]) + [(obj_extarq.arq.uuid, constants.ARQ_BIND_STATUS_FINISH)], + ) # NOTE(Shaohe) check it spawn to start a job. mock_spawn.assert_called_once_with( - obj_fpga_extarq.bind, self.context, fake_dep) + obj_fpga_extarq.bind, self.context, fake_dep + ) @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.common.nova_client.NovaAPI.notify_binding') @@ -253,11 +323,19 @@ class TestExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.deployable.Deployable.get_by_device_rp_uuid') @mock.patch('cyborg.common.utils.ThreadWorks.spawn_master') def test_apply_patch_fpga_arq_monitor_job( - self, mock_master, mock_get_dep, mock_check_state, mock_list, - mock_get, mock_attach_handle, mock_notify_bind, mock_conn): - + self, + mock_master, + mock_get_dep, + mock_check_state, + mock_list, + mock_get, + mock_attach_handle, + mock_notify_bind, + mock_conn, + ): good_states = constants.ARQ_STATES_TRANSFORM_MATRIX[ - constants.ARQ_BIND_STARTED] + constants.ARQ_BIND_STARTED + ] obj_extarq = self.fake_obj_extarqs[2] obj_fpga_extarq = self.fake_obj_fpga_extarqs[1] obj_fpga_extarq.state = self.fake_obj_fpga_extarqs[1] @@ -274,22 +352,34 @@ class TestExtARQObject(base.DbTestCase): uuid = obj_extarq.arq.uuid instance_uuid = obj_extarq.arq.instance_uuid dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) mock_get_dep.return_value = fake_dep valid_fields = { - uuid: {'hostname': obj_extarq.arq.hostname, - 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, - 'instance_uuid': instance_uuid} + uuid: { + 'hostname': obj_extarq.arq.hostname, + 'device_rp_uuid': obj_extarq.arq.device_rp_uuid, + 'instance_uuid': instance_uuid, } + } patch_list = { str(uuid): [ - {"path": "/hostname", "op": "add", - "value": obj_extarq.arq.hostname}, - {"path": "/device_rp_uuid", "op": "add", - "value": obj_extarq.arq.device_rp_uuid}, - {"path": "/instance_uuid", "op": "add", - "value": instance_uuid} + { + "path": "/hostname", + "op": "add", + "value": obj_extarq.arq.hostname, + }, + { + "path": "/device_rp_uuid", + "op": "add", + "value": obj_extarq.arq.device_rp_uuid, + }, + { + "path": "/instance_uuid", + "op": "add", + "value": instance_uuid, + }, ] } objects.ExtARQ.apply_patch(self.context, patch_list, valid_fields) @@ -302,11 +392,13 @@ class TestExtARQObject(base.DbTestCase): uuid = db_extarq['uuid'] mock_from_db_obj.return_value = db_extarq mock_obj_extarq.return_value = self.fake_obj_extarqs[0] - with mock.patch.object(self.dbapi, 'extarq_get', - autospec=True) as mock_extarq_get: + with mock.patch.object( + self.dbapi, 'extarq_get', autospec=True + ) as mock_extarq_get: mock_extarq_get.return_value = db_extarq - with mock.patch.object(self.dbapi, 'extarq_delete', - autospec=True) as mock_extarq_delete: + with mock.patch.object( + self.dbapi, 'extarq_delete', autospec=True + ) as mock_extarq_delete: extarq = objects.ExtARQ.get(self.context, uuid) extarq.destroy(self.context) mock_extarq_delete.assert_called_once_with(self.context, uuid) @@ -315,45 +407,62 @@ class TestExtARQObject(base.DbTestCase): def test_allocate_attach_handle(self, mock_check_state): obj_extarq = self.fake_obj_extarqs[0] dep_uuid = self.deployable_uuids[0] - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) self.assertRaises( exception.ResourceNotFound, - obj_extarq._allocate_attach_handle, self.context, fake_dep) + obj_extarq._allocate_attach_handle, + self.context, + fake_dep, + ) mock_check_state.assert_called_once_with( - self.context, constants.ARQ_BIND_FAILED) + self.context, constants.ARQ_BIND_FAILED + ) @mock.patch('logging.LoggerAdapter.error') @mock.patch('cyborg.objects.attach_handle.AttachHandle.allocate') @mock.patch('cyborg.objects.ExtARQ.update_check_state') def test_allocate_attach_handle_with_error_log( - self, mock_check_state, mock_allocate, mock_log): + self, mock_check_state, mock_allocate, mock_log + ): obj_extarq = self.fake_obj_extarqs[0] dep_uuid = self.deployable_uuids[0] e = exception.ResourceNotFound( - resource='AttachHandle', msg="Just for Test") - msg = ("Failed to allocate attach handle for ARQ %s" - "from deployable %s. Reason: %s") + resource='AttachHandle', msg="Just for Test" + ) + msg = ( + "Failed to allocate attach handle for ARQ %s" + "from deployable %s. Reason: %s" + ) mock_allocate.side_effect = e - fake_dep = fake_deployable.fake_deployable_obj(self.context, - uuid=dep_uuid) + fake_dep = fake_deployable.fake_deployable_obj( + self.context, uuid=dep_uuid + ) self.assertRaises( exception.ResourceNotFound, - obj_extarq._allocate_attach_handle, self.context, fake_dep) + obj_extarq._allocate_attach_handle, + self.context, + fake_dep, + ) mock_log.assert_called_once_with( - msg, obj_extarq.arq.uuid, fake_dep.uuid, str(e)) + msg, obj_extarq.arq.uuid, fake_dep.uuid, str(e) + ) mock_check_state.assert_called_once_with( - self.context, constants.ARQ_BIND_FAILED) + self.context, constants.ARQ_BIND_FAILED + ) @mock.patch('cyborg.objects.ExtARQ.update_check_state') @mock.patch('cyborg.objects.attach_handle.AttachHandle.get_by_id') @mock.patch('cyborg.objects.attach_handle.AttachHandle.deallocate') def test_deallocate_attach_handle( - self, mock_deallocate, mock_ah, mock_check_state): + self, mock_deallocate, mock_ah, mock_check_state + ): obj_extarq = self.fake_obj_extarqs[0] mock_ah.return_value = self.fake_obj_ahs[0] obj_extarq._deallocate_attach_handle( - self.context, mock_ah.id, obj_extarq.arq.hostname) + self.context, mock_ah.id, obj_extarq.arq.hostname + ) mock_check_state.assert_not_called() @mock.patch('logging.LoggerAdapter.error') @@ -361,22 +470,28 @@ class TestExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.attach_handle.AttachHandle.get_by_id') @mock.patch('cyborg.objects.attach_handle.AttachHandle.deallocate') def test_deallocate_attach_handle_with_error_log( - self, mock_ah, mock_deallocate, mock_check_state, mock_log): + self, mock_ah, mock_deallocate, mock_check_state, mock_log + ): obj_extarq = self.fake_obj_extarqs[0] mock_ah.return_value = self.fake_obj_ahs[0] e = exception.ResourceNotFound( - resource='AttachHandle', msg="Just for Test") - msg = ("Failed to deallocate attach handle %s for ARQ %s." - "Reason: %s") + resource='AttachHandle', msg="Just for Test" + ) + msg = "Failed to deallocate attach handle %s for ARQ %s.Reason: %s" mock_deallocate.side_effect = e self.assertRaises( exception.ResourceNotFound, - obj_extarq._deallocate_attach_handle, self.context, mock_ah.id, - obj_extarq.arq.hostname) + obj_extarq._deallocate_attach_handle, + self.context, + mock_ah.id, + obj_extarq.arq.hostname, + ) mock_log.assert_called_once_with( - msg, mock_ah.id, obj_extarq.arq.uuid, str(e)) + msg, mock_ah.id, obj_extarq.arq.uuid, str(e) + ) mock_check_state.assert_called_once_with( - self.context, constants.ARQ_UNBIND_FAILED) + self.context, constants.ARQ_UNBIND_FAILED + ) @mock.patch('cyborg.objects.ExtARQ.get') @mock.patch('cyborg.objects.ExtARQ._from_db_object') @@ -385,8 +500,9 @@ class TestExtARQObject(base.DbTestCase): uuid = db_extarq['uuid'] mock_from_db_obj.return_value = db_extarq mock_obj_extarq.return_value = self.fake_obj_extarqs[0] - with mock.patch.object(self.dbapi, 'extarq_update', - autospec=True) as mock_extarq_update: + with mock.patch.object( + self.dbapi, 'extarq_update', autospec=True + ) as mock_extarq_update: obj_extarq = objects.ExtARQ.get(self.context, uuid) obj_extarq.arq.hostname = 'newtestnode1' fake_arq_updated = db_extarq @@ -398,8 +514,7 @@ class TestExtARQObject(base.DbTestCase): def test_get_arq_bind_statuses(self): # ARQ state is 'Bound' by default in the fake extarqs arq_list = [extarq.arq for extarq in self.fake_obj_extarqs] - bind_status = constants.ARQ_BIND_STATES_STATUS_MAP[ - constants.ARQ_BOUND] + bind_status = constants.ARQ_BIND_STATES_STATUS_MAP[constants.ARQ_BOUND] expected = [(arq.uuid, bind_status) for arq in arq_list] result = objects.ExtARQ.get_arq_bind_statuses(arq_list) @@ -412,17 +527,19 @@ class TestExtARQObject(base.DbTestCase): for arq in arq_list: arq['state'] = constants.ARQ_INITIAL - self.assertRaises(exception.ARQBadState, - objects.ExtARQ.get_arq_bind_statuses, arq_list) + self.assertRaises( + exception.ARQBadState, + objects.ExtARQ.get_arq_bind_statuses, + arq_list, + ) @mock.patch('cyborg.objects.device_profile.DeviceProfile.get_by_name') @mock.patch('cyborg.objects.deployable.Deployable.get_by_id') - @mock.patch('cyborg.db.sqlalchemy.api.Connection.' - 'attach_handle_get_by_id') - @mock.patch('cyborg.db.sqlalchemy.api.Connection.' - 'device_profile_get_by_id') - def test_fill_obj_extarq_fields(self, mock_get_devprof, mock_get_ah, - mock_get_obj_dep, mock_obj_devprof): + @mock.patch('cyborg.db.sqlalchemy.api.Connection.attach_handle_get_by_id') + @mock.patch('cyborg.db.sqlalchemy.api.Connection.device_profile_get_by_id') + def test_fill_obj_extarq_fields( + self, mock_get_devprof, mock_get_ah, mock_get_obj_dep, mock_obj_devprof + ): in_db_extarq = self.fake_db_extarqs[0] # Since state is not 'Bound', attach_handle_get_by_id is not called. in_db_extarq['state'] = 'Initial' @@ -434,28 +551,31 @@ class TestExtARQObject(base.DbTestCase): mock_obj_devprof.return_value = obj_devprof out_db_extarq = objects.ExtARQ._fill_obj_extarq_fields( - self.context, in_db_extarq) + self.context, in_db_extarq + ) - self.assertEqual(out_db_extarq['device_profile_name'], - db_devprof['name']) + self.assertEqual( + out_db_extarq['device_profile_name'], db_devprof['name'] + ) self.assertEqual(out_db_extarq['attach_handle_type'], '') self.assertEqual(out_db_extarq['attach_handle_info'], '') self.assertEqual( out_db_extarq['deployable_uuid'], - '00000000-0000-0000-0000-000000000000') + '00000000-0000-0000-0000-000000000000', + ) devprof_group_id = out_db_extarq['device_profile_group_id'] - self.assertEqual(out_db_extarq['device_profile_group'], - obj_devprof['groups'][devprof_group_id]) + self.assertEqual( + out_db_extarq['device_profile_group'], + obj_devprof['groups'][devprof_group_id], + ) @mock.patch('cyborg.objects.device_profile.DeviceProfile.get_by_name') @mock.patch('cyborg.objects.deployable.Deployable.get_by_id') - @mock.patch('cyborg.db.sqlalchemy.api.Connection.' - 'attach_handle_get_by_id') - @mock.patch('cyborg.db.sqlalchemy.api.Connection.' - 'device_profile_get_by_id') + @mock.patch('cyborg.db.sqlalchemy.api.Connection.attach_handle_get_by_id') + @mock.patch('cyborg.db.sqlalchemy.api.Connection.device_profile_get_by_id') def test_fill_obj_extarq_fields_with_dep_id( - self, mock_get_devprof, mock_get_ah, mock_get_obj_dep, - mock_obj_devprof): + self, mock_get_devprof, mock_get_ah, mock_get_obj_dep, mock_obj_devprof + ): in_db_extarq = self.fake_db_extarqs[0] # Since state is not 'Bound', attach_handle_get_by_id is not called. in_db_extarq['state'] = 'Initial' @@ -465,20 +585,24 @@ class TestExtARQObject(base.DbTestCase): mock_get_devprof.return_value = db_devprof mock_obj_devprof.return_value = obj_devprof - mock_get_obj_dep.return_value = \ - fake_deployable.fake_deployable_obj(self.context, - uuid=self.deployable_uuids[0]) + mock_get_obj_dep.return_value = fake_deployable.fake_deployable_obj( + self.context, uuid=self.deployable_uuids[0] + ) out_db_extarq = objects.ExtARQ._fill_obj_extarq_fields( - self.context, in_db_extarq) + self.context, in_db_extarq + ) - self.assertEqual(out_db_extarq['device_profile_name'], - db_devprof['name']) + self.assertEqual( + out_db_extarq['device_profile_name'], db_devprof['name'] + ) self.assertEqual(out_db_extarq['attach_handle_type'], '') self.assertEqual(out_db_extarq['attach_handle_info'], '') devprof_group_id = out_db_extarq['device_profile_group_id'] - self.assertEqual(out_db_extarq['device_profile_group'], - obj_devprof['groups'][devprof_group_id]) + self.assertEqual( + out_db_extarq['device_profile_group'], + obj_devprof['groups'][devprof_group_id], + ) def test_obj_make_compatible(self): arq_obj = objects.ExtARQ(deployable_id=1) diff --git a/cyborg/tests/unit/objects/test_fpga_ext_arq.py b/cyborg/tests/unit/objects/test_fpga_ext_arq.py index 7c995621..5da83d15 100644 --- a/cyborg/tests/unit/objects/test_fpga_ext_arq.py +++ b/cyborg/tests/unit/objects/test_fpga_ext_arq.py @@ -31,42 +31,60 @@ from cyborg.tests.unit import fake_extarq class TestFPGAExtARQObject(base.DbTestCase): - def setUp(self): super().setUp() self.fake_fpga_db_extarqs = fake_extarq.get_fake_fpga_db_extarqs() self.fake_obj_fpga_extarqs = fake_extarq.get_fake_fpga_extarq_objs() - classes = ["no_program", "bitstream_program", - "function_program", "bad_program"] + classes = [ + "no_program", + "bitstream_program", + "function_program", + "bad_program", + ] self.class_fgpa_objects = dict( - zip(classes, self.fake_obj_fpga_extarqs)) - self.class_fgpa_dbs = dict( - zip(classes, self.fake_fpga_db_extarqs)) + zip(classes, self.fake_obj_fpga_extarqs) + ) + self.class_fgpa_dbs = dict(zip(classes, self.fake_fpga_db_extarqs)) self.bitstream_id = self.class_fgpa_objects["bitstream_program"][ - "device_profile_group"][constants.ACCEL_BITSTREAM_ID] + "device_profile_group" + ][constants.ACCEL_BITSTREAM_ID] self.function_id = self.class_fgpa_objects["function_program"][ - "device_profile_group"][constants.ACCEL_FUNCTION_ID] + "device_profile_group" + ][constants.ACCEL_FUNCTION_ID] self.images_md = { - "/images": - [ - {"id": self.bitstream_id, - "tags": ["trait:CUSTOM_FPGA_INTEL"], - constants.ACCEL_FUNCTION_ID: self.function_id}, + "/images": [ + { + "id": self.bitstream_id, + "tags": ["trait:CUSTOM_FPGA_INTEL"], + constants.ACCEL_FUNCTION_ID: self.function_id, + }, ] } self.deployable_uuids = ['0acbf8d6-e02a-4394-aae3-57557d209498'] - self.bdf = {"domain": "0000", "bus": "00", - "device": "01", "function": "1"} + self.bdf = { + "domain": "0000", + "bus": "00", + "device": "01", + "function": "1", + } self.cpid = { "id": 0, "uuid": "e4a66b0d-b377-40d6-9cdc-6bf7e720e596", "device_id": "1", "cpid_type": "PCI", - "cpid_info": json.dumps(self.bdf).encode('utf-8') + "cpid_info": json.dumps(self.bdf).encode('utf-8'), } - def response(self, status_code=200, content='', headers=None, - reason=None, elapsed=0, request=None, stream=False): + def response( + self, + status_code=200, + content='', + headers=None, + reason=None, + elapsed=0, + request=None, + stream=False, + ): res = requests.Response() res.status_code = status_code if isinstance(content, dict | list): @@ -113,8 +131,7 @@ class TestFPGAExtARQObject(base.DbTestCase): db_extarq = self.fake_fpga_db_extarqs[0] uuid = db_extarq['uuid'] mock_from_db_obj.return_value = self.fake_obj_fpga_extarqs[0] - with mock.patch.object( - self.dbapi, 'extarq_get') as mock_extarq_get: + with mock.patch.object(self.dbapi, 'extarq_get') as mock_extarq_get: mock_extarq_get.return_value = db_extarq obj_extarq = objects.FPGAExtARQ.get(self.context, uuid) mock_extarq_get.assert_called_once_with(self.context, uuid) @@ -125,11 +142,11 @@ class TestFPGAExtARQObject(base.DbTestCase): def test_get_bitstream_id_return_UUID(self, mock_from_db_obj): db_extarq = self.fake_fpga_db_extarqs[1] bit_id = db_extarq['device_profile_group'][ - constants.ACCEL_BITSTREAM_ID] + constants.ACCEL_BITSTREAM_ID + ] uuid = db_extarq['uuid'] mock_from_db_obj.return_value = self.fake_obj_fpga_extarqs[1] - with mock.patch.object( - self.dbapi, 'extarq_get') as mock_extarq_get: + with mock.patch.object(self.dbapi, 'extarq_get') as mock_extarq_get: mock_extarq_get.return_value = db_extarq obj_extarq = objects.FPGAExtARQ.get(self.context, uuid) mock_extarq_get.assert_called_once_with(self.context, uuid) @@ -142,8 +159,7 @@ class TestFPGAExtARQObject(base.DbTestCase): fun_id = db_extarq['device_profile_group'][constants.ACCEL_FUNCTION_ID] uuid = db_extarq['uuid'] mock_from_db_obj.return_value = self.fake_obj_fpga_extarqs[2] - with mock.patch.object( - self.dbapi, 'extarq_get') as mock_extarq_get: + with mock.patch.object(self.dbapi, 'extarq_get') as mock_extarq_get: mock_extarq_get.return_value = db_extarq obj_extarq = objects.FPGAExtARQ.get(self.context, uuid) mock_extarq_get.assert_called_once_with(self.context, uuid) @@ -154,8 +170,9 @@ class TestFPGAExtARQObject(base.DbTestCase): def test_list(self, mock_from_db_obj_list): db_extarqs = self.fake_fpga_db_extarqs mock_from_db_obj_list.return_value = self.fake_obj_fpga_extarqs - with mock.patch.object(self.dbapi, 'extarq_list', - autospec=True) as mock_get_list: + with mock.patch.object( + self.dbapi, 'extarq_list', autospec=True + ) as mock_get_list: mock_get_list.return_value = db_extarqs obj_extarqs = objects.FPGAExtARQ.list(self.context) self.assertEqual(1, mock_get_list.call_count) @@ -167,8 +184,10 @@ class TestFPGAExtARQObject(base.DbTestCase): @mock.patch('openstack.connection.Connection') def test_get_bitstream_md_from_bitstream_id(self, mock_conn): mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) obj_extarq = self.class_fgpa_objects["bitstream_program"] md = obj_extarq._get_bitstream_md_from_bitstream_id(self.bitstream_id) self.assertDictEqual(self.images_md["/images"][0], md) @@ -176,8 +195,10 @@ class TestFPGAExtARQObject(base.DbTestCase): @mock.patch('openstack.connection.Connection') def test_get_bitstream_md_from_function_id(self, mock_conn): mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) obj_extarq = self.class_fgpa_objects["function_program"] md = obj_extarq._get_bitstream_md_from_function_id(self.function_id) self.assertDictEqual(self.images_md["/images"][0], md) @@ -186,11 +207,14 @@ class TestFPGAExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.update_check_state') def test_needs_programming(self, mock_check_state, mock_conn): mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) # function_id is require obj_extarq = self.class_fgpa_objects["function_program"] @@ -205,8 +229,11 @@ class TestFPGAExtARQObject(base.DbTestCase): # Both bitstream_id and function_id are require obj_extarq = self.class_fgpa_objects["bad_program"] self.assertRaises( - exception.InvalidParameterValue, obj_extarq._needs_programming, - self.context, fake_dep) + exception.InvalidParameterValue, + obj_extarq._needs_programming, + self.context, + fake_dep, + ) # None of bitstream_id or function_id is require obj_extarq = self.class_fgpa_objects["no_program"] @@ -217,26 +244,27 @@ class TestFPGAExtARQObject(base.DbTestCase): @mock.patch('cyborg.objects.ExtARQ.update_check_state') def test_get_bitstream_md(self, mock_check_state, mock_conn): mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) # function_id is require obj_extarq = self.class_fgpa_objects["function_program"] bs_id = obj_extarq._get_bitstream_id() fun_id = obj_extarq._get_function_id() - md = obj_extarq.get_bitstream_md( - self.context, fake_dep, fun_id, bs_id) + md = obj_extarq.get_bitstream_md(self.context, fake_dep, fun_id, bs_id) self.assertDictEqual(self.images_md["/images"][0], md) # bitstream_id is require obj_extarq = self.class_fgpa_objects["bitstream_program"] bs_id = obj_extarq._get_bitstream_id() fun_id = obj_extarq._get_function_id() - md = obj_extarq.get_bitstream_md( - self.context, fake_dep, fun_id, bs_id) + md = obj_extarq.get_bitstream_md(self.context, fake_dep, fun_id, bs_id) self.assertDictEqual(self.images_md["/images"][0], md) # Both bitstream_id and function_id are require @@ -244,26 +272,33 @@ class TestFPGAExtARQObject(base.DbTestCase): bs_id = obj_extarq._get_bitstream_id() fun_id = obj_extarq._get_function_id() self.assertRaises( - exception.InvalidParameterValue, obj_extarq.get_bitstream_md, - self.context, fake_dep, fun_id, bs_id) + exception.InvalidParameterValue, + obj_extarq.get_bitstream_md, + self.context, + fake_dep, + fun_id, + bs_id, + ) # None of bitstream_id or function_id is require obj_extarq = self.class_fgpa_objects["no_program"] bs_id = obj_extarq._get_bitstream_id() fun_id = obj_extarq._get_function_id() - md = obj_extarq.get_bitstream_md( - self.context, fake_dep, fun_id, bs_id) + md = obj_extarq.get_bitstream_md(self.context, fake_dep, fun_id, bs_id) self.assertIsNone(md) @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.objects.ExtARQ.update_check_state') def test_need_extra_bind_job(self, mock_check_state, mock_conn): mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) # function_id is require obj_extarq = self.class_fgpa_objects["function_program"] @@ -278,8 +313,11 @@ class TestFPGAExtARQObject(base.DbTestCase): # Both bitstream_id and function_id are require obj_extarq = self.class_fgpa_objects["bad_program"] self.assertRaises( - exception.InvalidParameterValue, obj_extarq._need_extra_bind_job, - self.context, fake_dep) + exception.InvalidParameterValue, + obj_extarq._need_extra_bind_job, + self.context, + fake_dep, + ) # None of bitstream_id or function_id is require obj_extarq = self.class_fgpa_objects["no_program"] @@ -290,11 +328,13 @@ class TestFPGAExtARQObject(base.DbTestCase): @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.objects.ExtARQ.update_check_state') @mock.patch('cyborg.objects.Deployable.get_cpid_list') - def test_do_programming(self, mock_cpid_list, mock_check_state, mock_conn, - mock_program): + def test_do_programming( + self, mock_cpid_list, mock_check_state, mock_conn, mock_program + ): dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) # function_id is require obj_extarq = self.class_fgpa_objects["function_program"] @@ -305,39 +345,48 @@ class TestFPGAExtARQObject(base.DbTestCase): bs_id = obj_extarq._get_bitstream_id() obj_extarq._do_programming(self.context, fake_dep, bs_id) mock_program.assert_called_once_with( - self.context, 'newtestnode1', self.cpid, bs_id, "intel_fpga") + self.context, 'newtestnode1', self.cpid, bs_id, "intel_fpga" + ) @mock.patch('cyborg.objects.ExtARQ.update_check_state') @mock.patch('cyborg.objects.Deployable.get_cpid_list') - def test_do_programming_with_not_one_cp(self, mock_cpid_list, - mock_check_state): + def test_do_programming_with_not_one_cp( + self, mock_cpid_list, mock_check_state + ): dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) mock_cpid_list.return_value = [] obj_extarq = self.class_fgpa_objects["function_program"] obj_extarq.arq.hostname = 'newtestnode1' fake_dep.driver_name = "intel_fpga" bs_id = obj_extarq._get_bitstream_id() - self.assertRaises(exception.ExpectedOneObject, - obj_extarq._do_programming, - self.context, - fake_dep, - bs_id) + self.assertRaises( + exception.ExpectedOneObject, + obj_extarq._do_programming, + self.context, + fake_dep, + bs_id, + ) - @mock.patch('cyborg.common.placement_client.PlacementClient.' - '__init__') - @mock.patch('cyborg.common.placement_client.PlacementClient.' - 'add_traits_to_rp') - @mock.patch('cyborg.common.placement_client.PlacementClient.' - 'delete_traits_with_prefixes') + @mock.patch('cyborg.common.placement_client.PlacementClient.__init__') + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.add_traits_to_rp' + ) + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.' + 'delete_traits_with_prefixes' + ) def test_update_placement( - self, mock_delete_traits, mock_add_traits, mock_placement_init): + self, mock_delete_traits, mock_add_traits, mock_placement_init + ): mock_placement_init.return_value = None dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) # function_id is require obj_extarq = self.class_fgpa_objects["function_program"] @@ -346,41 +395,56 @@ class TestFPGAExtARQObject(base.DbTestCase): fun_id = obj_extarq._get_function_id() rp_uuid = obj_extarq.arq.device_rp_uuid obj_extarq._update_placement( - self.context, fun_id, fake_dep.driver_name) + self.context, fun_id, fake_dep.driver_name + ) function_id = fun_id.upper().replace('-', '_-') vendor = fake_dep.driver_name.upper() - trait_names = ["_".join(( - constants.FPGA_FUNCTION_ID, vendor, function_id))] + trait_names = [ + "_".join((constants.FPGA_FUNCTION_ID, vendor, function_id)) + ] mock_add_traits.assert_called_once_with(rp_uuid, trait_names) mock_delete_traits.assert_called_once_with( - self.context, rp_uuid, [constants.FPGA_FUNCTION_ID]) + self.context, rp_uuid, [constants.FPGA_FUNCTION_ID] + ) @mock.patch('cyborg.agent.rpcapi.AgentAPI.fpga_program') @mock.patch('cyborg.objects.Deployable.get_cpid_list') @mock.patch('cyborg.objects.Deployable.update') - @mock.patch('cyborg.common.placement_client.PlacementClient.' - '__init__') - @mock.patch('cyborg.common.placement_client.PlacementClient.' - 'add_traits_to_rp') - @mock.patch('cyborg.common.placement_client.PlacementClient.' - 'delete_traits_with_prefixes') + @mock.patch('cyborg.common.placement_client.PlacementClient.__init__') + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.add_traits_to_rp' + ) + @mock.patch( + 'cyborg.common.placement_client.PlacementClient.' + 'delete_traits_with_prefixes' + ) @mock.patch('openstack.connection.Connection') @mock.patch('cyborg.objects.ExtARQ.update_check_state') @mock.patch('cyborg.objects.ExtARQ.bind') def test_bind( - self, mock_bind, mock_check_state, mock_conn, mock_delete_traits, - mock_add_traits, mock_placement_init, mock_dp_update, - mock_cpid_list, mock_program): - + self, + mock_bind, + mock_check_state, + mock_conn, + mock_delete_traits, + mock_add_traits, + mock_placement_init, + mock_dp_update, + mock_cpid_list, + mock_program, + ): mock_placement_init.return_value = None mock_conn.return_value = type( - "Connection", (object,), - {"image": type("image", (object,), {"get": self.images_get})}) + "Connection", + (object,), + {"image": type("image", (object,), {"get": self.images_get})}, + ) dep_uuid = self.deployable_uuids[0] fake_dep = fake_deployable.fake_deployable_obj( - self.context, uuid=dep_uuid) + self.context, uuid=dep_uuid + ) mock_cpid_list.return_value = [self.cpid] diff --git a/cyborg/tests/unit/objects/test_objects.py b/cyborg/tests/unit/objects/test_objects.py index 5e0a1f4b..16ae09d0 100644 --- a/cyborg/tests/unit/objects/test_objects.py +++ b/cyborg/tests/unit/objects/test_objects.py @@ -31,18 +31,21 @@ class MyOwnedObject(base.CyborgPersistentObject, base.CyborgObject): fields = {'baz': fields.IntegerField()} -class MyObj(base.CyborgPersistentObject, base.CyborgObject, - base.CyborgObjectDictCompat): +class MyObj( + base.CyborgPersistentObject, base.CyborgObject, base.CyborgObjectDictCompat +): VERSION = '1.6' - fields = {'foo': fields.IntegerField(default=1), - 'bar': fields.StringField(), - 'missing': fields.StringField(), - 'readonly': fields.IntegerField(read_only=True), - 'rel_object': fields.ObjectField('MyOwnedObject', nullable=True), - 'rel_objects': fields.ListOfObjectsField('MyOwnedObject', - nullable=True), - 'mutable_default': fields.ListOfStringsField(default=[]), - } + fields = { + 'foo': fields.IntegerField(default=1), + 'bar': fields.StringField(), + 'missing': fields.StringField(), + 'readonly': fields.IntegerField(read_only=True), + 'rel_object': fields.ObjectField('MyOwnedObject', nullable=True), + 'rel_objects': fields.ListOfObjectsField( + 'MyOwnedObject', nullable=True + ), + 'mutable_default': fields.ListOfStringsField(default=[]), + } @staticmethod def _from_db_object(context, obj, db_obj): @@ -92,6 +95,7 @@ class MyObj(base.CyborgPersistentObject, base.CyborgObject, class RandomMixInWithNoFields: """Used to test object inheritance using a mixin that has no fields.""" + pass @@ -101,7 +105,6 @@ class TestSubclassedObject(RandomMixInWithNoFields, MyObj): class TestObjToPrimitive(test.base.TestCase): - def test_obj_to_primitive_list(self): @base.CyborgObjectRegistry.register_if(False) class MyObjElement(base.CyborgObject): @@ -117,14 +120,16 @@ class TestObjToPrimitive(test.base.TestCase): mylist = MyList() mylist.objects = [MyObjElement(1), MyObjElement(2), MyObjElement(3)] - self.assertEqual([1, 2, 3], - [x['foo'] for x in base.obj_to_primitive(mylist)]) + self.assertEqual( + [1, 2, 3], [x['foo'] for x in base.obj_to_primitive(mylist)] + ) def test_obj_to_primitive_dict(self): base.CyborgObjectRegistry.register(MyObj) myobj = MyObj(foo=1, bar='foo') - self.assertEqual({'foo': 1, 'bar': 'foo'}, - base.obj_to_primitive(myobj)) + self.assertEqual( + {'foo': 1, 'bar': 'foo'}, base.obj_to_primitive(myobj) + ) def test_obj_to_primitive_recursive(self): base.CyborgObjectRegistry.register(MyObj) @@ -135,22 +140,28 @@ class TestObjToPrimitive(test.base.TestCase): mylist = MyList(objects=[MyObj(), MyObj()]) for i, value in enumerate(mylist): value.foo = i - self.assertEqual([{'foo': 0}, {'foo': 1}], - base.obj_to_primitive(mylist)) + self.assertEqual( + [{'foo': 0}, {'foo': 1}], base.obj_to_primitive(mylist) + ) def test_obj_to_primitive_with_ip_addr(self): @base.CyborgObjectRegistry.register_if(False) class TestObject(base.CyborgObject): - fields = {'addr': fields.IPAddressField(), - 'cidr': fields.IPNetworkField()} + fields = { + 'addr': fields.IPAddressField(), + 'cidr': fields.IPNetworkField(), + } obj = TestObject(addr='1.2.3.4', cidr='1.1.1.1/16') - self.assertEqual({'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'}, - base.obj_to_primitive(obj)) + self.assertEqual( + {'addr': '1.2.3.4', 'cidr': '1.1.1.1/16'}, + base.obj_to_primitive(obj), + ) -def compare_obj(test, obj, db_obj, subs=None, allow_missing=None, - comparators=None): +def compare_obj( + test, obj, db_obj, subs=None, allow_missing=None, comparators=None +): """Compare a CyborgObject and a dict-like database object. This automatically converts TZ-aware datetimes and iterates over @@ -192,16 +203,24 @@ class _BaseTestCase(test.base.TestCase): super().setUp() self.user_id = 'fake-user' self.project_id = 'fake-project' - self.context = cyborg_context.RequestContext(self.user_id, - self.project_id) + self.context = cyborg_context.RequestContext( + self.user_id, self.project_id + ) base.CyborgObjectRegistry.register(MyObj) base.CyborgObjectRegistry.register(MyOwnedObject) - def compare_obj(self, obj, db_obj, subs=None, allow_missing=None, - comparators=None): - compare_obj(self, obj, db_obj, subs=subs, allow_missing=allow_missing, - comparators=comparators) + def compare_obj( + self, obj, db_obj, subs=None, allow_missing=None, comparators=None + ): + compare_obj( + self, + obj, + db_obj, + subs=subs, + allow_missing=allow_missing, + comparators=comparators, + ) def str_comparator(self, expected, obj_val): """Compare an object field to a string in the db by performing diff --git a/cyborg/tests/unit/policies/base.py b/cyborg/tests/unit/policies/base.py index 383ad602..9ff95eec 100644 --- a/cyborg/tests/unit/policies/base.py +++ b/cyborg/tests/unit/policies/base.py @@ -25,7 +25,6 @@ LOG = logging.getLogger(__name__) class BasePolicyTest(v2_test.APITestV2): - def setUp(self): super().setUp() self.policy = self.useFixture(policy_fixture.PolicyFixture()) @@ -37,54 +36,70 @@ class BasePolicyTest(v2_test.APITestV2): # legacy default role: "default:admin_or_owner" self.legacy_admin_context = cyborg_context.RequestContext( - user_id="legacy_admin", project_id=self.admin_project_id, - roles='admin') + user_id="legacy_admin", + project_id=self.admin_project_id, + roles='admin', + ) self.legacy_owner_context = cyborg_context.RequestContext( - user_id="legacy_owner", project_id=self.admin_project_id, - roles='member') + user_id="legacy_owner", + project_id=self.admin_project_id, + roles='member', + ) # system scoped users self.system_admin_context = cyborg_context.RequestContext( - user_id="sys_admin", - roles='admin', system_scope='all') + user_id="sys_admin", roles='admin', system_scope='all' + ) self.system_member_context = cyborg_context.RequestContext( - user_id="sys_member", - roles='member', system_scope='all') + user_id="sys_member", roles='member', system_scope='all' + ) self.system_reader_context = cyborg_context.RequestContext( - user_id="sys_reader", roles='reader', system_scope='all') + user_id="sys_reader", roles='reader', system_scope='all' + ) self.system_foo_context = cyborg_context.RequestContext( - user_id="sys_foo", roles='foo', system_scope='all') + user_id="sys_foo", roles='foo', system_scope='all' + ) # project scoped users self.project_admin_context = cyborg_context.RequestContext( - user_id="project_admin", project_id=self.project_id, - roles='admin') + user_id="project_admin", project_id=self.project_id, roles='admin' + ) self.project_member_context = cyborg_context.RequestContext( - user_id="project_member", project_id=self.project_id, - roles='member') + user_id="project_member", + project_id=self.project_id, + roles='member', + ) self.project_reader_context = cyborg_context.RequestContext( - user_id="project_reader", project_id=self.project_id, - roles='reader') + user_id="project_reader", + project_id=self.project_id, + roles='reader', + ) self.project_foo_context = cyborg_context.RequestContext( - user_id="project_foo", project_id=self.project_id, - roles='foo') + user_id="project_foo", project_id=self.project_id, roles='foo' + ) self.other_project_member_context = cyborg_context.RequestContext( user_id="other_project_member", project_id=self.project_id_other, - roles='member') + roles='member', + ) self.all_contexts = [ - self.legacy_admin_context, self.legacy_owner_context, - self.system_admin_context, self.system_member_context, - self.system_reader_context, self.system_foo_context, - self.project_admin_context, self.project_member_context, - self.project_reader_context, self.other_project_member_context, + self.legacy_admin_context, + self.legacy_owner_context, + self.system_admin_context, + self.system_member_context, + self.system_reader_context, + self.system_foo_context, + self.project_admin_context, + self.project_member_context, + self.project_reader_context, + self.other_project_member_context, self.project_foo_context, ] diff --git a/cyborg/tests/unit/policies/test_device_profiles.py b/cyborg/tests/unit/policies/test_device_profiles.py index 033e02c5..e541104c 100644 --- a/cyborg/tests/unit/policies/test_device_profiles.py +++ b/cyborg/tests/unit/policies/test_device_profiles.py @@ -48,10 +48,11 @@ class DeviceProfilePolicyTest(base.BasePolicyTest): self.create_authorized_contexts = [ self.legacy_admin_context, # legacy: admin self.system_admin_context, # new policy: system_admin - self.project_admin_context + self.project_admin_context, ] self.create_unauthorized_contexts = list( - set(self.all_contexts) - set(self.create_authorized_contexts)) + set(self.all_contexts) - set(self.create_authorized_contexts) + ) # check both legacy and new policies for delete APIs self.delete_authorized_contexts = [ @@ -63,10 +64,11 @@ class DeviceProfilePolicyTest(base.BasePolicyTest): # If later we need support owner policy, we should recheck here. # self.legacy_owner_context, self.system_admin_context, # new policy: system_admin - self.project_admin_context + self.project_admin_context, ] self.delete_unauthorized_contexts = list( - set(self.all_contexts) - set(self.delete_authorized_contexts)) + set(self.all_contexts) - set(self.delete_authorized_contexts) + ) def _validate_links(self, links, dp_uuid): has_self_link = False @@ -114,8 +116,9 @@ class DeviceProfilePolicyTest(base.BasePolicyTest): @mock.patch('cyborg.conductor.rpcapi.ConductorAPI.device_profile_delete') @mock.patch('cyborg.objects.DeviceProfile.get_by_name') @mock.patch('cyborg.objects.DeviceProfile.get_by_uuid') - def test_delete_device_profile_success(self, mock_dp_uuid, - mock_dp_name, mock_cond_del): + def test_delete_device_profile_success( + self, mock_dp_uuid, mock_dp_name, mock_cond_del + ): for context in self.delete_authorized_contexts: headers = self.gen_headers(context) # Delete by UUID @@ -156,15 +159,19 @@ class DeviceProfileScopeTypePolicyTest(DeviceProfilePolicyTest): # check that admin is able to do create and delete operations. self.create_authorized_contexts = [ self.legacy_admin_context, - self.project_admin_context] + self.project_admin_context, + ] self.delete_authorized_contexts = self.create_authorized_contexts # Check that system or non-admin is not able to perform the system # level actions on device_profiles. self.create_unauthorized_contexts = [ - self.system_admin_context, self.system_member_context, - self.system_reader_context, self.system_foo_context, + self.system_admin_context, + self.system_member_context, + self.system_reader_context, + self.system_foo_context, self.project_member_context, self.other_project_member_context, - self.project_foo_context, self.project_reader_context + self.project_foo_context, + self.project_reader_context, ] self.delete_unauthorized_contexts = self.create_unauthorized_contexts diff --git a/cyborg/tests/unit/policy_fixture.py b/cyborg/tests/unit/policy_fixture.py index 594fcecf..bf1064e5 100644 --- a/cyborg/tests/unit/policy_fixture.py +++ b/cyborg/tests/unit/policy_fixture.py @@ -35,8 +35,9 @@ class PolicyFixture(fixtures.Fixture): def setUp(self): super().setUp() self.policy_dir = self.useFixture(fixtures.TempDir()) - self.policy_file_name = os.path.join(self.policy_dir.path, - 'policy.yaml') + self.policy_file_name = os.path.join( + self.policy_dir.path, 'policy.yaml' + ) with open(self.policy_file_name, 'w') as policy_file: policy_file.write(policy_data) policy_opts.set_defaults(CONF) diff --git a/cyborg/tests/unit/services/_test_placement_client.py b/cyborg/tests/unit/services/_test_placement_client.py index 43ac23e8..0e866e3b 100644 --- a/cyborg/tests/unit/services/_test_placement_client.py +++ b/cyborg/tests/unit/services/_test_placement_client.py @@ -29,25 +29,28 @@ class PlacementAPIClientTestCase(base.DietTestCase): def setUp(self): super().setUp() self.mock_load_auth_p = mock.patch( - 'keystoneauth1.loading.load_auth_from_conf_options') + 'keystoneauth1.loading.load_auth_from_conf_options' + ) self.mock_load_auth = self.mock_load_auth_p.start() self.mock_request_p = mock.patch( - 'keystoneauth1.session.Session.request') + 'keystoneauth1.session.Session.request' + ) self.mock_request = self.mock_request_p.start() self.client = placement_client.PlacementClient() @mock.patch('keystoneauth1.session.Session') @mock.patch('keystoneauth1.loading.load_auth_from_conf_options') def test_constructor(self, load_auth_mock, ks_sess_mock): - placement_client.PlacementClient() load_auth_mock.assert_called_once_with(cfg.CONF, 'placement') - ks_sess_mock.assert_called_once_with(auth=load_auth_mock.return_value, - cert=None, - collect_timing=False, - split_loggers=False, - timeout=None, - verify=True) + ks_sess_mock.assert_called_once_with( + auth=load_auth_mock.return_value, + cert=None, + collect_timing=False, + split_loggers=False, + timeout=None, + verify=True, + ) @mock.patch('cyborg.common.placement_client.PlacementClient.post') def test_create_resource_provider(self, mock_post): @@ -55,8 +58,11 @@ class PlacementAPIClientTestCase(base.DietTestCase): self.client._create_resource_provider(self.context, rp_uuid, 'test') expected_url = '/resource_providers' mock_post.assert_called_once_with( - expected_url, {'uuid': rp_uuid, 'name': 'test'}, - version='1.20', global_request_id=mock.ANY) + expected_url, + {'uuid': rp_uuid, 'name': 'test'}, + version='1.20', + global_request_id=mock.ANY, + ) @mock.patch('cyborg.common.placement_client.PlacementClient.delete') def test_delete_resource_provider(self, mock_delete): @@ -64,7 +70,8 @@ class PlacementAPIClientTestCase(base.DietTestCase): self.client.delete_provider(rp_uuid) expected_url = '/resource_providers/' + rp_uuid mock_delete.assert_called_once_with( - expected_url, global_request_id=mock.ANY) + expected_url, global_request_id=mock.ANY + ) def test_create_inventory(self): expected_payload = 'fake_inventory' @@ -72,9 +79,12 @@ class PlacementAPIClientTestCase(base.DietTestCase): e_filter = {'region_name': mock.ANY, 'service_type': 'placement'} self.client.create_inventory(rp_uuid, expected_payload) expected_url = '/resource_providers/%s/inventories' % rp_uuid - self.mock_request.assert_called_once_with(expected_url, 'POST', - endpoint_filter=e_filter, - json=expected_payload) + self.mock_request.assert_called_once_with( + expected_url, + 'POST', + endpoint_filter=e_filter, + json=expected_payload, + ) def test_get_inventory(self): rp_uuid = uuidutils.generate_uuid() @@ -82,25 +92,34 @@ class PlacementAPIClientTestCase(base.DietTestCase): resource_class = 'fake_resource_class' self.client.get_inventory(rp_uuid, resource_class) expected_url = '/resource_providers/%s/inventories/%s' % ( - rp_uuid, resource_class) - self.mock_request.assert_called_once_with(expected_url, 'GET', - endpoint_filter=e_filter) + rp_uuid, + resource_class, + ) + self.mock_request.assert_called_once_with( + expected_url, 'GET', endpoint_filter=e_filter + ) def _test_get_inventory_not_found(self, details, expected_exception): rp_uuid = uuidutils.generate_uuid() resource_class = 'fake_resource_class' self.mock_request.side_effect = ks_exc.NotFound(details=details) - self.assertRaises(expected_exception, self.client.get_inventory, - rp_uuid, resource_class) + self.assertRaises( + expected_exception, + self.client.get_inventory, + rp_uuid, + resource_class, + ) def test_get_inventory_not_found_no_resource_provider(self): self._test_get_inventory_not_found( "No resource provider with uuid", - c_exc.PlacementResourceProviderNotFound) + c_exc.PlacementResourceProviderNotFound, + ) def test_get_inventory_not_found_no_inventory(self): self._test_get_inventory_not_found( - "No inventory of class", c_exc.PlacementInventoryNotFound) + "No inventory of class", c_exc.PlacementInventoryNotFound + ) def test_get_inventory_not_found_unknown_cause(self): self._test_get_inventory_not_found("Unknown cause", ks_exc.NotFound) @@ -112,16 +131,25 @@ class PlacementAPIClientTestCase(base.DietTestCase): resource_class = 'fake_resource_class' self.client.update_inventory(rp_uuid, expected_payload, resource_class) expected_url = '/resource_providers/%s/inventories/%s' % ( - rp_uuid, resource_class) - self.mock_request.assert_called_once_with(expected_url, 'PUT', - endpoint_filter=e_filter, - json=expected_payload) + rp_uuid, + resource_class, + ) + self.mock_request.assert_called_once_with( + expected_url, + 'PUT', + endpoint_filter=e_filter, + json=expected_payload, + ) def test_update_inventory_conflict(self): rp_uuid = uuidutils.generate_uuid() expected_payload = 'fake_inventory' resource_class = 'fake_resource_class' self.mock_request.side_effect = ks_exc.Conflict - self.assertRaises(c_exc.PlacementInventoryUpdateConflict, - self.client.update_inventory, rp_uuid, - expected_payload, resource_class) + self.assertRaises( + c_exc.PlacementInventoryUpdateConflict, + self.client.update_inventory, + rp_uuid, + expected_payload, + resource_class, + ) diff --git a/cyborg/tests/unit/test_exception.py b/cyborg/tests/unit/test_exception.py index 344ee87d..6d01145b 100644 --- a/cyborg/tests/unit/test_exception.py +++ b/cyborg/tests/unit/test_exception.py @@ -41,6 +41,9 @@ class TestException(base.TestCase): if parent._msg_fmt == cls._msg_fmt: bad_classes.append(cls.__name__) - self.assertEqual([], bad_classes, - 'Exception classes %s do not ' - 'set _msg_fmt' % ', '.join(bad_classes)) + self.assertEqual( + [], + bad_classes, + 'Exception classes %s do not ' + 'set _msg_fmt' % ', '.join(bad_classes), + ) diff --git a/cyborg/tests/unit/test_hacking.py b/cyborg/tests/unit/test_hacking.py deleted file mode 100644 index db8fbb93..00000000 --- a/cyborg/tests/unit/test_hacking.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2015 Intel, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. - -import textwrap -from unittest import mock - -import pycodestyle - -from cyborg.hacking import checks -from cyborg.tests import base - - -class HackingTestCase(base.TestCase): - """Hacking test class. - - This class tests the hacking checks in magnum.hacking.checks by passing - strings to the check methods like the pycodestyle/flake8 parser would. - The parser loops over each line in the file and then passes the parameters - to the check method. The parameter names in the check method dictate what - type of object is passed to the check method. The parameter types are:: - - logical_line: A processed line with the following modifications: - - Multi-line statements converted to a single line. - - Stripped left and right. - - Contents of strings replaced with "xxx" of same length. - - Comments removed. - physical_line: Raw line of text from the input file. - lines: a list of the raw lines from the input file - tokens: the tokens that contribute to this logical line - line_number: line number in the input file - total_lines: number of lines in the input file - blank_lines: blank lines before this one - indent_char: indentation character in this file (" " or "\t") - indent_level: indentation (with tabs expanded to multiples of 8) - previous_indent_level: indentation on previous line - previous_logical: previous logical line - filename: Path of the file being run through pycodestyle - - When running a test on a check method the return will be False/None if - there is no violation in the sample input. If there is an error a tuple is - returned with a position in the line, and a message. So to check the result - just assertTrue if the check is expected to fail and assertFalse if it - should pass. - """ - # We are patching pycodestyle so that only the check under test is - # actually installed. - - @mock.patch('pycodestyle._checks', - {'physical_line': {}, 'logical_line': {}, 'tree': {}}) - def _run_check(self, code, checker, filename=None): - pycodestyle.register_check(checker) - - lines = textwrap.dedent(code).strip().splitlines(True) - - checker = pycodestyle.Checker(filename=filename, lines=lines) - checker.check_all() - checker.report._deferred_print.sort() - return checker.report._deferred_print - - def _assert_has_errors(self, code, checker, expected_errors=None, - filename=None): - actual_errors = [e[:3] for e in - self._run_check(code, checker, filename)] - self.assertEqual(expected_errors or [], actual_errors) - - def _assert_has_no_errors(self, code, checker, filename=None): - self._assert_has_errors(code, checker, filename=filename) - - def test_no_mutable_default_args(self): - errors = [(1, 0, "M322")] - check = checks.no_mutable_default_args - - code = "def get_info_from_bdm(virt_type, bdm, mapping=[])" - self._assert_has_errors(code, check, errors) - - code = "defined = []" - self._assert_has_no_errors(code, check) - - code = "defined, undefined = [], {}" - self._assert_has_no_errors(code, check) - - def test_use_timeunitls_utcow(self): - errors = [(1, 0, "M310")] - check = checks.use_timeutils_utcnow - - code = "datetime.now" - self._assert_has_errors(code, check, errors) - - code = "datetime.utcnow" - self._assert_has_errors(code, check, errors) - - code = "datetime.aa" - self._assert_has_no_errors(code, check) - - code = "aaa" - self._assert_has_no_errors(code, check) - - def test_dict_constructor_with_list_copy(self): - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " dict([(i, connect_info[i])")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " attrs = dict([(k, _from_json(v))")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " type_names = dict((value, key) for key, value in")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " dict((value, key) for key, value in")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - "foo(param=dict((k, v) for k, v in bar.items()))")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " dict([[i,i] for i in range(3)])")))) - - self.assertEqual(1, len(list(checks.dict_constructor_with_list_copy( - " dd = dict([i,i] for i in range(3))")))) - - self.assertEqual(0, len(list(checks.dict_constructor_with_list_copy( - " create_kwargs = dict(snapshot=snapshot,")))) - - self.assertEqual(0, len(list(checks.dict_constructor_with_list_copy( - " self._render_dict(xml, data_el, data.__dict__)")))) - - def test_check_explicit_underscore_import(self): - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "LOG.info(_('My info message'))", - "magnum/tests/other_files.py"))), 1) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "msg = _('My message')", - "magnum/tests/other_files.py"))), 1) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "from magnum.i18n import _", - "magnum/tests/other_files.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "LOG.info(_('My info message'))", - "magnum/tests/other_files.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "msg = _('My message')", - "magnum/tests/other_files.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "from magnum.i18n import _, _LW", - "magnum/tests/other_files2.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "msg = _('My message')", - "magnum/tests/other_files2.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "_ = translations.ugettext", - "magnum/tests/other_files3.py"))), 0) - self.assertEqual(len(list(checks.check_explicit_underscore_import( - "msg = _('My message')", - "magnum/tests/other_files3.py"))), 0) diff --git a/pyproject.toml b/pyproject.toml index 95873bf9..25d3b8ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,8 @@ line-length = 79 target-version = "py310" [tool.ruff.lint] -select = ["E4", "E7", "E9", "F", "S", "UP", "W", "C90"] +select = ["E4", "E7", "E9", "F", "G", "LOG", "S", "UP", "W", "C90"] +external = ["H"] ignore = [ # asserts used for type narrowing only "S101", @@ -103,3 +104,7 @@ ignore = [ [tool.ruff.lint.mccabe] max-complexity = 20 + +[tool.ruff.format] +quote-style = "preserve" +docstring-code-format = true diff --git a/releasenotes/source/conf.py b/releasenotes/source/conf.py index ed79c55f..8d27f5e1 100644 --- a/releasenotes/source/conf.py +++ b/releasenotes/source/conf.py @@ -122,9 +122,13 @@ htmlhelp_basename = 'CyborgReleaseNotesdoc' # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - (master_doc, 'CyborgReleaseNotes.tex', - 'Cyborg Release Notes Documentation', - 'Cyborg developers', 'manual'), + ( + master_doc, + 'CyborgReleaseNotes.tex', + 'Cyborg Release Notes Documentation', + 'Cyborg developers', + 'manual', + ), ] @@ -133,8 +137,13 @@ latex_documents = [ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ - (master_doc, 'CyborgReleaseNotes', 'Cyborg Release Notes Documentation', - [author], 1) + ( + master_doc, + 'CyborgReleaseNotes', + 'Cyborg Release Notes Documentation', + [author], + 1, + ) ] @@ -144,9 +153,15 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'CyborgReleaseNotes', 'Cyborg Release Notes Documentation', - author, 'CyborgReleaseNotes', 'One line description of project.', - 'Miscellaneous'), + ( + master_doc, + 'CyborgReleaseNotes', + 'Cyborg Release Notes Documentation', + author, + 'CyborgReleaseNotes', + 'One line description of project.', + 'Miscellaneous', + ), ] # -- Options for Internationalization output ------------------------------ diff --git a/setup.py b/setup.py index cd35c3c3..481505b0 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,4 @@ import setuptools -setuptools.setup( - setup_requires=['pbr>=2.0.0'], - pbr=True) +setuptools.setup(setup_requires=['pbr>=2.0.0'], pbr=True) diff --git a/test-requirements.txt b/test-requirements.txt index 522333d5..a3cc1f49 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,6 +1,3 @@ -hacking>=8.0.0,<8.1.0 # Apache-2.0 - -bandit>=1.6.0 # Apache-2.0 coverage>=3.6,!=4.4 # Apache-2.0 fixtures>=3.0.0 # Apache-2.0/BSD oslotest>=3.2.0 # Apache-2.0 diff --git a/tools/flake8wrap.sh b/tools/flake8wrap.sh deleted file mode 100755 index c7478ac4..00000000 --- a/tools/flake8wrap.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/sh -# -# A simple wrapper around flake8 which makes it possible -# to ask it to only verify files changed in the current -# git HEAD patch. -# -# Intended to be invoked via tox: -# -# tox -epep8 -- -HEAD -# - -if test "x$1" = "x-HEAD" ; then - shift - files=$(git diff --name-only HEAD~1 | tr '\n' ' ') - echo "Running flake8 on ${files}" - echo "" - echo "Consider using the 'pre-commit' tool instead." - echo "" - echo " pip install --user pre-commit" - echo " pre-commit install --allow-missing-config" - echo "" - diff -u --from-file /dev/null ${files} | flake8 --diff "$@" -else - echo "Running flake8 on all files" - echo "" - exec flake8 "$@" -fi diff --git a/tox.ini b/tox.ini index 95b789f8..d54ee191 100644 --- a/tox.ini +++ b/tox.ini @@ -5,8 +5,6 @@ envlist = py3,pep8 [testenv] usedevelop = True allowlist_externals = - bash - find rm env make @@ -43,20 +41,9 @@ usedevelop = False [testenv:pep8] description = Run style checks. -commands = - bash tools/flake8wrap.sh {posargs} - # Check that all JSON files don't have \r\n in line. - bash -c "! find doc/ -type f -name *.json | xargs grep -U -n $'\r'" - # Check that all included JSON files are valid JSON - bash -c '! find doc/ -type f -name *.json | grep -v 'curl' | xargs -t -n1 python -m json.tool 2>&1 > /dev/null | grep -B1 -v ^python' - bash tools/check-cherry-picks.sh - doc8 doc/source/ CONTRIBUTING.rst HACKING.rst README.rst - -[testenv:fast8] -description = - Run style checks on the changes made since HEAD~. For a full run including docs, use 'pep8' -commands = - bash tools/flake8wrap.sh -HEAD +deps = pre-commit +skip_install = true +commands = pre-commit run -a --show-diff-on-failure [testenv:venv] commands = {posargs} @@ -119,20 +106,11 @@ commands = allowlist_externals = rm [flake8] -filename = *.py,app.wsgi +# We only enable the hacking (H) checks; ruff handles everything else +select = H +# H301 Ruff will put commas after imports that can't fit on one line +# H404 Docstrings don't always start with a newline +# H405 Multiline docstrings are okay +ignore = H301,H404,H405 show-source = True -ignore = E123,E125,H405,W503,W504 -builtins = _ -enable-extensions = H106,H203,H904 -exclude=.venv,.git,.tox,dist,doc,*lib/python*,*egg,build,*sqlalchemy/alembic/versions/*,demo/,releasenotes - -[testenv:bandit] -commands = bandit -r cyborg -x cyborg/tests/* -n 5 -ll - -[flake8:local-plugins] -extension = - M310 = checks:use_timeutils_utcnow - M322 = checks:no_mutable_default_args - M336 = checks:dict_constructor_with_list_copy - M340 = checks:check_explicit_underscore_import -paths = ./cyborg/hacking +exclude=.venv,.git,.tox,dist,doc,*lib/python*,*egg,build,releasenotes,*sqlalchemy/alembic/versions/*