### Expected behavior
When using torch.fx to convert a PyTorch model containing in-place operations (e.g., bitwise_or_), the resulting IR in TVM should accurately reflect the updated tensor and return the modified value.
### Actual behavior
Currently, the FX-based tracing produces incorrect IR in which the original input tensor is returned instead of the tensor updated by the in-place operation. This leads to a semantic mismatch between the PyTorch model and the converted IR.
Example:
```python
# Original PyTorch model
import torch
from torch.nn import Module

class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)
        return input
```
Produces this incorrect FX-derived IR:
```python
@R.function
def main(inp_0: R.Tensor((128, 128), dtype="int32"), inp_1: R.Tensor((128, 128), dtype="int32")) -> R.Tensor((128, 128), dtype="int32"):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(inp_0, inp_1)
        gv: R.Tensor((128, 128), dtype="int32") = inp_0  # Incorrect: should return lv
        R.output(gv)
    return gv
```
Whereas using exported_program gives the correct representation:
```python
@R.function
def main(input: R.Tensor((128, 128), dtype="int32"), other: R.Tensor((128, 128), dtype="int32")) -> R.Tuple(R.Tensor((128, 128), dtype="int32")):
    with R.dataflow():
        lv: R.Tensor((128, 128), dtype="int32") = R.bitwise_or(input, other)
        gv: R.Tuple(R.Tensor((128, 128), dtype="int32")) = (lv,)
        R.output(gv)
    return gv
```
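The root cause is visible at the `torch.fx` level, before any TVM conversion. A minimal sketch (assuming only PyTorch is installed): FX records the in-place `bitwise_or_` call as its own graph node, but because `return input` still refers to the original placeholder proxy, the graph's output node points at the untouched placeholder rather than at the result of the in-place call. A converter that rewrites `bitwise_or_` into the out-of-place `R.bitwise_or` must therefore also redirect later uses of `input` to the new value, which is what the FX path currently fails to do.

```python
import torch
from torch import fx
from torch.nn import Module

class Model(Module):
    def forward(self, input: torch.Tensor, other: torch.Tensor):
        input.bitwise_or_(other)
        return input

gm = fx.symbolic_trace(Model())

# The traced graph contains a call_method node for the in-place op...
has_inplace = any(
    n.op == "call_method" and n.target == "bitwise_or_" for n in gm.graph.nodes
)

# ...yet the output node still references the "input" placeholder,
# not the node produced by the in-place call.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
returns_placeholder = output_node.args[0].op == "placeholder"

print(has_inplace, returns_placeholder)  # → True True
```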
cc @shingjan