diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 0eedfc74feec..5530ae0ed569 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -996,6 +996,46 @@ StringRef getAMDArch(Operation *module) { return ref.drop_front(4); // drop the "hip:" } +// Rough utility for obtaining a SharedEnc for a LinearEncoding, +// as we've replaced DotOpEnc with Linear in some cases +// (specifically, fp4ToFp and similar unpack-upcast thru join) +std::optional +getSharedForLinear(ttg::LinearEncodingAttr enc, + ArrayRef globalOrder, ArrayRef shape, + unsigned elemBitWidth, ttg::CTALayoutAttr ctaLayout) { + auto ctx = enc.getContext(); + auto ll = enc.getLinearLayout(); + auto rank = shape.size(); + + if (rank != 2) + return std::nullopt; + + auto order = enc.getOrder(); + assert(globalOrder.size() == rank); + // TODO add memdesc_trans support for dot(trans(cvt(src) #linear) #dot_op) + if (order != globalOrder) + return std::nullopt; + + auto innerDim = order[0]; + auto outerDim = order[1]; + auto contigPerWarp = enc.getContigPerWarp(); + + constexpr unsigned BANK_SIZE{128}; + auto elemBytes = elemBitWidth / 8; + + auto vec = contigPerWarp[innerDim]; + auto rowSize = elemBytes * (unsigned)shape[innerDim]; + auto perPhase = std::max(BANK_SIZE / rowSize, 1u); + auto maxPhase = std::max(contigPerWarp[outerDim] / perPhase, 1u); + + // cp.async does not support transfer size < 4B + if (vec * elemBytes < 4 && perPhase < maxPhase) + return std::nullopt; + + return ttg::SwizzledSharedEncodingAttr::get(ctx, vec, perPhase, maxPhase, + order, ctaLayout); +} + // If all the transitive uses of the given value have are used by a convert to // the same dot operand encoding, return the shared encoding that needs to be // used to be compatible with users' layouts. If there are incompatible shared @@ -1023,7 +1063,8 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { if (!isa(user)) return std::nullopt; auto enc = - cast(user->getResult(0).getType()).getEncoding(); + cast(user->getResult(0).getType()) + .getEncoding(); if (isa(enc)) { auto srcTy = cast(val.getType()); auto CTALayout = ttg::getCTALayout(srcTy.getEncoding()); @@ -1039,6 +1080,16 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) { val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1, ttg::getOrder(srcTy.getEncoding()), ttg::getCTALayout(srcTy.getEncoding())); + } else if (auto linearEnc = dyn_cast(enc)) { + auto srcTy = cast(val.getType()); + auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = ttg::getOrder(srcTy.getEncoding()); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + auto attrOpt = getSharedForLinear(linearEnc, order, srcTy.getShape(), + bitWidth, ctaLayout); + if (!attrOpt) + return std::nullopt; + tempAttr = *attrOpt; } else { return std::nullopt; }