Merge remote-tracking branch 'cqlengine/master'

Conflicts:
	cqlengine/columns.py
	cqlengine/connection.py
	cqlengine/models.py
	cqlengine/query.py
This commit is contained in:
Greg Doermann
2014-07-30 18:07:21 -06:00
75 changed files with 8010 additions and 1397 deletions

5
.gitignore vendored
View File

@@ -38,3 +38,8 @@ html/
#Mr Developer
.mr.developer.cfg
.noseids
/commitlog
/data
docs/_build

View File

@@ -8,7 +8,9 @@ before_install:
- sudo service cassandra start
install:
- "pip install -r requirements.txt --use-mirrors"
- "pip install pytest --use-mirrors"
#- "pip install pytest --use-mirrors"
script:
- while [ ! -f /var/run/cassandra.pid ] ; do sleep 1 ; done # wait until cassandra is ready
- "py.test cqlengine/tests/"
- "nosetests --no-skip"
# - "py.test cqlengine/tests/"

13
AUTHORS
View File

@@ -1,9 +1,16 @@
PRIMARY AUTHORS
Blake Eggleston <bdeggleston@gmail.com>
Jon Haddad <jon@jonhaddad.com>
CONTRIBUTORS
CONTRIBUTORS (let us know if we missed you)
Eric Scrivner - test environment, connection pooling
Jon Haddad - helped hash out some of the architecture
Kevin Deldycke
Roey Berman
Danny Cosson
Michael Hall
Netanel Cohen-Tzemach
Mariusz Kryński
Greg Doermann
@pandu-rao

12
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,12 @@
### Contributing code to cqlengine
Before submitting a pull request, please make sure that it follows these guidelines:
* Limit yourself to one feature or bug fix per pull request.
* Include unittests that thoroughly test the feature/bug fix
* Write clear, descriptive commit messages.
* Many granular commits are preferred over large monolithic commits
* If you're adding or modifying features, please update the documentation
If you're working on a big ticket item, please check in on [cqlengine-users](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users).
We'd hate to have to steer you in a different direction after you've already put in a lot of hard work.

14
Makefile Normal file
View File

@@ -0,0 +1,14 @@
clean:
find . -name *.pyc -delete
rm -rf cqlengine/__pycache__
build: clean
python setup.py build
release: clean
python setup.py sdist upload
.PHONY: build

View File

@@ -1,7 +1,9 @@
cqlengine
===============
cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface similar to the Django orm and mongoengine
cqlengine is a Cassandra CQL 3 Object Mapper for Python
**Users of versions < 0.16, the default keyspace 'cqlengine' has been removed. Please read this before upgrading:** [Breaking Changes](https://cqlengine.readthedocs.org/en/latest/topics/models.html#keyspace-change)
[Documentation](https://cqlengine.readthedocs.org/en/latest/)
@@ -9,8 +11,6 @@ cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface simila
[Users Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users)
[Dev Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-dev)
## Installation
```
pip install cqlengine
@@ -20,23 +20,27 @@ pip install cqlengine
```python
#first, define a model
import uuid
from cqlengine import columns
from cqlengine.models import Model
class ExampleModel(Model):
read_repair_chance = 0.05 # optional - defaults to 0.1
example_id = columns.UUID(primary_key=True)
example_id = columns.UUID(primary_key=True, default=uuid.uuid4)
example_type = columns.Integer(index=True)
created_at = columns.DateTime()
description = columns.Text(required=False)
#next, setup the connection to your cassandra server(s)...
>>> from cqlengine import connection
>>> connection.setup(['127.0.0.1:9160'])
>>> connection.setup(['127.0.0.1'])
# or if you're still on cassandra 1.2
>>> connection.setup(['127.0.0.1'], protocol_version=1)
#...and create your CQL table
>>> from cqlengine.management import create_table
>>> create_table(ExampleModel)
>>> from cqlengine.management import sync_table
>>> sync_table(ExampleModel)
#now we can create some rows:
>>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now())
@@ -47,8 +51,6 @@ class ExampleModel(Model):
>>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now())
>>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now())
>>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now())
# Note: the UUID and DateTime columns will create uuid4 and datetime.now
# values automatically if we don't specify them when creating new rows
#and now we can run some queries against our table
>>> ExampleModel.objects.count()
@@ -57,7 +59,7 @@ class ExampleModel(Model):
>>> q.count()
4
>>> for instance in q:
>>> print q.description
>>> print instance.description
example5
example6
example7
@@ -71,6 +73,10 @@ example8
>>> q2.count()
1
>>> for instance in q2:
>>> print q.description
>>> print instance.description
example5
```
## Contributing
If you'd like to contribute to cqlengine, please read the [contributor guidelines](https://github.com/bdeggleston/cqlengine/blob/master/CONTRIBUTING.md)

7
RELEASE.txt Normal file
View File

@@ -0,0 +1,7 @@
Check changelog
Ensure docs are updated
Tests pass
Update VERSION
Push tag to github
Push release to pypi

172
changelog
View File

@@ -1,6 +1,176 @@
CHANGELOG
0.2.1 (in progress)
0.16.0
225: No handling of PagedResult from execute
222: figure out how to make travis not totally fail when a test is skipped
220: delayed connect. use setup(delayed_connect=True)
218: throw exception on create_table and delete_table
212: Unchanged primary key trigger error on update
206: FAQ - why we dont' do #189
191: Add support for simple table properties.
172: raise exception when None is passed in as query param
170: trying to perform queries before connection is established should raise useful exception
162: Not Possible to Make Non-Equality Filtering Queries
161: Filtering on DateTime column
154: Blob(bytes) column type issue
128: remove default read_repair_chance & ensure changes are saved when using sync_table
106: specify caching on model
99: specify caching options table management
94: type checking on sync_table (currently allows anything then fails miserably)
73: remove default 'cqlengine' keyspace table management
71: add named table and query expression usage to docs
0.15.0
* native driver integration
0.14.0
* fix for setting map to empty (Lifto)
* report when creating models with attributes that conflict with cqlengine (maedhroz)
* use stable version of sure package (maedhroz)
* performance improvements
0.13.0
* adding support for post batch callbacks (thanks Daniel Dotsenko github.com/dvdotsenko)
* fixing sync table for tables with multiple keys (thanks Daniel Dotsenko github.com/dvdotsenko)
* fixing bug in Bytes column (thanks Netanel Cohen-Tzemach github.com/natict)
* fixing bug with timestamps and DST (thanks Netanel Cohen-Tzemach github.com/natict)
0.12.0
* Normalize and unquote boolean values. (Thanks Kevin Deldycke github.com/kdeldycke)
* Fix race condition in connection manager (Thanks Roey Berman github.com/bergundy)
* allow access to instance columns as if it is a dict (Thanks Kevin Deldycke github.com/kdeldycke)
* Added support for blind partial updates to queryset (Thanks Danny Cosson github.com/dcosson)
* Model instance equality check bugfix (Thanks to Michael Hall, github.com/mahall)
* Fixed bug syncing tables with camel cased names (Thanks to Netanel Cohen-Tzemach, github.com/natict)
* Fixed bug dealing with eggs (Thanks Kevin Deldycke github.com/kdeldycke)
0.11.0
* support for USING TIMESTAMP <microseconds from epoch> via a .timestamp(timedelta(seconds=30)) syntax
- allows for long, timedelta, and datetime
* fixed use of USING TIMESTAMP in batches
* clear TTL and timestamp off models after persisting to DB
* allows UUID without dashes - (Thanks to Michael Hall, github.com/mahall)
* fixes regarding syncing schema settings (thanks Kai Lautaportti github.com/dokai)
0.10.0
* support for execute_on_exception within batches
0.9.2
* fixing create keyspace with network topology strategy
* fixing regression with query expressions
0.9
* adding update method
* adding support for ttls
* adding support for per-query consistency
* adding BigInt column (thanks @Lifto)
* adding support for timezone aware time uuid functions (thanks @dokai)
* only saving collection fields on insert if they've been modified
* adding model method that returns a list of modified columns
0.8.5
* adding support for timeouts
0.8.4
* changing value manager previous value copying to deepcopy
0.8.3
* better logging for operational errors
0.8.2
* fix for connection failover
0.8.1
* fix for models not exactly matching schema
0.8.0
* support for table polymorphism
* var int type
0.7.1
* refactoring query class to make defining custom model instantiation logic easier
0.7.0
* added counter columns
* added support for compaction settings at the model level
* deprecated delete_table in favor of drop_table
* deprecated create_table in favor of sync_table
* added support for custom QuerySets
0.6.0
* added table sync
0.5.2
* adding hex conversion to Bytes column
0.5.1
* improving connection pooling
* fixing bug with clustering order columns not being quoted
0.5
* eagerly loading results into the query result cache, the cql driver does this anyway,
and pulling them from the cursor was causing some problems with gevented queries,
this will cause some breaking changes for users calling execute directly
0.4.10
* changing query parameter placeholders from uuid1 to uuid4
0.4.7
* adding support for passing None into query batch methods to clear any batch objects
0.4.6
* fixing the way models are compared
0.4.5
* fixed bug where container columns would not call their child to_python method, this only really affected columns with special to_python logic
0.4.4
* adding query logging back
* fixed bug updating an empty list column
0.4.3
* fixed bug with Text column validation
0.4.2
* added support for instantiating container columns with column instances
0.4.1
* fixed bug in TimeUUID from datetime method
0.4.0
* removed default values from all column types
* explicit primary key is required (automatic id removed)
* added validation on keyname types on .create()
* changed table_name to __table_name__, read_repair_chance to __read_repair_chance__, keyspace to __keyspace__
* modified table name auto generator to ignore module name
* changed internal implementation of model value get/set
* added TimeUUID.from_datetime(), used for generating UUID1's for a specific
time
0.3.3
* added abstract base class models
0.3.2
* comprehesive rewrite of connection management (thanks @rustyrazorblade)
0.3
* added support for Token function (thanks @mrk-its)
* added support for compound partition key (thanks @mrk-its)s
* added support for defining clustering key ordering (thanks @mrk-its)
* added values_list to Query class, bypassing object creation if desired (thanks @mrk-its)
* fixed bug with Model.objects caching values (thanks @mrk-its)
* fixed Cassandra 1.2.5 compatibility bug
* updated model exception inheritance
0.2.1
* adding support for datetimes with tzinfo (thanks @gdoermann)
* fixing bug in saving map updates (thanks @pandu-rao)

1
cqlengine/VERSION Normal file
View File

@@ -0,0 +1 @@
0.16.1

View File

@@ -1,7 +1,32 @@
import pkg_resources
from cassandra import ConsistencyLevel
from cqlengine.columns import *
from cqlengine.functions import *
from cqlengine.models import Model, CounterModel
from cqlengine.query import BatchQuery
__version__ = '0.2'
__cqlengine_version_path__ = pkg_resources.resource_filename('cqlengine',
'VERSION')
__version__ = open(__cqlengine_version_path__, 'r').readline().strip()
# compaction
SizeTieredCompactionStrategy = "SizeTieredCompactionStrategy"
LeveledCompactionStrategy = "LeveledCompactionStrategy"
# Caching constants.
CACHING_ALL = "ALL"
CACHING_KEYS_ONLY = "KEYS_ONLY"
CACHING_ROWS_ONLY = "ROWS_ONLY"
CACHING_NONE = "NONE"
ANY = ConsistencyLevel.ANY
ONE = ConsistencyLevel.ONE
TWO = ConsistencyLevel.TWO
THREE = ConsistencyLevel.THREE
QUORUM = ConsistencyLevel.QUORUM
LOCAL_QUORUM = ConsistencyLevel.LOCAL_QUORUM
EACH_QUORUM = ConsistencyLevel.EACH_QUORUM
ALL = ConsistencyLevel.ALL

View File

@@ -1,19 +1,20 @@
#column field types
from copy import copy
from copy import deepcopy, copy
from datetime import datetime
from datetime import date
import re
from uuid import uuid1, uuid4
from cql.query import cql_quote
from cassandra.cqltypes import DateType
from cqlengine.exceptions import ValidationError
from cassandra.encoder import cql_quote
class BaseValueManager(object):
def __init__(self, instance, column, value):
self.instance = instance
self.column = column
self.previous_value = copy(value)
self.previous_value = deepcopy(value)
self.value = value
@property
@@ -52,6 +53,26 @@ class BaseValueManager(object):
else:
return property(_get, _set)
class ValueQuoter(object):
"""
contains a single value, which will quote itself for CQL insertion statements
"""
def __init__(self, value):
self.value = value
def __str__(self):
raise NotImplementedError
def __repr__(self):
return self.__str__()
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.value == other.value
return False
class Column(object):
#the cassandra type this column maps to
@@ -60,24 +81,38 @@ class Column(object):
instance_counter = 0
def __init__(self, primary_key=False, index=False, db_field=None, default=None, required=True):
def __init__(self,
primary_key=False,
partition_key=False,
index=False,
db_field=None,
default=None,
required=False,
clustering_order=None,
polymorphic_key=False):
"""
:param primary_key: bool flag, indicates this column is a primary key. The first primary key defined
on a model is the partition key, all others are cluster keys
on a model is the partition key (unless partition keys are set), all others are cluster keys
:param partition_key: indicates that this column should be the partition key, defining
more than one partition key column creates a compound partition key
:param index: bool flag, indicates an index should be created for this column
:param db_field: the fieldname this field will map to in the database
:param default: the default value, can be a value or a callable (no args)
:param required: boolean, is the field required?
:param required: boolean, is the field required? Model validation will raise and
exception if required is set to True and there is a None value assigned
:param clustering_order: only applicable on clustering keys (primary keys that are not partition keys)
determines the order that the clustering keys are sorted on disk
:param polymorphic_key: boolean, if set to True, this column will be used for saving and loading instances
of polymorphic tables
"""
self.primary_key = primary_key
self.partition_key = partition_key
self.primary_key = partition_key or primary_key
self.index = index
self.db_field = db_field
self.default = default
self.required = required
#only the model meta class should touch this
self._partition_key = False
self.clustering_order = clustering_order
self.polymorphic_key = polymorphic_key
#the column name in the model definition
self.column_name = None
@@ -137,7 +172,7 @@ class Column(object):
"""
Returns a column definition for CQL table definition
"""
return '"{}" {}'.format(self.db_field_name, self.db_type)
return '{} {}'.format(self.cql, self.db_type)
def set_column_name(self, name):
"""
@@ -156,17 +191,45 @@ class Column(object):
""" Returns the name of the cql index """
return 'index_{}'.format(self.db_field_name)
@property
def cql(self):
return self.get_cql()
def get_cql(self):
return '"{}"'.format(self.db_field_name)
def _val_is_null(self, val):
""" determines if the given value equates to a null value for the given column type """
return val is None
class Bytes(Column):
db_type = 'blob'
class Quoter(ValueQuoter):
def __str__(self):
return '0x' + self.value.encode('hex')
def to_database(self, value):
val = super(Bytes, self).to_database(value)
if val is None: return
return self.Quoter(val)
def to_python(self, value):
#return value[2:].decode('hex')
return value
class Ascii(Column):
db_type = 'ascii'
class Text(Column):
db_type = 'text'
def __init__(self, *args, **kwargs):
self.min_length = kwargs.pop('min_length', 1 if kwargs.get('required', True) else None)
self.min_length = kwargs.pop('min_length', 1 if kwargs.get('required', False) else None)
self.max_length = kwargs.pop('max_length', None)
super(Text, self).__init__(*args, **kwargs)
@@ -183,6 +246,7 @@ class Text(Column):
raise ValidationError('{} is shorter than {} characters'.format(self.column_name, self.min_length))
return value
class Integer(Column):
db_type = 'int'
@@ -200,44 +264,102 @@ class Integer(Column):
def to_database(self, value):
return self.validate(value)
class DateTime(Column):
db_type = 'timestamp'
def __init__(self, **kwargs):
super(DateTime, self).__init__(**kwargs)
class BigInt(Integer):
db_type = 'bigint'
class VarInt(Column):
db_type = 'varint'
def validate(self, value):
val = super(VarInt, self).validate(value)
if val is None:
return
try:
return long(val)
except (TypeError, ValueError):
raise ValidationError(
"{} can't be converted to integral value".format(value))
def to_python(self, value):
return self.validate(value)
def to_database(self, value):
return self.validate(value)
class CounterValueManager(BaseValueManager):
def __init__(self, instance, column, value):
super(CounterValueManager, self).__init__(instance, column, value)
self.value = self.value or 0
self.previous_value = self.previous_value or 0
class Counter(Integer):
db_type = 'counter'
value_manager = CounterValueManager
def __init__(self,
index=False,
db_field=None,
required=False):
super(Counter, self).__init__(
primary_key=False,
partition_key=False,
index=index,
db_field=db_field,
default=0,
required=required,
)
class DateTime(Column):
db_type = 'timestamp'
def to_python(self, value):
if value is None: return
if isinstance(value, datetime):
return value
return datetime.utcfromtimestamp(value)
elif isinstance(value, date):
return datetime(*(value.timetuple()[:6]))
try:
return datetime.utcfromtimestamp(value)
except TypeError:
return datetime.utcfromtimestamp(DateType.deserialize(value))
def to_database(self, value):
value = super(DateTime, self).to_database(value)
if value is None: return
if not isinstance(value, datetime):
raise ValidationError("'{}' is not a datetime object".format(value))
if isinstance(value, date):
value = datetime(value.year, value.month, value.day)
else:
raise ValidationError("'{}' is not a datetime object".format(value))
epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo)
offset = 0
if epoch.tzinfo:
offset_delta = epoch.tzinfo.utcoffset(epoch)
offset = offset_delta.days*24*3600 + offset_delta.seconds
return long(((value - epoch).total_seconds() - offset) * 1000)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
return long(((value - epoch).total_seconds() - offset) * 1000)
class Date(Column):
db_type = 'timestamp'
def __init__(self, **kwargs):
super(Date, self).__init__(**kwargs)
def to_python(self, value):
if value is None: return
if isinstance(value, datetime):
return value.date()
elif isinstance(value, date):
return value
return datetime.utcfromtimestamp(value).date()
try:
return datetime.utcfromtimestamp(value).date()
except TypeError:
return datetime.utcfromtimestamp(DateType.deserialize(value)).date()
def to_database(self, value):
value = super(Date, self).to_database(value)
if value is None: return
if isinstance(value, datetime):
value = value.date()
if not isinstance(value, date):
@@ -252,19 +374,25 @@ class UUID(Column):
"""
db_type = 'uuid'
re_uuid = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}')
def __init__(self, default=lambda:uuid4(), **kwargs):
super(UUID, self).__init__(default=default, **kwargs)
re_uuid = re.compile(r'[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}')
def validate(self, value):
val = super(UUID, self).validate(value)
if val is None: return
from uuid import UUID as _UUID
if isinstance(val, _UUID): return val
if not self.re_uuid.match(val):
raise ValidationError("{} is not a valid uuid".format(value))
return _UUID(val)
if isinstance(val, basestring) and self.re_uuid.match(val):
return _UUID(val)
raise ValidationError("{} is not a valid uuid".format(value))
def to_python(self, value):
return self.validate(value)
def to_database(self, value):
return self.validate(value)
from uuid import UUID as pyUUID, getnode
class TimeUUID(UUID):
"""
@@ -273,18 +401,52 @@ class TimeUUID(UUID):
db_type = 'timeuuid'
def __init__(self, **kwargs):
kwargs.setdefault('default', lambda: uuid1())
super(TimeUUID, self).__init__(**kwargs)
@classmethod
def from_datetime(self, dt):
"""
generates a UUID for a given datetime
:param dt: datetime
:type dt: datetime
:return:
"""
global _last_timestamp
epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
timestamp = (dt - epoch).total_seconds() - offset
node = None
clock_seq = None
nanoseconds = int(timestamp * 1e9)
timestamp = int(nanoseconds // 100) + 0x01b21dd213814000L
if clock_seq is None:
import random
clock_seq = random.randrange(1 << 14L) # instead of stable storage
time_low = timestamp & 0xffffffffL
time_mid = (timestamp >> 32L) & 0xffffL
time_hi_version = (timestamp >> 48L) & 0x0fffL
clock_seq_low = clock_seq & 0xffL
clock_seq_hi_variant = (clock_seq >> 8L) & 0x3fL
if node is None:
node = getnode()
return pyUUID(fields=(time_low, time_mid, time_hi_version,
clock_seq_hi_variant, clock_seq_low, node), version=1)
class Boolean(Column):
db_type = 'boolean'
def to_python(self, value):
def validate(self, value):
""" Always returns a Python boolean. """
value = super(Boolean, self).validate(value)
return bool(value)
def to_database(self, value):
return bool(value)
def to_python(self, value):
return self.validate(value)
class Float(Column):
db_type = 'double'
@@ -307,59 +469,76 @@ class Float(Column):
def to_database(self, value):
return self.validate(value)
class Decimal(Column):
db_type = 'decimal'
class Counter(Integer):
""" Validates like an integer, goes into the database as a counter
"""
db_type = 'counter'
class Counter(Column):
#TODO: counter field
def __init__(self, **kwargs):
super(Counter, self).__init__(**kwargs)
class ContainerQuoter(object):
"""
contains a single value, which will quote itself for CQL insertion statements
"""
def __init__(self, value):
self.value = value
def __str__(self):
raise NotImplementedError
def to_python(self, value):
return self.validate(value)
def to_database(self, value):
return self.validate(value)
class BaseContainerColumn(Column):
"""
Base Container type
Base Container type for collection-like columns.
https://cassandra.apache.org/doc/cql3/CQL.html#collections
"""
def __init__(self, value_type, **kwargs):
"""
:param value_type: a column class indicating the types of the value
"""
if not issubclass(value_type, Column):
inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance
if not inheritance_comparator(value_type, Column):
raise ValidationError('value_type must be a column class')
if issubclass(value_type, BaseContainerColumn):
if inheritance_comparator(value_type, BaseContainerColumn):
raise ValidationError('container types cannot be nested')
if value_type.db_type is None:
raise ValidationError('value_type cannot be an abstract column type')
self.value_type = value_type
self.value_col = self.value_type()
if isinstance(value_type, type):
self.value_type = value_type
self.value_col = self.value_type()
else:
self.value_col = value_type
self.value_type = self.value_col.__class__
super(BaseContainerColumn, self).__init__(**kwargs)
def validate(self, value):
value = super(BaseContainerColumn, self).validate(value)
# It is dangerous to let collections have more than 65535.
# See: https://issues.apache.org/jira/browse/CASSANDRA-5428
if value is not None and len(value) > 65535:
raise ValidationError("Collection can't have more than 65535 elements.")
return value
def get_column_def(self):
"""
Returns a column definition for CQL table definition
"""
db_type = self.db_type.format(self.value_type.db_type)
return '{} {}'.format(self.db_field_name, db_type)
return '{} {}'.format(self.cql, db_type)
def _val_is_null(self, val):
return not val
class BaseContainerQuoter(ValueQuoter):
def __nonzero__(self):
return bool(self.value)
def get_update_statement(self, val, prev, ctx):
"""
Used to add partial update statements
"""
raise NotImplementedError
class Set(BaseContainerColumn):
"""
@@ -369,20 +548,21 @@ class Set(BaseContainerColumn):
"""
db_type = 'set<{}>'
class Quoter(ContainerQuoter):
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '{' + ', '.join([cq(v) for v in self.value]) + '}'
def __init__(self, value_type, strict=True, **kwargs):
def __init__(self, value_type, strict=True, default=set, **kwargs):
"""
:param value_type: a column class indicating the types of the value
:param strict: sets whether non set values will be coerced to set
type on validation, or raise a validation error, defaults to True
"""
self.strict = strict
super(Set, self).__init__(value_type, **kwargs)
super(Set, self).__init__(value_type, default=default, **kwargs)
def validate(self, value):
val = super(Set, self).validate(value)
@@ -394,55 +574,23 @@ class Set(BaseContainerColumn):
else:
raise ValidationError('{} cannot be coerced to a set object'.format(val))
if None in val:
raise ValidationError("None not allowed in a set")
return {self.value_col.validate(v) for v in val}
def to_python(self, value):
if value is None: return set()
return {self.value_col.to_python(v) for v in value}
def to_database(self, value):
if value is None: return None
if isinstance(value, self.Quoter): return value
return self.Quoter({self.value_col.to_database(v) for v in value})
def get_update_statement(self, val, prev, ctx):
"""
Returns statements that will be added to an object's update statement
also updates the query context
:param val: the current column value
:param prev: the previous column value
:param ctx: the values that will be passed to the query
:rtype: list
"""
# remove from Quoter containers, if applicable
val = self.to_database(val)
prev = self.to_database(prev)
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
if val is None or val == prev:
# don't return anything if the new value is the same as
# the old one, or if the new value is none
return []
elif prev is None or not any({v in prev for v in val}):
field = uuid1().hex
ctx[field] = self.Quoter(val)
return ['"{}" = :{}'.format(self.db_field_name, field)]
else:
# partial update time
to_create = val - prev
to_delete = prev - val
statements = []
if to_create:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_create)
statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)]
if to_delete:
field_id = uuid1().hex
ctx[field_id] = self.Quoter(to_delete)
statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)]
return statements
class List(BaseContainerColumn):
"""
@@ -452,93 +600,36 @@ class List(BaseContainerColumn):
"""
db_type = 'list<{}>'
class Quoter(ContainerQuoter):
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '[' + ', '.join([cq(v) for v in self.value]) + ']'
def __nonzero__(self):
return bool(self.value)
def __init__(self, value_type, default=list, **kwargs):
return super(List, self).__init__(value_type=value_type, default=default, **kwargs)
def validate(self, value):
val = super(List, self).validate(value)
if val is None: return
if not isinstance(val, (set, list, tuple)):
raise ValidationError('{} is not a list object'.format(val))
if None in val:
raise ValidationError("None is not allowed in a list")
return [self.value_col.validate(v) for v in val]
def to_python(self, value):
if value is None: return None
return list(value)
if value is None: return []
return [self.value_col.to_python(v) for v in value]
def to_database(self, value):
if value is None: return None
if isinstance(value, self.Quoter): return value
return self.Quoter([self.value_col.to_database(v) for v in value])
def get_update_statement(self, val, prev, values):
"""
Returns statements that will be added to an object's update statement
also updates the query context
"""
# remove from Quoter containers, if applicable
val = self.to_database(val)
prev = self.to_database(prev)
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
def _insert():
field_id = uuid1().hex
values[field_id] = self.Quoter(val)
return ['"{}" = :{}'.format(self.db_field_name, field_id)]
if val is None or val == prev:
return []
elif prev is None:
return _insert()
elif len(val) < len(prev):
return _insert()
else:
# the prepend and append lists,
# if both of these are still None after looking
# at both lists, an insert statement will be returned
prepend = None
append = None
# the max start idx we want to compare
search_space = len(val) - max(0, len(prev)-1)
# the size of the sub lists we want to look at
search_size = len(prev)
for i in range(search_space):
#slice boundary
j = i + search_size
sub = val[i:j]
idx_cmp = lambda idx: prev[idx] == sub[idx]
if idx_cmp(0) and idx_cmp(-1) and prev == sub:
prepend = val[:i]
append = val[j:]
break
# create update statements
if prepend is append is None:
return _insert()
statements = []
if prepend:
field_id = uuid1().hex
# CQL seems to prepend element at a time, starting
# with the element at idx 0, we can either reverse
# it here, or have it inserted in reverse
prepend.reverse()
values[field_id] = self.Quoter(prepend)
statements += ['"{0}" = :{1} + "{0}"'.format(self.db_field_name, field_id)]
if append:
field_id = uuid1().hex
values[field_id] = self.Quoter(append)
statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)]
return statements
class Map(BaseContainerColumn):
"""
@@ -549,27 +640,41 @@ class Map(BaseContainerColumn):
db_type = 'map<{}, {}>'
class Quoter(ContainerQuoter):
class Quoter(BaseContainerQuoter):
def __str__(self):
cq = cql_quote
return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}'
def __init__(self, key_type, value_type, **kwargs):
def get(self, key):
return self.value.get(key)
def keys(self):
return self.value.keys()
def items(self):
return self.value.items()
def __init__(self, key_type, value_type, default=dict, **kwargs):
"""
:param key_type: a column class indicating the types of the key
:param value_type: a column class indicating the types of the value
"""
if not issubclass(value_type, Column):
inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance
if not inheritance_comparator(key_type, Column):
raise ValidationError('key_type must be a column class')
if issubclass(value_type, BaseContainerColumn):
if inheritance_comparator(key_type, BaseContainerColumn):
raise ValidationError('container types cannot be nested')
if key_type.db_type is None:
raise ValidationError('key_type cannot be an abstract column type')
self.key_type = key_type
self.key_col = self.key_type()
super(Map, self).__init__(value_type, **kwargs)
if isinstance(key_type, type):
self.key_type = key_type
self.key_col = self.key_type()
else:
self.key_col = key_type
self.key_type = self.key_col.__class__
super(Map, self).__init__(value_type, default=default, **kwargs)
def get_column_def(self):
"""
@@ -579,7 +684,7 @@ class Map(BaseContainerColumn):
self.key_type.db_type,
self.value_type.db_type
)
return '{} {}'.format(self.db_field_name, db_type)
return '{} {}'.format(self.cql, db_type)
def validate(self, value):
val = super(Map, self).validate(value)
@@ -589,62 +694,38 @@ class Map(BaseContainerColumn):
return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()}
def to_python(self, value):
if value is None:
return {}
if value is not None:
return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()}
return {self.key_col.to_python(k): self.value_col.to_python(v) for k,v in value.items()}
def to_database(self, value):
if value is None: return None
if isinstance(value, self.Quoter): return value
return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()})
def get_update_statement(self, val, prev, ctx):
"""
http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion
"""
# remove from Quoter containers, if applicable
val = self.to_database(val)
prev = self.to_database(prev)
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
val = val or {}
prev = prev or {}
#get the updated map
update = {k:v for k,v in val.items() if v != prev.get(k)}
statements = []
for k,v in update.items():
key_id = uuid1().hex
val_id = uuid1().hex
ctx[key_id] = k
ctx[val_id] = v
statements += ['"{}"[:{}] = :{}'.format(self.db_field_name, key_id, val_id)]
return statements
def get_delete_statement(self, val, prev, ctx):
"""
Returns statements that will be added to an object's delete statement
also updates the query context, used for removing keys from a map
"""
if val is prev is None:
return []
val = self.to_database(val)
prev = self.to_database(prev)
if isinstance(val, self.Quoter): val = val.value
if isinstance(prev, self.Quoter): prev = prev.value
old_keys = set(prev.keys()) if prev else set()
new_keys = set(val.keys()) if val else set()
del_keys = old_keys - new_keys
del_statements = []
for key in del_keys:
field_id = uuid1().hex
ctx[field_id] = key
del_statements += ['"{}"[:{}]'.format(self.db_field_name, field_id)]
return del_statements
class _PartitionKeysToken(Column):
"""
virtual column representing token of partition columns.
Used by filter(pk__token=Token(...)) filters
"""
def __init__(self, model):
self.partition_columns = model._partition_keys.values()
super(_PartitionKeysToken, self).__init__(partition_key=True)
@property
def db_field_name(self):
return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns]))
def to_database(self, value):
from cqlengine.functions import Token
assert isinstance(value, Token)
value.set_columns(self.partition_columns)
return value
def get_cql(self):
return "token({})".format(", ".join(c.cql for c in self.partition_columns))

View File

@@ -3,184 +3,113 @@
#http://cassandra.apache.org/doc/cql/CQL.html
from collections import namedtuple
import Queue
import random
from cassandra.cluster import Cluster
from cassandra.query import SimpleStatement, Statement
import cql
try:
import Queue as queue
except ImportError:
# python 3
import queue
from cqlengine.exceptions import CQLEngineException
import logging
from thrift.transport.TTransport import TTransportException
from cqlengine.exceptions import CQLEngineException, UndefinedKeyspaceException
from cassandra import ConsistencyLevel
from cqlengine.statements import BaseCQLStatement
from cassandra.query import dict_factory
LOG = logging.getLogger('cqlengine.cql')
class CQLConnectionError(CQLEngineException): pass
Host = namedtuple('Host', ['name', 'port'])
_hosts = []
_host_idx = 0
_conn= None
_username = None
_password = None
_max_connections = 10
def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None, lazy=False):
cluster = None
session = None
lazy_connect_args = None
default_consistency_level = None
def setup(
hosts,
default_keyspace,
consistency=ConsistencyLevel.ONE,
lazy_connect=False,
**kwargs):
"""
Records the hosts and connects to one of them
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
:param hosts: list of hosts, see http://datastax.github.io/python-driver/api/cassandra/cluster.html
:type hosts: list
:param default_keyspace: The default keyspace to use
:type default_keyspace: str
:param consistency: The global consistency level
:type consistency: int
:param lazy_connect: True if should not connect until first use
:type lazy_connect: bool
"""
global _hosts
global _username
global _password
global _max_connections
_username = username
_password = password
_max_connections = max_connections
global cluster, session, default_consistency_level, lazy_connect_args
if default_keyspace:
from cqlengine import models
models.DEFAULT_KEYSPACE = default_keyspace
if 'username' in kwargs or 'password' in kwargs:
raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider")
for host in hosts:
host = host.strip()
host = host.split(':')
if len(host) == 1:
_hosts.append(Host(host[0], 9160))
elif len(host) == 2:
_hosts.append(Host(*host))
else:
raise CQLConnectionError("Can't parse {}".format(''.join(host)))
if not default_keyspace:
raise UndefinedKeyspaceException()
if not _hosts:
raise CQLConnectionError("At least one host required")
from cqlengine import models
models.DEFAULT_KEYSPACE = default_keyspace
random.shuffle(_hosts)
default_consistency_level = consistency
if lazy_connect:
lazy_connect_args = (hosts, default_keyspace, consistency, kwargs)
return
if not lazy:
con = ConnectionPool.get()
ConnectionPool.put(con)
cluster = Cluster(hosts, **kwargs)
session = cluster.connect()
session.row_factory = dict_factory
def execute(query, params=None, consistency_level=None):
handle_lazy_connect()
if not session:
raise CQLEngineException("It is required to setup() cqlengine before executing queries")
class ConnectionPool(object):
"""Handles pooling of database connections."""
if consistency_level is None:
consistency_level = default_consistency_level
# Connection pool queue
_queue = None
if isinstance(query, Statement):
pass
@classmethod
def clear(cls):
"""
Force the connection pool to be cleared. Will close all internal
connections.
"""
try:
while not cls._queue.empty():
cls._queue.get().close()
except:
pass
elif isinstance(query, BaseCQLStatement):
params = query.get_context()
query = str(query)
query = SimpleStatement(query, consistency_level=consistency_level)
@classmethod
def get(cls):
"""
Returns a usable database connection. Uses the internal queue to
determine whether to return an existing connection or to create
a new one.
"""
try:
if cls._queue.empty():
return cls._create_connection()
return cls._queue.get()
except CQLConnectionError as cqle:
raise cqle
except:
if not cls._queue:
cls._queue = Queue.Queue(maxsize=_max_connections)
return cls._create_connection()
@classmethod
def put(cls, conn):
"""
Returns a connection to the queue freeing it up for other queries to
use.
:param conn: The connection to be released
:type conn: connection
"""
try:
if cls._queue.full():
conn.close()
else:
cls._queue.put(conn)
except:
if not cls._queue:
cls._queue = Queue.Queue(maxsize=_max_connections)
cls._queue.put(conn)
@classmethod
def _create_connection(cls):
"""
Creates a new connection for the connection pool.
"""
global _hosts
global _username
global _password
if not _hosts:
raise CQLConnectionError("At least one host required")
host = _hosts[_host_idx]
new_conn = cql.connect(host.name, host.port, user=_username, password=_password)
new_conn.set_cql_version('3.0.0')
return new_conn
elif isinstance(query, basestring):
query = SimpleStatement(query, consistency_level=consistency_level)
class connection_manager(object):
"""
Connection failure tolerant connection manager. Written to be used in a 'with' block for connection pooling
"""
def __init__(self):
if not _hosts:
raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup")
self.keyspace = None
self.con = ConnectionPool.get()
self.cur = None
LOG.info(query.query_string)
def close(self):
if self.cur: self.cur.close()
ConnectionPool.put(self.con)
params = params or {}
result = session.execute(query, params)
def __enter__(self):
return self
return result
def __exit__(self, type, value, traceback):
self.close()
def execute(self, query, params=None):
"""
Gets a connection from the pool and executes the given query, returns the cursor
def get_session():
handle_lazy_connect()
return session
if there's a connection problem, this will silently create a new connection pool
from the available hosts, and remove the problematic host from the host list
"""
if params is None:
params = {}
global _host_idx
for i in range(len(_hosts)):
try:
self.cur = self.con.cursor()
self.cur.execute(query, params)
return self.cur
except cql.ProgrammingError as ex:
raise CQLEngineException(unicode(ex))
except TTransportException:
#TODO: check for other errors raised in the event of a connection / server problem
#move to the next connection and set the connection pool
_host_idx += 1
_host_idx %= len(_hosts)
self.con.close()
self.con = ConnectionPool._create_connection()
raise CQLConnectionError("couldn't reach a Cassandra server")
def get_cluster():
handle_lazy_connect()
return cluster
def handle_lazy_connect():
global lazy_connect_args
if lazy_connect_args:
hosts, default_keyspace, consistency, kwargs = lazy_connect_args
lazy_connect_args = None
setup(hosts, default_keyspace, consistency, **kwargs)

View File

@@ -3,3 +3,4 @@ class CQLEngineException(Exception): pass
class ModelException(CQLEngineException): pass
class ValidationError(CQLEngineException): pass
class UndefinedKeyspaceException(CQLEngineException): pass

View File

@@ -1,31 +1,48 @@
from datetime import datetime
from uuid import uuid1
from cqlengine.exceptions import ValidationError
class BaseQueryFunction(object):
class QueryValue(object):
"""
Base class for query filter values. Subclasses of these classes can
be passed into .filter() keyword args
"""
format_string = '%({})s'
def __init__(self, value):
self.value = value
self.context_id = None
def __unicode__(self):
return self.format_string.format(self.context_id)
def set_context_id(self, ctx_id):
self.context_id = ctx_id
def get_context_size(self):
return 1
def update_context(self, ctx):
ctx[str(self.context_id)] = self.value
class BaseQueryFunction(QueryValue):
"""
Base class for filtering functions. Subclasses of these classes can
be passed into .filter() and will be translated into CQL functions in
the resulting query
"""
_cql_string = None
def __init__(self, value):
self.value = value
def to_cql(self, value_id):
"""
Returns a function for cql with the value id as it's argument
"""
return self._cql_string.format(value_id)
def get_value(self):
raise NotImplementedError
class MinTimeUUID(BaseQueryFunction):
"""
return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp
_cql_string = 'MinTimeUUID(:{})'
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
"""
format_string = 'MinTimeUUID(%({})s)'
def __init__(self, value):
"""
@@ -36,13 +53,23 @@ class MinTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required')
super(MinTimeUUID, self).__init__(value)
def get_value(self):
epoch = datetime(1970, 1, 1)
return long((self.value - epoch).total_seconds() * 1000)
def to_database(self, val):
epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
return long(((val - epoch).total_seconds() - offset) * 1000)
def update_context(self, ctx):
ctx[str(self.context_id)] = self.to_database(self.value)
class MaxTimeUUID(BaseQueryFunction):
"""
return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp
_cql_string = 'MaxTimeUUID(:{})'
http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun
"""
format_string = 'MaxTimeUUID(%({})s)'
def __init__(self, value):
"""
@@ -53,9 +80,41 @@ class MaxTimeUUID(BaseQueryFunction):
raise ValidationError('datetime instance is required')
super(MaxTimeUUID, self).__init__(value)
def get_value(self):
epoch = datetime(1970, 1, 1)
return long((self.value - epoch).total_seconds() * 1000)
def to_database(self, val):
epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
return long(((val - epoch).total_seconds() - offset) * 1000)
def update_context(self, ctx):
ctx[str(self.context_id)] = self.to_database(self.value)
class Token(BaseQueryFunction):
"""
compute the token for a given partition key
http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun
"""
def __init__(self, *values):
if len(values) == 1 and isinstance(values[0], (list, tuple)):
values = values[0]
super(Token, self).__init__(values)
self._columns = None
def set_columns(self, columns):
self._columns = columns
def get_context_size(self):
return len(self.value)
def __unicode__(self):
token_args = ', '.join('%({})s'.format(self.context_id + i) for i in range(self.get_context_size()))
return "token({})".format(token_args)
def update_context(self, ctx):
for i, (col, val) in enumerate(zip(self._columns, self.value)):
ctx[str(self.context_id + i)] = col.to_database(val)
class NotSet(object):

View File

@@ -1,8 +1,22 @@
import json
import warnings
from cqlengine import SizeTieredCompactionStrategy, LeveledCompactionStrategy
from cqlengine import ONE
from cqlengine.named import NamedTable
from cqlengine.connection import connection_manager
from cqlengine.connection import execute, get_cluster
from cqlengine.exceptions import CQLEngineException
import logging
from collections import namedtuple
Field = namedtuple('Field', ['name', 'type'])
logger = logging.getLogger(__name__)
from cqlengine.models import Model
# system keyspaces
schema_columnfamilies = NamedTable('system', 'schema_columnfamilies')
def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values):
"""
creates a keyspace
@@ -13,109 +27,302 @@ def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3,
:param durable_writes: 1.2 only, write log is bypassed if set to False
:param **replication_values: 1.2 only, additional values to ad to the replication data map
"""
with connection_manager() as con:
if not any([name == k.name for k in con.con.client.describe_keyspaces()]):
# if name not in [k.name for k in con.con.client.describe_keyspaces()]:
try:
#Try the 1.1 method
con.execute("""CREATE KEYSPACE {}
WITH strategy_class = '{}'
AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor))
except CQLEngineException:
#try the 1.2 method
replication_map = {
'class': strategy_class,
'replication_factor':replication_factor
}
replication_map.update(replication_values)
cluster = get_cluster()
query = """
CREATE KEYSPACE {}
WITH REPLICATION = {}
""".format(name, json.dumps(replication_map).replace('"', "'"))
if name not in cluster.metadata.keyspaces:
#try the 1.2 method
replication_map = {
'class': strategy_class,
'replication_factor':replication_factor
}
replication_map.update(replication_values)
if strategy_class.lower() != 'simplestrategy':
# Although the Cassandra documentation states for `replication_factor`
# that it is "Required if class is SimpleStrategy; otherwise,
# not used." we get an error if it is present.
replication_map.pop('replication_factor', None)
if strategy_class != 'SimpleStrategy':
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
query = """
CREATE KEYSPACE {}
WITH REPLICATION = {}
""".format(name, json.dumps(replication_map).replace('"', "'"))
if strategy_class != 'SimpleStrategy':
query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false')
execute(query)
con.execute(query)
def delete_keyspace(name):
with connection_manager() as con:
if name in [k.name for k in con.con.client.describe_keyspaces()]:
con.execute("DROP KEYSPACE {}".format(name))
cluster = get_cluster()
if name in cluster.metadata.keyspaces:
execute("DROP KEYSPACE {}".format(name))
def create_table(model, create_missing_keyspace=True):
raise CQLEngineException("create_table is deprecated, please use sync_table")
def sync_table(model, create_missing_keyspace=True):
"""
Inspects the model and creates / updates the corresponding table and columns.
Note that the attributes removed from the model are not deleted on the database.
They become effectively ignored by (will not show up on) the model.
:param create_missing_keyspace: (Defaults to True) Flags to us that we need to create missing keyspace
mentioned in the model automatically.
:type create_missing_keyspace: bool
"""
if not issubclass(model, Model):
raise CQLEngineException("Models must be derived from base Model.")
if model.__abstract__:
raise CQLEngineException("cannot create table from abstract model")
#construct query string
cf_name = model.column_family_name()
raw_cf_name = model.column_family_name(include_keyspace=False)
ks_name = model._get_keyspace()
#create missing keyspace
if create_missing_keyspace:
create_keyspace(model._get_keyspace())
create_keyspace(ks_name)
with connection_manager() as con:
#check for an existing column family
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]):
qs = ['CREATE TABLE {}'.format(cf_name)]
cluster = get_cluster()
#add column types
pkeys = []
qtypes = []
def add_column(col):
s = col.get_column_def()
if col.primary_key: pkeys.append('"{}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
keyspace = cluster.metadata.keyspaces[ks_name]
tables = keyspace.tables
qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys)))
qs += ['({})'.format(', '.join(qtypes))]
# add read_repair_chance
qs += ['WITH read_repair_chance = {}'.format(model.read_repair_chance)]
qs = ' '.join(qs)
#check for an existing column family
if raw_cf_name not in tables:
qs = get_create_table(model)
try:
con.execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
if "Cannot add already existing column family" not in unicode(ex):
raise
try:
execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the column family already exists
if "Cannot add already existing column family" not in unicode(ex):
raise
else:
# see if we're missing any columns
fields = get_fields(model)
field_names = [x.name for x in fields]
for name, col in model._columns.items():
if col.primary_key or col.partition_key: continue # we can't mess with the PK
if col.db_field_name in field_names: continue # skip columns already defined
#get existing index names, skip ones that already exist
ks_info = con.con.client.describe_keyspace(model._get_keyspace())
cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name]
idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else []
idx_names = filter(None, idx_names)
# add missing column using the column def
query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def())
logger.debug(query)
execute(query)
indexes = [c for n,c in model._columns.items() if c.index]
if indexes:
for column in indexes:
if column.db_index_name in idx_names: continue
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['("{}")'.format(column.db_field_name)]
qs = ' '.join(qs)
update_compaction(model)
try:
con.execute(qs)
except CQLEngineException as ex:
# 1.2 doesn't return cf names, so we have to examine the exception
# and ignore if it says the index already exists
if "Index already exists" not in unicode(ex):
raise
table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name]
indexes = [c for n,c in model._columns.items() if c.index]
for column in indexes:
if table.columns[column.db_field_name].index:
continue
qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)]
qs += ['ON {}'.format(cf_name)]
qs += ['("{}")'.format(column.db_field_name)]
qs = ' '.join(qs)
execute(qs)
def get_create_table(model):
cf_name = model.column_family_name()
qs = ['CREATE TABLE {}'.format(cf_name)]
#add column types
pkeys = [] # primary keys
ckeys = [] # clustering keys
qtypes = [] # field types
def add_column(col):
s = col.get_column_def()
if col.primary_key:
keys = (pkeys if col.partition_key else ckeys)
keys.append('"{}"'.format(col.db_field_name))
qtypes.append(s)
for name, col in model._columns.items():
add_column(col)
qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or ''))
qs += ['({})'.format(', '.join(qtypes))]
with_qs = []
table_properties = ['bloom_filter_fp_chance', 'caching', 'comment',
'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds',
'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush',
'read_repair_chance', 'replicate_on_write']
for prop_name in table_properties:
prop_value = getattr(model, '__{}__'.format(prop_name), None)
if prop_value is not None:
# Strings needs to be single quoted
if isinstance(prop_value, basestring):
prop_value = "'{}'".format(prop_value)
with_qs.append("{} = {}".format(prop_name, prop_value))
_order = ['"{}" {}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()]
if _order:
with_qs.append('clustering order by ({})'.format(', '.join(_order)))
compaction_options = get_compaction_options(model)
if compaction_options:
compaction_options = json.dumps(compaction_options).replace('"', "'")
with_qs.append("compaction = {}".format(compaction_options))
# Add table properties.
if with_qs:
qs += ['WITH {}'.format(' AND '.join(with_qs))]
qs = ' '.join(qs)
return qs
def get_compaction_options(model):
"""
Generates dictionary (later converted to a string) for creating and altering
tables with compaction strategy
:param model:
:return:
"""
if not model.__compaction__:
return {}
result = {'class':model.__compaction__}
def setter(key, limited_to_strategy = None):
"""
sets key in result, checking if the key is limited to either SizeTiered or Leveled
:param key: one of the compaction options, like "bucket_high"
:param limited_to_strategy: SizeTieredCompactionStrategy, LeveledCompactionStrategy
:return:
"""
mkey = "__compaction_{}__".format(key)
tmp = getattr(model, mkey)
if tmp and limited_to_strategy and limited_to_strategy != model.__compaction__:
raise CQLEngineException("{} is limited to {}".format(key, limited_to_strategy))
if tmp:
# Explicitly cast the values to strings to be able to compare the
# values against introspected values from Cassandra.
result[key] = str(tmp)
setter('tombstone_compaction_interval')
setter('tombstone_threshold')
setter('bucket_high', SizeTieredCompactionStrategy)
setter('bucket_low', SizeTieredCompactionStrategy)
setter('max_threshold', SizeTieredCompactionStrategy)
setter('min_threshold', SizeTieredCompactionStrategy)
setter('min_sstable_size', SizeTieredCompactionStrategy)
setter('sstable_size_in_mb', LeveledCompactionStrategy)
return result
def get_fields(model):
# returns all fields that aren't part of the PK
ks_name = model._get_keyspace()
col_family = model.column_family_name(include_keyspace=False)
query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s"
tmp = execute(query, [ks_name, col_family])
# Tables containing only primary keys do not appear to create
# any entries in system.schema_columns, as only non-primary-key attributes
# appear to be inserted into the schema_columns table
try:
return [Field(x['column_name'], x['validator']) for x in tmp if x['type'] == 'regular']
except KeyError:
return [Field(x['column_name'], x['validator']) for x in tmp]
# convert to Field named tuples
def get_table_settings(model):
# returns the table as provided by the native driver for a given model
cluster = get_cluster()
ks = model._get_keyspace()
table = model.column_family_name(include_keyspace=False)
table = cluster.metadata.keyspaces[ks].tables[table]
return table
def update_compaction(model):
"""Updates the compaction options for the given model if necessary.
:param model: The model to update.
:return: `True`, if the compaction options were modified in Cassandra,
`False` otherwise.
:rtype: bool
"""
logger.debug("Checking %s for compaction differences", model)
table = get_table_settings(model)
existing_options = table.options.copy()
existing_compaction_strategy = existing_options['compaction_strategy_class']
existing_options = json.loads(existing_options['compaction_strategy_options'])
desired_options = get_compaction_options(model)
desired_compact_strategy = desired_options.get('class', SizeTieredCompactionStrategy)
desired_options.pop('class', None)
do_update = False
if desired_compact_strategy not in existing_compaction_strategy:
do_update = True
for k, v in desired_options.items():
val = existing_options.pop(k, None)
if val != v:
do_update = True
# check compaction_strategy_options
if do_update:
options = get_compaction_options(model)
# jsonify
options = json.dumps(options).replace('"', "'")
cf_name = model.column_family_name()
query = "ALTER TABLE {} with compaction = {}".format(cf_name, options)
logger.debug(query)
execute(query)
return True
return False
def delete_table(model):
cf_name = model.column_family_name()
with connection_manager() as con:
try:
con.execute('drop table {};'.format(cf_name))
except CQLEngineException as ex:
#don't freak out if the table doesn't exist
if 'Cannot drop non existing column family' not in unicode(ex):
raise
raise CQLEngineException("delete_table has been deprecated in favor of drop_table()")
def drop_table(model):
# don't try to delete non existant tables
meta = get_cluster().metadata
ks_name = model._get_keyspace()
raw_cf_name = model.column_family_name(include_keyspace=False)
try:
table = meta.keyspaces[ks_name].tables[raw_cf_name]
execute('drop table {};'.format(model.column_family_name(include_keyspace=True)))
except KeyError:
pass

View File

@@ -1,15 +1,24 @@
from collections import OrderedDict
import re
import warnings
from cqlengine import columns
from cqlengine.exceptions import ModelException
from cqlengine.functions import BaseQueryFunction, NotSet
from cqlengine.query import QuerySet, QueryException, DMLQuery
from cqlengine.exceptions import ModelException, CQLEngineException, ValidationError
from cqlengine.query import ModelQuerySet, DMLQuery, AbstractQueryableColumn
from cqlengine.query import DoesNotExist as _DoesNotExist
from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
class ModelDefinitionException(ModelException): pass
DEFAULT_KEYSPACE = 'cqlengine'
class PolyMorphicModelException(ModelException): pass
DEFAULT_KEYSPACE = None
class UndefinedKeyspaceWarning(Warning):
pass
class hybrid_classmethod(object):
"""
@@ -27,35 +36,253 @@ class hybrid_classmethod(object):
else:
return self.instmethod.__get__(instance, owner)
def __call__(self, *args, **kwargs):
"""
Just a hint to IDEs that it's ok to call this
"""
raise NotImplementedError
class QuerySetDescriptor(object):
"""
returns a fresh queryset for the given model
it's declared on everytime it's accessed
"""
def __get__(self, obj, model):
""" :rtype: ModelQuerySet """
if model.__abstract__:
raise CQLEngineException('cannot execute queries against abstract models')
queryset = model.__queryset__(model)
# if this is a concrete polymorphic model, and the polymorphic
# key is an indexed column, add a filter clause to only return
# logical rows of the proper type
if model._is_polymorphic and not model._is_polymorphic_base:
name, column = model._polymorphic_column_name, model._polymorphic_column
if column.partition_key or column.index:
# look for existing poly types
return queryset.filter(**{name: model.__polymorphic_key__})
return queryset
def __call__(self, *args, **kwargs):
"""
Just a hint to IDEs that it's ok to call this
:rtype: ModelQuerySet
"""
raise NotImplementedError
class TTLDescriptor(object):
"""
returns a query set descriptor
"""
def __get__(self, instance, model):
if instance:
#instance = copy.deepcopy(instance)
# instance method
def ttl_setter(ts):
instance._ttl = ts
return instance
return ttl_setter
qs = model.__queryset__(model)
def ttl_setter(ts):
qs._ttl = ts
return qs
return ttl_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class TimestampDescriptor(object):
"""
returns a query set descriptor with a timestamp specified
"""
def __get__(self, instance, model):
if instance:
# instance method
def timestamp_setter(ts):
instance._timestamp = ts
return instance
return timestamp_setter
return model.objects.timestamp
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ConsistencyDescriptor(object):
"""
returns a query set descriptor if called on Class, instance if it was an instance call
"""
def __get__(self, instance, model):
if instance:
#instance = copy.deepcopy(instance)
def consistency_setter(consistency):
instance.__consistency__ = consistency
return instance
return consistency_setter
qs = model.__queryset__(model)
def consistency_setter(consistency):
qs._consistency = consistency
return qs
return consistency_setter
def __call__(self, *args, **kwargs):
raise NotImplementedError
class ColumnQueryEvaluator(AbstractQueryableColumn):
"""
Wraps a column and allows it to be used in comparator
expressions, returning query operators
ie:
Model.column == 5
"""
def __init__(self, column):
self.column = column
def __unicode__(self):
return self.column.db_field_name
def _get_column(self):
""" :rtype: ColumnQueryEvaluator """
return self.column
class ColumnDescriptor(object):
"""
Handles the reading and writing of column values to and from
a model instance's value manager, as well as creating
comparator queries
"""
def __init__(self, column):
"""
:param column:
:type column: columns.Column
:return:
"""
self.column = column
self.query_evaluator = ColumnQueryEvaluator(self.column)
def __get__(self, instance, owner):
"""
Returns either the value or column, depending
on if an instance is provided or not
:param instance: the model instance
:type instance: Model
"""
try:
return instance._values[self.column.column_name].getval()
except AttributeError as e:
return self.query_evaluator
def __set__(self, instance, value):
"""
Sets the value on an instance, raises an exception with classes
TODO: use None instance to create update statements
"""
if instance:
return instance._values[self.column.column_name].setval(value)
else:
raise AttributeError('cannot reassign column values')
def __delete__(self, instance):
"""
Sets the column value to None, if possible
"""
if instance:
if self.column.can_delete:
instance._values[self.column.column_name].delval()
else:
raise AttributeError('cannot delete {} columns'.format(self.column.column_name))
class BaseModel(object):
"""
The base model class, don't inherit from this, inherit from Model, defined below
"""
class DoesNotExist(QueryException): pass
class MultipleObjectsReturned(QueryException): pass
#table names will be generated automatically from it's model and package name
#however, you can also define them manually here
table_name = None
class DoesNotExist(_DoesNotExist): pass
#DEFAULT_TTL must be an integer seconds for the default time to live on any insert on the table
#this can be overridden on any given query, but you can set a default on the model
DEFAULT_TTL = None
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
#the keyspace for this model
keyspace = None
read_repair_chance = 0.1
objects = QuerySetDescriptor()
ttl = TTLDescriptor()
consistency = ConsistencyDescriptor()
def __init__(self, ttl=NotSet, **values):
# custom timestamps, see USING TIMESTAMP X
timestamp = TimestampDescriptor()
# _len is lazily created by __len__
# table names will be generated automatically from it's model
# however, you can also define them manually here
__table_name__ = None
# the keyspace for this model
__keyspace__ = None
# polymorphism options
__polymorphic_key__ = None
# compaction options
__compaction__ = None
__compaction_tombstone_compaction_interval__ = None
__compaction_tombstone_threshold__ = None
# compaction - size tiered options
__compaction_bucket_high__ = None
__compaction_bucket_low__ = None
__compaction_max_threshold__ = None
__compaction_min_threshold__ = None
__compaction_min_sstable_size__ = None
# compaction - leveled options
__compaction_sstable_size_in_mb__ = None
# end compaction
# the queryset class used for this class
__queryset__ = ModelQuerySet
__dmlquery__ = DMLQuery
#__ttl__ = None # this doesn't seem to be used
__consistency__ = None # can be set per query
# Additional table properties
__bloom_filter_fp_chance__ = None
__caching__ = None
__comment__ = None
__dclocal_read_repair_chance__ = None
__default_time_to_live__ = None
__gc_grace_seconds__ = None
__index_interval__ = None
__memtable_flush_period_in_ms__ = None
__populate_io_cache_on_flush__ = None
__read_repair_chance__ = None
__replicate_on_write__ = None
_timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP)
def __init__(self, **values):
self._values = {}
if ttl == NotSet:
self.ttl = self.DEFAULT_TTL
else:
self.ttl = ttl
for name, column in self._columns.items():
value = values.get(name, None)
if value is not None: value = column.to_python(value)
if value is not None or isinstance(column, columns.BaseContainerColumn):
value = column.to_python(value)
value_mngr = column.value_manager(self, column, value)
self._values[name] = value_mngr
@@ -64,6 +291,77 @@ class BaseModel(object):
self._is_persisted = False
self._batch = None
def __repr__(self):
"""
Pretty printing of models by their primary key
"""
return '{} <{}>'.format(self.__class__.__name__,
', '.join(('{}={}'.format(k, getattr(self, k)) for k,v in self._primary_keys.iteritems()))
)
@classmethod
def _discover_polymorphic_submodels(cls):
if not cls._is_polymorphic_base:
raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes')
def _discover(klass):
if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None:
cls._polymorphic_map[klass.__polymorphic_key__] = klass
for subklass in klass.__subclasses__():
_discover(subklass)
_discover(cls)
@classmethod
def _get_model_by_polymorphic_key(cls, key):
if not cls._is_polymorphic_base:
raise ModelException('_get_model_by_polymorphic_key can only be called on polymorphic base classes')
return cls._polymorphic_map.get(key)
@classmethod
def _construct_instance(cls, values):
"""
method used to construct instances from query results
this is where polymorphic deserialization occurs
"""
# we're going to take the values, which is from the DB as a dict
# and translate that into our local fields
# the db_map is a db_field -> model field map
items = values.items()
field_dict = dict([(cls._db_map.get(k, k),v) for k,v in items])
if cls._is_polymorphic:
poly_key = field_dict.get(cls._polymorphic_column_name)
if poly_key is None:
raise PolyMorphicModelException('polymorphic key was not found in values')
poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base
klass = poly_base._get_model_by_polymorphic_key(poly_key)
if klass is None:
poly_base._discover_polymorphic_submodels()
klass = poly_base._get_model_by_polymorphic_key(poly_key)
if klass is None:
raise PolyMorphicModelException(
'unrecognized polymorphic key {} for class {}'.format(poly_key, poly_base.__name__)
)
if not issubclass(klass, cls):
raise PolyMorphicModelException(
'{} is not a subclass of {}'.format(klass.__name__, cls.__name__)
)
field_dict = {k: v for k, v in field_dict.items() if k in klass._columns.keys()}
else:
klass = cls
instance = klass(**field_dict)
instance._is_persisted = True
return instance
def _can_update(self):
"""
Called by the save function to check if this should be
@@ -78,10 +376,35 @@ class BaseModel(object):
@classmethod
def _get_keyspace(cls):
""" Returns the manual keyspace, if set, otherwise the default keyspace """
return cls.keyspace or DEFAULT_KEYSPACE
return cls.__keyspace__ or DEFAULT_KEYSPACE
@classmethod
def _get_column(cls, name):
"""
Returns the column matching the given name, raising a key error if
it doesn't exist
:param name: the name of the column to return
:rtype: Column
"""
return cls._columns[name]
def __eq__(self, other):
return self.as_dict() == other.as_dict()
if self.__class__ != other.__class__:
return False
# check attribute keys
keys = set(self._columns.keys())
other_keys = set(other._columns.keys())
if keys != other_keys:
return False
# check that all of the attributes match
for key in other_keys:
if getattr(self, key, None) != getattr(other, key, None):
return False
return True
def __ne__(self, other):
return not self.__eq__(other)
@@ -93,16 +416,16 @@ class BaseModel(object):
otherwise, it creates it from the module and class name
"""
cf_name = ''
if cls.table_name:
cf_name = cls.table_name.lower()
if cls.__table_name__:
cf_name = cls.__table_name__.lower()
else:
# get polymorphic base table names if model is polymorphic
if cls._is_polymorphic and not cls._is_polymorphic_base:
return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace)
camelcase = re.compile(r'([a-z])([A-Z])')
ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s)
module = cls.__module__.split('.')
if module:
cf_name = ccase(module[-1]) + '_'
cf_name += ccase(cls.__name__)
#trim to less than 48 characters or cassandra will complain
cf_name = cf_name[-48:]
@@ -111,18 +434,56 @@ class BaseModel(object):
if not include_keyspace: return cf_name
return '{}.{}'.format(cls._get_keyspace(), cf_name)
@property
def pk(self):
""" Returns the object's primary key """
return getattr(self, self._pk_name)
def validate(self):
""" Cleans and validates the field values """
for name, col in self._columns.items():
val = col.validate(getattr(self, name))
setattr(self, name, val)
def as_dict(self):
### Let an instance be used like a dict of its columns keys/values
def __iter__(self):
""" Iterate over column ids. """
for column_id in self._columns.keys():
yield column_id
def __getitem__(self, key):
""" Returns column's value. """
if not isinstance(key, basestring):
raise TypeError
if key not in self._columns.keys():
raise KeyError
return getattr(self, key)
def __setitem__(self, key, val):
""" Sets a column's value. """
if not isinstance(key, basestring):
raise TypeError
if key not in self._columns.keys():
raise KeyError
return setattr(self, key, val)
def __len__(self):
""" Returns the number of columns defined on that model. """
try:
return self._len
except:
self._len = len(self._columns.keys())
return self._len
def keys(self):
""" Returns list of column's IDs. """
return [k for k in self]
def values(self):
""" Returns list of column's values. """
return [self[k] for k in self]
def items(self):
""" Returns a list of columns's IDs/values. """
return [(k, self[k]) for k in self]
def _as_dict(self):
""" Returns a map of column names to cleaned values """
values = self._dynamic_columns or {}
for name, col in self._columns.items():
@@ -131,37 +492,97 @@ class BaseModel(object):
@classmethod
def create(cls, **kwargs):
extra_columns = set(kwargs.keys()) - set(cls._columns.keys())
if extra_columns:
raise ValidationError("Incorrect columns passed: {}".format(extra_columns))
return cls.objects.create(**kwargs)
@classmethod
def all(cls):
return cls.objects.all()
@classmethod
def filter(cls, **kwargs):
return cls.objects.filter(**kwargs)
@classmethod
def get(cls, **kwargs):
return cls.objects.get(**kwargs)
def save(self, ttl=NotSet, timestamp=None):
if ttl == NotSet:
ttl = self.ttl
@classmethod
def filter(cls, *args, **kwargs):
# if kwargs.values().count(None):
# raise CQLEngineException("Cannot pass None as a filter")
return cls.objects.filter(*args, **kwargs)
@classmethod
def get(cls, *args, **kwargs):
return cls.objects.get(*args, **kwargs)
def save(self):
# handle polymorphic models
if self._is_polymorphic:
if self._is_polymorphic_base:
raise PolyMorphicModelException('cannot save polymorphic base model')
else:
setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)
is_new = self.pk is None
self.validate()
DMLQuery(self.__class__, self, batch=self._batch).save(ttl, timestamp)
self.__dmlquery__(self.__class__, self,
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__).save()
#reset the value managers
for v in self._values.values():
v.reset_previous_value()
self._is_persisted = True
self._ttl = None
self._timestamp = None
return self
def update(self, **values):
for k, v in values.items():
col = self._columns.get(k)
# check for nonexistant columns
if col is None:
raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.__class__.__name__, k))
# check for primary key update attempts
if col.is_primary_key:
raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(k, self.__module__, self.__class__.__name__))
setattr(self, k, v)
# handle polymorphic models
if self._is_polymorphic:
if self._is_polymorphic_base:
raise PolyMorphicModelException('cannot update polymorphic base model')
else:
setattr(self, self._polymorphic_column_name, self.__polymorphic_key__)
self.validate()
self.__dmlquery__(self.__class__, self,
batch=self._batch,
ttl=self._ttl,
timestamp=self._timestamp,
consistency=self.__consistency__).update()
#reset the value managers
for v in self._values.values():
v.reset_previous_value()
self._is_persisted = True
self._ttl = None
self._timestamp = None
return self
def delete(self):
""" Deletes this instance """
DMLQuery(self.__class__, self, batch=self._batch).delete()
self.__dmlquery__(self.__class__, self, batch=self._batch, timestamp=self._timestamp, consistency=self.__consistency__).delete()
def get_changed_columns(self):
""" returns a list of the columns that have been updated since instantiation or save """
return [k for k,v in self._values.items() if v.changed]
@classmethod
def _class_batch(cls, batch):
@@ -171,6 +592,7 @@ class BaseModel(object):
self._batch = batch
return self
batch = hybrid_classmethod(_class_batch, _inst_batch)
@@ -185,7 +607,6 @@ class ModelMetaClass(type):
column_dict = OrderedDict()
primary_keys = OrderedDict()
pk_name = None
primary_key = None
#get inherited properties
inherited_columns = OrderedDict()
@@ -193,52 +614,105 @@ class ModelMetaClass(type):
for k,v in getattr(base, '_defined_columns', {}).items():
inherited_columns.setdefault(k,v)
#short circuit __abstract__ inheritance
is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False)
#short circuit __polymorphic_key__ inheritance
attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None)
def _transform_column(col_name, col_obj):
column_dict[col_name] = col_obj
if col_obj.primary_key:
primary_keys[col_name] = col_obj
col_obj.set_column_name(col_name)
#set properties
_get = lambda self: self._values[col_name].getval()
_set = lambda self, val: self._values[col_name].setval(val)
_del = lambda self: self._values[col_name].delval()
if col_obj.can_delete:
attrs[col_name] = property(_get, _set)
else:
attrs[col_name] = property(_get, _set, _del)
attrs[col_name] = ColumnDescriptor(col_obj)
column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.Column)]
column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position))
is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions])
column_definitions = inherited_columns.items() + column_definitions
#columns defined on model, excludes automatically
#defined columns
polymorphic_columns = [c for c in column_definitions if c[1].polymorphic_key]
is_polymorphic = len(polymorphic_columns) > 0
if len(polymorphic_columns) > 1:
raise ModelDefinitionException('only one polymorphic_key can be defined in a model, {} found'.format(len(polymorphic_columns)))
polymorphic_column_name, polymorphic_column = polymorphic_columns[0] if polymorphic_columns else (None, None)
if isinstance(polymorphic_column, (columns.BaseContainerColumn, columns.Counter)):
raise ModelDefinitionException('counter and container columns cannot be used for polymorphic keys')
# find polymorphic base class
polymorphic_base = None
if is_polymorphic and not is_polymorphic_base:
def _get_polymorphic_base(bases):
for base in bases:
if getattr(base, '_is_polymorphic_base', False):
return base
klass = _get_polymorphic_base(base.__bases__)
if klass:
return klass
polymorphic_base = _get_polymorphic_base(bases)
defined_columns = OrderedDict(column_definitions)
#prepend primary key if one hasn't been defined
if not any([v.primary_key for k,v in column_definitions]):
k,v = 'id', columns.UUID(primary_key=True)
column_definitions = [(k,v)] + column_definitions
# check for primary key
if not is_abstract and not any([v.primary_key for k,v in column_definitions]):
raise ModelDefinitionException("At least 1 primary key is required.")
counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)]
data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)]
if counter_columns and data_columns:
raise ModelDefinitionException('counter models may not have data columns')
has_partition_keys = any(v.partition_key for (k, v) in column_definitions)
#TODO: check that the defined columns don't conflict with any of the Model API's existing attributes/methods
#transform column definitions
for k,v in column_definitions:
if pk_name is None and v.primary_key:
pk_name = k
primary_key = v
v._partition_key = True
_transform_column(k,v)
#setup primary key shortcut
if pk_name != 'pk':
attrs['pk'] = attrs[pk_name]
for k, v in column_definitions:
# don't allow a column with the same name as a built-in attribute or method
if k in BaseModel.__dict__:
raise ModelDefinitionException("column '{}' conflicts with built-in attribute/method".format(k))
#check for duplicate column names
# counter column primary keys are not allowed
if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)):
raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys')
# this will mark the first primary key column as a partition
# key, if one hasn't been set already
if not has_partition_keys and v.primary_key:
v.partition_key = True
has_partition_keys = True
_transform_column(k, v)
partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key)
clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key)
#setup partition key shortcut
if len(partition_keys) == 0:
if not is_abstract:
raise ModelException("at least one partition key must be defined")
if len(partition_keys) == 1:
pk_name = partition_keys.keys()[0]
attrs['pk'] = attrs[pk_name]
else:
# composite partition key case, get/set a tuple of values
_get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys())
_set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val))
attrs['pk'] = property(_get, _set)
# some validation
col_names = set()
for v in column_dict.values():
# check for duplicate column names
if v.db_field_name in col_names:
raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name))
if v.clustering_order and not (v.primary_key and not v.partition_key):
raise ModelException("clustering_order may be specified only for clustering primary keys")
if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'):
raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name))
col_names.add(v.db_field_name)
#create db_name -> model name map for loading
@@ -246,21 +720,46 @@ class ModelMetaClass(type):
for field_name, col in column_dict.items():
db_map[col.db_field_name] = field_name
#short circuit table_name inheritance
attrs['table_name'] = attrs.get('table_name')
#add management members to the class
attrs['_columns'] = column_dict
attrs['_primary_keys'] = primary_keys
attrs['_defined_columns'] = defined_columns
# maps the database field to the models key
attrs['_db_map'] = db_map
attrs['_pk_name'] = pk_name
attrs['_primary_key'] = primary_key
attrs['_dynamic_columns'] = {}
attrs['_partition_keys'] = partition_keys
attrs['_clustering_keys'] = clustering_keys
attrs['_has_counter'] = len(counter_columns) > 0
# add polymorphic management attributes
attrs['_is_polymorphic_base'] = is_polymorphic_base
attrs['_is_polymorphic'] = is_polymorphic
attrs['_polymorphic_base'] = polymorphic_base
attrs['_polymorphic_column'] = polymorphic_column
attrs['_polymorphic_column_name'] = polymorphic_column_name
attrs['_polymorphic_map'] = {} if is_polymorphic_base else None
#setup class exceptions
DoesNotExistBase = None
for base in bases:
DoesNotExistBase = getattr(base, 'DoesNotExist', None)
if DoesNotExistBase is not None: break
DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist)
attrs['DoesNotExist'] = type('DoesNotExist', (DoesNotExistBase,), {})
MultipleObjectsReturnedBase = None
for base in bases:
MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None)
if MultipleObjectsReturnedBase is not None: break
MultipleObjectsReturnedBase = DoesNotExistBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned)
attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {})
#create the class and add a QuerySet to it
klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs)
klass.objects = QuerySet(klass)
return klass
@@ -269,6 +768,7 @@ class Model(BaseModel):
the db name for the column family can be set as the attribute db_name, or
it will be genertaed from the class name
"""
__abstract__ = True
__metaclass__ = ModelMetaClass

122
cqlengine/named.py Normal file
View File

@@ -0,0 +1,122 @@
from cqlengine.exceptions import CQLEngineException
from cqlengine.query import AbstractQueryableColumn, SimpleQuerySet
from cqlengine.query import DoesNotExist as _DoesNotExist
from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned
class QuerySetDescriptor(object):
"""
returns a fresh queryset for the given model
it's declared on everytime it's accessed
"""
def __get__(self, obj, model):
""" :rtype: ModelQuerySet """
if model.__abstract__:
raise CQLEngineException('cannot execute queries against abstract models')
return SimpleQuerySet(obj)
def __call__(self, *args, **kwargs):
"""
Just a hint to IDEs that it's ok to call this
:rtype: ModelQuerySet
"""
raise NotImplementedError
class NamedColumn(AbstractQueryableColumn):
"""
A column that is not coupled to a model class, or type
"""
def __init__(self, name):
self.name = name
def __unicode__(self):
return self.name
def _get_column(self):
""" :rtype: NamedColumn """
return self
@property
def db_field_name(self):
return self.name
@property
def cql(self):
return self.get_cql()
def get_cql(self):
return '"{}"'.format(self.name)
def to_database(self, val):
return val
class NamedTable(object):
"""
A Table that is not coupled to a model class
"""
__abstract__ = False
objects = QuerySetDescriptor()
class DoesNotExist(_DoesNotExist): pass
class MultipleObjectsReturned(_MultipleObjectsReturned): pass
def __init__(self, keyspace, name):
self.keyspace = keyspace
self.name = name
def column(self, name):
return NamedColumn(name)
def column_family_name(self, include_keyspace=True):
"""
Returns the column family name if it's been defined
otherwise, it creates it from the module and class name
"""
if include_keyspace:
return '{}.{}'.format(self.keyspace, self.name)
else:
return self.name
def _get_column(self, name):
"""
Returns the column matching the given name
:rtype: Column
"""
return self.column(name)
# def create(self, **kwargs):
# return self.objects.create(**kwargs)
def all(self):
return self.objects.all()
def filter(self, *args, **kwargs):
return self.objects.filter(*args, **kwargs)
def get(self, *args, **kwargs):
return self.objects.get(*args, **kwargs)
class NamedKeyspace(object):
"""
A keyspace
"""
def __init__(self, name):
self.name = name
def table(self, name):
"""
returns a table descriptor with the given
name that belongs to this keyspace
"""
return NamedTable(self.name, name)

82
cqlengine/operators.py Normal file
View File

@@ -0,0 +1,82 @@
class QueryOperatorException(Exception): pass
class BaseQueryOperator(object):
# The symbol that identifies this operator in kwargs
# ie: colname__<symbol>
symbol = None
# The comparator symbol this operator uses in cql
cql_symbol = None
def __unicode__(self):
if self.cql_symbol is None:
raise QueryOperatorException("cql symbol is None")
return self.cql_symbol
def __str__(self):
return unicode(self).encode('utf-8')
@classmethod
def get_operator(cls, symbol):
if cls == BaseQueryOperator:
raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass")
if not hasattr(cls, 'opmap'):
cls.opmap = {}
def _recurse(klass):
if klass.symbol:
cls.opmap[klass.symbol.upper()] = klass
for subklass in klass.__subclasses__():
_recurse(subklass)
pass
_recurse(cls)
try:
return cls.opmap[symbol.upper()]
except KeyError:
raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol))
class BaseWhereOperator(BaseQueryOperator):
""" base operator used for where clauses """
class EqualsOperator(BaseWhereOperator):
symbol = 'EQ'
cql_symbol = '='
class InOperator(EqualsOperator):
symbol = 'IN'
cql_symbol = 'IN'
class GreaterThanOperator(BaseWhereOperator):
symbol = "GT"
cql_symbol = '>'
class GreaterThanOrEqualOperator(BaseWhereOperator):
symbol = "GTE"
cql_symbol = '>='
class LessThanOperator(BaseWhereOperator):
symbol = "LT"
cql_symbol = '<'
class LessThanOrEqualOperator(BaseWhereOperator):
symbol = "LTE"
cql_symbol = '<='
class BaseAssignmentOperator(BaseQueryOperator):
""" base operator used for insert and delete statements """
class AssignmentOperator(BaseAssignmentOperator):
cql_symbol = "="
class AddSymbol(BaseAssignmentOperator):
cql_symbol = "+"

File diff suppressed because it is too large Load Diff

724
cqlengine/statements.py Normal file
View File

@@ -0,0 +1,724 @@
import time
from datetime import datetime, timedelta
from cqlengine.functions import QueryValue
from cqlengine.operators import BaseWhereOperator, InOperator
class StatementException(Exception): pass
class ValueQuoter(object):
def __init__(self, value):
self.value = value
def __unicode__(self):
from cassandra.encoder import cql_quote
if isinstance(self.value, bool):
return 'true' if self.value else 'false'
elif isinstance(self.value, (list, tuple)):
return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']'
elif isinstance(self.value, dict):
return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k,v in self.value.items()]) + '}'
elif isinstance(self.value, set):
return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}'
return cql_quote(self.value)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.value == other.value
return False
def __str__(self):
return unicode(self).encode('utf-8')
class InQuoter(ValueQuoter):
def __unicode__(self):
from cassandra.encoder import cql_quote
return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')'
class BaseClause(object):
def __init__(self, field, value):
self.field = field
self.value = value
self.context_id = None
def __unicode__(self):
raise NotImplementedError
def __str__(self):
return unicode(self).encode('utf-8')
def __hash__(self):
return hash(self.field) ^ hash(self.value)
def __eq__(self, other):
if isinstance(other, self.__class__):
return self.field == other.field and self.value == other.value
return False
def __ne__(self, other):
return not self.__eq__(other)
def get_context_size(self):
""" returns the number of entries this clause will add to the query context """
return 1
def set_context_id(self, i):
""" sets the value placeholder that will be used in the query """
self.context_id = i
def update_context(self, ctx):
""" updates the query context with this clauses values """
assert isinstance(ctx, dict)
ctx[str(self.context_id)] = self.value
class WhereClause(BaseClause):
""" a single where statement used in queries """
def __init__(self, field, operator, value, quote_field=True):
"""
:param field:
:param operator:
:param value:
:param quote_field: hack to get the token function rendering properly
:return:
"""
if not isinstance(operator, BaseWhereOperator):
raise StatementException(
"operator must be of type {}, got {}".format(BaseWhereOperator, type(operator))
)
super(WhereClause, self).__init__(field, value)
self.operator = operator
self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value)
self.quote_field = quote_field
def __unicode__(self):
field = ('"{}"' if self.quote_field else '{}').format(self.field)
return u'{} {} {}'.format(field, self.operator, unicode(self.query_value))
def __hash__(self):
return super(WhereClause, self).__hash__() ^ hash(self.operator)
def __eq__(self, other):
if super(WhereClause, self).__eq__(other):
return self.operator.__class__ == other.operator.__class__
return False
def get_context_size(self):
return self.query_value.get_context_size()
def set_context_id(self, i):
super(WhereClause, self).set_context_id(i)
self.query_value.set_context_id(i)
def update_context(self, ctx):
if isinstance(self.operator, InOperator):
ctx[str(self.context_id)] = InQuoter(self.value)
else:
self.query_value.update_context(ctx)
class AssignmentClause(BaseClause):
""" a single variable st statement """
def __unicode__(self):
return u'"{}" = %({})s'.format(self.field, self.context_id)
def insert_tuple(self):
return self.field, self.context_id
class ContainerUpdateClause(AssignmentClause):
def __init__(self, field, value, operation=None, previous=None, column=None):
super(ContainerUpdateClause, self).__init__(field, value)
self.previous = previous
self._assignments = None
self._operation = operation
self._analyzed = False
self._column = column
def _to_database(self, val):
return self._column.to_database(val) if self._column else val
def _analyze(self):
raise NotImplementedError
def get_context_size(self):
raise NotImplementedError
def update_context(self, ctx):
raise NotImplementedError
class SetUpdateClause(ContainerUpdateClause):
""" updates a set collection """
def __init__(self, field, value, operation=None, previous=None, column=None):
super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._additions = None
self._removals = None
def __unicode__(self):
qs = []
ctx_id = self.context_id
if self.previous is None and not (self._assignments or self._additions or self._removals):
qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
if self._assignments:
qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
ctx_id += 1
if self._additions:
qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
ctx_id += 1
if self._removals:
qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)]
return ', '.join(qs)
def _analyze(self):
""" works out the updates to be performed """
if self.value is None or self.value == self.previous:
pass
elif self._operation == "add":
self._additions = self.value
elif self._operation == "remove":
self._removals = self.value
elif self.previous is None:
self._assignments = self.value
else:
# partial update time
self._additions = (self.value - self.previous) or None
self._removals = (self.previous - self.value) or None
self._analyzed = True
def get_context_size(self):
if not self._analyzed: self._analyze()
if self.previous is None and not (self._assignments or self._additions or self._removals):
return 1
return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals))
def update_context(self, ctx):
if not self._analyzed: self._analyze()
ctx_id = self.context_id
if self.previous is None and not (self._assignments or self._additions or self._removals):
ctx[str(ctx_id)] = self._to_database({})
if self._assignments:
ctx[str(ctx_id)] = self._to_database(self._assignments)
ctx_id += 1
if self._additions:
ctx[str(ctx_id)] = self._to_database(self._additions)
ctx_id += 1
if self._removals:
ctx[str(ctx_id)] = self._to_database(self._removals)
class ListUpdateClause(ContainerUpdateClause):
""" updates a list collection """
def __init__(self, field, value, operation=None, previous=None, column=None):
super(ListUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._append = None
self._prepend = None
def __unicode__(self):
if not self._analyzed: self._analyze()
qs = []
ctx_id = self.context_id
if self._assignments is not None:
qs += ['"{}" = %({})s'.format(self.field, ctx_id)]
ctx_id += 1
if self._prepend:
qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)]
ctx_id += 1
if self._append:
qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)]
return ', '.join(qs)
def get_context_size(self):
if not self._analyzed: self._analyze()
return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend))
def update_context(self, ctx):
if not self._analyzed: self._analyze()
ctx_id = self.context_id
if self._assignments is not None:
ctx[str(ctx_id)] = self._to_database(self._assignments)
ctx_id += 1
if self._prepend:
# CQL seems to prepend element at a time, starting
# with the element at idx 0, we can either reverse
# it here, or have it inserted in reverse
ctx[str(ctx_id)] = self._to_database(list(reversed(self._prepend)))
ctx_id += 1
if self._append:
ctx[str(ctx_id)] = self._to_database(self._append)
def _analyze(self):
""" works out the updates to be performed """
if self.value is None or self.value == self.previous:
pass
elif self._operation == "append":
self._append = self.value
elif self._operation == "prepend":
# self.value is a Quoter but we reverse self._prepend later as if
# it's a list, so we have to set it to the underlying list
self._prepend = self.value.value
elif self.previous is None:
self._assignments = self.value
elif len(self.value) < len(self.previous):
# if elements have been removed,
# rewrite the whole list
self._assignments = self.value
elif len(self.previous) == 0:
# if we're updating from an empty
# list, do a complete insert
self._assignments = self.value
else:
# the max start idx we want to compare
search_space = len(self.value) - max(0, len(self.previous)-1)
# the size of the sub lists we want to look at
search_size = len(self.previous)
for i in range(search_space):
#slice boundary
j = i + search_size
sub = self.value[i:j]
idx_cmp = lambda idx: self.previous[idx] == sub[idx]
if idx_cmp(0) and idx_cmp(-1) and self.previous == sub:
self._prepend = self.value[:i] or None
self._append = self.value[j:] or None
break
# if both append and prepend are still None after looking
# at both lists, an insert statement will be created
if self._prepend is self._append is None:
self._assignments = self.value
self._analyzed = True
class MapUpdateClause(ContainerUpdateClause):
""" updates a map collection """
def __init__(self, field, value, operation=None, previous=None, column=None):
super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column)
self._updates = None
def _analyze(self):
if self._operation == "update":
self._updates = self.value.keys()
else:
if self.previous is None:
self._updates = sorted([k for k, v in self.value.items()])
else:
self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None
self._analyzed = True
def get_context_size(self):
if not self._analyzed: self._analyze()
if self.previous is None and not self._updates:
return 1
return len(self._updates or []) * 2
def update_context(self, ctx):
if not self._analyzed: self._analyze()
ctx_id = self.context_id
if self.previous is None and not self._updates:
ctx[str(ctx_id)] = {}
else:
for key in self._updates or []:
val = self.value.get(key)
ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key
ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if self._column else val
ctx_id += 2
def __unicode__(self):
if not self._analyzed: self._analyze()
qs = []
ctx_id = self.context_id
if self.previous is None and not self._updates:
qs += ['"int_map" = %({})s'.format(ctx_id)]
else:
for _ in self._updates or []:
qs += ['"{}"[%({})s] = %({})s'.format(self.field, ctx_id, ctx_id + 1)]
ctx_id += 2
return ', '.join(qs)
class CounterUpdateClause(ContainerUpdateClause):
def __init__(self, field, value, previous=None, column=None):
super(CounterUpdateClause, self).__init__(field, value, previous=previous, column=column)
self.previous = self.previous or 0
def get_context_size(self):
return 1
def update_context(self, ctx):
ctx[str(self.context_id)] = self._to_database(abs(self.value - self.previous))
def __unicode__(self):
delta = self.value - self.previous
sign = '-' if delta < 0 else '+'
return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id)
class BaseDeleteClause(BaseClause):
pass
class FieldDeleteClause(BaseDeleteClause):
""" deletes a field from a row """
def __init__(self, field):
super(FieldDeleteClause, self).__init__(field, None)
def __unicode__(self):
return '"{}"'.format(self.field)
def update_context(self, ctx):
pass
def get_context_size(self):
return 0
class MapDeleteClause(BaseDeleteClause):
""" removes keys from a map """
def __init__(self, field, value, previous=None):
super(MapDeleteClause, self).__init__(field, value)
self.value = self.value or {}
self.previous = previous or {}
self._analyzed = False
self._removals = None
def _analyze(self):
self._removals = sorted([k for k in self.previous if k not in self.value])
self._analyzed = True
def update_context(self, ctx):
if not self._analyzed: self._analyze()
for idx, key in enumerate(self._removals):
ctx[str(self.context_id + idx)] = key
def get_context_size(self):
if not self._analyzed: self._analyze()
return len(self._removals)
def __unicode__(self):
if not self._analyzed: self._analyze()
return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))])
class BaseCQLStatement(object):
""" The base cql statement class """
def __init__(self, table, consistency=None, timestamp=None, where=None):
super(BaseCQLStatement, self).__init__()
self.table = table
self.consistency = consistency
self.context_id = 0
self.context_counter = self.context_id
self.timestamp = timestamp
self.where_clauses = []
for clause in where or []:
self.add_where_clause(clause)
def add_where_clause(self, clause):
"""
adds a where clause to this statement
:param clause: the clause to add
:type clause: WhereClause
"""
if not isinstance(clause, WhereClause):
raise StatementException("only instances of WhereClause can be added to statements")
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.where_clauses.append(clause)
def get_context(self):
"""
returns the context dict for this statement
:rtype: dict
"""
ctx = {}
for clause in self.where_clauses or []:
clause.update_context(ctx)
return ctx
def get_context_size(self):
return len(self.get_context())
def update_context_id(self, i):
self.context_id = i
self.context_counter = self.context_id
for clause in self.where_clauses:
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
@property
def timestamp_normalized(self):
"""
we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta
:return:
"""
if not self.timestamp:
return None
if isinstance(self.timestamp, (int, long)):
return self.timestamp
if isinstance(self.timestamp, timedelta):
tmp = datetime.now() + self.timestamp
else:
tmp = self.timestamp
return long(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond)
def __unicode__(self):
raise NotImplementedError
def __str__(self):
return unicode(self).encode('utf-8')
def __repr__(self):
return self.__unicode__()
@property
def _where(self):
return 'WHERE {}'.format(' AND '.join([unicode(c) for c in self.where_clauses]))
class SelectStatement(BaseCQLStatement):
""" a cql select statement """
def __init__(self,
table,
fields=None,
count=False,
consistency=None,
where=None,
order_by=None,
limit=None,
allow_filtering=False):
"""
:param where
:type where list of cqlengine.statements.WhereClause
"""
super(SelectStatement, self).__init__(
table,
consistency=consistency,
where=where
)
self.fields = [fields] if isinstance(fields, basestring) else (fields or [])
self.count = count
self.order_by = [order_by] if isinstance(order_by, basestring) else order_by
self.limit = limit
self.allow_filtering = allow_filtering
def __unicode__(self):
qs = ['SELECT']
if self.count:
qs += ['COUNT(*)']
else:
qs += [', '.join(['"{}"'.format(f) for f in self.fields]) if self.fields else '*']
qs += ['FROM', self.table]
if self.where_clauses:
qs += [self._where]
if self.order_by and not self.count:
qs += ['ORDER BY {}'.format(', '.join(unicode(o) for o in self.order_by))]
if self.limit:
qs += ['LIMIT {}'.format(self.limit)]
if self.allow_filtering:
qs += ['ALLOW FILTERING']
return ' '.join(qs)
class AssignmentStatement(BaseCQLStatement):
""" value assignment statements """
def __init__(self,
table,
assignments=None,
consistency=None,
where=None,
ttl=None,
timestamp=None):
super(AssignmentStatement, self).__init__(
table,
consistency=consistency,
where=where,
)
self.ttl = ttl
self.timestamp = timestamp
# add assignments
self.assignments = []
for assignment in assignments or []:
self.add_assignment_clause(assignment)
def update_context_id(self, i):
super(AssignmentStatement, self).update_context_id(i)
for assignment in self.assignments:
assignment.set_context_id(self.context_counter)
self.context_counter += assignment.get_context_size()
def add_assignment_clause(self, clause):
"""
adds an assignment clause to this statement
:param clause: the clause to add
:type clause: AssignmentClause
"""
if not isinstance(clause, AssignmentClause):
raise StatementException("only instances of AssignmentClause can be added to statements")
clause.set_context_id(self.context_counter)
self.context_counter += clause.get_context_size()
self.assignments.append(clause)
@property
def is_empty(self):
return len(self.assignments) == 0
def get_context(self):
ctx = super(AssignmentStatement, self).get_context()
for clause in self.assignments:
clause.update_context(ctx)
return ctx
class InsertStatement(AssignmentStatement):
""" an cql insert select statement """
def add_where_clause(self, clause):
raise StatementException("Cannot add where clauses to insert statements")
def __unicode__(self):
qs = ['INSERT INTO {}'.format(self.table)]
# get column names and context placeholders
fields = [a.insert_tuple() for a in self.assignments]
columns, values = zip(*fields)
qs += ["({})".format(', '.join(['"{}"'.format(c) for c in columns]))]
qs += ['VALUES']
qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))]
if self.ttl:
qs += ["USING TTL {}".format(self.ttl)]
if self.timestamp:
qs += ["USING TIMESTAMP {}".format(self.timestamp_normalized)]
return ' '.join(qs)
class UpdateStatement(AssignmentStatement):
""" an cql update select statement """
def __unicode__(self):
qs = ['UPDATE', self.table]
using_options = []
if self.ttl:
using_options += ["TTL {}".format(self.ttl)]
if self.timestamp:
using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)]
if using_options:
qs += ["USING {}".format(" AND ".join(using_options))]
qs += ['SET']
qs += [', '.join([unicode(c) for c in self.assignments])]
if self.where_clauses:
qs += [self._where]
return ' '.join(qs)
class DeleteStatement(BaseCQLStatement):
""" a cql delete statement """
def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None):
super(DeleteStatement, self).__init__(
table,
consistency=consistency,
where=where,
timestamp=timestamp
)
self.fields = []
if isinstance(fields, basestring):
fields = [fields]
for field in fields or []:
self.add_field(field)
def update_context_id(self, i):
super(DeleteStatement, self).update_context_id(i)
for field in self.fields:
field.set_context_id(self.context_counter)
self.context_counter += field.get_context_size()
def get_context(self):
ctx = super(DeleteStatement, self).get_context()
for field in self.fields:
field.update_context(ctx)
return ctx
def add_field(self, field):
if isinstance(field, basestring):
field = FieldDeleteClause(field)
if not isinstance(field, BaseClause):
raise StatementException("only instances of AssignmentClause can be added to statements")
field.set_context_id(self.context_counter)
self.context_counter += field.get_context_size()
self.fields.append(field)
def __unicode__(self):
qs = ['DELETE']
if self.fields:
qs += [', '.join(['{}'.format(f) for f in self.fields])]
qs += ['FROM', self.table]
delete_option = []
if self.timestamp:
delete_option += ["TIMESTAMP {}".format(self.timestamp_normalized)]
if delete_option:
qs += [" USING {} ".format(" AND ".join(delete_option))]
if self.where_clauses:
qs += [self._where]
return ' '.join(qs)

View File

@@ -1,18 +1,33 @@
from unittest import TestCase
from cqlengine import connection
import os
from cqlengine.connection import get_session
if os.environ.get('CASSANDRA_TEST_HOST'):
CASSANDRA_TEST_HOST = os.environ['CASSANDRA_TEST_HOST']
else:
CASSANDRA_TEST_HOST = 'localhost'
protocol_version = int(os.environ.get("CASSANDRA_PROTOCOL_VERSION", 2))
connection.setup([CASSANDRA_TEST_HOST], protocol_version=protocol_version, default_keyspace='cqlengine_test')
class BaseCassEngTestCase(TestCase):
@classmethod
def setUpClass(cls):
super(BaseCassEngTestCase, cls).setUpClass()
if not connection._hosts:
connection.setup(['localhost:9160'], default_keyspace='cqlengine_test')
# @classmethod
# def setUpClass(cls):
# super(BaseCassEngTestCase, cls).setUpClass()
session = None
def setUp(self):
self.session = get_session()
super(BaseCassEngTestCase, self).setUp()
def assertHasAttr(self, obj, attr):
self.assertTrue(hasattr(obj, attr),
self.assertTrue(hasattr(obj, attr),
"{} doesn't have attribute: {}".format(obj, attr))
def assertNotHasAttr(self, obj, attr):
self.assertFalse(hasattr(obj, attr),
self.assertFalse(hasattr(obj, attr),
"{} shouldn't have the attribute: {}".format(obj, attr))

View File

@@ -1,32 +1,91 @@
from datetime import datetime, timedelta
import json
from uuid import uuid4
from cqlengine import Model, ValidationError
from cqlengine import columns
from cqlengine.management import create_table, delete_table
from cqlengine.management import sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
class TestSetModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
int_set = columns.Set(columns.Integer, required=False)
text_set = columns.Set(columns.Text, required=False)
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
int_set = columns.Set(columns.Integer, required=False)
text_set = columns.Set(columns.Text, required=False)
class JsonTestColumn(columns.Column):
db_type = 'text'
def to_python(self, value):
if value is None: return
if isinstance(value, basestring):
return json.loads(value)
else:
return value
def to_database(self, value):
if value is None: return
return json.dumps(value)
class TestSetColumn(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestSetColumn, cls).setUpClass()
delete_table(TestSetModel)
create_table(TestSetModel)
drop_table(TestSetModel)
sync_table(TestSetModel)
@classmethod
def tearDownClass(cls):
super(TestSetColumn, cls).tearDownClass()
delete_table(TestSetModel)
drop_table(TestSetModel)
def test_add_none_fails(self):
with self.assertRaises(ValidationError):
m = TestSetModel.create(int_set=set([None]))
def test_empty_set_initial(self):
"""
tests that sets are set() by default, should never be none
:return:
"""
m = TestSetModel.create()
m.int_set.add(5)
m.save()
def test_deleting_last_item_should_succeed(self):
m = TestSetModel.create()
m.int_set.add(5)
m.save()
m.int_set.remove(5)
m.save()
m = TestSetModel.get(partition=m.partition)
self.assertNotIn(5, m.int_set)
def test_blind_deleting_last_item_should_succeed(self):
m = TestSetModel.create()
m.int_set.add(5)
m.save()
TestSetModel.objects(partition=m.partition).update(int_set=set())
m = TestSetModel.get(partition=m.partition)
self.assertNotIn(5, m.int_set)
def test_empty_set_retrieval(self):
m = TestSetModel.create()
m2 = TestSetModel.get(partition=m.partition)
m2.int_set.add(3)
def test_io_success(self):
""" Tests that a basic usage works as expected """
m1 = TestSetModel.create(int_set={1,2}, text_set={'kai', 'andreas'})
m1 = TestSetModel.create(int_set={1, 2}, text_set={'kai', 'andreas'})
m2 = TestSetModel.get(partition=m1.partition)
assert isinstance(m2.int_set, set)
@@ -45,54 +104,96 @@ class TestSetColumn(BaseCassEngTestCase):
with self.assertRaises(ValidationError):
TestSetModel.create(int_set={'string', True}, text_set={1, 3.0})
def test_element_count_validation(self):
"""
Tests that big collections are detected and raise an exception.
"""
TestSetModel.create(text_set={str(uuid4()) for i in range(65535)})
with self.assertRaises(ValidationError):
TestSetModel.create(text_set={str(uuid4()) for i in range(65536)})
def test_partial_updates(self):
""" Tests that partial udpates work as expected """
m1 = TestSetModel.create(int_set={1,2,3,4})
m1 = TestSetModel.create(int_set={1, 2, 3, 4})
m1.int_set.add(5)
m1.int_set.remove(1)
assert m1.int_set == {2,3,4,5}
assert m1.int_set == {2, 3, 4, 5}
m1.save()
m2 = TestSetModel.get(partition=m1.partition)
assert m2.int_set == {2,3,4,5}
m2 = TestSetModel.get(partition=m1.partition)
assert m2.int_set == {2, 3, 4, 5}
def test_partial_update_creation(self):
def test_instantiation_with_column_class(self):
"""
Tests that proper update statements are created for a partial set update
:return:
Tests that columns instantiated with a column class work properly
and that the class is instantiated in the constructor
"""
ctx = {}
col = columns.Set(columns.Integer, db_field="TEST")
statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx)
column = columns.Set(columns.Text)
assert isinstance(column.value_col, columns.Text)
def test_instantiation_with_column_instance(self):
"""
Tests that columns instantiated with a column instance work properly
"""
column = columns.Set(columns.Text(min_length=100))
assert isinstance(column.value_col, columns.Text)
def test_to_python(self):
""" Tests that to_python of value column is called """
column = columns.Set(JsonTestColumn)
val = {1, 2, 3}
db_val = column.to_database(val)
assert db_val.value == {json.dumps(v) for v in val}
py_val = column.to_python(db_val.value)
assert py_val == val
def test_default_empty_container_saving(self):
""" tests that the default empty container is not saved if it hasn't been updated """
pkey = uuid4()
# create a row with set data
TestSetModel.create(partition=pkey, int_set={3, 4})
# create another with no set data
TestSetModel.create(partition=pkey)
m = TestSetModel.get(partition=pkey)
self.assertEqual(m.int_set, {3, 4})
assert len([v for v in ctx.values() if {1} == v.value]) == 1
assert len([v for v in ctx.values() if {5} == v.value]) == 1
assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1
assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
class TestListModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
int_list = columns.List(columns.Integer, required=False)
text_list = columns.List(columns.Text, required=False)
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
int_list = columns.List(columns.Integer, required=False)
text_list = columns.List(columns.Text, required=False)
class TestListColumn(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestListColumn, cls).setUpClass()
delete_table(TestListModel)
create_table(TestListModel)
drop_table(TestListModel)
sync_table(TestListModel)
@classmethod
def tearDownClass(cls):
super(TestListColumn, cls).tearDownClass()
delete_table(TestListModel)
drop_table(TestListModel)
def test_initial(self):
tmp = TestListModel.create()
tmp.int_list.append(1)
def test_initial(self):
tmp = TestListModel.create()
tmp2 = TestListModel.get(partition=tmp.partition)
tmp2.int_list.append(1)
def test_io_success(self):
""" Tests that a basic usage works as expected """
m1 = TestListModel.create(int_list=[1,2], text_list=['kai', 'andreas'])
m1 = TestListModel.create(int_list=[1, 2], text_list=['kai', 'andreas'])
m2 = TestListModel.get(partition=m1.partition)
assert isinstance(m2.int_list, list)
@@ -114,6 +215,14 @@ class TestListColumn(BaseCassEngTestCase):
with self.assertRaises(ValidationError):
TestListModel.create(int_list=['string', True], text_list=[1, 3.0])
def test_element_count_validation(self):
"""
Tests that big collections are detected and raise an exception.
"""
TestListModel.create(text_list=[str(uuid4()) for i in range(65535)])
with self.assertRaises(ValidationError):
TestListModel.create(text_list=[str(uuid4()) for i in range(65536)])
def test_partial_updates(self):
""" Tests that partial udpates work as expected """
final = range(10)
@@ -123,43 +232,128 @@ class TestListColumn(BaseCassEngTestCase):
m1.int_list = final
m1.save()
m2 = TestListModel.get(partition=m1.partition)
m2 = TestListModel.get(partition=m1.partition)
assert list(m2.int_list) == final
def test_partial_update_creation(self):
def test_instantiation_with_column_class(self):
"""
Tests that proper update statements are created for a partial list update
:return:
Tests that columns instantiated with a column class work properly
and that the class is instantiated in the constructor
"""
final = range(10)
initial = final[3:7]
column = columns.List(columns.Text)
assert isinstance(column.value_col, columns.Text)
ctx = {}
col = columns.List(columns.Integer, db_field="TEST")
statements = col.get_update_statement(final, initial, ctx)
def test_instantiation_with_column_instance(self):
"""
Tests that columns instantiated with a column instance work properly
"""
column = columns.List(columns.Text(min_length=100))
assert isinstance(column.value_col, columns.Text)
assert len([v for v in ctx.values() if [2,1,0] == v.value]) == 1
assert len([v for v in ctx.values() if [7,8,9] == v.value]) == 1
assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
assert len([s for s in statements if '+ "TEST"' in s]) == 1
def test_to_python(self):
""" Tests that to_python of value column is called """
column = columns.List(JsonTestColumn)
val = [1, 2, 3]
db_val = column.to_database(val)
assert db_val.value == [json.dumps(v) for v in val]
py_val = column.to_python(db_val.value)
assert py_val == val
def test_default_empty_container_saving(self):
""" tests that the default empty container is not saved if it hasn't been updated """
pkey = uuid4()
# create a row with list data
TestListModel.create(partition=pkey, int_list=[1,2,3,4])
# create another with no list data
TestListModel.create(partition=pkey)
m = TestListModel.get(partition=pkey)
self.assertEqual(m.int_list, [1,2,3,4])
def test_remove_entry_works(self):
pkey = uuid4()
tmp = TestListModel.create(partition=pkey, int_list=[1,2])
tmp.int_list.pop()
tmp.update()
tmp = TestListModel.get(partition=pkey)
self.assertEqual(tmp.int_list, [1])
def test_update_from_non_empty_to_empty(self):
pkey = uuid4()
tmp = TestListModel.create(partition=pkey, int_list=[1,2])
tmp.int_list = []
tmp.update()
tmp = TestListModel.get(partition=pkey)
self.assertEqual(tmp.int_list, [])
def test_insert_none(self):
pkey = uuid4()
with self.assertRaises(ValidationError):
TestListModel.create(partition=pkey, int_list=[None])
def test_blind_list_updates_from_none(self):
""" Tests that updates from None work as expected """
m = TestListModel.create(int_list=None)
expected = [1, 2]
m.int_list = expected
m.save()
m2 = TestListModel.get(partition=m.partition)
assert m2.int_list == expected
TestListModel.objects(partition=m.partition).update(int_list=[])
m3 = TestListModel.get(partition=m.partition)
assert m3.int_list == []
class TestMapModel(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
int_map = columns.Map(columns.Integer, columns.UUID, required=False)
text_map = columns.Map(columns.Text, columns.DateTime, required=False)
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
int_map = columns.Map(columns.Integer, columns.UUID, required=False)
text_map = columns.Map(columns.Text, columns.DateTime, required=False)
class TestMapColumn(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestMapColumn, cls).setUpClass()
delete_table(TestMapModel)
create_table(TestMapModel)
drop_table(TestMapModel)
sync_table(TestMapModel)
@classmethod
def tearDownClass(cls):
super(TestMapColumn, cls).tearDownClass()
delete_table(TestMapModel)
drop_table(TestMapModel)
def test_empty_default(self):
tmp = TestMapModel.create()
tmp.int_map['blah'] = 1
def test_add_none_as_map_key(self):
with self.assertRaises(ValidationError):
TestMapModel.create(int_map={None:1})
def test_add_none_as_map_value(self):
with self.assertRaises(ValidationError):
TestMapModel.create(int_map={None:1})
def test_empty_retrieve(self):
tmp = TestMapModel.create()
tmp2 = TestMapModel.get(partition=tmp.partition)
tmp2.int_map['blah'] = 1
def test_remove_last_entry_works(self):
tmp = TestMapModel.create()
tmp.text_map["blah"] = datetime.now()
tmp.save()
del tmp.text_map["blah"]
tmp.save()
tmp = TestMapModel.get(partition=tmp.partition)
self.assertNotIn("blah", tmp.int_map)
def test_io_success(self):
""" Tests that a basic usage works as expected """
@@ -167,7 +361,7 @@ class TestMapColumn(BaseCassEngTestCase):
k2 = uuid4()
now = datetime.now()
then = now + timedelta(days=1)
m1 = TestMapModel.create(int_map={1:k1,2:k2}, text_map={'now':now, 'then':then})
m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then})
m2 = TestMapModel.get(partition=m1.partition)
assert isinstance(m2.int_map, dict)
@@ -188,7 +382,15 @@ class TestMapColumn(BaseCassEngTestCase):
Tests that attempting to use the wrong types will raise an exception
"""
with self.assertRaises(ValidationError):
TestMapModel.create(int_map={'key':2,uuid4():'val'}, text_map={2:5})
TestMapModel.create(int_map={'key': 2, uuid4(): 'val'}, text_map={2: 5})
def test_element_count_validation(self):
"""
Tests that big collections are detected and raise an exception.
"""
TestMapModel.create(text_map={str(uuid4()): i for i in range(65535)})
with self.assertRaises(ValidationError):
TestMapModel.create(text_map={str(uuid4()): i for i in range(65536)})
def test_partial_updates(self):
""" Tests that partial udpates work as expected """
@@ -199,36 +401,93 @@ class TestMapColumn(BaseCassEngTestCase):
earlier = early - timedelta(minutes=30)
later = now + timedelta(minutes=30)
initial = {'now':now, 'early':earlier}
final = {'later':later, 'early':early}
initial = {'now': now, 'early': earlier}
final = {'later': later, 'early': early}
m1 = TestMapModel.create(text_map=initial)
m1.text_map = final
m1.save()
m2 = TestMapModel.get(partition=m1.partition)
m2 = TestMapModel.get(partition=m1.partition)
assert m2.text_map == final
def test_updates_from_none(self):
""" Tests that updates from None work as expected """
m = TestMapModel.create(int_map=None)
expected = {1:uuid4()}
expected = {1: uuid4()}
m.int_map = expected
m.save()
m2 = TestMapModel.get(partition=m.partition)
assert m2.int_map == expected
m2.int_map = None
m2.save()
m3 = TestMapModel.get(partition=m.partition)
assert m3.int_map != expected
def test_blind_updates_from_none(self):
""" Tests that updates from None work as expected """
m = TestMapModel.create(int_map=None)
expected = {1: uuid4()}
m.int_map = expected
m.save()
m2 = TestMapModel.get(partition=m.partition)
assert m2.int_map == expected
TestMapModel.objects(partition=m.partition).update(int_map={})
m3 = TestMapModel.get(partition=m.partition)
assert m3.int_map != expected
def test_updates_to_none(self):
""" Tests that setting the field to None works as expected """
m = TestMapModel.create(int_map={1:uuid4()})
m = TestMapModel.create(int_map={1: uuid4()})
m.int_map = None
m.save()
m2 = TestMapModel.get(partition=m.partition)
assert m2.int_map is None
assert m2.int_map == {}
def test_instantiation_with_column_class(self):
"""
Tests that columns instantiated with a column class work properly
and that the class is instantiated in the constructor
"""
column = columns.Map(columns.Text, columns.Integer)
assert isinstance(column.key_col, columns.Text)
assert isinstance(column.value_col, columns.Integer)
def test_instantiation_with_column_instance(self):
"""
Tests that columns instantiated with a column instance work properly
"""
column = columns.Map(columns.Text(min_length=100), columns.Integer())
assert isinstance(column.key_col, columns.Text)
assert isinstance(column.value_col, columns.Integer)
def test_to_python(self):
""" Tests that to_python of value column is called """
column = columns.Map(JsonTestColumn, JsonTestColumn)
val = {1: 2, 3: 4, 5: 6}
db_val = column.to_database(val)
assert db_val.value == {json.dumps(k):json.dumps(v) for k,v in val.items()}
py_val = column.to_python(db_val.value)
assert py_val == val
def test_default_empty_container_saving(self):
""" tests that the default empty container is not saved if it hasn't been updated """
pkey = uuid4()
tmap = {1: uuid4(), 2: uuid4()}
# create a row with set data
TestMapModel.create(partition=pkey, int_map=tmap)
# create another with no set data
TestMapModel.create(partition=pkey)
m = TestMapModel.get(partition=pkey)
self.assertEqual(m.int_map, tmap)
# def test_partial_update_creation(self):
# """
@@ -246,3 +505,27 @@ class TestMapColumn(BaseCassEngTestCase):
# assert len([v for v in ctx.values() if [7,8,9] == v.value]) == 1
# assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1
# assert len([s for s in statements if '+ "TEST"' in s]) == 1
class TestCamelMapModel(Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
camelMap = columns.Map(columns.Text, columns.Integer, required=False)
class TestCamelMapColumn(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestCamelMapColumn, cls).setUpClass()
drop_table(TestCamelMapModel)
sync_table(TestCamelMapModel)
@classmethod
def tearDownClass(cls):
super(TestCamelMapColumn, cls).tearDownClass()
drop_table(TestCamelMapModel)
def test_camelcase_column(self):
TestCamelMapModel.create(partition=None, camelMap={'blah': 1})

View File

@@ -0,0 +1,94 @@
from uuid import uuid4
from cqlengine import Model
from cqlengine import columns
from cqlengine.management import sync_table, drop_table
from cqlengine.models import ModelDefinitionException
from cqlengine.tests.base import BaseCassEngTestCase
class TestCounterModel(Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.UUID(primary_key=True, default=uuid4)
counter = columns.Counter()
class TestClassConstruction(BaseCassEngTestCase):
def test_defining_a_non_counter_column_fails(self):
""" Tests that defining a non counter column field in a model with a counter column fails """
with self.assertRaises(ModelDefinitionException):
class model(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
counter = columns.Counter()
text = columns.Text()
def test_defining_a_primary_key_counter_column_fails(self):
""" Tests that defining primary keys on counter columns fails """
with self.assertRaises(TypeError):
class model(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.Counter(primary_ley=True)
counter = columns.Counter()
# force it
with self.assertRaises(ModelDefinitionException):
class model(Model):
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.Counter()
cluster.primary_key = True
counter = columns.Counter()
class TestCounterColumn(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestCounterColumn, cls).setUpClass()
drop_table(TestCounterModel)
sync_table(TestCounterModel)
@classmethod
def tearDownClass(cls):
super(TestCounterColumn, cls).tearDownClass()
drop_table(TestCounterModel)
def test_updates(self):
""" Tests that counter updates work as intended """
instance = TestCounterModel.create()
instance.counter += 5
instance.save()
actual = TestCounterModel.get(partition=instance.partition)
assert actual.counter == 5
def test_concurrent_updates(self):
""" Tests updates from multiple queries reaches the correct value """
instance = TestCounterModel.create()
new1 = TestCounterModel.get(partition=instance.partition)
new2 = TestCounterModel.get(partition=instance.partition)
new1.counter += 5
new1.save()
new2.counter += 5
new2.save()
actual = TestCounterModel.get(partition=instance.partition)
assert actual.counter == 10
def test_update_from_none(self):
""" Tests that updating from None uses a create statement """
instance = TestCounterModel()
instance.counter += 1
instance.save()
new = TestCounterModel.get(partition=instance.partition)
assert new.counter == 1
def test_new_instance_defaults_to_zero(self):
""" Tests that instantiating a new model instance will set the counter column to zero """
instance = TestCounterModel()
assert instance.counter == 0

View File

@@ -3,7 +3,10 @@ from datetime import datetime, timedelta
from datetime import date
from datetime import tzinfo
from decimal import Decimal as D
from unittest import TestCase
from uuid import uuid4, uuid1
from cqlengine import ValidationError
from cqlengine.connection import execute
from cqlengine.tests.base import BaseCassEngTestCase
@@ -12,6 +15,8 @@ from cqlengine.columns import Bytes
from cqlengine.columns import Ascii
from cqlengine.columns import Text
from cqlengine.columns import Integer
from cqlengine.columns import BigInt
from cqlengine.columns import VarInt
from cqlengine.columns import DateTime
from cqlengine.columns import Date
from cqlengine.columns import UUID
@@ -19,23 +24,27 @@ from cqlengine.columns import Boolean
from cqlengine.columns import Float
from cqlengine.columns import Decimal
from cqlengine.management import create_table, delete_table
from cqlengine.management import sync_table, drop_table
from cqlengine.models import Model
import sys
class TestDatetime(BaseCassEngTestCase):
class DatetimeTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
created_at = DateTime()
@classmethod
def setUpClass(cls):
super(TestDatetime, cls).setUpClass()
create_table(cls.DatetimeTest)
sync_table(cls.DatetimeTest)
@classmethod
def tearDownClass(cls):
super(TestDatetime, cls).tearDownClass()
delete_table(cls.DatetimeTest)
drop_table(cls.DatetimeTest)
def test_datetime_io(self):
now = datetime.now()
@@ -55,21 +64,78 @@ class TestDatetime(BaseCassEngTestCase):
dt2 = self.DatetimeTest.objects(test_id=0).first()
assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6]
def test_datetime_date_support(self):
today = date.today()
self.DatetimeTest.objects.create(test_id=0, created_at=today)
dt2 = self.DatetimeTest.objects(test_id=0).first()
assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat()
def test_datetime_none(self):
dt = self.DatetimeTest.objects.create(test_id=1, created_at=None)
dt2 = self.DatetimeTest.objects(test_id=1).first()
assert dt2.created_at is None
dts = self.DatetimeTest.objects.filter(test_id=1).values_list('created_at')
assert dts[0][0] is None
class TestBoolDefault(BaseCassEngTestCase):
class BoolDefaultValueTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
stuff = Boolean(default=True)
@classmethod
def setUpClass(cls):
super(TestBoolDefault, cls).setUpClass()
sync_table(cls.BoolDefaultValueTest)
def test_default_is_set(self):
tmp = self.BoolDefaultValueTest.create(test_id=1)
self.assertEqual(True, tmp.stuff)
tmp2 = self.BoolDefaultValueTest.get(test_id=1)
self.assertEqual(True, tmp2.stuff)
class TestVarInt(BaseCassEngTestCase):
class VarIntTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
bignum = VarInt(primary_key=True)
@classmethod
def setUpClass(cls):
super(TestVarInt, cls).setUpClass()
sync_table(cls.VarIntTest)
@classmethod
def tearDownClass(cls):
super(TestVarInt, cls).tearDownClass()
sync_table(cls.VarIntTest)
def test_varint_io(self):
long_int = sys.maxint + 1
int1 = self.VarIntTest.objects.create(test_id=0, bignum=long_int)
int2 = self.VarIntTest.objects(test_id=0).first()
assert int1.bignum == int2.bignum
class TestDate(BaseCassEngTestCase):
class DateTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
created_at = Date()
@classmethod
def setUpClass(cls):
super(TestDate, cls).setUpClass()
create_table(cls.DateTest)
sync_table(cls.DateTest)
@classmethod
def tearDownClass(cls):
super(TestDate, cls).tearDownClass()
delete_table(cls.DateTest)
drop_table(cls.DateTest)
def test_date_io(self):
today = date.today()
@@ -85,23 +151,32 @@ class TestDate(BaseCassEngTestCase):
assert isinstance(dt2.created_at, date)
assert dt2.created_at.isoformat() == now.date().isoformat()
def test_date_none(self):
self.DateTest.objects.create(test_id=1, created_at=None)
dt2 = self.DateTest.objects(test_id=1).first()
assert dt2.created_at is None
dts = self.DateTest.objects(test_id=1).values_list('created_at')
assert dts[0][0] is None
class TestDecimal(BaseCassEngTestCase):
class DecimalTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
dec_val = Decimal()
@classmethod
def setUpClass(cls):
super(TestDecimal, cls).setUpClass()
create_table(cls.DecimalTest)
sync_table(cls.DecimalTest)
@classmethod
def tearDownClass(cls):
super(TestDecimal, cls).tearDownClass()
delete_table(cls.DecimalTest)
drop_table(cls.DecimalTest)
def test_datetime_io(self):
def test_decimal_io(self):
dt = self.DecimalTest.objects.create(test_id=0, dec_val=D('0.00'))
dt2 = self.DecimalTest.objects(test_id=0).first()
assert dt2.dec_val == dt.dec_val
@@ -110,22 +185,55 @@ class TestDecimal(BaseCassEngTestCase):
dt2 = self.DecimalTest.objects(test_id=0).first()
assert dt2.dec_val == D('5')
class TestUUID(BaseCassEngTestCase):
class UUIDTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
a_uuid = UUID(default=uuid4())
@classmethod
def setUpClass(cls):
super(TestUUID, cls).setUpClass()
sync_table(cls.UUIDTest)
@classmethod
def tearDownClass(cls):
super(TestUUID, cls).tearDownClass()
drop_table(cls.UUIDTest)
def test_uuid_str_with_dashes(self):
a_uuid = uuid4()
t0 = self.UUIDTest.create(test_id=0, a_uuid=str(a_uuid))
t1 = self.UUIDTest.get(test_id=0)
assert a_uuid == t1.a_uuid
def test_uuid_str_no_dashes(self):
a_uuid = uuid4()
t0 = self.UUIDTest.create(test_id=1, a_uuid=a_uuid.hex)
t1 = self.UUIDTest.get(test_id=1)
assert a_uuid == t1.a_uuid
class TestTimeUUID(BaseCassEngTestCase):
class TimeUUIDTest(Model):
__keyspace__ = 'test'
test_id = Integer(primary_key=True)
timeuuid = TimeUUID()
timeuuid = TimeUUID(default=uuid1())
@classmethod
def setUpClass(cls):
super(TestTimeUUID, cls).setUpClass()
create_table(cls.TimeUUIDTest)
sync_table(cls.TimeUUIDTest)
@classmethod
def tearDownClass(cls):
super(TestTimeUUID, cls).tearDownClass()
delete_table(cls.TimeUUIDTest)
drop_table(cls.TimeUUIDTest)
def test_timeuuid_io(self):
"""
ensures that
:return:
"""
t0 = self.TimeUUIDTest.create(test_id=0)
t1 = self.TimeUUIDTest.get(test_id=0)
@@ -133,22 +241,32 @@ class TestTimeUUID(BaseCassEngTestCase):
class TestInteger(BaseCassEngTestCase):
class IntegerTest(Model):
test_id = UUID(primary_key=True)
value = Integer(default=0)
__keyspace__ = 'test'
test_id = UUID(primary_key=True, default=lambda:uuid4())
value = Integer(default=0, required=True)
def test_default_zero_fields_validate(self):
""" Tests that integer columns with a default value of 0 validate """
it = self.IntegerTest()
it.validate()
class TestBigInt(BaseCassEngTestCase):
class BigIntTest(Model):
__keyspace__ = 'test'
test_id = UUID(primary_key=True, default=lambda:uuid4())
value = BigInt(default=0, required=True)
def test_default_zero_fields_validate(self):
""" Tests that bigint columns with a default value of 0 validate """
it = self.BigIntTest()
it.validate()
class TestText(BaseCassEngTestCase):
def test_min_length(self):
#min len defaults to 1
col = Text()
with self.assertRaises(ValidationError):
col.validate('')
col.validate('')
col.validate('b')
@@ -174,7 +292,7 @@ class TestText(BaseCassEngTestCase):
Text().validate(bytearray('bytearray'))
with self.assertRaises(ValidationError):
Text().validate(None)
Text(required=True).validate(None)
with self.assertRaises(ValidationError):
Text().validate(5)
@@ -182,15 +300,48 @@ class TestText(BaseCassEngTestCase):
with self.assertRaises(ValidationError):
Text().validate(True)
def test_non_required_validation(self):
""" Tests that validation is ok on none and blank values if required is False """
Text().validate('')
Text().validate(None)
class TestExtraFieldsRaiseException(BaseCassEngTestCase):
class TestModel(Model):
__keyspace__ = 'test'
id = UUID(primary_key=True, default=uuid4)
def test_extra_field(self):
with self.assertRaises(ValidationError):
self.TestModel.create(bacon=5000)
class TestPythonDoesntDieWhenExtraFieldIsInCassandra(BaseCassEngTestCase):
class TestModel(Model):
__keyspace__ = 'test'
__table_name__ = 'alter_doesnt_break_running_app'
id = UUID(primary_key=True, default=uuid4)
def test_extra_field(self):
drop_table(self.TestModel)
sync_table(self.TestModel)
self.TestModel.create()
execute("ALTER TABLE {} add blah int".format(self.TestModel.column_family_name(include_keyspace=True)))
self.TestModel.objects().all()
class TestTimeUUIDFromDatetime(TestCase):
def test_conversion_specific_date(self):
dt = datetime(1981, 7, 11, microsecond=555000)
uuid = TimeUUID.from_datetime(dt)
from uuid import UUID
assert isinstance(uuid, UUID)
ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp
new_dt = datetime.utcfromtimestamp(ts)
# checks that we created a UUID1 with the proper timestamp
assert new_dt == dt

View File

@@ -0,0 +1,171 @@
from datetime import datetime, timedelta
from decimal import Decimal
from uuid import uuid1, uuid4, UUID
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.management import sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine.columns import ValueQuoter
from cqlengine import columns
import unittest
class BaseColumnIOTest(BaseCassEngTestCase):
"""
Tests that values are come out of cassandra in the format we expect
To test a column type, subclass this test, define the column, and the primary key
and data values you want to test
"""
# The generated test model is assigned here
_generated_model = None
# the column we want to test
column = None
# the values we want to test against, you can
# use a single value, or multiple comma separated values
pkey_val = None
data_val = None
@classmethod
def setUpClass(cls):
super(BaseColumnIOTest, cls).setUpClass()
#if the test column hasn't been defined, bail out
if not cls.column: return
# create a table with the given column
class IOTestModel(Model):
__keyspace__ = 'test'
table_name = cls.column.db_type + "_io_test_model_{}".format(uuid4().hex[:8])
pkey = cls.column(primary_key=True)
data = cls.column()
cls._generated_model = IOTestModel
sync_table(cls._generated_model)
#tupleify the tested values
if not isinstance(cls.pkey_val, tuple):
cls.pkey_val = cls.pkey_val,
if not isinstance(cls.data_val, tuple):
cls.data_val = cls.data_val,
@classmethod
def tearDownClass(cls):
super(BaseColumnIOTest, cls).tearDownClass()
if not cls.column: return
drop_table(cls._generated_model)
def comparator_converter(self, val):
""" If you want to convert the original value used to compare the model vales """
return val
def test_column_io(self):
""" Tests the given models class creates and retrieves values as expected """
if not self.column: return
for pkey, data in zip(self.pkey_val, self.data_val):
#create
m1 = self._generated_model.create(pkey=pkey, data=data)
#get
m2 = self._generated_model.get(pkey=pkey)
assert m1.pkey == m2.pkey == self.comparator_converter(pkey), self.column
assert m1.data == m2.data == self.comparator_converter(data), self.column
#delete
self._generated_model.filter(pkey=pkey).delete()
class TestBlobIO(BaseColumnIOTest):
column = columns.Bytes
pkey_val = 'blake', uuid4().bytes
data_val = 'eggleston', uuid4().bytes
class TestTextIO(BaseColumnIOTest):
column = columns.Text
pkey_val = 'bacon'
data_val = 'monkey'
class TestNonBinaryTextIO(BaseColumnIOTest):
column = columns.Text
pkey_val = 'bacon'
data_val = '0xmonkey'
class TestInteger(BaseColumnIOTest):
column = columns.Integer
pkey_val = 5
data_val = 6
class TestBigInt(BaseColumnIOTest):
column = columns.BigInt
pkey_val = 6
data_val = pow(2, 63) - 1
class TestDateTime(BaseColumnIOTest):
column = columns.DateTime
now = datetime(*datetime.now().timetuple()[:6])
pkey_val = now
data_val = now + timedelta(days=1)
class TestDate(BaseColumnIOTest):
column = columns.Date
now = datetime.now().date()
pkey_val = now
data_val = now + timedelta(days=1)
class TestUUID(BaseColumnIOTest):
column = columns.UUID
pkey_val = str(uuid4()), uuid4()
data_val = str(uuid4()), uuid4()
def comparator_converter(self, val):
return val if isinstance(val, UUID) else UUID(val)
class TestTimeUUID(BaseColumnIOTest):
column = columns.TimeUUID
pkey_val = str(uuid1()), uuid1()
data_val = str(uuid1()), uuid1()
def comparator_converter(self, val):
return val if isinstance(val, UUID) else UUID(val)
class TestFloatIO(BaseColumnIOTest):
column = columns.Float
pkey_val = 3.14
data_val = -1982.11
class TestDecimalIO(BaseColumnIOTest):
column = columns.Decimal
pkey_val = Decimal('1.35'), 5, '2.4'
data_val = Decimal('0.005'), 3.5, '8'
def comparator_converter(self, val):
return Decimal(val)
class TestQuoter(unittest.TestCase):
def test_equals(self):
assert ValueQuoter(False) == ValueQuoter(False)
assert ValueQuoter(1) == ValueQuoter(1)
assert ValueQuoter("foo") == ValueQuoter("foo")
assert ValueQuoter(1.55) == ValueQuoter(1.55)

View File

View File

@@ -0,0 +1,237 @@
import copy
import json
from time import sleep
from mock import patch, MagicMock
from cqlengine import Model, columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy
from cqlengine.exceptions import CQLEngineException
from cqlengine.management import get_compaction_options, drop_table, sync_table, get_table_settings
from cqlengine.tests.base import BaseCassEngTestCase
class CompactionModel(Model):
__keyspace__ = 'test'
__compaction__ = None
cid = columns.UUID(primary_key=True)
name = columns.Text()
class BaseCompactionTest(BaseCassEngTestCase):
def assert_option_fails(self, key):
# key is a normal_key, converted to
# __compaction_key__
key = "__compaction_{}__".format(key)
with patch.object(self.model, key, 10), \
self.assertRaises(CQLEngineException):
get_compaction_options(self.model)
class SizeTieredCompactionTest(BaseCompactionTest):
def setUp(self):
self.model = copy.deepcopy(CompactionModel)
self.model.__compaction__ = SizeTieredCompactionStrategy
def test_size_tiered(self):
result = get_compaction_options(self.model)
assert result['class'] == SizeTieredCompactionStrategy
def test_min_threshold(self):
self.model.__compaction_min_threshold__ = 2
result = get_compaction_options(self.model)
assert result['min_threshold'] == '2'
class LeveledCompactionTest(BaseCompactionTest):
def setUp(self):
self.model = copy.deepcopy(CompactionLeveledStrategyModel)
def test_simple_leveled(self):
result = get_compaction_options(self.model)
assert result['class'] == LeveledCompactionStrategy
def test_bucket_high_fails(self):
self.assert_option_fails('bucket_high')
def test_bucket_low_fails(self):
self.assert_option_fails('bucket_low')
def test_max_threshold_fails(self):
self.assert_option_fails('max_threshold')
def test_min_threshold_fails(self):
self.assert_option_fails('min_threshold')
def test_min_sstable_size_fails(self):
self.assert_option_fails('min_sstable_size')
def test_sstable_size_in_mb(self):
with patch.object(self.model, '__compaction_sstable_size_in_mb__', 32):
result = get_compaction_options(self.model)
assert result['sstable_size_in_mb'] == '32'
class LeveledcompactionTestTable(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
__compaction_sstable_size_in_mb__ = 64
user_id = columns.UUID(primary_key=True)
name = columns.Text()
from cqlengine.management import schema_columnfamilies
class AlterTableTest(BaseCassEngTestCase):
def test_alter_is_called_table(self):
drop_table(LeveledcompactionTestTable)
sync_table(LeveledcompactionTestTable)
with patch('cqlengine.management.update_compaction') as mock:
sync_table(LeveledcompactionTestTable)
assert mock.called == 1
def test_compaction_not_altered_without_changes_leveled(self):
from cqlengine.management import update_compaction
class LeveledCompactionChangesDetectionTest(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
__compaction_sstable_size_in_mb__ = 160
__compaction_tombstone_threshold__ = 0.125
__compaction_tombstone_compaction_interval__ = 3600
pk = columns.Integer(primary_key=True)
drop_table(LeveledCompactionChangesDetectionTest)
sync_table(LeveledCompactionChangesDetectionTest)
assert not update_compaction(LeveledCompactionChangesDetectionTest)
def test_compaction_not_altered_without_changes_sizetiered(self):
from cqlengine.management import update_compaction
class SizeTieredCompactionChangesDetectionTest(Model):
__keyspace__ = 'test'
__compaction__ = SizeTieredCompactionStrategy
__compaction_bucket_high__ = 20
__compaction_bucket_low__ = 10
__compaction_max_threshold__ = 200
__compaction_min_threshold__ = 100
__compaction_min_sstable_size__ = 1000
__compaction_tombstone_threshold__ = 0.125
__compaction_tombstone_compaction_interval__ = 3600
pk = columns.Integer(primary_key=True)
drop_table(SizeTieredCompactionChangesDetectionTest)
sync_table(SizeTieredCompactionChangesDetectionTest)
assert not update_compaction(SizeTieredCompactionChangesDetectionTest)
def test_alter_actually_alters(self):
tmp = copy.deepcopy(LeveledcompactionTestTable)
drop_table(tmp)
sync_table(tmp)
tmp.__compaction__ = SizeTieredCompactionStrategy
tmp.__compaction_sstable_size_in_mb__ = None
sync_table(tmp)
table_settings = get_table_settings(tmp)
self.assertRegexpMatches(table_settings.options['compaction_strategy_class'], '.*SizeTieredCompactionStrategy$')
def test_alter_options(self):
class AlterTable(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
__compaction_sstable_size_in_mb__ = 64
user_id = columns.UUID(primary_key=True)
name = columns.Text()
drop_table(AlterTable)
sync_table(AlterTable)
AlterTable.__compaction_sstable_size_in_mb__ = 128
sync_table(AlterTable)
class EmptyCompactionTest(BaseCassEngTestCase):
def test_empty_compaction(self):
class EmptyCompactionModel(Model):
__keyspace__ = 'test'
__compaction__ = None
cid = columns.UUID(primary_key=True)
name = columns.Text()
result = get_compaction_options(EmptyCompactionModel)
self.assertEqual({}, result)
class CompactionLeveledStrategyModel(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
cid = columns.UUID(primary_key=True)
name = columns.Text()
class CompactionSizeTieredModel(Model):
__keyspace__ = 'test'
__compaction__ = SizeTieredCompactionStrategy
cid = columns.UUID(primary_key=True)
name = columns.Text()
class OptionsTest(BaseCassEngTestCase):
def test_all_size_tiered_options(self):
class AllSizeTieredOptionsModel(Model):
__keyspace__ = 'test'
__compaction__ = SizeTieredCompactionStrategy
__compaction_bucket_low__ = .3
__compaction_bucket_high__ = 2
__compaction_min_threshold__ = 2
__compaction_max_threshold__ = 64
__compaction_tombstone_compaction_interval__ = 86400
cid = columns.UUID(primary_key=True)
name = columns.Text()
drop_table(AllSizeTieredOptionsModel)
sync_table(AllSizeTieredOptionsModel)
options = get_table_settings(AllSizeTieredOptionsModel).options['compaction_strategy_options']
options = json.loads(options)
expected = {u'min_threshold': u'2',
u'bucket_low': u'0.3',
u'tombstone_compaction_interval': u'86400',
u'bucket_high': u'2',
u'max_threshold': u'64'}
self.assertDictEqual(options, expected)
def test_all_leveled_options(self):
class AllLeveledOptionsModel(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
__compaction_sstable_size_in_mb__ = 64
cid = columns.UUID(primary_key=True)
name = columns.Text()
drop_table(AllLeveledOptionsModel)
sync_table(AllLeveledOptionsModel)
settings = get_table_settings(AllLeveledOptionsModel).options
options = json.loads(settings['compaction_strategy_options'])
self.assertDictEqual(options, {u'sstable_size_in_mb': u'64'})

View File

@@ -1,44 +1,14 @@
from cqlengine.management import create_table, delete_table
from cqlengine import ALL, CACHING_ALL, CACHING_NONE
from cqlengine.exceptions import CQLEngineException
from cqlengine.management import get_fields, sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.connection import ConnectionPool
from mock import Mock
from cqlengine import management
from cqlengine.tests.query.test_queryset import TestModel
from cqlengine.models import Model
from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy
class ConnectionPoolTestCase(BaseCassEngTestCase):
"""Test cassandra connection pooling."""
def setUp(self):
ConnectionPool.clear()
def test_should_create_single_connection_on_request(self):
"""Should create a single connection on first request"""
result = ConnectionPool.get()
self.assertIsNotNone(result)
self.assertEquals(0, ConnectionPool._queue.qsize())
ConnectionPool._queue.put(result)
self.assertEquals(1, ConnectionPool._queue.qsize())
def test_should_close_connection_if_queue_is_full(self):
"""Should close additional connections if queue is full"""
connections = [ConnectionPool.get() for x in range(10)]
for conn in connections:
ConnectionPool.put(conn)
fake_conn = Mock()
ConnectionPool.put(fake_conn)
fake_conn.close.assert_called_once_with()
def test_should_pop_connections_from_queue(self):
"""Should pull existing connections off of the queue"""
conn = ConnectionPool.get()
ConnectionPool.put(conn)
self.assertEquals(1, ConnectionPool._queue.qsize())
self.assertEquals(conn, ConnectionPool.get())
self.assertEquals(0, ConnectionPool._queue.qsize())
class CreateKeyspaceTest(BaseCassEngTestCase):
def test_create_succeeeds(self):
@@ -51,7 +21,219 @@ class DeleteTableTest(BaseCassEngTestCase):
"""
"""
create_table(TestModel)
sync_table(TestModel)
delete_table(TestModel)
delete_table(TestModel)
drop_table(TestModel)
drop_table(TestModel)
class LowercaseKeyModel(Model):
__keyspace__ = 'test'
first_key = columns.Integer(primary_key=True)
second_key = columns.Integer(primary_key=True)
some_data = columns.Text()
class CapitalizedKeyModel(Model):
__keyspace__ = 'test'
firstKey = columns.Integer(primary_key=True)
secondKey = columns.Integer(primary_key=True)
someData = columns.Text()
class PrimaryKeysOnlyModel(Model):
__keyspace__ = 'test'
__compaction__ = LeveledCompactionStrategy
first_ey = columns.Integer(primary_key=True)
second_key = columns.Integer(primary_key=True)
class CapitalizedKeyTest(BaseCassEngTestCase):
def test_table_definition(self):
""" Tests that creating a table with capitalized column names succeedso """
sync_table(LowercaseKeyModel)
sync_table(CapitalizedKeyModel)
drop_table(LowercaseKeyModel)
drop_table(CapitalizedKeyModel)
class FirstModel(Model):
__keyspace__ = 'test'
__table_name__ = 'first_model'
first_key = columns.UUID(primary_key=True)
second_key = columns.UUID()
third_key = columns.Text()
class SecondModel(Model):
__keyspace__ = 'test'
__table_name__ = 'first_model'
first_key = columns.UUID(primary_key=True)
second_key = columns.UUID()
third_key = columns.Text()
fourth_key = columns.Text()
class ThirdModel(Model):
__keyspace__ = 'test'
__table_name__ = 'first_model'
first_key = columns.UUID(primary_key=True)
second_key = columns.UUID()
third_key = columns.Text()
# removed fourth key, but it should stay in the DB
blah = columns.Map(columns.Text, columns.Text)
class FourthModel(Model):
__keyspace__ = 'test'
__table_name__ = 'first_model'
first_key = columns.UUID(primary_key=True)
second_key = columns.UUID()
third_key = columns.Text()
# removed fourth key, but it should stay in the DB
renamed = columns.Map(columns.Text, columns.Text, db_field='blah')
class AddColumnTest(BaseCassEngTestCase):
def setUp(self):
drop_table(FirstModel)
def test_add_column(self):
sync_table(FirstModel)
fields = get_fields(FirstModel)
# this should contain the second key
self.assertEqual(len(fields), 2)
# get schema
sync_table(SecondModel)
fields = get_fields(FirstModel)
self.assertEqual(len(fields), 3)
sync_table(ThirdModel)
fields = get_fields(FirstModel)
self.assertEqual(len(fields), 4)
sync_table(FourthModel)
fields = get_fields(FirstModel)
self.assertEqual(len(fields), 4)
class ModelWithTableProperties(Model):
__keyspace__ = 'test'
# Set random table properties
__bloom_filter_fp_chance__ = 0.76328
__caching__ = CACHING_ALL
__comment__ = 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR'
__default_time_to_live__ = 4756
__gc_grace_seconds__ = 2063
__index_interval__ = 98706
__memtable_flush_period_in_ms__ = 43681
__populate_io_cache_on_flush__ = True
__read_repair_chance__ = 0.17985
__replicate_on_write__ = False
__dclocal_read_repair_chance__ = 0.50811
key = columns.UUID(primary_key=True)
class TablePropertiesTests(BaseCassEngTestCase):
def setUp(self):
drop_table(ModelWithTableProperties)
def test_set_table_properties(self):
sync_table(ModelWithTableProperties)
self.assertDictContainsSubset({
'bloom_filter_fp_chance': 0.76328,
'caching': CACHING_ALL,
'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR',
'default_time_to_live': 4756,
'gc_grace_seconds': 2063,
'index_interval': 98706,
'memtable_flush_period_in_ms': 43681,
'populate_io_cache_on_flush': True,
'read_repair_chance': 0.17985,
'replicate_on_write': False,
# For some reason 'dclocal_read_repair_chance' in CQL is called
# just 'local_read_repair_chance' in the schema table.
# Source: https://issues.apache.org/jira/browse/CASSANDRA-6717
# TODO: due to a bug in the native driver i'm not seeing the local read repair chance show up
#'local_read_repair_chance': 0.50811,
}, management.get_table_settings(ModelWithTableProperties).options)
def test_table_property_update(self):
ModelWithTableProperties.__bloom_filter_fp_chance__ = 0.66778
ModelWithTableProperties.__caching__ = CACHING_NONE
ModelWithTableProperties.__comment__ = 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy'
ModelWithTableProperties.__default_time_to_live__ = 65178
ModelWithTableProperties.__gc_grace_seconds__ = 96362
ModelWithTableProperties.__index_interval__ = 94207
ModelWithTableProperties.__memtable_flush_period_in_ms__ = 60210
ModelWithTableProperties.__populate_io_cache_on_flush__ = False
ModelWithTableProperties.__read_repair_chance__ = 0.2989
ModelWithTableProperties.__replicate_on_write__ = True
ModelWithTableProperties.__dclocal_read_repair_chance__ = 0.12732
sync_table(ModelWithTableProperties)
table_settings = management.get_table_settings(ModelWithTableProperties).options
self.assertDictContainsSubset({
'bloom_filter_fp_chance': 0.66778,
'caching': CACHING_NONE,
'comment': 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy',
'default_time_to_live': 65178,
'gc_grace_seconds': 96362,
'index_interval': 94207,
'memtable_flush_period_in_ms': 60210,
'populate_io_cache_on_flush': False,
'read_repair_chance': 0.2989,
'replicate_on_write': True,
# TODO see above comment re: native driver missing local read repair chance
# 'local_read_repair_chance': 0.12732,
}, table_settings)
class SyncTableTests(BaseCassEngTestCase):
def setUp(self):
drop_table(PrimaryKeysOnlyModel)
def test_sync_table_works_with_primary_keys_only_tables(self):
# This is "create table":
sync_table(PrimaryKeysOnlyModel)
# let's make sure settings persisted correctly:
assert PrimaryKeysOnlyModel.__compaction__ == LeveledCompactionStrategy
# blows up with DoesNotExist if table does not exist
table_settings = management.get_table_settings(PrimaryKeysOnlyModel)
# let make sure the flag we care about
assert LeveledCompactionStrategy in table_settings.options['compaction_strategy_class']
# Now we are "updating" the table:
# setting up something to change
PrimaryKeysOnlyModel.__compaction__ = SizeTieredCompactionStrategy
# primary-keys-only tables do not create entries in system.schema_columns
# table. Only non-primary keys are added to that table.
# Our code must deal with that eventuality properly (not crash)
# on subsequent runs of sync_table (which runs get_fields internally)
get_fields(PrimaryKeysOnlyModel)
sync_table(PrimaryKeysOnlyModel)
table_settings = management.get_table_settings(PrimaryKeysOnlyModel)
assert SizeTieredCompactionStrategy in table_settings.options['compaction_strategy_class']
class NonModelFailureTest(BaseCassEngTestCase):
class FakeModel(object):
pass
def test_failure(self):
with self.assertRaises(CQLEngineException):
sync_table(self.FakeModel)

View File

@@ -1,7 +1,10 @@
from uuid import uuid4
import warnings
from cqlengine.query import QueryException, ModelQuerySet, DMLQuery
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.exceptions import ModelException
from cqlengine.models import Model
from cqlengine.exceptions import ModelException, CQLEngineException
from cqlengine.models import Model, ModelDefinitionException, ColumnQueryEvaluator, UndefinedKeyspaceWarning
from cqlengine import columns
import cqlengine
@@ -17,6 +20,8 @@ class TestModelClassFunction(BaseCassEngTestCase):
"""
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
text = columns.Text()
#check class attibutes
@@ -37,6 +42,8 @@ class TestModelClassFunction(BaseCassEngTestCase):
-the db_map allows columns
"""
class WildDBNames(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
content = columns.Text(db_field='words_and_whatnot')
numbers = columns.Integer(db_field='integers_etc')
@@ -60,17 +67,29 @@ class TestModelClassFunction(BaseCassEngTestCase):
"""
class Stuff(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
words = columns.Text()
content = columns.Text()
numbers = columns.Integer()
self.assertEquals(Stuff._columns.keys(), ['id', 'words', 'content', 'numbers'])
def test_exception_raised_when_creating_class_without_pk(self):
with self.assertRaises(ModelDefinitionException):
class TestModel(Model):
__keyspace__ = 'test'
count = columns.Integer()
text = columns.Text(required=False)
def test_value_managers_are_keeping_model_instances_isolated(self):
"""
Tests that instance value managers are isolated from other instances
"""
class Stuff(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
num = columns.Integer()
inst1 = Stuff(num=5)
@@ -85,6 +104,8 @@ class TestModelClassFunction(BaseCassEngTestCase):
Tests that fields defined on the super class are inherited properly
"""
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
text = columns.Text()
class InheritedModel(TestModel):
@@ -93,6 +114,15 @@ class TestModelClassFunction(BaseCassEngTestCase):
assert 'text' in InheritedModel._columns
assert 'numbers' in InheritedModel._columns
def test_column_family_name_generation(self):
""" Tests that auto column family name generation works as expected """
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
text = columns.Text()
assert TestModel.column_family_name(include_keyspace=False) == 'test_model'
def test_normal_fields_can_be_defined_between_primary_keys(self):
"""
Tests tha non primary key fields can be defined between primary key fields
@@ -117,24 +147,224 @@ class TestModelClassFunction(BaseCassEngTestCase):
"""
Test that metadata defined in one class, is not inherited by subclasses
"""
def test_partition_keys(self):
"""
Test compound partition key definition
"""
class ModelWithPartitionKeys(cqlengine.Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
c1 = cqlengine.Text(primary_key=True)
p1 = cqlengine.Text(partition_key=True)
p2 = cqlengine.Text(partition_key=True)
cols = ModelWithPartitionKeys._columns
self.assertTrue(cols['c1'].primary_key)
self.assertFalse(cols['c1'].partition_key)
self.assertTrue(cols['p1'].primary_key)
self.assertTrue(cols['p1'].partition_key)
self.assertTrue(cols['p2'].primary_key)
self.assertTrue(cols['p2'].partition_key)
obj = ModelWithPartitionKeys(p1='a', p2='b')
self.assertEquals(obj.pk, ('a', 'b'))
def test_del_attribute_is_assigned_properly(self):
""" Tests that columns that can be deleted have the del attribute """
class DelModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
key = columns.Integer(primary_key=True)
data = columns.Integer(required=False)
model = DelModel(key=4, data=5)
del model.data
with self.assertRaises(AttributeError):
del model.key
def test_does_not_exist_exceptions_are_not_shared_between_model(self):
""" Tests that DoesNotExist exceptions are not the same exception between models """
class Model1(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
class Model2(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
try:
raise Model1.DoesNotExist
except Model2.DoesNotExist:
assert False, "Model1 exception should not be caught by Model2"
except Model1.DoesNotExist:
#expected
pass
def test_does_not_exist_inherits_from_superclass(self):
""" Tests that a DoesNotExist exception can be caught by it's parent class DoesNotExist """
class Model1(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
class Model2(Model1):
pass
try:
raise Model2.DoesNotExist
except Model1.DoesNotExist:
#expected
pass
except Exception:
assert False, "Model2 exception should not be caught by Model1"
def test_abstract_model_keyspace_warning_is_skipped(self):
with warnings.catch_warnings(record=True) as warn:
class NoKeyspace(Model):
__abstract__ = True
key = columns.UUID(primary_key=True)
self.assertEqual(len(warn), 0)
class TestManualTableNaming(BaseCassEngTestCase):
class RenamedTest(cqlengine.Model):
keyspace = 'whatever'
table_name = 'manual_name'
__keyspace__ = 'whatever'
__table_name__ = 'manual_name'
id = cqlengine.UUID(primary_key=True)
data = cqlengine.Text()
def test_proper_table_naming(self):
assert self.RenamedTest.column_family_name(include_keyspace=False) == 'manual_name'
assert self.RenamedTest.column_family_name(include_keyspace=True) == 'whatever.manual_name'
def test_manual_table_name_is_not_inherited(self):
class InheritedTest(self.RenamedTest): pass
assert InheritedTest.table_name is None
class AbstractModel(Model):
__abstract__ = True
__keyspace__ = 'test'
class ConcreteModel(AbstractModel):
pkey = columns.Integer(primary_key=True)
data = columns.Integer()
class AbstractModelWithCol(Model):
__keyspace__ = 'test'
__abstract__ = True
pkey = columns.Integer(primary_key=True)
class ConcreteModelWithCol(AbstractModelWithCol):
data = columns.Integer()
class AbstractModelWithFullCols(Model):
__abstract__ = True
__keyspace__ = 'test'
pkey = columns.Integer(primary_key=True)
data = columns.Integer()
class TestAbstractModelClasses(BaseCassEngTestCase):
def test_id_field_is_not_created(self):
""" Tests that an id field is not automatically generated on abstract classes """
assert not hasattr(AbstractModel, 'id')
assert not hasattr(AbstractModelWithCol, 'id')
def test_id_field_is_not_created_on_subclass(self):
assert not hasattr(ConcreteModel, 'id')
def test_abstract_attribute_is_not_inherited(self):
""" Tests that __abstract__ attribute is not inherited """
assert not ConcreteModel.__abstract__
assert not ConcreteModelWithCol.__abstract__
def test_attempting_to_save_abstract_model_fails(self):
""" Attempting to save a model from an abstract model should fail """
with self.assertRaises(CQLEngineException):
AbstractModelWithFullCols.create(pkey=1, data=2)
def test_attempting_to_create_abstract_table_fails(self):
""" Attempting to create a table from an abstract model should fail """
from cqlengine.management import sync_table
with self.assertRaises(CQLEngineException):
sync_table(AbstractModelWithFullCols)
def test_attempting_query_on_abstract_model_fails(self):
""" Tests attempting to execute query with an abstract model fails """
with self.assertRaises(CQLEngineException):
iter(AbstractModelWithFullCols.objects(pkey=5)).next()
def test_abstract_columns_are_inherited(self):
""" Tests that columns defined in the abstract class are inherited into the concrete class """
assert hasattr(ConcreteModelWithCol, 'pkey')
assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator)
assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column)
def test_concrete_class_table_creation_cycle(self):
""" Tests that models with inherited abstract classes can be created, and have io performed """
from cqlengine.management import sync_table, drop_table
sync_table(ConcreteModelWithCol)
w1 = ConcreteModelWithCol.create(pkey=5, data=6)
w2 = ConcreteModelWithCol.create(pkey=6, data=7)
r1 = ConcreteModelWithCol.get(pkey=5)
r2 = ConcreteModelWithCol.get(pkey=6)
assert w1.pkey == r1.pkey
assert w1.data == r1.data
assert w2.pkey == r2.pkey
assert w2.data == r2.data
drop_table(ConcreteModelWithCol)
class TestCustomQuerySet(BaseCassEngTestCase):
""" Tests overriding the default queryset class """
class TestException(Exception): pass
def test_overriding_queryset(self):
class QSet(ModelQuerySet):
def create(iself, **kwargs):
raise self.TestException
class CQModel(Model):
__queryset__ = QSet
__keyspace__ = 'test'
part = columns.UUID(primary_key=True)
data = columns.Text()
with self.assertRaises(self.TestException):
CQModel.create(part=uuid4(), data='s')
def test_overriding_dmlqueryset(self):
class DMLQ(DMLQuery):
def save(iself):
raise self.TestException
class CDQModel(Model):
__keyspace__ = 'test'
__dmlquery__ = DMLQ
part = columns.UUID(primary_key=True)
data = columns.Text()
with self.assertRaises(self.TestException):
CDQModel().save()
class TestCachedLengthIsNotCarriedToSubclasses(BaseCassEngTestCase):
def test_subclassing(self):
length = len(ConcreteModelWithCol())
class AlreadyLoadedTest(ConcreteModelWithCol):
new_field = columns.Integer()
self.assertGreater(len(AlreadyLoadedTest()), length)

View File

@@ -1,12 +1,15 @@
from unittest import skip
from uuid import uuid4
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.management import create_table
from cqlengine.management import delete_table
from cqlengine.management import sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine import columns
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
@@ -15,7 +18,7 @@ class TestEqualityOperators(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestEqualityOperators, cls).setUpClass()
create_table(TestModel)
sync_table(TestModel)
def setUp(self):
super(TestEqualityOperators, self).setUp()
@@ -25,7 +28,7 @@ class TestEqualityOperators(BaseCassEngTestCase):
@classmethod
def tearDownClass(cls):
super(TestEqualityOperators, cls).tearDownClass()
delete_table(TestModel)
drop_table(TestModel)
def test_an_instance_evaluates_as_equal_to_itself(self):
"""

View File

@@ -0,0 +1,56 @@
from unittest import TestCase
from cqlengine.models import Model, ModelDefinitionException
from cqlengine import columns
class TestModel(TestCase):
""" Tests the non-io functionality of models """
def test_instance_equality(self):
""" tests the model equality functionality """
class EqualityModel(Model):
__keyspace__ = 'test'
pk = columns.Integer(primary_key=True)
m0 = EqualityModel(pk=0)
m1 = EqualityModel(pk=1)
self.assertEqual(m0, m0)
self.assertNotEqual(m0, m1)
def test_model_equality(self):
""" tests the model equality functionality """
class EqualityModel0(Model):
__keyspace__ = 'test'
pk = columns.Integer(primary_key=True)
class EqualityModel1(Model):
__keyspace__ = 'test'
kk = columns.Integer(primary_key=True)
m0 = EqualityModel0(pk=0)
m1 = EqualityModel1(kk=1)
self.assertEqual(m0, m0)
self.assertNotEqual(m0, m1)
class BuiltInAttributeConflictTest(TestCase):
"""tests Model definitions that conflict with built-in attributes/methods"""
def test_model_with_attribute_name_conflict(self):
"""should raise exception when model defines column that conflicts with built-in attribute"""
with self.assertRaises(ModelDefinitionException):
class IllegalTimestampColumnModel(Model):
__keyspace__ = 'test'
my_primary_key = columns.Integer(primary_key=True)
timestamp = columns.BigInt()
def test_model_with_method_name_conflict(self):
"""should raise exception when model defines column that conflicts with built-in method"""
with self.assertRaises(ModelDefinitionException):
class IllegalFilterColumnModel(Model):
__keyspace__ = 'test'
my_primary_key = columns.Integer(primary_key=True)
filter = columns.Text()

View File

@@ -1,39 +1,78 @@
from uuid import uuid4
import random
from datetime import date
from operator import itemgetter
from cqlengine.exceptions import CQLEngineException
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.management import create_table
from cqlengine.management import delete_table
from cqlengine.management import sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine import columns
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
a_bool = columns.Boolean(default=False)
class TestModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
a_bool = columns.Boolean(default=False)
class TestModelIO(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestModelIO, cls).setUpClass()
create_table(TestModel)
sync_table(TestModel)
@classmethod
def tearDownClass(cls):
super(TestModelIO, cls).tearDownClass()
delete_table(TestModel)
drop_table(TestModel)
def test_model_save_and_load(self):
"""
Tests that models can be saved and retrieved
"""
tm = TestModel.create(count=8, text='123456789')
self.assertIsInstance(tm, TestModel)
tm2 = TestModel.objects(id=tm.pk).first()
self.assertIsInstance(tm2, TestModel)
for cname in tm._columns.keys():
self.assertEquals(getattr(tm, cname), getattr(tm2, cname))
def test_model_read_as_dict(self):
"""
Tests that columns of an instance can be read as a dict.
"""
tm = TestModel.create(count=8, text='123456789', a_bool=True)
column_dict = {
'id': tm.id,
'count': tm.count,
'text': tm.text,
'a_bool': tm.a_bool,
}
self.assertEquals(sorted(tm.keys()), sorted(column_dict.keys()))
self.assertEquals(sorted(tm.values()), sorted(column_dict.values()))
self.assertEquals(
sorted(tm.items(), key=itemgetter(0)),
sorted(column_dict.items(), key=itemgetter(0)))
self.assertEquals(len(tm), len(column_dict))
for column_id in column_dict.keys():
self.assertEqual(tm[column_id], column_dict[column_id])
tm['count'] = 6
self.assertEqual(tm.count, 6)
def test_model_updating_works_properly(self):
"""
Tests that subsequent saves after initial model creation work
@@ -65,35 +104,38 @@ class TestModelIO(BaseCassEngTestCase):
tm.save()
tm2 = TestModel.objects(id=tm.pk).first()
self.assertIsInstance(tm2, TestModel)
assert tm2.text is None
assert tm2._values['text'].previous_value is None
def test_a_sensical_error_is_raised_if_you_try_to_create_a_table_twice(self):
"""
"""
create_table(TestModel)
create_table(TestModel)
sync_table(TestModel)
sync_table(TestModel)
class TestMultiKeyModel(Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False)
class TestDeleting(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestDeleting, cls).setUpClass()
delete_table(TestMultiKeyModel)
create_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
sync_table(TestMultiKeyModel)
@classmethod
def tearDownClass(cls):
super(TestDeleting, cls).tearDownClass()
delete_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
def test_deleting_only_deletes_one_object(self):
partition = random.randint(0,1000)
@@ -114,13 +156,13 @@ class TestUpdating(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestUpdating, cls).setUpClass()
delete_table(TestMultiKeyModel)
create_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
sync_table(TestMultiKeyModel)
@classmethod
def tearDownClass(cls):
super(TestUpdating, cls).tearDownClass()
delete_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
def setUp(self):
super(TestUpdating, self).setUp()
@@ -148,6 +190,14 @@ class TestUpdating(BaseCassEngTestCase):
assert check.count is None
assert check.text is None
def test_get_changed_columns(self):
assert self.instance.get_changed_columns() == []
self.instance.count = 1
changes = self.instance.get_changed_columns()
assert len(changes) == 1
assert changes == ['count']
self.instance.save()
assert self.instance.get_changed_columns() == []
class TestCanUpdate(BaseCassEngTestCase):
@@ -155,13 +205,13 @@ class TestCanUpdate(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestCanUpdate, cls).setUpClass()
delete_table(TestModel)
create_table(TestModel)
drop_table(TestModel)
sync_table(TestModel)
@classmethod
def tearDownClass(cls):
super(TestCanUpdate, cls).tearDownClass()
delete_table(TestModel)
drop_table(TestModel)
def test_success_case(self):
tm = TestModel(count=8, text='123456789')
@@ -193,16 +243,18 @@ class TestCanUpdate(BaseCassEngTestCase):
class IndexDefinitionModel(Model):
__keyspace__ = 'test'
key = columns.UUID(primary_key=True)
val = columns.Text(index=True)
class TestIndexedColumnDefinition(BaseCassEngTestCase):
def test_exception_isnt_raised_if_an_index_is_defined_more_than_once(self):
create_table(IndexDefinitionModel)
create_table(IndexDefinitionModel)
sync_table(IndexDefinitionModel)
sync_table(IndexDefinitionModel)
class ReservedWordModel(Model):
__keyspace__ = 'test'
token = columns.Text(primary_key=True)
insert = columns.Integer(index=True)
@@ -211,7 +263,7 @@ class TestQueryQuoting(BaseCassEngTestCase):
def test_reserved_cql_words_can_be_used_as_column_names(self):
"""
"""
create_table(ReservedWordModel)
sync_table(ReservedWordModel)
model1 = ReservedWordModel.create(token='1', insert=5)
@@ -222,3 +274,52 @@ class TestQueryQuoting(BaseCassEngTestCase):
assert model1.insert == model2[0].insert
class TestQueryModel(Model):
__keyspace__ = 'test'
test_id = columns.UUID(primary_key=True, default=uuid4)
date = columns.Date(primary_key=True)
description = columns.Text()
class TestQuerying(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestQuerying, cls).setUpClass()
drop_table(TestQueryModel)
sync_table(TestQueryModel)
@classmethod
def tearDownClass(cls):
super(TestQuerying, cls).tearDownClass()
drop_table(TestQueryModel)
def test_query_with_date(self):
uid = uuid4()
day = date(2013, 11, 26)
obj = TestQueryModel.create(test_id=uid, date=day, description=u'foo')
self.assertEqual(obj.description, u'foo')
inst = TestQueryModel.filter(
TestQueryModel.test_id == uid,
TestQueryModel.date == day).limit(1).first()
assert inst.test_id == uid
assert inst.date == day
def test_none_filter_fails():
class NoneFilterModel(Model):
__keyspace__ = 'test'
pk = columns.Integer(primary_key=True)
v = columns.Integer()
sync_table(NoneFilterModel)
try:
NoneFilterModel.objects(pk=None)
raise Exception("fail")
except CQLEngineException as e:
pass

View File

@@ -0,0 +1,241 @@
import uuid
import mock
from cqlengine import columns
from cqlengine import models
from cqlengine.connection import get_session
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine import management
class TestPolymorphicClassConstruction(BaseCassEngTestCase):
def test_multiple_polymorphic_key_failure(self):
""" Tests that defining a model with more than one polymorphic key fails """
with self.assertRaises(models.ModelDefinitionException):
class M(models.Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
type1 = columns.Integer(polymorphic_key=True)
type2 = columns.Integer(polymorphic_key=True)
def test_polymorphic_key_inheritance(self):
""" Tests that polymorphic_key attribute is not inherited """
class Base(models.Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
type1 = columns.Integer(polymorphic_key=True)
class M1(Base):
__polymorphic_key__ = 1
class M2(M1):
pass
assert M2.__polymorphic_key__ is None
def test_polymorphic_metaclass(self):
""" Tests that the model meta class configures polymorphic models properly """
class Base(models.Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
type1 = columns.Integer(polymorphic_key=True)
class M1(Base):
__polymorphic_key__ = 1
assert Base._is_polymorphic
assert M1._is_polymorphic
assert Base._is_polymorphic_base
assert not M1._is_polymorphic_base
assert Base._polymorphic_column is Base._columns['type1']
assert M1._polymorphic_column is M1._columns['type1']
assert Base._polymorphic_column_name == 'type1'
assert M1._polymorphic_column_name == 'type1'
def test_table_names_are_inherited_from_poly_base(self):
class Base(models.Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
type1 = columns.Integer(polymorphic_key=True)
class M1(Base):
__polymorphic_key__ = 1
assert Base.column_family_name() == M1.column_family_name()
def test_collection_columns_cant_be_polymorphic_keys(self):
with self.assertRaises(models.ModelDefinitionException):
class Base(models.Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
type1 = columns.Set(columns.Integer, polymorphic_key=True)
class PolyBase(models.Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid.uuid4)
row_type = columns.Integer(polymorphic_key=True)
class Poly1(PolyBase):
__polymorphic_key__ = 1
data1 = columns.Text()
class Poly2(PolyBase):
__polymorphic_key__ = 2
data2 = columns.Text()
class TestPolymorphicModel(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestPolymorphicModel, cls).setUpClass()
management.sync_table(Poly1)
management.sync_table(Poly2)
@classmethod
def tearDownClass(cls):
super(TestPolymorphicModel, cls).tearDownClass()
management.drop_table(Poly1)
management.drop_table(Poly2)
def test_saving_base_model_fails(self):
with self.assertRaises(models.PolyMorphicModelException):
PolyBase.create()
def test_saving_subclass_saves_poly_key(self):
p1 = Poly1.create(data1='pickle')
p2 = Poly2.create(data2='bacon')
assert p1.row_type == Poly1.__polymorphic_key__
assert p2.row_type == Poly2.__polymorphic_key__
def test_query_deserialization(self):
p1 = Poly1.create(data1='pickle')
p2 = Poly2.create(data2='bacon')
p1r = PolyBase.get(partition=p1.partition)
p2r = PolyBase.get(partition=p2.partition)
assert isinstance(p1r, Poly1)
assert isinstance(p2r, Poly2)
def test_delete_on_polymorphic_subclass_does_not_include_polymorphic_key(self):
p1 = Poly1.create()
session = get_session()
with mock.patch.object(session, 'execute') as m:
Poly1.objects(partition=p1.partition).delete()
# make sure our polymorphic key isn't in the CQL
# not sure how we would even get here if it was in there
# since the CQL would fail.
self.assertNotIn("row_type", m.call_args[0][0].query_string)
class UnindexedPolyBase(models.Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid.uuid4)
cluster = columns.UUID(primary_key=True, default=uuid.uuid4)
row_type = columns.Integer(polymorphic_key=True)
class UnindexedPoly1(UnindexedPolyBase):
__polymorphic_key__ = 1
data1 = columns.Text()
class UnindexedPoly2(UnindexedPolyBase):
__polymorphic_key__ = 2
data2 = columns.Text()
class UnindexedPoly3(UnindexedPoly2):
__polymorphic_key__ = 3
data3 = columns.Text()
class TestUnindexedPolymorphicQuery(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestUnindexedPolymorphicQuery, cls).setUpClass()
management.sync_table(UnindexedPoly1)
management.sync_table(UnindexedPoly2)
management.sync_table(UnindexedPoly3)
cls.p1 = UnindexedPoly1.create(data1='pickle')
cls.p2 = UnindexedPoly2.create(partition=cls.p1.partition, data2='bacon')
cls.p3 = UnindexedPoly3.create(partition=cls.p1.partition, data3='turkey')
@classmethod
def tearDownClass(cls):
super(TestUnindexedPolymorphicQuery, cls).tearDownClass()
management.drop_table(UnindexedPoly1)
management.drop_table(UnindexedPoly2)
management.drop_table(UnindexedPoly3)
def test_non_conflicting_type_results_work(self):
p1, p2, p3 = self.p1, self.p2, self.p3
assert len(list(UnindexedPoly1.objects(partition=p1.partition, cluster=p1.cluster))) == 1
assert len(list(UnindexedPoly2.objects(partition=p1.partition, cluster=p2.cluster))) == 1
def test_subclassed_model_results_work_properly(self):
p1, p2, p3 = self.p1, self.p2, self.p3
assert len(list(UnindexedPoly2.objects(partition=p1.partition, cluster__in=[p2.cluster, p3.cluster]))) == 2
def test_conflicting_type_results(self):
with self.assertRaises(models.PolyMorphicModelException):
list(UnindexedPoly1.objects(partition=self.p1.partition))
with self.assertRaises(models.PolyMorphicModelException):
list(UnindexedPoly2.objects(partition=self.p1.partition))
class IndexedPolyBase(models.Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid.uuid4)
cluster = columns.UUID(primary_key=True, default=uuid.uuid4)
row_type = columns.Integer(polymorphic_key=True, index=True)
class IndexedPoly1(IndexedPolyBase):
__polymorphic_key__ = 1
data1 = columns.Text()
class IndexedPoly2(IndexedPolyBase):
__polymorphic_key__ = 2
data2 = columns.Text()
class TestIndexedPolymorphicQuery(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestIndexedPolymorphicQuery, cls).setUpClass()
management.sync_table(IndexedPoly1)
management.sync_table(IndexedPoly2)
cls.p1 = IndexedPoly1.create(data1='pickle')
cls.p2 = IndexedPoly2.create(partition=cls.p1.partition, data2='bacon')
@classmethod
def tearDownClass(cls):
super(TestIndexedPolymorphicQuery, cls).tearDownClass()
management.drop_table(IndexedPoly1)
management.drop_table(IndexedPoly2)
def test_success_case(self):
assert len(list(IndexedPoly1.objects(partition=self.p1.partition))) == 1
assert len(list(IndexedPoly2.objects(partition=self.p1.partition))) == 1

View File

@@ -0,0 +1,91 @@
from uuid import uuid4
from mock import patch
from cqlengine.exceptions import ValidationError
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine import columns
from cqlengine.management import sync_table, drop_table
class TestUpdateModel(Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.UUID(primary_key=True, default=uuid4)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
class ModelUpdateTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(ModelUpdateTests, cls).setUpClass()
sync_table(TestUpdateModel)
@classmethod
def tearDownClass(cls):
super(ModelUpdateTests, cls).tearDownClass()
drop_table(TestUpdateModel)
def test_update_model(self):
""" tests calling udpate on models with no values passed in """
m0 = TestUpdateModel.create(count=5, text='monkey')
# independently save over a new count value, unknown to original instance
m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
m1.count = 6
m1.save()
# update the text, and call update
m0.text = 'monkey land'
m0.update()
# database should reflect both updates
m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
self.assertEqual(m2.count, m1.count)
self.assertEqual(m2.text, m0.text)
def test_update_values(self):
""" tests calling update on models with values passed in """
m0 = TestUpdateModel.create(count=5, text='monkey')
# independently save over a new count value, unknown to original instance
m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
m1.count = 6
m1.save()
# update the text, and call update
m0.update(text='monkey land')
self.assertEqual(m0.text, 'monkey land')
# database should reflect both updates
m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster)
self.assertEqual(m2.count, m1.count)
self.assertEqual(m2.text, m0.text)
def test_noop_model_update(self):
""" tests that calling update on a model with no changes will do nothing. """
m0 = TestUpdateModel.create(count=5, text='monkey')
with patch.object(self.session, 'execute') as execute:
m0.update()
assert execute.call_count == 0
with patch.object(self.session, 'execute') as execute:
m0.update(count=5)
assert execute.call_count == 0
def test_invalid_update_kwarg(self):
""" tests that passing in a kwarg to the update method that isn't a column will fail """
m0 = TestUpdateModel.create(count=5, text='monkey')
with self.assertRaises(ValidationError):
m0.update(numbers=20)
def test_primary_key_update_failure(self):
""" tests that attempting to update the value of a primary key will fail """
m0 = TestUpdateModel.create(count=5, text='monkey')
with self.assertRaises(ValidationError):
m0.update(partition=uuid4())

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,61 @@
import random
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.management import sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine import columns
class TestModel(Model):
__keyspace__ = 'test'
id = columns.Integer(primary_key=True)
clustering_key = columns.Integer(primary_key=True, clustering_order='desc')
class TestClusteringComplexModel(Model):
__keyspace__ = 'test'
id = columns.Integer(primary_key=True)
clustering_key = columns.Integer(primary_key=True, clustering_order='desc')
some_value = columns.Integer()
class TestClusteringOrder(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestClusteringOrder, cls).setUpClass()
sync_table(TestModel)
@classmethod
def tearDownClass(cls):
super(TestClusteringOrder, cls).tearDownClass()
drop_table(TestModel)
def test_clustering_order(self):
"""
Tests that models can be saved and retrieved
"""
items = list(range(20))
random.shuffle(items)
for i in items:
TestModel.create(id=1, clustering_key=i)
values = list(TestModel.objects.values_list('clustering_key', flat=True))
# [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L]
self.assertEquals(values, sorted(items, reverse=True))
def test_clustering_order_more_complex(self):
"""
Tests that models can be saved and retrieved
"""
sync_table(TestClusteringComplexModel)
items = list(range(20))
random.shuffle(items)
for i in items:
TestClusteringComplexModel.create(id=1, clustering_key=i, some_value=2)
values = list(TestClusteringComplexModel.objects.values_list('some_value', flat=True))
self.assertEquals([2] * 20, values)
drop_table(TestClusteringComplexModel)

View File

@@ -0,0 +1 @@
__author__ = 'bdeggleston'

View File

@@ -0,0 +1,9 @@
from unittest import TestCase
from cqlengine.operators import BaseQueryOperator, QueryOperatorException
class BaseOperatorTest(TestCase):
def test_get_operator_cannot_be_called_from_base_class(self):
with self.assertRaises(QueryOperatorException):
BaseQueryOperator.get_operator('*')

View File

@@ -0,0 +1,30 @@
from unittest import TestCase
from cqlengine.operators import *
class TestWhereOperators(TestCase):
def test_symbol_lookup(self):
""" tests where symbols are looked up properly """
def check_lookup(symbol, expected):
op = BaseWhereOperator.get_operator(symbol)
self.assertEqual(op, expected)
check_lookup('EQ', EqualsOperator)
check_lookup('IN', InOperator)
check_lookup('GT', GreaterThanOperator)
check_lookup('GTE', GreaterThanOrEqualOperator)
check_lookup('LT', LessThanOperator)
check_lookup('LTE', LessThanOrEqualOperator)
def test_operator_rendering(self):
""" tests symbols are rendered properly """
self.assertEqual("=", unicode(EqualsOperator()))
self.assertEqual("IN", unicode(InOperator()))
self.assertEqual(">", unicode(GreaterThanOperator()))
self.assertEqual(">=", unicode(GreaterThanOrEqualOperator()))
self.assertEqual("<", unicode(LessThanOperator()))
self.assertEqual("<=", unicode(LessThanOrEqualOperator()))

View File

@@ -3,28 +3,35 @@ from unittest import skip
from uuid import uuid4
import random
from cqlengine import Model, columns
from cqlengine.management import delete_table, create_table
from cqlengine.query import BatchQuery
from cqlengine.management import drop_table, sync_table
from cqlengine.query import BatchQuery, DMLQuery
from cqlengine.tests.base import BaseCassEngTestCase
class TestMultiKeyModel(Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False)
class BatchQueryLogModel(Model):
__keyspace__ = 'test'
# simple k/v table
k = columns.Integer(primary_key=True)
v = columns.Integer()
class BatchQueryTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BatchQueryTests, cls).setUpClass()
delete_table(TestMultiKeyModel)
create_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
sync_table(TestMultiKeyModel)
@classmethod
def tearDownClass(cls):
super(BatchQueryTests, cls).tearDownClass()
delete_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
def setUp(self):
super(BatchQueryTests, self).setUp()
@@ -104,3 +111,61 @@ class BatchQueryTests(BaseCassEngTestCase):
for m in TestMultiKeyModel.all():
m.delete()
def test_none_success_case(self):
""" Tests that passing None into the batch call clears any batch object """
b = BatchQuery()
q = TestMultiKeyModel.objects.batch(b)
assert q._batch == b
q = q.batch(None)
assert q._batch is None
def test_dml_none_success_case(self):
""" Tests that passing None into the batch call clears any batch object """
b = BatchQuery()
q = DMLQuery(TestMultiKeyModel, batch=b)
assert q._batch == b
q.batch(None)
assert q._batch is None
def test_batch_execute_on_exception_succeeds(self):
# makes sure if execute_on_exception == True we still apply the batch
drop_table(BatchQueryLogModel)
sync_table(BatchQueryLogModel)
obj = BatchQueryLogModel.objects(k=1)
self.assertEqual(0, len(obj))
try:
with BatchQuery(execute_on_exception=True) as b:
BatchQueryLogModel.batch(b).create(k=1, v=1)
raise Exception("Blah")
except:
pass
obj = BatchQueryLogModel.objects(k=1)
# should be 1 because the batch should execute
self.assertEqual(1, len(obj))
def test_batch_execute_on_exception_skips_if_not_specified(self):
# makes sure if execute_on_exception == True we still apply the batch
drop_table(BatchQueryLogModel)
sync_table(BatchQueryLogModel)
obj = BatchQueryLogModel.objects(k=2)
self.assertEqual(0, len(obj))
try:
with BatchQuery() as b:
BatchQueryLogModel.batch(b).create(k=2, v=2)
raise Exception("Blah")
except:
pass
obj = BatchQueryLogModel.objects(k=2)
# should be 0 because the batch should not execute
self.assertEqual(0, len(obj))

View File

@@ -4,13 +4,14 @@ from uuid import uuid4
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.exceptions import ModelException
from cqlengine.management import create_table
from cqlengine.management import delete_table
from cqlengine.management import sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine import columns
from cqlengine import query
class DateTimeQueryTestModel(Model):
__keyspace__ = 'test'
user = columns.Integer(primary_key=True)
day = columns.DateTime(primary_key=True)
data = columns.Text()
@@ -20,7 +21,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestDateTimeQueries, cls).setUpClass()
create_table(DateTimeQueryTestModel)
sync_table(DateTimeQueryTestModel)
cls.base_date = datetime.now() - timedelta(days=10)
for x in range(7):
@@ -35,7 +36,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
@classmethod
def tearDownClass(cls):
super(TestDateTimeQueries, cls).tearDownClass()
delete_table(DateTimeQueryTestModel)
drop_table(DateTimeQueryTestModel)
def test_range_query(self):
""" Tests that loading from a range of dates works properly """
@@ -43,7 +44,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
end = start + timedelta(days=3)
results = DateTimeQueryTestModel.filter(user=0, day__gte=start, day__lt=end)
assert len(results) == 3
assert len(results) == 3
def test_datetime_precision(self):
""" Tests that millisecond resolution is preserved when saving datetime objects """

View File

@@ -0,0 +1,246 @@
from cqlengine import operators
from cqlengine.named import NamedKeyspace
from cqlengine.operators import EqualsOperator, GreaterThanOrEqualOperator
from cqlengine.query import ResultObject
from cqlengine.tests.query.test_queryset import BaseQuerySetUsage
from cqlengine.tests.base import BaseCassEngTestCase
class TestQuerySetOperation(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestQuerySetOperation, cls).setUpClass()
cls.keyspace = NamedKeyspace('cqlengine_test')
cls.table = cls.keyspace.table('test_model')
def test_query_filter_parsing(self):
"""
Tests the queryset filter method parses it's kwargs properly
"""
query1 = self.table.objects(test_id=5)
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op.operator, operators.EqualsOperator)
assert op.value == 5
query2 = query1.filter(expected_result__gte=1)
assert len(query2._where) == 2
op = query2._where[1]
assert isinstance(op.operator, operators.GreaterThanOrEqualOperator)
assert op.value == 1
def test_query_expression_parsing(self):
""" Tests that query experessions are evaluated properly """
query1 = self.table.filter(self.table.column('test_id') == 5)
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op.operator, operators.EqualsOperator)
assert op.value == 5
query2 = query1.filter(self.table.column('expected_result') >= 1)
assert len(query2._where) == 2
op = query2._where[1]
assert isinstance(op.operator, operators.GreaterThanOrEqualOperator)
assert op.value == 1
def test_filter_method_where_clause_generation(self):
"""
Tests the where clause creation
"""
query1 = self.table.objects(test_id=5)
self.assertEqual(len(query1._where), 1)
where = query1._where[0]
self.assertEqual(where.field, 'test_id')
self.assertEqual(where.value, 5)
query2 = query1.filter(expected_result__gte=1)
self.assertEqual(len(query2._where), 2)
where = query2._where[0]
self.assertEqual(where.field, 'test_id')
self.assertIsInstance(where.operator, EqualsOperator)
self.assertEqual(where.value, 5)
where = query2._where[1]
self.assertEqual(where.field, 'expected_result')
self.assertIsInstance(where.operator, GreaterThanOrEqualOperator)
self.assertEqual(where.value, 1)
def test_query_expression_where_clause_generation(self):
"""
Tests the where clause creation
"""
query1 = self.table.objects(self.table.column('test_id') == 5)
self.assertEqual(len(query1._where), 1)
where = query1._where[0]
self.assertEqual(where.field, 'test_id')
self.assertEqual(where.value, 5)
query2 = query1.filter(self.table.column('expected_result') >= 1)
self.assertEqual(len(query2._where), 2)
where = query2._where[0]
self.assertEqual(where.field, 'test_id')
self.assertIsInstance(where.operator, EqualsOperator)
self.assertEqual(where.value, 5)
where = query2._where[1]
self.assertEqual(where.field, 'expected_result')
self.assertIsInstance(where.operator, GreaterThanOrEqualOperator)
self.assertEqual(where.value, 1)
class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
@classmethod
def setUpClass(cls):
super(TestQuerySetCountSelectionAndIteration, cls).setUpClass()
from cqlengine.tests.query.test_queryset import TestModel
ks,tn = TestModel.column_family_name().split('.')
cls.keyspace = NamedKeyspace(ks)
cls.table = cls.keyspace.table(tn)
def test_count(self):
""" Tests that adding filtering statements affects the count query as expected """
assert self.table.objects.count() == 12
q = self.table.objects(test_id=0)
assert q.count() == 4
def test_query_expression_count(self):
""" Tests that adding query statements affects the count query as expected """
assert self.table.objects.count() == 12
q = self.table.objects(self.table.column('test_id') == 0)
assert q.count() == 4
def test_iteration(self):
""" Tests that iterating over a query set pulls back all of the expected results """
q = self.table.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
# test with regular filtering
q = self.table.objects(attempt_id=3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
# test with query method
q = self.table.objects(self.table.column('attempt_id') == 3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
def test_multiple_iterations_work_properly(self):
""" Tests that iterating over a query set more than once works """
# test with both the filtering method and the query method
for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)):
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
#try it again
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
def test_multiple_iterators_are_isolated(self):
"""
tests that the use of one iterator does not affect the behavior of another
"""
for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)):
q = q.order_by('attempt_id')
expected_order = [0,1,2,3]
iter1 = iter(q)
iter2 = iter(q)
for attempt_id in expected_order:
assert iter1.next().attempt_id == attempt_id
assert iter2.next().attempt_id == attempt_id
def test_get_success_case(self):
"""
Tests that the .get() method works on new and existing querysets
"""
m = self.table.objects.get(test_id=0, attempt_id=0)
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
q = self.table.objects(test_id=0, attempt_id=0)
m = q.get()
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
q = self.table.objects(test_id=0)
m = q.get(attempt_id=0)
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
def test_query_expression_get_success_case(self):
"""
Tests that the .get() method works on new and existing querysets
"""
m = self.table.get(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0)
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
q = self.table.objects(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0)
m = q.get()
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
q = self.table.objects(self.table.column('test_id') == 0)
m = q.get(self.table.column('attempt_id') == 0)
assert isinstance(m, ResultObject)
assert m.test_id == 0
assert m.attempt_id == 0
def test_get_doesnotexist_exception(self):
"""
Tests that get calls that don't return a result raises a DoesNotExist error
"""
with self.assertRaises(self.table.DoesNotExist):
self.table.objects.get(test_id=100)
def test_get_multipleobjects_exception(self):
"""
Tests that get calls that return multiple results raise a MultipleObjectsReturned error
"""
with self.assertRaises(self.table.MultipleObjectsReturned):
self.table.objects.get(test_id=1)

View File

@@ -1,10 +1,13 @@
from datetime import datetime
import time
from cqlengine.columns import DateTime
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine import columns, Model
from cqlengine import functions
from cqlengine import query
from cqlengine.statements import WhereClause
from cqlengine.operators import EqualsOperator
from cqlengine.management import sync_table, drop_table
class TestQuerySetOperation(BaseCassEngTestCase):
@@ -13,22 +16,96 @@ class TestQuerySetOperation(BaseCassEngTestCase):
Tests that queries with helper functions are generated properly
"""
now = datetime.now()
col = columns.DateTime()
col.set_column_name('time')
qry = query.EqualsOperator(col, functions.MaxTimeUUID(now))
where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now))
where.set_context_id(5)
assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.identifier)
self.assertEqual(str(where), '"time" = MaxTimeUUID(%(5)s)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
def test_mintimeuuid_function(self):
"""
Tests that queries with helper functions are generated properly
"""
now = datetime.now()
col = columns.DateTime()
col.set_column_name('time')
qry = query.EqualsOperator(col, functions.MinTimeUUID(now))
where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now))
where.set_context_id(5)
assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.identifier)
self.assertEqual(str(where), '"time" = MinTimeUUID(%(5)s)')
ctx = {}
where.update_context(ctx)
self.assertEqual(ctx, {'5': DateTime().to_database(now)})
class TokenTestModel(Model):
__keyspace__ = 'test'
key = columns.Integer(primary_key=True)
val = columns.Integer()
class TestTokenFunction(BaseCassEngTestCase):
def setUp(self):
super(TestTokenFunction, self).setUp()
sync_table(TokenTestModel)
def tearDown(self):
super(TestTokenFunction, self).tearDown()
drop_table(TokenTestModel)
def test_token_function(self):
""" Tests that token functions work properly """
assert TokenTestModel.objects().count() == 0
for i in range(10):
TokenTestModel.create(key=i, val=i)
assert TokenTestModel.objects().count() == 10
seen_keys = set()
last_token = None
for instance in TokenTestModel.objects().limit(5):
last_token = instance.key
seen_keys.add(last_token)
assert len(seen_keys) == 5
for instance in TokenTestModel.objects(pk__token__gt=functions.Token(last_token)):
seen_keys.add(instance.key)
assert len(seen_keys) == 10
assert all([i in seen_keys for i in range(10)])
def test_compound_pk_token_function(self):
class TestModel(Model):
__keyspace__ = 'test'
p1 = columns.Text(partition_key=True)
p2 = columns.Text(partition_key=True)
func = functions.Token('a', 'b')
q = TestModel.objects.filter(pk__token__gt=func)
where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(%({})s, %({})s)'.format(1, 2))
# Verify that a SELECT query can be successfully generated
str(q._select_query())
# Token(tuple()) is also possible for convenience
# it (allows for Token(obj.pk) syntax)
func = functions.Token(('a', 'b'))
q = TestModel.objects.filter(pk__token__gt=func)
where = q._where[0]
where.set_context_id(1)
self.assertEquals(str(where), 'token("p1", "p2") > token(%({})s, %({})s)'.format(1, 2))
str(q._select_query())
# The 'pk__token' virtual column may only be compared to a Token
self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=10)
# A Token may only be compared to the `pk__token' virtual column
func = functions.Token('a', 'b')
self.assertRaises(query.QueryException, TestModel.objects.filter, p1__gt=func)
# The # of arguments to Token must match the # of partition keys
func = functions.Token('a')
self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=func)

View File

@@ -1,39 +1,74 @@
from datetime import datetime
import time
from unittest import TestCase, skipUnless
from uuid import uuid1, uuid4
import uuid
from cqlengine.tests.base import BaseCassEngTestCase
import mock
from cqlengine.exceptions import ModelException
from cqlengine import functions
from cqlengine.management import create_table
from cqlengine.management import delete_table
from cqlengine.management import sync_table, drop_table, sync_table
from cqlengine.management import drop_table
from cqlengine.models import Model
from cqlengine import columns
from cqlengine import query
from datetime import timedelta
from datetime import tzinfo
from cqlengine import statements
from cqlengine import operators
from cqlengine.connection import get_cluster, get_session
cluster = get_cluster()
class TzOffset(tzinfo):
"""Minimal implementation of a timezone offset to help testing with timezone
aware datetimes.
"""
def __init__(self, offset):
self._offset = timedelta(hours=offset)
def utcoffset(self, dt):
return self._offset
def tzname(self, dt):
return 'TzOffset: {}'.format(self._offset.hours)
def dst(self, dt):
return timedelta(0)
class TestModel(Model):
__keyspace__ = 'test'
test_id = columns.Integer(primary_key=True)
attempt_id = columns.Integer(primary_key=True)
description = columns.Text()
expected_result = columns.Integer()
test_result = columns.Integer()
class IndexedTestModel(Model):
__keyspace__ = 'test'
test_id = columns.Integer(primary_key=True)
attempt_id = columns.Integer(index=True)
description = columns.Text()
expected_result = columns.Integer()
test_result = columns.Integer(index=True)
class TestMultiClusteringModel(Model):
__keyspace__ = 'test'
one = columns.Integer(primary_key=True)
two = columns.Integer(primary_key=True)
three = columns.Integer(primary_key=True)
class TestQuerySetOperation(BaseCassEngTestCase):
def test_query_filter_parsing(self):
"""
Tests the queryset filter method parses it's kwargs properly
@@ -42,14 +77,35 @@ class TestQuerySetOperation(BaseCassEngTestCase):
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op, query.EqualsOperator)
assert isinstance(op, statements.WhereClause)
assert isinstance(op.operator, operators.EqualsOperator)
assert op.value == 5
query2 = query1.filter(expected_result__gte=1)
assert len(query2._where) == 2
op = query2._where[1]
assert isinstance(op, query.GreaterThanOrEqualOperator)
self.assertIsInstance(op, statements.WhereClause)
self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator)
assert op.value == 1
def test_query_expression_parsing(self):
""" Tests that query experessions are evaluated properly """
query1 = TestModel.filter(TestModel.test_id == 5)
assert len(query1._where) == 1
op = query1._where[0]
assert isinstance(op, statements.WhereClause)
assert isinstance(op.operator, operators.EqualsOperator)
assert op.value == 5
query2 = query1.filter(TestModel.expected_result >= 1)
assert len(query2._where) == 2
op = query2._where[1]
self.assertIsInstance(op, statements.WhereClause)
self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator)
assert op.value == 1
def test_using_invalid_column_names_in_filter_kwargs_raises_error(self):
@@ -57,27 +113,21 @@ class TestQuerySetOperation(BaseCassEngTestCase):
Tests that using invalid or nonexistant column names for filter args raises an error
"""
with self.assertRaises(query.QueryException):
query0 = TestModel.objects(nonsense=5)
TestModel.objects(nonsense=5)
def test_where_clause_generation(self):
def test_using_nonexistant_column_names_in_query_args_raises_error(self):
"""
Tests the where clause creation
Tests that using invalid or nonexistant columns for query args raises an error
"""
query1 = TestModel.objects(test_id=5)
ids = [o.identifier for o in query1._where]
where = query1._where_clause()
assert where == '"test_id" = :{}'.format(*ids)
with self.assertRaises(AttributeError):
TestModel.objects(TestModel.nonsense == 5)
query2 = query1.filter(expected_result__gte=1)
ids = [o.identifier for o in query2._where]
where = query2._where_clause()
assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids)
def test_querystring_generation(self):
def test_using_non_query_operators_in_query_args_raises_error(self):
"""
Tests the select querystring creation
Tests that providing query args that are not query operator instances raises an error
"""
with self.assertRaises(query.QueryException):
TestModel.objects(5)
def test_queryset_is_immutable(self):
"""
@@ -88,10 +138,25 @@ class TestQuerySetOperation(BaseCassEngTestCase):
query2 = query1.filter(expected_result__gte=1)
assert len(query2._where) == 2
assert len(query1._where) == 1
def test_the_all_method_clears_where_filter(self):
def test_queryset_limit_immutability(self):
"""
Tests that calling all on a queryset with previously defined filters returns a queryset with no filters
Tests that calling a queryset function that changes it's state returns a new queryset with same limit
"""
query1 = TestModel.objects(test_id=5).limit(1)
assert query1._limit == 1
query2 = query1.filter(expected_result__gte=1)
assert query2._limit == 1
query3 = query1.filter(expected_result__gte=1).limit(2)
assert query1._limit == 1
assert query3._limit == 2
def test_the_all_method_duplicates_queryset(self):
"""
Tests that calling all on a queryset with previously defined filters duplicates queryset
"""
query1 = TestModel.objects(test_id=5)
assert len(query1._where) == 1
@@ -100,7 +165,7 @@ class TestQuerySetOperation(BaseCassEngTestCase):
assert len(query2._where) == 2
query3 = query2.all()
assert len(query3._where) == 0
assert query3 == query2
def test_defining_only_and_defer_fails(self):
"""
@@ -112,16 +177,16 @@ class TestQuerySetOperation(BaseCassEngTestCase):
Tests that setting only or defer fields that don't exist raises an exception
"""
class BaseQuerySetUsage(BaseCassEngTestCase):
class BaseQuerySetUsage(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BaseQuerySetUsage, cls).setUpClass()
delete_table(TestModel)
delete_table(IndexedTestModel)
create_table(TestModel)
create_table(IndexedTestModel)
create_table(TestMultiClusteringModel)
drop_table(TestModel)
drop_table(IndexedTestModel)
sync_table(TestModel)
sync_table(IndexedTestModel)
sync_table(TestMultiClusteringModel)
TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30)
TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30)
@@ -149,39 +214,71 @@ class BaseQuerySetUsage(BaseCassEngTestCase):
IndexedTestModel.objects.create(test_id=7, attempt_id=3, description='try8', expected_result=20, test_result=20)
IndexedTestModel.objects.create(test_id=8, attempt_id=0, description='try9', expected_result=50, test_result=40)
IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60, test_result=40)
IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70, test_result=45)
IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, test_result=45)
IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60,
test_result=40)
IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70,
test_result=45)
IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75,
test_result=45)
@classmethod
def tearDownClass(cls):
super(BaseQuerySetUsage, cls).tearDownClass()
delete_table(TestModel)
delete_table(IndexedTestModel)
delete_table(TestMultiClusteringModel)
drop_table(TestModel)
drop_table(IndexedTestModel)
drop_table(TestMultiClusteringModel)
class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
def test_count(self):
""" Tests that adding filtering statements affects the count query as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(test_id=0)
assert q.count() == 4
def test_query_expression_count(self):
""" Tests that adding query statements affects the count query as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(TestModel.test_id == 0)
assert q.count() == 4
def test_query_limit_count(self):
""" Tests that adding query with a limit affects the count as expected """
assert TestModel.objects.count() == 12
q = TestModel.objects(TestModel.test_id == 0).limit(2)
result = q.count()
self.assertEqual(2, result)
def test_iteration(self):
""" Tests that iterating over a query set pulls back all of the expected results """
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
# test with regular filtering
q = TestModel.objects(attempt_id=3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0,20), (1,20), (2,75)])
compare_set = set([(0, 20), (1, 20), (2, 75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
# test with query method
q = TestModel.objects(TestModel.attempt_id == 3).allow_filtering()
assert len(q) == 3
#tuple of expected test_id, expected_result values
compare_set = set([(0, 20), (1, 20), (2, 75)])
for t in q:
val = t.test_id, t.expected_result
assert val in compare_set
@@ -190,34 +287,36 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
def test_multiple_iterations_work_properly(self):
""" Tests that iterating over a query set more than once works """
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
# test with both the filtering method and the query method
for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
#tuple of expected attempt_id, expected_result values
compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
#try it again
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
#try it again
compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
compare_set.remove(val)
assert len(compare_set) == 0
def test_multiple_iterators_are_isolated(self):
"""
tests that the use of one iterator does not affect the behavior of another
"""
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
iter1 = iter(q)
iter2 = iter(q)
for attempt_id in expected_order:
assert iter1.next().attempt_id == attempt_id
assert iter2.next().attempt_id == attempt_id
for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)):
q = q.order_by('attempt_id')
expected_order = [0, 1, 2, 3]
iter1 = iter(q)
iter2 = iter(q)
for attempt_id in expected_order:
assert iter1.next().attempt_id == attempt_id
assert iter2.next().attempt_id == attempt_id
def test_get_success_case(self):
"""
@@ -240,6 +339,27 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
assert m.test_id == 0
assert m.attempt_id == 0
def test_query_expression_get_success_case(self):
"""
Tests that the .get() method works on new and existing querysets
"""
m = TestModel.get(TestModel.test_id == 0, TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0, TestModel.attempt_id == 0)
m = q.get()
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
q = TestModel.objects(TestModel.test_id == 0)
m = q.get(TestModel.attempt_id == 0)
assert isinstance(m, TestModel)
assert m.test_id == 0
assert m.attempt_id == 0
def test_get_doesnotexist_exception(self):
"""
Tests that get calls that don't return a result raises a DoesNotExist error
@@ -258,25 +378,52 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
"""
"""
def test_non_quality_filtering():
class NonEqualityFilteringModel(Model):
__keyspace__ = 'test'
example_id = columns.UUID(primary_key=True, default=uuid.uuid4)
sequence_id = columns.Integer(primary_key=True) # sequence_id is a clustering key
example_type = columns.Integer(index=True)
created_at = columns.DateTime()
drop_table(NonEqualityFilteringModel)
sync_table(NonEqualityFilteringModel)
# setup table, etc.
NonEqualityFilteringModel.create(sequence_id=1, example_type=0, created_at=datetime.now())
NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now())
NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now())
qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering()
num = qA.count()
assert num == 1, num
class TestQuerySetOrdering(BaseQuerySetUsage):
def test_order_by_success_case(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
for model, expect in zip(q,expected_order):
expected_order = [0, 1, 2, 3]
for model, expect in zip(q, expected_order):
assert model.attempt_id == expect
q = q.order_by('-attempt_id')
expected_order.reverse()
for model, expect in zip(q,expected_order):
for model, expect in zip(q, expected_order):
assert model.attempt_id == expect
def test_ordering_by_non_second_primary_keys_fail(self):
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('test_id')
# kwarg filtering
with self.assertRaises(query.QueryException):
q = TestModel.objects(TestModel.test_id == 0).order_by('test_id')
def test_ordering_by_non_primary_keys_fails(self):
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id=0).order_by('description')
@@ -303,7 +450,6 @@ class TestQuerySetOrdering(BaseQuerySetUsage):
class TestQuerySetSlicing(BaseQuerySetUsage):
def test_out_of_range_index_raises_error(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
with self.assertRaises(IndexError):
@@ -311,39 +457,39 @@ class TestQuerySetSlicing(BaseQuerySetUsage):
def test_array_indexing_works_properly(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
expected_order = [0, 1, 2, 3]
for i in range(len(q)):
assert q[i].attempt_id == expected_order[i]
def test_negative_indexing_works_properly(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
expected_order = [0, 1, 2, 3]
assert q[-1].attempt_id == expected_order[-1]
assert q[-2].attempt_id == expected_order[-2]
def test_slicing_works_properly(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
expected_order = [0, 1, 2, 3]
for model, expect in zip(q[1:3], expected_order[1:3]):
assert model.attempt_id == expect
def test_negative_slicing(self):
q = TestModel.objects(test_id=0).order_by('attempt_id')
expected_order = [0,1,2,3]
expected_order = [0, 1, 2, 3]
for model, expect in zip(q[-3:], expected_order[-3:]):
assert model.attempt_id == expect
for model, expect in zip(q[:-1], expected_order[:-1]):
assert model.attempt_id == expect
class TestQuerySetValidation(BaseQuerySetUsage):
class TestQuerySetValidation(BaseQuerySetUsage):
def test_primary_key_or_index_must_be_specified(self):
"""
Tests that queries that don't have an equals relation to a primary key or indexed field fail
"""
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_result=25)
[i for i in q]
list([i for i in q])
def test_primary_key_or_index_must_have_equal_relation_filter(self):
"""
@@ -351,19 +497,17 @@ class TestQuerySetValidation(BaseQuerySetUsage):
"""
with self.assertRaises(query.QueryException):
q = TestModel.objects(test_id__gt=0)
[i for i in q]
list([i for i in q])
def test_indexed_field_can_be_queried(self):
"""
Tests that queries on an indexed field will work without any primary key relations specified
"""
q = IndexedTestModel.objects(test_result=25)
count = q.count()
assert q.count() == 4
class TestQuerySetDelete(BaseQuerySetUsage):
class TestQuerySetDelete(BaseQuerySetUsage):
def test_delete(self):
TestModel.objects.create(test_id=3, attempt_id=0, description='try9', expected_result=50, test_result=40)
TestModel.objects.create(test_id=3, attempt_id=1, description='try10', expected_result=60, test_result=40)
@@ -388,15 +532,15 @@ class TestQuerySetDelete(BaseQuerySetUsage):
with self.assertRaises(query.QueryException):
TestModel.objects(attempt_id=0).delete()
class TestQuerySetConnectionHandling(BaseQuerySetUsage):
class TestQuerySetConnectionHandling(BaseQuerySetUsage):
def test_conn_is_returned_after_filling_cache(self):
"""
Tests that the queryset returns it's connection after it's fetched all of it's results
"""
q = TestModel.objects(test_id=0)
#tuple of expected attempt_id, expected_result values
compare_set = set([(0,5), (1,10), (2,15), (3,20)])
compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)])
for t in q:
val = t.attempt_id, t.expected_result
assert val in compare_set
@@ -405,35 +549,67 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage):
assert q._con is None
assert q._cur is None
def test_conn_is_returned_after_queryset_is_garbage_collected(self):
""" Tests that the connection is returned to the connection pool after the queryset is gc'd """
from cqlengine.connection import ConnectionPool
# The queue size can be 1 if we just run this file's tests
# It will be 2 when we run 'em all
initial_size = ConnectionPool._queue.qsize()
q = TestModel.objects(test_id=0)
v = q[0]
assert ConnectionPool._queue.qsize() == initial_size - 1
del q
assert ConnectionPool._queue.qsize() == initial_size
class TimeUUIDQueryModel(Model):
partition = columns.UUID(primary_key=True)
time = columns.TimeUUID(primary_key=True)
data = columns.Text(required=False)
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True)
time = columns.TimeUUID(primary_key=True)
data = columns.Text(required=False)
class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(TestMinMaxTimeUUIDFunctions, cls).setUpClass()
create_table(TimeUUIDQueryModel)
sync_table(TimeUUIDQueryModel)
@classmethod
def tearDownClass(cls):
super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass()
delete_table(TimeUUIDQueryModel)
drop_table(TimeUUIDQueryModel)
def test_tzaware_datetime_support(self):
"""Test that using timezone aware datetime instances works with the
MinTimeUUID/MaxTimeUUID functions.
"""
pk = uuid4()
midpoint_utc = datetime.utcnow().replace(tzinfo=TzOffset(0))
midpoint_helsinki = midpoint_utc.astimezone(TzOffset(3))
# Assert pre-condition that we have the same logical point in time
assert midpoint_utc.utctimetuple() == midpoint_helsinki.utctimetuple()
assert midpoint_utc.timetuple() != midpoint_helsinki.timetuple()
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc - timedelta(minutes=1)),
data='1')
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc),
data='2')
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc + timedelta(minutes=1)),
data='3')
assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc))]
assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki))]
assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc))]
assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))]
def test_success_case(self):
""" Test that the min and max time uuid functions work as expected """
@@ -449,26 +625,85 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4')
time.sleep(0.2)
# test kwarg filtering
q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint))
q = [d for d in q]
assert len(q) == 2
datas = [d.data for d in q]
assert '1' in datas
assert '2' in datas
assert '1' in datas
assert '2' in datas
q = TimeUUIDQueryModel.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint))
assert len(q) == 2
datas = [d.data for d in q]
assert '3' in datas
assert '4' in datas
assert '3' in datas
assert '4' in datas
# test query expression filtering
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint)
)
q = [d for d in q]
assert len(q) == 2
datas = [d.data for d in q]
assert '1' in datas
assert '2' in datas
q = TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint)
)
assert len(q) == 2
datas = [d.data for d in q]
assert '3' in datas
assert '4' in datas
class TestInOperator(BaseQuerySetUsage):
def test_kwarg_success_case(self):
""" Tests the in operator works with the kwarg query method """
q = TestModel.filter(test_id__in=[0, 1])
assert q.count() == 8
def test_success_case(self):
q = TestModel.filter(test_id__in=[0,1])
def test_query_expression_success_case(self):
""" Tests the in operator works with the query expression query method """
q = TestModel.filter(TestModel.test_id.in_([0, 1]))
assert q.count() == 8
class TestValuesList(BaseQuerySetUsage):
def test_values_list(self):
q = TestModel.objects.filter(test_id=0, attempt_id=1)
item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first()
assert item == [0, 1, 'try2', 10, 30]
item = q.values_list('expected_result', flat=True).first()
assert item == 10
class TestObjectsProperty(BaseQuerySetUsage):
def test_objects_property_returns_fresh_queryset(self):
assert TestModel.objects._result_cache is None
len(TestModel.objects) # evaluate queryset
assert TestModel.objects._result_cache is None
@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
def test_paged_result_handling():
# addresses #225
class PagingTest(Model):
id = columns.Integer(primary_key=True)
val = columns.Integer()
sync_table(PagingTest)
PagingTest.create(id=1, val=1)
PagingTest.create(id=2, val=2)
session = get_session()
with mock.patch.object(session, 'default_fetch_size', 1):
results = PagingTest.objects()[:]
assert len(results) == 2

View File

@@ -0,0 +1,219 @@
from uuid import uuid4
from cqlengine.exceptions import ValidationError
from cqlengine.query import QueryException
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from cqlengine.management import sync_table, drop_table
from cqlengine import columns
class TestQueryUpdateModel(Model):
__keyspace__ = 'test'
partition = columns.UUID(primary_key=True, default=uuid4)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False, index=True)
text_set = columns.Set(columns.Text, required=False)
text_list = columns.List(columns.Text, required=False)
text_map = columns.Map(columns.Text, columns.Text, required=False)
class QueryUpdateTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(QueryUpdateTests, cls).setUpClass()
sync_table(TestQueryUpdateModel)
@classmethod
def tearDownClass(cls):
super(QueryUpdateTests, cls).tearDownClass()
drop_table(TestQueryUpdateModel)
def test_update_values(self):
""" tests calling udpate on a queryset """
partition = uuid4()
for i in range(5):
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
# sanity check
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == i
assert row.text == str(i)
# perform update
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == (6 if i == 3 else i)
assert row.text == str(i)
def test_update_values_validation(self):
""" tests calling udpate on models with values passed in """
partition = uuid4()
for i in range(5):
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
# sanity check
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == i
assert row.text == str(i)
# perform update
with self.assertRaises(ValidationError):
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf')
def test_invalid_update_kwarg(self):
""" tests that passing in a kwarg to the update method that isn't a column will fail """
with self.assertRaises(ValidationError):
TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000)
def test_primary_key_update_failure(self):
""" tests that attempting to update the value of a primary key will fail """
with self.assertRaises(ValidationError):
TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000)
def test_null_update_deletes_column(self):
""" setting a field to null in the update should issue a delete statement """
partition = uuid4()
for i in range(5):
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
# sanity check
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == i
assert row.text == str(i)
# perform update
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None)
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == i
assert row.text == (None if i == 3 else str(i))
def test_mixed_value_and_null_update(self):
""" tests that updating a columns value, and removing another works properly """
partition = uuid4()
for i in range(5):
TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
# sanity check
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == i
assert row.text == str(i)
# perform update
TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None)
for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
assert row.cluster == i
assert row.count == (6 if i == 3 else i)
assert row.text == (None if i == 3 else str(i))
def test_counter_updates(self):
pass
def test_set_add_updates(self):
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_set={"foo"})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(text_set__add={'bar'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"foo", "bar"})
def test_set_add_updates_new_record(self):
""" If the key doesn't exist yet, an update creates the record
"""
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(text_set__add={'bar'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"bar"})
def test_set_remove_updates(self):
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_set={"foo", "baz"})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_set__remove={'foo'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"baz"})
def test_set_remove_new_record(self):
""" Removing something not in the set should silently do nothing
"""
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_set={"foo"})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_set__remove={'afsd'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_set, {"foo"})
def test_list_append_updates(self):
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_list=["foo"])
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_list__append=['bar'])
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_list, ["foo", "bar"])
def test_list_prepend_updates(self):
""" Prepend two things since order is reversed by default by CQL """
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster, text_list=["foo"])
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_list__prepend=['bar', 'baz'])
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
def test_map_update_updates(self):
""" Merge a dictionary into existing value """
partition = uuid4()
cluster = 1
TestQueryUpdateModel.objects.create(
partition=partition, cluster=cluster,
text_map={"foo": '1', "bar": '2'})
TestQueryUpdateModel.objects(
partition=partition, cluster=cluster).update(
text_map__update={"bar": '3', "baz": '4'})
obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})
def test_map_update_none_deletes_key(self):
""" The CQL behavior is if you set a key in a map to null it deletes
that key from the map. Test that this works with __update.
This test fails because of a bug in the cql python library not
converting None to null (and the cql library is no longer in active
developement).
"""
# partition = uuid4()
# cluster = 1
# TestQueryUpdateModel.objects.create(
# partition=partition, cluster=cluster,
# text_map={"foo": '1', "bar": '2'})
# TestQueryUpdateModel.objects(
# partition=partition, cluster=cluster).update(
# text_map__update={"bar": None})
# obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
# self.assertEqual(obj.text_map, {"foo": '1'})

View File

@@ -0,0 +1 @@
__author__ = 'bdeggleston'

View File

@@ -0,0 +1,327 @@
from unittest import TestCase
from cqlengine.statements import AssignmentClause, SetUpdateClause, ListUpdateClause, MapUpdateClause, MapDeleteClause, FieldDeleteClause, CounterUpdateClause
class AssignmentClauseTests(TestCase):
def test_rendering(self):
pass
def test_insert_tuple(self):
ac = AssignmentClause('a', 'b')
ac.set_context_id(10)
self.assertEqual(ac.insert_tuple(), ('a', 10))
class SetUpdateClauseTests(TestCase):
def test_update_from_none(self):
c = SetUpdateClause('s', {1, 2}, previous=None)
c._analyze()
c.set_context_id(0)
self.assertEqual(c._assignments, {1, 2})
self.assertIsNone(c._additions)
self.assertIsNone(c._removals)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': {1, 2}})
def test_null_update(self):
""" tests setting a set to None creates an empty update statement """
c = SetUpdateClause('s', None, previous={1, 2})
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertIsNone(c._additions)
self.assertIsNone(c._removals)
self.assertEqual(c.get_context_size(), 0)
self.assertEqual(str(c), '')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {})
def test_no_update(self):
""" tests an unchanged value creates an empty update statement """
c = SetUpdateClause('s', {1, 2}, previous={1, 2})
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertIsNone(c._additions)
self.assertIsNone(c._removals)
self.assertEqual(c.get_context_size(), 0)
self.assertEqual(str(c), '')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {})
def test_additions(self):
c = SetUpdateClause('s', {1, 2, 3}, previous={1, 2})
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertEqual(c._additions, {3})
self.assertIsNone(c._removals)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = "s" + %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': {3}})
def test_removals(self):
c = SetUpdateClause('s', {1, 2}, previous={1, 2, 3})
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertIsNone(c._additions)
self.assertEqual(c._removals, {3})
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = "s" - %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': {3}})
def test_additions_and_removals(self):
c = SetUpdateClause('s', {2, 3}, previous={1, 2})
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertEqual(c._additions, {3})
self.assertEqual(c._removals, {1})
self.assertEqual(c.get_context_size(), 2)
self.assertEqual(str(c), '"s" = "s" + %(0)s, "s" = "s" - %(1)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': {3}, '1': {1}})
class ListUpdateClauseTests(TestCase):
def test_update_from_none(self):
c = ListUpdateClause('s', [1, 2, 3])
c._analyze()
c.set_context_id(0)
self.assertEqual(c._assignments, [1, 2, 3])
self.assertIsNone(c._append)
self.assertIsNone(c._prepend)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': [1, 2, 3]})
def test_update_from_empty(self):
c = ListUpdateClause('s', [1, 2, 3], previous=[])
c._analyze()
c.set_context_id(0)
self.assertEqual(c._assignments, [1, 2, 3])
self.assertIsNone(c._append)
self.assertIsNone(c._prepend)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': [1, 2, 3]})
def test_update_from_different_list(self):
c = ListUpdateClause('s', [1, 2, 3], previous=[3, 2, 1])
c._analyze()
c.set_context_id(0)
self.assertEqual(c._assignments, [1, 2, 3])
self.assertIsNone(c._append)
self.assertIsNone(c._prepend)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': [1, 2, 3]})
def test_append(self):
c = ListUpdateClause('s', [1, 2, 3, 4], previous=[1, 2])
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertEqual(c._append, [3, 4])
self.assertIsNone(c._prepend)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = "s" + %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': [3, 4]})
def test_prepend(self):
c = ListUpdateClause('s', [1, 2, 3, 4], previous=[3, 4])
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertIsNone(c._append)
self.assertEqual(c._prepend, [1, 2])
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s + "s"')
ctx = {}
c.update_context(ctx)
# test context list reversal
self.assertEqual(ctx, {'0': [2, 1]})
def test_append_and_prepend(self):
c = ListUpdateClause('s', [1, 2, 3, 4, 5, 6], previous=[3, 4])
c._analyze()
c.set_context_id(0)
self.assertIsNone(c._assignments)
self.assertEqual(c._append, [5, 6])
self.assertEqual(c._prepend, [1, 2])
self.assertEqual(c.get_context_size(), 2)
self.assertEqual(str(c), '"s" = %(0)s + "s", "s" = "s" + %(1)s')
ctx = {}
c.update_context(ctx)
# test context list reversal
self.assertEqual(ctx, {'0': [2, 1], '1': [5, 6]})
def test_shrinking_list_update(self):
""" tests that updating to a smaller list results in an insert statement """
c = ListUpdateClause('s', [1, 2, 3], previous=[1, 2, 3, 4])
c._analyze()
c.set_context_id(0)
self.assertEqual(c._assignments, [1, 2, 3])
self.assertIsNone(c._append)
self.assertIsNone(c._prepend)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"s" = %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': [1, 2, 3]})
class MapUpdateTests(TestCase):
def test_update(self):
c = MapUpdateClause('s', {3: 0, 5: 6}, previous={5: 0, 3: 4})
c._analyze()
c.set_context_id(0)
self.assertEqual(c._updates, [3, 5])
self.assertEqual(c.get_context_size(), 4)
self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6})
def test_update_from_null(self):
c = MapUpdateClause('s', {3: 0, 5: 6})
c._analyze()
c.set_context_id(0)
self.assertEqual(c._updates, [3, 5])
self.assertEqual(c.get_context_size(), 4)
self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6})
def test_nulled_columns_arent_included(self):
c = MapUpdateClause('s', {3: 0}, {1: 2, 3: 4})
c._analyze()
c.set_context_id(0)
self.assertNotIn(1, c._updates)
class CounterUpdateTests(TestCase):
def test_positive_update(self):
c = CounterUpdateClause('a', 5, 3)
c.set_context_id(5)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"a" = "a" + %(5)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'5': 2})
def test_negative_update(self):
c = CounterUpdateClause('a', 4, 7)
c.set_context_id(3)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"a" = "a" - %(3)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'3': 3})
def noop_update(self):
c = CounterUpdateClause('a', 5, 5)
c.set_context_id(5)
self.assertEqual(c.get_context_size(), 1)
self.assertEqual(str(c), '"a" = "a" + %(0)s')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'5': 0})
class MapDeleteTests(TestCase):
def test_update(self):
c = MapDeleteClause('s', {3: 0}, {1: 2, 3: 4, 5: 6})
c._analyze()
c.set_context_id(0)
self.assertEqual(c._removals, [1, 5])
self.assertEqual(c.get_context_size(), 2)
self.assertEqual(str(c), '"s"[%(0)s], "s"[%(1)s]')
ctx = {}
c.update_context(ctx)
self.assertEqual(ctx, {'0': 1, '1': 5})
class FieldDeleteTests(TestCase):
def test_str(self):
f = FieldDeleteClause("blake")
assert str(f) == '"blake"'

View File

@@ -0,0 +1,11 @@
from unittest import TestCase
from cqlengine.statements import AssignmentStatement, StatementException
class AssignmentStatementTest(TestCase):
def test_add_assignment_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = AssignmentStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_assignment_clause('x=5')

View File

@@ -0,0 +1,16 @@
from unittest import TestCase
from cqlengine.statements import BaseClause
class BaseClauseTests(TestCase):
def test_context_updating(self):
ss = BaseClause('a', 'b')
assert ss.get_context_size() == 1
ctx = {}
ss.set_context_id(10)
ss.update_context(ctx)
assert ctx == {'10': 'b'}

View File

@@ -0,0 +1,11 @@
from unittest import TestCase
from cqlengine.statements import BaseCQLStatement, StatementException
class BaseStatementTest(TestCase):
def test_where_clause_type_checking(self):
""" tests that only assignment clauses can be added to queries """
stmt = BaseCQLStatement('table', [])
with self.assertRaises(StatementException):
stmt.add_where_clause('x=5')

View File

@@ -0,0 +1,48 @@
from unittest import TestCase
from cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause
from cqlengine.operators import *
class DeleteStatementTests(TestCase):
def test_single_field_is_listified(self):
""" tests that passing a string field into the constructor puts it into a list """
ds = DeleteStatement('table', 'field')
self.assertEqual(len(ds.fields), 1)
self.assertEqual(ds.fields[0].field, 'field')
def test_field_rendering(self):
""" tests that fields are properly added to the select statement """
ds = DeleteStatement('table', ['f1', 'f2'])
self.assertTrue(unicode(ds).startswith('DELETE "f1", "f2"'), unicode(ds))
self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds))
def test_none_fields_rendering(self):
""" tests that a '*' is added if no fields are passed in """
ds = DeleteStatement('table', None)
self.assertTrue(unicode(ds).startswith('DELETE FROM'), unicode(ds))
self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds))
def test_table_rendering(self):
ds = DeleteStatement('table', None)
self.assertTrue(unicode(ds).startswith('DELETE FROM table'), unicode(ds))
self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds))
def test_where_clause_rendering(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(unicode(ds), 'DELETE FROM table WHERE "a" = %(0)s', unicode(ds))
def test_context_update(self):
ds = DeleteStatement('table', None)
ds.add_field(MapDeleteClause('d', {1: 2}, {1:2, 3: 4}))
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
ds.update_context_id(7)
self.assertEqual(unicode(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s')
self.assertEqual(ds.get_context(), {'7': 'b', '8': 3})
def test_context(self):
ds = DeleteStatement('table', None)
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(ds.get_context(), {'0': 'b'})

View File

@@ -0,0 +1,40 @@
from unittest import TestCase
from cqlengine.statements import InsertStatement, StatementException, AssignmentClause
class InsertStatementTests(TestCase):
def test_where_clause_failure(self):
""" tests that where clauses cannot be added to Insert statements """
ist = InsertStatement('table', None)
with self.assertRaises(StatementException):
ist.add_where_clause('s')
def test_statement(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
self.assertEqual(
unicode(ist),
'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)'
)
def test_context_update(self):
ist = InsertStatement('table', None)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
ist.update_context_id(4)
self.assertEqual(
unicode(ist),
'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)'
)
ctx = ist.get_context()
self.assertEqual(ctx, {'4': 'b', '5': 'd'})
def test_additional_rendering(self):
ist = InsertStatement('table', ttl=60)
ist.add_assignment_clause(AssignmentClause('a', 'b'))
ist.add_assignment_clause(AssignmentClause('c', 'd'))
self.assertIn('USING TTL 60', unicode(ist))

View File

@@ -0,0 +1,70 @@
from unittest import TestCase
from cqlengine.statements import SelectStatement, WhereClause
from cqlengine.operators import *
class SelectStatementTests(TestCase):
def test_single_field_is_listified(self):
""" tests that passing a string field into the constructor puts it into a list """
ss = SelectStatement('table', 'field')
self.assertEqual(ss.fields, ['field'])
def test_field_rendering(self):
""" tests that fields are properly added to the select statement """
ss = SelectStatement('table', ['f1', 'f2'])
self.assertTrue(unicode(ss).startswith('SELECT "f1", "f2"'), unicode(ss))
self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss))
def test_none_fields_rendering(self):
""" tests that a '*' is added if no fields are passed in """
ss = SelectStatement('table')
self.assertTrue(unicode(ss).startswith('SELECT *'), unicode(ss))
self.assertTrue(str(ss).startswith('SELECT *'), str(ss))
def test_table_rendering(self):
ss = SelectStatement('table')
self.assertTrue(unicode(ss).startswith('SELECT * FROM table'), unicode(ss))
self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss))
def test_where_clause_rendering(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(unicode(ss), 'SELECT * FROM table WHERE "a" = %(0)s', unicode(ss))
def test_count(self):
ss = SelectStatement('table', count=True, limit=10, order_by='d')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(unicode(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', unicode(ss))
self.assertIn('LIMIT', unicode(ss))
self.assertNotIn('ORDER', unicode(ss))
def test_context(self):
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(ss.get_context(), {'0': 'b'})
def test_context_id_update(self):
""" tests that the right things happen the the context id """
ss = SelectStatement('table')
ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b'))
self.assertEqual(ss.get_context(), {'0': 'b'})
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s')
ss.update_context_id(5)
self.assertEqual(ss.get_context(), {'5': 'b'})
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(5)s')
def test_additional_rendering(self):
ss = SelectStatement(
'table',
None,
order_by=['x', 'y'],
limit=15,
allow_filtering=True
)
qstr = unicode(ss)
self.assertIn('LIMIT 15', qstr)
self.assertIn('ORDER BY x, y', qstr)
self.assertIn('ALLOW FILTERING', qstr)

View File

@@ -0,0 +1,42 @@
from unittest import TestCase
from cqlengine.statements import UpdateStatement, WhereClause, AssignmentClause
from cqlengine.operators import *
class UpdateStatementTests(TestCase):
def test_table_rendering(self):
""" tests that fields are properly added to the select statement """
us = UpdateStatement('table')
self.assertTrue(unicode(us).startswith('UPDATE table SET'), unicode(us))
self.assertTrue(str(us).startswith('UPDATE table SET'), str(us))
def test_rendering(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
self.assertEqual(unicode(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', unicode(us))
def test_context(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'})
def test_context_update(self):
us = UpdateStatement('table')
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_assignment_clause(AssignmentClause('c', 'd'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
us.update_context_id(3)
self.assertEqual(unicode(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s')
self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'})
def test_additional_rendering(self):
us = UpdateStatement('table', ttl=60)
us.add_assignment_clause(AssignmentClause('a', 'b'))
us.add_where_clause(WhereClause('a', EqualsOperator(), 'x'))
self.assertIn('USING TTL 60', unicode(us))

View File

@@ -0,0 +1,25 @@
from unittest import TestCase
from cqlengine.operators import EqualsOperator
from cqlengine.statements import StatementException, WhereClause
class TestWhereClause(TestCase):
def test_operator_check(self):
""" tests that creating a where statement with a non BaseWhereOperator object fails """
with self.assertRaises(StatementException):
WhereClause('a', 'b', 'c')
def test_where_clause_rendering(self):
""" tests that where clauses are rendered properly """
wc = WhereClause('a', EqualsOperator(), 'c')
wc.set_context_id(5)
self.assertEqual('"a" = %(5)s', unicode(wc), unicode(wc))
self.assertEqual('"a" = %(5)s', str(wc), type(wc))
def test_equality_method(self):
""" tests that 2 identical where clauses evaluate as == """
wc1 = WhereClause('a', EqualsOperator(), 'c')
wc2 = WhereClause('a', EqualsOperator(), 'c')
assert wc1 == wc2

View File

@@ -1,29 +1,35 @@
from unittest import skip
from uuid import uuid4
import random
import mock
import sure
from cqlengine import Model, columns
from cqlengine.management import delete_table, create_table
from cqlengine.management import drop_table, sync_table
from cqlengine.query import BatchQuery
from cqlengine.tests.base import BaseCassEngTestCase
class TestMultiKeyModel(Model):
__keyspace__ = 'test'
partition = columns.Integer(primary_key=True)
cluster = columns.Integer(primary_key=True)
count = columns.Integer(required=False)
text = columns.Text(required=False)
class BatchQueryTests(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BatchQueryTests, cls).setUpClass()
delete_table(TestMultiKeyModel)
create_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
sync_table(TestMultiKeyModel)
@classmethod
def tearDownClass(cls):
super(BatchQueryTests, cls).tearDownClass()
delete_table(TestMultiKeyModel)
drop_table(TestMultiKeyModel)
def setUp(self):
super(BatchQueryTests, self).setUp()
@@ -109,3 +115,80 @@ class BatchQueryTests(BaseCassEngTestCase):
with BatchQuery() as b:
pass
class BatchQueryCallbacksTests(BaseCassEngTestCase):
def test_API_managing_callbacks(self):
# Callbacks can be added at init and after
def my_callback(*args, **kwargs):
pass
# adding on init:
batch = BatchQuery()
batch.add_callback(my_callback)
batch.add_callback(my_callback, 2, named_arg='value')
batch.add_callback(my_callback, 1, 3)
assert batch._callbacks == [
(my_callback, (), {}),
(my_callback, (2,), {'named_arg':'value'}),
(my_callback, (1, 3), {})
]
def test_callbacks_properly_execute_callables_and_tuples(self):
call_history = []
def my_callback(*args, **kwargs):
call_history.append(args)
# adding on init:
batch = BatchQuery()
batch.add_callback(my_callback)
batch.add_callback(my_callback, 'more', 'args')
batch.execute()
assert len(call_history) == 2
assert [(), ('more', 'args')] == call_history
def test_callbacks_tied_to_execute(self):
"""Batch callbacks should NOT fire if batch is not executed in context manager mode"""
call_history = []
def my_callback(*args, **kwargs):
call_history.append(args)
with BatchQuery() as batch:
batch.add_callback(my_callback)
pass
assert len(call_history) == 1
class SomeError(Exception):
pass
with self.assertRaises(SomeError):
with BatchQuery() as batch:
batch.add_callback(my_callback)
# this error bubbling up through context manager
# should prevent callback runs (along with b.execute())
raise SomeError
# still same call history. Nothing added
assert len(call_history) == 1
# but if execute ran, even with an error bubbling through
# the callbacks also would have fired
with self.assertRaises(SomeError):
with BatchQuery(execute_on_exception=True) as batch:
batch.add_callback(my_callback)
# this error bubbling up through context manager
# should prevent callback runs (along with b.execute())
raise SomeError
# still same call history
assert len(call_history) == 2

View File

@@ -0,0 +1,95 @@
from cqlengine.management import sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from uuid import uuid4
from cqlengine import columns
import mock
from cqlengine import ALL, BatchQuery
class TestConsistencyModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
class BaseConsistencyTest(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BaseConsistencyTest, cls).setUpClass()
sync_table(TestConsistencyModel)
@classmethod
def tearDownClass(cls):
super(BaseConsistencyTest, cls).tearDownClass()
drop_table(TestConsistencyModel)
class TestConsistency(BaseConsistencyTest):
def test_create_uses_consistency(self):
qs = TestConsistencyModel.consistency(ALL)
with mock.patch.object(self.session, 'execute') as m:
qs.create(text="i am not fault tolerant this way")
args = m.call_args
self.assertEqual(ALL, args[0][0].consistency_level)
def test_queryset_is_returned_on_create(self):
qs = TestConsistencyModel.consistency(ALL)
self.assertTrue(isinstance(qs, TestConsistencyModel.__queryset__), type(qs))
def test_update_uses_consistency(self):
t = TestConsistencyModel.create(text="bacon and eggs")
t.text = "ham sandwich"
with mock.patch.object(self.session, 'execute') as m:
t.consistency(ALL).save()
args = m.call_args
self.assertEqual(ALL, args[0][0].consistency_level)
def test_batch_consistency(self):
with mock.patch.object(self.session, 'execute') as m:
with BatchQuery(consistency=ALL) as b:
TestConsistencyModel.batch(b).create(text="monkey")
args = m.call_args
self.assertEqual(ALL, args[0][0].consistency_level)
with mock.patch.object(self.session, 'execute') as m:
with BatchQuery() as b:
TestConsistencyModel.batch(b).create(text="monkey")
args = m.call_args
self.assertNotEqual(ALL, args[0][0].consistency_level)
def test_blind_update(self):
t = TestConsistencyModel.create(text="bacon and eggs")
t.text = "ham sandwich"
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
TestConsistencyModel.objects(id=uid).consistency(ALL).update(text="grilled cheese")
args = m.call_args
self.assertEqual(ALL, args[0][0].consistency_level)
def test_delete(self):
# ensures we always carry consistency through on delete statements
t = TestConsistencyModel.create(text="bacon and eggs")
t.text = "ham and cheese sandwich"
uid = t.id
with mock.patch.object(self.session, 'execute') as m:
t.consistency(ALL).delete()
with mock.patch.object(self.session, 'execute') as m:
TestConsistencyModel.objects(id=uid).consistency(ALL).delete()
args = m.call_args
self.assertEqual(ALL, args[0][0].consistency_level)

View File

@@ -0,0 +1,34 @@
import os
from unittest import TestCase, skipUnless
from cqlengine import Model, Integer
from cqlengine.management import sync_table
from cqlengine.tests import base
import resource
import gc
class LoadTest(Model):
__keyspace__ = 'test'
k = Integer(primary_key=True)
v = Integer()
@skipUnless("LOADTEST" in os.environ, "LOADTEST not on")
def test_lots_of_queries():
sync_table(LoadTest)
import objgraph
gc.collect()
objgraph.show_most_common_types()
print "Starting..."
for i in range(1000000):
if i % 25000 == 0:
# print memory statistic
print "Memory usage: %s" % (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
LoadTest.create(k=i, v=i)
objgraph.show_most_common_types()
raise Exception("you shouldn't be here")

View File

@@ -0,0 +1,178 @@
"""
Tests surrounding the blah.timestamp( timedelta(seconds=30) ) format.
"""
from datetime import timedelta, datetime
from uuid import uuid4
import mock
import sure
from cqlengine import Model, columns, BatchQuery
from cqlengine.management import sync_table
from cqlengine.tests.base import BaseCassEngTestCase
class TestTimestampModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
class BaseTimestampTest(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BaseTimestampTest, cls).setUpClass()
sync_table(TestTimestampModel)
class BatchTest(BaseTimestampTest):
def test_batch_is_included(self):
with mock.patch.object(self.session, "execute") as m, BatchQuery(timestamp=timedelta(seconds=30)) as b:
TestTimestampModel.batch(b).create(count=1)
"USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string)
class CreateWithTimestampTest(BaseTimestampTest):
def test_batch(self):
with mock.patch.object(self.session, "execute") as m, BatchQuery() as b:
TestTimestampModel.timestamp(timedelta(seconds=10)).batch(b).create(count=1)
query = m.call_args[0][0].query_string
query.should.match(r"INSERT.*USING TIMESTAMP")
query.should_not.match(r"TIMESTAMP.*INSERT")
def test_timestamp_not_included_on_normal_create(self):
with mock.patch.object(self.session, "execute") as m:
TestTimestampModel.create(count=2)
"USING TIMESTAMP".shouldnt.be.within(m.call_args[0][0].query_string)
def test_timestamp_is_set_on_model_queryset(self):
delta = timedelta(seconds=30)
tmp = TestTimestampModel.timestamp(delta)
tmp._timestamp.should.equal(delta)
def test_non_batch_syntax_integration(self):
tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1)
tmp.should.be.ok
def test_non_batch_syntax_unit(self):
with mock.patch.object(self.session, "execute") as m:
TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1)
query = m.call_args[0][0].query_string
"USING TIMESTAMP".should.be.within(query)
class UpdateWithTimestampTest(BaseTimestampTest):
def setUp(self):
self.instance = TestTimestampModel.create(count=1)
super(UpdateWithTimestampTest, self).setUp()
def test_instance_update_includes_timestamp_in_query(self):
# not a batch
with mock.patch.object(self.session, "execute") as m:
self.instance.timestamp(timedelta(seconds=30)).update(count=2)
"USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string)
def test_instance_update_in_batch(self):
with mock.patch.object(self.session, "execute") as m, BatchQuery() as b:
self.instance.batch(b).timestamp(timedelta(seconds=30)).update(count=2)
query = m.call_args[0][0].query_string
"USING TIMESTAMP".should.be.within(query)
class DeleteWithTimestampTest(BaseTimestampTest):
def test_non_batch(self):
"""
we don't expect the model to come back at the end because the deletion timestamp should be in the future
"""
uid = uuid4()
tmp = TestTimestampModel.create(id=uid, count=1)
TestTimestampModel.get(id=uid).should.be.ok
tmp.timestamp(timedelta(seconds=5)).delete()
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
tmp = TestTimestampModel.create(id=uid, count=1)
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
# calling .timestamp sets the TS on the model
tmp.timestamp(timedelta(seconds=5))
tmp._timestamp.should.be.ok
# calling save clears the set timestamp
tmp.save()
tmp._timestamp.shouldnt.be.ok
tmp.timestamp(timedelta(seconds=5))
tmp.update()
tmp._timestamp.shouldnt.be.ok
def test_blind_delete(self):
"""
we don't expect the model to come back at the end because the deletion timestamp should be in the future
"""
uid = uuid4()
tmp = TestTimestampModel.create(id=uid, count=1)
TestTimestampModel.get(id=uid).should.be.ok
TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=5)).delete()
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
tmp = TestTimestampModel.create(id=uid, count=1)
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
def test_blind_delete_with_datetime(self):
"""
we don't expect the model to come back at the end because the deletion timestamp should be in the future
"""
uid = uuid4()
tmp = TestTimestampModel.create(id=uid, count=1)
TestTimestampModel.get(id=uid).should.be.ok
plus_five_seconds = datetime.now() + timedelta(seconds=5)
TestTimestampModel.objects(id=uid).timestamp(plus_five_seconds).delete()
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
tmp = TestTimestampModel.create(id=uid, count=1)
with self.assertRaises(TestTimestampModel.DoesNotExist):
TestTimestampModel.get(id=uid)
def test_delete_in_the_past(self):
uid = uuid4()
tmp = TestTimestampModel.create(id=uid, count=1)
TestTimestampModel.get(id=uid).should.be.ok
# delete the in past, should not affect the object created above
TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=-60)).delete()
TestTimestampModel.get(id=uid)

121
cqlengine/tests/test_ttl.py Normal file
View File

@@ -0,0 +1,121 @@
from cqlengine.management import sync_table, drop_table
from cqlengine.tests.base import BaseCassEngTestCase
from cqlengine.models import Model
from uuid import uuid4
from cqlengine import columns
import mock
from cqlengine.connection import get_session
class TestTTLModel(Model):
__keyspace__ = 'test'
id = columns.UUID(primary_key=True, default=lambda:uuid4())
count = columns.Integer()
text = columns.Text(required=False)
class BaseTTLTest(BaseCassEngTestCase):
@classmethod
def setUpClass(cls):
super(BaseTTLTest, cls).setUpClass()
sync_table(TestTTLModel)
@classmethod
def tearDownClass(cls):
super(BaseTTLTest, cls).tearDownClass()
drop_table(TestTTLModel)
class TTLQueryTests(BaseTTLTest):
def test_update_queryset_ttl_success_case(self):
""" tests that ttls on querysets work as expected """
def test_select_ttl_failure(self):
""" tests that ttls on select queries raise an exception """
class TTLModelTests(BaseTTLTest):
def test_ttl_included_on_create(self):
""" tests that ttls on models work as expected """
session = get_session()
with mock.patch.object(session, 'execute') as m:
TestTTLModel.ttl(60).create(text="hello blake")
query = m.call_args[0][0].query_string
self.assertIn("USING TTL", query)
def test_queryset_is_returned_on_class(self):
"""
ensures we get a queryset descriptor back
"""
qs = TestTTLModel.ttl(60)
self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs))
class TTLInstanceUpdateTest(BaseTTLTest):
def test_update_includes_ttl(self):
session = get_session()
model = TestTTLModel.create(text="goodbye blake")
with mock.patch.object(session, 'execute') as m:
model.ttl(60).update(text="goodbye forever")
query = m.call_args[0][0].query_string
self.assertIn("USING TTL", query)
def test_update_syntax_valid(self):
# sanity test that ensures the TTL syntax is accepted by cassandra
model = TestTTLModel.create(text="goodbye blake")
model.ttl(60).update(text="goodbye forever")
class TTLInstanceTest(BaseTTLTest):
def test_instance_is_returned(self):
"""
ensures that we properly handle the instance.ttl(60).save() scenario
:return:
"""
o = TestTTLModel.create(text="whatever")
o.text = "new stuff"
o = o.ttl(60)
self.assertEqual(60, o._ttl)
def test_ttl_is_include_with_query_on_update(self):
session = get_session()
o = TestTTLModel.create(text="whatever")
o.text = "new stuff"
o = o.ttl(60)
with mock.patch.object(session, 'execute') as m:
o.save()
query = m.call_args[0][0].query_string
self.assertIn("USING TTL", query)
class TTLBlindUpdateTest(BaseTTLTest):
def test_ttl_included_with_blind_update(self):
session = get_session()
o = TestTTLModel.create(text="whatever")
tid = o.id
with mock.patch.object(session, 'execute') as m:
TestTTLModel.objects(id=tid).ttl(60).update(text="bacon")
query = m.call_args[0][0].query_string
self.assertIn("USING TTL", query)

View File

@@ -11,7 +11,7 @@
# All configuration values have a default; values that are commented out
# serve to show the default.
import sys, os
import os
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
@@ -41,16 +41,18 @@ master_doc = 'index'
# General information about the project.
project = u'cqlengine'
copyright = u'2012, Blake Eggleston'
copyright = u'2012, Blake Eggleston, Jon Haddad'
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
__cqlengine_version_path__ = os.path.realpath(__file__ + '/../../cqlengine/VERSION')
# The short X.Y version.
version = '0.2'
version = open(__cqlengine_version_path__, 'r').readline().strip()
# The full version, including alpha/beta/rc tags.
release = '0.2'
release = version
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -183,8 +185,7 @@ latex_elements = {
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [
('index', 'cqlengine.tex', u'cqlengine Documentation',
u'Blake Eggleston', 'manual'),
('index', 'cqlengine.tex', u'cqlengine Documentation', u'Blake Eggleston, Jon Haddad', 'manual'),
]
# The name of an image file (relative to this directory) to place at the top of
@@ -214,7 +215,7 @@ latex_documents = [
# (source start file, name, description, authors, manual section).
man_pages = [
('index', 'cqlengine', u'cqlengine Documentation',
[u'Blake Eggleston'], 1)
[u'Blake Eggleston, Jon Haddad'], 1)
]
# If true, show URL addresses after external links.
@@ -227,9 +228,9 @@ man_pages = [
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
('index', 'cqlengine', u'cqlengine Documentation',
u'Blake Eggleston', 'cqlengine', 'One line description of project.',
'Miscellaneous'),
('index', 'cqlengine', u'cqlengine Documentation',
u'Blake Eggleston, Jon Haddad', 'cqlengine', 'One line description of project.',
'Miscellaneous'),
]
# Documents to append as an appendix to all manuals.

View File

@@ -5,20 +5,31 @@
cqlengine documentation
=======================
cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface similar to the Django orm and mongoengine
**Users of versions < 0.16, the default keyspace 'cqlengine' has been removed. Please read this before upgrading:** :ref:`Breaking Changes <keyspace-change>`
cqlengine is a Cassandra CQL 3 Object Mapper for Python
:ref:`getting-started`
Download
========
`Github <https://github.com/cqlengine/cqlengine>`_
`PyPi <https://pypi.python.org/pypi/cqlengine>`_
Contents:
=========
.. toctree::
:maxdepth: 2
topics/models
topics/queryset
topics/columns
topics/connection
topics/manage_schemas
topics/faq
.. _getting-started:
@@ -32,18 +43,24 @@ Getting Started
from cqlengine import Model
class ExampleModel(Model):
example_id = columns.UUID(primary_key=True)
example_id = columns.UUID(primary_key=True, default=uuid.uuid4)
example_type = columns.Integer(index=True)
created_at = columns.DateTime()
description = columns.Text(required=False)
#next, setup the connection to your cassandra server(s)...
>>> from cqlengine import connection
>>> connection.setup(['127.0.0.1:9160'])
# see http://datastax.github.io/python-driver/api/cassandra/cluster.html for options
# the list of hosts will be passed to create a Cluster() instance
>>> connection.setup(['127.0.0.1'])
# if you're connecting to a 1.2 cluster
>>> connection.setup(['127.0.0.1'], protocol_version=1)
#...and create your CQL table
>>> from cqlengine.management import create_table
>>> create_table(ExampleModel)
>>> from cqlengine.management import sync_table
>>> sync_table(ExampleModel)
#now we can create some rows:
>>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now())
@@ -54,8 +71,6 @@ Getting Started
>>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now())
>>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now())
>>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now())
# Note: the UUID and DateTime columns will create uuid4 and datetime.now
# values automatically if we don't specify them when creating new rows
#and now we can run some queries against our table
>>> ExampleModel.objects.count()
@@ -64,7 +79,7 @@ Getting Started
>>> q.count()
4
>>> for instance in q:
>>> print q.description
>>> print instance.description
example5
example6
example7
@@ -78,7 +93,7 @@ Getting Started
>>> q2.count()
1
>>> for instance in q2:
>>> print q.description
>>> print instance.description
example5
@@ -86,10 +101,6 @@ Getting Started
`Users Mailing List <https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users>`_
`Dev Mailing List <https://groups.google.com/forum/?fromgroups#!forum/cqlengine-dev>`_
**NOTE: cqlengine is in alpha and under development, some features may change. Make sure to check the changelog and test your app before upgrading**
Indices and tables
==================

View File

@@ -2,6 +2,10 @@
Columns
=======
**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_
.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU
.. module:: cqlengine.columns
.. class:: Bytes()
@@ -14,9 +18,9 @@ Columns
.. class:: Ascii()
Stores a US-ASCII character string ::
columns.Ascii()
.. class:: Text()
@@ -34,26 +38,50 @@ Columns
.. class:: Integer()
Stores an integer value ::
Stores a 32-bit signed integer value ::
columns.Integer()
.. class:: BigInt()
Stores a 64-bit signed long value ::
columns.BigInt()
.. class:: VarInt()
Stores an arbitrary-precision integer ::
columns.VarInt()
.. class:: DateTime()
Stores a datetime value.
Python's datetime.now callable is set as the default value for this column ::
columns.DateTime()
.. class:: UUID()
Stores a type 1 or type 4 UUID.
Python's uuid.uuid4 callable is set as the default value for this column. ::
columns.UUID()
.. class:: TimeUUID()
Stores a UUID value as the cql type 'timeuuid' ::
columns.TimeUUID()
.. classmethod:: from_datetime(dt)
generates a TimeUUID for the given datetime
:param dt: the datetime to create a time uuid from
:type dt: datetime.datetime
:returns: a time uuid created from the given datetime
:rtype: uuid1
.. class:: Boolean()
Stores a boolean True or False value ::
@@ -77,16 +105,24 @@ Columns
columns.Decimal()
.. class:: Counter()
Counters can be incremented and decremented ::
columns.Counter()
Collection Type Columns
----------------------------
CQLEngine also supports container column types. Each container column requires a column class argument to specify what type of objects it will hold. The Map column requires 2, one for the key, and the other for the value
*Example*
.. code-block:: python
class Person(Model):
id = columns.UUID(primary_key=True, default=uuid.uuid4)
first_name = columns.Text()
last_name = columns.Text()
@@ -94,7 +130,7 @@ Collection Type Columns
enemies = columns.Set(columns.Text)
todo_list = columns.List(columns.Text)
birthdays = columns.Map(columns.Text, columns.DateTime)
.. class:: Set()
@@ -145,12 +181,16 @@ Column Options
If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False.
*In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.*
*In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys, unless partition keys are specified manually using* :attr:`BaseColumn.partition_key`
.. attribute:: BaseColumn.partition_key
If True, this column is created as partition primary key. There may be many partition keys defined, forming a *composite partition key*
.. attribute:: BaseColumn.index
If True, an index will be created for this column. Defaults to False.
*Note: Indexes can only be created on models with one primary key*
.. attribute:: BaseColumn.db_field
@@ -163,5 +203,8 @@ Column Options
.. attribute:: BaseColumn.required
If True, this model cannot be saved without a value defined for this column. Defaults to True. Primary key fields cannot have their required fields set to False.
If True, this model cannot be saved without a value defined for this column. Defaults to False. Primary key fields cannot have their required fields set to False.
.. attribute:: BaseColumn.clustering_order
Defines CLUSTERING ORDER for this column (valid choices are "asc" (default) or "desc"). It may be specified only for clustering primary keys - more: http://www.datastax.com/docs/1.2/cql_cli/cql/CREATE_TABLE#using-clustering-order

View File

@@ -1,17 +1,26 @@
==============
==========
Connection
==============
==========
**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_
.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU
.. module:: cqlengine.connection
The setup function in `cqlengine.connection` records the Cassandra servers to connect to.
If there is a problem with one of the servers, cqlengine will try to connect to each of the other connections before failing.
.. function:: setup(hosts [, username=None, password=None])
.. function:: setup(hosts)
:param hosts: list of hosts, strings in the <hostname>:<port>, or just <hostname>
:type hosts: list
:param consistency: the consistency level of the connection, defaults to 'ONE'
:type consistency: int
# see http://datastax.github.io/python-driver/api/cassandra.html#cassandra.ConsistencyLevel
Records the hosts and connects to one of them
See the example at :ref:`getting-started`

8
docs/topics/faq.rst Normal file
View File

@@ -0,0 +1,8 @@
==========================
Frequently Asked Questions
==========================
Q: Why don't updates work correctly on models instantiated as Model(field=blah, field2=blah2)?
-------------------------------------------------------------------
A: The recommended way to create new rows is with the models .create method. The values passed into a model's init method are interpreted by the model as the values as they were read from a row. This allows the model to "know" which rows have changed since the row was read out of cassandra, and create suitable update statements.

View File

@@ -1,6 +1,12 @@
===============
Managing Schmas
===============
================
Managing Schemas
================
**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_
.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU
.. module:: cqlengine.connection
.. module:: cqlengine.management
@@ -20,23 +26,23 @@ Once a connection has been made to Cassandra, you can use the functions in ``cql
deletes the keyspace with the given name
.. function:: create_table(model [, create_missing_keyspace=True])
.. function:: sync_table(model [, create_missing_keyspace=True])
:param model: the :class:`~cqlengine.model.Model` class to make a table with
:type model: :class:`~cqlengine.model.Model`
:param create_missing_keyspace: *Optional* If True, the model's keyspace will be created if it does not already exist. Defaults to ``True``
:type create_missing_keyspace: bool
creates a CQL table for the given model
syncs a python model to cassandra (creates & alters)
.. function:: delete_table(model)
.. function:: drop_table(model)
:param model: the :class:`~cqlengine.model.Model` class to delete a column family for
:type model: :class:`~cqlengine.model.Model`
deletes the CQL table for the given model
See the example at :ref:`getting-started`

View File

@@ -2,12 +2,18 @@
Models
======
**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_
.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU
.. module:: cqlengine.connection
.. module:: cqlengine.models
A model is a python class representing a CQL table.
Example
=======
Examples
========
This example defines a Person table, with the columns ``first_name`` and ``last_name``
@@ -15,11 +21,11 @@ This example defines a Person table, with the columns ``first_name`` and ``last_
from cqlengine import columns
from cqlengine.models import Model
class Person(Model):
first_name = columns.Text()
last_name = columns.Text()
The Person model would create this CQL table:
@@ -32,10 +38,42 @@ The Person model would create this CQL table:
PRIMARY KEY (id)
)
Here's an example of a comment table created with clustering keys, in descending order:
.. code-block:: python
from cqlengine import columns
from cqlengine.models import Model
class Comment(Model):
photo_id = columns.UUID(primary_key=True)
comment_id = columns.TimeUUID(primary_key=True, clustering_order="DESC")
comment = columns.Text()
The Comment model's ``create table`` would look like the following:
.. code-block:: sql
CREATE TABLE comment (
photo_id uuid,
comment_id timeuuid,
comment text,
PRIMARY KEY (photo_id, comment_id)
) WITH CLUSTERING ORDER BY (comment_id DESC)
To sync the models to the database, you may do the following:
.. code-block:: python
from cqlengine.management import sync_table
sync_table(Person)
sync_table(Comment)
Columns
=======
Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. For a model to be valid it needs at least one primary key column (defined automatically if you don't define one) and one non-primary key column.
Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. For a model to be valid it needs at least one primary key column and one non-primary key column.
Just as in CQL, the order you define your columns in is important, and is the same order they are defined in on a model's corresponding table.
@@ -61,26 +99,37 @@ Column Types
Column Options
--------------
Each column can be defined with optional arguments to modify the way they behave. While some column types may define additional column options, these are the options that are available on all columns:
Each column can be defined with optional arguments to modify the way they behave. While some column types may
define additional column options, these are the options that are available on all columns:
:attr:`~cqlengine.columns.BaseColumn.primary_key`
If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False.
*In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.*
*In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first
primary key is the partition key, and all others are clustering keys, unless partition keys are specified
manually using* :attr:`~cqlengine.columns.BaseColumn.partition_key`
:attr:`~cqlengine.columns.BaseColumn.partition_key`
If True, this column is created as partition primary key. There may be many partition keys defined,
forming a *composite partition key*
:attr:`~cqlengine.columns.BaseColumn.clustering_order`
``ASC`` or ``DESC``, determines the clustering order of a clustering key.
:attr:`~cqlengine.columns.BaseColumn.index`
If True, an index will be created for this column. Defaults to False.
*Note: Indexes can only be created on models with one primary key*
:attr:`~cqlengine.columns.BaseColumn.db_field`
Explicitly sets the name of the column in the database table. If this is left blank, the column name will be the same as the name of the column attribute. Defaults to None.
Explicitly sets the name of the column in the database table. If this is left blank, the column name will be
the same as the name of the column attribute. Defaults to None.
:attr:`~cqlengine.columns.BaseColumn.default`
The default value for this column. If a model instance is saved without a value for this column having been defined, the default value will be used. This can be either a value or a callable object (ie: datetime.now is a valid default argument).
The default value for this column. If a model instance is saved without a value for this column having been
defined, the default value will be used. This can be either a value or a callable object (ie: datetime.now is a valid default argument).
Callable defaults will be called each time a default is assigned to a None value
:attr:`~cqlengine.columns.BaseColumn.required`
If True, this model cannot be saved without a value defined for this column. Defaults to True. Primary key fields cannot have their required fields set to False.
If True, this model cannot be saved without a value defined for this column. Defaults to False. Primary key fields always require values.
Model Methods
=============
@@ -93,7 +142,7 @@ Model Methods
*Example*
.. code-block:: python
#using the person model from earlier:
class Person(Model):
first_name = columns.Text()
@@ -102,7 +151,7 @@ Model Methods
person = Person(first_name='Blake', last_name='Eggleston')
person.first_name #returns 'Blake'
person.last_name #returns 'Eggleston'
.. method:: save()
@@ -117,22 +166,265 @@ Model Methods
#saves it to Cassandra
person.save()
.. method:: delete()
Deletes the object from the database.
.. method:: batch(batch_object)
Sets the batch object to run instance updates and inserts queries with.
.. method:: timestamp(timedelta_or_datetime)
Sets the timestamp for the query
.. method:: ttl(ttl_in_sec)
Sets the ttl values to run instance updates and inserts queries with.
.. method:: update(**values)
Performs an update on the model instance. You can pass in values to set on the model
for updating, or you can call without values to execute an update against any modified
fields. If no fields on the model have been modified since loading, no query will be
performed. Model validation is performed normally.
.. method:: get_changed_columns()
Returns a list of column names that have changed since the model was instantiated or saved
Model Attributes
================
.. attribute:: Model.table_name
.. attribute:: Model.__abstract__
*Optional.* Indicates that this model is only intended to be used as a base class for other models. You can't create tables for abstract models, but checks around schema validity are skipped during class construction.
.. attribute:: Model.__table_name__
*Optional.* Sets the name of the CQL table for this model. If left blank, the table name will be the name of the model, with it's module name as it's prefix. Manually defined table names are not inherited.
.. attribute:: Model.keyspace
.. _keyspace-change:
.. attribute:: Model.__keyspace__
*Optional.* Sets the name of the keyspace used by this model. Defaulst to cqlengine
Sets the name of the keyspace used by this model.
**Prior to cqlengine 0.16, this setting defaulted
to 'cqlengine'. As of 0.16, this field needs to be set on all non-abstract models, or their base classes.**
Table Polymorphism
==================
As of cqlengine 0.8, it is possible to save and load different model classes using a single CQL table.
This is useful in situations where you have different object types that you want to store in a single cassandra row.
For instance, suppose you want a table that stores rows of pets owned by an owner:
.. code-block:: python
class Pet(Model):
__table_name__ = 'pet'
owner_id = UUID(primary_key=True)
pet_id = UUID(primary_key=True)
pet_type = Text(polymorphic_key=True)
name = Text()
def eat(self, food):
pass
def sleep(self, time):
pass
class Cat(Pet):
__polymorphic_key__ = 'cat'
cuteness = Float()
def tear_up_couch(self):
pass
class Dog(Pet):
__polymorphic_key__ = 'dog'
fierceness = Float()
def bark_all_night(self):
pass
After calling ``sync_table`` on each of these tables, the columns defined in each model will be added to the
``pet`` table. Additionally, saving ``Cat`` and ``Dog`` models will save the meta data needed to identify each row
as either a cat or dog.
To setup a polymorphic model structure, follow these steps
1. Create a base model with a column set as the polymorphic_key (set ``polymorphic_key=True`` in the column definition)
2. Create subclass models, and define a unique ``__polymorphic_key__`` value on each
3. Run ``sync_table`` on each of the sub tables
**About the polymorphic key**
The polymorphic key is what cqlengine uses under the covers to map logical cql rows to the appropriate model type. The
base model maintains a map of polymorphic keys to subclasses. When a polymorphic model is saved, this value is automatically
saved into the polymorphic key column. You can set the polymorphic key column to any column type that you like, with
the exception of container and counter columns, although ``Integer`` columns make the most sense. Additionally, if you
set ``index=True`` on your polymorphic key column, you can execute queries against polymorphic subclasses, and a
``WHERE`` clause will be automatically added to your query, returning only rows of that type. Note that you must
define a unique ``__polymorphic_key__`` value to each subclass, and that you can only assign a single polymorphic
key column per model
Extending Model Validation
==========================
Each time you save a model instance in cqlengine, the data in the model is validated against the schema you've defined
for your model. Most of the validation is fairly straightforward, it basically checks that you're not trying to do
something like save text into an integer column, and it enforces the ``required`` flag set on column definitions.
It also performs any transformations needed to save the data properly.
However, there are often additional constraints or transformations you want to impose on your data, beyond simply
making sure that Cassandra won't complain when you try to insert it. To define additional validation on a model,
extend the model's validation method:
.. code-block:: python
class Member(Model):
person_id = UUID(primary_key=True)
name = Text(required=True)
def validate(self):
super(Member, self).validate()
if self.name == 'jon':
raise ValidationError('no jon\'s allowed')
*Note*: while not required, the convention is to raise a ``ValidationError`` (``from cqlengine import ValidationError``)
if validation fails
Table Properties
================
Each table can have its own set of configuration options.
These can be specified on a model with the following attributes:
.. attribute:: Model.__bloom_filter_fp_chance
.. attribute:: Model.__caching__
.. attribute:: Model.__comment__
.. attribute:: Model.__dclocal_read_repair_chance__
.. attribute:: Model.__default_time_to_live__
.. attribute:: Model.__gc_grace_seconds__
.. attribute:: Model.__index_interval__
.. attribute:: Model.__memtable_flush_period_in_ms__
.. attribute:: Model.__populate_io_cache_on_flush__
.. attribute:: Model.__read_repair_chance__
.. attribute:: Model.__replicate_on_write__
Example:
.. code-block:: python
from cqlengine import ROWS_ONLY, columns
from cqlengine.models import Model
class User(Model):
__caching__ = ROWS_ONLY # cache only rows instead of keys only by default
__gc_grace_seconds__ = 86400 # 1 day instead of the default 10 days
user_id = columns.UUID(primary_key=True)
name = columns.Text()
Will produce the following CQL statement:
.. code-block:: sql
CREATE TABLE cqlengine.user (
user_id uuid,
name text,
PRIMARY KEY (user_id)
) WITH caching = 'rows_only'
AND gc_grace_seconds = 86400;
See the `list of supported table properties for more information
<http://www.datastax.com/documentation/cql/3.1/cql/cql_reference/tabProp.html>`_.
Compaction Options
==================
As of cqlengine 0.7 we've added support for specifying compaction options. cqlengine will only use your compaction options if you have a strategy set. When a table is synced, it will be altered to match the compaction options set on your table. This means that if you are changing settings manually they will be changed back on resync. Do not use the compaction settings of cqlengine if you want to manage your compaction settings manually.
cqlengine supports all compaction options as of Cassandra 1.2.8.
Available Options:
.. attribute:: Model.__compaction_bucket_high__
.. attribute:: Model.__compaction_bucket_low__
.. attribute:: Model.__compaction_max_compaction_threshold__
.. attribute:: Model.__compaction_min_compaction_threshold__
.. attribute:: Model.__compaction_min_sstable_size__
.. attribute:: Model.__compaction_sstable_size_in_mb__
.. attribute:: Model.__compaction_tombstone_compaction_interval__
.. attribute:: Model.__compaction_tombstone_threshold__
For example:
.. code-block:: python
class User(Model):
__compaction__ = cqlengine.LeveledCompactionStrategy
__compaction_sstable_size_in_mb__ = 64
__compaction_tombstone_threshold__ = .2
user_id = columns.UUID(primary_key=True)
name = columns.Text()
or for SizeTieredCompaction:
.. code-block:: python
class TimeData(Model):
__compaction__ = SizeTieredCompactionStrategy
__compaction_bucket_low__ = .3
__compaction_bucket_high__ = 2
__compaction_min_threshold__ = 2
__compaction_max_threshold__ = 64
__compaction_tombstone_compaction_interval__ = 86400
Tables may use `LeveledCompactionStrategy` or `SizeTieredCompactionStrategy`. Both options are available in the top level cqlengine module. To reiterate, you will need to set your `__compaction__` option explicitly in order for cqlengine to handle any of your settings.
Manipulating model instances as dictionaries
============================================
As of cqlengine 0.12, we've added support for treating model instances like dictionaries. See below for examples.
.. code-block:: python
class Person(Model):
first_name = columns.Text()
last_name = columns.Text()
kevin = Person.create(first_name="Kevin", last_name="Deldycke")
dict(kevin) # returns {'first_name': 'Kevin', 'last_name': 'Deldycke'}
kevin['first_name'] # returns 'Kevin'
kevin.keys() # returns ['first_name', 'last_name']
kevin.values() # returns ['Kevin', 'Deldycke']
kevin.items() # returns [('first_name', 'Kevin'), ('last_name', 'Deldycke')]
kevin['first_name'] = 'KEVIN5000' # changes the models first name
Automatic Primary Keys
======================
CQL requires that all tables define at least one primary key. If a model definition does not include a primary key column, cqlengine will automatically add a uuid primary key column named ``id``.

View File

@@ -2,6 +2,12 @@
Making Queries
==============
**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_
.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU
.. module:: cqlengine.connection
.. module:: cqlengine.query
Retrieving objects
@@ -23,15 +29,15 @@ Retrieving all objects
.. _retrieving-objects-with-filters:
Retrieving objects with filters
----------------------------------------
-------------------------------
Typically, you'll want to query only a subset of the records in your database.
That can be accomplished with the QuerySet's ``.filter(\*\*)`` method.
For example, given the model definition:
.. code-block:: python
class Automobile(Model):
manufacturer = columns.Text(primary_key=True)
year = columns.Integer(primary_key=True)
@@ -40,11 +46,17 @@ Retrieving objects with filters
...and assuming the Automobile table contains a record of every car model manufactured in the last 20 years or so, we can retrieve only the cars made by a single manufacturer like this:
.. code-block:: python
q = Automobile.objects.filter(manufacturer='Tesla')
You can also use the more convenient syntax:
.. code-block:: python
q = Automobile.objects(Automobile.manufacturer == 'Tesla')
We can then further filter our query with another call to **.filter**
.. code-block:: python
@@ -123,6 +135,7 @@ Filtering Operators
q = Automobile.objects.filter(manufacturer='Tesla')
q = q.filter(year__in=[2011, 2012])
:attr:`> (__gt) <query.QueryOperator.GreaterThanOperator>`
.. code-block:: python
@@ -130,6 +143,10 @@ Filtering Operators
q = Automobile.objects.filter(manufacturer='Tesla')
q = q.filter(year__gt=2010) # year > 2010
# or the nicer syntax
q.filter(Automobile.year > 2010)
:attr:`>= (__gte) <query.QueryOperator.GreaterThanOrEqualOperator>`
.. code-block:: python
@@ -137,6 +154,10 @@ Filtering Operators
q = Automobile.objects.filter(manufacturer='Tesla')
q = q.filter(year__gte=2010) # year >= 2010
# or the nicer syntax
q.filter(Automobile.year >= 2010)
:attr:`< (__lt) <query.QueryOperator.LessThanOperator>`
.. code-block:: python
@@ -144,6 +165,10 @@ Filtering Operators
q = Automobile.objects.filter(manufacturer='Tesla')
q = q.filter(year__lt=2012) # year < 2012
# or...
q.filter(Automobile.year < 2012)
:attr:`<= (__lte) <query.QueryOperator.LessThanOrEqualOperator>`
.. code-block:: python
@@ -151,6 +176,8 @@ Filtering Operators
q = Automobile.objects.filter(manufacturer='Tesla')
q = q.filter(year__lte=2012) # year <= 2012
q.filter(Automobile.year <= 2012)
TimeUUID Functions
==================
@@ -178,7 +205,29 @@ TimeUUID Functions
DataStream.filter(time__gt=cqlengine.MinTimeUUID(min_time), time__lt=cqlengine.MaxTimeUUID(max_time))
QuerySets are imutable
Token Function
==============
Token functon may be used only on special, virtual column pk__token, representing token of partition key (it also works for composite partition keys).
Cassandra orders returned items by value of partition key token, so using cqlengine.Token we can easy paginate through all table rows.
See http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun
*Example*
.. code-block:: python
class Items(Model):
id = cqlengine.Text(primary_key=True)
data = cqlengine.Bytes()
query = Items.objects.all().limit(10)
first_page = list(query);
last = first_page[-1]
next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk)))
QuerySets are immutable
======================
When calling any method that changes a queryset, the method does not actually change the queryset object it's called on, but returns a new queryset object with the attributes of the original queryset, plus the attributes added in the method call.
@@ -198,7 +247,7 @@ Ordering QuerySets
Since Cassandra is essentially a distributed hash table on steroids, the order you get records back in will not be particularly predictable.
However, you can set a column to order on with the ``.order_by(column_name)`` method.
However, you can set a column to order on with the ``.order_by(column_name)`` method.
*Example*
@@ -213,15 +262,36 @@ Ordering QuerySets
*For instance, given our Automobile model, year is the only column we can order on.*
Batch Queries
===============
Values Lists
============
There is a special QuerySet's method ``.values_list()`` - when called, QuerySet returns lists of values instead of model instances. It may significantly speedup things with lower memory footprint for large responses.
Each tuple contains the value from the respective field passed into the ``values_list()`` call — so the first item is the first field, etc. For example:
.. code-block:: python
items = list(range(20))
random.shuffle(items)
for i in items:
TestModel.create(id=1, clustering_key=i)
values = list(TestModel.objects.values_list('clustering_key', flat=True))
# [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L]
Batch Queries
=============
cqlengine now supports batch queries using the BatchQuery class. Batch queries can be started and stopped manually, or within a context manager. To add queries to the batch object, you just need to precede the create/save/delete call with a call to batch, and pass in the batch object.
Batch Query General Use Pattern
-------------------------------
cqlengine now supports batch queries using the BatchQuery class. Batch queries can be started and stopped manually, or within a context manager. To add queries to the batch object, you just need to precede the create/save/delete call with a call to batch, and pass in the batch object.
You can only create, update, and delete rows with a batch query, attempting to read rows out of the database with a batch query will fail.
.. code-block:: python
from cqlengine import BatchQuery
#using a context manager
@@ -241,6 +311,72 @@ Batch Queries
em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now)
b.execute()
# updating in a batch
b = BatchQuery()
em1.description = "new description"
em1.batch(b).save()
em2.description = "another new description"
em2.batch(b).save()
b.execute()
# deleting in a batch
b = BatchQuery()
ExampleModel.objects(id=some_id).batch(b).delete()
ExampleModel.objects(id=some_id2).batch(b).delete()
b.execute()
Typically you will not want the block to execute if an exception occurs inside the `with` block. However, in the case that this is desirable, it's achievable by using the following syntax:
.. code-block:: python
with BatchQuery(execute_on_exception=True) as b:
LogEntry.batch(b).create(k=1, v=1)
mystery_function() # exception thrown in here
LogEntry.batch(b).create(k=1, v=2) # this code is never reached due to the exception, but anything leading up to here will execute in the batch.
If an exception is thrown somewhere in the block, any statements that have been added to the batch will still be executed. This is useful for some logging situations.
Batch Query Execution Callbacks
-------------------------------
In order to allow secondary tasks to be chained to the end of batch, BatchQuery instances allow callbacks to be
registered with the batch, to be executed immediately after the batch executes.
Multiple callbacks can be attached to same BatchQuery instance, they are executed in the same order that they
are added to the batch.
The callbacks attached to a given batch instance are executed only if the batch executes. If the batch is used as a
context manager and an exception is raised, the queued up callbacks will not be run.
.. code-block:: python
def my_callback(*args, **kwargs):
pass
batch = BatchQuery()
batch.add_callback(my_callback)
batch.add_callback(my_callback, 'positional arg', named_arg='named arg value')
# if you need reference to the batch within the callback,
# just trap it in the arguments to be passed to the callback:
batch.add_callback(my_callback, cqlengine_batch=batch)
# once the batch executes...
batch.execute()
# the effect of the above scheduled callbacks will be similar to
my_callback()
my_callback('positional arg', named_arg='named arg value')
my_callback(cqlengine_batch=batch)
Failure in any of the callbacks does not affect the batch's execution, as the callbacks are started after the execution
of the batch is complete.
QuerySet method reference
=========================
@@ -250,6 +386,15 @@ QuerySet method reference
Returns a queryset matching all rows
.. method:: batch(batch_object)
Sets the batch object to run the query on. Note that running a select query with a batch object will raise an exception
.. method:: consistency(consistency_setting)
Sets the consistency level for the operation. Options may be imported from the top level :attr:`cqlengine` package.
.. method:: count()
Returns the number of matching rows in your QuerySet
@@ -269,7 +414,7 @@ QuerySet method reference
.. method:: limit(num)
Limits the number of results returned by Cassandra.
*Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000*
.. method:: order_by(field_name)
@@ -277,8 +422,106 @@ QuerySet method reference
:param field_name: the name of the field to order on. *Note: the field_name must be a clustering key*
:type field_name: string
Sets the field to order on.
Sets the field to order on.
.. method:: allow_filtering()
Enables the (usually) unwise practive of querying on a clustering key without also defining a partition key
.. method:: timestamp(timestamp_or_long_or_datetime)
Allows for custom timestamps to be saved with the record.
.. method:: ttl(ttl_in_seconds)
:param ttl_in_seconds: time in seconds in which the saved values should expire
:type ttl_in_seconds: int
Sets the ttl to run the query query with. Note that running a select query with a ttl value will raise an exception
.. method:: update(**values)
Performs an update on the row selected by the queryset. Include values to update in the
update like so:
.. code-block:: python
Model.objects(key=n).update(value='x')
Passing in updates for columns which are not part of the model will raise a ValidationError.
Per column validation will be performed, but instance level validation will not
(`Model.validate` is not called).
The queryset update method also supports blindly adding and removing elements from container columns, without
loading a model instance from Cassandra.
Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a
non container column. However, adding `__<operation>` to the end of the keyword arg, makes the update call add
or remove items from the collection, without overwriting then entire column.
Given the model below, here are the operations that can be performed on the different container columns:
.. code-block:: python
class Row(Model):
row_id = columns.Integer(primary_key=True)
set_column = columns.Set(Integer)
list_column = columns.Set(Integer)
map_column = columns.Set(Integer, Integer)
:class:`~cqlengine.columns.Set`
- `add`: adds the elements of the given set to the column
- `remove`: removes the elements of the given set to the column
.. code-block:: python
# add elements to a set
Row.objects(row_id=5).update(set_column__add={6})
# remove elements to a set
Row.objects(row_id=5).update(set_column__remove={4})
:class:`~cqlengine.columns.List`
- `append`: appends the elements of the given list to the end of the column
- `prepend`: prepends the elements of the given list to the beginning of the column
.. code-block:: python
# append items to a list
Row.objects(row_id=5).update(list_column__append=[6, 7])
# prepend items to a list
Row.objects(row_id=5).update(list_column__prepend=[1, 2])
:class:`~cqlengine.columns.Map`
- `update`: adds the given keys/values to the columns, creating new entries if they didn't exist, and overwriting old ones if they did
.. code-block:: python
# add items to a map
Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4})
Named Tables
===================
Named tables are a way of querying a table without creating an class. They're useful for querying system tables or exploring an unfamiliar database.
.. code-block:: python
from cqlengine.connection import setup
setup("127.0.0.1", "cqlengine_test")
from cqlengine.named import NamedTable
user = NamedTable("cqlengine_test", "user")
user.objects()
user.objects()[0]
# {u'pk': 1, u't': datetime.datetime(2014, 6, 26, 17, 10, 31, 774000)}

6
requirements-dev.txt Normal file
View File

@@ -0,0 +1,6 @@
nose
nose-progressive
profilestats
pycallgraph
ipdbplugin==1.2
ipdb==0.7

View File

@@ -1,5 +1,6 @@
cql==1.4.0
ipython==0.13.1
ipdb==0.7
Sphinx==1.1.3
mock==1.0.1
ipython
ipdb
Sphinx
mock
sure==1.2.5
cassandra-driver>=2.0.0

View File

@@ -4,42 +4,37 @@ from setuptools import setup, find_packages
#python setup.py register
#python setup.py sdist upload
version = '0.2'
version = open('cqlengine/VERSION', 'r').readline().strip()
long_desc = """
cqlengine is a Cassandra CQL ORM for Python in the style of the Django orm and mongoengine
Cassandra CQL 3 Object Mapper for Python
[Documentation](https://cqlengine.readthedocs.org/en/latest/)
[Report a Bug](https://github.com/bdeggleston/cqlengine/issues)
[Users Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users)
[Dev Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-dev)
"""
setup(
name='cqlengine',
version=version,
description='Cassandra CQL ORM for Python in the style of the Django orm and mongoengine',
dependency_links = ['https://github.com/bdeggleston/cqlengine/archive/{0}.tar.gz#egg=cqlengine-{0}'.format(version)],
description='Cassandra CQL 3 Object Mapper for Python',
long_description=long_desc,
classifiers = [
"Development Status :: 3 - Alpha",
"Environment :: Web Environment",
"Environment :: Plugins",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 2.6",
"Programming Language :: Python :: 2.7",
"Topic :: Internet :: WWW/HTTP",
"Topic :: Software Development :: Libraries :: Python Modules",
],
keywords='cassandra,cql,orm',
install_requires = ['cql'],
author='Blake Eggleston',
install_requires = ['cassandra-driver >= 2.0.0'],
author='Blake Eggleston, Jon Haddad',
author_email='bdeggleston@gmail.com',
url='https://github.com/bdeggleston/cqlengine',
url='https://github.com/cqlengine/cqlengine',
license='BSD',
packages=find_packages(),
include_package_data=True,

6
tox.ini Normal file
View File

@@ -0,0 +1,6 @@
[tox]
envlist=py27,py34
[testenv]
deps= -rrequirements.txt
commands=nosetests --no-skip

13
upgrading.txt Normal file
View File

@@ -0,0 +1,13 @@
0.15 Upgrade to use Datastax Native Driver
We no longer raise cqlengine based OperationalError when connection fails. Now using the exception thrown in the native driver.
If you're calling setup() with the old thrift port in place, you will get connection errors. Either remove it (the default native port is assumed) or change to the default native port, 9042.
The cqlengine connection pool has been removed. Connections are now managed by the native driver. This should drastically reduce the socket overhead as the native driver can multiplex queries.
If you were previously manually using "ALL", "QUORUM", etc, to specificy consistency levels, you will need to migrate to the cqlengine.ALL, QUORUM, etc instead. If you had been using the module level constants before, nothing should need to change.
No longer accepting username & password as arguments to setup. Use the native driver's authentication instead. See http://datastax.github.io/python-driver/api/cassandra/auth.html