Updating support for the LongInteger primitive

This change updates the LongInteger primitive, improving class
documentation, adding standard Python operators, cleaning up the
original implementation, and expanding the corresponding unit test suite
to address the modifications.
This commit is contained in:
Peter Hamilton 2015-08-27 08:47:07 -04:00
parent f3c202cf3c
commit b04e3ab3dc
3 changed files with 388 additions and 171 deletions

View File

@ -19,3 +19,11 @@ class InvalidKmipEncoding(Exception):
An exception raised when processing invalid KMIP message encodings.
"""
pass
class InvalidPrimitiveLength(Exception):
"""
An exception raised for errors when processing primitives with invalid
lengths.
"""
pass

View File

@ -26,6 +26,7 @@ from kmip.core.enums import Tags
from kmip.core.errors import ErrorStrings
from kmip.core import errors
from kmip.core import exceptions
from kmip.core import utils
@ -249,51 +250,110 @@ class Integer(Base):
class LongInteger(Base):
"""
An encodeable object representing a long integer value.
A LongInteger is one of the KMIP primitive object types. It is encoded as
a signed, big-endian, 64-bit integer. For more information, see Section
9.1 of the KMIP 1.1 specification.
"""
LENGTH = 8
def __init__(self, value=None, tag=Tags.DEFAULT):
# Bounds for signed 64-bit integers
MIN = -9223372036854775808
MAX = 9223372036854775807
def __init__(self, value=0, tag=Tags.DEFAULT):
"""
Create a LongInteger.
Args:
value (int): The value of the LongInteger. Optional, defaults to 0.
tag (Tags): An enumeration defining the tag of the LongInteger.
Optional, defaults to Tags.DEFAULT.
"""
super(LongInteger, self).__init__(tag, type=Types.LONG_INTEGER)
self.value = value
self.length = self.LENGTH
self.length = LongInteger.LENGTH
self.validate()
def read_value(self, istream):
if self.length is not self.LENGTH:
raise errors.ReadValueError(LongInteger.__name__, 'length',
self.LENGTH, self.length)
def read(self, istream):
"""
Read the encoding of the LongInteger from the input stream.
Args:
istream (stream): A buffer containing the encoded bytes of a
LongInteger. Usually a BytearrayStream object. Required.
Raises:
InvalidPrimitiveLength: if the long integer encoding read in has
an invalid encoded length.
"""
super(LongInteger, self).read(istream)
if self.length is not LongInteger.LENGTH:
raise exceptions.InvalidPrimitiveLength(
"invalid long integer length read; "
"expected: {0}, observed: {1}".format(
LongInteger.LENGTH, self.length))
self.value = unpack('!q', istream.read(self.length))[0]
self.validate()
def read(self, istream):
super(LongInteger, self).read(istream)
self.read_value(istream)
def write(self, ostream):
"""
Write the encoding of the LongInteger to the output stream.
def write_value(self, ostream):
Args:
ostream (stream): A buffer to contain the encoded bytes of a
LongInteger. Usually a BytearrayStream object. Required.
"""
super(LongInteger, self).write(ostream)
ostream.write(pack('!q', self.value))
def write(self, ostream):
super(LongInteger, self).write(ostream)
self.write_value(ostream)
def validate(self):
self.__validate()
"""
Verify that the value of the LongInteger is valid.
def __validate(self):
Raises:
TypeError: if the value is not of type int or long
ValueError: if the value cannot be represented by a signed 64-bit
integer
"""
if self.value is not None:
data_type = type(self.value)
if data_type not in six.integer_types:
raise errors.StateTypeError(
LongInteger.__name__, "{0}".format(six.integer_types),
data_type)
num_bytes = utils.count_bytes(self.value)
if num_bytes > self.length:
raise errors.StateOverflowError(
LongInteger.__name__, 'value', self.length, num_bytes)
if not isinstance(self.value, six.integer_types):
raise TypeError('expected (one of): {0}, observed: {1}'.format(
six.integer_types, type(self.value)))
else:
if self.value > LongInteger.MAX:
raise ValueError(
'long integer value greater than accepted max')
elif self.value < LongInteger.MIN:
raise ValueError(
'long integer value less than accepted min')
def __repr__(self):
return '<Long Integer, %d>' % (self.value)
return "LongInteger(value={0}, tag={1})".format(self.value, self.tag)
def __str__(self):
return str(self.value)
def __eq__(self, other):
if isinstance(other, LongInteger):
if self.value == other.value:
return True
else:
return False
else:
return NotImplemented
def __ne__(self, other):
if isinstance(other, LongInteger):
return not self.__eq__(other)
else:
return NotImplemented
class BigInteger(Base):

View File

@ -15,195 +15,344 @@
import testtools
from kmip.core import errors
from kmip.core import exceptions
from kmip.core import primitives
from kmip.core import utils
class TestLongInteger(testtools.TestCase):
"""
Test suite for the LongInteger primitive.
"""
def setUp(self):
super(TestLongInteger, self).setUp()
self.stream = utils.BytearrayStream()
self.max_byte_long = 18446744073709551615
self.max_long = 9223372036854775807
self.bad_value = (
'Bad primitives.LongInteger.{0} after init: expected {1}, '
'received {2}')
self.bad_write = (
'Bad primitives.LongInteger write: expected {0} bytes, '
'received {1} bytes')
self.bad_encoding = (
'Bad primitives.LongInteger write: encoding mismatch')
self.bad_read = (
'Bad primitives.LongInteger.value read: expected {0}, '
'received {1}')
def tearDown(self):
super(TestLongInteger, self).tearDown()
def test_init(self):
i = primitives.LongInteger(0)
self.assertEqual(0, i.value,
self.bad_value.format('value', 0, i.value))
self.assertEqual(i.LENGTH, i.length,
self.bad_value.format('length', i.LENGTH, i.length))
"""
Test that a LongInteger can be instantiated.
"""
long_int = primitives.LongInteger(1)
self.assertEqual(1, long_int.value)
def test_init_unset(self):
i = primitives.LongInteger()
"""
Test that a LongInteger can be instantiated with no input.
"""
long_int = primitives.LongInteger()
self.assertEqual(0, long_int.value)
self.assertEqual(None, i.value,
self.bad_value.format('value', None, i.value))
self.assertEqual(i.LENGTH, i.length,
self.bad_value.format('length', i.LENGTH, i.length))
def test_init_on_max(self):
"""
Test that a LongInteger can be instantiated with the maximum possible
signed 64-bit value.
"""
primitives.LongInteger(primitives.LongInteger.MAX)
def test_init_on_min(self):
"""
Test that a LongInteger can be instantiated with the minimum possible
signed 64-bit value.
"""
primitives.LongInteger(primitives.LongInteger.MIN)
def test_validate_on_valid(self):
i = primitives.LongInteger()
i.value = 0
# Check no exception thrown
i.validate()
def test_validate_on_valid_long(self):
i = primitives.LongInteger()
i.value = self.max_long + 1
# Check no exception thrown
i.validate()
"""
Test that a LongInteger can be validated on good input.
"""
long_int = primitives.LongInteger(1)
long_int.validate()
def test_validate_on_valid_unset(self):
i = primitives.LongInteger()
# Check no exception thrown
i.validate()
"""
Test that a LongInteger with no preset value can be validated.
"""
long_int = primitives.LongInteger()
long_int.validate()
def test_validate_on_invalid_type(self):
i = primitives.LongInteger()
i.value = 'test'
"""
Test that a TypeError is thrown on input of invalid type (e.g., str).
"""
self.assertRaises(TypeError, primitives.LongInteger, 'invalid')
self.assertRaises(errors.StateTypeError, i.validate)
def test_validate_on_invalid_value_too_big(self):
"""
Test that a ValueError is thrown on input that is too large.
"""
self.assertRaises(
ValueError, primitives.LongInteger, primitives.LongInteger.MAX + 1)
def test_validate_on_invalid_value(self):
self.assertRaises(errors.StateOverflowError, primitives.LongInteger,
self.max_byte_long + 1)
def test_validate_on_invalid_value_too_small(self):
"""
Test that a ValueError is thrown on input that is too small.
"""
self.assertRaises(
ValueError, primitives.LongInteger, primitives.LongInteger.MIN - 1)
def test_read_value(self):
encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
i.read_value(self.stream)
def test_read_zero(self):
"""
Test that a LongInteger representing the value 0 can be read from a
byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00'
b'\x00')
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger()
long_int.read(stream)
self.assertEqual(0, long_int.value)
self.assertEqual(1, i.value, self.bad_read.format(1, i.value))
def test_read_max_max(self):
"""
Test that a LongInteger representing the maximum positive value can be
read from a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff'
b'\xff')
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger()
long_int.read(stream)
self.assertEqual(primitives.LongInteger.MAX, long_int.value)
def test_read_value_zero(self):
encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
i.read_value(self.stream)
self.assertEqual(0, i.value, self.bad_read.format(0, i.value))
def test_read_value_max_positive(self):
encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
i.read_value(self.stream)
self.assertEqual(self.max_long, i.value,
self.bad_read.format(1, i.value))
def test_read_value_min_negative(self):
encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
i.read_value(self.stream)
self.assertEqual(-1, i.value,
self.bad_read.format(1, i.value))
def test_read(self):
def test_read_min_max(self):
"""
Test that a LongInteger representing the minimum positive value can be
read from a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00'
b'\x01')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
i.read(self.stream)
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger()
long_int.read(stream)
self.assertEqual(1, long_int.value)
self.assertEqual(1, i.value, self.bad_read.format(1, i.value))
def test_read_max_min(self):
"""
Test that a LongInteger representing the maximum negative value can be
read from a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff'
b'\xff')
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger()
long_int.read(stream)
self.assertEqual(-1, long_int.value)
def test_read_min_min(self):
"""
Test that a LongInteger representing the minimum negative value can be
read from a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00'
b'\x00')
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger(primitives.LongInteger.MIN)
long_int.read(stream)
self.assertEqual(primitives.LongInteger.MIN, long_int.value)
def test_read_on_invalid_length(self):
"""
Test that an InvalidPrimitiveLength exception is thrown when attempting
to decode a LongInteger with an invalid length.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'
b'\x00')
self.stream = utils.BytearrayStream(encoding)
i = primitives.LongInteger()
stream = utils.BytearrayStream(encoding)
long_int = primitives.LongInteger()
self.assertRaises(errors.ReadValueError, i.read, self.stream)
self.assertRaises(
exceptions.InvalidPrimitiveLength, long_int.read, stream)
def test_write_value(self):
encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01')
i = primitives.LongInteger(1)
i.write_value(self.stream)
def test_write_zero(self):
"""
Test that a LongInteger representing the value 0 can be written to a
byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00'
b'\x00')
stream = utils.BytearrayStream()
long_int = primitives.LongInteger(0)
long_int.write(stream)
result = self.stream.read()
len_exp = len(encoding)
len_rcv = len(result)
result = stream.read()
self.assertEqual(len(encoding), len(result))
self.assertEqual(encoding, result)
self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp,
len_rcv))
self.assertEqual(encoding, result, self.bad_encoding)
def test_write_max_max(self):
"""
Test that a LongInteger representing the maximum positive value can be
written to a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff'
b'\xff')
stream = utils.BytearrayStream()
long_int = primitives.LongInteger(primitives.LongInteger.MAX)
long_int.write(stream)
def test_write_value_zero(self):
encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00')
i = primitives.LongInteger(0)
i.write_value(self.stream)
result = stream.read()
self.assertEqual(len(encoding), len(result))
self.assertEqual(encoding, result)
result = self.stream.read()
len_exp = len(encoding)
len_rcv = len(result)
self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp,
len_rcv))
self.assertEqual(encoding, result, self.bad_encoding)
def test_write_value_max_positive(self):
encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff')
i = primitives.LongInteger(self.max_long)
i.write_value(self.stream)
result = self.stream.read()
len_exp = len(encoding)
len_rcv = len(result)
self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp,
len_rcv))
self.assertEqual(encoding, result, self.bad_encoding)
def test_write_value_min_negative(self):
encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff')
i = primitives.LongInteger(-1)
i.write_value(self.stream)
result = self.stream.read()
len_exp = len(encoding)
len_rcv = len(result)
self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp,
len_rcv))
self.assertEqual(encoding, result, self.bad_encoding)
def test_write(self):
def test_write_min_max(self):
"""
Test that a LongInteger representing the minimum positive value can be
written to a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00'
b'\x01')
i = primitives.LongInteger(1)
i.write(self.stream)
stream = utils.BytearrayStream()
long_int = primitives.LongInteger(1)
long_int.write(stream)
result = self.stream.read()
len_exp = len(encoding)
len_rcv = len(result)
result = stream.read()
self.assertEqual(len(encoding), len(result))
self.assertEqual(encoding, result)
self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp,
len_rcv))
self.assertEqual(encoding, result, self.bad_encoding)
def test_write_max_min(self):
"""
Test that a LongInteger representing the maximum negative value can be
written to a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff'
b'\xff')
stream = utils.BytearrayStream()
long_int = primitives.LongInteger(-1)
long_int.write(stream)
result = stream.read()
self.assertEqual(len(encoding), len(result))
self.assertEqual(encoding, result)
def test_write_min_min(self):
"""
Test that a LongInteger representing the minimum negative value can be
written to a byte stream.
"""
encoding = (
b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00'
b'\x00')
stream = utils.BytearrayStream()
long_int = primitives.LongInteger(primitives.LongInteger.MIN)
long_int.write(stream)
result = stream.read()
self.assertEqual(len(encoding), len(result))
self.assertEqual(encoding, result)
def test_repr(self):
"""
Test that the representation of a LongInteger is formatted properly.
"""
long_int = primitives.LongInteger()
value = "value={0}".format(long_int.value)
tag = "tag={0}".format(long_int.tag)
self.assertEqual(
"LongInteger({0}, {1})".format(value, tag), repr(long_int))
def test_str(self):
"""
Test that the string representation of a LongInteger is formatted
properly.
"""
self.assertEqual("0", str(primitives.LongInteger()))
def test_equal_on_equal(self):
"""
Test that the equality operator returns True when comparing two
LongIntegers.
"""
a = primitives.LongInteger(1)
b = primitives.LongInteger(1)
self.assertTrue(a == b)
self.assertTrue(b == a)
def test_equal_on_equal_and_empty(self):
"""
Test that the equality operator returns True when comparing two
LongIntegers.
"""
a = primitives.LongInteger()
b = primitives.LongInteger()
self.assertTrue(a == b)
self.assertTrue(b == a)
def test_equal_on_not_equal(self):
"""
Test that the equality operator returns False when comparing two
LongIntegers with different values.
"""
a = primitives.LongInteger(1)
b = primitives.LongInteger(2)
self.assertFalse(a == b)
self.assertFalse(b == a)
def test_equal_on_type_mismatch(self):
"""
Test that the equality operator returns False when comparing a
LongInteger to a non-LongInteger object.
"""
a = primitives.LongInteger()
b = 'invalid'
self.assertFalse(a == b)
self.assertFalse(b == a)
def test_not_equal_on_equal(self):
"""
Test that the inequality operator returns False when comparing
two LongIntegers with the same values.
"""
a = primitives.LongInteger(1)
b = primitives.LongInteger(1)
self.assertFalse(a != b)
self.assertFalse(b != a)
def test_not_equal_on_equal_and_empty(self):
"""
Test that the inequality operator returns False when comparing
two LongIntegers.
"""
a = primitives.LongInteger()
b = primitives.LongInteger()
self.assertFalse(a != b)
self.assertFalse(b != a)
def test_not_equal_on_not_equal(self):
"""
Test that the inequality operator returns True when comparing two
LongIntegers with different values.
"""
a = primitives.LongInteger(1)
b = primitives.LongInteger(2)
self.assertTrue(a != b)
self.assertTrue(b != a)
def test_not_equal_on_type_mismatch(self):
"""
Test that the inequality operator returns True when comparing a
LongInteger to a non-LongInteger object.
"""
a = primitives.LongInteger()
b = 'invalid'
self.assertTrue(a != b)
self.assertTrue(b != a)