diff --git a/cassandra/metadata.py b/cassandra/metadata.py index 28ed1747..c893857b 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -1311,7 +1311,7 @@ class TableMetadata(object): if len(self.partition_key) > 1: ret += "(%s)" % ", ".join(protect_name(col.name) for col in self.partition_key) else: - ret += self.partition_key[0].name + ret += protect_name(self.partition_key[0].name) if self.clustering_key: ret += ", %s" % ", ".join(protect_name(col.name) for col in self.clustering_key) diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 1d7f68df..5dc3181e 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -28,7 +28,7 @@ from cassandra.cluster import Cluster from cassandra.cqltypes import DoubleType, Int32Type, ListType, UTF8Type, MapType from cassandra.encoder import Encoder from cassandra.metadata import (Metadata, KeyspaceMetadata, TableMetadata, IndexMetadata, - Token, MD5Token, TokenMap, murmur3, Function, Aggregate) + Token, MD5Token, TokenMap, murmur3, Function, Aggregate, protect_name, protect_names) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host @@ -65,24 +65,24 @@ class SchemaMetadataTests(unittest.TestCase): statement = "CREATE TABLE %s.%s (" % (self.ksname, self.cfname) if len(partition_cols) == 1 and not clustering_cols: - statement += "%s text PRIMARY KEY, " % partition_cols[0] + statement += "%s text PRIMARY KEY, " % protect_name(partition_cols[0]) else: - statement += ", ".join("%s text" % col for col in partition_cols) + statement += ", ".join("%s text" % protect_name(col) for col in partition_cols) statement += ", " - statement += ", ".join("%s text" % col for col in clustering_cols + other_cols) + statement += ", ".join("%s text" % protect_name(col) for col in clustering_cols + other_cols) if len(partition_cols) != 1 or clustering_cols: statement += ", PRIMARY KEY (" if len(partition_cols) > 1: - statement += "(" + ", ".join(partition_cols) + ")" + statement += "(" + ", ".join(protect_names(partition_cols)) + ")" else: - statement += partition_cols[0] + statement += protect_name(partition_cols[0]) if clustering_cols: statement += ", " - statement += ", ".join(clustering_cols) + statement += ", ".join(protect_names(clustering_cols)) statement += ")" @@ -149,6 +149,18 @@ class SchemaMetadataTests(unittest.TestCase): self.check_create_statement(tablemeta, create_statement) + def test_compound_primary_keys_protected(self): + create_statement = self.make_create_statement(["Aa"], ["Bb"], ["Cc"]) + create_statement += ' WITH CLUSTERING ORDER BY ("Bb" ASC)' + self.session.execute(create_statement) + tablemeta = self.get_table_metadata() + + self.assertEqual([u'Aa'], [c.name for c in tablemeta.partition_key]) + self.assertEqual([u'Bb'], [c.name for c in tablemeta.clustering_key]) + self.assertEqual([u'Aa', u'Bb', u'Cc'], sorted(tablemeta.columns.keys())) + + self.check_create_statement(tablemeta, create_statement) + def test_compound_primary_keys_more_columns(self): create_statement = self.make_create_statement(["a"], ["b", "c"], ["d", "e", "f"]) create_statement += " WITH CLUSTERING ORDER BY (b ASC, c ASC)"