@@ -500,84 +500,4 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
500500 cublaslt_gemm_nt (a.transpose (0 , 1 ), b, d, c);
501501}
502502
503- static void register_apis (pybind11::module_& m) {
504- // FP8 GEMMs
505- m.def (" fp8_gemm_nt" , &fp8_gemm_nt,
506- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
507- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
508- py::arg (" compiled_dims" ) = " nk" ,
509- py::arg (" disable_ue8m0_cast" ) = false );
510- m.def (" fp8_gemm_nn" , &fp8_gemm_nn,
511- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
512- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
513- py::arg (" compiled_dims" ) = " nk" ,
514- py::arg (" disable_ue8m0_cast" ) = false );
515- m.def (" fp8_gemm_tn" , &fp8_gemm_tn,
516- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
517- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
518- py::arg (" compiled_dims" ) = " mn" ,
519- py::arg (" disable_ue8m0_cast" ) = false );
520- m.def (" fp8_gemm_tt" , &fp8_gemm_tt,
521- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
522- py::arg (" c" ) = std::nullopt , py::arg (" recipe" ) = std::nullopt ,
523- py::arg (" compiled_dims" ) = " mn" ,
524- py::arg (" disable_ue8m0_cast" ) = false );
525- m.def (" m_grouped_fp8_gemm_nt_contiguous" , &m_grouped_fp8_gemm_nt_contiguous,
526- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
527- py::arg (" recipe" ) = std::nullopt , py::arg (" compiled_dims" ) = " nk" ,
528- py::arg (" disable_ue8m0_cast" ) = false );
529- m.def (" m_grouped_fp8_gemm_nn_contiguous" , &m_grouped_fp8_gemm_nn_contiguous,
530- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
531- py::arg (" recipe" ) = std::nullopt , py::arg (" compiled_dims" ) = " nk" ,
532- py::arg (" disable_ue8m0_cast" ) = false );
533- m.def (" m_grouped_fp8_gemm_nt_masked" , &m_grouped_fp8_gemm_nt_masked,
534- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" masked_m" ),
535- py::arg (" expected_m" ), py::arg (" recipe" ) = std::nullopt ,
536- py::arg (" compiled_dims" ) = " nk" , py::arg (" disable_ue8m0_cast" ) = false );
537- m.def (" k_grouped_fp8_gemm_tn_contiguous" , &k_grouped_fp8_gemm_tn_contiguous,
538- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" ks" ),
539- py::arg (" ks_tensor" ), py::arg (" c" ) = std::nullopt ,
540- py::arg (" recipe" ) = std::make_tuple (1 , 1 , 128 ),
541- py::arg (" compiled_dims" ) = " mn" );
542- m.def (" k_grouped_fp8_gemm_nt_contiguous" , &k_grouped_fp8_gemm_nt_contiguous,
543- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" ks" ),
544- py::arg (" ks_tensor" ), py::arg (" c" ) = std::nullopt ,
545- py::arg (" recipe" ) = std::make_tuple (1 , 1 , 128 ),
546- py::arg (" compiled_dims" ) = " mn" );
547-
548- // BF16 GEMMs
549- m.def (" bf16_gemm_nt" , &bf16_gemm_nt,
550- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
551- py::arg (" c" ) = std::nullopt ,
552- py::arg (" compiled_dims" ) = " nk" );
553- m.def (" bf16_gemm_nn" , &bf16_gemm_nn,
554- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
555- py::arg (" c" ) = std::nullopt ,
556- py::arg (" compiled_dims" ) = " nk" );
557- m.def (" bf16_gemm_tn" , &bf16_gemm_tn,
558- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
559- py::arg (" c" ) = std::nullopt ,
560- py::arg (" compiled_dims" ) = " mn" );
561- m.def (" bf16_gemm_tt" , &bf16_gemm_tt,
562- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ),
563- py::arg (" c" ) = std::nullopt ,
564- py::arg (" compiled_dims" ) = " mn" );
565- m.def (" m_grouped_bf16_gemm_nt_contiguous" , &m_grouped_bf16_gemm_nt_contiguous,
566- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" m_indices" ),
567- py::arg (" compiled_dims" ) = " nk" );
568- m.def (" m_grouped_bf16_gemm_nt_masked" , &m_grouped_bf16_gemm_nt_masked,
569- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" masked_m" ),
570- py::arg (" expected_m" ), py::arg (" compiled_dims" ) = " nk" );
571-
572- // cuBLASLt GEMMs
573- m.def (" cublaslt_gemm_nt" , &cublaslt_gemm_nt,
574- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" c" ) = std::nullopt );
575- m.def (" cublaslt_gemm_nn" , &cublaslt_gemm_nn,
576- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" c" ) = std::nullopt );
577- m.def (" cublaslt_gemm_tn" , &cublaslt_gemm_tn,
578- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" c" ) = std::nullopt );
579- m.def (" cublaslt_gemm_tt" , &cublaslt_gemm_tt,
580- py::arg (" a" ), py::arg (" b" ), py::arg (" d" ), py::arg (" c" ) = std::nullopt );
581- }
582-
583503} // namespace deep_gemm::gemm
0 commit comments