From 9bb95c7ab3c03f4fc24750fa46fa53bb48535662 Mon Sep 17 00:00:00 2001 From: "Guan-Ming (Wesley) Chiu" <105915352+guan404ming@users.noreply.github.com> Date: Sun, 30 Nov 2025 00:21:24 +0800 Subject: [PATCH] Implement bidirectional GRU --- .../torch/exported_program_translator.py | 486 +++++++++--------- .../test_frontend_from_exported_program.py | 86 ++++ 2 files changed, 327 insertions(+), 245 deletions(-) diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index fc0ca1820940..18bda9e1ec74 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -609,292 +609,288 @@ def _lstm(self, node: fx.Node) -> relax.Var: output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2])) return output - def _gru(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - input_tensor = args[0] - hx = args[1] if len(args) > 1 else None - params = args[2] if len(args) > 2 else None - has_biases = args[3] if len(args) > 3 else True - num_layers = args[4] if len(args) > 4 else 1 - _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference - _train = args[6] if len(args) > 6 else False # Not used in inference - bidirectional = args[7] if len(args) > 7 else False - batch_first = args[8] if len(args) > 8 else False - - if bidirectional: - raise NotImplementedError("Bidirectional GRU is not yet supported") - - input_shape = self.shape_of(input_tensor) - if batch_first: - batch_size, seq_len, input_size = input_shape - else: - seq_len, batch_size, input_size = input_shape - - if isinstance(seq_len, tvm.tir.IntImm): - seq_len = seq_len.value - if isinstance(batch_size, tvm.tir.IntImm): - batch_size = batch_size.value - if isinstance(input_size, tvm.tir.IntImm): - input_size = input_size.value + def _gru_cell_unroll( + self, + input_reshaped, + weight_ih, + weight_hh, + bias_ih, + bias_hh, + h_prev, + seq_len, + hidden_size, + dtype, + reverse=False, + ): + """Unroll GRU cells for a single direction.""" + gate_size = hidden_size - if params and len(params) >= 2: - # For multi-layer, we need to extract the first layer's weights - # to determine hidden size - if num_layers > 1: - # Multi-layer: params[0] is first layer's weight_ih - weight_ih = params[0] - else: - # Single layer: params[0] is weight_ih - weight_ih = params[0] - # Extract hidden size from weight dimensions - # weight_ih has shape (3 * hidden_size, input_size) - weight_ih_shape = self.shape_of(weight_ih) - hidden_size = weight_ih_shape[0] // 3 # 3 gates: reset, update, new - else: - # Fallback to a default hidden size - hidden_size = 16 + # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n) + # Reset gate weights + weight_ih_r = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size]) + ) + weight_hh_r = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size]) + ) - # Implement actual GRU computation using Relax operations - # GRU equations: - # r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr) - # z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz) - # n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn)) - # h_t = (1 - z_t) * n_t + z_t * h_{t-1} - dtype = input_tensor.struct_info.dtype + # Update gate weights + weight_ih_z = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size]) + ) + weight_hh_z = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size]) + ) - # Reshape input for processing - if batch_first: - # Input: (batch, seq_len, input_size) -> (seq_len, batch, input_size) - input_reshaped = self.block_builder.emit( - relax.op.permute_dims(input_tensor, axes=[1, 0, 2]) - ) - else: - input_reshaped = input_tensor + # New gate weights + weight_ih_n = self.block_builder.emit( + relax.op.strided_slice(weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]) + ) + weight_hh_n = self.block_builder.emit( + relax.op.strided_slice(weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size]) + ) - # Initialize hidden states for all layers - if hx is not None: - # hx shape: (num_layers, batch_size, hidden_size) - h_states = [] - for layer in range(num_layers): - h_layer = self.block_builder.emit( - relax.op.take(hx, relax.const(layer, "int64"), axis=0, mode="clip") - ) - h_states.append(h_layer) - else: - h_states = [] - for layer in range(num_layers): - h_layer = self.block_builder.emit( - relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) - ) - h_states.append(h_layer) + # Transpose weights for matmul + weight_ih_r_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_r, axes=[1, 0])) + weight_hh_r_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_r, axes=[1, 0])) + weight_ih_z_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_z, axes=[1, 0])) + weight_hh_z_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_z, axes=[1, 0])) + weight_ih_n_t = self.block_builder.emit(relax.op.permute_dims(weight_ih_n, axes=[1, 0])) + weight_hh_n_t = self.block_builder.emit(relax.op.permute_dims(weight_hh_n, axes=[1, 0])) outputs = [] + time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len) - for t in range(seq_len): + for t in time_steps: # Get input at time t: (batch_size, input_size) x_t = self.block_builder.emit( relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip") ) - # Process through each layer - current_input = x_t - new_h_states = [] - - for layer in range(num_layers): - # Get layer parameters - if params and len(params) >= 4 * num_layers: - # Multi-layer case: params are organized as - # [layer0_ih, layer0_hh, layer0_bias_ih, layer0_bias_hh, layer1_ih, ...] - param_offset = layer * 4 - weight_ih = params[param_offset] - weight_hh = params[param_offset + 1] - bias_ih = params[param_offset + 2] if has_biases else None - bias_hh = params[param_offset + 3] if has_biases else None - elif params and len(params) >= 4: - # Single layer case - weight_ih = params[0] - weight_hh = params[1] - bias_ih = params[2] if has_biases else None - bias_hh = params[3] if has_biases else None - else: - # Fallback: create zero weights - weight_ih = self.block_builder.emit( - relax.op.zeros( - relax.ShapeExpr( - (3 * hidden_size, input_size if layer == 0 else hidden_size) - ), - dtype, - ) - ) - weight_hh = self.block_builder.emit( - relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype) - ) - bias_ih = None - bias_hh = None - - # Get previous hidden state for this layer - h_prev = h_states[layer] - - # Split weights by gates: PyTorch GRU gate order: reset, update, new (r, z, n) - gate_size = hidden_size - - # Reset gate weights - weight_ih_r = self.block_builder.emit( - relax.op.strided_slice(weight_ih, axes=[0], begin=[0], end=[gate_size]) + # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr) + r_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_r_t)) + r_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t)) + if bias_ih is not None and bias_hh is not None: + bias_ih_r = self.block_builder.emit( + relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size]) ) - weight_hh_r = self.block_builder.emit( - relax.op.strided_slice(weight_hh, axes=[0], begin=[0], end=[gate_size]) + bias_hh_r = self.block_builder.emit( + relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size]) ) + r_t = self.block_builder.emit( + relax.op.sigmoid( + relax.op.add(relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r) + ) + ) + else: + r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh))) - # Update gate weights - weight_ih_z = self.block_builder.emit( + # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz) + z_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_z_t)) + z_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t)) + if bias_ih is not None and bias_hh is not None: + bias_ih_z = self.block_builder.emit( relax.op.strided_slice( - weight_ih, axes=[0], begin=[gate_size], end=[2 * gate_size] + bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size] ) ) - weight_hh_z = self.block_builder.emit( + bias_hh_z = self.block_builder.emit( relax.op.strided_slice( - weight_hh, axes=[0], begin=[gate_size], end=[2 * gate_size] + bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size] + ) + ) + z_t = self.block_builder.emit( + relax.op.sigmoid( + relax.op.add(relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z) ) ) + else: + z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh))) - # New gate weights - weight_ih_n = self.block_builder.emit( + # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn)) + n_ih = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_n_t)) + n_hh = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t)) + if bias_ih is not None and bias_hh is not None: + bias_ih_n = self.block_builder.emit( relax.op.strided_slice( - weight_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] + bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] ) ) - weight_hh_n = self.block_builder.emit( + bias_hh_n = self.block_builder.emit( relax.op.strided_slice( - weight_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] + bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] ) ) - - # Transpose weights for matmul - weight_ih_r_t = self.block_builder.emit( - relax.op.permute_dims(weight_ih_r, axes=[1, 0]) - ) - weight_hh_r_t = self.block_builder.emit( - relax.op.permute_dims(weight_hh_r, axes=[1, 0]) - ) - weight_ih_z_t = self.block_builder.emit( - relax.op.permute_dims(weight_ih_z, axes=[1, 0]) - ) - weight_hh_z_t = self.block_builder.emit( - relax.op.permute_dims(weight_hh_z, axes=[1, 0]) - ) - weight_ih_n_t = self.block_builder.emit( - relax.op.permute_dims(weight_ih_n, axes=[1, 0]) - ) - weight_hh_n_t = self.block_builder.emit( - relax.op.permute_dims(weight_hh_n, axes=[1, 0]) - ) - - # Compute reset gate: r_t = sigmoid(W_ir * x_t + b_ir + W_hr * h_{t-1} + b_hr) - r_ih = self.block_builder.emit( - relax.op.linear_algebra.matmul(current_input, weight_ih_r_t) - ) - r_hh = self.block_builder.emit( - relax.op.linear_algebra.matmul(h_prev, weight_hh_r_t) - ) - if bias_ih is not None and bias_hh is not None: - bias_ih_r = self.block_builder.emit( - relax.op.strided_slice(bias_ih, axes=[0], begin=[0], end=[gate_size]) - ) - bias_hh_r = self.block_builder.emit( - relax.op.strided_slice(bias_hh, axes=[0], begin=[0], end=[gate_size]) - ) - r_t = self.block_builder.emit( - relax.op.sigmoid( - relax.op.add( - relax.op.add(relax.op.add(r_ih, bias_ih_r), r_hh), bias_hh_r - ) + n_t = self.block_builder.emit( + relax.op.tanh( + relax.op.add( + relax.op.add(n_ih, bias_ih_n), + relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)), ) ) - else: - r_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(r_ih, r_hh))) - - # Compute update gate: z_t = sigmoid(W_iz * x_t + b_iz + W_hz * h_{t-1} + b_hz) - z_ih = self.block_builder.emit( - relax.op.linear_algebra.matmul(current_input, weight_ih_z_t) ) - z_hh = self.block_builder.emit( - relax.op.linear_algebra.matmul(h_prev, weight_hh_z_t) + else: + n_t = self.block_builder.emit( + relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh))) ) - if bias_ih is not None and bias_hh is not None: - bias_ih_z = self.block_builder.emit( - relax.op.strided_slice( - bias_ih, axes=[0], begin=[gate_size], end=[2 * gate_size] - ) - ) - bias_hh_z = self.block_builder.emit( - relax.op.strided_slice( - bias_hh, axes=[0], begin=[gate_size], end=[2 * gate_size] - ) - ) - z_t = self.block_builder.emit( - relax.op.sigmoid( - relax.op.add( - relax.op.add(relax.op.add(z_ih, bias_ih_z), z_hh), bias_hh_z - ) - ) - ) - else: - z_t = self.block_builder.emit(relax.op.sigmoid(relax.op.add(z_ih, z_hh))) - # Compute new gate: n_t = tanh(W_in * x_t + b_in + r_t * (W_hn * h_{t-1} + b_hn)) - n_ih = self.block_builder.emit( - relax.op.linear_algebra.matmul(current_input, weight_ih_n_t) + # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1} + one_minus_z = self.block_builder.emit(relax.op.subtract(relax.const(1.0, dtype), z_t)) + h_t = self.block_builder.emit( + relax.op.add(relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev)) + ) + + outputs.append(h_t) + h_prev = h_t + + if reverse: + outputs = outputs[::-1] + + output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + return output + + def _gru(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + input_tensor = args[0] + hx = args[1] if len(args) > 1 else None + params = args[2] if len(args) > 2 else None + has_biases = args[3] if len(args) > 3 else True + num_layers = args[4] if len(args) > 4 else 1 + _dropout = args[5] if len(args) > 5 else 0.0 # Not used in inference + _train = args[6] if len(args) > 6 else False # Not used in inference + bidirectional = args[7] if len(args) > 7 else False + batch_first = args[8] if len(args) > 8 else False + + if num_layers > 1: + raise NotImplementedError("Multi-layer GRU is not yet supported") + + input_shape = self.shape_of(input_tensor) + if batch_first: + batch_size, seq_len, input_size = input_shape + else: + seq_len, batch_size, input_size = input_shape + + seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len + batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size + input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size + + # Extract hidden size from parameters + # For bidirectional: params has weights for both directions + # params_per_direction = 4 if has_biases else 2 (weight_ih, weight_hh, [bias_ih, bias_hh]) + params_per_direction = 4 if has_biases else 2 + + if params and len(params) >= 2: + # Extract hidden size from weight dimensions + # weight_ih has shape (3 * hidden_size, input_size) + weight_ih_shape = self.shape_of(params[0]) + hidden_size = weight_ih_shape[0] // 3 # 3 gates: reset, update, new + else: + # Fallback to a default hidden size + hidden_size = 16 + + dtype = input_tensor.struct_info.dtype + + # Extract forward direction weights + if params and len(params) >= params_per_direction: + weight_ih_fwd = params[0] + weight_hh_fwd = params[1] + bias_ih_fwd = params[2] if has_biases else None + bias_hh_fwd = params[3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype) + ) + weight_hh_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype) + ) + bias_ih_fwd = None + bias_hh_fwd = None + + # Extract or create backward direction weights if bidirectional + if bidirectional: + if params and len(params) >= params_per_direction * 2: + weight_ih_bwd = params[params_per_direction] + weight_hh_bwd = params[params_per_direction + 1] + bias_ih_bwd = params[params_per_direction + 2] if has_biases else None + bias_hh_bwd = params[params_per_direction + 3] if has_biases else None + else: + # Fallback: create zero weights + weight_ih_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, input_size)), dtype) ) - n_hh = self.block_builder.emit( - relax.op.linear_algebra.matmul(h_prev, weight_hh_n_t) + weight_hh_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((3 * hidden_size, hidden_size)), dtype) ) - if bias_ih is not None and bias_hh is not None: - bias_ih_n = self.block_builder.emit( - relax.op.strided_slice( - bias_ih, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] - ) - ) - bias_hh_n = self.block_builder.emit( - relax.op.strided_slice( - bias_hh, axes=[0], begin=[2 * gate_size], end=[3 * gate_size] - ) - ) - n_t = self.block_builder.emit( - relax.op.tanh( - relax.op.add( - relax.op.add(n_ih, bias_ih_n), - relax.op.multiply(r_t, relax.op.add(n_hh, bias_hh_n)), - ) - ) - ) - else: - n_t = self.block_builder.emit( - relax.op.tanh(relax.op.add(n_ih, relax.op.multiply(r_t, n_hh))) - ) + bias_ih_bwd = None + bias_hh_bwd = None + else: + weight_ih_bwd = None + weight_hh_bwd = None + bias_ih_bwd = None + bias_hh_bwd = None - # Update hidden state: h_t = (1 - z_t) * n_t + z_t * h_{t-1} - one_minus_z = self.block_builder.emit( - relax.op.subtract(relax.const(1.0, dtype), z_t) + # Initialize hidden states + if hx is not None: + h_prev_fwd = self.block_builder.emit( + relax.op.take(hx, relax.const(0, "int64"), axis=0, mode="clip") + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.take(hx, relax.const(1, "int64"), axis=0, mode="clip") ) - h_t = self.block_builder.emit( - relax.op.add( - relax.op.multiply(one_minus_z, n_t), relax.op.multiply(z_t, h_prev) - ) + else: + h_prev_bwd = None + else: + h_prev_fwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) + ) + if bidirectional: + h_prev_bwd = self.block_builder.emit( + relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype) ) + else: + h_prev_bwd = None - new_h_states.append(h_t) - - current_input = h_t - - # Update hidden states for next time step - h_states = new_h_states + # Reshape input for processing + input_reshaped = ( + self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2])) + if batch_first + else input_tensor + ) - # Store output (from the last layer) - outputs.append(h_states[-1]) + # Process forward direction + output_fwd = self._gru_cell_unroll( + input_reshaped, + weight_ih_fwd, + weight_hh_fwd, + bias_ih_fwd, + bias_hh_fwd, + h_prev_fwd, + seq_len, + hidden_size, + dtype, + reverse=False, + ) - # Stack outputs: (seq_len, batch_size, hidden_size) - output = self.block_builder.emit(relax.op.stack(outputs, axis=0)) + # Process backward direction if bidirectional + if bidirectional: + output_bwd = self._gru_cell_unroll( + input_reshaped, + weight_ih_bwd, + weight_hh_bwd, + bias_ih_bwd, + bias_hh_bwd, + h_prev_bwd, + seq_len, + hidden_size, + dtype, + reverse=True, + ) + # Concatenate forward and backward outputs along feature dimension + output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2)) + else: + output = output_fwd # Reshape back to batch_first if needed if batch_first: diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 091f0a4a29c5..43eeb50b89fa 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -7992,6 +7992,92 @@ def forward(self, x): assert pytorch_output2.shape == tvm_output2_np.shape tvm.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5) + # Test bidirectional GRU with batch_first=True + class BidirectionalGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=4, + hidden_size=5, + num_layers=1, + batch_first=True, + bidirectional=True, + ) + + def forward(self, x): + y, _ = self.gru(x) + return y + + torch.manual_seed(44) + x3 = torch.randn(2, 3, 4, dtype=torch.float32) + model3 = BidirectionalGRU() + with torch.no_grad(): + pytorch_output3 = model3(x3) + + # Verify output shape is correct (hidden_size * 2 due to bidirectional) + assert pytorch_output3.shape == ( + 2, + 3, + 10, + ), f"Expected shape (2, 3, 10), got {pytorch_output3.shape}" + + exported_program3 = export(model3, args=(x3,)) + mod3 = from_exported_program(exported_program3) + ex3 = relax.build(mod3, target) + vm3 = relax.VirtualMachine(ex3, tvm.cpu()) + x3_tvm = tvm.runtime.tensor(x3.numpy()) + tvm_output3 = vm3["main"](x3_tvm) + if hasattr(tvm_output3, "numpy"): + tvm_output3_np = tvm_output3.numpy() + else: + tvm_output3_np = tvm_output3[0].numpy() + assert ( + pytorch_output3.shape == tvm_output3_np.shape + ), f"Shape mismatch: PyTorch {pytorch_output3.shape} vs TVM {tvm_output3_np.shape}" + tvm.testing.assert_allclose(pytorch_output3.numpy(), tvm_output3_np, rtol=1e-4, atol=1e-5) + + # Test bidirectional GRU with batch_first=False + class SeqFirstBidirectionalGRU(nn.Module): + def __init__(self): + super().__init__() + self.gru = nn.GRU( + input_size=3, + hidden_size=4, + num_layers=1, + batch_first=False, + bidirectional=True, + ) + + def forward(self, x): + y, _ = self.gru(x) + return y + + torch.manual_seed(45) + x4 = torch.randn(4, 2, 3, dtype=torch.float32) # (seq_len, batch, input_size) + model4 = SeqFirstBidirectionalGRU() + with torch.no_grad(): + pytorch_output4 = model4(x4) + + # Verify output shape (seq_len, batch, hidden_size * 2) + assert pytorch_output4.shape == ( + 4, + 2, + 8, + ), f"Expected shape (4, 2, 8), got {pytorch_output4.shape}" + + exported_program4 = export(model4, args=(x4,)) + mod4 = from_exported_program(exported_program4) + ex4 = relax.build(mod4, target) + vm4 = relax.VirtualMachine(ex4, tvm.cpu()) + x4_tvm = tvm.runtime.tensor(x4.numpy()) + tvm_output4 = vm4["main"](x4_tvm) + if hasattr(tvm_output4, "numpy"): + tvm_output4_np = tvm_output4.numpy() + else: + tvm_output4_np = tvm_output4[0].numpy() + assert pytorch_output4.shape == tvm_output4_np.shape + tvm.testing.assert_allclose(pytorch_output4.numpy(), tvm_output4_np, rtol=1e-4, atol=1e-5) + def test_dynamic_shape_with_range_constraints(): class DynamicModel(torch.nn.Module):