diff --git a/src/apps/api/tests/test_datasets.py b/src/apps/api/tests/test_datasets.py index 8f676d90a..daa0d6d52 100644 --- a/src/apps/api/tests/test_datasets.py +++ b/src/apps/api/tests/test_datasets.py @@ -175,7 +175,6 @@ def setUp(self): created_by=self.owner, downloads=5 ) - self.private_dataset = DataFactory( is_public=False, created_by=self.owner, @@ -187,13 +186,10 @@ def test_download_public_dataset(self, mock_make_url_sassy): # Mock the URL that would normally be generated for the file # This avoids depending on actual file storage or signature logic mock_make_url_sassy.return_value = "http://codebench-storage/public_dataset.zip" - response = self.client.get(reverse("datasets:download_by_pk", args=[self.public_dataset.pk])) - # Should redirect to the URL self.assertEqual(response.status_code, 302) self.assertEqual(response["Location"], "http://codebench-storage/public_dataset.zip") - # Should increment download count self.public_dataset.refresh_from_db() self.assertEqual(self.public_dataset.downloads, 6) @@ -203,27 +199,21 @@ def test_download_private_dataset_as_owner(self, mock_make_url_sassy): # Mock the URL that would normally be generated for the file # This avoids depending on actual file storage or signature logic mock_make_url_sassy.return_value = "http://codebench-storage/private_dataset.zip" - response = self.client.get(reverse("datasets:download_by_pk", args=[self.private_dataset.pk])) - self.assertEqual(response.status_code, 302) self.assertEqual(response["Location"], "http://codebench-storage/private_dataset.zip") - self.private_dataset.refresh_from_db() self.assertEqual(self.private_dataset.downloads, 3) def test_download_private_dataset_as_other_user(self): # Authenticate as a different user who is not the owner self.client.force_login(self.other_user) - response = self.client.get(reverse("datasets:download_by_pk", args=[self.private_dataset.pk])) - # Should return 404 (access denied) - self.assertEqual(response.status_code, 404) + self.assertEqual(response.status_code, 403) def test_download_nonexistent_dataset(self): response = self.client.get(reverse("datasets:download_by_pk", args=[99999])) - # Should return 404 (access denied) self.assertEqual(response.status_code, 404) diff --git a/src/apps/datasets/views.py b/src/apps/datasets/views.py index b6c312970..7330a9735 100644 --- a/src/apps/datasets/views.py +++ b/src/apps/datasets/views.py @@ -1,11 +1,74 @@ from django.contrib.auth.mixins import LoginRequiredMixin +from django.core.exceptions import PermissionDenied +from django.db.models import Q from django.http import HttpResponseRedirect, Http404 from django.shortcuts import get_object_or_404 from django.views.generic import TemplateView, DetailView - from datasets.models import Data from utils.data import make_url_sassy from api.serializers.datasets import DatasetSerializer +from competitions.models import Competition, CompetitionParticipant + + +def user_can_download(user, data): + if data.is_public: + return True + if not user.is_authenticated: + return False + if data.created_by == user: + return True + + # Organizers (creator + collaborators) can download any dataset in their competition + organizer_qs = Competition.objects.filter( + Q(created_by=user) | Q(collaborators=user) + ).filter( + Q(phases__public_data=data) | + Q(phases__starting_kit=data) | + Q(phases__task_instances__task__input_data=data) | + Q(phases__task_instances__task__reference_data=data) | + Q(phases__task_instances__task__scoring_program=data) | + Q(phases__task_instances__task__ingestion_program=data) | + Q(phases__task_instances__task__solutions__data=data) + ) + if data.type == Data.SUBMISSION and data.competition: + organizer_qs = organizer_qs | Competition.objects.filter( + Q(created_by=user) | Q(collaborators=user), + pk=data.competition_id, + ) + if organizer_qs.exists(): + return True + + # Reference data, submissions, and bundles are never accessible to participants + if data.type in (Data.REFERENCE_DATA, Data.SUBMISSION, Data.COMPETITION_BUNDLE): + return False + + approved_participant = Q( + participants__user=user, + participants__status=CompetitionParticipant.APPROVED, + ) + + if data.type in (Data.PUBLIC_DATA, Data.STARTING_KIT): + return Competition.objects.filter(approved_participant).filter( + Q(phases__public_data=data) | Q(phases__starting_kit=data) + ).exists() + + if data.type == Data.INPUT_DATA: + return Competition.objects.filter(approved_participant, make_input_data_available=True).filter( + phases__task_instances__task__input_data=data + ).exists() + + if data.type in (Data.SCORING_PROGRAM, Data.INGESTION_PROGRAM): + return Competition.objects.filter(approved_participant, make_programs_available=True).filter( + Q(phases__task_instances__task__scoring_program=data) | + Q(phases__task_instances__task__ingestion_program=data) + ).exists() + + if data.type == Data.SOLUTION: + return Competition.objects.filter(approved_participant).filter( + phases__task_instances__task__solutions__data=data + ).exists() + + return False class DataManagement(LoginRequiredMixin, TemplateView): @@ -26,20 +89,17 @@ class DatasetDetail(DetailView): def get_object(self, *args, **kwargs): dataset = super().get_object(*args, **kwargs) - # If dataset is public or (user is authenticated and is owner), return dataset if dataset.is_public or ( self.request.user.is_authenticated and dataset.created_by == self.request.user ): return dataset - # Otherwise return 404 raise Http404() def get_context_data(self, **kwargs): context = super().get_context_data(**kwargs) dataset = context["object"] - serializer = DatasetSerializer(dataset) context["object"] = serializer.data return context @@ -47,18 +107,19 @@ def get_context_data(self, **kwargs): def download(request, key): data = get_object_or_404(Data, key=key) + if not user_can_download(request.user, data): + if request.user.is_authenticated: + raise PermissionDenied() + raise Http404() return HttpResponseRedirect(make_url_sassy(data.data_file.name)) def download_by_pk(request, pk): dataset = get_object_or_404(Data, pk=pk) - - if dataset.is_public or dataset.created_by == request.user: - # Increment download count - dataset.downloads = (dataset.downloads or 0) + 1 - dataset.save(update_fields=["downloads"]) - - # Redirect to the actual file URL - return HttpResponseRedirect(make_url_sassy(dataset.data_file.name)) - - raise Http404() + if not user_can_download(request.user, dataset): + if request.user.is_authenticated: + raise PermissionDenied() + raise Http404() + dataset.downloads = (dataset.downloads or 0) + 1 + dataset.save(update_fields=["downloads"]) + return HttpResponseRedirect(make_url_sassy(dataset.data_file.name))