Merge branch 'tzaware-functions' of github.com:dokai/cqlengine into pull-118

This commit is contained in:
Blake Eggleston
2013-10-24 17:25:21 -07:00
3 changed files with 74 additions and 15 deletions

View File

@@ -332,11 +332,9 @@ class DateTime(Column):
else: else:
raise ValidationError("'{}' is not a datetime object".format(value)) raise ValidationError("'{}' is not a datetime object".format(value))
epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo) epoch = datetime(1970, 1, 1, tzinfo=value.tzinfo)
offset = 0 offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
if epoch.tzinfo:
offset_delta = epoch.tzinfo.utcoffset(epoch) return long(((value - epoch).total_seconds() - offset) * 1000)
offset = offset_delta.days*24*3600 + offset_delta.seconds
return long(((value - epoch).total_seconds() - offset) * 1000)
class Date(Column): class Date(Column):
@@ -406,12 +404,7 @@ class TimeUUID(UUID):
global _last_timestamp global _last_timestamp
epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo) epoch = datetime(1970, 1, 1, tzinfo=dt.tzinfo)
offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
offset = 0
if epoch.tzinfo:
offset_delta = epoch.tzinfo.utcoffset(epoch)
offset = offset_delta.days*24*3600 + offset_delta.seconds
timestamp = (dt - epoch).total_seconds() - offset timestamp = (dt - epoch).total_seconds() - offset
node = None node = None

View File

@@ -54,8 +54,10 @@ class MinTimeUUID(BaseQueryFunction):
super(MinTimeUUID, self).__init__(value) super(MinTimeUUID, self).__init__(value)
def get_value(self): def get_value(self):
epoch = datetime(1970, 1, 1) epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
return long((self.value - epoch).total_seconds() * 1000) offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column): def get_dict(self, column):
return {self.identifier: self.get_value()} return {self.identifier: self.get_value()}
@@ -79,8 +81,10 @@ class MaxTimeUUID(BaseQueryFunction):
super(MaxTimeUUID, self).__init__(value) super(MaxTimeUUID, self).__init__(value)
def get_value(self): def get_value(self):
epoch = datetime(1970, 1, 1) epoch = datetime(1970, 1, 1, tzinfo=self.value.tzinfo)
return long((self.value - epoch).total_seconds() * 1000) offset = epoch.tzinfo.utcoffset(epoch).total_seconds() if epoch.tzinfo else 0
return long(((self.value - epoch).total_seconds() - offset) * 1000)
def get_dict(self, column): def get_dict(self, column):
return {self.identifier: self.get_value()} return {self.identifier: self.get_value()}

View File

@@ -11,6 +11,25 @@ from cqlengine.management import delete_table
from cqlengine.models import Model from cqlengine.models import Model
from cqlengine import columns from cqlengine import columns
from cqlengine import query from cqlengine import query
from datetime import timedelta
from datetime import tzinfo
class TzOffset(tzinfo):
"""Minimal implementation of a timezone offset to help testing with timezone
aware datetimes.
"""
def __init__(self, offset):
self._offset = timedelta(hours=offset)
def utcoffset(self, dt):
return self._offset
def tzname(self, dt):
return 'TzOffset: {}'.format(self._offset.hours)
def dst(self, dt):
return timedelta(0)
class TestModel(Model): class TestModel(Model):
test_id = columns.Integer(primary_key=True) test_id = columns.Integer(primary_key=True)
@@ -515,6 +534,49 @@ class TestMinMaxTimeUUIDFunctions(BaseCassEngTestCase):
super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass() super(TestMinMaxTimeUUIDFunctions, cls).tearDownClass()
delete_table(TimeUUIDQueryModel) delete_table(TimeUUIDQueryModel)
def test_tzaware_datetime_support(self):
"""Test that using timezone aware datetime instances works with the
MinTimeUUID/MaxTimeUUID functions.
"""
pk = uuid4()
midpoint_utc = datetime.utcnow().replace(tzinfo=TzOffset(0))
midpoint_helsinki = midpoint_utc.astimezone(TzOffset(3))
# Assert pre-condition that we have the same logical point in time
assert midpoint_utc.utctimetuple() == midpoint_helsinki.utctimetuple()
assert midpoint_utc.timetuple() != midpoint_helsinki.timetuple()
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc - timedelta(minutes=1)),
data='1')
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc),
data='2')
TimeUUIDQueryModel.create(
partition=pk,
time=columns.TimeUUID.from_datetime(midpoint_utc + timedelta(minutes=1)),
data='3')
assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_utc))]
assert ['1', '2'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time <= functions.MaxTimeUUID(midpoint_helsinki))]
assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_utc))]
assert ['2', '3'] == [o.data for o in TimeUUIDQueryModel.filter(
TimeUUIDQueryModel.partition == pk,
TimeUUIDQueryModel.time >= functions.MinTimeUUID(midpoint_helsinki))]
def test_success_case(self): def test_success_case(self):
""" Test that the min and max time uuid functions work as expected """ """ Test that the min and max time uuid functions work as expected """
pk = uuid4() pk = uuid4()