diff --git a/.gitignore b/.gitignore index 16029f08..1b5006e6 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,8 @@ html/ #Mr Developer .mr.developer.cfg +.noseids +/commitlog +/data + +docs/_build
diff --git a/.travis.yml b/.travis.yml index 86b92c0a..172874a6 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,7 +8,9 @@ before_install: - sudo service cassandra start install: - "pip install -r requirements.txt --use-mirrors" - - "pip install pytest --use-mirrors" + #- "pip install pytest --use-mirrors" + script: - while [ ! -f /var/run/cassandra.pid ] ; do sleep 1 ; done # wait until cassandra is ready - - "py.test cqlengine/tests/" + - "nosetests --no-skip" + # - "py.test cqlengine/tests/"
diff --git a/AUTHORS b/AUTHORS index b361bf0d..2da11977 100644 --- a/AUTHORS +++ b/AUTHORS @@ -1,9 +1,16 @@ PRIMARY AUTHORS Blake Eggleston +Jon Haddad -CONTRIBUTORS +CONTRIBUTORS (let us know if we missed you) Eric Scrivner - test environment, connection pooling -Jon Haddad - helped hash out some of the architecture - +Kevin Deldycke +Roey Berman +Danny Cosson +Michael Hall +Netanel Cohen-Tzemach +Mariusz Kryński +Greg Doermann +@pandu-rao \ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..62f6d786 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,12 @@ +### Contributing code to cqlengine + +Before submitting a pull request, please make sure that it follows these guidelines: + +* Limit yourself to one feature or bug fix per pull request. +* Include unit tests that thoroughly test the feature or bug fix. +* Write clear, descriptive commit messages. +* Many granular commits are preferred over large monolithic commits. +* If you're adding or modifying features, please update the documentation. + +If you're working on a big ticket item, please check in on [cqlengine-users](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users). +We'd hate to have to steer you in a different direction after you've already put in a lot of hard work.
diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..061fc274 --- /dev/null +++ b/Makefile @@ -0,0 +1,14 @@ +clean: + find . -name '*.pyc' -delete + rm -rf cqlengine/__pycache__ + + +build: clean + python setup.py build + +release: clean + python setup.py sdist upload + + +.PHONY: build +
diff --git a/README.md b/README.md index 7b244ffe..361fff38 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ cqlengine =============== -cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface similar to the Django orm and mongoengine +cqlengine is a Cassandra CQL 3 Object Mapper for Python + +**Users of versions < 0.16: the default keyspace 'cqlengine' has been removed. Please read this before upgrading:** [Breaking Changes](https://cqlengine.readthedocs.org/en/latest/topics/models.html#keyspace-change) [Documentation](https://cqlengine.readthedocs.org/en/latest/) @@ -9,8 +11,6 @@ cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface simila [Users Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users) -[Dev Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-dev) - ## Installation ``` pip install cqlengine ``` @@ -20,23 +20,27 @@ pip install cqlengine ```python #first, define a model +import uuid from cqlengine import columns from cqlengine.models import Model class ExampleModel(Model): read_repair_chance = 0.05 # optional - defaults to 0.1 - example_id = columns.UUID(primary_key=True) + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) example_type = columns.Integer(index=True) created_at = columns.DateTime() description = columns.Text(required=False) #next, setup the connection to your cassandra server(s)... >>> from cqlengine import connection ->>> connection.setup(['127.0.0.1:9160']) +>>> connection.setup(['127.0.0.1'], 'cqlengine') + +# or if you're still on cassandra 1.2 +>>> connection.setup(['127.0.0.1'], 'cqlengine', protocol_version=1) #...and create your CQL table ->>> from cqlengine.management import create_table ->>> create_table(ExampleModel) +>>> from cqlengine.management import sync_table +>>> sync_table(ExampleModel) #now we can create some rows: >>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now()) @@ -47,8 +51,6 @@ class ExampleModel(Model): >>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now()) >>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now()) >>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now()) -# Note: the UUID and DateTime columns will create uuid4 and datetime.now -# values automatically if we don't specify them when creating new rows #and now we can run some queries against our table >>> ExampleModel.objects.count() 8 >>> q = ExampleModel.objects(example_type=1) >>> q.count() 4 >>> for instance in q: ->>> print q.description +>>> print instance.description example5 example6 example7 example8 @@ -71,6 +73,10 @@ example8 >>> q2.count() 1 >>> for instance in q2: ->>> print q.description +>>> print instance.description example5 ``` + +## Contributing + +If you'd like to contribute to cqlengine, please read the [contributor guidelines](https://github.com/bdeggleston/cqlengine/blob/master/CONTRIBUTING.md)
diff --git a/RELEASE.txt b/RELEASE.txt new file mode 100644 index 00000000..c4def8dd --- /dev/null +++ b/RELEASE.txt @@ -0,0 +1,7 @@ +Check changelog +Ensure docs are updated +Tests pass +Update VERSION +Push tag to github +Push release to pypi +
diff --git a/changelog b/changelog index 3edb19f0..2aa66d41 100644 --- a/changelog +++ b/changelog @@ -1,6 +1,176 @@ CHANGELOG -0.2.1 (in progress) +0.16.0 + + 225: No handling of PagedResult from execute + 222: figure out how to make travis not totally fail when a test is skipped + 220: delayed connect. use setup(lazy_connect=True) + 218: throw exception on create_table and delete_table + 212: Unchanged primary key triggers error on update + 206: FAQ - why we don't do #189 + 191: Add support for simple table properties.
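Entry 220 above lands in this patch as the `lazy_connect` flag on `connection.setup()` in `cqlengine/connection.py` further down. A minimal sketch of its intended use, with placeholder host and keyspace names:

```python
from cqlengine import connection

# Settings are recorded but no cluster connection is made here; the first
# query triggers the real connect via handle_lazy_connect().
connection.setup(['127.0.0.1'], 'my_keyspace', lazy_connect=True)
```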
+ 172: raise exception when None is passed in as query param + 170: trying to perform queries before connection is established should raise useful exception + 162: Not Possible to Make Non-Equality Filtering Queries + 161: Filtering on DateTime column + 154: Blob(bytes) column type issue + 128: remove default read_repair_chance & ensure changes are saved when using sync_table + 106: specify caching on model + 99: specify caching options in table management + 94: type checking on sync_table (currently allows anything then fails miserably) + 73: remove default 'cqlengine' keyspace from table management + 71: add named table and query expression usage to docs + + + +0.15.0 +* native driver integration + +0.14.0 + +* fix for setting map to empty (Lifto) +* report when creating models with attributes that conflict with cqlengine (maedhroz) +* use stable version of sure package (maedhroz) +* performance improvements + + +0.13.0 +* adding support for post batch callbacks (thanks Daniel Dotsenko github.com/dvdotsenko) +* fixing sync table for tables with multiple keys (thanks Daniel Dotsenko github.com/dvdotsenko) +* fixing bug in Bytes column (thanks Netanel Cohen-Tzemach github.com/natict) +* fixing bug with timestamps and DST (thanks Netanel Cohen-Tzemach github.com/natict) + +0.12.0 + +* Normalize and unquote boolean values. (Thanks Kevin Deldycke github.com/kdeldycke) +* Fix race condition in connection manager (Thanks Roey Berman github.com/bergundy) +* allow access to instance columns as if it is a dict (Thanks Kevin Deldycke github.com/kdeldycke) +* Added support for blind partial updates to queryset (Thanks Danny Cosson github.com/dcosson) +* Model instance equality check bugfix (Thanks to Michael Hall, github.com/mahall) +* Fixed bug syncing tables with camel cased names (Thanks to Netanel Cohen-Tzemach, github.com/natict) +* Fixed bug dealing with eggs (Thanks Kevin Deldycke github.com/kdeldycke) + +0.11.0 + +* support for USING TIMESTAMP via a .timestamp(timedelta(seconds=30)) syntax - allows for long, timedelta, and datetime +* fixed use of USING TIMESTAMP in batches +* clear TTL and timestamp off models after persisting to DB +* allows UUID without dashes - (Thanks to Michael Hall, github.com/mahall) +* fixes regarding syncing schema settings (thanks Kai Lautaportti github.com/dokai) + +0.10.0 + +* support for execute_on_exception within batches + +0.9.2 +* fixing create keyspace with network topology strategy +* fixing regression with query expressions + +0.9 +* adding update method +* adding support for ttls +* adding support for per-query consistency +* adding BigInt column (thanks @Lifto) +* adding support for timezone aware time uuid functions (thanks @dokai) +* only saving collection fields on insert if they've been modified +* adding model method that returns a list of modified columns + +0.8.5 +* adding support for timeouts + +0.8.4 +* changing value manager previous value copying to deepcopy + +0.8.3 +* better logging for operational errors + +0.8.2 +* fix for connection failover + +0.8.1 +* fix for models not exactly matching schema + +0.8.0 +* support for table polymorphism +* var int type + +0.7.1 +* refactoring query class to make defining custom model instantiation logic easier + +0.7.0 +* added counter columns +* added support for compaction settings at the model level +* deprecated delete_table in favor of drop_table +* deprecated create_table in favor of sync_table +* added support for custom QuerySets + +0.6.0 +* added table sync + +0.5.2 +* adding hex conversion to Bytes column
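The 0.13.0 and 0.10.0 entries above, and the 0.4.7 entry below, all touch batch support. A minimal sketch of batch usage with the README's `ExampleModel`, assuming the `BatchQuery` context manager and `Model.batch()` hook those entries describe:

```python
from datetime import datetime
from cqlengine import BatchQuery

# both creates are sent to Cassandra as a single batch when the block exits
with BatchQuery() as b:
    ExampleModel.batch(b).create(example_type=0, description="batched 1", created_at=datetime.now())
    ExampleModel.batch(b).create(example_type=0, description="batched 2", created_at=datetime.now())
```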
+0.5.1 +* improving connection pooling +* fixing bug with clustering order columns not being quoted + +0.5 +* eagerly loading results into the query result cache, the cql driver does this anyway, + and pulling them from the cursor was causing some problems with gevented queries, + this will cause some breaking changes for users calling execute directly + +0.4.10 +* changing query parameter placeholders from uuid1 to uuid4 + +0.4.7 +* adding support for passing None into query batch methods to clear any batch objects + +0.4.6 +* fixing the way models are compared + +0.4.5 +* fixed bug where container columns would not call their child to_python method, this only really affected columns with special to_python logic + +0.4.4 +* adding query logging back +* fixed bug updating an empty list column + +0.4.3 +* fixed bug with Text column validation + +0.4.2 +* added support for instantiating container columns with column instances + +0.4.1 +* fixed bug in TimeUUID from datetime method + +0.4.0 +* removed default values from all column types +* explicit primary key is required (automatic id removed) +* added validation on keyname types on .create() +* changed table_name to __table_name__, read_repair_chance to __read_repair_chance__, keyspace to __keyspace__ +* modified table name auto generator to ignore module name +* changed internal implementation of model value get/set +* added TimeUUID.from_datetime(), used for generating UUID1's for a specific +time + +0.3.3 +* added abstract base class models + +0.3.2 +* comprehensive rewrite of connection management (thanks @rustyrazorblade) + +0.3 +* added support for Token function (thanks @mrk-its) +* added support for compound partition key (thanks @mrk-its) +* added support for defining clustering key ordering (thanks @mrk-its) +* added values_list to Query class, bypassing object creation if desired (thanks @mrk-its) +* fixed bug with Model.objects caching values (thanks @mrk-its) +* fixed Cassandra 1.2.5 compatibility bug +* updated model exception inheritance + +0.2.1 * adding support for datetimes with tzinfo (thanks @gdoermann) * fixing bug in saving map updates (thanks @pandu-rao)
diff --git a/cqlengine/VERSION b/cqlengine/VERSION new file mode 100644 index 00000000..2a0970ca --- /dev/null +++ b/cqlengine/VERSION @@ -0,0 +1 @@ +0.16.1
diff --git a/cqlengine/__init__.py b/cqlengine/__init__.py index cd4e28ab..ecec3cd0 100644 --- a/cqlengine/__init__.py +++ b/cqlengine/__init__.py @@ -1,7 +1,32 @@ +import pkg_resources + +from cassandra import ConsistencyLevel + from cqlengine.columns import * from cqlengine.functions import * from cqlengine.models import Model, CounterModel from cqlengine.query import BatchQuery -__version__ = '0.2' +__cqlengine_version_path__ = pkg_resources.resource_filename('cqlengine', 'VERSION') +__version__ = open(__cqlengine_version_path__, 'r').readline().strip() + +# compaction +SizeTieredCompactionStrategy = "SizeTieredCompactionStrategy" +LeveledCompactionStrategy = "LeveledCompactionStrategy" + +# Caching constants.
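The compaction strategy names above, like the caching constants that follow, are plain values meant to be attached to model classes; `get_compaction_options()` and `get_create_table()` in `cqlengine/management.py` below read them back off the model. A sketch under those assumptions, with a hypothetical model (the `__compaction_sstable_size_in_mb__` spelling follows the `__compaction_{}__` key format used by `get_compaction_options`, and `__caching__` is one of the `table_properties` handled by `get_create_table`):

```python
from cqlengine import CACHING_KEYS_ONLY, LeveledCompactionStrategy, columns
from cqlengine.models import Model

class Event(Model):                             # hypothetical model
    __compaction__ = LeveledCompactionStrategy  # emitted as WITH compaction = {...}
    __compaction_sstable_size_in_mb__ = 160     # option limited to LeveledCompactionStrategy
    __caching__ = CACHING_KEYS_ONLY             # emitted as WITH caching = 'KEYS_ONLY'
    event_id = columns.UUID(primary_key=True)
```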
+CACHING_ALL = "ALL" +CACHING_KEYS_ONLY = "KEYS_ONLY" +CACHING_ROWS_ONLY = "ROWS_ONLY" +CACHING_NONE = "NONE" + +ANY = ConsistencyLevel.ANY +ONE = ConsistencyLevel.ONE +TWO = ConsistencyLevel.TWO +THREE = ConsistencyLevel.THREE +QUORUM = ConsistencyLevel.QUORUM +LOCAL_QUORUM = ConsistencyLevel.LOCAL_QUORUM +EACH_QUORUM = ConsistencyLevel.EACH_QUORUM +ALL = ConsistencyLevel.ALL diff --git a/cqlengine/columns.py b/cqlengine/columns.py index d5a83db5..9e92eefe 100644 --- a/cqlengine/columns.py +++ b/cqlengine/columns.py @@ -1,19 +1,20 @@ #column field types -from copy import copy +from copy import deepcopy, copy from datetime import datetime from datetime import date import re -from uuid import uuid1, uuid4 -from cql.query import cql_quote +from cassandra.cqltypes import DateType from cqlengine.exceptions import ValidationError +from cassandra.encoder import cql_quote + class BaseValueManager(object): def __init__(self, instance, column, value): self.instance = instance self.column = column - self.previous_value = copy(value) + self.previous_value = deepcopy(value) self.value = value @property @@ -52,6 +53,26 @@ class BaseValueManager(object): else: return property(_get, _set) + +class ValueQuoter(object): + """ + contains a single value, which will quote itself for CQL insertion statements + """ + def __init__(self, value): + self.value = value + + def __str__(self): + raise NotImplementedError + + def __repr__(self): + return self.__str__() + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return False + + class Column(object): #the cassandra type this column maps to @@ -60,24 +81,38 @@ class Column(object): instance_counter = 0 - def __init__(self, primary_key=False, index=False, db_field=None, default=None, required=True): + def __init__(self, + primary_key=False, + partition_key=False, + index=False, + db_field=None, + default=None, + required=False, + clustering_order=None, + polymorphic_key=False): """ :param primary_key: bool flag, indicates this column is a primary key. The first primary key defined - on a model is the partition key, all others are cluster keys + on a model is the partition key (unless partition keys are set), all others are cluster keys + :param partition_key: indicates that this column should be the partition key, defining + more than one partition key column creates a compound partition key :param index: bool flag, indicates an index should be created for this column :param db_field: the fieldname this field will map to in the database :param default: the default value, can be a value or a callable (no args) - :param required: boolean, is the field required? + :param required: boolean, is the field required? 
Model validation will raise an + exception if required is set to True and there is a None value assigned + :param clustering_order: only applicable on clustering keys (primary keys that are not partition keys) + determines the order that the clustering keys are sorted on disk + :param polymorphic_key: boolean, if set to True, this column will be used for saving and loading instances + of polymorphic tables """ - self.primary_key = primary_key + self.partition_key = partition_key + self.primary_key = partition_key or primary_key self.index = index self.db_field = db_field self.default = default self.required = required - - #only the model meta class should touch this - self._partition_key = False - + self.clustering_order = clustering_order + self.polymorphic_key = polymorphic_key #the column name in the model definition self.column_name = None @@ -137,7 +172,7 @@ class Column(object): """ Returns a column definition for CQL table definition """ - return '"{}" {}'.format(self.db_field_name, self.db_type) + return '{} {}'.format(self.cql, self.db_type) def set_column_name(self, name): """ @@ -156,17 +191,45 @@ class Column(object): """ Returns the name of the cql index """ return 'index_{}'.format(self.db_field_name) + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{}"'.format(self.db_field_name) + + def _val_is_null(self, val): + """ determines if the given value equates to a null value for the given column type """ + return val is None + + class Bytes(Column): db_type = 'blob' + class Quoter(ValueQuoter): + def __str__(self): + return '0x' + self.value.encode('hex') + + def to_database(self, value): + val = super(Bytes, self).to_database(value) + if val is None: return + + return self.Quoter(val) + + def to_python(self, value): + #return value[2:].decode('hex') + return value + + class Ascii(Column): db_type = 'ascii' + class Text(Column): db_type = 'text' def __init__(self, *args, **kwargs): - self.min_length = kwargs.pop('min_length', 1 if kwargs.get('required', True) else None) + self.min_length = kwargs.pop('min_length', 1 if kwargs.get('required', False) else None) self.max_length = kwargs.pop('max_length', None) super(Text, self).__init__(*args, **kwargs) @@ -183,6 +246,7 @@ class Text(Column): raise ValidationError('{} is shorter than {} characters'.format(self.column_name, self.min_length)) return value + class Integer(Column): db_type = 'int' @@ -200,44 +264,102 @@ class Integer(Column): def to_database(self, value): return self.validate(value) -class DateTime(Column): - db_type = 'timestamp' - def __init__(self, **kwargs): - super(DateTime, self).__init__(**kwargs) + +class BigInt(Integer): + db_type = 'bigint' + + +class VarInt(Column): + db_type = 'varint' + + def validate(self, value): + val = super(VarInt, self).validate(value) + if val is None: + return + try: + return long(val) + except (TypeError, ValueError): + raise ValidationError( + "{} can't be converted to integral value".format(value)) def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + +class CounterValueManager(BaseValueManager): + def __init__(self, instance, column, value): + super(CounterValueManager, self).__init__(instance, column, value) + self.value = self.value or 0 + self.previous_value = self.previous_value or 0 + + +class Counter(Integer): + db_type = 'counter' + + value_manager = CounterValueManager + + def __init__(self, + index=False, + db_field=None, + required=False): + super(Counter,
self).__init__( + primary_key=False, + partition_key=False, + index=index, + db_field=db_field, + default=0, + required=required, + ) + + +class DateTime(Column): + db_type = 'timestamp' + + def to_python(self, value): + if value is None: return if isinstance(value, datetime): return value - return datetime.utcfromtimestamp(value) + elif isinstance(value, date): + return datetime(*(value.timetuple()[:6])) + try: + return datetime.utcfromtimestamp(value) + except TypeError: + return datetime.utcfromtimestamp(DateType.deserialize(value)) def to_database(self, value): value = super(DateTime, self).to_database(value) + if value is None: return if not isinstance(value, datetime): - raise ValidationError("'{}' is not a datetime object".format(value)) + if isinstance(value, date): + value = datetime(value.year, value.month, value.day) + else: + raise ValidationError("'{}' is not a datetime object".format(value)) epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) - offset = 0 - if epoch.tzinfo: - offset_delta = epoch.tzinfo.utcoffset(epoch) - offset = offset_delta.days*24*3600 + offset_delta.seconds - return long(((value - epoch).total_seconds() - offset) * 1000) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + + return long(((value - epoch).total_seconds() - offset) * 1000) class Date(Column): db_type = 'timestamp' - def __init__(self, **kwargs): - super(Date, self).__init__(**kwargs) - def to_python(self, value): + if value is None: return if isinstance(value, datetime): return value.date() elif isinstance(value, date): return value - - return datetime.utcfromtimestamp(value).date() + try: + return datetime.utcfromtimestamp(value).date() + except TypeError: + return datetime.utcfromtimestamp(DateType.deserialize(value)).date() def to_database(self, value): value = super(Date, self).to_database(value) + if value is None: return if isinstance(value, datetime): value = value.date() if not isinstance(value, date): @@ -252,19 +374,25 @@ class UUID(Column): """ db_type = 'uuid' - re_uuid = re.compile(r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}') - - def __init__(self, default=lambda:uuid4(), **kwargs): - super(UUID, self).__init__(default=default, **kwargs) + re_uuid = re.compile(r'[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}') def validate(self, value): val = super(UUID, self).validate(value) if val is None: return from uuid import UUID as _UUID if isinstance(val, _UUID): return val - if not self.re_uuid.match(val): - raise ValidationError("{} is not a valid uuid".format(value)) - return _UUID(val) + if isinstance(val, basestring) and self.re_uuid.match(val): + return _UUID(val) + raise ValidationError("{} is not a valid uuid".format(value)) + + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + +from uuid import UUID as pyUUID, getnode + class TimeUUID(UUID): """ @@ -273,18 +401,52 @@ class TimeUUID(UUID): db_type = 'timeuuid' - def __init__(self, **kwargs): - kwargs.setdefault('default', lambda: uuid1()) - super(TimeUUID, self).__init__(**kwargs) + @classmethod + def from_datetime(self, dt): + """ + generates a UUID for a given datetime + + :param dt: datetime + :type dt: datetime + :return: + """ + global _last_timestamp + + epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + timestamp = (dt - epoch).total_seconds() - offset + + node = None + clock_seq = None + + nanoseconds = 
int(timestamp * 1e9) + timestamp = int(nanoseconds // 100) + 0x01b21dd213814000L + + if clock_seq is None: + import random + clock_seq = random.randrange(1 << 14L) # instead of stable storage + time_low = timestamp & 0xffffffffL + time_mid = (timestamp >> 32L) & 0xffffL + time_hi_version = (timestamp >> 48L) & 0x0fffL + clock_seq_low = clock_seq & 0xffL + clock_seq_hi_variant = (clock_seq >> 8L) & 0x3fL + if node is None: + node = getnode() + return pyUUID(fields=(time_low, time_mid, time_hi_version, + clock_seq_hi_variant, clock_seq_low, node), version=1) + class Boolean(Column): db_type = 'boolean' - def to_python(self, value): + def validate(self, value): + """ Always returns a Python boolean. """ + value = super(Boolean, self).validate(value) return bool(value) - def to_database(self, value): - return bool(value) + def to_python(self, value): + return self.validate(value) + class Float(Column): db_type = 'double' @@ -307,59 +469,76 @@ class Float(Column): def to_database(self, value): return self.validate(value) + class Decimal(Column): db_type = 'decimal' -class Counter(Integer): - """ Validates like an integer, goes into the database as a counter - """ - db_type = 'counter' - +class Counter(Column): + #TODO: counter field def __init__(self, **kwargs): super(Counter, self).__init__(**kwargs) - -class ContainerQuoter(object): - """ - contains a single value, which will quote itself for CQL insertion statements - """ - def __init__(self, value): - self.value = value - - def __str__(self): raise NotImplementedError + def to_python(self, value): + return self.validate(value) + + def to_database(self, value): + return self.validate(value) + + class BaseContainerColumn(Column): """ - Base Container type + Base Container type for collection-like columns. + + https://cassandra.apache.org/doc/cql3/CQL.html#collections """ def __init__(self, value_type, **kwargs): """ :param value_type: a column class indicating the types of the value """ - if not issubclass(value_type, Column): + inheritance_comparator = issubclass if isinstance(value_type, type) else isinstance + if not inheritance_comparator(value_type, Column): raise ValidationError('value_type must be a column class') - if issubclass(value_type, BaseContainerColumn): + if inheritance_comparator(value_type, BaseContainerColumn): raise ValidationError('container types cannot be nested') if value_type.db_type is None: raise ValidationError('value_type cannot be an abstract column type') - self.value_type = value_type - self.value_col = self.value_type() + if isinstance(value_type, type): + self.value_type = value_type + self.value_col = self.value_type() + else: + self.value_col = value_type + self.value_type = self.value_col.__class__ + super(BaseContainerColumn, self).__init__(**kwargs) + def validate(self, value): + value = super(BaseContainerColumn, self).validate(value) + # It is dangerous to let collections have more than 65535 elements.
+ # See: https://issues.apache.org/jira/browse/CASSANDRA-5428 + if value is not None and len(value) > 65535: + raise ValidationError("Collection can't have more than 65535 elements.") + return value + def get_column_def(self): """ Returns a column definition for CQL table definition """ db_type = self.db_type.format(self.value_type.db_type) - return '{} {}'.format(self.db_field_name, db_type) + return '{} {}'.format(self.cql, db_type) + + + def _val_is_null(self, val): + return not val + + +class BaseContainerQuoter(ValueQuoter): + + def __nonzero__(self): + return bool(self.value) - def get_update_statement(self, val, prev, ctx): - """ - Used to add partial update statements - """ - raise NotImplementedError class Set(BaseContainerColumn): """ @@ -369,20 +548,21 @@ class Set(BaseContainerColumn): """ db_type = 'set<{}>' - class Quoter(ContainerQuoter): + class Quoter(BaseContainerQuoter): def __str__(self): cq = cql_quote return '{' + ', '.join([cq(v) for v in self.value]) + '}' - def __init__(self, value_type, strict=True, **kwargs): + def __init__(self, value_type, strict=True, default=set, **kwargs): """ :param value_type: a column class indicating the types of the value :param strict: sets whether non set values will be coerced to set type on validation, or raise a validation error, defaults to True """ self.strict = strict - super(Set, self).__init__(value_type, **kwargs) + + super(Set, self).__init__(value_type, default=default, **kwargs) def validate(self, value): val = super(Set, self).validate(value) @@ -394,55 +574,23 @@ class Set(BaseContainerColumn): else: raise ValidationError('{} cannot be coerced to a set object'.format(val)) + if None in val: + raise ValidationError("None not allowed in a set") + return {self.value_col.validate(v) for v in val} + def to_python(self, value): + if value is None: return set() + return {self.value_col.to_python(v) for v in value} + def to_database(self, value): if value is None: return None + if isinstance(value, self.Quoter): return value return self.Quoter({self.value_col.to_database(v) for v in value}) - def get_update_statement(self, val, prev, ctx): - """ - Returns statements that will be added to an object's update statement - also updates the query context - :param val: the current column value - :param prev: the previous column value - :param ctx: the values that will be passed to the query - :rtype: list - """ - # remove from Quoter containers, if applicable - val = self.to_database(val) - prev = self.to_database(prev) - if isinstance(val, self.Quoter): val = val.value - if isinstance(prev, self.Quoter): prev = prev.value - - if val is None or val == prev: - # don't return anything if the new value is the same as - # the old one, or if the new value is none - return [] - elif prev is None or not any({v in prev for v in val}): - field = uuid1().hex - ctx[field] = self.Quoter(val) - return ['"{}" = :{}'.format(self.db_field_name, field)] - else: - # partial update time - to_create = val - prev - to_delete = prev - val - statements = [] - - if to_create: - field_id = uuid1().hex - ctx[field_id] = self.Quoter(to_create) - statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)] - - if to_delete: - field_id = uuid1().hex - ctx[field_id] = self.Quoter(to_delete) - statements += ['"{0}" = "{0}" - :{1}'.format(self.db_field_name, field_id)] - - return statements class List(BaseContainerColumn): """ @@ -452,93 +600,36 @@ class List(BaseContainerColumn): """ db_type = 'list<{}>' - class Quoter(ContainerQuoter): + class 
Quoter(BaseContainerQuoter): def __str__(self): cq = cql_quote return '[' + ', '.join([cq(v) for v in self.value]) + ']' + def __nonzero__(self): + return bool(self.value) + + def __init__(self, value_type, default=list, **kwargs): + return super(List, self).__init__(value_type=value_type, default=default, **kwargs) + def validate(self, value): val = super(List, self).validate(value) if val is None: return if not isinstance(val, (set, list, tuple)): raise ValidationError('{} is not a list object'.format(val)) + if None in val: + raise ValidationError("None is not allowed in a list") return [self.value_col.validate(v) for v in val] def to_python(self, value): - if value is None: return None - return list(value) + if value is None: return [] + return [self.value_col.to_python(v) for v in value] def to_database(self, value): if value is None: return None if isinstance(value, self.Quoter): return value return self.Quoter([self.value_col.to_database(v) for v in value]) - def get_update_statement(self, val, prev, values): - """ - Returns statements that will be added to an object's update statement - also updates the query context - """ - # remove from Quoter containers, if applicable - val = self.to_database(val) - prev = self.to_database(prev) - if isinstance(val, self.Quoter): val = val.value - if isinstance(prev, self.Quoter): prev = prev.value - - def _insert(): - field_id = uuid1().hex - values[field_id] = self.Quoter(val) - return ['"{}" = :{}'.format(self.db_field_name, field_id)] - - if val is None or val == prev: - return [] - elif prev is None: - return _insert() - elif len(val) < len(prev): - return _insert() - else: - # the prepend and append lists, - # if both of these are still None after looking - # at both lists, an insert statement will be returned - prepend = None - append = None - - # the max start idx we want to compare - search_space = len(val) - max(0, len(prev)-1) - - # the size of the sub lists we want to look at - search_size = len(prev) - - for i in range(search_space): - #slice boundary - j = i + search_size - sub = val[i:j] - idx_cmp = lambda idx: prev[idx] == sub[idx] - if idx_cmp(0) and idx_cmp(-1) and prev == sub: - prepend = val[:i] - append = val[j:] - break - - # create update statements - if prepend is append is None: - return _insert() - - statements = [] - if prepend: - field_id = uuid1().hex - # CQL seems to prepend element at a time, starting - # with the element at idx 0, we can either reverse - # it here, or have it inserted in reverse - prepend.reverse() - values[field_id] = self.Quoter(prepend) - statements += ['"{0}" = :{1} + "{0}"'.format(self.db_field_name, field_id)] - - if append: - field_id = uuid1().hex - values[field_id] = self.Quoter(append) - statements += ['"{0}" = "{0}" + :{1}'.format(self.db_field_name, field_id)] - - return statements class Map(BaseContainerColumn): """ @@ -549,27 +640,41 @@ class Map(BaseContainerColumn): db_type = 'map<{}, {}>' - class Quoter(ContainerQuoter): + class Quoter(BaseContainerQuoter): def __str__(self): cq = cql_quote return '{' + ', '.join([cq(k) + ':' + cq(v) for k,v in self.value.items()]) + '}' - def __init__(self, key_type, value_type, **kwargs): + def get(self, key): + return self.value.get(key) + + def keys(self): + return self.value.keys() + + def items(self): + return self.value.items() + + def __init__(self, key_type, value_type, default=dict, **kwargs): """ :param key_type: a column class indicating the types of the key :param value_type: a column class indicating the types of the value """ - if 
not issubclass(value_type, Column): + inheritance_comparator = issubclass if isinstance(key_type, type) else isinstance + if not inheritance_comparator(key_type, Column): raise ValidationError('key_type must be a column class') - if issubclass(value_type, BaseContainerColumn): + if inheritance_comparator(key_type, BaseContainerColumn): raise ValidationError('container types cannot be nested') if key_type.db_type is None: raise ValidationError('key_type cannot be an abstract column type') - self.key_type = key_type - self.key_col = self.key_type() - super(Map, self).__init__(value_type, **kwargs) + if isinstance(key_type, type): + self.key_type = key_type + self.key_col = self.key_type() + else: + self.key_col = key_type + self.key_type = self.key_col.__class__ + super(Map, self).__init__(value_type, default=default, **kwargs) def get_column_def(self): """ @@ -579,7 +684,7 @@ class Map(BaseContainerColumn): self.key_type.db_type, self.value_type.db_type ) - return '{} {}'.format(self.db_field_name, db_type) + return '{} {}'.format(self.cql, db_type) def validate(self, value): val = super(Map, self).validate(value) @@ -589,62 +694,38 @@ class Map(BaseContainerColumn): return {self.key_col.validate(k):self.value_col.validate(v) for k,v in val.items()} def to_python(self, value): + if value is None: + return {} if value is not None: - return {self.key_col.to_python(k):self.value_col.to_python(v) for k,v in value.items()} + return {self.key_col.to_python(k): self.value_col.to_python(v) for k,v in value.items()} def to_database(self, value): if value is None: return None if isinstance(value, self.Quoter): return value return self.Quoter({self.key_col.to_database(k):self.value_col.to_database(v) for k,v in value.items()}) - def get_update_statement(self, val, prev, ctx): - """ - http://www.datastax.com/docs/1.2/cql_cli/using/collections_map#deletion - """ - # remove from Quoter containers, if applicable - val = self.to_database(val) - prev = self.to_database(prev) - if isinstance(val, self.Quoter): val = val.value - if isinstance(prev, self.Quoter): prev = prev.value - val = val or {} - prev = prev or {} - - #get the updated map - update = {k:v for k,v in val.items() if v != prev.get(k)} - - statements = [] - for k,v in update.items(): - key_id = uuid1().hex - val_id = uuid1().hex - ctx[key_id] = k - ctx[val_id] = v - statements += ['"{}"[:{}] = :{}'.format(self.db_field_name, key_id, val_id)] - - return statements - - def get_delete_statement(self, val, prev, ctx): - """ - Returns statements that will be added to an object's delete statement - also updates the query context, used for removing keys from a map - """ - if val is prev is None: - return [] - - val = self.to_database(val) - prev = self.to_database(prev) - if isinstance(val, self.Quoter): val = val.value - if isinstance(prev, self.Quoter): prev = prev.value - - old_keys = set(prev.keys()) if prev else set() - new_keys = set(val.keys()) if val else set() - del_keys = old_keys - new_keys - - del_statements = [] - for key in del_keys: - field_id = uuid1().hex - ctx[field_id] = key - del_statements += ['"{}"[:{}]'.format(self.db_field_name, field_id)] - - return del_statements +class _PartitionKeysToken(Column): + """ + virtual column representing token of partition columns. 
+ Used by filter(pk__token=Token(...)) filters + """ + + def __init__(self, model): + self.partition_columns = model._partition_keys.values() + super(_PartitionKeysToken, self).__init__(partition_key=True) + + @property + def db_field_name(self): + return 'token({})'.format(', '.join(['"{}"'.format(c.db_field_name) for c in self.partition_columns])) + + def to_database(self, value): + from cqlengine.functions import Token + assert isinstance(value, Token) + value.set_columns(self.partition_columns) + return value + + def get_cql(self): + return "token({})".format(", ".join(c.cql for c in self.partition_columns)) + diff --git a/cqlengine/connection.py b/cqlengine/connection.py index c56d08a9..7cf8dea2 100644 --- a/cqlengine/connection.py +++ b/cqlengine/connection.py @@ -3,184 +3,113 @@ #http://cassandra.apache.org/doc/cql/CQL.html from collections import namedtuple -import Queue -import random +from cassandra.cluster import Cluster +from cassandra.query import SimpleStatement, Statement -import cql +try: + import Queue as queue +except ImportError: + # python 3 + import queue -from cqlengine.exceptions import CQLEngineException +import logging -from thrift.transport.TTransport import TTransportException +from cqlengine.exceptions import CQLEngineException, UndefinedKeyspaceException +from cassandra import ConsistencyLevel +from cqlengine.statements import BaseCQLStatement +from cassandra.query import dict_factory +LOG = logging.getLogger('cqlengine.cql') class CQLConnectionError(CQLEngineException): pass Host = namedtuple('Host', ['name', 'port']) -_hosts = [] -_host_idx = 0 -_conn= None -_username = None -_password = None -_max_connections = 10 -def setup(hosts, username=None, password=None, max_connections=10, default_keyspace=None, lazy=False): +cluster = None +session = None +lazy_connect_args = None +default_consistency_level = None + +def setup( + hosts, + default_keyspace, + consistency=ConsistencyLevel.ONE, + lazy_connect=False, + **kwargs): """ Records the hosts and connects to one of them - :param hosts: list of hosts, strings in the :, or just + :param hosts: list of hosts, see http://datastax.github.io/python-driver/api/cassandra/cluster.html + :type hosts: list + :param default_keyspace: The default keyspace to use + :type default_keyspace: str + :param consistency: The global consistency level + :type consistency: int + :param lazy_connect: True if should not connect until first use + :type lazy_connect: bool """ - global _hosts - global _username - global _password - global _max_connections - _username = username - _password = password - _max_connections = max_connections + global cluster, session, default_consistency_level, lazy_connect_args - if default_keyspace: - from cqlengine import models - models.DEFAULT_KEYSPACE = default_keyspace + if 'username' in kwargs or 'password' in kwargs: + raise CQLEngineException("Username & Password are now handled by using the native driver's auth_provider") - for host in hosts: - host = host.strip() - host = host.split(':') - if len(host) == 1: - _hosts.append(Host(host[0], 9160)) - elif len(host) == 2: - _hosts.append(Host(*host)) - else: - raise CQLConnectionError("Can't parse {}".format(''.join(host))) + if not default_keyspace: + raise UndefinedKeyspaceException() - if not _hosts: - raise CQLConnectionError("At least one host required") + from cqlengine import models + models.DEFAULT_KEYSPACE = default_keyspace - random.shuffle(_hosts) + default_consistency_level = consistency + if lazy_connect: + lazy_connect_args = (hosts, 
default_keyspace, consistency, kwargs) + return - if not lazy: - con = ConnectionPool.get() - ConnectionPool.put(con) + cluster = Cluster(hosts, **kwargs) + session = cluster.connect() + session.row_factory = dict_factory + +def execute(query, params=None, consistency_level=None): + + handle_lazy_connect() + + if not session: + raise CQLEngineException("It is required to setup() cqlengine before executing queries") -class ConnectionPool(object): - """Handles pooling of database connections.""" + if consistency_level is None: + consistency_level = default_consistency_level - # Connection pool queue - _queue = None + if isinstance(query, Statement): + pass - @classmethod - def clear(cls): - """ - Force the connection pool to be cleared. Will close all internal - connections. - """ - try: - while not cls._queue.empty(): - cls._queue.get().close() - except: - pass + elif isinstance(query, BaseCQLStatement): + params = query.get_context() + query = str(query) + query = SimpleStatement(query, consistency_level=consistency_level) - @classmethod - def get(cls): - """ - Returns a usable database connection. Uses the internal queue to - determine whether to return an existing connection or to create - a new one. - """ - try: - if cls._queue.empty(): - return cls._create_connection() - return cls._queue.get() - except CQLConnectionError as cqle: - raise cqle - except: - if not cls._queue: - cls._queue = Queue.Queue(maxsize=_max_connections) - return cls._create_connection() - - @classmethod - def put(cls, conn): - """ - Returns a connection to the queue freeing it up for other queries to - use. - - :param conn: The connection to be released - :type conn: connection - """ - try: - if cls._queue.full(): - conn.close() - else: - cls._queue.put(conn) - except: - if not cls._queue: - cls._queue = Queue.Queue(maxsize=_max_connections) - cls._queue.put(conn) - - @classmethod - def _create_connection(cls): - """ - Creates a new connection for the connection pool. - """ - global _hosts - global _username - global _password - - if not _hosts: - raise CQLConnectionError("At least one host required") - - host = _hosts[_host_idx] - - new_conn = cql.connect(host.name, host.port, user=_username, password=_password) - new_conn.set_cql_version('3.0.0') - return new_conn + elif isinstance(query, basestring): + query = SimpleStatement(query, consistency_level=consistency_level) -class connection_manager(object): - """ - Connection failure tolerant connection manager. 
Written to be used in a 'with' block for connection pooling - """ - def __init__(self): - if not _hosts: - raise CQLConnectionError("No connections have been configured, call cqlengine.connection.setup") - self.keyspace = None - self.con = ConnectionPool.get() - self.cur = None + LOG.info(query.query_string) - def close(self): - if self.cur: self.cur.close() - ConnectionPool.put(self.con) + params = params or {} + result = session.execute(query, params) - def __enter__(self): - return self + return result - def __exit__(self, type, value, traceback): - self.close() - def execute(self, query, params=None): - """ - Gets a connection from the pool and executes the given query, returns the cursor +def get_session(): + handle_lazy_connect() + return session - if there's a connection problem, this will silently create a new connection pool - from the available hosts, and remove the problematic host from the host list - """ - if params is None: - params = {} - global _host_idx - - for i in range(len(_hosts)): - try: - self.cur = self.con.cursor() - self.cur.execute(query, params) - return self.cur - except cql.ProgrammingError as ex: - raise CQLEngineException(unicode(ex)) - except TTransportException: - #TODO: check for other errors raised in the event of a connection / server problem - #move to the next connection and set the connection pool - _host_idx += 1 - _host_idx %= len(_hosts) - self.con.close() - self.con = ConnectionPool._create_connection() - - raise CQLConnectionError("couldn't reach a Cassandra server") +def get_cluster(): + handle_lazy_connect() + return cluster +def handle_lazy_connect(): + global lazy_connect_args + if lazy_connect_args: + hosts, default_keyspace, consistency, kwargs = lazy_connect_args + lazy_connect_args = None + setup(hosts, default_keyspace, consistency, **kwargs) diff --git a/cqlengine/exceptions.py b/cqlengine/exceptions.py index 87b5cdc4..94a51b92 100644 --- a/cqlengine/exceptions.py +++ b/cqlengine/exceptions.py @@ -3,3 +3,4 @@ class CQLEngineException(Exception): pass class ModelException(CQLEngineException): pass class ValidationError(CQLEngineException): pass +class UndefinedKeyspaceException(CQLEngineException): pass diff --git a/cqlengine/functions.py b/cqlengine/functions.py index b2886c9a..164bfd41 100644 --- a/cqlengine/functions.py +++ b/cqlengine/functions.py @@ -1,31 +1,48 @@ from datetime import datetime +from uuid import uuid1 from cqlengine.exceptions import ValidationError -class BaseQueryFunction(object): +class QueryValue(object): + """ + Base class for query filter values. Subclasses of these classes can + be passed into .filter() keyword args + """ + + format_string = '%({})s' + + def __init__(self, value): + self.value = value + self.context_id = None + + def __unicode__(self): + return self.format_string.format(self.context_id) + + def set_context_id(self, ctx_id): + self.context_id = ctx_id + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.value + + +class BaseQueryFunction(QueryValue): """ Base class for filtering functions. 
Subclasses of these classes can be passed into .filter() and will be translated into CQL functions in the resulting query """ - _cql_string = None - - def __init__(self, value): - self.value = value - - def to_cql(self, value_id): - """ - Returns a function for cql with the value id as it's argument - """ - return self._cql_string.format(value_id) - - def get_value(self): - raise NotImplementedError - class MinTimeUUID(BaseQueryFunction): + """ + return a fake timeuuid corresponding to the smallest possible timeuuid for the given timestamp - _cql_string = 'MinTimeUUID(:{})' + http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + """ + + format_string = 'MinTimeUUID(%({})s)' def __init__(self, value): """ @@ -36,13 +53,23 @@ class MinTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MinTimeUUID, self).__init__(value) - def get_value(self): - epoch = datetime(1970, 1, 1) - return long((self.value - epoch).total_seconds() * 1000) + def to_database(self, val): + epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + return long(((val - epoch).total_seconds() - offset) * 1000) + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.to_database(self.value) + class MaxTimeUUID(BaseQueryFunction): + """ + return a fake timeuuid corresponding to the largest possible timeuuid for the given timestamp - _cql_string = 'MaxTimeUUID(:{})' + http://cassandra.apache.org/doc/cql3/CQL.html#timeuuidFun + """ + + format_string = 'MaxTimeUUID(%({})s)' def __init__(self, value): """ @@ -53,9 +80,41 @@ class MaxTimeUUID(BaseQueryFunction): raise ValidationError('datetime instance is required') super(MaxTimeUUID, self).__init__(value) - def get_value(self): - epoch = datetime(1970, 1, 1) - return long((self.value - epoch).total_seconds() * 1000) + def to_database(self, val): + epoch = datetime(1970, 1, 1, tzinfo=val.tzinfo) + offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0 + return long(((val - epoch).total_seconds() - offset) * 1000) + + def update_context(self, ctx): + ctx[str(self.context_id)] = self.to_database(self.value) + + +class Token(BaseQueryFunction): + """ + compute the token for a given partition key + + http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun + """ + + def __init__(self, *values): + if len(values) == 1 and isinstance(values[0], (list, tuple)): + values = values[0] + super(Token, self).__init__(values) + self._columns = None + + def set_columns(self, columns): + self._columns = columns + + def get_context_size(self): + return len(self.value) + + def __unicode__(self): + token_args = ', '.join('%({})s'.format(self.context_id + i) for i in range(self.get_context_size())) + return "token({})".format(token_args) + + def update_context(self, ctx): + for i, (col, val) in enumerate(zip(self._columns, self.value)): + ctx[str(self.context_id + i)] = col.to_database(val) class NotSet(object): diff --git a/cqlengine/management.py b/cqlengine/management.py index 67c34f68..095e41b4 100644 --- a/cqlengine/management.py +++ b/cqlengine/management.py @@ -1,8 +1,22 @@ import json +import warnings +from cqlengine import SizeTieredCompactionStrategy, LeveledCompactionStrategy +from cqlengine import ONE +from cqlengine.named import NamedTable -from cqlengine.connection import connection_manager +from cqlengine.connection import execute, get_cluster from cqlengine.exceptions import CQLEngineException +import logging +from collections import 
namedtuple +Field = namedtuple('Field', ['name', 'type']) + +logger = logging.getLogger(__name__) +from cqlengine.models import Model + +# system keyspaces +schema_columnfamilies = NamedTable('system', 'schema_columnfamilies') + def create_keyspace(name, strategy_class='SimpleStrategy', replication_factor=3, durable_writes=True, **replication_values): """ creates a keyspace :param durable_writes: 1.2 only, commit log is bypassed if set to False :param **replication_values: 1.2 only, additional values to add to the replication data map """ - with connection_manager() as con: - if not any([name == k.name for k in con.con.client.describe_keyspaces()]): -# if name not in [k.name for k in con.con.client.describe_keyspaces()]: - try: - #Try the 1.1 method - con.execute("""CREATE KEYSPACE {} - WITH strategy_class = '{}' - AND strategy_options:replication_factor={};""".format(name, strategy_class, replication_factor)) - except CQLEngineException: - #try the 1.2 method - replication_map = { - 'class': strategy_class, - 'replication_factor':replication_factor - } - replication_map.update(replication_values) + cluster = get_cluster() - query = """ - CREATE KEYSPACE {} - WITH REPLICATION = {} - """.format(name, json.dumps(replication_map).replace('"', "'")) + if name not in cluster.metadata.keyspaces: + #try the 1.2 method + replication_map = { + 'class': strategy_class, + 'replication_factor':replication_factor + } + replication_map.update(replication_values) + if strategy_class.lower() != 'simplestrategy': + # Although the Cassandra documentation states for `replication_factor` + # that it is "Required if class is SimpleStrategy; otherwise, + # not used." we get an error if it is present. + replication_map.pop('replication_factor', None) - if strategy_class != 'SimpleStrategy': - query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false') + query = """ + CREATE KEYSPACE {} + WITH REPLICATION = {} + """.format(name, json.dumps(replication_map).replace('"', "'")) + + if strategy_class != 'SimpleStrategy': + query += " AND DURABLE_WRITES = {}".format('true' if durable_writes else 'false') + + execute(query) - con.execute(query) def delete_keyspace(name): - with connection_manager() as con: - if name in [k.name for k in con.con.client.describe_keyspaces()]: - con.execute("DROP KEYSPACE {}".format(name)) + cluster = get_cluster() + if name in cluster.metadata.keyspaces: + execute("DROP KEYSPACE {}".format(name)) def create_table(model, create_missing_keyspace=True): + raise CQLEngineException("create_table is deprecated, please use sync_table") + +def sync_table(model, create_missing_keyspace=True): + """ + Inspects the model and creates / updates the corresponding table and columns. + + Note that the attributes removed from the model are not deleted on the database. + They become effectively ignored by (will not show up on) the model. + + :param create_missing_keyspace: (Defaults to True) If True, the keyspace referenced + by the model will be created automatically if it does not exist.
+ :type create_missing_keyspace: bool + """ + + if not issubclass(model, Model): + raise CQLEngineException("Models must be derived from base Model.") + + if model.__abstract__: + raise CQLEngineException("cannot create table from abstract model") + + #construct query string cf_name = model.column_family_name() raw_cf_name = model.column_family_name(include_keyspace=False) + ks_name = model._get_keyspace() #create missing keyspace if create_missing_keyspace: - create_keyspace(model._get_keyspace()) + create_keyspace(ks_name) - with connection_manager() as con: - #check for an existing column family - ks_info = con.con.client.describe_keyspace(model._get_keyspace()) - if not any([raw_cf_name == cf.name for cf in ks_info.cf_defs]): - qs = ['CREATE TABLE {}'.format(cf_name)] + cluster = get_cluster() - #add column types - pkeys = [] - qtypes = [] - def add_column(col): - s = col.get_column_def() - if col.primary_key: pkeys.append('"{}"'.format(col.db_field_name)) - qtypes.append(s) - for name, col in model._columns.items(): - add_column(col) + keyspace = cluster.metadata.keyspaces[ks_name] + tables = keyspace.tables - qtypes.append('PRIMARY KEY ({})'.format(', '.join(pkeys))) - - qs += ['({})'.format(', '.join(qtypes))] - - # add read_repair_chance - qs += ['WITH read_repair_chance = {}'.format(model.read_repair_chance)] - qs = ' '.join(qs) + #check for an existing column family + if raw_cf_name not in tables: + qs = get_create_table(model) - try: - con.execute(qs) - except CQLEngineException as ex: - # 1.2 doesn't return cf names, so we have to examine the exception - # and ignore if it says the column family already exists - if "Cannot add already existing column family" not in unicode(ex): - raise + try: + execute(qs) + except CQLEngineException as ex: + # 1.2 doesn't return cf names, so we have to examine the exception + # and ignore if it says the column family already exists + if "Cannot add already existing column family" not in unicode(ex): + raise + else: + # see if we're missing any columns + fields = get_fields(model) + field_names = [x.name for x in fields] + for name, col in model._columns.items(): + if col.primary_key or col.partition_key: continue # we can't mess with the PK + if col.db_field_name in field_names: continue # skip columns already defined - #get existing index names, skip ones that already exist - ks_info = con.con.client.describe_keyspace(model._get_keyspace()) - cf_defs = [cf for cf in ks_info.cf_defs if cf.name == raw_cf_name] - idx_names = [i.index_name for i in cf_defs[0].column_metadata] if cf_defs else [] - idx_names = filter(None, idx_names) + # add missing column using the column def + query = "ALTER TABLE {} add {}".format(cf_name, col.get_column_def()) + logger.debug(query) + execute(query) - indexes = [c for n,c in model._columns.items() if c.index] - if indexes: - for column in indexes: - if column.db_index_name in idx_names: continue - qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)] - qs += ['ON {}'.format(cf_name)] - qs += ['("{}")'.format(column.db_field_name)] - qs = ' '.join(qs) + update_compaction(model) - try: - con.execute(qs) - except CQLEngineException as ex: - # 1.2 doesn't return cf names, so we have to examine the exception - # and ignore if it says the index already exists - if "Index already exists" not in unicode(ex): - raise + + table = cluster.metadata.keyspaces[ks_name].tables[raw_cf_name] + + indexes = [c for n,c in model._columns.items() if c.index] + + for column in indexes: + if 
table.columns[column.db_field_name].index: + continue + + qs = ['CREATE INDEX index_{}_{}'.format(raw_cf_name, column.db_field_name)] + qs += ['ON {}'.format(cf_name)] + qs += ['("{}")'.format(column.db_field_name)] + qs = ' '.join(qs) + execute(qs) + +def get_create_table(model): + cf_name = model.column_family_name() + qs = ['CREATE TABLE {}'.format(cf_name)] + + #add column types + pkeys = [] # primary keys + ckeys = [] # clustering keys + qtypes = [] # field types + def add_column(col): + s = col.get_column_def() + if col.primary_key: + keys = (pkeys if col.partition_key else ckeys) + keys.append('"{}"'.format(col.db_field_name)) + qtypes.append(s) + for name, col in model._columns.items(): + add_column(col) + + qtypes.append('PRIMARY KEY (({}){})'.format(', '.join(pkeys), ckeys and ', ' + ', '.join(ckeys) or '')) + + qs += ['({})'.format(', '.join(qtypes))] + + with_qs = [] + + table_properties = ['bloom_filter_fp_chance', 'caching', 'comment', + 'dclocal_read_repair_chance', 'default_time_to_live', 'gc_grace_seconds', + 'index_interval', 'memtable_flush_period_in_ms', 'populate_io_cache_on_flush', + 'read_repair_chance', 'replicate_on_write'] + for prop_name in table_properties: + prop_value = getattr(model, '__{}__'.format(prop_name), None) + if prop_value is not None: + # Strings needs to be single quoted + if isinstance(prop_value, basestring): + prop_value = "'{}'".format(prop_value) + with_qs.append("{} = {}".format(prop_name, prop_value)) + + _order = ['"{}" {}'.format(c.db_field_name, c.clustering_order or 'ASC') for c in model._clustering_keys.values()] + if _order: + with_qs.append('clustering order by ({})'.format(', '.join(_order))) + + compaction_options = get_compaction_options(model) + if compaction_options: + compaction_options = json.dumps(compaction_options).replace('"', "'") + with_qs.append("compaction = {}".format(compaction_options)) + + # Add table properties. + if with_qs: + qs += ['WITH {}'.format(' AND '.join(with_qs))] + + qs = ' '.join(qs) + return qs + + +def get_compaction_options(model): + """ + Generates dictionary (later converted to a string) for creating and altering + tables with compaction strategy + + :param model: + :return: + """ + if not model.__compaction__: + return {} + + result = {'class':model.__compaction__} + + def setter(key, limited_to_strategy = None): + """ + sets key in result, checking if the key is limited to either SizeTiered or Leveled + :param key: one of the compaction options, like "bucket_high" + :param limited_to_strategy: SizeTieredCompactionStrategy, LeveledCompactionStrategy + :return: + """ + mkey = "__compaction_{}__".format(key) + tmp = getattr(model, mkey) + if tmp and limited_to_strategy and limited_to_strategy != model.__compaction__: + raise CQLEngineException("{} is limited to {}".format(key, limited_to_strategy)) + + if tmp: + # Explicitly cast the values to strings to be able to compare the + # values against introspected values from Cassandra. 
+ result[key] = str(tmp) + + setter('tombstone_compaction_interval') + setter('tombstone_threshold') + + setter('bucket_high', SizeTieredCompactionStrategy) + setter('bucket_low', SizeTieredCompactionStrategy) + setter('max_threshold', SizeTieredCompactionStrategy) + setter('min_threshold', SizeTieredCompactionStrategy) + setter('min_sstable_size', SizeTieredCompactionStrategy) + + setter('sstable_size_in_mb', LeveledCompactionStrategy) + + return result + + +def get_fields(model): + # returns all fields that aren't part of the PK + ks_name = model._get_keyspace() + col_family = model.column_family_name(include_keyspace=False) + + query = "select * from system.schema_columns where keyspace_name = %s and columnfamily_name = %s" + tmp = execute(query, [ks_name, col_family]) + + # Tables containing only primary keys do not appear to create + # any entries in system.schema_columns, as only non-primary-key attributes + # appear to be inserted into the schema_columns table + + try: + return [Field(x['column_name'], x['validator']) for x in tmp if x['type'] == 'regular'] + except KeyError: + return [Field(x['column_name'], x['validator']) for x in tmp] + # convert to Field named tuples + + +def get_table_settings(model): + # returns the table as provided by the native driver for a given model + cluster = get_cluster() + ks = model._get_keyspace() + table = model.column_family_name(include_keyspace=False) + table = cluster.metadata.keyspaces[ks].tables[table] + return table + +def update_compaction(model): + """Updates the compaction options for the given model if necessary. + + :param model: The model to update. + + :return: `True`, if the compaction options were modified in Cassandra, + `False` otherwise. + :rtype: bool + """ + logger.debug("Checking %s for compaction differences", model) + table = get_table_settings(model) + + existing_options = table.options.copy() + + existing_compaction_strategy = existing_options['compaction_strategy_class'] + + existing_options = json.loads(existing_options['compaction_strategy_options']) + + desired_options = get_compaction_options(model) + + desired_compact_strategy = desired_options.get('class', SizeTieredCompactionStrategy) + + desired_options.pop('class', None) + + do_update = False + + if desired_compact_strategy not in existing_compaction_strategy: + do_update = True + + for k, v in desired_options.items(): + val = existing_options.pop(k, None) + if val != v: + do_update = True + + # check compaction_strategy_options + if do_update: + options = get_compaction_options(model) + # jsonify + options = json.dumps(options).replace('"', "'") + cf_name = model.column_family_name() + query = "ALTER TABLE {} with compaction = {}".format(cf_name, options) + logger.debug(query) + execute(query) + return True + + return False def delete_table(model): - cf_name = model.column_family_name() - with connection_manager() as con: - try: - con.execute('drop table {};'.format(cf_name)) - except CQLEngineException as ex: - #don't freak out if the table doesn't exist - if 'Cannot drop non existing column family' not in unicode(ex): - raise + raise CQLEngineException("delete_table has been deprecated in favor of drop_table()") + + +def drop_table(model): + + # don't try to delete nonexistent tables + meta = get_cluster().metadata + + ks_name = model._get_keyspace() + raw_cf_name = model.column_family_name(include_keyspace=False) + + try: + table = meta.keyspaces[ks_name].tables[raw_cf_name] + execute('drop table 
{};'.format(model.column_family_name(include_keyspace=True))) + except KeyError: + pass + + diff --git a/cqlengine/models.py b/cqlengine/models.py index 33ccb56b..ae0c7e4d 100644 --- a/cqlengine/models.py +++ b/cqlengine/models.py @@ -1,15 +1,24 @@ from collections import OrderedDict import re +import warnings from cqlengine import columns -from cqlengine.exceptions import ModelException -from cqlengine.functions import BaseQueryFunction, NotSet -from cqlengine.query import QuerySet, QueryException, DMLQuery +from cqlengine.exceptions import ModelException, CQLEngineException, ValidationError +from cqlengine.query import ModelQuerySet, DMLQuery, AbstractQueryableColumn +from cqlengine.query import DoesNotExist as _DoesNotExist +from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned class ModelDefinitionException(ModelException): pass -DEFAULT_KEYSPACE = 'cqlengine' + +class PolyMorphicModelException(ModelException): pass + +DEFAULT_KEYSPACE = None + +class UndefinedKeyspaceWarning(Warning): + pass + class hybrid_classmethod(object): """ @@ -27,35 +36,253 @@ class hybrid_classmethod(object): else: return self.instmethod.__get__(instance, owner) + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + """ + raise NotImplementedError + + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on, every time it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + queryset = model.__queryset__(model) + + # if this is a concrete polymorphic model, and the polymorphic + # key is an indexed column, add a filter clause to only return + # logical rows of the proper type + if model._is_polymorphic and not model._is_polymorphic_base: + name, column = model._polymorphic_column_name, model._polymorphic_column + if column.partition_key or column.index: + # look for existing poly types + return queryset.filter(**{name: model.__polymorphic_key__}) + + return queryset + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class TTLDescriptor(object): + """ + returns a query set descriptor + """ + def __get__(self, instance, model): + if instance: + #instance = copy.deepcopy(instance) + # instance method + def ttl_setter(ts): + instance._ttl = ts + return instance + return ttl_setter + + qs = model.__queryset__(model) + + def ttl_setter(ts): + qs._ttl = ts + return qs + + return ttl_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + +class TimestampDescriptor(object): + """ + returns a query set descriptor with a timestamp specified + """ + def __get__(self, instance, model): + if instance: + # instance method + def timestamp_setter(ts): + instance._timestamp = ts + return instance + return timestamp_setter + + return model.objects.timestamp + + + def __call__(self, *args, **kwargs): + raise NotImplementedError + +class ConsistencyDescriptor(object): + """ + returns a query set descriptor if called on the class, or the instance if called on an instance + """ + def __get__(self, instance, model): + if instance: + #instance = copy.deepcopy(instance) + def consistency_setter(consistency): + instance.__consistency__ = consistency + return instance + return consistency_setter + + qs = model.__queryset__(model) + + def consistency_setter(consistency): + qs._consistency = consistency + return qs + + return consistency_setter + + def __call__(self, *args, **kwargs): + raise NotImplementedError + + +class ColumnQueryEvaluator(AbstractQueryableColumn): + """ + Wraps a column and allows it to be used in comparator + expressions, returning query operators + + ie: + Model.column == 5 + """ + + def __init__(self, column): + self.column = column + + def __unicode__(self): + return self.column.db_field_name + + def _get_column(self): + """ :rtype: ColumnQueryEvaluator """ + return self.column + + +class ColumnDescriptor(object): + """ + Handles the reading and writing of column values to and from + a model instance's value manager, as well as creating + comparator queries + """ + + def __init__(self, column): + """ + :param column: + :type column: columns.Column + :return: + """ + self.column = column + self.query_evaluator = ColumnQueryEvaluator(self.column) + + def __get__(self, instance, owner): + """ + Returns either the value or column, depending + on whether an instance is provided or not + + :param instance: the model instance + :type instance: Model + """ + try: + return instance._values[self.column.column_name].getval() + except AttributeError as e: + return self.query_evaluator + + def __set__(self, instance, value): + """ + Sets the value on an instance; raises an exception when called on a class + TODO: use None instance to create update statements + """ + if instance: + return instance._values[self.column.column_name].setval(value) + else: + raise AttributeError('cannot reassign column values') + + def __delete__(self, instance): + """ + Sets the column value to None, if possible + """ + if instance: + if self.column.can_delete: + instance._values[self.column.column_name].delval() + else: + raise AttributeError('cannot delete {} columns'.format(self.column.column_name)) + + class BaseModel(object): """ The base model class, don't inherit from this, inherit from Model, defined below """ - - class DoesNotExist(QueryException): pass - class MultipleObjectsReturned(QueryException): pass - #table names will be generated automatically from it's model and package name - #however, you can also define them manually here - table_name = None + class DoesNotExist(_DoesNotExist): pass - #DEFAULT_TTL must be an integer seconds for the default time to live on any insert on the table - #this can be overridden on any given query, but you can set a default on the model - DEFAULT_TTL = None + class MultipleObjectsReturned(_MultipleObjectsReturned): pass - #the keyspace for this model - keyspace = None - read_repair_chance = 0.1 + objects = QuerySetDescriptor() + ttl = TTLDescriptor() + consistency = ConsistencyDescriptor() - def __init__(self, ttl=NotSet, **values): + # custom timestamps, see USING TIMESTAMP X + timestamp = TimestampDescriptor() + + # _len is lazily created by __len__ + + # table names will be generated automatically from its model + # however, you can also define them manually here + __table_name__ = None + + # the keyspace for this model + __keyspace__ = None + + # polymorphism options + __polymorphic_key__ = None + + # compaction options + __compaction__ = None + __compaction_tombstone_compaction_interval__ = None + __compaction_tombstone_threshold__ = None + + # compaction - size tiered options + __compaction_bucket_high__ = None + __compaction_bucket_low__ = None + __compaction_max_threshold__ = None + __compaction_min_threshold__ = None + __compaction_min_sstable_size__ = None + + # compaction - leveled options + 
__compaction_sstable_size_in_mb__ = None + + # end compaction + # the queryset class used for this class + __queryset__ = ModelQuerySet + __dmlquery__ = DMLQuery + + #__ttl__ = None # this doesn't seem to be used + __consistency__ = None # can be set per query + + # Additional table properties + __bloom_filter_fp_chance__ = None + __caching__ = None + __comment__ = None + __dclocal_read_repair_chance__ = None + __default_time_to_live__ = None + __gc_grace_seconds__ = None + __index_interval__ = None + __memtable_flush_period_in_ms__ = None + __populate_io_cache_on_flush__ = None + __read_repair_chance__ = None + __replicate_on_write__ = None + + _timestamp = None # optional timestamp to include with the operation (USING TIMESTAMP) + + def __init__(self, **values): self._values = {} - if ttl == NotSet: - self.ttl = self.DEFAULT_TTL - else: - self.ttl = ttl for name, column in self._columns.items(): value = values.get(name, None) - if value is not None: value = column.to_python(value) + if value is not None or isinstance(column, columns.BaseContainerColumn): + value = column.to_python(value) value_mngr = column.value_manager(self, column, value) self._values[name] = value_mngr @@ -64,6 +291,77 @@ class BaseModel(object): self._is_persisted = False self._batch = None + + def __repr__(self): + """ + Pretty printing of models by their primary key + """ + return '{} <{}>'.format(self.__class__.__name__, + ', '.join(('{}={}'.format(k, getattr(self, k)) for k,v in self._primary_keys.iteritems())) + ) + + + + @classmethod + def _discover_polymorphic_submodels(cls): + if not cls._is_polymorphic_base: + raise ModelException('_discover_polymorphic_submodels can only be called on polymorphic base classes') + def _discover(klass): + if not klass._is_polymorphic_base and klass.__polymorphic_key__ is not None: + cls._polymorphic_map[klass.__polymorphic_key__] = klass + for subklass in klass.__subclasses__(): + _discover(subklass) + _discover(cls) + + @classmethod + def _get_model_by_polymorphic_key(cls, key): + if not cls._is_polymorphic_base: + raise ModelException('_get_model_by_polymorphic_key can only be called on polymorphic base classes') + return cls._polymorphic_map.get(key) + + @classmethod + def _construct_instance(cls, values): + """ + method used to construct instances from query results + this is where polymorphic deserialization occurs + """ + # we're going to take the values, which is from the DB as a dict + # and translate that into our local fields + # the db_map is a db_field -> model field map + items = values.items() + field_dict = dict([(cls._db_map.get(k, k),v) for k,v in items]) + + if cls._is_polymorphic: + poly_key = field_dict.get(cls._polymorphic_column_name) + + if poly_key is None: + raise PolyMorphicModelException('polymorphic key was not found in values') + + poly_base = cls if cls._is_polymorphic_base else cls._polymorphic_base + + klass = poly_base._get_model_by_polymorphic_key(poly_key) + if klass is None: + poly_base._discover_polymorphic_submodels() + klass = poly_base._get_model_by_polymorphic_key(poly_key) + if klass is None: + raise PolyMorphicModelException( + 'unrecognized polymorphic key {} for class {}'.format(poly_key, poly_base.__name__) + ) + + if not issubclass(klass, cls): + raise PolyMorphicModelException( + '{} is not a subclass of {}'.format(klass.__name__, cls.__name__) + ) + + field_dict = {k: v for k, v in field_dict.items() if k in klass._columns.keys()} + + else: + klass = cls + + instance = klass(**field_dict) + instance._is_persisted = True + 
return instance + def _can_update(self): """ Called by the save function to check if this should be @@ -78,10 +376,35 @@ class BaseModel(object): @classmethod def _get_keyspace(cls): """ Returns the manual keyspace, if set, otherwise the default keyspace """ - return cls.keyspace or DEFAULT_KEYSPACE + return cls.__keyspace__ or DEFAULT_KEYSPACE + + @classmethod + def _get_column(cls, name): + """ + Returns the column matching the given name, raising a key error if + it doesn't exist + + :param name: the name of the column to return + :rtype: Column + """ + return cls._columns[name] def __eq__(self, other): - return self.as_dict() == other.as_dict() + if self.__class__ != other.__class__: + return False + + # check attribute keys + keys = set(self._columns.keys()) + other_keys = set(other._columns.keys()) + if keys != other_keys: + return False + + # check that all of the attributes match + for key in other_keys: + if getattr(self, key, None) != getattr(other, key, None): + return False + + return True def __ne__(self, other): return not self.__eq__(other) @@ -93,16 +416,16 @@ class BaseModel(object): otherwise, it creates it from the module and class name """ cf_name = '' - if cls.table_name: - cf_name = cls.table_name.lower() + if cls.__table_name__: + cf_name = cls.__table_name__.lower() else: + # get polymorphic base table names if model is polymorphic + if cls._is_polymorphic and not cls._is_polymorphic_base: + return cls._polymorphic_base.column_family_name(include_keyspace=include_keyspace) + camelcase = re.compile(r'([a-z])([A-Z])') ccase = lambda s: camelcase.sub(lambda v: '{}_{}'.format(v.group(1), v.group(2).lower()), s) - - module = cls.__module__.split('.') - if module: - cf_name = ccase(module[-1]) + '_' - + cf_name += ccase(cls.__name__) #trim to less than 48 characters or cassandra will complain cf_name = cf_name[-48:] @@ -111,18 +434,56 @@ class BaseModel(object): if not include_keyspace: return cf_name return '{}.{}'.format(cls._get_keyspace(), cf_name) - @property - def pk(self): - """ Returns the object's primary key """ - return getattr(self, self._pk_name) - def validate(self): """ Cleans and validates the field values """ for name, col in self._columns.items(): val = col.validate(getattr(self, name)) setattr(self, name, val) - def as_dict(self): + ### Let an instance be used like a dict of its columns' keys/values + + def __iter__(self): + """ Iterate over column ids. """ + for column_id in self._columns.keys(): + yield column_id + + def __getitem__(self, key): + """ Returns column's value. """ + if not isinstance(key, basestring): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return getattr(self, key) + + def __setitem__(self, key, val): + """ Sets a column's value. """ + if not isinstance(key, basestring): + raise TypeError + if key not in self._columns.keys(): + raise KeyError + return setattr(self, key, val) + + def __len__(self): + """ Returns the number of columns defined on that model. """ + try: + return self._len + except AttributeError: + self._len = len(self._columns.keys()) + return self._len + + def keys(self): + """ Returns a list of column IDs. """ + return [k for k in self] + + def values(self): + """ Returns a list of column values. """ + return [self[k] for k in self] + + def items(self): + """ Returns a list of column ID/value pairs. """ + return [(k, self[k]) for k in self] + + def _as_dict(self): """ Returns a map of column names to cleaned values """ values = self._dynamic_columns or {} for name, col in self._columns.items(): @@ -131,37 +492,97 @@ class BaseModel(object): @classmethod def create(cls, **kwargs): + extra_columns = set(kwargs.keys()) - set(cls._columns.keys()) + if extra_columns: + raise ValidationError("Incorrect columns passed: {}".format(extra_columns)) return cls.objects.create(**kwargs) - + @classmethod def all(cls): return cls.objects.all() - - @classmethod - def filter(cls, **kwargs): - return cls.objects.filter(**kwargs) - - @classmethod - def get(cls, **kwargs): - return cls.objects.get(**kwargs) - def save(self, ttl=NotSet, timestamp=None): - if ttl == NotSet: - ttl = self.ttl + @classmethod + def filter(cls, *args, **kwargs): + # if kwargs.values().count(None): + # raise CQLEngineException("Cannot pass None as a filter") + + return cls.objects.filter(*args, **kwargs) + + @classmethod + def get(cls, *args, **kwargs): + return cls.objects.get(*args, **kwargs) + + def save(self): + # handle polymorphic models + if self._is_polymorphic: + if self._is_polymorphic_base: + raise PolyMorphicModelException('cannot save polymorphic base model') + else: + setattr(self, self._polymorphic_column_name, self.__polymorphic_key__) + is_new = self.pk is None self.validate() - DMLQuery(self.__class__, self, batch=self._batch).save(ttl, timestamp) + self.__dmlquery__(self.__class__, self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__).save() #reset the value managers for v in self._values.values(): v.reset_previous_value() self._is_persisted = True + self._ttl = None + self._timestamp = None + + return self + + def update(self, **values): + for k, v in values.items(): + col = self._columns.get(k) + + # check for nonexistent columns + if col is None: + raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.__class__.__name__, k)) + + # check for primary key update attempts + if col.is_primary_key: + raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(k, self.__module__, self.__class__.__name__)) + + setattr(self, k, v) + + # handle polymorphic models + if self._is_polymorphic: + if self._is_polymorphic_base: + raise PolyMorphicModelException('cannot update polymorphic base model') + else: + setattr(self, self._polymorphic_column_name, self.__polymorphic_key__) + + self.validate() + self.__dmlquery__(self.__class__, self, + batch=self._batch, + ttl=self._ttl, + timestamp=self._timestamp, + consistency=self.__consistency__).update() + + #reset the value managers + for v in self._values.values(): + v.reset_previous_value() + self._is_persisted = True + + self._ttl = None + self._timestamp = None + return self def delete(self): """ Deletes this instance """ - DMLQuery(self.__class__, self, batch=self._batch).delete() + self.__dmlquery__(self.__class__, self, batch=self._batch, timestamp=self._timestamp, consistency=self.__consistency__).delete() + + def get_changed_columns(self): + """ returns a list of the columns that have been updated since instantiation or save """ + return [k for k,v in self._values.items() if v.changed] @classmethod def _class_batch(cls, batch): @@ -171,6 +592,7 @@ class BaseModel(object): self._batch = batch return self + batch = hybrid_classmethod(_class_batch, _inst_batch) @@ -185,7 +607,6 @@ class ModelMetaClass(type): column_dict = OrderedDict() primary_keys = OrderedDict() 
pk_name = None - primary_key = None #get inherited properties inherited_columns = OrderedDict() @@ -193,52 +614,105 @@ class ModelMetaClass(type): for k,v in getattr(base, '_defined_columns', {}).items(): inherited_columns.setdefault(k,v) + #short circuit __abstract__ inheritance + is_abstract = attrs['__abstract__'] = attrs.get('__abstract__', False) + + #short circuit __polymorphic_key__ inheritance + attrs['__polymorphic_key__'] = attrs.get('__polymorphic_key__', None) + def _transform_column(col_name, col_obj): column_dict[col_name] = col_obj if col_obj.primary_key: primary_keys[col_name] = col_obj col_obj.set_column_name(col_name) #set properties - _get = lambda self: self._values[col_name].getval() - _set = lambda self, val: self._values[col_name].setval(val) - _del = lambda self: self._values[col_name].delval() - if col_obj.can_delete: - attrs[col_name] = property(_get, _set) - else: - attrs[col_name] = property(_get, _set, _del) + attrs[col_name] = ColumnDescriptor(col_obj) column_definitions = [(k,v) for k,v in attrs.items() if isinstance(v, columns.Column)] column_definitions = sorted(column_definitions, lambda x,y: cmp(x[1].position, y[1].position)) + is_polymorphic_base = any([c[1].polymorphic_key for c in column_definitions]) + column_definitions = inherited_columns.items() + column_definitions - #columns defined on model, excludes automatically - #defined columns + polymorphic_columns = [c for c in column_definitions if c[1].polymorphic_key] + is_polymorphic = len(polymorphic_columns) > 0 + if len(polymorphic_columns) > 1: + raise ModelDefinitionException('only one polymorphic_key can be defined in a model, {} found'.format(len(polymorphic_columns))) + + polymorphic_column_name, polymorphic_column = polymorphic_columns[0] if polymorphic_columns else (None, None) + + if isinstance(polymorphic_column, (columns.BaseContainerColumn, columns.Counter)): + raise ModelDefinitionException('counter and container columns cannot be used for polymorphic keys') + + # find polymorphic base class + polymorphic_base = None + if is_polymorphic and not is_polymorphic_base: + def _get_polymorphic_base(bases): + for base in bases: + if getattr(base, '_is_polymorphic_base', False): + return base + klass = _get_polymorphic_base(base.__bases__) + if klass: + return klass + polymorphic_base = _get_polymorphic_base(bases) + defined_columns = OrderedDict(column_definitions) - #prepend primary key if one hasn't been defined - if not any([v.primary_key for k,v in column_definitions]): - k,v = 'id', columns.UUID(primary_key=True) - column_definitions = [(k,v)] + column_definitions + # check for primary key + if not is_abstract and not any([v.primary_key for k,v in column_definitions]): + raise ModelDefinitionException("At least 1 primary key is required.") + + counter_columns = [c for c in defined_columns.values() if isinstance(c, columns.Counter)] + data_columns = [c for c in defined_columns.values() if not c.primary_key and not isinstance(c, columns.Counter)] + if counter_columns and data_columns: + raise ModelDefinitionException('counter models may not have data columns') + + has_partition_keys = any(v.partition_key for (k, v) in column_definitions) - #TODO: check that the defined columns don't conflict with any of the Model API's existing attributes/methods #transform column definitions - for k,v in column_definitions: - if pk_name is None and v.primary_key: - pk_name = k - primary_key = v - v._partition_key = True - _transform_column(k,v) - - #setup primary key shortcut - if pk_name != 'pk': - 
attrs['pk'] = attrs[pk_name] + for k, v in column_definitions: + # don't allow a column with the same name as a built-in attribute or method + if k in BaseModel.__dict__: + raise ModelDefinitionException("column '{}' conflicts with built-in attribute/method".format(k)) - #check for duplicate column names + # counter column primary keys are not allowed + if (v.primary_key or v.partition_key) and isinstance(v, (columns.Counter, columns.BaseContainerColumn)): + raise ModelDefinitionException('counter columns and container columns cannot be used as primary keys') + + # this will mark the first primary key column as a partition + # key, if one hasn't been set already + if not has_partition_keys and v.primary_key: + v.partition_key = True + has_partition_keys = True + _transform_column(k, v) + + partition_keys = OrderedDict(k for k in primary_keys.items() if k[1].partition_key) + clustering_keys = OrderedDict(k for k in primary_keys.items() if not k[1].partition_key) + + #setup partition key shortcut + if len(partition_keys) == 0: + if not is_abstract: + raise ModelException("at least one partition key must be defined") + if len(partition_keys) == 1: + pk_name = partition_keys.keys()[0] + attrs['pk'] = attrs[pk_name] + else: + # composite partition key case, get/set a tuple of values + _get = lambda self: tuple(self._values[c].getval() for c in partition_keys.keys()) + _set = lambda self, val: tuple(self._values[c].setval(v) for (c, v) in zip(partition_keys.keys(), val)) + attrs['pk'] = property(_get, _set) + + # some validation col_names = set() for v in column_dict.values(): + # check for duplicate column names if v.db_field_name in col_names: raise ModelException("{} defines the column {} more than once".format(name, v.db_field_name)) + if v.clustering_order and not (v.primary_key and not v.partition_key): + raise ModelException("clustering_order may be specified only for clustering primary keys") + if v.clustering_order and v.clustering_order.lower() not in ('asc', 'desc'): + raise ModelException("invalid clustering order {} for column {}".format(repr(v.clustering_order), v.db_field_name)) col_names.add(v.db_field_name) #create db_name -> model name map for loading @@ -246,21 +720,46 @@ class ModelMetaClass(type): for field_name, col in column_dict.items(): db_map[col.db_field_name] = field_name - #short circuit table_name inheritance - attrs['table_name'] = attrs.get('table_name') - #add management members to the class attrs['_columns'] = column_dict attrs['_primary_keys'] = primary_keys attrs['_defined_columns'] = defined_columns + + # maps the database field to the models key attrs['_db_map'] = db_map attrs['_pk_name'] = pk_name - attrs['_primary_key'] = primary_key attrs['_dynamic_columns'] = {} + attrs['_partition_keys'] = partition_keys + attrs['_clustering_keys'] = clustering_keys + attrs['_has_counter'] = len(counter_columns) > 0 + + # add polymorphic management attributes + attrs['_is_polymorphic_base'] = is_polymorphic_base + attrs['_is_polymorphic'] = is_polymorphic + attrs['_polymorphic_base'] = polymorphic_base + attrs['_polymorphic_column'] = polymorphic_column + attrs['_polymorphic_column_name'] = polymorphic_column_name + attrs['_polymorphic_map'] = {} if is_polymorphic_base else None + + #setup class exceptions + DoesNotExistBase = None + for base in bases: + DoesNotExistBase = getattr(base, 'DoesNotExist', None) + if DoesNotExistBase is not None: break + DoesNotExistBase = DoesNotExistBase or attrs.pop('DoesNotExist', BaseModel.DoesNotExist) + attrs['DoesNotExist'] = 
type('DoesNotExist', (DoesNotExistBase,), {}) + + MultipleObjectsReturnedBase = None + for base in bases: + MultipleObjectsReturnedBase = getattr(base, 'MultipleObjectsReturned', None) + if MultipleObjectsReturnedBase is not None: break + MultipleObjectsReturnedBase = MultipleObjectsReturnedBase or attrs.pop('MultipleObjectsReturned', BaseModel.MultipleObjectsReturned) + attrs['MultipleObjectsReturned'] = type('MultipleObjectsReturned', (MultipleObjectsReturnedBase,), {}) + #create the class and add a QuerySet to it klass = super(ModelMetaClass, cls).__new__(cls, name, bases, attrs) - klass.objects = QuerySet(klass) + return klass @@ -269,6 +768,7 @@ class Model(BaseModel): the db name for the column family can be set as the attribute db_name, or it will be generated from the class name """ + __abstract__ = True __metaclass__ = ModelMetaClass diff --git a/cqlengine/named.py b/cqlengine/named.py new file mode 100644 index 00000000..c6ba3ac9 --- /dev/null +++ b/cqlengine/named.py @@ -0,0 +1,122 @@ +from cqlengine.exceptions import CQLEngineException +from cqlengine.query import AbstractQueryableColumn, SimpleQuerySet + +from cqlengine.query import DoesNotExist as _DoesNotExist +from cqlengine.query import MultipleObjectsReturned as _MultipleObjectsReturned + +class QuerySetDescriptor(object): + """ + returns a fresh queryset for the given model + it's declared on, every time it's accessed + """ + + def __get__(self, obj, model): + """ :rtype: ModelQuerySet """ + if model.__abstract__: + raise CQLEngineException('cannot execute queries against abstract models') + return SimpleQuerySet(obj) + + def __call__(self, *args, **kwargs): + """ + Just a hint to IDEs that it's ok to call this + + :rtype: ModelQuerySet + """ + raise NotImplementedError + + +class NamedColumn(AbstractQueryableColumn): + """ + A column that is not coupled to a model class, or type + """ + + def __init__(self, name): + self.name = name + + def __unicode__(self): + return self.name + + def _get_column(self): + """ :rtype: NamedColumn """ + return self + + @property + def db_field_name(self): + return self.name + + @property + def cql(self): + return self.get_cql() + + def get_cql(self): + return '"{}"'.format(self.name) + + def to_database(self, val): + return val + + +class NamedTable(object): + """ + A Table that is not coupled to a model class + """ + + __abstract__ = False + + objects = QuerySetDescriptor() + + class DoesNotExist(_DoesNotExist): pass + class MultipleObjectsReturned(_MultipleObjectsReturned): pass + + def __init__(self, keyspace, name): + self.keyspace = keyspace + self.name = name + + def column(self, name): + return NamedColumn(name) + + def column_family_name(self, include_keyspace=True): + """ + Returns the column family name if it's been defined + otherwise, it creates it from the module and class name + """ + if include_keyspace: + return '{}.{}'.format(self.keyspace, self.name) + else: + return self.name + + def _get_column(self, name): + """ + Returns the column matching the given name + + :rtype: Column + """ + return self.column(name) + + # def create(self, **kwargs): + # return self.objects.create(**kwargs) + + def all(self): + return self.objects.all() + + def filter(self, *args, **kwargs): + return self.objects.filter(*args, **kwargs) + + def get(self, *args, **kwargs): + return self.objects.get(*args, **kwargs) + + +class NamedKeyspace(object): + """ + A keyspace + """ + + def __init__(self, name): + self.name = name + + def table(self, name): + """ + returns a table descriptor with the given + name 
that belongs to this keyspace + """ + return NamedTable(self.name, name) + diff --git a/cqlengine/operators.py b/cqlengine/operators.py new file mode 100644 index 00000000..4fef0a6b --- /dev/null +++ b/cqlengine/operators.py @@ -0,0 +1,82 @@ +class QueryOperatorException(Exception): pass + + +class BaseQueryOperator(object): + # The symbol that identifies this operator in kwargs + # ie: colname__ + symbol = None + + # The comparator symbol this operator uses in cql + cql_symbol = None + + def __unicode__(self): + if self.cql_symbol is None: + raise QueryOperatorException("cql symbol is None") + return self.cql_symbol + + def __str__(self): + return unicode(self).encode('utf-8') + + @classmethod + def get_operator(cls, symbol): + if cls == BaseQueryOperator: + raise QueryOperatorException("get_operator can only be called from a BaseQueryOperator subclass") + if not hasattr(cls, 'opmap'): + cls.opmap = {} + def _recurse(klass): + if klass.symbol: + cls.opmap[klass.symbol.upper()] = klass + for subklass in klass.__subclasses__(): + _recurse(subklass) + pass + _recurse(cls) + try: + return cls.opmap[symbol.upper()] + except KeyError: + raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol)) + + +class BaseWhereOperator(BaseQueryOperator): + """ base operator used for where clauses """ + + +class EqualsOperator(BaseWhereOperator): + symbol = 'EQ' + cql_symbol = '=' + + +class InOperator(EqualsOperator): + symbol = 'IN' + cql_symbol = 'IN' + + +class GreaterThanOperator(BaseWhereOperator): + symbol = "GT" + cql_symbol = '>' + + +class GreaterThanOrEqualOperator(BaseWhereOperator): + symbol = "GTE" + cql_symbol = '>=' + + +class LessThanOperator(BaseWhereOperator): + symbol = "LT" + cql_symbol = '<' + + +class LessThanOrEqualOperator(BaseWhereOperator): + symbol = "LTE" + cql_symbol = '<=' + + +class BaseAssignmentOperator(BaseQueryOperator): + """ base operator used for insert and delete statements """ + + +class AssignmentOperator(BaseAssignmentOperator): + cql_symbol = "=" + + +class AddSymbol(BaseAssignmentOperator): + cql_symbol = "+" \ No newline at end of file diff --git a/cqlengine/query.py b/cqlengine/query.py index eca930a0..738822f3 100644 --- a/cqlengine/query.py +++ b/cqlengine/query.py @@ -1,208 +1,196 @@ -from collections import namedtuple import copy -from datetime import datetime -from hashlib import md5 -from time import time -from uuid import uuid1 -from cqlengine import BaseContainerColumn, BaseValueManager, Map, Counter +import time +from datetime import datetime, timedelta +from cqlengine import BaseContainerColumn, Map, columns +from cqlengine.columns import Counter, List, Set -from cqlengine.connection import connection_manager -from cqlengine.exceptions import CQLEngineException -from cqlengine.functions import BaseQueryFunction, format_timestamp +from cqlengine.connection import execute + +from cqlengine.exceptions import CQLEngineException, ValidationError +from cqlengine.functions import Token, BaseQueryFunction, QueryValue #CQL 3 reference: #http://www.datastax.com/docs/1.1/references/cql/index +from cqlengine.operators import InOperator, EqualsOperator, GreaterThanOperator, GreaterThanOrEqualOperator +from cqlengine.operators import LessThanOperator, LessThanOrEqualOperator, BaseWhereOperator +from cqlengine.statements import WhereClause, SelectStatement, DeleteStatement, UpdateStatement, AssignmentClause, InsertStatement, BaseCQLStatement, MapUpdateClause, MapDeleteClause, ListUpdateClause, SetUpdateClause, CounterUpdateClause + class 
QueryException(CQLEngineException): pass -class QueryOperatorException(QueryException): pass +class DoesNotExist(QueryException): pass +class MultipleObjectsReturned(QueryException): pass -class QueryOperator(object): - # The symbol that identifies this operator in filter kwargs - # ie: colname__ - symbol = None - # The comparator symbol this operator uses in cql - cql_symbol = None +class AbstractQueryableColumn(object): + """ + exposes cql query operators through python's + built-in comparator symbols + """ - def __init__(self, column, value): - self.column = column - self.value = value + def _get_column(self): + raise NotImplementedError - #the identifier is a unique key that will be used in string - #replacement on query strings, it's created from a hash - #of this object's id and the time - self.identifier = uuid1().hex + def __unicode__(self): + raise NotImplementedError - #perform validation on this operator - self.validate_operator() - self.validate_value() + def __str__(self): + return str(unicode(self)) - @property - def cql(self): - """ - Returns this operator's portion of the WHERE clause - :param valname: the dict key that this operator's compare value will be found in - """ - if isinstance(self.value, BaseQueryFunction): - return '"{}" {} {}'.format(self.column.db_field_name, self.cql_symbol, self.value.to_cql(self.identifier)) + def _to_database(self, val): + if isinstance(val, QueryValue): + return val else: - return '"{}" {} :{}'.format(self.column.db_field_name, self.cql_symbol, self.identifier) + return self._get_column().to_database(val) - def validate_operator(self): + def in_(self, item): """ - Checks that this operator can be used on the column provided + Returns an in operator + + used where you'd typically want to use python's `in` operator """ - if self.symbol is None: - raise QueryOperatorException( - "{} is not a valid operator, use one with 'symbol' defined".format( - self.__class__.__name__ - ) - ) - if self.cql_symbol is None: - raise QueryOperatorException( - "{} is not a valid operator, use one with 'cql_symbol' defined".format( - self.__class__.__name__ - ) - ) + return WhereClause(unicode(self), InOperator(), item) - def validate_value(self): - """ - Checks that the compare value works with this operator + def __eq__(self, other): + return WhereClause(unicode(self), EqualsOperator(), self._to_database(other)) - Doesn't do anything by default - """ - pass + def __gt__(self, other): + return WhereClause(unicode(self), GreaterThanOperator(), self._to_database(other)) - def get_dict(self): - """ - Returns this operators contribution to the cql.query arg dictionanry + def __ge__(self, other): + return WhereClause(unicode(self), GreaterThanOrEqualOperator(), self._to_database(other)) - ie: if this column's name is colname, and the identifier is colval, - this should return the dict: {'colval':} - SELECT * FROM column_family WHERE colname=:colval - """ - if isinstance(self.value, BaseQueryFunction): - return {self.identifier: self.column.to_database(self.value.get_value())} - else: - return {self.identifier: self.column.to_database(self.value)} + def __lt__(self, other): + return WhereClause(unicode(self), LessThanOperator(), self._to_database(other)) - @classmethod - def get_operator(cls, symbol): - if not hasattr(cls, 'opmap'): - QueryOperator.opmap = {} - def _recurse(klass): - if klass.symbol: - QueryOperator.opmap[klass.symbol.upper()] = klass - for subklass in klass.__subclasses__(): - _recurse(subklass) - pass - _recurse(QueryOperator) - try: - return QueryOperator.opmap[symbol.upper()] - except KeyError: - raise QueryOperatorException("{} doesn't map to a QueryOperator".format(symbol)) + def __le__(self, other): + return WhereClause(unicode(self), LessThanOrEqualOperator(), self._to_database(other)) -class EqualsOperator(QueryOperator): - symbol = 'EQ' - cql_symbol = '=' - -class InOperator(EqualsOperator): - symbol = 'IN' - cql_symbol = 'IN' - - class Quoter(object): - """ - contains a single value, which will quote itself for CQL insertion statements - """ - def __init__(self, value): - self.value = value - - def __str__(self): - from cql.query import cql_quote as cq - return '(' + ', '.join([cq(v) for v in self.value]) + ')' - - def get_dict(self): - if isinstance(self.value, BaseQueryFunction): - return {self.identifier: self.column.to_database(self.value.get_value())} - else: - try: - values = [v for v in self.value] - except TypeError: - raise QueryException("in operator arguments must be iterable, {} found".format(self.value)) - return {self.identifier: self.Quoter([self.column.to_database(v) for v in self.value])} - -class GreaterThanOperator(QueryOperator): - symbol = "GT" - cql_symbol = '>' - -class GreaterThanOrEqualOperator(QueryOperator): - symbol = "GTE" - cql_symbol = '>=' - -class LessThanOperator(QueryOperator): - symbol = "LT" - cql_symbol = '<' - -class LessThanOrEqualOperator(QueryOperator): - symbol = "LTE" - cql_symbol = '<=' class BatchType(object): Unlogged = 'UNLOGGED' Counter = 'COUNTER' + class BatchQuery(object): """ Handles the batching of queries http://www.datastax.com/docs/1.2/cql_cli/cql/BATCH """ + _consistency = None - def __init__(self, batch_type=None, timestamp=None): + def __init__(self, batch_type=None, timestamp=None, consistency=None, execute_on_exception=False): + """ + :param batch_type: (optional) One of the batch type values available through the BatchType enum + :type batch_type: str or None + :param timestamp: (optional) A datetime or timedelta object with desired timestamp to be applied + to the batch transaction. + :type timestamp: datetime or timedelta or None + :param consistency: (optional) One of the consistency values ("ANY", "ONE", "QUORUM", etc.) + :type consistency: str or None + :param execute_on_exception: (Defaults to False) Indicates that when the BatchQuery instance is used + as a context manager the queries accumulated within the context must be executed despite + encountering an error within the context. By default, any exception raised from within + the context scope will cause the batched queries not to be executed. + :type execute_on_exception: bool + :param callbacks: A list of functions to be executed after the batch executes. Note that if the batch + does not execute, the callbacks are not executed either. This is effectively a list of "on success" + callback handlers. If defined, must be a collection of callables. 
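+ As an illustrative sketch (the callback name here is hypothetical): + batch.add_callback(notify, 'done') registers a callback that is + invoked as notify('done') once the batch executes successfully.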
+ :type callbacks: list or set or tuple + """ self.queries = [] self.batch_type = batch_type - if timestamp is not None and not isinstance(timestamp, datetime): + if timestamp is not None and not isinstance(timestamp, (datetime, timedelta)): raise CQLEngineException('timestamp object must be an instance of datetime or timedelta') self.timestamp = timestamp + self._consistency = consistency + self._execute_on_exception = execute_on_exception + self._callbacks = [] - def add_query(self, query, params): - self.queries.append((query, params)) + def add_query(self, query): + if not isinstance(query, BaseCQLStatement): + raise CQLEngineException('only BaseCQLStatements can be added to a batch query') + self.queries.append(query) + + def consistency(self, consistency): + self._consistency = consistency + + def _execute_callbacks(self): + for callback, args, kwargs in self._callbacks: + callback(*args, **kwargs) + + # trying to clear up the ref counts for objects mentioned in the set + del self._callbacks + + def add_callback(self, fn, *args, **kwargs): + """Add a function, and arguments to pass to it, to be executed after the batch executes. + + A batch can support multiple callbacks. + + Note that if the batch does not execute, the callbacks are not executed. + A callback is thus an "on batch success" handler. + + :param fn: Callable object + :type fn: callable + :param *args: Positional arguments to be passed to the callback at the time of execution + :param **kwargs: Named arguments to be passed to the callback at the time of execution + """ + if not callable(fn): + raise ValueError("Value for argument 'fn' is {} and is not a callable object.".format(type(fn))) + self._callbacks.append((fn, args, kwargs)) def execute(self): if len(self.queries) == 0: # Empty batch is a no-op + # except for callbacks + self._execute_callbacks() return opener = 'BEGIN ' + (self.batch_type + ' ' if self.batch_type else '') + ' BATCH' if self.timestamp: + + if isinstance(self.timestamp, (int, long)): + ts = self.timestamp + elif isinstance(self.timestamp, (datetime, timedelta)): + ts = self.timestamp + if isinstance(self.timestamp, timedelta): + ts += datetime.now() # Apply timedelta + ts = long(time.mktime(ts.timetuple()) * 1e+6 + ts.microsecond) + else: + raise ValueError("Batch expects a long, a timedelta, or a datetime") + opener += ' USING TIMESTAMP {}'.format(ts) query_list = [opener] parameters = {} + ctx_counter = 0 for query in self.queries: + query.update_context_id(ctx_counter) + ctx = query.get_context() + ctx_counter += len(ctx) query_list.append(' ' + str(query)) parameters.update(ctx) query_list.append('APPLY BATCH;') - with connection_manager() as con: - con.execute('\n'.join(query_list), parameters) + execute('\n'.join(query_list), parameters, self._consistency) self.queries = [] + self._execute_callbacks() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - #don't execute if there was an exception - if exc_type is not None: return + #don't execute if there was an exception by default + if exc_type is not None and not self._execute_on_exception: return self.execute() -class QuerySet(object): + +class AbstractQuerySet(object): def __init__(self, model): - super(QuerySet, self).__init__() + super(AbstractQuerySet, self).__init__() self.model = model #Where clause filters @@ -221,6 +209,9 @@ class QuerySet(object): self._defer_fields = 
[] self._only_fields = [] + self._values_list = False + self._flat_values_list = False + #results cache self._con = None self._cur = None @@ -228,24 +219,34 @@ class QuerySet(object): self._result_idx = None self._batch = None + self._ttl = None + self._consistency = None + self._timestamp = None @property def column_family_name(self): return self.model.column_family_name() + def _execute(self, q): + if self._batch: + return self._batch.add_query(q) + else: + result = execute(q, consistency_level=self._consistency) + return result + def __unicode__(self): - return self._select_query() + return unicode(self._select_query()) def __str__(self): return str(self.__unicode__()) - def __call__(self, **kwargs): - return self.filter(**kwargs) + def __call__(self, *args, **kwargs): + return self.filter(*args, **kwargs) def __deepcopy__(self, memo): clone = self.__class__(self.model) - for k,v in self.__dict__.items(): - if k in ['_con', '_cur', '_result_cache', '_result_idx']: + for k, v in self.__dict__.items(): + if k in ['_con', '_cur', '_result_cache', '_result_idx']: # don't clone these clone.__dict__[k] = None elif k == '_batch': # we need to keep the same batch instance across @@ -259,72 +260,32 @@ class QuerySet(object): return clone def __len__(self): - return self.count() - - def __del__(self): - if self._con: - self._con.close() - self._con = None - self._cur = None + self._execute_query() + return len(self._result_cache) #----query generation / execution---- - def _validate_where_syntax(self): - """ Checks that a filterset will not create invalid cql """ + def _select_fields(self): + """ returns the fields to select """ + return [] - #check that there's either a = or IN relationship with a primary key or indexed field - equal_ops = [w for w in self._where if isinstance(w, EqualsOperator)] - if not any([w.column.primary_key or w.column.index for w in equal_ops]): - raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field') - - if not self._allow_filtering: - #if the query is not on an indexed field - if not any([w.column.index for w in equal_ops]): - if not any([w.column._partition_key for w in equal_ops]): - raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the querset') - - - #TODO: abuse this to see if we can get cql to raise an exception - - def _where_clause(self): - """ Returns a where clause based on the given filter args """ - self._validate_where_syntax() - return ' AND '.join([f.cql for f in self._where]) - - def _where_values(self): - """ Returns the value dict to be passed to the cql query """ - values = {} - for where in self._where: - values.update(where.get_dict()) - return values + def _validate_select_where(self): + """ put select query validation here """ def _select_query(self): """ Returns a select clause based on the given filter args """ - fields = self.model._columns.keys() - if self._defer_fields: - fields = [f for f in fields if f not in self._defer_fields] - elif self._only_fields: - fields = [f for f in fields if f in self._only_fields] - db_fields = [self.model._columns[f].db_field_name for f in fields] - - qs = ['SELECT {}'.format(', '.join(['"{}"'.format(f) for f in db_fields]))] - qs += ['FROM {}'.format(self.column_family_name)] - if self._where: - qs += ['WHERE {}'.format(self._where_clause())] - - if self._order: - qs += ['ORDER BY {}'.format(', '.join(self._order))] - - if self._limit: - qs += ['LIMIT 
{}'.format(self._limit)] - - if self._allow_filtering: - qs += ['ALLOW FILTERING'] - - return ' '.join(qs) + self._validate_select_where() + return SelectStatement( + self.column_family_name, + fields=self._select_fields(), + where=self._where, + order_by=self._order, + limit=self._limit, + allow_filtering=self._allow_filtering + ) #----Reads------ @@ -332,9 +293,8 @@ class QuerySet(object): if self._batch: raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") if self._result_cache is None: - self._con = connection_manager() - self._cur = self._con.execute(self._select_query(), self._where_values()) - self._result_cache = [None]*self._cur.rowcount + self._result_cache = list(self._execute(self._select_query())) + self._construct_result = self._get_result_constructor() def _fill_result_cache_to_idx(self, idx): self._execute_query() @@ -345,15 +305,12 @@ class QuerySet(object): if qty < 1: return else: - names = [i[0] for i in self._cur.description] - for values in self._cur.fetchmany(qty): - value_dict = dict(zip(names, values)) + for idx in range(qty): self._result_idx += 1 - self._result_cache[self._result_idx] = self._construct_instance(value_dict) + self._result_cache[self._result_idx] = self._construct_result(self._result_cache[self._result_idx]) #return the connection to the connection pool if we have all objects - if self._result_cache and self._result_cache[-1] is not None: - self._con.close() + if self._result_cache and self._result_idx == (len(self._result_cache) - 1): self._con = None self._cur = None @@ -362,7 +319,7 @@ class QuerySet(object): for idx in range(len(self._result_cache)): instance = self._result_cache[idx] - if instance is None: + if isinstance(instance, dict): self._fill_result_cache_to_idx(idx) yield self._result_cache[idx] @@ -393,19 +350,11 @@ class QuerySet(object): self._fill_result_cache_to_idx(s) return self._result_cache[s] - - def _construct_instance(self, values): - #translate column names to model names - field_dict = {} - db_map = self.model._db_map - for key, val in values.items(): - if key in db_map: - field_dict[db_map[key]] = val - else: - field_dict[key] = val - instance = self.model(**field_dict) - instance._is_persisted = True - return instance + def _get_result_constructor(self): + """ + Returns a function that will be used to instantiate query results + """ + raise NotImplementedError def batch(self, batch_obj): """ @@ -413,8 +362,8 @@ class QuerySet(object): :param batch_obj: :return: """ - if not isinstance(batch_obj, BatchQuery): - raise CQLEngineException('batch_obj must be a BatchQuery instance') + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') clone = copy.deepcopy(self) clone._batch = batch_obj return clone @@ -426,8 +375,11 @@ class QuerySet(object): return None def all(self): + return copy.deepcopy(self) + + def consistency(self, consistency): clone = copy.deepcopy(self) - clone._where = [] + clone._consistency = consistency return clone def _parse_filter_arg(self, arg): @@ -436,7 +388,7 @@ class QuerySet(object): __ :returns: colname, op tuple """ - statement = arg.split('__') + statement = arg.rsplit('__', 1) if len(statement) == 1: return arg, None elif len(statement) == 2: @@ -444,33 +396,76 @@ class QuerySet(object): else: raise QueryException("Can't parse '{}'".format(arg)) - def filter(self, **kwargs): + def filter(self, *args, **kwargs): + """ + Adds WHERE arguments to the queryset, returning 
a new queryset + + #TODO: show examples + + :rtype: AbstractQuerySet + """ #add arguments to the where clause filters + if kwargs.values().count(None): + raise CQLEngineException("None values on filter are not allowed") + clone = copy.deepcopy(self) + for operator in args: + if not isinstance(operator, WhereClause): + raise QueryException('{} is not a valid query operator'.format(operator)) + clone._where.append(operator) + for arg, val in kwargs.items(): col_name, col_op = self._parse_filter_arg(arg) + quote_field = True #resolve column and operator try: - column = self.model._columns[col_name] + column = self.model._get_column(col_name) except KeyError: - raise QueryException("Can't resolve column name: '{}'".format(col_name)) + if col_name == 'pk__token': + if not isinstance(val, Token): + raise QueryException("Virtual column 'pk__token' may only be compared to Token() values") + column = columns._PartitionKeysToken(self.model) + quote_field = False + else: + raise QueryException("Can't resolve column name: '{}'".format(col_name)) + + if isinstance(val, Token): + if col_name != 'pk__token': + raise QueryException("Token() values may only be compared to the 'pk__token' virtual column") + partition_columns = column.partition_columns + if len(partition_columns) != len(val.value): + raise QueryException( + 'Token() received {} arguments but model has {} partition keys'.format( + len(val.value), len(partition_columns))) + val.set_columns(partition_columns) #get query operator, or use equals if not supplied - operator_class = QueryOperator.get_operator(col_op or 'EQ') - operator = operator_class(column, val) + operator_class = BaseWhereOperator.get_operator(col_op or 'EQ') + operator = operator_class() - clone._where.append(operator) + if isinstance(operator, InOperator): + if not isinstance(val, (list, tuple)): + raise QueryException('IN queries must use a list/tuple value') + query_val = [column.to_database(v) for v in val] + elif isinstance(val, BaseQueryFunction): + query_val = val + else: + query_val = column.to_database(val) + + clone._where.append(WhereClause(column.db_field_name, operator, query_val, quote_field=quote_field)) return clone - def get(self, **kwargs): + def get(self, *args, **kwargs): """ Returns a single instance matching this query, optionally with additional filter kwargs. A DoesNotExist exception will be raised if there are no rows matching the query A MultipleObjectsReturned exception will be raised if there is more than one row matching the query """ - if kwargs: return self.filter(**kwargs).get() + if args or kwargs: + return self.filter(*args, **kwargs).get() + self._execute_query() if len(self._result_cache) == 0: raise self.model.DoesNotExist @@ -480,6 +475,12 @@ class QuerySet(object): else: return self[0] + def _get_ordering_condition(self, colname): + order_type = 'DESC' if colname.startswith('-') else 'ASC' + colname = colname.replace('-', '') + + return colname, order_type + def order_by(self, *colnames): """ orders the result set. 
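+ A hypothetical example: Model.objects.filter(pk=some_key).order_by('-created_at') + sorts the partition's rows by the 'created_at' clustering key in + descending order (a '-' prefix means DESC, otherwise ASC).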
@@ -494,24 +495,7 @@ class QuerySet(object): conditions = [] for colname in colnames: - order_type = 'DESC' if colname.startswith('-') else 'ASC' - colname = colname.replace('-', '') - - column = self.model._columns.get(colname) - if column is None: - raise QueryException("Can't resolve the column name: '{}'".format(colname)) - - #validate the column selection - if not column.primary_key: - raise QueryException( - "Can't order on '{}', can only order on (clustered) primary keys".format(colname)) - - pks = [v for k, v in self.model._columns.items() if v.primary_key] - if column == pks[0]: - raise QueryException( - "Can't order by the first primary key (partition key), clustering (secondary) keys only") - - conditions.append('"{}" {}'.format(column.db_field_name, order_type)) + conditions.append('"{}" {}'.format(*self._get_ordering_condition(colname))) clone = copy.deepcopy(self) clone._order.extend(conditions) @@ -521,20 +505,12 @@ class QuerySet(object): """ Returns the number of rows matched by this query """ if self._batch: raise CQLEngineException("Only inserts, updates, and deletes are available in batch mode") - #TODO: check for previous query execution and return row count if it exists + if self._result_cache is None: - qs = ['SELECT COUNT(*)'] - qs += ['FROM {}'.format(self.column_family_name)] - if self._where: - qs += ['WHERE {}'.format(self._where_clause())] - if self._allow_filtering: - qs += ['ALLOW FILTERING'] - - qs = ' '.join(qs) - - with connection_manager() as con: - cur = con.execute(qs, self._where_values()) - return cur.fetchone()[0] + query = self._select_query() + query.count = True + result = self._execute(query) + return result[0]['count'] else: return len(self._result_cache) @@ -594,26 +570,193 @@ class QuerySet(object): return self._only_or_defer('defer', fields) def create(self, **kwargs): - return self.model(**kwargs).batch(self._batch).save() + return self.model(**kwargs).batch(self._batch).ttl(self._ttl).\ + consistency(self._consistency).\ + timestamp(self._timestamp).save() - #----delete--- - def delete(self, columns=[]): + def delete(self): """ Deletes the contents of a query """ #validate where clause partition_key = self.model._primary_keys.values()[0] - if not any([c.column.db_field_name == partition_key.db_field_name for c in self._where]): + if not any([c.field == partition_key.column_name for c in self._where]): raise QueryException("The partition key must be defined on delete queries") - qs = ['DELETE FROM {}'.format(self.column_family_name)] - qs += ['WHERE {}'.format(self._where_clause())] - qs = ' '.join(qs) - if self._batch: - self._batch.add_query(qs, self._where_values()) + dq = DeleteStatement( + self.column_family_name, + where=self._where, + timestamp=self._timestamp + ) + self._execute(dq) + + def __eq__(self, q): + if len(self._where) == len(q._where): + return all([w in q._where for w in self._where]) + return False + + def __ne__(self, q): + return not (self != q) + + +class ResultObject(dict): + """ + adds attribute access to a dictionary + """ + + def __getattr__(self, item): + try: + return self[item] + except KeyError: + raise AttributeError + + +class SimpleQuerySet(AbstractQuerySet): + """ + + """ + + def _get_result_constructor(self): + """ + Returns a function that will be used to instantiate query results + """ + def _construct_instance(values): + return ResultObject(values) + return _construct_instance + + +class ModelQuerySet(AbstractQuerySet): + """ + + """ + def _validate_select_where(self): + """ Checks that a filterset 
+        #check that there's either a = or IN relationship with a primary key or indexed field + equal_ops = [self.model._columns.get(w.field) for w in self._where if isinstance(w.operator, EqualsOperator)] + token_comparison = any([w for w in self._where if isinstance(w.value, Token)]) + if not any([w.primary_key or w.index for w in equal_ops]) and not token_comparison and not self._allow_filtering: + raise QueryException('Where clauses require either a "=" or "IN" comparison with either a primary key or indexed field') + + if not self._allow_filtering: + #if the query is not on an indexed field + if not any([w.index for w in equal_ops]): + if not any([w.partition_key for w in equal_ops]) and not token_comparison: + raise QueryException('Filtering on a clustering key without a partition key is not allowed unless allow_filtering() is called on the queryset') + + def _select_fields(self): + if self._defer_fields or self._only_fields: + fields = self.model._columns.keys() + if self._defer_fields: + fields = [f for f in fields if f not in self._defer_fields] + elif self._only_fields: + fields = self._only_fields + return [self.model._columns[f].db_field_name for f in fields] + return super(ModelQuerySet, self)._select_fields() + + def _get_result_constructor(self): + """ Returns a function that will be used to instantiate query results """ + if not self._values_list: # we want models + return lambda rows: self.model._construct_instance(rows) + elif self._flat_values_list: # the user has requested flattened list (1 value per row) + return lambda row: row.popitem()[1] else: + return lambda row: self._get_row_value_list(self._only_fields, row) + + def _get_row_value_list(self, fields, row): + result = [] + for x in fields: + result.append(row[x]) + return result + + def _get_ordering_condition(self, colname): + colname, order_type = super(ModelQuerySet, self)._get_ordering_condition(colname) + + column = self.model._columns.get(colname) + if column is None: + raise QueryException("Can't resolve the column name: '{}'".format(colname)) + + #validate the column selection + if not column.primary_key: + raise QueryException( + "Can't order on '{}', can only order on (clustered) primary keys".format(colname)) + + pks = [v for k, v in self.model._columns.items() if v.primary_key] + if column == pks[0]: + raise QueryException( + "Can't order by the first primary key (partition key), clustering (secondary) keys only") + + return column.db_field_name, order_type + + def values_list(self, *fields, **kwargs): + """ Instructs the query set to return tuples, not model instances """ + flat = kwargs.pop('flat', False) + if kwargs: + raise TypeError('Unexpected keyword arguments to values_list: %s' + % (kwargs.keys(),)) + if flat and len(fields) > 1: + raise TypeError("'flat' is not valid when values_list is called with more than one field.") + clone = self.only(fields) + clone._values_list = True + clone._flat_values_list = flat + return clone + + def ttl(self, ttl): + clone = copy.deepcopy(self) + clone._ttl = ttl + return clone + + def timestamp(self, timestamp): + clone = copy.deepcopy(self) + clone._timestamp = timestamp + return clone + + def update(self, **values): + """ Updates the rows in this queryset """
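+        # A sketch of the calling convention (model and column names are
+        # hypothetical; the '__<op>' suffix maps onto the collection update
+        # clauses handled below):
+        #
+        #     Row.objects(partition=pk).update(count=5, tags__add={'x'})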
+        if not values: + return + + nulled_columns = set() + us = UpdateStatement(self.column_family_name, where=self._where, ttl=self._ttl, timestamp=self._timestamp) + for name, val in values.items(): + col_name, col_op = self._parse_filter_arg(name) + col = self.model._columns.get(col_name) + # check for nonexistent columns + if col is None: + raise ValidationError("{}.{} has no column named: {}".format(self.__module__, self.model.__name__, col_name)) + # check for primary key update attempts + if col.is_primary_key: + raise ValidationError("Cannot apply update to primary key '{}' for {}.{}".format(col_name, self.__module__, self.model.__name__)) + + val = col.validate(val) + if val is None: + nulled_columns.add(col_name) + continue + + # add the update statements + if isinstance(col, Counter): + # TODO: implement counter updates + raise NotImplementedError + elif isinstance(col, (List, Set, Map)): + if isinstance(col, List): + klass = ListUpdateClause + elif isinstance(col, Set): + klass = SetUpdateClause + elif isinstance(col, Map): + klass = MapUpdateClause + else: + raise RuntimeError + us.add_assignment_clause(klass(col_name, col.to_database(val), operation=col_op)) + else: + us.add_assignment_clause(AssignmentClause( + col_name, col.to_database(val))) + + if us.assignments: + self._execute(us) + + if nulled_columns: + ds = DeleteStatement(self.column_family_name, fields=nulled_columns, where=self._where) + self._execute(ds) + class DMLQuery(object): """ @@ -623,21 +766,115 @@ class DMLQuery(object): unlike the read query object, this is mutable """ + _ttl = None + _consistency = None + _timestamp = None - def __init__(self, model, instance=None, batch=None): + def __init__(self, model, instance=None, batch=None, ttl=None, consistency=None, timestamp=None): self.model = model self.column_family_name = self.model.column_family_name() self.instance = instance - self.batch = batch - pass + self._batch = batch + self._ttl = ttl + self._consistency = consistency + self._timestamp = timestamp + + def _execute(self, q): + if self._batch: + return self._batch.add_query(q) + else: + tmp = execute(q, consistency_level=self._consistency) + return tmp def batch(self, batch_obj): - if not isinstance(batch_obj, BatchQuery): - raise CQLEngineException('batch_obj must be a BatchQuery instance') - self.batch = batch_obj + if batch_obj is not None and not isinstance(batch_obj, BatchQuery): + raise CQLEngineException('batch_obj must be a BatchQuery instance or None') + self._batch = batch_obj return self - def save(self, ttl=None, timestamp=None): + def _delete_null_columns(self): + """ + executes a delete query to remove columns that have changed to null + """ + ds = DeleteStatement(self.column_family_name) + deleted_fields = False + for _, v in self.instance._values.items(): + col = v.column + if v.deleted: + ds.add_field(col.db_field_name) + deleted_fields = True + elif isinstance(col, Map): + uc = MapDeleteClause(col.db_field_name, v.value, v.previous_value) + if uc.get_context_size() > 0: + ds.add_field(uc) + deleted_fields = True + + if deleted_fields: + for name, col in self.model._primary_keys.items(): + ds.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(ds) + + def update(self): + """ + updates a row. + This is a blind update call. + All validation and cleaning needs to happen + prior to calling this.
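+
+        A hypothetical flow, for illustration:
+
+            instance.description = 'new text'
+            DMLQuery(ExampleModel, instance).update()  # UPDATEs only the changed column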
+ """ + if self.instance is None: + raise CQLEngineException("DML Query intance attribute is None") + assert type(self.instance) == self.model + + statement = UpdateStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + #get defined fields and their column names + for name, col in self.model._columns.items(): + if not col.is_primary_key: + val = getattr(self.instance, name, None) + val_mgr = self.instance._values[name] + + # don't update something that is null + if val is None: + continue + + # don't update something if it hasn't changed + if not val_mgr.changed and not isinstance(col, Counter): + continue + + if isinstance(col, (BaseContainerColumn, Counter)): + # get appropriate clause + if isinstance(col, List): klass = ListUpdateClause + elif isinstance(col, Map): klass = MapUpdateClause + elif isinstance(col, Set): klass = SetUpdateClause + elif isinstance(col, Counter): klass = CounterUpdateClause + else: raise RuntimeError + + # do the stuff + clause = klass(col.db_field_name, val, + previous=val_mgr.previous_value, column=col) + if clause.get_context_size() > 0: + statement.add_assignment_clause(clause) + else: + statement.add_assignment_clause(AssignmentClause( + col.db_field_name, + col.to_database(val) + )) + + if statement.get_context_size() > 0 or self.instance._has_counter: + for name, col in self.model._primary_keys.items(): + statement.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(statement) + + self._delete_null_columns() + + def save(self): """ Creates / updates a row. This is a blind insert call. @@ -648,156 +885,42 @@ class DMLQuery(object): raise CQLEngineException("DML Query intance attribute is None") assert type(self.instance) == self.model - #organize data - value_pairs = [] - values = self.instance.as_dict() - - #get defined fields and their column names - for name, col in self.model._columns.items(): - val = values.get(name) - if val is None: continue - value_pairs += [(col.db_field_name, val)] - - #construct query string - field_names = zip(*value_pairs)[0] - field_ids = {n:uuid1().hex for n in field_names} - field_values = dict(value_pairs) - query_values = {field_ids[n]:field_values[n] for n in field_names} - - qs = [] - - using = [] - if ttl is not None: - ttl = int(ttl) - using.append('TTL {} '.format(ttl)) - if timestamp: - ts = format_timestamp(timestamp) - using.append('TIMESTAMP {} '.format(ts)) - - usings = '' - if using: - using = 'AND '.join(using).strip() - usings = ' USING {}'.format(using) - - - if self.instance._can_update(): - qs += ["UPDATE {}".format(self.column_family_name)] - if usings: - qs += [usings] - qs += ["SET"] - - set_statements = [] - #get defined fields and their column names - for name, col in self.model._columns.items(): - if not col.is_primary_key: - val = values.get(name) - if val is None: continue - if isinstance(col, Counter): - field_ids.pop(name) - value = field_values.pop(name) - if value == 0: - # Don't increment that column - continue - elif value < 0: - sign = '-' - else: - sign = '+' - set_statements += ['{0} = {0} {1} {2}'.format(col.db_field_name, sign, abs(value))] - elif isinstance(col, BaseContainerColumn): - #remove value from query values, the column will handle it - query_values.pop(field_ids.get(name), None) - - val_mgr = self.instance._values[name] - set_statements += col.get_update_statement(val, val_mgr.previous_value, query_values) - pass - else: - set_statements += ['"{}" = 
:{}'.format(col.db_field_name, field_ids[col.db_field_name])] - qs += [', '.join(set_statements)] - - qs += ['WHERE'] - - where_statements = [] - for name, col in self.model._primary_keys.items(): - where_statements += ['"{}" = :{}'.format(col.db_field_name, field_ids[col.db_field_name])] - - qs += [' AND '.join(where_statements)] - - # clear the qs if there are not set statements - if not set_statements: qs = [] - + nulled_fields = set() + if self.instance._has_counter or self.instance._can_update(): + return self.update() else: - qs += ["INSERT INTO {}".format(self.column_family_name)] - qs += ["({})".format(', '.join(['"{}"'.format(f) for f in field_names]))] - qs += ['VALUES'] - qs += ["({})".format(', '.join([':'+field_ids[f] for f in field_names]))] - if usings: - qs += [usings] - - qs = ' '.join(qs) + insert = InsertStatement(self.column_family_name, ttl=self._ttl, timestamp=self._timestamp) + for name, col in self.instance._columns.items(): + val = getattr(self.instance, name, None) + if col._val_is_null(val): + if self.instance._values[name].changed: + nulled_fields.add(col.db_field_name) + continue + insert.add_assignment_clause(AssignmentClause( + col.db_field_name, + col.to_database(getattr(self.instance, name, None)) + )) # skip query execution if it's empty # caused by pointless update queries - if qs: - if self.batch: - self.batch.add_query(qs, query_values) - else: - with connection_manager() as con: - con.execute(qs, query_values) + if not insert.is_empty: + self._execute(insert) - - # delete nulled columns and removed map keys - qs = ['DELETE'] - query_values = {} - - del_statements = [] - for k,v in self.instance._values.items(): - col = v.column - if v.deleted: - del_statements += ['"{}"'.format(col.db_field_name)] - elif isinstance(col, Map): - del_statements += col.get_delete_statement(v.value, v.previous_value, query_values) - - if del_statements: - qs += [', '.join(del_statements)] - - qs += ['FROM {}'.format(self.column_family_name)] - - qs += ['WHERE'] - where_statements = [] - for name, col in self.model._primary_keys.items(): - field_id = uuid1().hex - query_values[field_id] = field_values[name] - where_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)] - qs += [' AND '.join(where_statements)] - - qs = ' '.join(qs) - - if self.batch: - self.batch.add_query(qs, query_values) - else: - with connection_manager() as con: - con.execute(qs, query_values) + # delete any nulled columns + self._delete_null_columns() def delete(self): """ Deletes one instance """ if self.instance is None: - raise CQLEngineException("DML Query intance attribute is None") - field_values = {} - qs = ['DELETE FROM {}'.format(self.column_family_name)] - qs += ['WHERE'] - where_statements = [] + raise CQLEngineException("DML Query instance attribute is None") + + ds = DeleteStatement(self.column_family_name, timestamp=self._timestamp) for name, col in self.model._primary_keys.items(): - field_id = uuid1().hex - field_values[field_id] = col.to_database(getattr(self.instance, name)) - where_statements += ['"{}" = :{}'.format(col.db_field_name, field_id)] - - qs += [' AND '.join(where_statements)] - qs = ' '.join(qs) - - if self.batch: - self.batch.add_query(qs, field_values) - else: - with connection_manager() as con: - con.execute(qs, field_values) + ds.add_where_clause(WhereClause( + col.db_field_name, + EqualsOperator(), + col.to_database(getattr(self.instance, name)) + )) + self._execute(ds) diff --git a/cqlengine/statements.py b/cqlengine/statements.py new file mode 100644 
index 00000000..549b529d --- /dev/null +++ b/cqlengine/statements.py @@ -0,0 +1,724 @@ +import time +from datetime import datetime, timedelta +from cqlengine.functions import QueryValue +from cqlengine.operators import BaseWhereOperator, InOperator + + +class StatementException(Exception): pass + + +class ValueQuoter(object): + + def __init__(self, value): + self.value = value + + def __unicode__(self): + from cassandra.encoder import cql_quote + if isinstance(self.value, bool): + return 'true' if self.value else 'false' + elif isinstance(self.value, (list, tuple)): + return '[' + ', '.join([cql_quote(v) for v in self.value]) + ']' + elif isinstance(self.value, dict): + return '{' + ', '.join([cql_quote(k) + ':' + cql_quote(v) for k,v in self.value.items()]) + '}' + elif isinstance(self.value, set): + return '{' + ', '.join([cql_quote(v) for v in self.value]) + '}' + return cql_quote(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.value == other.value + return False + + def __str__(self): + return unicode(self).encode('utf-8') + + +class InQuoter(ValueQuoter): + + def __unicode__(self): + from cassandra.encoder import cql_quote + return '(' + ', '.join([cql_quote(v) for v in self.value]) + ')' + + +class BaseClause(object): + + def __init__(self, field, value): + self.field = field + self.value = value + self.context_id = None + + def __unicode__(self): + raise NotImplementedError + + def __str__(self): + return unicode(self).encode('utf-8') + + def __hash__(self): + return hash(self.field) ^ hash(self.value) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.field == other.field and self.value == other.value + return False + + def __ne__(self, other): + return not self.__eq__(other) + + def get_context_size(self): + """ returns the number of entries this clause will add to the query context """ + return 1 + + def set_context_id(self, i): + """ sets the value placeholder that will be used in the query """ + self.context_id = i + + def update_context(self, ctx): + """ updates the query context with this clauses values """ + assert isinstance(ctx, dict) + ctx[str(self.context_id)] = self.value + + +class WhereClause(BaseClause): + """ a single where statement used in queries """ + + def __init__(self, field, operator, value, quote_field=True): + """ + :param field: the (db) field name the clause applies to + :param operator: a BaseWhereOperator instance + :param value: the value (or QueryValue) to compare against + :param quote_field: set False to leave the field unquoted, needed to get the token function rendering properly + """ + if not isinstance(operator, BaseWhereOperator): + raise StatementException( + "operator must be of type {}, got {}".format(BaseWhereOperator, type(operator)) + ) + super(WhereClause, self).__init__(field, value) + self.operator = operator + self.query_value = self.value if isinstance(self.value, QueryValue) else QueryValue(self.value) + self.quote_field = quote_field + + def __unicode__(self): + field = ('"{}"' if self.quote_field else '{}').format(self.field) + return u'{} {} {}'.format(field, self.operator, unicode(self.query_value)) + + def __hash__(self): + return super(WhereClause, self).__hash__() ^ hash(self.operator) + + def __eq__(self, other): + if super(WhereClause, self).__eq__(other): + return self.operator.__class__ == other.operator.__class__ + return False + + def get_context_size(self): + return self.query_value.get_context_size() + + def set_context_id(self, i): + super(WhereClause, self).set_context_id(i) + self.query_value.set_context_id(i) + + def update_context(self, ctx): + if isinstance(self.operator, InOperator): + ctx[str(self.context_id)] = InQuoter(self.value) + else: + self.query_value.update_context(ctx) + +
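+# Illustrative behaviour of the clause above (a sketch, not part of this change;
+# EqualsOperator comes from cqlengine.operators):
+#
+#     wc = WhereClause('a', EqualsOperator(), 5)
+#     wc.set_context_id(0)
+#     unicode(wc)                        # u'"a" = %(0)s'
+#     ctx = {}; wc.update_context(ctx)   # ctx == {'0': 5}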
+class AssignmentClause(BaseClause): + """ a single variable set statement """ + + def __unicode__(self): + return u'"{}" = %({})s'.format(self.field, self.context_id) + + def insert_tuple(self): + return self.field, self.context_id + + +class ContainerUpdateClause(AssignmentClause): + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(ContainerUpdateClause, self).__init__(field, value) + self.previous = previous + self._assignments = None + self._operation = operation + self._analyzed = False + self._column = column + + def _to_database(self, val): + return self._column.to_database(val) if self._column else val + + def _analyze(self): + raise NotImplementedError + + def get_context_size(self): + raise NotImplementedError + + def update_context(self, ctx): + raise NotImplementedError + + +class SetUpdateClause(ContainerUpdateClause): + """ updates a set collection """ + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(SetUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._additions = None + self._removals = None + + def __unicode__(self): + if not self._analyzed: self._analyze() + qs = [] + ctx_id = self.context_id + if self.previous is None and not (self._assignments or self._additions or self._removals): + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + if self._assignments: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._additions: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + ctx_id += 1 + if self._removals: + qs += ['"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + elif self._operation == "add": + self._additions = self.value + elif self._operation == "remove": + self._removals = self.value + elif self.previous is None: + self._assignments = self.value + else: + # partial update time + self._additions = (self.value - self.previous) or None + self._removals = (self.previous - self.value) or None + self._analyzed = True + + def get_context_size(self): + if not self._analyzed: self._analyze() + if self.previous is None and not (self._assignments or self._additions or self._removals): + return 1 + return int(bool(self._assignments)) + int(bool(self._additions)) + int(bool(self._removals)) + + def update_context(self, ctx): + if not self._analyzed: self._analyze() + ctx_id = self.context_id + if self.previous is None and not (self._assignments or self._additions or self._removals): + ctx[str(ctx_id)] = self._to_database({}) + if self._assignments: + ctx[str(ctx_id)] = self._to_database(self._assignments) + ctx_id += 1 + if self._additions: + ctx[str(ctx_id)] = self._to_database(self._additions) + ctx_id += 1 + if self._removals: + ctx[str(ctx_id)] = self._to_database(self._removals) + + +class ListUpdateClause(ContainerUpdateClause): + """ updates a list collection """ + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(ListUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._append = None + self._prepend = None + + def __unicode__(self): + if not self._analyzed: self._analyze() + qs = [] + ctx_id = self.context_id + if self._assignments is not None: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + ctx_id += 1 + + if
self._prepend: + qs += ['"{0}" = %({1})s + "{0}"'.format(self.field, ctx_id)] + ctx_id += 1 + + if self._append: + qs += ['"{0}" = "{0}" + %({1})s'.format(self.field, ctx_id)] + + return ', '.join(qs) + + def get_context_size(self): + if not self._analyzed: self._analyze() + return int(self._assignments is not None) + int(bool(self._append)) + int(bool(self._prepend)) + + def update_context(self, ctx): + if not self._analyzed: self._analyze() + ctx_id = self.context_id + if self._assignments is not None: + ctx[str(ctx_id)] = self._to_database(self._assignments) + ctx_id += 1 + if self._prepend: + # CQL seems to prepend one element at a time, starting + # with the element at idx 0, so we can either reverse + # it here, or have it inserted in reverse + ctx[str(ctx_id)] = self._to_database(list(reversed(self._prepend))) + ctx_id += 1 + if self._append: + ctx[str(ctx_id)] = self._to_database(self._append) + + def _analyze(self): + """ works out the updates to be performed """ + if self.value is None or self.value == self.previous: + pass + + elif self._operation == "append": + self._append = self.value + + elif self._operation == "prepend": + # self.value is a Quoter but we reverse self._prepend later as if + # it's a list, so we have to set it to the underlying list + self._prepend = self.value.value + + elif self.previous is None: + self._assignments = self.value + + elif len(self.value) < len(self.previous): + # if elements have been removed, + # rewrite the whole list + self._assignments = self.value + + elif len(self.previous) == 0: + # if we're updating from an empty + # list, do a complete insert + self._assignments = self.value + else: + + # the max start idx we want to compare + search_space = len(self.value) - max(0, len(self.previous)-1) + + # the size of the sub lists we want to look at + search_size = len(self.previous) + + for i in range(search_space): + #slice boundary + j = i + search_size + sub = self.value[i:j] + idx_cmp = lambda idx: self.previous[idx] == sub[idx] + if idx_cmp(0) and idx_cmp(-1) and self.previous == sub: + self._prepend = self.value[:i] or None + self._append = self.value[j:] or None + break + + # if both append and prepend are still None after looking + # at both lists, an insert statement will be created + if self._prepend is self._append is None: + self._assignments = self.value + + self._analyzed = True + + +class MapUpdateClause(ContainerUpdateClause): + """ updates a map collection """ + + def __init__(self, field, value, operation=None, previous=None, column=None): + super(MapUpdateClause, self).__init__(field, value, operation, previous, column=column) + self._updates = None + + def _analyze(self): + if self._operation == "update": + self._updates = self.value.keys() + else: + if self.previous is None: + self._updates = sorted([k for k, v in self.value.items()]) + else: + self._updates = sorted([k for k, v in self.value.items() if v != self.previous.get(k)]) or None + self._analyzed = True + + def get_context_size(self): + if not self._analyzed: self._analyze() + if self.previous is None and not self._updates: + return 1 + return len(self._updates or []) * 2 + + def update_context(self, ctx): + if not self._analyzed: self._analyze() + ctx_id = self.context_id + if self.previous is None and not self._updates: + ctx[str(ctx_id)] = {} + else: + for key in self._updates or []: + val = self.value.get(key) + ctx[str(ctx_id)] = self._column.key_col.to_database(key) if self._column else key + ctx[str(ctx_id + 1)] = self._column.value_col.to_database(val) if
self._column else val + ctx_id += 2 + + def __unicode__(self): + if not self._analyzed: self._analyze() + qs = [] + + ctx_id = self.context_id + if self.previous is None and not self._updates: + qs += ['"{}" = %({})s'.format(self.field, ctx_id)] + else: + for _ in self._updates or []: + qs += ['"{}"[%({})s] = %({})s'.format(self.field, ctx_id, ctx_id + 1)] + ctx_id += 2 + + return ', '.join(qs) + + +class CounterUpdateClause(ContainerUpdateClause): + + def __init__(self, field, value, previous=None, column=None): + super(CounterUpdateClause, self).__init__(field, value, previous=previous, column=column) + self.previous = self.previous or 0 + + def get_context_size(self): + return 1 + + def update_context(self, ctx): + ctx[str(self.context_id)] = self._to_database(abs(self.value - self.previous)) + + def __unicode__(self): + delta = self.value - self.previous + sign = '-' if delta < 0 else '+' + return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id) + + +class BaseDeleteClause(BaseClause): + pass + + +class FieldDeleteClause(BaseDeleteClause): + """ deletes a field from a row """ + + def __init__(self, field): + super(FieldDeleteClause, self).__init__(field, None) + + def __unicode__(self): + return '"{}"'.format(self.field) + + def update_context(self, ctx): + pass + + def get_context_size(self): + return 0 + + +class MapDeleteClause(BaseDeleteClause): + """ removes keys from a map """ + + def __init__(self, field, value, previous=None): + super(MapDeleteClause, self).__init__(field, value) + self.value = self.value or {} + self.previous = previous or {} + self._analyzed = False + self._removals = None + + def _analyze(self): + self._removals = sorted([k for k in self.previous if k not in self.value]) + self._analyzed = True + + def update_context(self, ctx): + if not self._analyzed: self._analyze() + for idx, key in enumerate(self._removals): + ctx[str(self.context_id + idx)] = key + + def get_context_size(self): + if not self._analyzed: self._analyze() + return len(self._removals) + + def __unicode__(self): + if not self._analyzed: self._analyze() + return ', '.join(['"{}"[%({})s]'.format(self.field, self.context_id + i) for i in range(len(self._removals))]) + + +class BaseCQLStatement(object): + """ The base cql statement class """ + + def __init__(self, table, consistency=None, timestamp=None, where=None): + super(BaseCQLStatement, self).__init__() + self.table = table + self.consistency = consistency + self.context_id = 0 + self.context_counter = self.context_id + self.timestamp = timestamp + + self.where_clauses = [] + for clause in where or []: + self.add_where_clause(clause) + + def add_where_clause(self, clause): + """ + adds a where clause to this statement + :param clause: the clause to add + :type clause: WhereClause + """ + if not isinstance(clause, WhereClause): + raise StatementException("only instances of WhereClause can be added to statements") + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.where_clauses.append(clause) + + def get_context(self): + """ + returns the context dict for this statement + :rtype: dict + """ + ctx = {} + for clause in self.where_clauses or []: + clause.update_context(ctx) + return ctx + + def get_context_size(self): + return len(self.get_context()) + + def update_context_id(self, i): + self.context_id = i + self.context_counter = self.context_id + for clause in self.where_clauses: + clause.set_context_id(self.context_counter) + self.context_counter +=
clause.get_context_size() + + @property + def timestamp_normalized(self): + """ + we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta + :return: + """ + if not self.timestamp: + return None + + if isinstance(self.timestamp, (int, long)): + return self.timestamp + + if isinstance(self.timestamp, timedelta): + tmp = datetime.now() + self.timestamp + else: + tmp = self.timestamp + + return long(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond) + + def __unicode__(self): + raise NotImplementedError + + def __str__(self): + return unicode(self).encode('utf-8') + + def __repr__(self): + return self.__unicode__() + + @property + def _where(self): + return 'WHERE {}'.format(' AND '.join([unicode(c) for c in self.where_clauses])) + + +class SelectStatement(BaseCQLStatement): + """ a cql select statement """ + + def __init__(self, + table, + fields=None, + count=False, + consistency=None, + where=None, + order_by=None, + limit=None, + allow_filtering=False): + + """ + :param where: a list of cqlengine.statements.WhereClause instances to filter on + """ + super(SelectStatement, self).__init__( + table, + consistency=consistency, + where=where + ) + + self.fields = [fields] if isinstance(fields, basestring) else (fields or []) + self.count = count + self.order_by = [order_by] if isinstance(order_by, basestring) else order_by + self.limit = limit + self.allow_filtering = allow_filtering + + def __unicode__(self): + qs = ['SELECT'] + if self.count: + qs += ['COUNT(*)'] + else: + qs += [', '.join(['"{}"'.format(f) for f in self.fields]) if self.fields else '*'] + qs += ['FROM', self.table] + + if self.where_clauses: + qs += [self._where] + + if self.order_by and not self.count: + qs += ['ORDER BY {}'.format(', '.join(unicode(o) for o in self.order_by))] + + if self.limit: + qs += ['LIMIT {}'.format(self.limit)] + + if self.allow_filtering: + qs += ['ALLOW FILTERING'] + + return ' '.join(qs)
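+# A quick sketch of the statement rendering contract (illustrative only;
+# operator classes come from cqlengine.operators):
+#
+#     ss = SelectStatement('table', fields=['f1'],
+#                          where=[WhereClause('f2', EqualsOperator(), 'value')])
+#     unicode(ss)       # u'SELECT "f1" FROM table WHERE "f2" = %(0)s'
+#     ss.get_context()  # {'0': 'value'}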
+ + +class AssignmentStatement(BaseCQLStatement): + """ value assignment statements """ + + def __init__(self, + table, + assignments=None, + consistency=None, + where=None, + ttl=None, + timestamp=None): + super(AssignmentStatement, self).__init__( + table, + consistency=consistency, + where=where, + ) + self.ttl = ttl + self.timestamp = timestamp + + # add assignments + self.assignments = [] + for assignment in assignments or []: + self.add_assignment_clause(assignment) + + def update_context_id(self, i): + super(AssignmentStatement, self).update_context_id(i) + for assignment in self.assignments: + assignment.set_context_id(self.context_counter) + self.context_counter += assignment.get_context_size() + + def add_assignment_clause(self, clause): + """ + adds an assignment clause to this statement + :param clause: the clause to add + :type clause: AssignmentClause + """ + if not isinstance(clause, AssignmentClause): + raise StatementException("only instances of AssignmentClause can be added to statements") + clause.set_context_id(self.context_counter) + self.context_counter += clause.get_context_size() + self.assignments.append(clause) + + @property + def is_empty(self): + return len(self.assignments) == 0 + + def get_context(self): + ctx = super(AssignmentStatement, self).get_context() + for clause in self.assignments: + clause.update_context(ctx) + return ctx + + +class InsertStatement(AssignmentStatement): + """ a cql insert statement """ + + def add_where_clause(self, clause): + raise StatementException("Cannot add where clauses to insert statements") + + def __unicode__(self): + qs = ['INSERT INTO {}'.format(self.table)] + + # get column names and context placeholders + fields = [a.insert_tuple() for a in self.assignments] + columns, values = zip(*fields) + + qs += ["({})".format(', '.join(['"{}"'.format(c) for c in columns]))] + qs += ['VALUES'] + qs += ["({})".format(', '.join(['%({})s'.format(v) for v in values]))] + + # TTL and TIMESTAMP share a single USING clause + using_options = [] + + if self.ttl: + using_options += ["TTL {}".format(self.ttl)] + + if self.timestamp: + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] + + return ' '.join(qs) + + +class UpdateStatement(AssignmentStatement): + """ a cql update statement """ + + def __unicode__(self): + qs = ['UPDATE', self.table] + + using_options = [] + + if self.ttl: + using_options += ["TTL {}".format(self.ttl)] + + if self.timestamp: + using_options += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if using_options: + qs += ["USING {}".format(" AND ".join(using_options))] + + qs += ['SET'] + qs += [', '.join([unicode(c) for c in self.assignments])] + + if self.where_clauses: + qs += [self._where] + + return ' '.join(qs) + + +class DeleteStatement(BaseCQLStatement): + """ a cql delete statement """ + + def __init__(self, table, fields=None, consistency=None, where=None, timestamp=None): + super(DeleteStatement, self).__init__( + table, + consistency=consistency, + where=where, + timestamp=timestamp + ) + self.fields = [] + if isinstance(fields, basestring): + fields = [fields] + for field in fields or []: + self.add_field(field) + + def update_context_id(self, i): + super(DeleteStatement, self).update_context_id(i) + for field in self.fields: + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + + def get_context(self): + ctx = super(DeleteStatement, self).get_context() + for field in self.fields: + field.update_context(ctx) + return ctx + + def add_field(self, field): + if isinstance(field, basestring): + field = FieldDeleteClause(field) + if not isinstance(field, BaseClause): + raise StatementException("only instances of BaseClause can be added to delete statements") + field.set_context_id(self.context_counter) + self.context_counter += field.get_context_size() + self.fields.append(field) + + def __unicode__(self): + qs = ['DELETE'] + if self.fields: + qs += [', '.join(['{}'.format(f) for f in self.fields])] + qs += ['FROM', self.table] + + delete_option = [] + + if self.timestamp: + delete_option += ["TIMESTAMP {}".format(self.timestamp_normalized)] + + if delete_option: + qs += ["USING {}".format(" AND ".join(delete_option))] + + if self.where_clauses: + qs += [self._where] + + return ' '.join(qs) + diff --git a/cqlengine/tests/base.py b/cqlengine/tests/base.py index 64e6a3c2..3250de14 100644 --- a/cqlengine/tests/base.py +++ b/cqlengine/tests/base.py @@ -1,18 +1,33 @@ from unittest import TestCase from cqlengine import connection +import os +from cqlengine.connection import get_session + + +if os.environ.get('CASSANDRA_TEST_HOST'): + CASSANDRA_TEST_HOST = os.environ['CASSANDRA_TEST_HOST'] +else: + CASSANDRA_TEST_HOST = 'localhost' + +protocol_version = int(os.environ.get("CASSANDRA_PROTOCOL_VERSION", 2)) + +connection.setup([CASSANDRA_TEST_HOST], protocol_version=protocol_version, default_keyspace='cqlengine_test') class BaseCassEngTestCase(TestCase): - @classmethod - def setUpClass(cls): - super(BaseCassEngTestCase, cls).setUpClass() - if not connection._hosts: - connection.setup(['localhost:9160'], default_keyspace='cqlengine_test') + # @classmethod + # def setUpClass(cls): + #
super(BaseCassEngTestCase, cls).setUpClass() + session = None + + def setUp(self): + self.session = get_session() + super(BaseCassEngTestCase, self).setUp() def assertHasAttr(self, obj, attr): - self.assertTrue(hasattr(obj, attr), + self.assertTrue(hasattr(obj, attr), "{} doesn't have attribute: {}".format(obj, attr)) def assertNotHasAttr(self, obj, attr): - self.assertFalse(hasattr(obj, attr), + self.assertFalse(hasattr(obj, attr), "{} shouldn't have the attribute: {}".format(obj, attr)) diff --git a/cqlengine/tests/columns/test_container_columns.py b/cqlengine/tests/columns/test_container_columns.py index 3533aad3..0b756f33 100644 --- a/cqlengine/tests/columns/test_container_columns.py +++ b/cqlengine/tests/columns/test_container_columns.py @@ -1,32 +1,91 @@ from datetime import datetime, timedelta +import json from uuid import uuid4 from cqlengine import Model, ValidationError from cqlengine import columns -from cqlengine.management import create_table, delete_table +from cqlengine.management import sync_table, drop_table from cqlengine.tests.base import BaseCassEngTestCase + class TestSetModel(Model): - partition = columns.UUID(primary_key=True, default=uuid4) - int_set = columns.Set(columns.Integer, required=False) - text_set = columns.Set(columns.Text, required=False) + __keyspace__ = 'test' + + partition = columns.UUID(primary_key=True, default=uuid4) + int_set = columns.Set(columns.Integer, required=False) + text_set = columns.Set(columns.Text, required=False) + + +class JsonTestColumn(columns.Column): + + db_type = 'text' + + def to_python(self, value): + if value is None: return + if isinstance(value, basestring): + return json.loads(value) + else: + return value + + def to_database(self, value): + if value is None: return + return json.dumps(value) + class TestSetColumn(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestSetColumn, cls).setUpClass() - delete_table(TestSetModel) - create_table(TestSetModel) + drop_table(TestSetModel) + sync_table(TestSetModel) @classmethod def tearDownClass(cls): super(TestSetColumn, cls).tearDownClass() - delete_table(TestSetModel) + drop_table(TestSetModel) + + def test_add_none_fails(self): + with self.assertRaises(ValidationError): + m = TestSetModel.create(int_set=set([None])) + + def test_empty_set_initial(self): + """ + tests that sets are set() by default, should never be none + :return: + """ + m = TestSetModel.create() + m.int_set.add(5) + m.save() + + def test_deleting_last_item_should_succeed(self): + m = TestSetModel.create() + m.int_set.add(5) + m.save() + m.int_set.remove(5) + m.save() + + m = TestSetModel.get(partition=m.partition) + self.assertNotIn(5, m.int_set) + + def test_blind_deleting_last_item_should_succeed(self): + m = TestSetModel.create() + m.int_set.add(5) + m.save() + + TestSetModel.objects(partition=m.partition).update(int_set=set()) + + m = TestSetModel.get(partition=m.partition) + self.assertNotIn(5, m.int_set) + + def test_empty_set_retrieval(self): + m = TestSetModel.create() + m2 = TestSetModel.get(partition=m.partition) + m2.int_set.add(3) def test_io_success(self): """ Tests that a basic usage works as expected """ - m1 = TestSetModel.create(int_set={1,2}, text_set={'kai', 'andreas'}) + m1 = TestSetModel.create(int_set={1, 2}, text_set={'kai', 'andreas'}) m2 = TestSetModel.get(partition=m1.partition) assert isinstance(m2.int_set, set) @@ -45,54 +104,96 @@ class TestSetColumn(BaseCassEngTestCase): with self.assertRaises(ValidationError): TestSetModel.create(int_set={'string', True}, text_set={1, 
3.0}) + def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception. + """ + TestSetModel.create(text_set={str(uuid4()) for i in range(65535)}) + with self.assertRaises(ValidationError): + TestSetModel.create(text_set={str(uuid4()) for i in range(65536)}) + def test_partial_updates(self): """ Tests that partial updates work as expected """ - m1 = TestSetModel.create(int_set={1,2,3,4}) + m1 = TestSetModel.create(int_set={1, 2, 3, 4}) m1.int_set.add(5) m1.int_set.remove(1) - assert m1.int_set == {2,3,4,5} + assert m1.int_set == {2, 3, 4, 5} m1.save() - m2 = TestSetModel.get(partition=m1.partition) - assert m2.int_set == {2,3,4,5} + m2 = TestSetModel.get(partition=m1.partition) + assert m2.int_set == {2, 3, 4, 5} - def test_partial_update_creation(self): + def test_instantiation_with_column_class(self): """ - Tests that proper update statements are created for a partial set update - :return: + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor """ - ctx = {} - col = columns.Set(columns.Integer, db_field="TEST") - statements = col.get_update_statement({1,2,3,4}, {2,3,4,5}, ctx) + column = columns.Set(columns.Text) + assert isinstance(column.value_col, columns.Text) + + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.Set(columns.Text(min_length=100)) + assert isinstance(column.value_col, columns.Text) + + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.Set(JsonTestColumn) + val = {1, 2, 3} + db_val = column.to_database(val) + assert db_val.value == {json.dumps(v) for v in val} + py_val = column.to_python(db_val.value) + assert py_val == val + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + # create a row with set data + TestSetModel.create(partition=pkey, int_set={3, 4}) + # create another with no set data + TestSetModel.create(partition=pkey) + + m = TestSetModel.get(partition=pkey) + self.assertEqual(m.int_set, {3, 4}) - assert len([v for v in ctx.values() if {1} == v.value]) == 1 - assert len([v for v in ctx.values() if {5} == v.value]) == 1 - assert len([s for s in statements if '"TEST" = "TEST" -' in s]) == 1 - assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1 class TestListModel(Model): - partition = columns.UUID(primary_key=True, default=uuid4) - int_list = columns.List(columns.Integer, required=False) - text_list = columns.List(columns.Text, required=False) + __keyspace__ = 'test' + + partition = columns.UUID(primary_key=True, default=uuid4) + int_list = columns.List(columns.Integer, required=False) + text_list = columns.List(columns.Text, required=False) + class TestListColumn(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestListColumn, cls).setUpClass() - delete_table(TestListModel) - create_table(TestListModel) + drop_table(TestListModel) + sync_table(TestListModel) @classmethod def tearDownClass(cls): super(TestListColumn, cls).tearDownClass() - delete_table(TestListModel) + drop_table(TestListModel) + + def test_initial(self): + tmp = TestListModel.create() + tmp.int_list.append(1) + + def test_initial_retrieve(self): + tmp = TestListModel.create() + tmp2 = TestListModel.get(partition=tmp.partition) + tmp2.int_list.append(1) def test_io_success(self): """ Tests that a
basic usage works as expected """ - m1 = TestListModel.create(int_list=[1,2], text_list=['kai', 'andreas']) + m1 = TestListModel.create(int_list=[1, 2], text_list=['kai', 'andreas']) m2 = TestListModel.get(partition=m1.partition) assert isinstance(m2.int_list, list) @@ -114,6 +215,14 @@ class TestListColumn(BaseCassEngTestCase): with self.assertRaises(ValidationError): TestListModel.create(int_list=['string', True], text_list=[1, 3.0]) + def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception. + """ + TestListModel.create(text_list=[str(uuid4()) for i in range(65535)]) + with self.assertRaises(ValidationError): + TestListModel.create(text_list=[str(uuid4()) for i in range(65536)]) + def test_partial_updates(self): """ Tests that partial updates work as expected """ final = range(10) @@ -123,43 +232,128 @@ class TestListColumn(BaseCassEngTestCase): m1.int_list = final m1.save() - m2 = TestListModel.get(partition=m1.partition) + m2 = TestListModel.get(partition=m1.partition) assert list(m2.int_list) == final - def test_partial_update_creation(self): + def test_instantiation_with_column_class(self): """ - Tests that proper update statements are created for a partial list update - :return: + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor """ - final = range(10) - initial = final[3:7] + column = columns.List(columns.Text) + assert isinstance(column.value_col, columns.Text) - ctx = {} - col = columns.List(columns.Integer, db_field="TEST") - statements = col.get_update_statement(final, initial, ctx) + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.List(columns.Text(min_length=100)) + assert isinstance(column.value_col, columns.Text) - assert len([v for v in ctx.values() if [2,1,0] == v.value]) == 1 - assert len([v for v in ctx.values() if [7,8,9] == v.value]) == 1 - assert len([s for s in statements if '"TEST" = "TEST" +' in s]) == 1 - assert len([s for s in statements if '+ "TEST"' in s]) == 1 + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.List(JsonTestColumn) + val = [1, 2, 3] + db_val = column.to_database(val) + assert db_val.value == [json.dumps(v) for v in val] + py_val = column.to_python(db_val.value) + assert py_val == val + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + # create a row with list data + TestListModel.create(partition=pkey, int_list=[1,2,3,4]) + # create another with no list data + TestListModel.create(partition=pkey) + + m = TestListModel.get(partition=pkey) + self.assertEqual(m.int_list, [1,2,3,4]) + + def test_remove_entry_works(self): + pkey = uuid4() + tmp = TestListModel.create(partition=pkey, int_list=[1,2]) + tmp.int_list.pop() + tmp.update() + tmp = TestListModel.get(partition=pkey) + self.assertEqual(tmp.int_list, [1]) + + def test_update_from_non_empty_to_empty(self): + pkey = uuid4() + tmp = TestListModel.create(partition=pkey, int_list=[1,2]) + tmp.int_list = [] + tmp.update() + + tmp = TestListModel.get(partition=pkey) + self.assertEqual(tmp.int_list, []) + + def test_insert_none(self): + pkey = uuid4() + with self.assertRaises(ValidationError): + TestListModel.create(partition=pkey, int_list=[None]) + + def test_blind_list_updates_from_none(self): + """ Tests that
updates from None work as expected """ + m = TestListModel.create(int_list=None) + expected = [1, 2] + m.int_list = expected + m.save() + + m2 = TestListModel.get(partition=m.partition) + assert m2.int_list == expected + + TestListModel.objects(partition=m.partition).update(int_list=[]) + + m3 = TestListModel.get(partition=m.partition) + assert m3.int_list == [] class TestMapModel(Model): - partition = columns.UUID(primary_key=True, default=uuid4) - int_map = columns.Map(columns.Integer, columns.UUID, required=False) - text_map = columns.Map(columns.Text, columns.DateTime, required=False) + __keyspace__ = 'test' + + partition = columns.UUID(primary_key=True, default=uuid4) + int_map = columns.Map(columns.Integer, columns.UUID, required=False) + text_map = columns.Map(columns.Text, columns.DateTime, required=False) + class TestMapColumn(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestMapColumn, cls).setUpClass() - delete_table(TestMapModel) - create_table(TestMapModel) + drop_table(TestMapModel) + sync_table(TestMapModel) @classmethod def tearDownClass(cls): super(TestMapColumn, cls).tearDownClass() - delete_table(TestMapModel) + drop_table(TestMapModel) + + def test_empty_default(self): + tmp = TestMapModel.create() + tmp.int_map['blah'] = 1 + + def test_add_none_as_map_key(self): + with self.assertRaises(ValidationError): + TestMapModel.create(int_map={None:1}) + + def test_add_none_as_map_value(self): + with self.assertRaises(ValidationError): + TestMapModel.create(int_map={1: None}) + + def test_empty_retrieve(self): + tmp = TestMapModel.create() + tmp2 = TestMapModel.get(partition=tmp.partition) + tmp2.int_map['blah'] = 1 + + def test_remove_last_entry_works(self): + tmp = TestMapModel.create() + tmp.text_map["blah"] = datetime.now() + tmp.save() + del tmp.text_map["blah"] + tmp.save() + + tmp = TestMapModel.get(partition=tmp.partition) + self.assertNotIn("blah", tmp.text_map) def test_io_success(self): """ Tests that a basic usage works as expected """ @@ -167,7 +361,7 @@ class TestMapColumn(BaseCassEngTestCase): k2 = uuid4() now = datetime.now() then = now + timedelta(days=1) - m1 = TestMapModel.create(int_map={1:k1,2:k2}, text_map={'now':now, 'then':then}) + m1 = TestMapModel.create(int_map={1: k1, 2: k2}, text_map={'now': now, 'then': then}) m2 = TestMapModel.get(partition=m1.partition) assert isinstance(m2.int_map, dict) @@ -188,7 +382,15 @@ class TestMapColumn(BaseCassEngTestCase): Tests that attempting to use the wrong types will raise an exception """ with self.assertRaises(ValidationError): - TestMapModel.create(int_map={'key':2,uuid4():'val'}, text_map={2:5}) + TestMapModel.create(int_map={'key': 2, uuid4(): 'val'}, text_map={2: 5}) + + def test_element_count_validation(self): + """ + Tests that big collections are detected and raise an exception.
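+        (Cassandra caps non-frozen collections at 65535 elements, which is why
+        the boundary sits at 65535/65536 below.)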
+ """ + TestMapModel.create(text_map={str(uuid4()): i for i in range(65535)}) + with self.assertRaises(ValidationError): + TestMapModel.create(text_map={str(uuid4()): i for i in range(65536)}) def test_partial_updates(self): """ Tests that partial udpates work as expected """ @@ -199,36 +401,93 @@ class TestMapColumn(BaseCassEngTestCase): earlier = early - timedelta(minutes=30) later = now + timedelta(minutes=30) - initial = {'now':now, 'early':earlier} - final = {'later':later, 'early':early} + initial = {'now': now, 'early': earlier} + final = {'later': later, 'early': early} m1 = TestMapModel.create(text_map=initial) m1.text_map = final m1.save() - m2 = TestMapModel.get(partition=m1.partition) + m2 = TestMapModel.get(partition=m1.partition) assert m2.text_map == final def test_updates_from_none(self): """ Tests that updates from None work as expected """ m = TestMapModel.create(int_map=None) - expected = {1:uuid4()} + expected = {1: uuid4()} m.int_map = expected m.save() m2 = TestMapModel.get(partition=m.partition) assert m2.int_map == expected + m2.int_map = None + m2.save() + m3 = TestMapModel.get(partition=m.partition) + assert m3.int_map != expected + + def test_blind_updates_from_none(self): + """ Tests that updates from None work as expected """ + m = TestMapModel.create(int_map=None) + expected = {1: uuid4()} + m.int_map = expected + m.save() + + m2 = TestMapModel.get(partition=m.partition) + assert m2.int_map == expected + + TestMapModel.objects(partition=m.partition).update(int_map={}) + + m3 = TestMapModel.get(partition=m.partition) + assert m3.int_map != expected def test_updates_to_none(self): """ Tests that setting the field to None works as expected """ - m = TestMapModel.create(int_map={1:uuid4()}) + m = TestMapModel.create(int_map={1: uuid4()}) m.int_map = None m.save() m2 = TestMapModel.get(partition=m.partition) - assert m2.int_map is None + assert m2.int_map == {} + + def test_instantiation_with_column_class(self): + """ + Tests that columns instantiated with a column class work properly + and that the class is instantiated in the constructor + """ + column = columns.Map(columns.Text, columns.Integer) + assert isinstance(column.key_col, columns.Text) + assert isinstance(column.value_col, columns.Integer) + + def test_instantiation_with_column_instance(self): + """ + Tests that columns instantiated with a column instance work properly + """ + column = columns.Map(columns.Text(min_length=100), columns.Integer()) + assert isinstance(column.key_col, columns.Text) + assert isinstance(column.value_col, columns.Integer) + + def test_to_python(self): + """ Tests that to_python of value column is called """ + column = columns.Map(JsonTestColumn, JsonTestColumn) + val = {1: 2, 3: 4, 5: 6} + db_val = column.to_database(val) + assert db_val.value == {json.dumps(k):json.dumps(v) for k,v in val.items()} + py_val = column.to_python(db_val.value) + assert py_val == val + + def test_default_empty_container_saving(self): + """ tests that the default empty container is not saved if it hasn't been updated """ + pkey = uuid4() + tmap = {1: uuid4(), 2: uuid4()} + # create a row with set data + TestMapModel.create(partition=pkey, int_map=tmap) + # create another with no set data + TestMapModel.create(partition=pkey) + + m = TestMapModel.get(partition=pkey) + self.assertEqual(m.int_map, tmap) # def test_partial_update_creation(self): # """ @@ -246,3 +505,27 @@ class TestMapColumn(BaseCassEngTestCase): # assert len([v for v in ctx.values() if [7,8,9] == v.value]) == 1 # assert len([s for s in 
statements if '"TEST" = "TEST" +' in s]) == 1 # assert len([s for s in statements if '+ "TEST"' in s]) == 1 + + +class TestCamelMapModel(Model): + __keyspace__ = 'test' + + partition = columns.UUID(primary_key=True, default=uuid4) + camelMap = columns.Map(columns.Text, columns.Integer, required=False) + + +class TestCamelMapColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestCamelMapColumn, cls).setUpClass() + drop_table(TestCamelMapModel) + sync_table(TestCamelMapModel) + + @classmethod + def tearDownClass(cls): + super(TestCamelMapColumn, cls).tearDownClass() + drop_table(TestCamelMapModel) + + def test_camelcase_column(self): + TestCamelMapModel.create(partition=None, camelMap={'blah': 1}) diff --git a/cqlengine/tests/columns/test_counter_column.py b/cqlengine/tests/columns/test_counter_column.py new file mode 100644 index 00000000..8e8b2361 --- /dev/null +++ b/cqlengine/tests/columns/test_counter_column.py @@ -0,0 +1,94 @@ +from uuid import uuid4 + +from cqlengine import Model +from cqlengine import columns +from cqlengine.management import sync_table, drop_table +from cqlengine.models import ModelDefinitionException +from cqlengine.tests.base import BaseCassEngTestCase + + +class TestCounterModel(Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + counter = columns.Counter() + + +class TestClassConstruction(BaseCassEngTestCase): + + def test_defining_a_non_counter_column_fails(self): + """ Tests that defining a non counter column field in a model with a counter column fails """ + with self.assertRaises(ModelDefinitionException): + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + counter = columns.Counter() + text = columns.Text() + + + def test_defining_a_primary_key_counter_column_fails(self): + """ Tests that defining primary keys on counter columns fails """ + with self.assertRaises(TypeError): + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Counter(primary_key=True) + counter = columns.Counter() + + # force it + with self.assertRaises(ModelDefinitionException): + class model(Model): + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.Counter() + cluster.primary_key = True + counter = columns.Counter() + + +class TestCounterColumn(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestCounterColumn, cls).setUpClass() + drop_table(TestCounterModel) + sync_table(TestCounterModel) + + @classmethod + def tearDownClass(cls): + super(TestCounterColumn, cls).tearDownClass() + drop_table(TestCounterModel) + + def test_updates(self): + """ Tests that counter updates work as intended """ + instance = TestCounterModel.create() + instance.counter += 5 + instance.save() + + actual = TestCounterModel.get(partition=instance.partition) + assert actual.counter == 5 + + def test_concurrent_updates(self): + """ Tests that updates from multiple queries reach the correct value """ + instance = TestCounterModel.create() + new1 = TestCounterModel.get(partition=instance.partition) + new2 = TestCounterModel.get(partition=instance.partition) + + new1.counter += 5 + new1.save() + new2.counter += 5 + new2.save() + + actual = TestCounterModel.get(partition=instance.partition) + assert actual.counter == 10 + + def test_update_from_none(self): + """ Tests that updating from None uses a create statement """ + instance = TestCounterModel() + instance.counter
+= 1 + instance.save() + + new = TestCounterModel.get(partition=instance.partition) + assert new.counter == 1 + + def test_new_instance_defaults_to_zero(self): + """ Tests that instantiating a new model instance will set the counter column to zero """ + instance = TestCounterModel() + assert instance.counter == 0 + diff --git a/cqlengine/tests/columns/test_validation.py b/cqlengine/tests/columns/test_validation.py index e323c465..0df006e8 100644 --- a/cqlengine/tests/columns/test_validation.py +++ b/cqlengine/tests/columns/test_validation.py @@ -3,7 +3,10 @@ from datetime import datetime, timedelta from datetime import date from datetime import tzinfo from decimal import Decimal as D +from unittest import TestCase +from uuid import uuid4, uuid1 from cqlengine import ValidationError +from cqlengine.connection import execute from cqlengine.tests.base import BaseCassEngTestCase @@ -12,6 +15,8 @@ from cqlengine.columns import Bytes from cqlengine.columns import Ascii from cqlengine.columns import Text from cqlengine.columns import Integer +from cqlengine.columns import BigInt +from cqlengine.columns import VarInt from cqlengine.columns import DateTime from cqlengine.columns import Date from cqlengine.columns import UUID @@ -19,23 +24,27 @@ from cqlengine.columns import Boolean from cqlengine.columns import Float from cqlengine.columns import Decimal -from cqlengine.management import create_table, delete_table +from cqlengine.management import sync_table, drop_table from cqlengine.models import Model +import sys + + class TestDatetime(BaseCassEngTestCase): class DatetimeTest(Model): + __keyspace__ = 'test' test_id = Integer(primary_key=True) created_at = DateTime() @classmethod def setUpClass(cls): super(TestDatetime, cls).setUpClass() - create_table(cls.DatetimeTest) + sync_table(cls.DatetimeTest) @classmethod def tearDownClass(cls): super(TestDatetime, cls).tearDownClass() - delete_table(cls.DatetimeTest) + drop_table(cls.DatetimeTest) def test_datetime_io(self): now = datetime.now() @@ -55,21 +64,78 @@ class TestDatetime(BaseCassEngTestCase): dt2 = self.DatetimeTest.objects(test_id=0).first() assert dt2.created_at.timetuple()[:6] == (now + timedelta(hours=1)).timetuple()[:6] + def test_datetime_date_support(self): + today = date.today() + self.DatetimeTest.objects.create(test_id=0, created_at=today) + dt2 = self.DatetimeTest.objects(test_id=0).first() + assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat() + + def test_datetime_none(self): + dt = self.DatetimeTest.objects.create(test_id=1, created_at=None) + dt2 = self.DatetimeTest.objects(test_id=1).first() + assert dt2.created_at is None + + dts = self.DatetimeTest.objects.filter(test_id=1).values_list('created_at') + assert dts[0][0] is None + + +class TestBoolDefault(BaseCassEngTestCase): + class BoolDefaultValueTest(Model): + __keyspace__ = 'test' + test_id = Integer(primary_key=True) + stuff = Boolean(default=True) + + @classmethod + def setUpClass(cls): + super(TestBoolDefault, cls).setUpClass() + sync_table(cls.BoolDefaultValueTest) + + def test_default_is_set(self): + tmp = self.BoolDefaultValueTest.create(test_id=1) + self.assertEqual(True, tmp.stuff) + tmp2 = self.BoolDefaultValueTest.get(test_id=1) + self.assertEqual(True, tmp2.stuff) + + + +class TestVarInt(BaseCassEngTestCase): + class VarIntTest(Model): + __keyspace__ = 'test' + test_id = Integer(primary_key=True) + bignum = VarInt(primary_key=True) + + @classmethod + def setUpClass(cls): + super(TestVarInt, cls).setUpClass() + 
sync_table(cls.VarIntTest) + + @classmethod + def tearDownClass(cls): + super(TestVarInt, cls).tearDownClass() + drop_table(cls.VarIntTest) + + def test_varint_io(self): + long_int = sys.maxint + 1 + int1 = self.VarIntTest.objects.create(test_id=0, bignum=long_int) + int2 = self.VarIntTest.objects(test_id=0).first() + assert int1.bignum == int2.bignum + class TestDate(BaseCassEngTestCase): class DateTest(Model): + __keyspace__ = 'test' test_id = Integer(primary_key=True) created_at = Date() @classmethod def setUpClass(cls): super(TestDate, cls).setUpClass() - create_table(cls.DateTest) + sync_table(cls.DateTest) @classmethod def tearDownClass(cls): super(TestDate, cls).tearDownClass() - delete_table(cls.DateTest) + drop_table(cls.DateTest) def test_date_io(self): today = date.today() @@ -85,23 +151,32 @@ assert isinstance(dt2.created_at, date) assert dt2.created_at.isoformat() == now.date().isoformat() + def test_date_none(self): + self.DateTest.objects.create(test_id=1, created_at=None) + dt2 = self.DateTest.objects(test_id=1).first() + assert dt2.created_at is None + + dts = self.DateTest.objects(test_id=1).values_list('created_at') + assert dts[0][0] is None + class TestDecimal(BaseCassEngTestCase): class DecimalTest(Model): + __keyspace__ = 'test' test_id = Integer(primary_key=True) dec_val = Decimal() @classmethod def setUpClass(cls): super(TestDecimal, cls).setUpClass() - create_table(cls.DecimalTest) + sync_table(cls.DecimalTest) @classmethod def tearDownClass(cls): super(TestDecimal, cls).tearDownClass() - delete_table(cls.DecimalTest) + drop_table(cls.DecimalTest) - def test_datetime_io(self): + def test_decimal_io(self): dt = self.DecimalTest.objects.create(test_id=0, dec_val=D('0.00')) dt2 = self.DecimalTest.objects(test_id=0).first() assert dt2.dec_val == dt.dec_val @@ -110,22 +185,55 @@ dt2 = self.DecimalTest.objects(test_id=0).first() assert dt2.dec_val == D('5') +class TestUUID(BaseCassEngTestCase): + class UUIDTest(Model): + __keyspace__ = 'test' + test_id = Integer(primary_key=True) + a_uuid = UUID(default=uuid4) + + @classmethod + def setUpClass(cls): + super(TestUUID, cls).setUpClass() + sync_table(cls.UUIDTest) + + @classmethod + def tearDownClass(cls): + super(TestUUID, cls).tearDownClass() + drop_table(cls.UUIDTest) + + def test_uuid_str_with_dashes(self): + a_uuid = uuid4() + t0 = self.UUIDTest.create(test_id=0, a_uuid=str(a_uuid)) + t1 = self.UUIDTest.get(test_id=0) + assert a_uuid == t1.a_uuid + + def test_uuid_str_no_dashes(self): + a_uuid = uuid4() + t0 = self.UUIDTest.create(test_id=1, a_uuid=a_uuid.hex) + t1 = self.UUIDTest.get(test_id=1) + assert a_uuid == t1.a_uuid + class TestTimeUUID(BaseCassEngTestCase): class TimeUUIDTest(Model): + __keyspace__ = 'test' test_id = Integer(primary_key=True) - timeuuid = TimeUUID() + timeuuid = TimeUUID(default=uuid1) @classmethod def setUpClass(cls): super(TestTimeUUID, cls).setUpClass() - create_table(cls.TimeUUIDTest) + sync_table(cls.TimeUUIDTest) @classmethod def tearDownClass(cls): super(TestTimeUUID, cls).tearDownClass() - delete_table(cls.TimeUUIDTest) + drop_table(cls.TimeUUIDTest) def test_timeuuid_io(self): + """ + Tests that a default timeuuid is generated on create and survives the round-trip + """ t0 = self.TimeUUIDTest.create(test_id=0) t1 = self.TimeUUIDTest.get(test_id=0)
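The column defaults above pass the callables `uuid4`/`uuid1` rather than their results: a value default is evaluated once, when the class body executes, while a callable default is invoked per instance (the counter test earlier relies on exactly this per-instance default behavior). A hypothetical sketch of the difference, using model names invented for illustration:

```python
from uuid import uuid4

from cqlengine import columns
from cqlengine.models import Model


class FreshDefault(Model):
    # hypothetical model: a callable default is invoked for each new instance
    __keyspace__ = 'test'
    id = columns.UUID(primary_key=True, default=uuid4)


class SharedDefault(Model):
    # hypothetical model: uuid4() ran once at class-definition time,
    # so every instance starts out with the same UUID
    __keyspace__ = 'test'
    id = columns.UUID(primary_key=True, default=uuid4())


assert FreshDefault().id != FreshDefault().id
assert SharedDefault().id == SharedDefault().id
```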
@@ -133,22 +241,32 @@ class TestInteger(BaseCassEngTestCase): class IntegerTest(Model): - test_id = UUID(primary_key=True) - value = Integer(default=0) + __keyspace__ = 'test' + test_id = UUID(primary_key=True, default=lambda:uuid4()) + value = Integer(default=0, required=True) def test_default_zero_fields_validate(self): """ Tests that integer columns with a default value of 0 validate """ it = self.IntegerTest() it.validate() +class TestBigInt(BaseCassEngTestCase): + class BigIntTest(Model): + __keyspace__ = 'test' + test_id = UUID(primary_key=True, default=lambda:uuid4()) + value = BigInt(default=0, required=True) + + def test_default_zero_fields_validate(self): + """ Tests that bigint columns with a default value of 0 validate """ + it = self.BigIntTest() + it.validate() + class TestText(BaseCassEngTestCase): def test_min_length(self): # blank values are accepted by default col = Text() - - with self.assertRaises(ValidationError): - col.validate('') + col.validate('') col.validate('b') @@ -174,7 +292,7 @@ class TestText(BaseCassEngTestCase): Text().validate(bytearray('bytearray')) with self.assertRaises(ValidationError): - Text().validate(None) + Text(required=True).validate(None) with self.assertRaises(ValidationError): Text().validate(5) @@ -182,15 +300,48 @@ with self.assertRaises(ValidationError): Text().validate(True) + def test_non_required_validation(self): + """ Tests that validation is ok on none and blank values if required is False """ + Text().validate('') + Text().validate(None) +class TestExtraFieldsRaiseException(BaseCassEngTestCase): + class TestModel(Model): + __keyspace__ = 'test' + id = UUID(primary_key=True, default=uuid4) + def test_extra_field(self): + with self.assertRaises(ValidationError): + self.TestModel.create(bacon=5000) +class TestPythonDoesntDieWhenExtraFieldIsInCassandra(BaseCassEngTestCase): + class TestModel(Model): + __keyspace__ = 'test' + __table_name__ = 'alter_doesnt_break_running_app' + id = UUID(primary_key=True, default=uuid4) + def test_extra_field(self): + drop_table(self.TestModel) + sync_table(self.TestModel) + self.TestModel.create() + execute("ALTER TABLE {} add blah int".format(self.TestModel.column_family_name(include_keyspace=True))) + self.TestModel.objects().all() +class TestTimeUUIDFromDatetime(TestCase): + def test_conversion_specific_date(self): + dt = datetime(1981, 7, 11, microsecond=555000) + uuid = TimeUUID.from_datetime(dt) + from uuid import UUID + assert isinstance(uuid, UUID) + ts = (uuid.time - 0x01b21dd213814000) / 1e7 # back to a timestamp + new_dt = datetime.utcfromtimestamp(ts) + + # checks that we created a UUID1 with the proper timestamp + assert new_dt == dt diff --git a/cqlengine/tests/columns/test_value_io.py b/cqlengine/tests/columns/test_value_io.py new file mode 100644 index 00000000..1515c25a --- /dev/null +++ b/cqlengine/tests/columns/test_value_io.py @@ -0,0 +1,171 @@ +from datetime import datetime, timedelta +from decimal import Decimal +from uuid import uuid1, uuid4, UUID + +from cqlengine.tests.base import BaseCassEngTestCase + +from cqlengine.management import sync_table +from cqlengine.management import drop_table +from cqlengine.models import Model +from cqlengine.columns import ValueQuoter +from cqlengine import columns +import unittest + + +class BaseColumnIOTest(BaseCassEngTestCase): + """ + Tests that values come out of cassandra in the format we expect + + To test a column type, subclass this test, define the column, and the primary key + and data values you want to test + """ + + # The generated test model is assigned here + _generated_model = None + + # the column we want to test + column = None + + # the values we want to test against, you 
can + # use a single value, or multiple comma separated values + pkey_val = None + data_val = None + + @classmethod + def setUpClass(cls): + super(BaseColumnIOTest, cls).setUpClass() + + #if the test column hasn't been defined, bail out + if not cls.column: return + + # create a table with the given column + class IOTestModel(Model): + __keyspace__ = 'test' + table_name = cls.column.db_type + "_io_test_model_{}".format(uuid4().hex[:8]) + pkey = cls.column(primary_key=True) + data = cls.column() + cls._generated_model = IOTestModel + sync_table(cls._generated_model) + + #tupleify the tested values + if not isinstance(cls.pkey_val, tuple): + cls.pkey_val = cls.pkey_val, + if not isinstance(cls.data_val, tuple): + cls.data_val = cls.data_val, + + @classmethod + def tearDownClass(cls): + super(BaseColumnIOTest, cls).tearDownClass() + if not cls.column: return + drop_table(cls._generated_model) + + def comparator_converter(self, val): + """ Override to convert the original value used to compare the model values """ + return val + + def test_column_io(self): + """ Tests that the given model class creates and retrieves values as expected """ + if not self.column: return + for pkey, data in zip(self.pkey_val, self.data_val): + #create + m1 = self._generated_model.create(pkey=pkey, data=data) + + #get + m2 = self._generated_model.get(pkey=pkey) + assert m1.pkey == m2.pkey == self.comparator_converter(pkey), self.column + assert m1.data == m2.data == self.comparator_converter(data), self.column + + #delete + self._generated_model.filter(pkey=pkey).delete() + +class TestBlobIO(BaseColumnIOTest): + + column = columns.Bytes + pkey_val = 'blake', uuid4().bytes + data_val = 'eggleston', uuid4().bytes + +class TestTextIO(BaseColumnIOTest): + + column = columns.Text + pkey_val = 'bacon' + data_val = 'monkey' + + +class TestNonBinaryTextIO(BaseColumnIOTest): + + column = columns.Text + pkey_val = 'bacon' + data_val = '0xmonkey' + +class TestInteger(BaseColumnIOTest): + + column = columns.Integer + pkey_val = 5 + data_val = 6 + +class TestBigInt(BaseColumnIOTest): + + column = columns.BigInt + pkey_val = 6 + data_val = pow(2, 63) - 1 + +class TestDateTime(BaseColumnIOTest): + + column = columns.DateTime + + now = datetime(*datetime.now().timetuple()[:6]) + pkey_val = now + data_val = now + timedelta(days=1) + +class TestDate(BaseColumnIOTest): + + column = columns.Date + + now = datetime.now().date() + pkey_val = now + data_val = now + timedelta(days=1) + +class TestUUID(BaseColumnIOTest): + + column = columns.UUID + + pkey_val = str(uuid4()), uuid4() + data_val = str(uuid4()), uuid4() + + def comparator_converter(self, val): + return val if isinstance(val, UUID) else UUID(val) + +class TestTimeUUID(BaseColumnIOTest): + + column = columns.TimeUUID + + pkey_val = str(uuid1()), uuid1() + data_val = str(uuid1()), uuid1() + + def comparator_converter(self, val): + return val if isinstance(val, UUID) else UUID(val) + +class TestFloatIO(BaseColumnIOTest): + + column = columns.Float + + pkey_val = 3.14 + data_val = -1982.11 + +class TestDecimalIO(BaseColumnIOTest): + + column = columns.Decimal + + pkey_val = Decimal('1.35'), 5, '2.4' + data_val = Decimal('0.005'), 3.5, '8' + + def comparator_converter(self, val): + return Decimal(val) + +class TestQuoter(unittest.TestCase): + + def test_equals(self): + assert ValueQuoter(False) == ValueQuoter(False) + assert ValueQuoter(1) == ValueQuoter(1) + assert ValueQuoter("foo") == ValueQuoter("foo") + assert ValueQuoter(1.55) == ValueQuoter(1.55)
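Each `BaseColumnIOTest` subclass above becomes a full create/get/delete round-trip for one column type, so extending coverage only takes a column class plus sample values. A hypothetical subclass for the `Ascii` column, following the same pattern (values are arbitrary samples):

```python
class TestAsciiIO(BaseColumnIOTest):
    # hypothetical subclass: round-trips an Ascii column through
    # create/get/delete exactly like the subclasses above
    column = columns.Ascii
    pkey_val = 'pkey-sample'
    data_val = 'data-sample'
```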
diff --git a/cqlengine/tests/connections/__init__.py b/cqlengine/tests/connections/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/cqlengine/tests/management/test_compaction_settings.py b/cqlengine/tests/management/test_compaction_settings.py new file mode 100644 index 00000000..2cad5d55 --- /dev/null +++ b/cqlengine/tests/management/test_compaction_settings.py @@ -0,0 +1,237 @@ +import copy +import json +from time import sleep +from mock import patch, MagicMock +from cqlengine import Model, columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy +from cqlengine.exceptions import CQLEngineException +from cqlengine.management import get_compaction_options, drop_table, sync_table, get_table_settings +from cqlengine.tests.base import BaseCassEngTestCase + + +class CompactionModel(Model): + __keyspace__ = 'test' + __compaction__ = None + cid = columns.UUID(primary_key=True) + name = columns.Text() + + +class BaseCompactionTest(BaseCassEngTestCase): + def assert_option_fails(self, key): + # key is a normal_key, converted to + # __compaction_key__ + + key = "__compaction_{}__".format(key) + + with patch.object(self.model, key, 10), \ + self.assertRaises(CQLEngineException): + get_compaction_options(self.model) + + +class SizeTieredCompactionTest(BaseCompactionTest): + + def setUp(self): + self.model = copy.deepcopy(CompactionModel) + self.model.__compaction__ = SizeTieredCompactionStrategy + + def test_size_tiered(self): + result = get_compaction_options(self.model) + assert result['class'] == SizeTieredCompactionStrategy + + def test_min_threshold(self): + self.model.__compaction_min_threshold__ = 2 + result = get_compaction_options(self.model) + assert result['min_threshold'] == '2' + + +class LeveledCompactionTest(BaseCompactionTest): + def setUp(self): + self.model = copy.deepcopy(CompactionLeveledStrategyModel) + + def test_simple_leveled(self): + result = get_compaction_options(self.model) + assert result['class'] == LeveledCompactionStrategy + + def test_bucket_high_fails(self): + self.assert_option_fails('bucket_high') + + def test_bucket_low_fails(self): + self.assert_option_fails('bucket_low') + + def test_max_threshold_fails(self): + self.assert_option_fails('max_threshold') + + def test_min_threshold_fails(self): + self.assert_option_fails('min_threshold') + + def test_min_sstable_size_fails(self): + self.assert_option_fails('min_sstable_size') + + def test_sstable_size_in_mb(self): + with patch.object(self.model, '__compaction_sstable_size_in_mb__', 32): + result = get_compaction_options(self.model) + + assert result['sstable_size_in_mb'] == '32' + + +class LeveledcompactionTestTable(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + __compaction_sstable_size_in_mb__ = 64 + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + +from cqlengine.management import schema_columnfamilies + +class AlterTableTest(BaseCassEngTestCase): + + def test_alter_is_called_table(self): + drop_table(LeveledcompactionTestTable) + sync_table(LeveledcompactionTestTable) + with patch('cqlengine.management.update_compaction') as mock: + sync_table(LeveledcompactionTestTable) + assert mock.call_count == 1 + + def test_compaction_not_altered_without_changes_leveled(self): + from cqlengine.management import update_compaction + + class LeveledCompactionChangesDetectionTest(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + __compaction_sstable_size_in_mb__ = 160 + __compaction_tombstone_threshold__ = 0.125 +
__compaction_tombstone_compaction_interval__ = 3600 + + pk = columns.Integer(primary_key=True) + + drop_table(LeveledCompactionChangesDetectionTest) + sync_table(LeveledCompactionChangesDetectionTest) + + assert not update_compaction(LeveledCompactionChangesDetectionTest) + + def test_compaction_not_altered_without_changes_sizetiered(self): + from cqlengine.management import update_compaction + + class SizeTieredCompactionChangesDetectionTest(Model): + __keyspace__ = 'test' + __compaction__ = SizeTieredCompactionStrategy + __compaction_bucket_high__ = 20 + __compaction_bucket_low__ = 10 + __compaction_max_threshold__ = 200 + __compaction_min_threshold__ = 100 + __compaction_min_sstable_size__ = 1000 + __compaction_tombstone_threshold__ = 0.125 + __compaction_tombstone_compaction_interval__ = 3600 + + pk = columns.Integer(primary_key=True) + + drop_table(SizeTieredCompactionChangesDetectionTest) + sync_table(SizeTieredCompactionChangesDetectionTest) + + assert not update_compaction(SizeTieredCompactionChangesDetectionTest) + + def test_alter_actually_alters(self): + tmp = copy.deepcopy(LeveledcompactionTestTable) + drop_table(tmp) + sync_table(tmp) + tmp.__compaction__ = SizeTieredCompactionStrategy + tmp.__compaction_sstable_size_in_mb__ = None + sync_table(tmp) + + table_settings = get_table_settings(tmp) + + self.assertRegexpMatches(table_settings.options['compaction_strategy_class'], '.*SizeTieredCompactionStrategy$') + + + def test_alter_options(self): + + class AlterTable(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + __compaction_sstable_size_in_mb__ = 64 + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AlterTable) + sync_table(AlterTable) + AlterTable.__compaction_sstable_size_in_mb__ = 128 + sync_table(AlterTable) + + + +class EmptyCompactionTest(BaseCassEngTestCase): + def test_empty_compaction(self): + class EmptyCompactionModel(Model): + __keyspace__ = 'test' + __compaction__ = None + cid = columns.UUID(primary_key=True) + name = columns.Text() + + result = get_compaction_options(EmptyCompactionModel) + self.assertEqual({}, result) + + +class CompactionLeveledStrategyModel(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + cid = columns.UUID(primary_key=True) + name = columns.Text() + + +class CompactionSizeTieredModel(Model): + __keyspace__ = 'test' + __compaction__ = SizeTieredCompactionStrategy + cid = columns.UUID(primary_key=True) + name = columns.Text() + + + +class OptionsTest(BaseCassEngTestCase): + + def test_all_size_tiered_options(self): + class AllSizeTieredOptionsModel(Model): + __keyspace__ = 'test' + __compaction__ = SizeTieredCompactionStrategy + __compaction_bucket_low__ = .3 + __compaction_bucket_high__ = 2 + __compaction_min_threshold__ = 2 + __compaction_max_threshold__ = 64 + __compaction_tombstone_compaction_interval__ = 86400 + + cid = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AllSizeTieredOptionsModel) + sync_table(AllSizeTieredOptionsModel) + + options = get_table_settings(AllSizeTieredOptionsModel).options['compaction_strategy_options'] + options = json.loads(options) + + expected = {u'min_threshold': u'2', + u'bucket_low': u'0.3', + u'tombstone_compaction_interval': u'86400', + u'bucket_high': u'2', + u'max_threshold': u'64'} + + self.assertDictEqual(options, expected) + + + def test_all_leveled_options(self): + + class AllLeveledOptionsModel(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + 
__compaction_sstable_size_in_mb__ = 64 + + cid = columns.UUID(primary_key=True) + name = columns.Text() + + drop_table(AllLeveledOptionsModel) + sync_table(AllLeveledOptionsModel) + + settings = get_table_settings(AllLeveledOptionsModel).options + + options = json.loads(settings['compaction_strategy_options']) + self.assertDictEqual(options, {u'sstable_size_in_mb': u'64'}) + diff --git a/cqlengine/tests/management/test_management.py b/cqlengine/tests/management/test_management.py index ac542bb8..1ed3137f 100644 --- a/cqlengine/tests/management/test_management.py +++ b/cqlengine/tests/management/test_management.py @@ -1,44 +1,14 @@ -from cqlengine.management import create_table, delete_table + +from cqlengine import ALL, CACHING_ALL, CACHING_NONE +from cqlengine.exceptions import CQLEngineException +from cqlengine.management import get_fields, sync_table, drop_table from cqlengine.tests.base import BaseCassEngTestCase - -from cqlengine.connection import ConnectionPool - -from mock import Mock from cqlengine import management from cqlengine.tests.query.test_queryset import TestModel +from cqlengine.models import Model +from cqlengine import columns, SizeTieredCompactionStrategy, LeveledCompactionStrategy -class ConnectionPoolTestCase(BaseCassEngTestCase): - """Test cassandra connection pooling.""" - - def setUp(self): - ConnectionPool.clear() - - def test_should_create_single_connection_on_request(self): - """Should create a single connection on first request""" - result = ConnectionPool.get() - self.assertIsNotNone(result) - self.assertEquals(0, ConnectionPool._queue.qsize()) - ConnectionPool._queue.put(result) - self.assertEquals(1, ConnectionPool._queue.qsize()) - - def test_should_close_connection_if_queue_is_full(self): - """Should close additional connections if queue is full""" - connections = [ConnectionPool.get() for x in range(10)] - for conn in connections: - ConnectionPool.put(conn) - fake_conn = Mock() - ConnectionPool.put(fake_conn) - fake_conn.close.assert_called_once_with() - - def test_should_pop_connections_from_queue(self): - """Should pull existing connections off of the queue""" - conn = ConnectionPool.get() - ConnectionPool.put(conn) - self.assertEquals(1, ConnectionPool._queue.qsize()) - self.assertEquals(conn, ConnectionPool.get()) - self.assertEquals(0, ConnectionPool._queue.qsize()) - class CreateKeyspaceTest(BaseCassEngTestCase): def test_create_succeeeds(self): @@ -51,7 +21,219 @@ class DeleteTableTest(BaseCassEngTestCase): """ """ - create_table(TestModel) + sync_table(TestModel) - delete_table(TestModel) - delete_table(TestModel) + drop_table(TestModel) + drop_table(TestModel) + +class LowercaseKeyModel(Model): + __keyspace__ = 'test' + first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + some_data = columns.Text() + +class CapitalizedKeyModel(Model): + __keyspace__ = 'test' + firstKey = columns.Integer(primary_key=True) + secondKey = columns.Integer(primary_key=True) + someData = columns.Text() + +class PrimaryKeysOnlyModel(Model): + __keyspace__ = 'test' + __compaction__ = LeveledCompactionStrategy + + first_key = columns.Integer(primary_key=True) + second_key = columns.Integer(primary_key=True) + + +class CapitalizedKeyTest(BaseCassEngTestCase): + + def test_table_definition(self): + """ Tests that creating a table with capitalized column names succeeds """ + sync_table(LowercaseKeyModel) + sync_table(CapitalizedKeyModel) + + drop_table(LowercaseKeyModel) + drop_table(CapitalizedKeyModel)
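The four models that follow all map onto the same `first_model` table and pin down sync_table's additive behavior: columns added to a model are added to the schema, but a column dropped from the model definition is never dropped from Cassandra. A condensed sketch of that flow, using hypothetical model names and assuming a configured connection and the `test` keyspace:

```python
from cqlengine import columns
from cqlengine.management import drop_table, get_fields, sync_table
from cqlengine.models import Model


class DemoV1(Model):
    # hypothetical first revision of a table
    __keyspace__ = 'test'
    __table_name__ = 'sync_demo'
    key = columns.UUID(primary_key=True)
    name = columns.Text()


class DemoV2(Model):
    # same table: 'email' added, 'name' dropped from the model
    __keyspace__ = 'test'
    __table_name__ = 'sync_demo'
    key = columns.UUID(primary_key=True)
    email = columns.Text()


drop_table(DemoV1)
sync_table(DemoV1)
assert len(get_fields(DemoV1)) == 1   # non-primary-key fields: just 'name'

sync_table(DemoV2)                    # adds 'email', leaves 'name' in place
assert len(get_fields(DemoV2)) == 2   # 'name' is still in the schema
```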
+ + +class FirstModel(Model): + __keyspace__ = 'test' + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + +class SecondModel(Model): + __keyspace__ = 'test' + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + fourth_key = columns.Text() + +class ThirdModel(Model): + __keyspace__ = 'test' + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + # removed fourth key, but it should stay in the DB + blah = columns.Map(columns.Text, columns.Text) + +class FourthModel(Model): + __keyspace__ = 'test' + __table_name__ = 'first_model' + first_key = columns.UUID(primary_key=True) + second_key = columns.UUID() + third_key = columns.Text() + # renamed the map column; db_field points at the existing 'blah' column + renamed = columns.Map(columns.Text, columns.Text, db_field='blah') + +class AddColumnTest(BaseCassEngTestCase): + def setUp(self): + drop_table(FirstModel) + + def test_add_column(self): + sync_table(FirstModel) + fields = get_fields(FirstModel) + + # non-primary-key fields: second_key and third_key + self.assertEqual(len(fields), 2) + # sync the expanded schema + sync_table(SecondModel) + + fields = get_fields(FirstModel) + self.assertEqual(len(fields), 3) + + sync_table(ThirdModel) + fields = get_fields(FirstModel) + self.assertEqual(len(fields), 4) + + sync_table(FourthModel) + fields = get_fields(FirstModel) + self.assertEqual(len(fields), 4) + + +class ModelWithTableProperties(Model): + __keyspace__ = 'test' + # Set random table properties + __bloom_filter_fp_chance__ = 0.76328 + __caching__ = CACHING_ALL + __comment__ = 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR' + __default_time_to_live__ = 4756 + __gc_grace_seconds__ = 2063 + __index_interval__ = 98706 + __memtable_flush_period_in_ms__ = 43681 + __populate_io_cache_on_flush__ = True + __read_repair_chance__ = 0.17985 + __replicate_on_write__ = False + __dclocal_read_repair_chance__ = 0.50811 + + key = columns.UUID(primary_key=True) + + +class TablePropertiesTests(BaseCassEngTestCase): + + def setUp(self): + drop_table(ModelWithTableProperties) + + def test_set_table_properties(self): + sync_table(ModelWithTableProperties) + self.assertDictContainsSubset({ + 'bloom_filter_fp_chance': 0.76328, + 'caching': CACHING_ALL, + 'comment': 'TxfguvBdzwROQALmQBOziRMbkqVGFjqcJfVhwGR', + 'default_time_to_live': 4756, + 'gc_grace_seconds': 2063, + 'index_interval': 98706, + 'memtable_flush_period_in_ms': 43681, + 'populate_io_cache_on_flush': True, + 'read_repair_chance': 0.17985, + 'replicate_on_write': False, + # For some reason 'dclocal_read_repair_chance' in CQL is called + # just 'local_read_repair_chance' in the schema table.
+ # Source: https://issues.apache.org/jira/browse/CASSANDRA-6717 + + # TODO: due to a bug in the native driver i'm not seeing the local read repair chance show up + #'local_read_repair_chance': 0.50811, + + }, management.get_table_settings(ModelWithTableProperties).options) + + def test_table_property_update(self): + ModelWithTableProperties.__bloom_filter_fp_chance__ = 0.66778 + ModelWithTableProperties.__caching__ = CACHING_NONE + ModelWithTableProperties.__comment__ = 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy' + ModelWithTableProperties.__default_time_to_live__ = 65178 + ModelWithTableProperties.__gc_grace_seconds__ = 96362 + ModelWithTableProperties.__index_interval__ = 94207 + ModelWithTableProperties.__memtable_flush_period_in_ms__ = 60210 + ModelWithTableProperties.__populate_io_cache_on_flush__ = False + ModelWithTableProperties.__read_repair_chance__ = 0.2989 + ModelWithTableProperties.__replicate_on_write__ = True + ModelWithTableProperties.__dclocal_read_repair_chance__ = 0.12732 + + sync_table(ModelWithTableProperties) + + table_settings = management.get_table_settings(ModelWithTableProperties).options + + self.assertDictContainsSubset({ + 'bloom_filter_fp_chance': 0.66778, + 'caching': CACHING_NONE, + 'comment': 'xirAkRWZVVvsmzRvXamiEcQkshkUIDINVJZgLYSdnGHweiBrAiJdLJkVohdRy', + 'default_time_to_live': 65178, + 'gc_grace_seconds': 96362, + 'index_interval': 94207, + 'memtable_flush_period_in_ms': 60210, + 'populate_io_cache_on_flush': False, + 'read_repair_chance': 0.2989, + 'replicate_on_write': True, + + # TODO see above comment re: native driver missing local read repair chance + # 'local_read_repair_chance': 0.12732, + }, table_settings)
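Table-level settings are declared as dunder attributes on the model, pushed to Cassandra by sync_table, and read back through get_table_settings, as the two tests above exercise. A minimal hypothetical sketch with a single property, assuming a configured connection and the `test` keyspace:

```python
from cqlengine import columns
from cqlengine import management
from cqlengine.management import drop_table, sync_table
from cqlengine.models import Model


class GraceDemo(Model):
    # hypothetical model: one table property set, the rest keep their defaults
    __keyspace__ = 'test'
    __gc_grace_seconds__ = 86400
    key = columns.UUID(primary_key=True)


drop_table(GraceDemo)
sync_table(GraceDemo)
options = management.get_table_settings(GraceDemo).options
assert options['gc_grace_seconds'] == 86400
```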
+ + +class SyncTableTests(BaseCassEngTestCase): + + def setUp(self): + drop_table(PrimaryKeysOnlyModel) + + def test_sync_table_works_with_primary_keys_only_tables(self): + + # This is "create table": + + sync_table(PrimaryKeysOnlyModel) + + # let's make sure settings persisted correctly: + + assert PrimaryKeysOnlyModel.__compaction__ == LeveledCompactionStrategy + # blows up with DoesNotExist if table does not exist + table_settings = management.get_table_settings(PrimaryKeysOnlyModel) + # let's make sure the flag we care about is set + + assert LeveledCompactionStrategy in table_settings.options['compaction_strategy_class'] + + + # Now we are "updating" the table: + + # setting up something to change + PrimaryKeysOnlyModel.__compaction__ = SizeTieredCompactionStrategy + + # primary-keys-only tables do not create entries in system.schema_columns + # table. Only non-primary keys are added to that table. + # Our code must deal with that eventuality properly (not crash) + # on subsequent runs of sync_table (which runs get_fields internally) + get_fields(PrimaryKeysOnlyModel) + sync_table(PrimaryKeysOnlyModel) + + table_settings = management.get_table_settings(PrimaryKeysOnlyModel) + assert SizeTieredCompactionStrategy in table_settings.options['compaction_strategy_class'] + +class NonModelFailureTest(BaseCassEngTestCase): + class FakeModel(object): + pass + + def test_failure(self): + with self.assertRaises(CQLEngineException): + sync_table(self.FakeModel) diff --git a/cqlengine/tests/model/test_class_construction.py b/cqlengine/tests/model/test_class_construction.py index 77b1d737..37f0f819 100644 --- a/cqlengine/tests/model/test_class_construction.py +++ b/cqlengine/tests/model/test_class_construction.py @@ -1,7 +1,10 @@ +from uuid import uuid4 +import warnings +from cqlengine.query import QueryException, ModelQuerySet, DMLQuery from cqlengine.tests.base import BaseCassEngTestCase -from cqlengine.exceptions import ModelException -from cqlengine.models import Model +from cqlengine.exceptions import ModelException, CQLEngineException +from cqlengine.models import Model, ModelDefinitionException, ColumnQueryEvaluator, UndefinedKeyspaceWarning from cqlengine import columns import cqlengine @@ -17,6 +20,8 @@ class TestModelClassFunction(BaseCassEngTestCase): """ class TestModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) text = columns.Text() #check class attributes @@ -37,6 +42,8 @@ class TestModelClassFunction(BaseCassEngTestCase): -the db_map allows columns """ class WildDBNames(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) content = columns.Text(db_field='words_and_whatnot') numbers = columns.Integer(db_field='integers_etc') @@ -60,17 +67,29 @@ class TestModelClassFunction(BaseCassEngTestCase): """ class Stuff(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) words = columns.Text() content = columns.Text() numbers = columns.Integer() self.assertEquals(Stuff._columns.keys(), ['id', 'words', 'content', 'numbers']) + def test_exception_raised_when_creating_class_without_pk(self): + with self.assertRaises(ModelDefinitionException): + class TestModel(Model): + __keyspace__ = 'test' + count = columns.Integer() + text = columns.Text(required=False) + + def test_value_managers_are_keeping_model_instances_isolated(self): """ Tests that instance value managers are isolated from other instances """ class Stuff(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) num = columns.Integer() inst1 = Stuff(num=5) @@ -85,6 +104,8 @@ class TestModelClassFunction(BaseCassEngTestCase): Tests that fields defined on the super class are inherited properly """ class TestModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) text = columns.Text() class InheritedModel(TestModel): @@ -93,6 +114,15 @@ class TestModelClassFunction(BaseCassEngTestCase): assert 'text' in InheritedModel._columns assert 'numbers' in InheritedModel._columns + def test_column_family_name_generation(self): + """ Tests that auto column family name generation works as expected """ + class TestModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + text = columns.Text() + + assert TestModel.column_family_name(include_keyspace=False) == 'test_model' + def
test_normal_fields_can_be_defined_between_primary_keys(self): """ Tests that non-primary-key fields can be defined between primary key fields @@ -117,24 +147,224 @@ """ Test that metadata defined in one class is not inherited by subclasses """ - + + def test_partition_keys(self): + """ + Test compound partition key definition + """ + class ModelWithPartitionKeys(cqlengine.Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + c1 = cqlengine.Text(primary_key=True) + p1 = cqlengine.Text(partition_key=True) + p2 = cqlengine.Text(partition_key=True) + + cols = ModelWithPartitionKeys._columns + + self.assertTrue(cols['c1'].primary_key) + self.assertFalse(cols['c1'].partition_key) + + self.assertTrue(cols['p1'].primary_key) + self.assertTrue(cols['p1'].partition_key) + self.assertTrue(cols['p2'].primary_key) + self.assertTrue(cols['p2'].partition_key) + + obj = ModelWithPartitionKeys(p1='a', p2='b') + self.assertEquals(obj.pk, ('a', 'b')) + + def test_del_attribute_is_assigned_properly(self): + """ Tests that columns that can be deleted have the del attribute """ + class DelModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + key = columns.Integer(primary_key=True) + data = columns.Integer(required=False) + + model = DelModel(key=4, data=5) + del model.data + with self.assertRaises(AttributeError): + del model.key + + def test_does_not_exist_exceptions_are_not_shared_between_model(self): + """ Tests that DoesNotExist exceptions are not the same exception between models """ + + class Model1(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + class Model2(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + try: + raise Model1.DoesNotExist + except Model2.DoesNotExist: + assert False, "Model1 exception should not be caught by Model2" + except Model1.DoesNotExist: + #expected + pass + + def test_does_not_exist_inherits_from_superclass(self): + """ Tests that a DoesNotExist exception can be caught by its parent class DoesNotExist """ + class Model1(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + + class Model2(Model1): + pass + + try: + raise Model2.DoesNotExist + except Model1.DoesNotExist: + #expected + pass + except Exception: + assert False, "Model2 exception should not be caught by Model1" + + def test_abstract_model_keyspace_warning_is_skipped(self): + with warnings.catch_warnings(record=True) as warn: + class NoKeyspace(Model): + __abstract__ = True + key = columns.UUID(primary_key=True) + + self.assertEqual(len(warn), 0) + class TestManualTableNaming(BaseCassEngTestCase): - + class RenamedTest(cqlengine.Model): - keyspace = 'whatever' - table_name = 'manual_name' - + __keyspace__ = 'whatever' + __table_name__ = 'manual_name' + id = cqlengine.UUID(primary_key=True) data = cqlengine.Text() - + def test_proper_table_naming(self): assert self.RenamedTest.column_family_name(include_keyspace=False) == 'manual_name' assert self.RenamedTest.column_family_name(include_keyspace=True) == 'whatever.manual_name' - def test_manual_table_name_is_not_inherited(self): - class InheritedTest(self.RenamedTest): pass - assert InheritedTest.table_name is None - +class AbstractModel(Model): + __abstract__ = True + __keyspace__ = 'test' + +class ConcreteModel(AbstractModel): + pkey = columns.Integer(primary_key=True) + data = 
columns.Integer() + +class AbstractModelWithCol(Model): + __keyspace__ = 'test' + __abstract__ = True + pkey = columns.Integer(primary_key=True) + +class ConcreteModelWithCol(AbstractModelWithCol): + data = columns.Integer() + +class AbstractModelWithFullCols(Model): + __abstract__ = True + __keyspace__ = 'test' + pkey = columns.Integer(primary_key=True) + data = columns.Integer() + +class TestAbstractModelClasses(BaseCassEngTestCase): + + def test_id_field_is_not_created(self): + """ Tests that an id field is not automatically generated on abstract classes """ + assert not hasattr(AbstractModel, 'id') + assert not hasattr(AbstractModelWithCol, 'id') + + def test_id_field_is_not_created_on_subclass(self): + assert not hasattr(ConcreteModel, 'id') + + def test_abstract_attribute_is_not_inherited(self): + """ Tests that __abstract__ attribute is not inherited """ + assert not ConcreteModel.__abstract__ + assert not ConcreteModelWithCol.__abstract__ + + def test_attempting_to_save_abstract_model_fails(self): + """ Attempting to save a model from an abstract model should fail """ + with self.assertRaises(CQLEngineException): + AbstractModelWithFullCols.create(pkey=1, data=2) + + def test_attempting_to_create_abstract_table_fails(self): + """ Attempting to create a table from an abstract model should fail """ + from cqlengine.management import sync_table + with self.assertRaises(CQLEngineException): + sync_table(AbstractModelWithFullCols) + + def test_attempting_query_on_abstract_model_fails(self): + """ Tests attempting to execute query with an abstract model fails """ + with self.assertRaises(CQLEngineException): + iter(AbstractModelWithFullCols.objects(pkey=5)).next() + + def test_abstract_columns_are_inherited(self): + """ Tests that columns defined in the abstract class are inherited into the concrete class """ + assert hasattr(ConcreteModelWithCol, 'pkey') + assert isinstance(ConcreteModelWithCol.pkey, ColumnQueryEvaluator) + assert isinstance(ConcreteModelWithCol._columns['pkey'], columns.Column) + + def test_concrete_class_table_creation_cycle(self): + """ Tests that models with inherited abstract classes can be created, and have io performed """ + from cqlengine.management import sync_table, drop_table + sync_table(ConcreteModelWithCol) + + w1 = ConcreteModelWithCol.create(pkey=5, data=6) + w2 = ConcreteModelWithCol.create(pkey=6, data=7) + + r1 = ConcreteModelWithCol.get(pkey=5) + r2 = ConcreteModelWithCol.get(pkey=6) + + assert w1.pkey == r1.pkey + assert w1.data == r1.data + assert w2.pkey == r2.pkey + assert w2.data == r2.data + + drop_table(ConcreteModelWithCol) + + +class TestCustomQuerySet(BaseCassEngTestCase): + """ Tests overriding the default queryset class """ + + class TestException(Exception): pass + + def test_overriding_queryset(self): + + class QSet(ModelQuerySet): + def create(iself, **kwargs): + raise self.TestException + + class CQModel(Model): + __queryset__ = QSet + __keyspace__ = 'test' + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CQModel.create(part=uuid4(), data='s') + + def test_overriding_dmlqueryset(self): + + class DMLQ(DMLQuery): + def save(iself): + raise self.TestException + + class CDQModel(Model): + __keyspace__ = 'test' + __dmlquery__ = DMLQ + part = columns.UUID(primary_key=True) + data = columns.Text() + + with self.assertRaises(self.TestException): + CDQModel().save() + + +class TestCachedLengthIsNotCarriedToSubclasses(BaseCassEngTestCase): + def test_subclassing(self): + + length 
= len(ConcreteModelWithCol()) + + class AlreadyLoadedTest(ConcreteModelWithCol): + new_field = columns.Integer() + + self.assertGreater(len(AlreadyLoadedTest()), length) diff --git a/cqlengine/tests/model/test_equality_operations.py b/cqlengine/tests/model/test_equality_operations.py index 4d5b9521..eeb8992b 100644 --- a/cqlengine/tests/model/test_equality_operations.py +++ b/cqlengine/tests/model/test_equality_operations.py @@ -1,12 +1,15 @@ from unittest import skip +from uuid import uuid4 from cqlengine.tests.base import BaseCassEngTestCase -from cqlengine.management import create_table -from cqlengine.management import delete_table +from cqlengine.management import sync_table +from cqlengine.management import drop_table from cqlengine.models import Model from cqlengine import columns class TestModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) count = columns.Integer() text = columns.Text(required=False) @@ -15,7 +18,7 @@ class TestEqualityOperators(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestEqualityOperators, cls).setUpClass() - create_table(TestModel) + sync_table(TestModel) def setUp(self): super(TestEqualityOperators, self).setUp() @@ -25,7 +28,7 @@ class TestEqualityOperators(BaseCassEngTestCase): @classmethod def tearDownClass(cls): super(TestEqualityOperators, cls).tearDownClass() - delete_table(TestModel) + drop_table(TestModel) def test_an_instance_evaluates_as_equal_to_itself(self): """ diff --git a/cqlengine/tests/model/test_model.py b/cqlengine/tests/model/test_model.py new file mode 100644 index 00000000..1bd143d0 --- /dev/null +++ b/cqlengine/tests/model/test_model.py @@ -0,0 +1,56 @@ +from unittest import TestCase + +from cqlengine.models import Model, ModelDefinitionException +from cqlengine import columns + + +class TestModel(TestCase): + """ Tests the non-io functionality of models """ + + def test_instance_equality(self): + """ tests the model equality functionality """ + class EqualityModel(Model): + __keyspace__ = 'test' + pk = columns.Integer(primary_key=True) + + m0 = EqualityModel(pk=0) + m1 = EqualityModel(pk=1) + + self.assertEqual(m0, m0) + self.assertNotEqual(m0, m1) + + def test_model_equality(self): + """ tests the model equality functionality """ + class EqualityModel0(Model): + __keyspace__ = 'test' + pk = columns.Integer(primary_key=True) + + class EqualityModel1(Model): + __keyspace__ = 'test' + kk = columns.Integer(primary_key=True) + + m0 = EqualityModel0(pk=0) + m1 = EqualityModel1(kk=1) + + self.assertEqual(m0, m0) + self.assertNotEqual(m0, m1) + + +class BuiltInAttributeConflictTest(TestCase): + """tests Model definitions that conflict with built-in attributes/methods""" + + def test_model_with_attribute_name_conflict(self): + """should raise exception when model defines column that conflicts with built-in attribute""" + with self.assertRaises(ModelDefinitionException): + class IllegalTimestampColumnModel(Model): + __keyspace__ = 'test' + my_primary_key = columns.Integer(primary_key=True) + timestamp = columns.BigInt() + + def test_model_with_method_name_conflict(self): + """should raise exception when model defines column that conflicts with built-in method""" + with self.assertRaises(ModelDefinitionException): + class IllegalFilterColumnModel(Model): + __keyspace__ = 'test' + my_primary_key = columns.Integer(primary_key=True) + filter = columns.Text() \ No newline at end of file diff --git a/cqlengine/tests/model/test_model_io.py b/cqlengine/tests/model/test_model_io.py 
index 08e77ded..1faaaf02 100644 --- a/cqlengine/tests/model/test_model_io.py +++ b/cqlengine/tests/model/test_model_io.py @@ -1,39 +1,78 @@ from uuid import uuid4 import random +from datetime import date +from operator import itemgetter +from cqlengine.exceptions import CQLEngineException from cqlengine.tests.base import BaseCassEngTestCase -from cqlengine.management import create_table -from cqlengine.management import delete_table +from cqlengine.management import sync_table +from cqlengine.management import drop_table from cqlengine.models import Model from cqlengine import columns class TestModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) count = columns.Integer() text = columns.Text(required=False) a_bool = columns.Boolean(default=False) + + class TestModelIO(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestModelIO, cls).setUpClass() - create_table(TestModel) + sync_table(TestModel) @classmethod def tearDownClass(cls): super(TestModelIO, cls).tearDownClass() - delete_table(TestModel) + drop_table(TestModel) def test_model_save_and_load(self): """ Tests that models can be saved and retrieved """ tm = TestModel.create(count=8, text='123456789') + self.assertIsInstance(tm, TestModel) + tm2 = TestModel.objects(id=tm.pk).first() + self.assertIsInstance(tm2, TestModel) for cname in tm._columns.keys(): self.assertEquals(getattr(tm, cname), getattr(tm2, cname)) + def test_model_read_as_dict(self): + """ + Tests that columns of an instance can be read as a dict. + """ + tm = TestModel.create(count=8, text='123456789', a_bool=True) + column_dict = { + 'id': tm.id, + 'count': tm.count, + 'text': tm.text, + 'a_bool': tm.a_bool, + } + self.assertEquals(sorted(tm.keys()), sorted(column_dict.keys())) + self.assertEquals(sorted(tm.values()), sorted(column_dict.values())) + self.assertEquals( + sorted(tm.items(), key=itemgetter(0)), + sorted(column_dict.items(), key=itemgetter(0))) + self.assertEquals(len(tm), len(column_dict)) + for column_id in column_dict.keys(): + self.assertEqual(tm[column_id], column_dict[column_id]) + + tm['count'] = 6 + self.assertEqual(tm.count, 6) + def test_model_updating_works_properly(self): """ Tests that subsequent saves after initial model creation work @@ -65,35 +104,38 @@ tm.save() tm2 = TestModel.objects(id=tm.pk).first() + self.assertIsInstance(tm2, TestModel) + assert tm2.text is None assert tm2._values['text'].previous_value is None - def test_a_sensical_error_is_raised_if_you_try_to_create_a_table_twice(self): """ """ - create_table(TestModel) - create_table(TestModel) + sync_table(TestModel) + sync_table(TestModel)
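test_model_read_as_dict above pins down the dict-style protocol on model instances; condensed into a hypothetical usage sketch (values are arbitrary, and it assumes this module's TestModel table is synced):

```python
tm = TestModel.create(count=8, text='123456789', a_bool=True)

# read access mirrors a plain dict over the defined columns
assert set(tm.keys()) == set(['id', 'count', 'text', 'a_bool'])
assert tm['count'] == 8
assert len(tm) == 4

# item assignment writes through to the attribute
tm['count'] = 6
assert tm.count == 6
```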
class TestMultiKeyModel(Model): + __keyspace__ = 'test' partition = columns.Integer(primary_key=True) cluster = columns.Integer(primary_key=True) count = columns.Integer(required=False) text = columns.Text(required=False) + class TestDeleting(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestDeleting, cls).setUpClass() - delete_table(TestMultiKeyModel) - create_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) @classmethod def tearDownClass(cls): super(TestDeleting, cls).tearDownClass() - delete_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) def test_deleting_only_deletes_one_object(self): partition = random.randint(0,1000) @@ -114,13 +156,13 @@ class TestUpdating(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(TestUpdating, cls).setUpClass() - delete_table(TestMultiKeyModel) - create_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) @classmethod def tearDownClass(cls): super(TestUpdating, cls).tearDownClass() - delete_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) def setUp(self): super(TestUpdating, self).setUp() @@ -148,6 +190,14 @@ assert check.count is None assert check.text is None + def test_get_changed_columns(self): + assert self.instance.get_changed_columns() == [] + self.instance.count = 1 + changes = self.instance.get_changed_columns() + assert len(changes) == 1 + assert changes == ['count'] + self.instance.save() + assert self.instance.get_changed_columns() == [] class TestCanUpdate(BaseCassEngTestCase): @@ -155,13 +205,13 @@ @classmethod def setUpClass(cls): super(TestCanUpdate, cls).setUpClass() - delete_table(TestModel) - create_table(TestModel) + drop_table(TestModel) + sync_table(TestModel) @classmethod def tearDownClass(cls): super(TestCanUpdate, cls).tearDownClass() - delete_table(TestModel) + drop_table(TestModel) def test_success_case(self): tm = TestModel(count=8, text='123456789') @@ -193,16 +243,18 @@ class IndexDefinitionModel(Model): + __keyspace__ = 'test' key = columns.UUID(primary_key=True) val = columns.Text(index=True) class TestIndexedColumnDefinition(BaseCassEngTestCase): def test_exception_isnt_raised_if_an_index_is_defined_more_than_once(self): - create_table(IndexDefinitionModel) - create_table(IndexDefinitionModel) + sync_table(IndexDefinitionModel) + sync_table(IndexDefinitionModel) class ReservedWordModel(Model): + __keyspace__ = 'test' token = columns.Text(primary_key=True) insert = columns.Integer(index=True) @@ -211,7 +263,7 @@ class TestQueryQuoting(BaseCassEngTestCase): def test_reserved_cql_words_can_be_used_as_column_names(self): """ """ - create_table(ReservedWordModel) + sync_table(ReservedWordModel) model1 = ReservedWordModel.create(token='1', insert=5) @@ -222,3 +274,52 @@ assert model1.insert == model2[0].insert +class TestQueryModel(Model): + __keyspace__ = 'test' + test_id = columns.UUID(primary_key=True, default=uuid4) + date = columns.Date(primary_key=True) + description = columns.Text() + + +class TestQuerying(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestQuerying, cls).setUpClass() + drop_table(TestQueryModel) + sync_table(TestQueryModel) + + @classmethod + def tearDownClass(cls): + super(TestQuerying, cls).tearDownClass() + drop_table(TestQueryModel) + + def test_query_with_date(self): + uid = uuid4() + day = date(2013, 11, 26) + obj = TestQueryModel.create(test_id=uid, date=day, description=u'foo') + + self.assertEqual(obj.description, u'foo') + + inst = TestQueryModel.filter( + TestQueryModel.test_id == uid, + TestQueryModel.date == day).limit(1).first() + + assert inst.test_id == uid + assert inst.date == day + +def test_none_filter_fails(): + class NoneFilterModel(Model): + __keyspace__ = 'test' + pk = columns.Integer(primary_key=True) + v = columns.Integer() + sync_table(NoneFilterModel) + + try: + NoneFilterModel.objects(pk=None) + except CQLEngineException: + pass + else: + raise AssertionError("expected CQLEngineException for a None filter value") + + + diff
--git a/cqlengine/tests/model/test_polymorphism.py b/cqlengine/tests/model/test_polymorphism.py new file mode 100644 index 00000000..450a6ed4 --- /dev/null +++ b/cqlengine/tests/model/test_polymorphism.py @@ -0,0 +1,241 @@ +import uuid +import mock + +from cqlengine import columns +from cqlengine import models +from cqlengine.connection import get_session +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine import management + + +class TestPolymorphicClassConstruction(BaseCassEngTestCase): + + def test_multiple_polymorphic_key_failure(self): + """ Tests that defining a model with more than one polymorphic key fails """ + with self.assertRaises(models.ModelDefinitionException): + class M(models.Model): + __keyspace__ = 'test' + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(polymorphic_key=True) + type2 = columns.Integer(polymorphic_key=True) + + def test_polymorphic_key_inheritance(self): + """ Tests that polymorphic_key attribute is not inherited """ + class Base(models.Model): + __keyspace__ = 'test' + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(polymorphic_key=True) + + class M1(Base): + __polymorphic_key__ = 1 + + class M2(M1): + pass + + assert M2.__polymorphic_key__ is None + + def test_polymorphic_metaclass(self): + """ Tests that the model meta class configures polymorphic models properly """ + class Base(models.Model): + __keyspace__ = 'test' + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(polymorphic_key=True) + + class M1(Base): + __polymorphic_key__ = 1 + + assert Base._is_polymorphic + assert M1._is_polymorphic + + assert Base._is_polymorphic_base + assert not M1._is_polymorphic_base + + assert Base._polymorphic_column is Base._columns['type1'] + assert M1._polymorphic_column is M1._columns['type1'] + + assert Base._polymorphic_column_name == 'type1' + assert M1._polymorphic_column_name == 'type1' + + def test_table_names_are_inherited_from_poly_base(self): + class Base(models.Model): + __keyspace__ = 'test' + partition = columns.Integer(primary_key=True) + type1 = columns.Integer(polymorphic_key=True) + + class M1(Base): + __polymorphic_key__ = 1 + + assert Base.column_family_name() == M1.column_family_name() + + def test_collection_columns_cant_be_polymorphic_keys(self): + with self.assertRaises(models.ModelDefinitionException): + class Base(models.Model): + __keyspace__ = 'test' + partition = columns.Integer(primary_key=True) + type1 = columns.Set(columns.Integer, polymorphic_key=True) + + +class PolyBase(models.Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = columns.Integer(polymorphic_key=True) + + +class Poly1(PolyBase): + __polymorphic_key__ = 1 + data1 = columns.Text() + + +class Poly2(PolyBase): + __polymorphic_key__ = 2 + data2 = columns.Text() + + +class TestPolymorphicModel(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestPolymorphicModel, cls).setUpClass() + management.sync_table(Poly1) + management.sync_table(Poly2) + + @classmethod + def tearDownClass(cls): + super(TestPolymorphicModel, cls).tearDownClass() + management.drop_table(Poly1) + management.drop_table(Poly2) + + def test_saving_base_model_fails(self): + with self.assertRaises(models.PolyMorphicModelException): + PolyBase.create() + + def test_saving_subclass_saves_poly_key(self): + p1 = Poly1.create(data1='pickle') + p2 = Poly2.create(data2='bacon') + + assert p1.row_type == Poly1.__polymorphic_key__ + assert p2.row_type 
== Poly2.__polymorphic_key__ + + def test_query_deserialization(self): + p1 = Poly1.create(data1='pickle') + p2 = Poly2.create(data2='bacon') + + p1r = PolyBase.get(partition=p1.partition) + p2r = PolyBase.get(partition=p2.partition) + + assert isinstance(p1r, Poly1) + assert isinstance(p2r, Poly2) + + def test_delete_on_polymorphic_subclass_does_not_include_polymorphic_key(self): + p1 = Poly1.create() + session = get_session() + with mock.patch.object(session, 'execute') as m: + Poly1.objects(partition=p1.partition).delete() + + # make sure our polymorphic key isn't in the CQL + # not sure how we would even get here if it was in there + # since the CQL would fail. + + self.assertNotIn("row_type", m.call_args[0][0].query_string) + + + + + +class UnindexedPolyBase(models.Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + cluster = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = columns.Integer(polymorphic_key=True) + + +class UnindexedPoly1(UnindexedPolyBase): + __polymorphic_key__ = 1 + data1 = columns.Text() + + +class UnindexedPoly2(UnindexedPolyBase): + __polymorphic_key__ = 2 + data2 = columns.Text() + + +class UnindexedPoly3(UnindexedPoly2): + __polymorphic_key__ = 3 + data3 = columns.Text() + + +class TestUnindexedPolymorphicQuery(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestUnindexedPolymorphicQuery, cls).setUpClass() + management.sync_table(UnindexedPoly1) + management.sync_table(UnindexedPoly2) + management.sync_table(UnindexedPoly3) + + cls.p1 = UnindexedPoly1.create(data1='pickle') + cls.p2 = UnindexedPoly2.create(partition=cls.p1.partition, data2='bacon') + cls.p3 = UnindexedPoly3.create(partition=cls.p1.partition, data3='turkey') + + @classmethod + def tearDownClass(cls): + super(TestUnindexedPolymorphicQuery, cls).tearDownClass() + management.drop_table(UnindexedPoly1) + management.drop_table(UnindexedPoly2) + management.drop_table(UnindexedPoly3) + + def test_non_conflicting_type_results_work(self): + p1, p2, p3 = self.p1, self.p2, self.p3 + assert len(list(UnindexedPoly1.objects(partition=p1.partition, cluster=p1.cluster))) == 1 + assert len(list(UnindexedPoly2.objects(partition=p1.partition, cluster=p2.cluster))) == 1 + + def test_subclassed_model_results_work_properly(self): + p1, p2, p3 = self.p1, self.p2, self.p3 + assert len(list(UnindexedPoly2.objects(partition=p1.partition, cluster__in=[p2.cluster, p3.cluster]))) == 2 + + def test_conflicting_type_results(self): + with self.assertRaises(models.PolyMorphicModelException): + list(UnindexedPoly1.objects(partition=self.p1.partition)) + with self.assertRaises(models.PolyMorphicModelException): + list(UnindexedPoly2.objects(partition=self.p1.partition)) + + +class IndexedPolyBase(models.Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid.uuid4) + cluster = columns.UUID(primary_key=True, default=uuid.uuid4) + row_type = columns.Integer(polymorphic_key=True, index=True) + + +class IndexedPoly1(IndexedPolyBase): + __polymorphic_key__ = 1 + data1 = columns.Text() + + +class IndexedPoly2(IndexedPolyBase): + __polymorphic_key__ = 2 + data2 = columns.Text() + + +class TestIndexedPolymorphicQuery(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestIndexedPolymorphicQuery, cls).setUpClass() + management.sync_table(IndexedPoly1) + management.sync_table(IndexedPoly2) + + cls.p1 = IndexedPoly1.create(data1='pickle') + cls.p2 = IndexedPoly2.create(partition=cls.p1.partition, 
data2='bacon') + + @classmethod + def tearDownClass(cls): + super(TestIndexedPolymorphicQuery, cls).tearDownClass() + management.drop_table(IndexedPoly1) + management.drop_table(IndexedPoly2) + + def test_success_case(self): + assert len(list(IndexedPoly1.objects(partition=self.p1.partition))) == 1 + assert len(list(IndexedPoly2.objects(partition=self.p1.partition))) == 1 + + diff --git a/cqlengine/tests/model/test_updates.py b/cqlengine/tests/model/test_updates.py new file mode 100644 index 00000000..4ffc817e --- /dev/null +++ b/cqlengine/tests/model/test_updates.py @@ -0,0 +1,91 @@ +from uuid import uuid4 + +from mock import patch +from cqlengine.exceptions import ValidationError + +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from cqlengine import columns +from cqlengine.management import sync_table, drop_table + + +class TestUpdateModel(Model): + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True, default=uuid4) + cluster = columns.UUID(primary_key=True, default=uuid4) + count = columns.Integer(required=False) + text = columns.Text(required=False, index=True) + + +class ModelUpdateTests(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(ModelUpdateTests, cls).setUpClass() + sync_table(TestUpdateModel) + + @classmethod + def tearDownClass(cls): + super(ModelUpdateTests, cls).tearDownClass() + drop_table(TestUpdateModel) + + def test_update_model(self): + """ tests calling update on models with no values passed in """ + m0 = TestUpdateModel.create(count=5, text='monkey') + + # independently save over a new count value, unknown to original instance + m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + m1.count = 6 + m1.save() + + # update the text, and call update + m0.text = 'monkey land' + m0.update() + + # database should reflect both updates + m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + self.assertEqual(m2.count, m1.count) + self.assertEqual(m2.text, m0.text) + + def test_update_values(self): + """ tests calling update on models with values passed in """ + m0 = TestUpdateModel.create(count=5, text='monkey') + + # independently save over a new count value, unknown to original instance + m1 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + m1.count = 6 + m1.save() + + # update the text, and call update + m0.update(text='monkey land') + self.assertEqual(m0.text, 'monkey land') + + # database should reflect both updates + m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) + self.assertEqual(m2.count, m1.count) + self.assertEqual(m2.text, m0.text)
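The update tests in this file hinge on update() writing only the columns that were passed in or actually changed, so untouched columns never clobber concurrent writers, and on kwargs being validated before any query is built. A condensed, hypothetical sketch reusing TestUpdateModel from the top of this file:

```python
from uuid import uuid4

from cqlengine.exceptions import ValidationError

m = TestUpdateModel.create(count=5, text='monkey')

# only 'text' is sent; the untouched 'count' column stays out of the
# UPDATE statement, so a concurrent writer's value survives
m.update(text='monkey land')
assert m.text == 'monkey land'

# kwargs are validated up front: unknown fields and primary keys are
# rejected before any query is issued
for bad_kwargs in ({'numbers': 20}, {'partition': uuid4()}):
    try:
        m.update(**bad_kwargs)
    except ValidationError:
        pass
    else:
        raise AssertionError('expected ValidationError')
```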
""" + m0 = TestUpdateModel.create(count=5, text='monkey') + + with patch.object(self.session, 'execute') as execute: + m0.update() + assert execute.call_count == 0 + + with patch.object(self.session, 'execute') as execute: + m0.update(count=5) + assert execute.call_count == 0 + + def test_invalid_update_kwarg(self): + """ tests that passing in a kwarg to the update method that isn't a column will fail """ + m0 = TestUpdateModel.create(count=5, text='monkey') + with self.assertRaises(ValidationError): + m0.update(numbers=20) + + def test_primary_key_update_failure(self): + """ tests that attempting to update the value of a primary key will fail """ + m0 = TestUpdateModel.create(count=5, text='monkey') + with self.assertRaises(ValidationError): + m0.update(partition=uuid4()) + diff --git a/cqlengine/tests/model/test_validation.py b/cqlengine/tests/model/test_validation.py index e69de29b..8b137891 100644 --- a/cqlengine/tests/model/test_validation.py +++ b/cqlengine/tests/model/test_validation.py @@ -0,0 +1 @@ + diff --git a/cqlengine/tests/model/test_value_lists.py b/cqlengine/tests/model/test_value_lists.py new file mode 100644 index 00000000..306fa28d --- /dev/null +++ b/cqlengine/tests/model/test_value_lists.py @@ -0,0 +1,61 @@ +import random +from cqlengine.tests.base import BaseCassEngTestCase + +from cqlengine.management import sync_table +from cqlengine.management import drop_table +from cqlengine.models import Model +from cqlengine import columns + + +class TestModel(Model): + __keyspace__ = 'test' + id = columns.Integer(primary_key=True) + clustering_key = columns.Integer(primary_key=True, clustering_order='desc') + +class TestClusteringComplexModel(Model): + __keyspace__ = 'test' + id = columns.Integer(primary_key=True) + clustering_key = columns.Integer(primary_key=True, clustering_order='desc') + some_value = columns.Integer() + +class TestClusteringOrder(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(TestClusteringOrder, cls).setUpClass() + sync_table(TestModel) + + @classmethod + def tearDownClass(cls): + super(TestClusteringOrder, cls).tearDownClass() + drop_table(TestModel) + + def test_clustering_order(self): + """ + Tests that models can be saved and retrieved + """ + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) + + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] + self.assertEquals(values, sorted(items, reverse=True)) + + def test_clustering_order_more_complex(self): + """ + Tests that models can be saved and retrieved + """ + sync_table(TestClusteringComplexModel) + + items = list(range(20)) + random.shuffle(items) + for i in items: + TestClusteringComplexModel.create(id=1, clustering_key=i, some_value=2) + + values = list(TestClusteringComplexModel.objects.values_list('some_value', flat=True)) + + self.assertEquals([2] * 20, values) + drop_table(TestClusteringComplexModel) + diff --git a/cqlengine/tests/operators/__init__.py b/cqlengine/tests/operators/__init__.py new file mode 100644 index 00000000..f6150a6d --- /dev/null +++ b/cqlengine/tests/operators/__init__.py @@ -0,0 +1 @@ +__author__ = 'bdeggleston' diff --git a/cqlengine/tests/operators/test_assignment_operators.py b/cqlengine/tests/operators/test_assignment_operators.py new file mode 100644 index 00000000..e69de29b diff --git a/cqlengine/tests/operators/test_base_operator.py 
b/cqlengine/tests/operators/test_base_operator.py new file mode 100644 index 00000000..af13fdb1 --- /dev/null +++ b/cqlengine/tests/operators/test_base_operator.py @@ -0,0 +1,9 @@ +from unittest import TestCase +from cqlengine.operators import BaseQueryOperator, QueryOperatorException + + +class BaseOperatorTest(TestCase): + + def test_get_operator_cannot_be_called_from_base_class(self): + with self.assertRaises(QueryOperatorException): + BaseQueryOperator.get_operator('*') \ No newline at end of file diff --git a/cqlengine/tests/operators/test_where_operators.py b/cqlengine/tests/operators/test_where_operators.py new file mode 100644 index 00000000..f8f0e8fa --- /dev/null +++ b/cqlengine/tests/operators/test_where_operators.py @@ -0,0 +1,30 @@ +from unittest import TestCase +from cqlengine.operators import * + + +class TestWhereOperators(TestCase): + + def test_symbol_lookup(self): + """ tests where symbols are looked up properly """ + + def check_lookup(symbol, expected): + op = BaseWhereOperator.get_operator(symbol) + self.assertEqual(op, expected) + + check_lookup('EQ', EqualsOperator) + check_lookup('IN', InOperator) + check_lookup('GT', GreaterThanOperator) + check_lookup('GTE', GreaterThanOrEqualOperator) + check_lookup('LT', LessThanOperator) + check_lookup('LTE', LessThanOrEqualOperator) + + def test_operator_rendering(self): + """ tests symbols are rendered properly """ + self.assertEqual("=", unicode(EqualsOperator())) + self.assertEqual("IN", unicode(InOperator())) + self.assertEqual(">", unicode(GreaterThanOperator())) + self.assertEqual(">=", unicode(GreaterThanOrEqualOperator())) + self.assertEqual("<", unicode(LessThanOperator())) + self.assertEqual("<=", unicode(LessThanOrEqualOperator())) + + diff --git a/cqlengine/tests/query/test_batch_query.py b/cqlengine/tests/query/test_batch_query.py index 5b31826a..9f3ecb5d 100644 --- a/cqlengine/tests/query/test_batch_query.py +++ b/cqlengine/tests/query/test_batch_query.py @@ -3,28 +3,35 @@ from unittest import skip from uuid import uuid4 import random from cqlengine import Model, columns -from cqlengine.management import delete_table, create_table -from cqlengine.query import BatchQuery +from cqlengine.management import drop_table, sync_table +from cqlengine.query import BatchQuery, DMLQuery from cqlengine.tests.base import BaseCassEngTestCase class TestMultiKeyModel(Model): + __keyspace__ = 'test' partition = columns.Integer(primary_key=True) cluster = columns.Integer(primary_key=True) count = columns.Integer(required=False) text = columns.Text(required=False) +class BatchQueryLogModel(Model): + __keyspace__ = 'test' + # simple k/v table + k = columns.Integer(primary_key=True) + v = columns.Integer() + class BatchQueryTests(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(BatchQueryTests, cls).setUpClass() - delete_table(TestMultiKeyModel) - create_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) @classmethod def tearDownClass(cls): super(BatchQueryTests, cls).tearDownClass() - delete_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) def setUp(self): super(BatchQueryTests, self).setUp() @@ -104,3 +111,61 @@ class BatchQueryTests(BaseCassEngTestCase): for m in TestMultiKeyModel.all(): m.delete() + def test_none_success_case(self): + """ Tests that passing None into the batch call clears any batch object """ + b = BatchQuery() + + q = TestMultiKeyModel.objects.batch(b) + assert q._batch == b + + q = q.batch(None) + assert q._batch is None + + def 
test_dml_none_success_case(self):
+        """ Tests that passing None into the batch call clears any batch object """
+        b = BatchQuery()
+
+        q = DMLQuery(TestMultiKeyModel, batch=b)
+        assert q._batch == b
+
+        q.batch(None)
+        assert q._batch is None
+
+    def test_batch_execute_on_exception_succeeds(self):
+        # makes sure if execute_on_exception == True we still apply the batch
+        drop_table(BatchQueryLogModel)
+        sync_table(BatchQueryLogModel)
+
+        obj = BatchQueryLogModel.objects(k=1)
+        self.assertEqual(0, len(obj))
+
+        try:
+            with BatchQuery(execute_on_exception=True) as b:
+                BatchQueryLogModel.batch(b).create(k=1, v=1)
+                raise Exception("Blah")
+        except:
+            pass
+
+        obj = BatchQueryLogModel.objects(k=1)
+        # should be 1 because the batch should execute
+        self.assertEqual(1, len(obj))
+
+    def test_batch_execute_on_exception_skips_if_not_specified(self):
+        # makes sure the batch is not applied if execute_on_exception is not set
+        drop_table(BatchQueryLogModel)
+        sync_table(BatchQueryLogModel)
+
+        obj = BatchQueryLogModel.objects(k=2)
+        self.assertEqual(0, len(obj))
+
+        try:
+            with BatchQuery() as b:
+                BatchQueryLogModel.batch(b).create(k=2, v=2)
+                raise Exception("Blah")
+        except:
+            pass
+
+        obj = BatchQueryLogModel.objects(k=2)
+
+        # should be 0 because the batch should not execute
+        self.assertEqual(0, len(obj))
diff --git a/cqlengine/tests/query/test_datetime_queries.py b/cqlengine/tests/query/test_datetime_queries.py
index e374bd7d..8179b0f2 100644
--- a/cqlengine/tests/query/test_datetime_queries.py
+++ b/cqlengine/tests/query/test_datetime_queries.py
@@ -4,13 +4,14 @@ from uuid import uuid4
 from cqlengine.tests.base import BaseCassEngTestCase
 from cqlengine.exceptions import ModelException
-from cqlengine.management import create_table
-from cqlengine.management import delete_table
+from cqlengine.management import sync_table
+from cqlengine.management import drop_table
 from cqlengine.models import Model
 from cqlengine import columns
 from cqlengine import query
 
 class DateTimeQueryTestModel(Model):
+    __keyspace__ = 'test'
     user = columns.Integer(primary_key=True)
     day = columns.DateTime(primary_key=True)
     data = columns.Text()
@@ -20,7 +21,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
     @classmethod
     def setUpClass(cls):
         super(TestDateTimeQueries, cls).setUpClass()
-        create_table(DateTimeQueryTestModel)
+        sync_table(DateTimeQueryTestModel)
 
         cls.base_date = datetime.now() - timedelta(days=10)
         for x in range(7):
@@ -35,7 +36,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
     @classmethod
     def tearDownClass(cls):
         super(TestDateTimeQueries, cls).tearDownClass()
-        delete_table(DateTimeQueryTestModel)
+        drop_table(DateTimeQueryTestModel)
 
     def test_range_query(self):
         """ Tests that loading from a range of dates works properly """
@@ -43,7 +44,7 @@ class TestDateTimeQueries(BaseCassEngTestCase):
         end = start + timedelta(days=3)
 
         results = DateTimeQueryTestModel.filter(user=0, day__gte=start, day__lt=end)
-        assert len(results) == 3
+        assert len(results) == 3
 
     def test_datetime_precision(self):
         """ Tests that millisecond resolution is preserved when saving datetime objects """
diff --git a/cqlengine/tests/query/test_named.py b/cqlengine/tests/query/test_named.py
new file mode 100644
index 00000000..38df1542
--- /dev/null
+++ b/cqlengine/tests/query/test_named.py
@@ -0,0 +1,246 @@
+from cqlengine import operators
+from cqlengine.named import NamedKeyspace
+from cqlengine.operators import EqualsOperator, GreaterThanOrEqualOperator
+from cqlengine.query import ResultObject
+from cqlengine.tests.query.test_queryset 
import BaseQuerySetUsage
+from cqlengine.tests.base import BaseCassEngTestCase
+
+
+class TestQuerySetOperation(BaseCassEngTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestQuerySetOperation, cls).setUpClass()
+        cls.keyspace = NamedKeyspace('cqlengine_test')
+        cls.table = cls.keyspace.table('test_model')
+
+    def test_query_filter_parsing(self):
+        """
+        Tests the queryset filter method parses its kwargs properly
+        """
+        query1 = self.table.objects(test_id=5)
+        assert len(query1._where) == 1
+
+        op = query1._where[0]
+        assert isinstance(op.operator, operators.EqualsOperator)
+        assert op.value == 5
+
+        query2 = query1.filter(expected_result__gte=1)
+        assert len(query2._where) == 2
+
+        op = query2._where[1]
+        assert isinstance(op.operator, operators.GreaterThanOrEqualOperator)
+        assert op.value == 1
+
+    def test_query_expression_parsing(self):
+        """ Tests that query expressions are evaluated properly """
+        query1 = self.table.filter(self.table.column('test_id') == 5)
+        assert len(query1._where) == 1
+
+        op = query1._where[0]
+        assert isinstance(op.operator, operators.EqualsOperator)
+        assert op.value == 5
+
+        query2 = query1.filter(self.table.column('expected_result') >= 1)
+        assert len(query2._where) == 2
+
+        op = query2._where[1]
+        assert isinstance(op.operator, operators.GreaterThanOrEqualOperator)
+        assert op.value == 1
+
+    def test_filter_method_where_clause_generation(self):
+        """
+        Tests the where clause creation
+        """
+        query1 = self.table.objects(test_id=5)
+        self.assertEqual(len(query1._where), 1)
+        where = query1._where[0]
+        self.assertEqual(where.field, 'test_id')
+        self.assertEqual(where.value, 5)
+
+        query2 = query1.filter(expected_result__gte=1)
+        self.assertEqual(len(query2._where), 2)
+
+        where = query2._where[0]
+        self.assertEqual(where.field, 'test_id')
+        self.assertIsInstance(where.operator, EqualsOperator)
+        self.assertEqual(where.value, 5)
+
+        where = query2._where[1]
+        self.assertEqual(where.field, 'expected_result')
+        self.assertIsInstance(where.operator, GreaterThanOrEqualOperator)
+        self.assertEqual(where.value, 1)
+
+    def test_query_expression_where_clause_generation(self):
+        """
+        Tests the where clause creation
+        """
+        query1 = self.table.objects(self.table.column('test_id') == 5)
+        self.assertEqual(len(query1._where), 1)
+        where = query1._where[0]
+        self.assertEqual(where.field, 'test_id')
+        self.assertEqual(where.value, 5)
+
+        query2 = query1.filter(self.table.column('expected_result') >= 1)
+        self.assertEqual(len(query2._where), 2)
+
+        where = query2._where[0]
+        self.assertEqual(where.field, 'test_id')
+        self.assertIsInstance(where.operator, EqualsOperator)
+        self.assertEqual(where.value, 5)
+
+        where = query2._where[1]
+        self.assertEqual(where.field, 'expected_result')
+        self.assertIsInstance(where.operator, GreaterThanOrEqualOperator)
+        self.assertEqual(where.value, 1)
+
+
+class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage):
+
+    @classmethod
+    def setUpClass(cls):
+        super(TestQuerySetCountSelectionAndIteration, cls).setUpClass()
+
+        from cqlengine.tests.query.test_queryset import TestModel
+
+        ks,tn = TestModel.column_family_name().split('.')
+        cls.keyspace = NamedKeyspace(ks)
+        cls.table = cls.keyspace.table(tn)
+
+    def test_count(self):
+        """ Tests that adding filtering statements affects the count query as expected """
+        assert self.table.objects.count() == 12
+
+        q = self.table.objects(test_id=0)
+        assert q.count() == 4
+
+    def test_query_expression_count(self):
+        """ Tests that adding query statements affects the count query as 
expected """ + assert self.table.objects.count() == 12 + + q = self.table.objects(self.table.column('test_id') == 0) + assert q.count() == 4 + + def test_iteration(self): + """ Tests that iterating over a query set pulls back all of the expected results """ + q = self.table.objects(test_id=0) + #tuple of expected attempt_id, expected_result values + compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with regular filtering + q = self.table.objects(attempt_id=3).allow_filtering() + assert len(q) == 3 + #tuple of expected test_id, expected_result values + compare_set = set([(0,20), (1,20), (2,75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with query method + q = self.table.objects(self.table.column('attempt_id') == 3).allow_filtering() + assert len(q) == 3 + #tuple of expected test_id, expected_result values + compare_set = set([(0,20), (1,20), (2,75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + def test_multiple_iterations_work_properly(self): + """ Tests that iterating over a query set more than once works """ + # test with both the filtering method and the query method + for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)): + #tuple of expected attempt_id, expected_result values + compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + #try it again + compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + def test_multiple_iterators_are_isolated(self): + """ + tests that the use of one iterator does not affect the behavior of another + """ + for q in (self.table.objects(test_id=0), self.table.objects(self.table.column('test_id') == 0)): + q = q.order_by('attempt_id') + expected_order = [0,1,2,3] + iter1 = iter(q) + iter2 = iter(q) + for attempt_id in expected_order: + assert iter1.next().attempt_id == attempt_id + assert iter2.next().attempt_id == attempt_id + + def test_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = self.table.objects.get(test_id=0, attempt_id=0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(test_id=0, attempt_id=0) + m = q.get() + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(test_id=0) + m = q.get(attempt_id=0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + def test_query_expression_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = self.table.get(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = self.table.objects(self.table.column('test_id') == 0, self.table.column('attempt_id') == 0) + m = q.get() + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = 
self.table.objects(self.table.column('test_id') == 0) + m = q.get(self.table.column('attempt_id') == 0) + assert isinstance(m, ResultObject) + assert m.test_id == 0 + assert m.attempt_id == 0 + + def test_get_doesnotexist_exception(self): + """ + Tests that get calls that don't return a result raises a DoesNotExist error + """ + with self.assertRaises(self.table.DoesNotExist): + self.table.objects.get(test_id=100) + + def test_get_multipleobjects_exception(self): + """ + Tests that get calls that return multiple results raise a MultipleObjectsReturned error + """ + with self.assertRaises(self.table.MultipleObjectsReturned): + self.table.objects.get(test_id=1) + + diff --git a/cqlengine/tests/query/test_queryoperators.py b/cqlengine/tests/query/test_queryoperators.py index 5db4a299..5572d00c 100644 --- a/cqlengine/tests/query/test_queryoperators.py +++ b/cqlengine/tests/query/test_queryoperators.py @@ -1,10 +1,13 @@ from datetime import datetime -import time +from cqlengine.columns import DateTime from cqlengine.tests.base import BaseCassEngTestCase from cqlengine import columns, Model from cqlengine import functions from cqlengine import query +from cqlengine.statements import WhereClause +from cqlengine.operators import EqualsOperator +from cqlengine.management import sync_table, drop_table class TestQuerySetOperation(BaseCassEngTestCase): @@ -13,22 +16,96 @@ class TestQuerySetOperation(BaseCassEngTestCase): Tests that queries with helper functions are generated properly """ now = datetime.now() - col = columns.DateTime() - col.set_column_name('time') - qry = query.EqualsOperator(col, functions.MaxTimeUUID(now)) + where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now)) + where.set_context_id(5) - assert qry.cql == '"time" = MaxTimeUUID(:{})'.format(qry.identifier) + self.assertEqual(str(where), '"time" = MaxTimeUUID(%(5)s)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': DateTime().to_database(now)}) def test_mintimeuuid_function(self): """ Tests that queries with helper functions are generated properly """ now = datetime.now() - col = columns.DateTime() - col.set_column_name('time') - qry = query.EqualsOperator(col, functions.MinTimeUUID(now)) + where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now)) + where.set_context_id(5) - assert qry.cql == '"time" = MinTimeUUID(:{})'.format(qry.identifier) + self.assertEqual(str(where), '"time" = MinTimeUUID(%(5)s)') + ctx = {} + where.update_context(ctx) + self.assertEqual(ctx, {'5': DateTime().to_database(now)}) +class TokenTestModel(Model): + __keyspace__ = 'test' + key = columns.Integer(primary_key=True) + val = columns.Integer() + +class TestTokenFunction(BaseCassEngTestCase): + + def setUp(self): + super(TestTokenFunction, self).setUp() + sync_table(TokenTestModel) + + def tearDown(self): + super(TestTokenFunction, self).tearDown() + drop_table(TokenTestModel) + + def test_token_function(self): + """ Tests that token functions work properly """ + assert TokenTestModel.objects().count() == 0 + for i in range(10): + TokenTestModel.create(key=i, val=i) + assert TokenTestModel.objects().count() == 10 + seen_keys = set() + last_token = None + for instance in TokenTestModel.objects().limit(5): + last_token = instance.key + seen_keys.add(last_token) + assert len(seen_keys) == 5 + for instance in TokenTestModel.objects(pk__token__gt=functions.Token(last_token)): + seen_keys.add(instance.key) + + assert len(seen_keys) == 10 + assert all([i in seen_keys for i in range(10)]) + + def 
test_compound_pk_token_function(self): + + class TestModel(Model): + __keyspace__ = 'test' + p1 = columns.Text(partition_key=True) + p2 = columns.Text(partition_key=True) + + func = functions.Token('a', 'b') + + q = TestModel.objects.filter(pk__token__gt=func) + where = q._where[0] + where.set_context_id(1) + self.assertEquals(str(where), 'token("p1", "p2") > token(%({})s, %({})s)'.format(1, 2)) + + # Verify that a SELECT query can be successfully generated + str(q._select_query()) + + # Token(tuple()) is also possible for convenience + # it (allows for Token(obj.pk) syntax) + func = functions.Token(('a', 'b')) + + q = TestModel.objects.filter(pk__token__gt=func) + where = q._where[0] + where.set_context_id(1) + self.assertEquals(str(where), 'token("p1", "p2") > token(%({})s, %({})s)'.format(1, 2)) + str(q._select_query()) + + # The 'pk__token' virtual column may only be compared to a Token + self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=10) + + # A Token may only be compared to the `pk__token' virtual column + func = functions.Token('a', 'b') + self.assertRaises(query.QueryException, TestModel.objects.filter, p1__gt=func) + + # The # of arguments to Token must match the # of partition keys + func = functions.Token('a') + self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=func) diff --git a/cqlengine/tests/query/test_queryset.py b/cqlengine/tests/query/test_queryset.py index a77d7502..1c9a3ed5 100644 --- a/cqlengine/tests/query/test_queryset.py +++ b/cqlengine/tests/query/test_queryset.py @@ -1,39 +1,74 @@ from datetime import datetime import time +from unittest import TestCase, skipUnless from uuid import uuid1, uuid4 +import uuid from cqlengine.tests.base import BaseCassEngTestCase - +import mock from cqlengine.exceptions import ModelException from cqlengine import functions -from cqlengine.management import create_table -from cqlengine.management import delete_table +from cqlengine.management import sync_table, drop_table, sync_table +from cqlengine.management import drop_table from cqlengine.models import Model from cqlengine import columns from cqlengine import query +from datetime import timedelta +from datetime import tzinfo + +from cqlengine import statements +from cqlengine import operators + + +from cqlengine.connection import get_cluster, get_session + +cluster = get_cluster() + + +class TzOffset(tzinfo): + """Minimal implementation of a timezone offset to help testing with timezone + aware datetimes. 
+ """ + + def __init__(self, offset): + self._offset = timedelta(hours=offset) + + def utcoffset(self, dt): + return self._offset + + def tzname(self, dt): + return 'TzOffset: {}'.format(self._offset.hours) + + def dst(self, dt): + return timedelta(0) + class TestModel(Model): + __keyspace__ = 'test' test_id = columns.Integer(primary_key=True) attempt_id = columns.Integer(primary_key=True) description = columns.Text() expected_result = columns.Integer() test_result = columns.Integer() + class IndexedTestModel(Model): + __keyspace__ = 'test' test_id = columns.Integer(primary_key=True) attempt_id = columns.Integer(index=True) description = columns.Text() expected_result = columns.Integer() test_result = columns.Integer(index=True) + class TestMultiClusteringModel(Model): + __keyspace__ = 'test' one = columns.Integer(primary_key=True) two = columns.Integer(primary_key=True) three = columns.Integer(primary_key=True) class TestQuerySetOperation(BaseCassEngTestCase): - def test_query_filter_parsing(self): """ Tests the queryset filter method parses it's kwargs properly @@ -42,14 +77,35 @@ class TestQuerySetOperation(BaseCassEngTestCase): assert len(query1._where) == 1 op = query1._where[0] - assert isinstance(op, query.EqualsOperator) + + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.EqualsOperator) assert op.value == 5 query2 = query1.filter(expected_result__gte=1) assert len(query2._where) == 2 op = query2._where[1] - assert isinstance(op, query.GreaterThanOrEqualOperator) + self.assertIsInstance(op, statements.WhereClause) + self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) + assert op.value == 1 + + def test_query_expression_parsing(self): + """ Tests that query experessions are evaluated properly """ + query1 = TestModel.filter(TestModel.test_id == 5) + assert len(query1._where) == 1 + + op = query1._where[0] + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.EqualsOperator) + assert op.value == 5 + + query2 = query1.filter(TestModel.expected_result >= 1) + assert len(query2._where) == 2 + + op = query2._where[1] + self.assertIsInstance(op, statements.WhereClause) + self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) assert op.value == 1 def test_using_invalid_column_names_in_filter_kwargs_raises_error(self): @@ -57,27 +113,21 @@ class TestQuerySetOperation(BaseCassEngTestCase): Tests that using invalid or nonexistant column names for filter args raises an error """ with self.assertRaises(query.QueryException): - query0 = TestModel.objects(nonsense=5) + TestModel.objects(nonsense=5) - def test_where_clause_generation(self): + def test_using_nonexistant_column_names_in_query_args_raises_error(self): """ - Tests the where clause creation + Tests that using invalid or nonexistant columns for query args raises an error """ - query1 = TestModel.objects(test_id=5) - ids = [o.identifier for o in query1._where] - where = query1._where_clause() - assert where == '"test_id" = :{}'.format(*ids) + with self.assertRaises(AttributeError): + TestModel.objects(TestModel.nonsense == 5) - query2 = query1.filter(expected_result__gte=1) - ids = [o.identifier for o in query2._where] - where = query2._where_clause() - assert where == '"test_id" = :{} AND "expected_result" >= :{}'.format(*ids) - - - def test_querystring_generation(self): + def test_using_non_query_operators_in_query_args_raises_error(self): """ - Tests the select querystring creation + Tests that providing query args that 
are not query operator instances raises an error """ + with self.assertRaises(query.QueryException): + TestModel.objects(5) def test_queryset_is_immutable(self): """ @@ -88,10 +138,25 @@ class TestQuerySetOperation(BaseCassEngTestCase): query2 = query1.filter(expected_result__gte=1) assert len(query2._where) == 2 + assert len(query1._where) == 1 - def test_the_all_method_clears_where_filter(self): + def test_queryset_limit_immutability(self): """ - Tests that calling all on a queryset with previously defined filters returns a queryset with no filters + Tests that calling a queryset function that changes it's state returns a new queryset with same limit + """ + query1 = TestModel.objects(test_id=5).limit(1) + assert query1._limit == 1 + + query2 = query1.filter(expected_result__gte=1) + assert query2._limit == 1 + + query3 = query1.filter(expected_result__gte=1).limit(2) + assert query1._limit == 1 + assert query3._limit == 2 + + def test_the_all_method_duplicates_queryset(self): + """ + Tests that calling all on a queryset with previously defined filters duplicates queryset """ query1 = TestModel.objects(test_id=5) assert len(query1._where) == 1 @@ -100,7 +165,7 @@ class TestQuerySetOperation(BaseCassEngTestCase): assert len(query2._where) == 2 query3 = query2.all() - assert len(query3._where) == 0 + assert query3 == query2 def test_defining_only_and_defer_fails(self): """ @@ -112,16 +177,16 @@ class TestQuerySetOperation(BaseCassEngTestCase): Tests that setting only or defer fields that don't exist raises an exception """ -class BaseQuerySetUsage(BaseCassEngTestCase): +class BaseQuerySetUsage(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(BaseQuerySetUsage, cls).setUpClass() - delete_table(TestModel) - delete_table(IndexedTestModel) - create_table(TestModel) - create_table(IndexedTestModel) - create_table(TestMultiClusteringModel) + drop_table(TestModel) + drop_table(IndexedTestModel) + sync_table(TestModel) + sync_table(IndexedTestModel) + sync_table(TestMultiClusteringModel) TestModel.objects.create(test_id=0, attempt_id=0, description='try1', expected_result=5, test_result=30) TestModel.objects.create(test_id=0, attempt_id=1, description='try2', expected_result=10, test_result=30) @@ -149,39 +214,71 @@ class BaseQuerySetUsage(BaseCassEngTestCase): IndexedTestModel.objects.create(test_id=7, attempt_id=3, description='try8', expected_result=20, test_result=20) IndexedTestModel.objects.create(test_id=8, attempt_id=0, description='try9', expected_result=50, test_result=40) - IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60, test_result=40) - IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70, test_result=45) - IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, test_result=45) + IndexedTestModel.objects.create(test_id=9, attempt_id=1, description='try10', expected_result=60, + test_result=40) + IndexedTestModel.objects.create(test_id=10, attempt_id=2, description='try11', expected_result=70, + test_result=45) + IndexedTestModel.objects.create(test_id=11, attempt_id=3, description='try12', expected_result=75, + test_result=45) @classmethod def tearDownClass(cls): super(BaseQuerySetUsage, cls).tearDownClass() - delete_table(TestModel) - delete_table(IndexedTestModel) - delete_table(TestMultiClusteringModel) + drop_table(TestModel) + drop_table(IndexedTestModel) + drop_table(TestMultiClusteringModel) + class 
TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): - def test_count(self): + """ Tests that adding filtering statements affects the count query as expected """ assert TestModel.objects.count() == 12 q = TestModel.objects(test_id=0) assert q.count() == 4 + def test_query_expression_count(self): + """ Tests that adding query statements affects the count query as expected """ + assert TestModel.objects.count() == 12 + + q = TestModel.objects(TestModel.test_id == 0) + assert q.count() == 4 + + def test_query_limit_count(self): + """ Tests that adding query with a limit affects the count as expected """ + assert TestModel.objects.count() == 12 + + q = TestModel.objects(TestModel.test_id == 0).limit(2) + result = q.count() + self.assertEqual(2, result) + def test_iteration(self): + """ Tests that iterating over a query set pulls back all of the expected results """ q = TestModel.objects(test_id=0) #tuple of expected attempt_id, expected_result values - compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) for t in q: val = t.attempt_id, t.expected_result assert val in compare_set compare_set.remove(val) assert len(compare_set) == 0 + # test with regular filtering q = TestModel.objects(attempt_id=3).allow_filtering() assert len(q) == 3 #tuple of expected test_id, expected_result values - compare_set = set([(0,20), (1,20), (2,75)]) + compare_set = set([(0, 20), (1, 20), (2, 75)]) + for t in q: + val = t.test_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 + + # test with query method + q = TestModel.objects(TestModel.attempt_id == 3).allow_filtering() + assert len(q) == 3 + #tuple of expected test_id, expected_result values + compare_set = set([(0, 20), (1, 20), (2, 75)]) for t in q: val = t.test_id, t.expected_result assert val in compare_set @@ -190,34 +287,36 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): def test_multiple_iterations_work_properly(self): """ Tests that iterating over a query set more than once works """ - q = TestModel.objects(test_id=0) - #tuple of expected attempt_id, expected_result values - compare_set = set([(0,5), (1,10), (2,15), (3,20)]) - for t in q: - val = t.attempt_id, t.expected_result - assert val in compare_set - compare_set.remove(val) - assert len(compare_set) == 0 + # test with both the filtering method and the query method + for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): + #tuple of expected attempt_id, expected_result values + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 - #try it again - compare_set = set([(0,5), (1,10), (2,15), (3,20)]) - for t in q: - val = t.attempt_id, t.expected_result - assert val in compare_set - compare_set.remove(val) - assert len(compare_set) == 0 + #try it again + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) + for t in q: + val = t.attempt_id, t.expected_result + assert val in compare_set + compare_set.remove(val) + assert len(compare_set) == 0 def test_multiple_iterators_are_isolated(self): """ tests that the use of one iterator does not affect the behavior of another """ - q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] - iter1 = iter(q) - iter2 = iter(q) - for attempt_id in expected_order: - assert iter1.next().attempt_id == attempt_id - assert 
iter2.next().attempt_id == attempt_id + for q in (TestModel.objects(test_id=0), TestModel.objects(TestModel.test_id == 0)): + q = q.order_by('attempt_id') + expected_order = [0, 1, 2, 3] + iter1 = iter(q) + iter2 = iter(q) + for attempt_id in expected_order: + assert iter1.next().attempt_id == attempt_id + assert iter2.next().attempt_id == attempt_id def test_get_success_case(self): """ @@ -240,6 +339,27 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): assert m.test_id == 0 assert m.attempt_id == 0 + def test_query_expression_get_success_case(self): + """ + Tests that the .get() method works on new and existing querysets + """ + m = TestModel.get(TestModel.test_id == 0, TestModel.attempt_id == 0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(TestModel.test_id == 0, TestModel.attempt_id == 0) + m = q.get() + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + + q = TestModel.objects(TestModel.test_id == 0) + m = q.get(TestModel.attempt_id == 0) + assert isinstance(m, TestModel) + assert m.test_id == 0 + assert m.attempt_id == 0 + def test_get_doesnotexist_exception(self): """ Tests that get calls that don't return a result raises a DoesNotExist error @@ -258,25 +378,52 @@ class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): """ """ + +def test_non_quality_filtering(): + class NonEqualityFilteringModel(Model): + __keyspace__ = 'test' + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) + sequence_id = columns.Integer(primary_key=True) # sequence_id is a clustering key + example_type = columns.Integer(index=True) + created_at = columns.DateTime() + + drop_table(NonEqualityFilteringModel) + sync_table(NonEqualityFilteringModel) + + # setup table, etc. 
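+    # note: sequence_id is a clustering key, so the non-equality predicate used
+    # below (sequence_id > 3) cannot be served by a plain SELECT; the queryset
+    # opts in via .allow_filtering(), which renders CQL's ALLOW FILTERING clause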
+ + NonEqualityFilteringModel.create(sequence_id=1, example_type=0, created_at=datetime.now()) + NonEqualityFilteringModel.create(sequence_id=3, example_type=0, created_at=datetime.now()) + NonEqualityFilteringModel.create(sequence_id=5, example_type=1, created_at=datetime.now()) + + qA = NonEqualityFilteringModel.objects(NonEqualityFilteringModel.sequence_id > 3).allow_filtering() + num = qA.count() + assert num == 1, num + + + class TestQuerySetOrdering(BaseQuerySetUsage): def test_order_by_success_case(self): - q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] - for model, expect in zip(q,expected_order): + expected_order = [0, 1, 2, 3] + for model, expect in zip(q, expected_order): assert model.attempt_id == expect q = q.order_by('-attempt_id') expected_order.reverse() - for model, expect in zip(q,expected_order): + for model, expect in zip(q, expected_order): assert model.attempt_id == expect def test_ordering_by_non_second_primary_keys_fail(self): - + # kwarg filtering with self.assertRaises(query.QueryException): q = TestModel.objects(test_id=0).order_by('test_id') + # kwarg filtering + with self.assertRaises(query.QueryException): + q = TestModel.objects(TestModel.test_id == 0).order_by('test_id') + def test_ordering_by_non_primary_keys_fails(self): with self.assertRaises(query.QueryException): q = TestModel.objects(test_id=0).order_by('description') @@ -303,7 +450,6 @@ class TestQuerySetOrdering(BaseQuerySetUsage): class TestQuerySetSlicing(BaseQuerySetUsage): - def test_out_of_range_index_raises_error(self): q = TestModel.objects(test_id=0).order_by('attempt_id') with self.assertRaises(IndexError): @@ -311,39 +457,39 @@ class TestQuerySetSlicing(BaseQuerySetUsage): def test_array_indexing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] + expected_order = [0, 1, 2, 3] for i in range(len(q)): assert q[i].attempt_id == expected_order[i] def test_negative_indexing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] + expected_order = [0, 1, 2, 3] assert q[-1].attempt_id == expected_order[-1] assert q[-2].attempt_id == expected_order[-2] def test_slicing_works_properly(self): q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] + expected_order = [0, 1, 2, 3] for model, expect in zip(q[1:3], expected_order[1:3]): assert model.attempt_id == expect def test_negative_slicing(self): q = TestModel.objects(test_id=0).order_by('attempt_id') - expected_order = [0,1,2,3] + expected_order = [0, 1, 2, 3] for model, expect in zip(q[-3:], expected_order[-3:]): assert model.attempt_id == expect for model, expect in zip(q[:-1], expected_order[:-1]): assert model.attempt_id == expect -class TestQuerySetValidation(BaseQuerySetUsage): +class TestQuerySetValidation(BaseQuerySetUsage): def test_primary_key_or_index_must_be_specified(self): """ Tests that queries that don't have an equals relation to a primary key or indexed field fail """ with self.assertRaises(query.QueryException): q = TestModel.objects(test_result=25) - [i for i in q] + list([i for i in q]) def test_primary_key_or_index_must_have_equal_relation_filter(self): """ @@ -351,19 +497,17 @@ class TestQuerySetValidation(BaseQuerySetUsage): """ with self.assertRaises(query.QueryException): q = TestModel.objects(test_id__gt=0) - [i for i in q] - + list([i for i in q]) def test_indexed_field_can_be_queried(self): """ Tests that queries on an indexed field will work 
without any primary key relations specified """ q = IndexedTestModel.objects(test_result=25) - count = q.count() assert q.count() == 4 -class TestQuerySetDelete(BaseQuerySetUsage): +class TestQuerySetDelete(BaseQuerySetUsage): def test_delete(self): TestModel.objects.create(test_id=3, attempt_id=0, description='try9', expected_result=50, test_result=40) TestModel.objects.create(test_id=3, attempt_id=1, description='try10', expected_result=60, test_result=40) @@ -388,15 +532,15 @@ class TestQuerySetDelete(BaseQuerySetUsage): with self.assertRaises(query.QueryException): TestModel.objects(attempt_id=0).delete() -class TestQuerySetConnectionHandling(BaseQuerySetUsage): +class TestQuerySetConnectionHandling(BaseQuerySetUsage): def test_conn_is_returned_after_filling_cache(self): """ Tests that the queryset returns it's connection after it's fetched all of it's results """ q = TestModel.objects(test_id=0) #tuple of expected attempt_id, expected_result values - compare_set = set([(0,5), (1,10), (2,15), (3,20)]) + compare_set = set([(0, 5), (1, 10), (2, 15), (3, 20)]) for t in q: val = t.attempt_id, t.expected_result assert val in compare_set @@ -405,35 +549,67 @@ class TestQuerySetConnectionHandling(BaseQuerySetUsage): assert q._con is None assert q._cur is None - def test_conn_is_returned_after_queryset_is_garbage_collected(self): - """ Tests that the connection is returned to the connection pool after the queryset is gc'd """ - from cqlengine.connection import ConnectionPool - # The queue size can be 1 if we just run this file's tests - # It will be 2 when we run 'em all - initial_size = ConnectionPool._queue.qsize() - q = TestModel.objects(test_id=0) - v = q[0] - assert ConnectionPool._queue.qsize() == initial_size - 1 - - del q - assert ConnectionPool._queue.qsize() == initial_size class TimeUUIDQueryModel(Model): - partition = columns.UUID(primary_key=True) - time = columns.TimeUUID(primary_key=True) - data = columns.Text(required=False) + __keyspace__ = 'test' + partition = columns.UUID(primary_key=True) + time = columns.TimeUUID(primary_key=True) + data = columns.Text(required=False) + class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): - @classmethod def setUpClass(cls): super(TestMinMaxTimeUUIDFunctions, cls).setUpClass() - create_table(TimeUUIDQueryModel) + sync_table(TimeUUIDQueryModel) @classmethod def tearDownClass(cls): super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass() - delete_table(TimeUUIDQueryModel) + drop_table(TimeUUIDQueryModel) + + def test_tzaware_datetime_support(self): + """Test that using timezone aware datetime instances works with the + MinTimeUUID/MaxTimeUUID functions. 
+ """ + pk = uuid4() + midpoint_utc = datetime.utcnow().replace(tzinfo=TzOffset(0)) + midpoint_helsinki = midpoint_utc.astimezone(TzOffset(3)) + + # Assert pre-condition that we have the same logical point in time + assert midpoint_utc.utctimetuple() == midpoint_helsinki.utctimetuple() + assert midpoint_utc.timetuple() != midpoint_helsinki.timetuple() + + TimeUUIDQueryModel.create( + partition=pk, + time=columns.TimeUUID.from_datetime(midpoint_utc - timedelta(minutes=1)), + data='1') + + TimeUUIDQueryModel.create( + partition=pk, + time=columns.TimeUUID.from_datetime(midpoint_utc), + data='2') + + TimeUUIDQueryModel.create( + partition=pk, + time=columns.TimeUUID.from_datetime(midpoint_utc + timedelta(minutes=1)), + data='3') + + assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc))] + + assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki))] + + assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc))] + + assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))] def test_success_case(self): """ Test that the min and max time uuid functions work as expected """ @@ -449,26 +625,85 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase): TimeUUIDQueryModel.create(partition=pk, time=uuid1(), data='4') time.sleep(0.2) + # test kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) q = [d for d in q] assert len(q) == 2 datas = [d.data for d in q] - assert '1' in datas - assert '2' in datas + assert '1' in datas + assert '2' in datas q = TimeUUIDQueryModel.filter(partition=pk, time__gte=functions.MinTimeUUID(midpoint)) assert len(q) == 2 datas = [d.data for d in q] - assert '3' in datas - assert '4' in datas + assert '3' in datas + assert '4' in datas + + # test query expression filtering + q = TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint) + ) + q = [d for d in q] + assert len(q) == 2 + datas = [d.data for d in q] + assert '1' in datas + assert '2' in datas + + q = TimeUUIDQueryModel.filter( + TimeUUIDQueryModel.partition == pk, + TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint) + ) + assert len(q) == 2 + datas = [d.data for d in q] + assert '3' in datas + assert '4' in datas class TestInOperator(BaseQuerySetUsage): + def test_kwarg_success_case(self): + """ Tests the in operator works with the kwarg query method """ + q = TestModel.filter(test_id__in=[0, 1]) + assert q.count() == 8 - def test_success_case(self): - q = TestModel.filter(test_id__in=[0,1]) + def test_query_expression_success_case(self): + """ Tests the in operator works with the query expression query method """ + q = TestModel.filter(TestModel.test_id.in_([0, 1])) assert q.count() == 8 +class TestValuesList(BaseQuerySetUsage): + def test_values_list(self): + q = TestModel.objects.filter(test_id=0, attempt_id=1) + item = q.values_list('test_id', 'attempt_id', 'description', 'expected_result', 'test_result').first() + assert item == [0, 1, 'try2', 10, 30] + + item = q.values_list('expected_result', flat=True).first() + assert item == 10 + 
+
+class TestObjectsProperty(BaseQuerySetUsage):
+    def test_objects_property_returns_fresh_queryset(self):
+        assert TestModel.objects._result_cache is None
+        len(TestModel.objects) # evaluate queryset
+        assert TestModel.objects._result_cache is None
+
+
+@skipUnless(cluster.protocol_version >= 2, "only runs against the cql3 protocol v2.0")
+def test_paged_result_handling():
+    # addresses #225
+    class PagingTest(Model):
+        id = columns.Integer(primary_key=True)
+        val = columns.Integer()
+    sync_table(PagingTest)
+
+    PagingTest.create(id=1, val=1)
+    PagingTest.create(id=2, val=2)
+
+    session = get_session()
+    with mock.patch.object(session, 'default_fetch_size', 1):
+        results = PagingTest.objects()[:]
+
+    assert len(results) == 2
diff --git a/cqlengine/tests/query/test_updates.py b/cqlengine/tests/query/test_updates.py
new file mode 100644
index 00000000..1037cf5a
--- /dev/null
+++ b/cqlengine/tests/query/test_updates.py
@@ -0,0 +1,219 @@
+from uuid import uuid4
+from cqlengine.exceptions import ValidationError
+from cqlengine.query import QueryException
+
+from cqlengine.tests.base import BaseCassEngTestCase
+from cqlengine.models import Model
+from cqlengine.management import sync_table, drop_table
+from cqlengine import columns
+
+
+class TestQueryUpdateModel(Model):
+    __keyspace__ = 'test'
+    partition = columns.UUID(primary_key=True, default=uuid4)
+    cluster = columns.Integer(primary_key=True)
+    count = columns.Integer(required=False)
+    text = columns.Text(required=False, index=True)
+    text_set = columns.Set(columns.Text, required=False)
+    text_list = columns.List(columns.Text, required=False)
+    text_map = columns.Map(columns.Text, columns.Text, required=False)
+
+class QueryUpdateTests(BaseCassEngTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        super(QueryUpdateTests, cls).setUpClass()
+        sync_table(TestQueryUpdateModel)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(QueryUpdateTests, cls).tearDownClass()
+        drop_table(TestQueryUpdateModel)
+
+    def test_update_values(self):
+        """ tests calling update on a queryset """
+        partition = uuid4()
+        for i in range(5):
+            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
+
+        # sanity check
+        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
+            assert row.cluster == i
+            assert row.count == i
+            assert row.text == str(i)
+
+        # perform update
+        TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6)
+
+        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
+            assert row.cluster == i
+            assert row.count == (6 if i == 3 else i)
+            assert row.text == str(i)
+
+    def test_update_values_validation(self):
+        """ tests calling update on models with values passed in """
+        partition = uuid4()
+        for i in range(5):
+            TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i))
+
+        # sanity check
+        for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)):
+            assert row.cluster == i
+            assert row.count == i
+            assert row.text == str(i)
+
+        # perform update
+        with self.assertRaises(ValidationError):
+            TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf')
+
+    def test_invalid_update_kwarg(self):
+        """ tests that passing in a kwarg to the update method that isn't a column will fail """
+        with self.assertRaises(ValidationError):
+            TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000)
+
+    def test_primary_key_update_failure(self):
+        """ tests that attempting to update the value of a 
primary key will fail """ + with self.assertRaises(ValidationError): + TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000) + + def test_null_update_deletes_column(self): + """ setting a field to null in the update should issue a delete statement """ + partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + assert row.cluster == i + assert row.count == i + assert row.text == str(i) + + # perform update + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None) + + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + assert row.cluster == i + assert row.count == i + assert row.text == (None if i == 3 else str(i)) + + def test_mixed_value_and_null_update(self): + """ tests that updating a columns value, and removing another works properly """ + partition = uuid4() + for i in range(5): + TestQueryUpdateModel.create(partition=partition, cluster=i, count=i, text=str(i)) + + # sanity check + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + assert row.cluster == i + assert row.count == i + assert row.text == str(i) + + # perform update + TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None) + + for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): + assert row.cluster == i + assert row.count == (6 if i == 3 else i) + assert row.text == (None if i == 3 else str(i)) + + def test_counter_updates(self): + pass + + def test_set_add_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set={"foo"}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update(text_set__add={'bar'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, {"foo", "bar"}) + + def test_set_add_updates_new_record(self): + """ If the key doesn't exist yet, an update creates the record + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update(text_set__add={'bar'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, {"bar"}) + + def test_set_remove_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set={"foo", "baz"}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_set__remove={'foo'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, {"baz"}) + + def test_set_remove_new_record(self): + """ Removing something not in the set should silently do nothing + """ + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_set={"foo"}) + TestQueryUpdateModel.objects( + partition=partition, cluster=cluster).update( + text_set__remove={'afsd'}) + obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) + self.assertEqual(obj.text_set, {"foo"}) + + def test_list_append_updates(self): + partition = uuid4() + cluster = 1 + TestQueryUpdateModel.objects.create( + partition=partition, cluster=cluster, text_list=["foo"]) + TestQueryUpdateModel.objects( + partition=partition, 
cluster=cluster).update(
+                text_list__append=['bar'])
+        obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
+        self.assertEqual(obj.text_list, ["foo", "bar"])
+
+    def test_list_prepend_updates(self):
+        """ Prepend two things since order is reversed by default by CQL """
+        partition = uuid4()
+        cluster = 1
+        TestQueryUpdateModel.objects.create(
+                partition=partition, cluster=cluster, text_list=["foo"])
+        TestQueryUpdateModel.objects(
+                partition=partition, cluster=cluster).update(
+                text_list__prepend=['bar', 'baz'])
+        obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
+        self.assertEqual(obj.text_list, ["bar", "baz", "foo"])
+
+    def test_map_update_updates(self):
+        """ Merge a dictionary into existing value """
+        partition = uuid4()
+        cluster = 1
+        TestQueryUpdateModel.objects.create(
+                partition=partition, cluster=cluster,
+                text_map={"foo": '1', "bar": '2'})
+        TestQueryUpdateModel.objects(
+                partition=partition, cluster=cluster).update(
+                text_map__update={"bar": '3', "baz": '4'})
+        obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
+        self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'})
+
+    def test_map_update_none_deletes_key(self):
+        """ The CQL behavior is that setting a map key to null deletes
+        that key from the map. Test that this works with __update.
+
+        This test fails because of a bug in the cql python library not
+        converting None to null (and the cql library is no longer in active
+        development).
+        """
+        # partition = uuid4()
+        # cluster = 1
+        # TestQueryUpdateModel.objects.create(
+        #     partition=partition, cluster=cluster,
+        #     text_map={"foo": '1', "bar": '2'})
+        # TestQueryUpdateModel.objects(
+        #     partition=partition, cluster=cluster).update(
+        #     text_map__update={"bar": None})
+        # obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster)
+        # self.assertEqual(obj.text_map, {"foo": '1'})
diff --git a/cqlengine/tests/statements/__init__.py b/cqlengine/tests/statements/__init__.py
new file mode 100644
index 00000000..f6150a6d
--- /dev/null
+++ b/cqlengine/tests/statements/__init__.py
@@ -0,0 +1 @@
+__author__ = 'bdeggleston'
diff --git a/cqlengine/tests/statements/test_assignment_clauses.py b/cqlengine/tests/statements/test_assignment_clauses.py
new file mode 100644
index 00000000..32da18b2
--- /dev/null
+++ b/cqlengine/tests/statements/test_assignment_clauses.py
@@ -0,0 +1,327 @@
+from unittest import TestCase
+from cqlengine.statements import AssignmentClause, SetUpdateClause, ListUpdateClause, MapUpdateClause, MapDeleteClause, FieldDeleteClause, CounterUpdateClause
+
+
+class AssignmentClauseTests(TestCase):
+
+    def test_rendering(self):
+        pass
+
+    def test_insert_tuple(self):
+        ac = AssignmentClause('a', 'b')
+        ac.set_context_id(10)
+        self.assertEqual(ac.insert_tuple(), ('a', 10))
+
+
+class SetUpdateClauseTests(TestCase):
+
+    def test_update_from_none(self):
+        c = SetUpdateClause('s', {1, 2}, previous=None)
+        c._analyze()
+        c.set_context_id(0)
+
+        self.assertEqual(c._assignments, {1, 2})
+        self.assertIsNone(c._additions)
+        self.assertIsNone(c._removals)
+
+        self.assertEqual(c.get_context_size(), 1)
+        self.assertEqual(str(c), '"s" = %(0)s')
+
+        ctx = {}
+        c.update_context(ctx)
+        self.assertEqual(ctx, {'0': {1, 2}})
+
+    def test_null_update(self):
+        """ tests setting a set to None creates an empty update statement """
+        c = SetUpdateClause('s', None, previous={1, 2})
+        c._analyze()
+        c.set_context_id(0)
+
+        self.assertIsNone(c._assignments)
+        
self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 0) + self.assertEqual(str(c), '') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {}) + + def test_no_update(self): + """ tests an unchanged value creates an empty update statement """ + c = SetUpdateClause('s', {1, 2}, previous={1, 2}) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._additions) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 0) + self.assertEqual(str(c), '') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {}) + + def test_additions(self): + c = SetUpdateClause('s', {1, 2, 3}, previous={1, 2}) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._additions, {3}) + self.assertIsNone(c._removals) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" + %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': {3}}) + + def test_removals(self): + c = SetUpdateClause('s', {1, 2}, previous={1, 2, 3}) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._additions) + self.assertEqual(c._removals, {3}) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" - %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': {3}}) + + def test_additions_and_removals(self): + c = SetUpdateClause('s', {2, 3}, previous={1, 2}) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._additions, {3}) + self.assertEqual(c._removals, {1}) + + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s" = "s" + %(0)s, "s" = "s" - %(1)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': {3}, '1': {1}}) + + +class ListUpdateClauseTests(TestCase): + + def test_update_from_none(self): + c = ListUpdateClause('s', [1, 2, 3]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_update_from_empty(self): + c = ListUpdateClause('s', [1, 2, 3], previous=[]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_update_from_different_list(self): + c = ListUpdateClause('s', [1, 2, 3], previous=[3, 2, 1]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + def test_append(self): + c = ListUpdateClause('s', [1, 2, 3, 4], previous=[1, 2]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._append, [3, 4]) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = "s" + %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [3, 4]}) + + def 
test_prepend(self): + c = ListUpdateClause('s', [1, 2, 3, 4], previous=[3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertIsNone(c._append) + self.assertEqual(c._prepend, [1, 2]) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s + "s"') + + ctx = {} + c.update_context(ctx) + # test context list reversal + self.assertEqual(ctx, {'0': [2, 1]}) + + def test_append_and_prepend(self): + c = ListUpdateClause('s', [1, 2, 3, 4, 5, 6], previous=[3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertIsNone(c._assignments) + self.assertEqual(c._append, [5, 6]) + self.assertEqual(c._prepend, [1, 2]) + + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s" = %(0)s + "s", "s" = "s" + %(1)s') + + ctx = {} + c.update_context(ctx) + # test context list reversal + self.assertEqual(ctx, {'0': [2, 1], '1': [5, 6]}) + + def test_shrinking_list_update(self): + """ tests that updating to a smaller list results in an insert statement """ + c = ListUpdateClause('s', [1, 2, 3], previous=[1, 2, 3, 4]) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._assignments, [1, 2, 3]) + self.assertIsNone(c._append) + self.assertIsNone(c._prepend) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"s" = %(0)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': [1, 2, 3]}) + + +class MapUpdateTests(TestCase): + + def test_update(self): + c = MapUpdateClause('s', {3: 0, 5: 6}, previous={5: 0, 3: 4}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._updates, [3, 5]) + self.assertEqual(c.get_context_size(), 4) + self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) + + def test_update_from_null(self): + c = MapUpdateClause('s', {3: 0, 5: 6}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._updates, [3, 5]) + self.assertEqual(c.get_context_size(), 4) + self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) + + def test_nulled_columns_arent_included(self): + c = MapUpdateClause('s', {3: 0}, {1: 2, 3: 4}) + c._analyze() + c.set_context_id(0) + + self.assertNotIn(1, c._updates) + + +class CounterUpdateTests(TestCase): + + def test_positive_update(self): + c = CounterUpdateClause('a', 5, 3) + c.set_context_id(5) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" + %(5)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'5': 2}) + + def test_negative_update(self): + c = CounterUpdateClause('a', 4, 7) + c.set_context_id(3) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" - %(3)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'3': 3}) + + def test_noop_update(self): + c = CounterUpdateClause('a', 5, 5) + c.set_context_id(5) + + self.assertEqual(c.get_context_size(), 1) + self.assertEqual(str(c), '"a" = "a" + %(5)s') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'5': 0}) + + +class MapDeleteTests(TestCase): + + def test_update(self): + c = MapDeleteClause('s', {3: 0}, {1: 2, 3: 4, 5: 6}) + c._analyze() + c.set_context_id(0) + + self.assertEqual(c._removals, [1, 5]) + self.assertEqual(c.get_context_size(), 2) + self.assertEqual(str(c), '"s"[%(0)s], "s"[%(1)s]') + + ctx = {} + c.update_context(ctx) + self.assertEqual(ctx, {'0': 1, '1': 5}) + +
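# FieldDeleteClause renders only the quoted column name, for deleting individual columns from a row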
+class FieldDeleteTests(TestCase): + + def test_str(self): + f = FieldDeleteClause("blake") + assert str(f) == '"blake"' diff --git a/cqlengine/tests/statements/test_assignment_statement.py b/cqlengine/tests/statements/test_assignment_statement.py new file mode 100644 index 00000000..695c7ceb --- /dev/null +++ b/cqlengine/tests/statements/test_assignment_statement.py @@ -0,0 +1,11 @@ +from unittest import TestCase +from cqlengine.statements import AssignmentStatement, StatementException + + +class AssignmentStatementTest(TestCase): + + def test_add_assignment_type_checking(self): + """ tests that only assignment clauses can be added to statements """ + stmt = AssignmentStatement('table', []) + with self.assertRaises(StatementException): + stmt.add_assignment_clause('x=5') \ No newline at end of file diff --git a/cqlengine/tests/statements/test_base_clause.py b/cqlengine/tests/statements/test_base_clause.py new file mode 100644 index 00000000..c5bbeb40 --- /dev/null +++ b/cqlengine/tests/statements/test_base_clause.py @@ -0,0 +1,16 @@ +from unittest import TestCase +from cqlengine.statements import BaseClause + + +class BaseClauseTests(TestCase): + + def test_context_updating(self): + ss = BaseClause('a', 'b') + assert ss.get_context_size() == 1 + + ctx = {} + ss.set_context_id(10) + ss.update_context(ctx) + assert ctx == {'10': 'b'} + + diff --git a/cqlengine/tests/statements/test_base_statement.py b/cqlengine/tests/statements/test_base_statement.py new file mode 100644 index 00000000..4acc1b92 --- /dev/null +++ b/cqlengine/tests/statements/test_base_statement.py @@ -0,0 +1,11 @@ +from unittest import TestCase +from cqlengine.statements import BaseCQLStatement, StatementException + + +class BaseStatementTest(TestCase): + + def test_where_clause_type_checking(self): + """ tests that only where clauses can be added to statements """ + stmt = BaseCQLStatement('table', []) + with self.assertRaises(StatementException): + stmt.add_where_clause('x=5') diff --git a/cqlengine/tests/statements/test_delete_statement.py b/cqlengine/tests/statements/test_delete_statement.py new file mode 100644 index 00000000..f3b0df9f --- /dev/null +++ b/cqlengine/tests/statements/test_delete_statement.py @@ -0,0 +1,48 @@ +from unittest import TestCase +from cqlengine.statements import DeleteStatement, WhereClause, MapDeleteClause +from cqlengine.operators import * + + +class DeleteStatementTests(TestCase): + + def test_single_field_is_listified(self): + """ tests that passing a string field into the constructor puts it into a list """ + ds = DeleteStatement('table', 'field') + self.assertEqual(len(ds.fields), 1) + self.assertEqual(ds.fields[0].field, 'field') + + def test_field_rendering(self): + """ tests that fields are properly added to the delete statement """ + ds = DeleteStatement('table', ['f1', 'f2']) + self.assertTrue(unicode(ds).startswith('DELETE "f1", "f2"'), unicode(ds)) + self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) + + def test_none_fields_rendering(self): + """ tests that 'DELETE FROM' is rendered when no fields are passed in """ + ds = DeleteStatement('table', None) + self.assertTrue(unicode(ds).startswith('DELETE FROM'), unicode(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) + + def test_table_rendering(self): + ds = DeleteStatement('table', None) + self.assertTrue(unicode(ds).startswith('DELETE FROM table'), unicode(ds)) + self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) + + def test_where_clause_rendering(self): + ds = DeleteStatement('table', None) +
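# the where clause should render with context id 0 +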
ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(unicode(ds), 'DELETE FROM table WHERE "a" = %(0)s', unicode(ds)) + + def test_context_update(self): + ds = DeleteStatement('table', None) + ds.add_field(MapDeleteClause('d', {1: 2}, {1:2, 3: 4})) + ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + + ds.update_context_id(7) + self.assertEqual(unicode(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') + self.assertEqual(ds.get_context(), {'7': 'b', '8': 3}) + + def test_context(self): + ds = DeleteStatement('table', None) + ds.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(ds.get_context(), {'0': 'b'}) diff --git a/cqlengine/tests/statements/test_insert_statement.py b/cqlengine/tests/statements/test_insert_statement.py new file mode 100644 index 00000000..32bb7435 --- /dev/null +++ b/cqlengine/tests/statements/test_insert_statement.py @@ -0,0 +1,40 @@ +from unittest import TestCase +from cqlengine.statements import InsertStatement, StatementException, AssignmentClause + + +class InsertStatementTests(TestCase): + + def test_where_clause_failure(self): + """ tests that where clauses cannot be added to Insert statements """ + ist = InsertStatement('table', None) + with self.assertRaises(StatementException): + ist.add_where_clause('s') + + def test_statement(self): + ist = InsertStatement('table', None) + ist.add_assignment_clause(AssignmentClause('a', 'b')) + ist.add_assignment_clause(AssignmentClause('c', 'd')) + + self.assertEqual( + unicode(ist), + 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' + ) + + def test_context_update(self): + ist = InsertStatement('table', None) + ist.add_assignment_clause(AssignmentClause('a', 'b')) + ist.add_assignment_clause(AssignmentClause('c', 'd')) + + ist.update_context_id(4) + self.assertEqual( + unicode(ist), + 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' + ) + ctx = ist.get_context() + self.assertEqual(ctx, {'4': 'b', '5': 'd'}) + + def test_additional_rendering(self): + ist = InsertStatement('table', ttl=60) + ist.add_assignment_clause(AssignmentClause('a', 'b')) + ist.add_assignment_clause(AssignmentClause('c', 'd')) + self.assertIn('USING TTL 60', unicode(ist)) diff --git a/cqlengine/tests/statements/test_quoter.py b/cqlengine/tests/statements/test_quoter.py new file mode 100644 index 00000000..e69de29b diff --git a/cqlengine/tests/statements/test_select_statement.py b/cqlengine/tests/statements/test_select_statement.py new file mode 100644 index 00000000..9c466485 --- /dev/null +++ b/cqlengine/tests/statements/test_select_statement.py @@ -0,0 +1,70 @@ +from unittest import TestCase +from cqlengine.statements import SelectStatement, WhereClause +from cqlengine.operators import * + + +class SelectStatementTests(TestCase): + + def test_single_field_is_listified(self): + """ tests that passing a string field into the constructor puts it into a list """ + ss = SelectStatement('table', 'field') + self.assertEqual(ss.fields, ['field']) + + def test_field_rendering(self): + """ tests that fields are properly added to the select statement """ + ss = SelectStatement('table', ['f1', 'f2']) + self.assertTrue(unicode(ss).startswith('SELECT "f1", "f2"'), unicode(ss)) + self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) + + def test_none_fields_rendering(self): + """ tests that a '*' is added if no fields are passed in """ + ss = SelectStatement('table') + self.assertTrue(unicode(ss).startswith('SELECT *'), unicode(ss)) + 
self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) + + def test_table_rendering(self): + ss = SelectStatement('table') + self.assertTrue(unicode(ss).startswith('SELECT * FROM table'), unicode(ss)) + self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) + + def test_where_clause_rendering(self): + ss = SelectStatement('table') + ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(unicode(ss), 'SELECT * FROM table WHERE "a" = %(0)s', unicode(ss)) + + def test_count(self): + ss = SelectStatement('table', count=True, limit=10, order_by='d') + ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(unicode(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', unicode(ss)) + self.assertIn('LIMIT', unicode(ss)) + self.assertNotIn('ORDER', unicode(ss)) + + def test_context(self): + ss = SelectStatement('table') + ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(ss.get_context(), {'0': 'b'}) + + def test_context_id_update(self): + """ tests that the right things happen to the context id """ + ss = SelectStatement('table') + ss.add_where_clause(WhereClause('a', EqualsOperator(), 'b')) + self.assertEqual(ss.get_context(), {'0': 'b'}) + self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s') + + ss.update_context_id(5) + self.assertEqual(ss.get_context(), {'5': 'b'}) + self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(5)s') + + def test_additional_rendering(self): + ss = SelectStatement( + 'table', + None, + order_by=['x', 'y'], + limit=15, + allow_filtering=True + ) + qstr = unicode(ss) + self.assertIn('LIMIT 15', qstr) + self.assertIn('ORDER BY x, y', qstr) + self.assertIn('ALLOW FILTERING', qstr) + diff --git a/cqlengine/tests/statements/test_update_statement.py b/cqlengine/tests/statements/test_update_statement.py new file mode 100644 index 00000000..fae69bbd --- /dev/null +++ b/cqlengine/tests/statements/test_update_statement.py @@ -0,0 +1,42 @@ +from unittest import TestCase +from cqlengine.statements import UpdateStatement, WhereClause, AssignmentClause +from cqlengine.operators import * + + +class UpdateStatementTests(TestCase): + + def test_table_rendering(self): + """ tests that the table name is properly rendered in the update statement """ + us = UpdateStatement('table') + self.assertTrue(unicode(us).startswith('UPDATE table SET'), unicode(us)) + self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) + + def test_rendering(self): + us = UpdateStatement('table') + us.add_assignment_clause(AssignmentClause('a', 'b')) + us.add_assignment_clause(AssignmentClause('c', 'd')) + us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + self.assertEqual(unicode(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', unicode(us)) + + def test_context(self): + us = UpdateStatement('table') + us.add_assignment_clause(AssignmentClause('a', 'b')) + us.add_assignment_clause(AssignmentClause('c', 'd')) + us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'}) + + def test_context_update(self): + us = UpdateStatement('table') + us.add_assignment_clause(AssignmentClause('a', 'b')) + us.add_assignment_clause(AssignmentClause('c', 'd')) + us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + us.update_context_id(3) + self.assertEqual(unicode(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') + self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'}) + +
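# note: update_context_id gives the where clause the first id, then numbers the assignments +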
def test_additional_rendering(self): + us = UpdateStatement('table', ttl=60) + us.add_assignment_clause(AssignmentClause('a', 'b')) + us.add_where_clause(WhereClause('a', EqualsOperator(), 'x')) + self.assertIn('USING TTL 60', unicode(us)) + diff --git a/cqlengine/tests/statements/test_where_clause.py b/cqlengine/tests/statements/test_where_clause.py new file mode 100644 index 00000000..be81b63e --- /dev/null +++ b/cqlengine/tests/statements/test_where_clause.py @@ -0,0 +1,25 @@ +from unittest import TestCase +from cqlengine.operators import EqualsOperator +from cqlengine.statements import StatementException, WhereClause + + +class TestWhereClause(TestCase): + + def test_operator_check(self): + """ tests that creating a where statement with a non BaseWhereOperator object fails """ + with self.assertRaises(StatementException): + WhereClause('a', 'b', 'c') + + def test_where_clause_rendering(self): + """ tests that where clauses are rendered properly """ + wc = WhereClause('a', EqualsOperator(), 'c') + wc.set_context_id(5) + + self.assertEqual('"a" = %(5)s', unicode(wc), unicode(wc)) + self.assertEqual('"a" = %(5)s', str(wc), type(wc)) + + def test_equality_method(self): + """ tests that 2 identical where clauses evaluate as == """ + wc1 = WhereClause('a', EqualsOperator(), 'c') + wc2 = WhereClause('a', EqualsOperator(), 'c') + assert wc1 == wc2 diff --git a/cqlengine/tests/test_batch_query.py b/cqlengine/tests/test_batch_query.py index 7ee98b70..39204a59 100644 --- a/cqlengine/tests/test_batch_query.py +++ b/cqlengine/tests/test_batch_query.py @@ -1,29 +1,35 @@ from unittest import skip from uuid import uuid4 import random + +import mock +import sure + from cqlengine import Model, columns -from cqlengine.management import delete_table, create_table +from cqlengine.management import drop_table, sync_table from cqlengine.query import BatchQuery from cqlengine.tests.base import BaseCassEngTestCase class TestMultiKeyModel(Model): + __keyspace__ = 'test' partition = columns.Integer(primary_key=True) cluster = columns.Integer(primary_key=True) count = columns.Integer(required=False) text = columns.Text(required=False) + class BatchQueryTests(BaseCassEngTestCase): @classmethod def setUpClass(cls): super(BatchQueryTests, cls).setUpClass() - delete_table(TestMultiKeyModel) - create_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) + sync_table(TestMultiKeyModel) @classmethod def tearDownClass(cls): super(BatchQueryTests, cls).tearDownClass() - delete_table(TestMultiKeyModel) + drop_table(TestMultiKeyModel) def setUp(self): super(BatchQueryTests, self).setUp() @@ -109,3 +115,80 @@ class BatchQueryTests(BaseCassEngTestCase): with BatchQuery() as b: pass + +class BatchQueryCallbacksTests(BaseCassEngTestCase): + + def test_API_managing_callbacks(self): + + # Callbacks can be added at init and after + + def my_callback(*args, **kwargs): + pass + + # adding on init: + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 2, named_arg='value') + batch.add_callback(my_callback, 1, 3) + + assert batch._callbacks == [ + (my_callback, (), {}), + (my_callback, (2,), {'named_arg':'value'}), + (my_callback, (1, 3), {}) + ] + + def test_callbacks_properly_execute_callables_and_tuples(self): + + call_history = [] + def my_callback(*args, **kwargs): + call_history.append(args) + + # adding on init: + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 'more', 'args') + + batch.execute() + + assert len(call_history) == 2 + assert [(), 
('more', 'args')] == call_history + + def test_callbacks_tied_to_execute(self): + """Batch callbacks should NOT fire if batch is not executed in context manager mode""" + + call_history = [] + def my_callback(*args, **kwargs): + call_history.append(args) + + with BatchQuery() as batch: + batch.add_callback(my_callback) + pass + + assert len(call_history) == 1 + + class SomeError(Exception): + pass + + with self.assertRaises(SomeError): + with BatchQuery() as batch: + batch.add_callback(my_callback) + # this error bubbling up through context manager + # should prevent callback runs (along with b.execute()) + raise SomeError + + # still same call history. Nothing added + assert len(call_history) == 1 + + # but if execute ran, even with an error bubbling through + # the callbacks also would have fired + with self.assertRaises(SomeError): + with BatchQuery(execute_on_exception=True) as batch: + batch.add_callback(my_callback) + # this error bubbling up through context manager + # should prevent callback runs (along with b.execute()) + raise SomeError + + # still same call history + assert len(call_history) == 2 diff --git a/cqlengine/tests/test_consistency.py b/cqlengine/tests/test_consistency.py new file mode 100644 index 00000000..a45117b0 --- /dev/null +++ b/cqlengine/tests/test_consistency.py @@ -0,0 +1,95 @@ +from cqlengine.management import sync_table, drop_table +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from uuid import uuid4 +from cqlengine import columns +import mock +from cqlengine import ALL, BatchQuery + +class TestConsistencyModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + +class BaseConsistencyTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseConsistencyTest, cls).setUpClass() + sync_table(TestConsistencyModel) + + @classmethod + def tearDownClass(cls): + super(BaseConsistencyTest, cls).tearDownClass() + drop_table(TestConsistencyModel) + + +class TestConsistency(BaseConsistencyTest): + def test_create_uses_consistency(self): + + qs = TestConsistencyModel.consistency(ALL) + with mock.patch.object(self.session, 'execute') as m: + qs.create(text="i am not fault tolerant this way") + + args = m.call_args + self.assertEqual(ALL, args[0][0].consistency_level) + + def test_queryset_is_returned_on_create(self): + qs = TestConsistencyModel.consistency(ALL) + self.assertTrue(isinstance(qs, TestConsistencyModel.__queryset__), type(qs)) + + def test_update_uses_consistency(self): + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham sandwich" + + with mock.patch.object(self.session, 'execute') as m: + t.consistency(ALL).save() + + args = m.call_args + self.assertEqual(ALL, args[0][0].consistency_level) + + + def test_batch_consistency(self): + + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery(consistency=ALL) as b: + TestConsistencyModel.batch(b).create(text="monkey") + + args = m.call_args + + self.assertEqual(ALL, args[0][0].consistency_level) + + with mock.patch.object(self.session, 'execute') as m: + with BatchQuery() as b: + TestConsistencyModel.batch(b).create(text="monkey") + + args = m.call_args + self.assertNotEqual(ALL, args[0][0].consistency_level) + + def test_blind_update(self): + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham sandwich" + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + 
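# a 'blind' update runs against the queryset without reading the instance back first +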
TestConsistencyModel.objects(id=uid).consistency(ALL).update(text="grilled cheese") + + args = m.call_args + self.assertEqual(ALL, args[0][0].consistency_level) + + + def test_delete(self): + # ensures we always carry consistency through on delete statements + t = TestConsistencyModel.create(text="bacon and eggs") + t.text = "ham and cheese sandwich" + uid = t.id + + with mock.patch.object(self.session, 'execute') as m: + t.consistency(ALL).delete() + + with mock.patch.object(self.session, 'execute') as m: + TestConsistencyModel.objects(id=uid).consistency(ALL).delete() + + args = m.call_args + self.assertEqual(ALL, args[0][0].consistency_level) diff --git a/cqlengine/tests/test_load.py b/cqlengine/tests/test_load.py new file mode 100644 index 00000000..6e64d96d --- /dev/null +++ b/cqlengine/tests/test_load.py @@ -0,0 +1,34 @@ +import os +from unittest import TestCase, skipUnless + +from cqlengine import Model, Integer +from cqlengine.management import sync_table +from cqlengine.tests import base +import resource +import gc + +class LoadTest(Model): + __keyspace__ = 'test' + k = Integer(primary_key=True) + v = Integer() + + +@skipUnless("LOADTEST" in os.environ, "LOADTEST not on") +def test_lots_of_queries(): + sync_table(LoadTest) + import objgraph + gc.collect() + objgraph.show_most_common_types() + + print "Starting..." + + for i in range(1000000): + if i % 25000 == 0: + # print memory statistic + print "Memory usage: %s" % (resource.getrusage(resource.RUSAGE_SELF).ru_maxrss) + + LoadTest.create(k=i, v=i) + + objgraph.show_most_common_types() + + raise Exception("you shouldn't be here") diff --git a/cqlengine/tests/test_timestamp.py b/cqlengine/tests/test_timestamp.py new file mode 100644 index 00000000..cd298cf4 --- /dev/null +++ b/cqlengine/tests/test_timestamp.py @@ -0,0 +1,178 @@ +""" +Tests surrounding the blah.timestamp( timedelta(seconds=30) ) format. 
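+e.g. TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) should render USING TIMESTAMP in the generated query.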
+""" +from datetime import timedelta, datetime + +from uuid import uuid4 +import mock +import sure +from cqlengine import Model, columns, BatchQuery +from cqlengine.management import sync_table +from cqlengine.tests.base import BaseCassEngTestCase + + +class TestTimestampModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + + +class BaseTimestampTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseTimestampTest, cls).setUpClass() + sync_table(TestTimestampModel) + + +class BatchTest(BaseTimestampTest): + + def test_batch_is_included(self): + with mock.patch.object(self.session, "execute") as m, BatchQuery(timestamp=timedelta(seconds=30)) as b: + TestTimestampModel.batch(b).create(count=1) + + "USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + + +class CreateWithTimestampTest(BaseTimestampTest): + + def test_batch(self): + with mock.patch.object(self.session, "execute") as m, BatchQuery() as b: + TestTimestampModel.timestamp(timedelta(seconds=10)).batch(b).create(count=1) + + query = m.call_args[0][0].query_string + + query.should.match(r"INSERT.*USING TIMESTAMP") + query.should_not.match(r"TIMESTAMP.*INSERT") + + def test_timestamp_not_included_on_normal_create(self): + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.create(count=2) + + "USING TIMESTAMP".shouldnt.be.within(m.call_args[0][0].query_string) + + def test_timestamp_is_set_on_model_queryset(self): + delta = timedelta(seconds=30) + tmp = TestTimestampModel.timestamp(delta) + tmp._timestamp.should.equal(delta) + + def test_non_batch_syntax_integration(self): + tmp = TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) + tmp.should.be.ok + + def test_non_batch_syntax_unit(self): + + with mock.patch.object(self.session, "execute") as m: + TestTimestampModel.timestamp(timedelta(seconds=30)).create(count=1) + + query = m.call_args[0][0].query_string + + "USING TIMESTAMP".should.be.within(query) + + +class UpdateWithTimestampTest(BaseTimestampTest): + def setUp(self): + self.instance = TestTimestampModel.create(count=1) + super(UpdateWithTimestampTest, self).setUp() + + def test_instance_update_includes_timestamp_in_query(self): + # not a batch + + with mock.patch.object(self.session, "execute") as m: + self.instance.timestamp(timedelta(seconds=30)).update(count=2) + + "USING TIMESTAMP".should.be.within(m.call_args[0][0].query_string) + + def test_instance_update_in_batch(self): + with mock.patch.object(self.session, "execute") as m, BatchQuery() as b: + self.instance.batch(b).timestamp(timedelta(seconds=30)).update(count=2) + + query = m.call_args[0][0].query_string + "USING TIMESTAMP".should.be.within(query) + + +class DeleteWithTimestampTest(BaseTimestampTest): + + def test_non_batch(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + tmp.timestamp(timedelta(seconds=5)).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + # calling .timestamp sets the TS on the model + tmp.timestamp(timedelta(seconds=5)) + tmp._timestamp.should.be.ok + + # calling save clears the set timestamp + 
tmp.save() + tmp._timestamp.shouldnt.be.ok + + tmp.timestamp(timedelta(seconds=5)) + tmp.update() + tmp._timestamp.shouldnt.be.ok + + def test_blind_delete(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=5)).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + def test_blind_delete_with_datetime(self): + """ + we don't expect the model to come back at the end because the deletion timestamp should be in the future + """ + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + plus_five_seconds = datetime.now() + timedelta(seconds=5) + + TestTimestampModel.objects(id=uid).timestamp(plus_five_seconds).delete() + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + tmp = TestTimestampModel.create(id=uid, count=1) + + with self.assertRaises(TestTimestampModel.DoesNotExist): + TestTimestampModel.get(id=uid) + + def test_delete_in_the_past(self): + uid = uuid4() + tmp = TestTimestampModel.create(id=uid, count=1) + + TestTimestampModel.get(id=uid).should.be.ok + + # delete the in past, should not affect the object created above + TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=-60)).delete() + + TestTimestampModel.get(id=uid) + + diff --git a/cqlengine/tests/test_ttl.py b/cqlengine/tests/test_ttl.py new file mode 100644 index 00000000..bb5e1bb4 --- /dev/null +++ b/cqlengine/tests/test_ttl.py @@ -0,0 +1,121 @@ +from cqlengine.management import sync_table, drop_table +from cqlengine.tests.base import BaseCassEngTestCase +from cqlengine.models import Model +from uuid import uuid4 +from cqlengine import columns +import mock +from cqlengine.connection import get_session + + +class TestTTLModel(Model): + __keyspace__ = 'test' + id = columns.UUID(primary_key=True, default=lambda:uuid4()) + count = columns.Integer() + text = columns.Text(required=False) + + +class BaseTTLTest(BaseCassEngTestCase): + + @classmethod + def setUpClass(cls): + super(BaseTTLTest, cls).setUpClass() + sync_table(TestTTLModel) + + @classmethod + def tearDownClass(cls): + super(BaseTTLTest, cls).tearDownClass() + drop_table(TestTTLModel) + + + +class TTLQueryTests(BaseTTLTest): + + def test_update_queryset_ttl_success_case(self): + """ tests that ttls on querysets work as expected """ + + def test_select_ttl_failure(self): + """ tests that ttls on select queries raise an exception """ + + +class TTLModelTests(BaseTTLTest): + + def test_ttl_included_on_create(self): + """ tests that ttls on models work as expected """ + session = get_session() + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.ttl(60).create(text="hello blake") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + def test_queryset_is_returned_on_class(self): + """ + ensures we get a queryset descriptor back + """ + qs = TestTTLModel.ttl(60) + self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs)) + + + +class TTLInstanceUpdateTest(BaseTTLTest): + def test_update_includes_ttl(self): + session = get_session() + + model = 
TestTTLModel.create(text="goodbye blake") + with mock.patch.object(session, 'execute') as m: + model.ttl(60).update(text="goodbye forever") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + def test_update_syntax_valid(self): + # sanity test that ensures the TTL syntax is accepted by cassandra + model = TestTTLModel.create(text="goodbye blake") + model.ttl(60).update(text="goodbye forever") + + + + + +class TTLInstanceTest(BaseTTLTest): + def test_instance_is_returned(self): + """ + ensures that we properly handle the instance.ttl(60).save() scenario + :return: + """ + o = TestTTLModel.create(text="whatever") + o.text = "new stuff" + o = o.ttl(60) + self.assertEqual(60, o._ttl) + + def test_ttl_is_include_with_query_on_update(self): + session = get_session() + + o = TestTTLModel.create(text="whatever") + o.text = "new stuff" + o = o.ttl(60) + + with mock.patch.object(session, 'execute') as m: + o.save() + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + +class TTLBlindUpdateTest(BaseTTLTest): + def test_ttl_included_with_blind_update(self): + session = get_session() + + o = TestTTLModel.create(text="whatever") + tid = o.id + + with mock.patch.object(session, 'execute') as m: + TestTTLModel.objects(id=tid).ttl(60).update(text="bacon") + + query = m.call_args[0][0].query_string + self.assertIn("USING TTL", query) + + + + diff --git a/docs/conf.py b/docs/conf.py index 96c7db6e..82a8833b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,7 +11,7 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +import os # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the @@ -41,16 +41,18 @@ master_doc = 'index' # General information about the project. project = u'cqlengine' -copyright = u'2012, Blake Eggleston' +copyright = u'2012, Blake Eggleston, Jon Haddad' # The version info for the project you're documenting, acts as replacement for # |version| and |release|, also used in various other places throughout the # built documents. # +__cqlengine_version_path__ = os.path.realpath(__file__ + '/../../cqlengine/VERSION') # The short X.Y version. -version = '0.2' +version = open(__cqlengine_version_path__, 'r').readline().strip() # The full version, including alpha/beta/rc tags. -release = '0.2' +release = version + # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -183,8 +185,7 @@ latex_elements = { # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'cqlengine.tex', u'cqlengine Documentation', - u'Blake Eggleston', 'manual'), + ('index', 'cqlengine.tex', u'cqlengine Documentation', u'Blake Eggleston, Jon Haddad', 'manual'), ] # The name of an image file (relative to this directory) to place at the top of @@ -214,7 +215,7 @@ latex_documents = [ # (source start file, name, description, authors, manual section). man_pages = [ ('index', 'cqlengine', u'cqlengine Documentation', - [u'Blake Eggleston'], 1) + [u'Blake Eggleston, Jon Haddad'], 1) ] # If true, show URL addresses after external links. 
@@ -227,9 +228,9 @@ man_pages = [ # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - ('index', 'cqlengine', u'cqlengine Documentation', - u'Blake Eggleston', 'cqlengine', 'One line description of project.', - 'Miscellaneous'), + ('index', 'cqlengine', u'cqlengine Documentation', + u'Blake Eggleston, Jon Haddad', 'cqlengine', 'One line description of project.', + 'Miscellaneous'), ] # Documents to append as an appendix to all manuals. diff --git a/docs/index.rst b/docs/index.rst index 6cea2f2c..7bb0e5e9 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,20 +5,31 @@ cqlengine documentation ======================= -cqlengine is a Cassandra CQL 3 Object Mapper for Python with an interface similar to the Django orm and mongoengine +**Users of versions < 0.16, the default keyspace 'cqlengine' has been removed. Please read this before upgrading:** :ref:`Breaking Changes ` + +cqlengine is a Cassandra CQL 3 Object Mapper for Python :ref:`getting-started` +Download +======== + +`Github `_ + +`PyPi `_ + Contents: +========= .. toctree:: :maxdepth: 2 - + topics/models topics/queryset topics/columns topics/connection topics/manage_schemas + topics/faq .. _getting-started: @@ -32,18 +43,24 @@ Getting Started from cqlengine import Model class ExampleModel(Model): - example_id = columns.UUID(primary_key=True) + example_id = columns.UUID(primary_key=True, default=uuid.uuid4) example_type = columns.Integer(index=True) created_at = columns.DateTime() description = columns.Text(required=False) #next, setup the connection to your cassandra server(s)... >>> from cqlengine import connection - >>> connection.setup(['127.0.0.1:9160']) + + # see http://datastax.github.io/python-driver/api/cassandra/cluster.html for options + # the list of hosts will be passed to create a Cluster() instance + >>> connection.setup(['127.0.0.1']) + + # if you're connecting to a 1.2 cluster + >>> connection.setup(['127.0.0.1'], protocol_version=1) #...and create your CQL table - >>> from cqlengine.management import create_table - >>> create_table(ExampleModel) + >>> from cqlengine.management import sync_table + >>> sync_table(ExampleModel) #now we can create some rows: >>> em1 = ExampleModel.create(example_type=0, description="example1", created_at=datetime.now()) @@ -54,8 +71,6 @@ Getting Started >>> em6 = ExampleModel.create(example_type=1, description="example6", created_at=datetime.now()) >>> em7 = ExampleModel.create(example_type=1, description="example7", created_at=datetime.now()) >>> em8 = ExampleModel.create(example_type=1, description="example8", created_at=datetime.now()) - # Note: the UUID and DateTime columns will create uuid4 and datetime.now - # values automatically if we don't specify them when creating new rows #and now we can run some queries against our table >>> ExampleModel.objects.count() @@ -64,7 +79,7 @@ Getting Started >>> q.count() 4 >>> for instance in q: - >>> print q.description + >>> print instance.description example5 example6 example7 @@ -78,7 +93,7 @@ Getting Started >>> q2.count() 1 >>> for instance in q2: - >>> print q.description + >>> print instance.description example5 @@ -86,10 +101,6 @@ Getting Started `Users Mailing List `_ -`Dev Mailing List `_ - -**NOTE: cqlengine is in alpha and under development, some features may change. 
Make sure to check the changelog and test your app before upgrading** - Indices and tables ================== diff --git a/docs/topics/columns.rst b/docs/topics/columns.rst index 0537e3bc..0f71e3c7 100644 --- a/docs/topics/columns.rst +++ b/docs/topics/columns.rst @@ -2,6 +2,10 @@ Columns ======= +**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_ + +.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU + .. module:: cqlengine.columns .. class:: Bytes() @@ -14,9 +18,9 @@ Columns .. class:: Ascii() Stores a US-ASCII character string :: - + columns.Ascii() - + .. class:: Text() @@ -34,26 +38,50 @@ Columns .. class:: Integer() - Stores an integer value :: + Stores a 32-bit signed integer value :: columns.Integer() +.. class:: BigInt() + + Stores a 64-bit signed long value :: + + columns.BigInt() + +.. class:: VarInt() + + Stores an arbitrary-precision integer :: + + columns.VarInt() + .. class:: DateTime() Stores a datetime value. - Python's datetime.now callable is set as the default value for this column :: - columns.DateTime() .. class:: UUID() Stores a type 1 or type 4 UUID. - Python's uuid.uuid4 callable is set as the default value for this column. :: - columns.UUID() +.. class:: TimeUUID() + + Stores a UUID value as the cql type 'timeuuid' :: + + columns.TimeUUID() + + .. classmethod:: from_datetime(dt) + + generates a TimeUUID for the given datetime + + :param dt: the datetime to create a time uuid from + :type dt: datetime.datetime + + :returns: a time uuid created from the given datetime + :rtype: uuid1 + .. class:: Boolean() Stores a boolean True or False value :: @@ -77,16 +105,24 @@ Columns columns.Decimal() +.. class:: Counter() + + Counters can be incremented and decremented :: + + columns.Counter() + + Collection Type Columns ---------------------------- CQLEngine also supports container column types. Each container column requires a column class argument to specify what type of objects it will hold. The Map column requires 2, one for the key, and the other for the value - + *Example* .. code-block:: python - + class Person(Model): + id = columns.UUID(primary_key=True, default=uuid.uuid4) first_name = columns.Text() last_name = columns.Text() @@ -94,7 +130,7 @@ Collection Type Columns enemies = columns.Set(columns.Text) todo_list = columns.List(columns.Text) birthdays = columns.Map(columns.Text, columns.DateTime) - + .. class:: Set() @@ -145,12 +181,16 @@ Column Options If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False. - *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.* + *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys, unless partition keys are specified manually using* :attr:`BaseColumn.partition_key` + + .. attribute:: BaseColumn.partition_key + + If True, this column is created as partition primary key. There may be many partition keys defined, forming a *composite partition key* .. attribute:: BaseColumn.index If True, an index will be created for this column. Defaults to False. - + *Note: Indexes can only be created on models with one primary key* .. attribute:: BaseColumn.db_field @@ -163,5 +203,8 @@ Column Options .. 
attribute:: BaseColumn.required - If True, this model cannot be saved without a value defined for this column. Defaults to True. Primary key fields cannot have their required fields set to False. + If True, this model cannot be saved without a value defined for this column. Defaults to False. Primary key fields cannot have their required flag set to False. + .. attribute:: BaseColumn.clustering_order + + Defines the CLUSTERING ORDER for this column (valid choices are "asc", the default, or "desc"). It may only be specified for clustering primary keys. For more information, see http://www.datastax.com/docs/1.2/cql_cli/cql/CREATE_TABLE#using-clustering-order diff --git a/docs/topics/connection.rst b/docs/topics/connection.rst index 14bc2341..11eeda39 100644 --- a/docs/topics/connection.rst +++ b/docs/topics/connection.rst @@ -1,17 +1,26 @@ -============== + ========== Connection -============== +========== + +**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_ + +.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU .. module:: cqlengine.connection The setup function in `cqlengine.connection` records the Cassandra servers to connect to. If there is a problem with one of the servers, cqlengine will try to connect to each of the other connections before failing. -.. function:: setup(hosts [, username=None, password=None]) +.. function:: setup(hosts) :param hosts: list of hosts, strings in the <host>:<port> format, or just <host> :type hosts: list + :param consistency: the consistency level of the connection, defaults to 'ONE' + :type consistency: int + + # see http://datastax.github.io/python-driver/api/cassandra.html#cassandra.ConsistencyLevel + Records the hosts and connects to one of them See the example at :ref:`getting-started` diff --git a/docs/topics/faq.rst b/docs/topics/faq.rst new file mode 100644 index 00000000..b250180b --- /dev/null +++ b/docs/topics/faq.rst @@ -0,0 +1,8 @@ +========================== +Frequently Asked Questions +========================== + +Q: Why don't updates work correctly on models instantiated as Model(field=blah, field2=blah2)? +------------------------------------------------------------------------------------------------ + +A: The recommended way to create new rows is with the model's ``.create`` method. The values passed into a model's init method are interpreted as values that were read from an existing row. This allows the model to "know" which fields have changed since the row was read out of cassandra, and to create suitable update statements. \ No newline at end of file diff --git a/docs/topics/manage_schemas.rst b/docs/topics/manage_schemas.rst index 2174e56f..1da5b1c9 100644 --- a/docs/topics/manage_schemas.rst +++ b/docs/topics/manage_schemas.rst @@ -1,6 +1,12 @@ -=============== -Managing Schmas -=============== +================ +Managing Schemas +================ + +**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_ + +.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU + +.. module:: cqlengine.connection .. module:: cqlengine.management @@ -20,23 +26,23 @@ Once a connection has been made to Cassandra, you can use the functions in ``cql deletes the keyspace with the given name -.. function:: create_table(model [, create_missing_keyspace=True]) - +..
function:: sync_table(model [, create_missing_keyspace=True]) + :param model: the :class:`~cqlengine.model.Model` class to make a table with :type model: :class:`~cqlengine.model.Model` :param create_missing_keyspace: *Optional* If True, the model's keyspace will be created if it does not already exist. Defaults to ``True`` :type create_missing_keyspace: bool - creates a CQL table for the given model + syncs a python model to cassandra (creates & alters) -.. function:: delete_table(model) +.. function:: drop_table(model) :param model: the :class:`~cqlengine.model.Model` class to delete a column family for :type model: :class:`~cqlengine.model.Model` deletes the CQL table for the given model - + See the example at :ref:`getting-started` diff --git a/docs/topics/models.rst b/docs/topics/models.rst index 0dd2c19c..16025907 100644 --- a/docs/topics/models.rst +++ b/docs/topics/models.rst @@ -2,12 +2,18 @@ Models ====== +**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_ + +.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU + +.. module:: cqlengine.connection + .. module:: cqlengine.models A model is a python class representing a CQL table. -Example -======= +Examples +======== This example defines a Person table, with the columns ``first_name`` and ``last_name`` @@ -15,11 +21,11 @@ This example defines a Person table, with the columns ``first_name`` and ``last_ from cqlengine import columns from cqlengine.models import Model - + class Person(Model): first_name = columns.Text() last_name = columns.Text() - + The Person model would create this CQL table: @@ -32,10 +38,42 @@ The Person model would create this CQL table: PRIMARY KEY (id) ) +Here's an example of a comment table created with clustering keys, in descending order: + +.. code-block:: python + + from cqlengine import columns + from cqlengine.models import Model + + class Comment(Model): + photo_id = columns.UUID(primary_key=True) + comment_id = columns.TimeUUID(primary_key=True, clustering_order="DESC") + comment = columns.Text() + +The Comment model's ``create table`` would look like the following: + +.. code-block:: sql + + CREATE TABLE comment ( + photo_id uuid, + comment_id timeuuid, + comment text, + PRIMARY KEY (photo_id, comment_id) + ) WITH CLUSTERING ORDER BY (comment_id DESC) + +To sync the models to the database, you may do the following: + +.. code-block:: python + + from cqlengine.management import sync_table + sync_table(Person) + sync_table(Comment) + + Columns ======= - Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. For a model to be valid it needs at least one primary key column (defined automatically if you don't define one) and one non-primary key column. + Columns in your models map to columns in your CQL table. You define CQL columns by defining column attributes on your model classes. For a model to be valid it needs at least one primary key column and one non-primary key column. Just as in CQL, the order you define your columns in is important, and is the same order they are defined in on a model's corresponding table. @@ -61,26 +99,37 @@ Column Types Column Options -------------- - Each column can be defined with optional arguments to modify the way they behave. 
While some column types may define additional column options, these are the options that are available on all columns: + Each column can be defined with optional arguments to modify the way they behave. While some column types may + define additional column options, these are the options that are available on all columns: :attr:`~cqlengine.columns.BaseColumn.primary_key` If True, this column is created as a primary key field. A model can have multiple primary keys. Defaults to False. - *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first primary key is the partition key, and all others are clustering keys.* + *In CQL, there are 2 types of primary keys: partition keys and clustering keys. As with CQL, the first + primary key is the partition key, and all others are clustering keys, unless partition keys are specified + manually using* :attr:`~cqlengine.columns.BaseColumn.partition_key` + + :attr:`~cqlengine.columns.BaseColumn.partition_key` + If True, this column is created as partition primary key. There may be many partition keys defined, + forming a *composite partition key* + + :attr:`~cqlengine.columns.BaseColumn.clustering_order` + ``ASC`` or ``DESC``, determines the clustering order of a clustering key. :attr:`~cqlengine.columns.BaseColumn.index` If True, an index will be created for this column. Defaults to False. - - *Note: Indexes can only be created on models with one primary key* :attr:`~cqlengine.columns.BaseColumn.db_field` - Explicitly sets the name of the column in the database table. If this is left blank, the column name will be the same as the name of the column attribute. Defaults to None. + Explicitly sets the name of the column in the database table. If this is left blank, the column name will be + the same as the name of the column attribute. Defaults to None. :attr:`~cqlengine.columns.BaseColumn.default` - The default value for this column. If a model instance is saved without a value for this column having been defined, the default value will be used. This can be either a value or a callable object (ie: datetime.now is a valid default argument). + The default value for this column. If a model instance is saved without a value for this column having been + defined, the default value will be used. This can be either a value or a callable object (ie: datetime.now is a valid default argument). + Callable defaults will be called each time a default is assigned to a None value :attr:`~cqlengine.columns.BaseColumn.required` - If True, this model cannot be saved without a value defined for this column. Defaults to True. Primary key fields cannot have their required fields set to False. + If True, this model cannot be saved without a value defined for this column. Defaults to False. Primary key fields always require values. Model Methods ============= @@ -93,7 +142,7 @@ Model Methods *Example* .. code-block:: python - + #using the person model from earlier: class Person(Model): first_name = columns.Text() @@ -102,7 +151,7 @@ Model Methods person = Person(first_name='Blake', last_name='Eggleston') person.first_name #returns 'Blake' person.last_name #returns 'Eggleston' - + .. method:: save() @@ -117,22 +166,265 @@ Model Methods #saves it to Cassandra person.save() - .. method:: delete() - + Deletes the object from the database. + .. method:: batch(batch_object) + + Sets the batch object to run instance updates and inserts queries with. + + .. method:: timestamp(timedelta_or_datetime) + + Sets the timestamp for the query + + .. 
method:: ttl(ttl_in_sec) + + Sets the TTL value used when running instance update and insert queries. + + .. method:: update(**values) + + Performs an update on the model instance. You can pass in values to set on the model + for updating, or you can call without values to execute an update against any modified + fields. If no fields on the model have been modified since loading, no query will be + performed. Model validation is performed normally. + + .. method:: get_changed_columns() + + Returns a list of column names that have changed since the model was instantiated or saved. + Model Attributes ================ - .. attribute:: Model.table_name + .. attribute:: Model.__abstract__ + + *Optional.* Indicates that this model is only intended to be used as a base class for other models. You can't create tables for abstract models, but checks around schema validity are skipped during class construction. + + .. attribute:: Model.__table_name__ *Optional.* Sets the name of the CQL table for this model. If left blank, the table name will be the name of the model, with it's module name as it's prefix. Manually defined table names are not inherited. - .. attribute:: Model.keyspace + .. _keyspace-change: + .. attribute:: Model.__keyspace__ - *Optional.* Sets the name of the keyspace used by this model. Defaulst to cqlengine + Sets the name of the keyspace used by this model. + + **Prior to cqlengine 0.16, this setting defaulted + to 'cqlengine'. As of 0.16, this field needs to be set on all non-abstract models, or their base classes.** + + +Table Polymorphism ================== + + As of cqlengine 0.8, it is possible to save and load different model classes using a single CQL table. + This is useful in situations where you have different object types that you want to store in a single cassandra table. + + For instance, suppose you want a table that stores rows of pets owned by an owner: + + .. code-block:: python + + class Pet(Model): + __table_name__ = 'pet' + owner_id = UUID(primary_key=True) + pet_id = UUID(primary_key=True) + pet_type = Text(polymorphic_key=True) + name = Text() + + def eat(self, food): + pass + + def sleep(self, time): + pass + + class Cat(Pet): + __polymorphic_key__ = 'cat' + cuteness = Float() + + def tear_up_couch(self): + pass + + class Dog(Pet): + __polymorphic_key__ = 'dog' + fierceness = Float() + + def bark_all_night(self): + pass + + After calling ``sync_table`` on each of these models, the columns defined in each model will be added to the + ``pet`` table. Additionally, saving ``Cat`` and ``Dog`` models will save the metadata needed to identify each row + as either a cat or dog. + + To set up a polymorphic model structure, follow these steps: + + 1. Create a base model with a column set as the polymorphic_key (set ``polymorphic_key=True`` in the column definition) + 2. Create subclass models, and define a unique ``__polymorphic_key__`` value on each + 3. Run ``sync_table`` on each of the sub tables + + **About the polymorphic key** + + The polymorphic key is what cqlengine uses under the covers to map logical cql rows to the appropriate model type. The + base model maintains a map of polymorphic keys to subclasses. When a polymorphic model is saved, this value is automatically + saved into the polymorphic key column. You can set the polymorphic key column to any column type that you like, with + the exception of container and counter columns, although ``Integer`` columns make the most sense.
Additionally, if you + set ``index=True`` on your polymorphic key column, you can execute queries against polymorphic subclasses, and a + ``WHERE`` clause will be automatically added to your query, returning only rows of that type. Note that you must + define a unique ``__polymorphic_key__`` value to each subclass, and that you can only assign a single polymorphic + key column per model + + +Extending Model Validation +========================== + + Each time you save a model instance in cqlengine, the data in the model is validated against the schema you've defined + for your model. Most of the validation is fairly straightforward, it basically checks that you're not trying to do + something like save text into an integer column, and it enforces the ``required`` flag set on column definitions. + It also performs any transformations needed to save the data properly. + + However, there are often additional constraints or transformations you want to impose on your data, beyond simply + making sure that Cassandra won't complain when you try to insert it. To define additional validation on a model, + extend the model's validation method: + + .. code-block:: python + + class Member(Model): + person_id = UUID(primary_key=True) + name = Text(required=True) + + def validate(self): + super(Member, self).validate() + if self.name == 'jon': + raise ValidationError('no jon\'s allowed') + + *Note*: while not required, the convention is to raise a ``ValidationError`` (``from cqlengine import ValidationError``) + if validation fails + + +Table Properties +================ + + Each table can have its own set of configuration options. + These can be specified on a model with the following attributes: + + .. attribute:: Model.__bloom_filter_fp_chance + + .. attribute:: Model.__caching__ + + .. attribute:: Model.__comment__ + + .. attribute:: Model.__dclocal_read_repair_chance__ + + .. attribute:: Model.__default_time_to_live__ + + .. attribute:: Model.__gc_grace_seconds__ + + .. attribute:: Model.__index_interval__ + + .. attribute:: Model.__memtable_flush_period_in_ms__ + + .. attribute:: Model.__populate_io_cache_on_flush__ + + .. attribute:: Model.__read_repair_chance__ + + .. attribute:: Model.__replicate_on_write__ + + Example: + + .. code-block:: python + + from cqlengine import ROWS_ONLY, columns + from cqlengine.models import Model + + class User(Model): + __caching__ = ROWS_ONLY # cache only rows instead of keys only by default + __gc_grace_seconds__ = 86400 # 1 day instead of the default 10 days + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + Will produce the following CQL statement: + + .. code-block:: sql + + CREATE TABLE cqlengine.user ( + user_id uuid, + name text, + PRIMARY KEY (user_id) + ) WITH caching = 'rows_only' + AND gc_grace_seconds = 86400; + + See the `list of supported table properties for more information + `_. + + +Compaction Options +================== + + As of cqlengine 0.7 we've added support for specifying compaction options. cqlengine will only use your compaction options if you have a strategy set. When a table is synced, it will be altered to match the compaction options set on your table. This means that if you are changing settings manually they will be changed back on resync. Do not use the compaction settings of cqlengine if you want to manage your compaction settings manually. + + cqlengine supports all compaction options as of Cassandra 1.2.8. + + Available Options: + + .. attribute:: Model.__compaction_bucket_high__ + + .. 
Table Properties
================

 Each table can have its own set of configuration options. + These can be specified on a model with the following attributes: + + .. attribute:: Model.__bloom_filter_fp_chance__ + + .. attribute:: Model.__caching__ + + .. attribute:: Model.__comment__ + + .. attribute:: Model.__dclocal_read_repair_chance__ + + .. attribute:: Model.__default_time_to_live__ + + .. attribute:: Model.__gc_grace_seconds__ + + .. attribute:: Model.__index_interval__ + + .. attribute:: Model.__memtable_flush_period_in_ms__ + + .. attribute:: Model.__populate_io_cache_on_flush__ + + .. attribute:: Model.__read_repair_chance__ + + .. attribute:: Model.__replicate_on_write__ + + Example: + + .. code-block:: python + + from cqlengine import ROWS_ONLY, columns + from cqlengine.models import Model + + class User(Model): + __caching__ = ROWS_ONLY # cache only rows instead of keys only by default + __gc_grace_seconds__ = 86400 # 1 day instead of the default 10 days + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + Will produce the following CQL statement: + + .. code-block:: sql + + CREATE TABLE cqlengine.user ( + user_id uuid, + name text, + PRIMARY KEY (user_id) + ) WITH caching = 'rows_only' + AND gc_grace_seconds = 86400; + + See the `list of supported table properties for more information + `_. + +
+Compaction Options
+==================

 As of cqlengine 0.7 we've added support for specifying compaction options. cqlengine will only use your compaction options if you have a strategy set. When a table is synced, it will be altered to match the compaction options set on your table. This means that if you are changing settings manually they will be changed back on resync. Do not use the compaction settings of cqlengine if you want to manage your compaction settings manually. + + cqlengine supports all compaction options as of Cassandra 1.2.8. + + Available Options: + + .. attribute:: Model.__compaction_bucket_high__ + + .. attribute:: Model.__compaction_bucket_low__ + + .. attribute:: Model.__compaction_max_compaction_threshold__ + + .. attribute:: Model.__compaction_min_compaction_threshold__ + + .. attribute:: Model.__compaction_min_sstable_size__ + + .. attribute:: Model.__compaction_sstable_size_in_mb__ + + .. attribute:: Model.__compaction_tombstone_compaction_interval__ + + .. attribute:: Model.__compaction_tombstone_threshold__ + + For example: + + .. code-block:: python + + class User(Model): + __compaction__ = cqlengine.LeveledCompactionStrategy + + __compaction_sstable_size_in_mb__ = 64 + __compaction_tombstone_threshold__ = .2 + + user_id = columns.UUID(primary_key=True) + name = columns.Text() + + or for SizeTieredCompaction: + + .. code-block:: python + + class TimeData(Model): + __compaction__ = SizeTieredCompactionStrategy + __compaction_bucket_low__ = .3 + __compaction_bucket_high__ = 2 + __compaction_min_compaction_threshold__ = 2 + __compaction_max_compaction_threshold__ = 64 + __compaction_tombstone_compaction_interval__ = 86400 + + Tables may use ``LeveledCompactionStrategy`` or ``SizeTieredCompactionStrategy``. Both options are available in the top level cqlengine module. To reiterate, you will need to set your ``__compaction__`` option explicitly in order for cqlengine to handle any of your settings.
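 Compaction settings take effect when the table is synced; a brief sketch, assuming the ``TimeData`` model above:

 .. code-block:: python

     from cqlengine.management import sync_table

     # creates the table if it doesn't exist; otherwise alters it so the
     # live compaction settings match the attributes set on the model
     sync_table(TimeData)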
+Manipulating model instances as dictionaries
+============================================

 As of cqlengine 0.12, we've added support for treating model instances like dictionaries. See below for examples. + + .. code-block:: python + + class Person(Model): + first_name = columns.Text() + last_name = columns.Text() + + kevin = Person.create(first_name="Kevin", last_name="Deldycke") + dict(kevin) # returns {'first_name': 'Kevin', 'last_name': 'Deldycke'} + kevin['first_name'] # returns 'Kevin' + kevin.keys() # returns ['first_name', 'last_name'] + kevin.values() # returns ['Kevin', 'Deldycke'] + kevin.items() # returns [('first_name', 'Kevin'), ('last_name', 'Deldycke')] + + kevin['first_name'] = 'KEVIN5000' # changes the model's first name -Automatic Primary Keys -====================== - CQL requires that all tables define at least one primary key. If a model definition does not include a primary key column, cqlengine will automatically add a uuid primary key column named ``id``. diff --git a/docs/topics/queryset.rst b/docs/topics/queryset.rst index 6733080c..e1e17264 100644 --- a/docs/topics/queryset.rst +++ b/docs/topics/queryset.rst @@ -2,6 +2,12 @@ Making Queries ============== +**Users of versions < 0.4, please read this post before upgrading:** `Breaking Changes`_ + +.. _Breaking Changes: https://groups.google.com/forum/?fromgroups#!topic/cqlengine-users/erkSNe1JwuU + +.. module:: cqlengine.connection + .. module:: cqlengine.query Retrieving objects ================== Retrieving all objects ---------------------- @@ -23,15 +29,15 @@ .. _retrieving-objects-with-filters: Retrieving objects with filters ----------------------------------------- +------------------------------- Typically, you'll want to query only a subset of the records in your database. That can be accomplished with the QuerySet's ``.filter(\*\*)`` method. For example, given the model definition: - + .. code-block:: python - + class Automobile(Model): manufacturer = columns.Text(primary_key=True) year = columns.Integer(primary_key=True) model = columns.Text() @@ -40,11 +46,17 @@ Retrieving objects with filters ...and assuming the Automobile table contains a record of every car model manufactured in the last 20 years or so, we can retrieve only the cars made by a single manufacturer like this: - + .. code-block:: python q = Automobile.objects.filter(manufacturer='Tesla') + You can also use the more convenient syntax: + + .. code-block:: python + + q = Automobile.objects(Automobile.manufacturer == 'Tesla') + We can then further filter our query with another call to **.filter** .. code-block:: python @@ -123,6 +135,7 @@ Filtering Operators q = Automobile.objects.filter(manufacturer='Tesla') q = q.filter(year__in=[2011, 2012]) + :attr:`> (__gt) ` .. code-block:: python @@ -130,6 +143,10 @@ Filtering Operators q = Automobile.objects.filter(manufacturer='Tesla') q = q.filter(year__gt=2010) # year > 2010 + # or the nicer syntax + + q.filter(Automobile.year > 2010) + :attr:`>= (__gte) ` .. code-block:: python @@ -137,6 +154,10 @@ Filtering Operators q = Automobile.objects.filter(manufacturer='Tesla') q = q.filter(year__gte=2010) # year >= 2010 + # or the nicer syntax + + q.filter(Automobile.year >= 2010) + :attr:`< (__lt) ` .. code-block:: python @@ -144,6 +165,10 @@ Filtering Operators q = Automobile.objects.filter(manufacturer='Tesla') q = q.filter(year__lt=2012) # year < 2012 + # or... + + q.filter(Automobile.year < 2012) + :attr:`<= (__lte) ` .. code-block:: python @@ -151,6 +176,8 @@ Filtering Operators q = Automobile.objects.filter(manufacturer='Tesla') q = q.filter(year__lte=2012) # year <= 2012 + q.filter(Automobile.year <= 2012) + TimeUUID Functions ================== @@ -178,7 +205,29 @@ TimeUUID Functions DataStream.filter(time__gt=cqlengine.MinTimeUUID(min_time), time__lt=cqlengine.MaxTimeUUID(max_time)) -QuerySets are imutable +Token Function +============== + + The Token function may only be used on the special, virtual column ``pk__token``, which represents the token of the partition key (it also works for composite partition keys). + Cassandra orders returned rows by the value of the partition key token, so using cqlengine.Token we can easily paginate through all of a table's rows. + + See http://cassandra.apache.org/doc/cql3/CQL.html#tokenFun + + *Example* + + .. code-block:: python + + class Items(Model): + id = cqlengine.Text(primary_key=True) + data = cqlengine.Bytes() + + query = Items.objects.all().limit(10) + + first_page = list(query) + last = first_page[-1] + next_page = list(query.filter(pk__token__gt=cqlengine.Token(last.pk)))
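 To walk an entire table this way, the last-token fetch can be wrapped in a loop. A minimal sketch, assuming the ``Items`` model above (``process`` is a hypothetical per-row handler):

 .. code-block:: python

     query = Items.objects.all().limit(10)

     page = list(query)
     while page:
         for item in page:
             process(item)
         # resume after the last partition key token we've seen
         page = list(query.filter(pk__token__gt=cqlengine.Token(page[-1].pk)))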
+QuerySets are immutable
=======================

 When calling any method that changes a queryset, the method does not actually change the queryset object it's called on, but returns a new queryset object with the attributes of the original queryset, plus the attributes added in the method call. @@ -198,7 +247,7 @@ Ordering QuerySets Since Cassandra is essentially a distributed hash table on steroids, the order you get records back in will not be particularly predictable. - However, you can set a column to order on with the ``.order_by(column_name)`` method. + However, you can set a column to order on with the ``.order_by(column_name)`` method. *Example* @@ -213,15 +262,36 @@ Ordering QuerySets *For instance, given our Automobile model, year is the only column we can order on.* -Batch Queries -=============== +Values Lists +============ + + The QuerySet method ``.values_list()`` returns lists of values instead of model instances. It can significantly speed things up, with a lower memory footprint, for large responses. + Each tuple contains the value from the respective field passed into the ``values_list()`` call: the first item is the first field, and so on. For example: + + .. code-block:: python + + import random + + items = list(range(20)) + random.shuffle(items) + for i in items: + TestModel.create(id=1, clustering_key=i) + + values = list(TestModel.objects.values_list('clustering_key', flat=True)) + # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L]
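 Without ``flat=True``, each result is a tuple with one entry per field requested; a short sketch, assuming the same ``TestModel`` rows:

 .. code-block:: python

     rows = list(TestModel.objects.values_list('id', 'clustering_key'))
     # [(1, 19L), (1, 18L), ..., (1, 0L)]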
+Batch Queries
+=============

 cqlengine now supports batch queries using the BatchQuery class. Batch queries can be started and stopped manually, or within a context manager. To add queries to the batch object, you just need to precede the create/save/delete call with a call to batch, and pass in the batch object. + +Batch Query General Use Pattern +------------------------------- - cqlengine now supports batch queries using the BatchQuery class. Batch queries can be started and stopped manually, or within a context manager. To add queries to the batch object, you just need to precede the create/save/delete call with a call to batch, and pass in the batch object. - You can only create, update, and delete rows with a batch query; attempting to read rows out of the database with a batch query will fail. .. code-block:: python - + from cqlengine import BatchQuery #using a context manager @@ -241,6 +311,72 @@ Batch Queries em3 = ExampleModel.batch(b).create(example_type=0, description="3", created_at=now) b.execute() + # updating in a batch + + b = BatchQuery() + em1.description = "new description" + em1.batch(b).save() + em2.description = "another new description" + em2.batch(b).save() + b.execute() + + # deleting in a batch + b = BatchQuery() + ExampleModel.objects(id=some_id).batch(b).delete() + ExampleModel.objects(id=some_id2).batch(b).delete() + b.execute() + + Typically you will not want the batch to execute if an exception occurs inside the ``with`` block. However, in the case that this is desirable, it's achievable by using the following syntax: + + .. code-block:: python + + with BatchQuery(execute_on_exception=True) as b: + LogEntry.batch(b).create(k=1, v=1) + mystery_function() # exception thrown in here + LogEntry.batch(b).create(k=1, v=2) # this code is never reached due to the exception, but anything leading up to here will execute in the batch. + + If an exception is thrown somewhere in the block, any statements that have been added to the batch will still be executed. This is useful for some logging situations. +
+Batch Query Execution Callbacks
+-------------------------------

 In order to allow secondary tasks to be chained to the end of a batch, BatchQuery instances allow callbacks to be + registered with the batch, to be executed immediately after the batch executes. + + Multiple callbacks can be attached to the same BatchQuery instance; they are executed in the order they + are added to the batch. + + The callbacks attached to a given batch instance are executed only if the batch executes. If the batch is used as a + context manager and an exception is raised, the queued up callbacks will not be run. + + .. code-block:: python + + def my_callback(*args, **kwargs): + pass + + batch = BatchQuery() + + batch.add_callback(my_callback) + batch.add_callback(my_callback, 'positional arg', named_arg='named arg value') + + # if you need a reference to the batch within the callback, + # just trap it in the arguments to be passed to the callback: + batch.add_callback(my_callback, cqlengine_batch=batch) + + # once the batch executes... + batch.execute() + + # the effect of the above scheduled callbacks will be similar to: + my_callback() + my_callback('positional arg', named_arg='named arg value') + my_callback(cqlengine_batch=batch) + + Failure in any of the callbacks does not affect the batch's execution, as the callbacks are started after the execution + of the batch is complete. + +
QuerySet method reference
=========================

 .. method:: all() Returns a queryset matching all rows + .. method:: batch(batch_object) + + Sets the batch object to run the query on. Note that running a select query with a batch object will raise an exception. + + .. method:: consistency(consistency_setting) + + Sets the consistency level for the operation. Options may be imported from the top level :attr:`cqlengine` package (see the combined example at the end of this reference). + + .. method:: count() Returns the number of matching rows in your QuerySet @@ -269,7 +414,7 @@ QuerySet method reference .. method:: limit(num) Limits the number of results returned by Cassandra. - + *Note that CQL's default limit is 10,000, so all queries without a limit set explicitly will have an implicit limit of 10,000* .. method:: order_by(field_name) :param field_name: the name of the field to order on. *Note: the field_name must be a clustering key* :type field_name: string - Sets the field to order on. + Sets the field to order on. .. method:: allow_filtering() Enables the (usually) unwise practice of querying on a clustering key without also defining a partition key. + + .. method:: timestamp(timestamp_or_long_or_datetime) + + Allows for custom timestamps to be saved with the record. + + .. method:: ttl(ttl_in_seconds) + + :param ttl_in_seconds: time in seconds in which the saved values should expire + :type ttl_in_seconds: int + + Sets the ttl to run the query with. Note that running a select query with a ttl value will raise an exception. + + .. method:: update(**values) + + Performs an update on the row selected by the queryset. Include values to update in the + update like so: + + .. code-block:: python + + Model.objects(key=n).update(value='x') + + Passing in updates for columns which are not part of the model will raise a ValidationError. + Per column validation will be performed, but instance level validation will not + (`Model.validate` is not called). + + The queryset update method also supports blindly adding and removing elements from container columns, without + loading a model instance from Cassandra. + + Using the syntax `.update(column_name={x, y, z})` will overwrite the contents of the container, like updating a + non container column. However, adding an operation suffix such as ``__add`` to the end of the keyword arg makes the update call add + or remove items from the collection, without overwriting the entire column. + + Given the model below, here are the operations that can be performed on the different container columns: + + .. code-block:: python + + class Row(Model): + row_id = columns.Integer(primary_key=True) + set_column = columns.Set(Integer) + list_column = columns.List(Integer) + map_column = columns.Map(Integer, Integer) + + :class:`~cqlengine.columns.Set` + + - `add`: adds the elements of the given set to the column + - `remove`: removes the elements of the given set from the column + + .. code-block:: python + + # add elements to a set + Row.objects(row_id=5).update(set_column__add={6}) + + # remove elements from a set + Row.objects(row_id=5).update(set_column__remove={4}) + + :class:`~cqlengine.columns.List` + + - `append`: appends the elements of the given list to the end of the column + - `prepend`: prepends the elements of the given list to the beginning of the column + + .. code-block:: python + + # append items to a list + Row.objects(row_id=5).update(list_column__append=[6, 7]) + + # prepend items to a list + Row.objects(row_id=5).update(list_column__prepend=[1, 2]) + + :class:`~cqlengine.columns.Map` + + - `update`: adds the given keys/values to the column, creating new entries if they didn't exist, and overwriting old ones if they did + + .. code-block:: python + + # add items to a map + Row.objects(row_id=5).update(map_column__update={1: 2, 3: 4})
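 As referenced from the ``consistency`` entry above, a short sketch combining a consistency level with a blind container update (assuming the ``Row`` model above; ``ALL`` is one of the module-level consistency constants):

 .. code-block:: python

     from cqlengine import ALL

     # add an element to the set without reading the row back first,
     # requiring all replicas to acknowledge the write
     Row.objects(row_id=5).consistency(ALL).update(set_column__add={7})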
+Named Tables
+===================

+Named tables are a way of querying a table without creating a class. They're useful for querying system tables or exploring an unfamiliar database. + + .. code-block:: python + + from cqlengine.connection import setup + setup(["127.0.0.1"], "cqlengine_test") + + from cqlengine.named import NamedTable + user = NamedTable("cqlengine_test", "user") + user.objects() + user.objects()[0] + + # {u'pk': 1, u't': datetime.datetime(2014, 6, 26, 17, 10, 31, 774000)} + diff --git a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 00000000..2773a99f --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +nose +nose-progressive +profilestats +pycallgraph +ipdbplugin==1.2 +ipdb==0.7 diff --git a/requirements.txt b/requirements.txt index 04a6bc44..e648425a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -cql==1.4.0 -ipython==0.13.1 -ipdb==0.7 -Sphinx==1.1.3 -mock==1.0.1 +ipython +ipdb +Sphinx +mock +sure==1.2.5 +cassandra-driver>=2.0.0 diff --git a/setup.py b/setup.py index af822166..a427f36b 100644 --- a/setup.py +++ b/setup.py @@ -4,42 +4,37 @@ from setuptools import setup, find_packages #python setup.py register #python setup.py sdist upload -version = '0.2' +version = open('cqlengine/VERSION', 'r').readline().strip() long_desc = """ -cqlengine is a Cassandra CQL ORM for Python in the style of the Django orm and mongoengine +Cassandra CQL 3 Object Mapper for Python [Documentation](https://cqlengine.readthedocs.org/en/latest/) [Report a Bug](https://github.com/bdeggleston/cqlengine/issues) [Users Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-users) - -[Dev Mailing List](https://groups.google.com/forum/?fromgroups#!forum/cqlengine-dev) """ setup( name='cqlengine', version=version, - description='Cassandra CQL ORM for Python in the style of the Django orm and mongoengine', - dependency_links = ['https://github.com/bdeggleston/cqlengine/archive/{0}.tar.gz#egg=cqlengine-{0}'.format(version)], + description='Cassandra CQL 3 Object Mapper for Python', long_description=long_desc, classifiers = [ - "Development Status :: 3 - Alpha", "Environment :: Web Environment", "Environment :: Plugins", "License :: OSI Approved :: BSD License", "Operating System :: OS Independent", - "Programming Language :: Python :: 2.6", "Programming Language :: Python :: 2.7", "Topic :: Internet :: WWW/HTTP", "Topic :: Software Development :: Libraries :: Python Modules", ], keywords='cassandra,cql,orm', - install_requires = ['cql'], - author='Blake Eggleston', + install_requires = ['cassandra-driver >= 2.0.0'], + author='Blake Eggleston, Jon Haddad', + author_email='bdeggleston@gmail.com', - url='https://github.com/bdeggleston/cqlengine', + url='https://github.com/cqlengine/cqlengine', license='BSD', packages=find_packages(), include_package_data=True, diff --git a/tox.ini b/tox.ini new file mode 100644 index 00000000..eee19848 --- /dev/null +++ b/tox.ini @@ -0,0 +1,6 @@ +[tox] +envlist=py27,py34 + +[testenv] +deps= -rrequirements.txt +commands=nosetests --no-skip
diff --git a/upgrading.txt b/upgrading.txt new file mode 100644 index 00000000..95727aa5 --- /dev/null +++ b/upgrading.txt @@ -0,0 +1,13 @@ +0.15 Upgrade to use Datastax Native Driver + +We no longer raise a cqlengine-based OperationalError when a connection fails. The exception thrown by the native driver is raised instead. + +If you're calling setup() with the old thrift port in place, you will get connection errors. Either remove it (the default native port is assumed) or change it to the default native port, 9042. + +The cqlengine connection pool has been removed. Connections are now managed by the native driver. This should drastically reduce the socket overhead, as the native driver can multiplex queries. + +If you were previously manually using "ALL", "QUORUM", etc. to specify consistency levels, you will need to migrate to cqlengine.ALL, QUORUM, etc. instead. If you had been using the module level constants before, nothing should need to change. + +We no longer accept a username & password as arguments to setup(). Use the native driver's authentication instead. See http://datastax.github.io/python-driver/api/cassandra/auth.html
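To make the 0.15 connection changes concrete, a hedged before/after sketch (the host and keyspace are hypothetical; for authentication, consult the driver docs linked above):

.. code-block:: python

    from cqlengine import connection, ALL

    # pre-0.15, thrift-style call (will now fail to connect):
    # connection.setup(['127.0.0.1:9160'])

    # 0.15+: drop the thrift port; the native port 9042 is assumed
    connection.setup(['127.0.0.1'], 'cqlengine_test')

    # consistency levels now come from the module-level constants
    SomeModel.objects.consistency(ALL).count()  # SomeModel is hypothetical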