Sanitize column names for named_tuple_factory
Fixes PYTHON-31
This commit is contained in:
@@ -1,31 +1,17 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from binascii import hexlify
|
||||
from collections import namedtuple
|
||||
import datetime
|
||||
import logging
|
||||
import re
|
||||
import socket
|
||||
import types
|
||||
from uuid import UUID
|
||||
|
||||
try:
|
||||
from collections import OrderedDict
|
||||
except ImportError: # Python <2.7
|
||||
from cassandra.util import OrderedDict # NOQA
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import socket
|
||||
import types
|
||||
from uuid import UUID
|
||||
try:
|
||||
from cStringIO import StringIO
|
||||
except ImportError:
|
||||
@@ -60,12 +46,20 @@ HEADER_DIRECTION_TO_CLIENT = 0x80
|
||||
HEADER_DIRECTION_MASK = 0x80
|
||||
|
||||
|
||||
NON_ALPHA_REGEX = re.compile('\W')
|
||||
END_UNDERSCORE_REGEX = re.compile('^_*(\w*[a-zA-Z0-9])_*$')
|
||||
|
||||
|
||||
def _clean_column_name(name):
|
||||
return END_UNDERSCORE_REGEX.sub("\g<1>", NON_ALPHA_REGEX.sub("_", name))
|
||||
|
||||
|
||||
def tuple_factory(colnames, rows):
|
||||
return rows
|
||||
|
||||
|
||||
def named_tuple_factory(colnames, rows):
|
||||
Row = namedtuple('Row', colnames)
|
||||
Row = namedtuple('Row', map(_clean_column_name, colnames))
|
||||
return [Row(*row) for row in rows]
|
||||
|
||||
|
||||
|
@@ -1,10 +1,9 @@
|
||||
import unittest
|
||||
import datetime
|
||||
import cassandra
|
||||
from cassandra.cqltypes import CassandraType, BooleanType, lookup_casstype_simple, lookup_casstype, \
|
||||
AsciiType, LongType, DecimalType, DoubleType, FloatType, Int32Type, UTF8Type, IntegerType, SetType, cql_typename
|
||||
|
||||
from cassandra.cluster import Cluster
|
||||
from cassandra.cqltypes import (BooleanType, lookup_casstype_simple, lookup_casstype,
|
||||
LongType, DecimalType, SetType, cql_typename)
|
||||
from cassandra.decoder import named_tuple_factory
|
||||
|
||||
|
||||
class TypeTests(unittest.TestCase):
|
||||
@@ -105,3 +104,12 @@ class TypeTests(unittest.TestCase):
|
||||
|
||||
self.assertEqual(cql_typename('DateType'), 'timestamp')
|
||||
self.assertEqual(cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)'), 'list<varint>')
|
||||
|
||||
def test_named_tuple_colname_substitution(self):
|
||||
colnames = ("func(abc)", "[applied]", "func(func(abc))", "foo_bar")
|
||||
rows = [(1, 2, 3, 4)]
|
||||
result = named_tuple_factory(colnames, rows)[0]
|
||||
self.assertEqual(result[0], result.func_abc)
|
||||
self.assertEqual(result[1], result.applied)
|
||||
self.assertEqual(result[2], result.func_func_abc)
|
||||
self.assertEqual(result[3], result.foo_bar)
|
||||
|
Reference in New Issue
Block a user