From f613704f9bd12cedfa2d1954683b398406a4f26d Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Thu, 27 Nov 2025 14:44:17 +0800 Subject: [PATCH 1/2] Implement bidirectional LSTM --- .../torch/exported_program_translator.py | 255 +++++++++++------- .../test_frontend_from_exported_program.py | 105 ++++---- 2 files changed, 202 insertions(+), 158 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index ac79024acfb9..7f9917437350 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -350,6 +350,75 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var: align_corners=align_corners, ) + def _lstm_cell_unroll( + self, + input_reshaped, + weight_ih, + weight_hh, + bias_ih, + bias_hh, + h_prev, + c_prev, + seq_len, + hidden_size, + reverse=False, + ): + """Unroll LSTM cells for a single direction.""" + weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0])) + weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0])) + outputs = [] + time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len) + + for t in time_steps: + x_t = self.block_builder.emit( + relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") + ) + ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t)) + hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t)) + + gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates)) + if bias_ih is not None: + gates = self.block_builder.emit(relax.op.add(gates, bias_ih)) + if bias_hh is not None: + gates = self.block_builder.emit(relax.op.add(gates, bias_hh)) + + i_gate = self.block_builder.emit( + relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size]) + ) + f_gate = self.block_builder.emit( + relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size]) + ) + g_gate = self.block_builder.emit( + relax.op.strided_slice( + gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size] + ) + ) + o_gate = self.block_builder.emit( + relax.op.strided_slice( + gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size] + ) + ) + + i_t = self.block_builder.emit(relax.op.sigmoid(i_gate)) + f_t = self.block_builder.emit(relax.op.sigmoid(f_gate)) + g_t = self.block_builder.emit(relax.op.tanh(g_gate)) + o_t = self.block_builder.emit(relax.op.sigmoid(o_gate)) + + c_t = self.block_builder.emit( + relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t)) + ) + h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t))) + + outputs.append(h_t) + h_prev = h_t + c_prev = c_t + + if reverse: + outputs = outputs[::-1] + + output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + return output + def _lstm(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) input_tensor = args[0] @@ -357,39 +426,30 @@ def _lstm(self, node: fx.Node) -> relax.Var: params = args[2] if len(args) > 2 else None has_biases = args[3] if len(args) > 3 else True num_layers = args[4] if len(args) > 4 else 1 - _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference - _train = args[6] if len(args) > 6 else False # Not used in inference bidirectional = args[7] if len(args) > 7 else False batch_first = args[8] if len(args) > 8 
else False - if bidirectional: - raise NotImplementedError("Bidirectional LSTM is not yet supported") + if num_layers > 1: raise NotImplementedError("Multi-layer LSTM is not yet supported") + input_shape = self.shape_of(input_tensor) if batch_first: - # Input shape: (batch, seq_len, input_size) batch_size, seq_len, input_size = input_shape else: - # Input shape: (seq_len, batch, input_size) seq_len, batch_size, input_size = input_shape - if isinstance(seq_len, tvm.tir.IntImm): - seq_len = seq_len.value - if isinstance(batch_size, tvm.tir.IntImm): - batch_size = batch_size.value - if isinstance(input_size, tvm.tir.IntImm): - input_size = input_size.value + seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size # Extract hidden size from the LSTM parameters # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh] # weight_ih shape: (4 * hidden_size, input_size) # weight_hh shape: (4 * hidden_size, hidden_size) if params and len(params) >= 2: - weight_ih = params[0] - weight_hh = params[1] # Extract hidden size from weight dimensions # weight_ih has shape (4 * hidden_size, input_size) - weight_ih_shape = self.shape_of(weight_ih) - hidden_size = weight_ih_shape[0] // 4 # 4 gates: input, forget, cell, output + weight_ih_shape = self.shape_of(params[0]) + hidden_size = weight_ih_shape[0] // 4 else: # Fallback to a default hidden size hidden_size = 16 @@ -402,109 +462,96 @@ def _lstm(self, node: fx.Node) -> relax.Var: # c_t = f_t * c_{t-1} + i_t * g_t # h_t = o_t * tanh(c_t) dtype = input_tensor.struct_info.dtype - if params and len(params) >= 4: - weight_ih = params[0] # (4 * hidden_size, input_size) - weight_hh = params[1] # (4 * hidden_size, hidden_size) - bias_ih = params[2] if has_biases else None # (4 * hidden_size,) - bias_hh = params[3] if has_biases else None # (4 * hidden_size,) + params_per_direction = 4 if has_biases else 2 + + weight_ih_fwd = params[0] if params else None + weight_hh_fwd = params[1] if params and len(params) > 1 else None + bias_ih_fwd = params[2] if params and has_biases and len(params) > 2 else None + bias_hh_fwd = params[3] if params and has_biases and len(params) > 3 else None + + if bidirectional and params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None else: - # Fallback: create zero weights - weight_ih = self.block_builder.emit( - relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) - ) - weight_hh = self.block_builder.emit( - relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) - ) - bias_ih = None - bias_hh = None - # Initialize hidden and cell states + weight_ih_bwd = None + weight_hh_bwd = None + bias_ih_bwd = None + bias_hh_bwd = None + if hx is not None and len(hx) >= 2: - h_0 = hx[0] # (num_layers, batch_size, hidden_size) - c_0 = hx[1] # (num_layers, batch_size, hidden_size) - # Extract the first layer's hidden state - h_prev = self.block_builder.emit( + h_0, c_0 = hx[0], hx[1] + h_prev_fwd = self.block_builder.emit( relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip") ) - c_prev = self.block_builder.emit( + c_prev_fwd = self.block_builder.emit( relax.op.take(c_0, 
relax.const(0, "int64"), axis=0, mode="clip") ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.take(h_0, relax.const(1, "int64"), axis=0, mode="clip") + ) + c_prev_bwd = self.block_builder.emit( + relax.op.take(c_0, relax.const(1, "int64"), axis=0, mode="clip") + ) + else: + h_prev_bwd = None + c_prev_bwd = None else: - h_prev = self.block_builder.emit( + h_prev_fwd = self.block_builder.emit( relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) ) - c_prev = self.block_builder.emit( + c_prev_fwd = self.block_builder.emit( relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) ) - # Reshape input for processing - if batch_first: - # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size) - input_reshaped = self.block_builder.emit( - relax.op.permute_dims(input_tensor, axes=[1, 0, 2]) - ) - else: - input_reshaped = input_tensor - weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0])) - weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0])) - outputs = [] - for t in range(seq_len): - # Get input at time t: (batch_size, input_size) - x_t = self.block_builder.emit( - relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") - ) - # Compute gates: W_ih * x_t + W_hh * h_{t-1} + bias - # Input-to-hidden: (batch_size, input_size) @ (4*hidden_size, input_size).T - ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t)) - - # Hidden-to-hidden: (batch_size, hidden_size) @ (4*hidden_size, hidden_size).T - hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t)) - # Add biases if present - if bias_ih is not None and bias_hh is not None: - gates = self.block_builder.emit( - relax.op.add(relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates), bias_hh) - ) - elif bias_ih is not None: - gates = self.block_builder.emit( - relax.op.add(relax.op.add(ih_gates, bias_ih), hh_gates) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) ) - elif bias_hh is not None: - gates = self.block_builder.emit( - relax.op.add(relax.op.add(ih_gates, hh_gates), bias_hh) + c_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) ) else: - gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates)) - # Split gates: (batch_size, 4 * hidden_size) -> 4 x (batch_size, hidden_size) - gate_size = hidden_size - i_gate = self.block_builder.emit( - relax.op.strided_slice(gates, axes=[1], begin=[0], end=[gate_size]) - ) - f_gate = self.block_builder.emit( - relax.op.strided_slice(gates, axes=[1], begin=[gate_size], end=[2 * gate_size]) - ) - g_gate = self.block_builder.emit( - relax.op.strided_slice(gates, axes=[1], begin=[2 * gate_size], end=[3 * gate_size]) - ) - o_gate = self.block_builder.emit( - relax.op.strided_slice(gates, axes=[1], begin=[3 * gate_size], end=[4 * gate_size]) - ) - # Apply activations - i_t = self.block_builder.emit(relax.op.sigmoid(i_gate)) - f_t = self.block_builder.emit(relax.op.sigmoid(f_gate)) - g_t = self.block_builder.emit(relax.op.tanh(g_gate)) - o_t = self.block_builder.emit(relax.op.sigmoid(o_gate)) - # Update cell state: c_t = f_t * c_{t-1} + i_t * g_t - c_t = self.block_builder.emit( - relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t)) + h_prev_bwd = None + c_prev_bwd = None + + input_reshaped = ( + self.block_builder.emit(relax.op.permute_dims(input_tensor, 
axes=[1, 0, 2])) + if batch_first + else input_tensor + ) + + output_fwd = self._lstm_cell_unroll( + input_reshaped, + weight_ih_fwd, + weight_hh_fwd, + bias_ih_fwd, + bias_hh_fwd, + h_prev_fwd, + c_prev_fwd, + seq_len, + hidden_size, + reverse=False, + ) + + if bidirectional: + output_bwd = self._lstm_cell_unroll( + input_reshaped, + weight_ih_bwd, + weight_hh_bwd, + bias_ih_bwd, + bias_hh_bwd, + h_prev_bwd, + c_prev_bwd, + seq_len, + hidden_size, + reverse=True, ) - # Update hidden state: h_t = o_t * tanh(c_t) - h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t))) - # Store output - outputs.append(h_t) - # Update for next iteration - h_prev = h_t - c_prev = c_t - # Stack outputs: (seq_len, batch_size, hidden_size) - output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) - # Reshape back to batch_first if needed + output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2)) + else: + output = output_fwd + if batch_first: # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size) output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 78a8a09a3cf4..cb8de68b3746 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -42,6 +42,37 @@ def verify_model( tvm.ir.assert_structural_equal(mod, expected) +def verify_model_numerically(torch_model, example_args, rtol=1e-4, atol=1e-5): + """Verify model by comparing numerical outputs between PyTorch and TVM.""" + with torch.no_grad(): + pytorch_output = torch_model(*example_args) + + exported_program = export(torch_model, args=example_args) + mod = from_exported_program(exported_program) + target = tvm.target.Target("llvm") + ex = relax.build(mod, target) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + tvm_args = [tvm.runtime.tensor(arg.numpy()) for arg in example_args] + tvm_output = vm["main"](*tvm_args) + + if hasattr(tvm_output, "numpy"): + tvm_output_np = tvm_output.numpy() + else: + tvm_output_np = tvm_output[0].numpy() + + pytorch_output_np = ( + pytorch_output.numpy() + if isinstance(pytorch_output, torch.Tensor) + else pytorch_output[0].numpy() + ) + + assert ( + pytorch_output_np.shape == tvm_output_np.shape + ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM {tvm_output_np.shape}" + np.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol) + + operator_basic_unary = [ (torch.abs, R.abs), (torch.acos, R.acos), @@ -7616,74 +7647,40 @@ def main( def test_lstm(): - class BasicLSTM(nn.Module): - def __init__(self): + class LSTM(nn.Module): + def __init__(self, input_size, hidden_size, batch_first, bidirectional): super().__init__() self.lstm = nn.LSTM( - input_size=4, - hidden_size=8, + input_size=input_size, + hidden_size=hidden_size, num_layers=1, - batch_first=True, - bidirectional=False, + batch_first=batch_first, + bidirectional=bidirectional, ) def forward(self, x): y, _ = self.lstm(x) return y + # Unidirectional LSTM with batch_first=True torch.manual_seed(42) x = torch.randn(2, 3, 4, dtype=torch.float32) - model = BasicLSTM() - with torch.no_grad(): - pytorch_output = model(x) - exported_program = export(model, args=(x,)) - mod = from_exported_program(exported_program) - target = tvm.target.Target("llvm") - ex = relax.build(mod, target) - vm = relax.VirtualMachine(ex, tvm.cpu()) - x_tvm = 
tvm.runtime.tensor(x.numpy()) - tvm_output = vm["main"](x_tvm) - if hasattr(tvm_output, "numpy"): - tvm_output_np = tvm_output.numpy() - else: - tvm_output_np = tvm_output[0].numpy() - assert ( - pytorch_output.shape == tvm_output_np.shape - ), f"Shape mismatch: PyTorch {pytorch_output.shape} vs TVM {tvm_output_np.shape}" - np.testing.assert_allclose(pytorch_output.numpy(), tvm_output_np, rtol=1e-4, atol=1e-5) - - class SeqFirstLSTM(nn.Module): - def __init__(self): - super().__init__() - self.lstm = nn.LSTM( - input_size=3, - hidden_size=6, - num_layers=1, - batch_first=False, - bidirectional=False, - ) - - def forward(self, x): - y, _ = self.lstm(x) - return y + verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=False), (x,)) + # Unidirectional LSTM with batch_first=False torch.manual_seed(43) x2 = torch.randn(4, 2, 3, dtype=torch.float32) - model2 = SeqFirstLSTM() - with torch.no_grad(): - pytorch_output2 = model2(x2) - exported_program2 = export(model2, args=(x2,)) - mod2 = from_exported_program(exported_program2) - ex2 = relax.build(mod2, target) - vm2 = relax.VirtualMachine(ex2, tvm.cpu()) - x2_tvm = tvm.runtime.tensor(x2.numpy()) - tvm_output2 = vm2["main"](x2_tvm) - if hasattr(tvm_output2, "numpy"): - tvm_output2_np = tvm_output2.numpy() - else: - tvm_output2_np = tvm_output2[0].numpy() - assert pytorch_output2.shape == tvm_output2_np.shape - np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) + verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=False), (x2,)) + + # Bidirectional LSTM with batch_first=True + torch.manual_seed(44) + x3 = torch.randn(2, 3, 4, dtype=torch.float32) + verify_model_numerically(LSTM(4, 8, batch_first=True, bidirectional=True), (x3,)) + + # Bidirectional LSTM with batch_first=False + torch.manual_seed(45) + x4 = torch.randn(4, 2, 3, dtype=torch.float32) + verify_model_numerically(LSTM(3, 6, batch_first=False, bidirectional=True), (x4,)) def test_tensor_none_tuple(): From 4b017974abed575c99017f0bf19795059aceae2c Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Fri, 28 Nov 2025 10:42:02 +0800 Subject: [PATCH 2/2] Update tests and add fallback --- .../torch/exported_program_translator.py | 44 ++++++++++++++----- .../test_frontend_from_exported_program.py | 5 ++- 2 files changed, 37 insertions(+), 12 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 7f9917437350..9d88ab886575 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -464,16 +464,40 @@ def _lstm(self, node: fx.Node) -> relax.Var: dtype = input_tensor.struct_info.dtype params_per_direction = 4 if has_biases else 2 - weight_ih_fwd = params[0] if params else None - weight_hh_fwd = params[1] if params and len(params) > 1 else None - bias_ih_fwd = params[2] if params and has_biases and len(params) > 2 else None - bias_hh_fwd = params[3] if params and has_biases and len(params) > 3 else None - - if bidirectional and params and len(params) >= params_per_direction * 2: - weight_ih_bwd = params[params_per_direction] - weight_hh_bwd = params[params_per_direction + 1] - bias_ih_bwd = params[params_per_direction + 2] if has_biases else None - bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + # Extract or create forward direction weights + if params and 
len(params) >= 2: + weight_ih_fwd = params[0] + weight_hh_fwd = params[1] + bias_ih_fwd = params[2] if has_biases and len(params) > 2 else None + bias_hh_fwd = params[3] if has_biases and len(params) > 3 else None + else: + # Fallback: create zero weights + weight_ih_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) + ) + weight_hh_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) + ) + bias_ih_fwd = None + bias_hh_fwd = None + + # Extract or create backward direction weights if bidirectional + if bidirectional: + if params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype) + ) + weight_hh_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype) + ) + bias_ih_bwd = None + bias_hh_bwd = None else: weight_ih_bwd = None weight_hh_bwd = None diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index cb8de68b3746..a9322c413913 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -42,7 +42,7 @@ def verify_model( tvm.ir.assert_structural_equal(mod, expected) -def verify_model_numerically(torch_model, example_args, rtol=1e-4, atol=1e-5): +def verify_model_numerically(torch_model, example_args, rtol=1e-7, atol=1e-7): """Verify model by comparing numerical outputs between PyTorch and TVM.""" with torch.no_grad(): pytorch_output = torch_model(*example_args) @@ -70,7 +70,7 @@ def verify_model_numerically(torch_model, example_args, rtol=1e-4, atol=1e-5): assert ( pytorch_output_np.shape == tvm_output_np.shape ), f"Shape mismatch: PyTorch {pytorch_output_np.shape} vs TVM {tvm_output_np.shape}" - np.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol) + tvm.testing.assert_allclose(pytorch_output_np, tvm_output_np, rtol=rtol, atol=atol) operator_basic_unary = [ @@ -7646,6 +7646,7 @@ def main( verify_model(MatrixMultiply(), example_args, {}, Expected) +@tvm.testing.requires_llvm def test_lstm(): class LSTM(nn.Module): def __init__(self, input_size, hidden_size, batch_first, bidirectional):
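
Note (outside the patches above, for reviewers): a minimal, self-contained sketch of how the new bidirectional path can be exercised end to end. It mirrors verify_model_numerically from the test diff; the BiLSTM module name, input shapes, seed, tolerances, and the `from torch.export import export` import path are illustrative assumptions, not part of the change itself.

    # Sketch: run the bidirectional LSTM translation added by these patches,
    # comparing PyTorch eager output against the Relax VM on the llvm target.
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.export import export  # assumed import path for torch.export

    import tvm
    from tvm import relax
    from tvm.relax.frontend.torch import from_exported_program


    class BiLSTM(nn.Module):
        def __init__(self):
            super().__init__()
            # Single layer, bidirectional, batch_first -- the configuration the new path targets.
            self.lstm = nn.LSTM(input_size=4, hidden_size=8, num_layers=1,
                                batch_first=True, bidirectional=True)

        def forward(self, x):
            y, _ = self.lstm(x)
            return y


    torch.manual_seed(0)
    x = torch.randn(2, 3, 4, dtype=torch.float32)
    model = BiLSTM().eval()
    with torch.no_grad():
        ref = model(x)

    mod = from_exported_program(export(model, args=(x,)))
    ex = relax.build(mod, tvm.target.Target("llvm"))
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.runtime.tensor(x.numpy()))
    out_np = out.numpy() if hasattr(out, "numpy") else out[0].numpy()

    # The feature dimension doubles to 2 * hidden_size because the forward and
    # backward hidden states are concatenated along axis 2 by the translator.
    assert out_np.shape == (2, 3, 16)
    np.testing.assert_allclose(ref.numpy(), out_np, rtol=1e-4, atol=1e-5)

The shape check reflects the design choice in _lstm: the backward direction is unrolled over reversed time steps, its per-step outputs are re-reversed to match PyTorch's layout, and the two directions are concatenated on the feature axis, so downstream consumers see the usual (seq_len, batch, 2 * hidden_size) or batch_first-permuted output.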