[spirv] Vectorize Linalg ops with reduction dimensions (iree-org#8836)
This helps generate better code for Linalg ops that have both
parallel and reduction iterator types but do not use minor
identity maps. To make the flow work, SPIRVVectorizePass
is tweaked to pull in more patterns. Along the way, this also
enables tiling along `linalg.generic` reduction dimensions.
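As a concrete example, the newly added `@outermost_reduction` test below exercises exactly this pattern; a trimmed sketch of its `linalg.generic` (SSA names abbreviated) looks like:

#map_in  = affine_map<(d0, d1, d2) -> (d2, d0, d1)>  // projected permutation, not a minor identity
#map_out = affine_map<(d0, d1, d2) -> (d0, d1)>

%sum = linalg.generic {
  indexing_maps = [#map_in, #map_out],
  iterator_types = ["parallel", "parallel", "reduction"]  // d2 is the reduction dimension
} ins(%input : tensor<4x2048x512xf32>) outs(%fill : tensor<2048x512xf32>) {
^bb0(%a: f32, %acc: f32):
  %0 = arith.addf %a, %acc : f32
  linalg.yield %0 : f32
} -> tensor<2048x512xf32>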
antiagainst authored May 23, 2022
1 parent 28bc420 commit 74cc6bc
Showing 9 changed files with 522 additions and 45 deletions.
36 changes: 24 additions & 12 deletions compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -410,17 +410,14 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
SmallVector<int64_t, 4> loopBounds = linalgOp.getStaticLoopRanges();
bool vectorizable =
allowVectorization &&
// The vectorization pipeline assumes tensor semantics when tiling.
!linalgOp.hasBufferSemantics() && !linalgOp.hasIndexSemantics() &&
// Skip vectorization for non-minor identity inputs as it generates
// vector.transfer_read ops with permutation maps that we currently
// cannot lower.
// TODO: Remove this restriction once the lowering of the permutation
// map is supported in core.
llvm::all_of(linalgOp.getIndexingMaps(),
[](AffineMap &map) { return map.isMinorIdentity(); }) &&
// TODO: Lowering of integers other than i32 may require emulation.
// This is currently not supported for vector operation.
// The vectorization pipeline assumes tensor semantics for tiling.
linalgOp.hasTensorSemantics() && !linalgOp.hasIndexSemantics() &&
// Require all affine maps to be projected permutation so that we can
// generate vector transfer ops.
llvm::all_of(
linalgOp.getIndexingMaps(),
[](AffineMap map) { return map.isProjectedPermutation(); }) &&
// TODO: Fix non-32-bit element type vectorization and remove this.
llvm::all_of(linalgOp->getOperands(), has32BitElementType) &&
llvm::none_of(loopBounds, ShapedType::isDynamic);

@@ -496,7 +493,7 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
if (distributeToThreads(subgroupSize) != 1) {
// Otherwise, allow larger and larger loss factor.

// Threads for distribution Use 32 at least.
// Threads for distribution. Use 32 at least.
int64_t numThreads = std::max(subgroupSize, 32);
// We can tolerate (1 / lossFactor) of threads in the workgroup to be idle.
int64_t lossFactor = 32;
@@ -515,6 +512,21 @@ static LogicalResult setDefaultOpConfig(spirv::ResourceLimitsAttr limits,
tileSizes.push_back(workgroupTileSizes);
tileSizes.push_back(threadTileSizes);

if (vectorizable) {
// Try to tile all reductions by size 4 if possible. This gives us a chance
// to perform vector4 loads if an input has its innermost dimension being
// a reduction. It also avoids generating too many instructions when
// unrolling vector later.
SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
for (const auto &it : llvm::enumerate(linalgOp.getIteratorTypes())) {
if (isReductionIterator(it.value()) && loopBounds[it.index()] % 4 == 0)
reductionTileSizes[it.index()] = 4;
}
if (llvm::any_of(reductionTileSizes, [](int64_t s) { return s != 0; })) {
tileSizes.push_back(reductionTileSizes);
}
}

return setOpConfigAndEntryPointFnTranslation(funcOp, op, tileSizes, pipeline,
workgroupSize);
}
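Worked through on the `@outermost_reduction` test added below: the loop bounds are (d0, d1, d2) = (2048, 512, 4) with d2 the reduction dimension, and 4 % 4 == 0, so an extra reduction level with tile sizes [0, 0, 4] is appended after the workgroup and invocation levels, giving:

#iree_codegen.lowering_config<tile_sizes = [[1, 128], [1, 4], [0, 0, 4]]>
// Levels: workgroup tiling, invocation (thread) tiling, reduction tiling.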
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVTile.cpp
@@ -52,7 +52,7 @@ static void populateTilingReductionPatterns(RewritePatternSet &patterns) {
auto filter = linalg::LinalgTransformationFilter({marker}, llvm::None);

linalg::TilingPatterns<linalg::BatchMatmulOp, linalg::Conv2DNhwcHwcfOp,
linalg::DepthwiseConv2DNhwcHwcOp,
linalg::DepthwiseConv2DNhwcHwcOp, linalg::GenericOp,
linalg::MatmulOp>::insert(patterns, tilingOptions,
filter);
}
@@ -283,7 +283,8 @@ class SPIRVTilePass final : public SPIRVTileBase<SPIRVTilePass> {
auto marker = builder.getStringAttr(getTileReductionMarker());
funcOp.walk([&](linalg::LinalgOp op) {
if (isa<linalg::ContractionOpInterface>(*op) ||
isa<linalg::ConvolutionOpInterface>(*op)) {
isa<linalg::ConvolutionOpInterface>(*op) ||
isa<linalg::GenericOp>(*op)) {
op->setAttr(linalg::LinalgTransforms::kLinalgTransformMarker, marker);
}
});
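For intuition, a rough sketch of what tiling a reduction dimension by 4 produces for a simple sum-style `linalg.generic` (hypothetical SSA names and constants such as %c0/%c4/%c384; shapes simplified, so this is not the literal output of SPIRVTilePass):

#map_in  = affine_map<(d0, d1) -> (d0, d1)>
#map_out = affine_map<(d0, d1) -> (d0)>

%result = scf.for %iv = %c0 to %c384 step %c4 iter_args(%acc = %fill) -> (tensor<128xf32>) {
  // Each iteration reduces a 128x4 slice into the loop-carried accumulator.
  %slice = tensor.extract_slice %in[0, %iv] [128, 4] [1, 1]
      : tensor<128x384xf32> to tensor<128x4xf32>
  %partial = linalg.generic {
    indexing_maps = [#map_in, #map_out],
    iterator_types = ["parallel", "reduction"]
  } ins(%slice : tensor<128x4xf32>) outs(%acc : tensor<128xf32>) {
  ^bb0(%a: f32, %b: f32):
    %0 = arith.addf %a, %b : f32
    linalg.yield %0 : f32
  } -> tensor<128xf32>
  scf.yield %partial : tensor<128xf32>
}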
72 changes: 62 additions & 10 deletions compiler/src/iree/compiler/Codegen/SPIRV/SPIRVVectorize.cpp
@@ -87,6 +87,20 @@ Optional<SmallVector<int64_t, 4>> getNativeVectorShape(Operation *op) {
SmallVector<int64_t, 4> nativeSize(contractOp.getIteratorTypes().size(), 1);
nativeSize[lastParalleldim] = 4; // Map to vec4 fma operations.
return nativeSize;
} else if (auto reductionOp = dyn_cast<vector::MultiDimReductionOp>(op)) {
// Unroll all reduction dimensions by size 1 for vector.multi_reduction.
auto srcVectorType = reductionOp.getSourceVectorType();
auto nativeSize = llvm::to_vector<4>(srcVectorType.getShape());
auto dims = reductionOp.getReductionDims().getAsValueRange<IntegerAttr>();
for (const auto &dimAttr : dims) {
nativeSize[dimAttr.getZExtValue()] = 1;
}
return nativeSize;
} else if (auto transposeOp = dyn_cast<vector::TransposeOp>(op)) {
auto vectorType = transposeOp.getResultType();
SmallVector<int64_t, 4> nativeSize(vectorType.getRank(), 1);
nativeSize.back() = getComputeVectorSize(vectorType.getShape().back());
return nativeSize;
}
return llvm::None;
}
@@ -100,6 +114,10 @@ void populateVectorizationPatterns(RewritePatternSet &patterns) {
patterns.add<linalg::LinalgVectorizationPattern>(
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>(),
opt);
// Additionally pull in patterns to canonicalize transfer ops and to shuffle
// broadcast/transpose ops around in order to cancel them or embed into
// contract ops. Embedding in the flexible contract ops will help to sustain
// the structure through various transformations.
vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
vector::populateVectorReductionToContractPatterns(patterns);
}
Expand Down Expand Up @@ -163,6 +181,34 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
llvm::dbgs() << "\n\n";
});

// Lower vector.multi_reduction early if any operand is a transpose op.
// The lowering itself generates transpose ops. This helps to cancel
// transpose ops. vector.multi_reduction is arguably a higher level op and
// the lowering also unrolls the multi_reduction op, so it makes sense for it
// to happen before normal unrolling.
{
SmallVector<Operation *> reductionOps;
funcOp.walk([&](vector::MultiDimReductionOp reductionOp) {
if (llvm::any_of(reductionOp->getOperands(), [](Value operand) {
return operand.getDefiningOp<vector::TransposeOp>();
}))
reductionOps.push_back(reductionOp);
return WalkResult::advance();
});
RewritePatternSet patterns(context);
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vector::VectorMultiReductionLowering::InnerParallel);
FrozenRewritePatternSet frozenSet(std::move(patterns));
applyOpPatternsAndFold(reductionOps, frozenSet,
/*strict=*/false);
}

LLVM_DEBUG({
llvm::dbgs() << "--- After lowering multi_reduction ops ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});

// Then unroll vectors to native vector size. We try to use 128-bit
// vectors for memory access and 4/2/1 vector sizes for computation.
{
@@ -182,7 +228,7 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
// Next run canonicalization to cast away leading size-1 dimensions. They
// can be generated from vector unrolling and generally cause issues to
// cancel corresponding read/write or insert/extract op pairs. This also
// need to happen befor hositing, where we would make certain vectors loop
// needs to happen before hoisting, where we would make certain vectors loop
// carried. Once that's done, it's hard to handle the leading size-1
// dimensions across regions.
{
@@ -251,33 +297,39 @@ class SPIRVVectorizePass : public SPIRVVectorizeBase<SPIRVVectorizePass> {
llvm::dbgs() << "\n\n";
});

// Lower vector broadcast and contraction.
// Lower vector broadcast/transpose and contraction.
{
RewritePatternSet patterns(context);
auto options = vector::VectorTransformsOptions()
.setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct)
.setVectorTransposeLowering(
vector::VectorTransposeLowering::EltWise);
vector::populateVectorBroadcastLoweringPatterns(patterns);
vector::populateVectorContractLoweringPatterns(
patterns,
vector::VectorTransformsOptions().setVectorTransformsOptions(
vector::VectorContractLowering::OuterProduct));
vector::populateVectorContractLoweringPatterns(patterns, options);
vector::populateVectorMultiReductionLoweringPatterns(
patterns, vector::VectorMultiReductionLowering::InnerParallel);
vector::populateVectorTransposeLoweringPatterns(patterns, options);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
}

LLVM_DEBUG({
llvm::dbgs() << "--- After lowering contract ops ---\n";
llvm::dbgs() << "--- After lowering various vector ops ---\n";
funcOp.print(llvm::dbgs(), OpPrintingFlags().useLocalScope());
llvm::dbgs() << "\n\n";
});

// Cast away leading size-1 dimensions again.
// Run all sorts of canonicalization patterns to clean up again.
{
RewritePatternSet patterns(context);
// We need to pull in patterns for casting away leading one dims to allow
// cancelling some read/write ops.
vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
vector::InsertOp::getCanonicalizationPatterns(patterns, context);
vector::ExtractOp::getCanonicalizationPatterns(patterns, context);
vector::TransferReadOp::getCanonicalizationPatterns(patterns, context);
vector::TransferWriteOp::getCanonicalizationPatterns(patterns, context);
vector::ReductionOp::getCanonicalizationPatterns(patterns, context);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
return signalPassFailure();
}
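To make the new unroll shapes concrete, here is roughly what the `getNativeVectorShape` additions above compute for the two newly handled ops (%src is a placeholder value; op syntax is sketched and may differ slightly across MLIR versions):

// vector.multi_reduction: reduction dimensions are unrolled to size 1, so for a
// vector<4x4xf32> source reducing dim 0 the native shape is [1, 4].
%r = vector.multi_reduction <add>, %src [0] : vector<4x4xf32> to vector<4xf32>

// vector.transpose: only the innermost result dimension keeps a compute vector
// size (4/2/1); leading dimensions unroll to 1, so the native shape here is [1, 4].
%t = vector.transpose %src, [1, 0] : vector<4x4xf32> to vector<4x4xf32>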
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD
@@ -42,10 +42,12 @@ iree_lit_test_suite(
"tile_and_vectorize_conv.mlir",
"tile_and_vectorize_matmul.mlir",
"tile_and_vectorize_to_cooperative_ops.mlir",
"tile_linalg_ops.mlir",
"vector_to_cooperative_matrix.mlir",
"vectorize_elementwise_ops.mlir",
"vectorize_matmul.mlir",
"vectorize_load_store.mlir",
"vectorize_reduction.mlir",
"vectorize_tensor_pad.mlir",
],
include = ["*.mlir"],
2 changes: 2 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -37,10 +37,12 @@ iree_lit_test_suite(
"tile_and_vectorize_conv.mlir"
"tile_and_vectorize_matmul.mlir"
"tile_and_vectorize_to_cooperative_ops.mlir"
"tile_linalg_ops.mlir"
"vector_to_cooperative_matrix.mlir"
"vectorize_elementwise_ops.mlir"
"vectorize_load_store.mlir"
"vectorize_matmul.mlir"
"vectorize_reduction.mlir"
"vectorize_tensor_pad.mlir"
TOOLS
FileCheck
@@ -360,3 +360,115 @@ hal.executable @dwconv_elementwise {
// CHECK-SAME: translation_info = #[[TRANSLATION]]
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[CONFIG]]


// -----

#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
#map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d1)>

hal.executable @outermost_reduction {
hal.executable.variant @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
max_compute_shared_memory_size = 32768 : i32,
max_compute_workgroup_invocations = 512 : i32,
max_compute_workgroup_size = dense<512> : vector<3xi32>,
subgroup_size = 32 : i32}>
}> {
hal.executable.entry_point @outermost_reduction layout(#executable_layout)
builtin.module {
func.func @outermost_reduction() {
%cst = arith.constant 0.000000e+00 : f32
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:4x2048x512xf32>
%1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:2048x512xf32>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0], sizes = [4, 2048, 512], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:4x2048x512xf32> -> tensor<4x2048x512xf32>
%3 = linalg.init_tensor [2048, 512] : tensor<2048x512xf32>
%4 = linalg.fill ins(%cst : f32) outs(%3 : tensor<2048x512xf32>) -> tensor<2048x512xf32>
%5 = linalg.generic {
indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%2 : tensor<4x2048x512xf32>) outs(%4 : tensor<2048x512xf32>) {
^bb0(%arg0: f32, %arg1: f32):
%6 = arith.addf %arg0, %arg1 : f32
linalg.yield %6 : f32
} -> tensor<2048x512xf32>
flow.dispatch.tensor.store %5, %1, offsets = [0, 0], sizes = [2048, 512], strides = [1, 1] : tensor<2048x512xf32> -> !flow.dispatch.tensor<writeonly:2048x512xf32>
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[1, 128], [1, 4], [0, 0, 4]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVVectorize>
// CHECK-LABEL: hal.executable.entry_point public @outermost_reduction
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]

// -----

#executable_layout = #hal.executable.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d0)>

hal.executable private @innermost_reduction {
hal.executable.variant public @vulkan_spirv_fb, target = <"vulkan-spirv", "vulkan-spirv-fb", {
spv.target_env = #spv.target_env<#spv.vce<v1.4, [Shader], []>, Unknown:IntegratedGPU, {
max_compute_shared_memory_size = 32768 : i32,
max_compute_workgroup_invocations = 512 : i32,
max_compute_workgroup_size = dense<512> : vector<3xi32>,
subgroup_size = 32 : i32}>
}> {
hal.executable.entry_point public @innermost_reduction ordinal(0) layout(#executable_layout)
builtin.module {
func.func @innermost_reduction() {
%cst = arith.constant -0.000000e+00 : f32
%0 = hal.interface.constant.load[0] : i32
%1 = hal.interface.constant.load[1] : i32
%2 = hal.interface.constant.load[2] : i32
%3 = arith.index_cast %0 {stream.alignment = 512 : index, stream.values = [0 : index, 394752 : index, 984064 : index]} : i32 to index
%4 = arith.index_cast %1 {stream.alignment = 512 : index, stream.values = [0 : index, 196608 : index, 197120 : index]} : i32 to index
%5 = arith.index_cast %2 {stream.alignment = 512 : index, stream.values = [512 : index, 197120 : index, 197632 : index]} : i32 to index
%6 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%3) alignment(64) : !flow.dispatch.tensor<readonly:128x384xf32>
%7 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%4) alignment(64) : !flow.dispatch.tensor<readonly:128xf32>
%8 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%5) alignment(64) : !flow.dispatch.tensor<writeonly:128xf32>
%9 = flow.dispatch.tensor.load %6, offsets = [0, 0], sizes = [128, 384], strides = [1, 1] : !flow.dispatch.tensor<readonly:128x384xf32> -> tensor<128x384xf32>
%10 = flow.dispatch.tensor.load %7, offsets = [0], sizes = [128], strides = [1] : !flow.dispatch.tensor<readonly:128xf32> -> tensor<128xf32>
%11 = linalg.init_tensor [128] : tensor<128xf32>
%12 = linalg.fill ins(%cst : f32) outs(%11 : tensor<128xf32>) -> tensor<128xf32>
%13 = linalg.generic {
indexing_maps = [#map0, #map1, #map1],
iterator_types = ["parallel", "reduction"]
} ins(%9, %10 : tensor<128x384xf32>, tensor<128xf32>) outs(%12 : tensor<128xf32>) {
^bb0(%arg0: f32, %arg1: f32, %arg2: f32):
%14 = arith.subf %arg0, %arg1 : f32
%15 = arith.mulf %14, %14 : f32
%16 = arith.addf %15, %arg2 : f32
linalg.yield %16 : f32
} -> tensor<128xf32>
flow.dispatch.tensor.store %13, %8, offsets = [0], sizes = [128], strides = [1] : tensor<128xf32> -> !flow.dispatch.tensor<writeonly:128xf32>
return
}
}
}
}

// CHECK-DAG: #[[$CONFIG:.+]] = #iree_codegen.lowering_config<tile_sizes = {{\[}}[128], [4], [0, 4]{{\]}}>
// CHECK-DAG: #[[$TRANSLATION:.+]] = #iree_codegen.translation_info<SPIRVVectorize>
// CHECK-LABEL: hal.executable.entry_point public @innermost_reduction
// CHECK-SAME: translation_info = #[[$TRANSLATION]]
// CHECK: linalg.generic
// CHECK-SAME: lowering_config = #[[$CONFIG]]