diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 5ca5f72787b7..455e42df97fc 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -743,7 +743,7 @@ def forward(self, x: Tensor): op.reshape(x, shape=[-1]), axis=0, ), - shape=[*x.shape, self.dim], # TODO(@junrushao): revisit and remove self.dim + shape=[*x.shape, self.weight.shape[1]], ) diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py index e9a4a6f62424..8dc4994465ca 100644 --- a/tests/python/relax/test_frontend_nn_modules.py +++ b/tests/python/relax/test_frontend_nn_modules.py @@ -365,7 +365,26 @@ def forward( assert_structural_equal(tvm_mod["forward"], forward, True) -def test_embedding(): +def test_embedding_1d(): + @R.function + def forward( + x: R.Tensor((4,), dtype="int32"), + _io: R.Object, + weight: R.Tensor((8, 16), dtype="float32"), + ) -> R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tuple(R.Object)): + R.func_attr({"num_input": 2}) + with R.dataflow(): + take: R.Tensor((4, 16), dtype="float32") = R.take(weight, x, axis=0) + gv1: R.Tuple(R.Tensor((4, 16), dtype="float32"), R.Tuple(R.Object)) = take, (_io,) + R.output(gv1) + return gv1 + + mod = modules.Embedding(8, 16, "float32") + tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((4,), "int32")}}, debug=True) + assert_structural_equal(tvm_mod["forward"], forward, True) + + +def test_embedding_2d(): @R.function def forward( x: R.Tensor((1, 4), dtype="int32"),