Description
When converting a PyTorch model that calls torch.nn.functional.scaled_dot_product_attention on 2D input tensors to a TVM Relax module via torch.export, an InternalError is raised during conversion. The TVM frontend incorrectly assumes 4D inputs of shape (batch, num_heads, seq_len, head_dim) for the attention operation, even though the actual input is 2D.
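For context, PyTorch's SDPA takes query/key/value in a (N, H, L, E) layout, while Relax's attention operator is documented to take (batch_size, seq_len, num_heads, head_dim). A converter would therefore typically swap the head and sequence axes with the permutation (0, 2, 1, 3). The NumPy sketch below is an illustration, not TVM's actual converter code. It assumes the frontend applies this fixed permutation and shows that the permutation works on a 4D tensor but fails on a 2D one, matching the reported PermuteDims check.

```python
import numpy as np

# Permutation that swaps the head and sequence axes:
# (batch, heads, seq_len, head_dim) -> (batch, seq_len, heads, head_dim).
# Assumed here to mirror what the frontend applies; not taken from TVM source.
SWAP_HEADS_AND_SEQ = (0, 2, 1, 3)


def to_seq_major(t: np.ndarray) -> np.ndarray:
    """Apply the fixed 4-axis permutation; only valid for rank-4 inputs."""
    return np.transpose(t, SWAP_HEADS_AND_SEQ)


q4 = np.zeros((2, 4, 8, 32))   # (N, H, L, E)
print(to_seq_major(q4).shape)  # (2, 8, 4, 32)

q2 = np.zeros((8, 32))         # (L, E), as in the reproducer
try:
    to_seq_major(q2)
except ValueError as err:
    # NumPy rejects a 4-entry axes tuple for a 2D array, analogous to
    # TVM's "tensor ndim is 2 while the given number of axes is 4".
    print("rejected:", err)
```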
Expected behavior
The PyTorch model with scaled_dot_product_attention should convert successfully to a TVM Relax module for any input rank that PyTorch accepts, including 2D inputs, since it runs correctly in native PyTorch.
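PyTorch documents SDPA inputs as shaped (N, ..., H, L, E), so leading batch dimensions are optional and a 2D (L, E) input is simply one unbatched, single-head sequence. A minimal NumPy reference of the non-causal, unmasked computation, softmax(Q K^T / sqrt(E)) V over the last two axes, makes this rank-agnostic behavior concrete. It is a sketch for comparing shapes and values, not the TVM or PyTorch implementation.

```python
import numpy as np


def sdpa_reference(q, k, v):
    """Non-causal scaled dot-product attention over the last two axes.

    Works for any rank >= 2: leading axes are treated as batch dimensions.
    """
    scale = 1.0 / np.sqrt(q.shape[-1])
    scores = (q @ np.swapaxes(k, -1, -2)) * scale            # (..., L, S)
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over S
    return weights @ v                                       # (..., L, E_v)


rng = np.random.default_rng(0)
x2d = rng.standard_normal((8, 32))   # same shape as the reproducer input
out2d = sdpa_reference(x2d, x2d, x2d)
print(out2d.shape)  # (8, 32)

# Lifting to (1, 1, L, E) gives the same values, so a 2D input is just a
# single-batch, single-head case of the 4D computation.
x4d = x2d[None, None]
out4d = sdpa_reference(x4d, x4d, x4d)
print(np.allclose(out4d[0, 0], out2d))  # True
```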
Actual behavior
An InternalError is raised during from_exported_program conversion with the message "PermuteDims expects the number of input axes to equal the ndim of the input tensor. However, the tensor ndim is 2 while the given number of axes is 4". This indicates a rank mismatch: a four-axis permutation is applied to a 2D tensor.
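One possible direction for a fix, sketched in NumPy rather than Relax, is to normalize inputs of any rank to 4D before applying the fixed permutation and then undo the normalization on the output. The helper names `lift_to_4d` and `restore_rank` are hypothetical and do not exist in TVM. The sketch only shows the shape bookkeeping.

```python
import numpy as np

SWAP_HEADS_AND_SEQ = (0, 2, 1, 3)  # (N, H, L, E) <-> (N, L, H, E)


def lift_to_4d(t: np.ndarray) -> np.ndarray:
    """Hypothetical helper: view a (..., L, E) tensor as (N, H, L, E)."""
    if t.ndim < 2:
        raise ValueError("SDPA inputs need at least (L, E) dimensions")
    if t.ndim <= 4:
        # Prepend singleton leading axes, e.g. (L, E) -> (1, 1, L, E).
        return t.reshape((1,) * (4 - t.ndim) + t.shape)
    # Fold extra leading batch axes into one: (A, B, H, L, E) -> (A*B, H, L, E).
    return t.reshape((-1,) + t.shape[-3:])


def restore_rank(out4d: np.ndarray, orig_shape: tuple) -> np.ndarray:
    """Hypothetical helper: reshape a (N, H, L, E_v) result back to (..., L, E_v)."""
    return out4d.reshape(tuple(orig_shape[:-2]) + out4d.shape[-2:])


x = np.arange(8 * 32, dtype=np.float32).reshape(8, 32)
x4 = lift_to_4d(x)                                   # (1, 1, 8, 32)
seq_major = np.transpose(x4, SWAP_HEADS_AND_SEQ)     # (1, 8, 1, 32): now valid
back = np.transpose(seq_major, SWAP_HEADS_AND_SEQ)   # the permutation is its own inverse
print(restore_rank(back, x.shape).shape)             # (8, 32)
```

In the frontend itself, the same bookkeeping would presumably be expressed with Relax reshape operations. Whether to lift non-4D inputs like this or to reject them with a clearer error is a design decision for the fix.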
Environment
- OS: Ubuntu 20.04.6 LTS
- TVM version: 0.23.dev0
- Python version: 3.11.14
Steps to reproduce
```python
import torch
import torch.nn as nn
import tvm
from tvm import relax


class MinimalAttentionModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)


model = MinimalAttentionModel()
model.eval()

# 2D input triggers the bug
x = torch.randn(8, 32)

# PyTorch execution works
with torch.no_grad():
    output = model(x)

# PyTorch export works
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program)  # InternalError here
```
Error Log
```
Traceback (most recent call last):
  File "test.py", line 29, in <module>
    mod = from_exported_program(exported_program) # InternalError here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
....
tvm.error.InternalError: PermuteDims expects the number of input axes to equal the ndim of the input tensor. However, the tensor ndim is 2 while the given number of axes is 4
```
Triage
- needs-triage
- bug
- frontend: pytorch