Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions backends/arm/_passes/rewrite_upsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ def get_resize_parameters_1d(
"We do not support align_corners=True for symbolic shapes."
)

# SymInt seems to not actually work for symbolic expressions, so use the underlying sympy objects instead
# Use the exported SymPy expressions for symbolic shapes.
input_size = (
input_size.node._expr
sympy.sympify(input_size.node.expr)
if isinstance(input_size, torch.SymInt)
else input_size
else sympy.sympify(input_size)
)
output_size = (
output_size.node._expr
sympy.sympify(output_size.node.expr)
if isinstance(output_size, torch.SymInt)
else output_size
else sympy.sympify(output_size)
)
if align_corners and input_size > 1 and output_size > 1:
scale_n = output_size - 1
Expand All @@ -91,17 +91,15 @@ def get_resize_parameters_1d(
scale_d = input_size - 1
else:
scale_d = input_size
ratio = scale_n / scale_d
if not sympy.sympify(ratio).is_constant():
ratio = sympy.nsimplify(sympy.simplify(scale_n / scale_d))
if ratio.free_symbols:
raise RuntimeError(
"Resize requires a constant ratio: " + str(ratio) + " is not constant!"
)
gcd = sympy.gcd(scale_n, scale_d)
scale_n = 2 * scale_n // gcd
scale_d = 2 * scale_d // gcd
# These should always be whole integers, based on the above calculations
scale_n = int(scale_n.evalf())
scale_d = int(scale_d.evalf())
ratio_num, ratio_den = ratio.as_numer_denom()
# TOSA encodes resize scales as doubled rationals.
scale_n = int((2 * ratio_num).evalf())
scale_d = int((2 * ratio_den).evalf())

if align_corners:
offset = 0
Expand All @@ -111,9 +109,11 @@ def get_resize_parameters_1d(

# Calculate border to maintain the correct the output size.
# Note that this should always result in a constant value, as the ratio is constant.
border = scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
border = sympy.simplify(
scale_d * (output_size - 1) - scale_n * (input_size - 1) + offset
)

if not sympy.sympify(border).is_constant():
if border.free_symbols:
raise RuntimeError(
"Resize requires a constant border: "
+ str(border)
Expand Down
44 changes: 43 additions & 1 deletion backends/arm/test/misc/test_tosa_dialect_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401

import pytest
import sympy # type: ignore
import torch

from executorch.backends.arm.tosa.dialect.lib import TosaValueError
Expand All @@ -15,6 +15,21 @@
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv


def _make_symint(
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
) -> torch.SymInt:
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
shape_env.constrain_symbol_range(
symint.node.expr, compiler_min=min, compiler_max=max
)
return symint


def _expr(sym: torch.SymInt) -> sympy.Expr:
return sympy.sympify(getattr(sym.node, "expr", sym.node._expr))


def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
Expand All @@ -34,3 +49,30 @@ def test_bilinear_resize_rejects_exact_one_sixteenth_downscale():
[-15, -15],
resize_mode="bilinear",
)


def test_resize_accepts_symbolic_scale_and_border_values():
shape_env = ShapeEnv()
scale_y_n = _make_symint(shape_env, "scale_y_n", hint=2, min=1, max=8)
border_y = _make_symint(shape_env, "border_y", hint=1, min=0, max=8)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.1+FP+shape"), shape_env
), FakeTensorMode(shape_env=shape_env) as mode:
x = mode.from_tensor(torch.empty(size=(1, 3, 4, 2), dtype=torch.float32))
output = exir_ops.backend.tosa.RESIZE.default(
x,
[scale_y_n, 1, 4, 2],
[0, 0],
[border_y, 0],
resize_mode="nearest",
)

assert output.dtype == torch.float32
assert (output.shape[0], output.shape[-1]) == (1, 2)
assert isinstance(output.shape[1], torch.SymInt)
assert output.shape[2] == 7
# The output height is computed as: (input_height - 1) * scale_y_n + border_y + 1.
# As the hegiht is a symbolic expression, we check that the expression is correct by
# comparing it to the expected expression.
assert str(_expr(output.shape[1])) == "(((border_y + 2*scale_y_n)//1)) + 1"
43 changes: 43 additions & 0 deletions backends/arm/test/passes/test_rewrite_upsample_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import sympy # type: ignore
import torch
from executorch.backends.arm._passes.rewrite_upsample import RewriteUpsamplePass
from torch.fx.experimental.symbolic_shapes import ShapeEnv


def _make_symint(
shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64
) -> torch.SymInt:
symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint)
shape_env.constrain_symbol_range(
symint.node.expr, compiler_min=min, compiler_max=max
)
return symint


def test_get_resize_parameters_1d_supports_symbolic_shapes_with_constant_ratio():
shape_env = ShapeEnv()
input_size = _make_symint(shape_env, "input_size", hint=4)
output_size = input_size * 2

scale_n, scale_d, offset, border = RewriteUpsamplePass.get_resize_parameters_1d(
input_size, output_size, align_corners=False
)

assert (scale_n, scale_d, offset, border) == (4, 2, -1, 1)


def test_get_resize_parameters_1d_rejects_non_constant_symbolic_ratio():
shape_env = ShapeEnv()
input_size = _make_symint(shape_env, "input_size", hint=4)
output_size = input_size + 1

with pytest.raises(RuntimeError, match="constant ratio"):
RewriteUpsamplePass.get_resize_parameters_1d(
input_size, output_size, align_corners=False
)
4 changes: 3 additions & 1 deletion backends/arm/tosa/dialect/ops/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def _get_output_dtype(

def _validate_resize_parameters(scale, border, resize_mode):
def in_int16_range(values):
return all((x >= -(2**15)) and (x <= 2**15 - 1) for x in values)
return all(
(x >= -(2**15)) and (x <= 2**15 - 1) for x in values if isinstance(x, int)
)

if not in_int16_range(scale):
raise TosaValueError("scale is out of the int16 range", op="RESIZE")
Expand Down
Loading