diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c0c4b39d..d3671fb5 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.8.9 +++++ +* :pr:`378`: implements the computation of discrepancies in ``method_to_onnx`` * :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``, 0.8.8 diff --git a/_doc/examples/plot_export_tiny_llm_method_generate.py b/_doc/examples/plot_export_tiny_llm_method_generate.py index 8109a251..6568ffe4 100644 --- a/_doc/examples/plot_export_tiny_llm_method_generate.py +++ b/_doc/examples/plot_export_tiny_llm_method_generate.py @@ -17,6 +17,7 @@ `arnir0/Tiny-LLM `_. """ +import pandas from transformers import AutoModelForCausalLM, AutoTokenizer from onnx_diagnostic import doc from onnx_diagnostic.export.api import method_to_onnx @@ -108,10 +109,24 @@ def generate_text( # %% -# Let's call generate again. +# Let's call generate again. The conversion is triggered after +# ``convert_after_n_calls=3`` calls to the method forward, +# which exactly what the method generate is doing. generated_text = generate_text(prompt, model, tokenizer) print(generated_text) +# %% +# We finally need to check the discrepancies. +# The exports produced an onnx file and dumped the input and output +# of the torch model. We now run the onnx model to check +# it produces the same results. +# It is done after because the model may not hold twice in memory +# (torch and onnxruntime). +# verbose=2 shows more information about expected outputs. +data = forward_replacement.check_discrepancies(verbose=1) +df = pandas.DataFrame(data) +print(df) + # %% diff --git a/_unittests/ut_export/test_api.py b/_unittests/ut_export/test_api.py index 186a7472..4a47f362 100644 --- a/_unittests/ut_export/test_api.py +++ b/_unittests/ut_export/test_api.py @@ -19,6 +19,7 @@ class TestValidate(ExtTestCase): + @ignore_warnings(FutureWarning) @hide_stdout() def test_to_onnx(self): class Model(torch.nn.Module): @@ -142,6 +143,9 @@ def forward(self, x, y): feeds = make_feeds(input_names, args, use_numpy=True) got = sess.run(None, feeds) self.assertEqualArray(expected, got[0]) + df = method_to_call.check_discrepancies() + self.assertIsInstance(df, list) + self.assertEqual(len(df), 2) self.clean_dump() @requires_experimental_experiment("0.1") @@ -243,6 +247,7 @@ def forward(self, x, y=None): feeds = make_feeds(input_names, (args, kwargs), use_numpy=True) got = sess.run(None, feeds) self.assertEqualArray(expected, got[0]) + method_to_call.check_discrepancies(verbose=1) self.clean_dump() diff --git a/onnx_diagnostic/export/api.py b/onnx_diagnostic/export/api.py index 244d6dfb..5d506bb8 100644 --- a/onnx_diagnostic/export/api.py +++ b/onnx_diagnostic/export/api.py @@ -1,11 +1,16 @@ import inspect import os import textwrap +import time +from collections.abc import Mapping, Iterable from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import torch from .dynamic_shapes import ModelInputs from .onnx_plug import EagerDirectReplacementWithOnnx -from ..helpers import string_type +from ..helpers import flatten_object, max_diff, string_diff, string_type +from ..helpers.torch_helper import torch_deepcopy +from ..helpers.rt_helper import make_feeds +from ..reference import OnnxruntimeEvaluator def get_main_dispatcher( @@ -314,10 +319,11 @@ def to_onnx( raise ValueError(f"Unknown exporter={exporter!r}") -class _WrapperToExportMethodToOnnx(torch.nn.Module): +class WrapperToExportMethodToOnnx(torch.nn.Module): """ Wraps an existing models in order to spy on inputs. - This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`. + This is used by :func:`onnx_diagnostic.export.api.method_to_onnx` + or :ref:`l-plot-tiny-llm-export-method-generate` for an example. """ def __init__( @@ -352,6 +358,7 @@ def __init__( else getattr(mod, method_name) ) self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = [] + self._outputs: List[Any] = [] self._convert_after_n_calls = convert_after_n_calls self._patch_kwargs = patch_kwargs self._method_src = None @@ -375,6 +382,7 @@ def __init__( inline=inline, ) self._export_done = False + self._serialization_classes: Set[type] = set() def __str__(self) -> str: return self.__repr__() @@ -385,18 +393,41 @@ def __repr__(self) -> str: f"{self._method_name})" ) + def _collect_classes(self, obj): + if obj is None or isinstance(obj, torch.Tensor): + return + cls = type(obj) + if cls.__module__ not in ("builtins",): + self._serialization_classes.add(cls) + if hasattr(obj, "__dict__"): + for v in vars(obj).values(): + self._collect_classes(v) + return + if isinstance(obj, Mapping): + for v in obj.values(): + self._collect_classes(v) + return + if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)): + for v in obj: + self._collect_classes(v) + return + def forward(self, *args, **kwargs): if not self._export_done: self._inputs.append( - ( - args, + torch_deepcopy( ( - kwargs - if not kwargs or not self.skip_kwargs_names - else { - k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names - } - ), + args, + ( + kwargs + if not kwargs or not self.skip_kwargs_names + else { + k: v + for k, v in kwargs.items() + if k not in self.skip_kwargs_names + } + ), + ) ) ) if self.verbose: @@ -405,12 +436,45 @@ def forward(self, *args, **kwargs): f"{string_type(self._inputs[-1], with_shape=True)}" ) if len(self._inputs) >= self._convert_after_n_calls: + name = os.path.splitext(self._to_onnx_kwargs["filename"])[0] + input_file = f"{name}.inputs.pt" + self._input_file = input_file + if self.verbose: + print( + f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}" + ) + torch.save(self._inputs, input_file) self._convert_method_to_onnx() - del self._inputs[:] self._export_done = True - return self._method_call(*args, **kwargs) + + begin = time.perf_counter() + res = self._method_call(*args, **kwargs) + duration = time.perf_counter() - begin + self._collect_classes([args, kwargs, res]) + if self._inputs: + self._outputs.append((torch_deepcopy(res), duration)) + assert len(self._inputs) == len(self._outputs), ( + f"Number of inputs {len(self._inputs)} and " + f"outputs {len(self._outputs)} are different." + ) + if self._export_done: + name = os.path.splitext(self._to_onnx_kwargs["filename"])[0] + output_file = f"{name}.outputs.pt" + if self.verbose: + print( + f"[method_to_onnx] save {len(self._outputs)} " + f"outputs in {output_file!r}" + ) + torch.save(self._outputs, output_file) + self._output_file = output_file + del self._inputs[:] + del self._outputs[:] + return res def _convert_method_to_onnx(self): + for args, kwargs in self._inputs: + self._serialization_classes |= {type(a) for a in args} + self._serialization_classes |= {type(a) for a in kwargs.values()} def make_method(self): inner_sig = inspect.signature(self._method_call) @@ -477,6 +541,112 @@ def __init__(self, parent): **self._to_onnx_kwargs, ) + def check_discrepancies( + self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0 + ) -> List[Dict[str, Union[str, int, float]]]: + """ + Computes the discrepancies between the saved inputs and outputs + with the saved onnx model. + + :param atol: absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16 + :param rtol: relative tolerance + :param hist: thresholds, the function determines the number of discrepancies + above that threshold. + :param verbose: verbosity + :return: results, a list of dictionaries, ready to be consumed by a dataframe + """ + assert self._export_done, "The onnx export was not done." + assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found" + assert os.path.exists( + self._output_file + ), f"output file {self._output_file!r} not found" + filename = self._to_onnx_kwargs["filename"] + assert isinstance(filename, str) and os.path.exists( + filename + ), f"onnx file {filename!r} not found" + classes = [ + cls + for cls in self._serialization_classes + if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device} + ] + if verbose: + print(f"[method_to_onnx.check_discrepancies] register classes {classes}") + print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}") + with torch.serialization.safe_globals(classes): + inputs = torch.load(self._input_file) + if verbose: + print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}") + with torch.serialization.safe_globals(classes): + outputs = torch.load(self._output_file) + assert len(inputs) == len(outputs), ( + f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, " + f"inputs={string_type(inputs, with_shape=True)}, " + f"outputs={string_type(outputs, with_shape=True)}" + ) + if verbose: + print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}") + sess = OnnxruntimeEvaluator(filename, whole=True) + input_names = sess.input_names + if verbose: + print(f"[method_to_onnx.check_discrepancies] input_names={input_names}") + data = [] + for i, (input, (output, latency)) in enumerate(zip(inputs, outputs)): + if verbose: + if verbose > 1: + print( + f"[method_to_onnx.check_discrepancies] process input {i}: " + f"{string_type(input, with_shape=True)}" + ) + print( + f"[method_to_onnx.check_discrepancies] expects: " + f"{string_type(output, with_shape=True)}" + ) + else: + print(f"[method_to_onnx.check_discrepancies] process input {i}") + + flat_inputs = flatten_object(input, drop_keys=True) + if len(flat_inputs) < len(input_names): + # not implemented yet, it is caused by a missing cache, + # which requires an empty cache instead + data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs))) + continue + assert len(flat_inputs) == len(input_names), ( + f"Length mismatch, expecting {len(input_names)} onnx inputs and got " + f"{len(flat_inputs)} flat torch inputs" + ) + feeds = make_feeds(input_names, flat_inputs) + begin = time.perf_counter() + ort_outputs = sess.run(None, feeds) + duration = time.perf_counter() - begin + diff = max_diff(output, ort_outputs, hist=hist) + if "rep" in diff and isinstance(diff["rep"], dict): + diff.update(diff["rep"]) + del diff["rep"] + diff["SUCCESS"] = ( + isinstance(diff["abs"], float) + and isinstance(diff["rel"], float) + and diff["abs"] < atol + and diff["rel"] < rtol + ) + diff.update( + dict( + index=i, + duration_torch=latency, + ort_duration=duration, + n_inputs=len(flat_inputs), + ) + ) + if verbose > 1: + print( + f"[method_to_onnx.check_discrepancies] ort output " + f"{string_type(ort_outputs, with_shape=True)}" + ) + print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}") + data.append(diff) + if verbose: + print("[method_to_onnx.check_discrepancies] done") + return data + def method_to_onnx( mod: "torch.nn.Module", @@ -533,7 +703,7 @@ def method_to_onnx( See :ref:`l-plot-tiny-llm-export-method-generate` for an example. """ - wrapped_model = _WrapperToExportMethodToOnnx( + wrapped_model = WrapperToExportMethodToOnnx( mod=mod, method_name=method_name, input_names=input_names,