Skip to content

Commit 389326b

Browse files
committed
feat: support libtorch
1 parent 2f9d878 commit 389326b

22 files changed

+651
-246
lines changed

csrc/apis/attention.hpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -219,22 +219,4 @@ static torch::Tensor fp8_paged_mqa_logits(const torch::Tensor& q,
219219
return logits;
220220
}
221221

222-
static void register_apis(pybind11::module_& m) {
223-
m.def("fp8_gemm_nt_skip_head_mid", &fp8_gemm_nt_skip_head_mid,
224-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("head_splits"),
225-
py::arg("recipe") = std::nullopt,
226-
py::arg("compiled_dims") = "nk",
227-
py::arg("disable_ue8m0_cast") = false);
228-
m.def("fp8_mqa_logits", &fp8_mqa_logits,
229-
py::arg("q"), py::arg("kv"), py::arg("weights"),
230-
py::arg("cu_seq_len_k_start"), py::arg("cu_seq_len_k_end"),
231-
py::arg("clean_logits") = true);
232-
m.def("get_paged_mqa_logits_metadata", &get_paged_mqa_logits_metadata,
233-
py::arg("context_lens"), py::arg("block_kv"), py::arg("num_sms"));
234-
m.def("fp8_paged_mqa_logits", &fp8_paged_mqa_logits,
235-
py::arg("q"), py::arg("kv_cache"), py::arg("weights"),
236-
py::arg("context_lens"), py::arg("block_table"), py::arg("schedule_meta"),
237-
py::arg("max_context_len"), py::arg("clean_logits") = false);
238-
}
239-
240222
} // namespace deep_gemm::attention

csrc/apis/einsum.hpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#pragma once
22

3-
#include <pybind11/pybind11.h>
4-
#include <torch/python.h>
5-
63
#include "../utils/exception.hpp"
74
#include "../utils/format.hpp"
85
#include "../utils/layout.hpp"
@@ -106,10 +103,4 @@ static void einsum(const std::string& expr,
106103
}
107104
}
108105

109-
static void register_apis(pybind11::module_& m) {
110-
m.def("einsum", &einsum,
111-
py::arg("expr"), py::arg("a"), py::arg("b"),
112-
py::arg("d"), py::arg("c") = std::nullopt);
113-
}
114-
115106
} // namespace deep_gemm::einsum

csrc/apis/gemm.hpp

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -500,84 +500,4 @@ static void cublaslt_gemm_tt(const torch::Tensor& a, const torch::Tensor& b,
500500
cublaslt_gemm_nt(a.transpose(0, 1), b, d, c);
501501
}
502502

503-
static void register_apis(pybind11::module_& m) {
504-
// FP8 GEMMs
505-
m.def("fp8_gemm_nt", &fp8_gemm_nt,
506-
py::arg("a"), py::arg("b"), py::arg("d"),
507-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
508-
py::arg("compiled_dims") = "nk",
509-
py::arg("disable_ue8m0_cast") = false);
510-
m.def("fp8_gemm_nn", &fp8_gemm_nn,
511-
py::arg("a"), py::arg("b"), py::arg("d"),
512-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
513-
py::arg("compiled_dims") = "nk",
514-
py::arg("disable_ue8m0_cast") = false);
515-
m.def("fp8_gemm_tn", &fp8_gemm_tn,
516-
py::arg("a"), py::arg("b"), py::arg("d"),
517-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
518-
py::arg("compiled_dims") = "mn",
519-
py::arg("disable_ue8m0_cast") = false);
520-
m.def("fp8_gemm_tt", &fp8_gemm_tt,
521-
py::arg("a"), py::arg("b"), py::arg("d"),
522-
py::arg("c") = std::nullopt, py::arg("recipe") = std::nullopt,
523-
py::arg("compiled_dims") = "mn",
524-
py::arg("disable_ue8m0_cast") = false);
525-
m.def("m_grouped_fp8_gemm_nt_contiguous", &m_grouped_fp8_gemm_nt_contiguous,
526-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
527-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
528-
py::arg("disable_ue8m0_cast") = false);
529-
m.def("m_grouped_fp8_gemm_nn_contiguous", &m_grouped_fp8_gemm_nn_contiguous,
530-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
531-
py::arg("recipe") = std::nullopt, py::arg("compiled_dims") = "nk",
532-
py::arg("disable_ue8m0_cast") = false);
533-
m.def("m_grouped_fp8_gemm_nt_masked", &m_grouped_fp8_gemm_nt_masked,
534-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
535-
py::arg("expected_m"), py::arg("recipe") = std::nullopt,
536-
py::arg("compiled_dims") = "nk", py::arg("disable_ue8m0_cast") = false);
537-
m.def("k_grouped_fp8_gemm_tn_contiguous", &k_grouped_fp8_gemm_tn_contiguous,
538-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
539-
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
540-
py::arg("recipe") = std::make_tuple(1, 1, 128),
541-
py::arg("compiled_dims") = "mn");
542-
m.def("k_grouped_fp8_gemm_nt_contiguous", &k_grouped_fp8_gemm_nt_contiguous,
543-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("ks"),
544-
py::arg("ks_tensor"), py::arg("c") = std::nullopt,
545-
py::arg("recipe") = std::make_tuple(1, 1, 128),
546-
py::arg("compiled_dims") = "mn");
547-
548-
// BF16 GEMMs
549-
m.def("bf16_gemm_nt", &bf16_gemm_nt,
550-
py::arg("a"), py::arg("b"), py::arg("d"),
551-
py::arg("c") = std::nullopt,
552-
py::arg("compiled_dims") = "nk");
553-
m.def("bf16_gemm_nn", &bf16_gemm_nn,
554-
py::arg("a"), py::arg("b"), py::arg("d"),
555-
py::arg("c") = std::nullopt,
556-
py::arg("compiled_dims") = "nk");
557-
m.def("bf16_gemm_tn", &bf16_gemm_tn,
558-
py::arg("a"), py::arg("b"), py::arg("d"),
559-
py::arg("c") = std::nullopt,
560-
py::arg("compiled_dims") = "mn");
561-
m.def("bf16_gemm_tt", &bf16_gemm_tt,
562-
py::arg("a"), py::arg("b"), py::arg("d"),
563-
py::arg("c") = std::nullopt,
564-
py::arg("compiled_dims") = "mn");
565-
m.def("m_grouped_bf16_gemm_nt_contiguous", &m_grouped_bf16_gemm_nt_contiguous,
566-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("m_indices"),
567-
py::arg("compiled_dims") = "nk");
568-
m.def("m_grouped_bf16_gemm_nt_masked", &m_grouped_bf16_gemm_nt_masked,
569-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("masked_m"),
570-
py::arg("expected_m"), py::arg("compiled_dims") = "nk");
571-
572-
// cuBLASLt GEMMs
573-
m.def("cublaslt_gemm_nt", &cublaslt_gemm_nt,
574-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
575-
m.def("cublaslt_gemm_nn", &cublaslt_gemm_nn,
576-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
577-
m.def("cublaslt_gemm_tn", &cublaslt_gemm_tn,
578-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
579-
m.def("cublaslt_gemm_tt", &cublaslt_gemm_tt,
580-
py::arg("a"), py::arg("b"), py::arg("d"), py::arg("c") = std::nullopt);
581-
}
582-
583503
} // namespace deep_gemm::gemm

csrc/apis/layout.hpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -69,17 +69,4 @@ static torch::Tensor transform_k_grouped_sf_into_required_layout(const torch::Te
6969
DG_HOST_UNREACHABLE("Unknown cases");
7070
}
7171

72-
static void register_apis(pybind11::module_& m) {
73-
m.def("transform_sf_into_required_layout", &transform_sf_into_required_layout,
74-
py::arg("sf"), py::arg("mn"), py::arg("k"), py::arg("recipe"),
75-
py::arg("num_groups") = std::nullopt, py::arg("is_sfa") = false,
76-
py::arg("disable_ue8m0_cast") = false);
77-
78-
m.def("get_tma_aligned_size", &get_tma_aligned_size);
79-
m.def("get_mk_alignment_for_contiguous_layout", &get_mk_alignment_for_contiguous_layout);
80-
m.def("get_mn_major_tma_aligned_tensor", &get_mn_major_tma_aligned_tensor);
81-
m.def("get_mn_major_tma_aligned_packed_ue8m0_tensor", &get_mn_major_tma_aligned_packed_ue8m0_tensor);
82-
m.def("get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor", &get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor);
83-
}
84-
8572
} // namespace deep_gemm::layout

csrc/apis/runtime.hpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,6 @@
55

66
namespace deep_gemm::runtime {
77

8-
static void register_apis(pybind11::module_& m) {
9-
m.def("set_num_sms", [&](const int& new_num_sms) {
10-
device_runtime->set_num_sms(new_num_sms);
11-
});
12-
m.def("get_num_sms", [&]() {
13-
return device_runtime->get_num_sms();
14-
});
15-
m.def("set_tc_util", [&](const int& new_tc_util) {
16-
device_runtime->set_tc_util(new_tc_util);
17-
});
18-
m.def("get_tc_util", [&]() {
19-
return device_runtime->get_tc_util();
20-
});
21-
22-
m.def("init", [&](const std::string& library_root_path, const std::string& cuda_home_path_by_python) {
23-
Compiler::prepare_init(library_root_path, cuda_home_path_by_python);
24-
KernelRuntime::prepare_init(cuda_home_path_by_python);
25-
});
26-
}
8+
// `init` and the other runtime functions are now exposed via TORCH_LIBRARY in python_api.cpp.
279

2810
} // namespace deep_gemm::runtime

csrc/jit/device_runtime.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ class DeviceRuntime {
1717
cublasLtHandle_t cublaslt_handle{};
1818
std::shared_ptr<torch::Tensor> cublaslt_workspace;
1919

20+
// cuBLASLt utils
21+
static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
22+
cublasLtHandle_t cublaslt_handle{};
23+
std::shared_ptr<torch::Tensor> cublaslt_workspace;
24+
2025
public:
2126
explicit DeviceRuntime() {
2227
cublaslt_workspace = std::make_shared<torch::Tensor>(torch::empty({kCublasLtWorkspaceSize}, dtype(torch::kByte).device(at::kCUDA)));

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
#pragma once
22

33
#include <cuda.h>
4-
#include <torch/python.h>
54

65
#include "../../utils/math.hpp"
76
#include "../heuristics/sm90.hpp"
@@ -75,10 +74,6 @@ static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const in
7574
}
7675
#endif
7776

78-
DG_HOST_ASSERT(base == 0);
79-
switch (mode) {
80-
case 0:
81-
case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
8277
case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
8378
case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
8479
case 128: return CU_TENSOR_MAP_SWIZZLE_128B;

csrc/jit_kernels/impls/sm100_bf16_gemm.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

csrc/jit_kernels/impls/sm100_bmk_bnk_mn.hpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"
@@ -134,4 +132,4 @@ static void sm100_bmn_bnk_mn_gemm(const torch::Tensor &a,
134132
SM100BmkBnkMnRuntime::launch(runtime, args);
135133
}
136134

137-
} // namespace deep_gemm
135+
} // namespace deep_gemm

csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
#pragma once
22

3-
#include <torch/python.h>
4-
53
#include "../../jit/compiler.hpp"
64
#include "../../jit/device_runtime.hpp"
75
#include "../../jit/kernel_runtime.hpp"

0 commit comments

Comments
 (0)