Merge pull request #98 from kvesteri/updates

Updates to support EncryptedType in more scenarios
This commit is contained in:
Konsta Vesterinen
2014-12-13 10:29:23 +02:00
6 changed files with 303 additions and 145 deletions

View File

@@ -3,6 +3,12 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Utils release. Here you can see the full list of changes between each SQLAlchemy-Utils release.
0.28.1 (2014-12-xx)
^^^^^^^^^^^^^^^^^^^^
- Improved EncryptedType to support more underlying_type's; now supports: Integer, Boolean, Date, Time, DateTime, ColorType, PhoneNumberType, Unicode(Text), String(Text), Enum
- Allow a callable to be used to lookup the key for EncryptedType
0.28.0 (2014-12-12) 0.28.0 (2014-12-12)
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^

View File

@@ -27,7 +27,7 @@ PY3 = sys.version_info[0] == 3
extras_require = { extras_require = {
'test': [ 'test': [
'pytest==2.2.3', 'pytest==2.3.5',
'Pygments>=1.2', 'Pygments>=1.2',
'Jinja2>=2.3', 'Jinja2>=2.3',
'docutils>=0.10', 'docutils>=0.10',

View File

@@ -49,6 +49,7 @@ class ColorType(types.TypeDecorator, ScalarCoercible):
""" """
STORE_FORMAT = u'hex' STORE_FORMAT = u'hex'
impl = types.Unicode(20) impl = types.Unicode(20)
python_type = colour.Color
def __init__(self, max_length=20, *args, **kwargs): def __init__(self, max_length=20, *args, **kwargs):
# Fail if colour is not found. # Fail if colour is not found.

View File

@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import base64 import base64
import six import six
from sqlalchemy.types import TypeDecorator, String import datetime
from sqlalchemy.types import TypeDecorator, String, Binary
from sqlalchemy_utils.exceptions import ImproperlyConfigured from sqlalchemy_utils.exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
cryptography = None cryptography = None
try: try:
@@ -24,13 +26,14 @@ class EncryptionDecryptionBaseEngine(object):
new engines. new engines.
""" """
def __init__(self, key): def _update_key(self, key):
"""Initialize a base engine."""
if isinstance(key, six.string_types): if isinstance(key, six.string_types):
key = six.b(key) key = key.encode()
self._digest = hashes.Hash(hashes.SHA256(), backend=default_backend()) digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
self._digest.update(key) digest.update(key)
self._engine_key = self._digest.finalize() engine_key = digest.finalize()
self._initialize_engine(engine_key)
def encrypt(self, value): def encrypt(self, value):
raise NotImplementedError('Subclasses must implement this!') raise NotImplementedError('Subclasses must implement this!')
@@ -45,14 +48,6 @@ class AesEngine(EncryptionDecryptionBaseEngine):
BLOCK_SIZE = 16 BLOCK_SIZE = 16
PADDING = six.b('*') PADDING = six.b('*')
def __init__(self, key):
super(AesEngine, self).__init__(key)
self._initialize_engine(self._engine_key)
def _update_key(self, new_key):
parent = EncryptionDecryptionBaseEngine(new_key)
self._initialize_engine(parent._engine_key)
def _initialize_engine(self, parent_class_key): def _initialize_engine(self, parent_class_key):
self.secret_key = parent_class_key self.secret_key = parent_class_key
self.iv = self.secret_key[:16] self.iv = self.secret_key[:16]
@@ -74,7 +69,7 @@ class AesEngine(EncryptionDecryptionBaseEngine):
value = repr(value) value = repr(value)
if isinstance(value, six.text_type): if isinstance(value, six.text_type):
value = str(value) value = str(value)
value = six.b(value) value = value.encode()
value = self._pad(value) value = self._pad(value)
encryptor = self.cipher.encryptor() encryptor = self.cipher.encryptor()
encrypted = encryptor.update(value) + encryptor.finalize() encrypted = encryptor.update(value) + encryptor.finalize()
@@ -96,14 +91,6 @@ class AesEngine(EncryptionDecryptionBaseEngine):
class FernetEngine(EncryptionDecryptionBaseEngine): class FernetEngine(EncryptionDecryptionBaseEngine):
"""Provide Fernet encryption and decryption methods.""" """Provide Fernet encryption and decryption methods."""
def __init__(self, key):
super(FernetEngine, self).__init__(key)
self._initialize_engine(self._engine_key)
def _update_key(self, new_key):
parent = EncryptionDecryptionBaseEngine(new_key)
self._initialize_engine(parent._engine_key)
def _initialize_engine(self, parent_class_key): def _initialize_engine(self, parent_class_key):
self.secret_key = base64.urlsafe_b64encode(parent_class_key) self.secret_key = base64.urlsafe_b64encode(parent_class_key)
self.fernet = Fernet(self.secret_key) self.fernet = Fernet(self.secret_key)
@@ -113,7 +100,7 @@ class FernetEngine(EncryptionDecryptionBaseEngine):
value = repr(value) value = repr(value)
if isinstance(value, six.text_type): if isinstance(value, six.text_type):
value = str(value) value = str(value)
value = six.b(value) value = value.encode()
encrypted = self.fernet.encrypt(value) encrypted = self.fernet.encrypt(value)
return encrypted return encrypted
@@ -126,7 +113,7 @@ class FernetEngine(EncryptionDecryptionBaseEngine):
return decrypted return decrypted
class EncryptedType(TypeDecorator): class EncryptedType(TypeDecorator, ScalarCoercible):
""" """
EncryptedType provides a way to encrypt and decrypt values, EncryptedType provides a way to encrypt and decrypt values,
to and from databases, that their type is a basic SQLAlchemy type. to and from databases, that their type is a basic SQLAlchemy type.
@@ -194,9 +181,23 @@ class EncryptedType(TypeDecorator):
Base.metadata.drop_all(connection) Base.metadata.drop_all(connection)
connection.close() connection.close()
engine.dispose() engine.dispose()
The key parameter accepts a callable to allow for the key to change
per-row instead of be fixed for the whole table.
::
def get_key():
return 'dynamic-key'
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
username = sa.Column(EncryptedType(
sa.Unicode, get_key))
""" """
impl = String impl = Binary
def __init__(self, type_in=None, key=None, engine=None, **kwargs): def __init__(self, type_in=None, key=None, engine=None, **kwargs):
"""Initialization.""" """Initialization."""
@@ -208,11 +209,13 @@ class EncryptedType(TypeDecorator):
# set the underlying type # set the underlying type
if type_in is None: if type_in is None:
type_in = String() type_in = String()
self.underlying_type = type_in() elif isinstance(type_in, type):
type_in = type_in()
self.underlying_type = type_in
self._key = key self._key = key
if not engine: if not engine:
engine = AesEngine engine = AesEngine
self.engine = engine(self._key) self.engine = engine()
@property @property
def key(self): def key(self):
@@ -221,13 +224,73 @@ class EncryptedType(TypeDecorator):
@key.setter @key.setter
def key(self, value): def key(self, value):
self._key = value self._key = value
self.engine._update_key(self._key)
def _update_key(self):
key = self._key() if callable(self._key) else self._key
self.engine._update_key(key)
def process_bind_param(self, value, dialect): def process_bind_param(self, value, dialect):
"""Encrypt a value on the way in.""" """Encrypt a value on the way in."""
if value is not None:
self._update_key()
try:
value = self.underlying_type.process_bind_param(
value, dialect
)
except AttributeError:
# Doesn't have 'process_bind_param'
# Handle 'boolean' and 'dates'
type_ = self.underlying_type.python_type
if issubclass(type_, bool):
value = 'true' if value else 'false'
elif issubclass(type_, (datetime.date, datetime.time)):
value = value.isoformat()
return self.engine.encrypt(value) return self.engine.encrypt(value)
def process_result_value(self, value, dialect): def process_result_value(self, value, dialect):
"""Decrypt value on the way out.""" """Decrypt value on the way out."""
if value is not None:
self._update_key()
decrypted_value = self.engine.decrypt(value) decrypted_value = self.engine.decrypt(value)
try:
return self.underlying_type.process_result_value(
decrypted_value, dialect
)
except AttributeError:
# Doesn't have 'process_result_value'
# Handle 'boolean' and 'dates'
type_ = self.underlying_type.python_type
if issubclass(type_, bool):
return decrypted_value == 'true'
elif issubclass(type_, datetime.datetime):
return datetime.datetime.strptime(
decrypted_value, '%Y-%m-%dT%H:%M:%S'
)
elif issubclass(type_, datetime.time):
return datetime.datetime.strptime(
decrypted_value, '%H:%M:%S'
).time()
elif issubclass(type_, datetime.date):
return datetime.datetime.strptime(
decrypted_value, '%Y-%m-%d'
).date()
# Handle all others
return self.underlying_type.python_type(decrypted_value) return self.underlying_type.python_type(decrypted_value)
def _coerce(self, value):
if isinstance(self.underlying_type, ScalarCoercible):
return self.underlying_type._coerce(value)
return value

View File

@@ -94,6 +94,9 @@ class PhoneNumberType(types.TypeDecorator, ScalarCoercible):
STORE_FORMAT = 'e164' STORE_FORMAT = 'e164'
impl = types.Unicode(20) impl = types.Unicode(20)
def python_type(self, text):
return self._coerce(text)
def __init__(self, country_code='US', max_length=20, *args, **kwargs): def __init__(self, country_code='US', max_length=20, *args, **kwargs):
# Bail if phonenumbers is not found. # Bail if phonenumbers is not found.
if phonenumbers is None: if phonenumbers is None:

View File

@@ -1,4 +1,6 @@
import sqlalchemy as sa import sqlalchemy as sa
from datetime import datetime, date, time
import pytest
from pytest import mark from pytest import mark
cryptography = None cryptography = None
try: try:
@@ -7,151 +9,234 @@ except ImportError:
pass pass
from tests import TestCase from tests import TestCase
from sqlalchemy_utils import EncryptedType from sqlalchemy_utils import EncryptedType, PhoneNumberType, ColorType
from sqlalchemy_utils.types.encrypted import AesEngine, FernetEngine from sqlalchemy_utils.types.encrypted import AesEngine, FernetEngine
@mark.skipif('cryptography is None') @mark.skipif('cryptography is None')
class EncryptedTypeTestCase(TestCase): class EncryptedTypeTestCase(TestCase):
def setup_method(self, method):
# set some test values @pytest.fixture(scope='function')
self.test_key = 'secretkey1234' def user(self, request):
self.user_name = u'someone'
self.test_token = self.generate_test_token()
self.active = True
self.accounts_num = 2
self.searched_user = None
super(EncryptedTypeTestCase, self).setup_method(method)
# set the values to the user object # set the values to the user object
self.user = self.User() self.user = self.User()
self.user.username = self.user_name self.user.username = self.user_name
self.user.phone = self.user_phone
self.user.color = self.user_color
self.user.date = self.user_date
self.user.time = self.user_time
self.user.enum = self.user_enum
self.user.datetime = self.user_datetime
self.user.access_token = self.test_token self.user.access_token = self.test_token
self.user.is_active = self.active self.user.is_active = self.active
self.user.accounts_num = self.accounts_num self.user.accounts_num = self.accounts_num
self.session.add(self.user) self.session.add(self.user)
self.session.commit() self.session.commit()
def teardown_method(self, method): # register a finalizer to cleanup
self.session.delete(self.user) def finalize():
self.session.commit()
del self.user_name del self.user_name
del self.test_token del self.test_token
del self.active del self.active
del self.accounts_num del self.accounts_num
del self.test_key del self.test_key
del self.searched_user del self.searched_user
super(EncryptedTypeTestCase, self).teardown_method(method)
def create_models(self): request.addfinalizer(finalize)
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
username = sa.Column(EncryptedType(
sa.Unicode,
self.test_key,
self.__class__.encryption_engine))
access_token = sa.Column(EncryptedType(
sa.String,
self.test_key,
self.__class__.encryption_engine))
is_active = sa.Column(EncryptedType(
sa.Boolean,
self.test_key,
self.__class__.encryption_engine))
accounts_num = sa.Column(EncryptedType(
sa.Integer,
self.test_key,
self.__class__.encryption_engine))
def __repr__(self): return self.session.query(self.User).get(self.user.id)
return (
"User(id={}, username={}, access_token={},"
"active={}, accounts={})".format(
self.id,
self.username,
self.access_token,
self.is_active,
self.accounts_num
)
)
self.User = User
def assert_username(self, _user):
assert _user.username == self.user_name
def assert_access_token(self, _user):
assert _user.access_token == self.test_token
def assert_is_active(self, _user):
assert _user.is_active == self.active
def assert_accounts_num(self, _user):
assert _user.accounts_num == self.accounts_num
def generate_test_token(self): def generate_test_token(self):
import string import string
import random import random
token = "" token = ''
characters = string.ascii_letters + string.digits characters = string.ascii_letters + string.digits
for i in range(60): for i in range(60):
token += ''.join(random.choice(characters)) token += ''.join(random.choice(characters))
return token return token
def create_models(self):
# set some test values
self.test_key = 'secretkey1234'
self.user_name = u'someone'
self.user_phone = u'(555) 555-5555'
self.user_color = u'#fff'
self.user_enum = 'One'
self.user_date = date(2010, 10, 2)
self.user_time = time(10, 12)
self.user_datetime = datetime(2010, 10, 2, 10, 12)
self.test_token = self.generate_test_token()
self.active = True
self.accounts_num = 2
self.searched_user = None
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
username = sa.Column(EncryptedType(
sa.Unicode,
self.test_key,
self.__class__.encryption_engine)
)
access_token = sa.Column(EncryptedType(
sa.String,
self.test_key,
self.__class__.encryption_engine)
)
is_active = sa.Column(EncryptedType(
sa.Boolean,
self.test_key,
self.__class__.encryption_engine)
)
accounts_num = sa.Column(EncryptedType(
sa.Integer,
self.test_key,
self.__class__.encryption_engine)
)
phone = sa.Column(EncryptedType(
PhoneNumberType,
self.test_key,
self.__class__.encryption_engine)
)
color = sa.Column(EncryptedType(
ColorType,
self.test_key,
self.__class__.encryption_engine)
)
date = sa.Column(EncryptedType(
sa.Date,
self.test_key,
self.__class__.encryption_engine)
)
time = sa.Column(EncryptedType(
sa.Time,
self.test_key,
self.__class__.encryption_engine)
)
datetime = sa.Column(EncryptedType(
sa.DateTime,
self.test_key,
self.__class__.encryption_engine)
)
enum = sa.Column(EncryptedType(
sa.Enum('One', name='user_enum_t'),
self.test_key,
self.__class__.encryption_engine)
)
self.User = User
class Team(self.Base):
__tablename__ = 'team'
id = sa.Column(sa.Integer, primary_key=True)
key = sa.Column(sa.String(50))
name = sa.Column(EncryptedType(
sa.Unicode,
lambda: self._team_key,
self.__class__.encryption_engine)
)
self.Team = Team
def test_unicode(self, user):
assert user.username == self.user_name
def test_string(self, user):
assert user.access_token == self.test_token
def test_boolean(self, user):
assert user.is_active == self.active
def test_integer(self, user):
assert user.accounts_num == self.accounts_num
def test_phone_number(self, user):
assert str(user.phone) == self.user_phone
def test_color(self, user):
assert user.color.hex == self.user_color
def test_date(self, user):
assert user.date == self.user_date
def test_datetime(self, user):
assert user.datetime == self.user_datetime
def test_time(self, user):
assert user.time == self.user_time
def test_enum(self, user):
assert user.enum == self.user_enum
def test_lookup_key(self):
# Add teams
self._team_key = 'one'
team = self.Team(key=self._team_key, name=u'One')
self.session.add(team)
self.session.commit()
team_1_id = team.id
self._team_key = 'two'
team = self.Team(key=self._team_key)
team.name = u'Two'
self.session.add(team)
self.session.commit()
team_2_id = team.id
# Lookup teams
self._team_key = self.session.query(self.Team.key).filter_by(
id=team_1_id
).one()[0]
team = self.session.query(self.Team).get(team_1_id)
assert team.name == u'One'
with pytest.raises(Exception):
self.session.query(self.Team).get(team_2_id)
self.session.expunge_all()
self._team_key = self.session.query(self.Team.key).filter_by(
id=team_2_id
).one()[0]
team = self.session.query(self.Team).get(team_2_id)
assert team.name == u'Two'
with pytest.raises(Exception):
self.session.query(self.Team).get(team_1_id)
self.session.expunge_all()
# Remove teams
self.session.query(self.Team).delete()
self.session.commit()
class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase): class TestAesEncryptedTypeTestcase(EncryptedTypeTestCase):
encryption_engine = AesEngine encryption_engine = AesEngine
def test_unicode(self): def test_lookup_by_encrypted_string(self, user):
self.searched_user = self.session.query(self.User).filter( test = self.session.query(self.User).filter(
self.User.access_token == self.test_token
).first()
self.assert_username(self.searched_user)
def test_string(self):
self.searched_user = self.session.query(self.User).filter(
self.User.username == self.user_name self.User.username == self.user_name
).first() ).first()
self.assert_access_token(self.searched_user)
def test_boolean(self): assert test.username == user.username
self.searched_user = self.session.query(self.User).filter(
self.User.access_token == self.test_token
).first()
self.assert_is_active(self.searched_user)
def test_integer(self):
self.searched_user = self.session.query(self.User).filter(
self.User.access_token == self.test_token
).first()
self.assert_accounts_num(self.searched_user)
class TestFernetEnryptedTypeTestCase(EncryptedTypeTestCase): class TestFernetEncryptedTypeTestCase(EncryptedTypeTestCase):
encryption_engine = FernetEngine encryption_engine = FernetEngine
def test_unicode(self):
self.searched_user = self.session.query(self.User).filter(
self.User.id == self.user.id
).first()
self.assert_username(self.searched_user)
def test_string(self):
self.searched_user = self.session.query(self.User).filter(
self.User.id == self.user.id
).first()
self.assert_access_token(self.searched_user)
def test_boolean(self):
self.searched_user = self.session.query(self.User).filter(
self.User.id == self.user.id
).first()
self.assert_is_active(self.searched_user)
def test_integer(self):
self.searched_user = self.session.query(self.User).filter(
self.User.id == self.user.id
).first()
self.assert_accounts_num(self.searched_user)