From b04e3ab3dc2e475db3b2674987021424a343563a Mon Sep 17 00:00:00 2001 From: Peter Hamilton Date: Thu, 27 Aug 2015 08:47:07 -0400 Subject: [PATCH] Updating support for the LongInteger primitive This change updates the LongInteger primitive, improving class documentation, adding standard Python operators, cleaning up the original implementation, and expanding the corresponding unit test suite to address the modifications. --- kmip/core/exceptions.py | 8 + kmip/core/primitives.py | 112 +++-- .../unit/core/primitives/test_long_integer.py | 439 ++++++++++++------ 3 files changed, 388 insertions(+), 171 deletions(-) diff --git a/kmip/core/exceptions.py b/kmip/core/exceptions.py index c238171..fcab164 100644 --- a/kmip/core/exceptions.py +++ b/kmip/core/exceptions.py @@ -19,3 +19,11 @@ class InvalidKmipEncoding(Exception): An exception raised when processing invalid KMIP message encodings. """ pass + + +class InvalidPrimitiveLength(Exception): + """ + An exception raised for errors when processing primitives with invalid + lengths. + """ + pass diff --git a/kmip/core/primitives.py b/kmip/core/primitives.py index adb4d3a..c8003e3 100644 --- a/kmip/core/primitives.py +++ b/kmip/core/primitives.py @@ -26,6 +26,7 @@ from kmip.core.enums import Tags from kmip.core.errors import ErrorStrings from kmip.core import errors +from kmip.core import exceptions from kmip.core import utils @@ -249,51 +250,110 @@ class Integer(Base): class LongInteger(Base): + """ + An encodeable object representing a long integer value. + + A LongInteger is one of the KMIP primitive object types. It is encoded as + a signed, big-endian, 64-bit integer. For more information, see Section + 9.1 of the KMIP 1.1 specification. + """ + LENGTH = 8 - def __init__(self, value=None, tag=Tags.DEFAULT): + # Bounds for signed 64-bit integers + MIN = -9223372036854775808 + MAX = 9223372036854775807 + + def __init__(self, value=0, tag=Tags.DEFAULT): + """ + Create a LongInteger. + + Args: + value (int): The value of the LongInteger. Optional, defaults to 0. + tag (Tags): An enumeration defining the tag of the LongInteger. + Optional, defaults to Tags.DEFAULT. + """ super(LongInteger, self).__init__(tag, type=Types.LONG_INTEGER) self.value = value - self.length = self.LENGTH + self.length = LongInteger.LENGTH self.validate() - def read_value(self, istream): - if self.length is not self.LENGTH: - raise errors.ReadValueError(LongInteger.__name__, 'length', - self.LENGTH, self.length) + def read(self, istream): + """ + Read the encoding of the LongInteger from the input stream. + + Args: + istream (stream): A buffer containing the encoded bytes of a + LongInteger. Usually a BytearrayStream object. Required. + + Raises: + InvalidPrimitiveLength: if the long integer encoding read in has + an invalid encoded length. + """ + super(LongInteger, self).read(istream) + + if self.length is not LongInteger.LENGTH: + raise exceptions.InvalidPrimitiveLength( + "invalid long integer length read; " + "expected: {0}, observed: {1}".format( + LongInteger.LENGTH, self.length)) self.value = unpack('!q', istream.read(self.length))[0] self.validate() - def read(self, istream): - super(LongInteger, self).read(istream) - self.read_value(istream) + def write(self, ostream): + """ + Write the encoding of the LongInteger to the output stream. - def write_value(self, ostream): + Args: + ostream (stream): A buffer to contain the encoded bytes of a + LongInteger. Usually a BytearrayStream object. Required. + """ + super(LongInteger, self).write(ostream) ostream.write(pack('!q', self.value)) - def write(self, ostream): - super(LongInteger, self).write(ostream) - self.write_value(ostream) - def validate(self): - self.__validate() + """ + Verify that the value of the LongInteger is valid. - def __validate(self): + Raises: + TypeError: if the value is not of type int or long + ValueError: if the value cannot be represented by a signed 64-bit + integer + """ if self.value is not None: - data_type = type(self.value) - if data_type not in six.integer_types: - raise errors.StateTypeError( - LongInteger.__name__, "{0}".format(six.integer_types), - data_type) - num_bytes = utils.count_bytes(self.value) - if num_bytes > self.length: - raise errors.StateOverflowError( - LongInteger.__name__, 'value', self.length, num_bytes) + if not isinstance(self.value, six.integer_types): + raise TypeError('expected (one of): {0}, observed: {1}'.format( + six.integer_types, type(self.value))) + else: + if self.value > LongInteger.MAX: + raise ValueError( + 'long integer value greater than accepted max') + elif self.value < LongInteger.MIN: + raise ValueError( + 'long integer value less than accepted min') def __repr__(self): - return '' % (self.value) + return "LongInteger(value={0}, tag={1})".format(self.value, self.tag) + + def __str__(self): + return str(self.value) + + def __eq__(self, other): + if isinstance(other, LongInteger): + if self.value == other.value: + return True + else: + return False + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, LongInteger): + return not self.__eq__(other) + else: + return NotImplemented class BigInteger(Base): diff --git a/kmip/tests/unit/core/primitives/test_long_integer.py b/kmip/tests/unit/core/primitives/test_long_integer.py index 420d7bf..95e8785 100644 --- a/kmip/tests/unit/core/primitives/test_long_integer.py +++ b/kmip/tests/unit/core/primitives/test_long_integer.py @@ -15,195 +15,344 @@ import testtools -from kmip.core import errors +from kmip.core import exceptions from kmip.core import primitives from kmip.core import utils class TestLongInteger(testtools.TestCase): + """ + Test suite for the LongInteger primitive. + """ def setUp(self): super(TestLongInteger, self).setUp() - self.stream = utils.BytearrayStream() - self.max_byte_long = 18446744073709551615 - self.max_long = 9223372036854775807 - self.bad_value = ( - 'Bad primitives.LongInteger.{0} after init: expected {1}, ' - 'received {2}') - self.bad_write = ( - 'Bad primitives.LongInteger write: expected {0} bytes, ' - 'received {1} bytes') - self.bad_encoding = ( - 'Bad primitives.LongInteger write: encoding mismatch') - self.bad_read = ( - 'Bad primitives.LongInteger.value read: expected {0}, ' - 'received {1}') def tearDown(self): super(TestLongInteger, self).tearDown() def test_init(self): - i = primitives.LongInteger(0) - - self.assertEqual(0, i.value, - self.bad_value.format('value', 0, i.value)) - self.assertEqual(i.LENGTH, i.length, - self.bad_value.format('length', i.LENGTH, i.length)) + """ + Test that a LongInteger can be instantiated. + """ + long_int = primitives.LongInteger(1) + self.assertEqual(1, long_int.value) def test_init_unset(self): - i = primitives.LongInteger() + """ + Test that a LongInteger can be instantiated with no input. + """ + long_int = primitives.LongInteger() + self.assertEqual(0, long_int.value) - self.assertEqual(None, i.value, - self.bad_value.format('value', None, i.value)) - self.assertEqual(i.LENGTH, i.length, - self.bad_value.format('length', i.LENGTH, i.length)) + def test_init_on_max(self): + """ + Test that a LongInteger can be instantiated with the maximum possible + signed 64-bit value. + """ + primitives.LongInteger(primitives.LongInteger.MAX) + + def test_init_on_min(self): + """ + Test that a LongInteger can be instantiated with the minimum possible + signed 64-bit value. + """ + primitives.LongInteger(primitives.LongInteger.MIN) def test_validate_on_valid(self): - i = primitives.LongInteger() - i.value = 0 - - # Check no exception thrown - i.validate() - - def test_validate_on_valid_long(self): - i = primitives.LongInteger() - i.value = self.max_long + 1 - - # Check no exception thrown - i.validate() + """ + Test that a LongInteger can be validated on good input. + """ + long_int = primitives.LongInteger(1) + long_int.validate() def test_validate_on_valid_unset(self): - i = primitives.LongInteger() - - # Check no exception thrown - i.validate() + """ + Test that a LongInteger with no preset value can be validated. + """ + long_int = primitives.LongInteger() + long_int.validate() def test_validate_on_invalid_type(self): - i = primitives.LongInteger() - i.value = 'test' + """ + Test that a TypeError is thrown on input of invalid type (e.g., str). + """ + self.assertRaises(TypeError, primitives.LongInteger, 'invalid') - self.assertRaises(errors.StateTypeError, i.validate) + def test_validate_on_invalid_value_too_big(self): + """ + Test that a ValueError is thrown on input that is too large. + """ + self.assertRaises( + ValueError, primitives.LongInteger, primitives.LongInteger.MAX + 1) - def test_validate_on_invalid_value(self): - self.assertRaises(errors.StateOverflowError, primitives.LongInteger, - self.max_byte_long + 1) + def test_validate_on_invalid_value_too_small(self): + """ + Test that a ValueError is thrown on input that is too small. + """ + self.assertRaises( + ValueError, primitives.LongInteger, primitives.LongInteger.MIN - 1) - def test_read_value(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) + def test_read_zero(self): + """ + Test that a LongInteger representing the value 0 can be read from a + byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(0, long_int.value) - self.assertEqual(1, i.value, self.bad_read.format(1, i.value)) + def test_read_max_max(self): + """ + Test that a LongInteger representing the maximum positive value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(primitives.LongInteger.MAX, long_int.value) - def test_read_value_zero(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(0, i.value, self.bad_read.format(0, i.value)) - - def test_read_value_max_positive(self): - encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(self.max_long, i.value, - self.bad_read.format(1, i.value)) - - def test_read_value_min_negative(self): - encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read_value(self.stream) - - self.assertEqual(-1, i.value, - self.bad_read.format(1, i.value)) - - def test_read(self): + def test_read_min_max(self): + """ + Test that a LongInteger representing the minimum positive value can be + read from a byte stream. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' b'\x01') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() - i.read(self.stream) + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(1, long_int.value) - self.assertEqual(1, i.value, self.bad_read.format(1, i.value)) + def test_read_max_min(self): + """ + Test that a LongInteger representing the maximum negative value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() + long_int.read(stream) + self.assertEqual(-1, long_int.value) + + def test_read_min_min(self): + """ + Test that a LongInteger representing the minimum negative value can be + read from a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger(primitives.LongInteger.MIN) + long_int.read(stream) + self.assertEqual(primitives.LongInteger.MIN, long_int.value) def test_read_on_invalid_length(self): + """ + Test that an InvalidPrimitiveLength exception is thrown when attempting + to decode a LongInteger with an invalid length. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' b'\x00') - self.stream = utils.BytearrayStream(encoding) - i = primitives.LongInteger() + stream = utils.BytearrayStream(encoding) + long_int = primitives.LongInteger() - self.assertRaises(errors.ReadValueError, i.read, self.stream) + self.assertRaises( + exceptions.InvalidPrimitiveLength, long_int.read, stream) - def test_write_value(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x01') - i = primitives.LongInteger(1) - i.write_value(self.stream) + def test_write_zero(self): + """ + Test that a LongInteger representing the value 0 can be written to a + byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(0) + long_int.write(stream) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) + def test_write_max_max(self): + """ + Test that a LongInteger representing the maximum positive value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x7f\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(primitives.LongInteger.MAX) + long_int.write(stream) - def test_write_value_zero(self): - encoding = (b'\x00\x00\x00\x00\x00\x00\x00\x00') - i = primitives.LongInteger(0) - i.write_value(self.stream) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write_value_max_positive(self): - encoding = (b'\x7f\xff\xff\xff\xff\xff\xff\xff') - i = primitives.LongInteger(self.max_long) - i.write_value(self.stream) - - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write_value_min_negative(self): - encoding = (b'\xff\xff\xff\xff\xff\xff\xff\xff') - i = primitives.LongInteger(-1) - i.write_value(self.stream) - - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) - - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) - - def test_write(self): + def test_write_min_max(self): + """ + Test that a LongInteger representing the minimum positive value can be + written to a byte stream. + """ encoding = ( b'\x42\x00\x00\x03\x00\x00\x00\x08\x00\x00\x00\x00\x00\x00\x00' b'\x01') - i = primitives.LongInteger(1) - i.write(self.stream) + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(1) + long_int.write(stream) - result = self.stream.read() - len_exp = len(encoding) - len_rcv = len(result) + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) - self.assertEqual(len_exp, len_rcv, self.bad_write.format(len_exp, - len_rcv)) - self.assertEqual(encoding, result, self.bad_encoding) + def test_write_max_min(self): + """ + Test that a LongInteger representing the maximum negative value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\xff\xff\xff\xff\xff\xff\xff' + b'\xff') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(-1) + long_int.write(stream) + + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) + + def test_write_min_min(self): + """ + Test that a LongInteger representing the minimum negative value can be + written to a byte stream. + """ + encoding = ( + b'\x42\x00\x00\x03\x00\x00\x00\x08\x80\x00\x00\x00\x00\x00\x00' + b'\x00') + stream = utils.BytearrayStream() + long_int = primitives.LongInteger(primitives.LongInteger.MIN) + long_int.write(stream) + + result = stream.read() + self.assertEqual(len(encoding), len(result)) + self.assertEqual(encoding, result) + + def test_repr(self): + """ + Test that the representation of a LongInteger is formatted properly. + """ + long_int = primitives.LongInteger() + value = "value={0}".format(long_int.value) + tag = "tag={0}".format(long_int.tag) + self.assertEqual( + "LongInteger({0}, {1})".format(value, tag), repr(long_int)) + + def test_str(self): + """ + Test that the string representation of a LongInteger is formatted + properly. + """ + self.assertEqual("0", str(primitives.LongInteger())) + + def test_equal_on_equal(self): + """ + Test that the equality operator returns True when comparing two + LongIntegers. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(1) + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_equal_and_empty(self): + """ + Test that the equality operator returns True when comparing two + LongIntegers. + """ + a = primitives.LongInteger() + b = primitives.LongInteger() + + self.assertTrue(a == b) + self.assertTrue(b == a) + + def test_equal_on_not_equal(self): + """ + Test that the equality operator returns False when comparing two + LongIntegers with different values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(2) + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_equal_on_type_mismatch(self): + """ + Test that the equality operator returns False when comparing a + LongInteger to a non-LongInteger object. + """ + a = primitives.LongInteger() + b = 'invalid' + + self.assertFalse(a == b) + self.assertFalse(b == a) + + def test_not_equal_on_equal(self): + """ + Test that the inequality operator returns False when comparing + two LongIntegers with the same values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(1) + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_equal_and_empty(self): + """ + Test that the inequality operator returns False when comparing + two LongIntegers. + """ + a = primitives.LongInteger() + b = primitives.LongInteger() + + self.assertFalse(a != b) + self.assertFalse(b != a) + + def test_not_equal_on_not_equal(self): + """ + Test that the inequality operator returns True when comparing two + LongIntegers with different values. + """ + a = primitives.LongInteger(1) + b = primitives.LongInteger(2) + + self.assertTrue(a != b) + self.assertTrue(b != a) + + def test_not_equal_on_type_mismatch(self): + """ + Test that the inequality operator returns True when comparing a + LongInteger to a non-LongInteger object. + """ + a = primitives.LongInteger() + b = 'invalid' + + self.assertTrue(a != b) + self.assertTrue(b != a)