Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.8.9
+++++

* :pr:`378`: implements the computation of discrepancies in ``method_to_onnx``
* :pr:`379`: update the handling of cache after the removal of HybridCache, SlidingWindowCache in ``transformers>=5``,

0.8.8
Expand Down
17 changes: 16 additions & 1 deletion _doc/examples/plot_export_tiny_llm_method_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
`arnir0/Tiny-LLM <https://huggingface.co/arnir0/Tiny-LLM>`_.
"""

import pandas
from transformers import AutoModelForCausalLM, AutoTokenizer
from onnx_diagnostic import doc
from onnx_diagnostic.export.api import method_to_onnx
Expand Down Expand Up @@ -108,10 +109,24 @@ def generate_text(


# %%
# Let's call generate again.
# Let's call generate again. The conversion is triggered after
# ``convert_after_n_calls=3`` calls to the method forward,
# which exactly what the method generate is doing.
generated_text = generate_text(prompt, model, tokenizer)
print(generated_text)

# %%
# We finally need to check the discrepancies.
# The exports produced an onnx file and dumped the input and output
# of the torch model. We now run the onnx model to check
# it produces the same results.
# It is done after because the model may not hold twice in memory
# (torch and onnxruntime).
# verbose=2 shows more information about expected outputs.
data = forward_replacement.check_discrepancies(verbose=1)
df = pandas.DataFrame(data)
print(df)


# %%

Expand Down
5 changes: 5 additions & 0 deletions _unittests/ut_export/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@


class TestValidate(ExtTestCase):
@ignore_warnings(FutureWarning)
@hide_stdout()
def test_to_onnx(self):
class Model(torch.nn.Module):
Expand Down Expand Up @@ -142,6 +143,9 @@ def forward(self, x, y):
feeds = make_feeds(input_names, args, use_numpy=True)
got = sess.run(None, feeds)
self.assertEqualArray(expected, got[0])
df = method_to_call.check_discrepancies()
self.assertIsInstance(df, list)
self.assertEqual(len(df), 2)
self.clean_dump()

@requires_experimental_experiment("0.1")
Expand Down Expand Up @@ -243,6 +247,7 @@ def forward(self, x, y=None):
feeds = make_feeds(input_names, (args, kwargs), use_numpy=True)
got = sess.run(None, feeds)
self.assertEqualArray(expected, got[0])
method_to_call.check_discrepancies(verbose=1)
self.clean_dump()


Expand Down
198 changes: 184 additions & 14 deletions onnx_diagnostic/export/api.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import inspect
import os
import textwrap
import time
from collections.abc import Mapping, Iterable
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
import torch
from .dynamic_shapes import ModelInputs
from .onnx_plug import EagerDirectReplacementWithOnnx
from ..helpers import string_type
from ..helpers import flatten_object, max_diff, string_diff, string_type
from ..helpers.torch_helper import torch_deepcopy
from ..helpers.rt_helper import make_feeds
from ..reference import OnnxruntimeEvaluator


def get_main_dispatcher(
Expand Down Expand Up @@ -314,10 +319,11 @@ def to_onnx(
raise ValueError(f"Unknown exporter={exporter!r}")


class _WrapperToExportMethodToOnnx(torch.nn.Module):
class WrapperToExportMethodToOnnx(torch.nn.Module):
"""
Wraps an existing models in order to spy on inputs.
This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`.
This is used by :func:`onnx_diagnostic.export.api.method_to_onnx`
or :ref:`l-plot-tiny-llm-export-method-generate` for an example.
"""

def __init__(
Expand Down Expand Up @@ -352,6 +358,7 @@ def __init__(
else getattr(mod, method_name)
)
self._inputs: List[Tuple[Tuple[Any, ...], Dict[str, Any]]] = []
self._outputs: List[Any] = []
self._convert_after_n_calls = convert_after_n_calls
self._patch_kwargs = patch_kwargs
self._method_src = None
Expand All @@ -375,6 +382,7 @@ def __init__(
inline=inline,
)
self._export_done = False
self._serialization_classes: Set[type] = set()

def __str__(self) -> str:
return self.__repr__()
Expand All @@ -385,18 +393,41 @@ def __repr__(self) -> str:
f"{self._method_name})"
)

def _collect_classes(self, obj):
if obj is None or isinstance(obj, torch.Tensor):
return
cls = type(obj)
if cls.__module__ not in ("builtins",):
self._serialization_classes.add(cls)
if hasattr(obj, "__dict__"):
for v in vars(obj).values():
self._collect_classes(v)
return
if isinstance(obj, Mapping):
for v in obj.values():
self._collect_classes(v)
return
if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
for v in obj:
self._collect_classes(v)
return

def forward(self, *args, **kwargs):
if not self._export_done:
self._inputs.append(
(
args,
torch_deepcopy(
(
kwargs
if not kwargs or not self.skip_kwargs_names
else {
k: v for k, v in kwargs.items() if k not in self.skip_kwargs_names
}
),
args,
(
kwargs
if not kwargs or not self.skip_kwargs_names
else {
k: v
for k, v in kwargs.items()
if k not in self.skip_kwargs_names
}
),
)
)
)
if self.verbose:
Expand All @@ -405,12 +436,45 @@ def forward(self, *args, **kwargs):
f"{string_type(self._inputs[-1], with_shape=True)}"
)
if len(self._inputs) >= self._convert_after_n_calls:
name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
input_file = f"{name}.inputs.pt"
self._input_file = input_file
if self.verbose:
print(
f"[method_to_onnx] save {len(self._inputs)} inputs in {input_file!r}"
)
torch.save(self._inputs, input_file)
self._convert_method_to_onnx()
del self._inputs[:]
self._export_done = True
return self._method_call(*args, **kwargs)

begin = time.perf_counter()
res = self._method_call(*args, **kwargs)
duration = time.perf_counter() - begin
self._collect_classes([args, kwargs, res])
if self._inputs:
self._outputs.append((torch_deepcopy(res), duration))
assert len(self._inputs) == len(self._outputs), (
f"Number of inputs {len(self._inputs)} and "
f"outputs {len(self._outputs)} are different."
)
if self._export_done:
name = os.path.splitext(self._to_onnx_kwargs["filename"])[0]
output_file = f"{name}.outputs.pt"
if self.verbose:
print(
f"[method_to_onnx] save {len(self._outputs)} "
f"outputs in {output_file!r}"
)
torch.save(self._outputs, output_file)
self._output_file = output_file
del self._inputs[:]
del self._outputs[:]
return res

def _convert_method_to_onnx(self):
for args, kwargs in self._inputs:
self._serialization_classes |= {type(a) for a in args}
self._serialization_classes |= {type(a) for a in kwargs.values()}

def make_method(self):
inner_sig = inspect.signature(self._method_call)
Expand Down Expand Up @@ -477,6 +541,112 @@ def __init__(self, parent):
**self._to_onnx_kwargs,
)

def check_discrepancies(
self, atol: float = 1e-4, rtol: float = 0.1, hist=(0.1, 0.01), verbose: int = 0
) -> List[Dict[str, Union[str, int, float]]]:
"""
Computes the discrepancies between the saved inputs and outputs
with the saved onnx model.

:param atol: absolute tolerance, recommended values, 1e-4 for float, 1e-2 flot float16
:param rtol: relative tolerance
:param hist: thresholds, the function determines the number of discrepancies
above that threshold.
:param verbose: verbosity
:return: results, a list of dictionaries, ready to be consumed by a dataframe
"""
assert self._export_done, "The onnx export was not done."
assert os.path.exists(self._input_file), f"input file {self._input_file!r} not found"
assert os.path.exists(
self._output_file
), f"output file {self._output_file!r} not found"
filename = self._to_onnx_kwargs["filename"]
assert isinstance(filename, str) and os.path.exists(
filename
), f"onnx file {filename!r} not found"
classes = [
cls
for cls in self._serialization_classes
if cls not in {int, float, bool, str, torch.Tensor, list, set, dict, torch.device}
]
if verbose:
print(f"[method_to_onnx.check_discrepancies] register classes {classes}")
print(f"[method_to_onnx.check_discrepancies] load {self._input_file!r}")
with torch.serialization.safe_globals(classes):
inputs = torch.load(self._input_file)
if verbose:
print(f"[method_to_onnx.check_discrepancies] load {self._output_file!r}")
with torch.serialization.safe_globals(classes):
outputs = torch.load(self._output_file)
assert len(inputs) == len(outputs), (
f"Unexpected number of inputs {len(inputs)} and outputs {len(outputs)}, "
f"inputs={string_type(inputs, with_shape=True)}, "
f"outputs={string_type(outputs, with_shape=True)}"
)
if verbose:
print(f"[method_to_onnx.check_discrepancies] create onnx session {filename!r}")
sess = OnnxruntimeEvaluator(filename, whole=True)
input_names = sess.input_names
if verbose:
print(f"[method_to_onnx.check_discrepancies] input_names={input_names}")
data = []
for i, (input, (output, latency)) in enumerate(zip(inputs, outputs)):
if verbose:
if verbose > 1:
print(
f"[method_to_onnx.check_discrepancies] process input {i}: "
f"{string_type(input, with_shape=True)}"
)
print(
f"[method_to_onnx.check_discrepancies] expects: "
f"{string_type(output, with_shape=True)}"
)
else:
print(f"[method_to_onnx.check_discrepancies] process input {i}")

flat_inputs = flatten_object(input, drop_keys=True)
if len(flat_inputs) < len(input_names):
# not implemented yet, it is caused by a missing cache,
# which requires an empty cache instead
data.append(dict(index=i, duration_torch=latency, n_inputs=len(flat_inputs)))
continue
assert len(flat_inputs) == len(input_names), (
f"Length mismatch, expecting {len(input_names)} onnx inputs and got "
f"{len(flat_inputs)} flat torch inputs"
)
feeds = make_feeds(input_names, flat_inputs)
begin = time.perf_counter()
ort_outputs = sess.run(None, feeds)
duration = time.perf_counter() - begin
diff = max_diff(output, ort_outputs, hist=hist)
if "rep" in diff and isinstance(diff["rep"], dict):
diff.update(diff["rep"])
del diff["rep"]
diff["SUCCESS"] = (
isinstance(diff["abs"], float)
and isinstance(diff["rel"], float)
and diff["abs"] < atol
and diff["rel"] < rtol
)
diff.update(
dict(
index=i,
duration_torch=latency,
ort_duration=duration,
n_inputs=len(flat_inputs),
)
)
if verbose > 1:
print(
f"[method_to_onnx.check_discrepancies] ort output "
f"{string_type(ort_outputs, with_shape=True)}"
)
print(f"[method_to_onnx.check_discrepancies] diff {string_diff(diff)}")
data.append(diff)
if verbose:
print("[method_to_onnx.check_discrepancies] done")
return data


def method_to_onnx(
mod: "torch.nn.Module",
Expand Down Expand Up @@ -533,7 +703,7 @@ def method_to_onnx(

See :ref:`l-plot-tiny-llm-export-method-generate` for an example.
"""
wrapped_model = _WrapperToExportMethodToOnnx(
wrapped_model = WrapperToExportMethodToOnnx(
mod=mod,
method_name=method_name,
input_names=input_names,
Expand Down
Loading