diff --git a/CHANGELOG.md b/CHANGELOG.md index c9db0d9b..bfd051c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Fixed + +- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example + on the matrix [[0., 0.], [0., 1.]]). + ## [0.9.0] - 2026-02-24 ### Added diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 06e1293d..a20be617 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -158,9 +158,10 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: try: self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) - except SolverError: - # On macOS, this can happen with: Solver 'ECOS' failed. + except (SolverError, ValueError): + # On macOS, SolverError can happen with: Solver 'ECOS' failed. # No idea why. The corresponding matrix is of shape [9, 11] with rank 5. + # ValueError can happen with, for example, the matrix [[0., 0.], [0., 1.]]. # Maybe other exceptions can happen in other cases. self.alpha_param.value = self.prvs_alpha_param.value diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index a1200d46..d82fca41 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,7 +1,7 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from utils.tensors import ones_, randn_, tensor_ try: from torchjd.aggregation import NashMTL @@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices] +edge_case_matrices = [ + tensor_([[0.0, 0.0], [0.0, 1.0]]) # This leads to a (caught) ValueError in _solve_optimization. 
+] +edge_case_pairs = [(_make_aggregator(matrix), matrix) for matrix in edge_case_matrices] requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))] @@ -27,8 +31,13 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: @mark.filterwarnings( "ignore:Solution may be inaccurate.", "ignore:You are solving a parameterized problem that is not DPP.", + "ignore:divide by zero encountered in divide", + "ignore:divide by zero encountered in true_divide", + "ignore:overflow encountered in divide", + "ignore:overflow encountered in true_divide", + "ignore:invalid value encountered in matmul", ) -@mark.parametrize(["aggregator", "matrix"], standard_pairs) +@mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs) def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix)