
Commit 07db17b

[PyTorch] Expose more activation functions (#2106)

Expose more activation functions.

Signed-off-by: Xin Yao <[email protected]>

1 parent ccc1abf commit 07db17b

10 files changed: +314 -88 lines
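Note: besides the raw extension bindings, the commit wires the new activations (QGELU, QGEGLU, SReLU, SReGLU, SiLU) into the fusible-ops API, as the test changes below show. A minimal usage sketch, mirroring the te_ops.Sequential pattern from tests/pytorch/test_fusible_ops.py; the import alias, the no-argument constructor, and the CUDA/bfloat16 setup are illustrative assumptions, not taken from this diff:

import torch
import transformer_engine.pytorch.ops as te_ops  # assumed import alias for the fusible-ops API

# SReGLU gates on the second half of the last dimension, so the input width is doubled.
act = te_ops.Sequential(te_ops.SReGLU())
x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
y = act(x)          # shape (32, 128): relu(x1) ** 2 * x2 over the two halves
y.sum().backward()  # runs the corresponding backward (dsreglu) kernel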

tests/pytorch/test_fusible_ops.py

Lines changed: 25 additions & 5 deletions
@@ -1532,7 +1532,10 @@ def test_make_extra_output(
         torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

-    @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
+    @pytest.mark.parametrize(
+        "activation",
+        ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
+    )
     @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
     @pytest.mark.parametrize("dtype", _dtypes)
     @pytest.mark.parametrize("quantization", _quantization_list)
@@ -1551,7 +1554,7 @@ def test_activation(

         # Tensor dimensions
         in_shape = list(out_shape)
-        if activation in ("geglu", "reglu", "swiglu"):
+        if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
             in_shape[-1] *= 2

         # Skip invalid configurations
@@ -1578,14 +1581,26 @@ def test_activation(
         y_ref: torch.Tensor
         if activation == "gelu":
             y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
-        elif activation == "relu":
-            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "geglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
+        elif activation == "qgelu":
+            y_ref = x_ref * torch.sigmoid(1.702 * x_ref)
+        elif activation == "qgeglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = x1 * torch.sigmoid(1.702 * x1) * x2
+        elif activation == "relu":
+            y_ref = torch.nn.functional.relu(x_ref)
         elif activation == "reglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.relu(x1) * x2
+        elif activation == "srelu":
+            y_ref = torch.nn.functional.relu(x_ref) ** 2
+        elif activation == "sreglu":
+            x1, x2 = x_ref.chunk(2, dim=-1)
+            y_ref = torch.nn.functional.relu(x1) ** 2 * x2
+        elif activation == "silu":
+            y_ref = torch.nn.functional.silu(x_ref)
         elif activation == "swiglu":
             x1, x2 = x_ref.chunk(2, dim=-1)
             y_ref = torch.nn.functional.silu(x1) * x2
@@ -1597,9 +1612,14 @@ def test_activation(
         recipe = make_recipe(quantization)
         make_op = dict(
             gelu=te_ops.GELU,
-            relu=te_ops.ReLU,
             geglu=te_ops.GEGLU,
+            qgelu=te_ops.QGELU,
+            qgeglu=te_ops.QGEGLU,
+            relu=te_ops.ReLU,
             reglu=te_ops.ReGLU,
+            srelu=te_ops.SReLU,
+            sreglu=te_ops.SReGLU,
+            silu=te_ops.SiLU,
             swiglu=te_ops.SwiGLU,
         )[activation]
         forward = te_ops.Sequential(
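The "*glu" branches above all follow one convention: the input's last dimension is doubled, the base activation is applied to the first half, and the result is gated by the second half (hence in_shape[-1] *= 2 earlier in the test). A standalone sketch of that reference semantics, not part of the commit:

import torch

def glu_variant(x: torch.Tensor, base_act) -> torch.Tensor:
    # Apply base_act to the first half of the last dim and gate with the second half.
    x1, x2 = x.chunk(2, dim=-1)
    return base_act(x1) * x2

# e.g. the sreglu reference above is glu_variant(x, lambda t: torch.nn.functional.relu(t) ** 2)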

tests/pytorch/test_numerics.py

Lines changed: 19 additions & 5 deletions
@@ -79,7 +79,18 @@

 all_boolean = [True, False]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]

 all_normalizations = ["LayerNorm", "RMSNorm"]

@@ -427,13 +438,16 @@ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:


 _supported_act = {
-    "geglu": nn.GELU(approximate="tanh"),
     "gelu": nn.GELU(approximate="tanh"),
-    "reglu": nn.ReLU(),
-    "relu": nn.ReLU(),
-    "swiglu": nn.SiLU(),
+    "geglu": nn.GELU(approximate="tanh"),
     "qgelu": TorchQuickGELU(),
+    "qgeglu": TorchQuickGELU(),
+    "relu": nn.ReLU(),
+    "reglu": nn.ReLU(),
     "srelu": TorchSquaredRELU(),
+    "sreglu": TorchSquaredRELU(),
+    "silu": nn.SiLU(),
+    "swiglu": nn.SiLU(),
 }

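TorchQuickGELU and TorchSquaredRELU are reference modules defined elsewhere in test_numerics.py and are not shown in this diff. A hypothetical reconstruction consistent with the reference formulas used in test_fusible_ops.py above, for orientation only:

import torch
import torch.nn as nn

class TorchQuickGELU(nn.Module):
    # QuickGELU: sigmoid-based approximation of GELU, x * sigmoid(1.702 * x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(1.702 * x)

class TorchSquaredRELU(nn.Module):
    # Squared ReLU: relu(x) ** 2
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.relu(x) ** 2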

tests/pytorch/test_sanity.py

Lines changed: 12 additions & 1 deletion
@@ -104,7 +104,18 @@ def is_fp8_supported(config: ModelConfig):
 all_boolean = [True, False]
 batch_sizes_with_zero = [0, 1, 2]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"]
+all_activations = [
+    "gelu",
+    "geglu",
+    "qgelu",
+    "qgeglu",
+    "relu",
+    "reglu",
+    "srelu",
+    "sreglu",
+    "silu",
+    "swiglu",
+]
 all_normalizations = ["LayerNorm", "RMSNorm"]


transformer_engine/pytorch/csrc/extensions.h

Lines changed: 24 additions & 13 deletions
@@ -154,38 +154,49 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional<at::Tensor> out = st
  * Activations
  **************************************************************************************************/

+/* GELU and variants*/
 py::object gelu(const at::Tensor &input, py::handle quantizer);

-py::object relu(const at::Tensor &input, py::handle quantizer);
+py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

 py::object geglu(const at::Tensor &input, py::handle quantizer);

-py::object qgeglu(const at::Tensor &input, py::handle quantizer);
+py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object reglu(const at::Tensor &input, py::handle quantizer);
+py::object qgelu(const at::Tensor &input, py::handle quantizer);

-py::object swiglu(const at::Tensor &input, py::handle quantizer);
+py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object qgelu(const at::Tensor &input, py::handle quantizer);
+py::object qgeglu(const at::Tensor &input, py::handle quantizer);

-py::object srelu(const at::Tensor &input, py::handle quantizer);
+py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+/* ReLU and variants*/
+py::object relu(const at::Tensor &input, py::handle quantizer);

 py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-
-py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object reglu(const at::Tensor &input, py::handle quantizer);

 py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

-py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
-
-py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+py::object srelu(const at::Tensor &input, py::handle quantizer);

 py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);

+py::object sreglu(const at::Tensor &input, py::handle quantizer);
+
+py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
+/* Silu and variants*/
+py::object silu(const at::Tensor &input, py::handle quantizer);
+
+py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
+py::object swiglu(const at::Tensor &input, py::handle quantizer);
+
+py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
+
 /***************************************************************************************************
  * LayerNorm
  **************************************************************************************************/

transformer_engine/pytorch/csrc/extensions/activation.cpp

Lines changed: 40 additions & 22 deletions
@@ -101,6 +101,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i
   return grad_input_py;
 }

+/* GELU and variants*/
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_gelu>(input, quantizer);
 }
@@ -109,30 +110,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua
   return dactivation_helper<nvte_dgelu>(grad, input, quantizer);
 }

-py::object relu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_relu>(input, quantizer);
+py::object geglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_geglu>(input, quantizer, 2);
 }

-py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
+py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
 }

-py::object geglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_geglu>(input, quantizer, 2);
+py::object qgelu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_qgelu>(input, quantizer);
 }

-py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
+py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
 }

-py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dgeglu>(grad, input, quantizer);
+py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_qgeglu>(input, quantizer, 2);
 }

 py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
   return dactivation_helper<nvte_dqgeglu>(grad, input, quantizer);
 }

+/* ReLU and variants*/
+py::object relu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_relu>(input, quantizer);
+}
+
+py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_drelu>(grad, input, quantizer);
+}
+
 py::object reglu(const at::Tensor& input, py::handle quantizer) {
   return activation_helper<nvte_reglu>(input, quantizer, 2);
 }
@@ -141,28 +151,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu
   return dactivation_helper<nvte_dreglu>(grad, input, quantizer);
 }

-py::object swiglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_swiglu>(input, quantizer, 2);
+py::object srelu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_srelu>(input, quantizer);
 }

-py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
+py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
 }

-py::object qgelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_qgelu>(input, quantizer);
+py::object sreglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_sreglu>(input, quantizer, 2);
 }

-py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dqgelu>(grad, input, quantizer);
+py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dsreglu>(grad, input, quantizer);
 }

-py::object srelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_helper<nvte_srelu>(input, quantizer);
+/* Silu and variants*/
+py::object silu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_silu>(input, quantizer);
 }

-py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return dactivation_helper<nvte_dsrelu>(grad, input, quantizer);
+py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dsilu>(grad, input, quantizer);
+}
+
+py::object swiglu(const at::Tensor& input, py::handle quantizer) {
+  return activation_helper<nvte_swiglu>(input, quantizer, 2);
 }

+py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
+  return dactivation_helper<nvte_dswiglu>(grad, input, quantizer);
+}
 }  // namespace transformer_engine::pytorch

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 25 additions & 10 deletions
@@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
         py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
+  /* GELU and variants*/
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
-        py::arg("quantizer"));
   m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  /* ReLU and variants */
+  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
+  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
        py::arg("quantizer"));
-  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+  m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"),
        py::arg("quantizer"));
-  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
+  /* SwiGLU and variants */
+  m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
        py::arg("quantizer"));
+  /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
-        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of ReLU and variants */
+  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
+  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+  m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of SiLU and variants */
+  m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
+  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer"));
+  /* DBias + DAct fusions*/
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
         py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
