diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 6895e946b8a45..e55060dc04204 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -3009,10 +3009,29 @@ def NVVM_GriddepcontrolLaunchDependentsOp // NVVM Mapa Op //===----------------------------------------------------------------------===// +// Helper predicates for address space checking +def IsGenericAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 0">; +def IsSharedAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 3">; +def IsSharedClusterAddressSpace : CPred<"llvm::cast($_self).getAddressSpace() == 7">; + +class NVVM_AddressSpaceMapping : + PredOpTrait<"valid address space mapping for NVVM mapa operation", + Or<[ + // Generic -> Generic + And<[ + SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsGenericAddressSpace>, + SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsGenericAddressSpace> + ]>, + // Shared -> SharedCluster + And<[ + SubstLeaves<"$_self", "$" # inputArg # ".getType()", IsSharedAddressSpace>, + SubstLeaves<"$_self", "$" # resultArg # ".getType()", IsSharedClusterAddressSpace> + ]> + ]>>; + def NVVM_MapaOp: NVVM_Op<"mapa", - [TypesMatchWith<"`res` and `a` should have the same type", - "a", "res", "$_self">, NVVMRequiresSM<90>]> { - let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$res); + [NVVM_AddressSpaceMapping<"a", "res">, NVVMRequiresSM<90>]> { + let results = (outs AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerSharedCluster]>:$res); let arguments = (ins AnyTypeOf<[LLVM_PointerGeneric, LLVM_PointerShared]>:$a, I32:$b); string llvmBuilder = [{ diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 251ca716c7a7a..7a85eea58c558 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1201,8 +1201,8 @@ func.func @cp_async(%arg0: !llvm.ptr<3>, %arg1: !llvm.ptr<1>) { // ----- func.func @mapa(%a: !llvm.ptr, %b : i32) { - // expected-error @below {{`res` and `a` should have the same type}} - %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<3> + // expected-error @below {{'nvvm.mapa' op failed to verify that valid address space mapping for NVVM mapa operation}} + %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr<7> return } diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c7fa41c98ac92..4349193aa1a45 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -552,7 +552,7 @@ func.func @mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK: nvvm.mapa %{{.*}} %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr // CHECK: nvvm.mapa %{{.*}} - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> return } diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index f86a04186f512..c119c1a0fd21f 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -760,8 +760,8 @@ llvm.func @nvvm_griddepcontrol_launch_dependents() { llvm.func @nvvm_mapa(%a: !llvm.ptr, %a_shared: !llvm.ptr<3>, %b : i32) { // CHECK-LLVM: call ptr @llvm.nvvm.mapa(ptr %{{.*}}, i32 %{{.*}}) %0 = nvvm.mapa %a, %b: !llvm.ptr -> !llvm.ptr - // CHECK-LLVM: call ptr addrspace(3) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) - %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<3> + // CHECK-LLVM: call ptr addrspace(7) @llvm.nvvm.mapa.shared.cluster(ptr addrspace(3) %{{.*}}, i32 %{{.*}}) + %1 = nvvm.mapa %a_shared, %b: !llvm.ptr<3> -> !llvm.ptr<7> llvm.return }