diff --git a/swauth/authtypes.py b/swauth/authtypes.py index 7698f4a..3764792 100644 --- a/swauth/authtypes.py +++ b/swauth/authtypes.py @@ -31,12 +31,34 @@ conditions: import hashlib import os +import sys #: Maximum length any valid token should ever be. MAX_TOKEN_LENGTH = 5000 +def validate_creds(creds): + """Parse and validate user credentials whether format is right + + :param creds: User credentials + :returns: Auth_type class instance and parsed user credentials in dict + :raises ValueError: If credential format is wrong (eg: bad auth_type) + """ + try: + auth_type, auth_rest = creds.split(':', 1) + except ValueError: + raise ValueError("Missing ':' in %s" % creds) + authtypes = sys.modules[__name__] + auth_encoder = getattr(authtypes, auth_type.title(), None) + if auth_encoder is None: + raise ValueError('Invalid auth_type: %s' % auth_type) + auth_encoder = auth_encoder() + parsed_creds = dict(type=auth_type, salt=None, hash=None) + parsed_creds.update(auth_encoder.validate(auth_rest)) + return auth_encoder, parsed_creds + + class Plaintext(object): """Provides a particular auth type for encoding format for encoding and matching user keys. @@ -54,15 +76,28 @@ class Plaintext(object): """ return "plaintext:%s" % key - def match(self, key, creds): + def match(self, key, creds, **kwargs): """Checks whether the user-provided key matches the user's credentials :param key: User-supplied key :param creds: User's stored credentials + :param kwargs: Extra keyword args for compatibility reason with + other auth_type classes :returns: True if the supplied key is valid, False otherwise """ return self.encode(key) == creds + def validate(self, auth_rest): + """Validate user credentials whether format is right for Plaintext + + :param auth_rest: User credentials' part without auth_type + :return: Dict with a hash part of user credentials + :raises ValueError: If credentials' part has zero length + """ + if len(auth_rest) == 0: + raise ValueError("Key must have non-zero length!") + return dict(hash=auth_rest) + class Sha1(object): """Provides a particular auth type for encoding format for encoding and @@ -98,19 +133,33 @@ class Sha1(object): salt = self.salt or os.urandom(32).encode('base64').rstrip() return self.encode_w_salt(salt, key) - def match(self, key, creds): + def match(self, key, creds, salt, **kwargs): """Checks whether the user-provided key matches the user's credentials :param key: User-supplied key :param creds: User's stored credentials + :param salt: Salt for hashing + :param kwargs: Extra keyword args for compatibility reason with + other auth_type classes :returns: True if the supplied key is valid, False otherwise """ - - type, rest = creds.split(':') - salt, enc = rest.split('$') - return self.encode_w_salt(salt, key) == creds + def validate(self, auth_rest): + """Validate user credentials whether format is right for Sha1 + + :param auth_rest: User credentials' part without auth_type + :return: Dict with a hash and a salt part of user credentials + :raises ValueError: If credentials' part doesn't contain delimiter + between a salt and a hash. + """ + try: + auth_salt, auth_hash = auth_rest.split('$') + except ValueError: + raise ValueError("Missing '$' in %s" % auth_rest) + + return dict(salt=auth_salt, hash=auth_hash) + class Sha512(object): """Provides a particular auth type for encoding format for encoding and @@ -146,15 +195,28 @@ class Sha512(object): salt = self.salt or os.urandom(32).encode('base64').rstrip() return self.encode_w_salt(salt, key) - def match(self, key, creds): + def match(self, key, creds, salt, **kwargs): """Checks whether the user-provided key matches the user's credentials :param key: User-supplied key :param creds: User's stored credentials + :param salt: Salt for hashing + :param kwargs: Extra keyword args for compatibility reason with + other auth_type classes :returns: True if the supplied key is valid, False otherwise """ - - type, rest = creds.split(':') - salt, enc = rest.split('$') - return self.encode_w_salt(salt, key) == creds + + def validate(self, auth_rest): + """Validate user credentials whether format is right for Sha512 + + :param auth_rest: User credentials' part without auth_type + :return: Dict with a hash and a salt part of user credentials + :raises ValueError: If credentials' part doesn't contain delimiter + between a salt and a hash. + """ + try: + auth_salt, auth_hash = auth_rest.split('$') + except ValueError: + raise ValueError("Missing '$' in %s" % auth_rest) + return dict(salt=auth_salt, hash=auth_hash) diff --git a/swauth/middleware.py b/swauth/middleware.py index e04100b..0bf6c6d 100644 --- a/swauth/middleware.py +++ b/swauth/middleware.py @@ -1074,10 +1074,9 @@ class Swauth(object): user[0] == '.' or (not key and not key_hash): return HTTPBadRequest(request=req) if key_hash: - if ':' not in key_hash: - return HTTPBadRequest(request=req) - auth_type, hash = key_hash.split(':') - if getattr(swauth.authtypes, auth_type.title(), None) is None: + try: + swauth.authtypes.validate_creds(key_hash) + except ValueError: return HTTPBadRequest(request=req) user_arg = account + ':' + user @@ -1540,12 +1539,13 @@ class Swauth(object): """ if user_detail: creds = user_detail.get('auth') - auth_type = creds.split(':')[0] - auth_encoder = getattr(swauth.authtypes, auth_type.title(), None) - if auth_encoder is None: - self.logger.error('Invalid auth_type %s' % auth_type) + try: + auth_encoder, creds_dict = \ + swauth.authtypes.validate_creds(creds) + except ValueError as e: + self.logger.error('%s' % e.args[0]) return False - return user_detail and auth_encoder().match(key, creds) + return user_detail and auth_encoder.match(key, creds, **creds_dict) def is_user_changing_own_key(self, req, user): """Check if the user is changing his own key. diff --git a/test/unit/test_authtypes.py b/test/unit/test_authtypes.py index cc88eec..3bd93d0 100644 --- a/test/unit/test_authtypes.py +++ b/test/unit/test_authtypes.py @@ -18,6 +18,56 @@ from swauth import authtypes import unittest +class TestValidation(unittest.TestCase): + def test_validate_creds(self): + creds = 'plaintext:keystring' + creds_dict = dict(type='plaintext', salt=None, hash='keystring') + auth_encoder, parsed_creds = authtypes.validate_creds(creds) + self.assertEqual(parsed_creds, creds_dict) + self.assertTrue(isinstance(auth_encoder, authtypes.Plaintext)) + + creds = 'sha1:salt$d50dc700c296e23ce5b41f7431a0e01f69010f06' + creds_dict = dict(type='sha1', salt='salt', + hash='d50dc700c296e23ce5b41f7431a0e01f69010f06') + auth_encoder, parsed_creds = authtypes.validate_creds(creds) + self.assertEqual(parsed_creds, creds_dict) + self.assertTrue(isinstance(auth_encoder, authtypes.Sha1)) + + creds = ('sha512:salt$482e73705fac6909e2d78e8bbaf65ac3ca1473' + '8f445cc2367b7daa3f0e8f3dcfe798e426b9e332776c8da59c' + '0c11d4832931d1bf48830f670ecc6ceb04fbad0f') + creds_dict = dict(type='sha512', salt='salt', + hash='482e73705fac6909e2d78e8bbaf65ac3ca1473' + '8f445cc2367b7daa3f0e8f3dcfe798e426b9e3' + '32776c8da59c0c11d4832931d1bf48830f670e' + 'cc6ceb04fbad0f') + auth_encoder, parsed_creds = authtypes.validate_creds(creds) + self.assertEqual(parsed_creds, creds_dict) + self.assertTrue(isinstance(auth_encoder, authtypes.Sha512)) + + def test_validate_creds_fail(self): + # wrong format, missing `:` + creds = 'unknown;keystring' + self.assertRaisesRegexp(ValueError, "Missing ':' in .*", + authtypes.validate_creds, creds) + # unknown auth_type + creds = 'unknown:keystring' + self.assertRaisesRegexp(ValueError, "Invalid auth_type: .*", + authtypes.validate_creds, creds) + # wrong plaintext keystring + creds = 'plaintext:' + self.assertRaisesRegexp(ValueError, "Key must have non-zero length!", + authtypes.validate_creds, creds) + # wrong sha1 format, missing `$` + creds = 'sha1:saltkeystring' + self.assertRaisesRegexp(ValueError, "Missing '\$' in .*", + authtypes.validate_creds, creds) + # wrong sha512 format, missing `$` + creds = 'sha512:saltkeystring' + self.assertRaisesRegexp(ValueError, "Missing '\$' in .*", + authtypes.validate_creds, creds) + + class TestPlaintext(unittest.TestCase): def setUp(self): @@ -54,16 +104,22 @@ class TestSha1(unittest.TestCase): def test_sha1_valid_match(self): creds = 'sha1:salt$d50dc700c296e23ce5b41f7431a0e01f69010f06' - match = self.auth_encoder.match('keystring', creds) + creds_dict = dict(type='sha1', salt='salt', + hash='d50dc700c296e23ce5b41f7431a0e01f69010f06') + match = self.auth_encoder.match('keystring', creds, **creds_dict) self.assertEqual(match, True) def test_sha1_invalid_match(self): creds = 'sha1:salt$deadbabedeadbabedeadbabec0ffeebadc0ffeee' - match = self.auth_encoder.match('keystring', creds) + creds_dict = dict(type='sha1', salt='salt', + hash='deadbabedeadbabedeadbabec0ffeebadc0ffeee') + match = self.auth_encoder.match('keystring', creds, **creds_dict) self.assertEqual(match, False) creds = 'sha1:salt$d50dc700c296e23ce5b41f7431a0e01f69010f06' - match = self.auth_encoder.match('keystring2', creds) + creds_dict = dict(type='sha1', salt='salt', + hash='d50dc700c296e23ce5b41f7431a0e01f69010f06') + match = self.auth_encoder.match('keystring2', creds, **creds_dict) self.assertEqual(match, False) @@ -86,20 +142,32 @@ class TestSha512(unittest.TestCase): creds = ('sha512:salt$482e73705fac6909e2d78e8bbaf65ac3ca14738f445cc2' '367b7daa3f0e8f3dcfe798e426b9e332776c8da59c0c11d4832931d1bf' '48830f670ecc6ceb04fbad0f') - match = self.auth_encoder.match('keystring', creds) + creds_dict = dict(type='sha512', salt='salt', + hash='482e73705fac6909e2d78e8bbaf65ac3ca14738f445cc2' + '367b7daa3f0e8f3dcfe798e426b9e332776c8da59c0c11' + 'd4832931d1bf48830f670ecc6ceb04fbad0f') + match = self.auth_encoder.match('keystring', creds, **creds_dict) self.assertEqual(match, True) def test_sha512_invalid_match(self): creds = ('sha512:salt$deadbabedeadbabedeadbabedeadbabedeadbabedeadba' 'bedeadbabedeadbabedeadbabedeadbabedeadbabedeadbabedeadbabe' 'c0ffeebadc0ffeeec0ffeeba') - match = self.auth_encoder.match('keystring', creds) + creds_dict = dict(type='sha512', salt='salt', + hash='deadbabedeadbabedeadbabedeadbabedeadbabedeadba' + 'bedeadbabedeadbabedeadbabedeadbabedeadbabedead' + 'babedeadbabec0ffeebadc0ffeeec0ffeeba') + match = self.auth_encoder.match('keystring', creds, **creds_dict) self.assertEqual(match, False) creds = ('sha512:salt$482e73705fac6909e2d78e8bbaf65ac3ca14738f445cc2' '367b7daa3f0e8f3dcfe798e426b9e332776c8da59c0c11d4832931d1bf' '48830f670ecc6ceb04fbad0f') - match = self.auth_encoder.match('keystring2', creds) + creds_dict = dict(type='sha512', salt='salt', + hash='482e73705fac6909e2d78e8bbaf65ac3ca14738f445cc2' + '367b7daa3f0e8f3dcfe798e426b9e332776c8da59c0c11' + 'd4832931d1bf48830f670ecc6ceb04fbad0f') + match = self.auth_encoder.match('keystring2', creds, **creds_dict) self.assertEqual(match, False) if __name__ == '__main__':