Implement custom constraints

This adds a new registry of constraints and a CustomConstraint class to
be able to validate against external services.

blueprint param-constraints

Co-Authored-By: cedric.soulas@cloudwatt.com
Change-Id: Iffaf2c1179eaf5d88cdea02b63782cea9f3fc616
This commit is contained in:
Thomas Herve 2014-01-31 11:26:11 +01:00
parent 764f2706d8
commit 10417621df
5 changed files with 196 additions and 5 deletions

View File

@ -17,6 +17,8 @@ import collections
import numbers
import re
from heat.engine import resources
class InvalidSchemaError(Exception):
pass
@ -439,3 +441,55 @@ class AllowedPattern(Constraint):
def _constraint(self):
return self.pattern
class CustomConstraint(Constraint):
"""
A constraint delegating validation to an external class.
"""
valid_types = (Schema.STRING_TYPE, Schema.INTEGER_TYPE, Schema.NUMBER_TYPE,
Schema.BOOLEAN_TYPE, Schema.LIST_TYPE)
def __init__(self, name, description=None, environment=None):
super(CustomConstraint, self).__init__(description)
self.name = name
self._environment = environment
self._custom_constraint = None
def _constraint(self):
return self.name
@property
def custom_constraint(self):
if self._custom_constraint is None:
if self._environment is None:
self._environment = resources.global_env()
constraint_class = self._environment.get_constraint(self.name)
if constraint_class:
self._custom_constraint = constraint_class()
return self._custom_constraint
def _str(self):
message = getattr(self.custom_constraint, "message", None)
if not message:
message = _('Value must be of type %s') % self.name
return message
def _err_msg(self, value):
constraint = self.custom_constraint
if constraint is None:
return _('"%(value)s" does not validate %(name)s '
'(constraint not found)') % {
"value": value, "name": self.name}
error = getattr(constraint, "error", None)
if error:
return error(value)
return _('"%(value)s" does not validate %(name)s') % {
"value": value, "name": self.name}
def _is_valid(self, value, context):
constraint = self.custom_constraint
if not constraint:
return False
return constraint.validate(value, context)

View File

@ -347,6 +347,7 @@ class Environment(object):
else:
self.params = dict((k, v) for (k, v) in env.iteritems()
if k != RESOURCE_REGISTRY)
self.constraints = {}
def load(self, env_snippet):
self.registry.load(env_snippet.get(RESOURCE_REGISTRY, {}))
@ -360,6 +361,9 @@ class Environment(object):
def register_class(self, resource_type, resource_class):
self.registry.register_class(resource_type, resource_class)
def register_constraint(self, constraint_name, constraint):
self.constraints[constraint_name] = constraint
def get_class(self, resource_type, resource_name=None):
return self.registry.get_class(resource_type, resource_name)
@ -370,3 +374,6 @@ class Environment(object):
registry_type=None):
return self.registry.get_resource_info(resource_type, resource_name,
registry_type)
def get_constraint(self, name):
return self.constraints.get(name)

View File

@ -14,6 +14,7 @@
# under the License.
import glob
import itertools
import os
import os.path
@ -29,11 +30,15 @@ LOG = log.getLogger(__name__)
def _register_resources(env, type_pairs):
for res_name, res_class in type_pairs:
env.register_class(res_name, res_class)
def _register_constraints(env, type_pairs):
for constraint_name, constraint in type_pairs:
env.register_constraint(constraint_name, constraint)
def _get_module_resources(module):
if callable(getattr(module, 'resource_mapping', None)):
try:
@ -44,11 +49,22 @@ def _get_module_resources(module):
return []
def _register_modules(env, modules):
import itertools
def _get_module_constraints(module):
if callable(getattr(module, 'constraint_mapping', None)):
return module.constraint_mapping().iteritems()
else:
return []
resource_lists = (_get_module_resources(m) for m in modules)
_register_resources(env, itertools.chain.from_iterable(resource_lists))
def _register_modules(env, modules):
data_lists = [(_get_module_resources(m), _get_module_constraints(m))
for m in modules]
if data_lists:
resource_lists, constraint_lists = zip(*data_lists)
_register_resources(env, itertools.chain.from_iterable(resource_lists))
_register_constraints(
env, itertools.chain.from_iterable(constraint_lists))
_environment = None

View File

@ -16,6 +16,7 @@
import testtools
from heat.engine import constraints
from heat.engine import environment
class SchemaTest(testtools.TestCase):
@ -218,3 +219,66 @@ class SchemaTest(testtools.TestCase):
constraints.Length, '1', 10)
self.assertRaises(constraints.InvalidSchemaError,
constraints.Length, 1, '10')
class CustomConstraintTest(testtools.TestCase):
def setUp(self):
super(CustomConstraintTest, self).setUp()
self.env = environment.Environment({})
def test_validation(self):
class ZeroConstraint(object):
def validate(self, value, context):
return value == 0
self.env.register_constraint("zero", ZeroConstraint)
constraint = constraints.CustomConstraint("zero", environment=self.env)
self.assertEqual("Value must be of type zero", str(constraint))
self.assertIsNone(constraint.validate(0))
error = self.assertRaises(ValueError, constraint.validate, 1)
self.assertEqual('"1" does not validate zero', str(error))
def test_custom_error(self):
class ZeroConstraint(object):
def error(self, value):
return "%s is not 0" % value
def validate(self, value, context):
return value == 0
self.env.register_constraint("zero", ZeroConstraint)
constraint = constraints.CustomConstraint("zero", environment=self.env)
error = self.assertRaises(ValueError, constraint.validate, 1)
self.assertEqual("1 is not 0", str(error))
def test_custom_message(self):
class ZeroConstraint(object):
message = "Only zero!"
def validate(self, value, context):
return value == 0
self.env.register_constraint("zero", ZeroConstraint)
constraint = constraints.CustomConstraint("zero", environment=self.env)
self.assertEqual("Only zero!", str(constraint))
def test_unknown_constraint(self):
constraint = constraints.CustomConstraint("zero", environment=self.env)
error = self.assertRaises(ValueError, constraint.validate, 1)
self.assertEqual('"1" does not validate zero (constraint not found)',
str(error))
def test_constraints(self):
class ZeroConstraint(object):
def validate(self, value, context):
return value == 0
self.env.register_constraint("zero", ZeroConstraint)
constraint = constraints.CustomConstraint("zero", environment=self.env)
self.assertEqual("zero", constraint["custom_constraint"])

View File

@ -15,6 +15,7 @@
import fixtures
import mock
import os.path
import sys
from oslo.config import cfg
@ -105,6 +106,55 @@ class EnvironmentTest(common.HeatTestCase):
env.get_resource_info('OS::Networking::FloatingIP',
'my_fip').value)
def test_constraints(self):
env = environment.Environment({})
first_constraint = object()
second_constraint = object()
env.register_constraint("constraint1", first_constraint)
env.register_constraint("constraint2", second_constraint)
self.assertIs(first_constraint, env.get_constraint("constraint1"))
self.assertIs(second_constraint, env.get_constraint("constraint2"))
self.assertIs(None, env.get_constraint("no_constraint"))
def test_constraints_registry(self):
constraint_content = '''
class MyConstraint(object):
pass
def constraint_mapping():
return {"constraint1": MyConstraint}
'''
plugin_dir = self.useFixture(fixtures.TempDir())
plugin_file = os.path.join(plugin_dir.path, 'test.py')
with open(plugin_file, 'w+') as ef:
ef.write(constraint_content)
self.addCleanup(sys.modules.pop, "heat.engine.plugins.test")
cfg.CONF.set_override('plugin_dirs', plugin_dir.path)
env = environment.Environment({})
resources._load_all(env)
self.assertEqual("MyConstraint",
env.get_constraint("constraint1").__name__)
self.assertIs(None, env.get_constraint("no_constraint"))
def test_constraints_registry_error(self):
constraint_content = '''
def constraint_mapping():
raise ValueError("oops")
'''
plugin_dir = self.useFixture(fixtures.TempDir())
plugin_file = os.path.join(plugin_dir.path, 'test.py')
with open(plugin_file, 'w+') as ef:
ef.write(constraint_content)
self.addCleanup(sys.modules.pop, "heat.engine.plugins.test")
cfg.CONF.set_override('plugin_dirs', plugin_dir.path)
env = environment.Environment({})
error = self.assertRaises(ValueError, resources._load_all, env)
self.assertEqual("oops", str(error))
class EnvironmentDuplicateTest(common.HeatTestCase):