@@ -376,10 +376,19 @@ class OpenACCClauseCIREmitter final
376
376
// on all operation types.
377
377
mlir::ArrayAttr getAsyncOnlyAttr () {
378
378
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
379
- mlir::acc::KernelsOp, mlir::acc::DataOp>)
379
+ mlir::acc::KernelsOp, mlir::acc::DataOp>) {
380
380
return operation.getAsyncOnlyAttr ();
381
- else if constexpr (isCombinedType<OpTy>)
381
+ } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>) {
382
+ if (!operation.getAsyncAttr ())
383
+ return mlir::ArrayAttr{};
384
+
385
+ llvm::SmallVector<mlir::Attribute> devTysTemp;
386
+ devTysTemp.push_back (mlir::acc::DeviceTypeAttr::get (
387
+ builder.getContext (), mlir::acc::DeviceType::None));
388
+ return mlir::ArrayAttr::get (builder.getContext (), devTysTemp);
389
+ } else if constexpr (isCombinedType<OpTy>) {
382
390
return operation.computeOp .getAsyncOnlyAttr ();
391
+ }
383
392
384
393
// Note: 'wait' has async as well, but it cannot have data clauses, so we
385
394
// don't have to handle them here.
@@ -391,10 +400,19 @@ class OpenACCClauseCIREmitter final
391
400
// on all operation types.
392
401
mlir::ArrayAttr getAsyncOperandsDeviceTypeAttr () {
393
402
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
394
- mlir::acc::KernelsOp, mlir::acc::DataOp>)
403
+ mlir::acc::KernelsOp, mlir::acc::DataOp>) {
395
404
return operation.getAsyncOperandsDeviceTypeAttr ();
396
- else if constexpr (isCombinedType<OpTy>)
405
+ } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>) {
406
+ if (!operation.getAsyncOperand ())
407
+ return mlir::ArrayAttr{};
408
+
409
+ llvm::SmallVector<mlir::Attribute> devTysTemp;
410
+ devTysTemp.push_back (mlir::acc::DeviceTypeAttr::get (
411
+ builder.getContext (), mlir::acc::DeviceType::None));
412
+ return mlir::ArrayAttr::get (builder.getContext (), devTysTemp);
413
+ } else if constexpr (isCombinedType<OpTy>) {
397
414
return operation.computeOp .getAsyncOperandsDeviceTypeAttr ();
415
+ }
398
416
399
417
// Note: 'wait' has async as well, but it cannot have data clauses, so we
400
418
// don't have to handle them here.
@@ -409,6 +427,8 @@ class OpenACCClauseCIREmitter final
409
427
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
410
428
mlir::acc::KernelsOp, mlir::acc::DataOp>)
411
429
return operation.getAsyncOperands ();
430
+ else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>)
431
+ return operation.getAsyncOperandMutable ();
412
432
else if constexpr (isCombinedType<OpTy>)
413
433
return operation.computeOp .getAsyncOperands ();
414
434
@@ -542,10 +562,11 @@ class OpenACCClauseCIREmitter final
542
562
void VisitAsyncClause (const OpenACCAsyncClause &clause) {
543
563
hasAsyncClause = true ;
544
564
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
545
- mlir::acc::KernelsOp, mlir::acc::DataOp>) {
546
- if (!clause.hasIntExpr ())
565
+ mlir::acc::KernelsOp, mlir::acc::DataOp,
566
+ mlir::acc::EnterDataOp>) {
567
+ if (!clause.hasIntExpr ()) {
547
568
operation.addAsyncOnly (builder.getContext (), lastDeviceTypeValues);
548
- else {
569
+ } else {
549
570
550
571
mlir::Value intExpr;
551
572
{
@@ -572,8 +593,8 @@ class OpenACCClauseCIREmitter final
572
593
applyToComputeOp (clause);
573
594
} else {
574
595
// TODO: When we've implemented this for everything, switch this to an
575
- // unreachable. Combined constructs remain. Data, enter data, exit data,
576
- // update constructs remain.
596
+ // unreachable. Combined constructs remain. Exit data, update constructs
597
+ // remain.
577
598
return clauseNotImplemented (clause);
578
599
}
579
600
}
@@ -604,7 +625,7 @@ class OpenACCClauseCIREmitter final
604
625
mlir::acc::KernelsOp, mlir::acc::InitOp,
605
626
mlir::acc::ShutdownOp, mlir::acc::SetOp,
606
627
mlir::acc::DataOp, mlir::acc::WaitOp,
607
- mlir::acc::HostDataOp>) {
628
+ mlir::acc::HostDataOp, mlir::acc::EnterDataOp >) {
608
629
operation.getIfCondMutable ().append (
609
630
createCondition (clause.getConditionExpr ()));
610
631
} else if constexpr (isCombinedType<OpTy>) {
@@ -659,7 +680,8 @@ class OpenACCClauseCIREmitter final
659
680
660
681
void VisitWaitClause (const OpenACCWaitClause &clause) {
661
682
if constexpr (isOneOfTypes<OpTy, mlir::acc::ParallelOp, mlir::acc::SerialOp,
662
- mlir::acc::KernelsOp, mlir::acc::DataOp>) {
683
+ mlir::acc::KernelsOp, mlir::acc::DataOp,
684
+ mlir::acc::EnterDataOp>) {
663
685
if (!clause.hasExprs ()) {
664
686
operation.addWaitOnly (builder.getContext (), lastDeviceTypeValues);
665
687
} else {
@@ -866,11 +888,16 @@ class OpenACCClauseCIREmitter final
866
888
var, mlir::acc::DataClause::acc_copyin, clause.getModifierList (),
867
889
/* structured=*/ true ,
868
890
/* implicit=*/ false );
891
+ } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>) {
892
+ for (const Expr *var : clause.getVarList ())
893
+ addDataOperand<mlir::acc::CopyinOp>(
894
+ var, mlir::acc::DataClause::acc_copyin, clause.getModifierList (),
895
+ /* structured=*/ false , /* implicit=*/ false );
869
896
} else if constexpr (isCombinedType<OpTy>) {
870
897
applyToComputeOp (clause);
871
898
} else {
872
899
// TODO: When we've implemented this for everything, switch this to an
873
- // unreachable. enter-data, declare constructs remain .
900
+ // unreachable. declare construct remains .
874
901
return clauseNotImplemented (clause);
875
902
}
876
903
}
@@ -900,11 +927,16 @@ class OpenACCClauseCIREmitter final
900
927
var, mlir::acc::DataClause::acc_create, clause.getModifierList (),
901
928
/* structured=*/ true ,
902
929
/* implicit=*/ false );
930
+ } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>) {
931
+ for (const Expr *var : clause.getVarList ())
932
+ addDataOperand<mlir::acc::CreateOp>(
933
+ var, mlir::acc::DataClause::acc_create, clause.getModifierList (),
934
+ /* structured=*/ false , /* implicit=*/ false );
903
935
} else if constexpr (isCombinedType<OpTy>) {
904
936
applyToComputeOp (clause);
905
937
} else {
906
938
// TODO: When we've implemented this for everything, switch this to an
907
- // unreachable. enter-data, declare constructs remain .
939
+ // unreachable. declare construct remains .
908
940
return clauseNotImplemented (clause);
909
941
}
910
942
}
@@ -974,12 +1006,15 @@ class OpenACCClauseCIREmitter final
974
1006
addDataOperand<mlir::acc::AttachOp, mlir::acc::DetachOp>(
975
1007
var, mlir::acc::DataClause::acc_attach, {}, /* structured=*/ true ,
976
1008
/* implicit=*/ false );
1009
+ } else if constexpr (isOneOfTypes<OpTy, mlir::acc::EnterDataOp>) {
1010
+ for (const Expr *var : clause.getVarList ())
1011
+ addDataOperand<mlir::acc::AttachOp>(
1012
+ var, mlir::acc::DataClause::acc_attach, {},
1013
+ /* structured=*/ false , /* implicit=*/ false );
977
1014
} else if constexpr (isCombinedType<OpTy>) {
978
1015
applyToComputeOp (clause);
979
1016
} else {
980
- // TODO: When we've implemented this for everything, switch this to an
981
- // unreachable. enter data remains.
982
- return clauseNotImplemented (clause);
1017
+ llvm_unreachable (" Unknown construct kind in VisitAttachClause" );
983
1018
}
984
1019
}
985
1020
};
@@ -1018,6 +1053,7 @@ EXPL_SPEC(mlir::acc::ShutdownOp)
1018
1053
EXPL_SPEC(mlir::acc::SetOp)
1019
1054
EXPL_SPEC(mlir::acc::WaitOp)
1020
1055
EXPL_SPEC(mlir::acc::HostDataOp)
1056
+ EXPL_SPEC(mlir::acc::EnterDataOp)
1021
1057
#undef EXPL_SPEC
1022
1058
1023
1059
template <typename ComputeOp, typename LoopOp>
0 commit comments