diff --git a/sqlalchemy_utils/primitives/__init__.py b/sqlalchemy_utils/primitives/__init__.py index f00f3ac..9a7dd90 100644 --- a/sqlalchemy_utils/primitives/__init__.py +++ b/sqlalchemy_utils/primitives/__init__.py @@ -1,7 +1,11 @@ from .number_range import NumberRange, NumberRangeException +from .weekday import WeekDay +from .weekdays import WeekDays __all__ = ( NumberRange, - NumberRangeException + NumberRangeException, + WeekDay, + WeekDays ) diff --git a/sqlalchemy_utils/primitives/utils.py b/sqlalchemy_utils/primitives/utils.py new file mode 100644 index 0000000..4efed7b --- /dev/null +++ b/sqlalchemy_utils/primitives/utils.py @@ -0,0 +1,13 @@ +import sys + + +def str_coercible(cls): + if sys.version_info[0] >= 3: # Python 3 + def __str__(self): + return self.__unicode__() + else: # Python 2 + def __str__(self): + return self.__unicode__().encode('utf8') + + cls.__str__ = __str__ + return cls diff --git a/sqlalchemy_utils/primitives/weekday.py b/sqlalchemy_utils/primitives/weekday.py new file mode 100644 index 0000000..8c9e7a2 --- /dev/null +++ b/sqlalchemy_utils/primitives/weekday.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +import sys +from sqlalchemy import types +from sqlalchemy.dialects.postgresql import BIT +import six + +try: + from functools import total_ordering +except ImportError: + # Python 2.6 port + from total_ordering import total_ordering +from sqlalchemy_utils import i18n +from .utils import str_coercible + + +@str_coercible +@total_ordering +class WeekDay(object): + NUM_WEEK_DAYS = 7 + + def __init__(self, index): + if not (0 <= index < self.NUM_WEEK_DAYS): + raise ValueError( + "index must be between 0 and %d" % self.NUM_WEEK_DAYS + ) + self.index = index + + def __eq__(self, other): + if isinstance(other, WeekDay): + return self.index == other.index + else: + return NotImplemented + + def __hash__(self): + return hash(self.index) + + def __lt__(self, other): + return self.position < other.position + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.index) + + def __unicode__(self): + return self.name + + def get_name(self, width='wide', context='format'): + names = i18n.get_day_names( + width, + context, + i18n.get_locale() + ) + return names[self.index] + + @property + def name(self): + return self.get_name() + + @property + def position(self): + return ( + self.index - + i18n.get_locale().first_week_day + ) % self.NUM_WEEK_DAYS diff --git a/sqlalchemy_utils/primitives/weekdays.py b/sqlalchemy_utils/primitives/weekdays.py new file mode 100644 index 0000000..b8dffb7 --- /dev/null +++ b/sqlalchemy_utils/primitives/weekdays.py @@ -0,0 +1,60 @@ +import six + +from .utils import str_coercible +from .weekday import WeekDay + + +@str_coercible +class WeekDays(object): + def __init__(self, bit_string_or_week_days): + if isinstance(bit_string_or_week_days, six.string_types): + self._days = set() + + if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS: + raise ValueError( + 'Bit string must be {0} characters long.'.format( + WeekDay.NUM_WEEK_DAYS + ) + ) + + for index, bit in enumerate(bit_string_or_week_days): + if bit not in '01': + raise ValueError( + 'Bit string may only contain zeroes and ones.' + ) + if bit == '1': + self._days.add(WeekDay(index)) + elif isinstance(bit_string_or_week_days, WeekDays): + self._days = bit_string_or_week_days._days + else: + self._days = set(bit_string_or_week_days) + + def __eq__(self, other): + if isinstance(other, WeekDays): + return self._days == other._days + elif isinstance(other, six.string_types): + return self.as_bit_string() == other + else: + return NotImplemented + + def __iter__(self): + for day in sorted(self._days): + yield day + + def __contains__(self, value): + return value in self._days + + def __repr__(self): + return '%s(%r)' % ( + self.__class__.__name__, + self.as_bit_string() + ) + + def __unicode__(self): + return u', '.join(six.text_type(day) for day in self) + + def as_bit_string(self): + return ''.join( + '1' if WeekDay(index) in self._days else '0' + for index in six.moves.xrange(WeekDay.NUM_WEEK_DAYS) + ) diff --git a/sqlalchemy_utils/types/weekdays.py b/sqlalchemy_utils/types/weekdays.py index 753e9d2..505988c 100644 --- a/sqlalchemy_utils/types/weekdays.py +++ b/sqlalchemy_utils/types/weekdays.py @@ -1,133 +1,7 @@ -# -*- coding: utf-8 -*- -import sys +import six from sqlalchemy import types from sqlalchemy.dialects.postgresql import BIT -import six - -try: - from functools import total_ordering -except ImportError: - # Python 2.6 port - from total_ordering import total_ordering -from sqlalchemy_utils import i18n - - -@total_ordering -class WeekDay(object): - NUM_WEEK_DAYS = 7 - - def __init__(self, index): - if not (0 <= index < self.NUM_WEEK_DAYS): - raise ValueError( - "index must be between 0 and %d" % self.NUM_WEEK_DAYS - ) - self.index = index - - def __eq__(self, other): - if isinstance(other, WeekDay): - return self.index == other.index - else: - return NotImplemented - - def __hash__(self): - return hash(self.index) - - def __lt__(self, other): - return self.position < other.position - - def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.index) - - if sys.version_info[0] >= 3: # Python 3 - def __str__(self): - return self.__unicode__() - else: # Python 2 - def __str__(self): - return self.__unicode__().encode('utf8') - - def __unicode__(self): - return self.name - - def get_name(self, width='wide', context='format'): - names = i18n.get_day_names( - width, - context, - i18n.get_locale() - ) - return names[self.index] - - @property - def name(self): - return self.get_name() - - @property - def position(self): - return ( - self.index - - i18n.get_locale().first_week_day - ) % self.NUM_WEEK_DAYS - - -class WeekDays(object): - def __init__(self, bit_string_or_week_days): - if isinstance(bit_string_or_week_days, six.string_types): - self._days = set() - - if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS: - raise ValueError( - 'Bit string must be {0} characters long.'.format( - WeekDay.NUM_WEEK_DAYS - ) - ) - - for index, bit in enumerate(bit_string_or_week_days): - if bit not in '01': - raise ValueError( - 'Bit string may only contain zeroes and ones.' - ) - if bit == '1': - self._days.add(WeekDay(index)) - elif isinstance(bit_string_or_week_days, WeekDays): - self._days = bit_string_or_week_days._days - else: - self._days = set(bit_string_or_week_days) - - def __eq__(self, other): - if isinstance(other, WeekDays): - return self._days == other._days - elif isinstance(other, six.string_types): - return self.as_bit_string() == other - else: - return NotImplemented - - def __iter__(self): - for day in sorted(self._days): - yield day - - def __contains__(self, value): - return value in self._days - - def __repr__(self): - return '%s(%r)' % ( - self.__class__.__name__, - self.as_bit_string() - ) - - if sys.version_info[0] >= 3: # Python 3 - def __str__(self): - return self.__unicode__() - else: # Python 2 - def __str__(self): - return self.__unicode__().encode('utf8') - - def __unicode__(self): - return u', '.join(six.text_type(day) for day in self) - - def as_bit_string(self): - return ''.join( - '1' if WeekDay(index) in self._days else '0' - for index in six.moves.xrange(WeekDay.NUM_WEEK_DAYS) - ) +from sqlalchemy_utils.primitives import WeekDay, WeekDays class WeekDaysType(types.TypeDecorator): diff --git a/tests/primitives/test_weekdays.py b/tests/primitives/test_weekdays.py index 21168ee..70f5681 100644 --- a/tests/primitives/test_weekdays.py +++ b/tests/primitives/test_weekdays.py @@ -81,6 +81,11 @@ class TestWeekDay(object): flexmock(day).should_receive('name').and_return(u'maanantaina') assert six.text_type(day) == u'maanantaina' + def test_str(self): + day = WeekDay(0) + flexmock(day).should_receive('name').and_return(u'maanantaina') + assert str(day) == 'maanantaina' + @pytest.mark.skipif('Locale is None') class TestWeekDays(object): @@ -158,3 +163,8 @@ class TestWeekDays(object): i18n.get_locale = lambda: Locale('fi') days = WeekDays('1000100') assert six.text_type(days) == u'maanantaina, perjantaina' + + def test_str(self): + i18n.get_locale = lambda: Locale('fi') + days = WeekDays('1000100') + assert str(days) == 'maanantaina, perjantaina'