diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 1938355169f0..e554648c41ad 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1477,10 +1477,49 @@ def _pixel_shuffle(self, node: fx.Node) -> relax.Var:
         return self.block_builder.emit(relax.op.nn.pixel_shuffle(data, upscale_factor))
 
     def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
-        transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
-        query = transpose_S_H(self.env[node.args[0]])
-        key = transpose_S_H(self.env[node.args[1]])
-        value = transpose_S_H(self.env[node.args[2]])
+        query_tensor = self.env[node.args[0]]
+        key_tensor = self.env[node.args[1]]
+        value_tensor = self.env[node.args[2]]
+
+        # Check the dimensionality of the input tensors
+        query_ndim = len(query_tensor.struct_info.shape)
+
+        # TVM's nn.attention requires 4D inputs in format (batch, num_heads, seq_len, head_dim)
+        # For 2D inputs (seq_len, head_dim), we need to reshape to 4D first
+        if query_ndim == 2:
+            # 2D input: (seq_len, head_dim) -> expand to (1, 1, seq_len, head_dim)
+            # Add batch dimension at axis 0
+            query_3d = self.block_builder.emit(relax.op.expand_dims(query_tensor, axis=0))
+            key_3d = self.block_builder.emit(relax.op.expand_dims(key_tensor, axis=0))
+            value_3d = self.block_builder.emit(relax.op.expand_dims(value_tensor, axis=0))
+            # Add num_heads dimension at axis 1
+            query = self.block_builder.emit(relax.op.expand_dims(query_3d, axis=1))
+            key = self.block_builder.emit(relax.op.expand_dims(key_3d, axis=1))
+            value = self.block_builder.emit(relax.op.expand_dims(value_3d, axis=1))
+
+            # No permutation needed for 2D inputs after expanding to 4D
+            # After attention, squeeze back to 2D: (1, 1, seq_len, head_dim) -> (seq_len, head_dim)
+            def transpose_and_reshape_back(tensor):
+                # Squeeze batch and num_heads dimensions
+                return self.block_builder.emit(relax.op.squeeze(tensor, axis=[0, 1]))
+
+        elif query_ndim == 4:
+            # 4D input: (batch, seq_len, num_heads, head_dim)
+            # -> (batch, num_heads, seq_len, head_dim)
+            transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
+            query = self.block_builder.emit(transpose_S_H(query_tensor))
+            key = self.block_builder.emit(transpose_S_H(key_tensor))
+            value = self.block_builder.emit(transpose_S_H(value_tensor))
+
+            # For 4D, transpose back after attention
+            def transpose_and_reshape_back(tensor):
+                return self.block_builder.emit(transpose_S_H(tensor))
+
+        else:
+            raise ValueError(
+                f"scaled_dot_product_attention expects 2D or 4D inputs, but got {query_ndim}D input"
+            )
+
         attn_mask = node.args[3] if len(node.args) > 3 else node.kwargs.get("attn_mask", None)
         dropout_p = node.args[4] if len(node.args) > 4 else node.kwargs.get("dropout_p", 0.0)
         assert dropout_p == 0.0, "Dropout is not supported"
@@ -1492,12 +1531,12 @@ def _scaled_dot_product_attention(self, node: fx.Node) -> relax.Var:
             msg = "Only a float mask is supported for the attn_mask input."
assert "float" in attn_mask.struct_info.dtype, msg - return self.block_builder.emit( - transpose_S_H( - relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) - ) + attention_output = self.block_builder.emit( + relax.op.nn.attention(query, key, value, bias=attn_mask, causal_mask=causal_mask) ) + return transpose_and_reshape_back(attention_output) + def _unbind(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0) @@ -1594,6 +1633,7 @@ def _any(self, node: fx.Node) -> relax.Var: x = args[0] dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None) keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False) + # For boolean tensors, any is equivalent to max (checking if any element is True) return self.block_builder.emit(relax.op.max(x, dim, keepdims=keepdim)) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 8ff46bf611b2..f59784c3a2f0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4255,6 +4255,45 @@ def main( run_ep_decomposition=True, ) + # Test 2D input (seq_len, head_dim) - bug fix for #18441 + class Attention2D(Module): + def forward(self, x): + return torch.nn.functional.scaled_dot_product_attention(x, x, x, is_causal=False) + + @I.ir_module + class Expected2D: + @R.function + def main( + x: R.Tensor((8, 32), dtype="float32"), + ) -> R.Tuple(R.Tensor((8, 32), dtype="float32")): + with R.dataflow(): + # Expand to add batch dimension for query, key, value separately + # (8, 32) -> (1, 8, 32) + lv: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv1: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + lv2: R.Tensor((1, 8, 32), dtype="float32") = R.expand_dims(x, axis=[0]) + # Expand to add num_heads dimension: (1, 8, 32) -> (1, 1, 8, 32) + lv3: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv, axis=[1]) + lv4: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv1, axis=[1]) + lv5: R.Tensor((1, 1, 8, 32), dtype="float32") = R.expand_dims(lv2, axis=[1]) + # Attention operation: (1, 1, 8, 32) -> (1, 1, 8, 32) + lv6: R.Tensor((1, 1, 8, 32), dtype="float32") = R.nn.attention( + lv3, lv4, lv5, scale=None, causal_mask=None, window_size=None + ) + # Squeeze batch and num_heads dimensions: (1, 1, 8, 32) -> (8, 32) + lv7: R.Tensor((8, 32), dtype="float32") = R.squeeze(lv6, axis=[0, 1]) + gv: R.Tuple(R.Tensor((8, 32), dtype="float32")) = (lv7,) + R.output(gv) + return gv + + verify_model( + Attention2D(), + (torch.randn(8, 32, dtype=torch.float32),), + {}, + Expected2D, + run_ep_decomposition=False, + ) + def test_unbind(): class Unbind1(Module):