Properly test access to tls_refs in the API layer

Change-Id: I264c525c36f301d23378a6d72aded741fcb9f4f6
Story: 2001640
Task: 6658
This commit is contained in:
Adam Harwell 2018-03-06 17:25:34 -08:00
parent 649b33d247
commit a1e443ccea
8 changed files with 157 additions and 23 deletions

View File

@ -19,6 +19,7 @@ from oslo_config import cfg
from oslo_db import exception as odb_exceptions
from oslo_utils import excutils
import pecan
from stevedore import driver as stevedore_driver
from wsme import types as wtypes
from wsmeext import pecan as wsme_pecan
@ -43,6 +44,11 @@ class ListenersController(base.BaseController):
def __init__(self):
super(ListenersController, self).__init__()
self.handler = self.handler.listener
self.cert_manager = stevedore_driver.DriverManager(
namespace='octavia.cert_manager',
name=CONF.certificates.cert_manager,
invoke_on_load=True,
).driver
def _get_db_listener(self, session, id):
"""Gets a listener object from the database."""
@ -125,7 +131,19 @@ class ListenersController(base.BaseController):
session, lb_id,
provisioning_status=constants.ACTIVE)
def _validate_create_listener(self, lock_session, lb_id, listener_dict):
def _validate_tls_refs(self, tls_refs):
context = pecan.request.context.get('octavia_context')
bad_refs = []
for ref in tls_refs:
try:
self.cert_manager.get_cert(context, ref, check_only=True)
except Exception:
bad_refs.append(ref)
if bad_refs:
raise exceptions.CertificateRetrievalException(ref=bad_refs)
def _validate_create_listener(self, lock_session, listener_dict):
"""Validate listener for wrong protocol or duplicate listeners
Update the load balancer db when provisioning status changes.
@ -140,6 +158,10 @@ class ListenersController(base.BaseController):
try:
sni_containers = listener_dict.pop('sni_containers', [])
tls_refs = [sni['tls_container_id'] for sni in sni_containers]
if listener_dict.get('tls_certificate_id'):
tls_refs.append(listener_dict.get('tls_certificate_id'))
self._validate_tls_refs(tls_refs)
db_listener = self.repositories.listener.create(
lock_session, **listener_dict)
if sni_containers:
@ -222,7 +244,7 @@ class ListenersController(base.BaseController):
lock_session, lb_id=load_balancer_id)
db_listener = self._validate_create_listener(
lock_session, load_balancer_id, listener_dict)
lock_session, listener_dict)
lock_session.commit()
except Exception:
with excutils.save_and_reraise_exception():
@ -240,7 +262,7 @@ class ListenersController(base.BaseController):
self._validate_pool(lock_session, load_balancer_id,
listener_dict['default_pool_id'])
db_listener = self._validate_create_listener(
lock_session, load_balancer_id, listener_dict)
lock_session, listener_dict)
# Now create l7policies
new_l7ps = []
@ -284,6 +306,12 @@ class ListenersController(base.BaseController):
self._test_lb_and_listener_statuses(context.session, load_balancer_id,
id=id)
sni_containers = listener.sni_container_refs or []
tls_refs = [sni for sni in sni_containers]
if listener.default_tls_container_ref:
tls_refs.append(listener.default_tls_container_ref)
self._validate_tls_refs(tls_refs)
try:
LOG.info("Sending Update of Listener %s to handler", id)
self.handler.update(db_listener, listener)

View File

@ -23,6 +23,7 @@ from stevedore import driver as stevedore_driver
from octavia.certificates.common import barbican as barbican_common
from octavia.certificates.manager import cert_mgr
from octavia.common.tls_utils import cert_parser
LOG = logging.getLogger(__name__)
@ -142,10 +143,19 @@ class BarbicanCertManager(cert_mgr.CertManager):
name=service_name,
url=resource_ref
)
return barbican_common.BarbicanCert(cert_container)
barbican_cert = barbican_common.BarbicanCert(cert_container)
LOG.debug('Validating certificate data for %s.', cert_ref)
cert_parser.validate_cert(
barbican_cert.get_certificate(),
private_key=barbican_cert.get_private_key(),
private_key_passphrase=(
barbican_cert.get_private_key_passphrase()),
intermediates=barbican_cert.get_intermediates())
LOG.debug('Certificate data validated for %s.', cert_ref)
return barbican_cert
except Exception as e:
with excutils.save_and_reraise_exception():
LOG.error('Error getting %s: %s', cert_ref, e)
LOG.error('Error getting cert %s: %s', cert_ref, str(e))
def delete_cert(self, context, cert_ref, resource_ref, service_name=None):
"""Deregister as a consumer for the specified cert.

View File

@ -132,6 +132,11 @@ class MisMatchedKey(OctaviaException):
message = _("Key and x509 certificate do not match")
class CertificateRetrievalException(APIException):
msg = _('Could not retrieve certificate: %(ref)s')
code = 400
class CertificateStorageException(OctaviaException):
message = _('Could not store certificate: %(msg)s')

View File

@ -50,10 +50,10 @@ def validate_cert(certificate, private_key=None,
:returns: boolean
"""
cert = _get_x509_from_pem_bytes(certificate)
if intermediates:
for imd in get_intermediates_pems(intermediates):
# Loading the certificates validates them
pass
if intermediates and not isinstance(intermediates, list):
# If the intermediates are in a list, then they are already loaded.
# Load the certificates to validate them, if they weren't already.
list(get_intermediates_pems(intermediates))
if private_key:
pkey = _read_private_key(private_key,
passphrase=private_key_passphrase)
@ -71,7 +71,7 @@ def _read_private_key(private_key_pem, passphrase=None):
:param passphrase: Optional passphrase needed to decrypt the private key
:returns: a RSAPrivatekey object
"""
if passphrase:
if passphrase and type(passphrase) == six.text_type:
passphrase = passphrase.encode("utf-8")
if type(private_key_pem) == six.text_type:
private_key_pem = private_key_pem.encode('utf-8')

View File

@ -90,6 +90,9 @@ class BaseAPITest(base_db_test.OctaviaDBTestBase):
patcher = mock.patch('octavia.api.handlers.controller_simulator.'
'handler.SimulatedControllerHandler')
self.handler_mock = patcher.start()
patcher2 = mock.patch('octavia.certificates.manager.barbican.'
'BarbicanCertManager')
self.cert_manager_mock = patcher2.start()
self.app = self._make_app()
self.project_id = uuidutils.generate_uuid()

View File

@ -467,6 +467,25 @@ class TestListener(base.BaseAPITest):
self.create_listener(constants.PROTOCOL_HTTP, 80, self.lb_id,
status=409)
def test_create_bad_tls_ref(self):
sni1 = uuidutils.generate_uuid()
sni2 = uuidutils.generate_uuid()
tls_ref = uuidutils.generate_uuid()
lb_listener = {'name': 'listener1', 'default_pool_id': None,
'protocol': constants.PROTOCOL_TERMINATED_HTTPS,
'protocol_port': 80,
'sni_container_refs': [sni1, sni2],
'default_tls_container_ref': tls_ref,
'loadbalancer_id': self.lb_id}
body = self._build_body(lb_listener)
self.cert_manager_mock().get_cert.side_effect = [
Exception("bad cert"), None, Exception("bad_cert")]
response = self.post(self.LISTENERS_PATH, body, status=400).json
self.assertIn(sni1, response['faultstring'])
self.assertNotIn(sni2, response['faultstring'])
self.assertIn(tls_ref, response['faultstring'])
def test_create_with_default_pool_id(self):
lb_listener = {'name': 'listener1',
'default_pool_id': self.pool_id,
@ -983,6 +1002,37 @@ class TestListener(base.BaseAPITest):
listener_id=listener['listener'].get('id'))
self.put(listener_path, {}, status=400)
def test_update_bad_tls_ref(self):
sni1 = uuidutils.generate_uuid()
sni2 = uuidutils.generate_uuid()
tls_ref = uuidutils.generate_uuid()
tls_ref2 = uuidutils.generate_uuid()
lb_listener = {'name': 'listener1', 'default_pool_id': None,
'protocol': constants.PROTOCOL_TERMINATED_HTTPS,
'protocol_port': 80,
'sni_container_refs': [sni1, sni2],
'default_tls_container_ref': tls_ref,
'loadbalancer_id': self.lb_id}
body = self._build_body(lb_listener)
api_listener = self.post(
self.LISTENERS_PATH, body).json['listener']
self.set_lb_status(self.lb_id)
lb_listener_put = {
'default_tls_container_ref': tls_ref2,
'sni_container_refs': [sni1, sni2]
}
body = self._build_body(lb_listener_put)
listener_path = self.LISTENER_PATH.format(
listener_id=api_listener['id'])
self.cert_manager_mock().get_cert.side_effect = [
Exception("bad cert"), None, Exception("bad cert")]
response = self.put(listener_path, body, status=400).json
self.assertIn(tls_ref2, response['faultstring'])
self.assertIn(sni1, response['faultstring'])
self.assertNotIn(sni2, response['faultstring'])
self.assertNotIn(tls_ref, response['faultstring'])
def test_update_pending_update(self):
lb = self.create_load_balancer(uuidutils.generate_uuid())
optionals = {'name': 'lb1', 'description': 'desc1',

View File

@ -11,11 +11,11 @@
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import uuid
from barbicanclient.v1 import containers
from barbicanclient.v1 import secrets
import mock
from oslo_utils import uuidutils
import six
import octavia.certificates.common.barbican as barbican_common
@ -33,28 +33,38 @@ class TestBarbicanManager(base.TestCase):
def setUp(self):
# Make a fake Container and contents
self.barbican_endpoint = 'http://localhost:9311/v1'
self.container_uuid = uuid.uuid4()
self.container_uuid = uuidutils.generate_uuid()
self.certificate_uuid = uuidutils.generate_uuid()
self.intermediates_uuid = uuidutils.generate_uuid()
self.private_key_uuid = uuidutils.generate_uuid()
self.private_key_passphrase_uuid = uuidutils.generate_uuid()
self.container_ref = '{0}/containers/{1}'.format(
self.barbican_endpoint, self.container_uuid
)
self.barbican_api = mock.MagicMock()
self.name = 'My Fancy Cert'
self.certificate = secrets.Secret(
api=mock.MagicMock(),
payload=sample.X509_CERT
api=self.barbican_api,
payload=sample.X509_CERT,
secret_ref=self.certificate_uuid
)
self.intermediates = secrets.Secret(
api=mock.MagicMock(),
payload=sample.X509_IMDS
api=self.barbican_api,
payload=sample.X509_IMDS,
secret_ref=self.intermediates_uuid
)
self.private_key = secrets.Secret(
api=mock.MagicMock(),
payload=sample.X509_CERT_KEY_ENCRYPTED
api=self.barbican_api,
payload=sample.X509_CERT_KEY_ENCRYPTED,
secret_ref=self.private_key_uuid
)
self.private_key_passphrase = secrets.Secret(
api=mock.MagicMock(),
payload=sample.X509_CERT_KEY_PASSPHRASE
api=self.barbican_api,
payload=sample.X509_CERT_KEY_PASSPHRASE,
secret_ref=self.private_key_passphrase_uuid
)
container = mock.Mock(spec=containers.CertificateContainer)
@ -225,6 +235,23 @@ class TestBarbicanManager(base.TestCase):
self.assertEqual(data.get_private_key_passphrase(),
six.b(self.private_key_passphrase.payload))
def test_get_cert_no_registration_raise_on_secret_access_failure(self):
self.bc.containers.get.return_value = self.container
type(self.certificate).payload = mock.PropertyMock(
side_effect=ValueError)
# Get the container data
self.assertRaises(
ValueError, self.cert_manager.get_cert,
context=self.context,
cert_ref=self.container_ref, check_only=True
)
# 'get' should be called once with the container_ref
self.bc.containers.get.assert_called_once_with(
container_ref=self.container_ref
)
def test_delete_cert(self):
# Attempt to deregister as a consumer
self.cert_manager.delete_cert(

View File

@ -78,14 +78,25 @@ class TestTLSParseUtils(base.TestCase):
self.assertTrue(
cert_parser.validate_cert(
sample_certs.X509_CERT,
private_key=sample_certs.X509_CERT_KEY,
intermediates=(sample_certs.TEST_X509_IMDS +
b"\nParser should ignore junk\n")))
private_key=sample_certs.X509_CERT_KEY.decode('utf-8')))
self.assertRaises(exceptions.MisMatchedKey,
cert_parser.validate_cert,
sample_certs.X509_CERT,
private_key=sample_certs.X509_CERT_KEY_2)
def test_validate_cert_handles_intermediates(self):
self.assertTrue(
cert_parser.validate_cert(
sample_certs.X509_CERT,
private_key=sample_certs.X509_CERT_KEY,
intermediates=(sample_certs.X509_IMDS +
b"\nParser should ignore junk\n")))
self.assertTrue(
cert_parser.validate_cert(
sample_certs.X509_CERT,
private_key=sample_certs.X509_CERT_KEY,
intermediates=sample_certs.X509_IMDS_LIST))
def test_split_x509s(self):
imds = []
for x509Pem in cert_parser._split_x509s(sample_certs.TEST_X509_IMDS):