diff --git a/dojo/api_v2/serializers.py b/dojo/api_v2/serializers.py index 1eeb021d165..d23cffcf4b4 100644 --- a/dojo/api_v2/serializers.py +++ b/dojo/api_v2/serializers.py @@ -2856,30 +2856,36 @@ def save(self): msg = "Invalid format" raise Exception(msg) + # Filter out ignored keys + language_names = [name for name in deserialized if name not in {"header", "SUM"}] + # Prepopulate existing Language_Type objects + existing_types = { + lt.language: lt + for lt in Language_Type.objects.filter(language__in=language_names) + } + # Determine which Language_Type objects need to be created + new_language_names = [name for name in language_names if name not in existing_types] + new_types = [Language_Type(language=name) for name in new_language_names] + Language_Type.objects.bulk_create(new_types) + # Add newly created Language_Type objects to cache + for lt in Language_Type.objects.filter(language__in=new_language_names): + existing_types[lt.language] = lt + # Delete all Languages for this product Languages.objects.filter(product=product).delete() - - for name in deserialized: - if name not in {"header", "SUM"}: - element = deserialized[name] - - try: - ( - language_type, - _created, - ) = Language_Type.objects.get_or_create(language=name) - except Language_Type.MultipleObjectsReturned: - language_type = Language_Type.objects.filter( - language=name, - ).first() - - language = Languages() - language.product = product - language.language = language_type - language.files = element.get("nFiles", 0) - language.blank = element.get("blank", 0) - language.comment = element.get("comment", 0) - language.code = element.get("code", 0) - language.save() + # Prepare Languages objects for bulk insert + languages_to_create = [ + Languages( + product=product, + language=existing_types[name], + files=deserialized[name].get("nFiles", 0), + blank=deserialized[name].get("blank", 0), + comment=deserialized[name].get("comment", 0), + code=deserialized[name].get("code", 0), + ) + for name in language_names + ] + # Bulk insert all Languages in one query + Languages.objects.bulk_create(languages_to_create) def validate(self, data): if is_scan_file_too_large(data["file"]): diff --git a/unittests/test_rest_framework.py b/unittests/test_rest_framework.py index 5666f4de1d8..5f10e4e1ad8 100644 --- a/unittests/test_rest_framework.py +++ b/unittests/test_rest_framework.py @@ -3788,26 +3788,101 @@ def __init__(self, *args, **kwargs): def __del__(self: object): self.payload["file"].close() + def _build_payload(self, data): + return { + "product": 1, + "file": SimpleUploadedFile( + "defectdojo_cloc.json", + json.dumps(data).encode("utf-8"), + content_type="application/json", + ), + } + def test_create(self): - BaseClass.CreateRequestTest.test_create(self) + self.payload["file"].close() + base_data = json.loads( + Path("unittests/files/defectdojo_cloc.json").read_text( + encoding="utf-8", + ), + ) + updated_data = json.loads(json.dumps(base_data)) + updated_data.pop("JSON", None) + updated_data["Python"]["code"] = 51057 + updated_data["Go"] = { + "nFiles": 1, + "blank": 2, + "comment": 3, + "code": 4, + } - languages = Languages.objects.filter(product=1).order_by("language") + test_cases = [ + ( + "initial", + base_data, + { + "JSON": { + "files": 21, + "blank": 7, + "comment": 0, + "code": 63996, + }, + "Python": { + "files": 432, + "blank": 10813, + "comment": 5054, + "code": 51056, + }, + }, + ), + ( + "updated", + updated_data, + { + "Go": { + "files": 1, + "blank": 2, + "comment": 3, + "code": 4, + }, + "Python": { + "files": 432, + "blank": 10813, + "comment": 5054, + "code": 51057, + }, + }, + ), + ] - self.assertEqual(2, len(languages)) + product = Product.objects.get(id=1) + for case_name, payload_data, expected in test_cases: + with self.subTest(case=case_name): + self.payload = self._build_payload(payload_data) + response = self.client.post(self.url, self.payload) + self.assertEqual(201, response.status_code, response.content[:1000]) + self.check_schema_response("post", "201", response) - self.assertEqual(languages[0].product, Product.objects.get(id=1)) - self.assertEqual(languages[0].language, Language_Type.objects.get(id=1)) - self.assertEqual(languages[0].files, 21) - self.assertEqual(languages[0].blank, 7) - self.assertEqual(languages[0].comment, 0) - self.assertEqual(languages[0].code, 63996) + languages = ( + Languages.objects.filter(product=1) + .select_related("language") + .order_by("language__language") + ) + self.assertEqual(len(expected), languages.count()) - self.assertEqual(languages[1].product, Product.objects.get(id=1)) - self.assertEqual(languages[1].language, Language_Type.objects.get(id=2)) - self.assertEqual(languages[1].files, 432) - self.assertEqual(languages[1].blank, 10813) - self.assertEqual(languages[1].comment, 5054) - self.assertEqual(languages[1].code, 51056) + languages_by_name = { + language.language.language: language + for language in languages + } + self.assertEqual(set(expected.keys()), set(languages_by_name.keys())) + + for name, counts in expected.items(): + language = languages_by_name[name] + self.assertEqual(product, language.product) + self.assertEqual(name, language.language.language) + self.assertEqual(counts["files"], language.files) + self.assertEqual(counts["blank"], language.blank) + self.assertEqual(counts["comment"], language.comment) + self.assertEqual(counts["code"], language.code) @versioned_fixtures