Add tensorflow driver implementation

Change-Id: I951cea9325d2ea4a843ea55d1731c481df899474
This commit is contained in:
bharath 2018-10-14 00:22:56 +05:30
parent a40f556027
commit 26e2299908
25 changed files with 332 additions and 138 deletions

View File

@ -305,6 +305,7 @@ function start_gyan_compute {
function start_gyan {
# ``run_process`` checks ``is_service_enabled``, it is not needed here
mkdir -p /opt/stack/data/gyan
start_gyan_api
start_gyan_compute
}

View File

@ -82,10 +82,10 @@ class V1(controllers_base.APIBase):
'hosts', '',
bookmark=True)]
v1.ml_models = [link.make_link('self', pecan.request.host_url,
'ml_models', ''),
'ml-models', ''),
link.make_link('bookmark',
pecan.request.host_url,
'ml_models', '',
'ml-models', '',
bookmark=True)]
return v1
@ -147,9 +147,9 @@ class Controller(controllers_base.Controller):
{'url': pecan.request.url,
'method': pecan.request.method,
'body': pecan.request.body})
LOG.debug(msg)
# LOG.debug(msg)
LOG.debug(args)
return super(Controller, self)._route(args)
__all__ = ('Controller',)
__all__ = ('Controller',)

View File

@ -10,6 +10,7 @@
# License for the specific language governing permissions and limitations
# under the License.
import base64
import shlex
from oslo_log import log as logging
@ -74,12 +75,13 @@ class MLModelController(base.Controller):
"""Controller for MLModels."""
_custom_actions = {
'train': ['POST'],
'upload_trained_model': ['POST'],
'deploy': ['GET'],
'undeploy': ['GET']
'undeploy': ['GET'],
'predict': ['POST']
}
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
def get_all(self, **kwargs):
@ -149,33 +151,55 @@ class MLModelController(base.Controller):
context.all_projects = True
ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one")
if ml_model.node:
compute_api = pecan.request.compute_api
try:
ml_model = compute_api.ml_model_show(context, ml_model)
except exception.MLModelHostNotUp:
raise exception.ServerNotUsable
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())
@base.Controller.api_version("1.0")
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
def upload_trained_model(self, ml_model_ident, **kwargs):
context = pecan.request.context
LOG.debug(ml_model_ident)
ml_model = utils.get_ml_model(ml_model_ident)
LOG.debug(ml_model)
ml_model.ml_data = pecan.request.body
ml_model.save(context)
pecan.response.status = 200
compute_api = pecan.request.compute_api
new_model = view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())
compute_api.ml_model_create(context, new_model)
return new_model
@base.Controller.api_version("1.0")
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
def predict(self, ml_model_ident, **kwargs):
context = pecan.request.context
LOG.debug(ml_model_ident)
ml_model = utils.get_ml_model(ml_model_ident)
pecan.response.status = 200
compute_api = pecan.request.compute_api
predict_dict = {
"data": base64.b64encode(pecan.request.POST['file'].file.read())
}
prediction = compute_api.ml_model_predict(context, ml_model_ident, **predict_dict)
return prediction
@base.Controller.api_version("1.0")
@pecan.expose('json')
@api_utils.enforce_content_types(['application/json'])
@exception.wrap_pecan_controller_exception
@validation.validate_query_param(pecan.request, schema.query_param_create)
@validation.validated(schema.ml_model_create)
def post(self, **ml_model_dict):
return self._do_post(**ml_model_dict)
def _do_post(self, **ml_model_dict):
"""Create or run a new ml model.
:param ml_model_dict: a ml_model within the request body.
"""
context = pecan.request.context
compute_api = pecan.request.compute_api
policy.enforce(context, "ml_model:create",
action="ml_model:create")
@ -183,22 +207,24 @@ class MLModelController(base.Controller):
ml_model_dict['user_id'] = context.user_id
name = ml_model_dict.get('name')
ml_model_dict['name'] = name
ml_model_dict['status'] = consts.CREATING
ml_model_dict['status'] = consts.CREATED
ml_model_dict['ml_type'] = ml_model_dict['type']
extra_spec = {}
extra_spec['hints'] = ml_model_dict.get('hints', None)
#ml_model_dict["model_data"] = open("/home/bharath/model.zip", "rb").read()
new_ml_model = objects.ML_Model(context, **ml_model_dict)
new_ml_model.create(context)
compute_api.ml_model_create(context, new_ml_model, **kwargs)
ml_model = new_ml_model.create(context)
LOG.debug(new_ml_model)
#compute_api.ml_model_create(context, new_ml_model)
# Set the HTTP Location Header
pecan.response.location = link.build_url('ml_models',
new_ml_model.uuid)
pecan.response.status = 202
return view.format_ml_model(context, pecan.request.node_url,
new_ml_model.as_dict())
ml_model.id)
pecan.response.status = 201
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
@validation.validated(schema.ml_model_update)
@ -217,11 +243,11 @@ class MLModelController(base.Controller):
return view.format_ml_model(context, pecan.request.node_url,
ml_model.as_dict())
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
@validation.validate_query_param(pecan.request, schema.query_param_delete)
def delete(self, ml_model_ident, force=False, **kwargs):
def delete(self, ml_model_ident, **kwargs):
"""Delete a ML Model.
:param ml_model_ident: UUID or Name of a ML Model.
@ -230,27 +256,7 @@ class MLModelController(base.Controller):
context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete")
try:
force = strutils.bool_from_string(force, strict=True)
except ValueError:
bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
raise exception.InvalidValue(_('Valid force values are: %s')
% bools)
stop = kwargs.pop('stop', False)
try:
stop = strutils.bool_from_string(stop, strict=True)
except ValueError:
bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
raise exception.InvalidValue(_('Valid stop values are: %s')
% bools)
compute_api = pecan.request.compute_api
if not force:
utils.validate_ml_model_state(ml_model, 'delete')
ml_model.status = consts.DELETING
if ml_model.node:
compute_api.ml_model_delete(context, ml_model, force)
else:
ml_model.destroy(context)
ml_model.destroy(context)
pecan.response.status = 204
@ -261,15 +267,19 @@ class MLModelController(base.Controller):
:param ml_model_ident: UUID or Name of a ML Model.
"""
context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
utils.validate_ml_model_state(ml_model, 'deploy')
LOG.debug('Calling compute.ml_model_deploy with %s',
ml_model.uuid)
context = pecan.request.context
compute_api = pecan.request.compute_api
compute_api.ml_model_deploy(context, ml_model)
ml_model.id)
ml_model.status = consts.DEPLOYED
url = pecan.request.url.replace("deploy", "predict")
ml_model.url = url
ml_model.save(context)
pecan.response.status = 202
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())
@pecan.expose('json')
@exception.wrap_pecan_controller_exception
@ -278,12 +288,15 @@ class MLModelController(base.Controller):
:param ml_model_ident: UUID or Name of a ML Model.
"""
context = pecan.request.context
ml_model = utils.get_ml_model(ml_model_ident)
check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
utils.validate_ml_model_state(ml_model, 'undeploy')
LOG.debug('Calling compute.ml_model_deploy with %s',
ml_model.uuid)
context = pecan.request.context
compute_api = pecan.request.compute_api
compute_api.ml_model_undeploy(context, ml_model)
ml_model.id)
ml_model.status = consts.SCHEDULED
ml_model.url = None
ml_model.save(context)
pecan.response.status = 202
return view.format_ml_model(context, pecan.request.host_url,
ml_model.as_dict())

View File

@ -18,8 +18,11 @@ _ml_model_properties = {}
ml_model_create = {
'type': 'object',
'properties': _ml_model_properties,
'required': ['name'],
'properties': {
"name": parameter_types.ml_model_name,
"type": parameter_types.ml_model_type
},
'required': ['name', 'type'],
'additionalProperties': False
}
@ -46,4 +49,4 @@ query_param_delete = {
'stop': parameter_types.boolean_extended
},
'additionalProperties': False
}
}

View File

@ -95,3 +95,17 @@ hostname = {
# real systems.
'pattern': '^[a-zA-Z0-9-._]*$',
}
ml_model_name = {
'type': 'string',
'minLength': 1,
'maxLength': 255,
'pattern': '^[a-zA-Z0-9-._]*$'
}
ml_model_type = {
'type': 'string',
'minLength': 1,
'maxLength': 255,
'pattern': '^[a-zA-Z0-9-._]*$'
}

View File

@ -13,41 +13,46 @@
import itertools
from oslo_log import log as logging
from gyan.api.controllers import link
from gyan.common.policies import ml_model as policies
_basic_keys = (
'uuid',
'id',
'user_id',
'project_id',
'name',
'url',
'status',
'status_reason',
'task_state',
'labels',
'host',
'status_detail'
'host_id',
'deployed',
'ml_type'
)
LOG = logging.getLogger(__name__)
def format_ml_model(context, url, ml_model):
def transform(key, value):
LOG.debug(key)
LOG.debug(value)
if key not in _basic_keys:
return
# strip the key if it is not allowed by policy
policy_action = policies.ML_MODEL % ('get_one:%s' % key)
if not context.can(policy_action, fatal=False, might_not_exist=True):
return
if key == 'uuid':
yield ('uuid', value)
if url:
yield ('links', [link.make_link(
'self', url, 'ml_models', value),
link.make_link(
'bookmark', url,
'ml_models', value,
bookmark=True)])
if key == 'id':
yield ('id', value)
# if url:
# yield ('links', [link.make_link(
# 'self', url, 'ml_models', value),
# link.make_link(
# 'bookmark', url,
# 'ml_models', value,
# bookmark=True)])
else:
yield (key, value)

View File

@ -1,5 +1,3 @@
# Copyright ? 2012 New Dream Network, LLC (DreamHost)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at

View File

@ -113,4 +113,4 @@ def version_check(action, version):
if req_version < min_version:
raise exception.InvalidParamInVersion(param=action,
req_version=req_version,
min_version=min_version)
min_version=min_version)

View File

@ -14,4 +14,7 @@
ALLOCATED = 'allocated'
CREATED = 'created'
UNDEPLOYED = 'undeployed'
DEPLOYED = 'deployed'
DEPLOYED = 'deployed'
CREATING = 'CREATING'
CREATED = 'CREATED'
SCHEDULED = 'SCHEDULED'

View File

@ -106,16 +106,27 @@ rules = [
]
),
policy.DocumentedRuleDefault(
name=ML_MODEL % 'upload',
name=ML_MODEL % 'upload_trained_model',
check_str=base.RULE_ADMIN_OR_OWNER,
description='Upload the trained ML Model',
operations=[
{
'path': '/v1/ml_models/{ml_model_ident}/upload',
'path': '/v1/ml_models/{ml_model_ident}/upload_trained_model',
'method': 'POST'
}
]
),
policy.DocumentedRuleDefault(
name=ML_MODEL % 'deploy',
check_str=base.RULE_ADMIN_OR_OWNER,
description='Upload the trained ML Model',
operations=[
{
'path': '/v1/ml_models/{ml_model_ident}/deploy',
'method': 'GET'
}
]
),
]

View File

@ -27,7 +27,8 @@ CONF = gyan.conf.CONF
def prepare_service(argv=None):
if argv is None:
argv = []
argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
log.register_options(CONF)
config.parse_args(argv)
config.set_config_defaults()

View File

@ -23,6 +23,7 @@ import functools
import inspect
import json
import mimetypes
import os
from oslo_concurrency import processutils
from oslo_context import context as common_context
@ -44,7 +45,7 @@ CONF = gyan.conf.CONF
LOG = logging.getLogger(__name__)
VALID_STATES = {
'deploy': [consts.CREATED, consts.UNDEPLOYED],
'deploy': [consts.CREATED, consts.UNDEPLOYED, consts.SCHEDULED],
'undeploy': [consts.DEPLOYED]
}
def safe_rstrip(value, chars=None):
@ -162,7 +163,7 @@ def get_ml_model(ml_model_ident):
def validate_ml_model_state(ml_model, action):
if ml_model.status not in VALID_STATES[action]:
raise exception.InvalidStateException(
id=ml_model.uuid,
id=ml_model.id,
action=action,
actual_state=ml_model.status)
@ -253,3 +254,12 @@ def decode_file_data(data):
return base64.b64decode(data)
except (TypeError, binascii.Error):
raise exception.Base64Exception()
def save_model(path, model):
file_path = os.path.join(path, model.id)
with open(file_path+'.zip', 'wb') as f:
f.write(model.ml_data)
zip_ref = zipfile.ZipFile(file_path+'.zip', 'r')
zip_ref.extractall(file_path)
zip_ref.close()

View File

@ -28,7 +28,6 @@ CONF = gyan.conf.CONF
LOG = logging.getLogger(__name__)
@profiler.trace_cls("rpc")
class API(object):
"""API for interacting with the compute manager."""
@ -36,10 +35,11 @@ class API(object):
self.rpcapi = rpcapi.API(context=context)
super(API, self).__init__()
def ml_model_create(self, context, new_ml_model, extra_spec):
def ml_model_create(self, context, new_ml_model, **extra_spec):
try:
host_state = self._schedule_ml_model(context, ml_model,
extra_spec)
host_state = {
"host": "localhost"
} #self._schedule_ml_model(context, ml_model, extra_spec)
except exception.NoValidHost:
new_ml_model.status = consts.ERROR
new_ml_model.status_reason = _(
@ -51,13 +51,17 @@ class API(object):
new_ml_model.status_reason = _("Unexpected exception occurred.")
new_ml_model.save(context)
raise
self.rpcapi.ml_model_create(context, host_state['host'],
LOG.debug(host_state)
return self.rpcapi.ml_model_create(context, host_state['host'],
new_ml_model)
def ml_model_predict(self, context, ml_model_id, **kwargs):
return self.rpcapi.ml_model_predict(context, ml_model_id,
**kwargs)
def ml_model_delete(self, context, ml_model, *args):
self._record_action_start(context, ml_model, ml_model_actions.DELETE)
return self.rpcapi.ml_model_delete(context, ml_model, *args)
def ml_model_show(self, context, ml_model):
return self.rpcapi.ml_model_show(context, ml_model)
return self.rpcapi.ml_model_show(context, ml_model)

View File

@ -1,5 +1,3 @@
# Copyright 2016 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@ -12,7 +10,9 @@
# License for the specific language governing permissions and limitations
# under the License.
import base64
import itertools
import os
import six
import time
@ -49,17 +49,18 @@ class Manager(periodic_task.PeriodicTasks):
self.host = CONF.compute.host
self._resource_tracker = None
def ml_model_create(self, context, limits, requested_networks,
requested_volumes, ml_model, run, pci_requests=None):
@utils.synchronized(ml_model.uuid)
def do_ml_model_create():
created_ml_model = self._do_ml_model_create(
context, ml_model, requested_networks, requested_volumes,
pci_requests, limits)
if run:
self._do_ml_model_start(context, created_ml_model)
def ml_model_create(self, context, ml_model):
db_ml_model = objects.ML_Model.get_by_uuid_db(context, ml_model["id"])
utils.save_model(CONF.state_path, db_ml_model)
obj_ml_model = objects.ML_Model.get_by_uuid(context, ml_model["id"])
obj_ml_model.status = consts.SCHEDULED
obj_ml_model.status_reason = "The ML Model is scheduled and saved to the host %s" % self.host
obj_ml_model.save(context)
utils.spawn_n(do_ml_model_create)
def ml_model_predict(self, context, ml_model_id, kwargs):
#open("/home/bharath/Documents/0.png", "wb").write(base64.b64decode(kwargs["data"]))
model_path = os.path.join(CONF.state_path, ml_model_id)
return self.driver.predict(context, model_path, base64.b64decode(kwargs["data"]))
@wrap_ml_model_event(prefix='compute')
def _do_ml_model_create(self, context, ml_model, requested_networks,
@ -118,4 +119,4 @@ class Manager(periodic_task.PeriodicTasks):
rt = compute_host_tracker.ComputeHostTracker(self.host,
self.driver)
self._resource_tracker = rt
return self._resource_tracker
return self._resource_tracker

View File

@ -1,5 +1,3 @@
# Copyright 2016 IBM Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@ -30,7 +28,6 @@ def check_ml_model_host(func):
return wrap
@profiler.trace_cls("rpc")
class API(rpc_service.API):
"""Client side of the ml_model compute rpc API.
@ -51,6 +48,10 @@ class API(rpc_service.API):
self._cast(host, 'ml_model_create',
ml_model=ml_model)
def ml_model_predict(self, context, ml_model_id, **kwargs):
return self._call("localhost", 'ml_model_predict',
ml_model_id=ml_model_id, kwargs=kwargs)
@check_ml_model_host
def ml_model_delete(self, context, ml_model, force):
return self._cast(ml_model.host, 'ml_model_delete',

View File

@ -1,6 +1,3 @@
# Copyright 2015 OpenStack Foundation
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at

View File

@ -0,0 +1,39 @@
"""Add ml_type and ml_data to ml_model table
Revision ID: f3bf9414f399
Revises: cebd81b206ca
Create Date: 2018-10-13 09:48:36.783322
"""
# revision identifiers, used by Alembic.
revision = 'f3bf9414f399'
down_revision = 'cebd81b206ca'
branch_labels = None
depends_on = None
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import mysql
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('compute_host', schema=None) as batch_op:
batch_op.alter_column('hostname',
existing_type=mysql.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('status',
existing_type=mysql.VARCHAR(length=255),
nullable=False)
batch_op.alter_column('type',
existing_type=mysql.VARCHAR(length=255),
nullable=False)
with op.batch_alter_table('ml_model', schema=None) as batch_op:
batch_op.add_column(sa.Column('ml_data', sa.LargeBinary(length=(2**32)-1), nullable=True))
batch_op.add_column(sa.Column('ml_type', sa.String(length=255), nullable=True))
batch_op.add_column(sa.Column('started_at', sa.DateTime(), nullable=True))
batch_op.create_unique_constraint('uniq_mlmodel0uuid', ['id'])
batch_op.drop_constraint(u'ml_model_ibfk_1', type_='foreignkey')
# ### end Alembic commands ###

View File

@ -1,5 +1,3 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
@ -13,6 +11,7 @@
# under the License.
"""SQLAlchemy storage backend."""
from oslo_log import log as logging
from oslo_db import exception as db_exc
from oslo_db.sqlalchemy import session as db_session
@ -39,6 +38,7 @@ profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
CONF = gyan.conf.CONF
_FACADE = None
LOG = logging.getLogger(__name__)
def _create_facade_lazily():
@ -90,7 +90,7 @@ def add_identity_filter(query, value):
if strutils.is_int_like(value):
return query.filter_by(id=value)
elif uuidutils.is_uuid_like(value):
return query.filter_by(uuid=value)
return query.filter_by(id=value)
else:
raise exception.InvalidIdentity(identity=value)
@ -230,16 +230,17 @@ class Connection(object):
def list_ml_models(self, context, filters=None, limit=None,
marker=None, sort_key=None, sort_dir=None):
query = model_query(models.Capsule)
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = self._add_ml_models_filters(query, filters)
return _paginate_query(models.Capsule, limit, marker,
LOG.debug(filters)
return _paginate_query(models.ML_Model, limit, marker,
sort_key, sort_dir, query)
def create_ml_model(self, context, values):
# ensure defaults are present for new ml_models
if not values.get('uuid'):
values['uuid'] = uuidutils.generate_uuid()
if not values.get('id'):
values['id'] = uuidutils.generate_uuid()
ml_model = models.ML_Model()
ml_model.update(values)
try:
@ -252,7 +253,7 @@ class Connection(object):
def get_ml_model_by_uuid(self, context, ml_model_uuid):
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = query.filter_by(uuid=ml_model_uuid)
query = query.filter_by(id=ml_model_uuid)
try:
return query.one()
except NoResultFound:
@ -261,7 +262,7 @@ class Connection(object):
def get_ml_model_by_name(self, context, ml_model_name):
query = model_query(models.ML_Model)
query = self._add_project_filters(context, query)
query = query.filter_by(meta_name=ml_model_name)
query = query.filter_by(name=ml_model_name)
try:
return query.one()
except NoResultFound:

View File

@ -31,6 +31,7 @@ from sqlalchemy import orm
from sqlalchemy import schema
from sqlalchemy import sql
from sqlalchemy import String
from sqlalchemy import LargeBinary
from sqlalchemy import Text
from sqlalchemy.types import TypeDecorator, TEXT
@ -120,11 +121,12 @@ class ML_Model(Base):
name = Column(String(255))
status = Column(String(20))
status_reason = Column(Text, nullable=True)
task_state = Column(String(20))
host_id = Column(String(255))
status_detail = Column(String(50))
deployed = Column(String(50))
host_id = Column(String(255), nullable=True)
deployed = Column(Text, nullable=True)
url = Column(Text, nullable=True)
hints = Column(Text, nullable=True)
ml_type = Column(String(255), nullable=True)
ml_data = Column(LargeBinary(length=(2**32)-1), nullable=True)
started_at = Column(DateTime)
@ -138,4 +140,4 @@ class ComputeHost(Base):
id = Column(String(36), primary_key=True, nullable=False)
hostname = Column(String(255), nullable=False)
status = Column(String(255), nullable=False)
type = Column(String(255), nullable=False)
type = Column(String(255), nullable=False)

View File

@ -15,8 +15,13 @@ import datetime
import eventlet
import functools
import types
import png
import os
import tempfile
import numpy as np
import tensorflow as tf
from docker import errors
from oslo_log import log as logging
from oslo_utils import timeutils
from oslo_utils import uuidutils
@ -47,6 +52,24 @@ class TensorflowDriver(driver.MLModelDriver):
return ml_model
pass
def _load(self, session, path):
saver = tf.train.import_meta_graph(path + '/model.meta')
saver.restore(session, tf.train.latest_checkpoint(path))
return tf.get_default_graph()
def predict(self, context, ml_model_path, data):
session = tf.Session()
graph = self._load(session, ml_model_path)
img_file, img_path = tempfile.mkstemp()
with os.fdopen(img_file, 'wb') as f:
f.write(data)
png_data = png.Reader(img_path)
img = np.array(list(png_data.read()[2]))
img = img.reshape(1, 784)
tensor = graph.get_tensor_by_name('x:0')
prediction = graph.get_tensor_by_name('classification:0')
return {"data": session.run(prediction, feed_dict={tensor:img})[0]}
def delete(self, context, ml_model, force):
pass

View File

@ -43,3 +43,18 @@ class Json(fields.FieldType):
class JsonField(fields.AutoTypedField):
AUTO_TYPE = Json()
class ModelFieldType(fields.FieldType):
def coerce(self, obj, attr, value):
return value
def from_primitive(self, obj, attr, value):
return self.coerce(obj, attr, value)
def to_primitive(self, obj, attr, value):
return value
class ModelField(fields.AutoTypedField):
AUTO_TYPE = ModelFieldType()

View File

@ -22,6 +22,7 @@ from gyan.objects import fields as z_fields
LOG = logging.getLogger(__name__)
@base.GyanObjectRegistry.register
class ML_Model(base.GyanPersistentObject, base.GyanObject):
VERSION = '1'
@ -35,16 +36,19 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
'status_reason': fields.StringField(nullable=True),
'url': fields.StringField(nullable=True),
'deployed': fields.BooleanField(nullable=True),
'node': fields.UUIDField(nullable=True),
'hints': fields.StringField(nullable=True),
'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True)
'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
'ml_data': z_fields.ModelField(nullable=True),
'ml_type': fields.StringField(nullable=True)
}
@staticmethod
def _from_db_object(ml_model, db_ml_model):
"""Converts a database entity to a formal object."""
for field in ml_model.fields:
if 'field' == 'ml_data':
continue
setattr(ml_model, field, db_ml_model[field])
ml_model.obj_reset_changes()
@ -67,6 +71,17 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
ml_model = ML_Model._from_db_object(cls(context), db_ml_model)
return ml_model
@base.remotable_classmethod
def get_by_uuid_db(cls, context, uuid):
"""Find a ml model based on uuid and return a :class:`ML_Model` object.
:param uuid: the uuid of a ml model.
:param context: Security context
:returns: a :class:`ML_Model` object.
"""
db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
return db_ml_model
@base.remotable_classmethod
def get_by_name(cls, context, name):
@ -125,7 +140,7 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
"""
values = self.obj_get_changes()
db_ml_model = dbapi.create_ml_model(context, values)
self._from_db_object(self, db_ml_model)
return self._from_db_object(self, db_ml_model)
@base.remotable
def destroy(self, context=None):
@ -138,7 +153,26 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
A context should be set when instantiating the
object, e.g.: ML Model(context)
"""
dbapi.destroy_ml_model(context, self.uuid)
dbapi.destroy_ml_model(context, self.id)
self.obj_reset_changes()
@base.remotable
def save(self, context=None):
"""Save updates to this ML Model.
Updates will be made column by column based on the result
of self.what_changed().
:param context: Security context. NOTE: This should only
be used internally by the indirection_api.
Unfortunately, RPC requires context as the first
argument, even though we don't use it.
A context should be set when instantiating the
object, e.g.: ML Model(context)
"""
updates = self.obj_get_changes()
dbapi.update_ml_model(context, self.id, updates)
self.obj_reset_changes()
def obj_load_attr(self, attrname):

View File

@ -1,8 +1,3 @@
# -*- coding: utf-8 -*-
# Copyright 2010-2011 OpenStack Foundation
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at

View File

@ -2,4 +2,29 @@
# of appearance. Changing the order has an impact on the overall integration
# process, which may cause wedges in the gate later.
pbr>=2.0 # Apache-2.0
PyYAML>=3.12 # MIT
eventlet!=0.18.3,!=0.20.1,>=0.18.2 # MIT
keystonemiddleware>=4.17.0 # Apache-2.0
pecan!=1.0.2,!=1.0.3,!=1.0.4,!=1.2,>=1.0.0 # BSD
oslo.i18n>=3.15.3 # Apache-2.0
oslo.log>=3.36.0 # Apache-2.0
oslo.concurrency>=3.25.0 # Apache-2.0
oslo.config>=5.2.0 # Apache-2.0
oslo.messaging>=5.29.0 # Apache-2.0
oslo.middleware>=3.31.0 # Apache-2.0
oslo.policy>=1.30.0 # Apache-2.0
oslo.privsep>=1.23.0 # Apache-2.0
oslo.serialization!=2.19.1,>=2.18.0 # Apache-2.0
oslo.service!=1.28.1,>=1.24.0 # Apache-2.0
oslo.versionedobjects>=1.31.2 # Apache-2.0
oslo.context>=2.19.2 # Apache-2.0
oslo.utils>=3.33.0 # Apache-2.0
oslo.db>=4.27.0 # Apache-2.0
os-brick>=2.2.0 # Apache-2.0
six>=1.10.0 # MIT
SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
stevedore>=1.20.0 # Apache-2.0
pypng
numpy
tensorflow
idx2numpy

View File

@ -1,5 +1,3 @@
# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at