Skip to content

feat: add LSTM, GRU, and RNN ONNX node converters#256

Open
nik875 wants to merge 1 commit into
ENOT-AutoDL:mainfrom
nik875:feat/lstm-gru-rnn-converters
Open

feat: add LSTM, GRU, and RNN ONNX node converters#256
nik875 wants to merge 1 commit into
ENOT-AutoDL:mainfrom
nik875:feat/lstm-gru-rnn-converters

Conversation

@nik875

@nik875 nik875 commented May 21, 2026

Copy link
Copy Markdown

Closes #136

Summary

This PR adds ONNX-to-PyTorch converters for the LSTM, GRU, and RNN operators, covering opset versions 7 and 14.

All three operators are registered via @add_converter in a new onnx2torch/node_converters/rnn.py module and wired into
init.py. operators.md is updated to mark them as supported.

Implementation notes

LSTM wraps nn.LSTM. ONNX gate order [I, O, F, C] is reordered to PyTorch [I, F, G, O] at conversion time by slicing and
re-concatenating weight chunks. Initial hidden/cell states are supported as optional graph inputs. sequence_lens, peephole
weights, input_forget=1, clip, and layout=1 raise NotImplementedError.

GRU uses a manual per-timestep loop rather than wrapping nn.GRU. This is necessary because the ONNX default
(linear_before_reset=0) computes:

ht = tanh(Wh·x + Wbh + Rh·(r⊙h) + Rbh)

whereas PyTorch nn.GRU implements linear_before_reset=1:

ht = tanh(Wh·x + Wbh + r⊙(Rh·h + Rbh))

These are mathematically different for non-zero hidden states. The manual loop correctly implements both formulas; weights are
stored as buffers in ONNX gate order [Z, R, H] without reordering. sequence_lens, clip, and layout=1 raise
NotImplementedError.

RNN wraps nn.RNN. No gate reordering is required. The activations attribute is respected for Tanh (default) and Relu.
sequence_lens, clip, and layout=1 raise NotImplementedError.

All three converters require W and R to be graph initializers (static weights). Bidirectional and reverse directions are
supported for all three operators.

Output shape

PyTorch recurrent layers return Y with shape [seq, batch, num_directions * hidden_size]. The converters reshape and permute
this to the ONNX-expected [seq, num_directions, batch, hidden_size].

Tests

32 new test cases in tests/node_converters/rnn_test.py using the standard make_model_from_nodes + check_onnx_model pattern:

  • LSTM: forward (with/without bias, multiple shapes), output subsets (Y only), bidirectional, reverse, initial h/c states
  • GRU: forward (with/without bias, multiple shapes), output subsets, bidirectional, reverse
  • RNN: forward (with/without bias, multiple shapes), output subsets, bidirectional, reverse

GRU round-trip tests (onnx → torch → onnx → ORT) use atol_onnx_torch2onnx=1e-4 because the manual loop exports as individual
elementwise ops rather than a single GRU node, causing minor floating-point order differences.

I was able to successfully load this CRNN model into PyTorch after these changes. I can not promise that it will work for every RNN model out-of-the-box.

Known limitations (documented in operators.md)

Operator Unsupported
LSTM layout=1, clip, input_forget=1, peephole weights, sequence_lens, dynamic weights
GRU layout=1, clip, linear_before_reset=1, sequence_lens
RNN layout=1, clip, sequence_lens, dynamic weights

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Can we increase the function of rnn structures such as lstm

1 participant