diff --git a/cassandra/cqltypes.py b/cassandra/cqltypes.py index 82cf09e7..6fd98ae4 100644 --- a/cassandra/cqltypes.py +++ b/cassandra/cqltypes.py @@ -369,7 +369,7 @@ class DecimalType(_CassandraType): try: sign, digits, exponent = dec.as_tuple() except AttributeError: - raise TypeError("Non-Decimal type received for Decimal value") + sign, digits, exponent = Decimal(dec).as_tuple() unscaled = int(''.join([str(digit) for digit in digits])) if sign: unscaled *= -1 diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index 3b98280c..f097a7cb 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from cassandra.marshal import bitlength +from cassandra.protocol import MAX_SUPPORTED_VERSION try: import unittest2 as unittest @@ -139,3 +140,11 @@ class UnmarshalTest(unittest.TestCase): def test_date(self): # separate test because it will deserialize as datetime self.assertEqual(DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 1), 1), datetime(2015, 11, 2)) + + def test_decimal(self): + # testing implicit numeric conversion + # int, float, tuple(sign, digits, exp) + for proto_ver in range(1, MAX_SUPPORTED_VERSION + 1): + for n in (10001, 100.1, (0, (1, 0, 0, 0, 0, 1), -3)): + expected = Decimal(n) + self.assertEqual(DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver), expected)