diff --git a/contentcuration/contentcuration/tests/test_exportchannel.py b/contentcuration/contentcuration/tests/test_exportchannel.py index 7c1f3d6761..e3c4568acb 100644 --- a/contentcuration/contentcuration/tests/test_exportchannel.py +++ b/contentcuration/contentcuration/tests/test_exportchannel.py @@ -4,9 +4,11 @@ import string import tempfile import uuid +from unittest import mock import pytest from celery import states +from django.conf import settings from django.core.management import call_command from django.db import connections from django_celery_results.models import TaskResult @@ -32,6 +34,7 @@ from .testdata import slideshow from .testdata import thumbnail_bytes from .testdata import tree +from .utils.restricted_filesystemstorage import RestrictedFileSystemStorage from contentcuration import models as cc from contentcuration.models import CustomTaskMetadata from contentcuration.utils.assessment.qti.archive import hex_to_qti_id @@ -39,6 +42,7 @@ from contentcuration.utils.publish import ChannelIncompleteError from contentcuration.utils.publish import convert_channel_thumbnail from contentcuration.utils.publish import create_content_database +from contentcuration.utils.publish import create_draft_channel_version from contentcuration.utils.publish import create_slideshow_manifest from contentcuration.utils.publish import fill_published_fields from contentcuration.utils.publish import map_prerequisites @@ -732,7 +736,7 @@ def test_create_slideshow_manifest(self): ) create_slideshow_manifest(ccnode) manifest_collection = cc.File.objects.filter( - contentnode=ccnode, preset_id=u"slideshow_manifest" + contentnode=ccnode, preset_id="slideshow_manifest" ) assert len(manifest_collection) == 1 @@ -1130,3 +1134,120 @@ def test_only_next_file_created(self): call_args = self.mock_save_export.call_args self.assertEqual(call_args[0][1], "next") self.assertEqual(call_args[0][2], True) + + +class PublishChannelDraftCleanupTestCase(StudioTestCase): + """Test that publish cleans up draft artifacts.""" + + @classmethod + def setUpClass(cls): + super(PublishChannelDraftCleanupTestCase, cls).setUpClass() + cls.patch_copy_db = patch("contentcuration.utils.publish.save_export_database") + cls.patch_copy_db.start() + + @classmethod + def tearDownClass(cls): + super(PublishChannelDraftCleanupTestCase, cls).tearDownClass() + cls.patch_copy_db.stop() + + def setUp(self): + super(PublishChannelDraftCleanupTestCase, self).setUp() + + self._temp_directory_ctx = tempfile.TemporaryDirectory() + self.test_db_root_dir = self._temp_directory_ctx.__enter__() + + restricted_storage = RestrictedFileSystemStorage(location=self.test_db_root_dir) + + self._storage_patch_ctx = mock.patch( + "contentcuration.utils.publish.storage", + new=restricted_storage, + ) + self._storage_patch_ctx.__enter__() + + os.makedirs( + os.path.join(self.test_db_root_dir, settings.DB_ROOT), exist_ok=True + ) + + self.content_channel = channel() + self.content_channel.version = 2 + self.content_channel.save() + + self.draft_db_path = os.path.join( + self.test_db_root_dir, + settings.DB_ROOT, + f"{self.content_channel.id}-next.sqlite3", + ) + + def tearDown(self): + self._temp_directory_ctx.__exit__(None, None, None) + self._storage_patch_ctx.__exit__(None, None, None) + + super(PublishChannelDraftCleanupTestCase, self).tearDown() + + def run_publish(self): + publish_channel( + self.admin_user.id, + self.content_channel.id, + force=True, + force_exercises=False, + send_email=False, + progress_tracker=None, + is_draft_version=False, + use_staging_tree=False, + ) + + def test_draft_channel_version_removed(self): + create_draft_channel_version(self.content_channel) + self.assertTrue( + cc.ChannelVersion.objects.filter( + channel=self.content_channel, version=None + ).exists() + ) + + self.run_publish() + + self.assertFalse( + cc.ChannelVersion.objects.filter( + channel=self.content_channel, version=None + ).exists() + ) + + def test_draft_database_removed(self): + with open(self.draft_db_path, "w") as f: + f.write("draft content") + self.assertTrue(os.path.exists(self.draft_db_path)) + + self.run_publish() + + self.assertFalse(os.path.exists(self.draft_db_path)) + + def test_no_draft_artifacts_no_error(self): + self.assertFalse( + cc.ChannelVersion.objects.filter( + channel=self.content_channel, version=None + ).exists() + ) + self.assertFalse(os.path.exists(self.draft_db_path)) + + self.run_publish() + + def test_published_channel_versions_not_affected(self): + create_draft_channel_version(self.content_channel) + + published_count = cc.ChannelVersion.objects.filter( + channel=self.content_channel, version__isnull=False + ).count() + + self.run_publish() + + self.content_channel.refresh_from_db() + new_published_count = cc.ChannelVersion.objects.filter( + channel=self.content_channel, version__isnull=False + ).count() + + self.assertEqual(new_published_count, published_count + 1) + self.assertFalse( + cc.ChannelVersion.objects.filter( + channel=self.content_channel, version=None + ).exists() + ) diff --git a/contentcuration/contentcuration/utils/publish.py b/contentcuration/contentcuration/utils/publish.py index 9a00814dd7..1b8f08131a 100644 --- a/contentcuration/contentcuration/utils/publish.py +++ b/contentcuration/contentcuration/utils/publish.py @@ -1141,6 +1141,13 @@ def publish_channel( # noqa: C901 else: increment_channel_version(channel) if not is_draft_version: + ccmodels.ChannelVersion.objects.filter( + channel=channel, version=None + ).delete() + draft_db_path = get_content_db_path(channel_id, "next") + if storage.exists(draft_db_path): + storage.delete(draft_db_path) + sync_contentnode_and_channel_tsvectors(channel_id=channel.id) mark_all_nodes_as_published(base_tree) fill_published_fields(channel, version_notes)