Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion contentcuration/contentcuration/tests/test_exportchannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import string
import tempfile
import uuid
from unittest import mock

import pytest
from celery import states
from django.conf import settings
from django.core.management import call_command
from django.db import connections
from django_celery_results.models import TaskResult
Expand All @@ -32,13 +34,15 @@
from .testdata import slideshow
from .testdata import thumbnail_bytes
from .testdata import tree
from .utils.restricted_filesystemstorage import RestrictedFileSystemStorage
from contentcuration import models as cc
from contentcuration.models import CustomTaskMetadata
from contentcuration.utils.assessment.qti.archive import hex_to_qti_id
from contentcuration.utils.celery.tasks import generate_task_signature
from contentcuration.utils.publish import ChannelIncompleteError
from contentcuration.utils.publish import convert_channel_thumbnail
from contentcuration.utils.publish import create_content_database
from contentcuration.utils.publish import create_draft_channel_version
from contentcuration.utils.publish import create_slideshow_manifest
from contentcuration.utils.publish import fill_published_fields
from contentcuration.utils.publish import map_prerequisites
Expand Down Expand Up @@ -732,7 +736,7 @@ def test_create_slideshow_manifest(self):
)
create_slideshow_manifest(ccnode)
manifest_collection = cc.File.objects.filter(
contentnode=ccnode, preset_id=u"slideshow_manifest"
contentnode=ccnode, preset_id="slideshow_manifest"
)
assert len(manifest_collection) == 1

Expand Down Expand Up @@ -1130,3 +1134,120 @@ def test_only_next_file_created(self):
call_args = self.mock_save_export.call_args
self.assertEqual(call_args[0][1], "next")
self.assertEqual(call_args[0][2], True)


class PublishChannelDraftCleanupTestCase(StudioTestCase):
"""Test that publish cleans up draft artifacts."""

@classmethod
def setUpClass(cls):
super(PublishChannelDraftCleanupTestCase, cls).setUpClass()
cls.patch_copy_db = patch("contentcuration.utils.publish.save_export_database")
cls.patch_copy_db.start()

@classmethod
def tearDownClass(cls):
super(PublishChannelDraftCleanupTestCase, cls).tearDownClass()
cls.patch_copy_db.stop()

def setUp(self):
super(PublishChannelDraftCleanupTestCase, self).setUp()

self._temp_directory_ctx = tempfile.TemporaryDirectory()
self.test_db_root_dir = self._temp_directory_ctx.__enter__()

restricted_storage = RestrictedFileSystemStorage(location=self.test_db_root_dir)

self._storage_patch_ctx = mock.patch(
"contentcuration.utils.publish.storage",
new=restricted_storage,
)
self._storage_patch_ctx.__enter__()

os.makedirs(
os.path.join(self.test_db_root_dir, settings.DB_ROOT), exist_ok=True
)

self.content_channel = channel()
self.content_channel.version = 2
self.content_channel.save()

self.draft_db_path = os.path.join(
self.test_db_root_dir,
settings.DB_ROOT,
f"{self.content_channel.id}-next.sqlite3",
)

def tearDown(self):
self._temp_directory_ctx.__exit__(None, None, None)
self._storage_patch_ctx.__exit__(None, None, None)

super(PublishChannelDraftCleanupTestCase, self).tearDown()

def run_publish(self):
publish_channel(
self.admin_user.id,
self.content_channel.id,
force=True,
force_exercises=False,
send_email=False,
progress_tracker=None,
is_draft_version=False,
use_staging_tree=False,
)

def test_draft_channel_version_removed(self):
create_draft_channel_version(self.content_channel)
self.assertTrue(
cc.ChannelVersion.objects.filter(
channel=self.content_channel, version=None
).exists()
)

self.run_publish()

self.assertFalse(
cc.ChannelVersion.objects.filter(
channel=self.content_channel, version=None
).exists()
)

def test_draft_database_removed(self):
with open(self.draft_db_path, "w") as f:
f.write("draft content")
self.assertTrue(os.path.exists(self.draft_db_path))

self.run_publish()

self.assertFalse(os.path.exists(self.draft_db_path))

def test_no_draft_artifacts_no_error(self):
self.assertFalse(
cc.ChannelVersion.objects.filter(
channel=self.content_channel, version=None
).exists()
)
self.assertFalse(os.path.exists(self.draft_db_path))

self.run_publish()

def test_published_channel_versions_not_affected(self):
create_draft_channel_version(self.content_channel)

published_count = cc.ChannelVersion.objects.filter(
channel=self.content_channel, version__isnull=False
).count()

self.run_publish()

self.content_channel.refresh_from_db()
new_published_count = cc.ChannelVersion.objects.filter(
channel=self.content_channel, version__isnull=False
).count()

self.assertEqual(new_published_count, published_count + 1)
self.assertFalse(
cc.ChannelVersion.objects.filter(
channel=self.content_channel, version=None
).exists()
)
7 changes: 7 additions & 0 deletions contentcuration/contentcuration/utils/publish.py
Original file line number Diff line number Diff line change
Expand Up @@ -1141,6 +1141,13 @@ def publish_channel( # noqa: C901
else:
increment_channel_version(channel)
if not is_draft_version:
ccmodels.ChannelVersion.objects.filter(
channel=channel, version=None
).delete()
draft_db_path = get_content_db_path(channel_id, "next")
if storage.exists(draft_db_path):
storage.delete(draft_db_path)

sync_contentnode_and_channel_tsvectors(channel_id=channel.id)
mark_all_nodes_as_published(base_tree)
fill_published_fields(channel, version_notes)
Expand Down