# anchor/tests/X509/test_extension.py (271 lines, 9.4 KiB, Python)

# -*- coding:utf-8 -*-
#
# Copyright 2014 Hewlett-Packard Development Company, L.P.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import unittest
import netaddr
from pyasn1.codec.der import encoder
from pyasn1.type import univ
from anchor.asn1 import rfc5280
from anchor.X509 import errors
from anchor.X509 import extension
class TestExtensionBase(unittest.TestCase):
    """Tests for the generic extension.X509Extension base class."""

    @staticmethod
    def _unknown_extension_asn1():
        """Build a non-critical rfc5280.Extension carrying OID 1.2.3.4.

        Several tests need this exact raw ASN.1 structure; building it in one
        place keeps them in sync.
        """
        asn1 = rfc5280.Extension()
        asn1['extnID'] = univ.ObjectIdentifier('1.2.3.4')
        asn1['critical'] = False
        asn1['extnValue'] = "foobar"
        return asn1

    def test_no_spec(self):
        # Constructing the base class with no arguments must be rejected.
        with self.assertRaises(errors.X509Error):
            extension.X509Extension()

    def test_invalid_asn(self):
        # A plain string is not an ASN.1 extension structure.
        with self.assertRaises(errors.X509Error):
            extension.X509Extension("foobar")

    def test_unknown_extension_str(self):
        ext = extension.X509Extension(self._unknown_extension_asn1())
        self.assertEqual("1.2.3.4: <unknown>", str(ext))

    def test_construct(self):
        ext = extension.construct_extension(self._unknown_extension_asn1())
        self.assertIsInstance(ext, extension.X509Extension)

    def test_construct_invalid_type(self):
        with self.assertRaises(errors.X509Error):
            extension.construct_extension("foobar")

    def test_critical(self):
        ext = extension.construct_extension(self._unknown_extension_asn1())
        self.assertFalse(ext.get_critical())
        ext.set_critical(True)
        self.assertTrue(ext.get_critical())

    def test_serialise(self):
        asn1 = self._unknown_extension_asn1()
        ext = extension.construct_extension(asn1)
        # The wrapper's DER output must match encoding the raw ASN.1 directly.
        self.assertEqual(ext.as_der(), encoder.encode(asn1))

    def test_broken_set_value(self):
        class SomeExt(extension.X509Extension):
            spec = rfc5280.Extension
            _oid = univ.ObjectIdentifier('1.2.3.4')

            @classmethod
            def _get_default_value(cls):
                # Deliberately return a value of the wrong type.
                return 1234

        # assertRaisesRegexp is a deprecated alias removed in Python 3.12, and
        # assertRaisesRegex does not exist on Python 2.7; inspecting the
        # caught exception's message works on both.
        with self.assertRaises(errors.X509Error) as ctx:
            SomeExt()
        self.assertIn('incorrect type', str(ctx.exception))
class TestBasicConstraints(unittest.TestCase):
    """Tests for extension.X509ExtensionBasicConstraints."""

    def setUp(self):
        # Every test starts from a freshly default-constructed extension.
        self.ext = extension.X509ExtensionBasicConstraints()

    def test_str(self):
        # The default extension renders as not-a-CA with pathLen None.
        expected = "basicConstraints: CA: FALSE, pathLen: None"
        self.assertEqual(str(self.ext), expected)

    def test_ca(self):
        ext = self.ext
        ext.set_ca(True)
        self.assertTrue(ext.get_ca())
        ext.set_ca(False)
        self.assertFalse(ext.get_ca())

    def test_pathlen(self):
        limit = 1
        self.ext.set_path_len_constraint(limit)
        self.assertEqual(limit, self.ext.get_path_len_constraint())
class TestKeyUsage(unittest.TestCase):
    """Tests for the named flags of extension.X509ExtensionKeyUsage."""

    def setUp(self):
        self.ext = extension.X509ExtensionKeyUsage()

    def test_usage_set(self):
        ext = self.ext
        ext.set_usage('digitalSignature', True)
        ext.set_usage('keyAgreement', False)
        self.assertTrue(ext.get_usage('digitalSignature'))
        self.assertFalse(ext.get_usage('keyAgreement'))

    def test_usage_reset(self):
        # Setting a flag and then clearing it must leave it cleared.
        for state in (True, False):
            self.ext.set_usage('digitalSignature', state)
        self.assertFalse(self.ext.get_usage('digitalSignature'))

    def test_usage_unset(self):
        # A flag that was never touched reads as off.
        self.assertFalse(self.ext.get_usage('keyAgreement'))

    def test_get_all_usage(self):
        requested = [
            ('digitalSignature', True),
            ('keyAgreement', False),
            ('keyEncipherment', True),
        ]
        for name, state in requested:
            self.ext.set_usage(name, state)
        # Only the flags set to True are expected back.
        self.assertEqual({'digitalSignature', 'keyEncipherment'},
                         set(self.ext.get_all_usages()))

    def test_str(self):
        self.ext.set_usage('digitalSignature', True)
        self.assertEqual("keyUsage: digitalSignature", str(self.ext))
class TestSubjectAltName(unittest.TestCase):
    """Tests for DNS and IP entries in extension.X509ExtensionSubjectAltName."""

    def setUp(self):
        self.ext = extension.X509ExtensionSubjectAltName()
        self.domain = 'example.com'
        self.ip = netaddr.IPAddress('1.2.3.4')
        self.ip6 = netaddr.IPAddress('::1')

    def _add_domain_and_ip(self):
        """Populate the extension with one DNS name and one IPv4 address."""
        self.ext.add_dns_id(self.domain)
        self.ext.add_ip(self.ip)

    def test_dns_ids(self):
        # The IP entry must not appear in the DNS id list.
        self._add_domain_and_ip()
        self.assertEqual([self.domain], self.ext.get_dns_ids())

    def test_ips(self):
        # The DNS entry must not appear in the IP list.
        self._add_domain_and_ip()
        self.assertEqual([self.ip], self.ext.get_ips())

    def test_ipv6(self):
        self.ext.add_ip(self.ip6)
        self.assertEqual([self.ip6], self.ext.get_ips())

    def test_add_ip_invalid(self):
        # An arbitrary non-address string is rejected.
        with self.assertRaises(errors.X509Error):
            self.ext.add_ip("abcdef")

    def test_str(self):
        self._add_domain_and_ip()
        self.assertEqual("subjectAltName: DNS:example.com, IP:1.2.3.4",
                         str(self.ext))
class TestNameConstraints(unittest.TestCase):
    """Tests for permitted/excluded subtrees of X509ExtensionNameConstraints."""

    def setUp(self):
        self.ext = extension.X509ExtensionNameConstraints()

    def _assert_lengths(self, permitted, excluded):
        """Assert the current number of permitted and excluded subtrees."""
        self.assertEqual(permitted, self.ext.get_permitted_length())
        self.assertEqual(excluded, self.ext.get_excluded_length())

    def test_length(self):
        # A freshly constructed extension has no subtrees of either kind.
        self._assert_lengths(0, 0)

    def test_add(self):
        test_name = 'example.com'
        test_type = 'dNSName'
        self._assert_lengths(0, 0)
        # The permitted and excluded lists are tracked independently.
        self.ext.add_permitted(test_type, test_name)
        self._assert_lengths(1, 0)
        self.ext.add_excluded(test_type, test_name)
        self._assert_lengths(1, 1)

    def test_excluded(self):
        self.ext.add_excluded('dNSName', 'example.com')
        # Expected value first, actual second, so failure output is labelled
        # the right way round.
        self.assertEqual((0, None), self.ext.get_excluded_range(0))
        # The name was added as str but is compared back as bytes.
        self.assertEqual(('dNSName', b'example.com'),
                         self.ext.get_excluded_name(0))

    def test_permitted(self):
        self.ext.add_permitted('dNSName', 'example.com')
        # Expected value first, actual second, so failure output is labelled
        # the right way round.
        self.assertEqual((0, None), self.ext.get_permitted_range(0))
        # The name was added as str but is compared back as bytes.
        self.assertEqual(('dNSName', b'example.com'),
                         self.ext.get_permitted_name(0))
class TestExtendedKeyUsage(unittest.TestCase):
    """Tests for OID-based usages in extension.X509ExtensionExtendedKeyUsage."""

    def setUp(self):
        self.ext = extension.X509ExtensionExtendedKeyUsage()

    def test_get_all(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        usages = self.ext.get_all_usages()
        self.assertEqual(2, len(usages))
        # Verify both usages that were enabled, not only the first.
        self.assertIn(rfc5280.id_kp_clientAuth, usages)
        self.assertIn(rfc5280.id_kp_codeSigning, usages)

    def test_get_one(self):
        self.assertFalse(self.ext.get_usage(rfc5280.id_kp_clientAuth))
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertTrue(self.ext.get_usage(rfc5280.id_kp_clientAuth))

    def test_set(self):
        self.assertEqual(0, len(self.ext.get_all_usages()))
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertEqual(1, len(self.ext.get_all_usages()))
        # Enabling an already-enabled usage must not add a duplicate.
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.assertEqual(1, len(self.ext.get_all_usages()))
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        self.assertEqual(2, len(self.ext.get_all_usages()))

    def test_unset(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_clientAuth, False)
        self.assertEqual(0, len(self.ext.get_all_usages()))
        # Disabling a usage that is not present leaves the list empty.
        self.ext.set_usage(rfc5280.id_kp_clientAuth, False)
        self.assertEqual(0, len(self.ext.get_all_usages()))

    def test_str(self):
        self.ext.set_usage(rfc5280.id_kp_clientAuth, True)
        self.ext.set_usage(rfc5280.id_kp_codeSigning, True)
        self.assertEqual(
            "extKeyUsage: TLS Web Client Authentication, Code Signing",
            str(self.ext))

    def test_invalid_usage(self):
        unknown = univ.ObjectIdentifier('1.2.3.4')
        self.assertRaises(ValueError, self.ext.get_usage, unknown)
        # set_usage takes (usage, state), as used throughout this class; the
        # unknown OID must be the usage argument so that the unknown-usage
        # check is what raises, rather than a rejection of ``True``.
        self.assertRaises(ValueError, self.ext.set_usage, unknown, True)
class TestAuthorityKeyId(unittest.TestCase):
    """Set/get round-trip tests for extension.X509ExtensionAuthorityKeyId."""

    def setUp(self):
        self.ext = extension.X509ExtensionAuthorityKeyId()

    def test_key_id(self):
        expected_id = b"12345678"
        self.ext.set_key_id(expected_id)
        self.assertEqual(expected_id, self.ext.get_key_id())

    def test_name_serial(self):
        # Only the serial-number accessor pair is exercised here.
        serial_number = 12345678
        self.ext.set_serial(serial_number)
        self.assertEqual(serial_number, self.ext.get_serial())
class TestSubjectKeyId(unittest.TestCase):
    """Set/get round-trip test for extension.X509ExtensionSubjectKeyId."""

    def setUp(self):
        self.ext = extension.X509ExtensionSubjectKeyId()

    def test_key_id(self):
        # The stored key id must be returned unchanged.
        raw_id = b"12345678"
        self.ext.set_key_id(raw_id)
        self.assertEqual(raw_id, self.ext.get_key_id())