Add connection selection support to management commands

This commit is contained in:
Alan Boudreault 2016-08-23 15:42:42 -04:00
parent f4a82ee073
commit 0b1d012524
4 changed files with 146 additions and 94 deletions

View File

@ -29,6 +29,7 @@ log = logging.getLogger(__name__)
NOT_SET = _NOT_SET # required for passing timeout to Session.execute
# connections registry
DEFAULT_CONNECTION = '_default_'
_connections = {}
# Because type models may be registered before a connection is present,
@ -124,18 +125,19 @@ def register_connection(name, hosts, consistency=None, lazy_connect=False,
_connections[name] = conn
if default:
_connections['_default_'] = conn
_connections[DEFAULT_CONNECTION] = conn
conn.setup()
return conn
def get_connection(name=None):
if not name:
name = '_default_'
name = DEFAULT_CONNECTION
if name not in _connections:
raise ValueError("Connection name '{0}' doesn't exist in the registry.".format(name))
raise CQLEngineException("Connection name '{0}' doesn't exist in the registry.".format(name))
conn = _connections[name]
conn.handle_lazy_connect()
@ -211,14 +213,13 @@ def setup(
from cassandra.cqlengine import models
models.DEFAULT_KEYSPACE = default_keyspace
conn = register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=kwargs, default=True)
conn.setup()
register_connection('default', hosts=hosts, consistency=consistency, lazy_connect=lazy_connect,
retry_connect=retry_connect, cluster_options=kwargs, default=True)
def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
def execute(query, params=None, consistency_level=None, timeout=NOT_SET, connection=None):
conn = get_connection()
conn = get_connection(connection)
if not conn.session:
raise CQLEngineException("It is required to setup() cqlengine before executing queries")
@ -238,22 +239,22 @@ def execute(query, params=None, consistency_level=None, timeout=NOT_SET):
return result
def get_session():
conn = get_connection()
def get_session(connection=None):
conn = get_connection(connection)
return conn.session
def get_cluster():
conn = get_connection()
def get_cluster(connection=None):
conn = get_connection(connection)
if not conn.cluster:
raise CQLEngineException("%s.cluster is not configured. Call one of the setup or default functions first." % __name__)
return conn.cluster
def register_udt(keyspace, type_name, klass):
def register_udt(keyspace, type_name, klass, connection=None):
udt_by_keyspace[keyspace][type_name] = klass
cluster = get_cluster()
cluster = get_cluster(connection)
if cluster:
try:
cluster.register_user_type(keyspace, type_name, klass)

View File

@ -18,11 +18,12 @@ import logging
import os
import six
import warnings
from itertools import product
from cassandra import metadata
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns, query
from cassandra.cqlengine.connection import execute, get_cluster
from cassandra.cqlengine.connection import execute, get_cluster, DEFAULT_CONNECTION
from cassandra.cqlengine.models import Model
from cassandra.cqlengine.named import NamedTable
from cassandra.cqlengine.usertype import UserType
@ -37,7 +38,34 @@ log = logging.getLogger(__name__)
schema_columnfamilies = NamedTable('system', 'schema_columnfamilies')
def create_keyspace_simple(name, replication_factor, durable_writes=True):
def get_context(keyspaces, connections):
    """Return the execution context as ``(connection, keyspace)`` pairs.

    :param keyspaces: list or tuple of keyspace names; a falsy value
        (``None``, empty list/tuple) contributes a single ``None`` keyspace
    :param connections: list or tuple of connection names; a falsy value
        contributes a single ``None`` connection
    :return: iterator over every ``(connection, keyspace)`` combination,
        with connections varying slowest
    :raises ValueError: if a truthy argument is neither a list nor a tuple
    """
    # keyspaces is validated first, then connections. A bare string is truthy
    # and rejected here, so product() never iterates it character by character.
    for label, value in (('keyspaces', keyspaces), ('connections', connections)):
        if value and not isinstance(value, (list, tuple)):
            raise ValueError('{0} must be a list or a tuple.'.format(label))

    # ``or [None]`` keeps the product non-empty when an argument is omitted.
    return product(connections or [None], keyspaces or [None])
def log_msg(msg, connection=None, keyspace=None):
    """Prefix a log message with its connection (and keyspace) context.

    :param msg: message, or message template, to prefix
    :param connection: connection name; a falsy value is reported as
        ``DEFAULT_CONNECTION``
    :param keyspace: keyspace name; left out of the prefix when falsy
    :return: ``msg`` prefixed with ``[Connection: <c>] `` or
        ``[Connection: <c>, Keyspace: <k>] ``

    ``msg`` is not interpolated here and the names are inserted verbatim;
    callers in this module pass a template and format the returned string
    afterwards (both ``%`` and ``str.format`` styles are used).
    """
    context = ['Connection: {0}'.format(connection or DEFAULT_CONNECTION)]
    if keyspace:
        context.append('Keyspace: {0}'.format(keyspace))
    return '[{0}] {1}'.format(', '.join(context), msg)
def create_keyspace_simple(name, replication_factor, durable_writes=True, connections=None):
"""
Creates a keyspace with SimpleStrategy for replica placement
@ -51,12 +79,13 @@ def create_keyspace_simple(name, replication_factor, durable_writes=True):
:param str name: name of keyspace to create
:param int replication_factor: keyspace replication factor, used with :attr:`~.SimpleStrategy`
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
_create_keyspace(name, durable_writes, 'SimpleStrategy',
{'replication_factor': replication_factor})
{'replication_factor': replication_factor}, connections=connections)
def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True):
def create_keyspace_network_topology(name, dc_replication_map, durable_writes=True, connections=None):
"""
Creates a keyspace with NetworkTopologyStrategy for replica placement
@ -70,25 +99,37 @@ def create_keyspace_network_topology(name, dc_replication_map, durable_writes=Tr
:param str name: name of keyspace to create
:param dict dc_replication_map: map of dc_names: replication_factor
:param bool durable_writes: Write log is bypassed if set to False
:param list connections: List of connection names
"""
_create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map)
_create_keyspace(name, durable_writes, 'NetworkTopologyStrategy', dc_replication_map, connections=connections)
def _create_keyspace(name, durable_writes, strategy_class, strategy_options):
def _create_keyspace(name, durable_writes, strategy_class, strategy_options, connections=None):
if not _allow_schema_modification():
return
cluster = get_cluster()
if connections:
if not isinstance(connections, (list, tuple)):
raise ValueError('Connections must be a list or a tuple.')
if name not in cluster.metadata.keyspaces:
log.info("Creating keyspace %s ", name)
ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
execute(ks_meta.as_cql_query())
def __create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=None):
cluster = get_cluster(connection)
if name not in cluster.metadata.keyspaces:
log.info(log_msg("Creating keyspace %s", connection=connection), name)
ks_meta = metadata.KeyspaceMetadata(name, durable_writes, strategy_class, strategy_options)
execute(ks_meta.as_cql_query(), connection=connection)
else:
log.info(log_msg("Not creating keyspace %s because it already exists", connection=connection), name)
if connections:
for connection in connections:
__create_keyspace(name, durable_writes, strategy_class, strategy_options, connection=connection)
else:
log.info("Not creating keyspace %s because it already exists", name)
__create_keyspace(name, durable_writes, strategy_class, strategy_options)
def drop_keyspace(name):
def drop_keyspace(name, connections=None):
"""
Drops a keyspace, if it exists.
@ -98,14 +139,25 @@ def drop_keyspace(name):
Take care to execute schema modifications in a single context (i.e. not concurrently with other clients).**
:param str name: name of keyspace to drop
:param list connections: List of connection names
"""
if not _allow_schema_modification():
return
cluster = get_cluster()
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)))
if connections:
if not isinstance(connections, (list, tuple)):
raise ValueError('Connections must be a list or a tuple.')
def _drop_keyspace(name, connection=None):
cluster = get_cluster(connection)
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {0}".format(metadata.protect_name(name)), connection=connection)
if connections:
for connection in connections:
_drop_keyspace(name, connection)
else:
_drop_keyspace(name)
def _get_index_name_by_column(table, column_name):
"""
@ -119,7 +171,7 @@ def _get_index_name_by_column(table, column_name):
return index_metadata.name
def sync_table(model, keyspaces=None):
def sync_table(model, keyspaces=None, connections=None):
"""
Inspects the model and creates / updates the corresponding table and columns.
@ -138,19 +190,13 @@ def sync_table(model, keyspaces=None):
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
if keyspaces:
if not isinstance(keyspaces, (list, tuple)):
raise ValueError('keyspaces must be a list or a tuple.')
for keyspace in keyspaces:
with query.ContextQuery(model, keyspace=keyspace) as m:
_sync_table(m)
else:
_sync_table(model)
context = get_context(keyspaces, connections)
for connection, keyspace in context:
with query.ContextQuery(model, keyspace=keyspace) as m:
_sync_table(m, connection=connection)
def _sync_table(model):
def _sync_table(model, connection=None):
if not _allow_schema_modification():
return
@ -165,12 +211,13 @@ def _sync_table(model):
ks_name = model._get_keyspace()
cluster = get_cluster()
cluster = get_cluster(connection)
try:
keyspace = cluster.metadata.keyspaces[ks_name]
except KeyError:
raise CQLEngineException("Keyspace '{0}' for model {1} does not exist.".format(ks_name, model))
msg = log_msg("Keyspace '{0}' for model {1} does not exist.", connection=connection)
raise CQLEngineException(msg.format(ks_name, model))
tables = keyspace.tables
@ -179,21 +226,21 @@ def _sync_table(model):
udts = []
columns.resolve_udts(col, udts)
for udt in [u for u in udts if u not in syncd_types]:
_sync_type(ks_name, udt, syncd_types)
_sync_type(ks_name, udt, syncd_types, connection=connection)
if raw_cf_name not in tables:
log.debug("sync_table creating new table %s", cf_name)
log.debug(log_msg("sync_table creating new table %s", keyspace=ks_name, connection=connection), cf_name)
qs = _get_create_table(model)
try:
execute(qs)
execute(qs, connection=connection)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
if "Cannot add already existing column family" not in unicode(ex):
raise
else:
log.debug("sync_table checking existing table %s", cf_name)
log.debug(log_msg("sync_table checking existing table %s", keyspace=ks_name, connection=connection), cf_name)
table_meta = tables[raw_cf_name]
_validate_pk(model, table_meta)
@ -207,24 +254,27 @@ def _sync_table(model):
if db_name in table_columns:
col_meta = table_columns[db_name]
if col_meta.cql_type != col.db_type:
msg = 'Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).' \
' Model should be updated.'.format(cf_name, db_name, col_meta.cql_type, col.db_type)
msg = log_msg('Existing table {0} has column "{1}" with a type ({2}) differing from the model type ({3}).'
' Model should be updated.', keyspace=ks_name, connection=connection)
msg = msg.format(cf_name, db_name, col_meta.cql_type, col.db_type)
warnings.warn(msg)
log.warning(msg)
continue
if col.primary_key or col.primary_key:
raise CQLEngineException("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}".format(model_name, db_name, cf_name))
msg = log_msg("Cannot add primary key '{0}' (with db_field '{1}') to existing table {2}", keyspace=ks_name, connection=connection)
raise CQLEngineException(msg.format(model_name, db_name, cf_name))
query = "ALTER TABLE {0} add {1}".format(cf_name, col.get_column_def())
execute(query)
execute(query, connection=connection)
db_fields_not_in_model = model_fields.symmetric_difference(table_columns)
if db_fields_not_in_model:
log.info("Table {0} has fields not referenced by model: {1}".format(cf_name, db_fields_not_in_model))
msg = log_msg("Table {0} has fields not referenced by model: {1}", keyspace=ks_name, connection=connection)
log.info(msg.format(cf_name, db_fields_not_in_model))
_update_options(model)
_update_options(model, connection=connection)
table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name]
@ -240,7 +290,7 @@ def _sync_table(model):
qs += ['ON {0}'.format(cf_name)]
qs += ['("{0}")'.format(column.db_field_name)]
qs = ' '.join(qs)
execute(qs)
execute(qs, connection=connection)
def _validate_pk(model, table_meta):
@ -259,7 +309,7 @@ def _validate_pk(model, table_meta):
_pk_string(meta_partition, meta_clustering)))
def sync_type(ks_name, type_model):
def sync_type(ks_name, type_model, connection=None):
"""
Inspects the type_model and creates / updates the corresponding type.
@ -277,33 +327,33 @@ def sync_type(ks_name, type_model):
if not issubclass(type_model, UserType):
raise CQLEngineException("Types must be derived from base UserType.")
_sync_type(ks_name, type_model)
_sync_type(ks_name, type_model, connection=connection)
def _sync_type(ks_name, type_model, omit_subtypes=None):
def _sync_type(ks_name, type_model, omit_subtypes=None, connection=None):
syncd_sub_types = omit_subtypes or set()
for field in type_model._fields.values():
udts = []
columns.resolve_udts(field, udts)
for udt in [u for u in udts if u not in syncd_sub_types]:
_sync_type(ks_name, udt, syncd_sub_types)
_sync_type(ks_name, udt, syncd_sub_types, connection=connection)
syncd_sub_types.add(udt)
type_name = type_model.type_name()
type_name_qualified = "%s.%s" % (ks_name, type_name)
cluster = get_cluster()
cluster = get_cluster(connection)
keyspace = cluster.metadata.keyspaces[ks_name]
defined_types = keyspace.user_types
if type_name not in defined_types:
log.debug("sync_type creating new type %s", type_name_qualified)
log.debug(log_msg("sync_type creating new type %s", keyspace=ks_name, connection=connection), type_name_qualified)
cql = get_create_type(type_model, ks_name)
execute(cql)
execute(cql, connection=connection)
cluster.refresh_user_type_metadata(ks_name, type_name)
type_model.register_for_keyspace(ks_name)
type_model.register_for_keyspace(ks_name, connection=connection)
else:
type_meta = defined_types[type_name]
defined_fields = type_meta.field_names
@ -311,24 +361,26 @@ def _sync_type(ks_name, type_model, omit_subtypes=None):
for field in type_model._fields.values():
model_fields.add(field.db_field_name)
if field.db_field_name not in defined_fields:
execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()))
execute("ALTER TYPE {0} ADD {1}".format(type_name_qualified, field.get_column_def()), connection=connection)
else:
field_type = type_meta.field_types[defined_fields.index(field.db_field_name)]
if field_type != field.db_type:
msg = 'Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).' \
' UserType should be updated.'.format(type_name_qualified, field.db_field_name, field_type, field.db_type)
msg = log_msg('Existing user type {0} has field "{1}" with a type ({2}) differing from the model user type ({3}).'
' UserType should be updated.', keyspace=ks_name, connection=connection)
msg = msg.format(type_name_qualified, field.db_field_name, field_type, field.db_type)
warnings.warn(msg)
log.warning(msg)
type_model.register_for_keyspace(ks_name)
type_model.register_for_keyspace(ks_name, connection=connection)
if len(defined_fields) == len(model_fields):
log.info("Type %s did not require synchronization", type_name_qualified)
log.info(log_msg("Type %s did not require synchronization", keyspace=ks_name, connection=connection), type_name_qualified)
return
db_fields_not_in_model = model_fields.symmetric_difference(defined_fields)
if db_fields_not_in_model:
log.info("Type %s has fields not referenced by model: %s", type_name_qualified, db_fields_not_in_model)
msg = log_msg("Type %s has fields not referenced by model: %s", keyspace=ks_name, connection=connection)
log.info(msg, type_name_qualified, db_fields_not_in_model)
def get_create_type(type_model, keyspace):
@ -377,9 +429,9 @@ def _get_create_table(model):
return ' '.join(query_strings)
def _get_table_metadata(model):
def _get_table_metadata(model, connection=None):
# returns the table as provided by the native driver for a given model
cluster = get_cluster()
cluster = get_cluster(connection)
ks = model._get_keyspace()
table = model._raw_column_family_name()
table = cluster.metadata.keyspaces[ks].tables[table]
@ -401,19 +453,22 @@ def _options_map_from_strings(option_strings):
return options
def _update_options(model):
def _update_options(model, connection=None):
"""Updates the table options for the given model if necessary.
:param model: The model to update.
:param connection: Name of the connection to use
:return: `True`, if the options were modified in Cassandra,
`False` otherwise.
:rtype: bool
"""
log.debug("Checking %s for option differences", model)
ks_name = model._get_keyspace()
msg = log_msg("Checking %s for option differences", keyspace=ks_name, connection=connection)
log.debug(msg, model)
model_options = model.__options__ or {}
table_meta = _get_table_metadata(model)
table_meta = _get_table_metadata(model, connection=connection)
# go to CQL string first to normalize meta from different versions
existing_option_strings = set(table_meta._make_option_strings(table_meta.options))
existing_options = _options_map_from_strings(existing_option_strings)
@ -425,7 +480,8 @@ def _update_options(model):
try:
existing_value = existing_options[name]
except KeyError:
raise KeyError("Invalid table option: '%s'; known options: %s" % (name, existing_options.keys()))
msg = log_msg("Invalid table option: '%s'; known options: %s", keyspace=ks_name, connection=connection)
raise KeyError(msg % (name, existing_options.keys()))
if isinstance(existing_value, six.string_types):
if value != existing_value:
update_options[name] = value
@ -441,13 +497,13 @@ def _update_options(model):
if update_options:
options = ' AND '.join(metadata.TableMetadataV3._make_option_strings(update_options))
query = "ALTER TABLE {0} WITH {1}".format(model.column_family_name(), options)
execute(query)
execute(query, connection=connection)
return True
return False
def drop_table(model, keyspaces=None):
def drop_table(model, keyspaces=None, connections=None):
"""
Drops the table indicated by the model, if it exists.
@ -459,29 +515,24 @@ def drop_table(model, keyspaces=None):
*There are plans to guard schema-modifying functions with an environment-driven conditional.*
"""
if keyspaces:
if not isinstance(keyspaces, (list, tuple)):
raise ValueError('keyspaces must be a list or a tuple.')
context = get_context(keyspaces, connections)
for connection, keyspace in context:
with query.ContextQuery(model, keyspace=keyspace) as m:
_drop_table(m, connection=connection)
for keyspace in keyspaces:
with query.ContextQuery(model, keyspace=keyspace) as m:
_drop_table(m)
else:
_drop_table(model)
def _drop_table(model):
def _drop_table(model, connection=None):
if not _allow_schema_modification():
return
# don't try to delete non-existent tables
meta = get_cluster().metadata
meta = get_cluster(connection).metadata
ks_name = model._get_keyspace()
raw_cf_name = model._raw_column_family_name()
try:
meta.keyspaces[ks_name].tables[raw_cf_name]
execute('DROP TABLE {0};'.format(model.column_family_name()))
execute('DROP TABLE {0};'.format(model.column_family_name()), connection=connection)
except KeyError:
pass

View File

@ -290,7 +290,7 @@ class ContextQuery(object):
raise CQLEngineException("Models must be derived from base Model.")
ks = keyspace if keyspace else model.__keyspace__
new_type = type(model.__name__, (model,), {'__keyspace__': ks})
new_type = type(model.__name__, (model,), {'__keyspace__': ks, '__abstract__': model.__abstract__})
self.model = new_type
@ -972,7 +972,7 @@ class AbstractQuerySet(object):
clone = copy.deepcopy(self)
if keyspace:
new_type = type(self.model.__name__, (self.model,), {'__keyspace__': keyspace})
new_type = type(self.model.__name__, (self.model,), {'__keyspace__': keyspace, '__abstract__': self.model.__abstract__})
clone.model = new_type
return clone

View File

@ -4,7 +4,7 @@ import six
from cassandra.util import OrderedDict
from cassandra.cqlengine import CQLEngineException
from cassandra.cqlengine import columns
from cassandra.cqlengine import connection
from cassandra.cqlengine import connection as conn
from cassandra.cqlengine import models
@ -112,8 +112,8 @@ class BaseUserType(object):
return [(k, self[k]) for k in self]
@classmethod
def register_for_keyspace(cls, keyspace):
connection.register_udt(keyspace, cls.type_name(), cls)
def register_for_keyspace(cls, keyspace, connection=None):
conn.register_udt(keyspace, cls.type_name(), cls, connection=connection)
@classmethod
def type_name(cls):