Enable all code style tests.

Fixes: bug #1187664

Change-Id: I214c94c4fd8277132c115e84d63d01d3753def4d
This commit is contained in:
Sergey Lukjanov 2013-06-05 11:43:27 +04:00
parent b6ca5d6720
commit 269dc98bf7
26 changed files with 220 additions and 205 deletions

View File

@ -16,7 +16,7 @@
from savanna.openstack.common import log as logging
from savanna.service import api
import savanna.utils.api as u
from savanna.utils.openstack.nova import novaclient
from savanna.utils.openstack import nova
LOG = logging.getLogger(__name__)
@ -135,34 +135,34 @@ def plugins_get_version(plugin_name, version):
@rest.get('/images')
def images_list():
return u.render(
images=[i.dict for i in novaclient().images.list_registered()])
images=[i.dict for i in nova.client().images.list_registered()])
@rest.get('/images/<image_id>')
def images_get(image_id):
return u.render(novaclient().images.get(image_id).dict)
return u.render(nova.client().images.get(image_id).dict)
def _render_image(image_id, nova):
return u.render(nova.images.get(image_id).wrapped_dict)
def _render_image(image_id, novaclient):
return u.render(novaclient.images.get(image_id).wrapped_dict)
@rest.post('/images/<image_id>')
def images_set(image_id, data):
nova = novaclient()
nova.images.set_description(image_id, **data)
return _render_image(image_id, nova)
novaclient = nova.client()
novaclient.images.set_description(image_id, **data)
return _render_image(image_id, novaclient)
@rest.post('/images/<image_id>/tag')
def image_tags_add(image_id, data):
nova = novaclient()
nova.images.tag(image_id, **data)
return _render_image(image_id, nova)
novaclient = nova.client()
novaclient.images.tag(image_id, **data)
return _render_image(image_id, novaclient)
@rest.post('/images/<image_id>/untag')
def image_tags_delete(image_id, data):
nova = novaclient()
nova.images.untag(image_id, **data)
return _render_image(image_id, nova)
novaclient = nova.client()
novaclient.images.untag(image_id, **data)
return _render_image(image_id, novaclient)

View File

@ -51,6 +51,15 @@ def ctx():
return _CTXS._curr_ctx
def current():
return ctx()
def session(context=None):
context = context or ctx()
return context.session
def set_ctx(new_ctx):
if not new_ctx and hasattr(_CTXS, '_curr_ctx'):
del _CTXS._curr_ctx

View File

@ -14,7 +14,7 @@
# limitations under the License.
from alembic import context
from logging.config import fileConfig
from logging import config as logging_config
from savanna.openstack.common import importutils
from sqlalchemy import create_engine, pool
@ -26,7 +26,7 @@ importutils.import_module('savanna.db.models')
config = context.config
savanna_config = config.savanna_config
fileConfig(config.config_file_name)
logging_config.fileConfig(config.config_file_name)
# set the target for 'autogenerate' support
target_metadata = model_base.SavannaBase.metadata

View File

@ -28,9 +28,9 @@ down_revision = None
from alembic import op
import sqlalchemy as sa
from savanna.utils.sqlatypes import JSONEncoded
from savanna.utils import sqlatypes as st
sa.JSONEncoded = JSONEncoded
sa.JSONEncoded = st.JSONEncoded
def upgrade():

View File

@ -12,6 +12,7 @@
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import sqlalchemy as sa
@ -20,11 +21,11 @@ from sqlalchemy import orm
from savanna.openstack.common import timeutils
from savanna.openstack.common import uuidutils
from savanna.utils.resources import BaseResource
from savanna.utils.sqlatypes import JsonDictType
from savanna.utils import resources
from savanna.utils import sqlatypes as st
class _SavannaBase(BaseResource):
class _SavannaBase(resources.BaseResource):
"""Base class for all Savanna Models."""
created = sa.Column(sa.DateTime, default=timeutils.utcnow,
@ -140,4 +141,4 @@ class ExtraMixin(object):
__filter_cols__ = ['extra']
extra = sa.Column(JsonDictType())
extra = sa.Column(st.JsonDictType())

View File

@ -19,10 +19,9 @@ from sqlalchemy.orm import relationship
from savanna.db import model_base as mb
from savanna.utils import crypto
from savanna.utils.openstack.nova import novaclient
from savanna.utils.openstack import nova
from savanna.utils import remote
from savanna.utils.sqlatypes import JsonDictType
from savanna.utils.sqlatypes import JsonListType
from savanna.utils import sqlatypes as st
CLUSTER_STATUSES = ['Starting', 'Active', 'Stopping', 'Error']
@ -39,7 +38,7 @@ class Cluster(mb.SavannaBase, mb.IdMixin, mb.TenantMixin,
name = sa.Column(sa.String(80), nullable=False)
default_image_id = sa.Column(sa.String(36))
cluster_configs = sa.Column(JsonDictType())
cluster_configs = sa.Column(st.JsonDictType())
node_groups = relationship('NodeGroup', cascade="all,delete",
backref='cluster')
# todo replace String type with sa.Enum(*CLUSTER_STATUSES)
@ -76,7 +75,7 @@ class Cluster(mb.SavannaBase, mb.IdMixin, mb.TenantMixin,
It contains 'public_key' and 'fingerprint' fields.
"""
if not hasattr(self, '_user_kp'):
self._user_kp = novaclient().keypairs.get(self.user_keypair_id)
self._user_kp = nova.client().keypairs.get(self.user_keypair_id)
return self._user_kp
@ -92,8 +91,8 @@ class NodeGroup(mb.SavannaBase, mb.IdMixin, mb.ExtraMixin):
name = sa.Column(sa.String(80), nullable=False)
flavor_id = sa.Column(sa.String(36), nullable=False)
image_id = sa.Column(sa.String(36), nullable=False)
node_processes = sa.Column(JsonListType())
node_configs = sa.Column(JsonDictType())
node_processes = sa.Column(st.JsonListType())
node_configs = sa.Column(st.JsonDictType())
anti_affinity_group = sa.Column(sa.String(36))
count = sa.Column(sa.Integer, nullable=False)
instances = relationship('Instance', cascade="all,delete",
@ -119,7 +118,7 @@ class NodeGroup(mb.SavannaBase, mb.IdMixin, mb.ExtraMixin):
@property
def username(self):
if not hasattr(self, '_username'):
self._username = novaclient().images.get(self.image_id).username
self._username = nova.client().images.get(self.image_id).username
return self._username
@property
@ -166,7 +165,7 @@ class Instance(mb.SavannaBase, mb.ExtraMixin):
@property
def nova_info(self):
"""Returns info from nova about instance."""
return novaclient().servers.get(self.instance_id)
return nova.client().servers.get(self.instance_id)
@property
def username(self):
@ -191,7 +190,7 @@ class ClusterTemplate(mb.SavannaBase, mb.IdMixin, mb.TenantMixin,
name = sa.Column(sa.String(80), nullable=False)
description = sa.Column(sa.String(200))
cluster_configs = sa.Column(JsonDictType())
cluster_configs = sa.Column(st.JsonDictType())
# todo add node_groups_suggestion helper
@ -228,8 +227,8 @@ class NodeGroupTemplate(mb.SavannaBase, mb.IdMixin, mb.TenantMixin,
name = sa.Column(sa.String(80), nullable=False)
description = sa.Column(sa.String(200))
flavor_id = sa.Column(sa.String(36), nullable=False)
node_processes = sa.Column(JsonListType())
node_configs = sa.Column(JsonDictType())
node_processes = sa.Column(st.JsonListType())
node_configs = sa.Column(st.JsonDictType())
def __init__(self, name, tenant_id, flavor_id, plugin_name,
hadoop_version, node_processes, node_configs=None,

View File

@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from savanna.context import ctx
from savanna.context import model_query
from savanna import context as ctx
import savanna.db.models as m
@ -22,17 +21,17 @@ import savanna.db.models as m
# todo check tenant_id and etc.
def get_clusters(**args):
return model_query(m.Cluster).filter_by(**args).all()
return ctx.model_query(m.Cluster).filter_by(**args).all()
def get_cluster(**args):
return model_query(m.Cluster).filter_by(**args).first()
return ctx.model_query(m.Cluster).filter_by(**args).first()
def create_cluster(values):
session = ctx().session
session = ctx.current().session
with session.begin():
values['tenant_id'] = ctx().tenant_id
values['tenant_id'] = ctx.current().tenant_id
ngs_vals = values.pop('node_groups', [])
cluster = m.Cluster(**values)
for ng in ngs_vals:
@ -45,24 +44,24 @@ def create_cluster(values):
def terminate_cluster(cluster):
with ctx().session.begin():
ctx().session.delete()
with ctx.current().session.begin():
ctx.current().session.delete(cluster)
## ClusterTemplate ops
def get_cluster_templates(**args):
return model_query(m.ClusterTemplate).filter_by(**args).all()
return ctx.model_query(m.ClusterTemplate).filter_by(**args).all()
def get_cluster_template(**args):
return model_query(m.ClusterTemplate).filter_by(**args).first()
return ctx.model_query(m.ClusterTemplate).filter_by(**args).first()
def create_cluster_template(values):
session = ctx().session
session = ctx.current().session
with session.begin():
values['tenant_id'] = ctx().tenant_id
values['tenant_id'] = ctx.current().tenant_id
ngts_vals = values.pop('node_group_templates', [])
cluster_template = m.ClusterTemplate(**values)
for ngt in ngts_vals:
@ -76,29 +75,29 @@ def create_cluster_template(values):
def terminate_cluster_template(**args):
with ctx().session.begin():
ctx().session.delete(get_cluster_template(**args))
with ctx.current().session.begin():
ctx.current().session.delete(get_cluster_template(**args))
## NodeGroupTemplate ops
def get_node_group_templates(**args):
return model_query(m.NodeGroupTemplate).filter_by(**args).all()
return ctx.model_query(m.NodeGroupTemplate).filter_by(**args).all()
def get_node_group_template(**args):
return model_query(m.NodeGroupTemplate).filter_by(**args).first()
return ctx.model_query(m.NodeGroupTemplate).filter_by(**args).first()
def create_node_group_template(values):
session = ctx().session
session = ctx.current().session
with session.begin():
values['tenant_id'] = ctx().tenant_id
values['tenant_id'] = ctx.current().tenant_id
node_group_template = m.NodeGroupTemplate(**values)
session.add(node_group_template)
return node_group_template
def terminate_node_group_template(**args):
with ctx().session.begin():
ctx().session.delete(get_node_group_template(**args))
with ctx.current().session.begin():
ctx.current().session.delete(get_node_group_template(**args))

View File

@ -13,26 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from eventlet import monkey_patch
from flask import Flask
from keystoneclient.middleware.auth_token import filter_factory as auth_token
import eventlet
import flask
from keystoneclient.middleware import auth_token
from oslo.config import cfg
from savanna.context import ctx
from savanna.plugins.base import setup_plugins
from werkzeug.exceptions import default_exceptions
from werkzeug.exceptions import HTTPException
from werkzeug import exceptions as werkzeug_exceptions
from savanna.api import v10 as api_v10
from savanna import context
from savanna.db import api as db_api
from savanna.middleware.auth_valid import filter_factory as auth_valid
from savanna.utils.api import render
from savanna.utils.scheduler import setup_scheduler
from savanna.middleware import auth_valid
from savanna.plugins import base as plugins_base
from savanna.utils import api as api_utils
from savanna.utils import scheduler
from savanna.openstack.common import log
LOG = log.getLogger(__name__)
monkey_patch(os=True, select=True, socket=True, thread=True, time=True)
eventlet.monkey_patch(
os=True, select=True, socket=True, thread=True, time=True)
opts = [
cfg.StrOpt('os_auth_protocol',
@ -67,11 +67,11 @@ def make_app():
Entry point for Savanna REST API server
"""
app = Flask('savanna.api')
app = flask.Flask('savanna.api')
@app.route('/', methods=['GET'])
def version_list():
return render({
return api_utils.render({
"versions": [
{"id": "v1.0", "status": "CURRENT"}
]
@ -80,32 +80,33 @@ def make_app():
@app.teardown_request
def teardown_request(_ex=None):
# todo how it'll work in case of exception?
session = ctx().session
session = context.session()
if session.transaction:
session.transaction.commit()
app.register_blueprint(api_v10.rest, url_prefix='/v1.0')
db_api.configure_db()
setup_scheduler(app)
setup_plugins()
scheduler.setup_scheduler(app)
plugins_base.setup_plugins()
def make_json_error(ex):
status_code = (ex.code
if isinstance(ex, HTTPException)
if isinstance(ex, werkzeug_exceptions.HTTPException)
else 500)
description = (ex.description
if isinstance(ex, HTTPException)
if isinstance(ex, werkzeug_exceptions.HTTPException)
else str(ex))
return render({'error': status_code, 'error_message': description},
status=status_code)
return api_utils.render({'error': status_code,
'error_message': description},
status=status_code)
for code in default_exceptions.iterkeys():
for code in werkzeug_exceptions.default_exceptions.iterkeys():
app.error_handler_spec[None][code] = make_json_error
app.wsgi_app = auth_valid(app.config)(app.wsgi_app)
app.wsgi_app = auth_valid.filter_factory(app.config)(app.wsgi_app)
app.wsgi_app = auth_token(
app.wsgi_app = auth_token.filter_factory(
app.config,
auth_host=CONF.os_auth_host,
auth_port=CONF.os_auth_port,

View File

@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
from abc import abstractmethod
import abc
import inspect
from oslo.config import cfg
from savanna.config import parse_configs
from savanna import config
from savanna.openstack.common import importutils
from savanna.openstack.common import log as logging
from savanna.utils.resources import BaseResource
from savanna.utils import resources
LOG = logging.getLogger(__name__)
@ -34,8 +35,8 @@ CONF = cfg.CONF
CONF.register_opts(opts)
class PluginInterface(BaseResource):
__metaclass__ = ABCMeta
class PluginInterface(resources.BaseResource):
__metaclass__ = abc.ABCMeta
name = 'plugin_interface'
@ -59,7 +60,7 @@ class PluginInterface(BaseResource):
"""
pass
@abstractmethod
@abc.abstractmethod
def get_title(self):
"""Plugin title
@ -102,13 +103,13 @@ class PluginManager(object):
]
CONF.register_opts(opts, group='plugin:%s' % plugin)
parse_configs()
config.parse_configs()
# register plugin-specific configs
for plugin_name in CONF.plugins:
self.plugins[plugin_name] = self._get_plugin_instance(plugin_name)
parse_configs()
config.parse_configs()
titles = []
for plugin_name in CONF.plugins:

View File

@ -13,32 +13,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
import abc
import functools
from savanna.plugins.base import PluginInterface
import savanna.utils.openstack.nova as nova
from savanna.utils.resources import BaseResource
from savanna.plugins import base as plugins_base
from savanna.utils.openstack import nova
from savanna.utils import resources
class ProvisioningPluginContext(object):
def __init__(self, headers):
self.headers = headers
self.nova = self._autoheaders(nova.novaclient)
self.nova = self._autoheaders(nova.client)
def _autoheaders(self, func):
return functools.partial(func, headers=self.headers)
class ProvisioningPluginBase(PluginInterface):
@abstractmethod
class ProvisioningPluginBase(plugins_base.PluginInterface):
@abc.abstractmethod
def get_versions(self):
pass
@abstractmethod
@abc.abstractmethod
def get_configs(self, hadoop_version):
pass
@abstractmethod
@abc.abstractmethod
def get_node_processes(self, hadoop_version):
pass
@ -48,11 +49,11 @@ class ProvisioningPluginBase(PluginInterface):
def update_infra(self, cluster):
pass
@abstractmethod
@abc.abstractmethod
def configure_cluster(self, cluster):
pass
@abstractmethod
@abc.abstractmethod
def start_cluster(self, cluster):
pass
@ -68,7 +69,7 @@ class ProvisioningPluginBase(PluginInterface):
return res
class Config(BaseResource):
class Config(resources.BaseResource):
"""Describes a single config parameter.
For example:

View File

@ -13,11 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from savanna.plugins.provisioning import Config
from savanna.plugins.provisioning import ProvisioningPluginBase
from savanna.plugins import provisioning as p
class VanillaProvider(ProvisioningPluginBase):
class VanillaProvider(p.ProvisioningPluginBase):
def get_plugin_opts(self):
return []
@ -37,7 +36,7 @@ class VanillaProvider(ProvisioningPluginBase):
def get_configs(self, hadoop_version):
return [
Config('heap_size', 'tasktracker', default_value='1024M')
p.Config('heap_size', 'tasktracker', default_value='1024M')
]
def get_node_processes(self, hadoop_version):

View File

@ -14,10 +14,10 @@
# limitations under the License.
from savanna import context
import savanna.db.storage as s
from savanna.db import storage as s
from savanna.openstack.common import log as logging
import savanna.plugins.base as plugin_base
from savanna.plugins.provisioning import ProvisioningPluginBase
from savanna.plugins import base as plugin_base
from savanna.plugins import provisioning
from savanna.service import instances as i
LOG = logging.getLogger(__name__)
@ -98,7 +98,8 @@ terminate_node_group_template = s.terminate_node_group_template
## Plugins ops
def get_plugins():
return plugin_base.PLUGINS.get_plugins(base=ProvisioningPluginBase)
return plugin_base.PLUGINS.get_plugins(
base=provisioning.ProvisioningPluginBase)
def get_plugin(plugin_name, version=None):

View File

@ -18,8 +18,8 @@ import time
from savanna import context
from savanna.db import models as m
from savanna.openstack.common import log as logging
from savanna.utils.crypto import private_key_to_public_key
import savanna.utils.openstack.nova as nova
from savanna.utils import crypto
from savanna.utils.openstack import nova
LOG = logging.getLogger(__name__)
@ -57,7 +57,7 @@ def _create_instances(cluster):
ids = aa_groups[aa_group]
hints = {'different_host': list(ids)} if ids else None
nova_instance = nova.novaclient().servers.create(
nova_instance = nova.client().servers.create(
name, node_group.image_id, node_group.flavor_id,
scheduler_hints=hints, files=files)
@ -80,7 +80,7 @@ def _generate_instance_files(node_group):
path_to_root = "/home/" + node_group.username
authorized_keys = user_key.public_key + '\n'
authorized_keys += private_key_to_public_key(cluster.private_key)
authorized_keys += crypto.private_key_to_public_key(cluster.private_key)
return {
path_to_root + "/.ssh/authorized_keys": authorized_keys,
@ -173,7 +173,7 @@ def _shutdown_instances(cluster, quiet=False):
"""Shutdown all instances related to the specified cluster."""
for node_group in cluster.node_groups:
for instance in node_group.instances:
nova.novaclient().servers.delete(instance.instance_id)
nova.client().servers.delete(instance.instance_id)
def shutdown_cluster(cluster):

View File

@ -18,10 +18,8 @@ import os
import tempfile
import unittest2
from savanna.context import Context
from savanna.context import set_ctx
from savanna.db.api import clear_db
from savanna.db.api import configure_db
from savanna import context
from savanna.db import api as db_api
from savanna.openstack.common.db.sqlalchemy import session
from savanna.openstack.common import timeutils
from savanna.openstack.common import uuidutils
@ -29,16 +27,17 @@ from savanna.openstack.common import uuidutils
class ModelTestCase(unittest2.TestCase):
def setUp(self):
set_ctx(Context('test_user', 'test_tenant', 'test_auth_token', {}))
context.set_ctx(
context.Context('test_user', 'test_tenant', 'test_auth_token', {}))
self.db_fd, self.db_path = tempfile.mkstemp()
session.set_defaults('sqlite:///' + self.db_path, self.db_path)
configure_db()
db_api.configure_db()
def tearDown(self):
clear_db()
db_api.clear_db()
os.close(self.db_fd)
os.unlink(self.db_path)
set_ctx(None)
context.set_ctx(None)
def assertIsValidModelObject(self, res):
self.assertIsNotNone(res)

View File

@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from savanna.context import ctx
import savanna.db.models as m
from savanna.tests.unit.db.models.base import ModelTestCase
from savanna import context as ctx
from savanna.db import models as m
from savanna.tests.unit.db.models import base as models_test_base
class ClusterModelTest(ModelTestCase):
class ClusterModelTest(models_test_base.ModelTestCase):
def testCreateCluster(self):
session = ctx().session
session = ctx.current().session
with session.begin():
c = m.Cluster('c-1', 't-1', 'p-1', 'hv-1')
session.add(c)

View File

@ -12,13 +12,14 @@
# implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from savanna.context import ctx
import savanna.db.models as m
from savanna.tests.unit.db.models.base import ModelTestCase
import unittest2
from savanna import context as ctx
import savanna.db.models as m
from savanna.tests.unit.db.models import base as models_test_base
SAMPLE_CONFIGS = {
'a': 'av',
'b': 123,
@ -26,9 +27,9 @@ SAMPLE_CONFIGS = {
}
class TemplatesModelTest(ModelTestCase):
class TemplatesModelTest(models_test_base.ModelTestCase):
def testCreateNodeGroupTemplate(self):
session = ctx().session
session = ctx.current().session
with session.begin():
ngt = m.NodeGroupTemplate('ngt-1', 't-1', 'f-1', 'p-1', 'hv-1',
['np-1', 'np-2'], SAMPLE_CONFIGS, "d")
@ -53,7 +54,7 @@ class TemplatesModelTest(ModelTestCase):
})
def testCreateClusterTemplate(self):
session = ctx().session
session = ctx.current().session
with session.begin():
c = m.ClusterTemplate('c-1', 't-1', 'p-1', 'hv-1', SAMPLE_CONFIGS,
"d")
@ -75,7 +76,7 @@ class TemplatesModelTest(ModelTestCase):
})
def testCreateClusterTemplateWithNodeGroupTemplates(self):
session = ctx().session
session = ctx.current().session
with session.begin():
ct = m.ClusterTemplate('ct', 't-1', 'p-1', 'hv-1')
session.add(ct)

View File

@ -14,15 +14,16 @@
# limitations under the License.
import mock
import savanna.context as ctx
from savanna import context as ctx
import savanna.db.models as m
from savanna.service.instances import _create_instances
from savanna.tests.unit.db.models.base import ModelTestCase
from savanna.service import instances
from savanna.tests.unit.db.models import base as models_test_base
import savanna.utils.crypto as c
class NodePlacementTest(ModelTestCase):
@mock.patch('savanna.utils.openstack.nova.novaclient')
class NodePlacementTest(models_test_base.ModelTestCase):
@mock.patch('savanna.utils.openstack.nova.client')
def test_one_node_groups_and_one_affinity_group(self, novaclient):
node_groups = [m.NodeGroup("test_group",
"test_flavor",
@ -36,7 +37,7 @@ class NodePlacementTest(ModelTestCase):
nova = _create_nova_mock(novaclient)
_create_instances(cluster)
instances._create_instances(cluster)
files = _generate_files(cluster)
nova.servers.create.assert_has_calls(
@ -56,7 +57,7 @@ class NodePlacementTest(ModelTestCase):
with session.begin():
self.assertEqual(session.query(m.Instance).count(), 2)
@mock.patch('savanna.utils.openstack.nova.novaclient')
@mock.patch('savanna.utils.openstack.nova.client')
def test_one_node_groups_and_no_affinity_group(self, novaclient):
node_groups = [m.NodeGroup("test_group",
"test_flavor",
@ -69,7 +70,7 @@ class NodePlacementTest(ModelTestCase):
nova = _create_nova_mock(novaclient)
_create_instances(cluster)
instances._create_instances(cluster)
files = _generate_files(cluster)
nova.servers.create.assert_has_calls(
@ -89,7 +90,7 @@ class NodePlacementTest(ModelTestCase):
with session.begin():
self.assertEqual(session.query(m.Instance).count(), 2)
@mock.patch('savanna.utils.openstack.nova.novaclient')
@mock.patch('savanna.utils.openstack.nova.client')
def test_two_node_groups_and_one_affinity_group(self, novaclient):
node_groups = [m.NodeGroup("test_group_1",
"test_flavor",
@ -110,7 +111,7 @@ class NodePlacementTest(ModelTestCase):
cluster = _create_cluster_mock(node_groups)
nova = _create_nova_mock(novaclient)
_create_instances(cluster)
instances._create_instances(cluster)
files = _generate_files(cluster)
nova.servers.create.assert_has_calls(

View File

@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from savanna.utils.patches import patch_minidom_writexml
import unittest
import unittest2
import xml.dom.minidom as xml
from savanna.utils import patches
class MinidomPatchesTest(unittest.TestCase):
class MinidomPatchesTest(unittest2.TestCase):
def setUp(self):
patch_minidom_writexml()
patches.patch_minidom_writexml()
def _generate_n_prettify_xml(self):
doc = xml.Document()

View File

@ -13,21 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import flask as f
import inspect
import mimetypes
import traceback
from werkzeug.datastructures import MIMEAccept
from savanna.context import Context
from savanna.context import set_ctx
import flask
from werkzeug import datastructures
from savanna import context
from savanna.openstack.common import log as logging
from savanna.openstack.common import wsgi
LOG = logging.getLogger(__name__)
class Rest(f.Blueprint):
class Rest(flask.Blueprint):
def get(self, rule, status_code=200):
return self._mroute('GET', rule, status_code)
@ -53,33 +53,36 @@ class Rest(f.Blueprint):
def handler(**kwargs):
# extract response content type
resp_type = f.request.accept_mimetypes
resp_type = flask.request.accept_mimetypes
type_suffix = kwargs.pop('resp_type', None)
if type_suffix:
suffix_mime = mimetypes.guess_type("res." + type_suffix)[0]
if suffix_mime:
resp_type = MIMEAccept([(suffix_mime, 1)])
f.request.resp_type = resp_type
resp_type = datastructures.MIMEAccept(
[(suffix_mime, 1)])
flask.request.resp_type = resp_type
# update status code
if status:
f.request.status_code = status
flask.request.status_code = status
kwargs.pop("tenant_id")
context = Context(f.request.headers['X-User-Id'],
f.request.headers['X-Tenant-Id'],
f.request.headers['X-Auth-Token'],
f.request.headers)
set_ctx(context)
ctx = context.Context(
flask.request.headers['X-User-Id'],
flask.request.headers['X-Tenant-Id'],
flask.request.headers[
'X-Auth-Token'],
flask.request.headers)
context.set_ctx(ctx)
# set func implicit args
args = inspect.getargspec(func).args
if 'ctx' in args:
kwargs['ctx'] = context
kwargs['ctx'] = ctx
if f.request.method in ['POST', 'PUT'] and 'data' in args:
if flask.request.method in ['POST', 'PUT'] and 'data' in args:
kwargs['data'] = request_data()
return func(**kwargs)
@ -97,8 +100,8 @@ class Rest(f.Blueprint):
return decorator
RT_JSON = MIMEAccept([("application/json", 1)])
RT_XML = MIMEAccept([("application/xml", 1)])
RT_JSON = datastructures.MIMEAccept([("application/json", 1)])
RT_XML = datastructures.MIMEAccept([("application/xml", 1)])
def _clean_nones(obj):
@ -136,14 +139,14 @@ def render(res=None, resp_type=None, status=None, **kwargs):
res = _clean_nones(res)
status_code = getattr(f.request, 'status_code', None)
status_code = getattr(flask.request, 'status_code', None)
if status:
status_code = status
if not status_code:
status_code = 200
if not resp_type:
resp_type = getattr(f.request, 'resp_type', RT_JSON)
resp_type = getattr(flask.request, 'resp_type', RT_JSON)
if not resp_type:
resp_type = RT_JSON
@ -161,19 +164,20 @@ def render(res=None, resp_type=None, status=None, **kwargs):
body = serializer.serialize(res)
resp_type = str(resp_type)
return f.Response(response=body, status=status_code, mimetype=resp_type)
return flask.Response(response=body, status=status_code,
mimetype=resp_type)
def request_data():
if hasattr(f.request, 'parsed_data'):
return f.request.parsed_data
if hasattr(flask.request, 'parsed_data'):
return flask.request.parsed_data
if not f.request.content_length > 0:
if not flask.request.content_length > 0:
LOG.debug("Empty body provided in request")
return dict()
deserializer = None
content_type = f.request.mimetype
content_type = flask.request.mimetype
if not content_type or content_type in RT_JSON:
deserializer = wsgi.JSONDeserializer()
elif content_type in RT_XML:
@ -183,9 +187,10 @@ def request_data():
abort_and_log(400, "Content type '%s' isn't supported" % content_type)
# parsed request data to avoid unwanted re-parsings
f.request.parsed_data = deserializer.deserialize(f.request.data)['body']
parsed_data = deserializer.deserialize(flask.request.data)['body']
flask.request.parsed_data = parsed_data
return f.request.parsed_data
return flask.request.parsed_data
def abort_and_log(status_code, descr, exc=None):
@ -195,7 +200,7 @@ def abort_and_log(status_code, descr, exc=None):
if exc is not None:
LOG.error(traceback.format_exc())
f.abort(status_code, description=descr)
flask.abort(status_code, description=descr)
def render_error_message(error_code, error_message, error_name):

View File

@ -16,7 +16,7 @@
from Crypto.PublicKey import RSA
from Crypto import Random
import paramiko
from six import StringIO
import six
def generate_private_key(length=2048):
@ -27,7 +27,7 @@ def generate_private_key(length=2048):
def to_paramiko_private_key(pkey):
"""Convert private key (str) to paramiko-specific RSAKey object."""
return paramiko.RSAKey(file_obj=StringIO(pkey))
return paramiko.RSAKey(file_obj=six.StringIO(pkey))
def private_key_to_public_key(key):

View File

@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from novaclient.v1_1.images import Image
from novaclient.v1_1.images import ImageManager
from novaclient.v1_1 import images
PROP_DESCR = '_savanna_description'
@ -32,7 +31,7 @@ def _ensure_tags(tags):
return [tags] if type(tags) in [str, unicode] else tags
class SavannaImage(Image):
class SavannaImage(images.Image):
def __init__(self, manager, info, loaded=False):
info['description'] = info.get('metadata', {}).get(PROP_DESCR)
info['username'] = info.get('metadata', {}).get(PROP_USERNAME)
@ -60,7 +59,7 @@ class SavannaImage(Image):
return self._info.copy()
class SavannaImageManager(ImageManager):
class SavannaImageManager(images.ImageManager):
"""Manage :class:`SavannaImage` resources.
This is an extended version of nova client's ImageManager with support of

View File

@ -14,10 +14,10 @@
# limitations under the License.
from novaclient import base
from novaclient.v1_1.keypairs import KeypairManager
from novaclient.v1_1 import keypairs
class SavannaKeypairManager(KeypairManager):
class SavannaKeypairManager(keypairs.KeypairManager):
def get(self, keypair):
"""Get a keypair.

View File

@ -16,15 +16,15 @@
import logging
from novaclient.v1_1 import client as nova_client
from savanna.context import ctx
from savanna import context
import savanna.utils.openstack.base as base
from savanna.utils.openstack.images import SavannaImageManager
from savanna.utils.openstack.keypairs import SavannaKeypairManager
from savanna.utils.openstack import images
from savanna.utils.openstack import keypairs
def novaclient():
headers = ctx().headers
def client():
headers = context.current().headers
username = headers['X-User-Name']
token = headers['X-Auth-Token']
tenant = headers['X-Tenant-Id']
@ -39,25 +39,25 @@ def novaclient():
nova.client.auth_token = token
nova.client.management_url = compute_url
nova.images = SavannaImageManager(nova)
nova.images = images.SavannaImageManager(nova)
if not hasattr(nova.keypairs, 'get'):
nova.keypairs = SavannaKeypairManager(nova)
nova.keypairs = keypairs.SavannaKeypairManager(nova)
return nova
def get_flavors():
return [flavor.name for flavor in novaclient().flavors.list()]
return [flavor.name for flavor in client().flavors.list()]
def get_flavor(**kwargs):
return novaclient().flavors.find(**kwargs)
return client().flavors.find(**kwargs)
def get_images():
return [image.id for image in novaclient().images.list()]
return [image.id for image in client().images.list()]
def get_limits():
limits = novaclient().limits.get().absolute
limits = client().limits.get().absolute
return dict((l.name, l.value) for l in limits)

View File

@ -15,13 +15,13 @@
import paramiko
from savanna.utils.crypto import to_paramiko_private_key
from savanna.utils import crypto
def setup_ssh_connection(host, username, private_key):
"""Setup SSH connection to the host using username and private key."""
if type(private_key) in [str, unicode]:
private_key = to_paramiko_private_key(private_key)
private_key = crypto.to_paramiko_private_key(private_key)
ssh = paramiko.SSHClient()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.connect(host, username=username, pkey=private_key)

View File

@ -13,16 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.ext.mutable import Mutable
from sqlalchemy.types import TypeDecorator, VARCHAR
from sqlalchemy.ext import mutable
from sqlalchemy import types as st
from savanna.openstack.common import jsonutils
class JSONEncoded(TypeDecorator):
class JSONEncoded(st.TypeDecorator):
"""Represents an immutable structure as a json-encoded string."""
impl = VARCHAR
impl = st.VARCHAR
def process_bind_param(self, value, dialect):
if value is not None:
@ -36,7 +36,7 @@ class JSONEncoded(TypeDecorator):
# todo verify this implementation
class MutableDict(Mutable, dict):
class MutableDict(mutable.Mutable, dict):
@classmethod
def coerce(cls, key, value):
"""Convert plain dictionaries to MutableDict."""
@ -45,7 +45,7 @@ class MutableDict(Mutable, dict):
return MutableDict(value)
# this call will raise ValueError
return Mutable.coerce(key, value)
return mutable.Mutable.coerce(key, value)
else:
return value
@ -66,7 +66,7 @@ class MutableDict(Mutable, dict):
# todo verify this implementation
class MutableList(Mutable, list):
class MutableList(mutable.Mutable, list):
@classmethod
def coerce(cls, key, value):
"""Convert plain lists to MutableList."""
@ -75,7 +75,7 @@ class MutableList(Mutable, list):
return MutableList(value)
# this call will raise ValueError
return Mutable.coerce(key, value)
return mutable.Mutable.coerce(key, value)
else:
return value

View File

@ -45,8 +45,6 @@ commands =
pylint --output-format=parseable --rcfile=.pylintrc bin/savanna-api bin/savanna-manage savanna | tee pylint-report.txt
[flake8]
# H302 import only modules
ignore = H302
show-source = true
builtins = _
exclude=.venv,.git,.tox,dist,doc,*openstack/common*,*lib/python*,*egg,tools