make table extensions for for views, too

PYTHON-682
This commit is contained in:
Adam Holmberg
2017-01-13 16:44:34 -06:00
parent ab9bafa5c0
commit f3612d41aa
2 changed files with 76 additions and 29 deletions

View File

@@ -1081,15 +1081,6 @@ class TableMetadata(object):
Metadata describing configuration for table extensions Metadata describing configuration for table extensions
""" """
_extension_registry = {}
class _RegisteredExtensionType(type):
def __new__(mcs, name, bases, dct):
cls = super(TableMetadata._RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct)
if name != 'RegisteredTableExtension':
TableMetadata._extension_registry[cls.name] = cls
return cls
def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None): def __init__(self, keyspace_name, name, partition_key=None, clustering_key=None, columns=None, triggers=None, options=None):
self.keyspace_name = keyspace_name self.keyspace_name = keyspace_name
self.name = name self.name = name
@@ -1138,9 +1129,10 @@ class TableMetadata(object):
for view_meta in self.views.values(): for view_meta in self.views.values():
ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),) ret += "\n\n%s;" % (view_meta.as_cql_query(formatted=True),)
if self.extensions: # None if self.extensions:
for k in six.viewkeys(self._extension_registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey registry = _RegisteredExtensionType._extension_registry
ext = self._extension_registry[k] for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k]) cql = ext.after_table_cql(self, k, self.extensions[k])
if cql: if cql:
ret += "\n\n%s" % (cql,) ret += "\n\n%s" % (cql,)
@@ -1260,7 +1252,18 @@ class TableExtensionInterface(object):
pass pass
@six.add_metaclass(TableMetadata._RegisteredExtensionType) class _RegisteredExtensionType(type):
_extension_registry = {}
def __new__(mcs, name, bases, dct):
cls = super(_RegisteredExtensionType, mcs).__new__(mcs, name, bases, dct)
if name != 'RegisteredTableExtension':
mcs._extension_registry[cls.name] = cls
return cls
@six.add_metaclass(_RegisteredExtensionType)
class RegisteredTableExtension(TableExtensionInterface): class RegisteredTableExtension(TableExtensionInterface):
""" """
Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map). Extending this class registers it by name (associated by key in the `system_schema.tables.extensions` map).
@@ -2327,6 +2330,7 @@ class SchemaParserV3(SchemaParserV22):
view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name, view_meta = MaterializedViewMetadata(keyspace_name, view_name, base_table_name,
include_all_columns, where_clause, self._build_table_options(row)) include_all_columns, where_clause, self._build_table_options(row))
self._build_table_columns(view_meta, col_rows) self._build_table_columns(view_meta, col_rows)
view_meta.extensions = row.get('extensions', {})
return view_meta return view_meta
@@ -2487,6 +2491,11 @@ class MaterializedViewMetadata(object):
view. view.
""" """
extensions = None
"""
Metadata describing configuration for table extensions
"""
def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options): def __init__(self, keyspace_name, view_name, base_table_name, include_all_columns, where_clause, options):
self.keyspace_name = keyspace_name self.keyspace_name = keyspace_name
self.name = view_name self.name = view_name
@@ -2523,13 +2532,22 @@ class MaterializedViewMetadata(object):
properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options) properties = TableMetadataV3._property_string(formatted, self.clustering_key, self.options)
return "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" \ ret = "CREATE MATERIALIZED VIEW %(keyspace)s.%(name)s AS%(sep)s" \
"SELECT %(selected_cols)s%(sep)s" \ "SELECT %(selected_cols)s%(sep)s" \
"FROM %(keyspace)s.%(base_table)s%(sep)s" \ "FROM %(keyspace)s.%(base_table)s%(sep)s" \
"WHERE %(where_clause)s%(sep)s" \ "WHERE %(where_clause)s%(sep)s" \
"PRIMARY KEY %(pk)s%(sep)s" \ "PRIMARY KEY %(pk)s%(sep)s" \
"WITH %(properties)s" % locals() "WITH %(properties)s" % locals()
if self.extensions:
registry = _RegisteredExtensionType._extension_registry
for k in six.viewkeys(registry) & self.extensions: # no viewkeys on OrderedMapSerializeKey
ext = registry[k]
cql = ext.after_table_cql(self, k, self.extensions[k])
if cql:
ret += "\n\n%s" % (cql,)
return ret
def export_as_string(self): def export_as_string(self):
return self.as_cql_query(formatted=True) + ";" return self.as_cql_query(formatted=True) + ";"

View File

@@ -29,13 +29,15 @@ from cassandra.cluster import Cluster
from cassandra.encoder import Encoder from cassandra.encoder import Encoder
from cassandra.metadata import (Metadata, KeyspaceMetadata, IndexMetadata, from cassandra.metadata import (Metadata, KeyspaceMetadata, IndexMetadata,
Token, MD5Token, TokenMap, murmur3, Function, Aggregate, protect_name, protect_names, Token, MD5Token, TokenMap, murmur3, Function, Aggregate, protect_name, protect_names,
get_schema_parser, RegisteredTableExtension) get_schema_parser, RegisteredTableExtension, _RegisteredExtensionType)
from cassandra.policies import SimpleConvictionPolicy from cassandra.policies import SimpleConvictionPolicy
from cassandra.pool import Host from cassandra.pool import Host
from tests.integration import get_cluster, use_singledc, PROTOCOL_VERSION, get_server_versions, execute_until_pass, \ from tests.integration import (get_cluster, use_singledc, PROTOCOL_VERSION, get_server_versions, execute_until_pass,
BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase, BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, \ BasicSegregatedKeyspaceUnitTestCase, BasicSharedKeyspaceUnitTestCase,
BasicExistingSegregatedKeyspaceUnitTestCase, dseonly, DSE_VERSION, get_supported_protocol_versions, greaterthanorequalcass30 BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION,
BasicExistingSegregatedKeyspaceUnitTestCase, dseonly, DSE_VERSION,
get_supported_protocol_versions, greaterthanorequalcass30)
def setup_module(): def setup_module():
@@ -864,15 +866,19 @@ class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase):
ks = self.keyspace_name ks = self.keyspace_name
ks_meta = s.cluster.metadata.keyspaces[ks] ks_meta = s.cluster.metadata.keyspaces[ks]
t = self.function_table_name t = self.function_table_name
v = t + 'view'
s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t)) s.execute("CREATE TABLE %s.%s (k text PRIMARY KEY, v int)" % (ks, t))
s.execute("CREATE MATERIALIZED VIEW %s.%s AS SELECT * FROM %s.%s WHERE v IS NOT NULL PRIMARY KEY (v, k)" % (ks, v, ks, t))
table_meta = ks_meta.tables[t] table_meta = ks_meta.tables[t]
view_meta = table_meta.views[v]
self.assertFalse(table_meta.extensions) self.assertFalse(table_meta.extensions)
self.assertNotIn(t, table_meta._extension_registry) self.assertFalse(view_meta.extensions)
original_cql = table_meta.export_as_string() original_table_cql = table_meta.export_as_string()
original_view_cql = view_meta.export_as_string()
# extensions registered, not present # extensions registered, not present
# -------------------------------------- # --------------------------------------
@@ -887,44 +893,67 @@ class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase):
name = t + '##' name = t + '##'
self.assertFalse(table_meta.extensions) self.assertFalse(table_meta.extensions)
self.assertIn(Ext0.name, table_meta._extension_registry) self.assertFalse(view_meta.extensions)
self.assertIn(Ext1.name, table_meta._extension_registry) self.assertIn(Ext0.name, _RegisteredExtensionType._extension_registry)
self.assertEqual(len(table_meta._extension_registry), 2) self.assertIn(Ext1.name, _RegisteredExtensionType._extension_registry)
self.assertEqual(len(_RegisteredExtensionType._extension_registry), 2)
self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_table_metadata(ks, t)
table_meta = ks_meta.tables[t] table_meta = ks_meta.tables[t]
view_meta = table_meta.views[v]
self.assertEqual(table_meta.export_as_string(), original_cql) self.assertEqual(table_meta.export_as_string(), original_table_cql)
self.assertEqual(view_meta.export_as_string(), original_view_cql)
p = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing
update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? AND view_name=?')
# extensions registered, one present # extensions registered, one present
# -------------------------------------- # --------------------------------------
ext_map = {Ext0.name: six.b("THA VALUE")} ext_map = {Ext0.name: six.b("THA VALUE")}
[s.execute(p, (ext_map, ks, t)) for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v)))
for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts
self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_table_metadata(ks, t)
self.cluster.refresh_materialized_view_metadata(ks, v)
table_meta = ks_meta.tables[t] table_meta = ks_meta.tables[t]
view_meta = table_meta.views[v]
self.assertIn(Ext0.name, table_meta.extensions) self.assertIn(Ext0.name, table_meta.extensions)
new_cql = table_meta.export_as_string() new_cql = table_meta.export_as_string()
self.assertNotEqual(new_cql, original_cql) self.assertNotEqual(new_cql, original_table_cql)
self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql)
self.assertNotIn(Ext1.name, new_cql) self.assertNotIn(Ext1.name, new_cql)
self.assertIn(Ext0.name, view_meta.extensions)
new_cql = view_meta.export_as_string()
self.assertNotEqual(new_cql, original_view_cql)
self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql)
self.assertNotIn(Ext1.name, new_cql)
# extensions registered, one present # extensions registered, one present
# -------------------------------------- # --------------------------------------
ext_map = {Ext0.name: six.b("THA VALUE"), ext_map = {Ext0.name: six.b("THA VALUE"),
Ext1.name: six.b("OTHA VALUE")} Ext1.name: six.b("OTHA VALUE")}
[s.execute(p, (ext_map, ks, t)) for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts [(s.execute(update_t, (ext_map, ks, t)), s.execute(update_v, (ext_map, ks, v)))
for _ in self.cluster.metadata.all_hosts()] # we're manipulating metadata - do it on all hosts
self.cluster.refresh_table_metadata(ks, t) self.cluster.refresh_table_metadata(ks, t)
self.cluster.refresh_materialized_view_metadata(ks, v)
table_meta = ks_meta.tables[t] table_meta = ks_meta.tables[t]
view_meta = table_meta.views[v]
self.assertIn(Ext0.name, table_meta.extensions) self.assertIn(Ext0.name, table_meta.extensions)
self.assertIn(Ext1.name, table_meta.extensions) self.assertIn(Ext1.name, table_meta.extensions)
new_cql = table_meta.export_as_string() new_cql = table_meta.export_as_string()
self.assertNotEqual(new_cql, original_cql) self.assertNotEqual(new_cql, original_table_cql)
self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql)
self.assertIn(Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]), new_cql) self.assertIn(Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]), new_cql)
self.assertIn(Ext0.name, view_meta.extensions)
self.assertIn(Ext1.name, view_meta.extensions)
new_cql = view_meta.export_as_string()
self.assertNotEqual(new_cql, original_view_cql)
self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql)
self.assertIn(Ext1.after_table_cql(view_meta, Ext1.name, ext_map[Ext1.name]), new_cql)
class TestCodeCoverage(unittest.TestCase): class TestCodeCoverage(unittest.TestCase):