Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,23 @@ def _cross_entropy(self, node: fx.node.Node) -> relax.Expr:
)
)

def _scaled_dot_product_attention(self, node: fx.node.Node) -> relax.Var:
assert len(node.args) <= 4, "Dropout, and causal masking are not supported."
transpose_S_H = lambda tensor: relax.op.permute_dims(tensor, [0, 2, 1, 3])
query = transpose_S_H(self.env[node.args[0]])
key = transpose_S_H(self.env[node.args[1]])
value = transpose_S_H(self.env[node.args[2]])

if len(node.args) == 4:
mask = self.env[node.args[3]]
msg = "Only a float mask is supported for the attn_mask input."
assert "float" in mask.struct_info.dtype, msg
attn = relax.op.nn.attention(query, key, value, bias=mask)
else:
attn = relax.op.nn.attention(query, key, value)

return self.block_builder.emit(attn)

########## Others ##########

def _size(self, node: fx.node.Node) -> relax.Expr:
Expand Down Expand Up @@ -1185,6 +1202,7 @@ def create_convert_map(self):
"neg": self._neg,
"max": self._max,
"cross_entropy": self._cross_entropy,
"scaled_dot_product_attention": self._scaled_dot_product_attention,
}

def from_fx(
Expand Down
Loading