From f5ee3e303bfeecb624c2165bf95dd88c5434ee2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 11:59:42 +0100 Subject: [PATCH 01/18] remove hybrid_cache --- .../onnx_export_serialization.py | 53 ++++++++++++------- .../serialization/transformers_impl.py | 46 ++++++++-------- 2 files changed, 55 insertions(+), 44 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 7b81277c..4e116a07 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -7,14 +7,9 @@ from transformers.cache_utils import DynamicCache, StaticCache try: - from transformers.cache_utils import ( - EncoderDecoderCache, - HybridCache, - SlidingWindowCache, - ) + from transformers.cache_utils import EncoderDecoderCache, SlidingWindowCache except ImportError: EncoderDecoderCache = None - HybridCache = None SlidingWindowCache = None from ..helpers import string_type from .serialization import _lower_name_with_ @@ -36,6 +31,15 @@ def get_mamba_cache_cls() -> type: return None +def get_hybrid_cache_cls() -> type: + try: + from transformers.cache_utils import HybridCache + + return HybridCache + except ImportError: + return None + + def register_class_serialization( cls, f_flatten: Callable, @@ -179,12 +183,6 @@ def serialization_functions( flatten_dynamic_cache, unflatten_dynamic_cache, flatten_with_keys_dynamic_cache, - flatten_hybrid_cache, - unflatten_hybrid_cache, - flatten_with_keys_hybrid_cache, - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, @@ -208,14 +206,6 @@ def serialization_functions( # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), verbose=verbose, ), - HybridCache: lambda verbose=verbose: register_class_serialization( - HybridCache, - flatten_hybrid_cache, - unflatten_hybrid_cache, - flatten_with_keys_hybrid_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ), EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( EncoderDecoderCache, flatten_encoder_decoder_cache, @@ -240,6 +230,12 @@ def serialization_functions( } MambaCache = get_mamba_cache_cls() if MambaCache: + from .serialization.transformers_impl import ( + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + ) + transformers_classes[MambaCache] = ( lambda verbose=verbose: register_class_serialization( MambaCache, @@ -249,6 +245,23 @@ def serialization_functions( verbose=verbose, ) ) + HybridCache = get_hybrid_cache_cls() + if HybridCache: + from .serialization.transformers_impl import ( + flatten_hybrid_cache, + unflatten_hybrid_cache, + flatten_with_keys_hybrid_cache, + ) + + transformers_classes[HybridCache] = ( + lambda verbose=verbose: register_class_serialization( + HybridCache, + flatten_hybrid_cache, + unflatten_hybrid_cache, + flatten_with_keys_hybrid_cache, + verbose=verbose, + ) + ) classes.update(transformers_classes) if patch_diffusers: diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 548bc1dd..31611c61 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -1,13 +1,7 @@ import itertools from typing import Any, Callable, List, Set, Tuple import torch -from transformers.cache_utils import ( - Cache, - DynamicCache, - EncoderDecoderCache, - HybridCache, - StaticCache, -) +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache try: from transformers.cache_utils import SlidingWindowCache @@ -15,6 +9,11 @@ SlidingWindowCache = None +try: + from transformers.cache_utils import HybridCache +except ImportError: + HybridCache = None + try: from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: @@ -99,26 +98,25 @@ def unflatten_dynamic_cache( # HybridCache ############# +if HybridCache: -def flatten_hybrid_cache( - cache: HybridCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - return _flatten_key_value_cache(cache) - - -def flatten_with_keys_hybrid_cache( - cache: HybridCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - return _flatten_with_keys_cache(cache) + def flatten_hybrid_cache( + cache: HybridCache, + ) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" + return _flatten_key_value_cache(cache) + def flatten_with_keys_hybrid_cache( + cache: HybridCache, + ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" + return _flatten_with_keys_cache(cache) -def unflatten_hybrid_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> HybridCache: - """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" - return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) + def unflatten_hybrid_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None + ) -> HybridCache: + """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" + return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) ############# From 2164b8dd7166a0d682a429034c28a7d0b88c7a50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 12:00:49 +0100 Subject: [PATCH 02/18] ch --- CHANGELOGS.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 825771cd..a82c5dd9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.9 +++++ +* :pr:`379`: remove hybrid_cache for ``transformers>=5`` + 0.8.8 +++++ From 9663b7f1a87ffcb8934776fb35f3a8c8f6d6fbc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 12:10:27 +0100 Subject: [PATCH 03/18] fix unit test --- _unittests/ut_helpers/test_cache_helper.py | 5 +++ .../onnx_export_serialization.py | 41 +++++++++++++------ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index db6e2781..7d018b01 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -167,6 +167,9 @@ def test_make_mamba_cache(self): ) self.assertEqual(0, max_diff(cache, cache)["abs"]) + @unittest.skipIf( + not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5" + ) def test_make_sliding_window_cache(self): cache = make_sliding_window_cache( [ @@ -223,6 +226,7 @@ def test_unflatten_flatten_static_cache(self): self.string_type(unflat, with_shape=True), ) + @unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5") def test_make_hybrid_cache(self): cache = make_hybrid_cache( [ @@ -240,6 +244,7 @@ def test_make_hybrid_cache(self): self.assertEqual(0, max_diff(cache, cache)["abs"]) self.assertEqual(0, max_diff(cache, torch_deepcopy(cache))["abs"]) + @unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5") def test_unflatten_flatten_hybrid_cache(self): with torch_export_patches(patch_transformers=True): c2 = make_hybrid_cache( diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 4e116a07..72f07430 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -7,10 +7,9 @@ from transformers.cache_utils import DynamicCache, StaticCache try: - from transformers.cache_utils import EncoderDecoderCache, SlidingWindowCache + from transformers.cache_utils import EncoderDecoderCache except ImportError: EncoderDecoderCache = None - SlidingWindowCache = None from ..helpers import string_type from .serialization import _lower_name_with_ @@ -40,6 +39,15 @@ def get_hybrid_cache_cls() -> type: return None +def get_sliding_window_cache_cls() -> type: + try: + from transformers.cache_utils import SlidingWindowCache + + return SlidingWindowCache + except ImportError: + return None + + def register_class_serialization( cls, f_flatten: Callable, @@ -186,9 +194,6 @@ def serialization_functions( flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, - flatten_sliding_window_cache, - unflatten_sliding_window_cache, - flatten_with_keys_sliding_window_cache, flatten_static_cache, unflatten_static_cache, flatten_with_keys_static_cache, @@ -213,13 +218,6 @@ def serialization_functions( flatten_with_keys_encoder_decoder_cache, verbose=verbose, ), - SlidingWindowCache: lambda verbose=verbose: register_class_serialization( - SlidingWindowCache, - flatten_sliding_window_cache, - unflatten_sliding_window_cache, - flatten_with_keys_sliding_window_cache, - verbose=verbose, - ), StaticCache: lambda verbose=verbose: register_class_serialization( StaticCache, flatten_static_cache, @@ -262,6 +260,25 @@ def serialization_functions( verbose=verbose, ) ) + + SlidingWindowCache = get_sliding_window_cache_cls() + if SlidingWindowCache: + from .serialization.transformers_impl import ( + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + ) + + transformers_classes[SlidingWindowCache] = ( + lambda verbose=verbose: register_class_serialization( + SlidingWindowCache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + verbose=verbose, + ) + ) + classes.update(transformers_classes) if patch_diffusers: From bc995be5b30e2e00334e62fdd5164a9986015290 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 12:13:22 +0100 Subject: [PATCH 04/18] remove unnecessary code --- .../torch_export_patches/onnx_export_serialization.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 72f07430..1b1bad0d 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -333,13 +333,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): - """Undo all registrations.""" - MambaCache = get_mamba_cache_cls() - cls_ensemble = ( - {DynamicCache, EncoderDecoderCache} - | set(undo) - | ({MambaCache} if MambaCache else set()) - ) + cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo) for cls in cls_ensemble: if undo.get(cls.__name__, False): unregister_class_serialization(cls, verbose) From d79d35558b3631fd82140edf6ab142edbae6e271 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 12:58:45 +0100 Subject: [PATCH 05/18] fix --- _unittests/ut_export/test_shape_helper.py | 20 ++++++++++++------ .../test_patch_serialization_transformers.py | 3 +++ .../patches/patch_torch.py | 21 +++++++++++-------- 3 files changed, 29 insertions(+), 15 deletions(-) diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 0615c819..9be69bc9 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -69,12 +69,16 @@ def test_all_dynamic_shape_all_transformers_cache(self): ], ), ( - make_sliding_window_cache( - [ - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - ] + ( + make_sliding_window_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + if make_sliding_window_cache is not None + else None ), [ {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"}, @@ -106,11 +110,15 @@ def test_all_dynamic_shape_all_transformers_cache(self): ] with torch_export_patches(patch_transformers=True): for cache, exds in caches: + if cache is None: + continue with self.subTest(cache_name=cache.__class__.__name__, patch=True): ds = all_dynamic_shapes_from_inputs(cache) self.assertEqual(exds, ds) for cache, exds in caches: + if cache is None: + continue with self.subTest(cache_name=cache.__class__.__name__, patch=False): ds = all_dynamic_shapes_from_inputs(cache) self.assertEqual(exds, ds) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 352be02a..4f2a92c0 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -177,6 +177,7 @@ def test_base_model_output_unflatten_flatten(self): self.assertEqual("#1[T1r3]", self.string_type(unflat)) @ignore_warnings(UserWarning) + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_base_sliding_window_cache_unflatten_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] @@ -187,6 +188,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self): @ignore_warnings(UserWarning) @requires_torch("2.7.99") + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_sliding_window_cache_export(self): class Model(torch.nn.Module): def forward(self, cache): @@ -208,6 +210,7 @@ def forward(self, cache): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_sliding_window_cache_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 63894b51..28d7ede1 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -541,14 +541,17 @@ def compute_concrete_val() -> sympy.Basic: # oblivious_var_to_val will be defined iff we have sizes # with DimDynamic.OBLIVIOUS_SIZE type. # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + var_to_val = getattr( + self, + "unbacked_var_to_val", + getattr(self, "oblivious_var_to_val", False), + ) if ( - self.oblivious_var_to_val - and not ( - correct_hint := orig_expr.xreplace(self.oblivious_var_to_val) - ).free_symbols + var_to_val + and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols and not ( counterfactual_hint := orig_expr.xreplace( - {k: max(2, v) for k, v in self.oblivious_var_to_val.items()} + {k: max(2, v) for k, v in var_to_val.items()} ) ).free_symbols and correct_hint == counterfactual_hint @@ -571,11 +574,11 @@ def compute_concrete_val() -> sympy.Basic: # and if they pass we add a runtime assertions and continue. if ( not ok - and self.unbacked_var_to_val + and var_to_val and not ( - unsound_result := orig_expr.xreplace( - self.unbacked_var_to_val - ).xreplace(self.var_to_val) + unsound_result := orig_expr.xreplace(var_to_val).xreplace( + var_to_val + ) ).free_symbols ): # pyrefly: ignore # unbound-name From 178e780aa4d71ac9f16c6c47784ffa15fd5f89c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 13:12:47 +0100 Subject: [PATCH 06/18] fix --- _unittests/ut_helpers/test_helper.py | 14 ++++++++------ _unittests/ut_helpers/test_torch_helper.py | 1 + _unittests/ut_torch_models/test_validate_models.py | 2 +- onnx_diagnostic/tasks/image_text_to_text.py | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index d2981a91..1aef9297 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -667,12 +667,14 @@ def test_max_diff_caches(self): [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))], max_cache_len=3 ) self.assertEqual(max_diff(cache, cache)["abs"], 0) - cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]) - self.assertEqual(max_diff(cache, cache)["abs"], 0) - cache = make_sliding_window_cache( - [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))] - ) - self.assertEqual(max_diff(cache, cache)["abs"], 0) + if make_hybrid_cache is not None: + cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + if make_sliding_window_cache is not None: + cache = make_sliding_window_cache( + [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))] + ) + self.assertEqual(max_diff(cache, cache)["abs"], 0) cache = make_encoder_decoder_cache(cache, cache) self.assertEqual(max_diff(cache, cache)["abs"], 0) diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index 2c3c2990..36c72c5c 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -344,6 +344,7 @@ def test_torch_deepcopy_base_model_outputs(self): self.assertEqual(hash1, hash2) self.assertGreater(torch_tensor_size(bo), 1) + @unittest.skipIf(make_sliding_window_cache is None, "SlidingWindowCache was removed") def test_torch_deepcopy_sliding_windon_cache(self): cache = make_sliding_window_cache( [ diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index ad6dae7e..23a5ce82 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -62,7 +62,7 @@ def test_validate_microsoft_phi4_reasoning(self): self.assertIn("onnx_filename", data) self.clean_dump() - @requires_transformers("4.53") + @requires_transformers("4.57") @requires_torch("2.8.99") @requires_experimental() @hide_stdout() diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 3b60a8b8..d9f02ff7 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -211,7 +211,7 @@ def _check_(): ), position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)), cache_position=torch.arange(0, sequence_length).to(torch.int64), - past_key_values=make_hybrid_cache( + past_key_values=(make_hybrid_cache or make_dynamic_cache)( [ ( torch.randn( From 43fbd1c050f441da970a4df6f3343adcf0301591 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 14:40:49 +0100 Subject: [PATCH 07/18] fix cache --- CHANGELOGS.rst | 2 +- onnx_diagnostic/helpers/cache_helper.py | 50 +++++++++++++++++-- onnx_diagnostic/helpers/torch_helper.py | 31 ++++++++---- onnx_diagnostic/tasks/image_text_to_text.py | 10 +++- .../serialization/transformers_impl.py | 12 +++++ 5 files changed, 89 insertions(+), 16 deletions(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index a82c5dd9..c0c4b39d 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,7 @@ Change Logs 0.8.9 +++++ -* :pr:`379`: remove hybrid_cache for ``transformers>=5`` +* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``, 0.8.8 +++++ diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index ff0977ab..36fe7679 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -28,12 +28,15 @@ def __init__(self, cache=None): ] self.key_cache = [layer.keys for layer in layers] self.value_cache = [layer.values for layer in layers] + self.cls_layers = [type(lay) for lay in cache.layers] elif cache is not None and hasattr(cache, "key_cache"): self.key_cache = cache.key_cache self.value_cache = cache.value_cache + self.cls_layers = None elif cache is None: self.key_cache = None self.value_cache = None + self.cls_layers = None else: raise NotImplementedError(f"type(cache)={type(cache)}") @@ -156,12 +159,16 @@ def _preprocess_key_value_pairs( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], + cls_layers: Optional[Union[str, List[type]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. This version is valid for ``transformers >= 4.50``. :param key_value_pairs: list of pairs of (key, values) + :param cls_layers: to select the appropriate class to use on each layer, + if specified, sliding_window is ignored, it can be a string + if all layers are expected to follow the same class :return: :class:`transformers.cache_utils.DynamicCache` Example: @@ -192,15 +199,42 @@ def make_dynamic_cache( are supported. """ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) + cls_kwargs = {} + if isinstance(cls_layers, str): + assert hasattr( + transformers.cache_utils, cls_layers + ), f"Unable to find class {cls_layers!r} in transformers.cache_utils" + cls_layer = getattr(transformers.cache_utils, cls_layers) + if cls_layers == "DynamicSlidingWindowLayer": + cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] + assert isinstance( + cls_kwargs["sliding_window"], int + ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" + elif cls_layers is not None: + unique = set(cls_layers) + assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}" + cls_layer = unique.pop() + if cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer: + from .helper import string_type + + assert key_value_pairs and key_value_pairs[0], ( + f"not implemented for key_value_pairs=" + f"{string_type(key_value_pairs, with_shape=True)}" + ) + cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] + assert isinstance( + cls_kwargs["sliding_window"], int + ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" + else: + cls_layer = transformers.cache_utils.DynamicLayer + if ( key_value_pairs and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor) and pv.Version(transformers.__version__) >= pv.Version("4.56") ): cache = transformers.cache_utils.DynamicCache() - cache.layers.extend( - [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs] - ) + cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) for i, layer in enumerate(cache.layers): k, v = key_value_pairs[i][0], key_value_pairs[i][1] layer.dtype = k.dtype @@ -214,7 +248,8 @@ def make_dynamic_cache( ) return finalize_cache(cache) - cache = transformers.cache_utils.DynamicCache(key_value_pairs) + cache = transformers.cache_utils.DynamicCache() + cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): # The cache constructor contains the two following lines # (in cache_utils.py) which append empty layers when the cache is @@ -508,9 +543,16 @@ def get_text_config(self, *args, **kwargs): ) return finalize_cache(cache) + def get_make_hybrid_cache(): + return make_sliding_window_cache + else: make_sliding_window_cache = None # type: ignore[assignment] + def get_make_hybrid_cache(): + return None + + if hasattr(transformers.cache_utils, "HybridCache"): def make_hybrid_cache( diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 2f7978db..756e5589 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -15,9 +15,6 @@ from .cache_helper import ( make_dynamic_cache, make_encoder_decoder_cache, - make_hybrid_cache, - make_sliding_window_cache, - make_mamba_cache, make_static_cache, CacheKeyValue, ) @@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any: return {to_any(t, to_value) for t in value} if type(value) is dict: return {k: to_any(t, to_value) for k, t in value.items()} - if value.__class__.__name__ in {"DynamicCache", "HybridCache"}: - make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache) + if value.__class__.__name__ == "DynamicCache": + cc = CacheKeyValue(value) + return make_dynamic_cache( + list( + zip( + [t.to(to_value) if t is not None else t for t in cc.key_cache], + [t.to(to_value) if t is not None else t for t in cc.value_cache], + ) + ), + cls_layers=cc.cls_layers, + ) + if value.__class__.__name__ in "HybridCache": + from .cache_helper import make_hybrid_cache + cc = CacheKeyValue(value) - return make[value.__class__.__name__]( # type: ignore[operator] + return make_hybrid_cache( list( zip( [t.to(to_value) if t is not None else t for t in cc.key_cache], @@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any: from .cache_helper import CacheKeyValue ca = CacheKeyValue(value) - return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))) + return make_dynamic_cache( + torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers + ) if value.__class__.__name__ == "StaticCache": from .cache_helper import CacheKeyValue @@ -858,12 +869,12 @@ def torch_deepcopy(value: Any) -> Any: max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]), ) if value.__class__.__name__ == "HybridCache": - from .cache_helper import CacheKeyValue + from .cache_helper import CacheKeyValue, make_hybrid_cache ca = CacheKeyValue(value) return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))) if value.__class__.__name__ == "SlidingWindowCache": - from .cache_helper import CacheKeyValue + from .cache_helper import CacheKeyValue, make_sliding_window_cache ca = CacheKeyValue(value) return make_sliding_window_cache( @@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any: torch_deepcopy(value.cross_attention_cache), ) if value.__class__.__name__ == "MambaCache": + from .cache_helper import make_mamba_cache + return make_mamba_cache(list(zip(value.conv_states, value.ssm_states))) if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index d9f02ff7..e998dead 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -1,7 +1,7 @@ import itertools from typing import Any, Callable, Dict, Optional, Tuple import torch -from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache +from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache from ..helpers.config_helper import ( update_config, check_hasattr, @@ -200,6 +200,12 @@ def _check_(): _check_() + make_hybrid_cache = get_make_hybrid_cache() + if make_hybrid_cache is None: + make_hybrid_cache = lambda *args: make_dynamic_cache( # noqa: E731 + *args, cls_layers="DynamicSlidingWindowLayer" + ) + inputs = dict( input_ids=dummies["input_ids"], token_type_ids=dummies["token_type_ids"], @@ -211,7 +217,7 @@ def _check_(): ), position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)), cache_position=torch.arange(0, sequence_length).to(torch.int64), - past_key_values=(make_hybrid_cache or make_dynamic_cache)( + past_key_values=make_hybrid_cache( [ ( torch.randn( diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 31611c61..05e4fa8c 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -77,6 +77,12 @@ def flatten_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + assert not dynamic_cache.layers or all( + lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers + ), ( + f"The serialization does not work yet on other layers " + f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" + ) return _flatten_key_value_cache(dynamic_cache) @@ -84,6 +90,12 @@ def flatten_with_keys_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + assert not dynamic_cache.layers or all( + lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers + ), ( + f"The serialization does not work yet on other layers " + f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" + ) return _flatten_with_keys_cache(dynamic_cache) From ce3097a4f11be886ed0b724e6f9c76ca9bba3109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 14:53:38 +0100 Subject: [PATCH 08/18] fix cache --- onnx_diagnostic/helpers/cache_helper.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 36fe7679..d782da8a 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -42,7 +42,9 @@ def __init__(self, cache=None): def make_dynamic_cache(self): """Does the reverse operation.""" - return make_dynamic_cache(list(zip(self.key_cache, self.value_cache))) + return make_dynamic_cache( + list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers + ) @property def n_layers(self) -> int: @@ -250,6 +252,8 @@ def make_dynamic_cache( cache = transformers.cache_utils.DynamicCache() cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) + for i, layer in enumerate(cache.layers): + layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): # The cache constructor contains the two following lines # (in cache_utils.py) which append empty layers when the cache is @@ -267,6 +271,7 @@ def make_dynamic_cache( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], + cls_layers=None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -298,6 +303,7 @@ def make_dynamic_cache( ) print(string_type(past_key_values, with_shape=True)) """ + assert not cls_layers, "cls_layers cannot be used for transformers<5." key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore for i, (key, value) in enumerate(key_value_pairs): From 738a253b8b5ee44445c9bddaaabbdc86f1347f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 15:07:29 +0100 Subject: [PATCH 09/18] update ci --- .github/workflows/ci.yml | 3 +++ _unittests/ut_torch_models/test_tiny_llms.py | 3 +++ onnx_diagnostic/helpers/cache_helper.py | 25 +++++++++++--------- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de947320..415c853c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,6 +120,9 @@ jobs: - name: pip freeze run: python -m pip freeze + - name: cache + run: PYTHONPATH=. UNITTEST_GOING=1 pytest _unittests/ut_helpers/test_cache_helper.py _unittests/ut_helpers/test_torch_helper.py _unittests/ut_helpers/test_fake_tensor_helper.py _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py _unittests/ut_torch_export_patches/test_patch_transformers.py _unittests/ut_torch_export_patches/test_dynamic_class.py _unittests/ut_torch_export_patches/test_onnx_export_errors.py + - name: tiny-llm torch.export.export run: PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_tiny_llms.py diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index ac37a7b7..d386fe66 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -37,6 +37,9 @@ def test_tiny_llm_export_dynamic(self): dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), ) got = ep.module()(**inputs) + # print("***", self.string_type(expected, with_shape=True, with_min_max=True)) + # print("***", self.string_type(got, with_shape=True, with_min_max=True)) + print(ep) self.assertEqualArrayAny(expected, got) @requires_transformers("4.52") diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index d782da8a..51f19f3a 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -251,16 +251,19 @@ def make_dynamic_cache( return finalize_cache(cache) cache = transformers.cache_utils.DynamicCache() - cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) - for i, layer in enumerate(cache.layers): - layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] - if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): - # The cache constructor contains the two following lines - # (in cache_utils.py) which append empty layers when the cache is - # initialized. We need to remove them. - # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) - # self.append_new_layers(self.num_hidden_layers - 1) - cache.layers[:] = cache.layers[-len(key_value_pairs) :] + if hasattr(cache, "layers"): + cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) + for i, layer in enumerate(cache.layers): + layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] + else: + cache = transformers.cache_utils.DynamicCache(key_value_pairs) + if len(key_value_pairs) < len(cache.layers): + # The cache constructor contains the two following lines + # (in cache_utils.py) which append empty layers when the cache is + # initialized. We need to remove them. + # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) + # self.append_new_layers(self.num_hidden_layers - 1) + cache.layers[:] = cache.layers[-len(key_value_pairs) :] assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), ( f"Unexpected number of layers in the cache ({len(cache.layers)}), " f"{len(key_value_pairs)} expected." @@ -271,7 +274,7 @@ def make_dynamic_cache( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], - cls_layers=None, + cls_layers: Optional[Union[str, List[type]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. From e0a8b11dee11941440225dcaf5a4d449fa59ed62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 15:12:03 +0100 Subject: [PATCH 10/18] fix mypy --- onnx_diagnostic/helpers/cache_helper.py | 9 ++++++--- onnx_diagnostic/tasks/image_text_to_text.py | 5 +---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 51f19f3a..c6103ea6 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -558,9 +558,6 @@ def get_make_hybrid_cache(): else: make_sliding_window_cache = None # type: ignore[assignment] - def get_make_hybrid_cache(): - return None - if hasattr(transformers.cache_utils, "HybridCache"): @@ -723,9 +720,15 @@ def get_text_config(self, *args, **kwargs): ) return finalize_cache(cache) + def get_make_hybrid_cache(): + return make_hybrid_cache + else: make_hybrid_cache = None # type: ignore[assignment] + def get_make_hybrid_cache(): + return None + def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache: """ diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index e998dead..f26b65e5 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -201,10 +201,7 @@ def _check_(): _check_() make_hybrid_cache = get_make_hybrid_cache() - if make_hybrid_cache is None: - make_hybrid_cache = lambda *args: make_dynamic_cache( # noqa: E731 - *args, cls_layers="DynamicSlidingWindowLayer" - ) + assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing" inputs = dict( input_ids=dummies["input_ids"], From 18b08bbbdf7b5e1897c9e516560cdf337071484f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 18:41:54 +0100 Subject: [PATCH 11/18] more unit test --- .../test_dynamic_class.py | 28 +++- .../test_onnx_export_errors.py | 2 +- .../test_patch_serialization_transformers.py | 125 ++++++++++++++++-- 3 files changed, 139 insertions(+), 16 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index dfa0f66c..5eae175e 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -16,6 +16,7 @@ ) from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( torch_export_patches, ) @@ -288,7 +289,7 @@ def test_phi2_export_module(self): data = get_untrained_model_with_inputs("microsoft/phi-2") model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] str_inputs = string_type(inputs, with_shape=True, with_min_max=True) - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) expected = model(**inputs_copied) self.maxDiff = None self.assertEqual(str_inputs, string_type(inputs, with_shape=True, with_min_max=True)) @@ -298,7 +299,7 @@ def test_phi2_export_module(self): string_type(inputs, with_shape=True, with_min_max=True), string_type(inputs_copied, with_shape=True, with_min_max=True), ) - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) self.assertEqual( str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) @@ -307,13 +308,13 @@ def test_phi2_export_module(self): ep = torch.export.export( model, (), - kwargs=inputs, + kwargs=torch_deepcopy(inputs), dynamic_shapes=use_dyn_not_str(dyn_shapes), strict=False, # True works but then the it fails during the execution ) # ep = ep.run_decompositions() mod = ep.module() - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) self.assertEqual( str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) @@ -368,6 +369,25 @@ def call_function(self, target, args, kwargs): got = MyInterpreter(ep.module()).run(*args) self.assertEqualAny(expected, got) + @ignore_warnings(UserWarning) + @requires_torch("2.9") + def test_tiny_llm_export_module(self): + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] + expected = model(**torch_deepcopy(inputs)) + + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + (), + kwargs=torch_deepcopy(inputs), + dynamic_shapes=use_dyn_not_str(dyn_shapes), + strict=False, + ) + mod = ep.module() + got = mod(**torch_deepcopy(inputs)) + self.assertEqualAny(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 2e8a75c4..48863955 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache): model(x, cache) DYN = torch.export.Dim.DYNAMIC - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache = MambaCache(_config(), max_batch_size=2, device="cpu") torch.export.export( Model(), diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 4f2a92c0..214834c3 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -1,7 +1,13 @@ import unittest import torch from transformers.modeling_outputs import BaseModelOutput -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + ignore_warnings, + requires_torch, + requires_transformers, +) +from onnx_diagnostic.helpers import flatten_object from onnx_diagnostic.helpers.cache_helper import ( make_encoder_decoder_cache, make_dynamic_cache, @@ -24,7 +30,7 @@ def test_encoder_decoder_cache_flatten(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]", @@ -42,7 +48,7 @@ def test_encoder_decoder_cache_deepcopy(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -75,7 +81,7 @@ def forward(self, cache): @ignore_warnings(UserWarning) def test_dynamic_cache_flatten(self): cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4,T1s4x4x4]", @@ -114,7 +120,7 @@ def forward(self, cache): @ignore_warnings(UserWarning) def test_dynamic_cache_deepcopy(self): cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -122,7 +128,7 @@ def test_dynamic_cache_deepcopy(self): def test_base_model_output_deepcopy(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) self.assertEqual(bo.__class__.__name__, "BaseModelOutput") - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): bo2 = torch_deepcopy([bo]) self.assertIsInstance(bo2, list) self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput") @@ -131,7 +137,7 @@ def test_base_model_output_deepcopy(self): @ignore_warnings(UserWarning) def test_base_model_output_string_type(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): self.assertEqual( "BaseModelOutput(last_hidden_state:T1s4x4x4)", self.string_type(bo, with_shape=True), @@ -140,7 +146,7 @@ def test_base_model_output_string_type(self): @ignore_warnings(UserWarning) def test_base_model_output_flatten(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(bo) self.assertEqual( "#1[T1s4x4x4]", @@ -182,7 +188,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -215,7 +221,7 @@ def test_sliding_window_cache_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4x4,T1s4x4x4x4]", @@ -247,7 +253,7 @@ def test_static_cache(self): self.string_type(bo, with_shape=True), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): # internal function bo2 = torch_deepcopy([bo]) self.assertIsInstance(bo2, list) @@ -291,6 +297,103 @@ def forward(self, cache): with torch_export_patches(patch_transformers=True, stop_if_static=1): torch.export.export(model, (bo,), dynamic_shapes=(ds,)) + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_flatten_unflatten(self): + values = [ + (torch.rand((2, 3, 4, 4)), torch.rand((2, 4, 3, 4))), + (torch.rand((2, 4, 4, 3)), torch.rand((2, 4, 4, 3))), + ] + cache = make_dynamic_cache(values) + flat_cache = flatten_object(cache) + order_cache = flatten_object(values) + with torch_export_patches(patch_transformers=True): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqualAny(flat_cache, flatten_object(cache2)) + self.assertEqualAny(order_cache, flatten_object(cache2)) + self.assertEqual( + [type(ly) for ly in cache.layers], [type(ly) for ly in cache2.layers] + ) + + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_in_a_model_args(self): + import torch + + class Model(torch.nn.Module): + def forward(self, x, cache): + acc = x.clone() + for i, lay in enumerate(cache.layers): + acc = acc + lay.keys * (i + 1) - lay.values + cache.update(x * (i + 1), x * 2 * (i + 1), i) + return acc, cache + + values = [ + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + ] + cache = make_dynamic_cache(values) + inputs = (torch.rand((2, 4, 1, 4)), cache) + inputs_copied = torch_deepcopy(inputs) + self.assertEqualAny(inputs, inputs_copied) + model = Model() + expected = model(*inputs) + DYN = torch.export.Dim.DYNAMIC + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + torch_deepcopy(inputs_copied), + dynamic_shapes=( + {0: DYN}, + [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}, {0: DYN, 2: DYN}, {0: DYN, 2: DYN}], + ), + ) + got = ep.module()(*inputs_copied) + self.assertEqualAny(expected, got) + + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_in_a_model_kwargs(self): + import torch + + class Model(torch.nn.Module): + def forward(self, x, cache): + acc = x.clone() + for i, lay in enumerate(cache.layers): + acc = acc + lay.keys * (i + 1) - lay.values + cache.update(x * (i + 1), x * 2 * (i + 1), i) + return acc, cache + + values = [ + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + ] + cache = make_dynamic_cache(values) + inputs = dict(x=torch.rand((2, 4, 1, 4)), cache=cache) + inputs_copied = torch_deepcopy(inputs) + self.assertEqualAny(inputs, inputs_copied) + model = Model() + expected = model(**inputs) + DYN = torch.export.Dim.DYNAMIC + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + (), + kwargs=torch_deepcopy(inputs_copied), + dynamic_shapes=dict( + x={0: DYN}, + cache=[ + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + ], + ), + ) + got = ep.module()(**inputs_copied) + self.assertEqualAny(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) From 9fd596f395e35d223a76515b1e3abe6640d9b182 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 19:00:13 +0100 Subject: [PATCH 12/18] add more unittest --- .../test_dynamic_class.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index 5eae175e..472f172a 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -374,9 +374,33 @@ def call_function(self, target, args, kwargs): def test_tiny_llm_export_module(self): data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] + inputs_copied = torch_deepcopy(inputs) + self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"]) + self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"]) + self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"]) + self.assertEqualArray( + inputs["past_key_values"].layers[0].keys, + inputs_copied["past_key_values"].layers[0].keys, + ) + self.assertEqualArray( + inputs["past_key_values"].layers[0].values, + inputs_copied["past_key_values"].layers[0].values, + ) expected = model(**torch_deepcopy(inputs)) with torch_export_patches(patch_transformers=True): + inputs_copied = torch_deepcopy(inputs) + self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"]) + self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"]) + self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"]) + self.assertEqualArray( + inputs["past_key_values"].layers[0].keys, + inputs_copied["past_key_values"].layers[0].keys, + ) + self.assertEqualArray( + inputs["past_key_values"].layers[0].values, + inputs_copied["past_key_values"].layers[0].values, + ) ep = torch.export.export( model, (), From 8d639b1efde97d382a6a2fed06b6c53153192baf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 19:09:34 +0100 Subject: [PATCH 13/18] fix cache --- _unittests/ut_torch_export_patches/test_dynamic_class.py | 2 +- onnx_diagnostic/helpers/cache_helper.py | 3 ++- onnx_diagnostic/helpers/torch_helper.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index 472f172a..183a681a 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -388,7 +388,7 @@ def test_tiny_llm_export_module(self): ) expected = model(**torch_deepcopy(inputs)) - with torch_export_patches(patch_transformers=True): + with torch_export_patches(patch_torch=False, patch_transformers=True): inputs_copied = torch_deepcopy(inputs) self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"]) self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"]) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index c6103ea6..ac647950 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -251,10 +251,11 @@ def make_dynamic_cache( return finalize_cache(cache) cache = transformers.cache_utils.DynamicCache() - if hasattr(cache, "layers"): + if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer: cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) for i, layer in enumerate(cache.layers): layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] + layer.is_initialized = True else: cache = transformers.cache_utils.DynamicCache(key_value_pairs) if len(key_value_pairs) < len(cache.layers): diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 756e5589..59dfa397 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -777,7 +777,7 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any: ), cls_layers=cc.cls_layers, ) - if value.__class__.__name__ in "HybridCache": + if value.__class__.__name__ == "HybridCache": from .cache_helper import make_hybrid_cache cc = CacheKeyValue(value) From 550aea25642686be2b17226f9f461c76e435f3e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 19:18:51 +0100 Subject: [PATCH 14/18] fix --- .../test_patch_serialization_transformers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 214834c3..1285c1e1 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -301,8 +301,8 @@ def forward(self, cache): @requires_transformers("4.99") def test_dynamic_cache_flatten_unflatten(self): values = [ - (torch.rand((2, 3, 4, 4)), torch.rand((2, 4, 3, 4))), - (torch.rand((2, 4, 4, 3)), torch.rand((2, 4, 4, 3))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), ] cache = make_dynamic_cache(values) flat_cache = flatten_object(cache) From 5a85d1f6d1cefc12aefb73cf48abce2faff42ae4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 19:40:13 +0100 Subject: [PATCH 15/18] fix --- .../ut_tasks/test_tasks_image_text_to_text.py | 4 ++++ onnx_diagnostic/helpers/cache_helper.py | 13 ++++++++++--- .../serialization/transformers_impl.py | 12 ++++++++---- 3 files changed, 22 insertions(+), 7 deletions(-) diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index 03ac9e2d..be37aaf0 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -10,6 +10,7 @@ from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str +from onnx_diagnostic.helpers.cache_helper import get_make_hybrid_cache class TestTasksImageTextToText(ExtTestCase): @@ -58,6 +59,9 @@ def test_image_text_to_text_tiny_gemma3(self): @requires_transformers("4.56.99") @requires_torch("2.8.99") def test_image_text_to_text_gemma3_4b_it(self): + make_hybrid_cache = get_make_hybrid_cache() + if make_hybrid_cache is None: + raise unittest.SkipTest("not implemented yet for transformers>=5") mid = "google/gemma-3-4b-it" data = get_untrained_model_with_inputs( mid, diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index ac647950..767c4f1f 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -216,7 +216,10 @@ def make_dynamic_cache( unique = set(cls_layers) assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}" cls_layer = unique.pop() - if cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer: + if ( + hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer") + and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer + ): from .helper import string_type assert key_value_pairs and key_value_pairs[0], ( @@ -228,7 +231,11 @@ def make_dynamic_cache( cls_kwargs["sliding_window"], int ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" else: - cls_layer = transformers.cache_utils.DynamicLayer + cls_layer = ( + transformers.cache_utils.DynamicLayer + if hasattr(transformers.cache_utils, "DynamicLayer") + else None + ) if ( key_value_pairs @@ -258,7 +265,7 @@ def make_dynamic_cache( layer.is_initialized = True else: cache = transformers.cache_utils.DynamicCache(key_value_pairs) - if len(key_value_pairs) < len(cache.layers): + if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): # The cache constructor contains the two following lines # (in cache_utils.py) which append empty layers when the cache is # initialized. We need to remove them. diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 05e4fa8c..58fbceee 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -77,8 +77,10 @@ def flatten_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - assert not dynamic_cache.layers or all( - lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers + assert ( + not hasattr(dynamic_cache, "layers") + or not dynamic_cache.layers + or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers) ), ( f"The serialization does not work yet on other layers " f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" @@ -90,8 +92,10 @@ def flatten_with_keys_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - assert not dynamic_cache.layers or all( - lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers + assert ( + not hasattr(dynamic_cache, "layers") + or not dynamic_cache.layers + or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers) ), ( f"The serialization does not work yet on other layers " f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" From 15b4a5ca7539aa4261e6a237eb6d1ef2021db87e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Fri, 9 Jan 2026 19:46:36 +0100 Subject: [PATCH 16/18] disable one test --- _unittests/ut_torch_export_patches/test_dynamic_class.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index 183a681a..cf908b95 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -13,6 +13,7 @@ ignore_warnings, hide_stdout, requires_torch, + requires_transformers, ) from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue @@ -371,6 +372,7 @@ def call_function(self, target, args, kwargs): @ignore_warnings(UserWarning) @requires_torch("2.9") + @requires_transformers("4.57") def test_tiny_llm_export_module(self): data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] From 555372d5c54bb35bc67019e264c7f62ba59442fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 10 Jan 2026 11:36:17 +0100 Subject: [PATCH 17/18] fix one unit test --- _unittests/ut_torch_models/test_validate_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 23a5ce82..13a886e6 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -77,7 +77,8 @@ def test_validate_microsoft_phi3_mini_128k(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + exporter_options=dict(prefer_deferred_runtime_asserts_over_guards=False), + stop_if_static=2, dump_folder="dump_test/validate_microsoft_phi3_mini_128k", ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5) From c896c9b3749d7c7a994be6230678322b9f994dd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Sat, 10 Jan 2026 11:54:29 +0100 Subject: [PATCH 18/18] fix final --- _unittests/ut_torch_models/test_validate_models.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index 13a886e6..cf817b00 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -1,6 +1,4 @@ import unittest -import packaging.version as pv -import torch from onnx_diagnostic.ext_test_case import ( ExtTestCase, hide_stdout, @@ -29,7 +27,7 @@ def test_validate_tiny_llms_bfloat16(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + stop_if_static=2, dump_folder="dump_test/validate_tiny_llm", dtype="bfloat16", device="cuda", @@ -55,7 +53,8 @@ def test_validate_microsoft_phi4_reasoning(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + exporter_options=dict(prefer_deferred_runtime_asserts_over_guards=False), + stop_if_static=2, dump_folder="dump_test/validate_microsoft_phi4_reasoning", ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5)