diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 4f350380..38b99ba2 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.12 ++++++ + +* :pr:`400`, :pr:`401`:, :pr:`402`: improves InputObserver (investigations), add it the documentation * :pr:`399`: update CI 0.8.11 diff --git a/_doc/final/plot_export_gemma3_tiny_input_observer.py b/_doc/final/plot_export_gemma3_tiny_input_observer.py new file mode 100644 index 00000000..2b57d8b5 --- /dev/null +++ b/_doc/final/plot_export_gemma3_tiny_input_observer.py @@ -0,0 +1,96 @@ +""" +.. _l-plot-gemma3-tiny-export-input-observer: + +Export Gemma3 tiny random with InputObserver +============================================ + +This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer` +for model `tiny-random/gemma-3 `_. +""" + +import pandas +from onnx_diagnostic import doc +from onnx_diagnostic.helpers import string_type +from onnx_diagnostic.export.api import to_onnx +from onnx_diagnostic.torch_export_patches import ( + register_additional_serialization_functions, + torch_export_patches, +) +from onnx_diagnostic.investigate.input_observer import InputObserver + + +from transformers import pipeline + +model_id = "tiny-random/gemma-3" +pipe = pipeline( + "image-text-to-text", + model=model_id, + device="cuda", + trust_remote_code=True, + max_new_tokens=3, +) +messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + { + "role": "user", + "content": [ + { + "type": "image", + "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG", + }, + {"type": "text", "text": "What animal is on the candy?"}, + ], + }, +] + + +# %% +# The model to observe. +print("model type:", type(pipe.model)) + +# %% +# Captures inputs and outputs for the model. +observer = InputObserver() +with ( + register_additional_serialization_functions(patch_transformers=True), + observer(pipe.model), +): + pipe(text=messages, max_new_tokens=4) + + +print(f"{observer.num_obs()} observations stored for encoder.") + +# %% +# Exports the model. +kwargs = observer.infer_arguments() +dynamic_shapes = observer.infer_dynamic_shapes(set_batch_dimension_for=True) +print(f"encoder kwargs={string_type(kwargs, with_shape=True)}") +print(f"encoder dynamic_shapes={dynamic_shapes}") +for candidate in observer.info.inputs: + print( + " ", + candidate, + candidate.str_obs(), + string_type(candidate.aligned_flat_list, with_shape=True), + ) + + +filename = "plot_export_gemma3_tiny_input_observer.onnx" +with torch_export_patches(patch_transformers=True): + to_onnx( + pipe.model, + args=(), + filename=filename, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + exporter="custom", + ) + +# %% +# Let's measure the discrepancies. +data = observer.check_discrepancies(filename, progress_bar=True) +print(pandas.DataFrame(data)) + + +# %% +doc.save_fig(doc.plot_dot(filename), f"{filename}.png", dpi=400) diff --git a/_doc/final/plot_export_whisper_tiny_input_observer.py b/_doc/final/plot_export_whisper_tiny_input_observer.py index bc753f32..8dec6a08 100644 --- a/_doc/final/plot_export_whisper_tiny_input_observer.py +++ b/_doc/final/plot_export_whisper_tiny_input_observer.py @@ -6,6 +6,9 @@ This reuses the recipe introduced by example :ref:`l-plot-tiny-llm-export-input-observer` for model `openai/whisper-tiny `_. + +The model ++++++++++ """ import pandas diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 7d018b01..eea74ba7 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -13,7 +13,7 @@ make_sliding_window_cache, make_static_cache, ) -from onnx_diagnostic.helpers.torch_helper import torch_deepcopy +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy, to_any from onnx_diagnostic.export import CoupleInputsDynamicShapes from onnx_diagnostic.torch_export_patches.patch_inputs import ( convert_dynamic_axes_into_dynamic_shapes, @@ -353,6 +353,63 @@ def forward(self, x, i, j): ) self.assertNotEmpty(ep) + @requires_transformers("4.57") + def test_make_dynamic_cache_2_types(self): + cache = make_dynamic_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ], + cls_layers=[ + transformers.cache_utils.DynamicLayer, + transformers.cache_utils.DynamicSlidingWindowLayer, + ], + ) + text = self.string_type(cache, with_shape=True) + self.assertEqual( + "DynamicCache(DynamicLayer(T1s4x5x6x7, T1s4x5x6x7), " + "DynamicSlidingWindowLayer(T1s4x5x6x7, T1s4x5x6x7))", + text, + ) + self.assertEqual(0, max_diff(cache, cache)["abs"]) + + @requires_transformers("4.57") + def test_unflatten_flatten_mixed_layers(self): + with torch_export_patches(patch_transformers=True): + c2 = make_dynamic_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ], + cls_layers=[ + transformers.cache_utils.DynamicLayer, + transformers.cache_utils.DynamicSlidingWindowLayer, + ], + ) + self.assertEqual(0, max_diff(c2, c2)["abs"]) + self.assertIsInstance(c2, transformers.cache_utils.DynamicCache) + flat, spec = torch.utils._pytree.tree_flatten(c2) + self.assertIsInstance(flat, list) + self.assertEqual(len(flat), 4) + unflat = flatten_unflatten_for_dynamic_shapes(c2) + self.assertIsInstance(unflat, list) + self.assertEqual(len(unflat), 4) + restored = torch.utils._pytree.tree_unflatten(flat, spec) + self.assertEqual( + [type(lay) for lay in c2.layers], [type(lay) for lay in restored.layers] + ) + self.assertEqual(0, max_diff(c2, restored)["abs"]) + ct = to_any(c2, torch.float16) + self.assertEqual( + [type(lay) for lay in c2.layers], [type(lay) for lay in ct.layers] + ) + self.assertLess(max_diff(c2, ct)["abs"], 1e-3) + c3 = torch_deepcopy(c2) + self.assertEqual(0, max_diff(c2, c3)["abs"]) + self.assertEqual( + [type(lay) for lay in c2.layers], [type(lay) for lay in c3.layers] + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 6c642fd5..b40dac7d 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -129,10 +129,12 @@ def forward(self, x, y, add=True): torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0]) cst = torch.export.Dim.DYNAMIC - self.assertEqual(dict(x={0: cst, 1: cst}, y={1: cst}), observer.infer_dynamic_shapes()) + self.assertEqual( + dict(x={0: cst, 1: cst}, y={1: cst}, add=None), observer.infer_dynamic_shapes() + ) args = observer.infer_arguments() self.assertIsInstance(args, dict) - self.assertEqual(2, len(args)) + self.assertEqual(3, len(args)) def test_io_captured_args_kwargs(self): class Model(torch.nn.Module): @@ -512,7 +514,6 @@ def forward(self, x, y_list, z_tuple=None): self.assertEqual(expected, observer.infer_dynamic_shapes()) def test_io_captured_custom_class(self): - class TestCustomClass: def __init__(self, keys, values): self.data = list(zip(keys, values)) @@ -570,7 +571,10 @@ def forward(self, x, custom=None): ] cst = torch.export.Dim.DYNAMIC - expected = ({0: cst, 1: cst}, [{0: cst, 1: cst}, {1: cst}, {1: cst}, {0: cst, 1: cst}]) + expected = ( + {0: cst, 1: cst}, + [{0: cst, 1: cst}, {1: cst}, {1: cst}, {0: cst, 1: cst}], + ) flat = torch.utils._pytree.tree_flatten(inputs[-1])[0] self.assertEqual(len(flat), 5) @@ -811,6 +815,35 @@ def forward(self, x=None, y=None): self.assertEqual(2, len(args)) self.assertEqual(len([v for v in args.values() if v is not None]), 2) + def test_io_int_kwargs(self): + class Model(torch.nn.Module): + def forward(self, x=None, y=None, option=1): + if option == 1: + return x + y + return x - y + + inputs = [ + dict(x=torch.randn((5, 7)), y=torch.randn((5, 7)), option=0), + dict(x=torch.randn((5, 9)), y=torch.randn((5, 9)), option=0), + ] + + model = Model() + observer = InputObserver() + with observer(model, store_n_calls=4): + for kwargs in inputs: + model(**kwargs) + kwargs = observer.infer_arguments() + self.assertIn("option", kwargs) + self.assertEqual(kwargs["option"], 0) + shapes = observer.infer_dynamic_shapes() + self.assertIn("option", shapes) + self.assertEqual(shapes["option"], None) + ep = torch.export.export(model, (), kwargs=kwargs, dynamic_shapes=shapes) + self.assertEqualArray(model(**kwargs), ep.module()(**kwargs)) + epo = torch.onnx.export(model, (), kwargs=kwargs, dynamic_shapes=shapes) + proto = epo.model_proto + self.assertEqual(["x", "y"], [i.name for i in proto.graph.input]) + def test_io_mixed_args_kwargs_as_dict_2(self): class Model(torch.nn.Module): def forward(self, x=None, y=None): @@ -845,6 +878,26 @@ def forward(self, x=None, y=None): # self.assertEqual(2, len(args)) # self.assertEqual(len([v for v in args.values() if v is not None]), 2) + def test_infer_dynamic_shapes_exception(self): + """ + dict(input_ids:T7s1x282, + pixel_values:T1s1x3x896x896, + attention_mask:T7s1x282, + position_ids:T7s1x282, + token_type_ids:T7s1x282,cache_position:T7s282 + ) + dict(input_ids:T7s1x1,attention_mask:T7s1x283,position_ids:T7s1x1, + past_key_values:DynamicCache( + DynamicSlidingWindowLayer(T16s1x1x282x32, T16s1x1x282x32), + DynamicLayer(T16s1x1x282x32, T16s1x1x282x32)), + token_type_ids:T7s1x1,cache_position:T7s1) + dict(input_ids:T7s1x1,attention_mask:T7s1x284,position_ids:T7s1x1, + past_key_values:DynamicCache( + DynamicSlidingWindowLayer(T16s1x1x283x32, T16s1x1x283x32), + DynamicLayer(T16s1x1x283x32, T16s1x1x283x32)), + token_type_ids:T7s1x1,cache_position:T7s1) + """ + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_investigate/test_input_observer_transformers.py b/_unittests/ut_investigate/test_input_observer_transformers.py index 0c0c9ba0..52284fed 100644 --- a/_unittests/ut_investigate/test_input_observer_transformers.py +++ b/_unittests/ut_investigate/test_input_observer_transformers.py @@ -20,16 +20,13 @@ def test_input_observer_onnx_generate_tiny_llm(self): data = get_untrained_model_with_inputs(mid) model, inputs, _ds = data["model"], data["inputs"], data["dynamic_shapes"] input_ids = inputs["input_ids"][:1] - attention_mask = inputs["attention_mask"][:1] observer = InputObserver() with ( register_additional_serialization_functions(patch_transformers=True), observer(model), ): - outputs = model.generate( - input_ids=input_ids, attention_mask=attention_mask, do_sample=False - ) + outputs = model.generate(input_ids=input_ids, do_sample=False) filenamec = self.get_dump_file("test_input_observer_onnx_generate_tiny_llm.onnx") with torch_export_patches(patch_transformers=True): @@ -49,7 +46,7 @@ def test_input_observer_onnx_generate_tiny_llm(self): onnx_tokens = onnx_generate( filenamec, input_ids=input_ids, - attention_mask=attention_mask, + attention_mask=torch.ones(input_ids.shape, dtype=torch.int64), eos_token_id=model.config.eos_token_id, max_new_tokens=20, ) diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 55dd3779..eef4e925 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -224,49 +224,68 @@ def make_dynamic_cache( are supported. """ key_value_pairs = _preprocess_key_value_pairs(key_value_pairs) - cls_kwargs = {} if isinstance(cls_layers, str): assert hasattr( transformers.cache_utils, cls_layers ), f"Unable to find class {cls_layers!r} in transformers.cache_utils" + cls_kwargs = {} cls_layer = getattr(transformers.cache_utils, cls_layers) if cls_layers == "DynamicSlidingWindowLayer": cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] assert isinstance( cls_kwargs["sliding_window"], int ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" - elif cls_layers is not None: - unique = set(cls_layers) - assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}" - cls_layer = unique.pop() + elif cls_layers is not None and isinstance(cls_layers, list): + assert len(cls_layers) == len(key_value_pairs), ( + f"Length mismatch {len(key_value_pairs)} expected but " + f"{len(cls_layers)} layer types are given." + ) + cls_kwargs = [{} for _kv in key_value_pairs] # type: ignore[assignment] + cls_layer = None if ( hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer") - and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer + and transformers.cache_utils.DynamicSlidingWindowLayer in cls_layers ): - from .helper import string_type - - assert key_value_pairs and key_value_pairs[0], ( - f"not implemented for key_value_pairs=" - f"{string_type(key_value_pairs, with_shape=True)}" - ) - cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2] - assert isinstance( - cls_kwargs["sliding_window"], int - ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}" + assert ( + key_value_pairs and key_value_pairs[0] + ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}" + for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs): + if clsy == transformers.cache_utils.DynamicSlidingWindowLayer: + kws["sliding_window"] = kv[0].shape[2] # type: ignore[index] + assert isinstance( + kws["sliding_window"], int # type: ignore[index] + ), f"sliding_window must be an integer but shape={kv[0].shape}" else: + assert ( + cls_layers is None + ), f"cls_layers must be list or a string but it is {cls_layers}" + cls_kwargs = {} cls_layer = ( transformers.cache_utils.DynamicLayer if hasattr(transformers.cache_utils, "DynamicLayer") else None ) + if cls_layer is not None: + cls_layers = [cls_layer for _ in key_value_pairs] + cls_kwargs = ( + cls_kwargs # type: ignore[assignment] + if isinstance(cls_kwargs, list) + else [cls_kwargs for _ in key_value_pairs] + ) + elif cls_layers is not None: + assert isinstance(cls_layers, list), f"Unexpected type cls_layers={cls_layers}" + assert isinstance(cls_kwargs, list), f"Unexpected type cls_kwargs={cls_kwargs}" + if ( key_value_pairs and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor) and pv.Version(transformers.__version__) >= pv.Version("4.56") ): cache = transformers.cache_utils.DynamicCache() - cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) + cache.layers.extend( + [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type] + ) for i, layer in enumerate(cache.layers): k, v = key_value_pairs[i][0], key_value_pairs[i][1] layer.dtype = k.dtype @@ -281,8 +300,22 @@ def make_dynamic_cache( return finalize_cache(cache) cache = transformers.cache_utils.DynamicCache() - if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer: - cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs]) + if hasattr(cache, "layers") and ( + cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer + ): + assert isinstance( + cls_kwargs, list + ), f"Wrong type {type(cls_kwargs)} for cls_kwargs" + assert len(cls_kwargs) == len( + cls_layers + ), f"Length mismatch between cls_kwargs={cls_kwargs} and cls_layers={cls_layers}" + assert len(cls_kwargs) == len(key_value_pairs), ( + f"Length mismatch between cls_kwargs={cls_kwargs} and " + f"len(key_value_pairs)={len(key_value_pairs)}" + ) + cache.layers.extend( + [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type] + ) for i, layer in enumerate(cache.layers): layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1] layer.is_initialized = True diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 3d0110cb..69229bc2 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -574,6 +574,32 @@ def string_type( print(f"[string_type] CACHE1:{type(obj)}") return f"MambaCache(conv_states={c}, ssm_states={d})" + if ( + obj.__class__.__name__ in {"DynamicCache"} + and hasattr(obj, "layers") + and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers) + ): + slay = [] + for lay in obj.layers: + skeys = string_type( + lay.keys, + with_shape=with_shape, + with_min_max=with_min_max, + with_device=with_device, + limit=limit, + verbose=verbose, + ) + svalues = string_type( + lay.keys, + with_shape=with_shape, + with_min_max=with_min_max, + with_device=with_device, + limit=limit, + verbose=verbose, + ) + slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})") + return f"{obj.__class__.__name__}({', '.join(slay)})" + if obj.__class__.__name__ in { "DynamicCache", "SlidingWindowCache", @@ -829,6 +855,19 @@ def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]: # noqa: F return f"{obj}" if obj.__class__.__name__ == "FakeTensorContext": return "FakeTensorContext(...)" + if obj.__class__.__name__ == "Chat": + import transformers.utils.chat_template_utils as ctu + + assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}" + msg = string_type( + obj.messages, + with_shape=with_shape, + with_min_max=with_min_max, + with_device=with_device, + limit=limit, + verbose=verbose, + ) + return f"Chat({msg})" if verbose: print(f"[string_type] END:{type(obj)}") diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 9031e162..9233cd04 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -81,14 +81,16 @@ def _infer_dynamic_dimensions( shape_list: list of shapes, they must all have the same length set_batch_dimension: - make the first dimension dynamic if it is not + forces the first dimension to be treated as dynamic, + even if all shapes have the same value for that dimension Returns: list of dynamic dimensions """ unique_ranks = {len(shape) for shape in shape_list} torch._check( - len(unique_ranks) == 1, lambda: "all shapes in shape_list must have the same rank" + len(unique_ranks) == 1, + lambda: "all shapes in shape_list must have the same rank", ) rank = unique_ranks.pop() dynamic = [] @@ -100,43 +102,32 @@ def _infer_dynamic_dimensions( class InputCandidate: - """Steals forward method to collect inputs and outputs. - This information is used to infer dynamic shapes and - export arguments. - - Examples - -------- - >>> input_observer = InputObserver() - >>> with input_observer(model): - >>> model(x1, y1) - >>> model(x2, y2) - >>> ep = torch.export.export( # or torch.onnx.export - >>> model, - >>> input_observer.infer_arguments(), - >>> dynamic_shapes.input_observer.infer_dynamic_shapes(), - >>> ) + """Retains one set of inputs given to the forward method or any + other method the class :class:`InputObserver` is stealing from. - With LLM: - >>> input_observer = InputObserver() - >>> with input_observer(model): - >>> model.generate(input_ids) - >>> ep = torch.export.export( # or torch.onnx.export - >>> model, - >>> () - >>> kwargs=input_observer.infer_arguments(), - >>> dynamic_shapes.input_observer.infer_dynamic_shapes(), - >>> ) - - See example :ref:`l-plot-tiny-llm-export-input-observer`. + Args: + args: Positional arguments. + kwargs: Optional arguments. + clone: Clone the inputs before storing them. Some tensors + may be modified inplace, the original value must be retained. + cst_kwargs: Any optional arguments constant over multiple calls. + int, float, str, bool values must be stored here. """ - def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any], clone: bool): + def __init__( + self, + args: tuple[Any, ...], + kwargs: dict[str, Any], + clone: bool, + cst_kwargs: dict[str, int | str | float | bool], + ): self.args = args self.kwargs = kwargs self.flat_list, self.spec = torch.utils._pytree.tree_flatten((args, kwargs)) self.n_tensors = sum(t is not None for t in self.flat_list) self._position_to_args_kwargs: list[int | str] | None = None self._n_tensors_for_args_kwargs: dict[int | str, int] | None = None + self.cst_kwargs = cst_kwargs.copy() if clone: self.flat_list = [ @@ -165,7 +156,8 @@ def str_obs(self) -> str: """Prints out some information about the osbervations.""" return ( f"InputCandidate(args={string_type(self.args, with_shape=True)}, " - f"kwargs={string_type(self.kwargs, with_shape=True)})" + f"kwargs={string_type(self.kwargs, with_shape=True)}, " + f"cst_kwargs={self.cst_kwargs})" ) def build_mappings(self) -> list[int | str]: @@ -222,6 +214,12 @@ def align_with( ): """Two candidates are considered as aligned if after being flattened if they have the same number of tensors (None allowed).""" + if self.cst_kwargs != best_candidate.cst_kwargs: + raise RuntimeError( + f"Two calls were made with different constant values, " + f"{self.cst_kwargs} != {best_candidate.cst_kwargs}" + ) + args = self.args if len(self.args) > len(best_candidate.args): # We need to move some args to kwargs as the best_candidate does. @@ -285,9 +283,14 @@ class InputObserverInfo: They are used a second time because :func:`torch.export.export` cares about the order in kwargs and dynamic shapes, it needs to be the same in the ordered dictionaries `add_inputs` receive. + default_values: Default values defined by the signature of the function, + any value equal to that is ignore to simplify the export. """ - def __init__(self, signature_names: list[str]): + def __init__( + self, signature_names: list[str], default_values: dict[str, int | bool | str | float] + ): + self.default_values = default_values self.inputs: list[InputCandidate] = [] self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] self.flat_outputs: list[list[torch.Tensor | None]] = [] @@ -307,6 +310,13 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): args: Positional arguments. kwargs: Named arguments. """ + cst_kwargs = { + k: v + for k, v in kwargs.items() + if k in self.signature_names + and isinstance(v, (int, float, bool, str)) + and v != self.default_values.get(k, None) + } kwargs = { k: v for k, v in kwargs.items() @@ -322,7 +332,7 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): if k not in ordered_kwargs: ordered_kwargs[k] = v - candidate = InputCandidate(args, ordered_kwargs, clone=True) + candidate = InputCandidate(args, ordered_kwargs, clone=True, cst_kwargs=cst_kwargs) self.inputs.append(candidate) if self._best_candidate is None or len(self._best_candidate) < len(candidate): self._best_candidate = candidate @@ -371,16 +381,21 @@ def infer_dynamic_shapes( self, set_batch_dimension_for: set[int | str] | bool | None = None, return_flat: bool = False, - ) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]: - """Infers dynamic shapes. based on the collected tensors. + ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]: + """Infers dynamic shapes based on the collected tensors. Most of the time, models do support a batch dimension but this batch dimension has the same value for every input sample. Instead of running inference on new samples, argument `set_batch_dimension_for` can be used to tell the first dimension is a dynamic dimension for a particular set of inputs referenced by their name (str) or their position (int). - `return_flat` tells the function to return a flat tuple instead of - nested structured. + Args: + set_batch_dimension_for (set[int | str] | None): Set of input identifiers, + by name (``str``) or position (``int``), for which the first dimension + should be treated as a dynamic batch dimension. If ``None`` or empty, + no additional batch dimensions are marked as dynamic. + return_flat: Tells the function to return a flat tuple instead of + nested structured. """ self.align_inputs_none_values() # type checking @@ -445,18 +460,20 @@ def _set_batch_dimension_for_flat_index(index): self._best_candidate.kwargs ): # It means forward method is called with tensors only. - if not self._best_candidate.kwargs: + if not self._best_candidate.kwargs and not self._best_candidate.cst_kwargs: # only positional arguments return tuple(flat_dynamic_shapes) if not self._best_candidate.args: # only named arguments - return dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes)) + ds = dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes)) + return {**ds, **dict.fromkeys(self._best_candidate.cst_kwargs, None)} # positional arguments needs to be moved to the named arguments n_args = len(self._best_candidate.args) pos_names = self.signature_names[:n_args] return { **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), **dict(zip(list(self._best_candidate.kwargs), flat_dynamic_shapes[n_args:])), + **dict.fromkeys(self._best_candidate.cst_kwargs, None), } # nested types, here comes the fun part because the shapes cannot be unflattened, @@ -492,6 +509,8 @@ def change_function(t): (self._best_candidate.args, self._best_candidate.kwargs), change_function=change_function, ) + if self._best_candidate.cst_kwargs: + ds_kwargs = {**ds_kwargs, **dict.fromkeys(self._best_candidate.cst_kwargs, None)} if not ds_kwargs: return tuple(ds_args) if not ds_args: @@ -595,6 +614,9 @@ def infer_arguments( args, kwargs = torch.utils._pytree.tree_unflatten( aligned_flat_list, candidate.aligned_spec ) + if self._best_candidate.cst_kwargs: + kwargs = {**kwargs, **self._best_candidate.cst_kwargs} + if not kwargs: return args if not args: @@ -631,6 +653,9 @@ class InputObserver: >>> kwargs=input_observer.infer_arguments(), >>> dynamic_shapes.input_observer.infer_dynamic_shapes(), >>> ) + + Examples can be found in :ref:`l-plot-tiny-llm-export-input-observer`, + :ref:`l-plot-whisper-tiny-export-input-observer`. """ def __init__(self): @@ -661,7 +686,10 @@ def num_obs(self) -> int: @contextlib.contextmanager def __call__( - self, model: torch.nn.Module, store_n_calls: int = 3, method_name: str = "forward" + self, + model: torch.nn.Module, + store_n_calls: int = 3, + method_name: str = "forward", ): """Starts collecting inputs and outputs of a specific method. The model method is replaced by a new one collecting tensors @@ -675,25 +703,34 @@ def __call__( method_name: Method name to spy on. """ if not hasattr(model, method_name): - raise ValueError(f"Model type {model} does not have a method {method_name!r}") + raise ValueError(f"Model type {model} does not have a method {method_name!r}.") captured_method = getattr(model, method_name) + sig = inspect.signature(captured_method) if self.info is None: self.info = InputObserverInfo( - signature_names=list(inspect.signature(captured_method).parameters) + signature_names=list(sig.parameters), + default_values={ + p.name: p.default + for p in sig.parameters.values() + if p.default != inspect.Parameter.empty + and isinstance(p.default, (int, bool, str, float)) + }, ) n_already_stored = len(self.info) - setattr( - model, - method_name, - lambda *args, _cm=captured_method, _snc=( - store_n_calls + n_already_stored - ), **kwargs: self._replaced_method( - *args, - _captured_method=_cm, - _store_n_calls=_snc, - **kwargs, - ), + lambda_method = lambda *args, _cm=captured_method, _snc=( # noqa: E731 + store_n_calls + n_already_stored + ), **kwargs: self._replaced_method( + *args, _captured_method=_cm, _store_n_calls=_snc, **kwargs ) + + # It may happen than the signature of the forward is used to trigger a preprocessing. + # This is used in GenerationMixin (transformers): + # position_ids_key = "decoder_position_ids" if ... else "position_ids" + # if position_ids_key in set(inspect.signature(self.forward).parameters.keys()): + lambda_method.__signature__ = sig # type: ignore[attr-defined] + + setattr(model, method_name, lambda_method) + try: yield self finally: @@ -705,13 +742,20 @@ def _check_captured(self): def infer_dynamic_shapes( self, set_batch_dimension_for: set[int | str] | bool | None = None - ) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]: + ) -> tuple[dict[int, Any] | None, ...] | dict[str, dict[int, Any] | None]: """ Infers dynamic shapes. Most of the time, models do support a batch dimension but this batch dimension has the same value for every input sample. Instead of running inference on new samples, argument `set_batch_dimension_for` can be used to tell the first dimension is a dynamic dimension for a particular set of inputs referenced by their name (str) or their position (int). + + Args: + set_batch_dimension_for (set[int | str] | None): A set of input + identifiers (by position as ``int`` or by name as ``str``) for + which the first dimension should be treated as a dynamic batch + dimension. If ``None``, no dimensions are explicitly marked as + dynamic. """ self._check_captured() assert self.info is not None # missed by type checking @@ -746,11 +790,22 @@ def infer_arguments( else: if isinstance(index_or_args_or_kwargs, tuple): index_or_candidate = InputCandidate( - args=index_or_args_or_kwargs, kwargs={}, clone=False + args=index_or_args_or_kwargs, kwargs={}, clone=False, cst_kwargs={} ) elif isinstance(index_or_args_or_kwargs, dict): index_or_candidate = InputCandidate( - args=(), kwargs=index_or_args_or_kwargs, clone=False + args=(), + kwargs={ + k: v + for k, v in index_or_args_or_kwargs.items() + if k not in self.info.default_values + }, + clone=False, + cst_kwargs={ + k: v + for k, v in index_or_args_or_kwargs.items() + if k in self.info.default_values + }, ) else: raise ValueError( @@ -832,6 +887,7 @@ def check_discrepancies( except Exception as e: error = str(e) ort_outputs = None + duration = time.perf_counter() - begin if error: diff: dict[str, Any] = dict(error=error, SUCCESS=False)