Skip to content

Commit 3c7c907

Browse files
committed
[DAP] Support both f32 and f64 type for 'dap.fir' operation.
1 parent 1f9b40a commit 3c7c907

File tree

6 files changed

+183
-177
lines changed

6 files changed

+183
-177
lines changed

benchmarks/AudioProcessing/Operations/FIROp/CMakeLists.txt

+33-90
Original file line numberDiff line numberDiff line change
@@ -1,94 +1,38 @@
1-
#-------------------------------------------------------------------------------
2-
# Generate MLIRFIRScalar
3-
#-------------------------------------------------------------------------------
4-
5-
add_custom_command(
6-
OUTPUT mlir-fir.o
7-
COMMAND
8-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
9-
${BUDDY_SOURCE_DIR}/benchmarks/AudioProcessing/Operations/FIROp/MLIRFIR.mlir
10-
-convert-scf-to-cf
11-
-llvm-request-c-wrappers
12-
-convert-arith-to-llvm
13-
-finalize-memref-to-llvm
14-
-convert-func-to-llvm
15-
-reconcile-unrealized-casts |
16-
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
17-
${LLVM_MLIR_BINARY_DIR}/llc
18-
-mtriple=${BUDDY_OPT_TRIPLE}
19-
-mattr=${BUDDY_OPT_ATTR}
20-
-filetype=obj
21-
-o ${BUDDY_BINARY_DIR}/../benchmarks/AudioProcessing/Operations/FIROp/mlir-fir.o
22-
DEPENDS
23-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
24-
${LLVM_MLIR_BINARY_DIR}/mlir-translate
25-
${LLVM_MLIR_BINARY_DIR}/llc
26-
)
27-
28-
add_library(MLIRFIRScalar STATIC mlir-fir.o)
29-
set_target_properties(MLIRFIRScalar PROPERTIES LINKER_LANGUAGE CXX)
30-
31-
#-------------------------------------------------------------------------------
32-
# Generate MLIRFIRTiledVectorization
33-
#-------------------------------------------------------------------------------
34-
35-
add_custom_command(
36-
OUTPUT fir-tile-vectorization.o
37-
COMMAND
38-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
39-
${BUDDY_SOURCE_DIR}/benchmarks/AudioProcessing/Operations/FIROp/MLIRFIRTiledVectorization.mlir
40-
-convert-scf-to-cf
41-
-convert-vector-to-llvm
42-
-llvm-request-c-wrappers
43-
-convert-arith-to-llvm
44-
-finalize-memref-to-llvm
45-
-convert-func-to-llvm
46-
-reconcile-unrealized-casts |
47-
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
48-
${LLVM_MLIR_BINARY_DIR}/llc
49-
-mtriple=${BUDDY_OPT_TRIPLE}
50-
-mattr=${BUDDY_OPT_ATTR}
51-
-filetype=obj
52-
-o ${BUDDY_BINARY_DIR}/../benchmarks/AudioProcessing/Operations/FIROp/fir-tile-vectorization.o
53-
DEPENDS
54-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
55-
${LLVM_MLIR_BINARY_DIR}/mlir-translate
56-
${LLVM_MLIR_BINARY_DIR}/llc
57-
)
58-
59-
add_library(MLIRFIRTiledVectorization STATIC fir-tile-vectorization.o)
60-
set_target_properties(MLIRFIRTiledVectorization PROPERTIES LINKER_LANGUAGE CXX)
61-
621
#-------------------------------------------------------------------------------
632
# Generate MLIRFIRVectorization
643
#-------------------------------------------------------------------------------
654

66-
add_custom_command(
67-
OUTPUT fir-vectorization.o
68-
COMMAND
69-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
70-
${BUDDY_SOURCE_DIR}/benchmarks/AudioProcessing/Operations/FIROp/MLIRFIRVectorization.mlir
71-
-convert-scf-to-cf
72-
-convert-vector-to-llvm
73-
-llvm-request-c-wrappers
74-
-convert-arith-to-llvm
75-
-finalize-memref-to-llvm
76-
-convert-func-to-llvm
77-
-reconcile-unrealized-casts |
78-
${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
79-
${LLVM_MLIR_BINARY_DIR}/llc
80-
-mtriple=${BUDDY_OPT_TRIPLE}
81-
-mattr=${BUDDY_OPT_ATTR}
82-
-filetype=obj
83-
-o ${BUDDY_BINARY_DIR}/../benchmarks/AudioProcessing/Operations/FIROp/fir-vectorization.o
84-
DEPENDS
85-
${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
86-
${LLVM_MLIR_BINARY_DIR}/mlir-translate
87-
${LLVM_MLIR_BINARY_DIR}/llc
88-
)
89-
90-
add_library(MLIRFIRVectorization STATIC fir-vectorization.o)
91-
set_target_properties(MLIRFIRVectorization PROPERTIES LINKER_LANGUAGE CXX)
5+
# Build a static library holding the vectorized FIR kernel for one element type.
#
# The template MLIRFIRVectorization.mlir is instantiated by replacing every
# occurrence of TYPE_PLACEHOLDER with `type` (via sed); the result is lowered
# with buddy-opt, translated to LLVM IR with mlir-translate, and compiled to an
# object file with llc.
#
# Arguments:
#   type - text substituted for TYPE_PLACEHOLDER in the template (e.g. f32).
#
# Defines the static library target MLIRFIRVectorization${type}.
function(build_fir_vectorization type)
  set(fir_vectorization_source
    ${BUDDY_SOURCE_DIR}/benchmarks/AudioProcessing/Operations/FIROp/MLIRFIRVectorization.mlir)

  add_custom_command(
    OUTPUT fir-vectorization-${type}.o
    COMMAND
      cat ${fir_vectorization_source} |
      sed 's/TYPE_PLACEHOLDER/${type}/g' |
      ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
        -convert-scf-to-cf
        -convert-vector-to-llvm
        -llvm-request-c-wrappers
        -convert-arith-to-llvm
        -finalize-memref-to-llvm
        -convert-func-to-llvm
        -reconcile-unrealized-casts |
      ${LLVM_MLIR_BINARY_DIR}/mlir-translate -mlir-to-llvmir |
      ${LLVM_MLIR_BINARY_DIR}/llc
        -mtriple=${BUDDY_OPT_TRIPLE}
        -mattr=${BUDDY_OPT_ATTR}
        -filetype=obj
        # NOTE(review): OUTPUT above is a relative name while this is an
        # explicit path; presumably both resolve to this directory's binary
        # dir -- confirm.
        -o ${BUDDY_BINARY_DIR}/../benchmarks/AudioProcessing/Operations/FIROp/fir-vectorization-${type}.o
    DEPENDS
      # The template is an input of the command: list it so that editing the
      # .mlir file regenerates the object instead of leaving a stale build.
      ${fir_vectorization_source}
      ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt
      ${LLVM_MLIR_BINARY_DIR}/mlir-translate
      ${LLVM_MLIR_BINARY_DIR}/llc
  )

  add_library(MLIRFIRVectorization${type} STATIC fir-vectorization-${type}.o)
  # The library contains only a prebuilt object, so CMake cannot infer a linker
  # language from its sources; set it explicitly.
  set_target_properties(MLIRFIRVectorization${type} PROPERTIES LINKER_LANGUAGE CXX)
endfunction()
33+
34+
build_fir_vectorization(f32)
35+
build_fir_vectorization(f64)
9236

9337
#-------------------------------------------------------------------------------
9438
# Generate dap-op-fir-benchmark
@@ -105,9 +49,8 @@ target_link_libraries(dap-op-fir-benchmark PRIVATE
10549
# Third-party library
10650
kfr_io
10751
# MLIR hand-written benchmark
108-
MLIRFIRScalar
109-
MLIRFIRTiledVectorization
110-
MLIRFIRVectorization
52+
MLIRFIRVectorizationf32
53+
MLIRFIRVectorizationf64
11154
# Buddy DAP library
11255
BuddyLibDAP
11356
# LLVM/MLIR library

benchmarks/AudioProcessing/Operations/FIROp/MLIRFIR.mlir

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616
//
17-
// This file provides the MLIR Fir function.
17+
// This file implements the scalar version of the FIR function, following the
18+
// same algorithm as Buddy's scalar version DAP pass: `--lower-dap`.
1819
//
1920
//===----------------------------------------------------------------------===//
2021

benchmarks/AudioProcessing/Operations/FIROp/MLIRFIRTiledVectorization.mlir

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616
//
17-
// This file provides the vectorized MLIR FIR function with tiling.
17+
// This file implements the vectorized FIR function using a tiling technique,
18+
// following the same algorithm as Buddy's vectorize DAP pass:
19+
// `--vectorize-dap="fir-vec-size=16 fir-tile-size=2048"`
1820
//
1921
//===----------------------------------------------------------------------===//
2022

@@ -110,7 +112,7 @@ func.func @fir_tiled_vectorization(%input : memref<?xf32>, %kernel : memref<?xf3
110112
scf.for %i = %address to %upbound step %vl_step {
111113
%in_vec = vector.load %input[%i] : memref<?xf32>, vector<16xf32>
112114
%out_index = arith.addi %i, %n : index
113-
%out_vec = vector.load %output[%out_index] : memref<?xf32>, vector<16xf32> // 需要计算output的偏移量
115+
%out_vec = vector.load %output[%out_index] : memref<?xf32>, vector<16xf32>
114116
%fma_vec = vector.fma %k_vec, %in_vec, %out_vec : vector<16xf32>
115117
vector.store %fma_vec, %output[%out_index] : memref<?xf32>, vector<16xf32>
116118
}

benchmarks/AudioProcessing/Operations/FIROp/MLIRFIRVectorization.mlir

+17-16
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414
//
1515
//===----------------------------------------------------------------------===//
1616
//
17-
// This file provides the vectorized MLIR FIR function (without tiling).
17+
// This file implements the vectorized MLIR FIR function (without tiling),
18+
// with a fixed vector size of 16.
1819
//
1920
//===----------------------------------------------------------------------===//
2021

21-
func.func @fir_vectorization(%input : memref<?xf32>, %kernel : memref<?xf32>,
22-
%output : memref<?xf32>) -> () {
22+
func.func @fir_vector_TYPE_PLACEHOLDER(%input : memref<?xTYPE_PLACEHOLDER>,
23+
%kernel : memref<?xTYPE_PLACEHOLDER>, %output : memref<?xTYPE_PLACEHOLDER>) -> () {
2324
// 1. Get the total length of the workload.
2425
%c0 = arith.constant 0 : index
2526
%c1 = arith.constant 1 : index
26-
%input_size = memref.dim %input, %c0 : memref<?xf32>
27-
%kernel_size = memref.dim %kernel, %c0 : memref<?xf32>
27+
%input_size = memref.dim %input, %c0 : memref<?xTYPE_PLACEHOLDER>
28+
%kernel_size = memref.dim %kernel, %c0 : memref<?xTYPE_PLACEHOLDER>
2829

2930
// 2. Set the iteration step (vector size).
3031
%vl_step = arith.constant 16 : index
@@ -40,30 +41,30 @@ func.func @fir_vectorization(%input : memref<?xf32>, %kernel : memref<?xf32>,
4041
// 4. Loop through each kernel element
4142
scf.for %n = %c0 to %kernel_size step %c1
4243
iter_args(%upbound = %upbound_init) -> (index) {
43-
%k_elem = memref.load %kernel[%n] : memref<?xf32>
44-
%k_vec = vector.splat %k_elem : vector<16xf32>
44+
%k_elem = memref.load %kernel[%n] : memref<?xTYPE_PLACEHOLDER>
45+
%k_vec = vector.splat %k_elem : vector<16xTYPE_PLACEHOLDER>
4546

4647
// 5. Perform the vectorization body.
4748
%iter_idx = scf.for %i = %c0 to %upbound step %vl_step
4849
iter_args(%iter_init = %c0) -> (index) {
49-
%in_vec = vector.load %input[%i] : memref<?xf32>, vector<16xf32>
50+
%in_vec = vector.load %input[%i] : memref<?xTYPE_PLACEHOLDER>, vector<16xTYPE_PLACEHOLDER>
5051
%out_index = arith.addi %i, %n : index
51-
%out_vec = vector.load %output[%out_index] : memref<?xf32>, vector<16xf32>
52-
%fma_vec = vector.fma %k_vec, %in_vec, %out_vec : vector<16xf32>
53-
vector.store %fma_vec, %output[%out_index] : memref<?xf32>, vector<16xf32>
52+
%out_vec = vector.load %output[%out_index] : memref<?xTYPE_PLACEHOLDER>, vector<16xTYPE_PLACEHOLDER>
53+
%fma_vec = vector.fma %k_vec, %in_vec, %out_vec : vector<16xTYPE_PLACEHOLDER>
54+
vector.store %fma_vec, %output[%out_index] : memref<?xTYPE_PLACEHOLDER>, vector<16xTYPE_PLACEHOLDER>
5455
%i_next = arith.addi %i, %vl_step : index
5556
scf.yield %i_next : index
5657
}
5758

5859
// 6. Process the remainder of the elements with scalar operations.
5960
%upbound_scalar = arith.addi %upbound, %vl_step_minus_1 : index
6061
scf.for %i = %iter_idx to %upbound_scalar step %c1 {
61-
%in_elem = memref.load %input[%i] : memref<?xf32>
62+
%in_elem = memref.load %input[%i] : memref<?xTYPE_PLACEHOLDER>
6263
%out_index = arith.addi %i, %n : index
63-
%out_elem = memref.load %output[%out_index] : memref<?xf32>
64-
%mul_elem = arith.mulf %in_elem, %k_elem : f32
65-
%add_elem = arith.addf %mul_elem, %out_elem : f32
66-
memref.store %add_elem, %output[%out_index] : memref<?xf32>
64+
%out_elem = memref.load %output[%out_index] : memref<?xTYPE_PLACEHOLDER>
65+
%mul_elem = arith.mulf %in_elem, %k_elem : TYPE_PLACEHOLDER
66+
%add_elem = arith.addf %mul_elem, %out_elem : TYPE_PLACEHOLDER
67+
memref.store %add_elem, %output[%out_index] : memref<?xTYPE_PLACEHOLDER>
6768
}
6869

6970
%upbound_next = arith.subi %upbound, %c1 : index

0 commit comments

Comments
 (0)