diff --git a/CHANGELOG.md b/CHANGELOG.md index c9db0d9b..f32df49b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,10 @@ changelog does not include internal changes that do not affect the user. ## [Unreleased] +### Added + +- Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). + ## [0.9.0] - 2026-02-24 ### Added diff --git a/README.md b/README.md index ccf443fc..05df2b1e 100644 --- a/README.md +++ b/README.md @@ -281,6 +281,7 @@ TorchJD provides many existing aggregators from the literature, listed in the fo | [Constant](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.Constant) | [ConstantWeighting](https://torchjd.org/stable/docs/aggregation/constant#torchjd.aggregation.ConstantWeighting) | - | | [DualProj](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProj) | [DualProjWeighting](https://torchjd.org/stable/docs/aggregation/dualproj#torchjd.aggregation.DualProjWeighting) | [Gradient Episodic Memory for Continual Learning](https://arxiv.org/pdf/1706.08840) | | [GradDrop](https://torchjd.org/stable/docs/aggregation/graddrop#torchjd.aggregation.GradDrop) | - | [Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout](https://arxiv.org/pdf/2010.06808) | +| [GradVac](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVac) | [GradVacWeighting](https://torchjd.org/stable/docs/aggregation/gradvac#torchjd.aggregation.GradVacWeighting) | [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874) | | [IMTLG](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLG) | [IMTLGWeighting](https://torchjd.org/stable/docs/aggregation/imtl_g#torchjd.aggregation.IMTLGWeighting) | [Towards Impartial Multi-task 
Learning](https://discovery.ucl.ac.uk/id/eprint/10120667/) | | [Krum](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.Krum) | [KrumWeighting](https://torchjd.org/stable/docs/aggregation/krum#torchjd.aggregation.KrumWeighting) | [Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent](https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf) | | [Mean](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.Mean) | [MeanWeighting](https://torchjd.org/stable/docs/aggregation/mean#torchjd.aggregation.MeanWeighting) | - | diff --git a/docs/source/docs/aggregation/gradvac.rst b/docs/source/docs/aggregation/gradvac.rst new file mode 100644 index 00000000..18e154bc --- /dev/null +++ b/docs/source/docs/aggregation/gradvac.rst @@ -0,0 +1,14 @@ +:hide-toc: + +GradVac +======= + +.. autoclass:: torchjd.aggregation.GradVac + :members: + :undoc-members: + :exclude-members: forward + +.. autoclass:: torchjd.aggregation.GradVacWeighting + :members: + :undoc-members: + :exclude-members: forward diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index c15d5980..64ba6f63 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -35,6 +35,7 @@ Abstract base classes dualproj.rst flattening.rst graddrop.rst + gradvac.rst imtl_g.rst krum.rst mean.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 9eed9bf7..93f824e3 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -66,6 +66,7 @@ from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop +from ._gradvac import GradVac, GradVacWeighting from ._imtl_g import IMTLG, IMTLGWeighting from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting @@ -92,6 +93,8 @@ "Flattening", "GeneralizedWeighting", "GradDrop", + 
"GradVac", + "GradVacWeighting", "IMTLG", "IMTLGWeighting", "Krum", diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py new file mode 100644 index 00000000..d2fbcbf0 --- /dev/null +++ b/src/torchjd/aggregation/_gradvac.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from typing import cast + +import torch +from torch import Tensor + +from torchjd._linalg import PSDMatrix + +from ._aggregator_bases import GramianWeightedAggregator +from ._utils.non_differentiable import raise_non_differentiable_error +from ._weighting_bases import Weighting + + +class GradVac(GramianWeightedAggregator): + r""" + :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of + Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task + Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) + `_. + + For each task :math:`i`, the order in which other tasks :math:`j` are visited is drawn at + random. For each pair :math:`(i, j)`, the cosine similarity :math:`\phi_{ij}` between the + (possibly already modified) gradient of task :math:`i` and the original gradient of task + :math:`j` is compared to an EMA target :math:`\hat{\phi}_{ij}`. When + :math:`\phi_{ij} < \hat{\phi}_{ij}`, a closed-form correction adds a scaled copy of + :math:`g_j` to :math:`g_i^{(\mathrm{PC})}`. The EMA is then updated with + :math:`\hat{\phi}_{ij} \leftarrow (1-\beta)\hat{\phi}_{ij} + \beta \phi_{ij}`. The aggregated + vector is the sum of the modified rows. + + This aggregator is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when + the number of tasks or dtype changes. + + :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). You may read or assign + the :attr:`beta` attribute between steps to tune the EMA update. 
+ :param eps: Small non-negative constant added to denominators when computing cosines and the + vaccine weight (default ``1e-8``); set to ``0`` to omit this stabilization. You may read or + assign the :attr:`eps` attribute between steps to tune numerical behavior. + + .. note:: + For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently + using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if + you need reproducibility. + + .. note:: + To apply GradVac with per-layer or per-parameter-group granularity, first aggregate the + Jacobian into groups, apply GradVac per group, and sum the results. See the grouping usage + example for details. + """ + + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps) + super().__init__(weighting) + self._gradvac_weighting = weighting + self.register_full_backward_pre_hook(raise_non_differentiable_error) + + @property + def beta(self) -> float: + """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``).""" + + return self._gradvac_weighting.beta + + @beta.setter + def beta(self, value: float) -> None: + self._gradvac_weighting.beta = value + + @property + def eps(self) -> float: + """Small non-negative constant added to denominators for numerical stability.""" + + return self._gradvac_weighting.eps + + @eps.setter + def eps(self, value: float) -> None: + self._gradvac_weighting.eps = value + + def reset(self) -> None: + """Clears EMA state so the next forward starts from zero targets.""" + + self._gradvac_weighting.reset() + + def __repr__(self) -> str: + return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" + + +class GradVacWeighting(Weighting[PSDMatrix]): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GradVac`. 
+ + All required quantities (gradient norms, cosine similarities, and their updates after the + vaccine correction) are derived purely from the Gramian, without needing the full Jacobian. + If :math:`g_i^{(\mathrm{PC})} = \sum_k c_{ik} g_k`, then: + + .. math:: + + \|g_i^{(\mathrm{PC})}\|^2 = \mathbf{c}_i G \mathbf{c}_i^\top,\qquad + g_i^{(\mathrm{PC})} \cdot g_j = \mathbf{c}_i G_{:,j} + + where :math:`G` is the Gramian. The correction :math:`g_i^{(\mathrm{PC})} \mathrel{+}= w + g_j` then becomes :math:`c_{ij} \mathrel{+}= w`, and the updated dot products follow + immediately. + + This weighting is stateful: it keeps :math:`\hat{\phi}` across calls. Use :meth:`reset` when + the number of tasks or dtype changes. + + :param beta: EMA decay for :math:`\hat{\phi}` (paper default ``0.5``). + :param eps: Small non-negative constant added to denominators (default ``1e-8``). + """ + + def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: + super().__init__() + if not (0.0 <= beta <= 1.0): + raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") + if eps < 0.0: + raise ValueError(f"Parameter `eps` must be non-negative. Found eps={eps!r}.") + + self._beta = beta + self._eps = eps + self._phi_t: Tensor | None = None + self._state_key: tuple[int, torch.dtype] | None = None + + @property + def beta(self) -> float: + """EMA decay coefficient for :math:`\\hat{\\phi}` (paper default ``0.5``).""" + + return self._beta + + @beta.setter + def beta(self, value: float) -> None: + if not (0.0 <= value <= 1.0): + raise ValueError(f"Attribute `beta` must be in [0, 1]. Found beta={value!r}.") + self._beta = value + + @property + def eps(self) -> float: + """Small non-negative constant added to denominators for numerical stability.""" + + return self._eps + + @eps.setter + def eps(self, value: float) -> None: + if value < 0.0: + raise ValueError(f"Attribute `eps` must be non-negative. 
Found eps={value!r}.") + self._eps = value + + def reset(self) -> None: + """Clears EMA state so the next forward starts from zero targets.""" + + self._phi_t = None + self._state_key = None + + def forward(self, gramian: PSDMatrix, /) -> Tensor: + # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration + device = gramian.device + dtype = gramian.dtype + cpu = torch.device("cpu") + + G = cast(PSDMatrix, gramian.to(device=cpu)) + m = G.shape[0] + + self._ensure_state(m, dtype) + phi_t = cast(Tensor, self._phi_t) + + beta = self._beta + eps = self._eps + + # C[i, :] holds coefficients such that g_i^PC = sum_k C[i,k] * g_k (original gradients). + # Initially each modified gradient equals the original, so C = I. + C = torch.eye(m, device=cpu, dtype=dtype) + + for i in range(m): + # Dot products of g_i^PC with every original g_j, shape (m,). + cG = C[i] @ G + + others = [j for j in range(m) if j != i] + perm = torch.randperm(len(others)) + shuffled_js = [others[idx] for idx in perm.tolist()] + + for j in shuffled_js: + dot_ij = cG[j] + norm_i_sq = (cG * C[i]).sum() + norm_i = norm_i_sq.clamp(min=0.0).sqrt() + norm_j = G[j, j].clamp(min=0.0).sqrt() + denom = norm_i * norm_j + eps + phi_ijk = dot_ij / denom + + phi_hat = phi_t[i, j] + if phi_ijk < phi_hat: + sqrt_1_phi2 = (1.0 - phi_ijk * phi_ijk).clamp(min=0.0).sqrt() + sqrt_1_hat2 = (1.0 - phi_hat * phi_hat).clamp(min=0.0).sqrt() + denom_w = norm_j * sqrt_1_hat2 + eps + w = norm_i * (phi_hat * sqrt_1_phi2 - phi_ijk * sqrt_1_hat2) / denom_w + C[i, j] = C[i, j] + w + cG = cG + w * G[j] + + phi_t[i, j] = (1.0 - beta) * phi_hat + beta * phi_ijk + + weights = C.sum(dim=0) + return weights.to(device) + + def _ensure_state(self, m: int, dtype: torch.dtype) -> None: + key = (m, dtype) + if self._state_key != key or self._phi_t is None: + self._phi_t = torch.zeros(m, m, dtype=dtype) + self._state_key = key diff --git a/tests/plots/_utils.py b/tests/plots/_utils.py index dc69bfda..40118fea 
100644 --- a/tests/plots/_utils.py +++ b/tests/plots/_utils.py @@ -1,3 +1,5 @@ +from collections.abc import Callable + import numpy as np import torch from plotly import graph_objects as go @@ -7,14 +9,22 @@ class Plotter: - def __init__(self, aggregators: list[Aggregator], matrix: torch.Tensor, seed: int = 0) -> None: - self.aggregators = aggregators + def __init__( + self, + aggregator_factories: dict[str, Callable[[], Aggregator]], + selected_keys: list[str], + matrix: torch.Tensor, + seed: int = 0, + ) -> None: + self._aggregator_factories = aggregator_factories + self.selected_keys = selected_keys self.matrix = matrix self.seed = seed def make_fig(self) -> Figure: torch.random.manual_seed(self.seed) - results = [agg(self.matrix) for agg in self.aggregators] + aggregators = [self._aggregator_factories[key]() for key in self.selected_keys] + results = [agg(self.matrix) for agg in aggregators] fig = go.Figure() @@ -23,14 +33,19 @@ def make_fig(self) -> Figure: fig.add_trace(cone) for i in range(len(self.matrix)): - scatter = make_vector_scatter(self.matrix[i], "blue", f"g{i + 1}") + scatter = make_vector_scatter( + self.matrix[i], + "blue", + f"g{i + 1}", + textposition="top right", + ) fig.add_trace(scatter) for i in range(len(results)): scatter = make_vector_scatter( results[i], "black", - str(self.aggregators[i]), + self.selected_keys[i], showlegend=True, dash=True, ) diff --git a/tests/plots/interactive_plotter.py b/tests/plots/interactive_plotter.py index 2a945f93..56031a61 100644 --- a/tests/plots/interactive_plotter.py +++ b/tests/plots/interactive_plotter.py @@ -1,22 +1,26 @@ import logging import os import webbrowser +from collections.abc import Callable from threading import Timer import numpy as np import torch from dash import Dash, Input, Output, callback, dcc, html from plotly.graph_objs import Figure +from typing_extensions import Unpack from plots._utils import Plotter, angle_to_coord, coord_to_angle from torchjd.aggregation import ( IMTLG, MGDA, 
+ Aggregator, AlignedMTL, CAGrad, ConFIG, DualProj, GradDrop, + GradVac, Mean, NashMTL, PCGrad, @@ -30,6 +34,14 @@ MAX_LENGTH = 25.0 +def _format_angle_display(angle: float) -> str: + return f"{angle:.4f} rad ({np.degrees(angle):.1f}°)" + + +def _format_length_display(r: float) -> str: + return f"{r:.4f}" + + def main() -> None: log = logging.getLogger("werkzeug") log.setLevel(logging.CRITICAL) @@ -42,26 +54,30 @@ def main() -> None: ], ) - aggregators = [ - AlignedMTL(), - CAGrad(c=0.5), - ConFIG(), - DualProj(), - GradDrop(), - IMTLG(), - Mean(), - MGDA(), - NashMTL(n_tasks=matrix.shape[0]), - PCGrad(), - Random(), - Sum(), - TrimmedMean(trim_number=1), - UPGrad(), - ] - - aggregators_dict = {str(aggregator): aggregator for aggregator in aggregators} - - plotter = Plotter([], matrix) + n_tasks = matrix.shape[0] + aggregator_factories: dict[str, Callable[[], Aggregator]] = { + "AlignedMTL-min": lambda: AlignedMTL(scale_mode="min"), + "AlignedMTL-median": lambda: AlignedMTL(scale_mode="median"), + "AlignedMTL-RMSE": lambda: AlignedMTL(scale_mode="rmse"), + str(CAGrad(c=0.5)): lambda: CAGrad(c=0.5), + str(ConFIG()): lambda: ConFIG(), + str(DualProj()): lambda: DualProj(), + str(GradDrop()): lambda: GradDrop(), + str(GradVac()): lambda: GradVac(), + str(IMTLG()): lambda: IMTLG(), + str(Mean()): lambda: Mean(), + str(MGDA()): lambda: MGDA(), + str(NashMTL(n_tasks=n_tasks)): lambda: NashMTL(n_tasks=n_tasks), + str(PCGrad()): lambda: PCGrad(), + str(Random()): lambda: Random(), + str(Sum()): lambda: Sum(), + str(TrimmedMean(trim_number=1)): lambda: TrimmedMean(trim_number=1), + str(UPGrad()): lambda: UPGrad(), + } + + aggregator_strings = list(aggregator_factories.keys()) + + plotter = Plotter(aggregator_factories, [], matrix) app = Dash(__name__) @@ -96,7 +112,6 @@ def main() -> None: gradient_slider_inputs.append(Input(angle_input, "value")) gradient_slider_inputs.append(Input(r_input, "value")) - aggregator_strings = [str(aggregator) for aggregator in aggregators] 
checklist = dcc.Checklist(aggregator_strings, [], id="aggregator-checklist") control_div = html.Div( @@ -115,22 +130,32 @@ def update_seed(value: int) -> Figure: plotter.seed = value return plotter.make_fig() + n_gradients = len(matrix) + gradient_value_outputs: list[Output] = [] + for i in range(n_gradients): + gradient_value_outputs.append(Output(f"g{i + 1}-angle-value", "children")) + gradient_value_outputs.append(Output(f"g{i + 1}-length-value", "children")) + @callback( Output("aggregations-fig", "figure", allow_duplicate=True), + *gradient_value_outputs, *gradient_slider_inputs, prevent_initial_call=True, ) - def update_gradient_coordinate(*values: str) -> Figure: + def update_gradient_coordinate(*values: str) -> tuple[Figure, Unpack[tuple[str, ...]]]: values_ = [float(value) for value in values] + display_parts: list[str] = [] for j in range(len(values_) // 2): angle = values_[2 * j] r = values_[2 * j + 1] x, y = angle_to_coord(angle, r) plotter.matrix[j, 0] = x plotter.matrix[j, 1] = y + display_parts.append(_format_angle_display(angle)) + display_parts.append(_format_length_display(r)) - return plotter.make_fig() + return (plotter.make_fig(), *display_parts) @callback( Output("aggregations-fig", "figure", allow_duplicate=True), @@ -138,9 +163,7 @@ def update_gradient_coordinate(*values: str) -> Figure: prevent_initial_call=True, ) def update_aggregators(value: list[str]) -> Figure: - aggregator_keys = value - new_aggregators = [aggregators_dict[key] for key in aggregator_keys] - plotter.aggregators = new_aggregators + plotter.selected_keys = list(value) return plotter.make_fig() Timer(1, open_browser).start() @@ -173,11 +196,56 @@ def make_gradient_div( style={"width": "250px"}, ) + label_style: dict[str, str | int] = { + "display": "inline-block", + "width": "52px", + "margin-right": "8px", + "vertical-align": "middle", + } + value_style: dict[str, str] = { + "display": "inline-block", + "margin-left": "10px", + "min-width": "140px", + "font-family": 
"monospace", + "font-size": "13px", + "vertical-align": "middle", + } + row_style: dict[str, str] = {"display": "block", "margin-bottom": "6px"} div = html.Div( [ - html.P(f"g{i + 1}", style={"display": "inline-block", "margin-right": 20}), - angle_input, - r_input, + dcc.Markdown( + f"$g_{{{i + 1}}}$", + mathjax=True, + style={ + "margin": "0 0 6px 0", + "font-weight": "bold", + "display": "block", + }, + ), + html.Div( + [ + html.Span("Angle", style=label_style), + angle_input, + html.Span( + id=f"g{i + 1}-angle-value", + children=_format_angle_display(angle), + style=value_style, + ), + ], + style=row_style, + ), + html.Div( + [ + html.Span("Length", style=label_style), + r_input, + html.Span( + id=f"g{i + 1}-length-value", + children=_format_length_display(r), + style=value_style, + ), + ], + style={**row_style, "margin-bottom": "12px"}, + ), ], ) return div, angle_input, r_input diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py new file mode 100644 index 00000000..bde2e8fd --- /dev/null +++ b/tests/unit/aggregation/test_gradvac.py @@ -0,0 +1,146 @@ +import torch +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import ones_, randn_, tensor_ + +from torchjd.aggregation import GradVac, GradVacWeighting + +from ._asserts import assert_expected_structure, assert_non_differentiable +from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows + +scaled_pairs = [(GradVac(), m) for m in scaled_matrices] +typical_pairs = [(GradVac(), m) for m in typical_matrices] +requires_grad_pairs = [(GradVac(), ones_(3, 5, requires_grad=True))] + + +def test_representations() -> None: + A = GradVac() + assert repr(A) == "GradVac(beta=0.5, eps=1e-08)" + assert str(A) == "GradVac" + + +def test_beta_out_of_range() -> None: + with raises(ValueError, match="beta"): + GradVac(beta=-0.1) + with raises(ValueError, match="beta"): + GradVac(beta=1.1) + + 
+def test_beta_setter_out_of_range() -> None: + A = GradVac() + with raises(ValueError, match="beta"): + A.beta = -0.1 + with raises(ValueError, match="beta"): + A.beta = 1.1 + + +def test_beta_setter_updates_value() -> None: + A = GradVac() + A.beta = 0.25 + assert A.beta == 0.25 + + +def test_eps_rejects_negative() -> None: + with raises(ValueError, match="eps"): + GradVac(eps=-1e-9) + + +def test_eps_setter_rejects_negative() -> None: + A = GradVac() + with raises(ValueError, match="eps"): + A.eps = -1e-9 + + +def test_eps_can_be_changed_between_steps() -> None: + J = tensor_([[1.0, 0.0], [0.0, 1.0]]) + A = GradVac() + A.eps = 1e-6 + assert A(J).isfinite().all() + A.reset() + A.eps = 1e-10 + assert A(J).isfinite().all() + + +def test_zero_rows_returns_zero_vector() -> None: + out = GradVac()(tensor_([]).reshape(0, 3)) + assert_close(out, tensor_([0.0, 0.0, 0.0])) + + +def test_zero_columns_returns_zero_vector() -> None: + out = GradVac()(tensor_([]).reshape(2, 0)) + assert out.shape == (0,) + + +def test_reproducible_with_manual_seed() -> None: + J = randn_((3, 8)) + torch.manual_seed(12345) + A1 = GradVac(beta=0.3) + out1 = A1(J) + torch.manual_seed(12345) + A2 = GradVac(beta=0.3) + out2 = A2(J) + assert_close(out1, out2) + + +@mark.parametrize("matrix", typical_matrices_2_plus_rows) +def test_reset_restores_first_step_behavior(matrix: Tensor) -> None: + torch.manual_seed(7) + A = GradVac(beta=0.5) + first = A(matrix) + A(matrix) + A.reset() + torch.manual_seed(7) + assert_close(first, A(matrix)) + + +@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) +def test_expected_structure(aggregator: GradVac, matrix: Tensor) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], requires_grad_pairs) +def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: + assert_non_differentiable(aggregator, matrix) + + +def test_weighting_beta_out_of_range() -> None: + with raises(ValueError, 
match="beta"): + GradVacWeighting(beta=-0.1) + with raises(ValueError, match="beta"): + GradVacWeighting(beta=1.1) + + +def test_weighting_eps_rejects_negative() -> None: + with raises(ValueError, match="eps"): + GradVacWeighting(eps=-1e-9) + + +def test_weighting_reset_restores_first_step_behavior() -> None: + J = randn_((3, 8)) + G = J @ J.T + torch.manual_seed(7) + w = GradVacWeighting(beta=0.5) + first = w(G) + w(G) + w.reset() + torch.manual_seed(7) + assert_close(first, w(G)) + + +def test_aggregator_and_weighting_agree() -> None: + """GradVac()(J) == GradVacWeighting()(J @ J.T) @ J for any matrix J.""" + + J = randn_((3, 8)) + G = J @ J.T + + torch.manual_seed(42) + A = GradVac(beta=0.3) + expected = A(J) + + torch.manual_seed(42) + W = GradVacWeighting(beta=0.3) + weights = W(G) + result = weights @ J + + assert_close(result, expected, rtol=1e-4, atol=1e-4) diff --git a/tests/unit/aggregation/test_values.py b/tests/unit/aggregation/test_values.py index 42faca91..f468dc44 100644 --- a/tests/unit/aggregation/test_values.py +++ b/tests/unit/aggregation/test_values.py @@ -14,6 +14,8 @@ DualProj, DualProjWeighting, GradDrop, + GradVac, + GradVacWeighting, IMTLGWeighting, Krum, KrumWeighting, @@ -57,6 +59,7 @@ (Constant(tensor([1.0, 2.0])), J_base, tensor([8.0, 3.0, 3.0])), (DualProj(), J_base, tensor([0.5563, 1.1109, 1.1109])), (GradDrop(), J_base, tensor([6.0, 2.0, 2.0])), + (GradVac(), J_base, tensor([0.5848, 3.8012, 3.8012])), (IMTLG(), J_base, tensor([0.0767, 1.0000, 1.0000])), (Krum(n_byzantine=1, n_selected=4), J_Krum, tensor([1.2500, 0.7500, 1.5000])), (Mean(), J_base, tensor([1.0, 1.0, 1.0])), @@ -77,6 +80,7 @@ (DualProjWeighting(), G_base, tensor([0.6109, 0.5000])), + (GradVacWeighting(), G_base, tensor([2.2222, 1.5789])), (IMTLGWeighting(), G_base, tensor([0.5923, 0.4077])), (KrumWeighting(1, 4), G_Krum, tensor([0.2500, 0.2500, 0.0000, 0.2500, 0.2500])), (MeanWeighting(), G_base, tensor([0.5000, 0.5000])), (MGDAWeighting(), G_base, tensor([0.6000,
0.4000])), (PCGradWeighting(), G_base, tensor([2.2222, 1.5789])),