@@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt,
         py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false,
         py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt);
+  /* GELU and variants */
   m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
-        py::arg("quantizer"));
   m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  /* ReLU and variants */
+  m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"),
+        py::arg("quantizer"));
   m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
+  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"),
+  m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"),
         py::arg("quantizer"));
-  m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"),
+  /* SwiGLU and variants */
+  m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"),
+        py::arg("quantizer"));
+  m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
-        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of ReLU and variants */
+  m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"),
+        py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
+  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"),
+  m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU",
+        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
+  /* Backward of SiLU and variants */
+  m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
-  m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"),
+  m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  /* DBias + DAct fusions */
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
         py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize",
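
Every binding in this hunk follows the same pybind11 pattern: m.def(python_name, function_pointer, docstring, py::arg...), with forward activations taking (input, quantizer) and their backward counterparts taking (grad, fwd_input, quantizer). Below is a minimal, self-contained sketch of that pattern, not Transformer Engine code: the names toy_ext, my_gelu, and my_dgelu are hypothetical stand-ins, and plain floats replace the real tensor and quantizer types for brevity.

// Hedged sketch of the m.def pattern above. Hypothetical throughout:
// "toy_ext", "my_gelu", and "my_dgelu" are illustrative names only.
#include <cmath>
#include <optional>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>  // enables std::optional <-> Python None conversion

namespace py = pybind11;

// Forward: tanh approximation of GELU; the optional "quantizer" is modeled
// here as a simple post-scale, unlike the real quantizer objects.
float my_gelu(float input, std::optional<float> quantizer) {
  const float x = input;
  const float y =
      0.5f * x * (1.0f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
  return quantizer ? y * (*quantizer) : y;
}

// Backward: mirrors the (grad, fwd_input, quantizer) signature of the d*
// bindings, applying the derivative of the tanh-approximated GELU.
float my_dgelu(float grad, float fwd_input, std::optional<float> quantizer) {
  const float x = fwd_input;
  const float u = 0.7978845608f * (x + 0.044715f * x * x * x);
  const float t = std::tanh(u);
  const float dudx = 0.7978845608f * (1.0f + 3.0f * 0.044715f * x * x);
  const float dy = 0.5f * (1.0f + t) + 0.5f * x * (1.0f - t * t) * dudx;
  const float dx = grad * dy;
  return quantizer ? dx * (*quantizer) : dx;
}

PYBIND11_MODULE(toy_ext, m) {
  /* Forward activation: (input, quantizer), as in the bindings above. */
  m.def("my_gelu", my_gelu, "GeLU activation (sketch)", py::arg("input"),
        py::arg("quantizer") = std::nullopt);
  /* Backward: (grad, fwd_input, quantizer), as in the d* bindings. */
  m.def("my_dgelu", my_dgelu, "Backward of GeLU (sketch)", py::arg("grad"),
        py::arg("fwd_input"), py::arg("quantizer") = std::nullopt);
}

Once built, the toy module would be called from Python as toy_ext.my_gelu(input=1.0) or toy_ext.my_dgelu(grad=1.0, fwd_input=1.0, quantizer=0.5); passing arguments by keyword is exactly what the py::arg annotations in the bindings above enable.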