Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ Change Logs
0.8.7
+++++

* :pr:`363`: patch for DynamicDimConstraintPrinter
* :pr:`360`: preliminary work for phi4

0.8.6
+++++

Expand Down
2 changes: 2 additions & 0 deletions _doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def linkcode_resolve(domain, info):
("py:class", "onnx_ir.Tuple"),
("py:class", "pandas.core.groupby.generic.DataFrameGroupBy"),
("py:class", "pipeline.Pipeline"),
("py:class", "torch._guards.Source"),
("py:class", "torch._ops.HigherOrderOperator"),
("py:class", "torch.fx.passes.operator_support.OperatorSupport"),
("py:class", "torch.fx.proxy.TracerBase"),
("py:class", "torch.FloatTensor"),
Expand Down
68 changes: 36 additions & 32 deletions _unittests/ut_torch_export_patches/test_patch_loops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, has_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.helpers.torch_helper import (
is_torchdynamo_exporting,
fake_torchdynamo_exporting,
Expand All @@ -11,6 +11,7 @@
register_patched_expressions,
patched_float_arange,
)
from onnx_diagnostic.torch_export_patches import torch_export_patches


class TestOnnxExportErrors(ExtTestCase):
Expand All @@ -20,9 +21,23 @@ def test_patched_expressions(self):
names = {_[0] for _ in res}
self.assertIn("float_arange", names)

@requires_torch("2.8")
def test_filter_position_ids(self):
def test_float_arange(self):
register_patched_expressions()
rg = torch.arange(0.0, 0.99, 0.1)
rg2 = torch.ops.patched.float_arange(
torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1)
)
rg3 = patched_float_arange(torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1))
self.assertEqualArray(rg, rg2, atol=1e-5)
self.assertEqualArray(rg, rg3, atol=1e-5)
with fake_torchdynamo_exporting():
rg4 = patched_float_arange(
torch.tensor(0.0), torch.tensor(0.99), torch.tensor(0.1)
)
self.assertEqualArray(rg, rg4, atol=1e-5)

@requires_torch("2.9.99")
def test_filter_position_ids(self):
def filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
Expand All @@ -42,15 +57,6 @@ def filter_position_ids(
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
return position_ids

def float_arange(start, end, step):
length = torch.sym_int((end - start) / step + (step * (1 - 1e-6)))
torch._check(length > 0)
res = torch.arange(0, length)
torch._check(res.is_contiguous())
fres = res.to(torch.float32)
fstart = torch.tensor(start, dtype=torch.float32)
return fres + fstart

def scan_filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
Expand All @@ -59,18 +65,21 @@ def scan_filter_position_ids(
):

def body(p_attn_mask, position_ids_row):
h_len = torch.tensor(1) / p_attn_mask[:, 0].sum()
w_len = torch.tensor(1) / p_attn_mask[0].sum()
fractional_coords_h = patched_float_arange(
torch.tensor(0.0), torch.tensor(1 - 1e-6), h_len
h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum()
w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum()
torch._check(h_len.item() > 0)
fractional_coords_h = torch.arange(
torch.tensor(0.0, dtype=p_attn_mask.dtype),
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
h_len,
)
fractional_coords_w = patched_float_arange(
torch.tensor(0.0), torch.tensor(1 - 1e-6), w_len
torch._check(w_len.item() > 0)
fractional_coords_w = torch.arange(
torch.tensor(0.0, dtype=p_attn_mask.dtype),
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
w_len,
)

# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum().item())
# torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum().item())

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

Expand Down Expand Up @@ -116,17 +125,12 @@ def forward(self, patch_attention_mask, position_ids, boundaries):
self.assertEqualArray(expected, got)

DYN = torch.export.Dim.DYNAMIC
ep = torch.export.export(model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN}))
try:
got = ep.module()(*inputs)
except Exception:
# At least it exports, we need to remove the assert from the exported program.
# Let's revisit this later.
if has_torch("2.11"):
raise
got = None
if got is not None:
self.assertEqualArray(expected, got)
with torch_export_patches(patch_torch=True):
ep = torch.export.export(
model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN})
)
got = ep.module()(*inputs)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
Expand Down
85 changes: 85 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,91 @@ def forward(self, x, y):
shape = output[0].args[0][0].meta["val"].shape
self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])")

@requires_torch("2.9.99")
def test_patched_DynamicDimConstraintPrinter(self):
def filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
boundaries: torch.Tensor,
num_patches_per_side: int,
):
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[:, 0].sum())
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / p_attn_mask[0].sum())

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
).flatten()
position_ids[batch_idx][p_attn_mask.view(-1)] = pos_ids
return position_ids

def scan_filter_position_ids(
patch_attention_mask: torch.Tensor,
position_ids: torch.Tensor,
boundaries: torch.Tensor,
num_patches_per_side: int,
):

def body(p_attn_mask, position_ids_row):
h_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[:, 0].sum()
w_len = torch.tensor(1, dtype=p_attn_mask.dtype) / p_attn_mask[0].sum()
torch._check(h_len.item() > 0)
fractional_coords_h = torch.arange(
torch.tensor(0.0, dtype=p_attn_mask.dtype),
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
h_len,
)
torch._check(w_len.item() > 0)
fractional_coords_w = torch.arange(
torch.tensor(0.0, dtype=p_attn_mask.dtype),
torch.tensor(1 - 1e-6, dtype=p_attn_mask.dtype),
w_len,
)

bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)

pos_ids = (
bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w
).flatten()

row = position_ids_row.clone()
row[p_attn_mask.view(-1)] = pos_ids
return [row]

return torch.ops.higher_order.scan(
body, [], [patch_attention_mask, position_ids], additional_inputs=[]
)

class Model(torch.nn.Module):
def forward(self, patch_attention_mask, position_ids, boundaries):
res = scan_filter_position_ids(
patch_attention_mask, position_ids, boundaries, 32
)
return res[0]

patch_attention_mask = torch.randint(0, 17, (32, 32, 32)) >= 1
patch_attention_mask[:, :, :] = True
position_ids = torch.zeros((32, 1024), dtype=torch.int64)
boundaries = (torch.arange(33).to(torch.float32) / 33)[1:-1]
inputs = (patch_attention_mask, position_ids, boundaries)

model = Model()
true_expected = filter_position_ids(*(*inputs, 32))
expected = model(*inputs)
self.assertEqualArray(true_expected, expected)

DYN = torch.export.Dim.DYNAMIC
with torch_export_patches(patch_torch=True):
ep = torch.export.export(
model, inputs, dynamic_shapes=({0: DYN}, {0: DYN}, {0: DYN})
)
got = ep.module()(*inputs)
self.assertEqualArray(expected, got)


if __name__ == "__main__":
unittest.main(verbosity=2)
25 changes: 25 additions & 0 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def _patch_torch(
catch_constraints: bool,
stop_if_static: int,
) -> Tuple[Optional[Callable], ...]:
import packaging.version as pv
import torch
import torch.jit
import torch._export.non_strict_utils # produce_guards_and_solve_constraints
Expand All @@ -238,6 +239,11 @@ def _patch_torch(
patched_ShapeEnv,
)

if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
from .patches.patch_torch import patched_DynamicDimConstraintPrinter
else:
patched_DynamicDimConstraintPrinter = None

f___constrain_user_specified_dimhint_range = None
f__broadcast_in_dim_meta = None
f__broadcast_shapes = None
Expand All @@ -259,6 +265,17 @@ def _patch_torch(
print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
print("[torch_export_patches] patch pytorch")

# torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
if patched_DynamicDimConstraintPrinter is not None:
f__print_symbol = (
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
)
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
patched_DynamicDimConstraintPrinter._print_Symbol
)
else:
f__print_symbol = None

# torch.vmap
f_vmap = torch.vmap
torch.vmap = patched_vmap
Expand Down Expand Up @@ -392,6 +409,7 @@ def _patch_torch(
f_shape_env__log_guard,
f_shape_env__set_replacement,
f_vmap,
f__print_symbol,
)


Expand All @@ -416,13 +434,18 @@ def _unpatch_torch(
f_shape_env__log_guard: Optional[Callable],
f_shape_env__set_replacement: Optional[Callable],
f_vmap: Optional[Callable],
f__print_symbol: Optional[Callable],
):
import torch
import torch.jit
import torch._export.non_strict_utils # produce_guards_and_solve_constraints
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# this should disappear when torch.jit is removed
if f__print_symbol is not None:
torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
f__print_symbol
)
torch.vmap = f_vmap
torch.jit.isinstance = f_jit_isinstance
torch._dynamo.mark_static_address = f_mark_static_address
Expand Down Expand Up @@ -992,6 +1015,7 @@ def torch_export_patches(
f_shape_env__log_guard,
f_shape_env__set_replacement,
f_vmap,
f__print_Symbol,
) = _patch_torch(
verbose, patch_details, patch_torch, catch_constraints, stop_if_static
)
Expand Down Expand Up @@ -1067,6 +1091,7 @@ def torch_export_patches(
f_shape_env__log_guard,
f_shape_env__set_replacement,
f_vmap,
f__print_Symbol,
)

if patch_transformers:
Expand Down
5 changes: 4 additions & 1 deletion onnx_diagnostic/torch_export_patches/patch_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:


def patched_float_arange(start, end, step):
"""Patched arange when start, end, step are floats."""
"""
Patched arange when start, end, step are floats.
This patch should not be needed after 2.10.
"""
if is_torchdynamo_exporting():
return torch.ops.patched.float_arange(start, end, step)
else:
Expand Down
15 changes: 15 additions & 0 deletions onnx_diagnostic/torch_export_patches/patches/patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import traceback
from functools import reduce
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Union
import sympy
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

Expand Down Expand Up @@ -1091,3 +1092,17 @@ def _greater_than_reduce(acc, x):
new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

return a.as_strided(shape, new_strides, a.storage_offset())


class patched_DynamicDimConstraintPrinter:
"""
Patches
``torch.tx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol``.
Valid for ``torch>=2.10``.
"""

def _print_Symbol(self, expr: sympy.Symbol) -> str:
assert isinstance(expr, sympy.Symbol), str(type(expr))
if self.symbol_to_source.get(expr):
return self.symbol_to_source[expr][0].name
return str(expr)
Loading