diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 73442a93..57ab1a8f 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -16,6 +16,9 @@ Abstract base classes .. autoclass:: torchjd.aggregation.Stateful :members: reset +.. autoclass:: torchjd.aggregation.Stochastic + :members: reset + .. toctree:: :hidden: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 400cfe27..5016458b 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,7 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting -from ._mixins import Stateful +from ._mixins import Stateful, Stochastic from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -109,6 +109,7 @@ "Random", "RandomWeighting", "Stateful", + "Stochastic", "Sum", "SumWeighting", "TrimmedMean", diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 61c9354e..4a98cfb9 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -6,6 +6,7 @@ from torchjd._linalg import Matrix from ._aggregator_bases import Aggregator +from ._mixins import Stochastic from ._utils.non_differentiable import raise_non_differentiable_error @@ -13,7 +14,7 @@ def _identity(P: Tensor) -> Tensor: return P -class GradDrop(Aggregator): +class GradDrop(Aggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: @@ -24,16 +25,21 @@ class GradDrop(Aggregator): increasing. Defaults to identity. :param leak: The tensor of leak values, determining how much each row is allowed to leak through. Defaults to None, which means no leak. 
+ :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None: + def __init__( + self, f: Callable = _identity, leak: Tensor | None = None, seed: int | None = None + ) -> None: if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " f"{leak.shape}`.", ) - super().__init__() + Aggregator.__init__(self) + Stochastic.__init__(self, seed=seed) self.f = f self.leak = leak @@ -50,7 +56,7 @@ def forward(self, matrix: Matrix, /) -> Tensor: P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0)) fP = self.f(P) - U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device) + U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device, generator=self.generator) vector = torch.zeros_like(matrix[0]) for i in range(len(matrix)): diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index cc518fbb..9c605a4b 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,16 +6,16 @@ from torch import Tensor from torchjd._linalg import PSDMatrix -from torchjd.aggregation._mixins import Stateful +from torchjd.aggregation._mixins import Stochastic from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVacWeighting(Weighting[PSDMatrix], Stateful): +class GradVacWeighting(Weighting[PSDMatrix], Stochastic): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._mixins.Stochastic` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. 
@@ -37,10 +37,13 @@ class GradVacWeighting(Weighting[PSDMatrix], Stateful): :param beta: EMA decay for :math:`\hat{\phi}`. :param eps: Small non-negative constant added to denominators. + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - super().__init__() + def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: + Weighting.__init__(self) + Stochastic.__init__(self, seed=seed) if not (0.0 <= beta <= 1.0): raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") if eps < 0.0: @@ -72,8 +75,9 @@ def eps(self, value: float) -> None: self._eps = value def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" + """Resets the random number generator and clears the EMA state.""" + Stochastic.reset(self) self._phi_t = None self._state_key = None @@ -101,7 +105,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: cG = C[i] @ G others = [j for j in range(m) if j != i] - perm = torch.randperm(len(others)) + perm = torch.randperm(len(others), generator=self.generator) shuffled_js = [others[idx] for idx in perm.tolist()] for j in shuffled_js: @@ -133,9 +137,9 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: self._state_key = key -class GradVac(GramianWeightedAggregator, Stateful): +class GradVac(GramianWeightedAggregator, Stochastic): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._mixins.Stochastic` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -155,22 +159,18 @@ class GradVac(GramianWeightedAggregator, Stateful): :param beta: EMA decay for 
:math:`\hat{\phi}`. :param eps: Small non-negative constant added to denominators. - - .. note:: - For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently - using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if - you need reproducibility. + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. .. note:: To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping strategy, please refer to the :doc:`Grouping ` examples. """ - gramian_weighting: GradVacWeighting - - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - weighting = GradVacWeighting(beta=beta, eps=eps) - super().__init__(weighting) + def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed) + GramianWeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) self._gradvac_weighting = weighting self.register_full_backward_pre_hook(raise_non_differentiable_error) @@ -191,8 +191,9 @@ def eps(self, value: float) -> None: self._gradvac_weighting.eps = value def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" + """Resets the random number generator and clears the EMA state.""" + Stochastic.reset(self) self._gradvac_weighting.reset() def __repr__(self) -> str: diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 8481feab..90f0f9a6 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,9 +1,52 @@ from abc import ABC, abstractmethod +import torch + class Stateful(ABC): - """Mixin adding a reset method.""" + r""" + Mixin for stateful mappings. + + A mapping implements `Stateful` **if and only if** its behavior depends on an internal + state. 
+ + Formally, a stateless mapping is a function :math:`f : x \mapsto y` whereas a stateful + mapping is a transition map :math:`A : (x, s) \mapsto (y, s')` where :math:`s` is the + internal state, :math:`s'` the updated state, and :math:`y` the output. + There exists an initial state :math:`s_0`, and the method `reset()` restores the state to + :math:`s_0`. A `Stateful` mapping must be constructed with the initial state :math:`s_0`. + """ @abstractmethod def reset(self) -> None: - """Resets the internal state.""" + """Resets the internal state to :math:`s_0`.""" + + +class StochasticState(Stateful): + r""" + State representing stochasticity. + + Internally, a ``StochasticState`` mapping holds a :class:`torch.Generator` that serves as an + independent random number stream. + + :param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn + from the global PyTorch RNG to fork an independent stream. + :param generator: An existing :class:`torch.Generator` to share, typically from a companion + :class:`StochasticState` instance. Mutually exclusive with ``seed``. 
+ """ + + def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None: + if generator is not None and seed is not None: + raise ValueError("Parameters `seed` and `generator` are mutually exclusive.") + if generator is not None: + self.generator = generator + else: + self.generator = torch.Generator() + if seed is None: + seed = int(torch.randint(0, 2**62, size=(1,), dtype=torch.int64).item()) + self.generator.manual_seed(seed) + self._initial_rng_state = self.generator.get_state() + + def reset(self) -> None: + """Resets the random number generator to its initial state.""" + self.generator.set_state(self._initial_rng_state) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 770ffe09..26ee42f4 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -4,18 +4,27 @@ from torch import Tensor from torchjd._linalg import PSDMatrix +from torchjd.aggregation import Stateful +from torchjd.aggregation._mixins import StochasticState from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class PCGradWeighting(Weighting[PSDMatrix]): +class PCGradWeighting(Weighting[PSDMatrix], Stateful): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.PCGrad`. + + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. 
""" + def __init__(self, seed: int | None = None) -> None: + super().__init__() + self.state = StochasticState(seed=seed) + def forward(self, gramian: PSDMatrix, /) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration device = gramian.device @@ -27,7 +36,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.zeros(dimension, device=cpu, dtype=dtype) for i in range(dimension): - permutation = torch.randperm(dimension) + permutation = torch.randperm(dimension, generator=self.state.generator) current_weights = torch.zeros(dimension, device=cpu, dtype=dtype) current_weights[i] = 1.0 @@ -46,16 +55,17 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights.to(device) -class PCGrad(GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator, Stateful): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. - """ - gramian_weighting: PCGradWeighting + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. + """ - def __init__(self) -> None: - super().__init__(PCGradWeighting()) + def __init__(self, seed: int | None = None) -> None: + super().__init__(PCGradWeighting(seed=seed)) # This prevents running into a RuntimeError due to modifying stored tensors in place. 
self.register_full_backward_pre_hook(raise_non_differentiable_error) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 8345a15c..4d602b57 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -5,30 +5,43 @@ from torchjd._linalg import Matrix from ._aggregator_bases import WeightedAggregator +from ._mixins import Stochastic from ._weighting_bases import Weighting -class RandomWeighting(Weighting[Matrix]): +class RandomWeighting(Weighting[Matrix], Stochastic): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights at each call. + + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ + def __init__(self, seed: int | None = None) -> None: + Weighting.__init__(self) + Stochastic.__init__(self, seed=seed) + def forward(self, matrix: Tensor, /) -> Tensor: - random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) + random_vector = torch.randn( + matrix.shape[0], device=matrix.device, dtype=matrix.dtype, generator=self.generator + ) weights = F.softmax(random_vector, dim=-1) return weights -class Random(WeightedAggregator): +class Random(WeightedAggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. - """ - weighting: RandomWeighting + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. 
+ """ - def __init__(self) -> None: - super().__init__(RandomWeighting()) + def __init__(self, seed: int | None = None) -> None: + weighting = RandomWeighting(seed=seed) + WeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 4b85bf09..b4d8c738 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -4,7 +4,7 @@ from torch.testing import assert_close from utils.tensors import rand_, randperm_ -from torchjd.aggregation import Aggregator +from torchjd.aggregation import Aggregator, Stateful from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError @@ -110,3 +110,33 @@ def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None: vector = aggregator(matrix) with raises(NonDifferentiableError): vector.backward(torch.ones_like(vector)) + + +def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None: + """ + Test that a given `Aggregator` is stateful. Specifically: + - For a fixed state, the aggregator is determinist on the matrix + - The reset method and the constructor both set the state to the initial state + """ + + assert isinstance(aggregator, Stateful) + + first_pair = (aggregator(matrix), aggregator(matrix)) + aggregator.reset() + second_pair = (aggregator(matrix), aggregator(matrix)) + + assert_close(first_pair[0], second_pair[0], atol=0.0, rtol=0.0) + assert_close(first_pair[1], second_pair[1], atol=0.0, rtol=0.0) + + +def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None: + """ + Test that a given `Aggregator` is stateless. Specifically, it must be deterministic. 
+ """ + + assert not isinstance(aggregator, Stateful) + + first = aggregator(matrix) + second = aggregator(matrix) + + assert_close(first, second, atol=0.0, rtol=0.0) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index d48e8855..24ce80a5 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -5,7 +5,7 @@ from torchjd.aggregation import AlignedMTL -from ._asserts import assert_expected_structure, assert_permutation_invariant +from ._asserts import assert_expected_structure, assert_permutation_invariant, assert_stateless from ._inputs import scaled_matrices, typical_matrices aggregators = [ @@ -28,6 +28,11 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: AlignedMTL, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = AlignedMTL(pref_vector=None) assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index c7d18b1f..70aa3b16 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ b/tests/unit/aggregation/test_cagrad.py @@ -12,7 +12,12 @@ pytest.skip("CAGrad dependencies not installed", allow_module_level=True) -from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable +from ._asserts import ( + assert_expected_structure, + assert_non_conflicting, + assert_non_differentiable, + assert_stateless, +) from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices] @@ -38,6 +43,11 @@ def test_non_conflicting(aggregator: CAGrad, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], 
typical_pairs) +def test_stateless(aggregator: CAGrad, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["c", "expectation"], [ diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 2db2ea0f..c3ba275e 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -10,6 +10,7 @@ assert_linear_under_scaling, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -39,6 +40,11 @@ def test_non_differentiable(aggregator: ConFIG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: ConFIG, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = ConFIG() assert repr(A) == "ConFIG(pref_vector=None)" diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index aa1332fc..07bcd110 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -11,6 +11,7 @@ from ._asserts import ( assert_expected_structure, assert_linear_under_scaling, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -42,6 +43,11 @@ def test_strongly_stationary(aggregator: Constant, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Constant, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["weights_shape", "expectation"], [ diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 5bd0e71a..051ec1a0 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ 
b/tests/unit/aggregation/test_dualproj.py @@ -10,6 +10,7 @@ assert_non_conflicting, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -45,6 +46,11 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: DualProj, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert ( diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 2868dca0..586d94c8 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -9,7 +9,7 @@ from torchjd.aggregation import GradDrop -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(GradDrop(), matrix) for matrix in scaled_matrices] @@ -27,6 +27,11 @@ def test_non_differentiable(aggregator: GradDrop, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: GradDrop, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.parametrize( ["leak_shape", "expectation"], [ diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py index bde2e8fd..88fd4466 100644 --- a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -6,7 +6,7 @@ from torchjd.aggregation import GradVac, GradVacWeighting -from ._asserts import assert_expected_structure, assert_non_differentiable 
+from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows scaled_pairs = [(GradVac(), m) for m in scaled_matrices] @@ -104,6 +104,11 @@ def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: GradVac, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + def test_weighting_beta_out_of_range() -> None: with raises(ValueError, match="beta"): GradVacWeighting(beta=-0.1) diff --git a/tests/unit/aggregation/test_imtl_g.py b/tests/unit/aggregation/test_imtl_g.py index 03c41d5e..3fa40ceb 100644 --- a/tests/unit/aggregation/test_imtl_g.py +++ b/tests/unit/aggregation/test_imtl_g.py @@ -9,6 +9,7 @@ assert_expected_structure, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, ) from ._inputs import scaled_matrices, typical_matrices @@ -32,6 +33,11 @@ def test_non_differentiable(aggregator: IMTLG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: IMTLG, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_imtlg_zero() -> None: """ Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 4097f2eb..bab3011f 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -7,7 +7,7 @@ from torchjd.aggregation import Krum -from ._asserts import assert_expected_structure +from ._asserts import assert_expected_structure, assert_stateless from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows scaled_pairs = [(Krum(n_byzantine=1), matrix) for matrix in 
scaled_matrices_2_plus_rows] @@ -19,6 +19,11 @@ def test_expected_structure(aggregator: Krum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Krum, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["n_byzantine", "expectation"], [ diff --git a/tests/unit/aggregation/test_mean.py b/tests/unit/aggregation/test_mean.py index 88c28e93..628f1e2a 100644 --- a/tests/unit/aggregation/test_mean.py +++ b/tests/unit/aggregation/test_mean.py @@ -7,6 +7,7 @@ assert_expected_structure, assert_linear_under_scaling, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -36,6 +37,11 @@ def test_strongly_stationary(aggregator: Mean, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Mean, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = Mean() assert repr(A) == "Mean()" diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 5c925b8f..69c9b9d8 100644 --- a/tests/unit/aggregation/test_mgda.py +++ b/tests/unit/aggregation/test_mgda.py @@ -11,6 +11,7 @@ assert_expected_structure, assert_non_conflicting, assert_permutation_invariant, + assert_stateless, ) from ._inputs import scaled_matrices, typical_matrices @@ -33,6 +34,11 @@ def test_permutation_invariant(aggregator: MGDA, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: MGDA, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( "shape", [ diff --git a/tests/unit/aggregation/test_nash_mtl.py 
b/tests/unit/aggregation/test_nash_mtl.py index d82fca41..9b7cfd52 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -10,7 +10,7 @@ pytest.skip("NashMTL dependencies not installed", allow_module_level=True) -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import nash_mtl_matrices @@ -48,6 +48,15 @@ def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.filterwarnings( + "ignore:Solution may be inaccurate.", + "ignore:You are solving a parameterized problem that is not DPP.", +) +@mark.parametrize(["aggregator", "matrix"], standard_pairs) +def test_stateful(aggregator: NashMTL, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.") def test_nash_mtl_reset() -> None: """ diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d..79d67d27 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -8,7 +8,7 @@ from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(PCGrad(), matrix) for matrix in scaled_matrices] @@ -26,6 +26,11 @@ def test_non_differentiable(aggregator: PCGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: PCGrad, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.parametrize( "shape", [ 
diff --git a/tests/unit/aggregation/test_random.py b/tests/unit/aggregation/test_random.py index 77ab7f42..fcc8b08b 100644 --- a/tests/unit/aggregation/test_random.py +++ b/tests/unit/aggregation/test_random.py @@ -3,7 +3,7 @@ from torchjd.aggregation import Random -from ._asserts import assert_expected_structure, assert_strongly_stationary +from ._asserts import assert_expected_structure, assert_stateful, assert_strongly_stationary from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices scaled_pairs = [(Random(), matrix) for matrix in scaled_matrices] @@ -21,6 +21,11 @@ def test_strongly_stationary(aggregator: Random, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: Random, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + def test_representations() -> None: A = Random() assert repr(A) == "Random()" diff --git a/tests/unit/aggregation/test_sum.py b/tests/unit/aggregation/test_sum.py index 386c507f..757e7e77 100644 --- a/tests/unit/aggregation/test_sum.py +++ b/tests/unit/aggregation/test_sum.py @@ -7,6 +7,7 @@ assert_expected_structure, assert_linear_under_scaling, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -36,6 +37,11 @@ def test_strongly_stationary(aggregator: Sum, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Sum, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = Sum() assert repr(A) == "Sum()" diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index 3a6ccb2b..97c027d8 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py 
@@ -7,7 +7,7 @@ from torchjd.aggregation import TrimmedMean -from ._asserts import assert_expected_structure, assert_permutation_invariant +from ._asserts import assert_expected_structure, assert_permutation_invariant, assert_stateless from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows scaled_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in scaled_matrices_2_plus_rows] @@ -24,6 +24,11 @@ def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: TrimmedMean, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["trim_number", "expectation"], [ diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 1859b662..758c2d4b 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -11,6 +11,7 @@ assert_non_conflicting, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -51,6 +52,11 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: UPGrad, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')"