feat: add LSTM, GRU, and RNN ONNX node converters #256
Open
nik875 wants to merge 1 commit into
Closes #136
Summary
This PR adds ONNX-to-PyTorch converters for the LSTM, GRU, and RNN operators, covering opset versions 7 and 14.
All three operators are registered via @add_converter in a new onnx2torch/node_converters/rnn.py module and wired into
__init__.py. operators.md is updated to mark them as supported.
Implementation notes
LSTM wraps nn.LSTM. ONNX gate order [I, O, F, C] is reordered to PyTorch [I, F, G, O] at conversion time by slicing and
re-concatenating weight chunks. Initial hidden/cell states are supported as optional graph inputs. sequence_lens, peephole
weights, input_forget=1, clip, and layout=1 raise NotImplementedError.
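The gate reordering can be illustrated with a minimal sketch. The helper name `reorder_lstm_gates` is hypothetical and not taken from this PR; it assumes a single-direction weight whose first dimension is 4 * hidden_size, laid out in ONNX order [I, O, F, C]. The same slicing would apply to W, R, and each half of the bias.

```python
import torch


def reorder_lstm_gates(weight: torch.Tensor, hidden_size: int) -> torch.Tensor:
    """Reorder gate blocks along dim 0 from ONNX [I, O, F, C] to PyTorch [I, F, G, O].

    Hypothetical helper for illustration; ONNX "C" corresponds to PyTorch "G".
    """
    i, o, f, c = torch.split(weight, hidden_size, dim=0)
    return torch.cat([i, f, c, o], dim=0)


hidden_size = 2
# Tag each ONNX gate block with a distinct value: I=0, O=1, F=2, C=3.
onnx_w = torch.cat([torch.full((hidden_size, 3), v) for v in (0.0, 1.0, 2.0, 3.0)])
torch_w = reorder_lstm_gates(onnx_w, hidden_size)

# First row of each block, now in PyTorch order I, F, G, O.
block_tags = torch_w[::hidden_size, 0].tolist()
print(block_tags)
```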
GRU uses a manual per-timestep loop rather than wrapping nn.GRU. This is necessary because the ONNX default
(linear_before_reset=0) computes:
ht = tanh(Wh·x + Wbh + Rh·(r⊙h) + Rbh)
whereas PyTorch nn.GRU implements linear_before_reset=1:
ht = tanh(Wh·x + Wbh + r⊙(Rh·h + Rbh))
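These are mathematically different whenever the hidden state or the recurrent bias Rbh is non-zero, because the reset gate
r is applied before the recurrent matmul in one case and after it in the other. The manual loop correctly implements both
formulas; weights are stored as buffers in ONNX gate order [Z, R, H] without reordering. sequence_lens, clip, and layout=1
raise NotImplementedError.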
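A single timestep covering both formulas might look like the sketch below. It is not the code from this PR: the function name `gru_step` and its argument layout are assumptions, with single-direction weights in ONNX gate order [Z, R, H] and the default sigmoid/tanh activations.

```python
import torch


def gru_step(x, h, W, R, Wb, Rb, linear_before_reset=0):
    """One GRU timestep with ONNX gate order [Z, R, H] (illustrative sketch).

    x: [batch, input], h: [batch, hidden],
    W: [3*hidden, input], R: [3*hidden, hidden], Wb and Rb: [3*hidden].
    """
    Wz, Wr, Wh = W.chunk(3, dim=0)
    Rz, Rr, Rh = R.chunk(3, dim=0)
    Wbz, Wbr, Wbh = Wb.chunk(3)
    Rbz, Rbr, Rbh = Rb.chunk(3)

    z = torch.sigmoid(x @ Wz.T + Wbz + h @ Rz.T + Rbz)  # update gate
    r = torch.sigmoid(x @ Wr.T + Wbr + h @ Rr.T + Rbr)  # reset gate

    if linear_before_reset:
        # Reset gate applied after the recurrent linear map (PyTorch nn.GRU behaviour).
        h_cand = torch.tanh(x @ Wh.T + Wbh + r * (h @ Rh.T + Rbh))
    else:
        # ONNX default: reset gate applied to h before the recurrent linear map.
        h_cand = torch.tanh(x @ Wh.T + Wbh + (r * h) @ Rh.T + Rbh)

    return (1 - z) * h_cand + z * h


torch.manual_seed(0)
batch, input_size, hidden_size = 2, 3, 4
x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
W = torch.randn(3 * hidden_size, input_size)
R = torch.randn(3 * hidden_size, hidden_size)
Wb = torch.randn(3 * hidden_size)
Rb = torch.randn(3 * hidden_size)

h_default = gru_step(x, h, W, R, Wb, Rb, linear_before_reset=0)
h_lbr = gru_step(x, h, W, R, Wb, Rb, linear_before_reset=1)
print((h_default - h_lbr).abs().max())
```

A full converter would call a step like this once per sequence position (and in reverse order for the backward direction), which is why the exported graph contains individual elementwise ops instead of a single GRU node.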
RNN wraps nn.RNN. No gate reordering is required. The activations attribute is respected for Tanh (default) and Relu.
sequence_lens, clip, and layout=1 raise NotImplementedError.
All three converters require W and R to be graph initializers (static weights). Bidirectional and reverse directions are
supported for all three operators.
Output shape
PyTorch recurrent layers return Y with shape [seq, batch, num_directions * hidden_size]. The converters reshape and permute
this to the ONNX-expected [seq, num_directions, batch, hidden_size].
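As a rough illustration of the layout change (a standalone sketch with made-up sizes, not the converter code itself), the reshape and permute on a bidirectional nn.LSTM output would be:

```python
import torch
from torch import nn

seq_len, batch, input_size, hidden_size = 5, 2, 3, 4
num_directions = 2  # bidirectional

lstm = nn.LSTM(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    y, (h_n, c_n) = lstm(x)  # y: [seq, batch, num_directions * hidden_size]

# Split the last axis into (direction, hidden), then move direction ahead of batch.
y_onnx = y.reshape(seq_len, batch, num_directions, hidden_size).permute(0, 2, 1, 3)
print(y_onnx.shape)  # ONNX layout: [seq, num_directions, batch, hidden_size]
```

In this layout, y_onnx[-1, 0] is the final forward hidden state and y_onnx[0, 1] is the final backward hidden state, matching h_n[0] and h_n[1].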
Tests
32 new test cases in tests/node_converters/rnn_test.py using the standard make_model_from_nodes + check_onnx_model pattern:
GRU round-trip tests (onnx → torch → onnx → ORT) use atol_onnx_torch2onnx=1e-4 because the manual loop exports as individual
elementwise ops rather than a single GRU node, causing minor floating-point order differences.
I was able to successfully load this CRNN model into PyTorch after these changes. I cannot promise that it will work for every RNN model out of the box.
Known limitations (documented in operators.md)
LSTM: layout=1, clip, input_forget=1, peephole weights, sequence_lens, dynamic weights
GRU: layout=1, clip, linear_before_reset=1, sequence_lens
RNN: layout=1, clip, sequence_lens, dynamic weights