feat(aggregation): Add GradVac aggregator #638
rkhosrowshahi wants to merge 1 commit into SimplexLab:main from …
Conversation
Implement Gradient Vaccine (ICLR 2021) as a stateful Jacobian aggregator. Support group_type 0 (whole model), 1 (all_layer via encoder), and 2 (all_matrix via shared_params), with DEFAULT_GRADVAC_EPS and configurable eps. Add Sphinx page and unit tests. Autogram is not supported; use torch.manual_seed for reproducible task shuffle order. Made-with: Cursor
Hey @ValerianRey:

I think the link check action failed because of bad luck; re-running the job fixed it.

Thanks a lot for the PR, I'll look at it tonight or tomorrow!

/opencode:Plan can you please check that the actual implementation (from line 219 to line 251 of _gradvac.py) is equivalent to the implementation in https://github.com/median-research-group/LibMTL/blob/main/LibMTL/weighting/GradVac.py? Please report anything that differs.
Let me analyze the differences between the two implementations.

Comparison Summary

After comparing lines 219-251 of _gradvac.py with the LibMTL reference:

✅ Equivalent Behaviors

🔍 Differences Found

📝 Notes

Conclusion: The implementations are algorithmically equivalent. TorchJD includes minor improvements (configurable eps, numerical clamp) and follows the codebase's design patterns.
/opencode:Plan Please review this PR in detail.
Really good work, tyvm! This seems almost ready to merge IMO. Just a bunch of nitpicks that should be really easy to fix, and 1 somewhat important implementation question (iterating over groups in outer loop vs inner loop).
I would also argue that it seems that this could be implemented as a Weighting internally (because we actually act on norms and cosine similarities between gradients, which is what the gramian contains). Also, it's possible to keep track of norms and cosine similarities between projected gradients even if we don't have those gradients, just by making some operations on the gramian. This is what we did to implement PCGrad as a Weighting.
For example, imagine g1 and g2 are two gradients. From the gramian, you know ||g1||, ||g2|| (the square roots of the diagonal elements), and g1 . g2 (an off-diagonal element), so you can deduce cos(g1, g2) from that.
If you compute g1' = g1 + w * g2, you can also directly deduce the norm of g1':
||g1'||² = ||g1||² + w² ||g2||² + 2w g1 . g2 (all elements of the right-hand side are known).
Similarly, you can compute g1' . g2 = (g1 + w * g2) . g2 = g1 . g2 + w ||g2||².
So even after projection, you still know the dot products between all of your gradients, meaning that you still know the "new" gramian.
I didn't think through it entirely but at a first glance it seems possible to adapt this as a weighting, because of that. The implementation may even be faster actually (because we have fewer norms to recompute). But it may be hard to implement, so IMO we should merge this without even trying to implement it as a Weighting, and we can always improve later. @PierreQuinton what do you think about that?
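The gramian bookkeeping sketched in the comment above can be checked in plain Python. This is an illustrative toy example (made-up gradient values, not TorchJD's API):

```python
import math

# Two hypothetical gradients (illustrative values only).
g1 = [3.0, 0.0]
g2 = [1.0, 2.0]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# The gramian holds every pairwise dot product.
G = [[dot(g1, g1), dot(g1, g2)],
     [dot(g2, g1), dot(g2, g2)]]

# Norms and cosine are recoverable from the gramian alone.
n1 = math.sqrt(G[0][0])
n2 = math.sqrt(G[1][1])
cos12 = G[0][1] / (n1 * n2)

# After the projection g1' = g1 + w * g2, the new quantities follow
# from gramian algebra, without ever materializing g1':
w = 0.5
norm_g1p_sq = G[0][0] + w**2 * G[1][1] + 2 * w * G[0][1]  # ||g1'||^2
dot_g1p_g2 = G[0][1] + w * G[1][1]                        # g1' . g2

# Sanity check against the explicit vectors.
g1p = [a + w * b for a, b in zip(g1, g2)]
assert abs(dot(g1p, g1p) - norm_g1p_sq) < 1e-9
assert abs(dot(g1p, g2) - dot_g1p_g2) < 1e-9
```

So after each projection the "new" gramian can be rebuilt from the old one, which is the observation that makes a Weighting-based implementation plausible.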
```python
#: Default small constant added to denominators for numerical stability.
DEFAULT_GRADVAC_EPS = 1e-8
```
I don't think we need that to be stored in a constant (we never do that for the default value of the params of the other aggregators).
```rst
The constructor argument ``group_type`` (default ``0``) sets **parameter granularity** for the
per-block cosine statistics in GradVac:

* ``0`` — **whole model** (``whole_model``): one block per task gradient row. Omit ``encoder`` and
  ``shared_params``.
* ``1`` — **all layer** (``all_layer``): one block per leaf submodule with parameters under
  ``encoder`` (same traversal as ``encoder.modules()`` in the reference formulation).
* ``2`` — **all matrix** (``all_matrix``): one block per tensor in ``shared_params``, in order. Use
  the same tensors as for the shared-parameter Jacobian columns (e.g. the parameters you would pass
  to a shared-gradient helper).
```
This part is already included in the built documentation by `.. autoclass:: torchjd.aggregation.GradVac`, so it ends up being duplicated. Btw, to look at the built documentation, you can run:

```shell
uv run make clean -C docs
uv run make html -C docs
```

and then open docs/build/html/index.html with a web browser.
```python
device = grads.device
dtype = grads.dtype
self._ensure_state(m, n, sizes, device, dtype)
assert self._rho_t is not None
```
This assert also cannot fail I think, so we can remove it.
We can return the self._rho_t from self._ensure_state():

```python
def _ensure_state(
    self,
    m: int,
    n: int,
    sizes: tuple[int, ...],
    device: torch.device,
    dtype: torch.dtype,
) -> Tensor:
    key = (m, n, sizes, device, dtype)
    num_groups = len(sizes)
    if self._state_key != key or self._phi_t is None:
        phi = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
        self._phi_t = phi
        self._state_key = key
        return phi
    return self._phi_t
```

Suggested change:

```diff
- assert self._rho_t is not None
+ phi_t = self._ensure_state(m, n, sizes, device, dtype)
```
I'm not a fan of changing internal state + returning it. I think the intention of the function becomes a bit harder to understand.
If you just get rid of the assert, does it work or do you get a problem reported by ty? If so we could maybe just cast, or even keep the assert.
Ok, then we initialize the state but don't return it:

```python
def _ensure_state(
    self,
    m: int,
    n: int,
    sizes: tuple[int, ...],
    device: torch.device,
    dtype: torch.dtype,
) -> None:
    key = (m, n, sizes, device, dtype)
    num_groups = len(sizes)
    if self._state_key != key or self._phi_t is None:
        self._phi_t = torch.zeros(m, m, num_groups, device=device, dtype=dtype)
        self._state_key = key
```

and in forward:

```python
self._ensure_state(m, n, sizes, device, dtype)
phi_t = cast(Tensor, self._phi_t)
```

cast here is to make sure the data type is Tensor.
```rst
Massively Multilingual Models (ICLR 2021 Spotlight)
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.

The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
```

In torchjd we usually denote the Jacobian as J, the number of objectives as (lowercase) m, and the number of parameters as (lowercase) n.

Suggested change:

```diff
- The input matrix is a Jacobian :math:`G \in \mathbb{R}^{M \times D}` whose rows are per-task
+ The input matrix is a Jacobian :math:`J \in \mathbb{R}^{m \times n}` whose rows are per-task
```
Thanks. I will also change the other two lines that used D.
```rst
each task gradient row is partitioned into blocks :math:`k` so that cosines and EMA targets
:math:`\bar{\rho}_{ijk}` are computed **per block** rather than only globally:

* ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block.
```

Suggested change:

```diff
- * ``0`` — **whole model** (``whole_model``): the full row of length :math:`D` is a single block.
+ * ``0`` — **whole model** (``whole_model``): the full row of length :math:`n` is a single block.
```
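As a rough illustration of the per-block statistics described in the quoted docstring, here is a minimal plain-Python sketch (hypothetical block sizes and rows, not the actual implementation):

```python
import math

# Hypothetical block sizes: two parameter groups of 2 and 3 entries.
sizes = (2, 3)
row_i = [1.0, 0.0, 2.0, 0.0, 0.0]  # illustrative task-gradient rows
row_j = [0.0, 1.0, 2.0, 0.0, 0.0]

def blocks(row, sizes):
    """Split a flat Jacobian row into consecutive blocks of the given sizes."""
    out, start = [], 0
    for s in sizes:
        out.append(row[start:start + s])
        start += s
    return out

def cosine(a, b, eps=1e-8):
    d = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return d / max(na * nb, eps)

# One cosine per block, rather than a single global cosine per task pair.
per_block = [cosine(a, b) for a, b in zip(blocks(row_i, sizes), blocks(row_j, sizes))]
# Block 0 is orthogonal (cosine 0); block 1 is perfectly aligned (cosine 1).
```

With group_type 1 or 2, the EMA targets are then tracked per (task pair, block) instead of per task pair only.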
```python
def test_eps_can_be_changed_between_steps() -> None:
    j = tensor([[1.0, 0.0], [0.0, 1.0]])
```

The tensor here is not gonna be affected by DEVICE and DTYPE. Please use tensor_ (from tests/utils/tensors.py) to ensure that the tensor's device and dtype are gonna change when we change the PYTEST_TORCH_DEVICE and PYTEST_TORCH_DTYPE variables.

Suggested change:

```diff
-     j = tensor([[1.0, 0.0], [0.0, 1.0]])
+     J = tensor_([[1.0, 0.0], [0.0, 1.0]])
```
Thanks. Was the capital J intended to indicate a constant variable or a typo?
No, we always use uppercase J for the Jacobian, because it matches the mathematical notation. Same as G for the Gramian. Also lowercase j is generally just an index in a for loop.
```python
def test_group_type_0_rejects_shared_params() -> None:
    p = nn.Parameter(tensor([1.0]))
```

Suggested change:

```diff
-     p = nn.Parameter(tensor([1.0]))
+     p = nn.Parameter(tensor_([1.0]))
```
```python
out = GradVac()(tensor([]).reshape(0, 3))
assert_close(out, tensor([0.0, 0.0, 0.0]))
```

Suggested change:

```diff
- out = GradVac()(tensor([]).reshape(0, 3))
- assert_close(out, tensor([0.0, 0.0, 0.0]))
+ out = GradVac()(tensor_([]).reshape(0, 3))
+ assert_close(out, tensor_([0.0, 0.0, 0.0]))
```
```python
def test_zero_columns_returns_zero_vector() -> None:
    """Handled inside forward before grouping validation."""

    out = GradVac()(tensor([]).reshape(2, 0))
```

Suggested change:

```diff
-     out = GradVac()(tensor([]).reshape(2, 0))
+     out = GradVac()(tensor_([]).reshape(2, 0))
```
```python
d = sum(p.numel() for p in net.parameters())
agg = GradVac(group_type=1, encoder=net)
with raises(ValueError, match="Jacobian width"):
    agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)]))
```

Suggested change:

```diff
-     agg(tensor([[1.0] * (d - 1), [2.0] * (d - 1)]))
+     agg(tensor_([[1.0] * (d - 1), [2.0] * (d - 1)]))
```
|
Opencode's review was quite low quality, but it mentioned something that I missed: we need a test for GradVac in tests/unit/aggregation/test_values.py. Similarly, I'd like to have GradVac added to tests/plots/interactive_plotter.py.
@rkhosrowshahi Thx for all the updates! Feel free to commit and push all the code suggestions and other changes you made!
@rkhosrowshahi Very nice, I like all this. I think for me, we need to remove the groupings, and we should add a …

Lastly, as @ValerianRey mentioned, this is most likely a gramian-based aggregator. I think the formulas he gave were right, and the Gramian operation can be deduced (and will probably be more efficient). So I think this should be made into a Gramian weighting and gramian-based aggregator pair (see …


Summary

Adds Gradient Vaccine (GradVac) from ICLR 2021 as a stateful Aggregator on the full task Jacobian.

Behavior

* Maintains EMA targets \bar{\rho}, with the closed-form vaccine update when \rho < \bar{\rho}.
* group_type: 0 whole model (single block); 1 all_layer via encoder (leaf modules with parameters); 2 all_matrix via shared_params (one block per tensor, iteration order = Jacobian column order).
* DEFAULT_GRADVAC_EPS and configurable eps (constructor + mutable attribute).
* Task order is shuffled with torch.randperm; use torch.manual_seed for reproducibility.

Files

* src/torchjd/aggregation/_gradvac.py, export in __init__.py
* docs/source/docs/aggregation/gradvac.rst + index toctree
* tests/unit/aggregation/test_gradvac.py

Verification

* ruff format / ruff check on touched paths
* ty check on _gradvac.py
* pytest tests/unit/aggregation/test_gradvac.py tests/unit/aggregation/test_values.py -W error
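For reference, the closed-form vaccine update mentioned in the summary can be sketched in plain Python. This follows the closed-form weight from the GradVac paper (Wang et al., ICLR 2021); the function name, values, and eps handling here are illustrative, not TorchJD's actual API:

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def norm(a):
    return math.sqrt(dot(a, a))

def gradvac_step(gi, gj, target, eps=1e-8):
    """If cos(gi, gj) is below the target rho_bar, nudge gi toward gj so that
    the cosine of the result with gj equals the target (paper's closed form)."""
    rho = dot(gi, gj) / max(norm(gi) * norm(gj), eps)
    if rho >= target:
        return gi  # already aligned enough; no vaccine update
    w = (
        norm(gi)
        * (target * math.sqrt(1 - rho**2) - rho * math.sqrt(1 - target**2))
        / max(norm(gj) * math.sqrt(1 - target**2), eps)
    )
    return [a + w * b for a, b in zip(gi, gj)]

# Orthogonal gradients with a target cosine of 0.5:
g1, g2 = [1.0, 0.0], [0.0, 1.0]
g1_new = gradvac_step(g1, g2, target=0.5)
cos_after = dot(g1_new, g2) / (norm(g1_new) * norm(g2))
# cos_after equals the target, as the closed form guarantees.
```

In the actual aggregator, target would be the per-pair (and, with group_type 1 or 2, per-block) EMA statistic \bar{\rho} rather than a fixed constant.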