cqle: Protect ks/cf identifiers that use keywords
PYTHON-244
This commit is contained in:
@@ -81,7 +81,7 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
|
||||
query = """
|
||||
CREATE KEYSPACE {}
|
||||
WITH REPLICATION = {}
|
||||
""".format(name, json.dumps(replication_map).replace('"', "'"))
|
||||
""".format(metadata.protect_name(name), json.dumps(replication_map).replace('"', "'"))
|
||||
|
||||
if strategy_class != 'SimpleStrategy':
|
||||
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
||||
@@ -163,7 +163,7 @@ def drop_keyspace(name):
|
||||
|
||||
cluster = get_cluster()
|
||||
if name in cluster.metadata.keyspaces:
|
||||
execute("DROP KEYSPACE {}".format(name))
|
||||
execute("DROP KEYSPACE {}".format(metadata.protect_name(name)))
|
||||
|
||||
|
||||
def sync_table(model):
|
||||
@@ -191,9 +191,8 @@ def sync_table(model):
|
||||
if model.__abstract__:
|
||||
raise CQLEngineException("cannot create table from abstract model")
|
||||
|
||||
# construct query string
|
||||
cf_name = model.column_family_name()
|
||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||
raw_cf_name = model._raw_column_family_name()
|
||||
|
||||
ks_name = model._get_keyspace()
|
||||
|
||||
@@ -433,7 +432,7 @@ def get_compaction_options(model):
|
||||
def get_fields(model):
|
||||
# returns all fields that aren't part of the PK
|
||||
ks_name = model._get_keyspace()
|
||||
col_family = model.column_family_name(include_keyspace=False)
|
||||
col_family = model._raw_column_family_name()
|
||||
field_types = ['regular', 'static']
|
||||
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
|
||||
tmp = execute(query, [ks_name, col_family])
|
||||
@@ -452,7 +451,7 @@ def get_table_settings(model):
|
||||
# returns the table as provided by the native driver for a given model
|
||||
cluster = get_cluster()
|
||||
ks = model._get_keyspace()
|
||||
table = model.column_family_name(include_keyspace=False)
|
||||
table = model._raw_column_family_name()
|
||||
table = cluster.metadata.keyspaces[ks].tables[table]
|
||||
return table
|
||||
|
||||
@@ -520,11 +519,11 @@ def drop_table(model):
|
||||
meta = get_cluster().metadata
|
||||
|
||||
ks_name = model._get_keyspace()
|
||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
||||
raw_cf_name = model._raw_column_family_name()
|
||||
|
||||
try:
|
||||
meta.keyspaces[ks_name].tables[raw_cf_name]
|
||||
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
|
||||
execute('drop table {};'.format(model.column_family_name()))
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
|
@@ -23,6 +23,7 @@ from cassandra.cqlengine import connection
|
||||
from cassandra.cqlengine import query
|
||||
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
||||
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
||||
from cassandra.metadata import protect_name
|
||||
from cassandra.util import OrderedDict
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -353,6 +354,8 @@ class BaseModel(object):
|
||||
|
||||
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
|
||||
|
||||
_table_name = None # used internally to cache a derived table name
|
||||
|
||||
def __init__(self, **values):
|
||||
self._values = {}
|
||||
self._ttl = self.__default_ttl__
|
||||
@@ -504,27 +507,33 @@ class BaseModel(object):
|
||||
Returns the column family name if it's been defined
|
||||
otherwise, it creates it from the module and class name
|
||||
"""
|
||||
cf_name = ''
|
||||
if cls.__table_name__:
|
||||
cf_name = cls.__table_name__.lower()
|
||||
else:
|
||||
# get polymorphic base table names if model is polymorphic
|
||||
if cls._is_polymorphic and not cls._is_polymorphic_base:
|
||||
return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)
|
||||
cf_name = protect_name(cls._raw_column_family_name())
|
||||
if include_keyspace:
|
||||
return '{}.{}'.format(protect_name(cls._get_keyspace()), cf_name)
|
||||
|
||||
camelcase = re.compile(r'([a-z])([A-Z])')
|
||||
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
|
||||
return cf_name
|
||||
|
||||
cf_name += ccase(cls.__name__)
|
||||
# trim to less than 48 characters or cassandra will complain
|
||||
cf_name = cf_name[-48:]
|
||||
cf_name = cf_name.lower()
|
||||
cf_name = re.sub(r'^_+', '', cf_name)
|
||||
|
||||
if not include_keyspace:
|
||||
return cf_name
|
||||
@classmethod
|
||||
def _raw_column_family_name(cls):
|
||||
if not cls._table_name:
|
||||
if cls.__table_name__:
|
||||
cls._table_name = cls.__table_name__.lower()
|
||||
else:
|
||||
if cls._is_polymorphic and not cls._is_polymorphic_base:
|
||||
cls._table_name = cls._polymorphic_base._raw_column_family_name()
|
||||
else:
|
||||
camelcase = re.compile(r'([a-z])([A-Z])')
|
||||
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
|
||||
|
||||
return '{}.{}'.format(cls._get_keyspace(), cf_name)
|
||||
cf_name = ccase(cls.__name__)
|
||||
# trim to less than 48 characters or cassandra will complain
|
||||
cf_name = cf_name[-48:]
|
||||
cf_name = cf_name.lower()
|
||||
cf_name = re.sub(r'^_+', '', cf_name)
|
||||
cls._table_name = cf_name
|
||||
|
||||
return cls._table_name
|
||||
|
||||
def validate(self):
|
||||
"""
|
||||
|
@@ -14,8 +14,9 @@
|
||||
|
||||
from unittest import TestCase
|
||||
|
||||
from cassandra.cqlengine.models import Model, ModelDefinitionException
|
||||
from cassandra.cqlengine import columns
|
||||
from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace
|
||||
from cassandra.cqlengine.models import Model, ModelDefinitionException
|
||||
|
||||
|
||||
class TestModel(TestCase):
|
||||
@@ -49,6 +50,40 @@ class TestModel(TestCase):
|
||||
self.assertEqual(m0, m0)
|
||||
self.assertNotEqual(m0, m1)
|
||||
|
||||
def test_keywords_as_names(self):
|
||||
create_keyspace_simple('keyspace', 1)
|
||||
|
||||
class table(Model):
|
||||
__keyspace__ = 'keyspace'
|
||||
select = columns.Integer(primary_key=True)
|
||||
table = columns.Text()
|
||||
|
||||
# create should work
|
||||
drop_table(table)
|
||||
sync_table(table)
|
||||
|
||||
created = table.create(select=0, table='table')
|
||||
selected = table.objects(select=0)[0]
|
||||
self.assertEqual(created.select, selected.select)
|
||||
self.assertEqual(created.table, selected.table)
|
||||
|
||||
# alter should work
|
||||
class table(Model):
|
||||
__keyspace__ = 'keyspace'
|
||||
select = columns.Integer(primary_key=True)
|
||||
table = columns.Text()
|
||||
where = columns.Text()
|
||||
|
||||
sync_table(table)
|
||||
|
||||
created = table.create(select=1, table='table')
|
||||
selected = table.objects(select=1)[0]
|
||||
self.assertEqual(created.select, selected.select)
|
||||
self.assertEqual(created.table, selected.table)
|
||||
self.assertEqual(created.where, selected.where)
|
||||
|
||||
drop_keyspace('keyspace')
|
||||
|
||||
|
||||
class BuiltInAttributeConflictTest(TestCase):
|
||||
"""tests Model definitions that conflict with built-in attributes/methods"""
|
||||
|
Reference in New Issue
Block a user