Add PKI API for compute nodes certificates
Treat the control node as a CA for certificates at compute nodes. Upon joining a cluster, the compute node will request a certificate to be created by generating a CSR and asking the control node to sign the certificate. This adds new config options for the compute private keys and certificate locations in use. Change-Id: I8e8b1a86cf7df752b6cb34cfdf65a87a72934ec5
This commit is contained in:
		@@ -83,6 +83,10 @@ def _get_default_config():
 | 
			
		||||
        f'{snap_common}/etc/ssl/certs/cert.pem',
 | 
			
		||||
        'config.tls.key-path':
 | 
			
		||||
        f'{snap_common}/etc/ssl/private/key.pem',
 | 
			
		||||
        'config.tls.compute.cert-path':
 | 
			
		||||
        f'{snap_common}/etc/ssl/certs/compute-cert.pem',
 | 
			
		||||
        'config.tls.compute.key-path':
 | 
			
		||||
        f'{snap_common}/etc/ssl/private/compute-key.pem',
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -136,6 +136,8 @@ setup:
 | 
			
		||||
    tls_cacert_path: 'config.tls.cacert-path'
 | 
			
		||||
    tls_cert_path: 'config.tls.cert-path'
 | 
			
		||||
    tls_key_path: 'config.tls.key-path'
 | 
			
		||||
    tls_compute_cert_path: 'config.tls.compute.cert-path'
 | 
			
		||||
    tls_compute_key_path: 'config.tls.compute.key-path'
 | 
			
		||||
entry_points:
 | 
			
		||||
  keystone-manage:
 | 
			
		||||
    binary: "{snap}/bin/keystone-manage"
 | 
			
		||||
 
 | 
			
		||||
@@ -4,15 +4,24 @@ import sys
 | 
			
		||||
import urllib3
 | 
			
		||||
import json
 | 
			
		||||
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from cluster import shell
 | 
			
		||||
from init import tls
 | 
			
		||||
 | 
			
		||||
CLUSTER_SERVICE_PORT = 10002
 | 
			
		||||
 | 
			
		||||
CLIENT_API_VERSION = '2.0.0'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UnauthorizedRequestError(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UnsupportedAPIError(Exception):
 | 
			
		||||
    pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join():
 | 
			
		||||
    """Join an existing cluster as a compute node."""
 | 
			
		||||
 | 
			
		||||
@@ -44,7 +53,7 @@ def join():
 | 
			
		||||
        resp = conn_pool.urlopen(
 | 
			
		||||
            'POST', '/join', retries=0, preload_content=True,
 | 
			
		||||
            headers={
 | 
			
		||||
                'API-VERSION': '1.0.0',
 | 
			
		||||
                'API-VERSION': CLIENT_API_VERSION,
 | 
			
		||||
                'Content-Type': 'application/json',
 | 
			
		||||
            }, body=request_body)
 | 
			
		||||
    except urllib3.exceptions.MaxRetryError as e:
 | 
			
		||||
@@ -59,22 +68,41 @@ def join():
 | 
			
		||||
        raise Exception('Could not retrieve a response from the clustering'
 | 
			
		||||
                        ' service.') from e
 | 
			
		||||
 | 
			
		||||
    if resp.status == 401:
 | 
			
		||||
        response_data = resp.data.decode('utf-8')
 | 
			
		||||
        # TODO: this should be more bulletproof in case a proxy server
 | 
			
		||||
        # returns this response - it will not have the expected format.
 | 
			
		||||
        print('An authorization failure has occurred while joining the'
 | 
			
		||||
              ' the cluster: please make sure the connection string'
 | 
			
		||||
              ' was entered as returned by the "add-compute" command'
 | 
			
		||||
              ' and that it was used before its expiration time.',
 | 
			
		||||
              file=sys.stderr)
 | 
			
		||||
        if response_data:
 | 
			
		||||
            message = json.loads(response_data)['message']
 | 
			
		||||
            raise UnauthorizedRequestError(message)
 | 
			
		||||
        raise UnauthorizedRequestError()
 | 
			
		||||
    if resp.status != 200:
 | 
			
		||||
        raise Exception('Unexpected response status received from the'
 | 
			
		||||
                        f' clustering service: {resp.status}')
 | 
			
		||||
        response_data = resp.data.decode('utf-8')
 | 
			
		||||
        message = ''
 | 
			
		||||
        if response_data:
 | 
			
		||||
            try:
 | 
			
		||||
                message = json.loads(response_data)['message']
 | 
			
		||||
            except json.JSONDecodeError:
 | 
			
		||||
                message = resp.data
 | 
			
		||||
            raise UnauthorizedRequestError(message)
 | 
			
		||||
 | 
			
		||||
        if resp.status == 401:
 | 
			
		||||
            print('An authorization failure has occurred while joining the'
 | 
			
		||||
                  ' the cluster: please make sure the connection string'
 | 
			
		||||
                  ' was entered as returned by the "add-compute" command'
 | 
			
		||||
                  ' and that it was used before its expiration time.',
 | 
			
		||||
                  file=sys.stderr)
 | 
			
		||||
            raise UnauthorizedRequestError(message)
 | 
			
		||||
        elif resp.status == 410:
 | 
			
		||||
            print('The control node no longer supports API version '
 | 
			
		||||
                  f'{CLIENT_API_VERSION}. Please update the local compute node'
 | 
			
		||||
                  ' snap version to match the version running on the control '
 | 
			
		||||
                  f'node {control_hostname}.', file=sys.stderr)
 | 
			
		||||
            raise UnsupportedAPIError(message)
 | 
			
		||||
        elif resp.status == 501:
 | 
			
		||||
            print('The control node does not support API version '
 | 
			
		||||
                  f'{CLIENT_API_VERSION}. Please update the microstack snap '
 | 
			
		||||
                  f'on control node {control_hostname} to a later version and '
 | 
			
		||||
                  'try again', file=sys.stderr)
 | 
			
		||||
            raise UnsupportedAPIError(message)
 | 
			
		||||
        else:
 | 
			
		||||
            msg = ('Unexpected response status received from the clustering '
 | 
			
		||||
                   f'service on {control_hostname}: {resp.status}')
 | 
			
		||||
            if message:
 | 
			
		||||
                msg = f'{msg} message: {message}'
 | 
			
		||||
            raise Exception(msg)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        response_data = resp.data.decode('utf-8')
 | 
			
		||||
@@ -111,6 +139,66 @@ def join():
 | 
			
		||||
            f.write(response_dict[tls_file])
 | 
			
		||||
    shell.config_set(**{'config.tls.generate-self-signed': False})
 | 
			
		||||
 | 
			
		||||
    compute_key_path = shell.config_get('config.tls.compute.key-path')
 | 
			
		||||
    compute_key_path = Path(compute_key_path)
 | 
			
		||||
    tls.create_or_get_private_key(compute_key_path)
 | 
			
		||||
    csr = tls.create_csr(compute_key_path)
 | 
			
		||||
 | 
			
		||||
    # Request CSRs for PKI settings
 | 
			
		||||
    request_body = json.dumps({
 | 
			
		||||
        'credential-id': credential_id,
 | 
			
		||||
        'credential-secret': credential_secret,
 | 
			
		||||
        'csr': csr.decode('utf-8')
 | 
			
		||||
    })
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        resp = conn_pool.urlopen(
 | 
			
		||||
            'POST', '/pki/sign', retries=0, preload_content=True,
 | 
			
		||||
            headers={
 | 
			
		||||
                'API-VERSION': CLIENT_API_VERSION,
 | 
			
		||||
                'Content-Type': 'application/json',
 | 
			
		||||
            }, body=request_body)
 | 
			
		||||
    except urllib3.exceptions.MaxRetryError as e:
 | 
			
		||||
        if isinstance(e.reason, urllib3.exceptions.SSLError):
 | 
			
		||||
            raise Exception(
 | 
			
		||||
                'The actual clustering service certificate fingerprint'
 | 
			
		||||
                ' did not match the expected one, please make sure that: '
 | 
			
		||||
                '(1) that a correct token was specified during initialization;'
 | 
			
		||||
                ' (2) a MITM attacks are not performed against HTTPS requests'
 | 
			
		||||
                ' (including transparent proxies).'
 | 
			
		||||
            ) from e.reason
 | 
			
		||||
        raise Exception('Could not retrieve a response from the clustering'
 | 
			
		||||
                        ' service.') from e
 | 
			
		||||
 | 
			
		||||
    if resp.status == 401:
 | 
			
		||||
        response_data = resp.data.decode('utf-8')
 | 
			
		||||
        # TODO: this should be more bulletproof in case a proxy server
 | 
			
		||||
        # returns this response - it will not have the expected format.
 | 
			
		||||
        print('An authorization failure has occurred while joining the'
 | 
			
		||||
              ' the cluster: please make sure the connection string'
 | 
			
		||||
              ' was entered as returned by the "add-compute" command'
 | 
			
		||||
              ' and that it was used before its expiration time.',
 | 
			
		||||
              file=sys.stderr)
 | 
			
		||||
        if response_data:
 | 
			
		||||
            message = json.loads(response_data)['message']
 | 
			
		||||
            raise UnauthorizedRequestError(message)
 | 
			
		||||
        raise UnauthorizedRequestError()
 | 
			
		||||
    if resp.status != 200:
 | 
			
		||||
        raise Exception('Unexpected response status received from the'
 | 
			
		||||
                        f' clustering service: {resp.status}')
 | 
			
		||||
 | 
			
		||||
    # The API was introduced so there's a possibility that we get a 404
 | 
			
		||||
    if resp.status == 404:
 | 
			
		||||
        raise Exception('The control node does not support CSR.')
 | 
			
		||||
 | 
			
		||||
    if not resp.data:
 | 
			
		||||
        raise Exception('The clustering service did not return a certificate, '
 | 
			
		||||
                        'which is unexpected. Check its status and try again.')
 | 
			
		||||
 | 
			
		||||
    compute_cert_path = shell.config_get('config.tls.compute.cert-path')
 | 
			
		||||
    with open(compute_cert_path, 'wb+') as f:
 | 
			
		||||
        f.write(resp.data)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    join()
 | 
			
		||||
 
 | 
			
		||||
@@ -6,12 +6,15 @@ import semantic_version
 | 
			
		||||
import keystoneclient.exceptions as kc_exceptions
 | 
			
		||||
 | 
			
		||||
from flask import Flask, request, jsonify
 | 
			
		||||
from functools import wraps
 | 
			
		||||
from werkzeug.exceptions import BadRequest
 | 
			
		||||
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
from cluster.shell import check_output
 | 
			
		||||
from cluster.shell import config_get
 | 
			
		||||
 | 
			
		||||
from init.tls import generate_cert_from_csr
 | 
			
		||||
 | 
			
		||||
from keystoneauth1.identity import v3
 | 
			
		||||
from keystoneauth1 import session
 | 
			
		||||
from keystoneclient.v3 import client as v3client
 | 
			
		||||
@@ -22,7 +25,7 @@ logger = logging.getLogger(__name__)
 | 
			
		||||
app = Flask(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
API_VERSION = semantic_version.Version('1.0.0')
 | 
			
		||||
API_VERSION = semantic_version.Version('2.0.0')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Unauthorized(Exception):
 | 
			
		||||
@@ -68,6 +71,11 @@ class IncorrectContentType(APIException):
 | 
			
		||||
               'application/json.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MissingParameterError(APIException):
 | 
			
		||||
    status_code = 400
 | 
			
		||||
    message = 'The request is missing necessary data.'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MissingAuthDataInRequest(APIException):
 | 
			
		||||
    status_code = 400
 | 
			
		||||
    message = 'The request does not have the required authentication data.'
 | 
			
		||||
@@ -146,60 +154,26 @@ def handle_authorization_failed(error):
 | 
			
		||||
    return _handle_api_version_exception(error)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.errorhandler(MissingParameterError)
 | 
			
		||||
def handle_missing_parameter_error(error):
 | 
			
		||||
    return _handle_api_version_exception(error)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.errorhandler(UnexpectedError)
 | 
			
		||||
def handle_unexpected_error(error):
 | 
			
		||||
    return _handle_api_version_exception(error)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_info():
 | 
			
		||||
    """Generate the configuration information to return to a client."""
 | 
			
		||||
    # TODO: be selective about what we return. For now, we just get everything.
 | 
			
		||||
    info = {}
 | 
			
		||||
    config = json.loads(check_output('snapctl', 'get', 'config'))
 | 
			
		||||
    info['config'] = config
 | 
			
		||||
def require_auth(f):
 | 
			
		||||
    """Validates that proper authentication credentials are provided in the
 | 
			
		||||
    request.
 | 
			
		||||
 | 
			
		||||
    # Add the controller's TLS certificate data
 | 
			
		||||
    tls_path_map = {
 | 
			
		||||
        'cacert-path': 'tls_cacert',
 | 
			
		||||
        'cert-path': 'tls_cert',
 | 
			
		||||
        'key-path': 'tls_key',
 | 
			
		||||
    }
 | 
			
		||||
    for tls_config, tls_file in tls_path_map.items():
 | 
			
		||||
        with open(config_get('config.tls.{}'.format(tls_config)), "r") as f:
 | 
			
		||||
            info[tls_file] = f.read()
 | 
			
		||||
 | 
			
		||||
    return info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route('/join', methods=['POST'])
 | 
			
		||||
def join():
 | 
			
		||||
    """Authorize a client node and return relevant config."""
 | 
			
		||||
 | 
			
		||||
    # Retrieve an API version from the request - it is a mandatory
 | 
			
		||||
    # header for this API.
 | 
			
		||||
    request_version = request.headers.get('API-Version')
 | 
			
		||||
    if request_version is None:
 | 
			
		||||
        logger.debug('The client has not specified the API-version header.')
 | 
			
		||||
        raise APIVersionMissing()
 | 
			
		||||
    else:
 | 
			
		||||
        try:
 | 
			
		||||
            api_version = semantic_version.Version(request_version)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            logger.debug('The client has specified an invalid API version.'
 | 
			
		||||
                         f': {request_version}')
 | 
			
		||||
            raise APIVersionInvalid()
 | 
			
		||||
 | 
			
		||||
    # Compare the API version used by the clustering service with the
 | 
			
		||||
    # one specified in the request and return an appropriate response.
 | 
			
		||||
    if api_version.major > API_VERSION.major:
 | 
			
		||||
        logger.debug('The client requested a version that is not'
 | 
			
		||||
                     f' supported yet: {api_version}.')
 | 
			
		||||
        raise APIVersionNotImplemented()
 | 
			
		||||
    elif api_version.major < API_VERSION.major:
 | 
			
		||||
        logger.debug('The client request version is no longer supported'
 | 
			
		||||
                     f': {api_version}.')
 | 
			
		||||
        raise APIVersionDropped()
 | 
			
		||||
    else:
 | 
			
		||||
    Checks the provided request.json data to retrieve the credential_id and the
 | 
			
		||||
    credential_secret and then validates the provided credentials and tokens
 | 
			
		||||
    are valid tokens.
 | 
			
		||||
    """
 | 
			
		||||
    @wraps(f)
 | 
			
		||||
    def decorated_function(*args, **kwargs):
 | 
			
		||||
        # Flask raises a BadRequest if the JSON content is invalid and
 | 
			
		||||
        # returns None if the Content-Type header is missing or not set
 | 
			
		||||
        # to application/json.
 | 
			
		||||
@@ -293,6 +267,61 @@ def join():
 | 
			
		||||
                             ' connecting to Keystone')
 | 
			
		||||
            raise UnexpectedError()
 | 
			
		||||
 | 
			
		||||
        return f(*args, **kwargs)
 | 
			
		||||
 | 
			
		||||
    return decorated_function
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def join_info():
 | 
			
		||||
    """Generate the configuration information to return to a client."""
 | 
			
		||||
    # TODO: be selective about what we return. For now, we just get everything.
 | 
			
		||||
    info = {}
 | 
			
		||||
    config = json.loads(check_output('snapctl', 'get', 'config'))
 | 
			
		||||
    info['config'] = config
 | 
			
		||||
 | 
			
		||||
    # Add the controller's TLS certificate data
 | 
			
		||||
    tls_path_map = {
 | 
			
		||||
        'cacert-path': 'tls_cacert',
 | 
			
		||||
        'cert-path': 'tls_cert',
 | 
			
		||||
        'key-path': 'tls_key',
 | 
			
		||||
    }
 | 
			
		||||
    for tls_config, tls_file in tls_path_map.items():
 | 
			
		||||
        with open(config_get('config.tls.{}'.format(tls_config)), "r") as f:
 | 
			
		||||
            info[tls_file] = f.read()
 | 
			
		||||
 | 
			
		||||
    return info
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route('/join', methods=['POST'])
 | 
			
		||||
@require_auth
 | 
			
		||||
def join():
 | 
			
		||||
    """Authorize a client node and return relevant config."""
 | 
			
		||||
 | 
			
		||||
    # Retrieve an API version from the request - it is a mandatory
 | 
			
		||||
    # header for this API.
 | 
			
		||||
    request_version = request.headers.get('API-Version')
 | 
			
		||||
    if request_version is None:
 | 
			
		||||
        logger.debug('The client has not specified the API-version header.')
 | 
			
		||||
        raise APIVersionMissing()
 | 
			
		||||
    else:
 | 
			
		||||
        try:
 | 
			
		||||
            api_version = semantic_version.Version(request_version)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            logger.debug('The client has specified an invalid API version.'
 | 
			
		||||
                         f': {request_version}')
 | 
			
		||||
            raise APIVersionInvalid()
 | 
			
		||||
 | 
			
		||||
    # Compare the API version used by the clustering service with the
 | 
			
		||||
    # one specified in the request and return an appropriate response.
 | 
			
		||||
    if api_version.major > API_VERSION.major:
 | 
			
		||||
        logger.debug('The client requested a version that is not'
 | 
			
		||||
                     f' supported yet: {api_version}.')
 | 
			
		||||
        raise APIVersionNotImplemented()
 | 
			
		||||
    elif api_version.major < API_VERSION.major:
 | 
			
		||||
        logger.debug('The client request version is no longer supported'
 | 
			
		||||
                     f': {api_version}. Client needs to be updated.')
 | 
			
		||||
        raise APIVersionDropped()
 | 
			
		||||
    else:
 | 
			
		||||
        # We were able to authenticate against Keystone using the
 | 
			
		||||
        # application credential and verify that it has not expired
 | 
			
		||||
        # so the information for a compute node to join the cluster can
 | 
			
		||||
@@ -300,6 +329,24 @@ def join():
 | 
			
		||||
        return json.dumps(join_info())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route('/pki/sign', methods=['POST'])
 | 
			
		||||
@require_auth
 | 
			
		||||
def sign_csr():
 | 
			
		||||
    """Signs the provided CSR."""
 | 
			
		||||
    json_data = request.json
 | 
			
		||||
    csr = json_data.get('csr')
 | 
			
		||||
    if not csr:
 | 
			
		||||
        logger.debug('Client requested CSR but did not provide CSR')
 | 
			
		||||
        raise MissingParameterError()
 | 
			
		||||
 | 
			
		||||
    csr = bytes(csr, 'utf-8')
 | 
			
		||||
    key_path = Path(config_get('config.tls.key-path'))
 | 
			
		||||
    ca_path = Path(config_get('config.tls.cacert-path'))
 | 
			
		||||
    cert = generate_cert_from_csr(ca_path, key_path, csr)
 | 
			
		||||
 | 
			
		||||
    return cert
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.route('/')
 | 
			
		||||
def home():
 | 
			
		||||
    status = {
 | 
			
		||||
 
 | 
			
		||||
@@ -92,6 +92,12 @@ class Clustering(Question):
 | 
			
		||||
 | 
			
		||||
        role = shell.config_get('config.cluster.role')
 | 
			
		||||
 | 
			
		||||
        # Generate the compute private key for all nodes since they all act as
 | 
			
		||||
        # compute nodes.
 | 
			
		||||
        compute_key_path = Path(
 | 
			
		||||
            shell.config_get('config.tls.compute.key-path'))
 | 
			
		||||
        tls.create_or_get_private_key(compute_key_path)
 | 
			
		||||
 | 
			
		||||
        if role == 'compute':
 | 
			
		||||
            log.info('Setting up as a compute node.')
 | 
			
		||||
            # Gets config info and sets local env vals.
 | 
			
		||||
 
 | 
			
		||||
@@ -5,6 +5,7 @@ from init.shell import check
 | 
			
		||||
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from dateutil.relativedelta import relativedelta
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
import ipaddress
 | 
			
		||||
import socket
 | 
			
		||||
 | 
			
		||||
@@ -18,6 +19,39 @@ from cryptography.x509.oid import NameOID
 | 
			
		||||
from init import shell
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_or_get_private_key(key_path: Path) -> rsa.RSAPrivateKey:
 | 
			
		||||
    """Generate a local private key file.
 | 
			
		||||
 | 
			
		||||
    :param key_path: path of the key
 | 
			
		||||
    :type key_path: Path
 | 
			
		||||
    :return: private key
 | 
			
		||||
    :rtype: rs.RSAPrivateKey
 | 
			
		||||
    """
 | 
			
		||||
    # If the key path exists, then attempt to load it in order to make sure
 | 
			
		||||
    # it is a valid private key.
 | 
			
		||||
    if key_path.exists():
 | 
			
		||||
        with open(key_path, 'rb') as f:
 | 
			
		||||
            key = serialization.load_pem_private_key(f.read(), None,
 | 
			
		||||
                                                     default_backend())
 | 
			
		||||
            if not isinstance(key, rsa.RSAPrivateKey):
 | 
			
		||||
                raise TypeError('Private key already exists but is not an '
 | 
			
		||||
                                'RSA key')
 | 
			
		||||
            return key
 | 
			
		||||
    key = rsa.generate_private_key(
 | 
			
		||||
        public_exponent=65537,
 | 
			
		||||
        key_size=2048,
 | 
			
		||||
        backend=default_backend(),
 | 
			
		||||
    )
 | 
			
		||||
    serialized_key = key.private_bytes(
 | 
			
		||||
        encoding=serialization.Encoding.PEM,
 | 
			
		||||
        format=serialization.PrivateFormat.PKCS8,
 | 
			
		||||
        encryption_algorithm=serialization.NoEncryption(),
 | 
			
		||||
    )
 | 
			
		||||
    key_path.write_bytes(serialized_key)
 | 
			
		||||
    check('chmod', '600', str(key_path))
 | 
			
		||||
    return key
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_self_signed(cert_path, key_path, ip=None,
 | 
			
		||||
                         fingerprint_config=None):
 | 
			
		||||
    """Generate a self-signed certificate with associated keys.
 | 
			
		||||
@@ -34,15 +68,11 @@ def generate_self_signed(cert_path, key_path, ip=None,
 | 
			
		||||
    """
 | 
			
		||||
    # Do not generate a new certificate and key if there is already an existing
 | 
			
		||||
    # pair. TODO: improve this check and allow renewal.
 | 
			
		||||
    if cert_path.exists() and key_path.exists():
 | 
			
		||||
    if cert_path.exists():
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    key = rsa.generate_private_key(
 | 
			
		||||
        public_exponent=65537,
 | 
			
		||||
        key_size=2048,
 | 
			
		||||
        backend=default_backend(),
 | 
			
		||||
    )
 | 
			
		||||
    cn = socket.gethostname()
 | 
			
		||||
    key = create_or_get_private_key(key_path=key_path)
 | 
			
		||||
    cn = socket.getfqdn()
 | 
			
		||||
    common_name = x509.Name([
 | 
			
		||||
        x509.NameAttribute(NameOID.COMMON_NAME, cn)
 | 
			
		||||
    ])
 | 
			
		||||
@@ -74,12 +104,85 @@ def generate_self_signed(cert_path, key_path, ip=None,
 | 
			
		||||
        shell.config_set(**{fingerprint_config: cert_fprint})
 | 
			
		||||
 | 
			
		||||
    serialized_cert = cert.public_bytes(encoding=serialization.Encoding.PEM)
 | 
			
		||||
    serialized_key = key.private_bytes(
 | 
			
		||||
        encoding=serialization.Encoding.PEM,
 | 
			
		||||
        format=serialization.PrivateFormat.PKCS8,
 | 
			
		||||
        encryption_algorithm=serialization.NoEncryption(),
 | 
			
		||||
    )
 | 
			
		||||
    cert_path.write_bytes(serialized_cert)
 | 
			
		||||
    key_path.write_bytes(serialized_key)
 | 
			
		||||
    check('chmod', '644', str(cert_path))
 | 
			
		||||
    check('chmod', '600', str(key_path))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_csr(key_path: Path, ip: str = None) -> \
 | 
			
		||||
        x509.CertificateSigningRequest:
 | 
			
		||||
    """Creates a Certificate Signing Request (CSR) for the local node.
 | 
			
		||||
 | 
			
		||||
    A CSR is created for the local node. The resulting CSR can be provided to
 | 
			
		||||
    generate a Certificate in a PKI infrastructure. The CSR will be generated
 | 
			
		||||
    using the local nodes hostname as the CN and SAN in the request. The CSR
 | 
			
		||||
    generated will not request certificate authority.
 | 
			
		||||
 | 
			
		||||
    :param key_path: the path to the local private key file
 | 
			
		||||
    :type key_path: Path
 | 
			
		||||
    :param ip: the ip address of the local node
 | 
			
		||||
    :type str: the ip address of the local node
 | 
			
		||||
    :returns: x509.CertificateSigningRequest object for the local node
 | 
			
		||||
    :rtype: x509.CertificateSigningRequest
 | 
			
		||||
    """
 | 
			
		||||
    with open(key_path, 'rb+') as f:
 | 
			
		||||
        key = serialization.load_pem_private_key(f.read(), None,
 | 
			
		||||
                                                 default_backend())
 | 
			
		||||
 | 
			
		||||
    hostname = socket.getfqdn()
 | 
			
		||||
    cn = x509.NameAttribute(NameOID.COMMON_NAME, hostname)
 | 
			
		||||
    if ip:
 | 
			
		||||
        san = x509.SubjectAlternativeName(
 | 
			
		||||
            [x509.DNSName(cn), x509.IPAddress(ipaddress.ip_address(ip))]
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        san = x509.SubjectAlternativeName([x509.DNSName(hostname)])
 | 
			
		||||
    not_ca = x509.BasicConstraints(ca=False, path_length=None)
 | 
			
		||||
 | 
			
		||||
    builder = x509.CertificateSigningRequestBuilder()
 | 
			
		||||
    builder = builder.subject_name(x509.Name([cn]))
 | 
			
		||||
    builder = builder.add_extension(san, critical=False)
 | 
			
		||||
    builder = builder.add_extension(not_ca, critical=True)
 | 
			
		||||
 | 
			
		||||
    request = builder.sign(key, hashes.SHA256(), backend=default_backend())
 | 
			
		||||
    return request.public_bytes(serialization.Encoding.PEM)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def generate_cert_from_csr(ca_path, key_path, client_csr):
 | 
			
		||||
    """Generates a certificate from a Certificate Signing Request (CSR).
 | 
			
		||||
 | 
			
		||||
    :param ca_path: the path to the ca cert
 | 
			
		||||
    :type ca_path: str or Path
 | 
			
		||||
    :param key_path: the path to the ca cert key file
 | 
			
		||||
    :type key_path: str or Path
 | 
			
		||||
    :param client_csr: the certificate signing request from a client
 | 
			
		||||
    :return: PEM encoded certificate
 | 
			
		||||
    :rtype: bytes
 | 
			
		||||
    """
 | 
			
		||||
    with open(ca_path, 'rb') as f:
 | 
			
		||||
        cacert = x509.load_pem_x509_certificate(f.read(), default_backend())
 | 
			
		||||
 | 
			
		||||
    with open(key_path, 'rb') as f:
 | 
			
		||||
        key = serialization.load_pem_private_key(f.read(), None,
 | 
			
		||||
                                                 default_backend())
 | 
			
		||||
 | 
			
		||||
    csr = x509.load_pem_x509_csr(client_csr, default_backend())
 | 
			
		||||
 | 
			
		||||
    builder = (
 | 
			
		||||
        x509.CertificateBuilder()
 | 
			
		||||
        .subject_name(csr.subject)
 | 
			
		||||
        .issuer_name(cacert.subject)
 | 
			
		||||
        .public_key(csr.public_key())
 | 
			
		||||
        .serial_number(x509.random_serial_number())
 | 
			
		||||
        .not_valid_before(datetime.utcnow())
 | 
			
		||||
        .not_valid_after(
 | 
			
		||||
            # Set it to expire 2 days before our cacert does
 | 
			
		||||
            cacert.not_valid_after - relativedelta(days=2)
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # Add requested extensions
 | 
			
		||||
    for extension in csr.extensions:
 | 
			
		||||
        builder.add_extension(extension.value, extension.critical)
 | 
			
		||||
 | 
			
		||||
    cert = builder.sign(key, hashes.SHA256(), default_backend())
 | 
			
		||||
    return cert.public_bytes(encoding=serialization.Encoding.PEM)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user