From 91c91d3d0178134382cb7be594460eaecb4cfe57 Mon Sep 17 00:00:00 2001 From: rkhosrowshahi Date: Wed, 15 Apr 2026 16:30:28 -0400 Subject: [PATCH 1/6] docs: restore Grouping example for per-parameter-group aggregation Re-add the GradVac grouping documentation (whole model, encoder-decoder, per-layer, per-tensor) with doctest snippets, link it from the examples index, and cross-link from GradVac. Changelog: note the new example under Unreleased. --- CHANGELOG.md | 1 + docs/source/examples/grouping.rst | 167 ++++++++++++++++++++++++++++ docs/source/examples/index.rst | 4 + src/torchjd/aggregation/_gradvac.py | 8 ++ 4 files changed, 180 insertions(+) create mode 100644 docs/source/examples/grouping.rst diff --git a/CHANGELOG.md b/CHANGELOG.md index 5104aa77e..e2692a86f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ changelog does not include internal changes that do not affect the user. - Added `GradVac` and `GradVacWeighting` from [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874). +- Documented per-parameter-group aggregation (GradVac-style grouping) in a new Grouping example. ### Fixed diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst new file mode 100644 index 000000000..aa9e85aad --- /dev/null +++ b/docs/source/examples/grouping.rst @@ -0,0 +1,167 @@ +Grouping +======== + +When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in +multi-task learning, the cosine similarities between task gradients can be computed at different +granularities. The GradVac paper introduces four strategies, each partitioning the shared +parameter vector differently: + +1. **Whole Model** (default) — one group covering all shared parameters. +2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately). +3. **All Layers** — one group per leaf module of the encoder. +4. 
**All Matrices** — one group per individual parameter tensor. + +In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group +after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group. +For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance +independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from +the original paper. + +.. note:: + The grouping is orthogonal to the choice of + :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions + determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians + are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared + parameters corresponds to the Whole Model strategy. Splitting those parameters into + sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a + dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the + Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures. + +.. note:: + The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to + any aggregator. + +1. Whole Model +-------------- + +A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters +together. Cosine similarities are computed between the full task gradient vectors. + +.. 
testcode:: + :emphasize-lines: 14, 19 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + gradvac = GradVac() + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + jac_to_grad(encoder.parameters(), gradvac) + optimizer.step() + optimizer.zero_grad() + +2. Encoder-Decoder +------------------ + +One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model +is split into an encoder and a decoder; cosine similarities are computed separately within each. +Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks +to receive Jacobians, which are then aggregated independently. + +.. 
testcode:: + :emphasize-lines: 8-9, 15-16, 22-23 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU()) + decoder = Sequential(Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + encoder_gradvac = GradVac() + decoder_gradvac = GradVac() + + for x, y1, y2 in zip(inputs, t1, t2): + enc_out = encoder(x) + dec_out = decoder(enc_out) + mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out) + jac_to_grad(encoder.parameters(), encoder_gradvac) + jac_to_grad(decoder.parameters(), decoder_gradvac) + optimizer.step() + optimizer.zero_grad() + +3. All Layers +------------- + +One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are +computed between the per-layer blocks of the task gradients. + +.. 
testcode:: + :emphasize-lines: 14-15, 20-21 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())] + gradvacs = [GradVac() for _ in leaf_layers] + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + for layer, gradvac in zip(leaf_layers, gradvacs): + jac_to_grad(layer.parameters(), gradvac) + optimizer.step() + optimizer.zero_grad() + +4. All Matrices +--------------- + +One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine +similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and +biases of each layer are treated as separate groups). + +.. 
testcode:: + :emphasize-lines: 14-15, 20-21 + + import torch + from torch.nn import Linear, MSELoss, ReLU, Sequential + from torch.optim import SGD + + from torchjd.aggregation import GradVac + from torchjd.autojac import jac_to_grad, mtl_backward + + encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + task1_head, task2_head = Linear(3, 1), Linear(3, 1) + optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1) + loss_fn = MSELoss() + inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) + + shared_params = list(encoder.parameters()) + gradvacs = [GradVac() for _ in shared_params] + + for x, y1, y2 in zip(inputs, t1, t2): + features = encoder(x) + mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) + for param, gradvac in zip(shared_params, gradvacs): + jac_to_grad([param], gradvac) + optimizer.step() + optimizer.zero_grad() diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index 49c5c1f46..c1f1e836b 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD. - :doc:`PyTorch Lightning Integration ` showcases how to combine TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task ``LightningModule`` optimized by Jacobian descent. +- :doc:`Grouping ` shows how to apply an aggregator independently per parameter group + (e.g. per layer), so that conflict resolution happens at a finer granularity than the full + shared parameter vector. - :doc:`Automatic Mixed Precision ` shows how to combine mixed precision training with TorchJD. .. toctree:: @@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD. 
monitoring.rst lightning_integration.rst amp.rst + grouping.rst diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index efb55f444..a98ee0c28 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -40,6 +40,14 @@ class GradVac(GramianWeightedAggregator, Stateful): For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if you need reproducibility. + + .. note:: + To apply GradVac with per-layer or per-parameter-group granularity, create a separate + :class:`GradVac` instance for each group and call + :func:`~torchjd.autojac.jac_to_grad` once per group after + :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state, + matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See + the :doc:`Grouping ` example for details. """ def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: From fefc53a82e9dcb829a77012f4eca564ce4e10070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 00:43:27 +0200 Subject: [PATCH 2/6] Improve grouping example * Add link to the paper * Simplify some formulations * Rename strategies whole model => together; encoder-decoder => per network; all layers => per layer; all matrices => per tensor * Place a bit less emphasis on GradVac: rename gradvac to aggregator when possible * Create losses in separate lines in the code examples * Remove a few redundant sentences from a note --- docs/source/examples/grouping.rst | 113 ++++++++++++++++-------------- 1 file changed, 62 insertions(+), 51 deletions(-) diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst index aa9e85aad..808e8d2b6 100644 --- a/docs/source/examples/grouping.rst +++ b/docs/source/examples/grouping.rst @@ -3,42 +3,45 @@ Grouping When applying a conflict-resolving aggregator such 
as :class:`~torchjd.aggregation.GradVac` in
 multi-task learning, the cosine similarities between task gradients can be computed at different
-granularities. The GradVac paper introduces four strategies, each partitioning the shared
-parameter vector differently:
+granularities. The [Gradient Vaccine paper](https://arxiv.org/pdf/2010.05874) introduces four
+strategies, each partitioning the shared parameter vector differently:
 
-1. **Whole Model** (default) — one group covering all shared parameters.
-2. **Encoder-Decoder** — one group per top-level sub-network (e.g. encoder and decoder separately).
-3. **All Layers** — one group per leaf module of the encoder.
-4. **All Matrices** — one group per individual parameter tensor.
+1. **Together** (baseline): one group covering all shared parameters. Corresponds to the
+   `whole_model` strategy in the paper.
+
+2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
+   Corresponds to the `enc_dec` strategy in the paper.
+
+3. **Per layer**: one group per leaf module of the encoder. Corresponds to the `all_layer` strategy
+   in the paper.
+
+4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
+   strategy in the paper.
 
 In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
 after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
-For stateful aggregators such as :class:`~torchjd.aggregation.GradVac`, each instance
-independently maintains its own EMA state :math:`\hat{\phi}`, matching the per-block targets from
-the original paper.
+For :class:`~torchjd.aggregation.Stateful` aggregators, each instance independently maintains its
+own state (e.g. the EMA :math:`\hat{\phi}` state in :class:`~torchjd.aggregation.GradVac`), matching
+the per-block targets from the original paper.
 
 .. 
note:: The grouping is orthogonal to the choice of :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. Those functions determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians - are partitioned for aggregation. Calling :func:`~torchjd.autojac.jac_to_grad` once on all shared - parameters corresponds to the Whole Model strategy. Splitting those parameters into - sub-networks and calling :func:`~torchjd.autojac.jac_to_grad` separately on each — with a - dedicated aggregator per sub-network — gives an arbitrary custom grouping, such as the - Encoder-Decoder strategy described in the GradVac paper for encoder-decoder architectures. + are partitioned for aggregation. .. note:: The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to - any aggregator. + any :class:`~torchjd.aggregation.Aggregator`. -1. Whole Model --------------- +1. Together +----------- -A single :class:`~torchjd.aggregation.GradVac` instance aggregates all shared parameters +A single :class:`~torchjd.aggregation.Aggregator` instance aggregates all shared parameters together. Cosine similarities are computed between the full task gradient vectors. .. testcode:: - :emphasize-lines: 14, 19 + :emphasize-lines: 14, 21 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -53,25 +56,27 @@ together. 
Cosine similarities are computed between the full task gradient vector loss_fn = MSELoss() inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) - gradvac = GradVac() + aggregator = GradVac() for x, y1, y2 in zip(inputs, t1, t2): features = encoder(x) - mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) - jac_to_grad(encoder.parameters(), gradvac) + loss1 = loss_fn(task1_head(features), y1) + loss2 = loss_fn(task2_head(features), y2) + mtl_backward([loss1, loss2], features=features) + jac_to_grad(encoder.parameters(), aggregator) optimizer.step() optimizer.zero_grad() -2. Encoder-Decoder ------------------- +2. Per network +-------------- -One :class:`~torchjd.aggregation.GradVac` instance per top-level sub-network. Here the model +One :class:`~torchjd.aggregation.Aggregator` instance per top-level sub-network. Here the model is split into an encoder and a decoder; cosine similarities are computed separately within each. Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks to receive Jacobians, which are then aggregated independently. .. testcode:: - :emphasize-lines: 8-9, 15-16, 22-23 + :emphasize-lines: 8-9, 15-16, 24-25 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -87,26 +92,28 @@ to receive Jacobians, which are then aggregated independently. 
loss_fn = MSELoss() inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) - encoder_gradvac = GradVac() - decoder_gradvac = GradVac() + encoder_aggregator = GradVac() + decoder_aggregator = GradVac() for x, y1, y2 in zip(inputs, t1, t2): enc_out = encoder(x) dec_out = decoder(enc_out) - mtl_backward([loss_fn(task1_head(dec_out), y1), loss_fn(task2_head(dec_out), y2)], features=dec_out) - jac_to_grad(encoder.parameters(), encoder_gradvac) - jac_to_grad(decoder.parameters(), decoder_gradvac) + loss1 = loss_fn(task1_head(dec_out), y1) + loss2 = loss_fn(task2_head(dec_out), y2) + mtl_backward([loss1, loss2], features=dec_out) + jac_to_grad(encoder.parameters(), encoder_aggregator) + jac_to_grad(decoder.parameters(), decoder_aggregator) optimizer.step() optimizer.zero_grad() -3. All Layers -------------- +3. Per layer +------------ -One :class:`~torchjd.aggregation.GradVac` instance per leaf module. Cosine similarities are -computed between the per-layer blocks of the task gradients. +One :class:`~torchjd.aggregation.Aggregator` instance per leaf module. Cosine similarities are +computed per-layer between the task gradients. .. testcode:: - :emphasize-lines: 14-15, 20-21 + :emphasize-lines: 14-15, 22-23 import torch from torch.nn import Linear, MSELoss, ReLU, Sequential @@ -121,26 +128,28 @@ computed between the per-layer blocks of the task gradients. 
loss_fn = MSELoss()
     inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)
 
-    leaf_layers = [m for m in encoder.modules() if not list(m.children()) and list(m.parameters())]
-    gradvacs = [GradVac() for _ in leaf_layers]
+    leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
+    aggregators = [GradVac() for _ in leaf_layers]
 
     for x, y1, y2 in zip(inputs, t1, t2):
         features = encoder(x)
-        mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features)
-        for layer, gradvac in zip(leaf_layers, gradvacs):
-            jac_to_grad(layer.parameters(), gradvac)
+        loss1 = loss_fn(task1_head(features), y1)
+        loss2 = loss_fn(task2_head(features), y2)
+        mtl_backward([loss1, loss2], features=features)
+        for layer, aggregator in zip(leaf_layers, aggregators):
+            jac_to_grad(layer.parameters(), aggregator)
         optimizer.step()
         optimizer.zero_grad()
 
-4. All Matrices
----------------
+4. Per tensor
+-------------
 
-One :class:`~torchjd.aggregation.GradVac` instance per individual parameter tensor. Cosine
-similarities are computed between the per-tensor blocks of the task gradients (e.g. weights and
-biases of each layer are treated as separate groups).
+One :class:`~torchjd.aggregation.Aggregator` instance per individual parameter tensor. Cosine
+similarities are computed per-tensor between the task gradients (e.g. weights and biases of each
+layer are treated as separate groups).
 
 .. testcode::
-    :emphasize-lines: 14-15, 20-21
+    :emphasize-lines: 14-15, 22-23
 
     import torch
     from torch.nn import Linear, MSELoss, ReLU, Sequential
     from torch.optim import SGD
@@ -156,12 +165,14 @@ biases of each layer are treated as separate groups. 
inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1) shared_params = list(encoder.parameters()) - gradvacs = [GradVac() for _ in shared_params] + aggregators = [GradVac() for _ in shared_params] for x, y1, y2 in zip(inputs, t1, t2): features = encoder(x) - mtl_backward([loss_fn(task1_head(features), y1), loss_fn(task2_head(features), y2)], features=features) - for param, gradvac in zip(shared_params, gradvacs): - jac_to_grad([param], gradvac) + loss1 = loss_fn(task1_head(features), y1) + loss2 = loss_fn(task2_head(features), y2) + mtl_backward([loss1, loss2], features=features) + for param, aggregator in zip(shared_params, aggregators): + jac_to_grad([param], aggregator) optimizer.step() optimizer.zero_grad() From 275eeb1290af1af36d7f6058fee6dc57e4c0fdb7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 00:47:40 +0200 Subject: [PATCH 3/6] Rename "shared parameters vector" to "parameter vector" so that it also applies to non-MTL --- docs/source/examples/index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index c1f1e836b..9d0c48849 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -31,7 +31,7 @@ This section contains some usage examples for TorchJD. ``LightningModule`` optimized by Jacobian descent. - :doc:`Grouping ` shows how to apply an aggregator independently per parameter group (e.g. per layer), so that conflict resolution happens at a finer granularity than the full - shared parameter vector. + parameter vector. - :doc:`Automatic Mixed Precision ` shows how to combine mixed precision training with TorchJD. .. 
toctree:: From c9ff9db4a14430b72a7dd5f15ccf58d14dc5445b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 00:48:22 +0200 Subject: [PATCH 4/6] Simplify the note about GradVac grouping and make it use explicitly the names used in GradVac's paper and in LibMTL --- src/torchjd/aggregation/_gradvac.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index a98ee0c28..e593a8eb5 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -42,12 +42,8 @@ class GradVac(GramianWeightedAggregator, Stateful): you need reproducibility. .. note:: - To apply GradVac with per-layer or per-parameter-group granularity, create a separate - :class:`GradVac` instance for each group and call - :func:`~torchjd.autojac.jac_to_grad` once per group after - :func:`~torchjd.autojac.mtl_backward`. Each instance maintains its own EMA state, - matching the per-block targets :math:`\hat{\phi}_{ijk}` from the original paper. See - the :doc:`Grouping ` example for details. + To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping + strategy, please refer to the :doc:`Grouping ` examples. 
""" def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: From d4797fcfd38c53d5e89c9c1cb91df7d04482e821 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 00:52:51 +0200 Subject: [PATCH 5/6] Simplify the introduction of the grouping example --- docs/source/examples/grouping.rst | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst index 808e8d2b6..8bfcee69a 100644 --- a/docs/source/examples/grouping.rst +++ b/docs/source/examples/grouping.rst @@ -1,10 +1,9 @@ Grouping ======== -When applying a conflict-resolving aggregator such as :class:`~torchjd.aggregation.GradVac` in -multi-task learning, the cosine similarities between task gradients can be computed at different -granularities. The [Gradient Vaccine paper](https://arxiv.org/pdf/2010.05874) introduces four -strategies, each partitioning the shared parameter vector differently: +The aggregation can be made independently on groups of parameters, at different granularities. The +[Gradient Vaccine paper](https://arxiv.org/pdf/2010.05874) introduces four strategies to partition +the parameters: 1. **Together** (baseline): one group covering all shared parameters. Corresponds to the `whole_model` stategy in the paper. From df75f0002dbfed6707acb71bdf514dd1d261619f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Thu, 16 Apr 2026 00:59:23 +0200 Subject: [PATCH 6/6] Fix link, a few more improvements / simplifications --- docs/source/examples/grouping.rst | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/examples/grouping.rst b/docs/source/examples/grouping.rst index 8bfcee69a..a04e50a04 100644 --- a/docs/source/examples/grouping.rst +++ b/docs/source/examples/grouping.rst @@ -2,29 +2,29 @@ Grouping ======== The aggregation can be made independently on groups of parameters, at different granularities. 
The
-[Gradient Vaccine paper](https://arxiv.org/pdf/2010.05874) introduces four strategies to partition
+`Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four strategies to partition
 the parameters:
 
-1. **Together** (baseline): one group covering all shared parameters. Corresponds to the
-   `whole_model` strategy in the paper.
+1. **Together** (baseline): one group covering all parameters. Corresponds to the `whole_model`
+   strategy in the paper.
 
 2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
    Corresponds to the `enc_dec` strategy in the paper.
 
-3. **Per layer**: one group per leaf module of the encoder. Corresponds to the `all_layer` strategy
+3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` strategy
    in the paper.
 
 4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
    strategy in the paper.
 
 In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
-after :func:`~torchjd.autojac.mtl_backward`, with a dedicated aggregator instance per group.
-For :class:`~torchjd.aggregation.Stateful` aggregators, each instance independently maintains its
-own state (e.g. the EMA :math:`\hat{\phi}` state in :class:`~torchjd.aggregation.GradVac`), matching
-the per-block targets from the original paper.
+after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
+aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each instance
+should independently maintain its own state (e.g. the EMA :math:`\hat{\phi}` state in
+:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).
 
 .. note::
-    The grouping is orthogonal to the choice of
+    The grouping is orthogonal to the choice between
     :func:`~torchjd.autojac.backward` vs :func:`~torchjd.autojac.mtl_backward`. 
Those functions determine *which* parameters receive Jacobians; grouping then determines *how* those Jacobians are partitioned for aggregation.