From 5eebcf7690e387c4487fea6fde869f200de660cc Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Thu, 9 Apr 2026 10:39:31 -0400 Subject: [PATCH 01/16] feat(aggregation): add GradVac aggregator Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator. Support group_type 0 (whole model), 1 (all_layer via encoder), and 2 (all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable eps. Add Sphinx page and unit tests. Autogram is not supported; use torch.manual_seed for reproducible task shuffle order. --- docs/source/docs/aggregation/gradvac.rst | 22 ++ docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 3 + src/torchjd/aggregation/_gradvac.py | 251 +++++++++++++++++++++++ tests/unit/aggregation/test_gradvac.py | 185 +++++++++++++++++ 5 files changed, 462 insertions(+) create mode 100644 docs/source/docs/aggregation/gradvac.rst create mode 100644 src/torchjd/aggregation/_gradvac.py create mode 100644 tests/unit/aggregation/test_gradvac.py diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst new file mode 100644 index 00000000..5fd2f7bf --- /dev/null +++ b/docs/source/docs/aggregation/gradvac.rst @@ -0,0 +1,22 @@ +:hide-toc: + +GradVac +======= + +.. autodata:: torchjd.aggregation.DEFAULT_GRADVAC_EPS + +The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the +per-block cosine statistics in GradVac: + +* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and + ``shared_params``. +* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under + ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation). +* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use + the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass + to a shared-gradient helper). + +.. 
autoclass:: torchjd.aggregation.GradVac + :members: + :undoc-members: + :exclude-members: forward diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index c15d5980..64ba6f63 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -35,6 +35,7 @@ Abstract base classes dualproj.rst flattening.rst graddrop.rst + gradvac.rst imtl_g.rst krum.rst mean.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 9eed9bf7..d6c39602 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -66,6 +66,7 @@ from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop +from ._gradvac import DEFAULT_GRADVAC_EPS, GradVac from ._imtl_g import IMTLG, IMTLGWeighting from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting @@ -87,11 +88,13 @@ "ConFIG", "Constant", "ConstantWeighting", + "DEFAULT_GRADVAC_EPS", "DualProj", "DualProjWeighting", "Flattening", "GeneralizedWeighting", "GradDrop", + "GradVac", "IMTLG", "IMTLGWeighting", "Krum", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py new file mode 100644 index 00000000..3d8f0c0f --- /dev/null +++ b/src/torchjd/aggregation/_gradvac.py @@ -0,0 +1,251 @@ +from __future__ import annotations + +from collections.abc import Iterable + +import torch +import torch.nn as nn +from torch import Tensor + +from torchjd._linalg import Matrix + +from ._aggregator_bases import Aggregator +from ._utils.non_differentiable import raise_non_differentiable_error + +#: Default small constant added to denominators for numerical stability. 
+DEFAULT_GRADVAC_EPS = 1e-8 + + +def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: + """ + Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate + ``encoder.modules()`` and append the total number of elements in each module that has no child + submodules and registers at least one parameter. + """ + + return tuple( + sum(w.numel() for w in module.parameters()) + for module in encoder.modules() + if len(module._modules) == 0 and len(module._parameters) > 0 + ) + + +def _gradvac_all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]: + """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout).""" + + return tuple(p.numel() for p in shared_params) + + +class GradVac(Aggregator): + r""" + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine + (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in + Massively Multilingual Models (ICLR 2021 Spotlight) + `_. + + The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task + gradients. For each task :math:`i`, rows are visited in a random order; for each other task + :math:`j` and each parameter block :math:`k`, the cosine correlation :math:`\rho_{ijk}` between + the (possibly already modified) gradient of task :math:`i` and the original gradient of task + :math:`j` on that block is compared to an EMA target :math:`\bar{\rho}_{ijk}`. When + :math:`\rho_{ijk} < \bar{\rho}_{ijk}`, a closed-form correction adds a scaled copy of + :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with + :math:`\bar{\rho}_{ijk} \leftarrow (1-\beta)\bar{\rho}_{ijk} + \beta \rho_{ijk}`. The aggregated + vector is the sum of the modified rows. + + This aggregator is stateful: it keeps :math:`\bar{\rho}` across calls. 
Use :meth:`reset` when + the number of tasks, parameter dimension, grouping, device, or dtype changes. + + **Parameter granularity** is selected by ``group_type`` (integer, default ``0``). It defines how + each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets + :math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally: + + * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block. + Cosine similarity is taken between entire task gradients. Do not pass ``encoder`` or + ``shared_params``. + * ``1`` — **all layer** (``all_layer``): one block per leaf ``nn.Module`` under ``encoder`` that + holds parameters (same rule as iterating ``encoder.modules()`` and selecting leaves with + parameters). Pass ``encoder``; ``shared_params`` must be omitted. + * ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in iteration + order. That order must match how Jacobian columns are laid out for those shared parameters. + Pass ``shared_params``; ``encoder`` must be omitted. + + :param beta: EMA decay for :math:`\bar{\rho}` (paper default ``0.5``). + :param group_type: Granularity of parameter grouping; see **Parameter granularity** above. + :param encoder: Module whose subtree defines ``all_layer`` blocks when ``group_type == 1``. + :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and + order when ``group_type == 2``. It is materialized once at construction. + :param eps: Small positive constant added to denominators when computing cosines and the + vaccine weight (default :data:`~torchjd.aggregation.DEFAULT_GRADVAC_EPS`). You may read or + assign the :attr:`eps` attribute between steps to tune numerical behavior. + + .. note:: + GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner + products, not only a Gram matrix. Only the autojac path is supported. + + .. 
note:: + Task-order shuffling uses the global PyTorch RNG (``torch.randperm``). Seed it with + ``torch.manual_seed`` if you need reproducibility. + """ + + def __init__( + self, + beta: float = 0.5, + group_type: int = 0, + encoder: nn.Module | None = None, + shared_params: Iterable[Tensor] | None = None, + eps: float = DEFAULT_GRADVAC_EPS, + ) -> None: + super().__init__() + if not (0.0 <= beta <= 1.0): + raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") + if group_type not in (0, 1, 2): + raise ValueError( + "Parameter `group_type` must be 0 (whole_model), 1 (all_layer), or 2 (all_matrix). " + f"Found group_type={group_type!r}.", + ) + params_tuple: tuple[Tensor, ...] = () + fixed_block_sizes: tuple[int, ...] | None + if group_type == 0: + if encoder is not None: + raise ValueError("Parameter `encoder` must be None when `group_type == 0`.") + if shared_params is not None: + raise ValueError("Parameter `shared_params` must be None when `group_type == 0`.") + fixed_block_sizes = None + elif group_type == 1: + if encoder is None: + raise ValueError("Parameter `encoder` is required when `group_type == 1`.") + if shared_params is not None: + raise ValueError("Parameter `shared_params` must be None when `group_type == 1`.") + fixed_block_sizes = _gradvac_all_layer_group_sizes(encoder) + if sum(fixed_block_sizes) == 0: + raise ValueError("Parameter `encoder` has no parameters in any leaf module.") + else: + if shared_params is None: + raise ValueError("Parameter `shared_params` is required when `group_type == 2`.") + if encoder is not None: + raise ValueError("Parameter `encoder` must be None when `group_type == 2`.") + params_tuple = tuple(shared_params) + if len(params_tuple) == 0: + raise ValueError( + "Parameter `shared_params` must be non-empty when `group_type == 2`." + ) + fixed_block_sizes = _gradvac_all_matrix_group_sizes(params_tuple) + + if eps <= 0.0: + raise ValueError(f"Parameter `eps` must be positive. 
Found eps={eps!r}.") + + self._beta = beta + self._group_type = group_type + self._encoder = encoder + self._shared_params_len = len(params_tuple) + self._fixed_block_sizes = fixed_block_sizes + self._eps = float(eps) + + self._rho_t: Tensor | None = None + self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None + + self.register_full_backward_pre_hook(raise_non_differentiable_error) + + @property + def eps(self) -> float: + """Small positive constant added to denominators for numerical stability.""" + + return self._eps + + @eps.setter + def eps(self, value: float) -> None: + v = float(value) + if v <= 0.0: + raise ValueError(f"Attribute `eps` must be positive. Found eps={value!r}.") + self._eps = v + + def reset(self) -> None: + """Clears EMA state so the next forward starts from zero targets.""" + + self._rho_t = None + self._state_key = None + + def __repr__(self) -> str: + enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)" + sp = "None" if self._group_type != 2 else f"n_params={self._shared_params_len}" + return ( + f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, " + f"encoder={enc}, shared_params={sp}, eps={self._eps!r})" + ) + + def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]: + if self._group_type == 0: + return (n,) + assert self._fixed_block_sizes is not None + sizes = self._fixed_block_sizes + if sum(sizes) != n: + raise ValueError( + "The Jacobian width `D` must equal the sum of block sizes implied by " + f"`encoder` or `shared_params` for this `group_type`. 
Found D={n}, " + f"sum(block_sizes)={sum(sizes)}.", + ) + return sizes + + def _ensure_state( + self, + m: int, + n: int, + sizes: tuple[int, ...], + device: torch.device, + dtype: torch.dtype, + ) -> None: + key = (m, n, sizes, device, dtype) + num_groups = len(sizes) + if self._state_key != key or self._rho_t is None: + self._rho_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype) + self._state_key = key + + def forward(self, matrix: Matrix, /) -> Tensor: + grads = matrix + m, n = grads.shape + if m == 0 or n == 0: + return torch.zeros(n, dtype=grads.dtype, device=grads.device) + + sizes = self._resolve_segment_sizes(n) + device = grads.device + dtype = grads.dtype + self._ensure_state(m, n, sizes, device, dtype) + assert self._rho_t is not None + + rho_t = self._rho_t + beta = self._beta + eps = self.eps + + pc_grads = grads.clone() + offsets = [0] + for s in sizes: + offsets.append(offsets[-1] + s) + + for i in range(m): + others = [j for j in range(m) if j != i] + perm = torch.randperm(len(others)) + order = perm.tolist() + shuffled_js = [others[idx] for idx in order] + + for j in shuffled_js: + for k in range(len(sizes)): + beg, end = offsets[k], offsets[k + 1] + slice_i = pc_grads[i, beg:end] + slice_j = grads[j, beg:end] + + norm_i = slice_i.norm() + norm_j = slice_j.norm() + denom = norm_i * norm_j + eps + rho_ijk = slice_i.dot(slice_j) / denom + + bar = rho_t[i, j, k] + if rho_ijk < bar: + sqrt_1_rho2 = (1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt() + sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt() + denom_w = norm_j * sqrt_1_bar2 + eps + w = norm_i * (bar * sqrt_1_rho2 - rho_ijk * sqrt_1_bar2) / denom_w + pc_grads[i, beg:end] = slice_i + slice_j * w + + rho_t[i, j, k] = (1.0 - beta) * bar + beta * rho_ijk + + return pc_grads.sum(dim=0) diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py new file mode 100644 index 00000000..49a8770a --- /dev/null +++ b/tests/unit/aggregation/test_gradvac.py @@ -0,0 
+1,185 @@ +import torch +import torch.nn as nn +from pytest import mark, raises +from torch import Tensor, tensor +from torch.testing import assert_close +from utils.tensors import ones_, randn_ + +from torchjd.aggregation import DEFAULT_GRADVAC_EPS, GradVac + +from ._asserts import assert_expected_structure, assert_non_differentiable +from ._inputs import scaled_matrices, typical_matrices + +scaled_pairs = [(GradVac(), m) for m in scaled_matrices] +typical_pairs = [(GradVac(), m) for m in typical_matrices] +requires_grad_pairs = [(GradVac(), ones_(3, 5, requires_grad=True))] + + +def test_repr() -> None: + g = GradVac() + assert repr(g).startswith("GradVac(") + assert "beta=" in repr(g) + assert "group_type=" in repr(g) + assert "encoder=" in repr(g) + assert "eps=" in repr(g) + + +def test_beta_out_of_range() -> None: + with raises(ValueError, match="beta"): + GradVac(beta=-0.1) + with raises(ValueError, match="beta"): + GradVac(beta=1.1) + + +def test_eps_non_positive() -> None: + with raises(ValueError, match="eps"): + GradVac(eps=0.0) + with raises(ValueError, match="eps"): + GradVac(eps=-1e-9) + + +def test_eps_setter_rejects_non_positive() -> None: + g = GradVac() + with raises(ValueError, match="eps"): + g.eps = 0.0 + + +def test_default_eps_constant() -> None: + assert DEFAULT_GRADVAC_EPS == 1e-8 + assert GradVac().eps == DEFAULT_GRADVAC_EPS + + +def test_eps_can_be_changed_between_steps() -> None: + j = tensor([[1.0, 0.0], [0.0, 1.0]]) + agg = GradVac() + agg.eps = 1e-6 + assert agg(j).isfinite().all() + agg.reset() + agg.eps = 1e-10 + assert agg(j).isfinite().all() + + +def test_group_type_invalid() -> None: + with raises(ValueError, match="group_type"): + GradVac(group_type=3) + + +def test_group_type_0_rejects_encoder() -> None: + net = nn.Linear(1, 1) + with raises(ValueError, match="encoder"): + GradVac(encoder=net) + + +def test_group_type_0_rejects_shared_params() -> None: + p = nn.Parameter(tensor([1.0])) + with raises(ValueError, 
match="shared_params"): + GradVac(shared_params=[p]) + + +def test_group_type_1_requires_encoder() -> None: + with raises(ValueError, match="encoder"): + GradVac(group_type=1) + + +def test_group_type_1_rejects_shared_params() -> None: + net = nn.Linear(1, 1) + p = nn.Parameter(tensor([1.0])) + with raises(ValueError, match="shared_params"): + GradVac(group_type=1, encoder=net, shared_params=[p]) + + +def test_group_type_2_requires_shared_params() -> None: + with raises(ValueError, match="shared_params"): + GradVac(group_type=2) + + +def test_group_type_2_rejects_encoder() -> None: + net = nn.Linear(1, 1) + with raises(ValueError, match="encoder"): + GradVac(group_type=2, encoder=net, shared_params=list(net.parameters())) + + +def test_encoder_without_leaf_parameters() -> None: + class Empty(nn.Module): + pass + + with raises(ValueError, match="encoder"): + GradVac(group_type=1, encoder=Empty()) + + +def test_shared_params_empty() -> None: + with raises(ValueError, match="shared_params"): + GradVac(group_type=2, shared_params=()) + + +def test_group_type_1_forward() -> None: + net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)) + d = sum(p.numel() for p in net.parameters()) + j = randn_((2, d)) + torch.manual_seed(0) + out = GradVac(group_type=1, encoder=net)(j) + assert out.shape == (d,) + assert out.isfinite().all() + + +def test_group_type_2_forward() -> None: + net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)) + params = list(net.parameters()) + d = sum(p.numel() for p in params) + j = randn_((2, d)) + torch.manual_seed(0) + out = GradVac(group_type=2, shared_params=params)(j) + assert out.shape == (d,) + assert out.isfinite().all() + + +def test_jacobian_width_mismatch() -> None: + net = nn.Linear(2, 2) + d = sum(p.numel() for p in net.parameters()) + agg = GradVac(group_type=1, encoder=net) + with raises(ValueError, match="Jacobian width"): + agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)])) + + +def 
test_zero_rows_returns_zero_vector() -> None: + out = GradVac()(tensor([]).reshape(0, 3)) + assert_close(out, tensor([0.0, 0.0, 0.0])) + + +def test_zero_columns_returns_zero_vector() -> None: + """Handled inside forward before grouping validation.""" + + out = GradVac()(tensor([]).reshape(2, 0)) + assert out.shape == (0,) + + +def test_reproducible_with_manual_seed() -> None: + j = randn_((3, 8)) + torch.manual_seed(12345) + a1 = GradVac(beta=0.3) + out1 = a1(j) + torch.manual_seed(12345) + a2 = GradVac(beta=0.3) + out2 = a2(j) + assert_close(out1, out2) + + +def test_reset_restores_first_step_behavior() -> None: + j = tensor([[1.0, 0.0], [0.0, 1.0]]) + torch.manual_seed(7) + agg = GradVac(beta=0.5) + first = agg(j) + agg(j) + agg.reset() + torch.manual_seed(7) + assert_close(first, agg(j)) + + +@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) +def test_expected_structure(aggregator: GradVac, matrix: Tensor) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) +def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: + assert_non_differentiable(aggregator, matrix) From a588c93ad97990e84bb31ff97c124aa32fe71785 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Sat, 11 Apr 2026 16:56:31 +0200 Subject: [PATCH 02/16] chore: Remove outdated doctesting stuff (#639) --- CONTRIBUTING.md | 8 +++----- pyproject.toml | 1 - 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 5b4c6253..2e422782 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -104,10 +104,9 @@ uv run pre-commit install uv run make doctest -C docs ``` - - To compute the code coverage locally, you should run the unit tests and the doc tests together, - with the `--cov` flag: + - To compute the code coverage locally, you should run the unit tests with the `--cov` flag: ```bash - 
uv run pytest tests/unit tests/doc --cov=src + uv run pytest tests/unit --cov=src ``` > [!TIP] @@ -148,8 +147,7 @@ should create it. ### Testing We ask contributors to implement the unit tests necessary to check the correctness of their -implementations. Besides, whenever usage examples are provided, we require the example's code to be -tested in `tests/doc`. We aim for 100% code coverage, but we greatly appreciate any PR, even with +implementations. We aim for 100% code coverage, but we greatly appreciate any PR, even with insufficient code coverage. To ensure that the tensors generated during the tests are on the right device and dtype, you have to use the partial functions defined in `tests/utils/tensors.py` to instantiate tensors. For instance, instead of diff --git a/pyproject.toml b/pyproject.toml index 84700e94..55e3263d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -163,7 +163,6 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "**/conftest.py" = ["ARG"] # Can't change argument names in the functions pytest expects -"tests/doc/test_rst.py" = ["ARG"] # For the lightning example [tool.ruff.lint.flake8-annotations] suppress-dummy-args = true From 9d65f63aad3e42e18d100f711aac80c358406657 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sat, 11 Apr 2026 19:12:04 +0200 Subject: [PATCH 03/16] chore: Add governance documentation (#637) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add GOVERNANCE.md documenting technical governance structure - Add CODEOWNERS file defining project maintainers - Add CODE_OF_CONDUCT.md referencing Linux Foundation CoC These files are required for PyTorch Ecosystem membership. 
--------- Co-authored-by: Valérian Rey <31951177+ValerianRey@users.noreply.github.com> --- CODEOWNERS | 24 ++++++++++++++ CODE_OF_CONDUCT.md | 7 ++++ CONTRIBUTING.md | 18 +++++++++-- GOVERNANCE.md | 80 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 127 insertions(+), 2 deletions(-) create mode 100644 CODEOWNERS create mode 100644 CODE_OF_CONDUCT.md create mode 100644 GOVERNANCE.md diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 00000000..7806988e --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1,24 @@ +# Code Owners for TorchJD +# +# This file defines the code owners for the repository. When a pull request is opened, +# GitHub automatically requests reviews from the appropriate code owners based on the +# files changed. +# +# Each line contains a pattern followed by one or more owners (GitHub usernames or team names). +# Patterns use gitignore-style glob patterns. +# +# For more information, see https://docs.github.com/en/repositories/managing-your-repositorys-settings-and-features/customizing-your-repository/about-code-owners +# +# Note: Ownership for new packages should be decided by the maintainers and SimplexLab. 

# Default owners for the entire repository +* @SimplexLab/maintainers + +# CI workflows +/.github/workflows/ @PierreQuinton @ValerianRey + +# Python packages in src/torchjd +/src/torchjd/_linalg/ @PierreQuinton @ValerianRey +/src/torchjd/aggregation/ @PierreQuinton @ValerianRey +/src/torchjd/autogram/ @PierreQuinton @ValerianRey +/src/torchjd/autojac/ @PierreQuinton @ValerianRey diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 00000000..24d3d186 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,7 @@ +# Code of Conduct + +TorchJD follows the [Linux Foundation Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). 
+ +## Changes to this Code of Conduct + +Changes to this Code of Conduct can only be made upon request from SimplexLab, which defines when and how such changes are possible. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2e422782..cde0df1e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,7 +1,21 @@ # Contributing to TorchJD -This document explains how to contribute to TorchJD. Please use issues or discussions to communicate -with maintainers before implementing major changes. +This document explains how to contribute to TorchJD. + +## Getting Started + +- **Minor changes** (bug fixes, documentation, small improvements): Open a pull request directly following the guidelines in this document. +- **Significant or major changes** (new features, API changes, architectural decisions): Join the [SimplexLab Discord server](https://discord.gg/76KkRnb3nk), introduce yourself and your idea, and discuss it with the community to determine if and how it fits within the project's goals before implementing. + +## Code Ownership + +This project uses a [CODEOWNERS](CODEOWNERS) file to automatically assign reviewers to pull requests +based on which files are changed. The code owners are the people or groups who created or maintain +specific parts of the codebase. + +When you open a pull request, GitHub will automatically request reviews from the relevant code owners +for the files you've modified. This ensures that changes are reviewed by the people most familiar +with the affected code. ## Installation diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 00000000..3501bb47 --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,80 @@ +# TorchJD Governance + +This document defines the governance structure and decision-making process for the TorchJD project. + +## Project Ownership + +The TorchJD project is the property of SimplexLab. SimplexLab has full authority over the project, including its direction, governance structure, and major decisions. 
Maintainers are typically members of SimplexLab and are responsible for day-to-day operations, code reviews, and technical decisions. + +## Maintainers + +TorchJD is maintained by: + +- **Valérian Rey** ([@ValerianRey](https://github.com/ValerianRey)) +- **Pierre Quinton** ([@PierreQuinton](https://github.com/PierreQuinton)) + +Maintainers are responsible for: +- Reviewing and merging pull requests +- Managing releases +- Setting project direction and priorities +- Ensuring code quality and consistency + +## Decision Making + +### Technical Decisions + +Most technical decisions are made through the pull request process: + +1. **Minor changes** (bug fixes, documentation, small improvements): Require approval from at least one maintainer +2. **Significant changes** (new features, API changes, refactoring): Should be discussed in an issue first, then require approval from at least one maintainer +3. **Major changes** (breaking changes, architectural decisions): Should be discussed in an issue or discussion thread and require consensus from all maintainers + +For significant or major changes, contributors should join the [SimplexLab Discord server](https://discord.gg/76KkRnb3nk), introduce themselves and their idea, and discuss it with the community to determine if and how it fits within the project's goals. + +### Pull Request Process + +1. Contributors submit pull requests following the guidelines in [CONTRIBUTING.md](CONTRIBUTING.md) +2. Maintainers review the code for correctness, style, and alignment with project goals +3. Once approved, any maintainer can merge the pull request +4. All pull requests must pass CI checks before being merged + +### Consensus + +For major decisions, maintainers aim for consensus. SimplexLab operates as a democratic decision-making body. 
If consensus among maintainers cannot be reached: +- The decision may be postponed for further discussion +- If a decision must be made, SimplexLab resolves the consensus based on the expertise of all maintainers relevant to the discussion as well as all people involved in the discussion + +## Release Process + +Releases are managed by maintainers following the process described in [CONTRIBUTING.md](CONTRIBUTING.md): + +1. Ensure all tests pass +2. Update the changelog +3. Update the version number +4. Create a release on GitHub +5. Verify deployment to PyPI + +## Adding Maintainers + +New maintainers may be added when: +- They have made significant, sustained contributions to the project +- They demonstrate understanding of the project's goals and coding standards +- They are committed to the long-term maintenance of the project + +New maintainers must be approved by SimplexLab, based on the report and recommendation of all existing maintainers. + +## Conflict Resolution + +Conflicts are resolved through discussion: +1. Issues should first be discussed in the relevant issue or pull request +2. If unresolved, maintainers discuss privately to reach consensus +3. If maintainers cannot reach consensus, SimplexLab has the final authority to resolve the conflict +4. The goal is always to find the best solution for the project and its users + +## Code of Conduct + +This project follows the [Linux Foundation Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). + +## Changes to Governance + +Changes to this governance document can only be made upon request from SimplexLab, which defines when and how such changes are possible. 
From 3ab336c10442874a6eb741ef3ad7b960748980e4 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Sun, 12 Apr 2026 12:29:22 -0400 Subject: [PATCH 04/16] refactor(gradvac): literal group types, eps/beta rules, and plotter UX - Use group_type "whole_model" | "all_layer" | "all_matrix" instead of 0/1/2 - Remove DEFAULT_GRADVAC_EPS from the public API; keep default 1e-8; allow eps=0 - Validate beta via setter; tighten GradVac repr/str expectations - Fix all_layer leaf sizing via children() and parameters() instead of private fields - Trim redundant GradVac.rst prose; align docs with the new API - Tests: GradVac cases, value regression with torch.manual_seed for GradVac - Plotter: factory dict + fresh aggregator instances per update; legend from selected keys; MathJax labels and live angle/length readouts in the sidebar This commit includes GradVac implementation with Aggregator class. --- docs/source/docs/aggregation/gradvac.rst | 13 -- src/torchjd/aggregation/__init__.py | 3 +- src/torchjd/aggregation/_gradvac.py | 194 ++++++++++++----------- tests/plots/_utils.py | 25 ++- tests/plots/interactive_plotter.py | 125 +++++++++++---- tests/unit/aggregation/test_gradvac.py | 91 +++++------ tests/unit/aggregation/test_values.py | 5 + 7 files changed, 269 insertions(+), 187 deletions(-) diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst index 5fd2f7bf..2116fca3 100644 --- a/docs/source/docs/aggregation/gradvac.rst +++ b/docs/source/docs/aggregation/gradvac.rst @@ -3,19 +3,6 @@ GradVac ======= -.. autodata:: torchjd.aggregation.DEFAULT_GRADVAC_EPS - -The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the -per-block cosine statistics in GradVac: - -* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and - ``shared_params``. 
-* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under - ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation). -* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use - the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass - to a shared-gradient helper). - .. autoclass:: torchjd.aggregation.GradVac :members: :undoc-members: diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index d6c39602..23f9d48c 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -66,7 +66,7 @@ from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop -from ._gradvac import DEFAULT_GRADVAC_EPS, GradVac +from ._gradvac import GradVac from ._imtl_g import IMTLG, IMTLGWeighting from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting @@ -88,7 +88,6 @@ "ConFIG", "Constant", "ConstantWeighting", - "DEFAULT_GRADVAC_EPS", "DualProj", "DualProjWeighting", "Flattening", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 3d8f0c0f..d9fad734 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterable +from typing import Literal, cast import torch import torch.nn as nn @@ -11,11 +12,8 @@ from ._aggregator_bases import Aggregator from ._utils.non_differentiable import raise_non_differentiable_error -#: Default small constant added to denominators for numerical stability. 
-DEFAULT_GRADVAC_EPS = 1e-8 - -def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: +def _all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: """ Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate ``encoder.modules()`` and append the total number of elements in each module that has no child @@ -25,11 +23,11 @@ def _gradvac_all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: return tuple( sum(w.numel() for w in module.parameters()) for module in encoder.modules() - if len(module._modules) == 0 and len(module._parameters) > 0 + if len(list(module.children())) == 0 and next(module.parameters(), None) is not None ) -def _gradvac_all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]: +def _all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]: """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout).""" return tuple(p.numel() for p in shared_params) @@ -42,40 +40,42 @@ class GradVac(Aggregator): Massively Multilingual Models (ICLR 2021 Spotlight) `_. - The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task - gradients. For each task :math:`i`, rows are visited in a random order; for each other task - :math:`j` and each parameter block :math:`k`, the cosine correlation :math:`\rho_{ijk}` between - the (possibly already modified) gradient of task :math:`i` and the original gradient of task - :math:`j` on that block is compared to an EMA target :math:`\bar{\rho}_{ijk}`. When - :math:`\rho_{ijk} < \bar{\rho}_{ijk}`, a closed-form correction adds a scaled copy of + The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task + gradients. 
For each task :math:`i` and each parameter block :math:`k`, the order in which other + tasks :math:`j` are visited is drawn at random (independently for each :math:`k`); for each pair + :math:`(i, j)` on block :math:`k`, the cosine correlation :math:`\phi_{ijk}` between the + (possibly already modified) gradient of task :math:`i` and the original gradient of task + :math:`j` on that block is compared to an EMA target :math:`\hat{\phi}_{ijk}`. When + :math:`\phi_{ijk} < \hat{\phi}_{ijk}`, a closed-form correction adds a scaled copy of :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with - :math:`\bar{\rho}_{ijk} \leftarrow (1-\beta)\bar{\rho}_{ijk} + \beta \rho_{ijk}`. The aggregated + :math:`\hat{\phi}_{ijk} \leftarrow (1-\beta)\hat{\phi}_{ijk} + \beta \phi_{ijk}`. The aggregated vector is the sum of the modified rows. - This aggregator is stateful: it keeps :math:`\bar{\rho}` across calls. Use :meth:`reset` when + This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when the number of tasks, parameter dimension, grouping, device, or dtype changes. - **Parameter granularity** is selected by ``group_type`` (integer, default ``0``). It defines how - each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets - :math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally: - - * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block. - Cosine similarity is taken between entire task gradients. Do not pass ``encoder`` or - ``shared_params``. - * ``1`` — **all layer** (``all_layer``): one block per leaf ``nn.Module`` under ``encoder`` that - holds parameters (same rule as iterating ``encoder.modules()`` and selecting leaves with - parameters). Pass ``encoder``; ``shared_params`` must be omitted. - * ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in iteration - order. 
That order must match how Jacobian columns are laid out for those shared parameters. - Pass ``shared_params``; ``encoder`` must be omitted. - - :param beta: EMA decay for :math:`\bar{\rho}` (paper default ``0.5``). + **Parameter granularity** is selected by ``group_type`` (default ``"whole_model"``). It defines + how each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets + :math:`\hat{\phi}_{ijk}` are computed **per block** rather than only globally: + + * ``"whole_model"``: the full row of length :math:`n` is a single block. Cosine similarity is + taken between entire task gradients. Do not pass ``encoder`` or ``shared_params``. + * ``"all_layer"``: one block per leaf ``nn.Module`` under ``encoder`` that holds parameters + (same rule as iterating ``encoder.modules()`` and selecting leaves with parameters). Pass + ``encoder``; ``shared_params`` must be omitted. + * ``"all_matrix"``: one block per tensor in ``shared_params``, in iteration order. That order + must match how Jacobian columns are laid out for those shared parameters. Pass + ``shared_params``; ``encoder`` must be omitted. + + :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign the + :attr:`beta` attribute between steps to tune the EMA update. :param group_type: Granularity of parameter grouping; see **Parameter granularity** above. - :param encoder: Module whose subtree defines ``all_layer`` blocks when ``group_type == 1``. + :param encoder: Module whose subtree defines ``all_layer`` blocks when + ``group_type == "all_layer"``. :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and - order when ``group_type == 2``. It is materialized once at construction. - :param eps: Small positive constant added to denominators when computing cosines and the - vaccine weight (default :data:`~torchjd.aggregation.DEFAULT_GRADVAC_EPS`). You may read or + order when ``group_type == "all_matrix"``. 
It is materialized once at construction. + :param eps: Small non-negative constant added to denominators when computing cosines and the + vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or assign the :attr:`eps` attribute between steps to tune numerical behavior. .. note:: @@ -83,105 +83,123 @@ class GradVac(Aggregator): products, not only a Gram matrix. Only the autojac path is supported. .. note:: - Task-order shuffling uses the global PyTorch RNG (``torch.randperm``). Seed it with + For each task :math:`i` and block :math:`k`, the order of other tasks :math:`j` is shuffled + independently using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if you need reproducibility. """ def __init__( self, beta: float = 0.5, - group_type: int = 0, + group_type: Literal["whole_model", "all_layer", "all_matrix"] = "whole_model", encoder: nn.Module | None = None, shared_params: Iterable[Tensor] | None = None, - eps: float = DEFAULT_GRADVAC_EPS, + eps: float = 1e-8, ) -> None: super().__init__() if not (0.0 <= beta <= 1.0): raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") - if group_type not in (0, 1, 2): - raise ValueError( - "Parameter `group_type` must be 0 (whole_model), 1 (all_layer), or 2 (all_matrix). " - f"Found group_type={group_type!r}.", - ) params_tuple: tuple[Tensor, ...] = () fixed_block_sizes: tuple[int, ...] | None - if group_type == 0: + if group_type == "whole_model": if encoder is not None: - raise ValueError("Parameter `encoder` must be None when `group_type == 0`.") + raise ValueError( + 'Parameter `encoder` must be None when `group_type == "whole_model"`.' + ) if shared_params is not None: - raise ValueError("Parameter `shared_params` must be None when `group_type == 0`.") + raise ValueError( + 'Parameter `shared_params` must be None when `group_type == "whole_model"`.' 
+ ) fixed_block_sizes = None - elif group_type == 1: + elif group_type == "all_layer": if encoder is None: - raise ValueError("Parameter `encoder` is required when `group_type == 1`.") + raise ValueError( + 'Parameter `encoder` is required when `group_type == "all_layer"`.' + ) if shared_params is not None: - raise ValueError("Parameter `shared_params` must be None when `group_type == 1`.") - fixed_block_sizes = _gradvac_all_layer_group_sizes(encoder) + raise ValueError( + 'Parameter `shared_params` must be None when `group_type == "all_layer"`.' + ) + fixed_block_sizes = _all_layer_group_sizes(encoder) if sum(fixed_block_sizes) == 0: raise ValueError("Parameter `encoder` has no parameters in any leaf module.") else: if shared_params is None: - raise ValueError("Parameter `shared_params` is required when `group_type == 2`.") + raise ValueError( + 'Parameter `shared_params` is required when `group_type == "all_matrix"`.' + ) if encoder is not None: - raise ValueError("Parameter `encoder` must be None when `group_type == 2`.") + raise ValueError( + 'Parameter `encoder` must be None when `group_type == "all_matrix"`.' + ) params_tuple = tuple(shared_params) if len(params_tuple) == 0: raise ValueError( - "Parameter `shared_params` must be non-empty when `group_type == 2`." + 'Parameter `shared_params` must be non-empty when `group_type == "all_matrix"`.' ) - fixed_block_sizes = _gradvac_all_matrix_group_sizes(params_tuple) + fixed_block_sizes = _all_matrix_group_sizes(params_tuple) - if eps <= 0.0: - raise ValueError(f"Parameter `eps` must be positive. Found eps={eps!r}.") + if eps < 0.0: + raise ValueError(f"Parameter `eps` must be non-negative. 
Found eps={eps!r}.") self._beta = beta self._group_type = group_type self._encoder = encoder self._shared_params_len = len(params_tuple) self._fixed_block_sizes = fixed_block_sizes - self._eps = float(eps) + self._eps = eps - self._rho_t: Tensor | None = None + self._phi_t: Tensor | None = None self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None self.register_full_backward_pre_hook(raise_non_differentiable_error) + @property + def beta(self) -> float: + """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``).""" + + return self._beta + + @beta.setter + def beta(self, value: float) -> None: + if not (0.0 <= value <= 1.0): + raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.") + self._beta = value + @property def eps(self) -> float: - """Small positive constant added to denominators for numerical stability.""" + """Small non-negative constant added to denominators for numerical stability.""" return self._eps @eps.setter def eps(self, value: float) -> None: - v = float(value) - if v <= 0.0: - raise ValueError(f"Attribute `eps` must be positive. Found eps={value!r}.") - self._eps = v + if value < 0.0: + raise ValueError(f"Attribute `eps` must be non-negative. 
Found eps={value!r}.") + self._eps = value def reset(self) -> None: """Clears EMA state so the next forward starts from zero targets.""" - self._rho_t = None + self._phi_t = None self._state_key = None def __repr__(self) -> str: enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)" - sp = "None" if self._group_type != 2 else f"n_params={self._shared_params_len}" + sp = "None" if self._group_type != "all_matrix" else f"n_params={self._shared_params_len}" return ( f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, " f"encoder={enc}, shared_params={sp}, eps={self._eps!r})" ) def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]: - if self._group_type == 0: + if self._group_type == "whole_model": return (n,) - assert self._fixed_block_sizes is not None - sizes = self._fixed_block_sizes + sizes = cast(tuple[int, ...], self._fixed_block_sizes) if sum(sizes) != n: raise ValueError( - "The Jacobian width `D` must equal the sum of block sizes implied by " - f"`encoder` or `shared_params` for this `group_type`. Found D={n}, " + "The Jacobian width `n` must equal the sum of block sizes implied by " + f"`encoder` or `shared_params` for this `group_type`. 
Found n={n}, " f"sum(block_sizes)={sum(sizes)}.", ) return sizes @@ -196,8 +214,8 @@ def _ensure_state( ) -> None: key = (m, n, sizes, device, dtype) num_groups = len(sizes) - if self._state_key != key or self._rho_t is None: - self._rho_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype) + if self._state_key != key or self._phi_t is None: + self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype) self._state_key = key def forward(self, matrix: Matrix, /) -> Tensor: @@ -210,10 +228,8 @@ def forward(self, matrix: Matrix, /) -> Tensor: device = grads.device dtype = grads.dtype self._ensure_state(m, n, sizes, device, dtype) - assert self._rho_t is not None - - rho_t = self._rho_t - beta = self._beta + phi_t = cast(Tensor, self._phi_t) + beta = self.beta eps = self.eps pc_grads = grads.clone() @@ -223,29 +239,27 @@ def forward(self, matrix: Matrix, /) -> Tensor: for i in range(m): others = [j for j in range(m) if j != i] - perm = torch.randperm(len(others)) - order = perm.tolist() - shuffled_js = [others[idx] for idx in order] - - for j in shuffled_js: - for k in range(len(sizes)): - beg, end = offsets[k], offsets[k + 1] + for k in range(len(sizes)): + perm = torch.randperm(len(others)) + shuffled_js = [others[idx] for idx in perm.tolist()] + beg, end = offsets[k], offsets[k + 1] + for j in shuffled_js: slice_i = pc_grads[i, beg:end] slice_j = grads[j, beg:end] norm_i = slice_i.norm() norm_j = slice_j.norm() denom = norm_i * norm_j + eps - rho_ijk = slice_i.dot(slice_j) / denom - - bar = rho_t[i, j, k] - if rho_ijk < bar: - sqrt_1_rho2 = (1.0 - rho_ijk * rho_ijk).clamp(min=0.0).sqrt() - sqrt_1_bar2 = (1.0 - bar * bar).clamp(min=0.0).sqrt() - denom_w = norm_j * sqrt_1_bar2 + eps - w = norm_i * (bar * sqrt_1_rho2 - rho_ijk * sqrt_1_bar2) / denom_w + phi_ijk = slice_i.dot(slice_j) / denom + + phi_hat = phi_t[i, j, k] + if phi_ijk < phi_hat: + sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt() + sqrt_1_hat2 = (1.0 - phi_hat * 
phi_hat).clamp(min=0.0).sqrt() + denom_w = norm_j * sqrt_1_hat2 + eps + w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w pc_grads[i, beg:end] = slice_i + slice_j * w - rho_t[i, j, k] = (1.0 - beta) * bar + beta * rho_ijk + phi_t[i, j, k] = (1.0 - beta) * phi_hat + beta * phi_ijk return pc_grads.sum(dim=0) diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index dc69bfda..40118fea 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + import numpy as np import torch from plotly import graph_objects as go @@ -7,14 +9,22 @@ class Plotter: - def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: - self.aggregators = aggregators + def __init__( + self, + aggregator_factories: dict[str, Callable[[], Aggregator]], + selected_keys: list[str], + matrix: torch.Tensor, + seed: int = 0, + ) -> None: + self._aggregator_factories = aggregator_factories + self.selected_keys = selected_keys self.matrix = matrix self.seed = seed def make_fig(self) -> Figure: torch.random.manual_seed(self.seed) - results = [agg(self.matrix) for agg in self.aggregators] + aggregators = [self._aggregator_factories[key]() for key in self.selected_keys] + results = [agg(self.matrix) for agg in aggregators] fig = go.Figure() @@ -23,14 +33,19 @@ def make_fig(self) -> Figure: fig.add_trace(cone) for i in range(len(self.matrix)): - scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}") + scatter = make_vector_scatter( + self.matrix[i], + "blue", + f"g{i + 1}", + textposition="top right", + ) fig.add_trace(scatter) for i in range(len(results)): scatter = make_vector_scatter( results[i], "black", - str(self.aggregators[i]), + self.selected_keys[i], showlegend=True, dash=True, ) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 2a945f93..4d43f397 100644 --- a/tests/plots/interactive_plotter.py +++ 
b/tests/plots/interactive_plotter.py @@ -1,6 +1,7 @@ import logging import os import webbrowser +from collections.abc import Callable from threading import Timer import numpy as np @@ -12,11 +13,13 @@ from torchjd.aggregation import ( IMTLG, MGDA, + Aggregator, AlignedMTL, CAGrad, ConFIG, DualProj, GradDrop, + GradVac, Mean, NashMTL, PCGrad, @@ -30,6 +33,14 @@ MAX_LENGTH = 25.0 +def _format_angle_display(angle: float) -> str: + return f"{angle:.4f} rad ({np.degrees(angle):.1f}°)" + + +def _format_length_display(r: float) -> str: + return f"{r:.4f}" + + def main() -> None: log = logging.getLogger("werkzeug") log.setLevel(logging.CRITICAL) @@ -42,26 +53,30 @@ def main() -> None: ], ) - aggregators = [ - AlignedMTL(), - CAGrad(c=0.5), - ConFIG(), - DualProj(), - GradDrop(), - IMTLG(), - Mean(), - MGDA(), - NashMTL(n_tasks=matrix.shape[0]), - PCGrad(), - Random(), - Sum(), - TrimmedMean(trim_number=1), - UPGrad(), - ] - - aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators} - - plotter = Plotter([], matrix) + n_tasks = matrix.shape[0] + aggregator_factories: dict[str, Callable[[], Aggregator]] = { + "AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"), + "AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"), + "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), + str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), + str(ConFIG()): lambda: ConFIG(), + str(DualProj()): lambda: DualProj(), + str(GradDrop()): lambda: GradDrop(), + str(GradVac()): lambda: GradVac(), + str(IMTLG()): lambda: IMTLG(), + str(Mean()): lambda: Mean(), + str(MGDA()): lambda: MGDA(), + str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks), + str(PCGrad()): lambda: PCGrad(), + str(Random()): lambda: Random(), + str(Sum()): lambda: Sum(), + str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), + str(UPGrad()): lambda: UPGrad(), + } + + aggregator_strings = list(aggregator_factories.keys()) + + plotter = Plotter(aggregator_factories, [], matrix) 
app = Dash(__name__) @@ -96,7 +111,6 @@ def main() -> None: gradient_slider_inputs.append(Input(angle_input, "value")) gradient_slider_inputs.append(Input(r_input, "value")) - aggregator_strings = [str(aggregator) for aggregator in aggregators] checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist") control_div = html.Div( @@ -115,22 +129,32 @@ def update_seed(value: int) -> Figure: plotter.seed = value return plotter.make_fig() + n_gradients = len(matrix) + gradient_value_outputs: list[Output] = [] + for i in range(n_gradients): + gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children")) + gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children")) + @callback( Output("aggregations-fig", "figure", allow_duplicate=True), + *gradient_value_outputs, *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values: str) -> Figure: + def update_gradient_coordinate(*values: str) -> tuple[Figure, ...]: values_ = [float(value) for value in values] + display_parts: list[str] = [] for j in range(len(values_) // 2): angle = values_[2 * j] r = values_[2 * j + 1] x, y = angle_to_coord(angle, r) plotter.matrix[j, 0] = x plotter.matrix[j, 1] = y + display_parts.append(_format_angle_display(angle)) + display_parts.append(_format_length_display(r)) - return plotter.make_fig() + return (plotter.make_fig(), *display_parts) @callback( Output("aggregations-fig", "figure", allow_duplicate=True), @@ -138,9 +162,7 @@ def update_gradient_coordinate(*values: str) -> Figure: prevent_initial_call=True, ) def update_aggregators(value: list[str]) -> Figure: - aggregator_keys = value - new_aggregators = [aggregators_dict[key] for key in aggregator_keys] - plotter.aggregators = new_aggregators + plotter.selected_keys = list(value) return plotter.make_fig() Timer(1, open_browser).start() @@ -173,11 +195,56 @@ def make_gradient_div( style={"width": "250px"}, ) + label_style: dict[str, str | int] = { + "display": 
"inline-block", + "width": "52px", + "margin-right": "8px", + "vertical-align": "middle", + } + value_style: dict[str, str] = { + "display": "inline-block", + "margin-left": "10px", + "min-width": "140px", + "font-family": "monospace", + "font-size": "13px", + "vertical-align": "middle", + } + row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"} div = html.Div( [ - html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}), - angle_input, - r_input, + dcc.Markdown( + f"$g_{{{i + 1}}}$", + mathjax=True, + style={ + "margin": "0 0 6px 0", + "font-weight": "bold", + "display": "block", + }, + ), + html.Div( + [ + html.Span("Angle", style=label_style), + angle_input, + html.Span( + id=f"g{i + 1}-angle-value", + children=_format_angle_display(angle), + style=value_style, + ), + ], + style=row_style, + ), + html.Div( + [ + html.Span("Length", style=label_style), + r_input, + html.Span( + id=f"g{i + 1}-length-value", + children=_format_length_display(r), + style=value_style, + ), + ], + style={**row_style, "margin-bottom": "12px"}, + ), ], ) return div, angle_input, r_input diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py index 49a8770a..7432aeac 100644 --- a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -1,27 +1,26 @@ import torch import torch.nn as nn from pytest import mark, raises -from torch import Tensor, tensor +from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from utils.tensors import ones_, randn_, tensor_ -from torchjd.aggregation import DEFAULT_GRADVAC_EPS, GradVac +from torchjd.aggregation import GradVac from ._asserts import assert_expected_structure, assert_non_differentiable -from ._inputs import scaled_matrices, typical_matrices +from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows scaled_pairs = [(GradVac(), m) for m in scaled_matrices] 
typical_pairs = [(GradVac(), m) for m in typical_matrices] requires_grad_pairs = [(GradVac(), ones_(3, 5, requires_grad=True))] -def test_repr() -> None: +def test_representations() -> None: g = GradVac() - assert repr(g).startswith("GradVac(") - assert "beta=" in repr(g) - assert "group_type=" in repr(g) - assert "encoder=" in repr(g) - assert "eps=" in repr(g) + assert repr(g) == ( + "GradVac(beta=0.5, group_type='whole_model', encoder=None, shared_params=None, eps=1e-08)" + ) + assert str(g) == "GradVac" def test_beta_out_of_range() -> None: @@ -31,26 +30,27 @@ def test_beta_out_of_range() -> None: GradVac(beta=1.1) -def test_eps_non_positive() -> None: - with raises(ValueError, match="eps"): - GradVac(eps=0.0) +def test_beta_setter_out_of_range() -> None: + g = GradVac() + with raises(ValueError, match="beta"): + g.beta = -0.1 + with raises(ValueError, match="beta"): + g.beta = 1.1 + + +def test_eps_rejects_negative() -> None: with raises(ValueError, match="eps"): GradVac(eps=-1e-9) -def test_eps_setter_rejects_non_positive() -> None: +def test_eps_setter_rejects_negative() -> None: g = GradVac() with raises(ValueError, match="eps"): - g.eps = 0.0 - - -def test_default_eps_constant() -> None: - assert DEFAULT_GRADVAC_EPS == 1e-8 - assert GradVac().eps == DEFAULT_GRADVAC_EPS + g.eps = -1e-9 def test_eps_can_be_changed_between_steps() -> None: - j = tensor([[1.0, 0.0], [0.0, 1.0]]) + j = tensor_([[1.0, 0.0], [0.0, 1.0]]) agg = GradVac() agg.eps = 1e-6 assert agg(j).isfinite().all() @@ -59,44 +59,39 @@ def test_eps_can_be_changed_between_steps() -> None: assert agg(j).isfinite().all() -def test_group_type_invalid() -> None: - with raises(ValueError, match="group_type"): - GradVac(group_type=3) - - def test_group_type_0_rejects_encoder() -> None: net = nn.Linear(1, 1) with raises(ValueError, match="encoder"): - GradVac(encoder=net) + GradVac(group_type="whole_model", encoder=net) def test_group_type_0_rejects_shared_params() -> None: - p = 
nn.Parameter(tensor([1.0])) + p = nn.Parameter(tensor_([1.0])) with raises(ValueError, match="shared_params"): - GradVac(shared_params=[p]) + GradVac(group_type="whole_model", shared_params=[p]) def test_group_type_1_requires_encoder() -> None: with raises(ValueError, match="encoder"): - GradVac(group_type=1) + GradVac(group_type="all_layer") def test_group_type_1_rejects_shared_params() -> None: net = nn.Linear(1, 1) - p = nn.Parameter(tensor([1.0])) + p = nn.Parameter(tensor_([1.0])) with raises(ValueError, match="shared_params"): - GradVac(group_type=1, encoder=net, shared_params=[p]) + GradVac(group_type="all_layer", encoder=net, shared_params=[p]) def test_group_type_2_requires_shared_params() -> None: with raises(ValueError, match="shared_params"): - GradVac(group_type=2) + GradVac(group_type="all_matrix") def test_group_type_2_rejects_encoder() -> None: net = nn.Linear(1, 1) with raises(ValueError, match="encoder"): - GradVac(group_type=2, encoder=net, shared_params=list(net.parameters())) + GradVac(group_type="all_matrix", encoder=net, shared_params=list(net.parameters())) def test_encoder_without_leaf_parameters() -> None: @@ -104,12 +99,12 @@ class Empty(nn.Module): pass with raises(ValueError, match="encoder"): - GradVac(group_type=1, encoder=Empty()) + GradVac(group_type="all_layer", encoder=Empty()) def test_shared_params_empty() -> None: with raises(ValueError, match="shared_params"): - GradVac(group_type=2, shared_params=()) + GradVac(group_type="all_matrix", shared_params=()) def test_group_type_1_forward() -> None: @@ -117,7 +112,7 @@ def test_group_type_1_forward() -> None: d = sum(p.numel() for p in net.parameters()) j = randn_((2, d)) torch.manual_seed(0) - out = GradVac(group_type=1, encoder=net)(j) + out = GradVac(group_type="all_layer", encoder=net)(j) assert out.shape == (d,) assert out.isfinite().all() @@ -128,7 +123,7 @@ def test_group_type_2_forward() -> None: d = sum(p.numel() for p in params) j = randn_((2, d)) torch.manual_seed(0) - 
out = GradVac(group_type=2, shared_params=params)(j) + out = GradVac(group_type="all_matrix", shared_params=params)(j) assert out.shape == (d,) assert out.isfinite().all() @@ -136,20 +131,20 @@ def test_group_type_2_forward() -> None: def test_jacobian_width_mismatch() -> None: net = nn.Linear(2, 2) d = sum(p.numel() for p in net.parameters()) - agg = GradVac(group_type=1, encoder=net) + agg = GradVac(group_type="all_layer", encoder=net) with raises(ValueError, match="Jacobian width"): - agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)])) + agg(tensor_([[1.0] * (d - 1), [2.0] * (d - 1)])) def test_zero_rows_returns_zero_vector() -> None: - out = GradVac()(tensor([]).reshape(0, 3)) - assert_close(out, tensor([0.0, 0.0, 0.0])) + out = GradVac()(tensor_([]).reshape(0, 3)) + assert_close(out, tensor_([0.0, 0.0, 0.0])) def test_zero_columns_returns_zero_vector() -> None: """Handled inside forward before grouping validation.""" - out = GradVac()(tensor([]).reshape(2, 0)) + out = GradVac()(tensor_([]).reshape(2, 0)) assert out.shape == (0,) @@ -164,15 +159,15 @@ def test_reproducible_with_manual_seed() -> None: assert_close(out1, out2) -def test_reset_restores_first_step_behavior() -> None: - j = tensor([[1.0, 0.0], [0.0, 1.0]]) +@mark.parametrize("matrix", typical_matrices_2_plus_rows) +def test_reset_restores_first_step_behavior(matrix: Tensor) -> None: torch.manual_seed(7) agg = GradVac(beta=0.5) - first = agg(j) - agg(j) + first = agg(matrix) + agg(matrix) agg.reset() torch.manual_seed(7) - assert_close(first, agg(j)) + assert_close(first, agg(matrix)) @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 42faca91..d4e8ba77 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -1,3 +1,4 @@ +import torch from pytest import mark, param from torch import Tensor, tensor from torch.testing import assert_close @@ -14,6 
+15,7 @@ DualProj, DualProjWeighting, GradDrop, + GradVac, IMTLGWeighting, Krum, KrumWeighting, @@ -57,6 +59,7 @@ (Constant(tensor([1.0, 2.0])), J_base, tensor([8.0, 3.0, 3.0])), (DualProj(), J_base, tensor([0.5563, 1.1109, 1.1109])), (GradDrop(), J_base, tensor([6.0, 2.0, 2.0])), + (GradVac(), J_base, tensor([0.5848, 3.8012, 3.8012])), (IMTLG(), J_base, tensor([0.0767, 1.0000, 1.0000])), (Krum(n_byzantine=1, n_selected=4), J_Krum, tensor([1.2500, 0.7500, 1.5000])), (Mean(), J_base, tensor([1.0, 1.0, 1.0])), @@ -113,6 +116,8 @@ def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor) -> None: """Test that the output values of an aggregator are fixed (on cpu).""" + if isinstance(A, GradVac): + torch.manual_seed(0) assert_close(A(J), expected_output, rtol=0, atol=1e-4) From e53849e769643ba9585a1ba8f69878f69ff9ae8c Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Sun, 12 Apr 2026 12:44:34 -0400 Subject: [PATCH 05/16] refactor(gradvac): base on GramianWeightedAggregator with GradVacWeighting GradVac only needs gradient norms and dot products, which are fully determined by the Gramian. This makes GradVac compatible with the autogram path. - Remove grouping parameters (group_type, encoder, shared_params) from GradVac - Export GradVacWeighting publicly --- docs/source/docs/aggregation/gradvac.rst | 5 + src/torchjd/aggregation/__init__.py | 3 +- src/torchjd/aggregation/_gradvac.py | 324 +++++++++-------------- tests/unit/aggregation/test_gradvac.py | 128 +++------ tests/unit/aggregation/test_values.py | 2 + 5 files changed, 185 insertions(+), 277 deletions(-) diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst index 2116fca3..18e154bc 100644 --- a/docs/source/docs/aggregation/gradvac.rst +++ b/docs/source/docs/aggregation/gradvac.rst @@ -7,3 +7,8 @@ GradVac :members: :undoc-members: :exclude-members: forward + +.. 
autoclass:: torchjd.aggregation.GradVacWeighting + :members: + :undoc-members: + :exclude-members: forward diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 23f9d48c..93f824e3 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -66,7 +66,7 @@ from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop -from ._gradvac import GradVac +from ._gradvac import GradVac, GradVacWeighting from ._imtl_g import IMTLG, IMTLGWeighting from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting @@ -94,6 +94,7 @@ "GeneralizedWeighting", "GradDrop", "GradVac", + "GradVacWeighting", "IMTLG", "IMTLGWeighting", "Krum", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index d9fad734..dc43a34b 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -1,158 +1,124 @@ from __future__ import annotations -from collections.abc import Iterable -from typing import Literal, cast +from typing import cast import torch -import torch.nn as nn from torch import Tensor -from torchjd._linalg import Matrix +from torchjd._linalg import PSDMatrix -from ._aggregator_bases import Aggregator +from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error +from ._weighting_bases import Weighting -def _all_layer_group_sizes(encoder: nn.Module) -> tuple[int, ...]: - """ - Block sizes per leaf submodule with parameters, matching the ``all_layer`` grouping: iterate - ``encoder.modules()`` and append the total number of elements in each module that has no child - submodules and registers at least one parameter. 
- """ - - return tuple( - sum(w.numel() for w in module.parameters()) - for module in encoder.modules() - if len(list(module.children())) == 0 and next(module.parameters(), None) is not None - ) - - -def _all_matrix_group_sizes(shared_params: Iterable[Tensor]) -> tuple[int, ...]: - """One block per tensor in ``shared_params`` order (``all_matrix`` / shared-parameter layout).""" - - return tuple(p.numel() for p in shared_params) - - -class GradVac(Aggregator): +class GradVac(GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing Gradient Vaccine - (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in - Massively Multilingual Models (ICLR 2021 Spotlight) + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of + Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task + Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_. - The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task - gradients. For each task :math:`i` and each parameter block :math:`k`, the order in which other - tasks :math:`j` are visited is drawn at random (independently for each :math:`k`); for each pair - :math:`(i, j)` on block :math:`k`, the cosine correlation :math:`\phi_{ijk}` between the + For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at + random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the (possibly already modified) gradient of task :math:`i` and the original gradient of task - :math:`j` on that block is compared to an EMA target :math:`\hat{\phi}_{ijk}`. When - :math:`\phi_{ijk} < \hat{\phi}_{ijk}`, a closed-form correction adds a scaled copy of - :math:`g_j` to the block of :math:`g_i^{(\mathrm{PC})}`. 
The EMA is then updated with - :math:`\hat{\phi}_{ijk} \leftarrow (1-\beta)\hat{\phi}_{ijk} + \beta \phi_{ijk}`. The aggregated + :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When + :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of + :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with + :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated vector is the sum of the modified rows. This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when - the number of tasks, parameter dimension, grouping, device, or dtype changes. - - **Parameter granularity** is selected by ``group_type`` (default ``"whole_model"``). It defines - how each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets - :math:`\hat{\phi}_{ijk}` are computed **per block** rather than only globally: - - * ``"whole_model"``: the full row of length :math:`n` is a single block. Cosine similarity is - taken between entire task gradients. Do not pass ``encoder`` or ``shared_params``. - * ``"all_layer"``: one block per leaf ``nn.Module`` under ``encoder`` that holds parameters - (same rule as iterating ``encoder.modules()`` and selecting leaves with parameters). Pass - ``encoder``; ``shared_params`` must be omitted. - * ``"all_matrix"``: one block per tensor in ``shared_params``, in iteration order. That order - must match how Jacobian columns are laid out for those shared parameters. Pass - ``shared_params``; ``encoder`` must be omitted. - - :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign the - :attr:`beta` attribute between steps to tune the EMA update. - :param group_type: Granularity of parameter grouping; see **Parameter granularity** above. - :param encoder: Module whose subtree defines ``all_layer`` blocks when - ``group_type == "all_layer"``. 
- :param shared_params: Iterable of parameter tensors defining ``all_matrix`` block sizes and - order when ``group_type == "all_matrix"``. It is materialized once at construction. + the number of tasks or dtype changes. + + :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign + the :attr:`beta` attribute between steps to tune the EMA update. :param eps: Small non-negative constant added to denominators when computing cosines and the vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or assign the :attr:`eps` attribute between steps to tune numerical behavior. .. note:: - GradVac is not compatible with autogram: it needs full Jacobian rows and per-block inner - products, not only a Gram matrix. Only the autojac path is supported. + For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently + using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if + you need reproducibility. .. note:: - For each task :math:`i` and block :math:`k`, the order of other tasks :math:`j` is shuffled - independently using the global PyTorch RNG (``torch.randperm``). Seed it with - ``torch.manual_seed`` if you need reproducibility. + To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the + Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage + example for details. 
+ """ + + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps) + super().__init__(weighting) + self._gradvac_weighting = weighting + self.register_full_backward_pre_hook(raise_non_differentiable_error) + + @property + def beta(self) -> float: + """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``).""" + + return self._gradvac_weighting.beta + + @beta.setter + def beta(self, value: float) -> None: + self._gradvac_weighting.beta = value + + @property + def eps(self) -> float: + """Small non-negative constant added to denominators for numerical stability.""" + + return self._gradvac_weighting.eps + + @eps.setter + def eps(self, value: float) -> None: + self._gradvac_weighting.eps = value + + def reset(self) -> None: + """Clears EMA state so the next forward starts from zero targets.""" + + self._gradvac_weighting.reset() + + def __repr__(self) -> str: + return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" + + +class GradVacWeighting(Weighting[PSDMatrix]): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GradVac`. + + All required quantities (gradient norms, cosine similarities, and their updates after the + vaccine correction) are derived purely from the Gramian, without needing the full Jacobian. + If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then: + + .. math:: + + \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad + g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j} + + where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w + g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow + immediately. + + This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when + the number of tasks or dtype changes. + + :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). 
+ :param eps: Small non-negative constant added to denominators (default ``1e-8``). """ - def __init__( - self, - beta: float = 0.5, - group_type: Literal["whole_model", "all_layer", "all_matrix"] = "whole_model", - encoder: nn.Module | None = None, - shared_params: Iterable[Tensor] | None = None, - eps: float = 1e-8, - ) -> None: + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: super().__init__() if not (0.0 <= beta <= 1.0): raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") - params_tuple: tuple[Tensor, ...] = () - fixed_block_sizes: tuple[int, ...] | None - if group_type == "whole_model": - if encoder is not None: - raise ValueError( - 'Parameter `encoder` must be None when `group_type == "whole_model"`.' - ) - if shared_params is not None: - raise ValueError( - 'Parameter `shared_params` must be None when `group_type == "whole_model"`.' - ) - fixed_block_sizes = None - elif group_type == "all_layer": - if encoder is None: - raise ValueError( - 'Parameter `encoder` is required when `group_type == "all_layer"`.' - ) - if shared_params is not None: - raise ValueError( - 'Parameter `shared_params` must be None when `group_type == "all_layer"`.' - ) - fixed_block_sizes = _all_layer_group_sizes(encoder) - if sum(fixed_block_sizes) == 0: - raise ValueError("Parameter `encoder` has no parameters in any leaf module.") - else: - if shared_params is None: - raise ValueError( - 'Parameter `shared_params` is required when `group_type == "all_matrix"`.' - ) - if encoder is not None: - raise ValueError( - 'Parameter `encoder` must be None when `group_type == "all_matrix"`.' - ) - params_tuple = tuple(shared_params) - if len(params_tuple) == 0: - raise ValueError( - 'Parameter `shared_params` must be non-empty when `group_type == "all_matrix"`.' - ) - fixed_block_sizes = _all_matrix_group_sizes(params_tuple) - if eps < 0.0: raise ValueError(f"Parameter `eps` must be non-negative. 
Found eps={eps!r}.") self._beta = beta - self._group_type = group_type - self._encoder = encoder - self._shared_params_len = len(params_tuple) - self._fixed_block_sizes = fixed_block_sizes self._eps = eps - self._phi_t: Tensor | None = None - self._state_key: tuple[int, int, tuple[int, ...], torch.device, torch.dtype] | None = None - - self.register_full_backward_pre_hook(raise_non_differentiable_error) + self._state_key: tuple[int, torch.dtype] | None = None @property def beta(self) -> float: @@ -184,82 +150,56 @@ def reset(self) -> None: self._phi_t = None self._state_key = None - def __repr__(self) -> str: - enc = "None" if self._encoder is None else f"{self._encoder.__class__.__name__}(...)" - sp = "None" if self._group_type != "all_matrix" else f"n_params={self._shared_params_len}" - return ( - f"{self.__class__.__name__}(beta={self._beta!r}, group_type={self._group_type!r}, " - f"encoder={enc}, shared_params={sp}, eps={self._eps!r})" - ) - - def _resolve_segment_sizes(self, n: int) -> tuple[int, ...]: - if self._group_type == "whole_model": - return (n,) - sizes = cast(tuple[int, ...], self._fixed_block_sizes) - if sum(sizes) != n: - raise ValueError( - "The Jacobian width `n` must equal the sum of block sizes implied by " - f"`encoder` or `shared_params` for this `group_type`. 
Found n={n}, " - f"sum(block_sizes)={sum(sizes)}.", - ) - return sizes - - def _ensure_state( - self, - m: int, - n: int, - sizes: tuple[int, ...], - device: torch.device, - dtype: torch.dtype, - ) -> None: - key = (m, n, sizes, device, dtype) - num_groups = len(sizes) - if self._state_key != key or self._phi_t is None: - self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype) - self._state_key = key + def forward(self, gramian: PSDMatrix, /) -> Tensor: + device = gramian.device + dtype = gramian.dtype + cpu = torch.device("cpu") - def forward(self, matrix: Matrix, /) -> Tensor: - grads = matrix - m, n = grads.shape - if m == 0 or n == 0: - return torch.zeros(n, dtype=grads.dtype, device=grads.device) + G = cast(PSDMatrix, gramian.to(device=cpu)) + m = G.shape[0] - sizes = self._resolve_segment_sizes(n) - device = grads.device - dtype = grads.dtype - self._ensure_state(m, n, sizes, device, dtype) + self._ensure_state(m, dtype) phi_t = cast(Tensor, self._phi_t) - beta = self.beta - eps = self.eps - pc_grads = grads.clone() - offsets = [0] - for s in sizes: - offsets.append(offsets[-1] + s) + beta = self._beta + eps = self._eps + + # C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients). + # Initially each modified gradient equals the original, so C = I. + C = torch.eye(m, device=cpu, dtype=dtype) for i in range(m): + # Dot products of g_i^PC with every original g_j, shape (m,). 
+ cG = C[i] @ G + others = [j for j in range(m) if j != i] - for k in range(len(sizes)): - perm = torch.randperm(len(others)) - shuffled_js = [others[idx] for idx in perm.tolist()] - beg, end = offsets[k], offsets[k + 1] - for j in shuffled_js: - slice_i = pc_grads[i, beg:end] - slice_j = grads[j, beg:end] - - norm_i = slice_i.norm() - norm_j = slice_j.norm() - denom = norm_i * norm_j + eps - phi_ijk = slice_i.dot(slice_j) / denom - - phi_hat = phi_t[i, j, k] - if phi_ijk < phi_hat: - sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt() - sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt() - denom_w = norm_j * sqrt_1_hat2 + eps - w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w - pc_grads[i, beg:end] = slice_i + slice_j * w - - phi_t[i, j, k] = (1.0 - beta) * phi_hat + beta * phi_ijk - - return pc_grads.sum(dim=0) + perm = torch.randperm(len(others)) + shuffled_js = [others[idx] for idx in perm.tolist()] + + for j in shuffled_js: + dot_ij = cG[j] + norm_i_sq = (cG * C[i]).sum() + norm_i = norm_i_sq.clamp(min=0.0).sqrt() + norm_j = G[j, j].clamp(min=0.0).sqrt() + denom = norm_i * norm_j + eps + phi_ijk = dot_ij / denom + + phi_hat = phi_t[i, j] + if phi_ijk < phi_hat: + sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt() + sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt() + denom_w = norm_j * sqrt_1_hat2 + eps + w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w + C[i, j] = C[i, j] + w + cG = cG + w * G[j] + + phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk + + weights = C.sum(dim=0) + return weights.to(device) + + def _ensure_state(self, m: int, dtype: torch.dtype) -> None: + key = (m, dtype) + if self._state_key != key or self._phi_t is None: + self._phi_t = torch.zeros(m, m, dtype=dtype) + self._state_key = key diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py index 7432aeac..f1105800 100644 --- 
a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -1,11 +1,10 @@ import torch -import torch.nn as nn from pytest import mark, raises from torch import Tensor from torch.testing import assert_close from utils.tensors import ones_, randn_, tensor_ -from torchjd.aggregation import GradVac +from torchjd.aggregation import GradVac, GradVacWeighting from ._asserts import assert_expected_structure, assert_non_differentiable from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows @@ -17,9 +16,7 @@ def test_representations() -> None: g = GradVac() - assert repr(g) == ( - "GradVac(beta=0.5, group_type='whole_model', encoder=None, shared_params=None, eps=1e-08)" - ) + assert repr(g) == "GradVac(beta=0.5, eps=1e-08)" assert str(g) == "GradVac" @@ -59,91 +56,12 @@ def test_eps_can_be_changed_between_steps() -> None: assert agg(j).isfinite().all() -def test_group_type_0_rejects_encoder() -> None: - net = nn.Linear(1, 1) - with raises(ValueError, match="encoder"): - GradVac(group_type="whole_model", encoder=net) - - -def test_group_type_0_rejects_shared_params() -> None: - p = nn.Parameter(tensor_([1.0])) - with raises(ValueError, match="shared_params"): - GradVac(group_type="whole_model", shared_params=[p]) - - -def test_group_type_1_requires_encoder() -> None: - with raises(ValueError, match="encoder"): - GradVac(group_type="all_layer") - - -def test_group_type_1_rejects_shared_params() -> None: - net = nn.Linear(1, 1) - p = nn.Parameter(tensor_([1.0])) - with raises(ValueError, match="shared_params"): - GradVac(group_type="all_layer", encoder=net, shared_params=[p]) - - -def test_group_type_2_requires_shared_params() -> None: - with raises(ValueError, match="shared_params"): - GradVac(group_type="all_matrix") - - -def test_group_type_2_rejects_encoder() -> None: - net = nn.Linear(1, 1) - with raises(ValueError, match="encoder"): - GradVac(group_type="all_matrix", encoder=net, 
shared_params=list(net.parameters())) - - -def test_encoder_without_leaf_parameters() -> None: - class Empty(nn.Module): - pass - - with raises(ValueError, match="encoder"): - GradVac(group_type="all_layer", encoder=Empty()) - - -def test_shared_params_empty() -> None: - with raises(ValueError, match="shared_params"): - GradVac(group_type="all_matrix", shared_params=()) - - -def test_group_type_1_forward() -> None: - net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)) - d = sum(p.numel() for p in net.parameters()) - j = randn_((2, d)) - torch.manual_seed(0) - out = GradVac(group_type="all_layer", encoder=net)(j) - assert out.shape == (d,) - assert out.isfinite().all() - - -def test_group_type_2_forward() -> None: - net = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)) - params = list(net.parameters()) - d = sum(p.numel() for p in params) - j = randn_((2, d)) - torch.manual_seed(0) - out = GradVac(group_type="all_matrix", shared_params=params)(j) - assert out.shape == (d,) - assert out.isfinite().all() - - -def test_jacobian_width_mismatch() -> None: - net = nn.Linear(2, 2) - d = sum(p.numel() for p in net.parameters()) - agg = GradVac(group_type="all_layer", encoder=net) - with raises(ValueError, match="Jacobian width"): - agg(tensor_([[1.0] * (d - 1), [2.0] * (d - 1)])) - - def test_zero_rows_returns_zero_vector() -> None: out = GradVac()(tensor_([]).reshape(0, 3)) assert_close(out, tensor_([0.0, 0.0, 0.0])) def test_zero_columns_returns_zero_vector() -> None: - """Handled inside forward before grouping validation.""" - out = GradVac()(tensor_([]).reshape(2, 0)) assert out.shape == (0,) @@ -178,3 +96,45 @@ def test_expected_structure(aggregator: GradVac, matrix: Tensor) -> None: @mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) + + +def test_weighting_beta_out_of_range() -> None: + with raises(ValueError, 
match="beta"): + GradVacWeighting(beta=-0.1) + with raises(ValueError, match="beta"): + GradVacWeighting(beta=1.1) + + +def test_weighting_eps_rejects_negative() -> None: + with raises(ValueError, match="eps"): + GradVacWeighting(eps=-1e-9) + + +def test_weighting_reset_restores_first_step_behavior() -> None: + j = randn_((3, 8)) + G = j @ j.T + torch.manual_seed(7) + w = GradVacWeighting(beta=0.5) + first = w(G) + w(G) + w.reset() + torch.manual_seed(7) + assert_close(first, w(G)) + + +def test_aggregator_and_weighting_agree() -> None: + """GradVac()(J) == GradVacWeighting()(J @ J.T) @ J for any matrix J.""" + + j = randn_((3, 8)) + G = j @ j.T + + torch.manual_seed(42) + agg = GradVac(beta=0.3) + expected = agg(j) + + torch.manual_seed(42) + weighting = GradVacWeighting(beta=0.3) + weights = weighting(G) + result = weights @ j + + assert_close(result, expected, rtol=1e-4, atol=1e-4) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index d4e8ba77..4af2b810 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -16,6 +16,7 @@ DualProjWeighting, GradDrop, GradVac, + GradVacWeighting, IMTLGWeighting, Krum, KrumWeighting, @@ -80,6 +81,7 @@ (DualProjWeighting(), G_base, tensor([0.6109, 0.5000])), (IMTLGWeighting(), G_base, tensor([0.5923, 0.4077])), (KrumWeighting(1, 4), G_Krum, tensor([0.2500, 0.2500, 0.0000, 0.2500, 0.2500])), + (GradVacWeighting(), G_base, tensor([2.2222, 1.5789])), (MeanWeighting(), G_base, tensor([0.5000, 0.5000])), (MGDAWeighting(), G_base, tensor([0.6000, 0.4000])), (PCGradWeighting(), G_base, tensor([2.2222, 1.5789])), From 49099640c1621b1002e623333c61402196208c8b Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Sun, 12 Apr 2026 12:59:33 -0400 Subject: [PATCH 06/16] fix: update type hint for update_gradient_coordinate function --- tests/plots/interactive_plotter.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 4d43f397..56031a61 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -8,6 +8,7 @@ import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure +from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( @@ -141,7 +142,7 @@ def update_seed(value: int) -> Figure: *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values: str) -> tuple[Figure, ...]: + def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, ...]]]: values_ = [float(value) for value in values] display_parts: list[str] = [] From a39f3430d4c6d32f8d322fa34bbbcfb836821d1b Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Sun, 12 Apr 2026 13:02:10 -0400 Subject: [PATCH 07/16] test(gradvac): cover beta setter success path for codecov --- tests/unit/aggregation/test_gradvac.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py index f1105800..807fc243 100644 --- a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -35,6 +35,12 @@ def test_beta_setter_out_of_range() -> None: g.beta = 1.1 +def test_beta_setter_updates_value() -> None: + g = GradVac() + g.beta = 0.25 + assert g.beta == 0.25 + + def test_eps_rejects_negative() -> None: with raises(ValueError, match="eps"): GradVac(eps=-1e-9) From 0359e60705a3437072ea14c0a541b90bc3dbb2b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 21:24:23 +0200 Subject: [PATCH 08/16] Rename some variables in test_gradvac.py --- tests/unit/aggregation/test_gradvac.py | 74 +++++++++++++------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/tests/unit/aggregation/test_gradvac.py 
b/tests/unit/aggregation/test_gradvac.py index 807fc243..bde2e8fd 100644 --- a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -15,9 +15,9 @@ def test_representations() -> None: - g = GradVac() - assert repr(g) == "GradVac(beta=0.5, eps=1e-08)" - assert str(g) == "GradVac" + A = GradVac() + assert repr(A) == "GradVac(beta=0.5, eps=1e-08)" + assert str(A) == "GradVac" def test_beta_out_of_range() -> None: @@ -28,17 +28,17 @@ def test_beta_out_of_range() -> None: def test_beta_setter_out_of_range() -> None: - g = GradVac() + A = GradVac() with raises(ValueError, match="beta"): - g.beta = -0.1 + A.beta = -0.1 with raises(ValueError, match="beta"): - g.beta = 1.1 + A.beta = 1.1 def test_beta_setter_updates_value() -> None: - g = GradVac() - g.beta = 0.25 - assert g.beta == 0.25 + A = GradVac() + A.beta = 0.25 + assert A.beta == 0.25 def test_eps_rejects_negative() -> None: @@ -47,19 +47,19 @@ def test_eps_rejects_negative() -> None: def test_eps_setter_rejects_negative() -> None: - g = GradVac() + A = GradVac() with raises(ValueError, match="eps"): - g.eps = -1e-9 + A.eps = -1e-9 def test_eps_can_be_changed_between_steps() -> None: - j = tensor_([[1.0, 0.0], [0.0, 1.0]]) - agg = GradVac() - agg.eps = 1e-6 - assert agg(j).isfinite().all() - agg.reset() - agg.eps = 1e-10 - assert agg(j).isfinite().all() + J = tensor_([[1.0, 0.0], [0.0, 1.0]]) + A = GradVac() + A.eps = 1e-6 + assert A(J).isfinite().all() + A.reset() + A.eps = 1e-10 + assert A(J).isfinite().all() def test_zero_rows_returns_zero_vector() -> None: @@ -73,25 +73,25 @@ def test_zero_columns_returns_zero_vector() -> None: def test_reproducible_with_manual_seed() -> None: - j = randn_((3, 8)) + J = randn_((3, 8)) torch.manual_seed(12345) - a1 = GradVac(beta=0.3) - out1 = a1(j) + A1 = GradVac(beta=0.3) + out1 = A1(J) torch.manual_seed(12345) - a2 = GradVac(beta=0.3) - out2 = a2(j) + A2 = GradVac(beta=0.3) + out2 = A2(J) assert_close(out1, out2) @mark.parametrize("matrix", 
typical_matrices_2_plus_rows) def test_reset_restores_first_step_behavior(matrix: Tensor) -> None: torch.manual_seed(7) - agg = GradVac(beta=0.5) - first = agg(matrix) - agg(matrix) - agg.reset() + A = GradVac(beta=0.5) + first = A(matrix) + A(matrix) + A.reset() torch.manual_seed(7) - assert_close(first, agg(matrix)) + assert_close(first, A(matrix)) @mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) @@ -117,8 +117,8 @@ def test_weighting_eps_rejects_negative() -> None: def test_weighting_reset_restores_first_step_behavior() -> None: - j = randn_((3, 8)) - G = j @ j.T + J = randn_((3, 8)) + G = J @ J.T torch.manual_seed(7) w = GradVacWeighting(beta=0.5) first = w(G) @@ -131,16 +131,16 @@ def test_weighting_reset_restores_first_step_behavior() -> None: def test_aggregator_and_weighting_agree() -> None: """GradVac()(J) == GradVacWeighting()(J @ J.T) @ J for any matrix J.""" - j = randn_((3, 8)) - G = j @ j.T + J = randn_((3, 8)) + G = J @ J.T torch.manual_seed(42) - agg = GradVac(beta=0.3) - expected = agg(j) + A = GradVac(beta=0.3) + expected = A(J) torch.manual_seed(42) - weighting = GradVacWeighting(beta=0.3) - weights = weighting(G) - result = weights @ j + W = GradVacWeighting(beta=0.3) + weights = W(G) + result = weights @ J assert_close(result, expected, rtol=1e-4, atol=1e-4) From 1da5f6ececc31d6dba2370213eea2a06962dc6d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 21:27:02 +0200 Subject: [PATCH 09/16] Add comment about why we move to cpu --- src/torchjd/aggregation/_gradvac.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index dc43a34b..d2fbcbf0 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -151,6 +151,7 @@ def reset(self) -> None: self._state_key = None def forward(self, gramian: PSDMatrix, /) -> Tensor: + # Move all computations on cpu to avoid moving memory between cpu 
and gpu at each iteration device = gramian.device dtype = gramian.dtype cpu = torch.device("cpu") From 21d55f981ff608c7f6f1b1a9b83de33a8e523f6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 21:31:27 +0200 Subject: [PATCH 10/16] Add GradVac to the aggregator table in README --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index ccf443fc..05df2b1e 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo | [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - | | [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) | | [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) | +| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) | | [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) | | 
[Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) | | [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - | From 17b1dd5130f242d6a065c1329d42f7745fa02b1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 21:33:15 +0200 Subject: [PATCH 11/16] Add changelog entry --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9db0d9b..f32df49b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). 
+ ## [0.9.0] - 2026-02-24 ### Added From f4e8e60423ab417ce749912d67d40fb4583f11b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sun, 12 Apr 2026 21:40:08 +0200 Subject: [PATCH 12/16] Remove seed setting in test_aggregator_output Seed is already set to 0 because of the autoused fix_randomness fixture declared in conftest.py --- tests/unit/aggregation/test_values.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 4af2b810..f468dc44 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -1,4 +1,3 @@ -import torch from pytest import mark, param from torch import Tensor, tensor from torch.testing import assert_close @@ -118,8 +117,6 @@ def test_aggregator_output(A: Aggregator, J: Tensor, expected_output: Tensor) -> None: """Test that the output values of an aggregator are fixed (on cpu).""" - if isinstance(A, GradVac): - torch.manual_seed(0) assert_close(A(J), expected_output, rtol=0, atol=1e-4) From 75c89c10b0dd58914bffd9c3466bb216c62d0a78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Mon, 13 Apr 2026 13:51:18 +0200 Subject: [PATCH 13/16] fix(aggregation): Add fallback in NashMTL (#640) --- CHANGELOG.md | 5 +++++ src/torchjd/aggregation/_nash_mtl.py | 5 +++-- tests/unit/aggregation/test_nash_mtl.py | 13 +++++++++++-- 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9db0d9b..bfd051c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,11 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Fixed + +- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example + on the matrix [[0., 0.], [0., 1.]]). 
+ ## [0.9.0] - 2026-02-24 ### Added diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 06e1293d..a20be617 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -158,9 +158,10 @@ def _solve_optimization(self, gtg: np.ndarray) -> np.ndarray: try: self.prob.solve(solver=cp.ECOS, warm_start=True, max_iters=100) - except SolverError: - # On macOS, this can happen with: Solver 'ECOS' failed. + except (SolverError, ValueError): + # On macOS, SolverError can happen with: Solver 'ECOS' failed. # No idea why. The corresponding matrix is of shape [9, 11] with rank 5. + # ValueError happens with for example matrix [[0., 0.], [0., 1.]]. # Maybe other exceptions can happen in other cases. self.alpha_param.value = self.prvs_alpha_param.value diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index a1200d46..d82fca41 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -1,7 +1,7 @@ from pytest import mark from torch import Tensor from torch.testing import assert_close -from utils.tensors import ones_, randn_ +from utils.tensors import ones_, randn_, tensor_ try: from torchjd.aggregation import NashMTL @@ -19,6 +19,10 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: standard_pairs = [(_make_aggregator(matrix), matrix) for matrix in nash_mtl_matrices] +edge_case_matrices = [ + tensor_([[0.0, 0.0], [0.0, 1.0]]) # This leads to a (caught) ValueError in _solve_optimization. 
+] +edge_case_pairs = [(_make_aggregator(matrix), matrix) for matrix in edge_case_matrices] requires_grad_pairs = [(NashMTL(n_tasks=3), ones_(3, 5, requires_grad=True))] @@ -27,8 +31,13 @@ def _make_aggregator(matrix: Tensor) -> NashMTL: @mark.filterwarnings( "ignore:Solution may be inaccurate.", "ignore:You are solving a parameterized problem that is not DPP.", + "ignore:divide by zero encountered in divide", + "ignore:divide by zero encountered in true_divide", + "ignore:overflow encountered in divide", + "ignore:overflow encountered in true_divide", + "ignore:invalid value encountered in matmul", ) -@mark.parametrize(["aggregator", "matrix"], standard_pairs) +@mark.parametrize(["aggregator", "matrix"], standard_pairs + edge_case_pairs) def test_expected_structure(aggregator: NashMTL, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) From 9ffdd135ca13797d488f74adfab9f38ab1319ee0 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Mon, 13 Apr 2026 11:57:53 -0400 Subject: [PATCH 14/16] Revert plot test refactors; keep GradVac in interactive plotter --- tests/plots/_utils.py | 25 ++---- tests/plots/interactive_plotter.py | 126 +++++++---------------------- 2 files changed, 35 insertions(+), 116 deletions(-) diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index 40118fea..dc69bfda 100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import numpy as np import torch from plotly import graph_objects as go @@ -9,22 +7,14 @@ class Plotter: - def __init__( - self, - aggregator_factories: dict[str, Callable[[], Aggregator]], - selected_keys: list[str], - matrix: torch.Tensor, - seed: int = 0, - ) -> None: - self._aggregator_factories = aggregator_factories - self.selected_keys = selected_keys + def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: + self.aggregators = aggregators self.matrix = matrix self.seed = seed def make_fig(self) -> 
Figure: torch.random.manual_seed(self.seed) - aggregators = [self._aggregator_factories[key]() for key in self.selected_keys] - results = [agg(self.matrix) for agg in aggregators] + results = [agg(self.matrix) for agg in self.aggregators] fig = go.Figure() @@ -33,19 +23,14 @@ def make_fig(self) -> Figure: fig.add_trace(cone) for i in range(len(self.matrix)): - scatter = make_vector_scatter( - self.matrix[i], - "blue", - f"g{i + 1}", - textposition="top right", - ) + scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}") fig.add_trace(scatter) for i in range(len(results)): scatter = make_vector_scatter( results[i], "black", - self.selected_keys[i], + str(self.aggregators[i]), showlegend=True, dash=True, ) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 56031a61..2411e4c3 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -1,20 +1,17 @@ import logging import os import webbrowser -from collections.abc import Callable from threading import Timer import numpy as np import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure -from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( IMTLG, MGDA, - Aggregator, AlignedMTL, CAGrad, ConFIG, @@ -34,14 +31,6 @@ MAX_LENGTH = 25.0 -def _format_angle_display(angle: float) -> str: - return f"{angle:.4f} rad ({np.degrees(angle):.1f}°)" - - -def _format_length_display(r: float) -> str: - return f"{r:.4f}" - - def main() -> None: log = logging.getLogger("werkzeug") log.setLevel(logging.CRITICAL) @@ -54,30 +43,27 @@ def main() -> None: ], ) - n_tasks = matrix.shape[0] - aggregator_factories: dict[str, Callable[[], Aggregator]] = { - "AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"), - "AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"), - "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), - 
str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), - str(ConFIG()): lambda: ConFIG(), - str(DualProj()): lambda: DualProj(), - str(GradDrop()): lambda: GradDrop(), - str(GradVac()): lambda: GradVac(), - str(IMTLG()): lambda: IMTLG(), - str(Mean()): lambda: Mean(), - str(MGDA()): lambda: MGDA(), - str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks), - str(PCGrad()): lambda: PCGrad(), - str(Random()): lambda: Random(), - str(Sum()): lambda: Sum(), - str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), - str(UPGrad()): lambda: UPGrad(), - } - - aggregator_strings = list(aggregator_factories.keys()) - - plotter = Plotter(aggregator_factories, [], matrix) + aggregators = [ + AlignedMTL(), + CAGrad(c=0.5), + ConFIG(), + DualProj(), + GradDrop(), + GradVac(), + IMTLG(), + Mean(), + MGDA(), + NashMTL(n_tasks=matrix.shape[0]), + PCGrad(), + Random(), + Sum(), + TrimmedMean(trim_number=1), + UPGrad(), + ] + + aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators} + + plotter = Plotter([], matrix) app = Dash(__name__) @@ -112,6 +98,7 @@ def main() -> None: gradient_slider_inputs.append(Input(angle_input, "value")) gradient_slider_inputs.append(Input(r_input, "value")) + aggregator_strings = [str(aggregator) for aggregator in aggregators] checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist") control_div = html.Div( @@ -130,32 +117,22 @@ def update_seed(value: int) -> Figure: plotter.seed = value return plotter.make_fig() - n_gradients = len(matrix) - gradient_value_outputs: list[Output] = [] - for i in range(n_gradients): - gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children")) - gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children")) - @callback( Output("aggregations-fig", "figure", allow_duplicate=True), - *gradient_value_outputs, *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, 
...]]]: + def update_gradient_coordinate(*values: str) -> Figure: values_ = [float(value) for value in values] - display_parts: list[str] = [] for j in range(len(values_) // 2): angle = values_[2 * j] r = values_[2 * j + 1] x, y = angle_to_coord(angle, r) plotter.matrix[j, 0] = x plotter.matrix[j, 1] = y - display_parts.append(_format_angle_display(angle)) - display_parts.append(_format_length_display(r)) - return (plotter.make_fig(), *display_parts) + return plotter.make_fig() @callback( Output("aggregations-fig", "figure", allow_duplicate=True), @@ -163,7 +140,9 @@ def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, prevent_initial_call=True, ) def update_aggregators(value: list[str]) -> Figure: - plotter.selected_keys = list(value) + aggregator_keys = value + new_aggregators = [aggregators_dict[key] for key in aggregator_keys] + plotter.aggregators = new_aggregators return plotter.make_fig() Timer(1, open_browser).start() @@ -196,56 +175,11 @@ def make_gradient_div( style={"width": "250px"}, ) - label_style: dict[str, str | int] = { - "display": "inline-block", - "width": "52px", - "margin-right": "8px", - "vertical-align": "middle", - } - value_style: dict[str, str] = { - "display": "inline-block", - "margin-left": "10px", - "min-width": "140px", - "font-family": "monospace", - "font-size": "13px", - "vertical-align": "middle", - } - row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"} div = html.Div( [ - dcc.Markdown( - f"$g_{{{i + 1}}}$", - mathjax=True, - style={ - "margin": "0 0 6px 0", - "font-weight": "bold", - "display": "block", - }, - ), - html.Div( - [ - html.Span("Angle", style=label_style), - angle_input, - html.Span( - id=f"g{i + 1}-angle-value", - children=_format_angle_display(angle), - style=value_style, - ), - ], - style=row_style, - ), - html.Div( - [ - html.Span("Length", style=label_style), - r_input, - html.Span( - id=f"g{i + 1}-length-value", - children=_format_length_display(r), - 
style=value_style, - ), - ], - style={**row_style, "margin-bottom": "12px"}, - ), + html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}), + angle_input, + r_input, ], ) return div, angle_input, r_input From e626475d3a8f70659cc373ca1827b12442ad46b5 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Mon, 13 Apr 2026 12:39:29 -0400 Subject: [PATCH 15/16] docs(aggregation): add grouping usage example and fix GradVac note Add a Grouping example page covering all four strategies from the GradVac paper (whole_model, enc_dec, all_layer, all_matrix), with a runnable code block for each. Update the GradVac docstring note to link to the new page instead of the previous placeholder text. Fix trailing whitespace in CHANGELOG.md. --- CHANGELOG.md | 2 +- docs/source/examples/grouping.rst | 167 ++++++++++++++++++++++++++++ docs/source/examples/index.rst | 4 + src/torchjd/aggregation/_gradvac.py | 9 +- 4 files changed, 178 insertions(+), 4 deletions(-) create mode 100644 docs/source/examples/grouping.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index ad21dcc3..5a15bcae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `GradVac` and `GradVacWeighting` from +- Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). - Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example on the matrix [[0., 0.], [0., 1.]]). 
diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst new file mode 100644 index 00000000..aa9e85aa --- /dev/null +++ b/docs/source/examples/grouping.rst @@ -0,0 +1,167 @@ +Grouping +======== + +When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in +multi-task learning, the cosine similarities between task gradients can be computed at different +granularities. The GradVac paper introduces four strategies, each partitioning the shared +parameter vector differently: + +1. **Whole Model** (default) — one group covering all shared parameters. +2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately). +3. **All Layers** — one group per leaf module of the encoder. +4. **All Matrices** — one group per individual parameter tensor. + +In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group +after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group. +For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance +independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from +the original paper. + +.. note:: + The grouping is orthogonal to the choice of + :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions + determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians + are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared + parameters corresponds to the Whole Model strategy. Splitting those parameters into + sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a + dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the + Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures. + +.. 
note:: + The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to + any aggregator. + +1. Whole Model +-------------- + +A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters +together. Cosine similarities are computed between the full task gradient vectors. + +.. testcode:: + :emphasize-lines: 14, 19 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + gradvac = GradVac() + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + jac_to_grad(encoder.parameters(), gradvac) + optimizer.step() + optimizer.zero_grad() + +2. Encoder-Decoder +------------------ + +One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model +is split into an encoder and a decoder; cosine similarities are computed separately within each. +Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks +to receive Jacobians, which are then aggregated independently. + +.. 
testcode:: + :emphasize-lines: 8-9, 15-16, 22-23 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU()) + decoder = Sequential(Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + encoder_gradvac = GradVac() + decoder_gradvac = GradVac() + + for x, y1, y2 in zip(inputs, t1, t2): + enc_out = encoder(x) + dec_out = decoder(enc_out) + mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out) + jac_to_grad(encoder.parameters(), encoder_gradvac) + jac_to_grad(decoder.parameters(), decoder_gradvac) + optimizer.step() + optimizer.zero_grad() + +3. All Layers +------------- + +One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are +computed between the per-layer blocks of the task gradients. + +.. 
testcode:: + :emphasize-lines: 14-15, 20-21 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())] + gradvacs = [GradVac() for _ in leaf_layers] + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + for layer, gradvac in zip(leaf_layers, gradvacs): + jac_to_grad(layer.parameters(), gradvac) + optimizer.step() + optimizer.zero_grad() + +4. All Matrices +--------------- + +One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine +similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and +biases of each layer are treated as separate groups). + +.. 
testcode:: + :emphasize-lines: 14-15, 20-21 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + shared_params = list(encoder.parameters()) + gradvacs = [GradVac() for _ in shared_params] + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + for param, gradvac in zip(shared_params, gradvacs): + jac_to_grad([param], gradvac) + optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 49c5c1f4..c1f1e836 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD. - :doc:`PyTorch Lightning Integration ` showcases how to combine TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task ``LightningModule`` optimized by Jacobian descent. +- :doc:`Grouping ` shows how to apply an aggregator independently per parameter group + (e.g. per layer), so that conflict resolution happens at a finer granularity than the full + shared parameter vector. - :doc:`Automatic Mixed Precision ` shows how to combine mixed precision training with TorchJD. .. toctree:: @@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD. 
monitoring.rst lightning_integration.rst amp.rst + grouping.rst diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index d2fbcbf0..b62728dd 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -43,9 +43,12 @@ class GradVac(GramianWeightedAggregator): you need reproducibility. .. note:: - To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the - Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage - example for details. + To apply GradVac with per-layer or per-parameter-group granularity, create a separate + :class:`GradVac` instance for each group and call + :func:`~torchjd.autojac.jac_to_grad` once per group after + :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state, + matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See + the :doc:`Grouping ` example for details. """ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: From a244d2b0de13f4085a91b07e948c9e7c14eeadf8 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Mon, 13 Apr 2026 12:54:29 -0400 Subject: [PATCH 16/16] docs(changelog): split Unreleased into Added and Fixed for GradVac and NashMTL --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a15bcae..5104aa77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ changelog does not include internal changes that do not affect the user. - Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). + +### Fixed + - Added a fallback for when the inner optimization of `NashMTL` fails (which can happen for example on the matrix [[0., 0.], [0., 1.]]).