Merged
269 changes: 170 additions & 99 deletions python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -378,46 +378,106 @@ def _upsample_bicubic2d(self, node: fx.node) -> relax.Var:
            align_corners=align_corners,
        )

    def _lstm_cell_unroll(
        self,
        input_reshaped,
        weight_ih,
        weight_hh,
        bias_ih,
        bias_hh,
        h_prev,
        c_prev,
        seq_len,
        hidden_size,
        reverse=False,
    ):
        """Unroll LSTM cells for a single direction."""
        weight_ih_t = self.block_builder.emit(relax.op.permute_dims(weight_ih, axes=[1, 0]))
        weight_hh_t = self.block_builder.emit(relax.op.permute_dims(weight_hh, axes=[1, 0]))
        outputs = []
        time_steps = range(seq_len - 1, -1, -1) if reverse else range(seq_len)

        for t in time_steps:
            x_t = self.block_builder.emit(
                relax.op.take(input_reshaped, relax.const(t, "int64"), axis=0, mode="clip")
            )
            ih_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(x_t, weight_ih_t))
            hh_gates = self.block_builder.emit(relax.op.linear_algebra.matmul(h_prev, weight_hh_t))

            gates = self.block_builder.emit(relax.op.add(ih_gates, hh_gates))
            if bias_ih is not None:
                gates = self.block_builder.emit(relax.op.add(gates, bias_ih))
            if bias_hh is not None:
                gates = self.block_builder.emit(relax.op.add(gates, bias_hh))

            i_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[0], end=[hidden_size])
            )
            f_gate = self.block_builder.emit(
                relax.op.strided_slice(gates, axes=[1], begin=[hidden_size], end=[2 * hidden_size])
            )
            g_gate = self.block_builder.emit(
                relax.op.strided_slice(
                    gates, axes=[1], begin=[2 * hidden_size], end=[3 * hidden_size]
                )
            )
            o_gate = self.block_builder.emit(
                relax.op.strided_slice(
                    gates, axes=[1], begin=[3 * hidden_size], end=[4 * hidden_size]
                )
            )

            i_t = self.block_builder.emit(relax.op.sigmoid(i_gate))
            f_t = self.block_builder.emit(relax.op.sigmoid(f_gate))
            g_t = self.block_builder.emit(relax.op.tanh(g_gate))
            o_t = self.block_builder.emit(relax.op.sigmoid(o_gate))

            c_t = self.block_builder.emit(
                relax.op.add(relax.op.multiply(f_t, c_prev), relax.op.multiply(i_t, g_t))
            )
            h_t = self.block_builder.emit(relax.op.multiply(o_t, relax.op.tanh(c_t)))

            outputs.append(h_t)
            h_prev = h_t
            c_prev = c_t

        if reverse:
            outputs = outputs[::-1]

        output = self.block_builder.emit(relax.op.stack(outputs, axis=0))
        return output

    def _lstm(self, node: fx.Node) -> relax.Var:
        args = self.retrieve_args(node)
        input_tensor = args[0]
        hx = args[1] if len(args) > 1 else None
        params = args[2] if len(args) > 2 else None
        has_biases = args[3] if len(args) > 3 else True
        num_layers = args[4] if len(args) > 4 else 1
        _dropout = args[5] if len(args) > 5 else 0.0  # Not used in inference
        _train = args[6] if len(args) > 6 else False  # Not used in inference
        bidirectional = args[7] if len(args) > 7 else False
        batch_first = args[8] if len(args) > 8 else False

        if num_layers > 1:
            raise NotImplementedError("Multi-layer LSTM is not yet supported")

        input_shape = self.shape_of(input_tensor)
        if batch_first:
            # Input shape: (batch, seq_len, input_size)
            batch_size, seq_len, input_size = input_shape
        else:
            # Input shape: (seq_len, batch, input_size)
            seq_len, batch_size, input_size = input_shape

        seq_len = int(seq_len) if isinstance(seq_len, tvm.tir.IntImm) else seq_len
        batch_size = int(batch_size) if isinstance(batch_size, tvm.tir.IntImm) else batch_size
        input_size = int(input_size) if isinstance(input_size, tvm.tir.IntImm) else input_size

        # Extract hidden size from the LSTM parameters
        # The parameters are: [weight_ih, weight_hh, bias_ih, bias_hh]
        # weight_ih shape: (4 * hidden_size, input_size)
        # weight_hh shape: (4 * hidden_size, hidden_size)
        if params and len(params) >= 2:
            weight_ih_shape = self.shape_of(params[0])
            hidden_size = weight_ih_shape[0] // 4  # 4 gates: input, forget, cell, output
        else:
            # Fallback to a default hidden size
            hidden_size = 16
Comment on lines 481 to 483
Contributor


Severity: medium

The code falls back to a default hidden_size of 16 when it cannot be inferred from the model parameters. This could lead to unexpected behavior or errors if the actual model has a different hidden size. It would be beneficial to add a warning to notify the user about this fallback, so they are aware of the potential discrepancy.
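
A minimal sketch of the suggested change, assuming the warnings module is acceptable here; the message wording is illustrative, not part of the PR:

import warnings

if params and len(params) >= 2:
    weight_ih_shape = self.shape_of(params[0])
    hidden_size = weight_ih_shape[0] // 4
else:
    # Assumption: surface the fallback instead of silently defaulting.
    warnings.warn(
        "Could not infer LSTM hidden_size from params; "
        "falling back to hidden_size=16, which may not match the model."
    )
    hidden_size = 16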

@@ -430,109 +490,120 @@ def _lstm(self, node: fx.Node) -> relax.Var:
        # c_t = f_t * c_{t-1} + i_t * g_t
        # h_t = o_t * tanh(c_t)
        dtype = input_tensor.struct_info.dtype
        params_per_direction = 4 if has_biases else 2

        # Extract or create forward direction weights
        if params and len(params) >= 2:
            weight_ih_fwd = params[0]
            weight_hh_fwd = params[1]
            bias_ih_fwd = params[2] if has_biases and len(params) > 2 else None
            bias_hh_fwd = params[3] if has_biases and len(params) > 3 else None
        else:
            # Fallback: create zero weights
            weight_ih_fwd = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
            )
            weight_hh_fwd = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
            )
            bias_ih_fwd = None
            bias_hh_fwd = None

        # Extract or create backward direction weights if bidirectional
        if bidirectional:
            if params and len(params) >= params_per_direction * 2:
                weight_ih_bwd = params[params_per_direction]
                weight_hh_bwd = params[params_per_direction + 1]
                bias_ih_bwd = params[params_per_direction + 2] if has_biases else None
                bias_hh_bwd = params[params_per_direction + 3] if has_biases else None
            else:
                # Fallback: create zero weights
                weight_ih_bwd = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((4 * hidden_size, input_size)), dtype)
                )
                weight_hh_bwd = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((4 * hidden_size, hidden_size)), dtype)
                )
                bias_ih_bwd = None
                bias_hh_bwd = None
        else:
            weight_ih_bwd = None
            weight_hh_bwd = None
            bias_ih_bwd = None
            bias_hh_bwd = None

        if hx is not None and len(hx) >= 2:
            h_0, c_0 = hx[0], hx[1]
            h_prev_fwd = self.block_builder.emit(
                relax.op.take(h_0, relax.const(0, "int64"), axis=0, mode="clip")
            )
            c_prev_fwd = self.block_builder.emit(
                relax.op.take(c_0, relax.const(0, "int64"), axis=0, mode="clip")
            )
            if bidirectional:
                h_prev_bwd = self.block_builder.emit(
                    relax.op.take(h_0, relax.const(1, "int64"), axis=0, mode="clip")
                )
                c_prev_bwd = self.block_builder.emit(
                    relax.op.take(c_0, relax.const(1, "int64"), axis=0, mode="clip")
                )
            else:
                h_prev_bwd = None
                c_prev_bwd = None
        else:
            h_prev_fwd = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
            )
            c_prev_fwd = self.block_builder.emit(
                relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
            )
            if bidirectional:
                h_prev_bwd = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                )
                c_prev_bwd = self.block_builder.emit(
                    relax.op.zeros(relax.ShapeExpr((batch_size, hidden_size)), dtype)
                )
            else:
                h_prev_bwd = None
                c_prev_bwd = None

        input_reshaped = (
            self.block_builder.emit(relax.op.permute_dims(input_tensor, axes=[1, 0, 2]))
            if batch_first
            else input_tensor
        )

        output_fwd = self._lstm_cell_unroll(
            input_reshaped,
            weight_ih_fwd,
            weight_hh_fwd,
            bias_ih_fwd,
            bias_hh_fwd,
            h_prev_fwd,
            c_prev_fwd,
            seq_len,
            hidden_size,
            reverse=False,
        )

        if bidirectional:
            output_bwd = self._lstm_cell_unroll(
                input_reshaped,
                weight_ih_bwd,
                weight_hh_bwd,
                bias_ih_bwd,
                bias_hh_bwd,
                h_prev_bwd,
                c_prev_bwd,
                seq_len,
                hidden_size,
                reverse=True,
            )
            # Concatenate forward and backward outputs along the feature axis
            output = self.block_builder.emit(relax.op.concat([output_fwd, output_bwd], axis=2))
        else:
            output = output_fwd

        if batch_first:
            # (seq_len, batch_size, hidden_size) -> (batch_size, seq_len, hidden_size)
            output = self.block_builder.emit(relax.op.permute_dims(output, axes=[1, 0, 2]))
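
For reference, the params_per_direction indexing above mirrors how PyTorch flattens LSTM weights: per layer, the forward-direction tensors come first ([w_ih, w_hh, b_ih, b_hh] when biases are present), followed by the same tensors for the reverse direction. A quick way to confirm the layout; this uses private nn.LSTM attributes, so treat it as illustrative and version-dependent:

import torch

lstm = torch.nn.LSTM(input_size=8, hidden_size=16, num_layers=1, bidirectional=True)
# Prints weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0, then the same
# four names with a _reverse suffix; each weight has 4 * hidden_size rows.
for name, w in zip(lstm._flat_weights_names, lstm._flat_weights):
    print(name, tuple(w.shape))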
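
End to end, the new path can be exercised roughly as follows. This is a minimal sketch assuming the translator is reached via tvm.relax.frontend.torch.from_exported_program; the module name, sizes, and shapes are arbitrary:

import torch
from tvm.relax.frontend.torch import from_exported_program

class BiLSTM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Single-layer bidirectional LSTM: the case this PR adds support for.
        self.lstm = torch.nn.LSTM(
            input_size=8, hidden_size=16, num_layers=1,
            bidirectional=True, batch_first=True,
        )

    def forward(self, x):
        output, _ = self.lstm(x)
        return output  # (batch, seq_len, 2 * hidden_size)

x = torch.randn(2, 5, 8)
exported = torch.export.export(BiLSTM().eval(), (x,))
mod = from_exported_program(exported)  # Relax IRModule with the unrolled LSTM
mod.show()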