[mlir][gpu] Add field to mark asynchronous side effects #72013

Draft: wants to merge 1 commit into main
6 changes: 3 additions & 3 deletions mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -423,9 +423,9 @@ def NVGPU_DeviceAsyncCopyOp : NVGPU_Op<"device_async_copy", [
```
}];
let results = (outs NVGPU_DeviceAsyncToken:$asyncToken);
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect, 1>]>:$dst,
Variadic<Index>:$dstIndices,
Arg<AnyMemRef, "", [MemReadAt<0, FullEffect>]>:$src,
Arg<AnyMemRef, "", [MemReadAt<0, FullEffect, 1>]>:$src,
Variadic<Index>:$srcIndices,
IndexAttr:$dstElements,
Optional<Index>:$srcElements,
@@ -642,7 +642,7 @@ def NVGPU_TmaAsyncLoadOp : NVGPU_Op<"tma.async.load", [AttrSizedOperandSegments]

The Op uses `$barrier` mbarrier based completion mechanism.
}];
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect>]>:$dst,
let arguments = (ins Arg<AnyMemRef, "", [MemWriteAt<0, FullEffect, 1>]>:$dst,
NVGPU_MBarrierGroup:$barriers,
NVGPU_TensorMapDescriptor:$tensorMapDescriptor,
Variadic<Index>:$coordinates,
6 changes: 5 additions & 1 deletion mlir/include/mlir/Interfaces/SideEffectInterfaceBase.td
@@ -160,7 +160,8 @@ def PartialEffect : EffectRange<0>;
// This class is the general base side effect class. This is used by derived
// effect interfaces to define their effects.
class SideEffect<EffectOpInterfaceBase interface, string effectName,
Resource resourceReference, int effectStage, EffectRange range>
Resource resourceReference, int effectStage, EffectRange range,
bits<1> isAsync>
: OpVariableDecorator {
/// The name of the base effects class.
string baseEffectName = interface.baseEffectName;
@@ -183,6 +184,9 @@ class SideEffect<EffectOpInterfaceBase interface, string effectName,

// Does this side effect act on every single value of resource.
bit effectOnFullRegion = range.Value;

// Does this side effect potentially occur after op exit
bit asynchronous = isAsync;

@joker-eph (Collaborator), Nov 11, 2023:

These are tricky semantics to add, and I am actually not sure how to model them.
How are we supposed to reason about effects that have this bit set?

Thanks for giving this a try; this is an important thing to solve. However, I feel this is a deep enough change that we should have a clear description of it in a proposal on Discourse.

Contributor Author:

That's fair. I will mark this as a draft until I can work out the proposal with my teammates.

Member:

FYI, effect instances have an attribute field for additional qualification of effects. It may be worth using, rather than pushing asynchronous on every potential client, as long as one can come up with a "conservatively safe" way of handling side effects while ignoring this information.

Contributor Author:

I have a variation that fills the attribute parameter with a BoolAttr. How would you recommend making it certain who added the parameter and for what purpose?

Member:

Use a custom attribute (GpuAsyncAttr) that stores a bool (or even a unit attribute, if that makes sense) rather than a BoolAttr. There was a previous patch where we discussed putting in a DictAttr and using string keys.

Note that I'm not sure this is the right mechanism; it's conditioned on there being a way to interpret the absence of the attribute in a conservatively correct way. (The only user of the attribute I'm aware of so far is a downstream project that uses affine-ish descriptors of the subset of memref elements accessed, in the absence of which analyses can just conservatively assume that all elements are accessed.)
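
To make the suggestion above concrete, here is a minimal sketch in plain C++ (no MLIR dependency) of a dedicated qualifier type carried in an optional parameter slot. `GpuAsyncQualifier`, `QualifiedEffect`, and `isMarkedAsync` are hypothetical names used only for illustration; the real mechanism would be an MLIR attribute stored in the effect instance's parameters field. The sketch reads a missing qualifier as "not marked"; whether that default is conservatively safe is exactly the open question raised in the comment.

```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-in for a dedicated attribute such as the suggested
// GpuAsyncAttr. A distinct type (rather than a bare bool) records who
// attached the qualifier and for what purpose.
struct GpuAsyncQualifier {
  bool async;
};

// Hypothetical, reduced effect record: `parameters` models the optional
// attribute slot on an effect instance.
struct QualifiedEffect {
  std::optional<GpuAsyncQualifier> parameters;
};

// A client that understands the qualifier. Absence is read as "not marked";
// clients that ignore the slot never see the bit at all.
inline bool isMarkedAsync(const QualifiedEffect &effect) {
  return effect.parameters.has_value() && effect.parameters->async;
}

// Small usage check covering the three possible states of the slot.
inline void checkQualifierLookup() {
  assert(!isMarkedAsync(QualifiedEffect{}));
  assert(isMarkedAsync(QualifiedEffect{GpuAsyncQualifier{true}}));
  assert(!isMarkedAsync(QualifiedEffect{GpuAsyncQualifier{false}}));
}
```

Compared with adding a constructor argument, this keeps existing clients source-compatible, at the cost of every interested pass having to look the qualifier up itself.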

}

// This class is the base used for specifying effects applied to an operation.
18 changes: 12 additions & 6 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.h
@@ -141,28 +141,28 @@ class EffectInstance {
EffectInstance(EffectT *effect, Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), stage(0),
effectOnFullRegion(false) {}
EffectInstance(EffectT *effect, int stage, bool effectOnFullRegion,
EffectInstance(EffectT *effect, int stage, bool effectOnFullRegion, bool async,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), stage(stage),
effectOnFullRegion(effectOnFullRegion) {}
effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
EffectInstance(EffectT *effect, Value value,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value), stage(0),
effectOnFullRegion(false) {}
EffectInstance(EffectT *effect, Value value, int stage,
bool effectOnFullRegion,
bool effectOnFullRegion, bool async,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(value), stage(stage),
effectOnFullRegion(effectOnFullRegion) {}
effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
EffectInstance(EffectT *effect, SymbolRefAttr symbol,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(symbol), stage(0),
effectOnFullRegion(false) {}
EffectInstance(EffectT *effect, SymbolRefAttr symbol, int stage,
bool effectOnFullRegion,
bool effectOnFullRegion, bool async,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), value(symbol), stage(stage),
effectOnFullRegion(effectOnFullRegion) {}
effectOnFullRegion(effectOnFullRegion), asynchronous(async) {}
EffectInstance(EffectT *effect, Attribute parameters,
Resource *resource = DefaultResource::get())
: effect(effect), resource(resource), parameters(parameters), stage(0),
@@ -221,6 +221,9 @@ class EffectInstance {
/// Return if this side effect act on every single value of resource.
bool getEffectOnFullRegion() const { return effectOnFullRegion; }

/// Return if the side effect may occur after the op exits.
bool getAsynchronous() const { return asynchronous; }

private:
/// The specific effect being applied.
EffectT *effect;
@@ -242,6 +245,9 @@ class EffectInstance {

// Does this side effect act on every single value of resource.
bool effectOnFullRegion;

/// Does this side effect potentially occur after op exit.
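// Defaults to false so that constructors which do not take `async` leave
// the flag well-defined.
bool asynchronous = false;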
};
} // namespace SideEffects

24 changes: 12 additions & 12 deletions mlir/include/mlir/Interfaces/SideEffectInterfaces.td
@@ -35,8 +35,8 @@ def MemoryEffectsOpInterface

// The base class for defining specific memory effects.
class MemoryEffect<string effectName, Resource resource, int stage,
EffectRange range>
: SideEffect<MemoryEffectsOpInterface, effectName, resource, stage, range>;
EffectRange range, bits<1> async>
: SideEffect<MemoryEffectsOpInterface, effectName, resource, stage, range, async>;

// This class represents the trait for memory effects that may be placed on
// operations.
@@ -51,7 +51,7 @@ class MemoryEffects<list<MemoryEffect> effects = []>
// not any visible mutation or dereference.
class MemAlloc<Resource resource, int stage = 0,
EffectRange range = PartialEffect>
: MemoryEffect<"::mlir::MemoryEffects::Allocate", resource, stage, range>;
: MemoryEffect<"::mlir::MemoryEffects::Allocate", resource, stage, range, 0>;
def MemAlloc : MemAlloc<DefaultResource, 0, PartialEffect>;
class MemAllocAt<int stage, EffectRange range = PartialEffect>
: MemAlloc<DefaultResource, stage, range>;
@@ -61,7 +61,7 @@ class MemAllocAt<int stage, EffectRange range = PartialEffect>
// resource, and not any visible allocation, mutation or dereference.
class MemFree<Resource resource, int stage = 0,
EffectRange range = PartialEffect>
: MemoryEffect<"::mlir::MemoryEffects::Free", resource, stage, range>;
: MemoryEffect<"::mlir::MemoryEffects::Free", resource, stage, range, 0>;
def MemFree : MemFree<DefaultResource, 0, PartialEffect>;
class MemFreeAt<int stage, EffectRange range = PartialEffect>
: MemFree<DefaultResource, stage, range>;
@@ -70,21 +70,21 @@ class MemFreeAt<int stage, EffectRange range = PartialEffect>
// resource. A 'read' effect implies only dereferencing of the resource, and
// not any visible mutation.
class MemRead<Resource resource, int stage = 0,
EffectRange range = PartialEffect>
: MemoryEffect<"::mlir::MemoryEffects::Read", resource, stage, range>;
EffectRange range = PartialEffect, bits<1> async = 0>
: MemoryEffect<"::mlir::MemoryEffects::Read", resource, stage, range, async>;
def MemRead : MemRead<DefaultResource, 0, PartialEffect>;
class MemReadAt<int stage, EffectRange range = PartialEffect>
: MemRead<DefaultResource, stage, range>;
class MemReadAt<int stage, EffectRange range = PartialEffect, bits<1> async = 0>
: MemRead<DefaultResource, stage, range, async>;

// The following effect indicates that the operation writes to some
// resource. A 'write' effect implies only mutating a resource, and not any
// visible dereference or read.
class MemWrite<Resource resource, int stage = 0,
EffectRange range = PartialEffect>
: MemoryEffect<"::mlir::MemoryEffects::Write", resource, stage, range>;
EffectRange range = PartialEffect, bits<1> async = 0>
: MemoryEffect<"::mlir::MemoryEffects::Write", resource, stage, range, async>;
def MemWrite : MemWrite<DefaultResource, 0, PartialEffect>;
class MemWriteAt<int stage, EffectRange range = PartialEffect>
: MemWrite<DefaultResource, stage, range>;
class MemWriteAt<int stage, EffectRange range = PartialEffect, bits<1> async = 0>
: MemWrite<DefaultResource, stage, range, async>;

//===----------------------------------------------------------------------===//
// Effect Traits
3 changes: 3 additions & 0 deletions mlir/include/mlir/TableGen/SideEffects.h
@@ -41,6 +41,9 @@ class SideEffect : public Operator::VariableDecorator {
// Return if this side effect act on every single value of resource.
bool getEffectOnfullRegion() const;

// Return if the side effect occurs after op exit.
bool getAsynchronous() const;

static bool classof(const Operator::VariableDecorator *var);
};

5 changes: 5 additions & 0 deletions mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp
@@ -519,6 +519,11 @@ static bool
haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects,
ArrayRef<MemoryEffects::EffectInstance> afterEffects) {
for (const MemoryEffects::EffectInstance &before : beforeEffects) {
// Before may conflict with after, but since it is async, a BarrierOp cannot
// synchronize the effects. If the async field is set, it is presumed that
// some architecture-specific mechanism is needed to synchronize the effect.
if (before.getAsynchronous()) continue;

for (const MemoryEffects::EffectInstance &after : afterEffects) {
// If cannot alias, definitely no conflict.
if (!mayAlias(before, after))
4 changes: 4 additions & 0 deletions mlir/lib/TableGen/SideEffects.cpp
@@ -42,6 +42,10 @@ bool SideEffect::getEffectOnfullRegion() const {
return def->getValueAsBit("effectOnFullRegion");
}

bool SideEffect::getAsynchronous() const {
return def->getValueAsBit("asynchronous");
}

bool SideEffect::classof(const Operator::VariableDecorator *var) {
return var->getDef().isSubClassOf("SideEffect");
}
20 changes: 20 additions & 0 deletions mlir/test/Dialect/GPU/barrier-elimination.mlir
@@ -182,3 +182,23 @@ attributes {__parallel_region_boundary_for_test} {
%4 = memref.load %C[] : memref<f32>
return %0, %1, %2, %3, %4 : f32, f32, f32, f32, f32
}

// CHECK-LABEL: @async_copy
func.func @async_copy() -> ()
attributes {__parallel_region_boundary_for_test} {
// CHECK: %[[A:.+]] = memref.alloc
// CHECK: %[[B:.+]] = memref.alloc
%A = memref.alloc() : memref<f32>
%B = memref.alloc() : memref<f32, #gpu.address_space<workgroup>>
gpu.barrier
// CHECK: %[[T:.+]] = nvgpu.device_async_copy %[[A]][], %[[B]][], 1
%token = nvgpu.device_async_copy %A[], %B[], 1 : memref<f32> to memref<f32, #gpu.address_space<workgroup>>
// This needs to be erased because it can't synchronize the effects on %B.
gpu.barrier
// This does synchronize the effects on %B.
// CHECK-NEXT: nvgpu.device_async_wait %[[T]]
nvgpu.device_async_wait %token
// CHECK-NEXT: linalg.abs ins(%[[B]] : memref<f32, #gpu.address_space<workgroup>>) outs(%[[A]] : memref<f32>)
linalg.abs ins(%B: memref<f32, #gpu.address_space<workgroup>>) outs(%A: memref<f32>)
return
}
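
The interaction between the `haveConflictingEffects` change and this test can be sketched in plain C++ without any MLIR dependency. The types and functions below (`ModelEffect`, `conflicts`, `barrierHasWork`) are hypothetical, heavily reduced stand-ins: the real pass also consults `mayAlias` and distinguishes more cases, and the buffer numbering here is only loosely modeled on the `@async_copy` test. The sketch shows one thing: once the "before" effects are marked asynchronous, a plain barrier has nothing left to order and becomes a candidate for removal.

```cpp
#include <cassert>
#include <vector>

// Hypothetical, reduced stand-ins for MLIR's effect types; not the real API.
enum class Access { Read, Write };

struct ModelEffect {
  Access access;
  int buffer;        // stand-in for the memref the effect touches
  bool asynchronous; // mirrors the proposed getAsynchronous() bit
};

// Simplified conflict rule: same buffer and at least one side writes.
inline bool conflicts(const ModelEffect &before, const ModelEffect &after) {
  return before.buffer == after.buffer &&
         (before.access == Access::Write || after.access == Access::Write);
}

// Returns true when a barrier between the two effect sets has something to
// order. Asynchronous "before" effects are skipped, as in the patch, because
// a plain barrier is presumed unable to synchronize them.
inline bool barrierHasWork(const std::vector<ModelEffect> &beforeEffects,
                           const std::vector<ModelEffect> &afterEffects) {
  for (const ModelEffect &before : beforeEffects) {
    if (before.asynchronous)
      continue;
    for (const ModelEffect &after : afterEffects)
      if (conflicts(before, after))
        return true;
  }
  return false;
}

// Scenario: an async copy reads buffer 0 and writes buffer 1; a later op
// reads buffer 1 and writes buffer 0.
inline void checkAsyncCopyScenario() {
  const std::vector<ModelEffect> copy = {{Access::Read, 0, true},
                                         {Access::Write, 1, true}};
  const std::vector<ModelEffect> consumer = {{Access::Read, 1, false},
                                             {Access::Write, 0, false}};
  // Every "before" effect is asynchronous, so the barrier has no work.
  assert(!barrierHasWork(copy, consumer));
}
```

Note that in this model the ordering guarantee is assumed to come from a separate mechanism (in the test, `nvgpu.device_async_wait`); the sketch does not model that wait.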
2 changes: 1 addition & 1 deletion mlir/test/lib/Dialect/Test/TestInterfaces.td
@@ -128,7 +128,7 @@ def TestEffectOpInterface

class TestEffect<string effectName>
: SideEffect<TestEffectOpInterface, effectName, DefaultResource, 0,
PartialEffect>;
PartialEffect, 0>;

class TestEffects<list<TestEffect> effects = []>
: SideEffectsTraitBase<TestEffectOpInterface, effects>;
14 changes: 7 additions & 7 deletions mlir/test/mlir-tblgen/op-side-effects.td
@@ -26,15 +26,15 @@ def SideEffectOpB : TEST_Op<"side_effect_op_b",

// CHECK: void SideEffectOpA::getEffects
// CHECK: for (::mlir::Value value : getODSOperands(0))
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), value, 0, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), value, 0, false, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: for (::mlir::Value value : getODSOperands(1))
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), value, 1, true, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), getSymbolAttr(), 0, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), getFlatSymbolAttr(), 0, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), value, 1, true, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), getSymbolAttr(), 0, false, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), getFlatSymbolAttr(), 0, false, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: if (auto symbolRef = getOptionalSymbolAttr())
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), symbolRef, 0, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Read::get(), symbolRef, 0, false, false, ::mlir::SideEffects::DefaultResource::get());
// CHECK: for (::mlir::Value value : getODSResults(0))
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Allocate::get(), value, 0, false, CustomResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Allocate::get(), value, 0, false, false, CustomResource::get());

// CHECK: void SideEffectOpB::getEffects
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), 0, false, CustomResource::get());
// CHECK: effects.emplace_back(::mlir::MemoryEffects::Write::get(), 0, false, false, CustomResource::get());
14 changes: 8 additions & 6 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3304,9 +3304,10 @@ void OpEmitter::genSideEffectInterfaceMethods() {
// {1}: Optional value or symbol reference.
// {2}: The side effect stage.
// {3}: Does this side effect act on every single value of resource.
// {4}: The resource class.
// {4}: Is asynchronous
// {5}: The resource class.
const char *addEffectCode =
" effects.emplace_back({0}::get(), {1}{2}, {3}, {4}::get());\n";
" effects.emplace_back({0}::get(), {1}{2}, {3}, {4}, {5}::get());\n";

for (auto &it : interfaceEffects) {
// Generate the 'getEffects' method.
@@ -3325,10 +3326,11 @@ void OpEmitter::genSideEffectInterfaceMethods() {
StringRef resource = location.effect.getResource();
int stage = (int)location.effect.getStage();
bool effectOnFullRegion = (int)location.effect.getEffectOnfullRegion();
bool async = (int)location.effect.getAsynchronous();
if (location.kind == EffectKind::Static) {
// A static instance has no attached value.
body << llvm::formatv(addEffectCode, effect, "", stage,
effectOnFullRegion, resource)
effectOnFullRegion, async, resource)
.str();
} else if (location.kind == EffectKind::Symbol) {
// A symbol reference requires adding the proper attribute.
@@ -3337,11 +3339,11 @@ void OpEmitter::genSideEffectInterfaceMethods() {
if (attr->attr.isOptional()) {
body << " if (auto symbolRef = " << argName << "Attr())\n "
<< llvm::formatv(addEffectCode, effect, "symbolRef, ", stage,
effectOnFullRegion, resource)
effectOnFullRegion, async, resource)
.str();
} else {
body << llvm::formatv(addEffectCode, effect, argName + "Attr(), ",
stage, effectOnFullRegion, resource)
stage, effectOnFullRegion, async, resource)
.str();
}
} else {
@@ -3350,7 +3352,7 @@ void OpEmitter::genSideEffectInterfaceMethods() {
<< (location.kind == EffectKind::Operand ? "Operands" : "Results")
<< "(" << location.index << "))\n "
<< llvm::formatv(addEffectCode, effect, "value, ", stage,
effectOnFullRegion, resource)
effectOnFullRegion, async, resource)
.str();
}
}
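
To illustrate how the extra `{4}` slot changes the generated code, here is a minimal standalone sketch that reproduces the shape of `addEffectCode` with plain `std::string` concatenation instead of `llvm::formatv`. `emitAddEffect` is a hypothetical helper, not part of mlir-tblgen, and the non-negative stage check is an assumption of this sketch. The booleans are rendered as `true`/`false`, matching the CHECK lines in `op-side-effects.td`.

```cpp
#include <cassert>
#include <string>

// Hypothetical helper mirroring the updated format string:
//   effects.emplace_back({0}::get(), {1}{2}, {3}, {4}, {5}::get());
// `valueOrSymbol` is either empty (static effect) or carries its own
// trailing ", " (for example "value, "), as in the generator.
inline std::string emitAddEffect(const std::string &effect,
                                 const std::string &valueOrSymbol, int stage,
                                 bool effectOnFullRegion, bool async,
                                 const std::string &resource) {
  assert(stage >= 0 && "this sketch assumes non-negative stages");
  auto boolStr = [](bool b) { return b ? "true" : "false"; };
  return "effects.emplace_back(" + effect + "::get(), " + valueOrSymbol +
         std::to_string(stage) + ", " + boolStr(effectOnFullRegion) + ", " +
         boolStr(async) + ", " + resource + "::get());";
}
```

Calling `emitAddEffect("::mlir::MemoryEffects::Write", "", 0, false, false, "CustomResource")` yields the same text as the updated `SideEffectOpB` CHECK line; passing `true` for `async` is the only difference an asynchronous effect would show in the emitted call.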