From 7fdb1e388bcc8ea6c764ec2d1303ba833ee2f537 Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Fri, 16 Jan 2026 12:34:59 +0800 Subject: [PATCH 1/2] Replace topi.take with relax.op.take --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1479d6f23913..409d17847991 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3161,11 +3161,10 @@ def _impl_v1(cls, bb, inputs, attr, params): if pos_ids is None: pos_ids = relax.const([list(range(seq_len))] * batch_size, dtype="int64") - # TODO(jwfromm) Replace with relax ops once take has better support. - word_vec = bb.emit_te(topi.take, word_emb, input_ids, 0) + word_vec = relax.op.take(word_emb, input_ids, axis=0) if segment_ids: - segment_vec = bb.emit_te(topi.take, segment_emb, segment_ids, 0) - pos_vec = bb.emit_te(topi.take, pos_emb, pos_ids, 0) + segment_vec = relax.op.take(segment_emb, segment_ids, axis=0) + pos_vec = relax.op.take(pos_emb, pos_ids, axis=0) vec_sum = relax.op.add(word_vec, pos_vec) if segment_ids: @@ -3323,15 +3322,11 @@ def _impl_v11(cls, bb, inputs, attr, params): mode = attr.get("mode", b"DCR").decode("utf-8") b, c, h, w = inputs[0].struct_info.shape if mode == "DCR": - x = relax.op.reshape( - inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) - ) + x = relax.op.reshape(inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)) x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) elif mode == "CRD": - x = relax.op.reshape( - inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) - ) + x = relax.op.reshape(inputs[0], (b, c // (block_size**2), block_size, block_size, h, w)) x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) else: From 8c707a588dbe56deefff2fd16153d9001cd645d4 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 16 Jan 2026 13:53:12 +0800 Subject: [PATCH 2/2] Fix lint --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 409d17847991..c5eb0420a3c4 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -3322,11 +3322,15 @@ def _impl_v11(cls, bb, inputs, attr, params): mode = attr.get("mode", b"DCR").decode("utf-8") b, c, h, w = inputs[0].struct_info.shape if mode == "DCR": - x = relax.op.reshape(inputs[0], (b, block_size, block_size, c // (block_size**2), h, w)) + x = relax.op.reshape( + inputs[0], (b, block_size, block_size, c // (block_size**2), h, w) + ) x = relax.op.permute_dims(x, [0, 3, 4, 1, 5, 2]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) elif mode == "CRD": - x = relax.op.reshape(inputs[0], (b, c // (block_size**2), block_size, block_size, h, w)) + x = relax.op.reshape( + inputs[0], (b, c // (block_size**2), block_size, block_size, h, w) + ) x = relax.op.permute_dims(x, [0, 1, 4, 2, 5, 3]) return relax.op.reshape(x, (b, c // (block_size**2), h * block_size, w * block_size)) else: