Skip to content

Commit c64e1ce

Browse files
committed
[HLSL][Matrix] Add support for single subscript accessor
fixes #166206 - Add swizzle support if row index is constant - Add test cases - Add new AST type - Add new LValue for Matrix Row Type - TODO: Make the new LValue a dynamic index version of ExtVectorElt
1 parent 1bd0ec4 commit c64e1ce

33 files changed

+992
-2
lines changed

clang/include/clang/AST/ComputeDependence.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class ParenExpr;
2828
class UnaryOperator;
2929
class UnaryExprOrTypeTraitExpr;
3030
class ArraySubscriptExpr;
31+
class MatrixSingleSubscriptExpr;
3132
class MatrixSubscriptExpr;
3233
class CompoundLiteralExpr;
3334
class ImplicitCastExpr;
@@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
117118
ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
118119
ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
119120
ExprDependence computeDependence(ArraySubscriptExpr *E);
121+
ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
120122
ExprDependence computeDependence(MatrixSubscriptExpr *E);
121123
ExprDependence computeDependence(CompoundLiteralExpr *E);
122124
ExprDependence computeDependence(ImplicitCastExpr *E);

clang/include/clang/AST/Expr.h

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr {
27902790
}
27912791
};
27922792

2793+
/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
2794+
/// MatrixType extension when you want to get\set a vector from a Matrix.
2795+
class MatrixSingleSubscriptExpr : public Expr {
2796+
enum { BASE, ROW_IDX, END_EXPR };
2797+
Stmt *SubExprs[END_EXPR];
2798+
2799+
public:
2800+
/// matrix[row]
2801+
///
2802+
/// \param Base The matrix expression.
2803+
/// \param RowIdx The row index expression.
2804+
/// \param T The type of the row (usually a vector type).
2805+
/// \param RBracketLoc Location of the closing ']'.
2806+
MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
2807+
SourceLocation RBracketLoc)
2808+
: Expr(MatrixSingleSubscriptExprClass, T,
2809+
Base->getValueKind(), // lvalue/rvalue follows the matrix base
2810+
OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you
2811+
// prefer
2812+
SubExprs[BASE] = Base;
2813+
SubExprs[ROW_IDX] = RowIdx;
2814+
ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
2815+
setDependence(computeDependence(this));
2816+
}
2817+
2818+
/// Create an empty matrix single-subscript expression.
2819+
explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
2820+
: Expr(MatrixSingleSubscriptExprClass, Shell) {}
2821+
2822+
Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
2823+
const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
2824+
void setBase(Expr *E) { SubExprs[BASE] = E; }
2825+
2826+
Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
2827+
const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
2828+
void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }
2829+
2830+
SourceLocation getBeginLoc() const LLVM_READONLY {
2831+
return getBase()->getBeginLoc();
2832+
}
2833+
2834+
SourceLocation getEndLoc() const { return getRBracketLoc(); }
2835+
2836+
SourceLocation getExprLoc() const LLVM_READONLY {
2837+
return getBase()->getExprLoc();
2838+
}
2839+
2840+
SourceLocation getRBracketLoc() const {
2841+
return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
2842+
}
2843+
void setRBracketLoc(SourceLocation L) {
2844+
ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
2845+
}
2846+
2847+
static bool classof(const Stmt *T) {
2848+
return T->getStmtClass() == MatrixSingleSubscriptExprClass;
2849+
}
2850+
2851+
// Iterators
2852+
child_range children() {
2853+
return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
2854+
}
2855+
const_child_range children() const {
2856+
return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
2857+
}
2858+
};
2859+
27932860
/// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
27942861
/// extension.
27952862
/// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2894,6 +2894,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
28942894
// over the children.
28952895
DEF_TRAVERSE_STMT(AddrLabelExpr, {})
28962896
DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
2897+
DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
28972898
DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
28982899
DEF_TRAVERSE_STMT(ArraySectionExpr, {})
28992900
DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})

clang/include/clang/AST/Stmt.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,7 @@ class alignas(void *) Stmt {
540540
class ArrayOrMatrixSubscriptExprBitfields {
541541
friend class ArraySubscriptExpr;
542542
friend class MatrixSubscriptExpr;
543+
friend class MatrixSingleSubscriptExpr;
543544

544545
LLVM_PREFERRED_TYPE(ExprBitfields)
545546
unsigned : NumExprBits;

clang/include/clang/Basic/StmtNodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def UnaryOperator : StmtNode<Expr>;
7575
def OffsetOfExpr : StmtNode<Expr>;
7676
def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
7777
def ArraySubscriptExpr : StmtNode<Expr>;
78+
def MatrixSingleSubscriptExpr : StmtNode<Expr>;
7879
def MatrixSubscriptExpr : StmtNode<Expr>;
7980
def ArraySectionExpr : StmtNode<Expr>;
8081
def OMPIteratorExpr : StmtNode<Expr>;

clang/include/clang/Sema/Sema.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7405,6 +7405,9 @@ class Sema final : public SemaBase {
74057405
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
74067406
Expr *Idx, SourceLocation RLoc);
74077407

7408+
ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
7409+
SourceLocation RBLoc);
7410+
74087411
ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
74097412
Expr *ColumnIdx,
74107413
SourceLocation RBLoc);

clang/lib/AST/ComputeDependence.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) {
115115
return E->getLHS()->getDependence() | E->getRHS()->getDependence();
116116
}
117117

118+
ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
119+
return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
120+
}
121+
118122
ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
119123
return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
120124
(E->getColumnIdx() ? E->getColumnIdx()->getDependence()

clang/lib/AST/Expr.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3792,6 +3792,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
37923792

37933793
case ParenExprClass:
37943794
case ArraySubscriptExprClass:
3795+
case MatrixSingleSubscriptExprClass:
37953796
case MatrixSubscriptExprClass:
37963797
case ArraySectionExprClass:
37973798
case OMPArrayShapingExprClass:

clang/lib/AST/ExprClassification.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
259259
}
260260
return Cl::CL_LValue;
261261

262+
case Expr::MatrixSingleSubscriptExprClass:
263+
return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase());
264+
262265
// Subscripting matrix types behaves like member accesses.
263266
case Expr::MatrixSubscriptExprClass:
264267
return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());

clang/lib/AST/ExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20887,6 +20887,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
2088720887
case Expr::ImaginaryLiteralClass:
2088820888
case Expr::StringLiteralClass:
2088920889
case Expr::ArraySubscriptExprClass:
20890+
case Expr::MatrixSingleSubscriptExprClass:
2089020891
case Expr::MatrixSubscriptExprClass:
2089120892
case Expr::ArraySectionExprClass:
2089220893
case Expr::OMPArrayShapingExprClass:

0 commit comments

Comments
 (0)