Skip to content

Commit a33d98f

Browse files
[CK TILE ENGINE] GEMM Multi D Restructure (#3121)
* Renaming old code * Adding GEMM code with new Architecture * Partial Progress : Errors * Partial Progress : Working code * Changes to element wise function * Removing Debugging statements * Working GEMM Multi D code * Removing Stale Code * Address Copilot review comments * Address Copilot review comments * Changes to validation file * Changes to common code snippets * Creating common folder * Removing duplicate files * Pointing to right common file * Pointing to right common file * Pointing to right common file * Changing to VERBOSE * Changing CMAKE messages to verbose * Updating Cmake with right layout datatype configs * Working code for GEMM Multi D
1 parent 04efd28 commit a33d98f

22 files changed

+2415
-1974
lines changed

Jenkinsfile

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1642,14 +1642,9 @@ pipeline {
16421642
ninja -j64 benchmark_gemm_preshuffle_all && \
16431643
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
16441644
--warmup 5 --repeat 5 --verbose --json results.json && \
1645-
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
1646-
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
1647-
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
1648-
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
1649-
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
1650-
./bin/benchmark_gemm_multi_d_fp16_crrr && \
1651-
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
1652-
./bin/benchmark_gemm_multi_d_fp16_rcrr """
1645+
ninja -j64 benchmark_gemm_multi_d_all && \
1646+
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
1647+
--warmup 5 --repeat 5 --verbose --json results.json """
16531648
}
16541649
steps{
16551650
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
@@ -1682,14 +1677,9 @@ pipeline {
16821677
ninja -j64 benchmark_gemm_preshuffle_all && \
16831678
python3 ../tile_engine/ops/gemm_preshuffle/gemm_preshuffle_benchmark.py . --problem-sizes "1024,1024,1024" \
16841679
--warmup 5 --repeat 5 --verbose --json results.json && \
1685-
ninja -j64 benchmark_gemm_multi_d_fp16_rrrr && \
1686-
./bin/benchmark_gemm_multi_d_fp16_rrrr && \
1687-
ninja -j64 benchmark_gemm_multi_d_fp16_ccrr && \
1688-
./bin/benchmark_gemm_multi_d_fp16_ccrr && \
1689-
ninja -j64 benchmark_gemm_multi_d_fp16_crrr && \
1690-
./bin/benchmark_gemm_multi_d_fp16_crrr && \
1691-
ninja -j64 benchmark_gemm_multi_d_fp16_rcrr && \
1692-
./bin/benchmark_gemm_multi_d_fp16_rcrr """
1680+
ninja -j64 benchmark_gemm_multi_d_all && \
1681+
python3 ../tile_engine/ops/gemm_multi_d/gemm_multi_d_benchmark.py . --problem-sizes "1024,1024,1024" \
1682+
--warmup 5 --repeat 5 --verbose --json results.json """
16931683
}
16941684
steps{
16951685
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)

tile_engine/ops/gemm/commons/validation_utils.py renamed to tile_engine/ops/commons/validation_utils.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -125,38 +125,13 @@
125125
[32, 32, 64],
126126
],
127127
},
128-
"gfx1201": {
128+
"gfx1201": { # Check how to handle for GEMM and Multi D
129129
"fp16_fp16_fp16": [
130130
[16, 16, 16],
131131
],
132132
},
133133
}
134134

135-
# Supported warp tile combinations for different GPU architectures and data types
136-
WARP_SUPPORTED_COMBINATIONS = {
137-
"gfx90a": [
138-
[1, 4, 1],
139-
[2, 2, 1],
140-
[4, 1, 1],
141-
],
142-
"gfx942": [
143-
[1, 4, 1],
144-
[2, 2, 1],
145-
[4, 1, 1],
146-
],
147-
"gfx950": [
148-
[1, 4, 1],
149-
[2, 2, 1],
150-
[4, 1, 1],
151-
],
152-
"gfx1201": [
153-
[2, 4, 1],
154-
[1, 8, 1],
155-
[8, 1, 1],
156-
[4, 2, 1],
157-
],
158-
}
159-
160135
# Unsupported trait combinations
161136
TRAIT_UNSUPPORTED_COMBINATIONS = {
162137
("compv3", "cshuffle", "interwave"),
@@ -441,6 +416,20 @@ def get_abc_layouts(layout_code: str) -> Tuple[str, str, str]:
441416
return a_layout, b_layout, c_layout
442417

443418

419+
def get_abcd_layouts(layout_code: str) -> Tuple[str, str, str, List[str]]:
420+
"""
421+
Return (ALayout, BLayout, CLayout, [D0Layout, D1Layout]) from a 4-letter code like 'rcrr', 'ccrr', 'crrr', 'rrrr'.
422+
"""
423+
code = str(layout_code).strip().lower()
424+
425+
a_layout = LAYOUT_MAP[code[0]]
426+
b_layout = LAYOUT_MAP[code[1]]
427+
c_layout = LAYOUT_MAP[code[2]]
428+
d0_layout = LAYOUT_MAP[code[3]]
429+
d1_layout = LAYOUT_MAP[code[3]]
430+
return a_layout, b_layout, c_layout, [d0_layout, d1_layout]
431+
432+
444433
def validate_whole_wg_cover_configuration(
445434
tile_m,
446435
tile_n,
@@ -464,13 +453,13 @@ def validate_whole_wg_cover_configuration(
464453

465454
# A matrix validation
466455
if layout[0] == "r":
467-
XPerTile = tile_k
468-
YPerTile = tile_m
469-
470456
vector_load_size = get_global_vector_load_size(
471457
BlockSize, tile_k, a_datatype, tile_m, tile_k
472458
)
473459

460+
XPerTile = tile_k
461+
YPerTile = tile_m
462+
474463
elif layout[0] == "c":
475464
vector_load_size = get_global_vector_load_size(
476465
BlockSize, tile_k, a_datatype, tile_m, tile_m
@@ -485,7 +474,6 @@ def validate_whole_wg_cover_configuration(
485474
)
486475

487476
if not wg_cover_core_valid:
488-
print("I am here 1")
489477
logging.debug(
490478
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
491479
)
@@ -521,7 +509,7 @@ def validate_whole_wg_cover_configuration(
521509
if not wg_cover_core_valid:
522510
print("I am here 3")
523511
logging.debug(
524-
f"whole workgroup cover failed for Matrix A distribution: {wg_cover_core_error}"
512+
f"whole workgroup cover failed for Matrix B distribution: {wg_cover_core_error}"
525513
)
526514
return False, wg_cover_core_error
527515

@@ -540,7 +528,6 @@ def validate_whole_wg_cover_configuration(
540528
XPerTile, YPerTile, BlockSize, vector_load_size, warp_size
541529
)
542530
if not wg_cover_core_valid:
543-
print("I am here 4")
544531
logging.debug(
545532
f"whole workgroup cover failed for Matrix B: {wg_cover_core_error}"
546533
)
@@ -557,7 +544,7 @@ def wg_cover_core_validation(
557544
warp_size: int,
558545
) -> Tuple[bool, str]:
559546
if XPerTile % vector_load_size != 0:
560-
return False
547+
return False, "XPerTile is not divisible by vector_load_size"
561548

562549
num_warps = BlockSize / warp_size
563550
LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size)
@@ -567,7 +554,7 @@ def wg_cover_core_validation(
567554
Y1 = warp_size // X0
568555

569556
if X0 * Y1 != warp_size:
570-
return False, ""
557+
return False, "X0 * Y1 != warp_size"
571558

572559
return True, ""
573560

@@ -583,9 +570,9 @@ def get_global_vector_load_size(
583570
PackedSize = 1
584571

585572
if (
586-
XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
573+
PackedSize == 2
574+
and XPerTile % (PackedSize * 32 / element_size(DataType)) == 0
587575
and elements_per_thread % (PackedSize * 32 / element_size(DataType)) == 0
588-
and PackedSize == 2
589576
):
590577
return PackedSize * 32 / element_size(DataType)
591578
elif (

tile_engine/ops/gemm/CMakeLists.txt

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,15 @@ function(build_individual_gemm_targets datatype layout)
122122
if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "")
123123
set(config_filename "$ENV{GEMM_CONFIG_FILE}")
124124
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${config_filename}")
125-
message(STATUS " Using config from environment variable: ${config_filename}")
125+
message(VERBOSE " Using config from environment variable: ${config_filename}")
126126
elseif(NOT "${GEMM_CONFIG_FILE}" STREQUAL "")
127127
# Use CMake variable if set
128128
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/${GEMM_CONFIG_FILE}")
129-
message(STATUS " Using custom config: ${GEMM_CONFIG_FILE}")
129+
message(VERBOSE " Using custom config: ${GEMM_CONFIG_FILE}")
130130
else()
131131
# Use default config for all layouts
132132
set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json")
133-
message(STATUS " Using default config for layout ${layout}")
133+
message(VERBOSE " Using default config for layout ${layout}")
134134
endif()
135135

136136
# Check if config file exists
@@ -151,16 +151,16 @@ function(build_individual_gemm_targets datatype layout)
151151
endif()
152152

153153
# Generate individual kernel files using parallel version
154-
message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
155-
message(STATUS " Working path: ${working_path}")
156-
message(STATUS " Config file: ${json_blob}")
157-
message(STATUS " Python executable: ${Python3_EXECUTABLE}")
158-
message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
154+
message(VERBOSE "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...")
155+
message(VERBOSE " Working path: ${working_path}")
156+
message(VERBOSE " Config file: ${json_blob}")
157+
message(VERBOSE " Python executable: ${Python3_EXECUTABLE}")
158+
message(VERBOSE " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py")
159159

160160
# Create working directory first
161161
file(MAKE_DIRECTORY ${working_path})
162162

163-
message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
163+
message(VERBOSE "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
164164
--working_path ${working_path}
165165
--datatype ${datatype}
166166
--layout ${layout}
@@ -169,7 +169,7 @@ function(build_individual_gemm_targets datatype layout)
169169
--list_kernels ")
170170

171171
# First, just list the kernels (fast operation)
172-
message(STATUS " Listing kernel configurations...")
172+
message(VERBOSE " Listing kernel configurations...")
173173
execute_process(
174174
COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py
175175
--working_path ${working_path}
@@ -192,7 +192,7 @@ function(build_individual_gemm_targets datatype layout)
192192
if(EXISTS ${working_path}/gemm_kernel_count.txt)
193193
file(READ ${working_path}/gemm_kernel_count.txt kernel_count)
194194
string(STRIP "${kernel_count}" kernel_count)
195-
message(STATUS " Found ${kernel_count} kernel configurations")
195+
message(VERBOSE " Found ${kernel_count} kernel configurations")
196196
else()
197197
message(FATAL_ERROR "Kernel count file not found")
198198
endif()
@@ -216,10 +216,10 @@ function(build_individual_gemm_targets datatype layout)
216216
endfunction()
217217

218218
# Main build logic - Only individual builds supported
219-
message(STATUS "=== Starting Tile Engine GEMM Configuration ===")
220-
message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}")
221-
message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}")
222-
message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
219+
message(VERBOSE "=== Starting Tile Engine GEMM Configuration ===")
220+
message(VERBOSE "GEMM_DATATYPE: ${GEMM_DATATYPE}")
221+
message(VERBOSE "GEMM_LAYOUT: ${GEMM_LAYOUT}")
222+
message(VERBOSE "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
223223

224224
# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201
225225
set(GEMM_GPU_TARGETS_INDIVIDUAL "")
@@ -228,15 +228,15 @@ set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201")
228228
foreach(target IN LISTS SUPPORTED_GPU_TARGETS)
229229
if(target IN_LIST DESIRED_TARGETS)
230230
list(APPEND GEMM_GPU_TARGETS_INDIVIDUAL ${target})
231-
message(STATUS " Adding GPU target: ${target}")
231+
message(VERBOSE " Adding GPU target: ${target}")
232232
endif()
233233
endforeach()
234234

235235
# Skip build if no matching targets found
236236
if(NOT GEMM_GPU_TARGETS_INDIVIDUAL)
237237
message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
238238
else()
239-
message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
239+
message(VERBOSE "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}")
240240

241241
# Enable parallel compilation optimizations
242242
# Set up job pools for better parallel compilation control
@@ -251,12 +251,12 @@ else()
251251
find_program(CCACHE_PROGRAM ccache)
252252
if(CCACHE_PROGRAM)
253253
set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM})
254-
message(STATUS "Using ccache for faster compilation")
254+
message(VERBOSE "Using ccache for faster compilation")
255255
else()
256256
message(WARNING "ccache requested but not found")
257257
endif()
258258
else()
259-
message(STATUS "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
259+
message(VERBOSE "ccache disabled for GEMM ops (use -DENABLE_CCACHE_GEMM=ON to enable)")
260260
endif()
261261

262262
# Create master collection targets

tile_engine/ops/gemm/gemm_instance_builder.py

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,30 @@
88
import concurrent.futures
99
from pathlib import Path
1010
import logging
11-
from commons.validation_utils import (
12-
is_tile_config_valid,
13-
is_trait_combination_valid,
14-
get_dtype_string,
15-
get_abc_layouts,
16-
)
11+
import importlib.util
12+
13+
14+
def _import_validation_utils():
15+
"""Import validation utilities from commons directory."""
16+
current_dir = os.path.dirname(os.path.abspath(__file__))
17+
parent_dir = os.path.dirname(current_dir)
18+
19+
# Load the module dynamically
20+
spec = importlib.util.spec_from_file_location(
21+
"validation_utils", os.path.join(parent_dir, "commons", "validation_utils.py")
22+
)
23+
validation_utils = importlib.util.module_from_spec(spec)
24+
spec.loader.exec_module(validation_utils)
25+
26+
return validation_utils
27+
28+
29+
# Import validation functions
30+
_validation_utils = _import_validation_utils()
31+
is_tile_config_valid = _validation_utils.is_tile_config_valid
32+
is_trait_combination_valid = _validation_utils.is_trait_combination_valid
33+
get_dtype_string = _validation_utils.get_dtype_string
34+
get_abc_layouts = _validation_utils.get_abc_layouts
1735

1836
logging.basicConfig(level=logging.INFO)
1937

@@ -563,6 +581,8 @@ def generate_individual(self, num_workers=None):
563581
tile_configs = self._get_tile_configs()
564582
trait_combos = self._generate_trait_combinations()
565583
k_block_per_cu = self.config.get("k_block_per_cu")
584+
if k_block_per_cu is None:
585+
k_block_per_cu = 1
566586

567587
# Prepare work items for parallel processing
568588
work_items = []
@@ -574,11 +594,12 @@ def generate_individual(self, num_workers=None):
574594
trait_combo,
575595
k_block_per_cu,
576596
self.working_path,
597+
self.gpu_target,
577598
self.datatype,
578599
self.layout,
600+
self.config_json,
579601
)
580602
)
581-
582603
print(
583604
f"Generating {len(work_items)} individual kernel files using {num_workers} workers..."
584605
)
@@ -615,7 +636,6 @@ def generate_individual(self, num_workers=None):
615636
print(
616637
f" Progress: {completed}/{len(work_items)} kernels generated"
617638
)
618-
619639
try:
620640
result = future.result()
621641
if result:
@@ -662,10 +682,19 @@ def _generate_cmake_individual_targets(self, kernel_list):
662682

663683
def _generate_single_kernel_individual(work_item):
664684
"""Worker function to generate a single individual kernel file"""
665-
tile_config, trait_combo, k_block_per_cu, working_path, datatype, layout = work_item
685+
(
686+
tile_config,
687+
trait_combo,
688+
k_block_per_cu,
689+
working_path,
690+
gpu_target,
691+
datatype,
692+
layout,
693+
config_json,
694+
) = work_item
666695

667696
# Create a temporary builder instance for this worker
668-
builder = GemmKernelBuilder(working_path, datatype, layout)
697+
builder = GemmKernelBuilder(working_path, gpu_target, datatype, layout, config_json)
669698

670699
try:
671700
kernel_name, instance_code = builder._generate_kernel_instance(
@@ -798,6 +827,8 @@ def main():
798827
)
799828

800829
k_block_per_cu = builder.config.get("k_block_per_cu")
830+
if k_block_per_cu is None:
831+
k_block_per_cu = 1
801832

802833
# Generate the kernel
803834
kernel_name, instance_code = builder._generate_kernel_instance(

0 commit comments

Comments
 (0)