RuntimeError during WXFormer training on Casper with PyTorch 2.10.0 #288

@kevinyang-cky

Description

This documents a regression when upgrading to PyTorch 2.10.0.

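Training runs successfully with PyTorch 2.8.0 and 2.9.0, but fails on 2.10.0 during the forward pass. The traceback ends in the `spectral_norm` forward pre-hook, where the power-iteration step `torch.mv(weight_mat.t(), u)` raises a cuBLAS error: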

[rank0]: Traceback (most recent call last):
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel/applications/train_regional_goes.py", line 525, in <module>
[rank0]:     main_cli()
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel/applications/train_regional_goes.py", line 521, in main_cli
[rank0]:     main(world_rank, world_size, conf, backend)
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel/applications/train_regional_goes.py", line 358, in main
[rank0]:     result = trainer.fit(
[rank0]:              ^^^^^^^^^^^^
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel_exp005/credit/trainers/base_trainer.py", line 313, in fit
[rank0]:     train_results = self.train_one_epoch(
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel_exp005/credit/trainers/trainerLES.py", line 244, in train_one_epoch
[rank0]:     y_pred = self.model(x)
[rank0]:              ^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 1666, in forward
[rank0]:     else self._run_ddp_forward(*inputs, **kwargs)
[rank0]:          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/parallel/distributed.py", line 1492, in _run_ddp_forward
[rank0]:     return self.module(*inputs, **kwargs)  # type: ignore[index]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel_exp005/credit/models/wxformer/crossformer.py", line 665, in forward
[rank0]:     x = cel(x)
[rank0]:         ^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
[rank0]:     return self.checkpoint_fn(  # type: ignore[misc]
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/_compile.py", line 54, in inner
[rank0]:     return disable_fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 1181, in _fn
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/utils/checkpoint.py", line 512, in checkpoint
[rank0]:     ret = function(*args, **kwargs)
[rank0]:           ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1787, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel_exp005/credit/models/wxformer/crossformer.py", line 187, in forward
[rank0]:     fmaps = tuple(map(lambda conv: conv(x), self.convs))
[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/u/home/kevinyang/miles-credit_multichannel_exp005/credit/models/wxformer/crossformer.py", line 187, in <lambda>
[rank0]:     fmaps = tuple(map(lambda conv: conv(x), self.convs))
[rank0]:                                    ^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1776, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1882, in _call_impl
[rank0]:     return inner()
[rank0]:            ^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1819, in inner
[rank0]:     args_result = hook(self, args)
[rank0]:                   ^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/utils/spectral_norm.py", line 129, in __call__
[rank0]:     self.compute_weight(module, do_power_iteration=module.training),
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/glade/work/kevinyang/conda-envs/miles-credit-casper/lib/python3.12/site-packages/torch/nn/utils/spectral_norm.py", line 104, in compute_weight
[rank0]:     torch.mv(weight_mat.t(), u), dim=0, eps=self.eps, out=v
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: CUBLAS_STATUS_INVALID_VALUE when calling `cublasSgemv(handle, op, m, n, &alpha, a, lda, x, incx, &beta, y, incy)`
[rank0]:[W306 10:32:52.126109338 ProcessGroupNCCL.cpp:1553] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
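
To help narrow this down, here is a minimal sketch that follows the same call path as the traceback: a `spectral_norm`-wrapped convolution called in training mode under activation checkpointing, outside of DDP and the trainer. It has not been confirmed to reproduce the error. The helper names `build_block` and `run_once`, the layer sizes, and the input shape are all hypothetical and not taken from the WXFormer config, and `use_reentrant=False` is an assumption about how the checkpoint wrapper is configured.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def build_block(in_ch: int = 4, out_ch: int = 8) -> nn.Module:
    # Hypothetical stand-in for one of the convs in the failing layer;
    # the real channel counts and kernel sizes are not known from the issue.
    conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
    # Same utility that appears in the traceback (torch/nn/utils/spectral_norm.py).
    return nn.utils.spectral_norm(conv)


def run_once(device: str) -> torch.Tensor:
    torch.manual_seed(0)
    # train() makes the forward pre-hook run power iteration, i.e. torch.mv.
    block = build_block().to(device).train()
    x = torch.randn(2, 4, 16, 16, device=device, requires_grad=True)
    # Assumption: non-reentrant checkpointing, as a guess at the wrapper setup.
    y = checkpoint(block, x, use_reentrant=False)
    y.sum().backward()
    return y


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("torch", torch.__version__, "device", device)
    y = run_once(device)
    print("output shape", tuple(y.shape))
```

If this runs cleanly on a Casper GPU node under 2.10.0, the failure probably depends on something the sketch leaves out (DDP, the actual weight shapes, or mixed precision). If it fails with the same `CUBLAS_STATUS_INVALID_VALUE`, it could serve as a standalone reproducer.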
