cqle: Protect ks/cf identifiers that use keywords

PYTHON-244
This commit is contained in:
Adam Holmberg
2015-05-12 15:04:43 -05:00
parent bc3ea9b1bc
commit 90e2ace36f
3 changed files with 69 additions and 26 deletions

View File

@@ -81,7 +81,7 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
query = """
CREATE KEYSPACE {}
WITH REPLICATION = {}
""".format(name, json.dumps(replication_map).replace('"', "'"))
""".format(metadata.protect_name(name), json.dumps(replication_map).replace('"', "'"))
if strategy_class != 'SimpleStrategy':
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
@@ -163,7 +163,7 @@ def drop_keyspace(name):
cluster = get_cluster()
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {}".format(name))
execute("DROP KEYSPACE {}".format(metadata.protect_name(name)))
def sync_table(model):
@@ -191,9 +191,8 @@ def sync_table(model):
if model.__abstract__:
raise CQLEngineException("cannot create table from abstract model")
# construct query string
cf_name = model.column_family_name()
raw_cf_name = model.column_family_name(include_keyspace=False)
raw_cf_name = model._raw_column_family_name()
ks_name = model._get_keyspace()
@@ -433,7 +432,7 @@ def get_compaction_options(model):
def get_fields(model):
# returns all fields that aren't part of the PK
ks_name = model._get_keyspace()
col_family = model.column_family_name(include_keyspace=False)
col_family = model._raw_column_family_name()
field_types = ['regular', 'static']
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
tmp = execute(query, [ks_name, col_family])
@@ -452,7 +451,7 @@ def get_table_settings(model):
# returns the table as provided by the native driver for a given model
cluster = get_cluster()
ks = model._get_keyspace()
table = model.column_family_name(include_keyspace=False)
table = model._raw_column_family_name()
table = cluster.metadata.keyspaces[ks].tables[table]
return table
@@ -520,11 +519,11 @@ def drop_table(model):
meta = get_cluster().metadata
ks_name = model._get_keyspace()
raw_cf_name = model.column_family_name(include_keyspace=False)
raw_cf_name = model._raw_column_family_name()
try:
meta.keyspaces[ks_name].tables[raw_cf_name]
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
execute('drop table {};'.format(model.column_family_name()))
except KeyError:
pass

View File

@@ -23,6 +23,7 @@ from cassandra.cqlengine import connection
from cassandra.cqlengine import query
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
from cassandra.metadata import protect_name
from cassandra.util import OrderedDict
log = logging.getLogger(__name__)
@@ -353,6 +354,8 @@ class BaseModel(object):
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
_table_name = None # used internally to cache a derived table name
def __init__(self, **values):
self._values = {}
self._ttl = self.__default_ttl__
@@ -504,27 +507,33 @@ class BaseModel(object):
Returns the column family name if it's been defined
otherwise, it creates it from the module and class name
"""
cf_name = ''
if cls.__table_name__:
cf_name = cls.__table_name__.lower()
else:
# get polymorphic base table names if model is polymorphic
if cls._is_polymorphic and not cls._is_polymorphic_base:
return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)
cf_name = protect_name(cls._raw_column_family_name())
if include_keyspace:
return '{}.{}'.format(protect_name(cls._get_keyspace()), cf_name)
camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
return cf_name
cf_name += ccase(cls.__name__)
# trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
cf_name = cf_name.lower()
cf_name = re.sub(r'^_+', '', cf_name)
if not include_keyspace:
return cf_name
@classmethod
def _raw_column_family_name(cls):
if not cls._table_name:
if cls.__table_name__:
cls._table_name = cls.__table_name__.lower()
else:
if cls._is_polymorphic and not cls._is_polymorphic_base:
cls._table_name = cls._polymorphic_base._raw_column_family_name()
else:
camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
return '{}.{}'.format(cls._get_keyspace(), cf_name)
cf_name = ccase(cls.__name__)
# trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
cf_name = cf_name.lower()
cf_name = re.sub(r'^_+', '', cf_name)
cls._table_name = cf_name
return cls._table_name
def validate(self):
"""

View File

@@ -14,8 +14,9 @@
from unittest import TestCase
from cassandra.cqlengine.models import Model, ModelDefinitionException
from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace
from cassandra.cqlengine.models import Model, ModelDefinitionException
class TestModel(TestCase):
@@ -49,6 +50,40 @@ class TestModel(TestCase):
self.assertEqual(m0, m0)
self.assertNotEqual(m0, m1)
def test_keywords_as_names(self):
create_keyspace_simple('keyspace', 1)
class table(Model):
__keyspace__ = 'keyspace'
select = columns.Integer(primary_key=True)
table = columns.Text()
# create should work
drop_table(table)
sync_table(table)
created = table.create(select=0, table='table')
selected = table.objects(select=0)[0]
self.assertEqual(created.select, selected.select)
self.assertEqual(created.table, selected.table)
# alter should work
class table(Model):
__keyspace__ = 'keyspace'
select = columns.Integer(primary_key=True)
table = columns.Text()
where = columns.Text()
sync_table(table)
created = table.create(select=1, table='table')
selected = table.objects(select=1)[0]
self.assertEqual(created.select, selected.select)
self.assertEqual(created.table, selected.table)
self.assertEqual(created.where, selected.where)
drop_keyspace('keyspace')
class BuiltInAttributeConflictTest(TestCase):
"""tests Model definitions that conflict with built-in attributes/methods"""