diff --git a/sqlalchemy_utils/types/password.py b/sqlalchemy_utils/types/password.py index 5d95e80..c943535 100644 --- a/sqlalchemy_utils/types/password.py +++ b/sqlalchemy_utils/types/password.py @@ -4,6 +4,7 @@ from sqlalchemy_utils import ImproperlyConfigured from sqlalchemy import types from sqlalchemy.dialects import postgresql from .scalar_coercible import ScalarCoercible +from sqlalchemy.ext.mutable import Mutable passlib = None try: @@ -13,27 +14,61 @@ except ImportError: pass -class Password(object): - def __init__(self, value, context=None): - # Store the hash. - self.hash = value +class Password(Mutable, object): + + @classmethod + def coerce(cls, key, value): + if isinstance(value, Password): + return value + + if isinstance(value, (six.string_types, six.binary_type)): + return cls(value, secret=True) + + super(Password, cls).coerce(key, value) + + def __init__(self, value, context=None, secret=False): + # Store the hash (if it is one). + self.hash = value if not secret else None + + # Store the secret if we have one. + self.secret = value if secret else None + + # The hash should be bytes. + if isinstance(self.hash, six.text_type): + self.hash = self.hash.encode('utf8') # Save weakref of the password context (if we have one) - if context is not None: - self.context = weakref.proxy(context) + self.context = weakref.proxy(context) if context is not None else None def __eq__(self, value): + if self.hash is None or value is None: + # Ensure that we don't continue comparison if one of us is None. + return self.hash is value + if isinstance(value, Password): # Comparing 2 hashes isn't very useful; but this equality # method breaks otherwise. return value.hash == self.hash - valid, new = self.context.verify_and_update(value, self.hash) - if valid and new: - # New hash was calculated due to various reasons; stored one - # wasn't optimal, etc. - self.hash = new - return valid + if self.context is None: + # Compare 2 hashes again as we don't know how to validate. + return value == self + + if isinstance(value, (six.string_types, six.binary_type)): + valid, new = self.context.verify_and_update(value, self.hash) + if valid and new: + # New hash was calculated due to various reasons; stored one + # wasn't optimal, etc. + self.hash = new + + # The hash should be bytes. + if isinstance(self.hash, six.text_type): + self.hash = self.hash.encode('utf8') + self.changed() + + return valid + + return False def __ne__(self, value): return not (self == value) @@ -143,6 +178,7 @@ class PasswordType(types.TypeDecorator, ScalarCoercible): return Password(value, self.context) def _coerce(self, value): + if value is None: return @@ -155,4 +191,12 @@ class PasswordType(types.TypeDecorator, ScalarCoercible): # If were given a password object; ensure the context is right. value.context = weakref.proxy(self.context) + # If were given a password secret; encrypt it. + if value.secret is not None: + value.hash = self.context.encrypt(value.secret).encode('utf8') + value.secret = None + return value + + +Password.associate_with(PasswordType)