cqle: Protect ks/cf identifiers that use keywords
PYTHON-244
This commit is contained in:
@@ -81,7 +81,7 @@ def create_keyspace(name, strategy_class, replication_factor, durable_writes=Tru
|
|||||||
query = """
|
query = """
|
||||||
CREATE KEYSPACE {}
|
CREATE KEYSPACE {}
|
||||||
WITH REPLICATION = {}
|
WITH REPLICATION = {}
|
||||||
""".format(name, json.dumps(replication_map).replace('"', "'"))
|
""".format(metadata.protect_name(name), json.dumps(replication_map).replace('"', "'"))
|
||||||
|
|
||||||
if strategy_class != 'SimpleStrategy':
|
if strategy_class != 'SimpleStrategy':
|
||||||
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
|
||||||
@@ -163,7 +163,7 @@ def drop_keyspace(name):
|
|||||||
|
|
||||||
cluster = get_cluster()
|
cluster = get_cluster()
|
||||||
if name in cluster.metadata.keyspaces:
|
if name in cluster.metadata.keyspaces:
|
||||||
execute("DROP KEYSPACE {}".format(name))
|
execute("DROP KEYSPACE {}".format(metadata.protect_name(name)))
|
||||||
|
|
||||||
|
|
||||||
def sync_table(model):
|
def sync_table(model):
|
||||||
@@ -191,9 +191,8 @@ def sync_table(model):
|
|||||||
if model.__abstract__:
|
if model.__abstract__:
|
||||||
raise CQLEngineException("cannot create table from abstract model")
|
raise CQLEngineException("cannot create table from abstract model")
|
||||||
|
|
||||||
# construct query string
|
|
||||||
cf_name = model.column_family_name()
|
cf_name = model.column_family_name()
|
||||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
raw_cf_name = model._raw_column_family_name()
|
||||||
|
|
||||||
ks_name = model._get_keyspace()
|
ks_name = model._get_keyspace()
|
||||||
|
|
||||||
@@ -433,7 +432,7 @@ def get_compaction_options(model):
|
|||||||
def get_fields(model):
|
def get_fields(model):
|
||||||
# returns all fields that aren't part of the PK
|
# returns all fields that aren't part of the PK
|
||||||
ks_name = model._get_keyspace()
|
ks_name = model._get_keyspace()
|
||||||
col_family = model.column_family_name(include_keyspace=False)
|
col_family = model._raw_column_family_name()
|
||||||
field_types = ['regular', 'static']
|
field_types = ['regular', 'static']
|
||||||
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
|
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
|
||||||
tmp = execute(query, [ks_name, col_family])
|
tmp = execute(query, [ks_name, col_family])
|
||||||
@@ -452,7 +451,7 @@ def get_table_settings(model):
|
|||||||
# returns the table as provided by the native driver for a given model
|
# returns the table as provided by the native driver for a given model
|
||||||
cluster = get_cluster()
|
cluster = get_cluster()
|
||||||
ks = model._get_keyspace()
|
ks = model._get_keyspace()
|
||||||
table = model.column_family_name(include_keyspace=False)
|
table = model._raw_column_family_name()
|
||||||
table = cluster.metadata.keyspaces[ks].tables[table]
|
table = cluster.metadata.keyspaces[ks].tables[table]
|
||||||
return table
|
return table
|
||||||
|
|
||||||
@@ -520,11 +519,11 @@ def drop_table(model):
|
|||||||
meta = get_cluster().metadata
|
meta = get_cluster().metadata
|
||||||
|
|
||||||
ks_name = model._get_keyspace()
|
ks_name = model._get_keyspace()
|
||||||
raw_cf_name = model.column_family_name(include_keyspace=False)
|
raw_cf_name = model._raw_column_family_name()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
meta.keyspaces[ks_name].tables[raw_cf_name]
|
meta.keyspaces[ks_name].tables[raw_cf_name]
|
||||||
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
|
execute('drop table {};'.format(model.column_family_name()))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from cassandra.cqlengine import connection
|
|||||||
from cassandra.cqlengine import query
|
from cassandra.cqlengine import query
|
||||||
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
from cassandra.cqlengine.query import DoesNotExist as _DoesNotExist
|
||||||
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
from cassandra.cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
|
||||||
|
from cassandra.metadata import protect_name
|
||||||
from cassandra.util import OrderedDict
|
from cassandra.util import OrderedDict
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@@ -353,6 +354,8 @@ class BaseModel(object):
|
|||||||
|
|
||||||
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
|
_if_not_exists = False # optional if_not_exists flag to check existence before insertion
|
||||||
|
|
||||||
|
_table_name = None # used internally to cache a derived table name
|
||||||
|
|
||||||
def __init__(self, **values):
|
def __init__(self, **values):
|
||||||
self._values = {}
|
self._values = {}
|
||||||
self._ttl = self.__default_ttl__
|
self._ttl = self.__default_ttl__
|
||||||
@@ -504,27 +507,33 @@ class BaseModel(object):
|
|||||||
Returns the column family name if it's been defined
|
Returns the column family name if it's been defined
|
||||||
otherwise, it creates it from the module and class name
|
otherwise, it creates it from the module and class name
|
||||||
"""
|
"""
|
||||||
cf_name = ''
|
cf_name = protect_name(cls._raw_column_family_name())
|
||||||
if cls.__table_name__:
|
if include_keyspace:
|
||||||
cf_name = cls.__table_name__.lower()
|
return '{}.{}'.format(protect_name(cls._get_keyspace()), cf_name)
|
||||||
else:
|
|
||||||
# get polymorphic base table names if model is polymorphic
|
|
||||||
if cls._is_polymorphic and not cls._is_polymorphic_base:
|
|
||||||
return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)
|
|
||||||
|
|
||||||
|
return cf_name
|
||||||
|
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _raw_column_family_name(cls):
|
||||||
|
if not cls._table_name:
|
||||||
|
if cls.__table_name__:
|
||||||
|
cls._table_name = cls.__table_name__.lower()
|
||||||
|
else:
|
||||||
|
if cls._is_polymorphic and not cls._is_polymorphic_base:
|
||||||
|
cls._table_name = cls._polymorphic_base._raw_column_family_name()
|
||||||
|
else:
|
||||||
camelcase = re.compile(r'([a-z])([A-Z])')
|
camelcase = re.compile(r'([a-z])([A-Z])')
|
||||||
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
|
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
|
||||||
|
|
||||||
cf_name += ccase(cls.__name__)
|
cf_name = ccase(cls.__name__)
|
||||||
# trim to less than 48 characters or cassandra will complain
|
# trim to less than 48 characters or cassandra will complain
|
||||||
cf_name = cf_name[-48:]
|
cf_name = cf_name[-48:]
|
||||||
cf_name = cf_name.lower()
|
cf_name = cf_name.lower()
|
||||||
cf_name = re.sub(r'^_+', '', cf_name)
|
cf_name = re.sub(r'^_+', '', cf_name)
|
||||||
|
cls._table_name = cf_name
|
||||||
|
|
||||||
if not include_keyspace:
|
return cls._table_name
|
||||||
return cf_name
|
|
||||||
|
|
||||||
return '{}.{}'.format(cls._get_keyspace(), cf_name)
|
|
||||||
|
|
||||||
def validate(self):
|
def validate(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -14,8 +14,9 @@
|
|||||||
|
|
||||||
from unittest import TestCase
|
from unittest import TestCase
|
||||||
|
|
||||||
from cassandra.cqlengine.models import Model, ModelDefinitionException
|
|
||||||
from cassandra.cqlengine import columns
|
from cassandra.cqlengine import columns
|
||||||
|
from cassandra.cqlengine.management import sync_table, drop_table, create_keyspace_simple, drop_keyspace
|
||||||
|
from cassandra.cqlengine.models import Model, ModelDefinitionException
|
||||||
|
|
||||||
|
|
||||||
class TestModel(TestCase):
|
class TestModel(TestCase):
|
||||||
@@ -49,6 +50,40 @@ class TestModel(TestCase):
|
|||||||
self.assertEqual(m0, m0)
|
self.assertEqual(m0, m0)
|
||||||
self.assertNotEqual(m0, m1)
|
self.assertNotEqual(m0, m1)
|
||||||
|
|
||||||
|
def test_keywords_as_names(self):
|
||||||
|
create_keyspace_simple('keyspace', 1)
|
||||||
|
|
||||||
|
class table(Model):
|
||||||
|
__keyspace__ = 'keyspace'
|
||||||
|
select = columns.Integer(primary_key=True)
|
||||||
|
table = columns.Text()
|
||||||
|
|
||||||
|
# create should work
|
||||||
|
drop_table(table)
|
||||||
|
sync_table(table)
|
||||||
|
|
||||||
|
created = table.create(select=0, table='table')
|
||||||
|
selected = table.objects(select=0)[0]
|
||||||
|
self.assertEqual(created.select, selected.select)
|
||||||
|
self.assertEqual(created.table, selected.table)
|
||||||
|
|
||||||
|
# alter should work
|
||||||
|
class table(Model):
|
||||||
|
__keyspace__ = 'keyspace'
|
||||||
|
select = columns.Integer(primary_key=True)
|
||||||
|
table = columns.Text()
|
||||||
|
where = columns.Text()
|
||||||
|
|
||||||
|
sync_table(table)
|
||||||
|
|
||||||
|
created = table.create(select=1, table='table')
|
||||||
|
selected = table.objects(select=1)[0]
|
||||||
|
self.assertEqual(created.select, selected.select)
|
||||||
|
self.assertEqual(created.table, selected.table)
|
||||||
|
self.assertEqual(created.where, selected.where)
|
||||||
|
|
||||||
|
drop_keyspace('keyspace')
|
||||||
|
|
||||||
|
|
||||||
class BuiltInAttributeConflictTest(TestCase):
|
class BuiltInAttributeConflictTest(TestCase):
|
||||||
"""tests Model definitions that conflict with built-in attributes/methods"""
|
"""tests Model definitions that conflict with built-in attributes/methods"""
|
||||||
|
|||||||
Reference in New Issue
Block a user