
Commit 92585b7

feat: support misc kernel launch

rainj-me authored and FlamingoPg committed
1 parent 389326b · commit 92585b7

File tree: 9 files changed (+52 −24 lines)

csrc/jit/device_runtime.hpp

Lines changed: 10 additions & 5 deletions

@@ -11,11 +11,7 @@ namespace deep_gemm {
 class DeviceRuntime {
     int num_sms = 0, tc_util = 0;
     std::shared_ptr<cudaDeviceProp> cached_prop;
-
-    // cuBLASLt utils
-    static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
-    cublasLtHandle_t cublaslt_handle{};
-    std::shared_ptr<torch::Tensor> cublaslt_workspace;
+    int compile_mode = 0;
 
     // cuBLASLt utils
     static constexpr size_t kCublasLtWorkspaceSize = 32 * 1024 * 1024;
@@ -82,6 +78,15 @@ class DeviceRuntime {
         return num_sms;
     }
 
+    void set_compile_mode(const int& new_compile_mode) {
+        DG_HOST_ASSERT(0 <= new_compile_mode and new_compile_mode <= 1);
+        compile_mode = new_compile_mode;
+    }
+
+    int get_compile_mode() {
+        return compile_mode;
+    }
+
     void set_tc_util(const int& new_tc_util) {
         DG_HOST_ASSERT(0 <= new_tc_util and new_tc_util <= 100);
         tc_util = new_tc_util;
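
The new compile_mode field defaults to 0 (compile and launch); set_compile_mode accepts only 0 or 1, and mode 1 puts the runtime into a compile-only state that the MAYBE_LAUNCH macro introduced below consults before every kernel launch.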

csrc/jit_kernels/impls/runtime_utils.hpp

Lines changed: 12 additions & 2 deletions

@@ -67,13 +67,17 @@ static CUtensorMapDataType aten_dtype_to_tensor_map_dtype(const at::ScalarType&
 }
 
 static CUtensorMapSwizzle mode_into_tensor_map_swizzle(const int& mode, const int& base) {
-#if CUDA_VERSION >= 12080
     if (base != 0) {
+#if CUDA_VERSION >= 12080
         DG_HOST_ASSERT(base == 32 and mode == 128);
         return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B;
-    }
 #endif
+    }
+
+    switch (mode) {
+        case 0:
+        case 16: return CU_TENSOR_MAP_SWIZZLE_NONE;
         case 32: return CU_TENSOR_MAP_SWIZZLE_32B;
         case 64: return CU_TENSOR_MAP_SWIZZLE_64B;
         case 128: return CU_TENSOR_MAP_SWIZZLE_128B;
@@ -215,4 +219,10 @@ static CUtensorMap make_tma_sf_desc(const cute::UMMA::Major& major,
                            allow_tf32);
 }
 
+#define MAYBE_LAUNCH(EXPR) do { \
+    if (device_runtime->get_compile_mode() == 0) { \
+        (EXPR); \
+    } \
+} while (0)
+
 } // namespace deep_gemm
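
The MAYBE_LAUNCH(EXPR) macro is the heart of this commit: it expands to do { if (device_runtime->get_compile_mode() == 0) { (EXPR); } } while (0), the usual trick for making a macro behave like a single statement. Every kernel launch site below is wrapped in it, so in the default mode 0 behavior is unchanged, while mode 1 still JIT-compiles the kernel but skips the launch, warming the kernel cache without writing any outputs.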

csrc/jit_kernels/impls/sm100_bf16_gemm.hpp

Lines changed: 1 addition & 1 deletion

@@ -134,7 +134,7 @@ static void sm100_bf16_gemm(const torch::Tensor& a,
     };
     const auto& code = SM100BF16GemmRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_bf16_gemm", code);
-    SM100BF16GemmRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100BF16GemmRuntime::launch(runtime, args));
 }
 
 } // namespace deep_gemm

csrc/jit_kernels/impls/sm100_fp8_gemm_1d1d.hpp

Lines changed: 4 additions & 4 deletions

@@ -148,7 +148,7 @@ static void sm100_fp8_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa
     };
     const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_fp8_gemm_1d1d", code);
-    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
 }
 
 static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -206,7 +206,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d1d(const torch::Tensor& a, con
     };
     const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d1d", code);
-    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
 }
 
 static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -265,7 +265,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d1d(const torch::Tensor& a, const t
     };
     const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d1d", code);
-    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
 }
 
 static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -346,7 +346,7 @@ static void fp8_k_grouped_gemm_1d1d(const torch::Tensor& a, const torch::Tensor&
     };
     const auto& code = SM100FP8Gemm1D1DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_fp8_k_grouped_gemm_1d1d", code);
-    SM100FP8Gemm1D1DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D1DRuntime::launch(runtime, args));
 }
 
 } // namespace deep_gemm

csrc/jit_kernels/impls/sm100_fp8_gemm_1d2d.hpp

Lines changed: 3 additions & 3 deletions

@@ -127,7 +127,7 @@ static void sm100_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa
     };
     const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_fp8_gemm_1d2d", code);
-    SM100FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -181,7 +181,7 @@ static void sm100_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, con
     };
     const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_m_grouped_fp8_gemm_contiguous_1d2d", code);
-    SM100FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -236,7 +236,7 @@ static void sm100_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const t
     };
     const auto& code = SM100FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm100_fp8_m_grouped_gemm_masked_1d2d", code);
-    SM100FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM100FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 } // namespace deep_gemm

csrc/jit_kernels/impls/sm90_bf16_gemm.hpp

Lines changed: 3 additions & 3 deletions

@@ -115,7 +115,7 @@ static void sm90_bf16_gemm(const torch::Tensor& a,
     };
     const auto& code = SM90BF16GemmRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_bf16_gemm", code);
-    SM90BF16GemmRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
 }
 
 static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
@@ -168,7 +168,7 @@ static void sm90_m_grouped_bf16_gemm_contiguous(const torch::Tensor& a,
     };
     const auto& code = SM90BF16GemmRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_m_grouped_bf16_gemm_contiguous", code);
-    SM90BF16GemmRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
 }
 
 static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
@@ -222,7 +222,7 @@ static void sm90_bf16_m_grouped_gemm_masked(const torch::Tensor& a,
     };
     const auto& code = SM90BF16GemmRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_bf16_m_grouped_gemm_masked", code);
-    SM90BF16GemmRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90BF16GemmRuntime::launch(runtime, args));
 }
 
 } // namespace deep_gemm

csrc/jit_kernels/impls/sm90_fp8_gemm_1d2d.hpp

Lines changed: 3 additions & 3 deletions

@@ -128,7 +128,7 @@ static void sm90_fp8_gemm_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
     };
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_gemm_1d2d", code);
-    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -188,7 +188,7 @@ static void sm90_m_grouped_fp8_gemm_contiguous_1d2d(const torch::Tensor& a, cons
     };
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_m_grouped_fp8_gemm_contiguous_1d2d", code);
-    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const torch::Tensor& sfa,
@@ -249,7 +249,7 @@ static void sm90_m_grouped_fp8_gemm_masked_1d2d(const torch::Tensor& a, const to
     };
     const auto& code = SM90FP8Gemm1D2DRuntime::generate(args);
     const auto& runtime = compiler->build("sm90_fp8_m_grouped_gemm_masked_1d2d", code);
-    SM90FP8Gemm1D2DRuntime::launch(runtime, args);
+    MAYBE_LAUNCH(SM90FP8Gemm1D2DRuntime::launch(runtime, args));
 }
 
 } // namespace deep_gemm

csrc/python_api.cpp

Lines changed: 10 additions & 0 deletions

@@ -214,6 +214,16 @@ TORCH_LIBRARY(deep_gemm, m) {
         return static_cast<int64_t>(deep_gemm::device_runtime->get_num_sms());
     });
 
+    m.def("set_compile_mode(int new_compile_mode) -> ()");
+    m.impl("set_compile_mode", [](int64_t new_compile_mode) {
+        deep_gemm::device_runtime->set_compile_mode(static_cast<int>(new_compile_mode));
+    });
+
+    m.def("get_compile_mode() -> int");
+    m.impl("get_compile_mode", []() -> int64_t {
+        return static_cast<int64_t>(deep_gemm::device_runtime->get_compile_mode());
+    });
+
     m.def("set_tc_util(int new_tc_util) -> ()");
     m.impl("set_tc_util", [](int64_t new_tc_util) {
         deep_gemm::device_runtime->set_tc_util(static_cast<int>(new_tc_util));
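
These registrations expose the compile mode through the torch.ops namespace, mirroring the existing set_num_sms/set_tc_util plumbing. A minimal sketch of driving the raw ops (assuming the extension is built and importable; values outside [0, 1] trip the DG_HOST_ASSERT in DeviceRuntime::set_compile_mode):

    import torch
    import deep_gemm  # importing the package loads the extension and registers the ops

    torch.ops.deep_gemm.set_compile_mode(1)             # 1: compile kernels, skip launches
    assert torch.ops.deep_gemm.get_compile_mode() == 1
    torch.ops.deep_gemm.set_compile_mode(0)             # 0: default compile-and-launch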

deep_gemm/__init__.py

Lines changed: 6 additions & 3 deletions

@@ -44,13 +44,16 @@ def _ensure_initialized() -> None:
 
 
 def _wrap_op(name: str):
+    func = getattr(torch.ops.deep_gemm, name)
     def _fn(*args, **kwargs):
         _ensure_initialized()
-        return getattr(torch.ops.deep_gemm, name)(*args, **kwargs)
+        return func(*args, **kwargs)
     return _fn
 
 set_num_sms = _wrap_op('set_num_sms')
 get_num_sms = _wrap_op('get_num_sms')
+set_compile_mode = _wrap_op('set_compile_mode')
+get_compile_mode = _wrap_op('get_compile_mode')
 set_tc_util = _wrap_op('set_tc_util')
 get_tc_util = _wrap_op('get_tc_util')
 
@@ -121,10 +124,10 @@ def _verify_ops_loaded():
         'cublaslt_gemm_nt', 'cublaslt_gemm_nn',
         'cublaslt_gemm_tn', 'cublaslt_gemm_tt',
     ]
-
+
     available_ops = list(torch.ops.deep_gemm.__dict__.keys())
     missing_ops = [op for op in expected_ops if op not in available_ops]
-
+
     if missing_ops:
         print(f"Warning: Missing operations: {missing_ops}")
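
With set_compile_mode/get_compile_mode exported (and _wrap_op now resolving each op handle once instead of on every call), a compile-only warm-up pass might look like the sketch below. bf16_gemm_nt and the tensor shapes are illustrative stand-ins, not part of this diff; any entry point whose launch is wrapped in MAYBE_LAUNCH is gated the same way:

    import torch
    import deep_gemm

    # Hypothetical inputs for a single NT GEMM warm-up.
    a = torch.randn(128, 256, device='cuda', dtype=torch.bfloat16)
    b = torch.randn(512, 256, device='cuda', dtype=torch.bfloat16)
    d = torch.empty(128, 512, device='cuda', dtype=torch.bfloat16)

    deep_gemm.set_compile_mode(1)    # JIT-compile only; kernel launches are skipped
    deep_gemm.bf16_gemm_nt(a, b, d)  # fills the JIT cache; d is left unwritten
    deep_gemm.set_compile_mode(0)    # restore the default compile-and-launch mode
    deep_gemm.bf16_gemm_nt(a, b, d)  # runs the now-cached kernel for real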
