diff --git a/django_admin_reversefields/mixins.py b/django_admin_reversefields/mixins.py index 2a832f0..8c3e975 100644 --- a/django_admin_reversefields/mixins.py +++ b/django_admin_reversefields/mixins.py @@ -51,7 +51,7 @@ from django import forms from django.contrib.admin.widgets import FilteredSelectMultiple from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured -from django.db import models, transaction +from django.db import IntegrityError, models, transaction from django.http import HttpRequest PermissionCallable = Callable[ @@ -898,11 +898,7 @@ def _apply_bulk_unbind(self, config: ReverseRelationConfig, instance, exclude_pk Raises: forms.ValidationError: If database constraints prevent the unbind operation. - Exception: Any other database error during the bulk update. """ - from django import forms - from django.db import IntegrityError - try: # Build queryset for objects currently bound to this instance queryset = config.model._default_manager.filter(**{config.fk_field: instance}) @@ -911,18 +907,13 @@ def _apply_bulk_unbind(self, config: ReverseRelationConfig, instance, exclude_pk if exclude_pks: queryset = queryset.exclude(pk__in=exclude_pks) - # Perform bulk unbind using .update() - if queryset.exists(): - queryset.update(**{config.fk_field: None}) + # .update() is a no-op on empty querysets and returns 0. + queryset.update(**{config.fk_field: None}) except IntegrityError as e: raise forms.ValidationError( f"Bulk unbind operation failed for {config.model._meta.verbose_name}: {e}" ) from e - except Exception as e: - raise forms.ValidationError( - f"Unexpected error during bulk unbind operation: {e}" - ) from e def _apply_bulk_bind(self, config: ReverseRelationConfig, instance, target_objects): """Bind multiple objects using .update() for performance. @@ -938,11 +929,7 @@ def _apply_bulk_bind(self, config: ReverseRelationConfig, instance, target_objec Raises: forms.ValidationError: If database constraints prevent the bind operation. - Exception: Any other database error during the bulk update. """ - from django import forms - from django.db import IntegrityError - if not target_objects: return @@ -965,8 +952,6 @@ def _apply_bulk_bind(self, config: ReverseRelationConfig, instance, target_objec raise forms.ValidationError( f"Bulk bind operation failed for {config.model._meta.verbose_name}: {e}" ) from e - except Exception as e: - raise forms.ValidationError(f"Unexpected error during bulk bind operation: {e}") from e def _apply_bulk_operations(self, config: ReverseRelationConfig, instance, selection): """Coordinate bulk unbind and bind operations for a reverse relation field. @@ -982,44 +967,31 @@ def _apply_bulk_operations(self, config: ReverseRelationConfig, instance, select iterable of objects for multi-select). Raises: - forms.ValidationError: If database constraints prevent the operations - or other errors occur during bulk updates. + forms.ValidationError: If database constraints prevent the operations. """ - from django import forms + if config.multiple: + # Multi-select scenario + selected = list(selection) if selection else [] + selected_ids = {obj.pk for obj in selected} - try: - if config.multiple: - # Multi-select scenario - selected = list(selection) if selection else [] - selected_ids = {obj.pk for obj in selected} + # Step 1: Bulk unbind objects that are no longer selected + # (exclude the ones that should remain bound) + self._apply_bulk_unbind(config, instance, selected_ids) - # Step 1: Bulk unbind objects that are no longer selected - # (exclude the ones that should remain bound) - self._apply_bulk_unbind(config, instance, selected_ids) + # Step 2: Bulk bind newly selected objects + self._apply_bulk_bind(config, instance, selected) - # Step 2: Bulk bind newly selected objects - self._apply_bulk_bind(config, instance, selected) + else: + # Single-select scenario + target = selection - else: - # Single-select scenario - target = selection - - # Step 1: Bulk unbind all current relations - # For single-select, we unbind everything first - self._apply_bulk_unbind(config, instance, set()) - - # Step 2: Bulk bind the target (if provided) - if target: - self._apply_bulk_bind(config, instance, [target]) - - except forms.ValidationError: - # Re-raise validation errors as-is - raise - except Exception as e: - # Wrap unexpected errors in ValidationError with meaningful message - raise forms.ValidationError( - f"Bulk operation failed for {config.model._meta.verbose_name}: {e}" - ) from e + # Step 1: Bulk unbind all current relations + # For single-select, we unbind everything first + self._apply_bulk_unbind(config, instance, set()) + + # Step 2: Bulk bind the target (if provided) + if target: + self._apply_bulk_bind(config, instance, [target]) def _apply_individual_operations(self, config: ReverseRelationConfig, instance, selection): """Apply bind/unbind operations using individual model saves. diff --git a/tests/parameterized/test_binding.py b/tests/parameterized/test_binding.py index 2bdf145..0e1daf0 100644 --- a/tests/parameterized/test_binding.py +++ b/tests/parameterized/test_binding.py @@ -464,3 +464,29 @@ def test_commit_false_defers_reverse_updates_until_save_model_both_modes(self): self.assertEqual(project.company, obj) self.assertEqual(department.company, obj) self.assertIsNone(form._reverse_relation_data) + + def test_single_select_resubmit_same_value_both_modes(self): + """Re-submitting the same single-select value should remain bound after save.""" + for bulk_enabled in [False, True]: + with self.subTest(bulk_enabled=bulk_enabled): + department = Department.objects.create(name=f"Same Selection {bulk_enabled}") + admin_instance = create_parameterized_admin(bulk_enabled=bulk_enabled) + request = self.factory.post("/") + form_cls = admin_instance.get_form(request, self.company) + + first_submit = form_cls( + {"name": self.company.name, "department_binding": department.pk}, + instance=self.company, + ) + self.assertTrue(first_submit.is_valid()) + saved_company = first_submit.save() + + second_submit = form_cls( + {"name": self.company.name, "department_binding": department.pk}, + instance=saved_company, + ) + self.assertTrue(second_submit.is_valid()) + second_submit.save() + + department.refresh_from_db() + self.assertEqual(department.company, saved_company) diff --git a/tests/test_edge_cases.py b/tests/test_edge_cases.py index caa4e3a..fcd81a4 100644 --- a/tests/test_edge_cases.py +++ b/tests/test_edge_cases.py @@ -1,6 +1,7 @@ """Tests for edge cases and non-parameterizable scenarios.""" # Test imports +from django import forms from django.contrib import admin from django.core.exceptions import ImproperlyConfigured from django.db import transaction @@ -384,6 +385,61 @@ class TestAdmin(ReverseRelationAdminMixin, admin.ModelAdmin): settings_a.refresh_from_db() self.assertIsNone(settings_a.company) + def test_bulk_atomic_rollback_across_multiple_fields(self): + """Bulk updates across fields should roll back as one unit on failure.""" + + class TestAdmin(ReverseRelationAdminMixin, admin.ModelAdmin): + reverse_relations = { + "department_binding": ReverseRelationConfig( + model=Department, + fk_field="company", + multiple=False, + bulk=True, + ), + "project_binding": ReverseRelationConfig( + model=Project, + fk_field="company", + multiple=False, + bulk=True, + ), + } + + department = Department.objects.create(name="Rollback Dept") + project = Project.objects.create(name="Rollback Project") + + request = self.factory.post("/") + admin_inst = TestAdmin(Company, self.site) + + original_bulk_bind = admin_inst._apply_bulk_bind + + def fail_on_project_bind(config, instance, target_objects): + if config.model is Project: + raise forms.ValidationError("Forced bulk failure on project binding") + return original_bulk_bind(config, instance, target_objects) + + admin_inst._apply_bulk_bind = fail_on_project_bind + + form_cls = admin_inst.get_form(request, self.company) + form = form_cls( + { + "name": self.company.name, + "department_binding": department.pk, + "project_binding": project.pk, + }, + instance=self.company, + ) + + self.assertTrue(form.is_valid()) + + with self.assertRaises(forms.ValidationError): + form.save() + + # Department was processed first and would have been bound without rollback. + department.refresh_from_db() + project.refresh_from_db() + self.assertIsNone(department.company) + self.assertIsNone(project.company) + def test_multiple_companies_complex_bindings(self): """Test base operations with multiple companies and complex binding patterns.""" # Create multiple companies and objects