Skip to content
3 changes: 3 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Stateful
:members: reset

.. autoclass:: torchjd.aggregation.Stochastic
:members: reset


.. toctree::
:hidden:
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._mixins import Stateful, Stochastic

Check failure on line 74 in src/torchjd/aggregation/__init__.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/aggregation/__init__.py:74:32: unresolved-import: Module `torchjd.aggregation._mixins` has no member `Stochastic`
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
Expand Down Expand Up @@ -109,6 +109,7 @@
"Random",
"RandomWeighting",
"Stateful",
"Stochastic",
"Sum",
"SumWeighting",
"TrimmedMean",
Expand Down
14 changes: 10 additions & 4 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._mixins import Stochastic

Check failure on line 9 in src/torchjd/aggregation/_graddrop.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/aggregation/_graddrop.py:9:22: unresolved-import: Module `torchjd.aggregation._mixins` has no member `Stochastic`
from ._utils.non_differentiable import raise_non_differentiable_error


def _identity(P: Tensor) -> Tensor:
    """Default transformation for GradDrop's probability tensor: returns ``P`` unchanged."""
    return P


class GradDrop(Aggregator):
class GradDrop(Aggregator, Stochastic):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
Expand All @@ -24,16 +25,21 @@
increasing. Defaults to identity.
:param leak: The tensor of leak values, determining how much each row is allowed to leak
through. Defaults to None, which means no leak.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
def __init__(
self, f: Callable = _identity, leak: Tensor | None = None, seed: int | None = None
) -> None:
if leak is not None and leak.dim() != 1:
raise ValueError(
"Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
f"{leak.shape}`.",
)

super().__init__()
Aggregator.__init__(self)
Stochastic.__init__(self, seed=seed)
self.f = f
self.leak = leak

Expand All @@ -50,7 +56,7 @@

P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0))
fP = self.f(P)
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device)
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device, generator=self.generator)

vector = torch.zeros_like(matrix[0])
for i in range(len(matrix)):
Expand Down
41 changes: 21 additions & 20 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation._mixins import Stateful
from torchjd.aggregation._mixins import Stochastic

Check failure on line 9 in src/torchjd/aggregation/_gradvac.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/aggregation/_gradvac.py:9:41: unresolved-import: Module `torchjd.aggregation._mixins` has no member `Stochastic`

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVacWeighting(Weighting[PSDMatrix], Stateful):
class GradVacWeighting(Weighting[PSDMatrix], Stochastic):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._mixins.Stochastic`
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

Expand All @@ -37,10 +37,13 @@

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
super().__init__()
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
Weighting.__init__(self)
Stochastic.__init__(self, seed=seed)
if not (0.0 <= beta <= 1.0):
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
if eps < 0.0:
Expand Down Expand Up @@ -72,8 +75,9 @@
self._eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""
"""Resets the random number generator and clears the EMA state."""

Stochastic.reset(self)
self._phi_t = None
self._state_key = None

Expand Down Expand Up @@ -101,7 +105,7 @@
cG = C[i] @ G

others = [j for j in range(m) if j != i]
perm = torch.randperm(len(others))
perm = torch.randperm(len(others), generator=self.generator)
shuffled_js = [others[idx] for idx in perm.tolist()]

for j in shuffled_js:
Expand Down Expand Up @@ -133,9 +137,9 @@
self._state_key = key


class GradVac(GramianWeightedAggregator, Stateful):
class GradVac(GramianWeightedAggregator, Stochastic):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._mixins.Stochastic`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
Expand All @@ -155,22 +159,18 @@

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.

.. note::
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
you need reproducibility.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.

.. note::
To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping
strategy, please refer to the :doc:`Grouping </examples/grouping>` examples.
"""

gramian_weighting: GradVacWeighting

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
GramianWeightedAggregator.__init__(self, weighting)
Stochastic.__init__(self, generator=weighting.generator)
self._gradvac_weighting = weighting
self.register_full_backward_pre_hook(raise_non_differentiable_error)

Expand All @@ -191,8 +191,9 @@
self._gradvac_weighting.eps = value

def reset(self) -> None:
"""Clears EMA state so the next forward starts from zero targets."""
"""Resets the random number generator and clears the EMA state."""

Stochastic.reset(self)
self._gradvac_weighting.reset()

def __repr__(self) -> str:
Expand Down
47 changes: 45 additions & 2 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,52 @@
from abc import ABC, abstractmethod

import torch


class Stateful(ABC):
"""Mixin adding a reset method."""
r"""
Mixin for stateful mappings.

A maping implements `Stateful` **if and only if** its behavior depends on an internal
state.

Formally, a stateless mapping is a function :math:`f : x \mapsto y` whereas a stateful
maping is a transition map :math:`A : (x, s) \mapsto (y, s')` where :math:`s` is the
internal state, :math:`s'` the updated state, and :math:`y` the output.
There exists an initial state :math:`s_0`, and the method `reset()` restores the state to
:math:`s_0`. A `Stateful` mapping must be constructed with the intial state :math:`s_0`.
"""

@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""
"""Resets the internal state :math:`s_0`."""


class StochasticState(Stateful):
r"""
State respresenting stochasticity.

Internally, a ``StochasticState`` mapping holds a :class:`torch.Generator` that serves as an
independent random number stream.

:param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn
from the global PyTorch RNG to fork an independent stream.
:param generator: An existing :class:`torch.Generator` to share, typically from a companion
:class:`StochasticState` instance. Mutually exclusive with ``seed``.
"""

def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None:
if generator is not None and seed is not None:
raise ValueError("Parameters `seed` and `generator` are mutually exclusive.")
if generator is not None:
self.generator = generator
else:
self.generator = torch.Generator()
if seed is None:
seed = int(torch.randint(0, 2**62, size=(1,), dtype=torch.int64).item())
self.generator.manual_seed(seed)
self._initial_rng_state = self.generator.get_state()

def reset(self) -> None:
"""Resets the random number generator to its initial state."""
self.generator.set_state(self._initial_rng_state)
24 changes: 17 additions & 7 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation import Stateful
from torchjd.aggregation._mixins import StochasticState

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class PCGradWeighting(Weighting[PSDMatrix]):
class PCGradWeighting(Weighting[PSDMatrix], Stateful):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.PCGrad`.

:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, seed: int | None = None) -> None:
super().__init__()
self.state = StochasticState(seed=seed)

def forward(self, gramian: PSDMatrix, /) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = gramian.device
Expand All @@ -27,7 +36,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
weights = torch.zeros(dimension, device=cpu, dtype=dtype)

for i in range(dimension):
permutation = torch.randperm(dimension)
permutation = torch.randperm(dimension, generator=self.state.generator)
current_weights = torch.zeros(dimension, device=cpu, dtype=dtype)
current_weights[i] = 1.0

Expand All @@ -46,16 +55,17 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)


class PCGrad(GramianWeightedAggregator, Stateful):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
    `Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.

    :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
        the global PyTorch RNG to fork an independent stream.
    """

    # NOTE(review): this class uses `Stateful` as a base — confirm the module still imports it.
    def __init__(self, seed: int | None = None) -> None:
        weighting = PCGradWeighting(seed=seed)
        super().__init__(weighting)
        self._pcgrad_weighting = weighting

        # This prevents running into a RuntimeError due to modifying stored tensors in place.
        self.register_full_backward_pre_hook(raise_non_differentiable_error)

    def reset(self) -> None:
        """Resets the internal random number generator to its initial state."""
        # Stateful.reset is abstract; rewind the weighting's random state so the next forward
        # replays the same sequence of permutations.
        self._pcgrad_weighting.state.reset()
27 changes: 20 additions & 7 deletions src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,43 @@
from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._mixins import Stochastic

Check failure on line 8 in src/torchjd/aggregation/_random.py

View workflow job for this annotation

GitHub Actions / Code quality (ty and pre-commit hooks)

ty (unresolved-import)

src/torchjd/aggregation/_random.py:8:22: unresolved-import: Module `torchjd.aggregation._mixins` has no member `Stochastic`
from ._weighting_bases import Weighting


class RandomWeighting(Weighting[Matrix], Stochastic):
    """
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
    at each call.

    :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
        the global PyTorch RNG to fork an independent stream.
    """

    def __init__(self, seed: int | None = None) -> None:
        # Both bases are initialized explicitly so that `seed` reaches Stochastic.
        Weighting.__init__(self)
        Stochastic.__init__(self, seed=seed)

    def forward(self, matrix: Tensor, /) -> Tensor:
        # Draw Gaussian logits from the internal stream (never the global RNG) and softmax
        # them, yielding strictly positive weights that sum to one.
        random_vector = torch.randn(
            matrix.shape[0], device=matrix.device, dtype=matrix.dtype, generator=self.generator
        )
        weights = F.softmax(random_vector, dim=-1)
        return weights


class Random(WeightedAggregator, Stochastic):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of
    the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
    Random Weighting: A Litmus Test for Multi-Task Learning
    <https://arxiv.org/pdf/2111.10603.pdf>`_.

    :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
        the global PyTorch RNG to fork an independent stream.
    """

    def __init__(self, seed: int | None = None) -> None:
        weighting = RandomWeighting(seed=seed)
        WeightedAggregator.__init__(self, weighting)
        # Share the weighting's generator so that resetting either object rewinds the same
        # random stream.
        Stochastic.__init__(self, generator=weighting.generator)
32 changes: 31 additions & 1 deletion tests/unit/aggregation/_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch.testing import assert_close
from utils.tensors import rand_, randperm_

from torchjd.aggregation import Aggregator
from torchjd.aggregation import Aggregator, Stateful
from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError


Expand Down Expand Up @@ -110,3 +110,33 @@ def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None:
vector = aggregator(matrix)
with raises(NonDifferentiableError):
vector.backward(torch.ones_like(vector))


def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None:
    """
    Test that a given `Aggregator` is stateful. Specifically:
    - For a fixed state, the aggregator is deterministic on the matrix
    - The reset method and the constructor both set the state to the initial state
    """

    assert isinstance(aggregator, Stateful)

    # Two consecutive calls advance the internal state; after reset() the same two-call
    # sequence must replay identically, element for element.
    first_pair = (aggregator(matrix), aggregator(matrix))
    aggregator.reset()
    second_pair = (aggregator(matrix), aggregator(matrix))

    # Exact equality (atol=rtol=0): replaying from the initial state must be bit-identical.
    assert_close(first_pair[0], second_pair[0], atol=0.0, rtol=0.0)
    assert_close(first_pair[1], second_pair[1], atol=0.0, rtol=0.0)


def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None:
    """
    Test that a given `Aggregator` is stateless. Specifically, it must be deterministic.
    """

    assert not isinstance(aggregator, Stateful)

    # With no internal state, two identical calls must produce bit-identical results.
    first = aggregator(matrix)
    second = aggregator(matrix)

    assert_close(first, second, atol=0.0, rtol=0.0)
Loading
Loading