cqle: Protect ks/cf identifiers that use keywords

PYTHON-244
This commit is contained in:
Adam Holmberg
2015-05-12 15:04:43 -05:00
parent bc3ea9b1bc
commit 90e2ace36f
3 changed files with 69 additions and 26 deletions

View File

@@ -81,7 +81,7 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
query = """ query = """
CREATE KEYSPACE {} CREATE KEYSPACE {}
WITH REPLICATION = {} WITH REPLICATION = {}
""".format(name, json.dumps(replication_map).replace('"', "'")) """.format(metadata.protect_name(name), json.dumps(replication_map).replace('"', "'"))
if strategy_class != 'SimpleStrategy': if strategy_class != 'SimpleStrategy':
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false') query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
@@ -163,7 +163,7 @@ def drop_keyspace(name):
cluster = get_cluster() cluster = get_cluster()
if name in cluster.metadata.keyspaces: if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {}".format(name)) execute("DROP KEYSPACE {}".format(metadata.protect_name(name)))
def sync_table(model): def sync_table(model):
@@ -191,9 +191,8 @@ def sync_table(model):
if model.__abstract__: if model.__abstract__:
raise CQLEngineException("cannot create table from abstract model") raise CQLEngineException("cannot create table from abstract model")
# construct query string
cf_name = model.column_family_name() cf_name = model.column_family_name()
raw_cf_name = model.column_family_name(include_keyspace=False) raw_cf_name = model._raw_column_family_name()
ks_name = model._get_keyspace() ks_name = model._get_keyspace()
@@ -433,7 +432,7 @@ def get_compaction_options(model):
def get_fields(model): def get_fields(model):
# returns all fields that aren't part of the PK # returns all fields that aren't part of the PK
ks_name = model._get_keyspace() ks_name = model._get_keyspace()
col_family = model.column_family_name(include_keyspace=False) col_family = model._raw_column_family_name()
field_types = ['regular', 'static'] field_types = ['regular', 'static']
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s" query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
tmp = execute(query, [ks_name, col_family]) tmp = execute(query, [ks_name, col_family])
@@ -452,7 +451,7 @@ def get_table_settings(model):
# returns the table as provided by the native driver for a given model # returns the table as provided by the native driver for a given model
cluster = get_cluster() cluster = get_cluster()
ks = model._get_keyspace() ks = model._get_keyspace()
table = model.column_family_name(include_keyspace=False) table = model._raw_column_family_name()
table = cluster.metadata.keyspaces[ks].tables[table] table = cluster.metadata.keyspaces[ks].tables[table]
return table return table
@@ -520,11 +519,11 @@ def drop_table(model):
meta = get_cluster().metadata meta = get_cluster().metadata
ks_name = model._get_keyspace() ks_name = model._get_keyspace()
raw_cf_name = model.column_family_name(include_keyspace=False) raw_cf_name = model._raw_column_family_name()
try: try:
meta.keyspaces[ks_name].tables[raw_cf_name] meta.keyspaces[ks_name].tables[raw_cf_name]
execute('drop table {};'.format(model.column_family_name(include_keyspace=True))) execute('drop table {};'.format(model.column_family_name()))
except KeyError: except KeyError:
pass pass

View File

@@ -23,6 +23,7 @@ from cassandra.cqlengine import connection
from cassandra.cqlengine import query from cassandra.cqlengine import query
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
from cassandra.metadata import protect_name
from cassandra.util import OrderedDict from cassandra.util import OrderedDict
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -353,6 +354,8 @@ class BaseModel(object):
_if_not_exists = False # optional if_not_exists flag to check existence before insertion _if_not_exists = False # optional if_not_exists flag to check existence before insertion
_table_name = None # used internally to cache a derived table name
def __init__(self, **values): def __init__(self, **values):
self._values = {} self._values = {}
self._ttl = self.__default_ttl__ self._ttl = self.__default_ttl__
@@ -504,27 +507,33 @@ class BaseModel(object):
Returns the column family name if it's been defined Returns the column family name if it's been defined
otherwise, it creates it from the module and class name otherwise, it creates it from the module and class name
""" """
cf_name = '' cf_name = protect_name(cls._raw_column_family_name())
if cls.__table_name__: if include_keyspace:
cf_name = cls.__table_name__.lower() return '{}.{}'.format(protect_name(cls._get_keyspace()), cf_name)
else:
# get polymorphic base table names if model is polymorphic
if cls._is_polymorphic and not cls._is_polymorphic_base:
return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)
return cf_name
@classmethod
def _raw_column_family_name(cls):
if not cls._table_name:
if cls.__table_name__:
cls._table_name = cls.__table_name__.lower()
else:
if cls._is_polymorphic and not cls._is_polymorphic_base:
cls._table_name = cls._polymorphic_base._raw_column_family_name()
else:
camelcase = re.compile(r'([a-z])([A-Z])') camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s) ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
cf_name += ccase(cls.__name__) cf_name = ccase(cls.__name__)
# trim to less than 48 characters or cassandra will complain # trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:] cf_name = cf_name[-48:]
cf_name = cf_name.lower() cf_name = cf_name.lower()
cf_name = re.sub(r'^_+', '', cf_name) cf_name = re.sub(r'^_+', '', cf_name)
cls._table_name = cf_name
if not include_keyspace: return cls._table_name
return cf_name
return '{}.{}'.format(cls._get_keyspace(), cf_name)
def validate(self): def validate(self):
""" """

View File

@@ -14,8 +14,9 @@
from unittest import TestCase from unittest import TestCase
from cassandra.cqlengine.models import Model, ModelDefinitionException
from cassandra.cqlengine import columns from cassandra.cqlengine import columns
from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace
from cassandra.cqlengine.models import Model, ModelDefinitionException
class TestModel(TestCase): class TestModel(TestCase):
@@ -49,6 +50,40 @@ class TestModel(TestCase):
self.assertEqual(m0, m0) self.assertEqual(m0, m0)
self.assertNotEqual(m0, m1) self.assertNotEqual(m0, m1)
def test_keywords_as_names(self):
create_keyspace_simple('keyspace', 1)
class table(Model):
__keyspace__ = 'keyspace'
select = columns.Integer(primary_key=True)
table = columns.Text()
# create should work
drop_table(table)
sync_table(table)
created = table.create(select=0, table='table')
selected = table.objects(select=0)[0]
self.assertEqual(created.select, selected.select)
self.assertEqual(created.table, selected.table)
# alter should work
class table(Model):
__keyspace__ = 'keyspace'
select = columns.Integer(primary_key=True)
table = columns.Text()
where = columns.Text()
sync_table(table)
created = table.create(select=1, table='table')
selected = table.objects(select=1)[0]
self.assertEqual(created.select, selected.select)
self.assertEqual(created.table, selected.table)
self.assertEqual(created.where, selected.where)
drop_keyspace('keyspace')
class BuiltInAttributeConflictTest(TestCase): class BuiltInAttributeConflictTest(TestCase):
"""tests Model definitions that conflict with built-in attributes/methods""" """tests Model definitions that conflict with built-in attributes/methods"""