diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 78199c2d..eddfc61e 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,9 @@ Change Logs 0.8.7 +++++ +* :pr:`363`: patch for DynamicDimConstraintPrinter +* :pr:`360`: preliminary work for phi4 + 0.8.6 +++++ diff --git a/_doc/conf.py b/_doc/conf.py index ae7d9174..3e7bc6a0 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -134,6 +134,8 @@ def linkcode_resolve(domain, info): ("py:class", "onnx_ir.Tuple"), ("py:class", "pandas.core.groupby.generic.DataFrameGroupBy"), ("py:class", "pipeline.Pipeline"), + ("py:class", "torch._guards.Source"), + ("py:class", "torch._ops.HigherOrderOperator"), ("py:class", "torch.fx.passes.operator_support.OperatorSupport"), ("py:class", "torch.fx.proxy.TracerBase"), ("py:class", "torch.FloatTensor"), diff --git a/_unittests/ut_torch_export_patches/test_patch_loops.py b/_unittests/ut_torch_export_patches/test_patch_loops.py index c7438505..ddb09d0d 100644 --- a/_unittests/ut_torch_export_patches/test_patch_loops.py +++ b/_unittests/ut_torch_export_patches/test_patch_loops.py @@ -1,6 +1,6 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, has_torch +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch from onnx_diagnostic.helpers.torch_helper import ( is_torchdynamo_exporting, fake_torchdynamo_exporting, @@ -11,6 +11,7 @@ register_patched_expressions, patched_float_arange, ) +from onnx_diagnostic.torch_export_patches import torch_export_patches class TestOnnxExportErrors(ExtTestCase): @@ -20,9 +21,23 @@ def test_patched_expressions(self): names = {_[0] for _ in res} self.assertIn("float_arange", names) - @requires_torch("2.8") - def test_filter_position_ids(self): + def test_float_arange(self): + register_patched_expressions() + rg = torch.arange(0.0, 0.99, 0.1) + rg2 = torch.ops.patched.float_arange( + torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1) + ) + rg3 = patched_float_arange(torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1)) + self.assertEqualArray(rg, rg2, atol=1e-5) + self.assertEqualArray(rg, rg3, atol=1e-5) + with fake_torchdynamo_exporting(): + rg4 = patched_float_arange( + torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1) + ) + self.assertEqualArray(rg, rg4, atol=1e-5) + @requires_torch("2.9.99") + def test_filter_position_ids(self): def filter_position_ids( patch_attention_mask: torch.Tensor, position_ids: torch.Tensor, @@ -42,15 +57,6 @@ def filter_position_ids( position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids return position_ids - def float_arange(start, end, step): - length = torch.sym_int((end - start) / step + (step * (1 - 1e-6))) - torch._check(length > 0) - res = torch.arange(0, length) - torch._check(res.is_contiguous()) - fres = res.to(torch.float32) - fstart = torch.tensor(start, dtype=torch.float32) - return fres + fstart - def scan_filter_position_ids( patch_attention_mask: torch.Tensor, position_ids: torch.Tensor, @@ -59,18 +65,21 @@ def scan_filter_position_ids( ): def body(p_attn_mask, position_ids_row): - h_len = torch.tensor(1) / p_attn_mask[:, 0].sum() - w_len = torch.tensor(1) / p_attn_mask[0].sum() - fractional_coords_h = patched_float_arange( - torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len + h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum() + w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum() + torch._check(h_len.item() > 0) + fractional_coords_h = torch.arange( + torch.tensor(0.0, dtype=p_attn_mask.dtype), + torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype), + h_len, ) - fractional_coords_w = patched_float_arange( - torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len + torch._check(w_len.item() > 0) + fractional_coords_w = torch.arange( + torch.tensor(0.0, dtype=p_attn_mask.dtype), + torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype), + w_len, ) - # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item()) - # torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item()) - bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) @@ -116,17 +125,12 @@ def forward(self, patch_attention_mask, position_ids, boundaries): self.assertEqualArray(expected, got) DYN = torch.export.Dim.DYNAMIC - ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN})) - try: - got = ep.module()(*inputs) - except Exception: - # At least it exports, we need to remove the assert from the exported program. - # Let's revisit this later. - if has_torch("2.11"): - raise - got = None - if got is not None: - self.assertEqualArray(expected, got) + with torch_export_patches(patch_torch=True): + ep = torch.export.export( + model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}) + ) + got = ep.module()(*inputs) + self.assertEqualArray(expected, got) if __name__ == "__main__": diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index e2da377a..e1427c52 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -579,6 +579,91 @@ def forward(self, x, y): shape = output[0].args[0][0].meta["val"].shape self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])") + @requires_torch("2.9.99") + def test_patched_DynamicDimConstraintPrinter(self): + def filter_position_ids( + patch_attention_mask: torch.Tensor, + position_ids: torch.Tensor, + boundaries: torch.Tensor, + num_patches_per_side: int, + ): + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum()) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum()) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = ( + bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids + return position_ids + + def scan_filter_position_ids( + patch_attention_mask: torch.Tensor, + position_ids: torch.Tensor, + boundaries: torch.Tensor, + num_patches_per_side: int, + ): + + def body(p_attn_mask, position_ids_row): + h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum() + w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum() + torch._check(h_len.item() > 0) + fractional_coords_h = torch.arange( + torch.tensor(0.0, dtype=p_attn_mask.dtype), + torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype), + h_len, + ) + torch._check(w_len.item() > 0) + fractional_coords_w = torch.arange( + torch.tensor(0.0, dtype=p_attn_mask.dtype), + torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype), + w_len, + ) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = ( + bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w + ).flatten() + + row = position_ids_row.clone() + row[p_attn_mask.view(-1)] = pos_ids + return [row] + + return torch.ops.higher_order.scan( + body, [], [patch_attention_mask, position_ids], additional_inputs=[] + ) + + class Model(torch.nn.Module): + def forward(self, patch_attention_mask, position_ids, boundaries): + res = scan_filter_position_ids( + patch_attention_mask, position_ids, boundaries, 32 + ) + return res[0] + + patch_attention_mask = torch.randint(0, 17, (32, 32, 32)) >= 1 + patch_attention_mask[:, :, :] = True + position_ids = torch.zeros((32, 1024), dtype=torch.int64) + boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1] + inputs = (patch_attention_mask, position_ids, boundaries) + + model = Model() + true_expected = filter_position_ids(*(*inputs, 32)) + expected = model(*inputs) + self.assertEqualArray(true_expected, expected) + + DYN = torch.export.Dim.DYNAMIC + with torch_export_patches(patch_torch=True): + ep = torch.export.export( + model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}) + ) + got = ep.module()(*inputs) + self.assertEqualArray(expected, got) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index fc4470c2..fdae4311 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -221,6 +221,7 @@ def _patch_torch( catch_constraints: bool, stop_if_static: int, ) -> Tuple[Optional[Callable], ...]: + import packaging.version as pv import torch import torch.jit import torch._export.non_strict_utils # produce_guards_and_solve_constraints @@ -238,6 +239,11 @@ def _patch_torch( patched_ShapeEnv, ) + if pv.Version(torch.__version__) >= pv.Version("2.9.99"): + from .patches.patch_torch import patched_DynamicDimConstraintPrinter + else: + patched_DynamicDimConstraintPrinter = None + f___constrain_user_specified_dimhint_range = None f__broadcast_in_dim_meta = None f__broadcast_shapes = None @@ -259,6 +265,17 @@ def _patch_torch( print(f"[torch_export_patches] stop_if_static={stop_if_static!r}") print("[torch_export_patches] patch pytorch") + # torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol + if patched_DynamicDimConstraintPrinter is not None: + f__print_symbol = ( + torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol + ) + torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = ( + patched_DynamicDimConstraintPrinter._print_Symbol + ) + else: + f__print_symbol = None + # torch.vmap f_vmap = torch.vmap torch.vmap = patched_vmap @@ -392,6 +409,7 @@ def _patch_torch( f_shape_env__log_guard, f_shape_env__set_replacement, f_vmap, + f__print_symbol, ) @@ -416,6 +434,7 @@ def _unpatch_torch( f_shape_env__log_guard: Optional[Callable], f_shape_env__set_replacement: Optional[Callable], f_vmap: Optional[Callable], + f__print_symbol: Optional[Callable], ): import torch import torch.jit @@ -423,6 +442,10 @@ def _unpatch_torch( from torch.fx.experimental.symbolic_shapes import ShapeEnv # this should disappear when torch.jit is removed + if f__print_symbol is not None: + torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = ( + f__print_symbol + ) torch.vmap = f_vmap torch.jit.isinstance = f_jit_isinstance torch._dynamo.mark_static_address = f_mark_static_address @@ -992,6 +1015,7 @@ def torch_export_patches( f_shape_env__log_guard, f_shape_env__set_replacement, f_vmap, + f__print_Symbol, ) = _patch_torch( verbose, patch_details, patch_torch, catch_constraints, stop_if_static ) @@ -1067,6 +1091,7 @@ def torch_export_patches( f_shape_env__log_guard, f_shape_env__set_replacement, f_vmap, + f__print_Symbol, ) if patch_transformers: diff --git a/onnx_diagnostic/torch_export_patches/patch_expressions.py b/onnx_diagnostic/torch_export_patches/patch_expressions.py index 0b6b1990..7fa891a2 100644 --- a/onnx_diagnostic/torch_export_patches/patch_expressions.py +++ b/onnx_diagnostic/torch_export_patches/patch_expressions.py @@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable: def patched_float_arange(start, end, step): - """Patched arange when start, end, step are floats.""" + """ + Patched arange when start, end, step are floats. + This patch should not be needed after 2.10. + """ if is_torchdynamo_exporting(): return torch.ops.patched.float_arange(start, end, step) else: diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py index b947c8c5..63894b51 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py @@ -5,6 +5,7 @@ import traceback from functools import reduce from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union +import sympy import torch from torch._subclasses.fake_tensor import FakeTensorMode @@ -1091,3 +1092,17 @@ def _greater_than_reduce(acc, x): new_strides.append(a.stride()[original_idx] * a.size()[original_idx]) return a.as_strided(shape, new_strides, a.storage_offset()) + + +class patched_DynamicDimConstraintPrinter: + """ + Patches + ``torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``. + Valid for ``torch>=2.10``. + """ + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + if self.symbol_to_source.get(expr): + return self.symbol_to_source[expr][0].name + return str(expr)