Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.8.12
++++++


* :pr:`400`, :pr:`401`, :pr:`402`: improves InputObserver (investigations), adds it to the documentation
* :pr:`399`: update CI

0.8.11
Expand Down
96 changes: 96 additions & 0 deletions _doc/final/plot_export_gemma3_tiny_input_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""
.. _l-plot-gemma3-tiny-export-input-observer:

Export Gemma3 tiny random with InputObserver
============================================

This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer`
for model `tiny-random/gemma-3 <https://huggingface.co/tiny-random/gemma-3>`_.
"""

import pandas
import torch
from transformers import pipeline

from onnx_diagnostic import doc
from onnx_diagnostic.export.api import to_onnx
from onnx_diagnostic.helpers import string_type
from onnx_diagnostic.investigate.input_observer import InputObserver
from onnx_diagnostic.torch_export_patches import (
    register_additional_serialization_functions,
    torch_export_patches,
)

model_id = "tiny-random/gemma-3"
# Use the GPU when one is available and fall back to CPU otherwise, so the
# example does not fail on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipeline(
    "image-text-to-text",
    model=model_id,
    device=device,
    trust_remote_code=True,
    max_new_tokens=3,
)
# A chat-style request mixing one image and one text question.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG",
            },
            {"type": "text", "text": "What animal is on the candy?"},
        ],
    },
]


# %%
# The model to observe.
print("model type:", type(pipe.model))

# %%
# Captures inputs and outputs for the model.
observer = InputObserver()
with (
    register_additional_serialization_functions(patch_transformers=True),
    observer(pipe.model),
):
    # Every call made to ``pipe.model`` while the pipeline runs is recorded.
    pipe(text=messages, max_new_tokens=4)


# The observed object is the whole model (``pipe.model``), not an encoder.
print(f"{observer.num_obs()} observations stored for the model.")

# %%
# Exports the model.
kwargs = observer.infer_arguments()
dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True)
print(f"model kwargs={string_type(kwargs, with_shape=True)}")
print(f"model dynamic_shapes={dynamic_shapes}")
for candidate in observer.info.inputs:
    print(
        " ",
        candidate,
        candidate.str_obs(),
        string_type(candidate.aligned_flat_list, with_shape=True),
    )


filename = "plot_export_gemma3_tiny_input_observer.onnx"
with torch_export_patches(patch_transformers=True):
    to_onnx(
        pipe.model,
        args=(),
        filename=filename,
        kwargs=kwargs,
        dynamic_shapes=dynamic_shapes,
        exporter="custom",
    )

# %%
# Let's measure the discrepancies.
data = observer.check_discrepancies(filename, progress_bar=True)
print(pandas.DataFrame(data))


# %%
# Renders the exported model with ``doc.plot_dot`` and saves the figure as a PNG.
doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400)
3 changes: 3 additions & 0 deletions _doc/final/plot_export_whisper_tiny_input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer`
for model `openai/whisper-tiny <https://huggingface.co/openai/whisper-tiny>`_.

The model
+++++++++
"""

import pandas
Expand Down
59 changes: 58 additions & 1 deletion _unittests/ut_helpers/test_cache_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
make_sliding_window_cache,
make_static_cache,
)
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, to_any
from onnx_diagnostic.export import CoupleInputsDynamicShapes
from onnx_diagnostic.torch_export_patches.patch_inputs import (
convert_dynamic_axes_into_dynamic_shapes,
Expand Down Expand Up @@ -353,6 +353,63 @@ def forward(self, x, i, j):
)
self.assertNotEmpty(ep)

@requires_transformers("4.57")
def test_make_dynamic_cache_2_types(self):
    """
    Builds a two-layer cache mixing ``DynamicLayer`` and
    ``DynamicSlidingWindowLayer`` and checks its string representation
    as well as a zero discrepancy when compared with itself.
    """
    key_value_pairs = [
        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
    ]
    layer_classes = [
        transformers.cache_utils.DynamicLayer,
        transformers.cache_utils.DynamicSlidingWindowLayer,
    ]
    cache = make_dynamic_cache(key_value_pairs, cls_layers=layer_classes)

    expected = (
        "DynamicCache(DynamicLayer(T1s4x5x6x7, T1s4x5x6x7), "
        "DynamicSlidingWindowLayer(T1s4x5x6x7, T1s4x5x6x7))"
    )
    self.assertEqual(expected, self.string_type(cache, with_shape=True))

    self_diff = max_diff(cache, cache)
    self.assertEqual(0, self_diff["abs"])

@requires_transformers("4.57")
def test_unflatten_flatten_mixed_layers(self):
    """
    Checks that a cache mixing ``DynamicLayer`` and
    ``DynamicSlidingWindowLayer`` keeps its layer types and its values
    through pytree flatten/unflatten, a cast to float16 with ``to_any``
    and a copy with ``torch_deepcopy``.
    """

    def layer_types(cache):
        # Ordered list of the layer classes stored in the cache.
        return [type(lay) for lay in cache.layers]

    with torch_export_patches(patch_transformers=True):
        key_value_pairs = [
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
            (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
        ]
        layer_classes = [
            transformers.cache_utils.DynamicLayer,
            transformers.cache_utils.DynamicSlidingWindowLayer,
        ]
        c2 = make_dynamic_cache(key_value_pairs, cls_layers=layer_classes)
        self.assertEqual(0, max_diff(c2, c2)["abs"])
        self.assertIsInstance(c2, transformers.cache_utils.DynamicCache)

        # Flattening: two layers, each with a key and a value tensor.
        pytree = torch.utils._pytree
        flat, spec = pytree.tree_flatten(c2)
        self.assertIsInstance(flat, list)
        self.assertEqual(len(flat), 4)
        unflat = flatten_unflatten_for_dynamic_shapes(c2)
        self.assertIsInstance(unflat, list)
        self.assertEqual(len(unflat), 4)

        # Round trip through the pytree specification.
        restored = pytree.tree_unflatten(flat, spec)
        self.assertEqual(layer_types(c2), layer_types(restored))
        self.assertEqual(0, max_diff(c2, restored)["abs"])

        # Cast to float16: same layer types, small numerical difference.
        ct = to_any(c2, torch.float16)
        self.assertEqual(layer_types(c2), layer_types(ct))
        self.assertLess(max_diff(c2, ct)["abs"], 1e-3)

        # Deep copy: identical values and layer types.
        c3 = torch_deepcopy(c2)
        self.assertEqual(0, max_diff(c2, c3)["abs"])
        self.assertEqual(layer_types(c2), layer_types(c3))


if __name__ == "__main__":
unittest.main(verbosity=2)
61 changes: 57 additions & 4 deletions _unittests/ut_investigate/test_input_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,12 @@ def forward(self, x, y, add=True):
torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0])

cst = torch.export.Dim.DYNAMIC
self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), observer.infer_dynamic_shapes())
self.assertEqual(
dict(x={0: cst, 1: cst}, y={1: cst}, add=None), observer.infer_dynamic_shapes()
)
args = observer.infer_arguments()
self.assertIsInstance(args, dict)
self.assertEqual(2, len(args))
self.assertEqual(3, len(args))

def test_io_captured_args_kwargs(self):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -512,7 +514,6 @@ def forward(self, x, y_list, z_tuple=None):
self.assertEqual(expected, observer.infer_dynamic_shapes())

def test_io_captured_custom_class(self):

class TestCustomClass:
def __init__(self, keys, values):
self.data = list(zip(keys, values))
Expand Down Expand Up @@ -570,7 +571,10 @@ def forward(self, x, custom=None):
]

cst = torch.export.Dim.DYNAMIC
expected = ({0: cst, 1: cst}, [{0: cst, 1: cst}, {1: cst}, {1: cst}, {0: cst, 1: cst}])
expected = (
{0: cst, 1: cst},
[{0: cst, 1: cst}, {1: cst}, {1: cst}, {0: cst, 1: cst}],
)
flat = torch.utils._pytree.tree_flatten(inputs[-1])[0]
self.assertEqual(len(flat), 5)

Expand Down Expand Up @@ -811,6 +815,35 @@ def forward(self, x=None, y=None):
self.assertEqual(2, len(args))
self.assertEqual(len([v for v in args.values() if v is not None]), 2)

def test_io_int_kwargs(self):
    """
    Checks that an integer keyword argument is returned by
    ``infer_arguments``, is mapped to ``None`` by ``infer_dynamic_shapes``
    and does not appear among the inputs of the exported ONNX graph.
    """

    class Model(torch.nn.Module):
        def forward(self, x=None, y=None, option=1):
            if option != 1:
                return x - y
            return x + y

    observed_calls = [
        dict(x=torch.randn((5, 7)), y=torch.randn((5, 7)), option=0),
        dict(x=torch.randn((5, 9)), y=torch.randn((5, 9)), option=0),
    ]

    model = Model()
    observer = InputObserver()
    with observer(model, store_n_calls=4):
        for call_kwargs in observed_calls:
            model(**call_kwargs)

    # The integer argument is kept with its observed value.
    inferred = observer.infer_arguments()
    self.assertIn("option", inferred)
    self.assertEqual(inferred["option"], 0)

    # It has no dynamic dimension.
    dyn_shapes = observer.infer_dynamic_shapes()
    self.assertIn("option", dyn_shapes)
    self.assertEqual(dyn_shapes["option"], None)

    exported = torch.export.export(model, (), kwargs=inferred, dynamic_shapes=dyn_shapes)
    self.assertEqualArray(model(**inferred), exported.module()(**inferred))

    onnx_program = torch.onnx.export(model, (), kwargs=inferred, dynamic_shapes=dyn_shapes)
    graph_input_names = [i.name for i in onnx_program.model_proto.graph.input]
    self.assertEqual(["x", "y"], graph_input_names)

def test_io_mixed_args_kwargs_as_dict_2(self):
class Model(torch.nn.Module):
def forward(self, x=None, y=None):
Expand Down Expand Up @@ -845,6 +878,26 @@ def forward(self, x=None, y=None):
# self.assertEqual(2, len(args))
# self.assertEqual(len([v for v in args.values() if v is not None]), 2)

def test_infer_dynamic_shapes_exception(self):
    """
    Placeholder test: it contains no statement besides this docstring
    and therefore always passes.

    The text below records three sets of observed inputs: a first call
    with ``pixel_values`` and no cache, then two calls with a
    ``DynamicCache`` mixing ``DynamicSlidingWindowLayer`` and
    ``DynamicLayer``. Presumably this is the scenario
    ``infer_dynamic_shapes`` should be checked against
    (TODO: confirm and add assertions).

    ::

        dict(input_ids:T7s1x282,
        pixel_values:T1s1x3x896x896,
        attention_mask:T7s1x282,
        position_ids:T7s1x282,
        token_type_ids:T7s1x282,cache_position:T7s282
        )
        dict(input_ids:T7s1x1,attention_mask:T7s1x283,position_ids:T7s1x1,
        past_key_values:DynamicCache(
        DynamicSlidingWindowLayer(T16s1x1x282x32, T16s1x1x282x32),
        DynamicLayer(T16s1x1x282x32, T16s1x1x282x32)),
        token_type_ids:T7s1x1,cache_position:T7s1)
        dict(input_ids:T7s1x1,attention_mask:T7s1x284,position_ids:T7s1x1,
        past_key_values:DynamicCache(
        DynamicSlidingWindowLayer(T16s1x1x283x32, T16s1x1x283x32),
        DynamicLayer(T16s1x1x283x32, T16s1x1x283x32)),
        token_type_ids:T7s1x1,cache_position:T7s1)
    """


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 2 additions & 5 deletions _unittests/ut_investigate/test_input_observer_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@ def test_input_observer_onnx_generate_tiny_llm(self):
data = get_untrained_model_with_inputs(mid)
model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"]
input_ids = inputs["input_ids"][:1]
attention_mask = inputs["attention_mask"][:1]

observer = InputObserver()
with (
register_additional_serialization_functions(patch_transformers=True),
observer(model),
):
outputs = model.generate(
input_ids=input_ids, attention_mask=attention_mask, do_sample=False
)
outputs = model.generate(input_ids=input_ids, do_sample=False)

filenamec = self.get_dump_file("test_input_observer_onnx_generate_tiny_llm.onnx")
with torch_export_patches(patch_transformers=True):
Expand All @@ -49,7 +46,7 @@ def test_input_observer_onnx_generate_tiny_llm(self):
onnx_tokens = onnx_generate(
filenamec,
input_ids=input_ids,
attention_mask=attention_mask,
attention_mask=torch.ones(input_ids.shape, dtype=torch.int64),
eos_token_id=model.config.eos_token_id,
max_new_tokens=20,
)
Expand Down
Loading
Loading