diff --git a/cassandra/cqlengine/management.py b/cassandra/cqlengine/management.py index 344b262c..19f8e9fa 100644 --- a/cassandra/cqlengine/management.py +++ b/cassandra/cqlengine/management.py @@ -151,7 +151,11 @@ def sync_table(model): cluster = get_cluster() - keyspace = cluster.metadata.keyspaces[ks_name] + try: + keyspace = cluster.metadata.keyspaces[ks_name] + except KeyError: + raise CQLEngineException("Keyspace '{0}' for model {1} does not exist.".format(ks_name, model)) + tables = keyspace.tables syncd_types = set() @@ -161,7 +165,6 @@ def sync_table(model): for udt in [u for u in udts if u not in syncd_types]: _sync_type(ks_name, udt, syncd_types) - # check for an existing column family if raw_cf_name not in tables: log.debug("sync_table creating new table %s", cf_name) qs = _get_create_table(model) @@ -175,32 +178,35 @@ def sync_table(model): raise else: log.debug("sync_table checking existing table %s", cf_name) - # see if we're missing any columns table_meta = tables[raw_cf_name] - field_names = _get_non_pk_field_names(tables[raw_cf_name]) - model_fields = set() - # # TODO: does this work with db_name?? - for name, col in model._columns.items(): - if col.primary_key or col.partition_key: - continue # we can't mess with the PK - model_fields.add(name) - if col.db_field_name in field_names: - col_meta = table_meta.columns[col.db_field_name] + _validate_pk(model, table_meta) + + table_columns = table_meta.columns + model_fields = set() + + for model_name, col in model._columns.items(): + db_name = col.db_field_name + model_fields.add(db_name) + if db_name in table_columns: + col_meta = table_columns[db_name] if col_meta.cql_type != col.db_type: msg = 'Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' \ - ' Model should be updated.'.format(cf_name, col.db_field_name, col_meta.cql_type, col.db_type) + ' Model should be updated.'.format(cf_name, db_name, col_meta.cql_type, col.db_type) warnings.warn(msg) log.warning(msg) - continue # skip columns already defined - # add missing column using the column def + continue + + if col.primary_key or col.primary_key: + raise CQLEngineException("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}".format(model_name, db_name, cf_name)) + query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def()) execute(query) - db_fields_not_in_model = model_fields.symmetric_difference(field_names) + db_fields_not_in_model = model_fields.symmetric_difference(table_columns) if db_fields_not_in_model: - log.info("Table %s has fields not referenced by model: %s", cf_name, db_fields_not_in_model) + log.info("Table {0} has fields not referenced by model: {1}".format(cf_name, db_fields_not_in_model)) _update_options(model) @@ -221,6 +227,22 @@ def sync_table(model): execute(qs) +def _validate_pk(model, table_meta): + model_partition = [c.db_field_name for c in model._partition_keys.values()] + meta_partition = [c.name for c in table_meta.partition_key] + model_clustering = [c.db_field_name for c in model._clustering_keys.values()] + meta_clustering = [c.name for c in table_meta.clustering_key] + + if model_partition != meta_partition or model_clustering != meta_clustering: + def _pk_string(partition, clustering): + return "PRIMARY KEY (({0}){1})".format(', '.join(partition), ', ' + ', '.join(clustering) if clustering else '') + raise CQLEngineException("Model {0} PRIMARY KEY composition does not match existing table {1}. " + "Model: {2}; Table: {3}. " + "Update model or drop the table.".format(model, model.column_family_name(), + _pk_string(model_partition, model_clustering), + _pk_string(meta_partition, meta_clustering))) + + def sync_type(ks_name, type_model): """ Inspects the type_model and creates / updates the corresponding type. @@ -339,13 +361,6 @@ def _get_create_table(model): return ' '.join(query_strings) -def _get_non_pk_field_names(table_meta): - # returns all fields that aren't part of the PK - pk_names = set(col.name for col in table_meta.primary_key) - field_names = [name for name in table_meta.columns.keys() if name not in pk_names] - return field_names - - def _get_table_metadata(model): # returns the table as provided by the native driver for a given model cluster = get_cluster() diff --git a/tests/integration/cqlengine/management/test_management.py b/tests/integration/cqlengine/management/test_management.py index 109ab22f..e4b35e21 100644 --- a/tests/integration/cqlengine/management/test_management.py +++ b/tests/integration/cqlengine/management/test_management.py @@ -21,7 +21,7 @@ import logging from cassandra.cqlengine.connection import get_session, get_cluster from cassandra.cqlengine import CQLEngineException from cassandra.cqlengine import management -from cassandra.cqlengine.management import _get_non_pk_field_names, _get_table_metadata, sync_table, drop_table, sync_type +from cassandra.cqlengine.management import _get_table_metadata, sync_table, drop_table, sync_type from cassandra.cqlengine.models import Model from cassandra.cqlengine import columns @@ -128,7 +128,7 @@ class FourthModel(Model): first_key = columns.UUID(primary_key=True) second_key = columns.UUID() third_key = columns.Text() - # removed fourth key, but it should stay in the DB + # renamed model field, but map to existing column renamed = columns.Map(columns.Text, columns.Text, db_field='blah') @@ -138,23 +138,31 @@ class AddColumnTest(BaseCassEngTestCase): def test_add_column(self): sync_table(FirstModel) - fields = _get_non_pk_field_names(_get_table_metadata(FirstModel)) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(set(meta_columns), set(FirstModel._columns)) - # this should contain the second key - self.assertEqual(len(fields), 2) - # get schema sync_table(SecondModel) - - fields = _get_non_pk_field_names(_get_table_metadata(FirstModel)) - self.assertEqual(len(fields), 3) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(set(meta_columns), set(SecondModel._columns)) sync_table(ThirdModel) - fields = _get_non_pk_field_names(_get_table_metadata(FirstModel)) - self.assertEqual(len(fields), 4) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(len(meta_columns), 5) + self.assertEqual(len(ThirdModel._columns), 4) + self.assertIn('fourth_key', meta_columns) + self.assertNotIn('fourth_key', ThirdModel._columns) + self.assertIn('blah', ThirdModel._columns) + self.assertIn('blah', meta_columns) sync_table(FourthModel) - fields = _get_non_pk_field_names(_get_table_metadata(FirstModel)) - self.assertEqual(len(fields), 4) + meta_columns = _get_table_metadata(FirstModel).columns + self.assertEqual(len(meta_columns), 5) + self.assertEqual(len(ThirdModel._columns), 4) + self.assertIn('fourth_key', meta_columns) + self.assertNotIn('fourth_key', FourthModel._columns) + self.assertIn('renamed', FourthModel._columns) + self.assertNotIn('renamed', meta_columns) + self.assertIn('blah', meta_columns) class ModelWithTableProperties(Model):