
Add tensorflow driver implementation

Change-Id: I951cea9325d2ea4a843ea55d1731c481df899474
Author: bharath
Parent commit: 26e2299908

devstack/lib/gyan  (+1 / -0)

@@ -305,6 +305,7 @@ function start_gyan_compute {
 function start_gyan {
 
     # ``run_process`` checks ``is_service_enabled``, it is not needed here
+    mkdir -p /opt/stack/data/gyan
     start_gyan_api
     start_gyan_compute
 }

gyan/api/controllers/v1/__init__.py  (+5 / -5)

@@ -82,10 +82,10 @@ class V1(controllers_base.APIBase):
                                         'hosts', '',
                                         bookmark=True)]
         v1.ml_models = [link.make_link('self', pecan.request.host_url,
-                                    'ml_models', ''),
+                                    'ml-models', ''),
                      link.make_link('bookmark',
                                     pecan.request.host_url,
-                                    'ml_models', '',
+                                    'ml-models', '',
                                     bookmark=True)]
         return v1
 
@@ -147,9 +147,9 @@ class Controller(controllers_base.Controller):
                    {'url': pecan.request.url,
                     'method': pecan.request.method,
                     'body': pecan.request.body})
-            LOG.debug(msg)
-
+            # LOG.debug(msg)
+        LOG.debug(args)
         return super(Controller, self)._route(args)
 
 
-__all__ = ('Controller',)
+__all__ = ('Controller',)

gyan/api/controllers/v1/ml_models.py  (+67 / -54)

@@ -10,6 +10,7 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import base64
 import shlex
 
 from oslo_log import log as logging
@@ -74,12 +75,13 @@ class MLModelController(base.Controller):
     """Controller for MLModels."""
 
     _custom_actions = {
-        'train': ['POST'],
+        'upload_trained_model': ['POST'],
         'deploy': ['GET'],
-        'undeploy': ['GET']
+        'undeploy': ['GET'],
+        'predict': ['POST']
     }
 
-    
+
     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     def get_all(self, **kwargs):
@@ -149,33 +151,55 @@ class MLModelController(base.Controller):
             context.all_projects = True
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:get_one")
-        if ml_model.node:
-            compute_api = pecan.request.compute_api
-            try:
-                ml_model = compute_api.ml_model_show(context, ml_model)
-            except exception.MLModelHostNotUp:
-                raise exception.ServerNotUsable
-
         return view.format_ml_model(context, pecan.request.host_url,
                                      ml_model.as_dict())
 
+    @base.Controller.api_version("1.0")
+    @pecan.expose('json')
+    @exception.wrap_pecan_controller_exception
+    def upload_trained_model(self, ml_model_ident, **kwargs):
+        context = pecan.request.context
+        LOG.debug(ml_model_ident)
+        ml_model = utils.get_ml_model(ml_model_ident)
+        LOG.debug(ml_model)
+        ml_model.ml_data = pecan.request.body
+        ml_model.save(context)
+        pecan.response.status = 200
+        compute_api = pecan.request.compute_api
+        new_model = view.format_ml_model(context, pecan.request.host_url,
+                                         ml_model.as_dict())
+        compute_api.ml_model_create(context, new_model)
+        return new_model
+    
+    @base.Controller.api_version("1.0")
+    @pecan.expose('json')
+    @exception.wrap_pecan_controller_exception
+    def predict(self, ml_model_ident, **kwargs):
+        context = pecan.request.context
+        LOG.debug(ml_model_ident)
+        ml_model = utils.get_ml_model(ml_model_ident)
+        pecan.response.status = 200
+        compute_api = pecan.request.compute_api
+        predict_dict = {
+            "data": base64.b64encode(pecan.request.POST['file'].file.read())
+        }
+        prediction = compute_api.ml_model_predict(context, ml_model_ident, **predict_dict)
+        return prediction
+
     @base.Controller.api_version("1.0")
     @pecan.expose('json')
     @api_utils.enforce_content_types(['application/json'])
     @exception.wrap_pecan_controller_exception
-    @validation.validate_query_param(pecan.request, schema.query_param_create)
     @validation.validated(schema.ml_model_create)
     def post(self, **ml_model_dict):
         return self._do_post(**ml_model_dict)
 
-
     def _do_post(self, **ml_model_dict):
         """Create or run a new ml model.
 
         :param ml_model_dict: a ml_model within the request body.
         """
         context = pecan.request.context
-        compute_api = pecan.request.compute_api
         policy.enforce(context, "ml_model:create",
                        action="ml_model:create")
 
@@ -183,22 +207,24 @@ class MLModelController(base.Controller):
         ml_model_dict['user_id'] = context.user_id
         name = ml_model_dict.get('name')
         ml_model_dict['name'] = name
-        
-        ml_model_dict['status'] = consts.CREATING
+
+        ml_model_dict['status'] = consts.CREATED
+        ml_model_dict['ml_type'] = ml_model_dict['type']
         extra_spec = {}
         extra_spec['hints'] = ml_model_dict.get('hints', None)
+        #ml_model_dict["model_data"] = open("/home/bharath/model.zip", "rb").read()
         new_ml_model = objects.ML_Model(context, **ml_model_dict)
-        new_ml_model.create(context)
-
-        compute_api.ml_model_create(context, new_ml_model, **kwargs)
+        ml_model = new_ml_model.create(context)
+        LOG.debug(new_ml_model)
+        #compute_api.ml_model_create(context, new_ml_model)
         # Set the HTTP Location Header
         pecan.response.location = link.build_url('ml_models',
-                                                 new_ml_model.uuid)
-        pecan.response.status = 202
-        return view.format_ml_model(context, pecan.request.node_url,
-                                     new_ml_model.as_dict())
+                                                 ml_model.id)
+        pecan.response.status = 201
+        return view.format_ml_model(context, pecan.request.host_url,
+                                     ml_model.as_dict())
+
 
-    
     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     @validation.validated(schema.ml_model_update)
@@ -217,11 +243,11 @@ class MLModelController(base.Controller):
         return view.format_ml_model(context, pecan.request.node_url,
                                      ml_model.as_dict())
 
-    
+
     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
     @validation.validate_query_param(pecan.request, schema.query_param_delete)
-    def delete(self, ml_model_ident, force=False, **kwargs):
+    def delete(self, ml_model_ident, **kwargs):
         """Delete a ML Model.
 
         :param ml_model_ident: UUID or Name of a ML Model.
@@ -230,27 +256,7 @@ class MLModelController(base.Controller):
         context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:delete")
-        try:
-            force = strutils.bool_from_string(force, strict=True)
-        except ValueError:
-            bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
-            raise exception.InvalidValue(_('Valid force values are: %s')
-                                         % bools)
-        stop = kwargs.pop('stop', False)
-        try:
-            stop = strutils.bool_from_string(stop, strict=True)
-        except ValueError:
-            bools = ', '.join(strutils.TRUE_STRINGS + strutils.FALSE_STRINGS)
-            raise exception.InvalidValue(_('Valid stop values are: %s')
-                                         % bools)
-        compute_api = pecan.request.compute_api
-        if not force:
-            utils.validate_ml_model_state(ml_model, 'delete')
-        ml_model.status = consts.DELETING
-        if ml_model.node:
-            compute_api.ml_model_delete(context, ml_model, force)
-        else:
-            ml_model.destroy(context)
+        ml_model.destroy(context)
         pecan.response.status = 204
 
 
@@ -261,15 +267,19 @@ class MLModelController(base.Controller):
 
         :param ml_model_ident: UUID or Name of a ML Model.
         """
+        context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
         utils.validate_ml_model_state(ml_model, 'deploy')
         LOG.debug('Calling compute.ml_model_deploy with %s',
-                  ml_model.uuid)
-        context = pecan.request.context
-        compute_api = pecan.request.compute_api
-        compute_api.ml_model_deploy(context, ml_model)
+                  ml_model.id)
+        ml_model.status =  consts.DEPLOYED
+        url = pecan.request.url.replace("deploy", "predict")
+        ml_model.url = url
+        ml_model.save(context)
         pecan.response.status = 202
+        return view.format_ml_model(context, pecan.request.host_url,
+                                     ml_model.as_dict())
 
     @pecan.expose('json')
     @exception.wrap_pecan_controller_exception
@@ -278,12 +288,15 @@ class MLModelController(base.Controller):
 
         :param ml_model_ident: UUID or Name of a ML Model.
         """
+        context = pecan.request.context
         ml_model = utils.get_ml_model(ml_model_ident)
         check_policy_on_ml_model(ml_model.as_dict(), "ml_model:deploy")
         utils.validate_ml_model_state(ml_model, 'undeploy')
         LOG.debug('Calling compute.ml_model_deploy with %s',
-                  ml_model.uuid)
-        context = pecan.request.context
-        compute_api = pecan.request.compute_api
-        compute_api.ml_model_undeploy(context, ml_model)
+                  ml_model.id)
+        ml_model.status = consts.SCHEDULED
+        ml_model.url = None
+        ml_model.save(context)
         pecan.response.status = 202
+        return view.format_ml_model(context, pecan.request.host_url,
+                                     ml_model.as_dict())
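
For reference, the new controller actions can be exercised end to end with a plain HTTP client. The sketch below is illustrative only: the base URL, port, token, and file names are assumptions, and authentication follows whatever keystonemiddleware expects in a real deployment.

    # Illustrative client for the new ml_models endpoints (hypothetical base
    # URL and token; the paths themselves come from the controller above).
    import requests

    BASE = 'http://127.0.0.1:9512/v1/ml_models'        # assumed API endpoint
    HEADERS = {'X-Auth-Token': '<keystone-token>'}     # assumed auth scheme

    # 1. Register a model record (the schema requires 'name' and 'type').
    model = requests.post(BASE, json={'name': 'mnist', 'type': 'tensorflow'},
                          headers=HEADERS).json()

    # 2. Upload the trained model archive as the raw request body.
    with open('model.zip', 'rb') as f:
        requests.post('%s/%s/upload_trained_model' % (BASE, model['id']),
                      data=f.read(), headers=HEADERS)

    # 3. Deploy; the response carries the predict URL in the 'url' field.
    deployed = requests.get('%s/%s/deploy' % (BASE, model['id']),
                            headers=HEADERS).json()

    # 4. Predict with a multipart upload; the controller reads the 'file' field.
    with open('digit.png', 'rb') as f:
        prediction = requests.post(deployed['url'], files={'file': f},
                                   headers=HEADERS).json()
    print(prediction)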

gyan/api/controllers/v1/schemas/ml_models.py  (+6 / -3)

@@ -18,8 +18,11 @@ _ml_model_properties = {}
 
 ml_model_create = {
     'type': 'object',
-    'properties': _ml_model_properties,
-    'required': ['name'],
+    'properties': {
+        "name": parameter_types.ml_model_name,
+        "type": parameter_types.ml_model_type
+    },
+    'required': ['name', 'type'],
     'additionalProperties': False
 }
 
@@ -46,4 +49,4 @@ query_param_delete = {
         'stop': parameter_types.boolean_extended
     },
     'additionalProperties': False
-}
+}

gyan/api/controllers/v1/schemas/parameter_types.py  (+14 / -0)

@@ -95,3 +95,17 @@ hostname = {
     # real systems.
     'pattern': '^[a-zA-Z0-9-._]*$',
 }
+
+ml_model_name = {
+    'type': 'string',
+    'minLength': 1,
+    'maxLength': 255,
+    'pattern': '^[a-zA-Z0-9-._]*$'
+}
+
+ml_model_type = {
+    'type': 'string',
+    'minLength': 1,
+    'maxLength': 255,
+    'pattern': '^[a-zA-Z0-9-._]*$'
+}
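
Taken together, the two schema files above define the JSON Schema the create request is validated against. A minimal standalone check of a request body, using the jsonschema library directly rather than the API's @validation.validated decorator, might look like this (the inline schema mirrors the definitions above):

    # Standalone illustration of the ml_model_create schema shown above.
    import jsonschema

    string_255 = {'type': 'string', 'minLength': 1, 'maxLength': 255,
                  'pattern': '^[a-zA-Z0-9-._]*$'}

    ml_model_create = {
        'type': 'object',
        'properties': {'name': string_255, 'type': string_255},
        'required': ['name', 'type'],
        'additionalProperties': False,
    }

    jsonschema.validate({'name': 'mnist', 'type': 'tensorflow'},
                        ml_model_create)          # passes
    jsonschema.validate({'name': 'mnist'},
                        ml_model_create)          # raises ValidationError: 'type' is required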

gyan/api/controllers/v1/views/ml_models_view.py  (+19 / -14)

@@ -13,41 +13,46 @@
 
 import itertools
 
+from oslo_log import log as logging
+
 from gyan.api.controllers import link
 from gyan.common.policies import ml_model as policies
 
 _basic_keys = (
-    'uuid',
+    'id',
     'user_id',
     'project_id',
     'name',
     'url',
     'status',
     'status_reason',
-    'task_state',
-    'labels',
-    'host',
-    'status_detail'
+    'host_id',
+    'deployed',
+    'ml_type'
 )
 
+LOG = logging.getLogger(__name__)
+
 
 
 def format_ml_model(context, url, ml_model):
     def transform(key, value):
+        LOG.debug(key)
+        LOG.debug(value)
         if key not in _basic_keys:
            return
         # strip the key if it is not allowed by policy
         policy_action = policies.ML_MODEL % ('get_one:%s' % key)
         if not context.can(policy_action, fatal=False, might_not_exist=True):
             return
-        if key == 'uuid':
-            yield ('uuid', value)
-            if url:
-                yield ('links', [link.make_link(
-                    'self', url, 'ml_models', value),
-                    link.make_link(
-                        'bookmark', url,
-                        'ml_models', value,
-                        bookmark=True)])
+        if key == 'id':
+            yield ('id', value)
+            # if url:
+            #     yield ('links', [link.make_link(
+            #         'self', url, 'ml_models', value),
+            #         link.make_link(
+            #             'bookmark', url,
+            #             'ml_models', value,
+            #             bookmark=True)])
         else:
             yield (key, value)
 

gyan/api/middleware/parsable_error.py  (+0 / -2)

@@ -1,5 +1,3 @@
-# Copyright © 2012 New Dream Network, LLC (DreamHost)
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at

gyan/api/utils.py  (+1 / -1)

@@ -113,4 +113,4 @@ def version_check(action, version):
     if req_version < min_version:
         raise exception.InvalidParamInVersion(param=action,
                                               req_version=req_version,
-                                              min_version=min_version)
+                                              min_version=min_version)

gyan/common/consts.py  (+4 / -1)

@@ -14,4 +14,7 @@
 ALLOCATED = 'allocated'
 CREATED = 'created'
 UNDEPLOYED = 'undeployed'
-DEPLOYED = 'deployed'
+DEPLOYED = 'deployed'
+CREATING = 'CREATING'
+CREATED = 'CREATED'
+SCHEDULED = 'SCHEDULED'

gyan/common/policies/ml_model.py  (+13 / -2)

@@ -106,16 +106,27 @@ rules = [
         ]
     ),
     policy.DocumentedRuleDefault(
-        name=ML_MODEL % 'upload',
+        name=ML_MODEL % 'upload_trained_model',
         check_str=base.RULE_ADMIN_OR_OWNER,
         description='Upload the trained ML Model',
         operations=[
             {
-                'path': '/v1/ml_models/{ml_model_ident}/upload',
+                'path': '/v1/ml_models/{ml_model_ident}/upload_trained_model',
                 'method': 'POST'
             }
         ]
     ),
+    policy.DocumentedRuleDefault(
+        name=ML_MODEL % 'deploy',
+        check_str=base.RULE_ADMIN_OR_OWNER,
+        description='Upload the trained ML Model',
+        operations=[
+            {
+                'path': '/v1/ml_models/{ml_model_ident}/deploy',
+                'method': 'GET'
+            }
+        ]
+    ),
 ]
 
 

gyan/common/service.py  (+2 / -1)

@@ -27,7 +27,8 @@ CONF = gyan.conf.CONF
 
 def prepare_service(argv=None):
     if argv is None:
-        argv = []
+        argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
+    argv = ['/usr/local/bin/gyan-api', '--config-file', '/etc/gyan/gyan.conf']
     log.register_options(CONF)
     config.parse_args(argv)
     config.set_config_defaults()

gyan/common/utils.py  (+12 / -2)

@@ -23,6 +23,7 @@ import functools
 import inspect
 import json
 import mimetypes
+import os
 
 from oslo_concurrency import processutils
 from oslo_context import context as common_context
@@ -44,7 +45,7 @@ CONF = gyan.conf.CONF
 LOG = logging.getLogger(__name__)
 
 VALID_STATES = {
-    'deploy': [consts.CREATED, consts.UNDEPLOYED],
+    'deploy': [consts.CREATED, consts.UNDEPLOYED, consts.SCHEDULED],
     'undeploy': [consts.DEPLOYED]
 }
 def safe_rstrip(value, chars=None):
@@ -162,7 +163,7 @@ def get_ml_model(ml_model_ident):
 def validate_ml_model_state(ml_model, action):
     if ml_model.status not in VALID_STATES[action]:
         raise exception.InvalidStateException(
-            id=ml_model.uuid,
+            id=ml_model.id,
             action=action,
             actual_state=ml_model.status)
 
@@ -253,3 +254,12 @@ def decode_file_data(data):
         return base64.b64decode(data)
     except (TypeError, binascii.Error):
         raise exception.Base64Exception()
+
+
+def save_model(path, model):
+    file_path = os.path.join(path, model.id)
+    with open(file_path+'.zip', 'wb') as f:
+        f.write(model.ml_data)
+    zip_ref = zipfile.ZipFile(file_path+'.zip', 'r')
+    zip_ref.extractall(file_path)
+    zip_ref.close()
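
Note that save_model as added calls zipfile, but the hunk above only adds an os import; unless zipfile is already imported elsewhere in the module, it would be missing. A self-contained sketch of the same logic, with that assumption made explicit:

    import os
    import zipfile  # assumed import; not visible in the diff hunk above

    def save_model(path, model):
        # Persist the uploaded archive as <path>/<model.id>.zip ...
        file_path = os.path.join(path, model.id)
        with open(file_path + '.zip', 'wb') as f:
            f.write(model.ml_data)
        # ... then unpack it into <path>/<model.id>/ for the driver to load.
        with zipfile.ZipFile(file_path + '.zip', 'r') as zip_ref:
            zip_ref.extractall(file_path)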

gyan/compute/api.py  (+11 / -7)

@@ -28,7 +28,6 @@ CONF = gyan.conf.CONF
 LOG = logging.getLogger(__name__)
 
 
-@profiler.trace_cls("rpc")
 class API(object):
     """API for interacting with the compute manager."""
 
@@ -36,10 +35,11 @@ class API(object):
         self.rpcapi = rpcapi.API(context=context)
         super(API, self).__init__()
 
-    def ml_model_create(self, context, new_ml_model, extra_spec):
+    def ml_model_create(self, context, new_ml_model, **extra_spec):
        try:
-            host_state = self._schedule_ml_model(context, ml_model,
-                                                  extra_spec)
+            host_state = {
+                "host": "localhost"
+            } #self._schedule_ml_model(context, ml_model, extra_spec)
         except exception.NoValidHost:
             new_ml_model.status = consts.ERROR
             new_ml_model.status_reason = _(
@@ -51,13 +51,17 @@ class API(object):
             new_ml_model.status_reason = _("Unexpected exception occurred.")
             new_ml_model.save(context)
             raise
-
-        self.rpcapi.ml_model_create(context, host_state['host'],
+        LOG.debug(host_state)
+        return self.rpcapi.ml_model_create(context, host_state['host'],
                                      new_ml_model)
+    
+    def ml_model_predict(self, context, ml_model_id, **kwargs):
+        return self.rpcapi.ml_model_predict(context, ml_model_id,
+                                     **kwargs)
 
     def ml_model_delete(self, context, ml_model, *args):
         self._record_action_start(context, ml_model, ml_model_actions.DELETE)
         return self.rpcapi.ml_model_delete(context, ml_model, *args)
 
     def ml_model_show(self, context, ml_model):
-        return self.rpcapi.ml_model_show(context, ml_model)
+        return self.rpcapi.ml_model_show(context, ml_model)

gyan/compute/manager.py  (+14 / -13)

@@ -1,5 +1,3 @@
-#    Copyright 2016 IBM Corp.
-#
 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 #    not use this file except in compliance with the License. You may obtain
 #    a copy of the License at
@@ -12,7 +10,9 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import base64
 import itertools
+import os
 
 import six
 import time
@@ -49,17 +49,18 @@ class Manager(periodic_task.PeriodicTasks):
         self.host = CONF.compute.host
         self._resource_tracker = None
 
-    def ml_model_create(self, context, limits, requested_networks,
-                         requested_volumes, ml_model, run, pci_requests=None):
-        @utils.synchronized(ml_model.uuid)
-        def do_ml_model_create():
-            created_ml_model = self._do_ml_model_create(
-                context, ml_model, requested_networks, requested_volumes,
-                pci_requests, limits)
-            if run:
-                self._do_ml_model_start(context, created_ml_model)
+    def ml_model_create(self, context, ml_model):
+        db_ml_model = objects.ML_Model.get_by_uuid_db(context, ml_model["id"])
+        utils.save_model(CONF.state_path, db_ml_model)
+        obj_ml_model = objects.ML_Model.get_by_uuid(context, ml_model["id"])
+        obj_ml_model.status = consts.SCHEDULED
+        obj_ml_model.status_reason = "The ML Model is scheduled and saved to the host %s" % self.host
+        obj_ml_model.save(context)
 
-        utils.spawn_n(do_ml_model_create)
+    def ml_model_predict(self, context, ml_model_id, kwargs):
+        #open("/home/bharath/Documents/0.png", "wb").write(base64.b64decode(kwargs["data"]))
+        model_path = os.path.join(CONF.state_path, ml_model_id)
+        return self.driver.predict(context, model_path, base64.b64decode(kwargs["data"]))
 
     @wrap_ml_model_event(prefix='compute')
     def _do_ml_model_create(self, context, ml_model, requested_networks,
@@ -118,4 +119,4 @@ class Manager(periodic_task.PeriodicTasks):
             rt = compute_host_tracker.ComputeHostTracker(self.host,
                                                          self.driver)
             self._resource_tracker = rt
-        return self._resource_tracker
+        return self._resource_tracker

gyan/compute/rpcapi.py  (+4 / -3)

@@ -1,5 +1,3 @@
-#    Copyright 2016 IBM Corp.
-#
 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 #    not use this file except in compliance with the License. You may obtain
 #    a copy of the License at
@@ -30,7 +28,6 @@ def check_ml_model_host(func):
     return wrap
 
 
-@profiler.trace_cls("rpc")
 class API(rpc_service.API):
     """Client side of the ml_model compute rpc API.
 
@@ -51,6 +48,10 @@ class API(rpc_service.API):
         self._cast(host, 'ml_model_create', 
                    ml_model=ml_model)
 
+    def ml_model_predict(self, context, ml_model_id, **kwargs):
+        return self._call("localhost", 'ml_model_predict', 
+                   ml_model_id=ml_model_id, kwargs=kwargs)
+
     @check_ml_model_host
     def ml_model_delete(self, context, ml_model, force):
         return self._cast(ml_model.host, 'ml_model_delete',

gyan/conf/scheduler.py  (+0 / -3)

@@ -1,6 +1,3 @@
-# Copyright 2015 OpenStack Foundation
-# All Rights Reserved.
-#
 #    Licensed under the Apache License, Version 2.0 (the "License"); you may
 #    not use this file except in compliance with the License. You may obtain
 #    a copy of the License at

gyan/db/sqlalchemy/alembic/versions/f3bf9414f399_add_ml_type_and_ml_data_to_ml_model_.py  (+39 / -0)

@@ -0,0 +1,39 @@
+"""Add ml_type and ml_data to ml_model table
+
+Revision ID: f3bf9414f399
+Revises: cebd81b206ca
+Create Date: 2018-10-13 09:48:36.783322
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'f3bf9414f399'
+down_revision = 'cebd81b206ca'
+branch_labels = None
+depends_on = None
+
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import mysql
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('compute_host', schema=None) as batch_op:
+        batch_op.alter_column('hostname',
+               existing_type=mysql.VARCHAR(length=255),
+               nullable=False)
+        batch_op.alter_column('status',
+               existing_type=mysql.VARCHAR(length=255),
+               nullable=False)
+        batch_op.alter_column('type',
+               existing_type=mysql.VARCHAR(length=255),
+               nullable=False)
+
+    with op.batch_alter_table('ml_model', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('ml_data', sa.LargeBinary(length=(2**32)-1), nullable=True))
+        batch_op.add_column(sa.Column('ml_type', sa.String(length=255), nullable=True))
+        batch_op.add_column(sa.Column('started_at', sa.DateTime(), nullable=True))
+        batch_op.create_unique_constraint('uniq_mlmodel0uuid', ['id'])
+        batch_op.drop_constraint(u'ml_model_ibfk_1', type_='foreignkey')
+        
+    # ### end Alembic commands ###

gyan/db/sqlalchemy/api.py  (+10 / -9)

@@ -1,5 +1,3 @@
-# Copyright 2013 Hewlett-Packard Development Company, L.P.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at
@@ -13,6 +11,7 @@
 # under the License.
 
 """SQLAlchemy storage backend."""
+from oslo_log import log as logging
 
 from oslo_db import exception as db_exc
 from oslo_db.sqlalchemy import session as db_session
@@ -39,6 +38,7 @@ profiler_sqlalchemy = importutils.try_import('osprofiler.sqlalchemy')
 CONF = gyan.conf.CONF
 
 _FACADE = None
+LOG = logging.getLogger(__name__)
 
 
 def _create_facade_lazily():
@@ -90,7 +90,7 @@ def add_identity_filter(query, value):
     if strutils.is_int_like(value):
         return query.filter_by(id=value)
     elif uuidutils.is_uuid_like(value):
-        return query.filter_by(uuid=value)
+        return query.filter_by(id=value)
     else:
         raise exception.InvalidIdentity(identity=value)
 
@@ -230,16 +230,17 @@ class Connection(object):
 
     def list_ml_models(self, context, filters=None, limit=None,
                       marker=None, sort_key=None, sort_dir=None):
-        query = model_query(models.Capsule)
+        query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
         query = self._add_ml_models_filters(query, filters)
-        return _paginate_query(models.Capsule, limit, marker,
+        LOG.debug(filters)
+        return _paginate_query(models.ML_Model, limit, marker,
                                sort_key, sort_dir, query)
 
     def create_ml_model(self, context, values):
         # ensure defaults are present for new ml_models
-        if not values.get('uuid'):
-            values['uuid'] = uuidutils.generate_uuid()
+        if not values.get('id'):
+            values['id'] = uuidutils.generate_uuid()
         ml_model = models.ML_Model()
         ml_model.update(values)
         try:
@@ -252,7 +253,7 @@ class Connection(object):
     def get_ml_model_by_uuid(self, context, ml_model_uuid):
         query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
-        query = query.filter_by(uuid=ml_model_uuid)
+        query = query.filter_by(id=ml_model_uuid)
         try:
             return query.one()
         except NoResultFound:
@@ -261,7 +262,7 @@ class Connection(object):
     def get_ml_model_by_name(self, context, ml_model_name):
         query = model_query(models.ML_Model)
         query = self._add_project_filters(context, query)
-        query = query.filter_by(meta_name=ml_model_name)
+        query = query.filter_by(name=ml_model_name)
         try:
             return query.one()
         except NoResultFound:

gyan/db/sqlalchemy/models.py  (+7 / -5)

@@ -31,6 +31,7 @@ from sqlalchemy import orm
 from sqlalchemy import schema
 from sqlalchemy import sql
 from sqlalchemy import String
+from sqlalchemy import LargeBinary
 from sqlalchemy import Text
 from sqlalchemy.types import TypeDecorator, TEXT
 
@@ -120,11 +121,12 @@ class ML_Model(Base):
     name = Column(String(255))
     status = Column(String(20))
     status_reason = Column(Text, nullable=True)
-    task_state = Column(String(20))
-    host_id = Column(String(255))
-    status_detail = Column(String(50))
-    deployed = Column(String(50))
+    host_id = Column(String(255), nullable=True)
     deployed = Column(Text, nullable=True)
+    url = Column(Text, nullable=True)
+    hints = Column(Text, nullable=True)
+    ml_type = Column(String(255), nullable=True)
+    ml_data = Column(LargeBinary(length=(2**32)-1), nullable=True)
     started_at = Column(DateTime)
 
 
@@ -138,4 +140,4 @@ class ComputeHost(Base):
     id = Column(String(36), primary_key=True, nullable=False)
     hostname = Column(String(255), nullable=False)
     status = Column(String(255), nullable=False)
-    type = Column(String(255), nullable=False)
+    type = Column(String(255), nullable=False)

gyan/ml_model/tensorflow/driver.py  (+24 / -1)

@@ -15,8 +15,13 @@ import datetime
 import eventlet
 import functools
 import types
+import png
+import os
+import tempfile
+import numpy as np
+
+import tensorflow as tf
 
-from docker import errors
 from oslo_log import log as logging
 from oslo_utils import timeutils
 from oslo_utils import uuidutils
@@ -47,6 +52,24 @@ class TensorflowDriver(driver.MLModelDriver):
         return ml_model
         pass
 
+    def _load(self, session, path):
+        saver = tf.train.import_meta_graph(path + '/model.meta')
+        saver.restore(session, tf.train.latest_checkpoint(path))
+        return tf.get_default_graph()
+
+    def predict(self, context, ml_model_path, data):
+        session = tf.Session()
+        graph = self._load(session, ml_model_path)
+        img_file, img_path = tempfile.mkstemp()
+        with os.fdopen(img_file, 'wb') as f:
+            f.write(data)
+        png_data = png.Reader(img_path)
+        img = np.array(list(png_data.read()[2]))
+        img = img.reshape(1, 784)
+        tensor = graph.get_tensor_by_name('x:0')
+        prediction = graph.get_tensor_by_name('classification:0')
+        return {"data": session.run(prediction, feed_dict={tensor:img})[0]}
+
 
     def delete(self, context, ml_model, force):
         pass
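
The driver's predict expects the unpacked archive to contain a TF 1.x checkpoint named model.* whose graph exposes a 784-wide input placeholder called x and an output tensor called classification. As a point of reference, a compatible (untrained) checkpoint could be produced roughly as below; the model architecture is an assumption, only the tensor names and file layout matter to the driver.

    import tensorflow as tf  # TF 1.x API, matching the driver above

    # Input placeholder named 'x' so graph.get_tensor_by_name('x:0') resolves.
    x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    logits = tf.matmul(x, W) + b
    # Output op named 'classification' so 'classification:0' resolves.
    tf.argmax(logits, axis=1, name='classification')

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... training would go here ...
        # Writes model.meta / model.index / model.data-* plus a checkpoint
        # file, which is what import_meta_graph() and latest_checkpoint()
        # look for after save_model() has unpacked the archive.
        saver.save(sess, '/tmp/mnist_model/model')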

gyan/objects/fields.py  (+15 / -0)

@@ -43,3 +43,18 @@ class Json(fields.FieldType):
 
 class JsonField(fields.AutoTypedField):
     AUTO_TYPE = Json()
+
+
+class ModelFieldType(fields.FieldType):
+    def coerce(self, obj, attr, value):
+        return value
+
+    def from_primitive(self, obj, attr, value):
+        return self.coerce(obj, attr, value)
+
+    def to_primitive(self, obj, attr, value):
+        return value
+
+
+class ModelField(fields.AutoTypedField):
+    AUTO_TYPE = ModelFieldType()

gyan/objects/ml_model.py  (+38 / -4)

@@ -22,6 +22,7 @@ from gyan.objects import fields as z_fields
 
 LOG = logging.getLogger(__name__)
 
+
 @base.GyanObjectRegistry.register
 class ML_Model(base.GyanPersistentObject, base.GyanObject):
     VERSION = '1'
@@ -35,16 +36,19 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         'status_reason': fields.StringField(nullable=True),
         'url': fields.StringField(nullable=True),
         'deployed': fields.BooleanField(nullable=True),
-        'node': fields.UUIDField(nullable=True),
         'hints': fields.StringField(nullable=True),
         'created_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
-        'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True)
+        'updated_at': fields.DateTimeField(tzinfo_aware=False, nullable=True),
+        'ml_data': z_fields.ModelField(nullable=True),
+        'ml_type': fields.StringField(nullable=True)
     }
 
     @staticmethod
     def _from_db_object(ml_model, db_ml_model):
         """Converts a database entity to a formal object."""
         for field in ml_model.fields:
+            if 'field' == 'ml_data':
+                continue
             setattr(ml_model, field, db_ml_model[field])
 
         ml_model.obj_reset_changes()
@@ -67,6 +71,17 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
         ml_model = ML_Model._from_db_object(cls(context), db_ml_model)
         return ml_model
+    
+    @base.remotable_classmethod
+    def get_by_uuid_db(cls, context, uuid):
+        """Find a ml model based on uuid and return a :class:`ML_Model` object.
+
+        :param uuid: the uuid of a ml model.
+        :param context: Security context
+        :returns: a :class:`ML_Model` object.
+        """
+        db_ml_model = dbapi.get_ml_model_by_uuid(context, uuid)
+        return db_ml_model
 
     @base.remotable_classmethod
     def get_by_name(cls, context, name):
@@ -125,7 +140,7 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
         """
         values = self.obj_get_changes()
         db_ml_model = dbapi.create_ml_model(context, values)
-        self._from_db_object(self, db_ml_model)
+        return self._from_db_object(self, db_ml_model)
 
     @base.remotable
     def destroy(self, context=None):
@@ -138,7 +153,26 @@ class ML_Model(base.GyanPersistentObject, base.GyanObject):
                         A context should be set when instantiating the
                         object, e.g.: ML Model(context)
         """
-        dbapi.destroy_ml_model(context, self.uuid)
+        dbapi.destroy_ml_model(context, self.id)
+        self.obj_reset_changes()
+
+    @base.remotable
+    def save(self, context=None):
+        """Save updates to this ML Model.
+
+        Updates will be made column by column based on the result
+        of self.what_changed().
+
+        :param context: Security context. NOTE: This should only
+                        be used internally by the indirection_api.
+                        Unfortunately, RPC requires context as the first
+                        argument, even though we don't use it.
+                        A context should be set when instantiating the
+                        object, e.g.: ML Model(context)
+        """
+        updates = self.obj_get_changes()
+        dbapi.update_ml_model(context, self.id, updates)
+
         self.obj_reset_changes()
 
     def obj_load_attr(self, attrname):
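
With create() now returning the object and the new save() persisting changed columns, a typical object lifecycle looks roughly like the sketch below. It is illustrative only: the context comes from the surrounding service and the field values are placeholders.

    # Hypothetical walk-through of the updated ML_Model object API.
    from gyan.common import consts
    from gyan import objects

    def register_model(context, name):
        model = objects.ML_Model(context, name=name, ml_type='tensorflow',
                                 status=consts.CREATED)
        model = model.create(context)      # create() now returns the object
        model.status = consts.DEPLOYED
        model.url = '/v1/ml_models/%s/predict' % model.id
        model.save(context)                # new save() writes only changed columns
        return objects.ML_Model.get_by_uuid(context, model.id)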

gyan/tests/base.py  (+0 / -5)

@@ -1,8 +1,3 @@
-# -*- coding: utf-8 -*-
-
-# Copyright 2010-2011 OpenStack Foundation
-# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
-#
 # Licensed under the Apache License, Version 2.0 (the "License"); you may
 # not use this file except in compliance with the License. You may obtain
 # a copy of the License at

requirements.txt  (+26 / -1)

@@ -2,4 +2,29 @@
 # of appearance. Changing the order has an impact on the overall integration
 # process, which may cause wedges in the gate later.
 
-pbr>=2.0 # Apache-2.0
+PyYAML>=3.12 # MIT
+eventlet!=0.18.3,!=0.20.1,>=0.18.2 # MIT
+keystonemiddleware>=4.17.0 # Apache-2.0
+pecan!=1.0.2,!=1.0.3,!=1.0.4,!=1.2,>=1.0.0 # BSD
+oslo.i18n>=3.15.3 # Apache-2.0
+oslo.log>=3.36.0 # Apache-2.0
+oslo.concurrency>=3.25.0 # Apache-2.0
+oslo.config>=5.2.0 # Apache-2.0
+oslo.messaging>=5.29.0 # Apache-2.0
+oslo.middleware>=3.31.0 # Apache-2.0
+oslo.policy>=1.30.0 # Apache-2.0
+oslo.privsep>=1.23.0 # Apache-2.0
+oslo.serialization!=2.19.1,>=2.18.0 # Apache-2.0
+oslo.service!=1.28.1,>=1.24.0 # Apache-2.0
+oslo.versionedobjects>=1.31.2 # Apache-2.0
+oslo.context>=2.19.2 # Apache-2.0
+oslo.utils>=3.33.0 # Apache-2.0
+oslo.db>=4.27.0 # Apache-2.0
+os-brick>=2.2.0 # Apache-2.0
+six>=1.10.0 # MIT
+SQLAlchemy!=1.1.5,!=1.1.6,!=1.1.7,!=1.1.8,>=1.0.10 # MIT
+stevedore>=1.20.0 # Apache-2.0
+pypng
+numpy
+tensorflow
+idx2numpy

setup.py  (+0 / -2)

@@ -1,5 +1,3 @@
-# Copyright (c) 2013 Hewlett-Packard Development Company, L.P.
-#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
