Skip to content

Commit f581ef5

Browse files
authored
[mlir][gpu] Add gpu.rotate operation (#142796)
Add gpu.rotate operation and a pattern to convert gpu.rotate to SPIR-V OpGroupNonUniformRotateKHR.
1 parent a97826a commit f581ef5

File tree

6 files changed

+323
-3
lines changed

6 files changed

+323
-3
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,8 +1304,8 @@ def GPU_ShuffleOp : GPU_Op<
13041304
Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
13051305
let summary = "Shuffles values within a subgroup.";
13061306
let description = [{
1307-
The "shuffle" op moves values to a across lanes (a.k.a., invocations,
1308-
work items) within the same subgroup. The `width` argument specifies the
1307+
The "shuffle" op moves values across lanes (a.k.a., local invocations)
1308+
within the same subgroup. The `width` argument specifies the
13091309
number of lanes that participate in the shuffle, and must be uniform
13101310
across all lanes. Further, the first `width` lanes of the subgroup must
13111311
be active.
@@ -1366,6 +1366,54 @@ def GPU_ShuffleOp : GPU_Op<
13661366
];
13671367
}
13681368

1369+
def GPU_RotateOp : GPU_Op<
1370+
"rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
1371+
Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
1372+
Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
1373+
let summary = "Rotate values within a subgroup.";
1374+
let description = [{
1375+
The "rotate" op moves values across lanes (a.k.a., local invocations)
1376+
within the same subgroup. The `width` argument specifies the
1377+
number of lanes that participate in the rotation, and must be uniform across
1378+
all participating lanes. Further, the first `width` lanes of the subgroup
1379+
must be active.
1380+
1381+
`width` must be a power of two, and `offset` must be in the range
1382+
`[0, width)`.
1383+
1384+
`rotateResult` is the `value` held by the invocation whose id within the
1385+
subgroup is calculated as follows:
1386+
1387+
```mlir
1388+
Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
1389+
```
1390+
1391+
Returns the `rotateResult` and `true` if the current lane id is smaller than
1392+
`width`, and a poison value and `false` otherwise.
1393+
1394+
example:
1395+
1396+
```mlir
1397+
%offset = arith.constant 1 : i32
1398+
%width = arith.constant 16 : i32
1399+
%1, %2 = gpu.rotate %0, %offset, %width : f32
1400+
```
1401+
1402+
For lane `k`, returns the value from lane `(k + offset) % width`, i.e. `(k + 1) % 16` in this example.
1403+
}];
1404+
1405+
let assemblyFormat = [{
1406+
$value `,` $offset `,` $width attr-dict `:` type($value)
1407+
}];
1408+
1409+
let builders = [
1410+
// Helper function that creates a rotate with constant offset/width.
1411+
OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
1412+
];
1413+
1414+
let hasVerifier = 1;
1415+
}
1416+
13691417
def GPU_BarrierOp : GPU_Op<"barrier"> {
13701418
let summary = "Synchronizes all work items of a workgroup.";
13711419
let description = [{

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
122122
ConversionPatternRewriter &rewriter) const override;
123123
};
124124

125+
/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
126+
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
127+
public:
128+
using OpConversionPattern::OpConversionPattern;
129+
130+
LogicalResult
131+
matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
132+
ConversionPatternRewriter &rewriter) const override;
133+
};
134+
125135
class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
126136
public:
127137
using OpConversionPattern::OpConversionPattern;
@@ -488,6 +498,41 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
488498
return success();
489499
}
490500

501+
//===----------------------------------------------------------------------===//
502+
// Rotate
503+
//===----------------------------------------------------------------------===//
504+
505+
LogicalResult GPURotateConversion::matchAndRewrite(
506+
gpu::RotateOp rotateOp, OpAdaptor adaptor,
507+
ConversionPatternRewriter &rewriter) const {
508+
const spirv::TargetEnv &targetEnv =
509+
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
510+
unsigned subgroupSize =
511+
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
512+
IntegerAttr widthAttr;
513+
if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
514+
widthAttr.getValue().getZExtValue() > subgroupSize)
515+
return rewriter.notifyMatchFailure(
516+
rotateOp,
517+
"rotate width is not a constant or larger than target subgroup size");
518+
519+
Location loc = rotateOp.getLoc();
520+
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
521+
Value rotateResult = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
522+
loc, scope, adaptor.getValue(), adaptor.getOffset(), adaptor.getWidth());
523+
Value validVal;
524+
if (widthAttr.getValue().getZExtValue() == subgroupSize) {
525+
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
526+
} else {
527+
Value laneId = rewriter.create<gpu::LaneIdOp>(loc, widthAttr);
528+
validVal = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
529+
laneId, adaptor.getWidth());
530+
}
531+
532+
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
533+
return success();
534+
}
535+
491536
//===----------------------------------------------------------------------===//
492537
// Group ops
493538
//===----------------------------------------------------------------------===//
@@ -776,7 +821,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
776821
RewritePatternSet &patterns) {
777822
patterns.add<
778823
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
779-
GPUReturnOpConversion, GPUShuffleConversion,
824+
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
780825
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
781826
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
782827
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,49 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
13311331
mode);
13321332
}
13331333

1334+
//===----------------------------------------------------------------------===//
1335+
// RotateOp
1336+
//===----------------------------------------------------------------------===//
1337+
1338+
void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
1339+
int32_t offset, int32_t width) {
1340+
build(builder, result, value,
1341+
builder.create<arith::ConstantOp>(result.location,
1342+
builder.getI32IntegerAttr(offset)),
1343+
builder.create<arith::ConstantOp>(result.location,
1344+
builder.getI32IntegerAttr(width)));
1345+
}
1346+
1347+
LogicalResult RotateOp::verify() {
1348+
auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
1349+
if (!offsetConstOp)
1350+
return emitOpError() << "offset is not a constant value";
1351+
1352+
auto offsetIntAttr =
1353+
llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
1354+
1355+
auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
1356+
if (!widthConstOp)
1357+
return emitOpError() << "width is not a constant value";
1358+
1359+
auto widthIntAttr =
1360+
llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
1361+
1362+
llvm::APInt offsetValue = offsetIntAttr.getValue();
1363+
llvm::APInt widthValue = widthIntAttr.getValue();
1364+
1365+
if (!widthValue.isPowerOf2())
1366+
return emitOpError() << "width must be a power of two";
1367+
1368+
if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
1369+
int64_t widthValueInt = widthValue.getSExtValue();
1370+
return emitOpError() << "offset must be in the range [0, " << widthValueInt
1371+
<< ")";
1372+
}
1373+
1374+
return success();
1375+
}
1376+
13341377
//===----------------------------------------------------------------------===//
13351378
// BarrierOp
13361379
//===----------------------------------------------------------------------===//
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
2+
3+
module attributes {
4+
gpu.container_module,
5+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
6+
#spirv.resource_limits<subgroup_size = 16>>
7+
} {
8+
9+
gpu.module @kernels {
10+
// CHECK-LABEL: spirv.func @rotate()
11+
gpu.func @rotate() kernel
12+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
13+
%offset = arith.constant 4 : i32
14+
%width = arith.constant 16 : i32
15+
%val = arith.constant 42.0 : f32
16+
17+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
18+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
19+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
20+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
21+
// CHECK: %{{.+}} = spirv.Constant true
22+
%result, %valid = gpu.rotate %val, %offset, %width : f32
23+
gpu.return
24+
}
25+
}
26+
27+
}
28+
29+
// -----
30+
31+
module attributes {
32+
gpu.container_module,
33+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
34+
#spirv.resource_limits<subgroup_size = 16>>
35+
} {
36+
37+
gpu.module @kernels {
38+
// CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
39+
gpu.func @rotate_width_less_than_subgroup_size() kernel
40+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
41+
%offset = arith.constant 4 : i32
42+
%width = arith.constant 8 : i32
43+
%val = arith.constant 42.0 : f32
44+
45+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
46+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
47+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
48+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
49+
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
50+
// CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
51+
// CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
52+
%result, %valid = gpu.rotate %val, %offset, %width : f32
53+
gpu.return
54+
}
55+
}
56+
57+
}
58+
59+
// -----
60+
61+
module attributes {
62+
gpu.container_module,
63+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
64+
#spirv.resource_limits<subgroup_size = 16>>
65+
} {
66+
67+
gpu.module @kernels {
68+
gpu.func @rotate_with_bigger_than_subgroup_size() kernel
69+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
70+
%offset = arith.constant 4 : i32
71+
%width = arith.constant 32 : i32
72+
%val = arith.constant 42.0 : f32
73+
74+
// expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
75+
%result, %valid = gpu.rotate %val, %offset, %width : f32
76+
gpu.return
77+
}
78+
}
79+
80+
}
81+
82+
// -----
83+
84+
module attributes {
85+
gpu.container_module,
86+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
87+
#spirv.resource_limits<subgroup_size = 16>>
88+
} {
89+
90+
gpu.module @kernels {
91+
gpu.func @rotate_non_const_width(%width: i32) kernel
92+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
93+
%offset = arith.constant 4 : i32
94+
%val = arith.constant 42.0 : f32
95+
96+
// expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
97+
%result, %valid = gpu.rotate %val, %offset, %width : f32
98+
gpu.return
99+
}
100+
}
101+
102+
}

mlir/test/Dialect/GPU/invalid.mlir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a
478478

479479
// -----
480480

481+
func.func @rotate_mismatching_type(%arg0 : f32) {
482+
%offset = arith.constant 4 : i32
483+
%width = arith.constant 16 : i32
484+
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
485+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
486+
return
487+
}
488+
489+
// -----
490+
491+
func.func @rotate_unsupported_type(%arg0 : index) {
492+
%offset = arith.constant 4 : i32
493+
%width = arith.constant 16 : i32
494+
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
495+
%rotate, %valid = gpu.rotate %arg0, %offset, %width : index
496+
return
497+
}
498+
499+
// -----
500+
501+
func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
502+
%offset = arith.constant 4 : i32
503+
%width = arith.constant 16 : i32
504+
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
505+
%rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
506+
return
507+
}
508+
509+
// -----
510+
511+
func.func @rotate_unsupported_width(%arg0 : f32) {
512+
%offset = arith.constant 4 : i32
513+
%width = arith.constant 15 : i32
514+
// expected-error@+1 {{op width must be a power of two}}
515+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
516+
return
517+
}
518+
519+
// -----
520+
521+
func.func @rotate_unsupported_offset(%arg0 : f32) {
522+
%offset = arith.constant 16 : i32
523+
%width = arith.constant 16 : i32
524+
// expected-error@+1 {{op offset must be in the range [0, 16)}}
525+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
526+
return
527+
}
528+
529+
// -----
530+
531+
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
532+
%offset = arith.constant -1 : i32
533+
%width = arith.constant 16 : i32
534+
// expected-error@+1 {{op offset must be in the range [0, 16)}}
535+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
536+
return
537+
}
538+
539+
// -----
540+
541+
func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
542+
%width = arith.constant 16 : i32
543+
// expected-error@+1 {{op offset is not a constant value}}
544+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
545+
return
546+
}
547+
548+
// -----
549+
550+
func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
551+
%offset = arith.constant 0 : i32
552+
// expected-error@+1 {{op width is not a constant value}}
553+
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
554+
return
555+
}
556+
557+
// -----
558+
481559
module {
482560
gpu.module @gpu_funcs {
483561
// expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
140140
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
141141
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32
142142

143+
// CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
144+
%rotate_width = arith.constant 16 : i32
145+
%rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32
146+
143147
"gpu.barrier"() : () -> ()
144148

145149
"some_op"(%bIdX, %tIdX) : (index, index) -> ()

0 commit comments

Comments
 (0)