-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[HLSL][Matrix] Add support for single subscript accessor #170779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HLSL][Matrix] Add support for single subscript accessor #170779
Conversation
|
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) Changesfixes #166206
Patch is 72.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170779.diff 33 Files Affected:
diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..895105640b931 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -28,6 +28,7 @@ class ParenExpr;
class UnaryOperator;
class UnaryExprOrTypeTraitExpr;
class ArraySubscriptExpr;
+class MatrixSingleSubscriptExpr;
class MatrixSubscriptExpr;
class CompoundLiteralExpr;
class ImplicitCastExpr;
@@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
ExprDependence computeDependence(ArraySubscriptExpr *E);
+ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
ExprDependence computeDependence(MatrixSubscriptExpr *E);
ExprDependence computeDependence(CompoundLiteralExpr *E);
ExprDependence computeDependence(ImplicitCastExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..16d9bbe8ff7c1 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr {
}
};
+/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
+/// MatrixType extension when you want to get\set a vector from a Matrix.
+class MatrixSingleSubscriptExpr : public Expr {
+ enum { BASE, ROW_IDX, END_EXPR };
+ Stmt *SubExprs[END_EXPR];
+
+public:
+ /// matrix[row]
+ ///
+ /// \param Base The matrix expression.
+ /// \param RowIdx The row index expression.
+ /// \param T The type of the row (usually a vector type).
+ /// \param RBracketLoc Location of the closing ']'.
+ MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
+ SourceLocation RBracketLoc)
+ : Expr(MatrixSingleSubscriptExprClass, T,
+ Base->getValueKind(), // lvalue/rvalue follows the matrix base
+ OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you
+ // prefer
+ SubExprs[BASE] = Base;
+ SubExprs[ROW_IDX] = RowIdx;
+ ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
+ setDependence(computeDependence(this));
+ }
+
+ /// Create an empty matrix single-subscript expression.
+ explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
+ : Expr(MatrixSingleSubscriptExprClass, Shell) {}
+
+ Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
+ const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
+ void setBase(Expr *E) { SubExprs[BASE] = E; }
+
+ Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
+ const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
+ void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }
+
+ SourceLocation getBeginLoc() const LLVM_READONLY {
+ return getBase()->getBeginLoc();
+ }
+
+ SourceLocation getEndLoc() const { return getRBracketLoc(); }
+
+ SourceLocation getExprLoc() const LLVM_READONLY {
+ return getBase()->getExprLoc();
+ }
+
+ SourceLocation getRBracketLoc() const {
+ return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
+ }
+ void setRBracketLoc(SourceLocation L) {
+ ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
+ }
+
+ static bool classof(const Stmt *T) {
+ return T->getStmtClass() == MatrixSingleSubscriptExprClass;
+ }
+
+ // Iterators
+ child_range children() {
+ return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+ }
+ const_child_range children() const {
+ return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+ }
+};
+
/// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
/// extension.
/// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..92409b72e4f0c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2893,6 +2893,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
// over the children.
DEF_TRAVERSE_STMT(AddrLabelExpr, {})
DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
+DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
DEF_TRAVERSE_STMT(ArraySectionExpr, {})
DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index e1cca34d2212c..21d0a7dfe577c 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -530,6 +530,7 @@ class alignas(void *) Stmt {
class ArrayOrMatrixSubscriptExprBitfields {
friend class ArraySubscriptExpr;
friend class MatrixSubscriptExpr;
+ friend class MatrixSingleSubscriptExpr;
LLVM_PREFERRED_TYPE(ExprBitfields)
unsigned : NumExprBits;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..ada74807e56e2 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -74,6 +74,7 @@ def UnaryOperator : StmtNode<Expr>;
def OffsetOfExpr : StmtNode<Expr>;
def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
def ArraySubscriptExpr : StmtNode<Expr>;
+def MatrixSingleSubscriptExpr : StmtNode<Expr>;
def MatrixSubscriptExpr : StmtNode<Expr>;
def ArraySectionExpr : StmtNode<Expr>;
def OMPIteratorExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..d4d5c3d8bed17 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7406,6 +7406,9 @@ class Sema final : public SemaBase {
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
Expr *Idx, SourceLocation RLoc);
+ ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
+ SourceLocation RBLoc);
+
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
Expr *ColumnIdx,
SourceLocation RBLoc);
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..8429f17d26be5 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) {
return E->getLHS()->getDependence() | E->getRHS()->getDependence();
}
+ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
+ return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
+}
+
ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
(E->getColumnIdx() ? E->getColumnIdx()->getDependence()
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..b400b2a083d9b 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3789,6 +3789,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
case ParenExprClass:
case ArraySubscriptExprClass:
+ case MatrixSingleSubscriptExprClass:
case MatrixSubscriptExprClass:
case ArraySectionExprClass:
case OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..9995d1b411c5b 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
}
return Cl::CL_LValue;
+ case Expr::MatrixSingleSubscriptExprClass:
+ return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase());
+
// Subscripting matrix types behaves like member accesses.
case Expr::MatrixSubscriptExprClass:
return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 11c5e1c6e90f4..52481dc71b75d 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -20667,6 +20667,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
case Expr::ImaginaryLiteralClass:
case Expr::StringLiteralClass:
case Expr::ArraySubscriptExprClass:
+ case Expr::MatrixSingleSubscriptExprClass:
case Expr::MatrixSubscriptExprClass:
case Expr::ArraySectionExprClass:
case Expr::OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 5572e0a7ae59c..cb71987fba766 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -5482,6 +5482,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
break;
}
+ case Expr::MatrixSingleSubscriptExprClass: {
+ NotPrimaryExpr();
+ const MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E);
+ Out << "ix";
+ mangleExpression(ME->getBase());
+ mangleExpression(ME->getRowIdx());
+ break;
+ }
+
case Expr::MatrixSubscriptExprClass: {
NotPrimaryExpr();
const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E);
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index ff8ca01ec5477..51b9c47f22ff0 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1685,6 +1685,14 @@ void StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) {
OS << "]";
}
+void StmtPrinter::VisitMatrixSingleSubscriptExpr(
+ MatrixSingleSubscriptExpr *Node) {
+ PrintExpr(Node->getBase());
+ OS << "[";
+ PrintExpr(Node->getRowIdx());
+ OS << "]";
+}
+
void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) {
PrintExpr(Node->getBase());
OS << "[";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 4a8c638c85331..c7b7c65715dfc 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -1508,6 +1508,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const ArraySubscriptExpr *S) {
VisitExpr(S);
}
+void StmtProfiler::VisitMatrixSingleSubscriptExpr(
+ const MatrixSingleSubscriptExpr *S) {
+ VisitExpr(S);
+}
+
void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) {
VisitExpr(S);
}
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e842158236cd4..ca06b5df94cb3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1796,6 +1796,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
return EmitUnaryOpLValue(cast<UnaryOperator>(E));
case Expr::ArraySubscriptExprClass:
return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
+ case Expr::MatrixSingleSubscriptExprClass:
+ return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E));
case Expr::MatrixSubscriptExprClass:
return EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E));
case Expr::ArraySectionExprClass:
@@ -2440,6 +2442,35 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext"));
}
+ if (LV.isMatrixRow()) {
+ QualType MatTy = LV.getType();
+ const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+ unsigned NumRows = MT->getNumRows();
+ unsigned NumCols = MT->getNumColumns();
+
+ llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc);
+
+ llvm::Value *Row = LV.getMatrixRowIdx();
+ llvm::Value *Result =
+ llvm::UndefValue::get(ConvertType(LV.getType())); // <NumCols x T>
+
+ llvm::MatrixBuilder MB(Builder);
+
+ for (unsigned Col = 0; Col < NumCols; ++Col) {
+ llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+ llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+ llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
+
+ llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+ Result = Builder.CreateInsertElement(Result, Elt, Lane);
+ }
+
+ return RValue::get(Result);
+ }
assert(LV.isBitField() && "Unknown LValue type!");
return EmitLoadOfBitfieldLValue(LV, Loc);
@@ -2662,6 +2693,36 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
addInstToCurrentSourceAtom(I, Vec);
return;
}
+ if (Dst.isMatrixRow()) {
+ QualType MatTy = Dst.getType();
+ const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+ unsigned NumRows = MT->getNumRows();
+ unsigned NumCols = MT->getNumColumns();
+
+ llvm::Value *MatrixVec =
+ Builder.CreateLoad(Dst.getAddress(), "matrix.load");
+
+ llvm::Value *Row = Dst.getMatrixRowIdx();
+ llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
+
+ llvm::MatrixBuilder MB(Builder);
+
+ for (unsigned Col = 0; Col < NumCols; ++Col) {
+ llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+ llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+ llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+ llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
+
+ MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
+ }
+
+ Builder.CreateStore(MatrixVec, Dst.getAddress());
+ return;
+ }
assert(Dst.isBitField() && "Unknown LValue type");
return EmitStoreThroughBitfieldLValue(Src, Dst);
@@ -4874,6 +4935,35 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const Expr *E) {
return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned);
}
+LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
+ const MatrixSingleSubscriptExpr *E) {
+ LValue Base = EmitLValue(E->getBase());
+ llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
+
+ if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) {
+
+ // Extract matrix shape from the AST type
+ const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+ unsigned NumCols = MatTy->getNumColumns();
+ llvm::SmallVector<llvm::Constant *, 8> Indices;
+ Indices.reserve(NumCols);
+
+ unsigned Row = RowConst->getZExtValue();
+ unsigned Start = Row * NumCols;
+ for (unsigned C = 0; C < NumCols; ++C) {
+ Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C));
+ }
+ llvm::Constant *Elts = llvm::ConstantVector::get(Indices);
+ return LValue::MakeExtVectorElt(
+ MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts,
+ E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+ }
+
+ return LValue::MakeMatrixRow(
+ MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
+ E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+}
+
LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
assert(
!E->isIncomplete() &&
@@ -5146,6 +5236,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
Base.getBaseInfo(), TBAAAccessInfo());
}
+ if (Base.isMatrixRow())
+ return EmitUnsupportedLValue(E, "Matrix single index swizzle");
+
assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");
llvm::Constant *BaseElts = Base.getExtVectorElts();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 769bc37b0e131..70397e8cb99c2 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -599,6 +599,7 @@ class ScalarExprEmitter
}
Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E);
+ Value *VisitMatrixSingleSubscriptExpr(MatrixSingleSubscriptExpr *E);
Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E);
Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E);
Value *VisitConvertVectorExpr(ConvertVectorExpr *E);
@@ -2109,6 +2110,40 @@ Value *ScalarExprEmitter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
return Builder.CreateExtractElement(Base, Idx, "vecext");
}
+Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
+ MatrixSingleSubscriptExpr *E) {
+ TestAndClearIgnoreResultAssign();
+
+ auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+ unsigned NumRows = MatrixTy->getNumRows();
+ unsigned NumColumns = MatrixTy->getNumColumns();
+
+ // Row index
+ Value *RowIdx = CGF.EmitMatrixIndexExpr(E->getRowIdx());
+
+ llvm::MatrixBuilder MB(Builder);
+
+ // The row index must be in [0, NumRows)
+ if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
+ MB.CreateIndexAssumption(RowIdx, NumRows);
+
+ Value *FlatMatrix = Visit(E->getBase());
+ llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType());
+ auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns);
+ Value *RowVec = llvm::UndefValue::get(ResultTy);
+
+ for (unsigned Col = 0; Col != NumColumns; ++Col) {
+ Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
+ Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
+ Value *Elt =
+ Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
+ Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+ RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins");
+ }
+
+ return RowVec;
+}
+
Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
TestAndClearIgnoreResultAssign();
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index 6b381b59e71cd..c08ca70de10e1 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -187,7 +187,8 @@ class LValue {
BitField, // This is a bitfield l-value, use getBitfield*.
ExtVectorElt, // This is an extended vector subset, use getExtVectorComp
GlobalReg, // This is a register l-value, use getGlobalReg()
- MatrixElt // This is a matrix element, use getVector*
+ MatrixElt, // This is a matrix element, use getVector*
+ MatrixRow // This is a matrix vector subset, use getVector*
} LVType;
union {
@@ -282,6 +283,7 @@ class LValue {
bool isExtVectorElt() const { return LVType == ExtVectorElt; }
bool isGlobalReg() const { return LVType == GlobalReg; }
bool isMatrixElt() const { return LVType == MatrixElt; }
+ bool isMatrixRow() const { return LVType == MatrixRow; }
bool isVolatileQualified() const { return Quals.hasVolatile(); }
bool isRestrictQualified() const { return Quals.hasRestrict(); }
@@ -398,6 +400,11 @@ class LValue {
return VectorIdx;
}
+ llvm::Value *getMatrixRowIdx() const {
+ assert(isMatrixRow());
+ return VectorIdx;
+ }
+
// extended vector elements.
Address getExtVectorAddress() const {
assert(isExtVectorElt());
@@ -486,6 +493,16 @@ class LValue {
return R;
}
+ static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx,
+ QualType MatrixTy, LValueBaseInfo BaseInfo,
+ TBAAAccessInfo TBAAInfo) {
+ LValue LV;
+ LV.LVType = MatrixRow;
+ LV.VectorIdx = RowIdx; // store the row index here
+ LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, TBAAInfo);
+ return LV;
+ }
+
static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx,
QualType type, LValueBaseInfo BaseInfo,
TBAAAccessInfo TBAAInfo) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 8c4c1c8c2dc95..3abe516debcb0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4412,6 +4412,7 @@ class CodeGenFunction : public CodeGenTypeCache {
LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
bool Accessed = false);
llvm::Value *EmitMatrixIndexExpr(const Expr *E);
+ LValue EmitMatrixSingleSubscriptExpr(const MatrixSingleSubscriptExpr *E);
LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E);
LValue EmitArraySectionExpr(const ArraySect...
[truncated]
|
|
✅ With the latest revision this PR passed the undef deprecator. |
| GlobalReg, // This is a register l-value, use getGlobalReg() | ||
| MatrixElt // This is a matrix element, use getVector* | ||
| MatrixElt, // This is a matrix element, use getVector* | ||
| MatrixRow // This is a matrix vector subset, use getVector* |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MatrixRow to me implies a vector but open to a better name.
37cc571 to
4c112b9
Compare
🐧 Linux x64 Test Results
✅ The build succeeded and all tests passed. |
4c112b9 to
ff98397
Compare
hekota
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add tests for the template transformation?
clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl
Show resolved
Hide resolved
clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl
Outdated
Show resolved
Hide resolved
fixes llvm#166206 - Add swizzle support if row index is constant - Add test cases - Add new AST type - Add new LValue for Matrix Row Type - TODO: Make the new LValue a dynamic index version of ExtVectorElt
ff98397 to
e2a3247
Compare
🪟 Windows x64 Test Results
✅ The build succeeded and all tests passed. |
e2a3247 to
82e2bc0
Compare
a48169b to
0f4584c
Compare
0f4584c to
b032922
Compare
| OK_ObjCSubscript, | ||
|
|
||
| /// A matrix component is a single element of a matrix. | ||
| /// A matrix component is a single element or range of elements on a matrix. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| /// A matrix component is a single element or range of elements on a matrix. | |
| /// A matrix component is a single element or range of elements of a matrix. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just copied the docs for vector /// A vector component is an element or range of elements of a vector. But sure I'll change both
clang/docs/MatrixTypes.rst
Outdated
| A single subscript expression into a matrix is legal in HLSL and yields the | ||
| column-sized vector for the selected row, but is ill-formed in C and C++. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I find "column-sized vector" to be a bit ambiguous in meaning.
I believe you intended for "column-sized vector" to mean a vector whose number of elements is equal to the number of columns in the matrix, but "column-sized vector" could also mean the vector has the same number of elements as one column of the matrix due to sounding very close to the term "column vector" in Linear Algebra.
I would change the description to say "A single subscript expression into a matrix... yields a row vector containing the elements of the selected row"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I dunno more wordy than I wanted and doesn't specify vector size. Maybe something like
denotes a vector whose length equals the matrix’s column dimension for the specified row lane
OR if we think vector size is implied just:
denotes a vector for the specified row lane
fixes #166206