Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ changelog does not include internal changes that do not affect the user.

## [Unreleased]

### Added

- Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).

## [0.9.0] - 2026-02-24

### Added
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo
| [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - |
| [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) |
| [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) |
| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) |
| [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) |
| [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) |
| [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - |
Expand Down
14 changes: 14 additions & 0 deletions docs/source/docs/aggregation/gradvac.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
:hide-toc:

GradVac
=======

.. autoclass:: torchjd.aggregation.GradVac
:members:
:undoc-members:
:exclude-members: forward

.. autoclass:: torchjd.aggregation.GradVacWeighting
:members:
:undoc-members:
:exclude-members: forward
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Abstract base classes
dualproj.rst
flattening.rst
graddrop.rst
gradvac.rst
imtl_g.rst
krum.rst
mean.rst
Expand Down
3 changes: 3 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
from ._gradvac import GradVac, GradVacWeighting
from ._imtl_g import IMTLG, IMTLGWeighting
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
Expand All @@ -92,6 +93,8 @@
"Flattening",
"GeneralizedWeighting",
"GradDrop",
"GradVac",
"GradVacWeighting",
"IMTLG",
"IMTLGWeighting",
"Krum",
Expand Down
206 changes: 206 additions & 0 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
from __future__ import annotations

from typing import cast

import torch
from torch import Tensor

from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVac(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.

For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at
random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the
(possibly already modified) gradient of task :math:`i` and the original gradient of task
:math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When
:math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of
:math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with
:math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated
vector is the sum of the modified rows.

This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
the number of tasks or dtype changes.

:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign
the :attr:`beta` attribute between steps to tune the EMA update.
:param eps: Small non-negative constant added to denominators when computing cosines and the
vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or
assign the :attr:`eps` attribute between steps to tune numerical behavior.

.. note::
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
you need reproducibility.

.. note::
To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the
Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage
example for details.
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
self._gradvac_weighting = weighting
self.register_full_backward_pre_hook(raise_non_differentiable_error)

@property
def beta(self) -> float:
"""EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""

return self._gradvac_weighting.beta

@beta.setter
def beta(self, value: float) -> None:
self._gradvac_weighting.beta = value

@property
def eps(self) -> float:
"""Small non-negative constant added to denominators for numerical stability."""

return self._gradvac_weighting.eps

@eps.setter
def eps(self, value: float) -> None:
self._gradvac_weighting.eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""

self._gradvac_weighting.reset()

def __repr__(self) -> str:
return f"GradVac(beta={self.beta!r}, eps={self.eps!r})"


class GradVacWeighting(Weighting[PSDMatrix]):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

All required quantities (gradient norms, cosine similarities, and their updates after the
vaccine correction) are derived purely from the Gramian, without needing the full Jacobian.
If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then:

.. math::

\|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad
g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j}

where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w
g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow
immediately.

This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when
the number of tasks or dtype changes.

:param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``).
:param eps: Small non-negative constant added to denominators (default ``1e-8``).
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
super().__init__()
if not (0.0 <= beta <= 1.0):
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
if eps < 0.0:
raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.")

self._beta = beta
self._eps = eps
self._phi_t: Tensor | None = None
self._state_key: tuple[int, torch.dtype] | None = None

@property
def beta(self) -> float:
"""EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``)."""

return self._beta

@beta.setter
def beta(self, value: float) -> None:
if not (0.0 <= value <= 1.0):
raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.")
self._beta = value

@property
def eps(self) -> float:
"""Small non-negative constant added to denominators for numerical stability."""

return self._eps

@eps.setter
def eps(self, value: float) -> None:
if value < 0.0:
raise ValueError(f"Attribute `eps` must be non-negative. Found eps={value!r}.")
self._eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""

self._phi_t = None
self._state_key = None

def forward(self, gramian: PSDMatrix, /) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = gramian.device
dtype = gramian.dtype
cpu = torch.device("cpu")

G = cast(PSDMatrix, gramian.to(device=cpu))
m = G.shape[0]

self._ensure_state(m, dtype)
phi_t = cast(Tensor, self._phi_t)

beta = self._beta
eps = self._eps

# C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients).
# Initially each modified gradient equals the original, so C = I.
C = torch.eye(m, device=cpu, dtype=dtype)

for i in range(m):
# Dot products of g_i^PC with every original g_j, shape (m,).
cG = C[i] @ G

others = [j for j in range(m) if j != i]
perm = torch.randperm(len(others))
shuffled_js = [others[idx] for idx in perm.tolist()]

for j in shuffled_js:
dot_ij = cG[j]
norm_i_sq = (cG * C[i]).sum()
norm_i = norm_i_sq.clamp(min=0.0).sqrt()
norm_j = G[j, j].clamp(min=0.0).sqrt()
denom = norm_i * norm_j + eps
phi_ijk = dot_ij / denom

phi_hat = phi_t[i, j]
if phi_ijk < phi_hat:
sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt()
sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt()
denom_w = norm_j * sqrt_1_hat2 + eps
w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w
C[i, j] = C[i, j] + w
cG = cG + w * G[j]

phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk

weights = C.sum(dim=0)
return weights.to(device)

def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
key = (m, dtype)
if self._state_key != key or self._phi_t is None:
self._phi_t = torch.zeros(m, m, dtype=dtype)
self._state_key = key
25 changes: 20 additions & 5 deletions tests/plots/_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import Callable

import numpy as np
import torch
from plotly import graph_objects as go
Expand All @@ -7,14 +9,22 @@


class Plotter:
def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None:
self.aggregators = aggregators
def __init__(
self,
aggregator_factories: dict[str, Callable[[], Aggregator]],
selected_keys: list[str],
matrix: torch.Tensor,
seed: int = 0,
) -> None:
self._aggregator_factories = aggregator_factories
self.selected_keys = selected_keys
self.matrix = matrix
self.seed = seed

def make_fig(self) -> Figure:
torch.random.manual_seed(self.seed)
results = [agg(self.matrix) for agg in self.aggregators]
aggregators = [self._aggregator_factories[key]() for key in self.selected_keys]
results = [agg(self.matrix) for agg in aggregators]

fig = go.Figure()

Expand All @@ -23,14 +33,19 @@ def make_fig(self) -> Figure:
fig.add_trace(cone)

for i in range(len(self.matrix)):
scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}")
scatter = make_vector_scatter(
self.matrix[i],
"blue",
f"g{i + 1}",
textposition="top right",
)
fig.add_trace(scatter)

for i in range(len(results)):
scatter = make_vector_scatter(
results[i],
"black",
str(self.aggregators[i]),
self.selected_keys[i],
showlegend=True,
dash=True,
)
Expand Down
Loading
Loading