Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from le_utils.constants import content_kinds
from le_utils.constants import roles
from le_utils.constants.labels.accessibility_categories import ACCESSIBILITYCATEGORIESLIST
from le_utils.constants.labels.subjects import SUBJECTSLIST

from contentcuration import models
from contentcuration.tests import testdata
Expand All @@ -31,6 +32,9 @@
from contentcuration.viewsets.sync.utils import generate_update_event


nested_subjects = [subject for subject in SUBJECTSLIST if "." in subject]


def create_and_get_contentnode(parent_id):
contentnode = models.ContentNode.objects.create(
title="Aron's cool contentnode",
Expand Down Expand Up @@ -634,6 +638,26 @@ def test_update_contentnode_remove_from_extra_fields(self):
with self.assertRaises(KeyError):
models.ContentNode.objects.get(id=contentnode.id).extra_fields["m"]

def test_update_contentnode_remove_from_extra_fields_nested(self):
user = testdata.user()
metadata = self.contentnode_db_metadata
metadata["extra_fields"] = {
"options": {
"modality": "QUIZ",
},
}
contentnode = models.ContentNode.objects.create(**metadata)
self.client.force_authenticate(user=user)
# Remove extra_fields.options.modality
response = self.client.post(
self.sync_url,
[generate_update_event(contentnode.id, CONTENTNODE, {"extra_fields.options.modality": None})],
format="json",
)
self.assertEqual(response.status_code, 200, response.content)
with self.assertRaises(KeyError):
models.ContentNode.objects.get(id=contentnode.id).extra_fields["options"]["modality"]

def test_update_contentnode_add_multiple_metadata_labels(self):
user = testdata.user()

Expand All @@ -657,12 +681,35 @@ def test_update_contentnode_add_multiple_metadata_labels(self):
self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[0]])
self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[1]])

def test_update_contentnode_add_multiple_nested_metadata_labels(self):
user = testdata.user()

contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata)
self.client.force_authenticate(user=user)
# Add metadata label to categories
response = self.client.post(
self.sync_url,
[generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[0]): True})],
format="json",
)
self.assertEqual(response.status_code, 200, response.content)
self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]])

response = self.client.post(
self.sync_url,
[generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[1]): True})],
format="json",
)
self.assertEqual(response.status_code, 200, response.content)
self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]])
self.assertTrue(models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[1]])

def test_update_contentnode_remove_metadata_label(self):
user = testdata.user()
metadata = self.contentnode_db_metadata
metadata["accessibility_labels"] = {ACCESSIBILITYCATEGORIESLIST[0]: True}

contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata)
contentnode = models.ContentNode.objects.create(**metadata)
self.client.force_authenticate(user=user)
# Add metadata label to accessibility_labels
response = self.client.post(
Expand All @@ -674,6 +721,23 @@ def test_update_contentnode_remove_metadata_label(self):
with self.assertRaises(KeyError):
models.ContentNode.objects.get(id=contentnode.id).accessibility_labels[ACCESSIBILITYCATEGORIESLIST[0]]

def test_update_contentnode_remove_nested_metadata_label(self):
user = testdata.user()
metadata = self.contentnode_db_metadata
metadata["categories"] = {nested_subjects[0]: True}

contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata)
self.client.force_authenticate(user=user)
# Add metadata label to categories
response = self.client.post(
self.sync_url,
[generate_update_event(contentnode.id, CONTENTNODE, {"categories.{}".format(nested_subjects[0]): None})],
format="json",
)
self.assertEqual(response.status_code, 200, response.content)
with self.assertRaises(KeyError):
models.ContentNode.objects.get(id=contentnode.id).categories[nested_subjects[0]]

def test_update_contentnode_tags(self):
user = testdata.user()
contentnode = models.ContentNode.objects.create(**self.contentnode_db_metadata)
Expand Down
64 changes: 33 additions & 31 deletions contentcuration/contentcuration/viewsets/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ def get_value(self, dictionary):
# get just field name
value = dictionary.get(self.field_name, {})

self.initial_value = value

if value is None:
return empty

Expand All @@ -136,11 +134,26 @@ def get_value(self, dictionary):
# then merge in fields with keys like `content_defaults.author`
multi_value = MultiValueDict()
multi_value.update(dictionary)
html_value = unnest_dict(
html.parse_html_dict(multi_value, prefix=self.field_name).dict()
)
value.update(html_value)

html_value = html.parse_html_dict(multi_value, prefix=self.field_name).dict()

fields = getattr(self, "fields", {})

for key in html_value:
# Split on the first occurrence of a "." in case we are dealing with a dot path
# referencing a child field of this field.
keys = key.split(".", 1)
# Only attempt to use this if there is a dot path, and the parent of the dot path is
# a valid child field. Otherwise, we just use the value as-is.
if key not in fields and len(keys) == 2 and keys[0] in fields:
# If it is a valid child field, we invoke the nested field's get_value method
# with the value of the child field.
# N.B. the get_value method expects a dictionary that references the field's name
# not just the value.
value[keys[0]] = fields[keys[0]].get_value({keys[0]: {keys[1]: html_value[key]}})
if key in value:
del value[key]
else:
value[key] = html_value[key]
return value if value.keys() else empty


Expand All @@ -154,30 +167,19 @@ def create(self, validated_data):

def update(self, instance, validated_data):
instance = instance or self.default_value()
instance.update(validated_data)
# This should have been set when get_value was invoked
# But could be `None`, so we check if it is truthy here.
if getattr(self, "initial_value", None):
# Iterate through each field
for key in self.initial_value:
# If the field is explicitly being set as None, then
# we need to delete it from the instance.
if self.initial_value[key] is None:
# Follow the dot path to find the nested object
obj = instance
# Iterate through each part of the dot path
# up until, but not including the final key
for part in key.split(".")[:-1]:
if isinstance(obj, dict):
# If it's a dict use get to get the next level object
obj = obj.get(part)
elif isinstance(obj, list):
# If it's a list, use the index to get the next level object
obj = obj[int(part)]
else:
raise ValidationError("Tried to access a dot path in an invalid type")
# Finally, delete the final key
obj.pop(key.split(".")[-1])
for key, value in validated_data.items():
if value is None:
# If the value is None, we delete the key from the instance.
# Silently ignore deletion of values that don't exist
if key in instance:
del instance[key]
elif hasattr(self.fields[key], "update"):
# If the nested field has an update method (e.g. a nested serializer),
# call the update value so that we can do any recursive updates
self.fields[key].update(instance[key], validated_data[key])
else:
# Otherwise, just update the value
instance[key] = validated_data[key]
return instance


Expand Down
14 changes: 13 additions & 1 deletion contentcuration/contentcuration/viewsets/contentnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,17 @@ class TagField(DotPathValueMixin, DictField):
pass


class MetadataLabelBooleanField(BooleanField):
def bind(self, field_name, parent):
# By default the bind method of the Field class sets the source_attrs to field_name.split(".").
# As we have literal field names that include "." we need to override this behavior.
# Otherwise it will attempt to set the source_attrs to a nested path, assuming that it is a source path,
# not a materialized path. This probably means that it was a bad idea to use "." in the materialized path,
# but alea iacta est.
super(MetadataLabelBooleanField, self).bind(field_name, parent)
self.source_attrs = [self.source]


class MetadataLabelsField(JSONFieldDictSerializer):
def __init__(self, choices, *args, **kwargs):
self.choices = choices
Expand All @@ -280,8 +291,9 @@ def __init__(self, choices, *args, **kwargs):
def get_fields(self):
fields = {}
for label_id, label_name in self.choices:
field = BooleanField(required=False, label=label_name, allow_null=True)
field = MetadataLabelBooleanField(required=False, label=label_name, allow_null=True)
fields[label_id] = field

return fields


Expand Down