Skip to content

[mlir][gpu] Add gpu.rotate operation #142796

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 1, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1304,8 +1304,8 @@ def GPU_ShuffleOp : GPU_Op<
Results<(outs AnyIntegerOrFloatOr1DVector:$shuffleResult, I1:$valid)> {
let summary = "Shuffles values within a subgroup.";
let description = [{
The "shuffle" op moves values to a across lanes (a.k.a., invocations,
work items) within the same subgroup. The `width` argument specifies the
The "shuffle" op moves values across lanes in a subgroup (a.k.a., local
invocation) within the same subgroup. The `width` argument specifies the
number of lanes that participate in the shuffle, and must be uniform
across all lanes. Further, the first `width` lanes of the subgroup must
be active.
Expand Down Expand Up @@ -1364,6 +1364,54 @@ def GPU_ShuffleOp : GPU_Op<
];
}

// Subgroup rotation op. `offset` and `width` are i32 SSA operands; the op
// verifier additionally requires both to be defined by `arith.constant`.
def GPU_RotateOp : GPU_Op<
    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult, I1:$valid)> {
  let summary = "Rotate values within a subgroup.";
  let description = [{
    The "rotate" op moves values across lanes (a.k.a., invocations, work
    items) within the same subgroup. The `width` argument specifies the
    number of lanes that participate in the rotation, and must be uniform across
    all participating lanes. Further, the first `width` lanes of the subgroup
    must be active.

    `width` must be a power of two, and `offset` must be in the range
    `[0, width)`. Both `offset` and `width` must be defined by `arith.constant`
    operations; this is enforced by the op verifier.

    `rotateResult` is the `value` held by the invocation whose id within the
    subgroup is calculated as follows:

    ```
    Invocation ID = ((LaneId + offset) & (width - 1)) + (LaneId & ~(width - 1))
    ```

    Returns the `rotateResult` and `true` if the current lane id is smaller than
    `width`, and poison value and `false` otherwise.

    Example:

    ```mlir
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    %1, %2 = gpu.rotate %0, %offset, %width : f32
    ```

    For lane `k` (with `k < width`), returns the value from lane
    `(k + offset) % width`, i.e. lane `(k + 1) % 16` in this example.
  }];

  let assemblyFormat = [{
    $value `,` $offset `,` $width attr-dict `:` type($value)
  }];

  let builders = [
    // Helper function that creates a rotate with constant offset/width. The
    // two integers are materialized as i32 `arith.constant` operands.
    OpBuilder<(ins "Value":$value, "int32_t":$offset, "int32_t":$width)>
  ];

  // Checks that offset/width are constants, that width is a power of two and
  // that offset lies in [0, width).
  let hasVerifier = 1;
}

def GPU_BarrierOp : GPU_Op<"barrier"> {
let summary = "Synchronizes all work items of a workgroup.";
let description = [{
Expand Down
47 changes: 46 additions & 1 deletion mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};

/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHROp.
///
/// The rewrite only applies when the rotate `width` is a constant that does not
/// exceed the subgroup size advertised by the SPIR-V target environment; in
/// every other case the pattern reports a match failure. Besides the rotated
/// value, the rewrite also produces the replacement for the op's `valid`
/// result.
class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
public:
// Inherit the base-class constructors.
using OpConversionPattern::OpConversionPattern;

// Performs the gpu.rotate -> SPIR-V rewrite described above.
LogicalResult
matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override;
};

class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
public:
using OpConversionPattern::OpConversionPattern;
Expand Down Expand Up @@ -458,6 +468,41 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
return success();
}

//===----------------------------------------------------------------------===//
// Rotate
//===----------------------------------------------------------------------===//

LogicalResult GPURotateConversion::matchAndRewrite(
    gpu::RotateOp rotateOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  // The subgroup size of the target environment bounds the accepted width.
  const auto *spirvConverter = getTypeConverter<SPIRVTypeConverter>();
  const spirv::TargetEnv &env = spirvConverter->getTargetEnv();
  unsigned maxWidth = env.getAttr().getResourceLimits().getSubgroupSize();

  // Give up unless `width` is a constant that fits into one subgroup.
  IntegerAttr constWidth;
  bool widthIsConstant =
      matchPattern(rotateOp.getWidth(), m_Constant(&constWidth));
  if (!widthIsConstant || constWidth.getValue().getZExtValue() > maxWidth)
    return rewriter.notifyMatchFailure(
        rotateOp,
        "rotate width is not a constant or larger than target subgroup size");
  uint64_t width = constWidth.getValue().getZExtValue();

  // Emit the subgroup-scoped rotate itself.
  Location loc = rotateOp.getLoc();
  auto subgroupScope =
      rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
  Value rotated = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
      loc, subgroupScope, adaptor.getValue(), adaptor.getOffset(),
      adaptor.getWidth());

  // Emit the replacement for the `valid` result.
  Value isValid;
  if (width != maxWidth) {
    // Narrower than the subgroup: valid iff the lane id is below `width`.
    // The constant width attribute is forwarded to the LaneIdOp builder.
    Value laneId = rewriter.create<gpu::LaneIdOp>(loc, constWidth);
    isValid = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult,
                                             laneId, adaptor.getWidth());
  } else {
    // Width equals the subgroup size: the flag is a constant `true`.
    isValid = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
  }

  rewriter.replaceOp(rotateOp, {rotated, isValid});
  return success();
}

//===----------------------------------------------------------------------===//
// Group ops
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -733,7 +778,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
patterns.add<
GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
GPUReturnOpConversion, GPUShuffleConversion,
GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
Expand Down
44 changes: 44 additions & 0 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,50 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
mode);
}

//===----------------------------------------------------------------------===//
// RotateOp
//===----------------------------------------------------------------------===//

/// Builds a rotate op from integer `offset` and `width` by materializing each
/// as an i32 `arith.constant` at `result.location` and delegating to the
/// builder that takes SSA operands.
void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
                     int32_t offset, int32_t width) {
  // Create the constants in separate statements: the evaluation order of
  // function-call arguments is unspecified in C++, so creating them inline
  // would make the order of the emitted constants compiler-dependent.
  Value offsetValue = builder.create<arith::ConstantOp>(
      result.location, builder.getI32IntegerAttr(offset));
  Value widthValue = builder.create<arith::ConstantOp>(
      result.location, builder.getI32IntegerAttr(width));
  build(builder, result, value, offsetValue, widthValue);
}

/// Verifies the constraints on the `offset` and `width` operands:
///  - both must be defined by an `arith.constant` holding an integer attribute,
///  - `width` must be a power of two,
///  - `offset` must lie in the range `[0, width)`.
LogicalResult RotateOp::verify() {
  auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
  if (!offsetConstOp)
    return emitOpError() << "offset is not a constant value";

  // dyn_cast yields a null attribute on a type mismatch; report it instead of
  // dereferencing the null attribute below.
  auto offsetIntAttr =
      llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
  if (!offsetIntAttr)
    return emitOpError() << "offset is not a constant integer value";

  auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
  if (!widthConstOp)
    return emitOpError() << "width is not a constant value";

  auto widthIntAttr =
      llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
  if (!widthIntAttr)
    return emitOpError() << "width is not a constant integer value";

  llvm::APInt offsetValue = offsetIntAttr.getValue();
  llvm::APInt widthValue = widthIntAttr.getValue();

  if (!widthValue.isPowerOf2())
    return emitOpError() << "width must be a power of two";

  // The offset is compared as a signed value so that negative offsets are
  // rejected; the width is printed as an unsigned value in the diagnostic.
  if (offsetValue.isNegative() || offsetValue.sge(widthValue))
    return emitOpError() << "offset must be in the range [0, "
                         << widthValue.getZExtValue() << ")";

  return success();
}

//===----------------------------------------------------------------------===//
// BarrierOp
//===----------------------------------------------------------------------===//
Expand Down
102 changes: 102 additions & 0 deletions mlir/test/Conversion/GPUToSPIRV/rotate.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s

// Rotate whose width (16) equals the target subgroup size (16): the op lowers
// to spirv.GroupNonUniformRotateKHR and a `spirv.Constant true` is emitted
// after it.
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate()
gpu.func @rotate() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %{{.+}} = spirv.Constant true
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}

// -----

// Rotate whose width (8) is smaller than the target subgroup size (16): after
// the rotate, the SubgroupLocalInvocationId builtin is loaded and compared
// against the width with spirv.ULessThan.
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
// CHECK-LABEL: spirv.func @rotate_width_less_than_subgroup_size()
gpu.func @rotate_width_less_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 8 : i32
%val = arith.constant 42.0 : f32

// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
// CHECK: %[[WIDTH:.+]] = spirv.Constant 8 : i32
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
// CHECK: %[[INVOCATION_ID_ADDR:.+]] = spirv.mlir.addressof @__builtin__SubgroupLocalInvocationId__
// CHECK: %[[INVOCATION_ID:.+]] = spirv.Load "Input" %[[INVOCATION_ID_ADDR]]
// CHECK: %{{.+}} = spirv.ULessThan %[[INVOCATION_ID]], %[[WIDTH]]
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}

// -----

// Negative test: the rotate width (32) exceeds the target subgroup size (16),
// so the conversion does not legalize the op.
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
gpu.func @rotate_with_bigger_than_subgroup_size() kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%width = arith.constant 32 : i32
%val = arith.constant 42.0 : f32

// expected-error @+1 {{failed to legalize operation 'gpu.rotate'}}
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}

// -----

// Negative test: the rotate width is a kernel argument rather than a constant.
// The diagnostic text is the one produced by the gpu.rotate op verifier.
module attributes {
gpu.container_module,
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>,
#spirv.resource_limits<subgroup_size = 16>>
} {

gpu.module @kernels {
gpu.func @rotate_non_const_width(%width: i32) kernel
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
%offset = arith.constant 4 : i32
%val = arith.constant 42.0 : f32

// expected-error @+1 {{'gpu.rotate' op width is not a constant value}}
%result, %valid = gpu.rotate %val, %offset, %width : f32
gpu.return
}
}

}
78 changes: 78 additions & 0 deletions mlir/test/Dialect/GPU/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,84 @@ func.func @shuffle_unsupported_type_vec(%arg0 : vector<[4]xf32>, %arg1 : i32, %a

// -----

// The rotated result type (i32) differs from the input value type (f32); the
// generic op form is used so the mismatch can be spelled out.
func.func @rotate_mismatching_type(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op failed to verify that all of {value, rotateResult} have same type}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (i32, i1)
return
}

// -----

// `index` is rejected as the type of the rotated value.
func.func @rotate_unsupported_type(%arg0 : index) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'index'}}
%rotate, %valid = gpu.rotate %arg0, %offset, %width : index
return
}

// -----

// A scalable vector (`vector<[4]xf32>`) is rejected as the type of the rotated
// value; only fixed-length 1-D vectors are accepted.
func.func @rotate_unsupported_type_vec(%arg0 : vector<[4]xf32>) {
%offset = arith.constant 4 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op operand #0 must be Integer or Float or fixed-length vector of Integer or Float values of ranks 1, but got 'vector<[4]xf32>'}}
%rotate, %valid = gpu.rotate %arg0, %offset, %width : vector<[4]xf32>
return
}

// -----

// A width of 15 is rejected because it is not a power of two.
func.func @rotate_unsupported_width(%arg0 : f32) {
%offset = arith.constant 4 : i32
%width = arith.constant 15 : i32
// expected-error@+1 {{op width must be a power of two}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}

// -----

// An offset equal to the width (16) is rejected: the upper bound of the
// allowed range is exclusive.
func.func @rotate_unsupported_offset(%arg0 : f32) {
%offset = arith.constant 16 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}

// -----

// A negative offset (-1) is rejected as out of range.
func.func @rotate_unsupported_offset_minus(%arg0 : f32) {
%offset = arith.constant -1 : i32
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset must be in the range [0, 16)}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}

// -----

// An offset supplied as a function argument (not a constant) is rejected.
func.func @rotate_offset_non_constant(%arg0 : f32, %offset : i32) {
%width = arith.constant 16 : i32
// expected-error@+1 {{op offset is not a constant value}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}

// -----

// A width supplied as a function argument (not a constant) is rejected.
func.func @rotate_width_non_constant(%arg0 : f32, %width : i32) {
%offset = arith.constant 0 : i32
// expected-error@+1 {{op width is not a constant value}}
%rotate, %valid = "gpu.rotate"(%arg0, %offset, %width) : (f32, i32, i32) -> (f32, i1)
return
}

// -----

module {
gpu.module @gpu_funcs {
// expected-error @+1 {{custom op 'gpu.func' gpu.func requires named arguments}}
Expand Down
4 changes: 4 additions & 0 deletions mlir/test/Dialect/GPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ module attributes {gpu.container_module} {
// CHECK: gpu.shuffle idx %{{.*}}, %{{.*}}, %{{.*}} : f32
%shfl3, %pred3 = gpu.shuffle idx %arg0, %offset, %width : f32

// CHECK: gpu.rotate %{{.*}}, %{{.*}}, %{{.*}} : f32
%rotate_width = arith.constant 16 : i32
%rotate, %pred4 = gpu.rotate %arg0, %offset, %rotate_width : f32

"gpu.barrier"() : () -> ()

"some_op"(%bIdX, %tIdX) : (index, index) -> ()
Expand Down