diff --git a/savanna/context.py b/savanna/context.py index ad100b81..e24155bb 100644 --- a/savanna/context.py +++ b/savanna/context.py @@ -94,6 +94,23 @@ def model_save(model, context=None): return model +def model_update(model, context=None, **kwargs): + if not hasattr(model, '__table__'): + # TODO(slikjanov): replace with specific exception + raise RuntimeError("Specified object isn't model, class: %s" + % model.__class__.__name__) + columns = model.__table__.columns + for prop in kwargs: + if prop not in columns: + # TODO(slukjanov): replace with specific exception + raise RuntimeError( + "Model class '%s' doesn't contains specified property '%s'" + % (model.__class__.__name__, prop)) + setattr(model, prop, kwargs[prop]) + + return model_save(model, context) + + def spawn(func, *args, **kwargs): ctx = current().clone() diff --git a/savanna/service/api.py b/savanna/service/api.py index 88bb5a2b..323b079e 100644 --- a/savanna/service/api.py +++ b/savanna/service/api.py @@ -48,14 +48,12 @@ def scale_cluster(cluster_id, data): additional = construct_ngs_for_scaling(additional_node_groups) try: - cluster.status = 'Validating' - context.model_save(cluster) + context.model_update(cluster, status='Validating') _validate_cluster(cluster, plugin, additional) plugin.validate_scaling(cluster, to_be_enlarged, additional) except Exception: with excutils.save_and_reraise_exception(): - cluster.status = 'Active' - context.model_save(cluster) + context.model_update(cluster, status='Active') # If we are here validation is successful. # So let's update bd and to_be_enlarged map: @@ -77,13 +75,11 @@ def create_cluster(values): # validating cluster try: - cluster.status = 'Validating' - context.model_save(cluster) + context.model_update(cluster, status='Validating') plugin.validate(cluster) except Exception: with excutils.save_and_reraise_exception(): - cluster.status = 'Error' - context.model_save(cluster) + context.model_update(cluster, status='Error') context.spawn(_provision_cluster, cluster.id) @@ -95,18 +91,15 @@ def _provision_nodes(cluster_id, node_group_names_map): cluster = get_cluster(id=cluster_id) plugin = plugin_base.PLUGINS.get_plugin(cluster.plugin_name) - cluster.status = 'Scaling' - context.model_save(cluster) + context.model_update(cluster, status='Scaling') instances = i.scale_cluster(cluster, node_group_names_map) if instances: - cluster.status = 'Configuring' - context.model_save(cluster) + context.model_update(cluster, status='Configuring') plugin.scale_cluster(cluster, instances) # cluster is now up and ready - cluster.status = 'Active' - context.model_save(cluster) + context.model_update(cluster, status='Active') def _provision_cluster(cluster_id): @@ -114,34 +107,29 @@ def _provision_cluster(cluster_id): plugin = plugin_base.PLUGINS.get_plugin(cluster.plugin_name) # updating cluster infra - cluster.status = 'InfraUpdating' - context.model_save(cluster) + context.model_update(cluster, status='InfraUpdating') plugin.update_infra(cluster) # creating instances and configuring them i.create_cluster(cluster) # configure cluster - cluster.status = 'Configuring' - context.model_save(cluster) + context.model_update(cluster, status='Configuring') plugin.configure_cluster(cluster) # starting prepared and configured cluster - cluster.status = 'Starting' - context.model_save(cluster) + context.model_update(cluster, status='Starting') plugin.start_cluster(cluster) # cluster is now up and ready - cluster.status = 'Active' - context.model_save(cluster) + context.model_update(cluster, status='Active') return cluster def terminate_cluster(**args): cluster = get_cluster(**args) - cluster.status = 'Deleting' - context.model_save(cluster) + context.model_update(cluster, status='Deleting') plugin = plugin_base.PLUGINS.get_plugin(cluster.plugin_name) plugin.on_terminate_cluster(cluster) diff --git a/savanna/service/instances.py b/savanna/service/instances.py index 7b04c632..8396156c 100644 --- a/savanna/service/instances.py +++ b/savanna/service/instances.py @@ -28,27 +28,23 @@ LOG = logging.getLogger(__name__) def create_cluster(cluster): try: # create all instances - cluster.status = 'Spawning' - context.model_save(cluster) + context.model_update(cluster, status='Spawning') _create_instances(cluster) # wait for all instances are up and accessible - cluster.status = 'Waiting' - context.model_save(cluster) + context.model_update(cluster, status='Waiting') _await_instances(cluster) # attach volumes volumes.attach(cluster) # prepare all instances - cluster.status = 'Preparing' - context.model_save(cluster) + context.model_update(cluster, status='Preparing') _configure_instances(cluster) except Exception as ex: LOG.warn("Can't start cluster '%s' (reason: %s)", cluster.name, ex) with excutils.save_and_reraise_exception(): - cluster.status = 'Error' - context.model_save(cluster) + context.model_update(cluster, status='Error') _rollback_cluster_creation(cluster, ex) diff --git a/savanna/tests/unit/test_context.py b/savanna/tests/unit/test_context.py new file mode 100644 index 00000000..02f3a2cb --- /dev/null +++ b/savanna/tests/unit/test_context.py @@ -0,0 +1,59 @@ +# Copyright (c) 2013 Mirantis Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import sqlalchemy as sa + +from savanna import context +from savanna.db import model_base as mb +from savanna.tests.unit import base + + +class TestModel(mb.SavannaBase, mb.IdMixin): + test_field = sa.Column(sa.String(80)) + + +def _insert_test_object(): + t = TestModel() + t.test_field = 123 + context.model_save(t) + + return t + + +class ModelHelpersTest(base.DbTestCase): + def test_model_save(self): + self.assertEqual(0, len(context.model_query(TestModel).all())) + + t = _insert_test_object() + + self.assertEqual(1, len(context.model_query(TestModel).all())) + + db_t = context.model_query(TestModel).first() + + self.assertEqual(t.id, db_t.id) + self.assertEqual(t.test_field, db_t.test_field) + + def test_model_update(self): + _insert_test_object() + + t = context.current().session.query(TestModel).first() + + context.model_update(t, test_field=42) + + db_t = context.model_query(TestModel).first() + + self.assertEqual(t.id, db_t.id) + self.assertEqual(42, db_t.test_field)