Add tensorflow driver implementation
Change-Id: I951cea9325d2ea4a843ea55d1731c481df899474
parent a40f556027, commit 26e2299908
@@ -305,6 +305,7 @@ function start_gyan_compute {

 function start_gyan {
     # ``run_process`` checks ``is_service_enabled``, it is not needed here
+    mkdir -p /opt/stack/data/gyan
     start_gyan_api
     start_gyan_compute
 }
@@ -82,10 +82,10 @@ class V1(controllers_base.APIBase):
                                        'hosts', '',
                                        bookmark=True)]
         v1.ml_models = [link.make_link('self', pecan.request.host_url,
-                                       'ml_models', ''),
+                                       'ml-models', ''),
                         link.make_link('bookmark',
                                        pecan.request.host_url,
-                                       'ml_models', '',
+                                       'ml-models', '',
                                        bookmark=True)]
         return v1
@@ -147,9 +147,9 @@ class Controller(controllers_base.Controller):
                {'url': pecan.request.url,
                 'method': pecan.request.method,
                 'body': pecan.request.body})
-        LOG.debug(msg)
+        # LOG.debug(msg)
+        LOG.debug(args)
         return super(Controller, self)._route(args)

-__all__ = ('Controller',)
\ No newline at end of file
+__all__ = ('Controller',)
@@ -10,6 +10,7 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+import base64
 import shlex

 from oslo_log import log as logging
@@ -74,12 +75,13 @@ class MLModelController(base.Controller):
     """Controller for MLModels."""

     _custom_actions = {
         'train': ['POST'],
         'upload_trained_model': ['POST'],
         'deploy': ['GET'],
-        'undeploy': ['GET']
+        'undeploy': ['GET'],
+        'predict': ['POST']
     }

     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     def get_all(self, **kwargs):
@@ -149,33 +151,55 @@ class MLModelController(base.Controller):
             context.all_projects = True
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one")
         if ml_model.node:
             compute_api = pecan.request.compute_api
             try:
                 ml_model = compute_api.ml_model_show(context, ml_model)
             except exception.MLModelHostNotUp:
                 raise exception.ServerNotUsable

         return view.format_ml_model(context, pecan.request.host_url,
                                     ml_model.as_dict())

     @base.Controller.api_version("1.0")
     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     def upload_trained_model(self, ml_model_ident, **kwargs):
         context = pecan.request.context
         LOG.debug(ml_model_ident)
         ml_model = utils.get_ml_model(ml_model_ident)
         LOG.debug(ml_model)
         ml_model.ml_data = pecan.request.body
         ml_model.save(context)
         pecan.response.status = 200
+        compute_api = pecan.request.compute_api
+        new_model = view.format_ml_model(context, pecan.request.host_url,
+                                         ml_model.as_dict())
+        compute_api.ml_model_create(context, new_model)
+        return new_model
+
+    @base.Controller.api_version("1.0")
+    @pecan.expose('json')
+    @exception.wrap_pecan_controller_exception
+    def predict(self, ml_model_ident, **kwargs):
+        context = pecan.request.context
+        LOG.debug(ml_model_ident)
+        ml_model = utils.get_ml_model(ml_model_ident)
+        pecan.response.status = 200
+        compute_api = pecan.request.compute_api
+        predict_dict = {
+            "data": base64.b64encode(pecan.request.POST['file'].file.read())
+        }
+        prediction = compute_api.ml_model_predict(context, ml_model_ident, **predict_dict)
+        return prediction

     @base.Controller.api_version("1.0")
     @pecan.expose('json')
     @api_utils.enforce_content_types(['application/json'])
     @exception.wrap_pecan_controller_exception
     @validation.validate_query_param(pecan.request, schema.query_param_create)
     @validation.validated(schema.ml_model_create)
     def post(self, **ml_model_dict):
         return self._do_post(**ml_model_dict)

     def _do_post(self, **ml_model_dict):
         """Create or run a new ml model.

         :param ml_model_dict: a ml_model within the request body.
         """
         context = pecan.request.context
         compute_api = pecan.request.compute_api
         policy.enforce(context, "ml_model:create",
                        action="ml_model:create")
@@ -183,22 +207,24 @@ class MLModelController(base.Controller):
         ml_model_dict['user_id'] = context.user_id
         name = ml_model_dict.get('name')
         ml_model_dict['name'] = name

-        ml_model_dict['status'] = consts.CREATING
+        ml_model_dict['status'] = consts.CREATED
+        ml_model_dict['ml_type'] = ml_model_dict['type']
         extra_spec = {}
         extra_spec['hints'] = ml_model_dict.get('hints', None)
+        #ml_model_dict["model_data"] = open("/home/bharath/model.zip", "rb").read()
         new_ml_model = objects.ML_Model(context, **ml_model_dict)
-        new_ml_model.create(context)
-
-        compute_api.ml_model_create(context, new_ml_model, **kwargs)
+        ml_model = new_ml_model.create(context)
+        LOG.debug(new_ml_model)
+        #compute_api.ml_model_create(context, new_ml_model)
         # Set the HTTP Location Header
         pecan.response.location = link.build_url('ml_models',
-                                                 new_ml_model.uuid)
-        pecan.response.status = 202
-        return view.format_ml_model(context, pecan.request.node_url,
-                                    new_ml_model.as_dict())
+                                                 ml_model.id)
+        pecan.response.status = 201
+        return view.format_ml_model(context, pecan.request.host_url,
+                                    ml_model.as_dict())

     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     @validation.validated(schema.ml_model_update)
@@ -217,11 +243,11 @@ class MLModelController(base.Controller):
         return view.format_ml_model(context, pecan.request.node_url,
                                     ml_model.as_dict())

     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     @validation.validate_query_param(pecan.request, schema.query_param_delete)
-    def delete(self, ml_model_ident, force=False, **kwargs):
+    def delete(self, ml_model_ident, **kwargs):
         """Delete a ML Model.

         :param ml_model_ident: UUID or Name of a ML Model.
@@ -230,27 +256,7 @@ class MLModelController(base.Controller):
         context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete")
-        try:
-            force = strutils.bool_from_string(force, strict=True)
-        except ValueError:
-            bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
-            raise exception.InvalidValue(_('Valid force values are: %s')
-                                         % bools)
-        stop = kwargs.pop('stop', False)
-        try:
-            stop = strutils.bool_from_string(stop, strict=True)
-        except ValueError:
-            bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
-            raise exception.InvalidValue(_('Valid stop values are: %s')
-                                         % bools)
-        compute_api = pecan.request.compute_api
-        if not force:
-            utils.validate_ml_model_state(ml_model, 'delete')
-        ml_model.status = consts.DELETING
-        if ml_model.node:
-            compute_api.ml_model_delete(context, ml_model, force)
-        else:
-            ml_model.destroy(context)
+        ml_model.destroy(context)
         pecan.response.status = 204
@@ -261,15 +267,19 @@ class MLModelController(base.Controller):

         :param ml_model_ident: UUID or Name of a ML Model.
         """
         context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
         utils.validate_ml_model_state(ml_model, 'deploy')
         LOG.debug('Calling compute.ml_model_deploy with %s',
-                  ml_model.uuid)
-        context = pecan.request.context
-        compute_api = pecan.request.compute_api
-        compute_api.ml_model_deploy(context, ml_model)
+                  ml_model.id)
+        ml_model.status = consts.DEPLOYED
+        url = pecan.request.url.replace("deploy", "predict")
+        ml_model.url = url
+        ml_model.save(context)
         pecan.response.status = 202
         return view.format_ml_model(context, pecan.request.host_url,
                                     ml_model.as_dict())

     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
@@ -278,12 +288,15 @@ class MLModelController(base.Controller):

         :param ml_model_ident: UUID or Name of a ML Model.
         """
         context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
         utils.validate_ml_model_state(ml_model, 'undeploy')
         LOG.debug('Calling compute.ml_model_deploy with %s',
-                  ml_model.uuid)
-        context = pecan.request.context
-        compute_api = pecan.request.compute_api
-        compute_api.ml_model_undeploy(context, ml_model)
+                  ml_model.id)
+        ml_model.status = consts.SCHEDULED
+        ml_model.url = None
+        ml_model.save(context)
         pecan.response.status = 202
         return view.format_ml_model(context, pecan.request.host_url,
                                     ml_model.as_dict())
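Taken together, the controller changes above give the model lifecycle a usable HTTP surface: create (POST), upload_trained_model (raw zip body), deploy (GET, which rewrites its own URL into the predict endpoint), and predict (multipart form). A minimal client walk-through — the host, port, and model id are placeholders, and the requests package is assumed to be available:

import requests

base = 'http://127.0.0.1:9512/v1/ml_models'   # hypothetical endpoint
model_id = 'REPLACE-WITH-MODEL-ID'

# Upload the zipped checkpoint as the raw request body.
with open('model.zip', 'rb') as f:
    requests.post('%s/%s/upload_trained_model' % (base, model_id), data=f)

# Deploy; the response's "url" is the deploy URL with "deploy" replaced
# by "predict" (note: a plain str.replace, so any other "deploy"
# substring in the URL would be rewritten too).
requests.get('%s/%s/deploy' % (base, model_id))

# Predict with a multipart "file" field, matching
# pecan.request.POST['file'] in the controller.
with open('digit.png', 'rb') as f:
    print(requests.post('%s/%s/predict' % (base, model_id),
                        files={'file': f}).json())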
@@ -18,8 +18,11 @@ _ml_model_properties = {}

 ml_model_create = {
     'type': 'object',
-    'properties': _ml_model_properties,
-    'required': ['name'],
+    'properties': {
+        "name": parameter_types.ml_model_name,
+        "type": parameter_types.ml_model_type
+    },
+    'required': ['name', 'type'],
     'additionalProperties': False
 }
@@ -46,4 +49,4 @@ query_param_delete = {
         'stop': parameter_types.boolean_extended
     },
     'additionalProperties': False
-}
\ No newline at end of file
+}
@@ -95,3 +95,17 @@ hostname = {
     # real systems.
     'pattern': '^[a-zA-Z0-9-._]*$',
 }
+
+ml_model_name = {
+    'type': 'string',
+    'minLength': 1,
+    'maxLength': 255,
+    'pattern': '^[a-zA-Z0-9-._]*$'
+}
+
+ml_model_type = {
+    'type': 'string',
+    'minLength': 1,
+    'maxLength': 255,
+    'pattern': '^[a-zA-Z0-9-._]*$'
+}
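With these parameter types, the create schema now rejects anything but name and type. A quick way to exercise the schema outside the service, assuming the jsonschema package that the validation layer wraps:

import jsonschema

ml_model_name = {'type': 'string', 'minLength': 1, 'maxLength': 255,
                 'pattern': '^[a-zA-Z0-9-._]*$'}
ml_model_type = dict(ml_model_name)

ml_model_create = {
    'type': 'object',
    'properties': {'name': ml_model_name, 'type': ml_model_type},
    'required': ['name', 'type'],
    'additionalProperties': False,
}

jsonschema.validate({'name': 'mnist', 'type': 'tensorflow'}, ml_model_create)

# A missing "type" or any extra key now raises ValidationError:
try:
    jsonschema.validate({'name': 'mnist'}, ml_model_create)
except jsonschema.ValidationError as e:
    print(e.message)   # "'type' is a required property"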
@@ -13,41 +13,46 @@

 import itertools

 from oslo_log import log as logging

 from gyan.api.controllers import link
 from gyan.common.policies import ml_model as policies

 _basic_keys = (
-    'uuid',
+    'id',
     'user_id',
     'project_id',
     'name',
     'url',
     'status',
     'status_reason',
-    'task_state',
-    'labels',
-    'host',
-    'status_detail'
+    'host_id',
+    'deployed',
+    'ml_type'
 )

 LOG = logging.getLogger(__name__)


 def format_ml_model(context, url, ml_model):
     def transform(key, value):
+        LOG.debug(key)
+        LOG.debug(value)
         if key not in _basic_keys:
             return
         # strip the key if it is not allowed by policy
         policy_action = policies.ML_MODEL % ('get_one:%s' % key)
         if not context.can(policy_action, fatal=False, might_not_exist=True):
             return
-        if key == 'uuid':
-            yield ('uuid', value)
-            if url:
-                yield ('links', [link.make_link(
-                    'self', url, 'ml_models', value),
-                    link.make_link(
-                        'bookmark', url,
-                        'ml_models', value,
-                        bookmark=True)])
+        if key == 'id':
+            yield ('id', value)
+            # if url:
+            #     yield ('links', [link.make_link(
+            #         'self', url, 'ml_models', value),
+            #         link.make_link(
+            #             'bookmark', url,
+            #             'ml_models', value,
+            #             bookmark=True)])
         else:
             yield (key, value)
@@ -1,5 +1,3 @@
-# Copyright © 2012 New Dream Network, LLC (DreamHost)
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -113,4 +113,4 @@ def version_check(action, version):
     if req_version < min_version:
         raise exception.InvalidParamInVersion(param=action,
                                               req_version=req_version,
-                                              min_version=min_version)
\ No newline at end of file
+                                              min_version=min_version)
@@ -14,4 +14,7 @@
 ALLOCATED = 'allocated'
 CREATED = 'created'
 UNDEPLOYED = 'undeployed'
-DEPLOYED = 'deployed'
\ No newline at end of file
+DEPLOYED = 'deployed'
+CREATING = 'CREATING'
+CREATED = 'CREATED'
+SCHEDULED = 'SCHEDULED'
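Careful with the re-binding here: the module now assigns CREATED twice, so the later uppercase value wins at import time and consts.CREATED is 'CREATED', not 'created'. In short:

CREATED = 'created'   # earlier assignment, kept from before
CREATED = 'CREATED'   # new assignment wins when the module is imported
assert CREATED == 'CREATED'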
@@ -106,16 +106,27 @@ rules = [
         ]
     ),
     policy.DocumentedRuleDefault(
-        name=ML_MODEL % 'upload',
+        name=ML_MODEL % 'upload_trained_model',
         check_str=base.RULE_ADMIN_OR_OWNER,
         description='Upload the trained ML Model',
         operations=[
             {
-                'path': '/v1/ml_models/{ml_model_ident}/upload',
+                'path': '/v1/ml_models/{ml_model_ident}/upload_trained_model',
                 'method': 'POST'
             }
         ]
     ),
+    policy.DocumentedRuleDefault(
+        name=ML_MODEL % 'deploy',
+        check_str=base.RULE_ADMIN_OR_OWNER,
+        description='Deploy the trained ML Model',
+        operations=[
+            {
+                'path': '/v1/ml_models/{ml_model_ident}/deploy',
+                'method': 'GET'
+            }
+        ]
+    ),
 ]
@@ -27,7 +27,8 @@ CONF = gyan.conf.CONF

 def prepare_service(argv=None):
     if argv is None:
         argv = []
+    argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
     log.register_options(CONF)
     config.parse_args(argv)
     config.set_config_defaults()
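Note that the added line overrides argv unconditionally, so both the `if argv is None` guard above it and any real CLI arguments passed by the console-script entry points are dead. A guarded variant, sketched with the same hard-coded defaults:

import gyan.conf
from oslo_log import log

CONF = gyan.conf.CONF

def prepare_service(argv=None):
    # Fall back to the packaged config only when the caller gave us
    # nothing, instead of clobbering real command-line arguments.
    if not argv:
        argv = ['/usr/local/bin/gyan-api',
                '--config-file', '/etc/gyan/gyan.conf']
    log.register_options(CONF)
    # ... parse args and set config defaults as in the original ...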
@@ -23,6 +23,7 @@ import functools
 import inspect
 import json
 import mimetypes
+import os

 from oslo_concurrency import processutils
 from oslo_context import context as common_context
@@ -44,7 +45,7 @@ CONF = gyan.conf.CONF
 LOG = logging.getLogger(__name__)

 VALID_STATES = {
-    'deploy': [consts.CREATED, consts.UNDEPLOYED],
+    'deploy': [consts.CREATED, consts.UNDEPLOYED, consts.SCHEDULED],
     'undeploy': [consts.DEPLOYED]
 }

 def safe_rstrip(value, chars=None):
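Adding consts.SCHEDULED to the deploy list is what lets a freshly uploaded model (which the compute manager marks SCHEDULED) be deployed, and lets an undeployed model (reset to SCHEDULED) be deployed again. A toy walk-through of the table, with constant values mirroring the consts module:

CREATED, UNDEPLOYED = 'CREATED', 'undeployed'
SCHEDULED, DEPLOYED = 'SCHEDULED', 'deployed'

VALID_STATES = {
    'deploy': [CREATED, UNDEPLOYED, SCHEDULED],
    'undeploy': [DEPLOYED],
}

def can(action, status):
    return status in VALID_STATES[action]

assert can('deploy', SCHEDULED)     # upload -> SCHEDULED -> deploy
assert can('undeploy', DEPLOYED)    # deploy -> DEPLOYED -> undeploy
assert not can('deploy', DEPLOYED)  # no double deploy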
@@ -162,7 +163,7 @@ def get_ml_model(ml_model_ident):
 def validate_ml_model_state(ml_model, action):
     if ml_model.status not in VALID_STATES[action]:
         raise exception.InvalidStateException(
-            id=ml_model.uuid,
+            id=ml_model.id,
             action=action,
             actual_state=ml_model.status)
@@ -253,3 +254,12 @@ def decode_file_data(data):
         return base64.b64decode(data)
     except (TypeError, binascii.Error):
         raise exception.Base64Exception()
+
+
+def save_model(path, model):
+    file_path = os.path.join(path, model.id)
+    with open(file_path + '.zip', 'wb') as f:
+        f.write(model.ml_data)
+    zip_ref = zipfile.ZipFile(file_path + '.zip', 'r')
+    zip_ref.extractall(file_path)
+    zip_ref.close()
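save_model also relies on zipfile, which no hunk in this change imports, so it presumably was already available in the module. A self-contained equivalent using the context-manager form, so the archive handle is closed even if extraction fails:

import os
import zipfile

def save_model(path, model):
    # Write the uploaded bytes to <path>/<model.id>.zip, then unpack
    # them into a directory of the same name next to the archive.
    file_path = os.path.join(path, model.id)
    with open(file_path + '.zip', 'wb') as f:
        f.write(model.ml_data)
    with zipfile.ZipFile(file_path + '.zip') as zip_ref:
        zip_ref.extractall(file_path)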
@@ -28,7 +28,6 @@ CONF = gyan.conf.CONF

 LOG = logging.getLogger(__name__)

-
 @profiler.trace_cls("rpc")
 class API(object):
     """API for interacting with the compute manager."""
@@ -36,10 +35,11 @@ class API(object):
         self.rpcapi = rpcapi.API(context=context)
         super(API, self).__init__()

-    def ml_model_create(self, context, new_ml_model, extra_spec):
+    def ml_model_create(self, context, new_ml_model, **extra_spec):
         try:
-            host_state = self._schedule_ml_model(context, ml_model,
-                                                 extra_spec)
+            host_state = {
+                "host": "localhost"
+            }  # self._schedule_ml_model(context, ml_model, extra_spec)
         except exception.NoValidHost:
             new_ml_model.status = consts.ERROR
             new_ml_model.status_reason = _(
@@ -51,13 +51,17 @@ class API(object):
             new_ml_model.status_reason = _("Unexpected exception occurred.")
             new_ml_model.save(context)
             raise

-        self.rpcapi.ml_model_create(context, host_state['host'],
-                                    new_ml_model)
+        LOG.debug(host_state)
+        return self.rpcapi.ml_model_create(context, host_state['host'],
+                                           new_ml_model)
+
+    def ml_model_predict(self, context, ml_model_id, **kwargs):
+        return self.rpcapi.ml_model_predict(context, ml_model_id,
+                                            **kwargs)

     def ml_model_delete(self, context, ml_model, *args):
         self._record_action_start(context, ml_model, ml_model_actions.DELETE)
         return self.rpcapi.ml_model_delete(context, ml_model, *args)

     def ml_model_show(self, context, ml_model):
-        return self.rpcapi.ml_model_show(context, ml_model)
\ No newline at end of file
+        return self.rpcapi.ml_model_show(context, ml_model)
@@ -1,5 +1,3 @@
-# Copyright 2016 IBM Corp.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -12,7 +10,9 @@
 # License for the specific language governing permissions and limitations
 # under the License.

+import base64
 import itertools
+import os

 import six
 import time
@@ -49,17 +49,18 @@ class Manager(periodic_task.PeriodicTasks):
         self.host = CONF.compute.host
         self._resource_tracker = None

-    def ml_model_create(self, context, limits, requested_networks,
-                        requested_volumes, ml_model, run, pci_requests=None):
-        @utils.synchronized(ml_model.uuid)
-        def do_ml_model_create():
-            created_ml_model = self._do_ml_model_create(
-                context, ml_model, requested_networks, requested_volumes,
-                pci_requests, limits)
-            if run:
-                self._do_ml_model_start(context, created_ml_model)
+    def ml_model_create(self, context, ml_model):
+        db_ml_model = objects.ML_Model.get_by_uuid_db(context, ml_model["id"])
+        utils.save_model(CONF.state_path, db_ml_model)
+        obj_ml_model = objects.ML_Model.get_by_uuid(context, ml_model["id"])
+        obj_ml_model.status = consts.SCHEDULED
+        obj_ml_model.status_reason = "The ML Model is scheduled and saved to the host %s" % self.host
+        obj_ml_model.save(context)

-        utils.spawn_n(do_ml_model_create)
+    def ml_model_predict(self, context, ml_model_id, kwargs):
+        #open("/home/bharath/Documents/0.png", "wb").write(base64.b64decode(kwargs["data"]))
+        model_path = os.path.join(CONF.state_path, ml_model_id)
+        return self.driver.predict(context, model_path, base64.b64decode(kwargs["data"]))

     @wrap_ml_model_event(prefix='compute')
     def _do_ml_model_create(self, context, ml_model, requested_networks,
@@ -118,4 +119,4 @@ class Manager(periodic_task.PeriodicTasks):
             rt = compute_host_tracker.ComputeHostTracker(self.host,
                                                          self.driver)
             self._resource_tracker = rt
-        return self._resource_tracker
\ No newline at end of file
+        return self._resource_tracker
@@ -1,5 +1,3 @@
-# Copyright 2016 IBM Corp.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -30,7 +28,6 @@ def check_ml_model_host(func):
     return wrap

-
 @profiler.trace_cls("rpc")
 class API(rpc_service.API):
     """Client side of the ml_model compute rpc API.
@@ -51,6 +48,10 @@ class API(rpc_service.API):
         self._cast(host, 'ml_model_create',
                    ml_model=ml_model)

+    def ml_model_predict(self, context, ml_model_id, **kwargs):
+        return self._call("localhost", 'ml_model_predict',
+                          ml_model_id=ml_model_id, kwargs=kwargs)
+
     @check_ml_model_host
     def ml_model_delete(self, context, ml_model, force):
         return self._cast(ml_model.host, 'ml_model_delete',
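Unlike the surrounding casts, predict is a synchronous _call — the caller needs the prediction back — and it is pinned to "localhost", consistent with the stubbed-out scheduler in the compute API, which places every model on localhost until _schedule_ml_model is reinstated. The keyword bundle also travels as a single argument literally named kwargs, which is why the manager declares it positionally. A toy illustration of that calling convention:

# Mirrors Manager.ml_model_predict(self, context, ml_model_id, kwargs):
# "kwargs" is one ordinary dict parameter, not a **kwargs expansion.
def ml_model_predict(context, ml_model_id, kwargs):
    return {'model': ml_model_id, 'keys': sorted(kwargs)}

print(ml_model_predict(None, 'abc123', {'data': 'base64-bytes'}))
# -> {'model': 'abc123', 'keys': ['data']}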
@@ -1,6 +1,3 @@
-# Copyright 2015 OpenStack Foundation
-# All Rights Reserved.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -0,0 +1,39 @@
+"""Add ml_type and ml_data to ml_model table
+
+Revision ID: f3bf9414f399
+Revises: cebd81b206ca
+Create Date: 2018-10-13 09:48:36.783322
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'f3bf9414f399'
+down_revision = 'cebd81b206ca'
+branch_labels = None
+depends_on = None
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('compute_host', schema=None) as batch_op:
+        batch_op.alter_column('hostname',
+                              existing_type=mysql.VARCHAR(length=255),
+                              nullable=False)
+        batch_op.alter_column('status',
+                              existing_type=mysql.VARCHAR(length=255),
+                              nullable=False)
+        batch_op.alter_column('type',
+                              existing_type=mysql.VARCHAR(length=255),
+                              nullable=False)
+
+    with op.batch_alter_table('ml_model', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('ml_data', sa.LargeBinary(length=(2**32)-1), nullable=True))
+        batch_op.add_column(sa.Column('ml_type', sa.String(length=255), nullable=True))
+        batch_op.add_column(sa.Column('started_at', sa.DateTime(), nullable=True))
+        batch_op.create_unique_constraint('uniq_mlmodel0uuid', ['id'])
+        batch_op.drop_constraint(u'ml_model_ibfk_1', type_='foreignkey')
+
+    # ### end Alembic commands ###
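The revision can be applied with the project's usual db-sync tooling or with plain Alembic, provided an alembic.ini points at the gyan database (the config path below is illustrative):

from alembic import command
from alembic.config import Config

cfg = Config('/etc/gyan/alembic.ini')  # hypothetical config location
command.upgrade(cfg, 'f3bf9414f399')   # or simply 'head'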
@@ -1,5 +1,3 @@
-# Copyright 2013 Hewlett-Packard Development Company, L.P.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -13,6 +11,7 @@
 # under the License.

 """SQLAlchemy storage backend."""
+from oslo_log import log as logging

 from oslo_db import exception as db_exc
 from oslo_db.sqlalchemy import session as db_session
@@ -39,6 +38,7 @@ profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
 CONF = gyan.conf.CONF

 _FACADE = None
+LOG = logging.getLogger(__name__)


 def _create_facade_lazily():
@@ -90,7 +90,7 @@ def add_identity_filter(query, value):
     if strutils.is_int_like(value):
         return query.filter_by(id=value)
     elif uuidutils.is_uuid_like(value):
-        return query.filter_by(uuid=value)
+        return query.filter_by(id=value)
     else:
         raise exception.InvalidIdentity(identity=value)
@@ -230,16 +230,17 @@ class Connection(object):

     def list_ml_models(self, context, filters=None, limit=None,
                        marker=None, sort_key=None, sort_dir=None):
-        query = model_query(models.Capsule)
+        query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
         query = self._add_ml_models_filters(query, filters)
-        return _paginate_query(models.Capsule, limit, marker,
+        LOG.debug(filters)
+        return _paginate_query(models.ML_Model, limit, marker,
                                sort_key, sort_dir, query)

     def create_ml_model(self, context, values):
         # ensure defaults are present for new ml_models
-        if not values.get('uuid'):
-            values['uuid'] = uuidutils.generate_uuid()
+        if not values.get('id'):
+            values['id'] = uuidutils.generate_uuid()
         ml_model = models.ML_Model()
         ml_model.update(values)
         try:
@@ -252,7 +253,7 @@ class Connection(object):
     def get_ml_model_by_uuid(self, context, ml_model_uuid):
         query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
-        query = query.filter_by(uuid=ml_model_uuid)
+        query = query.filter_by(id=ml_model_uuid)
         try:
             return query.one()
         except NoResultFound:
@@ -261,7 +262,7 @@ class Connection(object):
     def get_ml_model_by_name(self, context, ml_model_name):
         query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
-        query = query.filter_by(meta_name=ml_model_name)
+        query = query.filter_by(name=ml_model_name)
         try:
             return query.one()
         except NoResultFound:
@@ -31,6 +31,7 @@ from sqlalchemy import orm
 from sqlalchemy import schema
 from sqlalchemy import sql
 from sqlalchemy import String
+from sqlalchemy import LargeBinary
 from sqlalchemy import Text
 from sqlalchemy.types import TypeDecorator, TEXT
@@ -120,11 +121,12 @@ class ML_Model(Base):
     name = Column(String(255))
     status = Column(String(20))
     status_reason = Column(Text, nullable=True)
-    task_state = Column(String(20))
-    host_id = Column(String(255))
-    status_detail = Column(String(50))
-    deployed = Column(String(50))
+    host_id = Column(String(255), nullable=True)
+    deployed = Column(Text, nullable=True)
+    url = Column(Text, nullable=True)
     hints = Column(Text, nullable=True)
+    ml_type = Column(String(255), nullable=True)
+    ml_data = Column(LargeBinary(length=(2**32)-1), nullable=True)
+    started_at = Column(DateTime)
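A sizing note: LargeBinary(length=(2**32)-1) asks for the largest binary column MySQL offers — MySQL promotes any BLOB declared longer than 16,777,215 bytes to LONGBLOB (4 GiB ceiling), which is what the auto-generated migration creates for ml_data. A quick way to see what the column renders as (output may be BLOB(4294967295) or LONGBLOB depending on the SQLAlchemy version; MySQL stores either as LONGBLOB):

import sqlalchemy as sa
from sqlalchemy.dialects import mysql

print(sa.LargeBinary(length=(2 ** 32) - 1).compile(dialect=mysql.dialect()))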
@@ -138,4 +140,4 @@ class ComputeHost(Base):
     id = Column(String(36), primary_key=True, nullable=False)
     hostname = Column(String(255), nullable=False)
     status = Column(String(255), nullable=False)
-    type = Column(String(255), nullable=False)
\ No newline at end of file
+    type = Column(String(255), nullable=False)
@@ -15,8 +15,13 @@ import datetime
 import eventlet
 import functools
 import types
+import png
+import os
+import tempfile
+import numpy as np

+import tensorflow as tf

 from docker import errors
 from oslo_log import log as logging
 from oslo_utils import timeutils
 from oslo_utils import uuidutils
@@ -47,6 +52,24 @@ class TensorflowDriver(driver.MLModelDriver):
         return ml_model
         pass

+    def _load(self, session, path):
+        saver = tf.train.import_meta_graph(path + '/model.meta')
+        saver.restore(session, tf.train.latest_checkpoint(path))
+        return tf.get_default_graph()
+
+    def predict(self, context, ml_model_path, data):
+        session = tf.Session()
+        graph = self._load(session, ml_model_path)
+        img_file, img_path = tempfile.mkstemp()
+        with os.fdopen(img_file, 'wb') as f:
+            f.write(data)
+        png_data = png.Reader(img_path)
+        img = np.array(list(png_data.read()[2]))
+        img = img.reshape(1, 784)
+        tensor = graph.get_tensor_by_name('x:0')
+        prediction = graph.get_tensor_by_name('classification:0')
+        return {"data": session.run(prediction, feed_dict={tensor: img})[0]}
+
     def delete(self, context, ml_model, force):
         pass
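predict() assumes an MNIST-style 28x28 single-channel PNG (hence the reshape to (1, 784)) and a checkpoint whose graph exposes an input tensor named x:0 and an output named classification:0. A minimal TF 1.x sketch of a graph that satisfies those expectations — the layer sizes, training step, and save path are illustrative only:

import tensorflow as tf  # TensorFlow 1.x API, as used by the driver

x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
tf.argmax(logits, axis=1, name='classification')  # tensor 'classification:0'

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # ... training would happen here ...
    # Writes model.meta plus checkpoint data under the model directory,
    # which is what _load() restores via tf.train.latest_checkpoint().
    saver.save(sess, '/opt/stack/data/gyan/MODEL_ID/model')  # placeholder id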
@@ -43,3 +43,18 @@ class Json(fields.FieldType):

 class JsonField(fields.AutoTypedField):
     AUTO_TYPE = Json()
+
+
+class ModelFieldType(fields.FieldType):
+    def coerce(self, obj, attr, value):
+        return value
+
+    def from_primitive(self, obj, attr, value):
+        return self.coerce(obj, attr, value)
+
+    def to_primitive(self, obj, attr, value):
+        return value
+
+
+class ModelField(fields.AutoTypedField):
+    AUTO_TYPE = ModelFieldType()
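ModelFieldType is a deliberate pass-through: no coercion and no (de)serialization, so the raw model bytes assigned to ml_data cross the object layer untouched. A minimal round-trip check, assuming oslo.versionedobjects is installed:

from oslo_versionedobjects import fields


class ModelFieldType(fields.FieldType):
    def coerce(self, obj, attr, value):
        return value


class ModelField(fields.AutoTypedField):
    AUTO_TYPE = ModelFieldType()


field = ModelField(nullable=True)
blob = b'PK\x03\x04...'  # e.g. the magic bytes of an uploaded zip
assert field.coerce(None, 'ml_data', blob) is blob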
@@ -22,6 +22,7 @@ from gyan.objects import fields as z_fields

 LOG = logging.getLogger(__name__)

+
 @base.GyanObjectRegistry.register
 class ML_Model(base.GyanPersistentObject, base.GyanObject):
     VERSION = '1'
@@ -35,16 +36,19 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         'status_reason': fields.StringField(nullable=True),
         'url': fields.StringField(nullable=True),
         'deployed': fields.BooleanField(nullable=True),
         'node': fields.UUIDField(nullable=True),
         'hints': fields.StringField(nullable=True),
         'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
-        'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True)
+        'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
+        'ml_data': z_fields.ModelField(nullable=True),
+        'ml_type': fields.StringField(nullable=True)
     }
     @staticmethod
     def _from_db_object(ml_model, db_ml_model):
         """Converts a database entity to a formal object."""
         for field in ml_model.fields:
+            if field == 'ml_data':
+                continue
             setattr(ml_model, field, db_ml_model[field])

         ml_model.obj_reset_changes()
@@ -67,6 +71,17 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
         ml_model = ML_Model._from_db_object(cls(context), db_ml_model)
         return ml_model

+    @base.remotable_classmethod
+    def get_by_uuid_db(cls, context, uuid):
+        """Find a ml model based on uuid and return the raw database record.
+
+        :param uuid: the uuid of a ml model.
+        :param context: Security context
+        :returns: the database row for the ml model, not a :class:`ML_Model`.
+        """
+        db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
+        return db_ml_model
+
     @base.remotable_classmethod
     def get_by_name(cls, context, name):
@@ -125,7 +140,7 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         """
         values = self.obj_get_changes()
         db_ml_model = dbapi.create_ml_model(context, values)
-        self._from_db_object(self, db_ml_model)
+        return self._from_db_object(self, db_ml_model)

     @base.remotable
     def destroy(self, context=None):
@@ -138,7 +153,26 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         A context should be set when instantiating the
         object, e.g.: ML Model(context)
         """
-        dbapi.destroy_ml_model(context, self.uuid)
+        dbapi.destroy_ml_model(context, self.id)
         self.obj_reset_changes()

+    @base.remotable
+    def save(self, context=None):
+        """Save updates to this ML Model.
+
+        Updates will be made column by column based on the result
+        of self.what_changed().
+
+        :param context: Security context. NOTE: This should only
+                        be used internally by the indirection_api.
+                        Unfortunately, RPC requires context as the first
+                        argument, even though we don't use it.
+                        A context should be set when instantiating the
+                        object, e.g.: ML Model(context)
+        """
+        updates = self.obj_get_changes()
+        dbapi.update_ml_model(context, self.id, updates)
+
+        self.obj_reset_changes()

     def obj_load_attr(self, attrname):
@@ -1,8 +1,3 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2010-2011 OpenStack Foundation
-# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -2,4 +2,29 @@
 # of appearance. Changing the order has an impact on the overall integration
 # process, which may cause wedges in the gate later.

 pbr>=2.0 # Apache-2.0
+PyYAML>=3.12 # MIT
+eventlet!=0.18.3,!=0.20.1,>=0.18.2 # MIT
+keystonemiddleware>=4.17.0 # Apache-2.0
+pecan!=1.0.2,!=1.0.3,!=1.0.4,!=1.2,>=1.0.0 # BSD
+oslo.i18n>=3.15.3 # Apache-2.0
+oslo.log>=3.36.0 # Apache-2.0
+oslo.concurrency>=3.25.0 # Apache-2.0
+oslo.config>=5.2.0 # Apache-2.0
+oslo.messaging>=5.29.0 # Apache-2.0
+oslo.middleware>=3.31.0 # Apache-2.0
+oslo.policy>=1.30.0 # Apache-2.0
+oslo.privsep>=1.23.0 # Apache-2.0
+oslo.serialization!=2.19.1,>=2.18.0 # Apache-2.0
+oslo.service!=1.28.1,>=1.24.0 # Apache-2.0
+oslo.versionedobjects>=1.31.2 # Apache-2.0
+oslo.context>=2.19.2 # Apache-2.0
+oslo.utils>=3.33.0 # Apache-2.0
+oslo.db>=4.27.0 # Apache-2.0
+os-brick>=2.2.0 # Apache-2.0
+six>=1.10.0 # MIT
+SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
+stevedore>=1.20.0 # Apache-2.0
+pypng
+numpy
+tensorflow
+idx2numpy