From c4fc172feab27fce50ce3e6a39011e0e7926730a Mon Sep 17 00:00:00 2001 From: Richard Tibbles Date: Sat, 23 Apr 2022 08:59:47 -0700 Subject: [PATCH 1/2] Use recursion to properly handle nested extra_fields updates. --- .../tests/viewsets/test_contentnode.py | 20 ++++++ .../contentcuration/viewsets/common.py | 64 ++++++++++--------- 2 files changed, 53 insertions(+), 31 deletions(-) diff --git a/contentcuration/contentcuration/tests/viewsets/test_contentnode.py b/contentcuration/contentcuration/tests/viewsets/test_contentnode.py index 5bb17aabd0..7e4a4bd4c9 100644 --- a/contentcuration/contentcuration/tests/viewsets/test_contentnode.py +++ b/contentcuration/contentcuration/tests/viewsets/test_contentnode.py @@ -634,6 +634,26 @@ def test_update_contentnode_remove_from_extra_fields(self): with self.assertRaises(KeyError): models.ContentNode.objects.get(id=contentnode.id).extra_fields["m"] + def test_update_contentnode_remove_from_extra_fields_nested(self): + user = testdata.user() + metadata = self.contentnode_db_metadata + metadata["extra_fields"] = { + "options": { + "modality": "QUIZ", + }, + } + contentnode = models.ContentNode.objects.create(**metadata) + self.client.force_authenticate(user=user) + # Remove extra_fields.options.modality + response = self.client.post( + self.sync_url, + [generate_update_event(contentnode.id, CONTENTNODE, {"extra_fields.options.modality": None})], + format="json", + ) + self.assertEqual(response.status_code, 200, response.content) + with self.assertRaises(KeyError): + models.ContentNode.objects.get(id=contentnode.id).extra_fields["options"]["modality"] + def test_update_contentnode_add_multiple_metadata_labels(self): user = testdata.user() diff --git a/contentcuration/contentcuration/viewsets/common.py b/contentcuration/contentcuration/viewsets/common.py index fb03a3cc9f..dbe03056dc 100644 --- a/contentcuration/contentcuration/viewsets/common.py +++ b/contentcuration/contentcuration/viewsets/common.py @@ -125,8 +125,6 @@ def get_value(self, dictionary): # get just field name value = dictionary.get(self.field_name, {}) - self.initial_value = value - if value is None: return empty @@ -136,11 +134,26 @@ def get_value(self, dictionary): # then merge in fields with keys like `content_defaults.author` multi_value = MultiValueDict() multi_value.update(dictionary) - html_value = unnest_dict( - html.parse_html_dict(multi_value, prefix=self.field_name).dict() - ) - value.update(html_value) - + html_value = html.parse_html_dict(multi_value, prefix=self.field_name).dict() + + fields = getattr(self, "fields", {}) + + for key in html_value: + # Split on the first occurrence of a "." in case we are dealing with a dot path + # referencing a child field of this field. + keys = key.split(".", 1) + # Only attempt to use this if there is a dot path, and the parent of the dot path is + # a valid child field. Otherwise, we just use the value as-is. + if key not in fields and len(keys) == 2 and keys[0] in fields: + # If it is a valid child field, we invoke the nested field's get_value method + # with the value of the child field. + # N.B. the get_value method expects a dictionary that references the field's name + # not just the value. + value[keys[0]] = fields[keys[0]].get_value({keys[0]: {keys[1]: html_value[key]}}) + if key in value: + del value[key] + else: + value[key] = html_value[key] return value if value.keys() else empty @@ -154,30 +167,19 @@ def create(self, validated_data): def update(self, instance, validated_data): instance = instance or self.default_value() - instance.update(validated_data) - # This should have been set when get_value was invoked - # But could be `None`, so we check if it is truthy here. - if getattr(self, "initial_value", None): - # Iterate through each field - for key in self.initial_value: - # If the field is explicitly being set as None, then - # we need to delete it from the instance. - if self.initial_value[key] is None: - # Follow the dot path to find the nested object - obj = instance - # Iterate through each part of the dot path - # up until, but not including the final key - for part in key.split(".")[:-1]: - if isinstance(obj, dict): - # If it's a dict use get to get the next level object - obj = obj.get(part) - elif isinstance(obj, list): - # If it's a list, use the index to get the next level object - obj = obj[int(part)] - else: - raise ValidationError("Tried to access a dot path in an invalid type") - # Finally, delete the final key - obj.pop(key.split(".")[-1]) + for key, value in validated_data.items(): + if value is None: + # If the value is None, we delete the key from the instance. + # Silently ignore deletion of values that don't exist + if key in instance: + del instance[key] + elif hasattr(self.fields[key], "update"): + # If the nested field has an update method (e.g. a nested serializer), + # call the update value so that we can do any recursive updates + self.fields[key].update(instance[key], validated_data[key]) + else: + # Otherwise, just update the value + instance[key] = validated_data[key] return instance From e6daf76765a1f7b9b64446ce8799740a6faf7f45 Mon Sep 17 00:00:00 2001 From: Richard Tibbles Date: Sat, 23 Apr 2022 09:00:20 -0700 Subject: [PATCH 2/2] Add a dot path exception for metadata label fields. --- .../tests/viewsets/test_contentnode.py | 46 ++++++++++++++++++- .../contentcuration/viewsets/contentnode.py | 14 +++++- 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/contentcuration/contentcuration/tests/viewsets/test_contentnode.py b/contentcuration/contentcuration/tests/viewsets/test_contentnode.py index 7e4a4bd4c9..d8fb7d5e06 100644 --- a/contentcuration/contentcuration/tests/viewsets/test_contentnode.py +++ b/contentcuration/contentcuration/tests/viewsets/test_contentnode.py @@ -16,6 +16,7 @@ from le_utils.constants import content_kinds from le_utils.constants import roles from le_utils.constants.labels.accessibility_categories import ACCESSIBILITYCATEGORIESLIST +from le_utils.constants.labels.subjects import SUBJECTSLIST from contentcuration import models from contentcuration.tests import testdata @@ -31,6 +32,9 @@ from contentcuration.viewsets.sync.utils import generate_update_event +nested_subjects = [subject for subject in SUBJECTSLIST if "." in subject] + + def create_and_get_contentnode(parent_id): contentnode = models.ContentNode.objects.create( title="Aron's cool contentnode", @@ -677,12 +681,35 @@ def test_update_contentnode_add_multiple_metadata_labels(self): self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[0]]) self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[1]]) + def test_update_contentnode_add_multiple_nested_metadata_labels(self): + user = testdata.user() + + contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata) + self.client.force_authenticate(user=user) + # Add metadata label to categories + response = self.client.post( + self.sync_url, + [generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[0]): True})], + format="json", + ) + self.assertEqual(response.status_code, 200, response.content) + self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]]) + + response = self.client.post( + self.sync_url, + [generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[1]): True})], + format="json", + ) + self.assertEqual(response.status_code, 200, response.content) + self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]]) + self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[1]]) + def test_update_contentnode_remove_metadata_label(self): user = testdata.user() metadata = self.contentnode_db_metadata metadata["accessibility_labels"] = {ACCESSIBILITYCATEGORIESLIST[0]: True} - contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata) + contentnode = models.ContentNode.objects.create(**metadata) self.client.force_authenticate(user=user) # Add metadata label to accessibility_labels response = self.client.post( @@ -694,6 +721,23 @@ def test_update_contentnode_remove_metadata_label(self): with self.assertRaises(KeyError): models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[0]] + def test_update_contentnode_remove_nested_metadata_label(self): + user = testdata.user() + metadata = self.contentnode_db_metadata + metadata["categories"] = {nested_subjects[0]: True} + + contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata) + self.client.force_authenticate(user=user) + # Add metadata label to categories + response = self.client.post( + self.sync_url, + [generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[0]): None})], + format="json", + ) + self.assertEqual(response.status_code, 200, response.content) + with self.assertRaises(KeyError): + models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]] + def test_update_contentnode_tags(self): user = testdata.user() contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata) diff --git a/contentcuration/contentcuration/viewsets/contentnode.py b/contentcuration/contentcuration/viewsets/contentnode.py index ca5609365f..7c8c5bdc0c 100644 --- a/contentcuration/contentcuration/viewsets/contentnode.py +++ b/contentcuration/contentcuration/viewsets/contentnode.py @@ -271,6 +271,17 @@ class TagField(DotPathValueMixin, DictField): pass +class MetadataLabelBooleanField(BooleanField): + def bind(self, field_name, parent): + # By default the bind method of the Field class sets the source_attrs to field_name.split("."). + # As we have literal field names that include "." we need to override this behavior. + # Otherwise it will attempt to set the source_attrs to a nested path, assuming that it is a source path, + # not a materialized path. This probably means that it was a bad idea to use "." in the materialized path, + # but alea iacta est. + super(MetadataLabelBooleanField, self).bind(field_name, parent) + self.source_attrs = [self.source] + + class MetadataLabelsField(JSONFieldDictSerializer): def __init__(self, choices, *args, **kwargs): self.choices = choices @@ -280,8 +291,9 @@ def __init__(self, choices, *args, **kwargs): def get_fields(self): fields = {} for label_id, label_name in self.choices: - field = BooleanField(required=False, label=label_name, allow_null=True) + field = MetadataLabelBooleanField(required=False, label=label_name, allow_null=True) fields[label_id] = field + return fields