@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -15,6 +16,11 @@
     operator_test,
     OperatorTest,
 )
+from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM
+
+
+def _get_lstm_cls(use_quantizable_lstm: bool):
+    return QuantizableLSTM if use_quantizable_lstm else torch.nn.LSTM


 class Model(torch.nn.Module):
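
# --- Aside (not part of the diff): a minimal sketch of why a class swap is
# enough here. The assumption is that torch.nn.quantizable.modules.rnn.LSTM
# mirrors torch.nn.LSTM's constructor signature and its
# (output, (h_n, c_n)) return convention; shapes below are illustrative.
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM

for lstm_cls in (torch.nn.LSTM, QuantizableLSTM):
    lstm = lstm_cls(input_size=64, hidden_size=32, num_layers=2, batch_first=True)
    out, (h_n, c_n) = lstm(torch.randn(1, 10, 64))  # (batch, seq_len, input_size)
    assert out.shape == (1, 10, 32)  # (batch, seq_len, hidden_size)
# --- End aside.
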
@@ -27,9 +33,11 @@ def __init__(
         batch_first=True,
         dropout=0.0,
         bidirectional=False,
+        use_quantizable_lstm: bool = False,
     ):
         super().__init__()
-        self.lstm = torch.nn.LSTM(
+        lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+        self.lstm = lstm_cls(
             input_size=input_size,
             hidden_size=hidden_size,
             num_layers=num_layers,
@@ -47,116 +55,144 @@ def forward(self, x):
 class LSTM(OperatorTest):
     @dtype_test
     def test_lstm_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2).to(dtype),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),  # (batch=1, seq_len, input_size)
             flow,
         )

     @dtype_test
     def test_lstm_no_bias_dtype(self, flow: TestFlow, dtype) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, bias=False).to(dtype),
+            Model(
+                num_layers=2, bias=False, use_quantizable_lstm=use_quantizable_lstm
+            ).to(dtype),
             ((torch.rand(1, 10, 64) * 10).to(dtype),),
             flow,
         )

     def test_lstm_feature_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(input_size=32, hidden_size=16),
+            Model(
+                input_size=32,
+                hidden_size=16,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 8, 32),),  # (batch=1, seq_len, input_size)
             flow,
         )
         self._test_op(
-            Model(input_size=128, hidden_size=64),
+            Model(
+                input_size=128,
+                hidden_size=64,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 12, 128),),
             flow,
         )
         self._test_op(
-            Model(input_size=256, hidden_size=128),
+            Model(
+                input_size=256,
+                hidden_size=128,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 6, 256),),
             flow,
         )
         self._test_op(
-            Model(input_size=16, hidden_size=32),
+            Model(
+                input_size=16,
+                hidden_size=32,
+                use_quantizable_lstm=use_quantizable_lstm,
+            ),
             (torch.randn(1, 5, 16),),
             flow,
         )

     def test_lstm_batch_sizes(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(8, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(32, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(100, 10, 64),),
             flow,
         )

     def test_lstm_seq_lengths(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 5, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 20, 64),),
             flow,
         )
         self._test_op(
-            Model(),
+            Model(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 50, 64),),
             flow,
         )

     def test_lstm_batch_first_false(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(batch_first=False),
+            Model(batch_first=False, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(10, 1, 64),),  # (seq_len, batch=1, input_size)
             flow,
         )

     def test_lstm_num_layers(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2),
+            Model(num_layers=2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )
         self._test_op(
-            Model(num_layers=3),
+            Model(num_layers=3, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )

     def test_lstm_bidirectional(self, flow: TestFlow) -> None:
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(bidirectional=True),
+            Model(bidirectional=True, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )

     def test_lstm_with_dropout(self, flow: TestFlow) -> None:
         # Note: Dropout is only effective with num_layers > 1
+        use_quantizable_lstm = flow.quantize
         self._test_op(
-            Model(num_layers=2, dropout=0.2),
+            Model(num_layers=2, dropout=0.2, use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(1, 10, 64),),
             flow,
         )

     def test_lstm_with_initial_states(self, flow: TestFlow) -> None:
         # Create a model that accepts initial states
         class ModelWithStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -169,9 +205,10 @@ def forward(self, x, h0, c0):
         batch_size = 1
         num_layers = 2
         hidden_size = 32
+        use_quantizable_lstm = flow.quantize

         self._test_op(
-            ModelWithStates(),
+            ModelWithStates(use_quantizable_lstm=use_quantizable_lstm),
             (
                 torch.randn(batch_size, 10, 64),  # input
                 torch.randn(num_layers, batch_size, hidden_size),  # h0
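
# --- Aside (not part of the diff): the initial-state convention this test
# exercises. Assuming both LSTM variants accept an optional (h0, c0) tuple as
# the second forward argument, matching torch.nn.LSTM, each state tensor is
# shaped (num_layers * num_directions, batch, hidden_size).
import torch
from torch.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM

lstm = QuantizableLSTM(input_size=64, hidden_size=32, num_layers=2, batch_first=True)
h0 = torch.randn(2, 1, 32)  # (num_layers, batch=1, hidden_size)
c0 = torch.randn(2, 1, 32)
out, (h_n, c_n) = lstm(torch.randn(1, 10, 64), (h0, c0))
# --- End aside.
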
@@ -183,9 +220,10 @@ def forward(self, x, h0, c0):
     def test_lstm_return_hidden_states(self, flow: TestFlow) -> None:
         # Create a model that returns both output and hidden states
         class ModelWithHiddenStates(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, use_quantizable_lstm: bool = False):
                 super().__init__()
-                self.lstm = torch.nn.LSTM(
+                lstm_cls = _get_lstm_cls(use_quantizable_lstm)
+                self.lstm = lstm_cls(
                     input_size=64,
                     hidden_size=32,
                     num_layers=2,
@@ -200,9 +238,10 @@ def forward(self, x):
         batch_size = 1
         seq_len = 10
         input_size = 64
+        use_quantizable_lstm = flow.quantize

         self._test_op(
-            ModelWithHiddenStates(),
+            ModelWithHiddenStates(use_quantizable_lstm=use_quantizable_lstm),
             (torch.randn(batch_size, seq_len, input_size),),
             flow,
         )