Protect singular partition key name in composite key CQL output

PYTHON-375
This commit is contained in:
Adam Holmberg
2015-08-06 09:07:51 -05:00
parent cc5b650f6a
commit 88c380a7b3
2 changed files with 20 additions and 8 deletions

View File

@@ -1311,7 +1311,7 @@ class TableMetadata(object):
if len(self.partition_key) > 1:
ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key)
else:
ret += self.partition_key[0].name
ret += protect_name(self.partition_key[0].name)
if self.clustering_key:
ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key)

View File

@@ -28,7 +28,7 @@ from cassandra.cluster import Cluster
from cassandra.cqltypes import DoubleType, Int32Type, ListType, UTF8Type, MapType
from cassandra.encoder import Encoder
from cassandra.metadata import (Metadata, KeyspaceMetadata, TableMetadata, IndexMetadata,
Token, MD5Token, TokenMap, murmur3, Function, Aggregate)
Token, MD5Token, TokenMap, murmur3, Function, Aggregate, protect_name, protect_names)
from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host
@@ -65,24 +65,24 @@ class SchemaMetadataTests(unittest.TestCase):
statement = "CREATE TABLE %s.%s (" % (self.ksname, self.cfname)
if len(partition_cols) == 1 and not clustering_cols:
statement += "%s text PRIMARY KEY, " % partition_cols[0]
statement += "%s text PRIMARY KEY, " % protect_name(partition_cols[0])
else:
statement += ", ".join("%s text" % col for col in partition_cols)
statement += ", ".join("%s text" % protect_name(col) for col in partition_cols)
statement += ", "
statement += ", ".join("%s text" % col for col in clustering_cols + other_cols)
statement += ", ".join("%s text" % protect_name(col) for col in clustering_cols + other_cols)
if len(partition_cols) != 1 or clustering_cols:
statement += ", PRIMARY KEY ("
if len(partition_cols) > 1:
statement += "(" + ", ".join(partition_cols) + ")"
statement += "(" + ", ".join(protect_names(partition_cols)) + ")"
else:
statement += partition_cols[0]
statement += protect_name(partition_cols[0])
if clustering_cols:
statement += ", "
statement += ", ".join(clustering_cols)
statement += ", ".join(protect_names(clustering_cols))
statement += ")"
@@ -149,6 +149,18 @@ class SchemaMetadataTests(unittest.TestCase):
self.check_create_statement(tablemeta, create_statement)
def test_compound_primary_keys_protected(self):
create_statement = self.make_create_statement(["Aa"], ["Bb"], ["Cc"])
create_statement += ' WITH CLUSTERING ORDER BY ("Bb" ASC)'
self.session.execute(create_statement)
tablemeta = self.get_table_metadata()
self.assertEqual([u'Aa'], [c.name for c in tablemeta.partition_key])
self.assertEqual([u'Bb'], [c.name for c in tablemeta.clustering_key])
self.assertEqual([u'Aa', u'Bb', u'Cc'], sorted(tablemeta.columns.keys()))
self.check_create_statement(tablemeta, create_statement)
def test_compound_primary_keys_more_columns(self):
create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"])
create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)"