Skip to content

Conversation

@farzonl
Copy link
Member

@farzonl farzonl commented Dec 5, 2025

fixes #166206

  • Add swizzle support if row index is constant
  • Add test cases
  • Add new AST type
  • Add new LValue for Matrix Row Type
  • TODO: Make the new LValue a dynamic index version of ExtVectorElt

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:codegen IR generation bugs: mangling, exceptions, etc. clang:as-a-library libclang and C++ API clang:static analyzer HLSL HLSL Language Support labels Dec 5, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2025

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang-modules
@llvm/pr-subscribers-clang-static-analyzer-1

@llvm/pr-subscribers-clang

Author: Farzon Lotfi (farzonl)

Changes

fixes #166206

  • Add swizzle support if row index is constant
  • Add test cases
  • Add new AST type
  • Add new LValue for Matrix Row Type
  • TODO: Make the new LValue a dynamic index version of ExtVectorElt

Patch is 72.24 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/170779.diff

33 Files Affected:

  • (modified) clang/include/clang/AST/ComputeDependence.h (+2)
  • (modified) clang/include/clang/AST/Expr.h (+67)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1)
  • (modified) clang/include/clang/AST/Stmt.h (+1)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+1)
  • (modified) clang/include/clang/Sema/Sema.h (+3)
  • (modified) clang/lib/AST/ComputeDependence.cpp (+4)
  • (modified) clang/lib/AST/Expr.cpp (+1)
  • (modified) clang/lib/AST/ExprClassification.cpp (+3)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+9)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+8)
  • (modified) clang/lib/AST/StmtProfile.cpp (+5)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+93)
  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+35)
  • (modified) clang/lib/CodeGen/CGValue.h (+18-1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+60-1)
  • (modified) clang/lib/Sema/TreeTransform.h (+29)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+8)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+9)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+5)
  • (added) clang/test/AST/HLSL/matrix-single-subscript-getter.hlsl (+77)
  • (added) clang/test/AST/HLSL/matrix-single-subscript-setter.hlsl (+59)
  • (added) clang/test/AST/HLSL/matrix-single-subscript-swizzle.hlsl (+56)
  • (added) clang/test/AST/HLSL/pch_with_matrix_single_subscript.hlsl (+16)
  • (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptConstSwizzle.hlsl (+60)
  • (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptDynamicSwizzle.hlsl (+10)
  • (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl (+205)
  • (added) clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl (+126)
  • (added) clang/test/SemaHLSL/matrix_single_subscript_errors.hlsl (+12)
  • (modified) clang/tools/libclang/CXCursor.cpp (+5)
diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..895105640b931 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -28,6 +28,7 @@ class ParenExpr;
 class UnaryOperator;
 class UnaryExprOrTypeTraitExpr;
 class ArraySubscriptExpr;
+class MatrixSingleSubscriptExpr;
 class MatrixSubscriptExpr;
 class CompoundLiteralExpr;
 class ImplicitCastExpr;
@@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
 ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
 ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
 ExprDependence computeDependence(ArraySubscriptExpr *E);
+ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
 ExprDependence computeDependence(MatrixSubscriptExpr *E);
 ExprDependence computeDependence(CompoundLiteralExpr *E);
 ExprDependence computeDependence(ImplicitCastExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..16d9bbe8ff7c1 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -2790,6 +2790,73 @@ class ArraySubscriptExpr : public Expr {
   }
 };
 
+/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
+/// MatrixType extension when you want to get\set a vector from a Matrix.
+class MatrixSingleSubscriptExpr : public Expr {
+  enum { BASE, ROW_IDX, END_EXPR };
+  Stmt *SubExprs[END_EXPR];
+
+public:
+  /// matrix[row]
+  ///
+  /// \param Base        The matrix expression.
+  /// \param RowIdx      The row index expression.
+  /// \param T           The type of the row (usually a vector type).
+  /// \param RBracketLoc Location of the closing ']'.
+  MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
+                            SourceLocation RBracketLoc)
+      : Expr(MatrixSingleSubscriptExprClass, T,
+             Base->getValueKind(), // lvalue/rvalue follows the matrix base
+             OK_MatrixComponent) { // or OK_Ordinary/OK_VectorComponent if you
+                                   // prefer
+    SubExprs[BASE] = Base;
+    SubExprs[ROW_IDX] = RowIdx;
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
+    setDependence(computeDependence(this));
+  }
+
+  /// Create an empty matrix single-subscript expression.
+  explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
+      : Expr(MatrixSingleSubscriptExprClass, Shell) {}
+
+  Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
+  const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
+  void setBase(Expr *E) { SubExprs[BASE] = E; }
+
+  Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
+  const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
+  void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }
+
+  SourceLocation getBeginLoc() const LLVM_READONLY {
+    return getBase()->getBeginLoc();
+  }
+
+  SourceLocation getEndLoc() const { return getRBracketLoc(); }
+
+  SourceLocation getExprLoc() const LLVM_READONLY {
+    return getBase()->getExprLoc();
+  }
+
+  SourceLocation getRBracketLoc() const {
+    return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
+  }
+  void setRBracketLoc(SourceLocation L) {
+    ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
+  }
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == MatrixSingleSubscriptExprClass;
+  }
+
+  // Iterators
+  child_range children() {
+    return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+  const_child_range children() const {
+    return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
+  }
+};
+
 /// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
 /// extension.
 /// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..92409b72e4f0c 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2893,6 +2893,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
 // over the children.
 DEF_TRAVERSE_STMT(AddrLabelExpr, {})
 DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
+DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
 DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
 DEF_TRAVERSE_STMT(ArraySectionExpr, {})
 DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})
diff --git a/clang/include/clang/AST/Stmt.h b/clang/include/clang/AST/Stmt.h
index e1cca34d2212c..21d0a7dfe577c 100644
--- a/clang/include/clang/AST/Stmt.h
+++ b/clang/include/clang/AST/Stmt.h
@@ -530,6 +530,7 @@ class alignas(void *) Stmt {
   class ArrayOrMatrixSubscriptExprBitfields {
     friend class ArraySubscriptExpr;
     friend class MatrixSubscriptExpr;
+    friend class MatrixSingleSubscriptExpr;
 
     LLVM_PREFERRED_TYPE(ExprBitfields)
     unsigned : NumExprBits;
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..ada74807e56e2 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -74,6 +74,7 @@ def UnaryOperator : StmtNode<Expr>;
 def OffsetOfExpr : StmtNode<Expr>;
 def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
 def ArraySubscriptExpr : StmtNode<Expr>;
+def MatrixSingleSubscriptExpr : StmtNode<Expr>;
 def MatrixSubscriptExpr : StmtNode<Expr>;
 def ArraySectionExpr : StmtNode<Expr>;
 def OMPIteratorExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 4a601a0eaf1b9..d4d5c3d8bed17 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -7406,6 +7406,9 @@ class Sema final : public SemaBase {
   ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
                                              Expr *Idx, SourceLocation RLoc);
 
+  ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
+                                                    SourceLocation RBLoc);
+
   ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
                                               Expr *ColumnIdx,
                                               SourceLocation RBLoc);
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..8429f17d26be5 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) {
   return E->getLHS()->getDependence() | E->getRHS()->getDependence();
 }
 
+ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
+  return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
+}
+
 ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
   return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
          (E->getColumnIdx() ? E->getColumnIdx()->getDependence()
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..b400b2a083d9b 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -3789,6 +3789,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
 
   case ParenExprClass:
   case ArraySubscriptExprClass:
+  case MatrixSingleSubscriptExprClass:
   case MatrixSubscriptExprClass:
   case ArraySectionExprClass:
   case OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..9995d1b411c5b 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
     }
     return Cl::CL_LValue;
 
+  case Expr::MatrixSingleSubscriptExprClass:
+    return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase());
+
   // Subscripting matrix types behaves like member accesses.
   case Expr::MatrixSubscriptExprClass:
     return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 11c5e1c6e90f4..52481dc71b75d 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -20667,6 +20667,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
   case Expr::ImaginaryLiteralClass:
   case Expr::StringLiteralClass:
   case Expr::ArraySubscriptExprClass:
+  case Expr::MatrixSingleSubscriptExprClass:
   case Expr::MatrixSubscriptExprClass:
   case Expr::ArraySectionExprClass:
   case Expr::OMPArrayShapingExprClass:
diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp
index 5572e0a7ae59c..cb71987fba766 100644
--- a/clang/lib/AST/ItaniumMangle.cpp
+++ b/clang/lib/AST/ItaniumMangle.cpp
@@ -5482,6 +5482,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
     break;
   }
 
+  case Expr::MatrixSingleSubscriptExprClass: {
+    NotPrimaryExpr();
+    const MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E);
+    Out << "ix";
+    mangleExpression(ME->getBase());
+    mangleExpression(ME->getRowIdx());
+    break;
+  }
+
   case Expr::MatrixSubscriptExprClass: {
     NotPrimaryExpr();
     const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E);
diff --git a/clang/lib/AST/StmtPrinter.cpp b/clang/lib/AST/StmtPrinter.cpp
index ff8ca01ec5477..51b9c47f22ff0 100644
--- a/clang/lib/AST/StmtPrinter.cpp
+++ b/clang/lib/AST/StmtPrinter.cpp
@@ -1685,6 +1685,14 @@ void StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) {
   OS << "]";
 }
 
+void StmtPrinter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *Node) {
+  PrintExpr(Node->getBase());
+  OS << "[";
+  PrintExpr(Node->getRowIdx());
+  OS << "]";
+}
+
 void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) {
   PrintExpr(Node->getBase());
   OS << "[";
diff --git a/clang/lib/AST/StmtProfile.cpp b/clang/lib/AST/StmtProfile.cpp
index 4a8c638c85331..c7b7c65715dfc 100644
--- a/clang/lib/AST/StmtProfile.cpp
+++ b/clang/lib/AST/StmtProfile.cpp
@@ -1508,6 +1508,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const ArraySubscriptExpr *S) {
   VisitExpr(S);
 }
 
+void StmtProfiler::VisitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *S) {
+  VisitExpr(S);
+}
+
 void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) {
   VisitExpr(S);
 }
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index e842158236cd4..ca06b5df94cb3 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -1796,6 +1796,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
     return EmitUnaryOpLValue(cast<UnaryOperator>(E));
   case Expr::ArraySubscriptExprClass:
     return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
+  case Expr::MatrixSingleSubscriptExprClass:
+    return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E));
   case Expr::MatrixSubscriptExprClass:
     return EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E));
   case Expr::ArraySectionExprClass:
@@ -2440,6 +2442,35 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
         Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
     return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext"));
   }
+  if (LV.isMatrixRow()) {
+    QualType MatTy = LV.getType();
+    const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+    unsigned NumRows = MT->getNumRows();
+    unsigned NumCols = MT->getNumColumns();
+
+    llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc);
+
+    llvm::Value *Row = LV.getMatrixRowIdx();
+    llvm::Value *Result =
+        llvm::UndefValue::get(ConvertType(LV.getType())); // <NumCols x T>
+
+    llvm::MatrixBuilder MB(Builder);
+
+    for (unsigned Col = 0; Col < NumCols; ++Col) {
+      llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+      llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+      llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
+
+      llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+      Result = Builder.CreateInsertElement(Result, Elt, Lane);
+    }
+
+    return RValue::get(Result);
+  }
 
   assert(LV.isBitField() && "Unknown LValue type!");
   return EmitLoadOfBitfieldLValue(LV, Loc);
@@ -2662,6 +2693,36 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
       addInstToCurrentSourceAtom(I, Vec);
       return;
     }
+    if (Dst.isMatrixRow()) {
+      QualType MatTy = Dst.getType();
+      const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();
+
+      unsigned NumRows = MT->getNumRows();
+      unsigned NumCols = MT->getNumColumns();
+
+      llvm::Value *MatrixVec =
+          Builder.CreateLoad(Dst.getAddress(), "matrix.load");
+
+      llvm::Value *Row = Dst.getMatrixRowIdx();
+      llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
+
+      llvm::MatrixBuilder MB(Builder);
+
+      for (unsigned Col = 0; Col < NumCols; ++Col) {
+        llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
+
+        llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
+
+        llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+
+        llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
+
+        MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
+      }
+
+      Builder.CreateStore(MatrixVec, Dst.getAddress());
+      return;
+    }
 
     assert(Dst.isBitField() && "Unknown LValue type");
     return EmitStoreThroughBitfieldLValue(Src, Dst);
@@ -4874,6 +4935,35 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const Expr *E) {
   return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned);
 }
 
+LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
+    const MatrixSingleSubscriptExpr *E) {
+  LValue Base = EmitLValue(E->getBase());
+  llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());
+
+  if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) {
+
+    // Extract matrix shape from the AST type
+    const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+    unsigned NumCols = MatTy->getNumColumns();
+    llvm::SmallVector<llvm::Constant *, 8> Indices;
+    Indices.reserve(NumCols);
+
+    unsigned Row = RowConst->getZExtValue();
+    unsigned Start = Row * NumCols;
+    for (unsigned C = 0; C < NumCols; ++C) {
+      Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C));
+    }
+    llvm::Constant *Elts = llvm::ConstantVector::get(Indices);
+    return LValue::MakeExtVectorElt(
+        MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts,
+        E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+  }
+
+  return LValue::MakeMatrixRow(
+      MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
+      E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
+}
+
 LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
   assert(
       !E->isIncomplete() &&
@@ -5146,6 +5236,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
     return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
                                     Base.getBaseInfo(), TBAAAccessInfo());
   }
+  if (Base.isMatrixRow())
+    return EmitUnsupportedLValue(E, "Matrix single index swizzle");
+
   assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");
 
   llvm::Constant *BaseElts = Base.getExtVectorElts();
diff --git a/clang/lib/CodeGen/CGExprScalar.cpp b/clang/lib/CodeGen/CGExprScalar.cpp
index 769bc37b0e131..70397e8cb99c2 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -599,6 +599,7 @@ class ScalarExprEmitter
   }
 
   Value *VisitArraySubscriptExpr(ArraySubscriptExpr *E);
+  Value *VisitMatrixSingleSubscriptExpr(MatrixSingleSubscriptExpr *E);
   Value *VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E);
   Value *VisitShuffleVectorExpr(ShuffleVectorExpr *E);
   Value *VisitConvertVectorExpr(ConvertVectorExpr *E);
@@ -2109,6 +2110,40 @@ Value *ScalarExprEmitter::VisitArraySubscriptExpr(ArraySubscriptExpr *E) {
   return Builder.CreateExtractElement(Base, Idx, "vecext");
 }
 
+Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
+    MatrixSingleSubscriptExpr *E) {
+  TestAndClearIgnoreResultAssign();
+
+  auto *MatrixTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
+  unsigned NumRows = MatrixTy->getNumRows();
+  unsigned NumColumns = MatrixTy->getNumColumns();
+
+  // Row index
+  Value *RowIdx = CGF.EmitMatrixIndexExpr(E->getRowIdx());
+
+  llvm::MatrixBuilder MB(Builder);
+
+  // The row index must be in [0, NumRows)
+  if (CGF.CGM.getCodeGenOpts().OptimizationLevel > 0)
+    MB.CreateIndexAssumption(RowIdx, NumRows);
+
+  Value *FlatMatrix = Visit(E->getBase());
+  llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType());
+  auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns);
+  Value *RowVec = llvm::UndefValue::get(ResultTy);
+
+  for (unsigned Col = 0; Col != NumColumns; ++Col) {
+    Value *ColVal = llvm::ConstantInt::get(RowIdx->getType(), Col);
+    Value *EltIdx = MB.CreateIndex(RowIdx, ColVal, NumRows, "matrix_row_idx");
+    Value *Elt =
+        Builder.CreateExtractElement(FlatMatrix, EltIdx, "matrix_elem");
+    Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
+    RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins");
+  }
+
+  return RowVec;
+}
+
 Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {
   TestAndClearIgnoreResultAssign();
 
diff --git a/clang/lib/CodeGen/CGValue.h b/clang/lib/CodeGen/CGValue.h
index 6b381b59e71cd..c08ca70de10e1 100644
--- a/clang/lib/CodeGen/CGValue.h
+++ b/clang/lib/CodeGen/CGValue.h
@@ -187,7 +187,8 @@ class LValue {
     BitField,     // This is a bitfield l-value, use getBitfield*.
     ExtVectorElt, // This is an extended vector subset, use getExtVectorComp
     GlobalReg,    // This is a register l-value, use getGlobalReg()
-    MatrixElt     // This is a matrix element, use getVector*
+    MatrixElt,    // This is a matrix element, use getVector*
+    MatrixRow     // This is a matrix vector subset, use getVector*
   } LVType;
 
   union {
@@ -282,6 +283,7 @@ class LValue {
   bool isExtVectorElt() const { return LVType == ExtVectorElt; }
   bool isGlobalReg() const { return LVType == GlobalReg; }
   bool isMatrixElt() const { return LVType == MatrixElt; }
+  bool isMatrixRow() const { return LVType == MatrixRow; }
 
   bool isVolatileQualified() const { return Quals.hasVolatile(); }
   bool isRestrictQualified() const { return Quals.hasRestrict(); }
@@ -398,6 +400,11 @@ class LValue {
     return VectorIdx;
   }
 
+  llvm::Value *getMatrixRowIdx() const {
+    assert(isMatrixRow());
+    return VectorIdx;
+  }
+
   // extended vector elements.
   Address getExtVectorAddress() const {
     assert(isExtVectorElt());
@@ -486,6 +493,16 @@ class LValue {
     return R;
   }
 
+  static LValue MakeMatrixRow(Address Addr, llvm::Value *RowIdx,
+                              QualType MatrixTy, LValueBaseInfo BaseInfo,
+                              TBAAAccessInfo TBAAInfo) {
+    LValue LV;
+    LV.LVType = MatrixRow;
+    LV.VectorIdx = RowIdx; // store the row index here
+    LV.Initialize(MatrixTy, MatrixTy.getQualifiers(), Addr, BaseInfo, TBAAInfo);
+    return LV;
+  }
+
   static LValue MakeMatrixElt(Address matAddress, llvm::Value *Idx,
                               QualType type, LValueBaseInfo BaseInfo,
                               TBAAAccessInfo TBAAInfo) {
diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h
index 8c4c1c8c2dc95..3abe516debcb0 100644
--- a/clang/lib/CodeGen/CodeGenFunction.h
+++ b/clang/lib/CodeGen/CodeGenFunction.h
@@ -4412,6 +4412,7 @@ class CodeGenFunction : public CodeGenTypeCache {
   LValue EmitArraySubscriptExpr(const ArraySubscriptExpr *E,
                                 bool Accessed = false);
   llvm::Value *EmitMatrixIndexExpr(const Expr *E);
+  LValue EmitMatrixSingleSubscriptExpr(const MatrixSingleSubscriptExpr *E);
   LValue EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E);
   LValue EmitArraySectionExpr(const ArraySect...
[truncated]

@github-actions
Copy link

github-actions bot commented Dec 5, 2025

✅ With the latest revision this PR passed the undef deprecator.

GlobalReg, // This is a register l-value, use getGlobalReg()
MatrixElt // This is a matrix element, use getVector*
MatrixElt, // This is a matrix element, use getVector*
MatrixRow // This is a matrix vector subset, use getVector*
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MatrixRow to me implies a vector but open to a better name.

@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from 37cc571 to 4c112b9 Compare December 5, 2025 00:36
@github-actions
Copy link

github-actions bot commented Dec 5, 2025

🐧 Linux x64 Test Results

  • 112121 tests passed
  • 4524 tests skipped

✅ The build succeeded and all tests passed.

@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from 4c112b9 to ff98397 Compare December 5, 2025 01:09
Copy link
Member

@hekota hekota left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add tests for the template transformation?

fixes llvm#166206

- Add swizzle support if row index is constant
- Add test cases
- Add new AST type
- Add new LValue for Matrix Row Type
- TODO: Make the new LValue a dynamic index version of ExtVectorElt
@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from ff98397 to e2a3247 Compare December 16, 2025 17:59
@github-actions
Copy link

github-actions bot commented Dec 16, 2025

🪟 Windows x64 Test Results

  • 53091 tests passed
  • 2084 tests skipped

✅ The build succeeded and all tests passed.

@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from e2a3247 to 82e2bc0 Compare December 16, 2025 19:45
@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from a48169b to 0f4584c Compare December 17, 2025 18:31
@farzonl farzonl force-pushed the feature/matrix_single_subscript-166206 branch from 0f4584c to b032922 Compare December 17, 2025 19:43
OK_ObjCSubscript,

/// A matrix component is a single element of a matrix.
/// A matrix component is a single element or range of elements on a matrix.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// A matrix component is a single element or range of elements on a matrix.
/// A matrix component is a single element or range of elements of a matrix.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just copied the docs for vector /// A vector component is an element or range of elements of a vector. But sure I'll change both

Comment on lines 109 to 110
A single subscript expression into a matrix is legal in HLSL and yields the
column-sized vector for the selected row, but is ill-formed in C and C++.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I find "column-sized vector" to be a bit ambiguous in meaning.

I believe you intended for "column-sized vector" to mean a vector whose number of elements is equal to the number of columns in the matrix, but "column-sized vector" could also mean the vector has the same number of elements as one column of the matrix due to sounding very close to the term "column vector" in Linear Algebra.

I would change the description to say "A single subscript expression into a matrix... yields a row vector containing the elements of the selected row"

Copy link
Member Author

@farzonl farzonl Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dunno more wordy than I wanted and doesn't specify vector size. Maybe something like

  • denotes a vector whose length equals the matrix’s column dimension for the specified row lane

OR if we think vector size is implied just:

  • denotes a vector for the specified row lane

@farzonl farzonl merged commit 60b6c53 into llvm:main Dec 17, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:as-a-library libclang and C++ API clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:static analyzer clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Clang needs to support single subscript operators on matrix type

4 participants