diff --git a/stacktask/api/migrations/0001_initial.py b/stacktask/api/migrations/0001_initial.py index 5982d16..5365a7b 100644 --- a/stacktask/api/migrations/0001_initial.py +++ b/stacktask/api/migrations/0001_initial.py @@ -27,6 +27,7 @@ class Migration(migrations.Migration): name='Task', fields=[ ('uuid', models.CharField(default=stacktask.api.models.hex_uuid, max_length=200, serialize=False, primary_key=True)), + ('hash_key', models.CharField(max_length=64, db_index=True)), ('ip_address', models.GenericIPAddressField()), ('keystone_user', jsonfield.fields.JSONField(default={})), ('project_id', models.CharField(max_length=200, null=True, db_index=True)), diff --git a/stacktask/api/models.py b/stacktask/api/models.py index 53c9129..6e07216 100644 --- a/stacktask/api/models.py +++ b/stacktask/api/models.py @@ -30,6 +30,8 @@ class Task(models.Model): """ uuid = models.CharField(max_length=200, default=hex_uuid, primary_key=True) + hash_key = models.CharField(max_length=64, db_index=True) + # who is this: ip_address = models.GenericIPAddressField() keystone_user = JSONField(default={}) diff --git a/stacktask/api/v1/tasks.py b/stacktask/api/v1/tasks.py index 47c61d8..e942b92 100644 --- a/stacktask/api/v1/tasks.py +++ b/stacktask/api/v1/tasks.py @@ -19,7 +19,7 @@ from django.utils import timezone from stacktask.api import utils from stacktask.api.v1.views import APIViewWithLogger from stacktask.api.v1.utils import ( - send_email, create_notification, create_token) + send_email, create_notification, create_token, create_task_hash) from django.conf import settings @@ -101,6 +101,16 @@ class TaskView(APIViewWithLogger): errors.update(action['serializer'].errors) return {'errors': errors} + hash_key = create_task_hash(self.task_type, action_list) + + tasks = Task.objects.filter( + hash_key=hash_key, + completed=0, + cancelled=0) + + if len(tasks) > 0: + return {'errors': ['Task is a duplicate of an existing one']} + ip_address = request.META['REMOTE_ADDR'] keystone_user = request.keystone_user @@ -108,11 +118,13 @@ class TaskView(APIViewWithLogger): task = Task.objects.create( ip_address=ip_address, keystone_user=keystone_user, project_id=keystone_user['project_id'], - task_type=self.task_type) + task_type=self.task_type, + hash_key=hash_key) except KeyError: task = Task.objects.create( ip_address=ip_address, keystone_user=keystone_user, - task_type=self.task_type) + task_type=self.task_type, + hash_key=hash_key) task.save() for i, action in enumerate(action_list): diff --git a/stacktask/api/v1/tests.py b/stacktask/api/v1/tests.py index 0334a72..b736815 100644 --- a/stacktask/api/v1/tests.py +++ b/stacktask/api/v1/tests.py @@ -1039,3 +1039,68 @@ class APITests(APITestCase): response = self.client.put(url, data, format='json', headers=headers) self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + @mock.patch( + 'stacktask.actions.models.user_store.IdentityManager', FakeManager) + @mock.patch( + 'stacktask.actions.tenant_setup.models.IdentityManager', FakeManager) + def test_duplicate_tasks_new_project(self): + """ + Ensure we can't submit duplicate tasks + """ + + project = mock.Mock() + project.id = 'test_project_id' + project.name = 'test_project' + project.roles = {} + + setup_temp_cache({}, {}) + + url = "/v1/actions/CreateProject" + data = {'project_name': "test_project", 'email': "test@example.com"} + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + data = {'project_name': "test_project_2", 'email': "test@example.com"} + response = self.client.post(url, data, format='json') + self.assertEqual(response.status_code, status.HTTP_200_OK) + + @mock.patch( + 'stacktask.actions.models.user_store.IdentityManager', FakeManager) + def test_duplicate_tasks_new_user(self): + """ + Ensure we can't submit duplicate tasks + """ + project = mock.Mock() + project.id = 'test_project_id' + project.name = 'test_project' + project.roles = {} + + setup_temp_cache({'test_project': project}, {}) + + url = "/v1/actions/InviteUser" + headers = { + 'project_name': "test_project", + 'project_id': "test_project_id", + 'roles': "project_owner,Member,project_mod", + 'username': "test@example.com", + 'user_id': "test_user_id", + 'authenticated': True + } + data = {'email': "test@example.com", 'roles': ["Member"], + 'project_id': 'test_project_id'} + response = self.client.post(url, data, format='json', headers=headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'notes': ['created token']}) + response = self.client.post(url, data, format='json', headers=headers) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + + data = {'email': "test2@example.com", 'roles': ["Member"], + 'project_id': 'test_project_id'} + response = self.client.post(url, data, format='json', headers=headers) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(response.data, {'notes': ['created token']}) + response = self.client.post(url, data, format='json', headers=headers) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) diff --git a/stacktask/api/v1/utils.py b/stacktask/api/v1/utils.py index d2aa5f8..275bb8c 100644 --- a/stacktask/api/v1/utils.py +++ b/stacktask/api/v1/utils.py @@ -6,6 +6,7 @@ from django.core.mail import send_mail from smtplib import SMTPException from django.conf import settings from django.template import loader +import hashlib def create_token(task): @@ -91,3 +92,22 @@ def create_notification(task, notes, error=False): for note_engine, conf in class_conf.get('notifications', {}).iteritems(): engine = settings.NOTIFICATION_ENGINES[note_engine](conf) engine.notify(task, notes, error) + + +def create_task_hash(task_type, action_list): + hashable_list = [task_type, ] + + for action in action_list: + hashable_list.append(action['name']) + # iterate like this to maintain consistent order for hash + for field in action['action'].required: + try: + hashable_list.append( + action['serializer'].validated_data[field]) + except KeyError: + if field is "username" and settings.USERNAME_IS_EMAIL: + continue + else: + raise + + return hashlib.sha256(str(hashable_list)).hexdigest()