Add tensorflow driver implementation

Change-Id: I951cea9325d2ea4a843ea55d1731c481df899474
This commit is contained in:
bharath 2018-10-14 00:22:56 +05:30
parent a40f556027
commit 26e2299908
25 changed files with 332 additions and 138 deletions

View File

@ -305,6 +305,7 @@ function start_gyan_compute {
function start_gyan { function start_gyan {
# ``run_process`` checks ``is_service_enabled``, it is not needed here # ``run_process`` checks ``is_service_enabled``, it is not needed here
mkdir -p /opt/stack/data/gyan
start_gyan_api start_gyan_api
start_gyan_compute start_gyan_compute
} }

View File

@ -82,10 +82,10 @@ class V1(controllers_base.APIBase):
'hosts', '', 'hosts', '',
bookmark=True)] bookmark=True)]
v1.ml_models = [link.make_link('self', pecan.request.host_url, v1.ml_models = [link.make_link('self', pecan.request.host_url,
'ml_models', ''), 'ml-models', ''),
link.make_link('bookmark', link.make_link('bookmark',
pecan.request.host_url, pecan.request.host_url,
'ml_models', '', 'ml-models', '',
bookmark=True)] bookmark=True)]
return v1 return v1
@ -147,9 +147,9 @@ class Controller(controllers_base.Controller):
{'url': pecan.request.url, {'url': pecan.request.url,
'method': pecan.request.method, 'method': pecan.request.method,
'body': pecan.request.body}) 'body': pecan.request.body})
LOG.debug(msg) # LOG.debug(msg)
LOG.debug(args)
return super(Controller, self)._route(args) return super(Controller, self)._route(args)
__all__ = ('Controller',) __all__ = ('Controller',)

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import base64
import shlex import shlex
from oslo_log import log as logging from oslo_log import log as logging
@ -74,12 +75,13 @@ class MLModelController(base.Controller):
"""Controller for MLModels.""" """Controller for MLModels."""
_custom_actions = { _custom_actions = {
'train': ['POST'], 'upload_trained_model': ['POST'],
'deploy': ['GET'], 'deploy': ['GET'],
'undeploy': ['GET'] 'undeploy': ['GET'],
'predict': ['POST']
} }
@pecan.expose('json') @pecan.expose('json')
@exception.wrap_pecan_controller_exception @exception.wrap_pecan_controller_exception
def get_all(self, **kwargs): def get_all(self, **kwargs):
@ -149,33 +151,55 @@ class MLModelController(base.Controller):
context.all_projects = True context.all_projects = True
ml_model = utils.get_ml_model(ml_model_ident) ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one") check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one")
if ml_model.node:
compute_api = pecan.request.compute_api
try:
ml_model = compute_api.ml_model_show(context, ml_model)
except exception.MLModelHostNotUp:
raise exception.ServerNotUsable
return view.format_ml_model(context, pecan.request.host_url, return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict()) ml_model.as_dict())
@base.Controller.api_version("1.0")
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
def upload_trained_model(self, ml_model_ident, **kwargs):
    """Attach trained model data to an existing ML model.

    The entire raw request body is stored as the model payload
    (``ml_data``), then the model is handed to the compute layer
    for placement on a host.

    :param ml_model_ident: UUID or name of the ML model.
    :returns: the API view dict of the updated ML model.
    """
    context = pecan.request.context
    LOG.debug(ml_model_ident)
    ml_model = utils.get_ml_model(ml_model_ident)
    LOG.debug(ml_model)
    # The whole request body is treated as the serialized model blob.
    ml_model.ml_data = pecan.request.body
    ml_model.save(context)
    pecan.response.status = 200
    compute_api = pecan.request.compute_api
    new_model = view.format_ml_model(context, pecan.request.host_url,
                                     ml_model.as_dict())
    # NOTE(review): unlike get_one/delete, no policy check is enforced
    # here -- confirm this is intentional.
    compute_api.ml_model_create(context, new_model)
    return new_model
@base.Controller.api_version("1.0")
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
def predict(self, ml_model_ident, **kwargs):
    """Run a prediction against an ML model.

    Expects a multipart form upload under the ``file`` field; the file
    content is base64-encoded and forwarded to the compute service.

    :param ml_model_ident: UUID or name of the ML model.
    :returns: the prediction result returned by the compute service.
    """
    context = pecan.request.context
    LOG.debug(ml_model_ident)
    # NOTE(review): ml_model is looked up (validating existence) but the
    # object itself is unused and no policy check is enforced -- confirm.
    ml_model = utils.get_ml_model(ml_model_ident)
    pecan.response.status = 200
    compute_api = pecan.request.compute_api
    predict_dict = {
        "data": base64.b64encode(pecan.request.POST['file'].file.read())
    }
    prediction = compute_api.ml_model_predict(context, ml_model_ident,
                                              **predict_dict)
    return prediction
@base.Controller.api_version("1.0") @base.Controller.api_version("1.0")
@pecan.expose('json') @pecan.expose('json')
@api_utils.enforce_content_types(['application/json']) @api_utils.enforce_content_types(['application/json'])
@exception.wrap_pecan_controller_exception @exception.wrap_pecan_controller_exception
@validation.validate_query_param(pecan.request, schema.query_param_create)
@validation.validated(schema.ml_model_create) @validation.validated(schema.ml_model_create)
def post(self, **ml_model_dict): def post(self, **ml_model_dict):
return self._do_post(**ml_model_dict) return self._do_post(**ml_model_dict)
def _do_post(self, **ml_model_dict): def _do_post(self, **ml_model_dict):
"""Create or run a new ml model. """Create or run a new ml model.
:param ml_model_dict: a ml_model within the request body. :param ml_model_dict: a ml_model within the request body.
""" """
context = pecan.request.context context = pecan.request.context
compute_api = pecan.request.compute_api
policy.enforce(context, "ml_model:create", policy.enforce(context, "ml_model:create",
action="ml_model:create") action="ml_model:create")
@ -183,22 +207,24 @@ class MLModelController(base.Controller):
ml_model_dict['user_id'] = context.user_id ml_model_dict['user_id'] = context.user_id
name = ml_model_dict.get('name') name = ml_model_dict.get('name')
ml_model_dict['name'] = name ml_model_dict['name'] = name
ml_model_dict['status'] = consts.CREATING ml_model_dict['status'] = consts.CREATED
ml_model_dict['ml_type'] = ml_model_dict['type']
extra_spec = {} extra_spec = {}
extra_spec['hints'] = ml_model_dict.get('hints', None) extra_spec['hints'] = ml_model_dict.get('hints', None)
#ml_model_dict["model_data"] = open("/home/bharath/model.zip", "rb").read()
new_ml_model = objects.ML_Model(context, **ml_model_dict) new_ml_model = objects.ML_Model(context, **ml_model_dict)
new_ml_model.create(context) ml_model = new_ml_model.create(context)
LOG.debug(new_ml_model)
compute_api.ml_model_create(context, new_ml_model, **kwargs) #compute_api.ml_model_create(context, new_ml_model)
# Set the HTTP Location Header # Set the HTTP Location Header
pecan.response.location = link.build_url('ml_models', pecan.response.location = link.build_url('ml_models',
new_ml_model.uuid) ml_model.id)
pecan.response.status = 202 pecan.response.status = 201
return view.format_ml_model(context, pecan.request.node_url, return view.format_ml_model(context, pecan.request.host_url,
new_ml_model.as_dict()) ml_model.as_dict())
@pecan.expose('json') @pecan.expose('json')
@exception.wrap_pecan_controller_exception @exception.wrap_pecan_controller_exception
@validation.validated(schema.ml_model_update) @validation.validated(schema.ml_model_update)
@ -217,11 +243,11 @@ class MLModelController(base.Controller):
return view.format_ml_model(context, pecan.request.node_url, return view.format_ml_model(context, pecan.request.node_url,
ml_model.as_dict()) ml_model.as_dict())
@pecan.expose('json') @pecan.expose('json')
@exception.wrap_pecan_controller_exception @exception.wrap_pecan_controller_exception
@validation.validate_query_param(pecan.request, schema.query_param_delete) @validation.validate_query_param(pecan.request, schema.query_param_delete)
def delete(self, ml_model_ident, force=False, **kwargs): def delete(self, ml_model_ident, **kwargs):
"""Delete a ML Model. """Delete a ML Model.
:param ml_model_ident: UUID or Name of a ML Model. :param ml_model_ident: UUID or Name of a ML Model.
@ -230,27 +256,7 @@ class MLModelController(base.Controller):
context = pecan.request.context context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident) ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete") check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete")
try: ml_model.destroy(context)
force = strutils.bool_from_string(force, strict=True)
except ValueError:
bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
raise exception.InvalidValue(_('Valid force values are: %s')
% bools)
stop = kwargs.pop('stop', False)
try:
stop = strutils.bool_from_string(stop, strict=True)
except ValueError:
bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
raise exception.InvalidValue(_('Valid stop values are: %s')
% bools)
compute_api = pecan.request.compute_api
if not force:
utils.validate_ml_model_state(ml_model, 'delete')
ml_model.status = consts.DELETING
if ml_model.node:
compute_api.ml_model_delete(context, ml_model, force)
else:
ml_model.destroy(context)
pecan.response.status = 204 pecan.response.status = 204
@ -261,15 +267,19 @@ class MLModelController(base.Controller):
:param ml_model_ident: UUID or Name of a ML Model. :param ml_model_ident: UUID or Name of a ML Model.
""" """
context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident) ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy") check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
utils.validate_ml_model_state(ml_model, 'deploy') utils.validate_ml_model_state(ml_model, 'deploy')
LOG.debug('Calling compute.ml_model_deploy with %s', LOG.debug('Calling compute.ml_model_deploy with %s',
ml_model.uuid) ml_model.id)
context = pecan.request.context ml_model.status = consts.DEPLOYED
compute_api = pecan.request.compute_api url = pecan.request.url.replace("deploy", "predict")
compute_api.ml_model_deploy(context, ml_model) ml_model.url = url
ml_model.save(context)
pecan.response.status = 202 pecan.response.status = 202
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())
@pecan.expose('json') @pecan.expose('json')
@exception.wrap_pecan_controller_exception @exception.wrap_pecan_controller_exception
@ -278,12 +288,15 @@ class MLModelController(base.Controller):
:param ml_model_ident: UUID or Name of a ML Model. :param ml_model_ident: UUID or Name of a ML Model.
""" """
context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident) ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy") check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
utils.validate_ml_model_state(ml_model, 'undeploy') utils.validate_ml_model_state(ml_model, 'undeploy')
LOG.debug('Calling compute.ml_model_deploy with %s', LOG.debug('Calling compute.ml_model_deploy with %s',
ml_model.uuid) ml_model.id)
context = pecan.request.context ml_model.status = consts.SCHEDULED
compute_api = pecan.request.compute_api ml_model.url = None
compute_api.ml_model_undeploy(context, ml_model) ml_model.save(context)
pecan.response.status = 202 pecan.response.status = 202
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())

View File

@ -18,8 +18,11 @@ _ml_model_properties = {}
ml_model_create = { ml_model_create = {
'type': 'object', 'type': 'object',
'properties': _ml_model_properties, 'properties': {
'required': ['name'], "name": parameter_types.ml_model_name,
"type": parameter_types.ml_model_type
},
'required': ['name', 'type'],
'additionalProperties': False 'additionalProperties': False
} }
@ -46,4 +49,4 @@ query_param_delete = {
'stop': parameter_types.boolean_extended 'stop': parameter_types.boolean_extended
}, },
'additionalProperties': False 'additionalProperties': False
} }

View File

@ -95,3 +95,17 @@ hostname = {
# real systems. # real systems.
'pattern': '^[a-zA-Z0-9-._]*$', 'pattern': '^[a-zA-Z0-9-._]*$',
} }
# JSON-schema fragment for ML model names: non-empty, at most 255
# characters, restricted to alphanumerics plus '-', '.' and '_'.
ml_model_name = {
    'type': 'string',
    'minLength': 1,
    'maxLength': 255,
    'pattern': '^[a-zA-Z0-9-._]*$'
}

# Schema fragment for the model's type field (presumably the framework
# identifier, e.g. 'tensorflow' -- confirm against callers). Same
# character restrictions as ml_model_name.
ml_model_type = {
    'type': 'string',
    'minLength': 1,
    'maxLength': 255,
    'pattern': '^[a-zA-Z0-9-._]*$'
}

View File

@ -13,41 +13,46 @@
import itertools import itertools
from oslo_log import log as logging
from gyan.api.controllers import link from gyan.api.controllers import link
from gyan.common.policies import ml_model as policies from gyan.common.policies import ml_model as policies
_basic_keys = ( _basic_keys = (
'uuid', 'id',
'user_id', 'user_id',
'project_id', 'project_id',
'name', 'name',
'url', 'url',
'status', 'status',
'status_reason', 'status_reason',
'task_state', 'host_id',
'labels', 'deployed',
'host', 'ml_type'
'status_detail'
) )
LOG = logging.getLogger(__name__)
def format_ml_model(context, url, ml_model): def format_ml_model(context, url, ml_model):
def transform(key, value): def transform(key, value):
LOG.debug(key)
LOG.debug(value)
if key not in _basic_keys: if key not in _basic_keys:
return return
# strip the key if it is not allowed by policy # strip the key if it is not allowed by policy
policy_action = policies.ML_MODEL % ('get_one:%s' % key) policy_action = policies.ML_MODEL % ('get_one:%s' % key)
if not context.can(policy_action, fatal=False, might_not_exist=True): if not context.can(policy_action, fatal=False, might_not_exist=True):
return return
if key == 'uuid': if key == 'id':
yield ('uuid', value) yield ('id', value)
if url: # if url:
yield ('links', [link.make_link( # yield ('links', [link.make_link(
'self', url, 'ml_models', value), # 'self', url, 'ml_models', value),
link.make_link( # link.make_link(
'bookmark', url, # 'bookmark', url,
'ml_models', value, # 'ml_models', value,
bookmark=True)]) # bookmark=True)])
else: else:
yield (key, value) yield (key, value)

View File

@ -1,5 +1,3 @@
# Copyright © 2012 New Dream Network, LLC (DreamHost)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at

View File

@ -113,4 +113,4 @@ def version_check(action, version):
if req_version < min_version: if req_version < min_version:
raise exception.InvalidParamInVersion(param=action, raise exception.InvalidParamInVersion(param=action,
req_version=req_version, req_version=req_version,
min_version=min_version) min_version=min_version)

View File

@ -14,4 +14,7 @@
ALLOCATED = 'allocated' ALLOCATED = 'allocated'
CREATED = 'created' CREATED = 'created'
UNDEPLOYED = 'undeployed' UNDEPLOYED = 'undeployed'
DEPLOYED = 'deployed' DEPLOYED = 'deployed'
CREATING = 'CREATING'
# NOTE(review): this redefines CREATED (declared above as 'created') with
# different casing, so consts.CREATED now resolves to 'CREATED' -- confirm
# which value existing rows and VALID_STATES consumers expect.
CREATED = 'CREATED'
SCHEDULED = 'SCHEDULED'

View File

@ -106,16 +106,27 @@ rules = [
] ]
), ),
policy.DocumentedRuleDefault( policy.DocumentedRuleDefault(
name=ML_MODEL % 'upload', name=ML_MODEL % 'upload_trained_model',
check_str=base.RULE_ADMIN_OR_OWNER, check_str=base.RULE_ADMIN_OR_OWNER,
description='Upload the trained ML Model', description='Upload the trained ML Model',
operations=[ operations=[
{ {
'path': '/v1/ml_models/{ml_model_ident}/upload', 'path': '/v1/ml_models/{ml_model_ident}/upload_trained_model',
'method': 'POST' 'method': 'POST'
} }
] ]
), ),
policy.DocumentedRuleDefault(
name=ML_MODEL % 'deploy',
check_str=base.RULE_ADMIN_OR_OWNER,
description='Upload the trained ML Model',
operations=[
{
'path': '/v1/ml_models/{ml_model_ident}/deploy',
'method': 'GET'
}
]
),
] ]

View File

@ -27,7 +27,8 @@ CONF = gyan.conf.CONF
def prepare_service(argv=None): def prepare_service(argv=None):
if argv is None: if argv is None:
argv = [] argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
log.register_options(CONF) log.register_options(CONF)
config.parse_args(argv) config.parse_args(argv)
config.set_config_defaults() config.set_config_defaults()

View File

@ -23,6 +23,7 @@ import functools
import inspect import inspect
import json import json
import mimetypes import mimetypes
import os
from oslo_concurrency import processutils from oslo_concurrency import processutils
from oslo_context import context as common_context from oslo_context import context as common_context
@ -44,7 +45,7 @@ CONF = gyan.conf.CONF
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
VALID_STATES = { VALID_STATES = {
'deploy': [consts.CREATED, consts.UNDEPLOYED], 'deploy': [consts.CREATED, consts.UNDEPLOYED, consts.SCHEDULED],
'undeploy': [consts.DEPLOYED] 'undeploy': [consts.DEPLOYED]
} }
def safe_rstrip(value, chars=None): def safe_rstrip(value, chars=None):
@ -162,7 +163,7 @@ def get_ml_model(ml_model_ident):
def validate_ml_model_state(ml_model, action): def validate_ml_model_state(ml_model, action):
if ml_model.status not in VALID_STATES[action]: if ml_model.status not in VALID_STATES[action]:
raise exception.InvalidStateException( raise exception.InvalidStateException(
id=ml_model.uuid, id=ml_model.id,
action=action, action=action,
actual_state=ml_model.status) actual_state=ml_model.status)
@ -253,3 +254,12 @@ def decode_file_data(data):
return base64.b64decode(data) return base64.b64decode(data)
except (TypeError, binascii.Error): except (TypeError, binascii.Error):
raise exception.Base64Exception() raise exception.Base64Exception()
def save_model(path, model):
    """Persist a model's zipped binary payload to disk and extract it.

    Writes ``model.ml_data`` (raw bytes of a zip archive) to
    ``<path>/<model.id>.zip`` and extracts it into ``<path>/<model.id>``.

    :param path: base directory in which to store the model.
    :param model: object exposing ``id`` (string, used as the file and
        directory name) and ``ml_data`` (zip archive bytes).
    """
    # Local import: the module's import block is not visible here and the
    # hunk only added ``import os`` -- the original referenced ``zipfile``
    # without importing it, which would raise NameError at call time.
    import zipfile

    file_path = os.path.join(path, model.id)
    zip_path = file_path + '.zip'
    with open(zip_path, 'wb') as f:
        f.write(model.ml_data)
    # Context manager guarantees the archive handle is closed even if
    # extraction raises (the original leaked it on error).
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(file_path)

View File

@ -28,7 +28,6 @@ CONF = gyan.conf.CONF
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@profiler.trace_cls("rpc")
class API(object): class API(object):
"""API for interacting with the compute manager.""" """API for interacting with the compute manager."""
@ -36,10 +35,11 @@ class API(object):
self.rpcapi = rpcapi.API(context=context) self.rpcapi = rpcapi.API(context=context)
super(API, self).__init__() super(API, self).__init__()
def ml_model_create(self, context, new_ml_model, extra_spec): def ml_model_create(self, context, new_ml_model, **extra_spec):
try: try:
host_state = self._schedule_ml_model(context, ml_model, host_state = {
extra_spec) "host": "localhost"
} #self._schedule_ml_model(context, ml_model, extra_spec)
except exception.NoValidHost: except exception.NoValidHost:
new_ml_model.status = consts.ERROR new_ml_model.status = consts.ERROR
new_ml_model.status_reason = _( new_ml_model.status_reason = _(
@ -51,13 +51,17 @@ class API(object):
new_ml_model.status_reason = _("Unexpected exception occurred.") new_ml_model.status_reason = _("Unexpected exception occurred.")
new_ml_model.save(context) new_ml_model.save(context)
raise raise
LOG.debug(host_state)
self.rpcapi.ml_model_create(context, host_state['host'], return self.rpcapi.ml_model_create(context, host_state['host'],
new_ml_model) new_ml_model)
def ml_model_predict(self, context, ml_model_id, **kwargs):
    """Proxy a prediction request to the compute RPC API.

    :param context: request context.
    :param ml_model_id: UUID or name of the ML model.
    :param kwargs: prediction payload (e.g. base64-encoded input data).
    :returns: the prediction result from the compute node.
    """
    return self.rpcapi.ml_model_predict(context, ml_model_id,
                                        **kwargs)
def ml_model_delete(self, context, ml_model, *args): def ml_model_delete(self, context, ml_model, *args):
self._record_action_start(context, ml_model, ml_model_actions.DELETE) self._record_action_start(context, ml_model, ml_model_actions.DELETE)
return self.rpcapi.ml_model_delete(context, ml_model, *args) return self.rpcapi.ml_model_delete(context, ml_model, *args)
def ml_model_show(self, context, ml_model): def ml_model_show(self, context, ml_model):
return self.rpcapi.ml_model_show(context, ml_model) return self.rpcapi.ml_model_show(context, ml_model)

View File

@ -1,5 +1,3 @@
# Copyright 2016 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at
@ -12,7 +10,9 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
import base64
import itertools import itertools
import os
import six import six
import time import time
@ -49,17 +49,18 @@ class Manager(periodic_task.PeriodicTasks):
self.host = CONF.compute.host self.host = CONF.compute.host
self._resource_tracker = None self._resource_tracker = None
def ml_model_create(self, context, limits, requested_networks, def ml_model_create(self, context, ml_model):
requested_volumes, ml_model, run, pci_requests=None): db_ml_model = objects.ML_Model.get_by_uuid_db(context, ml_model["id"])
@utils.synchronized(ml_model.uuid) utils.save_model(CONF.state_path, db_ml_model)
def do_ml_model_create(): obj_ml_model = objects.ML_Model.get_by_uuid(context, ml_model["id"])
created_ml_model = self._do_ml_model_create( obj_ml_model.status = consts.SCHEDULED
context, ml_model, requested_networks, requested_volumes, obj_ml_model.status_reason = "The ML Model is scheduled and saved to the host %s" % self.host
pci_requests, limits) obj_ml_model.save(context)
if run:
self._do_ml_model_start(context, created_ml_model)
utils.spawn_n(do_ml_model_create) def ml_model_predict(self, context, ml_model_id, kwargs):
#open("/home/bharath/Documents/0.png", "wb").write(base64.b64decode(kwargs["data"]))
model_path = os.path.join(CONF.state_path, ml_model_id)
return self.driver.predict(context, model_path, base64.b64decode(kwargs["data"]))
@wrap_ml_model_event(prefix='compute') @wrap_ml_model_event(prefix='compute')
def _do_ml_model_create(self, context, ml_model, requested_networks, def _do_ml_model_create(self, context, ml_model, requested_networks,
@ -118,4 +119,4 @@ class Manager(periodic_task.PeriodicTasks):
rt = compute_host_tracker.ComputeHostTracker(self.host, rt = compute_host_tracker.ComputeHostTracker(self.host,
self.driver) self.driver)
self._resource_tracker = rt self._resource_tracker = rt
return self._resource_tracker return self._resource_tracker

View File

@ -1,5 +1,3 @@
# Copyright 2016 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at
@ -30,7 +28,6 @@ def check_ml_model_host(func):
return wrap return wrap
@profiler.trace_cls("rpc")
class API(rpc_service.API): class API(rpc_service.API):
"""Client side of the ml_model compute rpc API. """Client side of the ml_model compute rpc API.
@ -51,6 +48,10 @@ class API(rpc_service.API):
self._cast(host, 'ml_model_create', self._cast(host, 'ml_model_create',
ml_model=ml_model) ml_model=ml_model)
def ml_model_predict(self, context, ml_model_id, **kwargs):
    """RPC call to run a prediction on a compute node.

    NOTE(review): the target host is hard-coded to "localhost" rather
    than derived from the model's host (other methods cast to
    ``ml_model.host``) -- confirm before any multi-host deployment.
    """
    return self._call("localhost", 'ml_model_predict',
                      ml_model_id=ml_model_id, kwargs=kwargs)
@check_ml_model_host @check_ml_model_host
def ml_model_delete(self, context, ml_model, force): def ml_model_delete(self, context, ml_model, force):
return self._cast(ml_model.host, 'ml_model_delete', return self._cast(ml_model.host, 'ml_model_delete',

View File

@ -1,6 +1,3 @@
# Copyright 2015 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at

View File

@ -0,0 +1,39 @@
"""Add ml_type and ml_data to ml_model table
Revision ID: f3bf9414f399
Revises: cebd81b206ca
Create Date: 2018-10-13 09:48:36.783322
"""
# revision identifiers, used by Alembic.
revision = 'f3bf9414f399'
down_revision = 'cebd81b206ca'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
def upgrade():
    """Tighten compute_host columns and add ML model data/type columns."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Make the core compute_host attributes mandatory.
    with op.batch_alter_table('compute_host', schema=None) as batch_op:
        batch_op.alter_column('hostname',
                              existing_type=mysql.VARCHAR(length=255),
                              nullable=False)
        batch_op.alter_column('status',
                              existing_type=mysql.VARCHAR(length=255),
                              nullable=False)
        batch_op.alter_column('type',
                              existing_type=mysql.VARCHAR(length=255),
                              nullable=False)

    with op.batch_alter_table('ml_model', schema=None) as batch_op:
        # LONGBLOB-sized column ((2**32)-1 bytes) for the zipped model payload.
        batch_op.add_column(sa.Column('ml_data', sa.LargeBinary(length=(2**32)-1), nullable=True))
        batch_op.add_column(sa.Column('ml_type', sa.String(length=255), nullable=True))
        batch_op.add_column(sa.Column('started_at', sa.DateTime(), nullable=True))
        batch_op.create_unique_constraint('uniq_mlmodel0uuid', ['id'])
        # Drop the FK so ml_model rows no longer require a referenced parent.
        batch_op.drop_constraint(u'ml_model_ibfk_1', type_='foreignkey')
    # ### end Alembic commands ###

View File

@ -1,5 +1,3 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at
@ -13,6 +11,7 @@
# under the License. # under the License.
"""SQLAlchemy storage backend.""" """SQLAlchemy storage backend."""
from oslo_log import log as logging
from oslo_db import exception as db_exc from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session from oslo_db.sqlalchemy import session as db_session
@ -39,6 +38,7 @@ profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
CONF = gyan.conf.CONF CONF = gyan.conf.CONF
_FACADE = None _FACADE = None
LOG = logging.getLogger(__name__)
def _create_facade_lazily(): def _create_facade_lazily():
@ -90,7 +90,7 @@ def add_identity_filter(query, value):
if strutils.is_int_like(value): if strutils.is_int_like(value):
return query.filter_by(id=value) return query.filter_by(id=value)
elif uuidutils.is_uuid_like(value): elif uuidutils.is_uuid_like(value):
return query.filter_by(uuid=value) return query.filter_by(id=value)
else: else:
raise exception.InvalidIdentity(identity=value) raise exception.InvalidIdentity(identity=value)
@ -230,16 +230,17 @@ class Connection(object):
def list_ml_models(self, context, filters=None, limit=None, def list_ml_models(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None): marker=None, sort_key=None, sort_dir=None):
query = model_query(models.Capsule) query = model_query(models.ML_Model)
query = self._add_project_filters(context, query) query = self._add_project_filters(context, query)
query = self._add_ml_models_filters(query, filters) query = self._add_ml_models_filters(query, filters)
return _paginate_query(models.Capsule, limit, marker, LOG.debug(filters)
return _paginate_query(models.ML_Model, limit, marker,
sort_key, sort_dir, query) sort_key, sort_dir, query)
def create_ml_model(self, context, values): def create_ml_model(self, context, values):
# ensure defaults are present for new ml_models # ensure defaults are present for new ml_models
if not values.get('uuid'): if not values.get('id'):
values['uuid'] = uuidutils.generate_uuid() values['id'] = uuidutils.generate_uuid()
ml_model = models.ML_Model() ml_model = models.ML_Model()
ml_model.update(values) ml_model.update(values)
try: try:
@ -252,7 +253,7 @@ class Connection(object):
def get_ml_model_by_uuid(self, context, ml_model_uuid): def get_ml_model_by_uuid(self, context, ml_model_uuid):
query = model_query(models.ML_Model) query = model_query(models.ML_Model)
query = self._add_project_filters(context, query) query = self._add_project_filters(context, query)
query = query.filter_by(uuid=ml_model_uuid) query = query.filter_by(id=ml_model_uuid)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:
@ -261,7 +262,7 @@ class Connection(object):
def get_ml_model_by_name(self, context, ml_model_name): def get_ml_model_by_name(self, context, ml_model_name):
query = model_query(models.ML_Model) query = model_query(models.ML_Model)
query = self._add_project_filters(context, query) query = self._add_project_filters(context, query)
query = query.filter_by(meta_name=ml_model_name) query = query.filter_by(name=ml_model_name)
try: try:
return query.one() return query.one()
except NoResultFound: except NoResultFound:

View File

@ -31,6 +31,7 @@ from sqlalchemy import orm
from sqlalchemy import schema from sqlalchemy import schema
from sqlalchemy import sql from sqlalchemy import sql
from sqlalchemy import String from sqlalchemy import String
from sqlalchemy import LargeBinary
from sqlalchemy import Text from sqlalchemy import Text
from sqlalchemy.types import TypeDecorator, TEXT from sqlalchemy.types import TypeDecorator, TEXT
@ -120,11 +121,12 @@ class ML_Model(Base):
name = Column(String(255)) name = Column(String(255))
status = Column(String(20)) status = Column(String(20))
status_reason = Column(Text, nullable=True) status_reason = Column(Text, nullable=True)
task_state = Column(String(20)) host_id = Column(String(255), nullable=True)
host_id = Column(String(255))
status_detail = Column(String(50))
deployed = Column(String(50))
deployed = Column(Text, nullable=True) deployed = Column(Text, nullable=True)
url = Column(Text, nullable=True)
hints = Column(Text, nullable=True)
ml_type = Column(String(255), nullable=True)
ml_data = Column(LargeBinary(length=(2**32)-1), nullable=True)
started_at = Column(DateTime) started_at = Column(DateTime)
@ -138,4 +140,4 @@ class ComputeHost(Base):
id = Column(String(36), primary_key=True, nullable=False) id = Column(String(36), primary_key=True, nullable=False)
hostname = Column(String(255), nullable=False) hostname = Column(String(255), nullable=False)
status = Column(String(255), nullable=False) status = Column(String(255), nullable=False)
type = Column(String(255), nullable=False) type = Column(String(255), nullable=False)

View File

@ -15,8 +15,13 @@ import datetime
import eventlet import eventlet
import functools import functools
import types import types
import png
import os
import tempfile
import numpy as np
import tensorflow as tf
from docker import errors
from oslo_log import log as logging from oslo_log import log as logging
from oslo_utils import timeutils from oslo_utils import timeutils
from oslo_utils import uuidutils from oslo_utils import uuidutils
@ -47,6 +52,24 @@ class TensorflowDriver(driver.MLModelDriver):
return ml_model return ml_model
pass pass
def _load(self, session, path):
    """Restore a checkpointed TF graph from *path* into *session*.

    Uses the TF 1.x checkpoint API: imports ``<path>/model.meta`` and
    restores the latest checkpoint weights.
    """
    saver = tf.train.import_meta_graph(path + '/model.meta')
    saver.restore(session, tf.train.latest_checkpoint(path))
    return tf.get_default_graph()

def predict(self, context, ml_model_path, data):
    """Run inference on a single PNG image with the saved model.

    :param ml_model_path: directory holding the TF checkpoint files.
    :param data: raw PNG bytes of the input image.
    :returns: dict with the predicted value under "data".

    NOTE(review): the 1x784 reshape and the 'x:0'/'classification:0'
    tensor names hard-code an MNIST-style model layout -- confirm this
    is the only supported model. The temp file is also never removed.
    """
    session = tf.Session()
    graph = self._load(session, ml_model_path)
    # Write the payload to a temp file so pypng can read it from disk.
    img_file, img_path = tempfile.mkstemp()
    with os.fdopen(img_file, 'wb') as f:
        f.write(data)
    png_data = png.Reader(img_path)
    # read()[2] yields the pixel rows; flatten to a single 784-wide vector.
    img = np.array(list(png_data.read()[2]))
    img = img.reshape(1, 784)
    tensor = graph.get_tensor_by_name('x:0')
    prediction = graph.get_tensor_by_name('classification:0')
    return {"data": session.run(prediction, feed_dict={tensor: img})[0]}
def delete(self, context, ml_model, force): def delete(self, context, ml_model, force):
pass pass

View File

@ -43,3 +43,18 @@ class Json(fields.FieldType):
class JsonField(fields.AutoTypedField): class JsonField(fields.AutoTypedField):
AUTO_TYPE = Json() AUTO_TYPE = Json()
class ModelFieldType(fields.FieldType):
    """Pass-through field type for raw ML model binary payloads.

    All three hooks return the value untouched: the blob needs no
    coercion or (de)serialization at the object layer.
    """

    def coerce(self, obj, attr, value):
        return value

    def from_primitive(self, obj, attr, value):
        # Delegate to coerce so both paths share one (no-op) conversion.
        return self.coerce(obj, attr, value)

    def to_primitive(self, obj, attr, value):
        return value
class ModelField(fields.AutoTypedField):
    # Auto-typed field wrapper so objects can declare ml_data fields.
    AUTO_TYPE = ModelFieldType()

View File

@ -22,6 +22,7 @@ from gyan.objects import fields as z_fields
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@base.GyanObjectRegistry.register @base.GyanObjectRegistry.register
class ML_Model(base.GyanPersistentObject, base.GyanObject): class ML_Model(base.GyanPersistentObject, base.GyanObject):
VERSION = '1' VERSION = '1'
@ -35,16 +36,19 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
'status_reason': fields.StringField(nullable=True), 'status_reason': fields.StringField(nullable=True),
'url': fields.StringField(nullable=True), 'url': fields.StringField(nullable=True),
'deployed': fields.BooleanField(nullable=True), 'deployed': fields.BooleanField(nullable=True),
'node': fields.UUIDField(nullable=True),
'hints': fields.StringField(nullable=True), 'hints': fields.StringField(nullable=True),
'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True), 'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True) 'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
'ml_data': z_fields.ModelField(nullable=True),
'ml_type': fields.StringField(nullable=True)
} }
@staticmethod @staticmethod
def _from_db_object(ml_model, db_ml_model): def _from_db_object(ml_model, db_ml_model):
"""Converts a database entity to a formal object.""" """Converts a database entity to a formal object."""
for field in ml_model.fields: for field in ml_model.fields:
if 'field' == 'ml_data':
continue
setattr(ml_model, field, db_ml_model[field]) setattr(ml_model, field, db_ml_model[field])
ml_model.obj_reset_changes() ml_model.obj_reset_changes()
@ -67,6 +71,17 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid) db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
ml_model = ML_Model._from_db_object(cls(context), db_ml_model) ml_model = ML_Model._from_db_object(cls(context), db_ml_model)
return ml_model return ml_model
@base.remotable_classmethod
def get_by_uuid_db(cls, context, uuid):
    """Find a ml model based on uuid and return a :class:`ML_Model` object.

    :param uuid: the uuid of a ml model.
    :param context: Security context
    :returns: a :class:`ML_Model` object.
    """
    # Return the raw DB record without converting to an object.
    return dbapi.get_ml_model_by_uuid(context, uuid)
@base.remotable_classmethod @base.remotable_classmethod
def get_by_name(cls, context, name): def get_by_name(cls, context, name):
@ -125,7 +140,7 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
""" """
values = self.obj_get_changes() values = self.obj_get_changes()
db_ml_model = dbapi.create_ml_model(context, values) db_ml_model = dbapi.create_ml_model(context, values)
self._from_db_object(self, db_ml_model) return self._from_db_object(self, db_ml_model)
@base.remotable @base.remotable
def destroy(self, context=None): def destroy(self, context=None):
@ -138,7 +153,26 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
A context should be set when instantiating the A context should be set when instantiating the
object, e.g.: ML Model(context) object, e.g.: ML Model(context)
""" """
dbapi.destroy_ml_model(context, self.uuid) dbapi.destroy_ml_model(context, self.id)
self.obj_reset_changes()
@base.remotable
def save(self, context=None):
    """Save updates to this ML Model.

    Updates will be made column by column based on the result
    of self.what_changed().

    :param context: Security context. NOTE: This should only
                    be used internally by the indirection_api.
                    Unfortunately, RPC requires context as the first
                    argument, even though we don't use it.
                    A context should be set when instantiating the
                    object, e.g.: ML Model(context)
    """
    # Persist only the fields that changed, then clear the change tracker.
    updates = self.obj_get_changes()
    dbapi.update_ml_model(context, self.id, updates)
    self.obj_reset_changes()
def obj_load_attr(self, attrname): def obj_load_attr(self, attrname):

View File

@ -1,8 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2010-2011 OpenStack Foundation
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may # Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain # not use this file except in compliance with the License. You may obtain
# a copy of the License at # a copy of the License at

View File

@ -2,4 +2,29 @@
# of appearance. Changing the order has an impact on the overall integration # of appearance. Changing the order has an impact on the overall integration
# process, which may cause wedges in the gate later. # process, which may cause wedges in the gate later.
pbr>=2.0 # Apache-2.0 PyYAML>=3.12 # MIT
eventlet!=0.18.3,!=0.20.1,>=0.18.2 # MIT
keystonemiddleware>=4.17.0 # Apache-2.0
pecan!=1.0.2,!=1.0.3,!=1.0.4,!=1.2,>=1.0.0 # BSD
oslo.i18n>=3.15.3 # Apache-2.0
oslo.log>=3.36.0 # Apache-2.0
oslo.concurrency>=3.25.0 # Apache-2.0
oslo.config>=5.2.0 # Apache-2.0
oslo.messaging>=5.29.0 # Apache-2.0
oslo.middleware>=3.31.0 # Apache-2.0
oslo.policy>=1.30.0 # Apache-2.0
oslo.privsep>=1.23.0 # Apache-2.0
oslo.serialization!=2.19.1,>=2.18.0 # Apache-2.0
oslo.service!=1.28.1,>=1.24.0 # Apache-2.0
oslo.versionedobjects>=1.31.2 # Apache-2.0
oslo.context>=2.19.2 # Apache-2.0
oslo.utils>=3.33.0 # Apache-2.0
oslo.db>=4.27.0 # Apache-2.0
os-brick>=2.2.0 # Apache-2.0
six>=1.10.0 # MIT
SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
stevedore>=1.20.0 # Apache-2.0
pypng
numpy
tensorflow
idx2numpy

View File

@ -1,5 +1,3 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at