diff --git a/ara/api/tests/tests_records.py b/ara/api/tests/tests_records.py index aad243d..bf7e1b4 100644 --- a/ara/api/tests/tests_records.py +++ b/ara/api/tests/tests_records.py @@ -81,3 +81,24 @@ class RecordTestCase(APITestCase): self.assertEqual(1, len(request.data["results"])) self.assertEqual(record.key, request.data["results"][0]["key"]) self.assertEqual(record.playbook.id, request.data["results"][0]["playbook"]) + + def test_get_records_by_key(self): + playbook = factories.PlaybookFactory() + record = factories.RecordFactory(playbook=playbook, key="by_key") + factories.RecordFactory(key="another_record") + request = self.client.get("/api/v1/records?key=%s" % record.key) + self.assertEqual(2, models.Record.objects.all().count()) + self.assertEqual(1, len(request.data["results"])) + self.assertEqual(record.key, request.data["results"][0]["key"]) + self.assertEqual(record.playbook.id, request.data["results"][0]["playbook"]) + + def test_get_records_by_playbook_and_key(self): + playbook = factories.PlaybookFactory() + record = factories.RecordFactory(playbook=playbook, key="by_playbook_and_key") + factories.RecordFactory(playbook=playbook, key="another_record_in_playbook") + factories.RecordFactory(key="another_record_in_another_playbook") + request = self.client.get("/api/v1/records?playbook=%s&key=%s" % (playbook.id, record.key)) + self.assertEqual(3, models.Record.objects.all().count()) + self.assertEqual(1, len(request.data["results"])) + self.assertEqual(record.key, request.data["results"][0]["key"]) + self.assertEqual(record.playbook.id, request.data["results"][0]["playbook"]) diff --git a/ara/api/views.py b/ara/api/views.py index 84271e7..40d9b0e 100644 --- a/ara/api/views.py +++ b/ara/api/views.py @@ -75,7 +75,7 @@ class FileViewSet(viewsets.ModelViewSet): class RecordViewSet(viewsets.ModelViewSet): queryset = models.Record.objects.all() serializer_class = serializers.RecordSerializer - filter_fields = ("playbook",) + filter_fields = ("playbook", "key") class StatsViewSet(viewsets.ModelViewSet):