Commit 418c584

Use quantizable LSTM in test when flow has quantize=True (#14893)
When the flow quantizes, it makes more sense to use the quantizable version of the LSTM. For example, the XNNPACK int8 tests currently pass even though all tensors stay float, because the quantizer is not triggered.

Signed-off-by: Erik Lundell <[email protected]>
1 parent 8fbc42c commit 418c584
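
As a rough illustration of the pattern this commit adds (not part of the diff; the TinyLSTM module and its sizes are invented for the example), the test models now pick the quantizable LSTM class whenever the flow quantizes. The quantizable variant is built from smaller modules that observers can typically attach to, whereas torch.nn.LSTM usually lowers to a single fused op that a quantizer may skip entirely:

# Illustrative sketch only; TinyLSTM and its sizes are hypothetical.
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM


def _get_lstm_cls(use_quantizable_lstm: bool):
    # Same selection helper the diff introduces: fall back to the regular
    # torch.nn.LSTM when the flow does not quantize.
    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM


class TinyLSTM(torch.nn.Module):
    def __init__(self, use_quantizable_lstm: bool = False):
        super().__init__()
        self.lstm = _get_lstm_cls(use_quantizable_lstm)(
            input_size=64, hidden_size=32, num_layers=1, batch_first=True
        )

    def forward(self, x):
        output, _ = self.lstm(x)
        return output


# A float flow keeps torch.nn.LSTM; a quantizing flow gets the quantizable
# variant, whose internal ops a quantizer can observe and annotate.
print(type(TinyLSTM(use_quantizable_lstm=False).lstm))
print(type(TinyLSTM(use_quantizable_lstm=True).lstm))
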

File tree

1 file changed (+63 additions, −24 deletions)

backends/test/suite/operators/test_lstm.py

Lines changed: 63 additions & 24 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -15,6 +16,11 @@
     operator_test,
     OperatorTest,
 )
+from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM
+
+
+def _get_lstm_cls(use_quantizable_lstm: bool):
+    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM
 
 
 class Model(torch.nn.Module):
@@ -27,9 +33,11 @@ def __init__(
         batch_first=True,
         dropout=0.0,
         bidirectional=False,
+        use_quantizable_lstm: bool = False,
     ):
         super().__init__()
-        self.lstm = torch.nn.LSTM(
+        lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+        self.lstm = lstm_cls(
             input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
@@ -47,116 +55,144 @@ def forward(self, x):
 class LSTM(OperatorTest):
     @dtype_test
     def test_lstm_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2).to(dtype),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),  # (batch=1, seq_len, input_size)
             flow,
         )
 
     @dtype_test
     def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, bias=False).to(dtype),
+            Model(
+                num_layers=2, bias=False, use_quantizable_lstm=use_quantizable_lstm
+            ).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),
             flow,
         )
 
     def test_lstm_feature_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(input_size=32, hidden_size=16),
+            Model(
+                input_size=32,
+                hidden_size=16,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 8, 32),),  # (batch=1, seq_len, input_size)
             flow,
         )
         self._test_op(
-            Model(input_size=128, hidden_size=64),
+            Model(
+                input_size=128,
+                hidden_size=64,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 12, 128),),
             flow,
         )
         self._test_op(
-            Model(input_size=256, hidden_size=128),
+            Model(
+                input_size=256,
+                hidden_size=128,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 6, 256),),
             flow,
         )
         self._test_op(
-            Model(input_size=16, hidden_size=32),
+            Model(
+                input_size=16,
+                hidden_size=32,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 5, 16),),
             flow,
         )
 
     def test_lstm_batch_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(8, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(32, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(100, 10, 64),),
             flow,
         )
 
     def test_lstm_seq_lengths(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 5, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 20, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 50, 64),),
             flow,
         )
 
     def test_lstm_batch_first_false(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(batch_first=False),
+            Model(batch_first=False, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(10, 1, 64),),  # (seq_len, batch=1, input_size)
             flow,
         )
 
     def test_lstm_num_layers(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(num_layers=3),
+            Model(num_layers=3, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_bidirectional(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(bidirectional=True),
+            Model(bidirectional=True, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_with_dropout(self, flow: TestFlow) -> None:
         # Note: Dropout is only effective with num_layers > 1
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, dropout=0.2),
+            Model(num_layers=2, dropout=0.2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
 
     def test_lstm_with_initial_states(self, flow: TestFlow) -> None:
         # Create a model that accepts initial states
         class ModelWithStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -169,9 +205,10 @@ def forward(self, x, h0, c0):
         batch_size = 1
         num_layers = 2
         hidden_size = 32
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithStates(),
+            ModelWithStates(use_quantizable_lstm=use_quantizable_lstm),
             (
                 torch.randn(batch_size, 10, 64),  # input
                 torch.randn(num_layers, batch_size, hidden_size),  # h0
@@ -183,9 +220,10 @@ def forward(self, x, h0, c0):
     def test_lstm_return_hidden_states(self, flow: TestFlow) -> None:
         # Create a model that returns both output and hidden states
         class ModelWithHiddenStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -200,9 +238,10 @@ def forward(self, x):
         batch_size = 1
         seq_len = 10
         input_size = 64
+        use_quantizable_lstm = flow.quantize
 
         self._test_op(
-            ModelWithHiddenStates(),
+            ModelWithHiddenStates(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(batch_size, seq_len, input_size),),
             flow,
         )
