Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix lint
  • Loading branch information
yongwww committed Feb 17, 2025
commit 94fc6619e8456ba4bc61784cb1149a96201cb7fe
2 changes: 1 addition & 1 deletion python/tvm/relax/frontend/nn/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1567,7 +1567,7 @@ def get_timestep_embedding(

# Zero pad
if embedding_dim % 2 == 1:
emb = _op.nn.pad(emb, 0, (0, 1, 0, 0))
emb = _op.nn.pad(emb, (0, 1, 0, 0))

# Cast to proper output type
emb = _op.astype(emb, dtype)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relax/op/nn/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tvm import DataType
from tvm.tir import FloatImm

from ...expr import Expr, const
from ...expr import Expr
from . import _ffi_api


Expand Down
10 changes: 6 additions & 4 deletions tests/python/relax/test_frontend_nn_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,9 @@ def test(self, x: Tensor):
return chunk

@R.function
def test(x: R.Tensor((8,), dtype="float32"), _io: R.Object) -> R.Tuple(
def test(
x: R.Tensor((8,), dtype="float32"), _io: R.Object
) -> R.Tuple(
R.Tuple(
R.Tensor((2,), dtype="float32"),
R.Tensor((2,), dtype="float32"),
Expand Down Expand Up @@ -488,9 +490,9 @@ def test(
) -> R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)):
R.func_attr({"num_input": 4})
with R.dataflow():
scaled_dot_product_attention: R.Tensor((1, 32, 32, 32), dtype="float32") = (
R.nn.attention(query, key, value, scale=None, causal_mask=None)
)
scaled_dot_product_attention: R.Tensor(
(1, 32, 32, 32), dtype="float32"
) = R.nn.attention(query, key, value, scale=None, causal_mask=None)
gv1: R.Tuple(R.Tensor((1, 32, 32, 32), dtype="float32"), R.Tuple(R.Object)) = (
scaled_dot_product_attention,
(_io,),
Expand Down