Refactoring API - Cue Objects - DB modules

* ensure autonomous DB transaction when creating a new cluster.
* move DB read/write access calls from DB models to DB API.
* add unit tests for DB models
* add unit tests for DB API
* update REST api to secod approach, documented here:
  https://wiki.openstack.org/wiki/Cue/api_2
* add unit tests for REST API
* add unit tests for Cue Objects
* add integration tests for Cue Objects

Change-Id: Ibbeaada7c24aa3e1e9c4cf1abaf60b3521a683f1
This commit is contained in:
dagnello 2014-12-17 08:42:10 -08:00
parent 6dd5bcab06
commit 0887839f44
19 changed files with 1192 additions and 438 deletions

View File

@ -19,6 +19,8 @@
"""Version 1 of the Cue API """Version 1 of the Cue API
""" """
from cue.api.controllers import base from cue.api.controllers import base
from cue.common import exception
from cue.common.i18n import _ # noqa
from cue import objects from cue import objects
import uuid import uuid
@ -49,35 +51,11 @@ class EndPoint(base.APIBase):
"URL to endpoint" "URL to endpoint"
class Node(base.APIBase):
"""Representation of a Node."""
def __init__(self, **kwargs):
self.fields = []
node_object_fields = list(objects.Node.fields)
# Adding endpoints since it is an api-only attribute.
self.fields.append('end_points')
for k in node_object_fields:
# only add fields we expose in the api
if hasattr(self, k):
self.fields.append(k)
setattr(self, k, kwargs.get(k, wtypes.Unset))
id = wtypes.text
"UUID of node"
flavor = wsme.wsattr(wtypes.text, mandatory=True)
"Flavor of cluster"
status = wtypes.text
"Current status of node"
end_points = wtypes.wsattr([EndPoint], default=[])
"List of endpoints on accessing node"
class Cluster(base.APIBase): class Cluster(base.APIBase):
"""Representation of a cluster.""" """Representation of a cluster."""
# todo(dagnello): WSME attribute verification sometimes triggers 500 server
# error when user input was actually invalid (400). Example: if 'size' was
# provided as a string/char, e.g. 'a', api returns 500 server error.
def __init__(self, **kwargs): def __init__(self, **kwargs):
self.fields = [] self.fields = []
@ -93,42 +71,44 @@ class Cluster(base.APIBase):
id = wtypes.text id = wtypes.text
"UUID of cluster" "UUID of cluster"
nic = wtypes.wsattr(wtypes.text, mandatory=True) network_id = wtypes.wsattr(wtypes.text, mandatory=True)
"NIC of Neutron network" "NIC of Neutron network"
nodes = wtypes.wsattr([Node], default=[])
"List of nodes of cluster"
name = wsme.wsattr(wtypes.text, mandatory=True) name = wsme.wsattr(wtypes.text, mandatory=True)
"Name of cluster" "Name of cluster"
status = wtypes.text status = wtypes.text
"Current status of cluster" "Current status of cluster"
flavor = wsme.wsattr(wtypes.text, mandatory=True)
"Flavor of cluster"
size = wtypes.IntegerType()
"Number of nodes in cluster"
volume_size = wtypes.IntegerType() volume_size = wtypes.IntegerType()
"Volume size for nodes in cluster" "Volume size for nodes in cluster"
end_points = wtypes.wsattr([EndPoint], default=[])
"List of endpoints on accessing node"
def get_complete_cluster(cluster_id): def get_complete_cluster(cluster_id):
"""Helper to retrieve the api-compatible full structure of a cluster.""" """Helper to retrieve the api-compatible full structure of a cluster."""
cluster_obj = objects.Cluster.get_cluster(cluster_id) cluster_obj = objects.Cluster.get_cluster_by_id(cluster_id)
# construct api cluster object # construct api cluster object
cluster = Cluster(**cluster_obj.as_dict()) cluster = Cluster(**cluster_obj.as_dict())
cluster_nodes = objects.Node.get_nodes(cluster_id) cluster_nodes = objects.Node.get_nodes_by_cluster_id(cluster_id)
# construct api node objects for node in cluster_nodes:
cluster.nodes = [Node(**obj_node.as_dict()) for obj_node in
cluster_nodes]
for node in cluster.nodes:
# extract endpoints from node # extract endpoints from node
node_endpoints = objects.Endpoint.get_endpoints(node.id) node_endpoints = objects.Endpoint.get_endpoints_by_node_id(node.id)
# construct api endpoint objects # construct api endpoint objects
node.end_points = [EndPoint(**obj_endpoint.as_dict()) for cluster.end_points = [EndPoint(**obj_endpoint.as_dict()) for
obj_endpoint in node_endpoints] obj_endpoint in node_endpoints]
return cluster return cluster
@ -151,7 +131,7 @@ class ClusterController(rest.RestController):
@wsme_pecan.wsexpose(None, status_code=202) @wsme_pecan.wsexpose(None, status_code=202)
def delete(self): def delete(self):
"""Delete this Cluster.""" """Delete this Cluster."""
objects.Cluster.mark_as_delete_cluster(self.id) objects.Cluster.update_cluster_deleting(self.id)
class ClustersController(rest.RestController): class ClustersController(rest.RestController):
@ -174,21 +154,17 @@ class ClustersController(rest.RestController):
:param data: cluster parameters within the request body. :param data: cluster parameters within the request body.
""" """
# validate user parameters if data.size <= 0:
cluster_flavor = data.nodes[0].flavor raise exception.Invalid(_("Invalid cluster size provided"))
for node in data.nodes:
if cluster_flavor != node.flavor:
pecan.abort(400)
# create new cluster object with required data from user # create new cluster object with required data from user
new_cluster = objects.Cluster(**data.as_dict()) new_cluster = objects.Cluster(**data.as_dict())
# TODO(dagnello): project_id will have to be extracted from HTTP header # TODO(dagnello): project_id will have to be extracted from HTTP header
project_id = unicode(uuid.uuid1()) project_id = unicode(uuid.uuid1())
number_of_nodes = len(data.nodes)
# create new cluster with node related data from user # create new cluster with node related data from user
new_cluster.create_cluster(project_id, cluster_flavor, number_of_nodes) new_cluster.create(project_id)
cluster = get_complete_cluster(new_cluster.id) cluster = get_complete_cluster(new_cluster.id)

View File

@ -53,7 +53,7 @@ class Connection(object):
""" """
@abc.abstractmethod @abc.abstractmethod
def create_cluster(self, cluster_values, flavor, number_of_nodes): def create_cluster(self, cluster_values):
"""Creates a new cluster. """Creates a new cluster.
:param cluster_values: Dictionary of several required items. :param cluster_values: Dictionary of several required items.
@ -61,18 +61,18 @@ class Connection(object):
:: ::
{ {
'network_id': obj_utils.str_or_none,
'project_id': obj_utils.str_or_none, 'project_id': obj_utils.str_or_none,
'name': obj_utils.str_or_none, 'name': obj_utils.str_or_none,
'nic': obj_utils.str_or_none, 'flavor': obj_utils.str_or_none,
'size': obj_utils.int_or_none,
'volume_size': obj_utils.int_or_none, 'volume_size': obj_utils.int_or_none,
} }
:param flavor: The required flavor for nodes in this cluster.
:param number_of_nodes: The number of nodes in this cluster.
""" """
@abc.abstractmethod @abc.abstractmethod
def get_cluster(self, cluster_id): def get_cluster_by_id(self, cluster_id):
"""Returns a Cluster objects for specified cluster_id. """Returns a Cluster objects for specified cluster_id.
:param cluster_id: UUID of a cluster. :param cluster_id: UUID of a cluster.
@ -81,7 +81,7 @@ class Connection(object):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_nodes(self, cluster_id): def get_nodes_in_cluster(self, cluster_id):
"""Returns a list of Node objects for specified cluster. """Returns a list of Node objects for specified cluster.
:param cluster_id: UUID of the cluster. :param cluster_id: UUID of the cluster.
@ -90,7 +90,16 @@ class Connection(object):
""" """
@abc.abstractmethod @abc.abstractmethod
def get_endpoints(self, node_id): def get_node_by_id(self, node_id):
"""Returns a node for the specified node_id.
:param node_id: UUID of the node.
:returns: a :class:'Node' object.
"""
@abc.abstractmethod
def get_endpoints_in_node(self, node_id):
"""Returns a list of Endpoint objects for specified node. """Returns a list of Endpoint objects for specified node.
:param node_id: UUID of the node. :param node_id: UUID of the node.
@ -99,7 +108,16 @@ class Connection(object):
""" """
@abc.abstractmethod @abc.abstractmethod
def mark_as_delete_cluster(self, cluster_id): def get_endpoint_by_id(self, endpoint_id):
"""Returns an endpoint for the specified endpoint_id.
:param endpoint_id: UUID of the endpoint.
:returns: a :class:'Endpoint' object.
"""
@abc.abstractmethod
def update_cluster_deleting(self, cluster_id):
"""Marks specified cluster to indicate deletion. """Marks specified cluster to indicate deletion.
:param cluster_id: UUID of a cluster. :param cluster_id: UUID of a cluster.

View File

@ -36,28 +36,29 @@ def upgrade():
### commands auto generated by Alembic - please adjust! ### ### commands auto generated by Alembic - please adjust! ###
op.create_table('clusters', op.create_table('clusters',
sa.Column('id', types.UUID(), nullable=False), sa.Column('id', types.UUID(), nullable=False),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('network_id', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('flavor', sa.String(length=50), nullable=False),
sa.Column('size', sa.Integer(), nullable=False),
sa.Column('volume_size', sa.Integer(), nullable=True),
sa.Column('deleted', sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False), sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False), sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True), sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.Column('project_id', sa.String(length=36), nullable=False),
sa.Column('nic', sa.String(length=36), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('volume_size', sa.Integer(), nullable=False),
sa.Column('deleted', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )
op.create_table('nodes', op.create_table('nodes',
sa.Column('id', types.UUID(), nullable=False), sa.Column('id', types.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.Column('cluster_id', types.UUID(), nullable=True), sa.Column('cluster_id', types.UUID(), nullable=True),
sa.Column('flavor', sa.String(length=36), nullable=False), sa.Column('flavor', sa.String(length=36), nullable=False),
sa.Column('instance_id', sa.String(length=36), nullable=True), sa.Column('instance_id', sa.String(length=36), nullable=True),
sa.Column('status', sa.String(length=50), nullable=False), sa.Column('status', sa.String(length=50), nullable=False),
sa.Column('volume_size', sa.Integer(), nullable=False),
sa.Column('deleted', sa.Boolean(), nullable=False), sa.Column('deleted', sa.Boolean(), nullable=False),
sa.Column('created_at', sa.DateTime(), nullable=False),
sa.Column('updated_at', sa.DateTime(), nullable=False),
sa.Column('deleted_at', sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(['cluster_id'], ['clusters.id'], ), sa.ForeignKeyConstraint(['cluster_id'], ['clusters.id'], ),
sa.PrimaryKeyConstraint('id') sa.PrimaryKeyConstraint('id')
) )

View File

@ -14,12 +14,19 @@
# under the License. # under the License.
# #
# Copied from Neutron # Copied from Neutron
import uuid
from cue.common import exception
from cue.common.i18n import _ # noqa
from cue.db import api from cue.db import api
from cue.db.sqlalchemy import models from cue.db.sqlalchemy import models
from oslo.config import cfg from oslo.config import cfg
from oslo.db import exception as db_exception
from oslo.db import options as db_options from oslo.db import options as db_options
from oslo.db.sqlalchemy import session from oslo.db.sqlalchemy import session
from oslo.utils import timeutils
from sqlalchemy.orm import exc as sql_exception
CONF = cfg.CONF CONF = cfg.CONF
@ -58,6 +65,17 @@ def get_backend():
return Connection() return Connection()
def model_query(model, *args, **kwargs):
"""Query helper for simpler session usage.
:param session: if present, the session to use
"""
session = kwargs.get('session') or get_session()
query = session.query(model, *args)
return query
class Connection(api.Connection): class Connection(api.Connection):
"""SqlAlchemy connection implementation.""" """SqlAlchemy connection implementation."""
@ -65,51 +83,91 @@ class Connection(api.Connection):
pass pass
def get_clusters(self, project_id): def get_clusters(self, project_id):
db_session = get_session() query = model_query(models.Cluster).filter_by(deleted=False)
db_filter = { #project_id=project_id)
# TODO(dagnello): update project_id accordingly when enabled return query.all()
#'project_id': project_id,
'deleted': False, def create_cluster(self, cluster_values):
if not cluster_values.get('id'):
cluster_values['id'] = str(uuid.uuid4())
cluster_values['status'] = models.Status.BUILDING
cluster = models.Cluster()
cluster.update(cluster_values)
node_values = {
'cluster_id': cluster_values['id'],
'flavor': cluster_values['flavor'],
'status': models.Status.BUILDING,
} }
return models.Cluster.get_all(db_session, **db_filter)
def create_cluster(self, cluster_values, flavor, number_of_nodes): db_session = get_session()
with db_session.begin():
cluster.save(db_session)
db_session.flush()
for i in range(cluster_values['size']):
node = models.Node()
node_id = str(uuid.uuid4())
node_values['id'] = node_id
node.update(node_values)
node.save(db_session)
return cluster
def get_cluster_by_id(self, cluster_id):
query = model_query(models.Cluster).filter_by(id=cluster_id)
try:
cluster = query.one()
except db_exception.DBError:
# Todo(dagnello): User input will be validated from REST API and
# not from DB transactions.
raise exception.Invalid(_("badly formed cluster_id UUID string"))
except sql_exception.NoResultFound:
raise exception.NotFound(_("Cluster was not found"))
return cluster
def get_nodes_in_cluster(self, cluster_id):
query = model_query(models.Node).filter_by(cluster_id=cluster_id)
# No need to catch user-derived exceptions around not found or badly
# formed UUIDs if these happen, they should be classified as internal
# server errors since the user is not able to access nodes directly.
return query.all()
def get_node_by_id(self, node_id):
query = model_query(models.Node).filter_by(id=node_id)
return query.one()
def get_endpoints_in_node(self, node_id):
query = model_query(models.Endpoint).filter_by(node_id=node_id)
# No need to catch user-derived exceptions for same reason as above
return query.all()
def get_endpoint_by_id(self, endpoint_id):
query = model_query(models.Endpoint).filter_by(id=endpoint_id)
return query.one()
def update_cluster_deleting(self, cluster_id):
values = {'status': models.Status.DELETING,
'updated_at': timeutils.utcnow()}
cluster_query = model_query(models.Cluster).filter_by(id=cluster_id)
try:
cluster_query.one()
except db_exception.DBError:
# Todo(dagnello): User input will be validated from REST API and
# not from DB transactions.
raise exception.Invalid(_("badly formed cluster_id UUID string"))
except sql_exception.NoResultFound:
raise exception.NotFound(_("Cluster was not found"))
db_session = get_session() db_session = get_session()
cluster_ref = models.Cluster.add(db_session, cluster_values.project_id, with db_session.begin():
cluster_values.name, cluster_query.update(values)
cluster_values.nic, nodes_query = model_query(models.Node).filter_by(
cluster_values.volume_size)
for i in range(number_of_nodes):
models.Node.add(db_session, cluster_ref.id, flavor,
cluster_values['volume_size'])
return cluster_ref
def get_cluster(self, cluster_id):
db_session = get_session()
return models.Cluster.get(db_session, id=cluster_id)
def get_nodes(self, cluster_id):
db_session = get_session()
return models.Node.get_all(db_session, cluster_id=cluster_id)
def get_endpoints(self, node_id):
db_session = get_session()
node_endpoint_ref = models.Endpoint.get_all(db_session,
node_id=node_id)
return node_endpoint_ref
def mark_as_delete_cluster(self, cluster_id):
db_session = get_session()
cluster_node_ref = models.Node.get_all(db_session,
cluster_id=cluster_id) cluster_id=cluster_id)
models.Cluster.delete(db_session, cluster_id) nodes_query.update(values)
for node in cluster_node_ref:
models.Node.delete(db_session, node.id)

View File

@ -13,8 +13,6 @@
# under the License. # under the License.
# #
# Copied from Octavia # Copied from Octavia
from cue.common import exception
from cue.db.sqlalchemy import types from cue.db.sqlalchemy import types
import uuid import uuid
@ -26,44 +24,6 @@ from sqlalchemy.ext import declarative
class CueBase(models.ModelBase): class CueBase(models.ModelBase):
@classmethod
def create(cls, session, **kwargs):
with session.begin():
instance = cls(**kwargs)
session.add(instance)
return instance
@classmethod
def delete(cls, session, id):
model = session.query(cls).filter_by(id=id).first()
if not model:
raise exception.NotFound
with session.begin():
session.delete(model)
session.flush()
@classmethod
def delete_batch(self, session, ids=None):
[self.delete(session, id) for id in ids]
@classmethod
def update(cls, session, id, **kwargs):
with session.begin():
kwargs.update(updated_at=timeutils.utcnow())
session.query(cls).filter_by(id=id).update(kwargs)
@classmethod
def get(cls, session, **filters):
instance = session.query(cls).filter_by(**filters).first()
if not instance:
raise exception.NotFound
return instance
@classmethod
def get_all(cls, session, **filters):
data = session.query(cls).filter_by(**filters).all()
return data
def as_dict(self): def as_dict(self):
d = {} d = {}
for c in self.__table__.columns: for c in self.__table__.columns:

View File

@ -35,21 +35,10 @@ class Endpoint(base.BASE, base.IdMixin):
primary_key=True) primary_key=True)
uri = sa.Column(sa.String(255), nullable=False) uri = sa.Column(sa.String(255), nullable=False)
type = sa.Column(sa.String(length=255), nullable=False) type = sa.Column(sa.String(length=255), nullable=False)
deleted = sa.Column(sa.Boolean(), nullable=False) deleted = sa.Column(sa.Boolean(), default=False, nullable=False)
sa.Index("endpoints_id_idx", "id", unique=True) sa.Index("endpoints_id_idx", "id", unique=True)
sa.Index("endpoints_nodes_id_idx", "node_id", unique=False) sa.Index("endpoints_nodes_id_idx", "node_id", unique=False)
@classmethod
def add(cls, session, node_id, endpoint_type, uri):
endpoint = {
"node_id": node_id,
"uri": uri,
"type": endpoint_type,
"deleted": False
}
return super(Endpoint, cls).create(session, **endpoint)
class Node(base.BASE, base.IdMixin, base.TimeMixin): class Node(base.BASE, base.IdMixin, base.TimeMixin):
__tablename__ = 'nodes' __tablename__ = 'nodes'
@ -60,53 +49,20 @@ class Node(base.BASE, base.IdMixin, base.TimeMixin):
flavor = sa.Column(sa.String(36), nullable=False) flavor = sa.Column(sa.String(36), nullable=False)
instance_id = sa.Column(sa.String(36), nullable=True) instance_id = sa.Column(sa.String(36), nullable=True)
status = sa.Column(sa.String(50), nullable=False) status = sa.Column(sa.String(50), nullable=False)
volume_size = sa.Column(sa.Integer(), nullable=False) deleted = sa.Column(sa.Boolean(), default=False, nullable=False)
deleted = sa.Column(sa.Boolean(), nullable=False)
sa.Index("nodes_id_idx", "id", unique=True) sa.Index("nodes_id_idx", "id", unique=True)
sa.Index("nodes_cluster_id_idx", "cluster_id", unique=False) sa.Index("nodes_cluster_id_idx", "cluster_id", unique=False)
@classmethod
def add(cls, session, cluster_id, flavor, vol_size):
node = {
"cluster_id": cluster_id,
"flavor": flavor,
"volume_size": vol_size,
"deleted": False,
"status": Status.BUILDING
}
return super(Node, cls).create(session, **node)
@classmethod
def delete(cls, session, node_id):
super(Node, cls).update(session, id=node_id, status=Status.DELETING)
class Cluster(base.BASE, base.IdMixin, base.TimeMixin): class Cluster(base.BASE, base.IdMixin, base.TimeMixin):
__tablename__ = 'clusters' __tablename__ = 'clusters'
project_id = sa.Column(sa.String(36), nullable=False) project_id = sa.Column(sa.String(36), nullable=False)
nic = sa.Column(sa.String(36), nullable=False) network_id = sa.Column(sa.String(36), nullable=False)
name = sa.Column(sa.String(255), nullable=False) name = sa.Column(sa.String(255), nullable=False)
status = sa.Column(sa.String(50), nullable=False) status = sa.Column(sa.String(50), nullable=False)
volume_size = sa.Column(sa.Integer(), nullable=False) flavor = sa.Column(sa.String(50), nullable=False)
deleted = sa.Column(sa.Boolean(), nullable=False) size = sa.Column(sa.Integer(), default=1, nullable=False)
volume_size = sa.Column(sa.Integer(), nullable=True)
deleted = sa.Column(sa.Boolean(), default=False, nullable=False)
sa.Index("clusters_cluster_id_idx", "cluster_id", unique=True) sa.Index("clusters_cluster_id_idx", "cluster_id", unique=True)
@classmethod
def add(cls, session, project_id, name, nic, vol_size):
cluster = {
"project_id": project_id,
"name": name,
"nic": nic,
"volume_size": vol_size,
"deleted": False,
"status": Status.BUILDING
}
return super(Cluster, cls).create(session, **cluster)
@classmethod
def delete(cls, session, cluster_id):
super(Cluster, cls).update(session, id=cluster_id,
status=Status.DELETING)

View File

@ -121,3 +121,12 @@ class CueObject(object):
return dict((k, getattr(self, k)) return dict((k, getattr(self, k))
for k in self.fields for k in self.fields
if hasattr(self, k)) if hasattr(self, k))
def obj_get_changes(self):
"""Returns dict of changed fields and their new values."""
changes = {}
for key in self._changed_fields:
changes[key] = self[key]
return changes

View File

@ -27,11 +27,14 @@ class Cluster(base.CueObject):
fields = { fields = {
'id': obj_utils.str_or_none, 'id': obj_utils.str_or_none,
'network_id': obj_utils.str_or_none,
'project_id': obj_utils.str_or_none, 'project_id': obj_utils.str_or_none,
'name': obj_utils.str_or_none, 'name': obj_utils.str_or_none,
'nic': obj_utils.str_or_none,
'volume_size': obj_utils.int_or_none,
'status': obj_utils.str_or_none, 'status': obj_utils.str_or_none,
'flavor': obj_utils.str_or_none,
'size': obj_utils.int_or_none,
'volume_size': obj_utils.int_or_none,
'deleted': obj_utils.bool_or_none,
'created_at': obj_utils.datetime_or_str_or_none, 'created_at': obj_utils.datetime_or_str_or_none,
'updated_at': obj_utils.datetime_or_str_or_none, 'updated_at': obj_utils.datetime_or_str_or_none,
'deleted_at': obj_utils.datetime_or_str_or_none, 'deleted_at': obj_utils.datetime_or_str_or_none,
@ -44,17 +47,16 @@ class Cluster(base.CueObject):
cluster[field] = db_cluster[field] cluster[field] = db_cluster[field]
return cluster return cluster
def create_cluster(self, project_id, flavor, number_of_nodes): def create(self, project_id):
"""Creates a new cluster. """Creates a new cluster.
:param project_id: The project id the cluster resides in. :param project_id: The project id the cluster resides in.
:param flavor: The required flavor for nodes in this cluster.
:param number_of_nodes: The number of nodes in this cluster.
""" """
self['project_id'] = project_id self['project_id'] = project_id
db_cluster = self.dbapi.create_cluster(self, flavor, cluster_changes = self.obj_get_changes()
number_of_nodes)
db_cluster = self.dbapi.create_cluster(cluster_changes)
self._from_db_object(self, db_cluster) self._from_db_object(self, db_cluster)
@ -70,22 +72,22 @@ class Cluster(base.CueObject):
return [Cluster._from_db_object(Cluster(), obj) for obj in db_clusters] return [Cluster._from_db_object(Cluster(), obj) for obj in db_clusters]
@classmethod @classmethod
def get_cluster(cls, cluster_id): def get_cluster_by_id(cls, cluster_id):
"""Returns a Cluster objects for specified cluster_id. """Returns a Cluster objects for specified cluster_id.
:param cluster_id: UUID of a cluster. :param cluster_id: UUID of a cluster.
:returns: a :class:'Cluster' object. :returns: a :class:'Cluster' object.
""" """
db_cluster = cls.dbapi.get_cluster(cluster_id) db_cluster = cls.dbapi.get_cluster_by_id(cluster_id)
cluster = Cluster._from_db_object(Cluster(), db_cluster) cluster = Cluster._from_db_object(Cluster(), db_cluster)
return cluster return cluster
@classmethod @classmethod
def mark_as_delete_cluster(cls, cluster_id): def update_cluster_deleting(cls, cluster_id):
"""Marks specified cluster to indicate deletion. """Marks specified cluster to indicate deletion.
:param cluster_id: UUID of a cluster. :param cluster_id: UUID of a cluster.
""" """
cls.dbapi.mark_as_delete_cluster(cluster_id) cls.dbapi.update_cluster_deleting(cluster_id)

View File

@ -40,13 +40,13 @@ class Endpoint(base.CueObject):
return cluster return cluster
@classmethod @classmethod
def get_endpoints(cls, node_id): def get_endpoints_by_node_id(cls, node_id):
"""Returns a list of Endpoint objects for specified node. """Returns a list of Endpoint objects for specified node.
:param node_id: UUID of the node. :param node_id: UUID of the node.
:returns: a list of :class:'Endpoint' object. :returns: a list of :class:'Endpoint' object.
""" """
db_endpoints = cls.dbapi.get_endpoints(node_id) db_endpoints = cls.dbapi.get_endpoints_in_node(node_id)
return [Endpoint._from_db_object(Endpoint(), obj) for obj in db_endpoints] return [Endpoint._from_db_object(Endpoint(), obj) for obj in db_endpoints]

View File

@ -44,13 +44,13 @@ class Node(base.CueObject):
return node return node
@classmethod @classmethod
def get_nodes(cls, cluster_id): def get_nodes_by_cluster_id(cls, cluster_id):
"""Returns a list of Node objects for specified cluster. """Returns a list of Node objects for specified cluster.
:param cluster_id: UUID of the cluster. :param cluster_id: UUID of the cluster.
:returns: a list of :class:'Node' object. :returns: a list of :class:'Node' object.
""" """
db_nodes = cls.dbapi.get_nodes(cluster_id) db_nodes = cls.dbapi.get_nodes_in_cluster(cluster_id)
return [Node._from_db_object(Node(), obj) for obj in db_nodes] return [Node._from_db_object(Node(), obj) for obj in db_nodes]

View File

@ -0,0 +1,74 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Authors: Davide Agnello <davide.agnello@hp.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright [2014] Hewlett-Packard Development Company, L.P.
# limitations under the License.
"""
Common API base class to controller test classes.
"""
import cue.tests.api as api_base
class ApiCommon(api_base.FunctionalTest):
cluster_name = "test-cluster"
def validate_cluster_values(self, cluster_ref, cluster_cmp):
self.assertEqual(cluster_ref.id if hasattr(cluster_ref, "id") else
cluster_ref["id"],
cluster_cmp.id if hasattr(cluster_cmp, "id") else
cluster_cmp["id"],
"Invalid cluster id value")
self.assertEqual(cluster_ref.network_id if hasattr(cluster_ref,
"network_id")
else cluster_ref["network_id"],
cluster_cmp.network_id if hasattr(cluster_cmp,
"network_id")
else cluster_cmp["network_id"],
"Invalid cluster network_id value")
self.assertEqual(cluster_ref.name if hasattr(cluster_ref, "name")
else cluster_ref["name"],
cluster_cmp.name if hasattr(cluster_cmp, "name")
else cluster_cmp["name"],
"Invalid cluster name value")
self.assertEqual(cluster_ref.status if hasattr(cluster_ref, "status")
else cluster_ref["status"],
cluster_cmp.status if hasattr(cluster_cmp, "status")
else cluster_cmp["status"],
"Invalid cluster status value")
self.assertEqual(cluster_ref.flavor if hasattr(cluster_ref, "flavor")
else cluster_ref["flavor"],
cluster_cmp.flavor if hasattr(cluster_cmp, "flavor")
else cluster_cmp["flavor"],
"Invalid cluster flavor value")
self.assertEqual(cluster_ref.size if hasattr(cluster_ref, "size")
else cluster_ref["size"],
cluster_cmp.size if hasattr(cluster_cmp, "size")
else cluster_cmp["size"],
"Invalid cluster size value")
self.assertEqual(cluster_ref.volume_size if hasattr(cluster_ref,
"volume_size")
else cluster_ref["volume_size"],
cluster_cmp.volume_size if hasattr(cluster_cmp,
"volume_size")
else cluster_cmp["volume_size"],
"Invalid cluster volume_size value")
self.assertEqual(unicode(cluster_ref.created_at.isoformat()),
cluster_cmp["created_at"],
"Invalid cluster created_at value")
self.assertEqual(unicode(cluster_ref.updated_at.isoformat()),
cluster_cmp["updated_at"],
"Invalid cluster updated_at value")

View File

@ -0,0 +1,134 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Authors: Davide Agnello <davide.agnello@hp.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright [2014] Hewlett-Packard Development Company, L.P.
# limitations under the License.
"""
Tests for the API /cluster/ controller methods.
"""
import uuid
from cue.db.sqlalchemy import models
from cue import objects
from cue.tests.api import api_common
from cue.tests import utils as test_utils
class TestGetCluster(api_common.ApiCommon):
def setUp(self):
super(TestGetCluster, self).setUp()
def test_get_cluster_not_found(self):
"""test get non-existing cluster."""
data = self.get_json('/clusters/' + str(uuid.uuid4()),
expect_errors=True)
self.assertEqual(404, data.status_code,
'Invalid status code value received.')
self.assertEqual('404 Not Found', data.status,
'Invalid status value received.')
self.assertIn('Cluster was not found',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_get_cluster_invalid_uuid_format(self):
"""test get cluster with invalid id uuid format."""
invalid_uuid = u"25c06c22.fadd.4c83-a515-974a29668ba9"
data = self.get_json('/clusters/' + invalid_uuid, expect_errors=True)
self.assertEqual(400, data.status_code,
'Invalid status code value received.')
self.assertEqual('400 Bad Request', data.status,
'Invalid status value received.')
self.assertIn('badly formed cluster_id UUID string',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_get_cluster_invalid_uri(self):
"""test get cluster with invalid URI string.
Example: get /clusters/<cluster_id>/invalid_resource
"""
def test_get_cluster_valid_uri(self):
"""test get cluster with valid URI strings.
Examples (with and without end forward slash):
get /clusters/<cluster_id>
get /clusters/<cluster_id>/
"""
def test_get_cluster(self):
"""test get cluster on valid existing cluster."""
cluster = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name)
data = self.get_json('/clusters/' + cluster.id)
self.validate_cluster_values(cluster, data)
class TestDeleteCluster(api_common.ApiCommon):
def setUp(self):
super(TestDeleteCluster, self).setUp()
def test_delete_cluster_not_found(self):
"""test delete non-existing cluster."""
data = self.delete('/clusters/' + str(uuid.uuid4()),
expect_errors=True)
self.assertEqual(404, data.status_code,
'Invalid status code value received.')
self.assertEqual('404 Not Found', data.status,
'Invalid status value received.')
self.assertIn('Cluster was not found',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_delete_cluster_invalid_uuid_format(self):
"""test delete cluster with invalid uuid format."""
invalid_uuid = u"25c06c22.fadd.4c83-a515-974a29668ba9"
data = self.delete('/clusters/' + invalid_uuid, expect_errors=True)
self.assertEqual(400, data.status_code,
'Invalid status code value received.')
self.assertEqual('400 Bad Request', data.status,
'Invalid status value received.')
self.assertIn('badly formed cluster_id UUID string',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_deleted_cluster_already_deleted(self):
"""test delete cluster that has already been deleted."""
def test_delete_pending_cluster(self):
"""test delete cluster that is pending deletion."""
def test_delete_cluster(self):
"""test delete cluster on valid existing cluster."""
cluster = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name)
cluster_in_db = objects.Cluster.get_cluster_by_id(cluster.id)
self.assertEqual(models.Status.BUILDING, cluster_in_db.status,
"Invalid cluster status value")
self.delete('/clusters/' + cluster.id)
cluster_in_db = objects.Cluster.get_cluster_by_id(cluster.id)
cluster.status = models.Status.DELETING
cluster.updated_at = cluster_in_db.created_at
cluster.updated_at = cluster_in_db.updated_at
data = self.get_json('/clusters/' + cluster.id)
self.validate_cluster_values(cluster, data)

View File

@ -12,38 +12,149 @@
# License for the specific language governing permissions and limitations # License for the specific language governing permissions and limitations
# under the License. # under the License.
""" """
Tests for the API /clusters/ methods. Tests for the API /clusters/ controller methods.
""" """
from cue.db.sqlalchemy import models
import cue.tests.api as api_base from cue import objects
from cue.tests.db import utils as dbutils from cue.tests.api import api_common
from cue.tests import utils as test_utils
# class TestClusterObject(base.TestCase): class TestListClusters(api_common.ApiCommon):
#
# def test_cluster_init(self):
# # port_dict = apiutils.port_post_data(node_id=None)
# # del port_dict['extra']
# # port = api_port.Port(**port_dict)
# # self.assertEqual(wtypes.Unset, port.extra)
class TestListClusters(api_base.FunctionalTest):
cluster_name = "test-cluster"
def setUp(self): def setUp(self):
super(TestListClusters, self).setUp() super(TestListClusters, self).setUp()
#self.cluster = dbutils.create_test_cluster()
def test_empty(self): def test_empty(self):
data = self.get_json('/clusters') data = self.get_json('/clusters')
# TODO(vipul): This should probably return a empty 'clusters'
self.assertEqual([], data) self.assertEqual([], data)
def test_one(self): def test_one(self):
cluster = dbutils.create_test_cluster(name=self.cluster_name) cluster = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name)
data = self.get_json('/clusters') data = self.get_json('/clusters')
self.assertEqual(cluster.id, data[0]["id"]) self.assertEqual(len(data), 1, "Invalid number of clusters returned")
self.assertEqual(self.cluster_name, data[0]["name"])
self.validate_cluster_values(cluster, data[0])
def test_multiple(self):
cluster_0 = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name + '_0')
cluster_1 = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name + '_1')
cluster_2 = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name + '_2')
cluster_3 = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name + '_3')
cluster_4 = test_utils.create_db_test_cluster_from_objects_api(
name=self.cluster_name + '_4')
data = self.get_json('/clusters')
self.assertEqual(len(data), 5, "Invalid number of clusters returned")
self.validate_cluster_values(cluster_0, data[0])
self.validate_cluster_values(cluster_1, data[1])
self.validate_cluster_values(cluster_2, data[2])
self.validate_cluster_values(cluster_3, data[3])
self.validate_cluster_values(cluster_4, data[4])
class TestCreateCluster(api_common.ApiCommon):
def setUp(self):
super(TestCreateCluster, self).setUp()
def test_create_empty_body(self):
cluster_params = {}
header = {'Content-Type': 'application/json'}
data = self.post_json('/clusters', params=cluster_params,
headers=header, expect_errors=True)
self.assertEqual(400, data.status_code,
'Invalid status code value received.')
self.assertEqual('400 Bad Request', data.status,
'Invalid status value received.')
self.assertIn('Invalid input for field/attribute',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_create_size_zero(self):
"""test create an empty cluster."""
api_cluster = test_utils.create_api_test_cluster(size=0)
header = {'Content-Type': 'application/json'}
data = self.post_json('/clusters', params=api_cluster.as_dict(),
headers=header, expect_errors=True)
self.assertEqual(400, data.status_code,
'Invalid status code value received.')
self.assertEqual('400 Bad Request', data.status,
'Invalid status value received.')
self.assertIn('Invalid cluster size provided',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_create_size_one(self):
"""test create a cluster with one node.
Will verify cluster create from DB record then verifies cluster get
returns the same cluster from the API.
"""
api_cluster = test_utils.create_api_test_cluster(size=1)
header = {'Content-Type': 'application/json'}
data = self.post_json('/clusters', params=api_cluster.as_dict(),
headers=header, status=202)
cluster = objects.Cluster.get_cluster_by_id(data.json["id"])
self.validate_cluster_values(cluster, data.json)
self.assertEqual(models.Status.BUILDING, data.json['status'])
data_api = self.get_json('/clusters/' + cluster.id)
self.validate_cluster_values(cluster, data_api)
self.assertEqual(models.Status.BUILDING, data_api['status'])
def test_create_size_three(self):
"""test create a cluster with three nodes.
Will verify cluster create from DB record then verifies cluster get
returns the same cluster from the API.
"""
api_cluster = test_utils.create_api_test_cluster(size=3)
header = {'Content-Type': 'application/json'}
data = self.post_json('/clusters', params=api_cluster.as_dict(),
headers=header, status=202)
cluster = objects.Cluster.get_cluster_by_id(data.json["id"])
self.validate_cluster_values(cluster, data.json)
self.assertEqual(models.Status.BUILDING, data.json['status'])
data_api = self.get_json('/clusters/' + cluster.id)
self.validate_cluster_values(cluster, data_api)
self.assertEqual(models.Status.BUILDING, data_api['status'])
def test_create_invalid_size_format(self):
"""test with invalid formatted size parameter."""
api_cluster = test_utils.create_api_test_cluster(size="a")
header = {'Content-Type': 'application/json'}
data = self.post_json('/clusters', params=api_cluster.as_dict(),
headers=header, expect_errors=True)
self.assertEqual(500, data.status_code,
'Invalid status code value received.')
self.assertEqual('500 Internal Server Error', data.status,
'Invalid status value received.')
self.assertIn('invalid literal for int() with base 10:',
data.namespace["faultstring"],
'Invalid faultstring received.')
def test_create_invalid_volume_size(self):
"""test with invalid volume_size parameter."""
def test_create_invalid_parameter_set_id(self):
"""test with invalid parameter set: id."""
def test_create_invalid_parameter_set_status(self):
"""test with invalid parameter set: status."""
def test_create_invalid_parameter_set_created_at(self):
"""test with invalid parameter set: created_at."""
def test_create_invalid_parameter_set_updated_at(self):
"""test with invalid parameter set: updated_at."""
def test_create_invalid_parameter_set_deleted_at(self):
"""test with invalid parameter set: deleted_at."""

90
cue/tests/db/test_api.py Normal file
View File

@ -0,0 +1,90 @@
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Authors: Davide Agnello <davide.agnello@hp.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright [2014] Hewlett-Packard Development Company, L.P.
# limitations under the License.
import uuid
from cue.db import api as db_api
from cue.tests import base
UUID1 = str(uuid.uuid4())
UUID2 = str(uuid.uuid4())
class ApiTests(base.TestCase):
dbapi = db_api.get_instance()
def test_get_clusters(self):
"""Verifies get clusters DB API."""
def test_create_clusters(self):
"""Verifies create cluster DB API."""
cluster_values = {
"project_id": UUID1,
"name": "Rabbit Cluster",
"network_id": UUID2,
"flavor": "medium",
"size": 5,
"volume_size": 250,
}
db_cluster = self.dbapi.create_cluster(cluster_values)
self.assertEqual(cluster_values["name"], db_cluster.name,
"invalid name value")
self.assertEqual(cluster_values["network_id"], db_cluster.network_id,
"invalid network_id value")
self.assertEqual(cluster_values["flavor"], db_cluster.flavor,
"invalid flavor value")
self.assertEqual(cluster_values["size"], db_cluster.size,
"invalid size value")
self.assertEqual(cluster_values["volume_size"], db_cluster.volume_size,
"invalid volume_size value")
self.assertEqual(False, db_cluster.deleted, "invalid deleted value")
def test_get_cluster_by_id(self):
"""Verifies create cluster DB API."""
def test_get_nodes_in_cluster(self):
"""Verifies create cluster DB API."""
cluster_values = {
"project_id": UUID1,
"name": "Rabbit Cluster",
"network_id": UUID2,
"flavor": "medium",
"size": 5,
"volume_size": 250,
}
db_cluster = self.dbapi.create_cluster(cluster_values)
db_nodes = self.dbapi.get_nodes_in_cluster(db_cluster.id)
for node in db_nodes:
self.assertEqual(db_cluster.id, node.cluster_id,
"invalid flavor value")
self.assertEqual(cluster_values["flavor"], node.flavor,
"invalid flavor value")
self.assertEqual(False, node.deleted,
"invalid deleted value")
def test_get_endpoints_in_node(self):
"""Verifies create cluster DB API."""
def test_update_cluster_deleting(self):
"""Verifies create cluster DB API."""

View File

@ -15,164 +15,227 @@
# Copied from Octavia # Copied from Octavia
import uuid import uuid
from cue.common import exception from cue.db import api as db_api
from cue.db.sqlalchemy import api from cue.db.sqlalchemy import api as sql_api
from cue.db.sqlalchemy import models from cue.db.sqlalchemy import models
from cue.tests import base from cue.tests import base
from oslo.utils import timeutils
UUID1 = str(uuid.uuid4()) UUID1 = str(uuid.uuid4())
UUID2 = str(uuid.uuid4()) UUID2 = str(uuid.uuid4())
UUID3 = str(uuid.uuid4())
class ClusterTests(base.TestCase): class ModelsTests(base.TestCase):
def test_create(self):
def test_create_cluster_model(self):
"""Verifies a new cluster record is created in DB.""" """Verifies a new cluster record is created in DB."""
data = { cluster_values = {
"project_id": UUID1, "id": UUID1,
"name": "test", "network_id": UUID3,
"project_id": UUID2,
"name": "Cluster test",
"status": models.Status.BUILDING, "status": models.Status.BUILDING,
"nic": UUID2, "flavor": "medium",
"volume_size": 0, "size": 3,
"deleted": False "volume_size": 250,
"deleted": False,
"created_at": timeutils.utcnow(),
"updated_at": timeutils.utcnow(),
"deleted_at": timeutils.utcnow(),
} }
ref = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref, models.Cluster)
self.session = api.get_session() cluster = models.Cluster()
get_ref = models.Cluster.get(self.session, id=ref.id) cluster.update(cluster_values)
self.assertEqual(ref.id, get_ref.id, "Database object does not match "
"submitted object")
def test_update(self): self.assertEqual(cluster_values["id"], cluster.id, "Invalid ID value")
"""Verifies update record function.""" self.assertEqual(cluster_values["project_id"], cluster.project_id,
"Invalid project_id value")
self.assertEqual(cluster_values["network_id"], cluster.network_id,
"Invalid network_id value")
self.assertEqual(cluster_values["name"], cluster.name, "Invalid name"
"value")
self.assertEqual(cluster_values["status"], cluster.status, "Invalid "
"status"
"value")
self.assertEqual(cluster_values["flavor"], cluster.flavor,
"Invalid flavor value")
self.assertEqual(cluster_values["size"], cluster.size,
"Invalid size value")
self.assertEqual(cluster_values["volume_size"], cluster.volume_size,
"Invalid volume_size value")
self.assertEqual(cluster_values["deleted"], cluster.deleted,
"Invalid deleted value")
self.assertEqual(cluster_values["created_at"], cluster.created_at,
"Invalid created_at value")
self.assertEqual(cluster_values["updated_at"], cluster.updated_at,
"Invalid updated_at value")
self.assertEqual(cluster_values["deleted_at"], cluster.deleted_at,
"Invalid deleted_at value")
data = { db_session = sql_api.get_session()
"project_id": UUID1, cluster.save(db_session)
"name": "test",
dbapi = db_api.get_instance()
cluster_db = dbapi.get_cluster_by_id(cluster_values["id"])
self.assertEqual(cluster_values["id"], cluster_db.id, "Invalid ID "
"value")
self.assertEqual(cluster_values["project_id"], cluster_db.project_id,
"Invalid project_id value")
self.assertEqual(cluster_values["network_id"], cluster_db.network_id,
"Invalid network_id value")
self.assertEqual(cluster_values["name"], cluster_db.name, "Invalid "
"name value")
self.assertEqual(cluster_values["status"], cluster_db.status,
"Invalid status value")
self.assertEqual(cluster_values["flavor"], cluster_db.flavor,
"Invalid flavor value")
self.assertEqual(cluster_values["size"], cluster_db.size,
"Invalid size value")
self.assertEqual(cluster_values["volume_size"], cluster_db.volume_size,
"Invalid volume_size value")
self.assertEqual(cluster_values["deleted"], cluster_db.deleted,
"Invalid deleted value")
self.assertEqual(cluster_values["created_at"], cluster_db.created_at,
"Invalid created_at value")
self.assertEqual(cluster_values["updated_at"], cluster_db.updated_at,
"Invalid updated_at value")
self.assertEqual(cluster_values["deleted_at"], cluster_db.deleted_at,
"Invalid deleted_at value")
def test_create_node_model(self):
"""Verifies a new cluster record is created in DB."""
dbapi = db_api.get_instance()
cluster_values = {
"network_id": UUID3,
"project_id": UUID2,
"name": "Cluster test",
"flavor": "medium",
"size": 3,
"volume_size": 250,
}
db_cluster = dbapi.create_cluster(cluster_values)
node_values = {
"id": UUID1,
"cluster_id": db_cluster.id,
"instance_id": "NovaInstanceId",
"flavor": "Large",
"status": models.Status.BUILDING, "status": models.Status.BUILDING,
"nic": UUID2, "deleted": False,
"volume_size": 0, "created_at": timeutils.utcnow(),
"deleted": False "updated_at": timeutils.utcnow(),
"deleted_at": timeutils.utcnow(),
} }
ref = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref, models.Cluster)
self.session = api.get_session() node = models.Node()
data2 = { node.update(node_values)
"name": "NewName",
"status": "ACTIVE" self.assertEqual(node_values["id"], node.id, "Invalid ID value")
self.assertEqual(node_values["cluster_id"], node.cluster_id,
"Invalid cluster_id value")
self.assertEqual(node_values["instance_id"], node.instance_id,
"Invalid instance_id value")
self.assertEqual(node_values["status"], node.status, "Invalid status "
"value")
self.assertEqual(node_values["flavor"], node.flavor, "Invalid flavor "
"value")
self.assertEqual(node_values["deleted"], node.deleted,
"Invalid deleted value")
self.assertEqual(node_values["created_at"], node.created_at,
"Invalid created_at value")
self.assertEqual(node_values["updated_at"], node.updated_at,
"Invalid updated_at value")
self.assertEqual(node_values["deleted_at"], node.deleted_at,
"Invalid deleted_at value")
db_session = sql_api.get_session()
node.save(db_session)
node_db = dbapi.get_node_by_id(node_values["id"])
self.assertEqual(node_values["id"], node_db.id, "Invalid ID value")
self.assertEqual(node_values["cluster_id"], node_db.cluster_id,
"Invalid cluster_id value")
self.assertEqual(node_values["instance_id"], node_db.instance_id,
"Invalid instance_id value")
self.assertEqual(node_values["status"], node_db.status, "Invalid "
"status value")
self.assertEqual(node_values["flavor"], node_db.flavor, "Invalid "
"flavor value")
self.assertEqual(node_values["deleted"], node_db.deleted,
"Invalid deleted value")
self.assertEqual(node_values["created_at"], node_db.created_at,
"Invalid created_at value")
self.assertEqual(node_values["updated_at"], node_db.updated_at,
"Invalid updated_at value")
self.assertEqual(node_values["deleted_at"], node_db.deleted_at,
"Invalid deleted_at value")
def test_create_endpoint_model(self):
"""Verifies a new cluster record is created in DB."""
dbapi = db_api.get_instance()
cluster_values = {
"network_id": UUID3,
"project_id": UUID2,
"name": "Cluster test",
"flavor": "medium",
"size": 3,
"volume_size": 250,
} }
models.Cluster.update(self.session, ref.id, **data2) db_cluster = dbapi.create_cluster(cluster_values)
self.session = api.get_session() node_values = {
get_ref = models.Cluster.get(self.session, id=ref.id) "cluster_id": db_cluster.id,
self.assertEqual(str(ref.name), 'test', "Original cluster name was" "flavor": "Large",
"unexpectedly changed")
self.assertEqual(str(get_ref.name), 'NewName', "Cluster name was not "
"updated")
self.assertEqual(str(get_ref.status), 'ACTIVE', "Cluster status was "
"not updated")
self.assertGreater(get_ref.updated_at, ref.updated_at, "Updated "
"datetime was "
"not updated.")
def test_delete(self):
"""Verifies deleting existing record is removed from DB."""
data = {
"project_id": UUID1,
"name": "test",
"nic": UUID2,
"volume_size": 0,
"status": models.Status.BUILDING, "status": models.Status.BUILDING,
"deleted": False
}
ref = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref, models.Cluster)
models.Cluster.delete(self.session, ref.id)
self.session = api.get_session()
try:
get_ref = models.Cluster.get(self.session, id=ref.id)
except exception.NotFound:
self.fail('Record was deleted entirely')
else:
if get_ref.status != models.Status.DELETING:
self.fail('Record status was not update to DELETING')
def test_get_delete_batch(self):
"""Verifies delete batch records from DB."""
data = {
"project_id": UUID1,
"name": "test1",
"status": models.Status.BUILDING,
"nic": UUID2,
"volume_size": 0,
"deleted": False
}
ref1 = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref1, models.Cluster)
data.update(name="test2")
ref2 = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref2, models.Cluster)
data.update(name="test3")
ref3 = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref3, models.Cluster)
data.update(name="test4")
ref4 = models.Cluster.create(self.session, **data)
self.assertIsInstance(ref4, models.Cluster)
self.session = api.get_session()
clusters_before = models.Cluster.get_all(self.session,
status=models.Status.BUILDING)
ids = [ref1.id, ref2.id, ref3.id, ref4.id]
self.assertEqual(len(ids), len(clusters_before), "Not able to get "
"all created "
"clusters")
models.Cluster.delete_batch(self.session, ids)
clusters_after = models.Cluster.get_all(self.session,
status=models.Status.DELETING)
self.assertEqual(len(ids), len(clusters_after), "Not all cluster "
"record's statuses "
"were marked as "
"deleted")
class NodeRepositoryTests(base.TestCase):
def test_create(self):
"""Verifies a new cluster record and a new node record pointed to
correct cluster are created.
"""
cluster = {
"project_id": UUID1,
"name": "test",
"nic": UUID2,
"volume_size": 0,
"status": "BUILDING",
"deleted": False
} }
cluster_ref = models.Cluster.create(self.session, **cluster) node = models.Node()
node.update(node_values)
db_session = sql_api.get_session()
node.save(db_session)
node = { endpoint_values = {
"flavor": 'foo', "id": UUID1,
"instance_id": 'bar', "node_id": node.id,
"cluster_id": cluster_ref.id, "uri": "amqp://10.20.30.40:10000",
"volume_size": 0, "type": "AMQP",
"status": "BUILDING", "deleted": False,
"deleted": False
} }
node_ref = models.Node.create(self.session, **node) endpoint = models.Endpoint()
self.assertIsInstance(node_ref, models.Node) endpoint.update(endpoint_values)
self.assertEqual(endpoint_values["id"], endpoint.id, "Invalid ID "
"value")
self.assertEqual(endpoint_values["node_id"], endpoint.node_id,
"Invalid node_id value")
self.assertEqual(endpoint_values["uri"], endpoint.uri, "Invalid uri"
"value")
self.assertEqual(endpoint_values["type"], endpoint.type, "Invalid "
"type"
"value")
self.assertEqual(endpoint_values["deleted"], endpoint.deleted,
"Invalid deleted value")
endpoint.save(db_session)
endpoint_db = dbapi.get_endpoint_by_id(endpoint_values["id"])
self.assertEqual(endpoint_values["id"], endpoint_db.id, "Invalid ID "
"value")
self.assertEqual(endpoint_values["node_id"], endpoint_db.node_id,
"Invalid node_id value")
self.assertEqual(endpoint_values["uri"], endpoint_db.uri, "Invalid uri"
"value")
self.assertEqual(endpoint_values["type"], endpoint_db.type, "Invalid "
"type"
"value")
self.assertEqual(endpoint_values["deleted"], endpoint_db.deleted,
"Invalid deleted value")

View File

@ -1,59 +0,0 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Cue test utilities."""
from oslo.utils import timeutils
from cue import objects
def get_test_cluster(**kw):
return {
'id': kw.get('id', '1be26c0b-03f2-4d2e-ae87-c02d7f33c781'),
'project_id': kw.get('project_id', '1234567890'),
'name': kw.get('name', 'sample_cluster'),
'nic': kw.get('nic', '3dc26c0b-03f2-4d2e-ae87-c02d7f33c788'),
'status': kw.get('status', 'BUILDING'),
'volume_size': kw.get('volume_size', 10),
'created_at': kw.get('created_at', timeutils.utcnow()),
'updated_at': kw.get('updated_at', timeutils.utcnow()),
}
def create_test_cluster(**kw):
"""Create test Cluster entry in DB and return Cluster DB object.
Function to be used to create test Cluster objects in the database.
:param kw: kwargs with overriding values for cluster's attributes.
:returns: Test Cluster DB object.
"""
cluster = get_test_cluster(**kw)
cluster_parameters = {
'name': cluster['name'],
'nic': cluster['nic'],
'volume_size': cluster['volume_size'],
}
new_cluster = objects.Cluster(**cluster_parameters)
project_id = cluster['project_id']
number_of_nodes = 1
new_cluster.create_cluster(project_id, "flavor1", number_of_nodes)
return new_cluster

View File

View File

@ -0,0 +1,209 @@
# Copyright 2015 Hewlett-Packard Development Company, L.P.
#
# Authors: Davide Agnello <davide.agnello@hp.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright [2014] Hewlett-Packard Development Company, L.P.
# limitations under the License.
"""
Tests for cue objects classes.
"""
import iso8601
from oslo.utils import timeutils
from cue.api.controllers import v1
from cue.db import api as db_api
from cue.db.sqlalchemy import models
from cue import objects
from cue.tests import base
from cue.tests import utils as test_utils
class ClusterObjectsTests(base.TestCase):
dbapi = db_api.get_instance()
def compare_dates(self, datetime1, datetime2):
if datetime1.utcoffset() is None:
datetime1 = datetime1.replace(tzinfo=iso8601.iso8601.Utc())
if datetime2.utcoffset() is None:
datetime2 = datetime2.replace(tzinfo=iso8601.iso8601.Utc())
self.assertEqual(datetime1 == datetime2, True,
"Invalid datetime value")
def validate_cluster_values(self, cluster_ref, cluster_cmp):
self.assertEqual(cluster_ref.id if hasattr(cluster_ref, "id") else
cluster_ref["id"],
cluster_cmp.id if hasattr(cluster_cmp, "id") else
cluster_cmp["id"],
"Invalid cluster id value")
self.assertEqual(cluster_ref.network_id if hasattr(cluster_ref,
"network_id")
else cluster_ref["network_id"],
cluster_cmp.network_id if hasattr(cluster_cmp,
"network_id")
else cluster_cmp["network_id"],
"Invalid cluster network_id value")
self.assertEqual(cluster_ref.name if hasattr(cluster_ref, "name")
else cluster_ref["name"],
cluster_cmp.name if hasattr(cluster_cmp, "name")
else cluster_cmp["name"],
"Invalid cluster name value")
self.assertEqual(cluster_ref.status if hasattr(cluster_ref, "status")
else cluster_ref["status"],
cluster_cmp.status if hasattr(cluster_cmp, "status")
else cluster_cmp["status"],
"Invalid cluster status value")
self.assertEqual(cluster_ref.flavor if hasattr(cluster_ref, "flavor")
else cluster_ref["flavor"],
cluster_cmp.flavor if hasattr(cluster_cmp, "flavor")
else cluster_cmp["flavor"],
"Invalid cluster flavor value")
self.assertEqual(cluster_ref.size if hasattr(cluster_ref, "size")
else cluster_ref["size"],
cluster_cmp.size if hasattr(cluster_cmp, "size")
else cluster_cmp["size"],
"Invalid cluster size value")
self.assertEqual(cluster_ref.volume_size if hasattr(cluster_ref,
"volume_size")
else cluster_ref["volume_size"],
cluster_cmp.volume_size if hasattr(cluster_cmp,
"volume_size")
else cluster_cmp["volume_size"],
"Invalid cluster volume_size value")
self.compare_dates(cluster_ref.created_at if hasattr(cluster_ref,
"created_at")
else cluster_ref["created_at"],
cluster_cmp.created_at if hasattr(cluster_cmp,
"created_at")
else cluster_cmp["created_at"])
self.compare_dates(cluster_ref.updated_at if hasattr(cluster_ref,
"updated_at")
else cluster_ref["updated_at"],
cluster_cmp.updated_at if hasattr(cluster_cmp,
"updated_at")
else cluster_cmp["updated_at"])
def test_cluster_object_generation(self):
"""Test Cluster Object generation from a cluster dictionary object."""
cluster_dict = test_utils.get_test_cluster()
cluster_object = objects.Cluster(**cluster_dict)
self.validate_cluster_values(cluster_object, cluster_dict)
def test_node_object_generation(self):
"""Test Node Object generation from a cluster dictionary object."""
def test_endpoint_object_generation(self):
"""Test Endpoint Object generation from a cluster dictionary object."""
def test_cluster_api_to_object_to_api(self):
"""Tests Cluster api object conversion to Cluster object and back
to api object.
"""
api_cluster = test_utils.create_api_test_cluster_all()
object_cluster = objects.Cluster(**api_cluster.as_dict())
self.validate_cluster_values(api_cluster, object_cluster)
api_cluster_2 = v1.Cluster(**object_cluster.as_dict())
self.validate_cluster_values(api_cluster, api_cluster_2)
def test_cluster_db_to_object_to_db(self):
"""Tests Cluster db object conversion to Cluster object and back
to db object.
"""
db_cluster_object = test_utils.create_db_test_cluster_model_object(
deleted_at=timeutils.utcnow(), deleted=True)
object_cluster = objects.Cluster._from_db_object(objects.Cluster(),
db_cluster_object)
self.validate_cluster_values(db_cluster_object, object_cluster)
self.assertEqual(db_cluster_object.deleted,
object_cluster.deleted,
"Invalid cluster deleted_at value")
self.compare_dates(db_cluster_object.deleted_at
if hasattr(db_cluster_object, "deleted_at")
else db_cluster_object["deleted_at"],
object_cluster.deleted_at
if hasattr(object_cluster, "deleted_at")
else object_cluster["deleted_at"])
cluster_changes = object_cluster.obj_get_changes()
db_cluster_object_2 = models.Cluster()
db_cluster_object_2.update(cluster_changes)
self.validate_cluster_values(db_cluster_object, db_cluster_object_2)
self.assertEqual(db_cluster_object.deleted,
db_cluster_object_2.deleted,
"Invalid cluster deleted_at value")
self.compare_dates(db_cluster_object.deleted_at
if hasattr(db_cluster_object, "deleted_at")
else db_cluster_object["deleted_at"],
db_cluster_object_2.deleted_at
if hasattr(db_cluster_object_2, "deleted_at")
else db_cluster_object_2["deleted_at"])
class NodeObjectsTests(base.TestCase):
dbapi = db_api.get_instance()
def test_node_api_to_object_to_api(self):
"""Tests Node api object conversion to Node object and back
to api object.
"""
def test_node_db_to_object_to_db(self):
"""Tests Node db object conversion to Node object and back
to db object.
"""
class EndpointObjectsTests(base.TestCase):
dbapi = db_api.get_instance()
def test_endpoint_api_to_object_to_api(self):
"""Tests Endpoint api object conversion to Endpoint object and back
to api object.
"""
def test_endpoint_db_to_object_to_db(self):
"""Tests Endpoint db object conversion to Endpoint object and back
to db object.
"""
class ClusterObjectsApiTests(base.TestCase):
def test_create_cluster(self):
"""Tests creating a Cluster from Cluster objects API."""
def test_get_clusters(self):
"""Tests getting all Clusters from Cluster objects API."""
def test_get_clusters_by_id(self):
"""Tests get Cluster by id from Cluster objects API."""
def test_mark_cluster_as_delete(self):
"""Tests marking clusters for delete from Cluster objects API."""
class NodeObjectsApiTests(base.TestCase):
def test_get_nodes_by_cluster_id(self):
"""Tests get nodes by cluster id from Nodes objects API."""
class EndpointObjectsApiTests(base.TestCase):
def test_get_endpoints_by_node_id(self):
"""Tests get endpoint objects by node id from Endpoint objects API."""

152
cue/tests/utils.py Normal file
View File

@ -0,0 +1,152 @@
# Copyright 2013 Hewlett-Packard Development Company, L.P.
# All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Cue test utilities."""
from oslo.utils import timeutils
from cue.api.controllers import v1
from cue.db.sqlalchemy import models
from cue import objects
def get_test_cluster(**kw):
return {
'id': kw.get('id', '1be26c0b-03f2-4d2e-ae87-c02d7f33c781'),
'project_id': kw.get('project_id', '1234567890'),
'name': kw.get('name', 'sample_cluster'),
'network_id': kw.get('network_id',
'3dc26c0b-03f2-4d2e-ae87-c02d7f33c788'),
'status': kw.get('status', 'BUILDING'),
'flavor': kw.get('flavor', 'flavor1'),
'size': kw.get('size', 1),
'volume_size': kw.get('volume_size', 10),
'deleted': kw.get('deleted', False),
'created_at': kw.get('created_at', timeutils.utcnow()),
'updated_at': kw.get('updated_at', timeutils.utcnow()),
'deleted_at': kw.get('deleted_at', None),
}
def create_api_test_cluster(**kw):
"""Create test Cluster api object and return this object.
Function to be used to acquire an API Cluster object set with only required
fields. This would mimic a cluster object values received from REST API.
:param kw: kwargs with overriding values for cluster's attributes.
:returns: Test Cluster API object.
"""
cluster = get_test_cluster(**kw)
cluster_parameters = {
'name': cluster['name'],
'network_id': cluster['network_id'],
'flavor': cluster['flavor'],
'size': str(cluster['size']),
'volume_size': str(cluster['volume_size']),
}
new_cluster = v1.Cluster(**cluster_parameters)
return new_cluster
def create_api_test_cluster_all(**kw):
"""Create fully-populated test Cluster api object and return this object.
Function to be used to acquire an API Cluster object with all fields set.
:param kw: kwargs with overriding values for cluster's attributes.
:returns: Test Cluster API object.
"""
cluster = get_test_cluster(**kw)
cluster_parameters = {
'name': cluster['name'],
'network_id': cluster['network_id'],
'flavor': cluster['flavor'],
'size': cluster['size'],
'volume_size': cluster['volume_size'],
'id': cluster['id'],
'project_id': cluster['project_id'],
'status': cluster['status'],
'created_at': cluster['created_at'],
'updated_at': cluster['updated_at'],
}
new_cluster = v1.Cluster(**cluster_parameters)
return new_cluster
def create_db_test_cluster_from_objects_api(**kw):
"""Create test Cluster entry in DB from objects API and return Cluster
DB object. Function to be used to create test Cluster objects in the
database.
:param kw: kwargs with overriding values for cluster's attributes.
:returns: Test Cluster DB object.
"""
cluster = get_test_cluster(**kw)
cluster_parameters = {
'name': cluster['name'],
'network_id': cluster['network_id'],
'flavor': cluster['flavor'],
'size': cluster['size'],
'volume_size': cluster['volume_size'],
}
new_cluster = objects.Cluster(**cluster_parameters)
project_id = cluster['project_id']
new_cluster.create(project_id)
return new_cluster
def create_db_test_cluster_model_object(**kw):
"""Create test Cluster DB model object.
:param kw: kwargs with overriding values for cluster's attributes.
:returns: Test Cluster DB model object.
"""
cluster = get_test_cluster(**kw)
cluster_parameters = {
'name': cluster['name'],
'network_id': cluster['network_id'],
'flavor': cluster['flavor'],
'size': cluster['size'],
'volume_size': cluster['volume_size'],
'id': cluster['id'],
'project_id': cluster['project_id'],
'status': cluster['status'],
'deleted': cluster['deleted'],
'created_at': cluster['created_at'],
'updated_at': cluster['updated_at'],
'deleted_at': cluster['deleted_at'],
}
new_cluster = models.Cluster()
new_cluster.update(cluster_parameters)
return new_cluster