diff --git a/CHANGES.rst b/CHANGES.rst index 051d57c..5dbc83c 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Utils release. +0.28.1 (2014-12-xx) +^^^^^^^^^^^^^^^^^^^^ + +- Improved EncryptedType to support more underlying_type's; now supports: Integer, Boolean, Date, Time, DateTime, ColorType, PhoneNumberType, Unicode(Text), String(Text), Enum +- Allow a callable to be used to lookup the key for EncryptedType + 0.28.0 (2014-12-12) ^^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index b34cf1d..a321aa1 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ PY3 = sys.version_info[0] == 3 extras_require = { 'test': [ - 'pytest==2.2.3', + 'pytest==2.3.5', 'Pygments>=1.2', 'Jinja2>=2.3', 'docutils>=0.10', diff --git a/sqlalchemy_utils/types/color.py b/sqlalchemy_utils/types/color.py index 1d4624b..0064243 100644 --- a/sqlalchemy_utils/types/color.py +++ b/sqlalchemy_utils/types/color.py @@ -49,6 +49,7 @@ class ColorType(types.TypeDecorator, ScalarCoercible): """ STORE_FORMAT = u'hex' impl = types.Unicode(20) + python_type = colour.Color def __init__(self, max_length=20, *args, **kwargs): # Fail if colour is not found. diff --git a/sqlalchemy_utils/types/encrypted.py b/sqlalchemy_utils/types/encrypted.py index ae5dd3b..0d28b41 100644 --- a/sqlalchemy_utils/types/encrypted.py +++ b/sqlalchemy_utils/types/encrypted.py @@ -1,8 +1,10 @@ # -*- coding: utf-8 -*- import base64 import six -from sqlalchemy.types import TypeDecorator, String +import datetime +from sqlalchemy.types import TypeDecorator, String, Binary from sqlalchemy_utils.exceptions import ImproperlyConfigured +from .scalar_coercible import ScalarCoercible cryptography = None try: @@ -24,13 +26,14 @@ class EncryptionDecryptionBaseEngine(object): new engines. """ - def __init__(self, key): - """Initialize a base engine.""" + def _update_key(self, key): if isinstance(key, six.string_types): - key = six.b(key) - self._digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) - self._digest.update(key) - self._engine_key = self._digest.finalize() + key = key.encode() + digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) + digest.update(key) + engine_key = digest.finalize() + + self._initialize_engine(engine_key) def encrypt(self, value): raise NotImplementedError('Subclasses must implement this!') @@ -45,14 +48,6 @@ class AesEngine(EncryptionDecryptionBaseEngine): BLOCK_SIZE = 16 PADDING = six.b('*') - def __init__(self, key): - super(AesEngine, self).__init__(key) - self._initialize_engine(self._engine_key) - - def _update_key(self, new_key): - parent = EncryptionDecryptionBaseEngine(new_key) - self._initialize_engine(parent._engine_key) - def _initialize_engine(self, parent_class_key): self.secret_key = parent_class_key self.iv = self.secret_key[:16] @@ -74,7 +69,7 @@ class AesEngine(EncryptionDecryptionBaseEngine): value = repr(value) if isinstance(value, six.text_type): value = str(value) - value = six.b(value) + value = value.encode() value = self._pad(value) encryptor = self.cipher.encryptor() encrypted = encryptor.update(value) + encryptor.finalize() @@ -96,14 +91,6 @@ class AesEngine(EncryptionDecryptionBaseEngine): class FernetEngine(EncryptionDecryptionBaseEngine): """Provide Fernet encryption and decryption methods.""" - def __init__(self, key): - super(FernetEngine, self).__init__(key) - self._initialize_engine(self._engine_key) - - def _update_key(self, new_key): - parent = EncryptionDecryptionBaseEngine(new_key) - self._initialize_engine(parent._engine_key) - def _initialize_engine(self, parent_class_key): self.secret_key = base64.urlsafe_b64encode(parent_class_key) self.fernet = Fernet(self.secret_key) @@ -113,7 +100,7 @@ class FernetEngine(EncryptionDecryptionBaseEngine): value = repr(value) if isinstance(value, six.text_type): value = str(value) - value = six.b(value) + value = value.encode() encrypted = self.fernet.encrypt(value) return encrypted @@ -126,7 +113,7 @@ class FernetEngine(EncryptionDecryptionBaseEngine): return decrypted -class EncryptedType(TypeDecorator): +class EncryptedType(TypeDecorator, ScalarCoercible): """ EncryptedType provides a way to encrypt and decrypt values, to and from databases, that their type is a basic SQLAlchemy type. @@ -194,9 +181,23 @@ class EncryptedType(TypeDecorator): Base.metadata.drop_all(connection) connection.close() engine.dispose() + + The key parameter accepts a callable to allow for the key to change + per-row instead of be fixed for the whole table. + + :: + def get_key(): + return 'dynamic-key' + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + username = sa.Column(EncryptedType( + sa.Unicode, get_key)) + """ - impl = String + impl = Binary def __init__(self, type_in=None, key=None, engine=None, **kwargs): """Initialization.""" @@ -208,11 +209,13 @@ class EncryptedType(TypeDecorator): # set the underlying type if type_in is None: type_in = String() - self.underlying_type = type_in() + elif isinstance(type_in, type): + type_in = type_in() + self.underlying_type = type_in self._key = key if not engine: engine = AesEngine - self.engine = engine(self._key) + self.engine = engine() @property def key(self): @@ -221,13 +224,73 @@ class EncryptedType(TypeDecorator): @key.setter def key(self, value): self._key = value - self.engine._update_key(self._key) + + def _update_key(self): + key = self._key() if callable(self._key) else self._key + self.engine._update_key(key) def process_bind_param(self, value, dialect): """Encrypt a value on the way in.""" - return self.engine.encrypt(value) + if value is not None: + self._update_key() + + try: + value = self.underlying_type.process_bind_param( + value, dialect + ) + + except AttributeError: + # Doesn't have 'process_bind_param' + + # Handle 'boolean' and 'dates' + type_ = self.underlying_type.python_type + if issubclass(type_, bool): + value = 'true' if value else 'false' + + elif issubclass(type_, (datetime.date, datetime.time)): + value = value.isoformat() + + return self.engine.encrypt(value) def process_result_value(self, value, dialect): """Decrypt value on the way out.""" - decrypted_value = self.engine.decrypt(value) - return self.underlying_type.python_type(decrypted_value) + if value is not None: + self._update_key() + decrypted_value = self.engine.decrypt(value) + + try: + return self.underlying_type.process_result_value( + decrypted_value, dialect + ) + + except AttributeError: + # Doesn't have 'process_result_value' + + # Handle 'boolean' and 'dates' + type_ = self.underlying_type.python_type + if issubclass(type_, bool): + return decrypted_value == 'true' + + elif issubclass(type_, datetime.datetime): + return datetime.datetime.strptime( + decrypted_value, '%Y-%m-%dT%H:%M:%S' + ) + + elif issubclass(type_, datetime.time): + return datetime.datetime.strptime( + decrypted_value, '%H:%M:%S' + ).time() + + elif issubclass(type_, datetime.date): + return datetime.datetime.strptime( + decrypted_value, '%Y-%m-%d' + ).date() + + # Handle all others + return self.underlying_type.python_type(decrypted_value) + + def _coerce(self, value): + if isinstance(self.underlying_type, ScalarCoercible): + return self.underlying_type._coerce(value) + + return value diff --git a/sqlalchemy_utils/types/phone_number.py b/sqlalchemy_utils/types/phone_number.py index 268ae70..762c459 100644 --- a/sqlalchemy_utils/types/phone_number.py +++ b/sqlalchemy_utils/types/phone_number.py @@ -94,6 +94,9 @@ class PhoneNumberType(types.TypeDecorator, ScalarCoercible): STORE_FORMAT = 'e164' impl = types.Unicode(20) + def python_type(self, text): + return self._coerce(text) + def __init__(self, country_code='US', max_length=20, *args, **kwargs): # Bail if phonenumbers is not found. if phonenumbers is None: diff --git a/tests/types/test_encrypted.py b/tests/types/test_encrypted.py index 002bc9b..d5ca19c 100644 --- a/tests/types/test_encrypted.py +++ b/tests/types/test_encrypted.py @@ -1,4 +1,6 @@ import sqlalchemy as sa +from datetime import datetime, date, time +import pytest from pytest import mark cryptography = None try: @@ -7,151 +9,234 @@ except ImportError: pass from tests import TestCase -from sqlalchemy_utils import EncryptedType +from sqlalchemy_utils import EncryptedType, PhoneNumberType, ColorType from sqlalchemy_utils.types.encrypted import AesEngine, FernetEngine @mark.skipif('cryptography is None') class EncryptedTypeTestCase(TestCase): - def setup_method(self, method): - # set some test values - self.test_key = 'secretkey1234' - self.user_name = u'someone' - self.test_token = self.generate_test_token() - self.active = True - self.accounts_num = 2 - self.searched_user = None - super(EncryptedTypeTestCase, self).setup_method(method) + + @pytest.fixture(scope='function') + def user(self, request): # set the values to the user object self.user = self.User() self.user.username = self.user_name + self.user.phone = self.user_phone + self.user.color = self.user_color + self.user.date = self.user_date + self.user.time = self.user_time + self.user.enum = self.user_enum + self.user.datetime = self.user_datetime self.user.access_token = self.test_token self.user.is_active = self.active self.user.accounts_num = self.accounts_num self.session.add(self.user) self.session.commit() - def teardown_method(self, method): - self.session.delete(self.user) - self.session.commit() - del self.user_name - del self.test_token - del self.active - del self.accounts_num - del self.test_key - del self.searched_user - super(EncryptedTypeTestCase, self).teardown_method(method) + # register a finalizer to cleanup + def finalize(): + del self.user_name + del self.test_token + del self.active + del self.accounts_num + del self.test_key + del self.searched_user - def create_models(self): - class User(self.Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - username = sa.Column(EncryptedType( - sa.Unicode, - self.test_key, - self.__class__.encryption_engine)) - access_token = sa.Column(EncryptedType( - sa.String, - self.test_key, - self.__class__.encryption_engine)) - is_active = sa.Column(EncryptedType( - sa.Boolean, - self.test_key, - self.__class__.encryption_engine)) - accounts_num = sa.Column(EncryptedType( - sa.Integer, - self.test_key, - self.__class__.encryption_engine)) + request.addfinalizer(finalize) - def __repr__(self): - return ( - "User(id={}, username={}, access_token={}," - "active={}, accounts={})".format( - self.id, - self.username, - self.access_token, - self.is_active, - self.accounts_num - ) - ) - - self.User = User - - def assert_username(self, _user): - assert _user.username == self.user_name - - def assert_access_token(self, _user): - assert _user.access_token == self.test_token - - def assert_is_active(self, _user): - assert _user.is_active == self.active - - def assert_accounts_num(self, _user): - assert _user.accounts_num == self.accounts_num + return self.session.query(self.User).get(self.user.id) def generate_test_token(self): import string import random - token = "" + token = '' characters = string.ascii_letters + string.digits for i in range(60): token += ''.join(random.choice(characters)) return token + def create_models(self): + # set some test values + self.test_key = 'secretkey1234' + self.user_name = u'someone' + self.user_phone = u'(555) 555-5555' + self.user_color = u'#fff' + self.user_enum = 'One' + self.user_date = date(2010, 10, 2) + self.user_time = time(10, 12) + self.user_datetime = datetime(2010, 10, 2, 10, 12) + self.test_token = self.generate_test_token() + self.active = True + self.accounts_num = 2 + self.searched_user = None + + class User(self.Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + + username = sa.Column(EncryptedType( + sa.Unicode, + self.test_key, + self.__class__.encryption_engine) + ) + + access_token = sa.Column(EncryptedType( + sa.String, + self.test_key, + self.__class__.encryption_engine) + ) + + is_active = sa.Column(EncryptedType( + sa.Boolean, + self.test_key, + self.__class__.encryption_engine) + ) + + accounts_num = sa.Column(EncryptedType( + sa.Integer, + self.test_key, + self.__class__.encryption_engine) + ) + + phone = sa.Column(EncryptedType( + PhoneNumberType, + self.test_key, + self.__class__.encryption_engine) + ) + + color = sa.Column(EncryptedType( + ColorType, + self.test_key, + self.__class__.encryption_engine) + ) + + date = sa.Column(EncryptedType( + sa.Date, + self.test_key, + self.__class__.encryption_engine) + ) + + time = sa.Column(EncryptedType( + sa.Time, + self.test_key, + self.__class__.encryption_engine) + ) + + datetime = sa.Column(EncryptedType( + sa.DateTime, + self.test_key, + self.__class__.encryption_engine) + ) + + enum = sa.Column(EncryptedType( + sa.Enum('One', name='user_enum_t'), + self.test_key, + self.__class__.encryption_engine) + ) + + self.User = User + + class Team(self.Base): + __tablename__ = 'team' + id = sa.Column(sa.Integer, primary_key=True) + key = sa.Column(sa.String(50)) + name = sa.Column(EncryptedType( + sa.Unicode, + lambda: self._team_key, + self.__class__.encryption_engine) + ) + + self.Team = Team + + def test_unicode(self, user): + assert user.username == self.user_name + + def test_string(self, user): + assert user.access_token == self.test_token + + def test_boolean(self, user): + assert user.is_active == self.active + + def test_integer(self, user): + assert user.accounts_num == self.accounts_num + + def test_phone_number(self, user): + assert str(user.phone) == self.user_phone + + def test_color(self, user): + assert user.color.hex == self.user_color + + def test_date(self, user): + assert user.date == self.user_date + + def test_datetime(self, user): + assert user.datetime == self.user_datetime + + def test_time(self, user): + assert user.time == self.user_time + + def test_enum(self, user): + assert user.enum == self.user_enum + + def test_lookup_key(self): + # Add teams + self._team_key = 'one' + team = self.Team(key=self._team_key, name=u'One') + self.session.add(team) + self.session.commit() + team_1_id = team.id + + self._team_key = 'two' + team = self.Team(key=self._team_key) + team.name = u'Two' + self.session.add(team) + self.session.commit() + team_2_id = team.id + + # Lookup teams + self._team_key = self.session.query(self.Team.key).filter_by( + id=team_1_id + ).one()[0] + + team = self.session.query(self.Team).get(team_1_id) + + assert team.name == u'One' + + with pytest.raises(Exception): + self.session.query(self.Team).get(team_2_id) + + self.session.expunge_all() + + self._team_key = self.session.query(self.Team.key).filter_by( + id=team_2_id + ).one()[0] + + team = self.session.query(self.Team).get(team_2_id) + + assert team.name == u'Two' + + with pytest.raises(Exception): + self.session.query(self.Team).get(team_1_id) + + self.session.expunge_all() + + # Remove teams + self.session.query(self.Team).delete() + self.session.commit() + class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase): encryption_engine = AesEngine - def test_unicode(self): - self.searched_user = self.session.query(self.User).filter( - self.User.access_token == self.test_token - ).first() - self.assert_username(self.searched_user) - - def test_string(self): - self.searched_user = self.session.query(self.User).filter( + def test_lookup_by_encrypted_string(self, user): + test = self.session.query(self.User).filter( self.User.username == self.user_name ).first() - self.assert_access_token(self.searched_user) - def test_boolean(self): - self.searched_user = self.session.query(self.User).filter( - self.User.access_token == self.test_token - ).first() - self.assert_is_active(self.searched_user) - - def test_integer(self): - self.searched_user = self.session.query(self.User).filter( - self.User.access_token == self.test_token - ).first() - self.assert_accounts_num(self.searched_user) + assert test.username == user.username -class TestFernetEnryptedTypeTestCase(EncryptedTypeTestCase): +class TestFernetEncryptedTypeTestCase(EncryptedTypeTestCase): encryption_engine = FernetEngine - - def test_unicode(self): - self.searched_user = self.session.query(self.User).filter( - self.User.id == self.user.id - ).first() - self.assert_username(self.searched_user) - - def test_string(self): - self.searched_user = self.session.query(self.User).filter( - self.User.id == self.user.id - ).first() - self.assert_access_token(self.searched_user) - - def test_boolean(self): - self.searched_user = self.session.query(self.User).filter( - self.User.id == self.user.id - ).first() - self.assert_is_active(self.searched_user) - - def test_integer(self): - self.searched_user = self.session.query(self.User).filter( - self.User.id == self.user.id - ).first() - self.assert_accounts_num(self.searched_user)