Skip to content
Closed
Prev Previous commit
Next Next commit
Check for TupleStructInfo in all call_* variants
  • Loading branch information
Lunderberg committed Oct 31, 2023
commit fc16d5da7de6e48a1d1f281c22ad3187fa2775e5
28 changes: 16 additions & 12 deletions python/tvm/relax/op/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,18 @@ def null_value() -> Call:
return _ffi_api.null_value() # type: ignore


def _normalize_arg_tuple(args: Expr) -> Expr:
if isinstance(args, RxTuple) or isinstance(args.struct_info_, TupleStructInfo):
# A tuple, or a Var bound to a tuple, are kept as-is
return args
elif isinstance(args, Expr):
# A single argument is wrapped into a tuple
return RxTuple((args,))
else:
# Anything else is left for the FFI to handle
return args


@args_converter.auto
def call_tir(
gvar: GlobalVar,
Expand Down Expand Up @@ -97,12 +109,7 @@ def call_tir(
ret: Call
A call node for the call_tir operator.
"""
if (
isinstance(args, Expr)
and not isinstance(args, RxTuple)
and not isinstance(args.struct_info_, TupleStructInfo)
):
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down Expand Up @@ -156,8 +163,7 @@ def call_tir_with_grad(
ret: Call
A call node for the call_tir_with_grad operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down Expand Up @@ -224,8 +230,7 @@ def call_tir_inplace(
ret: Call
A call node for the call_tir operator.
"""
if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(inplace_indices, list):
inplace_indices = [inplace_indices]
Expand Down Expand Up @@ -279,8 +284,7 @@ def call_dps_packed(
if isinstance(func, str):
func = ExternFunc(func)

if isinstance(args, Expr) and not isinstance(args, RxTuple): # type: ignore
args = RxTuple((args,))
args = _normalize_arg_tuple(args)

if not isinstance(out_sinfo, list):
out_sinfo = [out_sinfo]
Expand Down