diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 1479d6f23913..adc5998d1836 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -911,8 +911,7 @@ class Size(OnnxOpConverter): @classmethod def _impl_v1(cls, bb, inputs, attr, params): - # TODO(tvm-team): add native support for size op - return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0]))) + return relax.op.size(inputs[0]) class EyeLike(OnnxOpConverter): diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c6504d79c9a5..2ebca3811f92 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -40,6 +40,7 @@ register_gradient, shape_of, shape_to_tensor, + size, tensor_to_shape, to_vdevice, ) diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py index ffa19fbaa060..d46aa883f0fb 100644 --- a/python/tvm/relax/op/base.py +++ b/python/tvm/relax/op/base.py @@ -634,6 +634,22 @@ def shape_of(expr: Expr) -> Expr: return _ffi_api.shape_of(expr) # type: ignore # pylint: disable=no-member +def size(expr: Expr) -> Expr: + """Get the total number of elements in a tensor. + + Parameters + ---------- + expr : Expr + The input tensor. + + Returns + ------- + result : Expr + A scalar tensor of dtype int64 containing the total number of elements. + """ + return _ffi_api.size(expr) # type: ignore # pylint: disable=no-member + + def tensor_to_shape(expr: Expr) -> Expr: """Convert tensor to shape expr. Parameters @@ -777,11 +793,13 @@ def call_pure_packed( sinfo_args = [sinfo_args] sinfo_args = [ - sinfo() - if callable(sinfo) - else sinfo.asobject() - if isinstance(sinfo, ObjectConvertible) - else sinfo + ( + sinfo() + if callable(sinfo) + else sinfo.asobject() + if isinstance(sinfo, ObjectConvertible) + else sinfo + ) for sinfo in sinfo_args ] diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py index e031386e6e41..a41c74cae0bd 100644 --- a/python/tvm/relax/transform/legalize_ops/inspect_op.py +++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py @@ -23,6 +23,7 @@ from ...block_builder import BlockBuilder from ...expr import Call, Expr +from ... import op from .common import register_legalize @@ -126,3 +127,8 @@ def _get_tensor_elem_offset(dlpack_handle: T.handle) -> T.int64: gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset") return Call(gvar, call.args) + + +@register_legalize("relax.size") +def _size(_bb: BlockBuilder, call: Call) -> Expr: + return op.prod(op.shape_to_tensor(op.shape_of(call.args[0]))) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index e0a009a94eb8..5410c3c03a43 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -163,6 +163,7 @@ sign, sin, sinh, + size, slice_scatter, sort, split, @@ -938,6 +939,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "shape", "shape_of", "ShapeExpr", + "size", "std", "str", "sum", diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index 3acfb53b2784..d7d68766dd1a 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -1125,6 +1125,32 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf); } +// size + +StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) { + auto arg_sinfo = GetStructInfo(call->args[0]); + auto* tensor_sinfo = GetStructInfo(call->args[0]).as(); + CHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo + << "; use MatchCast if necessary"; + return TensorStructInfo(ShapeExpr(ffi::Array{}), DataType::Int(64)); +} + +TVM_REGISTER_OP("relax.size") + .set_num_inputs(1) + .add_argument("input", "Expr", "The input tensor") + .set_attr("FInferStructInfo", InferStructInfoSize) + .set_attr("FPurity", Bool(true)); + +Expr MakeSize(Expr expr) { + static const Op& op = Op::Get("relax.size"); + return Call(op, {expr}, {}, {}); +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("relax.op.size", MakeSize); +} + // tensor_to_shape StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) { diff --git a/tests/python/relax/test_op_size.py b/tests/python/relax/test_op_size.py new file mode 100644 index 000000000000..77c5ebef5af1 --- /dev/null +++ b/tests/python/relax/test_op_size.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R + + +def test_op_size(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((), "int64"): + return R.size(x) + + x_np = np.random.rand(2, 3).astype("float32") + x = tvm.runtime.tensor(x_np) + + target = tvm.target.Target("llvm") + ex = relax.build(Module, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"](x) + assert res.numpy() == 6 + + +def test_op_size_dynamic(): + @tvm.script.ir_module + class Module: + @R.function + def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((), "int64"): + return R.size(x) + + x_np = np.random.rand(4, 5).astype("float32") + x = tvm.runtime.tensor(x_np) + + target = tvm.target.Target("llvm") + ex = relax.build(Module, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + res = vm["main"](x) + assert res.numpy() == 20 + + +if __name__ == "__main__": + tvm.testing.main()