From 400dfa6ebc4264fed2388617c10687771d608b09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 29 Oct 2024 22:56:47 +0100 Subject: [PATCH 1/6] Add lightning to test dependencies --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index a40e846cb..e48689fdb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,7 @@ doc = [ test = [ "pytest>=7.3", # Before version 7.3, not all tests are run + "lightning", ] plot = [ From e248eda1d3062598f54c6286799dbb5fd0df94b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 29 Oct 2024 22:57:28 +0100 Subject: [PATCH 2/6] Add pytorch_lightning usage example --- docs/source/examples/index.rst | 4 ++ .../source/examples/lightning_integration.rst | 70 +++++++++++++++++++ tests/doc/test_rst.py | 59 ++++++++++++++++ 3 files changed, 133 insertions(+) create mode 100644 docs/source/examples/lightning_integration.rst diff --git a/docs/source/examples/index.rst b/docs/source/examples/index.rst index ff3c068dd..5278cd760 100644 --- a/docs/source/examples/index.rst +++ b/docs/source/examples/index.rst @@ -13,6 +13,9 @@ This section contains some usage examples for TorchJD. - :doc:`Multi-Task Learning (MTL) <mtl>` provides an example of multi-task learning where Jacobian descent is used to optimize the vector of per-task losses of a multi-task model, using the dedicated backpropagation function :doc:`mtl_backward <../docs/autojac/mtl_backward>`. +- :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine + TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task + ``LightningModule`` optimized by Jacobian descent. .. toctree:: :hidden: @@ -20,3 +23,4 @@ This section contains some usage examples for TorchJD. 
basic_usage.rst iwrm.rst mtl.rst + lightning_integration.rst diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst new file mode 100644 index 000000000..0095e4b6e --- /dev/null +++ b/docs/source/examples/lightning_integration.rst @@ -0,0 +1,70 @@ +PyTorch Lightning Integration +============================= + +To make a step of Jacobian descent with TorchJD in a :class:`~lightning.LightningModule`, you simply +have to disable ``automatic_optimization`` and to override the ``training_step`` method. + +The following code provides an example implementation for multi-task learning using a +:class:`~lightning.LightningModule`. + +.. code-block:: python + :emphasize-lines: 9-10, 18, 32-38 + + import torch + from lightning import LightningModule, Trainer + from lightning.pytorch.utilities.types import OptimizerLRScheduler + from torch.nn import Linear, ReLU, Sequential + from torch.nn.functional import mse_loss + from torch.optim import Adam + from torch.utils.data import DataLoader, TensorDataset + + from torchjd import mtl_backward + from torchjd.aggregation import UPGrad + + class Model(LightningModule): + def __init__(self): + super().__init__() + self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + self.task1_head = Linear(3, 1) + self.task2_head = Linear(3, 1) + self.automatic_optimization = False + + def training_step(self, batch, batch_idx) -> None: + input, target1, target2 = batch + + features = self.feature_extractor(input) + output1 = self.task1_head(features) + output2 = self.task2_head(features) + + loss1 = mse_loss(output1, target1) + loss2 = mse_loss(output2, target2) + + opt = self.optimizers() + opt.zero_grad() + mtl_backward( + losses=[loss1, loss2], + features=features, + tasks_params=[self.task1_head.parameters(), self.task2_head.parameters()], + shared_params=self.feature_extractor.parameters(), + A=UPGrad(), + ) + opt.step() + + def configure_optimizers(self) -> 
OptimizerLRScheduler: + optimizer = Adam(self.parameters(), lr=1e-3) + return optimizer + + model = Model() + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 + task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task + task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task + dataset = TensorDataset(inputs, task1_targets, task2_targets) + train_loader = DataLoader(dataset) + trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False) + trainer.fit(model=model, train_dataloaders=train_loader) + +.. warning:: + This will not handle scaling in low-precision settings. There is currently no easy fix. + +.. warning:: + Make sure that your model is not compiled. TorchJD is not compatible with compiled models. diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index afa7ef02a..cbb1c99ec 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -116,3 +116,62 @@ def test_mtl(): A=A, ) optimizer.step() + + +def test_lightning_integration(): + import warnings + + warnings.filterwarnings("ignore") + + import torch + from lightning import LightningModule, Trainer + from lightning.pytorch.utilities.types import OptimizerLRScheduler + from torch.nn import Linear, ReLU, Sequential + from torch.nn.functional import mse_loss + from torch.optim import Adam + from torch.utils.data import DataLoader, TensorDataset + + from torchjd import mtl_backward + from torchjd.aggregation import UPGrad + + class Model(LightningModule): + def __init__(self): + super().__init__() + self.feature_extractor = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) + self.task1_head = Linear(3, 1) + self.task2_head = Linear(3, 1) + self.automatic_optimization = False + + def training_step(self, batch, batch_idx) -> None: + input, target1, target2 = batch + + features = self.feature_extractor(input) + output1 = self.task1_head(features) + output2 = 
self.task2_head(features) + + loss1 = mse_loss(output1, target1) + loss2 = mse_loss(output2, target2) + + opt = self.optimizers() + opt.zero_grad() + mtl_backward( + losses=[loss1, loss2], + features=features, + tasks_params=[self.task1_head.parameters(), self.task2_head.parameters()], + shared_params=self.feature_extractor.parameters(), + A=UPGrad(), + ) + opt.step() + + def configure_optimizers(self) -> OptimizerLRScheduler: + optimizer = Adam(self.parameters(), lr=1e-3) + return optimizer + + model = Model() + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 + task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task + task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task + dataset = TensorDataset(inputs, task1_targets, task2_targets) + train_loader = DataLoader(dataset) + trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False) + trainer.fit(model=model, train_dataloaders=train_loader) From 48b6015d74a2b6856b77c314438d2dca07f477a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 30 Oct 2024 12:35:26 +0100 Subject: [PATCH 3/6] Add more blank lines in lightning_integration --- docs/source/examples/lightning_integration.rst | 3 +++ tests/doc/test_rst.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index 0095e4b6e..abc266d4d 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -55,12 +55,15 @@ The following code provides an example implementation for multi-task learning us return optimizer model = Model() + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second 
task + dataset = TensorDataset(inputs, task1_targets, task2_targets) train_loader = DataLoader(dataset) trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False) + trainer.fit(model=model, train_dataloaders=train_loader) .. warning:: diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index cbb1c99ec..ae63097a6 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -168,10 +168,13 @@ def configure_optimizers(self) -> OptimizerLRScheduler: return optimizer model = Model() + inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 task1_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the first task task2_targets = torch.randn(8, 16, 1) # 8 batches of 16 targets for the second task + dataset = TensorDataset(inputs, task1_targets, task2_targets) train_loader = DataLoader(dataset) trainer = Trainer(accelerator="cpu", max_epochs=1, enable_checkpointing=False, logger=False) + trainer.fit(model=model, train_dataloaders=train_loader) From b016fdf94d851eb63bd230d0eb341e7456657122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 30 Oct 2024 12:44:44 +0100 Subject: [PATCH 4/6] Improve text around the lightning example --- docs/source/examples/lightning_integration.rst | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/source/examples/lightning_integration.rst b/docs/source/examples/lightning_integration.rst index abc266d4d..77ce0e274 100644 --- a/docs/source/examples/lightning_integration.rst +++ b/docs/source/examples/lightning_integration.rst @@ -1,11 +1,14 @@ PyTorch Lightning Integration ============================= -To make a step of Jacobian descent with TorchJD in a :class:`~lightning.LightningModule`, you simply -have to disable ``automatic_optimization`` and to override the ``training_step`` method. 
+To use Jacobian descent with TorchJD in a :class:`~lightning.LightningModule`, you need to turn off +automatic optimization by setting ``automatic_optimization`` to ``False`` and customize the +``training_step`` method so that it calls the appropriate TorchJD function (:doc:`backward +<../docs/autojac/backward>` or :doc:`mtl_backward <../docs/autojac/mtl_backward>`). -The following code provides an example implementation for multi-task learning using a -:class:`~lightning.LightningModule`. +The following code example demonstrates a basic multi-task learning setup using a +:class:`~lightning.LightningModule` that will call :doc:`mtl_backward +<../docs/autojac/mtl_backward>` at each training iteration. .. code-block:: python :emphasize-lines: 9-10, 18, 32-38 @@ -67,7 +70,9 @@ The following code provides an example implementation for multi-task learning us trainer.fit(model=model, train_dataloaders=train_loader) .. warning:: - This will not handle scaling in low-precision settings. There is currently no easy fix. + This will not handle automatic scaling in low-precision settings. There is currently no easy + fix. .. warning:: - Make sure that your model is not compiled. TorchJD is not compatible with compiled models. + TorchJD is incompatible with compiled models, so you must ensure that your model is not + compiled. From f5fa032e5b3d3561eb8afc83303437c5184000da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 30 Oct 2024 12:45:47 +0100 Subject: [PATCH 5/6] Add changelog entry --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad1be8e51..4dfa05750 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). This changelog does not include internal changes that do not affect the user. 
+## [Unreleased] + +### Added + +- PyTorch Lightning integration example. + ## [0.2.1] - 2024-09-17 ### Changed From 3d0020f1a4f52b1396059aa5fcf1aeb2dcb55a59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 30 Oct 2024 12:58:09 +0100 Subject: [PATCH 6/6] Add lower bound on lightning version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e48689fdb..2c05281d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ doc = [ test = [ "pytest>=7.3", # Before version 7.3, not all tests are run - "lightning", + "lightning>=2.0.9", # No OptimizerLRScheduler public type before 2.0.9 ] plot = [