From e710f8495504246bdf1e5a71ab4bb24a33ba5971 Mon Sep 17 00:00:00 2001 From: Oscar Andersson Date: Wed, 1 Apr 2026 13:43:25 +0200 Subject: [PATCH] Arm backend: Sympify tosa.RESIZE lowering Modify rewrite_upsample to support dynamic shapes. Signed-off-by: Oscar Andersson Change-Id: I992a2e12da1c5cb6ed38c705829a67f911d29c20 --- backends/arm/_passes/rewrite_upsample.py | 30 ++++++------- .../arm/test/misc/test_tosa_dialect_resize.py | 44 ++++++++++++++++++- .../test/passes/test_rewrite_upsample_pass.py | 43 ++++++++++++++++++ backends/arm/tosa/dialect/ops/resize.py | 4 +- 4 files changed, 104 insertions(+), 17 deletions(-) create mode 100644 backends/arm/test/passes/test_rewrite_upsample_pass.py diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index da336d0dde3..9f81f5cbbe5 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -72,16 +72,16 @@ def get_resize_parameters_1d( "We do not support align_corners=True for symbolic shapes." ) - # SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead + # Use the exported SymPy expressions for symbolic shapes. input_size = ( - input_size.node._expr + sympy.sympify(input_size.node.expr) if isinstance(input_size, torch.SymInt) - else input_size + else sympy.sympify(input_size) ) output_size = ( - output_size.node._expr + sympy.sympify(output_size.node.expr) if isinstance(output_size, torch.SymInt) - else output_size + else sympy.sympify(output_size) ) if align_corners and input_size > 1 and output_size > 1: scale_n = output_size - 1 @@ -91,17 +91,15 @@ def get_resize_parameters_1d( scale_d = input_size - 1 else: scale_d = input_size - ratio = scale_n / scale_d - if not sympy.sympify(ratio).is_constant(): + ratio = sympy.nsimplify(sympy.simplify(scale_n / scale_d)) + if ratio.free_symbols: raise RuntimeError( "Resize requires a constant ratio: " + str(ratio) + " is not constant!" ) - gcd = sympy.gcd(scale_n, scale_d) - scale_n = 2 * scale_n // gcd - scale_d = 2 * scale_d // gcd - # These should always be whole integers, based on the above calculations - scale_n = int(scale_n.evalf()) - scale_d = int(scale_d.evalf()) + ratio_num, ratio_den = ratio.as_numer_denom() + # TOSA encodes resize scales as doubled rationals. + scale_n = int((2 * ratio_num).evalf()) + scale_d = int((2 * ratio_den).evalf()) if align_corners: offset = 0 @@ -111,9 +109,11 @@ def get_resize_parameters_1d( # Calculate border to maintain the correct the output size. # Note that this should always result in a constant value, as the ratio is constant. - border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset + border = sympy.simplify( + scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset + ) - if not sympy.sympify(border).is_constant(): + if border.free_symbols: raise RuntimeError( "Resize requires a constant border: " + str(border) diff --git a/backends/arm/test/misc/test_tosa_dialect_resize.py b/backends/arm/test/misc/test_tosa_dialect_resize.py index 91e7aad8ad9..bfa35961195 100644 --- a/backends/arm/test/misc/test_tosa_dialect_resize.py +++ b/backends/arm/test/misc/test_tosa_dialect_resize.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. import executorch.backends.arm.tosa.dialect # noqa: F401 - import pytest +import sympy # type: ignore import torch from executorch.backends.arm.tosa.dialect.lib import TosaValueError @@ -15,6 +15,21 @@ ) from executorch.exir.dialects._ops import ops as exir_ops from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def _expr(sym: torch.SymInt) -> sympy.Expr: + return sympy.sympify(getattr(sym.node, "expr", sym.node._expr)) def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): @@ -34,3 +49,30 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale(): [-15, -15], resize_mode="bilinear", ) + + +def test_resize_accepts_symbolic_scale_and_border_values(): + shape_env = ShapeEnv() + scale_y_n = _make_symint(shape_env, "scale_y_n", hint=2, min=1, max=8) + border_y = _make_symint(shape_env, "border_y", hint=1, min=0, max=8) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env + ), FakeTensorMode(shape_env=shape_env) as mode: + x = mode.from_tensor(torch.empty(size=(1, 3, 4, 2), dtype=torch.float32)) + output = exir_ops.backend.tosa.RESIZE.default( + x, + [scale_y_n, 1, 4, 2], + [0, 0], + [border_y, 0], + resize_mode="nearest", + ) + + assert output.dtype == torch.float32 + assert (output.shape[0], output.shape[-1]) == (1, 2) + assert isinstance(output.shape[1], torch.SymInt) + assert output.shape[2] == 7 + # The output height is computed as: (input_height - 1) * scale_y_n + border_y + 1. + # As the hegiht is a symbolic expression, we check that the expression is correct by + # comparing it to the expected expression. + assert str(_expr(output.shape[1])) == "(((border_y + 2*scale_y_n)//1)) + 1" diff --git a/backends/arm/test/passes/test_rewrite_upsample_pass.py b/backends/arm/test/passes/test_rewrite_upsample_pass.py new file mode 100644 index 00000000000..a5a5b1bbff6 --- /dev/null +++ b/backends/arm/test/passes/test_rewrite_upsample_pass.py @@ -0,0 +1,43 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import sympy # type: ignore +import torch +from executorch.backends.arm._passes.rewrite_upsample import RewriteUpsamplePass +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def test_get_resize_parameters_1d_supports_symbolic_shapes_with_constant_ratio(): + shape_env = ShapeEnv() + input_size = _make_symint(shape_env, "input_size", hint=4) + output_size = input_size * 2 + + scale_n, scale_d, offset, border = RewriteUpsamplePass.get_resize_parameters_1d( + input_size, output_size, align_corners=False + ) + + assert (scale_n, scale_d, offset, border) == (4, 2, -1, 1) + + +def test_get_resize_parameters_1d_rejects_non_constant_symbolic_ratio(): + shape_env = ShapeEnv() + input_size = _make_symint(shape_env, "input_size", hint=4) + output_size = input_size + 1 + + with pytest.raises(RuntimeError, match="constant ratio"): + RewriteUpsamplePass.get_resize_parameters_1d( + input_size, output_size, align_corners=False + ) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 47add0ffb7f..c48ff508afc 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -52,7 +52,9 @@ def _get_output_dtype( def _validate_resize_parameters(scale, border, resize_mode): def in_int16_range(values): - return all((x >= -(2**15)) and (x <= 2**15 - 1) for x in values) + return all( + (x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int) + ) if not in_int16_range(scale): raise TosaValueError("scale is out of the int16 range", op="RESIZE")