[Bug] InternalError: PermuteDims dimension mismatch when converting scaled_dot_product_attention with 2D inputs #18441

@LiSsHhUuAaIi

Description

A PyTorch model that calls torch.nn.functional.scaled_dot_product_attention on 2D input tensors can be exported with torch.export. Converting the exported program to a TVM Relax module then fails with an InternalError. The TVM PyTorch frontend appears to assume 4D inputs for the attention operation, but the actual input here is 2D.
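
The attention computation in the report is well defined for 2D tensors. PyTorch computes `softmax(Q @ K^T / sqrt(E)) @ V` over the last two axes and treats any leading axes as batch dimensions, so an `(L, E)` input simply has no batch or head axes. The following minimal pure-Python sketch of that formula for 2D inputs is for illustration only. `sdpa_2d` is a hypothetical name, and the sketch omits masking, dropout, `is_causal`, and a custom `scale`.

```python
import math


def sdpa_2d(q, k, v):
    """Reference scaled dot-product attention for 2D inputs (no mask, no dropout).

    q: L x E, k: S x E, v: S x Ev, given as lists of lists of floats.
    Returns an L x Ev list of lists equal to softmax(q @ k^T / sqrt(E)) @ v.
    """
    scale = 1.0 / math.sqrt(len(q[0]))  # PyTorch's default scale is 1/sqrt(E)
    out = []
    for q_row in q:
        # Scores of this query row against every key row.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        # Numerically stable softmax over the key axis.
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out


# Self-attention on a 2D (L=2, E=2) input, mirroring sdpa(x, x, x) in the report.
x = [[1.0, 0.0], [0.0, 1.0]]
result = sdpa_2d(x, x, x)
print(result)  # roughly [[0.67, 0.33], [0.33, 0.67]]
```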

Expected behavior

The model should convert to a TVM Relax module for any input rank that scaled_dot_product_attention accepts, since it runs correctly in native PyTorch and torch.export exports it without error.
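
One possible direction for the frontend is to normalize any input of rank 2 or higher to the 4D `(batch, heads, seq_len, head_dim)` layout before the existing axis permutation, and then restore the original leading shape on the output. This is a hypothetical sketch, not TVM's actual fix. The helpers `lift_to_4d` and `restore_shape` are illustrative names that only compute shapes; in a real converter they would correspond to reshape ops on the Relax tensors.

```python
from math import prod


def lift_to_4d(shape):
    """Map a PyTorch SDPA input shape (..., L, E) to a 4D (B, H, L, E) shape.

    Hypothetical helper: rank-2 inputs get unit batch and head axes, rank-3
    inputs get a unit head axis, and rank > 4 folds the extra leading axes
    into the batch axis.
    """
    if len(shape) < 2:
        raise ValueError(f"SDPA needs at least 2 dims, got {len(shape)}")
    *lead, seq_len, head_dim = shape
    if not lead:
        return (1, 1, seq_len, head_dim)
    if len(lead) == 1:
        return (lead[0], 1, seq_len, head_dim)
    *batch, heads = lead
    return (prod(batch), heads, seq_len, head_dim)


def restore_shape(orig_shape, out_4d):
    """Give a 4D attention output (B, H, L, Ev) back the original leading dims."""
    return tuple(orig_shape[:-2]) + tuple(out_4d[-2:])


print(lift_to_4d((8, 32)))                    # (1, 1, 8, 32)
print(lift_to_4d((2, 8, 32)))                 # (2, 1, 8, 32)
print(lift_to_4d((2, 3, 4, 8, 32)))           # (6, 4, 8, 32)
print(restore_shape((8, 32), (1, 1, 8, 32)))  # (8, 32)
```

Folding extra leading axes into the batch axis keeps the permutation fixed at four axes, so only the reshapes around it depend on the input rank.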

Actual behavior

from_exported_program raises an InternalError with the message "PermuteDims expects the number of input axes to equal the ndim of the input tensor. However, the tensor ndim is 2 while the given number of axes is 4". In other words, the frontend emits a four-axis PermuteDims on a 2D tensor.
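
The failure is a plain rank check: a permutation must list every axis of the tensor exactly once. The sketch below mimics that validation on shapes in pure Python. It assumes that the frontend applies a fixed four-axis permutation such as `[0, 2, 1, 3]` (swapping the sequence and head axes) to each attention input. That permutation is an assumption inferred from the error message, not quoted from TVM, and `permute_shape` is a hypothetical helper, not a TVM API.

```python
def permute_shape(shape, axes):
    """Return shape permuted by axes, validating like a PermuteDims op would."""
    if len(axes) != len(shape):
        raise ValueError(
            "PermuteDims expects the number of input axes to equal the ndim of "
            f"the input tensor. However, the tensor ndim is {len(shape)} while "
            f"the given number of axes is {len(axes)}"
        )
    if sorted(axes) != list(range(len(shape))):
        raise ValueError(f"axes {axes} is not a permutation of range({len(shape)})")
    return tuple(shape[a] for a in axes)


SWAP_SEQ_AND_HEADS = [0, 2, 1, 3]  # assumed 4D permutation applied to attention inputs

print(permute_shape((1, 1, 8, 32), SWAP_SEQ_AND_HEADS))  # (1, 8, 1, 32): 4D input is fine
try:
    permute_shape((8, 32), SWAP_SEQ_AND_HEADS)  # the 2D input from the report
except ValueError as err:
    print(err)  # same mismatch as the reported InternalError
```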

Environment

  • OS: Ubuntu 20.04.6 LTS
  • TVM version: 0.23.dev0
  • Python version: 3.11.14

Steps to reproduce

import torch
import torch.nn as nn
import tvm
from tvm import relax

class MinimalAttentionModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False)


model = MinimalAttentionModel()
model.eval()

# 2D input triggers the bug
x = torch.randn(8, 32)

# PyTorch execution works
with torch.no_grad():
    output = model(x)

# PyTorch export works  
exported_program = torch.export.export(model, (x,))

# TVM conversion fails
from tvm.relax.frontend.torch import from_exported_program
mod = from_exported_program(exported_program)  # InternalError here

Error Log

Traceback (most recent call last):
  File "test.py", line 29, in <module>
    mod = from_exported_program(exported_program)  # InternalError here
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  ....

tvm.error.InternalError: PermuteDims expects the number of input axes to equal the ndim of the input tensor. However, the tensor ndim is 2 while the given number of axes is 4

Triage

  • needs-triage
  • bug
  • frontend: pytorch
