Skip to content
Merged
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,9 @@ jobs:
- name: pip freeze
run: python -m pip freeze

- name: cache
run: PYTHONPATH=. UNITTEST_GOING=1 pytest _unittests/ut_helpers/test_cache_helper.py _unittests/ut_helpers/test_torch_helper.py _unittests/ut_helpers/test_fake_tensor_helper.py _unittests/ut_torch_export_patches/test_patch_serialization_transformers.py _unittests/ut_torch_export_patches/test_patch_transformers.py _unittests/ut_torch_export_patches/test_dynamic_class.py _unittests/ut_torch_export_patches/test_onnx_export_errors.py

- name: tiny-llm torch.export.export
run: PYTHONPATH=. UNITTEST_GOING=1 python _unittests/ut_torch_models/test_tiny_llms.py

Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.8.9
+++++

* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``,

0.8.8
+++++

Expand Down
20 changes: 14 additions & 6 deletions _unittests/ut_export/test_shape_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,16 @@ def test_all_dynamic_shape_all_transformers_cache(self):
],
),
(
make_sliding_window_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
(
make_sliding_window_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
]
)
if make_sliding_window_cache is not None
else None
),
[
{0: "d_0_0", 1: "d_0_1", 2: "d_0_2", 3: "d_0_3"},
Expand Down Expand Up @@ -106,11 +110,15 @@ def test_all_dynamic_shape_all_transformers_cache(self):
]
with torch_export_patches(patch_transformers=True):
for cache, exds in caches:
if cache is None:
continue
with self.subTest(cache_name=cache.__class__.__name__, patch=True):
ds = all_dynamic_shapes_from_inputs(cache)
self.assertEqual(exds, ds)

for cache, exds in caches:
if cache is None:
continue
with self.subTest(cache_name=cache.__class__.__name__, patch=False):
ds = all_dynamic_shapes_from_inputs(cache)
self.assertEqual(exds, ds)
Expand Down
5 changes: 5 additions & 0 deletions _unittests/ut_helpers/test_cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def test_make_mamba_cache(self):
)
self.assertEqual(0, max_diff(cache, cache)["abs"])

@unittest.skipIf(
not make_sliding_window_cache, "SlidingWindowCache removed in transformers>=5"
)
def test_make_sliding_window_cache(self):
cache = make_sliding_window_cache(
[
Expand Down Expand Up @@ -223,6 +226,7 @@ def test_unflatten_flatten_static_cache(self):
self.string_type(unflat, with_shape=True),
)

@unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5")
def test_make_hybrid_cache(self):
cache = make_hybrid_cache(
[
Expand All @@ -240,6 +244,7 @@ def test_make_hybrid_cache(self):
self.assertEqual(0, max_diff(cache, cache)["abs"])
self.assertEqual(0, max_diff(cache, torch_deepcopy(cache))["abs"])

@unittest.skipIf(not make_hybrid_cache, "HybridCache removed in transformers>=5")
def test_unflatten_flatten_hybrid_cache(self):
with torch_export_patches(patch_transformers=True):
c2 = make_hybrid_cache(
Expand Down
14 changes: 8 additions & 6 deletions _unittests/ut_helpers/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,12 +667,14 @@ def test_max_diff_caches(self):
[(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))], max_cache_len=3
)
self.assertEqual(max_diff(cache, cache)["abs"], 0)
cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))])
self.assertEqual(max_diff(cache, cache)["abs"], 0)
cache = make_sliding_window_cache(
[(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]
)
self.assertEqual(max_diff(cache, cache)["abs"], 0)
if make_hybrid_cache is not None:
cache = make_hybrid_cache([(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))])
self.assertEqual(max_diff(cache, cache)["abs"], 0)
if make_sliding_window_cache is not None:
cache = make_sliding_window_cache(
[(torch.rand((1, 1, 3, 4)), torch.rand((1, 1, 3, 4)))]
)
self.assertEqual(max_diff(cache, cache)["abs"], 0)
cache = make_encoder_decoder_cache(cache, cache)
self.assertEqual(max_diff(cache, cache)["abs"], 0)

Expand Down
1 change: 1 addition & 0 deletions _unittests/ut_helpers/test_torch_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def test_torch_deepcopy_base_model_outputs(self):
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(bo), 1)

@unittest.skipIf(make_sliding_window_cache is None, "SlidingWindowCache was removed")
def test_torch_deepcopy_sliding_windon_cache(self):
cache = make_sliding_window_cache(
[
Expand Down
4 changes: 4 additions & 0 deletions _unittests/ut_tasks/test_tasks_image_text_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.helpers.cache_helper import get_make_hybrid_cache


class TestTasksImageTextToText(ExtTestCase):
Expand Down Expand Up @@ -58,6 +59,9 @@ def test_image_text_to_text_tiny_gemma3(self):
@requires_transformers("4.56.99")
@requires_torch("2.8.99")
def test_image_text_to_text_gemma3_4b_it(self):
make_hybrid_cache = get_make_hybrid_cache()
if make_hybrid_cache is None:
raise unittest.SkipTest("not implemented yet for transformers>=5")
mid = "google/gemma-3-4b-it"
data = get_untrained_model_with_inputs(
mid,
Expand Down
54 changes: 50 additions & 4 deletions _unittests/ut_torch_export_patches/test_dynamic_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
ignore_warnings,
hide_stdout,
requires_torch,
requires_transformers,
)
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, CacheKeyValue
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
torch_export_patches,
)
Expand Down Expand Up @@ -288,7 +290,7 @@ def test_phi2_export_module(self):
data = get_untrained_model_with_inputs("microsoft/phi-2")
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
str_inputs = string_type(inputs, with_shape=True, with_min_max=True)
inputs_copied = copy.deepcopy(inputs)
inputs_copied = torch_deepcopy(inputs)
expected = model(**inputs_copied)
self.maxDiff = None
self.assertEqual(str_inputs, string_type(inputs, with_shape=True, with_min_max=True))
Expand All @@ -298,7 +300,7 @@ def test_phi2_export_module(self):
string_type(inputs, with_shape=True, with_min_max=True),
string_type(inputs_copied, with_shape=True, with_min_max=True),
)
inputs_copied = copy.deepcopy(inputs)
inputs_copied = torch_deepcopy(inputs)
self.assertEqual(
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)
Expand All @@ -307,13 +309,13 @@ def test_phi2_export_module(self):
ep = torch.export.export(
model,
(),
kwargs=inputs,
kwargs=torch_deepcopy(inputs),
dynamic_shapes=use_dyn_not_str(dyn_shapes),
strict=False, # True works but then the it fails during the execution
)
# ep = ep.run_decompositions()
mod = ep.module()
inputs_copied = copy.deepcopy(inputs)
inputs_copied = torch_deepcopy(inputs)
self.assertEqual(
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
)
Expand Down Expand Up @@ -368,6 +370,50 @@ def call_function(self, target, args, kwargs):
got = MyInterpreter(ep.module()).run(*args)
self.assertEqualAny(expected, got)

@ignore_warnings(UserWarning)
@requires_torch("2.9")
@requires_transformers("4.57")
def test_tiny_llm_export_module(self):
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
inputs_copied = torch_deepcopy(inputs)
self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"])
self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"])
self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"])
self.assertEqualArray(
inputs["past_key_values"].layers[0].keys,
inputs_copied["past_key_values"].layers[0].keys,
)
self.assertEqualArray(
inputs["past_key_values"].layers[0].values,
inputs_copied["past_key_values"].layers[0].values,
)
expected = model(**torch_deepcopy(inputs))

with torch_export_patches(patch_torch=False, patch_transformers=True):
inputs_copied = torch_deepcopy(inputs)
self.assertEqualArray(inputs["input_ids"], inputs_copied["input_ids"])
self.assertEqualArray(inputs["position_ids"], inputs_copied["position_ids"])
self.assertEqualArray(inputs["attention_mask"], inputs_copied["attention_mask"])
self.assertEqualArray(
inputs["past_key_values"].layers[0].keys,
inputs_copied["past_key_values"].layers[0].keys,
)
self.assertEqualArray(
inputs["past_key_values"].layers[0].values,
inputs_copied["past_key_values"].layers[0].values,
)
ep = torch.export.export(
model,
(),
kwargs=torch_deepcopy(inputs),
dynamic_shapes=use_dyn_not_str(dyn_shapes),
strict=False,
)
mod = ep.module()
got = mod(**torch_deepcopy(inputs))
self.assertEqualAny(expected, got)


if __name__ == "__main__":
unittest.main(verbosity=2)
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
model(x, cache)
DYN = torch.export.Dim.DYNAMIC

with torch_export_patches():
with torch_export_patches(patch_transformers=True):
cache = MambaCache(_config(), max_batch_size=2, device="cpu")
torch.export.export(
Model(),
Expand Down
Loading
Loading