diff --git a/deepmd/pt/utils/auto_batch_size.py b/deepmd/pt/utils/auto_batch_size.py index 5f8e0930d3..306d722fad 100644 --- a/deepmd/pt/utils/auto_batch_size.py +++ b/deepmd/pt/utils/auto_batch_size.py @@ -49,20 +49,54 @@ def is_oom_error(self, e: Exception) -> bool: e : Exception Exception """ - # several sources think CUSOLVER_STATUS_INTERNAL_ERROR is another out-of-memory error, - # such as https://github.com/JuliaGPU/CUDA.jl/issues/1924 - # (the meaningless error message should be considered as a bug in cusolver) - if ( - isinstance(e, RuntimeError) - and ( - "CUDA out of memory." in e.args[0] - or "CUDA driver error: out of memory" in e.args[0] - or "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR" in e.args[0] - # https://github.com/deepmodeling/deepmd-kit/issues/4594 - or "CUDA error: out of memory" in e.args[0] - ) - ) or isinstance(e, torch.cuda.OutOfMemoryError): - # Release all unoccupied cached memory + if isinstance(e, torch.cuda.OutOfMemoryError): torch.cuda.empty_cache() return True + + if not isinstance(e, RuntimeError): + return False + + # Gather messages from the exception itself and its chain. AOTInductor + # (.pt2) sometimes strips the underlying OOM message when rewrapping, + # but not always; checking ``__cause__`` / ``__context__`` catches the + # remaining cases when the original error is preserved. + msgs: list[str] = [] + cur: BaseException | None = e + seen: set[int] = set() + while cur is not None and id(cur) not in seen: + seen.add(id(cur)) + if cur.args: + first = cur.args[0] + if isinstance(first, str): + msgs.append(first) + cur = cur.__cause__ or cur.__context__ + + # Several sources treat CUSOLVER_STATUS_INTERNAL_ERROR as an OOM, e.g. + # https://github.com/JuliaGPU/CUDA.jl/issues/1924 + # https://github.com/deepmodeling/deepmd-kit/issues/4594 + plain_oom_markers = ( + "CUDA out of memory.", + "CUDA driver error: out of memory", + "CUDA error: out of memory", + "cusolver error: CUSOLVER_STATUS_INTERNAL_ERROR", + ) + if any(m in msg for msg in msgs for m in plain_oom_markers): + torch.cuda.empty_cache() + return True + + # AOTInductor (.pt2) wraps the underlying CUDA OOM as a generic + # ``run_func_(...) API call failed at .../model_container_runner.cpp``. + # The original "CUDA out of memory" text is printed to stderr only and + # is absent from the Python-level RuntimeError, so we match on the + # wrapper signature. If the root cause turns out to be something + # other than OOM, ``execute()`` will keep shrinking the batch and + # eventually raise ``OutOfMemoryError`` at batch size 1, which is a + # clean failure rather than an uncaught exception. + aoti_wrapped = any( + "run_func_(" in msg and "model_container_runner" in msg for msg in msgs + ) + if aoti_wrapped: + torch.cuda.empty_cache() + return True + return False diff --git a/source/tests/pt/test_auto_batch_size.py b/source/tests/pt/test_auto_batch_size.py index c67a23df52..e7bb69b62e 100644 --- a/source/tests/pt/test_auto_batch_size.py +++ b/source/tests/pt/test_auto_batch_size.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import unittest +from unittest import ( + mock, +) import numpy as np @@ -9,6 +12,36 @@ class TestAutoBatchSize(unittest.TestCase): + @mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache") + def test_is_oom_error_cuda_message(self, empty_cache) -> None: + auto_batch_size = AutoBatchSize(256, 2.0) + + self.assertTrue( + auto_batch_size.is_oom_error(RuntimeError("CUDA out of memory.")) + ) + empty_cache.assert_called_once() + + @mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache") + def test_is_oom_error_empty_runtime_error_from_cuda_oom(self, empty_cache) -> None: + auto_batch_size = AutoBatchSize(256, 2.0) + cause = RuntimeError("CUDA driver error: out of memory") + error = RuntimeError() + error.__cause__ = cause + + self.assertTrue(auto_batch_size.is_oom_error(error)) + empty_cache.assert_called_once() + + @mock.patch("deepmd.pt.utils.auto_batch_size.torch.cuda.empty_cache") + def test_is_oom_error_aoti_wrapper(self, empty_cache) -> None: + auto_batch_size = AutoBatchSize(256, 2.0) + error = RuntimeError( + "run_func_(...) API call failed at " + "/tmp/torchinductor/model_container_runner.cpp" + ) + + self.assertTrue(auto_batch_size.is_oom_error(error)) + empty_cache.assert_called_once() + def test_execute_all(self) -> None: dd0 = np.zeros((10000, 2, 1, 3, 4)) dd1 = np.ones((10000, 2, 1, 3, 4))