diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index de947320..415c853c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -120,6 +120,9 @@ jobs: - name: pip freeze run: python -m pip freeze + - name: cache + run: PYTHONPATH=. UNITTEST_GOING=1 pytest _unittests/ut_helpers/test_cache_helper.py _unittests/ut_helpers/test_torch_helper.py _unittests/ut_helpers/test_fake_tensor_helper.py _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py _unittests/ut_torch_export_patches/test_patch_transformers.py _unittests/ut_torch_export_patches/test_dynamic_class.py _unittests/ut_torch_export_patches/test_onnx_export_errors.py + - name: tiny-llm torch.export.export run: PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_tiny_llms.py diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 825771cd..c0c4b39d 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.9 +++++ +* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``, + 0.8.8 +++++ diff --git a/_unittests/ut_export/test_shape_helper.py b/_unittests/ut_export/test_shape_helper.py index 0615c819..9be69bc9 100644 --- a/_unittests/ut_export/test_shape_helper.py +++ b/_unittests/ut_export/test_shape_helper.py @@ -69,12 +69,16 @@ def test_all_dynamic_shape_all_transformers_cache(self): ], ), ( - make_sliding_window_cache( - [ - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), - ] + ( + make_sliding_window_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + if make_sliding_window_cache is not None + else None ), [ {0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"}, @@ -106,11 +110,15 @@ def test_all_dynamic_shape_all_transformers_cache(self): ] with torch_export_patches(patch_transformers=True): for cache, exds in caches: + if cache is None: + continue with self.subTest(cache_name=cache.__class__.__name__, patch=True): ds = all_dynamic_shapes_from_inputs(cache) self.assertEqual(exds, ds) for cache, exds in caches: + if cache is None: + continue with self.subTest(cache_name=cache.__class__.__name__, patch=False): ds = all_dynamic_shapes_from_inputs(cache) self.assertEqual(exds, ds) diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index db6e2781..7d018b01 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -167,6 +167,9 @@ def test_make_mamba_cache(self): ) self.assertEqual(0, max_diff(cache, cache)["abs"]) + @unittest.skipIf( + not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5" + ) def test_make_sliding_window_cache(self): cache = make_sliding_window_cache( [ @@ -223,6 +226,7 @@ def test_unflatten_flatten_static_cache(self): self.string_type(unflat, with_shape=True), ) + @unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5") def test_make_hybrid_cache(self): cache = make_hybrid_cache( [ @@ -240,6 +244,7 @@ def test_make_hybrid_cache(self): self.assertEqual(0, max_diff(cache, cache)["abs"]) self.assertEqual(0, max_diff(cache, torch_deepcopy(cache))["abs"]) + @unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5") def test_unflatten_flatten_hybrid_cache(self): with torch_export_patches(patch_transformers=True): c2 = make_hybrid_cache( diff --git a/_unittests/ut_helpers/test_helper.py b/_unittests/ut_helpers/test_helper.py index d2981a91..1aef9297 100644 --- a/_unittests/ut_helpers/test_helper.py +++ b/_unittests/ut_helpers/test_helper.py @@ -667,12 +667,14 @@ def test_max_diff_caches(self): [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))], max_cache_len=3 ) self.assertEqual(max_diff(cache, cache)["abs"], 0) - cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]) - self.assertEqual(max_diff(cache, cache)["abs"], 0) - cache = make_sliding_window_cache( - [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))] - ) - self.assertEqual(max_diff(cache, cache)["abs"], 0) + if make_hybrid_cache is not None: + cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]) + self.assertEqual(max_diff(cache, cache)["abs"], 0) + if make_sliding_window_cache is not None: + cache = make_sliding_window_cache( + [(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))] + ) + self.assertEqual(max_diff(cache, cache)["abs"], 0) cache = make_encoder_decoder_cache(cache, cache) self.assertEqual(max_diff(cache, cache)["abs"], 0) diff --git a/_unittests/ut_helpers/test_torch_helper.py b/_unittests/ut_helpers/test_torch_helper.py index 2c3c2990..36c72c5c 100644 --- a/_unittests/ut_helpers/test_torch_helper.py +++ b/_unittests/ut_helpers/test_torch_helper.py @@ -344,6 +344,7 @@ def test_torch_deepcopy_base_model_outputs(self): self.assertEqual(hash1, hash2) self.assertGreater(torch_tensor_size(bo), 1) + @unittest.skipIf(make_sliding_window_cache is None, "SlidingWindowCache was removed") def test_torch_deepcopy_sliding_windon_cache(self): cache = make_sliding_window_cache( [ diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index 03ac9e2d..be37aaf0 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -10,6 +10,7 @@ from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str +from onnx_diagnostic.helpers.cache_helper import get_make_hybrid_cache class TestTasksImageTextToText(ExtTestCase): @@ -58,6 +59,9 @@ def test_image_text_to_text_tiny_gemma3(self): @requires_transformers("4.56.99") @requires_torch("2.8.99") def test_image_text_to_text_gemma3_4b_it(self): + make_hybrid_cache = get_make_hybrid_cache() + if make_hybrid_cache is None: + raise unittest.SkipTest("not implemented yet for transformers>=5") mid = "google/gemma-3-4b-it" data = get_untrained_model_with_inputs( mid, diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index dfa0f66c..cf908b95 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -13,9 +13,11 @@ ignore_warnings, hide_stdout, requires_torch, + requires_transformers, ) from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( torch_export_patches, ) @@ -288,7 +290,7 @@ def test_phi2_export_module(self): data = get_untrained_model_with_inputs("microsoft/phi-2") model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] str_inputs = string_type(inputs, with_shape=True, with_min_max=True) - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) expected = model(**inputs_copied) self.maxDiff = None self.assertEqual(str_inputs, string_type(inputs, with_shape=True, with_min_max=True)) @@ -298,7 +300,7 @@ def test_phi2_export_module(self): string_type(inputs, with_shape=True, with_min_max=True), string_type(inputs_copied, with_shape=True, with_min_max=True), ) - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) self.assertEqual( str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) @@ -307,13 +309,13 @@ def test_phi2_export_module(self): ep = torch.export.export( model, (), - kwargs=inputs, + kwargs=torch_deepcopy(inputs), dynamic_shapes=use_dyn_not_str(dyn_shapes), strict=False, # True works but then the it fails during the execution ) # ep = ep.run_decompositions() mod = ep.module() - inputs_copied = copy.deepcopy(inputs) + inputs_copied = torch_deepcopy(inputs) self.assertEqual( str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) @@ -368,6 +370,50 @@ def call_function(self, target, args, kwargs): got = MyInterpreter(ep.module()).run(*args) self.assertEqualAny(expected, got) + @ignore_warnings(UserWarning) + @requires_torch("2.9") + @requires_transformers("4.57") + def test_tiny_llm_export_module(self): + data = get_untrained_model_with_inputs("arnir0/Tiny-LLM") + model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"] + inputs_copied = torch_deepcopy(inputs) + self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"]) + self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"]) + self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"]) + self.assertEqualArray( + inputs["past_key_values"].layers[0].keys, + inputs_copied["past_key_values"].layers[0].keys, + ) + self.assertEqualArray( + inputs["past_key_values"].layers[0].values, + inputs_copied["past_key_values"].layers[0].values, + ) + expected = model(**torch_deepcopy(inputs)) + + with torch_export_patches(patch_torch=False, patch_transformers=True): + inputs_copied = torch_deepcopy(inputs) + self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"]) + self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"]) + self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"]) + self.assertEqualArray( + inputs["past_key_values"].layers[0].keys, + inputs_copied["past_key_values"].layers[0].keys, + ) + self.assertEqualArray( + inputs["past_key_values"].layers[0].values, + inputs_copied["past_key_values"].layers[0].values, + ) + ep = torch.export.export( + model, + (), + kwargs=torch_deepcopy(inputs), + dynamic_shapes=use_dyn_not_str(dyn_shapes), + strict=False, + ) + mod = ep.module() + got = mod(**torch_deepcopy(inputs)) + self.assertEqualAny(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 2e8a75c4..48863955 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache): model(x, cache) DYN = torch.export.Dim.DYNAMIC - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache = MambaCache(_config(), max_batch_size=2, device="cpu") torch.export.export( Model(), diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py index 352be02a..1285c1e1 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py @@ -1,7 +1,13 @@ import unittest import torch from transformers.modeling_outputs import BaseModelOutput -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch +from onnx_diagnostic.ext_test_case import ( + ExtTestCase, + ignore_warnings, + requires_torch, + requires_transformers, +) +from onnx_diagnostic.helpers import flatten_object from onnx_diagnostic.helpers.cache_helper import ( make_encoder_decoder_cache, make_dynamic_cache, @@ -24,7 +30,7 @@ def test_encoder_decoder_cache_flatten(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]", @@ -42,7 +48,7 @@ def test_encoder_decoder_cache_deepcopy(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -75,7 +81,7 @@ def forward(self, cache): @ignore_warnings(UserWarning) def test_dynamic_cache_flatten(self): cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4,T1s4x4x4]", @@ -114,7 +120,7 @@ def forward(self, cache): @ignore_warnings(UserWarning) def test_dynamic_cache_deepcopy(self): cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -122,7 +128,7 @@ def test_dynamic_cache_deepcopy(self): def test_base_model_output_deepcopy(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) self.assertEqual(bo.__class__.__name__, "BaseModelOutput") - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): bo2 = torch_deepcopy([bo]) self.assertIsInstance(bo2, list) self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput") @@ -131,7 +137,7 @@ def test_base_model_output_deepcopy(self): @ignore_warnings(UserWarning) def test_base_model_output_string_type(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): self.assertEqual( "BaseModelOutput(last_hidden_state:T1s4x4x4)", self.string_type(bo, with_shape=True), @@ -140,7 +146,7 @@ def test_base_model_output_string_type(self): @ignore_warnings(UserWarning) def test_base_model_output_flatten(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(bo) self.assertEqual( "#1[T1s4x4x4]", @@ -177,16 +183,18 @@ def test_base_model_output_unflatten_flatten(self): self.assertEqual("#1[T1r3]", self.string_type(unflat)) @ignore_warnings(UserWarning) + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_base_sliding_window_cache_unflatten_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @ignore_warnings(UserWarning) @requires_torch("2.7.99") + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_sliding_window_cache_export(self): class Model(torch.nn.Module): def forward(self, cache): @@ -208,11 +216,12 @@ def forward(self, cache): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) + @unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed") def test_sliding_window_cache_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4x4,T1s4x4x4x4]", @@ -244,7 +253,7 @@ def test_static_cache(self): self.string_type(bo, with_shape=True), ) - with torch_export_patches(): + with torch_export_patches(patch_transformers=True): # internal function bo2 = torch_deepcopy([bo]) self.assertIsInstance(bo2, list) @@ -288,6 +297,103 @@ def forward(self, cache): with torch_export_patches(patch_transformers=True, stop_if_static=1): torch.export.export(model, (bo,), dynamic_shapes=(ds,)) + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_flatten_unflatten(self): + values = [ + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + ] + cache = make_dynamic_cache(values) + flat_cache = flatten_object(cache) + order_cache = flatten_object(values) + with torch_export_patches(patch_transformers=True): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqualAny(flat_cache, flatten_object(cache2)) + self.assertEqualAny(order_cache, flatten_object(cache2)) + self.assertEqual( + [type(ly) for ly in cache.layers], [type(ly) for ly in cache2.layers] + ) + + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_in_a_model_args(self): + import torch + + class Model(torch.nn.Module): + def forward(self, x, cache): + acc = x.clone() + for i, lay in enumerate(cache.layers): + acc = acc + lay.keys * (i + 1) - lay.values + cache.update(x * (i + 1), x * 2 * (i + 1), i) + return acc, cache + + values = [ + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + ] + cache = make_dynamic_cache(values) + inputs = (torch.rand((2, 4, 1, 4)), cache) + inputs_copied = torch_deepcopy(inputs) + self.assertEqualAny(inputs, inputs_copied) + model = Model() + expected = model(*inputs) + DYN = torch.export.Dim.DYNAMIC + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + torch_deepcopy(inputs_copied), + dynamic_shapes=( + {0: DYN}, + [{0: DYN, 2: DYN}, {0: DYN, 2: DYN}, {0: DYN, 2: DYN}, {0: DYN, 2: DYN}], + ), + ) + got = ep.module()(*inputs_copied) + self.assertEqualAny(expected, got) + + @ignore_warnings(UserWarning) + @requires_transformers("4.99") + def test_dynamic_cache_in_a_model_kwargs(self): + import torch + + class Model(torch.nn.Module): + def forward(self, x, cache): + acc = x.clone() + for i, lay in enumerate(cache.layers): + acc = acc + lay.keys * (i + 1) - lay.values + cache.update(x * (i + 1), x * 2 * (i + 1), i) + return acc, cache + + values = [ + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + (torch.rand((2, 4, 4, 4)), torch.rand((2, 4, 4, 4))), + ] + cache = make_dynamic_cache(values) + inputs = dict(x=torch.rand((2, 4, 1, 4)), cache=cache) + inputs_copied = torch_deepcopy(inputs) + self.assertEqualAny(inputs, inputs_copied) + model = Model() + expected = model(**inputs) + DYN = torch.export.Dim.DYNAMIC + with torch_export_patches(patch_transformers=True): + ep = torch.export.export( + model, + (), + kwargs=torch_deepcopy(inputs_copied), + dynamic_shapes=dict( + x={0: DYN}, + cache=[ + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + {0: DYN, 2: DYN}, + ], + ), + ) + got = ep.module()(**inputs_copied) + self.assertEqualAny(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index ac37a7b7..d386fe66 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -37,6 +37,9 @@ def test_tiny_llm_export_dynamic(self): dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]), ) got = ep.module()(**inputs) + # print("***", self.string_type(expected, with_shape=True, with_min_max=True)) + # print("***", self.string_type(got, with_shape=True, with_min_max=True)) + print(ep) self.assertEqualArrayAny(expected, got) @requires_transformers("4.52") diff --git a/_unittests/ut_torch_models/test_validate_models.py b/_unittests/ut_torch_models/test_validate_models.py index ad6dae7e..cf817b00 100644 --- a/_unittests/ut_torch_models/test_validate_models.py +++ b/_unittests/ut_torch_models/test_validate_models.py @@ -1,6 +1,4 @@ import unittest -import packaging.version as pv -import torch from onnx_diagnostic.ext_test_case import ( ExtTestCase, hide_stdout, @@ -29,7 +27,7 @@ def test_validate_tiny_llms_bfloat16(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + stop_if_static=2, dump_folder="dump_test/validate_tiny_llm", dtype="bfloat16", device="cuda", @@ -55,14 +53,15 @@ def test_validate_microsoft_phi4_reasoning(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + exporter_options=dict(prefer_deferred_runtime_asserts_over_guards=False), + stop_if_static=2, dump_folder="dump_test/validate_microsoft_phi4_reasoning", ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5) self.assertIn("onnx_filename", data) self.clean_dump() - @requires_transformers("4.53") + @requires_transformers("4.57") @requires_torch("2.8.99") @requires_experimental() @hide_stdout() @@ -77,7 +76,8 @@ def test_validate_microsoft_phi3_mini_128k(self): do_same=True, patch=True, rewrite=True, - stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0, + exporter_options=dict(prefer_deferred_runtime_asserts_over_guards=False), + stop_if_static=2, dump_folder="dump_test/validate_microsoft_phi3_mini_128k", ) self.assertLess(summary["disc_onnx_ort_run_abs"], 2e-5) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index ff0977ab..767c4f1f 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -28,18 +28,23 @@ def __init__(self, cache=None): ] self.key_cache = [layer.keys for layer in layers] self.value_cache = [layer.values for layer in layers] + self.cls_layers = [type(lay) for lay in cache.layers] elif cache is not None and hasattr(cache, "key_cache"): self.key_cache = cache.key_cache self.value_cache = cache.value_cache + self.cls_layers = None elif cache is None: self.key_cache = None self.value_cache = None + self.cls_layers = None else: raise NotImplementedError(f"type(cache)={type(cache)}") def make_dynamic_cache(self): """Does the reverse operation.""" - return make_dynamic_cache(list(zip(self.key_cache, self.value_cache))) + return make_dynamic_cache( + list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers + ) @property def n_layers(self) -> int: @@ -156,12 +161,16 @@ def _preprocess_key_value_pairs( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], + cls_layers: Optional[Union[str, List[type]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. This version is valid for ``transformers >= 4.50``. :param key_value_pairs: list of pairs of (key, values) + :param cls_layers: to select the appropriate class to use on each layer, + if specified, sliding_window is ignored, it can be a string + if all layers are expected to follow the same class :return: :class:`transformers.cache_utils.DynamicCache` Example: @@ -192,15 +201,49 @@ def make_dynamic_cache( are supported. """ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) + cls_kwargs = {} + if isinstance(cls_layers, str): + assert hasattr( + transformers.cache_utils, cls_layers + ), f"Unable to find class {cls_layers!r} in transformers.cache_utils" + cls_layer = getattr(transformers.cache_utils, cls_layers) + if cls_layers == "DynamicSlidingWindowLayer": + cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] + assert isinstance( + cls_kwargs["sliding_window"], int + ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" + elif cls_layers is not None: + unique = set(cls_layers) + assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}" + cls_layer = unique.pop() + if ( + hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer") + and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer + ): + from .helper import string_type + + assert key_value_pairs and key_value_pairs[0], ( + f"not implemented for key_value_pairs=" + f"{string_type(key_value_pairs, with_shape=True)}" + ) + cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] + assert isinstance( + cls_kwargs["sliding_window"], int + ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" + else: + cls_layer = ( + transformers.cache_utils.DynamicLayer + if hasattr(transformers.cache_utils, "DynamicLayer") + else None + ) + if ( key_value_pairs and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor) and pv.Version(transformers.__version__) >= pv.Version("4.56") ): cache = transformers.cache_utils.DynamicCache() - cache.layers.extend( - [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs] - ) + cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) for i, layer in enumerate(cache.layers): k, v = key_value_pairs[i][0], key_value_pairs[i][1] layer.dtype = k.dtype @@ -214,14 +257,21 @@ def make_dynamic_cache( ) return finalize_cache(cache) - cache = transformers.cache_utils.DynamicCache(key_value_pairs) - if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): - # The cache constructor contains the two following lines - # (in cache_utils.py) which append empty layers when the cache is - # initialized. We need to remove them. - # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) - # self.append_new_layers(self.num_hidden_layers - 1) - cache.layers[:] = cache.layers[-len(key_value_pairs) :] + cache = transformers.cache_utils.DynamicCache() + if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer: + cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) + for i, layer in enumerate(cache.layers): + layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] + layer.is_initialized = True + else: + cache = transformers.cache_utils.DynamicCache(key_value_pairs) + if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers): + # The cache constructor contains the two following lines + # (in cache_utils.py) which append empty layers when the cache is + # initialized. We need to remove them. + # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1) + # self.append_new_layers(self.num_hidden_layers - 1) + cache.layers[:] = cache.layers[-len(key_value_pairs) :] assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), ( f"Unexpected number of layers in the cache ({len(cache.layers)}), " f"{len(key_value_pairs)} expected." @@ -232,6 +282,7 @@ def make_dynamic_cache( def make_dynamic_cache( key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]], + cls_layers: Optional[Union[str, List[type]]] = None, ) -> transformers.cache_utils.DynamicCache: """ Creates an instance of :class:`transformers.cache_utils.DynamicCache`. @@ -263,6 +314,7 @@ def make_dynamic_cache( ) print(string_type(past_key_values, with_shape=True)) """ + assert not cls_layers, "cls_layers cannot be used for transformers<5." key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore for i, (key, value) in enumerate(key_value_pairs): @@ -508,9 +560,13 @@ def get_text_config(self, *args, **kwargs): ) return finalize_cache(cache) + def get_make_hybrid_cache(): + return make_sliding_window_cache + else: make_sliding_window_cache = None # type: ignore[assignment] + if hasattr(transformers.cache_utils, "HybridCache"): def make_hybrid_cache( @@ -672,9 +728,15 @@ def get_text_config(self, *args, **kwargs): ) return finalize_cache(cache) + def get_make_hybrid_cache(): + return make_hybrid_cache + else: make_hybrid_cache = None # type: ignore[assignment] + def get_make_hybrid_cache(): + return None + def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache: """ diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 2f7978db..59dfa397 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -15,9 +15,6 @@ from .cache_helper import ( make_dynamic_cache, make_encoder_decoder_cache, - make_hybrid_cache, - make_sliding_window_cache, - make_mamba_cache, make_static_cache, CacheKeyValue, ) @@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any: return {to_any(t, to_value) for t in value} if type(value) is dict: return {k: to_any(t, to_value) for k, t in value.items()} - if value.__class__.__name__ in {"DynamicCache", "HybridCache"}: - make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache) + if value.__class__.__name__ == "DynamicCache": cc = CacheKeyValue(value) - return make[value.__class__.__name__]( # type: ignore[operator] + return make_dynamic_cache( + list( + zip( + [t.to(to_value) if t is not None else t for t in cc.key_cache], + [t.to(to_value) if t is not None else t for t in cc.value_cache], + ) + ), + cls_layers=cc.cls_layers, + ) + if value.__class__.__name__ == "HybridCache": + from .cache_helper import make_hybrid_cache + + cc = CacheKeyValue(value) + return make_hybrid_cache( list( zip( [t.to(to_value) if t is not None else t for t in cc.key_cache], @@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any: from .cache_helper import CacheKeyValue ca = CacheKeyValue(value) - return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))) + return make_dynamic_cache( + torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers + ) if value.__class__.__name__ == "StaticCache": from .cache_helper import CacheKeyValue @@ -858,12 +869,12 @@ def torch_deepcopy(value: Any) -> Any: max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]), ) if value.__class__.__name__ == "HybridCache": - from .cache_helper import CacheKeyValue + from .cache_helper import CacheKeyValue, make_hybrid_cache ca = CacheKeyValue(value) return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))) if value.__class__.__name__ == "SlidingWindowCache": - from .cache_helper import CacheKeyValue + from .cache_helper import CacheKeyValue, make_sliding_window_cache ca = CacheKeyValue(value) return make_sliding_window_cache( @@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any: torch_deepcopy(value.cross_attention_cache), ) if value.__class__.__name__ == "MambaCache": + from .cache_helper import make_mamba_cache + return make_mamba_cache(list(zip(value.conv_states, value.ssm_states))) if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: diff --git a/onnx_diagnostic/tasks/image_text_to_text.py b/onnx_diagnostic/tasks/image_text_to_text.py index 3b60a8b8..f26b65e5 100644 --- a/onnx_diagnostic/tasks/image_text_to_text.py +++ b/onnx_diagnostic/tasks/image_text_to_text.py @@ -1,7 +1,7 @@ import itertools from typing import Any, Callable, Dict, Optional, Tuple import torch -from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache +from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache from ..helpers.config_helper import ( update_config, check_hasattr, @@ -200,6 +200,9 @@ def _check_(): _check_() + make_hybrid_cache = get_make_hybrid_cache() + assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing" + inputs = dict( input_ids=dummies["input_ids"], token_type_ids=dummies["token_type_ids"], diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 7b81277c..1b1bad0d 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -7,15 +7,9 @@ from transformers.cache_utils import DynamicCache, StaticCache try: - from transformers.cache_utils import ( - EncoderDecoderCache, - HybridCache, - SlidingWindowCache, - ) + from transformers.cache_utils import EncoderDecoderCache except ImportError: EncoderDecoderCache = None - HybridCache = None - SlidingWindowCache = None from ..helpers import string_type from .serialization import _lower_name_with_ @@ -36,6 +30,24 @@ def get_mamba_cache_cls() -> type: return None +def get_hybrid_cache_cls() -> type: + try: + from transformers.cache_utils import HybridCache + + return HybridCache + except ImportError: + return None + + +def get_sliding_window_cache_cls() -> type: + try: + from transformers.cache_utils import SlidingWindowCache + + return SlidingWindowCache + except ImportError: + return None + + def register_class_serialization( cls, f_flatten: Callable, @@ -179,18 +191,9 @@ def serialization_functions( flatten_dynamic_cache, unflatten_dynamic_cache, flatten_with_keys_dynamic_cache, - flatten_hybrid_cache, - unflatten_hybrid_cache, - flatten_with_keys_hybrid_cache, - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, flatten_encoder_decoder_cache, unflatten_encoder_decoder_cache, flatten_with_keys_encoder_decoder_cache, - flatten_sliding_window_cache, - unflatten_sliding_window_cache, - flatten_with_keys_sliding_window_cache, flatten_static_cache, unflatten_static_cache, flatten_with_keys_static_cache, @@ -208,14 +211,6 @@ def serialization_functions( # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), verbose=verbose, ), - HybridCache: lambda verbose=verbose: register_class_serialization( - HybridCache, - flatten_hybrid_cache, - unflatten_hybrid_cache, - flatten_with_keys_hybrid_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ), EncoderDecoderCache: lambda verbose=verbose: register_class_serialization( EncoderDecoderCache, flatten_encoder_decoder_cache, @@ -223,13 +218,6 @@ def serialization_functions( flatten_with_keys_encoder_decoder_cache, verbose=verbose, ), - SlidingWindowCache: lambda verbose=verbose: register_class_serialization( - SlidingWindowCache, - flatten_sliding_window_cache, - unflatten_sliding_window_cache, - flatten_with_keys_sliding_window_cache, - verbose=verbose, - ), StaticCache: lambda verbose=verbose: register_class_serialization( StaticCache, flatten_static_cache, @@ -240,6 +228,12 @@ def serialization_functions( } MambaCache = get_mamba_cache_cls() if MambaCache: + from .serialization.transformers_impl import ( + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + ) + transformers_classes[MambaCache] = ( lambda verbose=verbose: register_class_serialization( MambaCache, @@ -249,6 +243,42 @@ def serialization_functions( verbose=verbose, ) ) + HybridCache = get_hybrid_cache_cls() + if HybridCache: + from .serialization.transformers_impl import ( + flatten_hybrid_cache, + unflatten_hybrid_cache, + flatten_with_keys_hybrid_cache, + ) + + transformers_classes[HybridCache] = ( + lambda verbose=verbose: register_class_serialization( + HybridCache, + flatten_hybrid_cache, + unflatten_hybrid_cache, + flatten_with_keys_hybrid_cache, + verbose=verbose, + ) + ) + + SlidingWindowCache = get_sliding_window_cache_cls() + if SlidingWindowCache: + from .serialization.transformers_impl import ( + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + ) + + transformers_classes[SlidingWindowCache] = ( + lambda verbose=verbose: register_class_serialization( + SlidingWindowCache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + verbose=verbose, + ) + ) + classes.update(transformers_classes) if patch_diffusers: @@ -303,13 +333,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0): def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): - """Undo all registrations.""" - MambaCache = get_mamba_cache_cls() - cls_ensemble = ( - {DynamicCache, EncoderDecoderCache} - | set(undo) - | ({MambaCache} if MambaCache else set()) - ) + cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo) for cls in cls_ensemble: if undo.get(cls.__name__, False): unregister_class_serialization(cls, verbose) diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index 63894b51..28d7ede1 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -541,14 +541,17 @@ def compute_concrete_val() -> sympy.Basic: # oblivious_var_to_val will be defined iff we have sizes # with DimDynamic.OBLIVIOUS_SIZE type. # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + var_to_val = getattr( + self, + "unbacked_var_to_val", + getattr(self, "oblivious_var_to_val", False), + ) if ( - self.oblivious_var_to_val - and not ( - correct_hint := orig_expr.xreplace(self.oblivious_var_to_val) - ).free_symbols + var_to_val + and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols and not ( counterfactual_hint := orig_expr.xreplace( - {k: max(2, v) for k, v in self.oblivious_var_to_val.items()} + {k: max(2, v) for k, v in var_to_val.items()} ) ).free_symbols and correct_hint == counterfactual_hint @@ -571,11 +574,11 @@ def compute_concrete_val() -> sympy.Basic: # and if they pass we add a runtime assertions and continue. if ( not ok - and self.unbacked_var_to_val + and var_to_val and not ( - unsound_result := orig_expr.xreplace( - self.unbacked_var_to_val - ).xreplace(self.var_to_val) + unsound_result := orig_expr.xreplace(var_to_val).xreplace( + var_to_val + ) ).free_symbols ): # pyrefly: ignore # unbound-name diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py index 548bc1dd..58fbceee 100644 --- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py @@ -1,13 +1,7 @@ import itertools from typing import Any, Callable, List, Set, Tuple import torch -from transformers.cache_utils import ( - Cache, - DynamicCache, - EncoderDecoderCache, - HybridCache, - StaticCache, -) +from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache try: from transformers.cache_utils import SlidingWindowCache @@ -15,6 +9,11 @@ SlidingWindowCache = None +try: + from transformers.cache_utils import HybridCache +except ImportError: + HybridCache = None + try: from transformers.models.mamba.modeling_mamba import MambaCache except ImportError: @@ -78,6 +77,14 @@ def flatten_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + assert ( + not hasattr(dynamic_cache, "layers") + or not dynamic_cache.layers + or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers) + ), ( + f"The serialization does not work yet on other layers " + f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" + ) return _flatten_key_value_cache(dynamic_cache) @@ -85,6 +92,14 @@ def flatten_with_keys_dynamic_cache( dynamic_cache: DynamicCache, ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" + assert ( + not hasattr(dynamic_cache, "layers") + or not dynamic_cache.layers + or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers) + ), ( + f"The serialization does not work yet on other layers " + f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}" + ) return _flatten_with_keys_cache(dynamic_cache) @@ -99,26 +114,25 @@ def unflatten_dynamic_cache( # HybridCache ############# +if HybridCache: -def flatten_hybrid_cache( - cache: HybridCache, -) -> Tuple[List[Any], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - return _flatten_key_value_cache(cache) - - -def flatten_with_keys_hybrid_cache( - cache: HybridCache, -) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: - """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" - return _flatten_with_keys_cache(cache) + def flatten_hybrid_cache( + cache: HybridCache, + ) -> Tuple[List[Any], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" + return _flatten_key_value_cache(cache) + def flatten_with_keys_hybrid_cache( + cache: HybridCache, + ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects.""" + return _flatten_with_keys_cache(cache) -def unflatten_hybrid_cache( - values: List[Any], context: torch.utils._pytree.Context, output_type=None -) -> HybridCache: - """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" - return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) + def unflatten_hybrid_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None + ) -> HybridCache: + """Restores a :class:`transformers.cache_utils.HybridCache` from python objects.""" + return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type) #############