Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions python/tvm/relax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,11 @@ def _convert_te_arg_helper(arg):
arg, ShapeExpr
), "For Expr having ShapeStructInfo, emit_te now only supports ShapeExpr"
return [_convert_te_arg_helper(val) for val in arg.values]
if isinstance(arg.struct_info, PrimStructInfo):
return arg.value
if (
isinstance(arg.struct_info, PrimStructInfo)
and arg.struct_info.value is not None
):
return _convert_te_arg_helper(arg.struct_info.value)
elif isinstance(arg, (list, Array)):
return [_convert_te_arg_helper(x) for x in arg]
elif isinstance(arg, tuple):
Expand Down
54 changes: 54 additions & 0 deletions tests/python/relax/test_blockbuilder_emit_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
""" This file tests advanced emit_te features with help of TVMScript assertion"""
# The tests here depend on tvmscript
import tvm
from tvm import te, tir
from tvm import relax as rx
from tvm.ir.base import assert_structural_equal
Expand Down Expand Up @@ -69,3 +70,56 @@ def main(
return gv

assert_structural_equal(after, Expected)


def test_symbolic_shape_in_prim_value():
"""Symbolic vars may be provided to TE in R.Prim"""

def te_slice(tensor, i):
return tvm.te.compute([tensor.shape[1]], lambda j: tensor[i, j], name="slice")

def from_builder():
bb = rx.BlockBuilder()
A = rx.Var("A", R.Tensor([16, 16], "float32"))
tir_i = tvm.tir.Var("tir_i", "int64")
relax_i = rx.Var("relax_i", R.Prim(value=tir_i))

with bb.function("main", params=[A, relax_i]):
A_sliced = bb.emit_te(te_slice, A, relax_i)
bb.emit_func_output(A_sliced)

return bb.get()

@I.ir_module
class Expected:
@T.prim_func(private=True)
def te_slice(
A: T.Buffer([T.int64(16), T.int64(16)], "float32"),
Output: T.Buffer(T.int64(16), "float32"),
row_index: T.int64,
):
T.func_attr({"tir.noalias": T.bool(True)})

for i in range(A.shape[1]):
with T.block("slice"):
vi = T.axis.remap("S", [i])
Output[vi] = A[row_index, vi]

@R.function
def main(
A: R.Tensor([16, 16], "float32"),
arg_row_index: R.Prim(value="row_index"),
):
cls = Expected

row_index = T.int64()

gv = R.call_tir(
cls.te_slice,
A,
tir_vars=[row_index],
out_sinfo=R.Tensor([16], "float32"),
)
return gv

tvm.ir.assert_structural_equal(from_builder(), Expected)