@@ -620,12 +620,9 @@ void MultiBatchMatmulOp::getEffects(
620620
621621LogicalResult ScaledDotProductAttentionOp::verify () { return success (); }
622622
623- // / Given an N-dimensional tensor x, this method converts
624- // / softmax(x) to the following sequence of operations:
625- // /
626- // / 1. transpose ins[1]
627- // / 2. matmul ins[0] @ 1
628- // /
623+ // / This method converts ScaledDotProductAttention into the following
624+ // / sequence of operations:
625+ // / output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
629626FailureOr<SmallVector<Value>>
630627ScaledDotProductAttentionOp::decomposeOperation (OpBuilder &b) {
631628 OpBuilder::InsertionGuard guard (b);
@@ -635,6 +632,7 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
635632 mask = getInputs ()[3 ];
636633 auto dtype = cast<RankedTensorType>(query.getType ()).getElementType ();
637634 auto shape = cast<RankedTensorType>(query.getType ()).getShape ();
635+ float rsqrt_head = 1 / sqrt (shape[3 ]);
638636
639637 SmallVector<int64_t > permutation{0 , 1 , 3 , 2 };
640638 SmallVector<int64_t > transposeShape{shape[0 ], shape[1 ], shape[3 ], shape[2 ]};
@@ -652,16 +650,40 @@ ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
652650 /* inputs=*/ ValueRange{query, transpose->getResult (0 )},
653651 /* outputs=*/ ValueRange{matmulQKOut.getResult ()});
654652
653+ auto mulOut = b.create <tensor::EmptyOp>(loc, matmulQKShape, dtype);
654+  // Scale the QK matmul result by 1/sqrt(head_dim) before adding the mask.
655+ SmallVector<AffineMap, 4 > indexingMaps;
656+ indexingMaps.push_back (b.getMultiDimIdentityMap (4 ));
657+ indexingMaps.push_back (b.getMultiDimIdentityMap (4 ));
658+ auto mul = b.create <linalg::GenericOp>(
659+ /* location=*/ loc, matmulQKOut.getResult ().getType (),
660+ /* inputs=*/ ValueRange{matmulQK->getResult (0 )},
661+ /* outputs=*/ ValueRange{mulOut.getResult ()}, indexingMaps,
662+ SmallVector<utils::IteratorType>(4 , utils::IteratorType::parallel),
663+ [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
664+ Value constant = b.create <arith::ConstantOp>(
665+ loc, nestedBuilder.getFloatAttr (dtype, rsqrt_head));
666+ Value added =
667+ nestedBuilder.create <arith::MulFOp>(loc, args[0 ], constant);
668+ nestedBuilder.create <linalg::YieldOp>(nestedLoc, added);
669+ });
670+
655671 auto addOut = b.create <tensor::EmptyOp>(loc, matmulQKShape, dtype);
656672 auto add = b.create <linalg::AddOp>(
657673 /* location=*/ loc, addOut.getResult ().getType (),
658- /* inputs=*/ ValueRange{matmulQK ->getResult (0 ), mask},
674+ /* inputs=*/ ValueRange{mul ->getResult (0 ), mask},
659675 /* outputs=*/ ValueRange{addOut.getResult ()});
660676
677+ auto softmaxOut = b.create <tensor::EmptyOp>(loc, matmulQKShape, dtype);
678+ auto softmax = b.create <linalg::SoftmaxOp>(
679+ /* location=*/ loc, softmaxOut.getResult ().getType (),
680+ /* inputs=*/ add->getResult (0 ),
681+ /* outputs=*/ softmaxOut.getResult (), 3 );
682+
661683 auto matmulVOut = b.create <tensor::EmptyOp>(loc, shape, dtype);
662684 auto matmulV = b.create <linalgx::MultiBatchMatmulOp>(
663685 /* location=*/ loc, matmulVOut.getResult ().getType (),
664- /* inputs=*/ ValueRange{add ->getResult (0 ), value},
686+ /* inputs=*/ ValueRange{softmax ->getResult (0 ), value},
665687 /* outputs=*/ ValueRange{matmulVOut.getResult ()});
666688 return SmallVector<Value>{matmulV.getResults ()[0 ]};
667689}
0 commit comments