Skip to content

Commit 6c6be64

Browse files
Liangliang-MajikunshangYizhouZ
authored
Grouped gemm cutlass (#22)
* add flash attention interface Signed-off-by: Kunshang Ji <[email protected]> * update interface Signed-off-by: Kunshang Ji <[email protected]> * add cutlass deps (#1) * add cutlass Signed-off-by: Kunshang Ji <[email protected]> * fix import Signed-off-by: Kunshang Ji <[email protected]> --------- Signed-off-by: Kunshang Ji <[email protected]> * add chunk_prefill step<1> * fix register * fix cmake * debug msg * functional ready * dev base Signed-off-by: Ma, Liangliang <[email protected]> * base of grouped_gemm_fp8 Signed-off-by: Ma, Liangliang <[email protected]> * update func Signed-off-by: Ma, Liangliang <[email protected]> * add test Signed-off-by: Ma, Liangliang <[email protected]> * update functor Signed-off-by: Ma, Liangliang <[email protected]> * update grouped_gemm Signed-off-by: Ma, Liangliang <[email protected]> * build ready Signed-off-by: Ma, Liangliang <[email protected]> * base integration done Signed-off-by: Ma, Liangliang <[email protected]> * grouped gemm base ready Signed-off-by: Ma, Liangliang <[email protected]> * gemm2 use cutlass grouped_mm Signed-off-by: Ma, Liangliang <[email protected]> * gemm1 use cutlass group_mm Signed-off-by: Ma, Liangliang <[email protected]> * rm flash_attn in this pr Signed-off-by: Ma, Liangliang <[email protected]> * rebase CMakeLists Signed-off-by: Ma, Liangliang <[email protected]> * use main Cmakes Signed-off-by: Ma, Liangliang <[email protected]> * use main setup Signed-off-by: Ma, Liangliang <[email protected]> * mv utils Signed-off-by: Ma, Liangliang <[email protected]> * finish rebase Signed-off-by: Ma, Liangliang <[email protected]> * add profile and change to col-maj Signed-off-by: Ma, Liangliang <[email protected]> * dont not reserve block_C Signed-off-by: Ma, Liangliang <[email protected]> * remove redundant allocation Signed-off-by: Ma, Liangliang <[email protected]> * e2e debug Signed-off-by: Ma, Liangliang <[email protected]> * add release func Signed-off-by: Ma, Liangliang <[email protected]> * gemm args allocate once Signed-off-by: Ma, Liangliang <[email protected]> * hidden_states copy Signed-off-by: Ma, Liangliang <[email protected]> * output bf16 Signed-off-by: Ma, Liangliang <[email protected]> * use static tensor buffer Signed-off-by: Ma, Liangliang <[email protected]> * remove ptr_C Signed-off-by: Ma, Liangliang <[email protected]> * fix device lost Signed-off-by: Ma, Liangliang <[email protected]> * acc and oom fixed Signed-off-by: Ma, Liangliang <[email protected]> * base Signed-off-by: Ma, Liangliang <[email protected]> * update CMakeLists Signed-off-by: Ma, Liangliang <[email protected]> * refactor csrc of cutlass Signed-off-by: Ma, Liangliang <[email protected]> * put src in vllm Signed-off-by: Ma, Liangliang <[email protected]> * add adapter src Signed-off-by: Ma, Liangliang <[email protected]> * clean up Signed-off-by: Ma, Liangliang <[email protected]> * add test Signed-off-by: Ma, Liangliang <[email protected]> * clean up Signed-off-by: Ma, Liangliang <[email protected]> * fix format Signed-off-by: Ma, Liangliang <[email protected]> * fix format f841 Signed-off-by: Ma, Liangliang <[email protected]> --------- Signed-off-by: Kunshang Ji <[email protected]> Signed-off-by: Ma, Liangliang <[email protected]> Co-authored-by: Kunshang Ji <[email protected]> Co-authored-by: Yizhou Wang <[email protected]>
1 parent a25429f commit 6c6be64

20 files changed

+6037
-2
lines changed

CMakeLists.txt

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
177177
FetchContent_Declare(
178178
cutlass-sycl
179179
GIT_REPOSITORY https://github.com/intel/cutlass-sycl
180+
180181
# Please keep this in sync with CUTLASS_REVISION line above.
181182
GIT_TAG ${CUTLASS_REVISION}
182183
GIT_PROGRESS TRUE
@@ -196,7 +197,6 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
196197
set(CUTLASS_ENABLE_GDC_FOR_SM100_DEFAULT OFF CACHE BOOL "DISABLE CUDA")
197198
# list(APPEND CMAKE_CXX_FLAGS "-ftemplate-backtrace-limit=0 " )
198199
# list(APPEND CMAKE_CXX_FLAGS "-fdiagnostics-color=always " )
199-
200200

201201
FetchContent_MakeAvailable(cutlass-sycl)
202202
set(CUTLASS_INCLUDE_DIR ${cutlass-sycl_SOURCE_DIR}/include CACHE PATH "CUTLASS Header Library")
@@ -269,11 +269,15 @@ endif ()
269269
#
270270
# xpu only ops/kernels, implemented with cutlass/onednn/sycl.
271271
#
272+
file(GLOB CUTLASS_BACKEND_SRCS
273+
csrc/xpu/cutlass_kernels/*.cpp
274+
)
272275
if(VLLM_GPU_LANG STREQUAL "SYCL")
273276
set(VLLM_EXT_XPU_SRC
274277
"csrc/xpu/torch_bindings.cpp"
275278
"csrc/xpu/lora/lora_shrink.cpp"
276279
"csrc/xpu/lora/lora_expand.cpp"
280+
${CUTLASS_BACKEND_SRCS}
277281
)
278282
include_directories("/usr/include")
279283
set(CMPLR_ROOT $ENV{CMPLR_ROOT})
@@ -282,6 +286,12 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
282286
list(APPEND VLLM_GPU_FLAGS "-DVLLM_BUILD_XPU_OPS" )
283287
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64")
284288
list(APPEND VLLM_LINK_LIBRARIES "sycl" "OpenCL" "pthread" "m" "dl" "torch" )
289+
# CUTLASS FLAGS
290+
list(APPEND VLLM_GPU_FLAGS "-O3" "-DNDEBUG")
291+
list(APPEND VLLM_GPU_FLAGS "-gline-tables-only")
292+
list(APPEND VLLM_GPU_FLAGS "-fsycl" "-fsycl-targets=spir64_gen" "-ftemplate-backtrace-limit=10")
293+
list(APPEND VLLM_GPU_LINK_FLAGS "-fsycl" "-fsycl-targets=spir64_gen")
294+
list(APPEND VLLM_GPU_LINK_FLAGS -Xsycl-target-backend=spir64_gen "-device bmg-g21-a0 -internal_options -cl-intel-256-GRF-per-thread")
285295
endif()
286296

287297
if(ONEDNN_FOUND)
@@ -305,6 +315,8 @@ define_gpu_extension_target(
305315
ARCHITECTURES ${VLLM_GPU_ARCHES}
306316
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
307317
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
318+
INCLUDE_DIRECTORIES ${CUTLASS_APP_INCLUDE_DIR}
319+
INCLUDE_DIRECTORIES ${VLLM_INCLUDE_DIR}
308320
USE_SABI 3
309321
WITH_SOABI)
310322

csrc/core/registration.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#pragma once
2-
32
#include <Python.h>
43

54
#define _CONCAT(A, B) A##B
Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
/***************************************************************************************************
2+
* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
3+
*reserved. SPDX-License-Identifier: BSD-3-Clause
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are met:
7+
*
8+
* 1. Redistributions of source code must retain the above copyright notice,
9+
*this list of conditions and the following disclaimer.
10+
*
11+
* 2. Redistributions in binary form must reproduce the above copyright notice,
12+
* this list of conditions and the following disclaimer in the documentation
13+
* and/or other materials provided with the distribution.
14+
*
15+
* 3. Neither the name of the copyright holder nor the names of its
16+
* contributors may be used to endorse or promote products derived from
17+
* this software without specific prior written permission.
18+
*
19+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20+
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21+
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
22+
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
23+
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
24+
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
25+
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
26+
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
27+
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
28+
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
29+
*POSSIBILITY OF SUCH DAMAGE.
30+
*
31+
**************************************************************************************************/
32+
33+
/*! \file
34+
\brief
35+
Default kernel-level GEMM definitions combine threadblock-scoped matrix
36+
multiply-add with the appropriate threadblock-scoped epilogue.
37+
38+
Note, CUTLASS epilogues universally target row-major outputs. Column-major
39+
outputs are accommodated by exchanging A and B operands and assuming
40+
transposed layouts. Partial specializations here choose
41+
'device::GemmTransposed' to implement this functionality.
42+
43+
*/
44+
45+
#pragma once
46+
47+
#include "cutlass/cutlass.h"
48+
49+
#include "cutlass/complex.h"
50+
#include "cutlass/layout/matrix.h"
51+
#include "cutlass/numeric_types.h"
52+
53+
#include "gemm_universal_k.h"
54+
#include "cutlass/gemm/kernel/gemm_universal_streamk.h"
55+
#include "cutlass/gemm/kernel/default_gemm.h"
56+
#include "cutlass/gemm/kernel/default_gemm_complex.h"
57+
58+
#include "cutlass/layout/permute.h"
59+
60+
/////////////////////////////////////////////////////////////////////////////////////////////////
61+
62+
namespace cutlass {
63+
namespace gemm {
64+
namespace kernel {
65+
66+
/////////////////////////////////////////////////////////////////////////////////////////////////
67+
68+
template <
69+
/// Element type for A matrix operand
70+
typename ElementA_,
71+
/// Layout type for A matrix operand
72+
typename LayoutA_,
73+
/// Complex elementwise transformation on A operand
74+
ComplexTransform TransformA,
75+
/// Access granularity of A matrix in units of elements
76+
int kAlignmentA,
77+
/// Element type for B matrix operand
78+
typename ElementB_,
79+
/// Layout type for B matrix operand
80+
typename LayoutB_,
81+
/// Complex elementwise transformation on B operand
82+
ComplexTransform TransformB,
83+
/// Access granularity of B matrix in units of elements
84+
int kAlignmentB,
85+
/// Element type for C and D matrix operands
86+
typename ElementC_,
87+
/// Layout type for C and D matrix operands
88+
typename LayoutC_,
89+
/// Element type for internal accumulation
90+
typename ElementAccumulator,
91+
/// Operator class tag
92+
typename OperatorClass,
93+
/// Tag indicating architecture to tune for
94+
typename ArchTag,
95+
/// Threadblock-level tile size (concept: GemmShape)
96+
typename ThreadblockShape,
97+
/// Warp-level tile size (concept: GemmShape)
98+
typename WarpShape,
99+
/// Instruction tile size (concept: GemmShape)
100+
typename InstructionShape,
101+
/// Epilogue output operator
102+
typename EpilogueOutputOp,
103+
/// Threadblock-level swizzling operator
104+
typename ThreadblockSwizzle,
105+
/// Number of stages used in the pipelined mainloop
106+
int Stages,
107+
/// Operation performed by GEMM
108+
typename Operator,
109+
/// Use zfill or predicate for out-of-bound cp.async
110+
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
111+
/// Gather operand A by using an index array
112+
bool GatherA = false,
113+
/// Gather operand B by using an index array
114+
bool GatherB = false,
115+
/// Scatter result D by using an index array
116+
bool ScatterD = false,
117+
/// Permute result D
118+
typename PermuteDLayout = layout::NoPermute,
119+
/// Permute operand A
120+
typename PermuteALayout_ = layout::NoPermute,
121+
/// Permute operand B
122+
typename PermuteBLayout_ = layout::NoPermute,
123+
///
124+
typename Enable = void>
125+
struct DefaultGemmUniversal;
126+
127+
/////////////////////////////////////////////////////////////////////////////////////////////////
128+
//
129+
// Real-valued GEMM kernels
130+
//
131+
132+
template <
133+
/// Element type for A matrix operand
134+
typename ElementA,
135+
/// Layout type for A matrix operand
136+
typename LayoutA,
137+
/// Access granularity of A matrix in units of elements
138+
int kAlignmentA,
139+
/// Element type for B matrix operand
140+
typename ElementB,
141+
/// Layout type for B matrix operand
142+
typename LayoutB,
143+
/// Access granularity of B matrix in units of elements
144+
int kAlignmentB,
145+
/// Element type for C and D matrix operands
146+
typename ElementC,
147+
/// Layout type for C and D matrix operands
148+
typename LayoutC,
149+
/// Element type for internal accumulation
150+
typename ElementAccumulator,
151+
/// Operator class tag
152+
typename OperatorClass,
153+
/// Tag indicating architecture to tune for
154+
typename ArchTag,
155+
/// Threadblock-level tile size (concept: GemmShape)
156+
typename ThreadblockShape,
157+
/// Warp-level tile size (concept: GemmShape)
158+
typename WarpShape,
159+
/// Warp-level tile size (concept: GemmShape)
160+
typename InstructionShape,
161+
/// Epilogue output operator
162+
typename EpilogueOutputOp,
163+
/// Threadblock-level swizzling operator
164+
typename ThreadblockSwizzle,
165+
/// Number of stages used in the pipelined mainloop
166+
int Stages,
167+
/// Operation performed by GEMM
168+
typename Operator,
169+
/// Use zfill or predicate for out-of-bound cp.async
170+
SharedMemoryClearOption SharedMemoryClear,
171+
/// Gather operand A by using an index array
172+
bool GatherA,
173+
/// Gather operand B by using an index array
174+
bool GatherB,
175+
/// Scatter result D by using an index array
176+
bool ScatterD,
177+
/// Permute result D
178+
typename PermuteDLayout,
179+
/// Permute operand A
180+
typename PermuteALayout,
181+
/// Permute operand B
182+
typename PermuteBLayout>
183+
struct DefaultGemmUniversal<
184+
ElementA, LayoutA,
185+
ComplexTransform::kNone, // transform A
186+
kAlignmentA, ElementB, LayoutB,
187+
ComplexTransform::kNone, // transform B
188+
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag,
189+
ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
190+
ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, GatherA, GatherB,
191+
ScatterD, PermuteDLayout, PermuteALayout, PermuteBLayout,
192+
typename platform::enable_if<
193+
!cutlass::is_complex<ElementAccumulator>::value>::type> {
194+
using DefaultGemmKernel = typename kernel::DefaultGemm<
195+
ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementC,
196+
LayoutC, ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape,
197+
WarpShape, InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
198+
true, Operator, SharedMemoryClear, GatherA, GatherB, ScatterD,
199+
PermuteDLayout, PermuteALayout, PermuteBLayout>::GemmKernel;
200+
201+
/// Universal kernel without StreamkFeature member type
202+
template <class SwizzleT, class Enable = void>
203+
class SelectBase
204+
: public kernel::GemmUniversal<typename DefaultGemmKernel::Mma,
205+
typename DefaultGemmKernel::Epilogue,
206+
SwizzleT> {};
207+
208+
/// Universal kernel with StreamkFeature member type
209+
template <class SwizzleT>
210+
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature>
211+
: public kernel::GemmUniversalStreamk<
212+
typename DefaultGemmKernel::Mma,
213+
typename DefaultGemmKernel::Epilogue, SwizzleT> {};
214+
215+
/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
216+
using GemmKernel = SelectBase<ThreadblockSwizzle>;
217+
};
218+
219+
/////////////////////////////////////////////////////////////////////////////////////////////////
220+
221+
//
222+
// Complex-valued GEMM kernels
223+
//
224+
225+
template <
226+
/// Element type for A matrix operand
227+
typename ElementA,
228+
/// Layout type for A matrix operand
229+
typename LayoutA,
230+
/// Complex elementwise transformation on A operand
231+
ComplexTransform TransformA,
232+
/// Access granularity of A matrix in units of elements
233+
int kAlignmentA,
234+
/// Element type for B matrix operand
235+
typename ElementB,
236+
/// Layout type for B matrix operand
237+
typename LayoutB,
238+
/// Complex elementwise transformation on B operand
239+
ComplexTransform TransformB,
240+
/// Access granularity of B matrix in units of elements
241+
int kAlignmentB,
242+
/// Element type for C and D matrix operands
243+
typename ElementC,
244+
/// Layout type for C and D matrix operands
245+
typename LayoutC,
246+
/// Element type for internal accumulation
247+
typename ElementAccumulator,
248+
/// Operator class tag
249+
typename OperatorClass,
250+
/// Tag indicating architecture to tune for
251+
typename ArchTag,
252+
/// Threadblock-level tile size (concept: GemmShape)
253+
typename ThreadblockShape,
254+
/// Warp-level tile size (concept: GemmShape)
255+
typename WarpShape,
256+
/// Warp-level tile size (concept: GemmShape)
257+
typename InstructionShape,
258+
/// Epilogue output operator
259+
typename EpilogueOutputOp,
260+
/// Threadblock-level swizzling operator
261+
typename ThreadblockSwizzle,
262+
/// Number of stages used in the pipelined mainloop
263+
int Stages,
264+
/// Operation performed by GEMM
265+
typename Operator,
266+
/// Use zfill or predicate for out-of-bound cp.async
267+
SharedMemoryClearOption SharedMemoryClear>
268+
struct DefaultGemmUniversal<
269+
ElementA, LayoutA, TransformA, kAlignmentA, ElementB, LayoutB, TransformB,
270+
kAlignmentB, ElementC, LayoutC, ElementAccumulator, OperatorClass, ArchTag,
271+
ThreadblockShape, WarpShape, InstructionShape, EpilogueOutputOp,
272+
ThreadblockSwizzle, Stages, Operator, SharedMemoryClear, false, false,
273+
false, layout::NoPermute, layout::NoPermute, layout::NoPermute,
274+
typename platform::enable_if<
275+
cutlass::is_complex<ElementAccumulator>::value>::type> {
276+
using DefaultGemmKernel = typename kernel::DefaultGemmComplex<
277+
ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
278+
ElementAccumulator, OperatorClass, ArchTag, ThreadblockShape, WarpShape,
279+
InstructionShape, EpilogueOutputOp, ThreadblockSwizzle, Stages,
280+
TransformA, TransformB, Operator, false>::GemmKernel;
281+
282+
/// Universal kernel without StreamkFeature member type
283+
template <class SwizzleT, class Enable = void>
284+
class SelectBase
285+
: public kernel::GemmUniversal<typename DefaultGemmKernel::Mma,
286+
typename DefaultGemmKernel::Epilogue,
287+
SwizzleT> {};
288+
289+
/// Universal kernel with StreamkFeature member type
290+
template <class SwizzleT>
291+
class SelectBase<SwizzleT, typename SwizzleT::StreamkFeature>
292+
: public kernel::GemmUniversalStreamk<
293+
typename DefaultGemmKernel::Mma,
294+
typename DefaultGemmKernel::Epilogue, SwizzleT> {};
295+
296+
/// Select kernel by ThreadblockSwizzle's support for StreamkFeature
297+
using GemmKernel = SelectBase<ThreadblockSwizzle>;
298+
};
299+
300+
/////////////////////////////////////////////////////////////////////////////////////////////////
301+
302+
} // namespace kernel
303+
} // namespace gemm
304+
} // namespace cutlass
305+
306+
/////////////////////////////////////////////////////////////////////////////////////////////////

0 commit comments

Comments
 (0)