diff --git a/cyborg/accelerator/drivers/driver.py b/cyborg/accelerator/drivers/driver.py new file mode 100644 index 00000000..a4b5b1a5 --- /dev/null +++ b/cyborg/accelerator/drivers/driver.py @@ -0,0 +1,38 @@ +import abc +import six + + +@six.add_metaclass(abc.ABCMeta) +class GenericDriver(object): + + @abc.abstractmethod + def discover(self): + """Discover a specified accelerator. + + :return: the list of driver device objs + """ + pass + + @abc.abstractmethod + def update(self, control_path, image_path): + """Update the device firmware with specific image. + + :param control_path: the image update control path of device. + :param image_path: The image path of the firmware binary. + + :return: True if update successfully otherwise False + """ + pass + + @abc.abstractmethod + def get_stats(self): + """Collects device stats. + + It is used to collect information from the device about the device + capabilities. Such as performance info like temprature, power, volt, + packet_count info. + + :return: The stats info of the device. The format should follow the + current Cyborg device-deploy-accelerator model + """ + pass diff --git a/cyborg/tests/unit/accelerator/drivers/test_driver.py b/cyborg/tests/unit/accelerator/drivers/test_driver.py new file mode 100644 index 00000000..92069778 --- /dev/null +++ b/cyborg/tests/unit/accelerator/drivers/test_driver.py @@ -0,0 +1,31 @@ +import six + +from cyborg.tests import base +from cyborg.accelerator.drivers.driver import GenericDriver + + +class WellDoneDriver(GenericDriver): + def discover(self): + pass + + def update(self, control_path, image_path): + pass + + def get_stats(self): + pass + + +class NotCompleteDriver(GenericDriver): + def discover(self): + pass + + +class TestGenericDriver(base.TestCase): + + def test_generic_driver(self): + driver = WellDoneDriver() + # Can't instantiate abstract class NotCompleteDriver with + # abstract methods get_stats, update + result = self.assertRaises(TypeError, NotCompleteDriver) + self.assertIn("Can't instantiate abstract class", + six.text_type(result))