Skip to content

Conversation

@farzonl
Copy link
Member

@farzonl farzonl commented Dec 8, 2025

fixes #159438

This patch adds MatrixElementExpr, a new AST node for HLSL matrix element and swizzle access (e.g. M._m00, M._11_22_33).

It introduces a shared ElementAccessExprBase used by both matrix and vector swizzle expressions, updates Sema to parse and validate zero-based and one-based accessors, detects duplicates for l-value checks, and emits improved diagnostics. CodeGen is updated to lower scalar and multi-element accesses consistently, and full AST serialization, dumping, and tooling support is included. This implementation reflects the updated RFC for HLSL matrix accessor semantics.

@github-actions
Copy link

github-actions bot commented Dec 8, 2025

🐧 Linux x64 Test Results

  • 112682 tests passed
  • 4095 tests skipped

✅ The build succeeded and all tests passed.

@farzonl farzonl force-pushed the feature/matrix-swizzle-issue-159438 branch from d37d179 to 6e259ed Compare December 9, 2025 00:25
++NumComponents;
}
if (NumComponents == 0 || NumComponents > 4) {
S.Diag(OpLoc, diag::err_hlsl_matrix_swizzle_invalid_length)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO add a test for this to clang/test/SemaHLSL/matrix-member-access-errors.hlsl with 5 components

@farzonl farzonl force-pushed the feature/matrix-swizzle-issue-159438 branch from 6e259ed to 33847da Compare December 9, 2025 17:13
@github-actions
Copy link

github-actions bot commented Dec 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@farzonl farzonl force-pushed the feature/matrix-swizzle-issue-159438 branch from 33847da to fcdbceb Compare December 9, 2025 17:16
@farzonl farzonl force-pushed the feature/matrix-swizzle-issue-159438 branch from fcdbceb to 06ce71f Compare December 9, 2025 21:44
@farzonl farzonl marked this pull request as ready for review December 11, 2025 21:29
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:codegen IR generation bugs: mangling, exceptions, etc. clang:as-a-library libclang and C++ API clang:static analyzer HLSL HLSL Language Support ClangIR Anything related to the ClangIR project labels Dec 11, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-clang-static-analyzer-1
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-clang-modules

Author: Farzon Lotfi (farzonl)

Changes

fixes #159438

This patch adds MatrixElementExpr, a new AST node for HLSL matrix element and swizzle access (e.g. M._m00, M._11_22_33).

It introduces a shared ElementAccessExprBase used by both matrix and vector swizzle expressions, updates Sema to parse and validate zero-based and one-based accessors, detects duplicates for l-value checks, and emits improved diagnostics. CodeGen is updated to lower scalar and multi-element accesses consistently, and full AST serialization, dumping, and tooling support is included. This implementation reflects the updated RFC for HLSL matrix accessor semantics.


Patch is 165.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171225.diff

43 Files Affected:

  • (modified) clang/include/clang/AST/ComputeDependence.h (+2)
  • (modified) clang/include/clang/AST/Expr.h (+68-39)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1)
  • (modified) clang/include/clang/AST/TextNodeDumper.h (+1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+1)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4)
  • (modified) clang/include/clang/Serialization/ASTBitCodes.h (+3)
  • (modified) clang/lib/AST/ComputeDependence.cpp (+4)
  • (modified) clang/lib/AST/Expr.cpp (+125-1)
  • (modified) clang/lib/AST/ExprClassification.cpp (+15)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+1)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+6)
  • (modified) clang/lib/AST/StmtProfile.cpp (+5)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+2)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+47)
  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+3)
  • (modified) clang/lib/Sema/SemaExprMember.cpp (+17)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+199)
  • (modified) clang/lib/Sema/TreeTransform.h (+25-6)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1)
  • (added) clang/test/AST/HLSL/matrix-member-access-scalar.hlsl (+38)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle-ast-dump-json.hlsl (+25)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle-ast-print.hlsl (+21)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle.hlsl (+49)
  • (added) clang/test/AST/HLSL/pch_with_matrix_element_accessor.hlsl (+26)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-accessor-scalar-load.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-accessor-scalar-store.hlsl (+345)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-swizzle-load.hlsl (+108)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-swizzle-store.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-accessor-scalar-load.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-accessor-scalar-store.hlsl (+345)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-swizzle-load.hlsl (+108)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-swizzle-store.hlsl (+230)
  • (added) clang/test/SemaHLSL/matrix-member-access-errors.hlsl (+28)
  • (modified) clang/tools/libclang/CXCursor.cpp (+1)
diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..16fdbfcac0864 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -44,6 +44,7 @@ class ArrayInitLoopExpr;
 class ImplicitValueInitExpr;
 class InitListExpr;
 class ExtVectorElementExpr;
+class MatrixElementExpr;
 class BlockExpr;
 class AsTypeExpr;
 class DeclRefExpr;
@@ -133,6 +134,7 @@ ExprDependence computeDependence(ArrayInitLoopExpr *E);
 ExprDependence computeDependence(ImplicitValueInitExpr *E);
 ExprDependence computeDependence(InitListExpr *E);
 ExprDependence computeDependence(ExtVectorElementExpr *E);
+ExprDependence computeDependence(MatrixElementExpr *E);
 ExprDependence computeDependence(BlockExpr *E,
                                  bool ContainsUnexpandedParameterPack);
 ExprDependence computeDependence(AsTypeExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..f0d9004a70cb7 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -291,6 +291,7 @@ class Expr : public ValueStmt {
     LV_NotObjectType,
     LV_IncompleteVoidType,
     LV_DuplicateVectorComponents,
+    LV_DuplicateMatrixComponents,
     LV_InvalidExpression,
     LV_InvalidMessageExpression,
     LV_MemberFunction,
@@ -306,8 +307,9 @@ class Expr : public ValueStmt {
     MLV_NotObjectType,
     MLV_IncompleteVoidType,
     MLV_DuplicateVectorComponents,
+    MLV_DuplicateMatrixComponents,
     MLV_InvalidExpression,
-    MLV_LValueCast,           // Specialized form of MLV_InvalidExpression.
+    MLV_LValueCast, // Specialized form of MLV_InvalidExpression.
     MLV_IncompleteType,
     MLV_ConstQualified,
     MLV_ConstQualifiedField,
@@ -340,16 +342,17 @@ class Expr : public ValueStmt {
     enum Kinds {
       CL_LValue,
       CL_XValue,
-      CL_Function, // Functions cannot be lvalues in C.
-      CL_Void, // Void cannot be an lvalue in C.
+      CL_Function,        // Functions cannot be lvalues in C.
+      CL_Void,            // Void cannot be an lvalue in C.
       CL_AddressableVoid, // Void expression whose address can be taken in C.
       CL_DuplicateVectorComponents, // A vector shuffle with dupes.
+      CL_DuplicateMatrixComponents, // A matrix shuffle with dupes.
       CL_MemberFunction, // An expression referring to a member function
       CL_SubObjCPropertySetting,
-      CL_ClassTemporary, // A temporary of class type, or subobject thereof.
-      CL_ArrayTemporary, // A temporary of array type.
+      CL_ClassTemporary,    // A temporary of class type, or subobject thereof.
+      CL_ArrayTemporary,    // A temporary of array type.
       CL_ObjCMessageRValue, // ObjC message is an rvalue
-      CL_PRValue // A prvalue for any other reason, of any other type
+      CL_PRValue            // A prvalue for any other reason, of any other type
     };
     /// The results of modification testing.
     enum ModifiableType {
@@ -6488,30 +6491,24 @@ class GenericSelectionExpr final
 // Clang Extensions
 //===----------------------------------------------------------------------===//
 
-/// ExtVectorElementExpr - This represents access to specific elements of a
-/// vector, and may occur on the left hand side or right hand side.  For example
-/// the following is legal:  "V.xy = V.zw" if V is a 4 element extended vector.
-///
-/// Note that the base may have either vector or pointer to vector type, just
-/// like a struct field reference.
-///
-class ExtVectorElementExpr : public Expr {
+template <class Derived> class ElementAccessExprBase : public Expr {
+protected:
   Stmt *Base;
   IdentifierInfo *Accessor;
   SourceLocation AccessorLoc;
-public:
-  ExtVectorElementExpr(QualType ty, ExprValueKind VK, Expr *base,
-                       IdentifierInfo &accessor, SourceLocation loc)
-      : Expr(ExtVectorElementExprClass, ty, VK,
-             (VK == VK_PRValue ? OK_Ordinary : OK_VectorComponent)),
-        Base(base), Accessor(&accessor), AccessorLoc(loc) {
-    setDependence(computeDependence(this));
+
+  ElementAccessExprBase(StmtClass SC, QualType Ty, ExprValueKind VK, Expr *Base,
+                        IdentifierInfo &Accessor, SourceLocation Loc,
+                        ExprObjectKind OK)
+      : Expr(SC, Ty, VK, OK), Base(Base), Accessor(&Accessor),
+        AccessorLoc(Loc) {
+    setDependence(computeDependence(static_cast<Derived *>(this)));
   }
 
-  /// Build an empty vector element expression.
-  explicit ExtVectorElementExpr(EmptyShell Empty)
-    : Expr(ExtVectorElementExprClass, Empty) { }
+  explicit ElementAccessExprBase(StmtClass SC, EmptyShell Empty)
+      : Expr(SC, Empty) {}
 
+public:
   const Expr *getBase() const { return cast<Expr>(Base); }
   Expr *getBase() { return cast<Expr>(Base); }
   void setBase(Expr *E) { Base = E; }
@@ -6522,22 +6519,40 @@ class ExtVectorElementExpr : public Expr {
   SourceLocation getAccessorLoc() const { return AccessorLoc; }
   void setAccessorLoc(SourceLocation L) { AccessorLoc = L; }
 
-  /// getNumElements - Get the number of components being selected.
-  unsigned getNumElements() const;
-
-  /// containsDuplicateElements - Return true if any element access is
-  /// repeated.
-  bool containsDuplicateElements() const;
-
-  /// getEncodedElementAccess - Encode the elements accessed into an llvm
-  /// aggregate Constant of ConstantInt(s).
-  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
-
   SourceLocation getBeginLoc() const LLVM_READONLY {
     return getBase()->getBeginLoc();
   }
   SourceLocation getEndLoc() const LLVM_READONLY { return AccessorLoc; }
 
+  child_range children() { return child_range(&Base, &Base + 1); }
+  const_child_range children() const {
+    return const_child_range(&Base, &Base + 1);
+  }
+};
+
+/// ExtVectorElementExpr - This represents access to specific elements of a
+/// vector, and may occur on the left hand side or right hand side.  For example
+/// the following is legal:  "V.xy = V.zw" if V is a 4 element extended vector.
+///
+/// Note that the base may have either vector or pointer to vector type, just
+/// like a struct field reference.
+///
+class ExtVectorElementExpr
+    : public ElementAccessExprBase<ExtVectorElementExpr> {
+public:
+  ExtVectorElementExpr(QualType Ty, ExprValueKind VK, Expr *Base,
+                       IdentifierInfo &Accessor, SourceLocation Loc)
+      : ElementAccessExprBase(
+            ExtVectorElementExprClass, Ty, VK, Base, Accessor, Loc,
+            (VK == VK_PRValue ? OK_Ordinary : OK_VectorComponent)) {}
+
+  explicit ExtVectorElementExpr(EmptyShell Empty)
+      : ElementAccessExprBase(ExtVectorElementExprClass, Empty) {}
+
+  unsigned getNumElements() const;
+  bool containsDuplicateElements() const;
+  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
+
   /// isArrow - Return true if the base expression is a pointer to vector,
   /// return false if the base expression is a vector.
   bool isArrow() const;
@@ -6545,11 +6560,25 @@ class ExtVectorElementExpr : public Expr {
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == ExtVectorElementExprClass;
   }
+};
 
-  // Iterators
-  child_range children() { return child_range(&Base, &Base+1); }
-  const_child_range children() const {
-    return const_child_range(&Base, &Base + 1);
+class MatrixElementExpr : public ElementAccessExprBase<MatrixElementExpr> {
+public:
+  MatrixElementExpr(QualType Ty, ExprValueKind VK, Expr *Base,
+                    IdentifierInfo &Accessor, SourceLocation Loc)
+      : ElementAccessExprBase(
+            MatrixElementExprClass, Ty, VK, Base, Accessor, Loc,
+            OK_Ordinary /*TODO: Should we add a new OK_MatrixComponent?*/) {}
+
+  explicit MatrixElementExpr(EmptyShell Empty)
+      : ElementAccessExprBase(MatrixElementExprClass, Empty) {}
+
+  unsigned getNumElements() const;
+  bool containsDuplicateElements() const;
+  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == MatrixElementExprClass;
   }
 };
 
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..0cc10403657ca 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2940,6 +2940,7 @@ DEF_TRAVERSE_STMT(UserDefinedLiteral, {})
 DEF_TRAVERSE_STMT(DesignatedInitExpr, {})
 DEF_TRAVERSE_STMT(DesignatedInitUpdateExpr, {})
 DEF_TRAVERSE_STMT(ExtVectorElementExpr, {})
+DEF_TRAVERSE_STMT(MatrixElementExpr, {})
 DEF_TRAVERSE_STMT(GNUNullExpr, {})
 DEF_TRAVERSE_STMT(ImplicitValueInitExpr, {})
 DEF_TRAVERSE_STMT(NoInitExpr, {})
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 88ecd526e3d7e..ab828be124b0b 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -286,6 +286,7 @@ class TextNodeDumper
   void VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *Node);
   void VisitMemberExpr(const MemberExpr *Node);
   void VisitExtVectorElementExpr(const ExtVectorElementExpr *Node);
+  void VisitMatrixElementExpr(const MatrixElementExpr *Node);
   void VisitBinaryOperator(const BinaryOperator *Node);
   void VisitCompoundAssignOperator(const CompoundAssignOperator *Node);
   void VisitAddrLabelExpr(const AddrLabelExpr *Node);
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fec278c21a89e..bba27597be6bd 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -9336,6 +9336,8 @@ def err_typecheck_lvalue_casts_not_supported : Error<
 
 def err_typecheck_duplicate_vector_components_not_mlvalue : Error<
   "vector is not assignable (contains duplicate components)">;
+def err_typecheck_duplicate_matrix_components_not_mlvalue : Error<
+  "matrix is not assignable (contains duplicate components)">;
 def err_block_decl_ref_not_modifiable_lvalue : Error<
   "variable is not assignable (missing __block type specifier)">;
 def err_lambda_decl_ref_not_modifiable_lvalue : Error<
@@ -13061,6 +13063,7 @@ def err_builtin_matrix_stride_too_small: Error<
   "stride must be greater or equal to the number of rows">;
 def err_builtin_matrix_invalid_dimension: Error<
   "%0 dimension is outside the allowed range [1, %1]">;
+def err_builtin_matrix_invalid_member: Error<"invalid matrix member '%0' expected %1">;
 
 def warn_mismatched_import : Warning<
   "import %select{module|name}0 (%1) does not match the import %select{module|name}0 (%2) of the "
@@ -13318,6 +13321,15 @@ def err_hlsl_builtin_scalar_vector_mismatch
           "%select{all|second and third}0 arguments to %1 must be of scalar or "
           "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
+def err_hlsl_matrix_element_not_in_bounds : Error<
+  "matrix %select{row|column}0 element accessor is out of bounds of %select{zero|one}1 based indexing">;
+
+def err_hlsl_matrix_index_out_of_bounds : Error<
+  "matrix %select{row|column}0 index %1 is out of bounds of %select{rows|columns}0 size %2">;
+
+def err_hlsl_matrix_swizzle_invalid_length : Error<
+  "matrix swizzle length must be between 1 and 4 but is %0">;
+
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<VectorConversion>;
 
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..950ebba66b720 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -91,6 +91,7 @@ def CStyleCastExpr : StmtNode<ExplicitCastExpr>;
 def OMPArrayShapingExpr : StmtNode<Expr>;
 def CompoundLiteralExpr : StmtNode<Expr>;
 def ExtVectorElementExpr : StmtNode<Expr>;
+def MatrixElementExpr : StmtNode<Expr>;
 def InitListExpr : StmtNode<Expr>;
 def DesignatedInitExpr : StmtNode<Expr>;
 def DesignatedInitUpdateExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index a2faa91d1e54d..953b3529c40b6 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -214,6 +214,10 @@ class SemaHLSL : public SemaBase {
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
   bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
+  QualType CheckMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK,
+                                SourceLocation OpLoc,
+                                const IdentifierInfo *CompName,
+                                SourceLocation CompLoc);
 
 private:
   // HLSL resource type attributes need to be processed all at once.
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index d7d429eacd67a..2fc71b7f916ac 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1693,6 +1693,9 @@ enum StmtCode {
   /// An ExtVectorElementExpr record.
   EXPR_EXT_VECTOR_ELEMENT,
 
+  /// A MatrixElementExpr record.
+  EXPR_MATRIX_ELEMENT,
+
   /// An InitListExpr record.
   EXPR_INIT_LIST,
 
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..610aa16e9ae13 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -252,6 +252,10 @@ ExprDependence clang::computeDependence(ExtVectorElementExpr *E) {
   return E->getBase()->getDependence();
 }
 
+ExprDependence clang::computeDependence(MatrixElementExpr *E) {
+  return E->getBase()->getDependence();
+}
+
 ExprDependence clang::computeDependence(BlockExpr *E,
                                         bool ContainsUnexpandedParameterPack) {
   auto D = toExprDependenceForImpliedType(E->getType()->getDependence());
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..2fd071fa3fb45 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -25,6 +25,7 @@
 #include "clang/AST/IgnoreExpr.h"
 #include "clang/AST/Mangle.h"
 #include "clang/AST/RecordLayout.h"
+#include "clang/AST/TypeBase.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/SourceManager.h"
@@ -3798,6 +3799,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
   case BinaryConditionalOperatorClass:
   case CompoundLiteralExprClass:
   case ExtVectorElementExprClass:
+  case MatrixElementExprClass:
   case DesignatedInitExprClass:
   case DesignatedInitUpdateExprClass:
   case ArrayInitLoopExprClass:
@@ -4418,7 +4420,14 @@ unsigned ExtVectorElementExpr::getNumElements() const {
   return 1;
 }
 
-/// containsDuplicateElements - Return true if any element access is repeated.
+unsigned MatrixElementExpr::getNumElements() const {
+  if (const ConstantMatrixType *MT = getType()->getAs<ConstantMatrixType>())
+    return MT->getNumElementsFlattened();
+  return 1;
+}
+
+/// containsDuplicateElements - Return true if any Vector element access is
+/// repeated.
 bool ExtVectorElementExpr::containsDuplicateElements() const {
   // FIXME: Refactor this code to an accessor on the AST node which returns the
   // "type" of component access, and share with code below and in Sema.
@@ -4439,6 +4448,68 @@ bool ExtVectorElementExpr::containsDuplicateElements() const {
   return false;
 }
 
+/// containsDuplicateElements - Return true if any Matrix element access is
+/// repeated.
+bool MatrixElementExpr::containsDuplicateElements() const {
+  StringRef Comp = Accessor->getName();
+  assert(!Comp.empty() && Comp[0] == '_' && "invalid matrix accessor");
+
+  // Get the matrix type so we know bounds.
+  const ConstantMatrixType *MT =
+      getBase()->getType()->getAs<ConstantMatrixType>();
+  assert(MT && "MatrixElementExpr base must be a matrix type");
+
+  unsigned Rows = MT->getNumRows();
+  unsigned Cols = MT->getNumColumns();
+  unsigned Max = Rows * Cols;
+
+  // Zero-indexed: _mRC  (4 chars per component)
+  // One-indexed: _RC    (3 chars per component)
+  bool IsZeroIndexed = false;
+  unsigned ChunkLen = 0;
+
+  if (Comp.size() >= 2 && Comp[0] == '_' && Comp[1] == 'm') {
+    IsZeroIndexed = true;
+    ChunkLen = 4;
+  } else {
+    IsZeroIndexed = false;
+    ChunkLen = 3;
+  }
+
+  assert(ChunkLen && "unrecognized matrix swizzle format");
+  assert(Comp.size() % ChunkLen == 0 &&
+         "matrix swizzle accessor has invalid length");
+
+  // Track visited elements using real matrix size.
+  SmallVector<bool, 16> Seen(Max, false);
+
+  for (unsigned I = 0, e = Comp.size(); I < e; I += ChunkLen) {
+    unsigned Row = 0, Col = 0;
+
+    if (IsZeroIndexed) {
+      // Pattern: _mRC
+      assert(Comp[I] == '_' && Comp[I + 1] == 'm');
+      Row = Comp[I + 2] - '0'; // 0..(Rows-1)
+      Col = Comp[I + 3] - '0';
+    } else {
+      // Pattern: _RC
+      assert(Comp[I] == '_');
+      Row = (Comp[I + 1] - '1'); // 1..Rows (ie same as 0..Rows-1)
+      Col = (Comp[I + 2] - '1');
+    }
+
+    // Bounds check (Sema should enforce correctness, but we assert anyway)
+    assert(Row < Rows && Col < Cols && "matrix swizzle index out of bounds");
+
+    unsigned Index = Row * Cols + Col;
+    if (Seen[Index])
+      return true;
+
+    Seen[Index] = true;
+  }
+  return false;
+}
+
 /// getEncodedElementAccess - We encode the fields as a llvm ConstantArray.
 void ExtVectorElementExpr::getEncodedElementAccess(
     SmallVectorImpl<uint32_t> &Elts) const {
@@ -4472,6 +4543,59 @@ void ExtVectorElementExpr::getEncodedElementAccess(
   }
 }
 
+void MatrixElementExpr::getEncodedElementAccess(
+    SmallVectorImpl<uint32_t> &Elts) const {
+  StringRef Comp = Accessor->getName();
+  assert(!Comp.empty() && Comp[0] == '_' && "invalid matrix accessor");
+
+  const ConstantMatrixType *MT =
+      getBase()->getType()->getAs<ConstantMatrixType>();
+  assert(MT && "MatrixElementExpr base must be a matrix type");
+
+  unsigned Rows = MT->getNumRows();
+  unsigned Cols = MT->getNumColumns();
+
+  // Zero-indexed: _mRC (4 chars per component: '_', 'm', row, col)
+  // One-indexed:  _RC  (3 chars per component: '_', row, col)
+  bool IsZeroIndexed = false;
+  unsigned ChunkLen = 0;
+
+  if (Comp.size() >= 2 && Comp[0] == '_' && Comp[1] == 'm') {
+    IsZeroIndexed = true;
+    ChunkLen = 4;
+  } else {
+    IsZeroIndexed = false;
+    ChunkLen = 3;
+  }
+
+  assert(ChunkLen != 0 && "unrecognized matrix swizzle format");
+  assert(Comp.size() % ChunkLen == 0 &&
+         "matrix swizzle accessor has invalid length");
+
+  for (unsigned i = 0, e = Comp.size(); i < e; i += ChunkLen) {
+    unsigned Row = 0, Col = 0;
+
+    if (IsZeroIndexed) {
+      // Pattern: _mRC
+      assert(Comp[i] == '_' && Comp[i + 1] == 'm' &&
+             "invalid zero-indexed matrix swizzle component");
+      Row = static_cast<unsigned>(Comp[i + 2] - '0'); // 0..Rows-1
+      Col = static_cast<unsigned>(Comp[i + 3] - '0'); // 0..Cols-1
+    } else {
+      // Pattern: _RC
+      assert(Comp[i] == '_' && "invalid one-indexed matrix swizzle component");
+      Row = static_cast<unsigned>(Comp[i + 1] - '1'); // 1..Rows -> 0..Rows-1
+      Col = static_cast<unsigned>(Comp[i + 2] - '1'); // 1..Cols -> 0..Cols-1
+    }
+
+    // Sema should have validated these, but assert here for sanity.
+    assert(Row < Rows && Col < Cols && "matrix swizzle index out of range");
+
+    unsigned Index = Row * Cols + Col;
+    Elts.push_back(Index);
+  }
+}
+
 ShuffleVectorExpr::ShuffleVectorExpr(const ASTContext &C, ArrayRef<Expr *> args,
                                      QualType Type, SourceLocation BLoc,
                                      SourceLocation RP)
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..7bb5a2202c3bf 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -63,6 +63,7 @@ Cl Expr::ClassifyImpl(ASTContext &Ctx, SourceLocation *Loc) const {
   case Cl::CL_Void:
   case Cl::CL_AddressableVoid:
   case Cl::CL_DuplicateVectorComponents:
+  case Cl::CL_DuplicateMatrixComponents:
   case Cl::CL_MemberFunction:
   case Cl::CL_SubObjCPropertySetting:
   case Cl::CL_ClassTempor...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 11, 2025

@llvm/pr-subscribers-clangir

Author: Farzon Lotfi (farzonl)

Changes

fixes #159438

This patch adds MatrixElementExpr, a new AST node for HLSL matrix element and swizzle access (e.g. M._m00, M._11_22_33).

It introduces a shared ElementAccessExprBase used by both matrix and vector swizzle expressions, updates Sema to parse and validate zero-based and one-based accessors, detects duplicates for l-value checks, and emits improved diagnostics. CodeGen is updated to lower scalar and multi-element accesses consistently, and full AST serialization, dumping, and tooling support is included. This implementation reflects the updated RFC for HLSL matrix accessor semantics.


Patch is 165.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/171225.diff

43 Files Affected:

  • (modified) clang/include/clang/AST/ComputeDependence.h (+2)
  • (modified) clang/include/clang/AST/Expr.h (+68-39)
  • (modified) clang/include/clang/AST/RecursiveASTVisitor.h (+1)
  • (modified) clang/include/clang/AST/TextNodeDumper.h (+1)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+12)
  • (modified) clang/include/clang/Basic/StmtNodes.td (+1)
  • (modified) clang/include/clang/Sema/SemaHLSL.h (+4)
  • (modified) clang/include/clang/Serialization/ASTBitCodes.h (+3)
  • (modified) clang/lib/AST/ComputeDependence.cpp (+4)
  • (modified) clang/lib/AST/Expr.cpp (+125-1)
  • (modified) clang/lib/AST/ExprClassification.cpp (+15)
  • (modified) clang/lib/AST/ExprConstant.cpp (+1)
  • (modified) clang/lib/AST/ItaniumMangle.cpp (+1)
  • (modified) clang/lib/AST/StmtPrinter.cpp (+6)
  • (modified) clang/lib/AST/StmtProfile.cpp (+5)
  • (modified) clang/lib/AST/TextNodeDumper.cpp (+4)
  • (modified) clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp (+2)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+47)
  • (modified) clang/lib/CodeGen/CGExprScalar.cpp (+1)
  • (modified) clang/lib/CodeGen/CodeGenFunction.h (+1)
  • (modified) clang/lib/Sema/SemaExceptionSpec.cpp (+1)
  • (modified) clang/lib/Sema/SemaExpr.cpp (+3)
  • (modified) clang/lib/Sema/SemaExprMember.cpp (+17)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+199)
  • (modified) clang/lib/Sema/TreeTransform.h (+25-6)
  • (modified) clang/lib/Serialization/ASTReaderStmt.cpp (+11)
  • (modified) clang/lib/Serialization/ASTWriterStmt.cpp (+8)
  • (modified) clang/lib/StaticAnalyzer/Core/ExprEngine.cpp (+1)
  • (added) clang/test/AST/HLSL/matrix-member-access-scalar.hlsl (+38)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle-ast-dump-json.hlsl (+25)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle-ast-print.hlsl (+21)
  • (added) clang/test/AST/HLSL/matrix-member-access-swizzle.hlsl (+49)
  • (added) clang/test/AST/HLSL/pch_with_matrix_element_accessor.hlsl (+26)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-accessor-scalar-load.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-accessor-scalar-store.hlsl (+345)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-swizzle-load.hlsl (+108)
  • (added) clang/test/CodeGenHLSL/matrix-member-one-based-swizzle-store.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-accessor-scalar-load.hlsl (+230)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-accessor-scalar-store.hlsl (+345)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-swizzle-load.hlsl (+108)
  • (added) clang/test/CodeGenHLSL/matrix-member-zero-based-swizzle-store.hlsl (+230)
  • (added) clang/test/SemaHLSL/matrix-member-access-errors.hlsl (+28)
  • (modified) clang/tools/libclang/CXCursor.cpp (+1)
diff --git a/clang/include/clang/AST/ComputeDependence.h b/clang/include/clang/AST/ComputeDependence.h
index c298f2620f211..16fdbfcac0864 100644
--- a/clang/include/clang/AST/ComputeDependence.h
+++ b/clang/include/clang/AST/ComputeDependence.h
@@ -44,6 +44,7 @@ class ArrayInitLoopExpr;
 class ImplicitValueInitExpr;
 class InitListExpr;
 class ExtVectorElementExpr;
+class MatrixElementExpr;
 class BlockExpr;
 class AsTypeExpr;
 class DeclRefExpr;
@@ -133,6 +134,7 @@ ExprDependence computeDependence(ArrayInitLoopExpr *E);
 ExprDependence computeDependence(ImplicitValueInitExpr *E);
 ExprDependence computeDependence(InitListExpr *E);
 ExprDependence computeDependence(ExtVectorElementExpr *E);
+ExprDependence computeDependence(MatrixElementExpr *E);
 ExprDependence computeDependence(BlockExpr *E,
                                  bool ContainsUnexpandedParameterPack);
 ExprDependence computeDependence(AsTypeExpr *E);
diff --git a/clang/include/clang/AST/Expr.h b/clang/include/clang/AST/Expr.h
index 573cc72db35c6..f0d9004a70cb7 100644
--- a/clang/include/clang/AST/Expr.h
+++ b/clang/include/clang/AST/Expr.h
@@ -291,6 +291,7 @@ class Expr : public ValueStmt {
     LV_NotObjectType,
     LV_IncompleteVoidType,
     LV_DuplicateVectorComponents,
+    LV_DuplicateMatrixComponents,
     LV_InvalidExpression,
     LV_InvalidMessageExpression,
     LV_MemberFunction,
@@ -306,8 +307,9 @@ class Expr : public ValueStmt {
     MLV_NotObjectType,
     MLV_IncompleteVoidType,
     MLV_DuplicateVectorComponents,
+    MLV_DuplicateMatrixComponents,
     MLV_InvalidExpression,
-    MLV_LValueCast,           // Specialized form of MLV_InvalidExpression.
+    MLV_LValueCast, // Specialized form of MLV_InvalidExpression.
     MLV_IncompleteType,
     MLV_ConstQualified,
     MLV_ConstQualifiedField,
@@ -340,16 +342,17 @@ class Expr : public ValueStmt {
     enum Kinds {
       CL_LValue,
       CL_XValue,
-      CL_Function, // Functions cannot be lvalues in C.
-      CL_Void, // Void cannot be an lvalue in C.
+      CL_Function,        // Functions cannot be lvalues in C.
+      CL_Void,            // Void cannot be an lvalue in C.
       CL_AddressableVoid, // Void expression whose address can be taken in C.
       CL_DuplicateVectorComponents, // A vector shuffle with dupes.
+      CL_DuplicateMatrixComponents, // A matrix shuffle with dupes.
       CL_MemberFunction, // An expression referring to a member function
       CL_SubObjCPropertySetting,
-      CL_ClassTemporary, // A temporary of class type, or subobject thereof.
-      CL_ArrayTemporary, // A temporary of array type.
+      CL_ClassTemporary,    // A temporary of class type, or subobject thereof.
+      CL_ArrayTemporary,    // A temporary of array type.
       CL_ObjCMessageRValue, // ObjC message is an rvalue
-      CL_PRValue // A prvalue for any other reason, of any other type
+      CL_PRValue            // A prvalue for any other reason, of any other type
     };
     /// The results of modification testing.
     enum ModifiableType {
@@ -6488,30 +6491,24 @@ class GenericSelectionExpr final
 // Clang Extensions
 //===----------------------------------------------------------------------===//
 
-/// ExtVectorElementExpr - This represents access to specific elements of a
-/// vector, and may occur on the left hand side or right hand side.  For example
-/// the following is legal:  "V.xy = V.zw" if V is a 4 element extended vector.
-///
-/// Note that the base may have either vector or pointer to vector type, just
-/// like a struct field reference.
-///
-class ExtVectorElementExpr : public Expr {
+template <class Derived> class ElementAccessExprBase : public Expr {
+protected:
   Stmt *Base;
   IdentifierInfo *Accessor;
   SourceLocation AccessorLoc;
-public:
-  ExtVectorElementExpr(QualType ty, ExprValueKind VK, Expr *base,
-                       IdentifierInfo &accessor, SourceLocation loc)
-      : Expr(ExtVectorElementExprClass, ty, VK,
-             (VK == VK_PRValue ? OK_Ordinary : OK_VectorComponent)),
-        Base(base), Accessor(&accessor), AccessorLoc(loc) {
-    setDependence(computeDependence(this));
+
+  ElementAccessExprBase(StmtClass SC, QualType Ty, ExprValueKind VK, Expr *Base,
+                        IdentifierInfo &Accessor, SourceLocation Loc,
+                        ExprObjectKind OK)
+      : Expr(SC, Ty, VK, OK), Base(Base), Accessor(&Accessor),
+        AccessorLoc(Loc) {
+    setDependence(computeDependence(static_cast<Derived *>(this)));
   }
 
-  /// Build an empty vector element expression.
-  explicit ExtVectorElementExpr(EmptyShell Empty)
-    : Expr(ExtVectorElementExprClass, Empty) { }
+  explicit ElementAccessExprBase(StmtClass SC, EmptyShell Empty)
+      : Expr(SC, Empty) {}
 
+public:
   const Expr *getBase() const { return cast<Expr>(Base); }
   Expr *getBase() { return cast<Expr>(Base); }
   void setBase(Expr *E) { Base = E; }
@@ -6522,22 +6519,40 @@ class ExtVectorElementExpr : public Expr {
   SourceLocation getAccessorLoc() const { return AccessorLoc; }
   void setAccessorLoc(SourceLocation L) { AccessorLoc = L; }
 
-  /// getNumElements - Get the number of components being selected.
-  unsigned getNumElements() const;
-
-  /// containsDuplicateElements - Return true if any element access is
-  /// repeated.
-  bool containsDuplicateElements() const;
-
-  /// getEncodedElementAccess - Encode the elements accessed into an llvm
-  /// aggregate Constant of ConstantInt(s).
-  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
-
   SourceLocation getBeginLoc() const LLVM_READONLY {
     return getBase()->getBeginLoc();
   }
   SourceLocation getEndLoc() const LLVM_READONLY { return AccessorLoc; }
 
+  child_range children() { return child_range(&Base, &Base + 1); }
+  const_child_range children() const {
+    return const_child_range(&Base, &Base + 1);
+  }
+};
+
+/// ExtVectorElementExpr - This represents access to specific elements of a
+/// vector, and may occur on the left hand side or right hand side.  For example
+/// the following is legal:  "V.xy = V.zw" if V is a 4 element extended vector.
+///
+/// Note that the base may have either vector or pointer to vector type, just
+/// like a struct field reference.
+///
+class ExtVectorElementExpr
+    : public ElementAccessExprBase<ExtVectorElementExpr> {
+public:
+  ExtVectorElementExpr(QualType Ty, ExprValueKind VK, Expr *Base,
+                       IdentifierInfo &Accessor, SourceLocation Loc)
+      : ElementAccessExprBase(
+            ExtVectorElementExprClass, Ty, VK, Base, Accessor, Loc,
+            (VK == VK_PRValue ? OK_Ordinary : OK_VectorComponent)) {}
+
+  explicit ExtVectorElementExpr(EmptyShell Empty)
+      : ElementAccessExprBase(ExtVectorElementExprClass, Empty) {}
+
+  unsigned getNumElements() const;
+  bool containsDuplicateElements() const;
+  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
+
   /// isArrow - Return true if the base expression is a pointer to vector,
   /// return false if the base expression is a vector.
   bool isArrow() const;
@@ -6545,11 +6560,25 @@ class ExtVectorElementExpr : public Expr {
   static bool classof(const Stmt *T) {
     return T->getStmtClass() == ExtVectorElementExprClass;
   }
+};
 
-  // Iterators
-  child_range children() { return child_range(&Base, &Base+1); }
-  const_child_range children() const {
-    return const_child_range(&Base, &Base + 1);
+class MatrixElementExpr : public ElementAccessExprBase<MatrixElementExpr> {
+public:
+  MatrixElementExpr(QualType Ty, ExprValueKind VK, Expr *Base,
+                    IdentifierInfo &Accessor, SourceLocation Loc)
+      : ElementAccessExprBase(
+            MatrixElementExprClass, Ty, VK, Base, Accessor, Loc,
+            OK_Ordinary /*TODO: Should we add a new OK_MatrixComponent?*/) {}
+
+  explicit MatrixElementExpr(EmptyShell Empty)
+      : ElementAccessExprBase(MatrixElementExprClass, Empty) {}
+
+  unsigned getNumElements() const;
+  bool containsDuplicateElements() const;
+  void getEncodedElementAccess(SmallVectorImpl<uint32_t> &Elts) const;
+
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() == MatrixElementExprClass;
   }
 };
 
diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h
index 8f427427d71ed..0cc10403657ca 100644
--- a/clang/include/clang/AST/RecursiveASTVisitor.h
+++ b/clang/include/clang/AST/RecursiveASTVisitor.h
@@ -2940,6 +2940,7 @@ DEF_TRAVERSE_STMT(UserDefinedLiteral, {})
 DEF_TRAVERSE_STMT(DesignatedInitExpr, {})
 DEF_TRAVERSE_STMT(DesignatedInitUpdateExpr, {})
 DEF_TRAVERSE_STMT(ExtVectorElementExpr, {})
+DEF_TRAVERSE_STMT(MatrixElementExpr, {})
 DEF_TRAVERSE_STMT(GNUNullExpr, {})
 DEF_TRAVERSE_STMT(ImplicitValueInitExpr, {})
 DEF_TRAVERSE_STMT(NoInitExpr, {})
diff --git a/clang/include/clang/AST/TextNodeDumper.h b/clang/include/clang/AST/TextNodeDumper.h
index 88ecd526e3d7e..ab828be124b0b 100644
--- a/clang/include/clang/AST/TextNodeDumper.h
+++ b/clang/include/clang/AST/TextNodeDumper.h
@@ -286,6 +286,7 @@ class TextNodeDumper
   void VisitUnaryExprOrTypeTraitExpr(const UnaryExprOrTypeTraitExpr *Node);
   void VisitMemberExpr(const MemberExpr *Node);
   void VisitExtVectorElementExpr(const ExtVectorElementExpr *Node);
+  void VisitMatrixElementExpr(const MatrixElementExpr *Node);
   void VisitBinaryOperator(const BinaryOperator *Node);
   void VisitCompoundAssignOperator(const CompoundAssignOperator *Node);
   void VisitAddrLabelExpr(const AddrLabelExpr *Node);
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index fec278c21a89e..bba27597be6bd 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -9336,6 +9336,8 @@ def err_typecheck_lvalue_casts_not_supported : Error<
 
 def err_typecheck_duplicate_vector_components_not_mlvalue : Error<
   "vector is not assignable (contains duplicate components)">;
+def err_typecheck_duplicate_matrix_components_not_mlvalue : Error<
+  "matrix is not assignable (contains duplicate components)">;
 def err_block_decl_ref_not_modifiable_lvalue : Error<
   "variable is not assignable (missing __block type specifier)">;
 def err_lambda_decl_ref_not_modifiable_lvalue : Error<
@@ -13061,6 +13063,7 @@ def err_builtin_matrix_stride_too_small: Error<
   "stride must be greater or equal to the number of rows">;
 def err_builtin_matrix_invalid_dimension: Error<
   "%0 dimension is outside the allowed range [1, %1]">;
+def err_builtin_matrix_invalid_member: Error<"invalid matrix member '%0' expected %1">;
 
 def warn_mismatched_import : Warning<
   "import %select{module|name}0 (%1) does not match the import %select{module|name}0 (%2) of the "
@@ -13318,6 +13321,15 @@ def err_hlsl_builtin_scalar_vector_mismatch
           "%select{all|second and third}0 arguments to %1 must be of scalar or "
           "vector type with matching scalar element type%diff{: $ vs $|}2,3">;
 
+def err_hlsl_matrix_element_not_in_bounds : Error<
+  "matrix %select{row|column}0 element accessor is out of bounds of %select{zero|one}1 based indexing">;
+
+def err_hlsl_matrix_index_out_of_bounds : Error<
+  "matrix %select{row|column}0 index %1 is out of bounds of %select{rows|columns}0 size %2">;
+
+def err_hlsl_matrix_swizzle_invalid_length : Error<
+  "matrix swizzle length must be between 1 and 4 but is %0">;
+
 def warn_hlsl_impcast_vector_truncation : Warning<
   "implicit conversion truncates vector: %0 to %1">, InGroup<VectorConversion>;
 
diff --git a/clang/include/clang/Basic/StmtNodes.td b/clang/include/clang/Basic/StmtNodes.td
index bf3686bb372d5..950ebba66b720 100644
--- a/clang/include/clang/Basic/StmtNodes.td
+++ b/clang/include/clang/Basic/StmtNodes.td
@@ -91,6 +91,7 @@ def CStyleCastExpr : StmtNode<ExplicitCastExpr>;
 def OMPArrayShapingExpr : StmtNode<Expr>;
 def CompoundLiteralExpr : StmtNode<Expr>;
 def ExtVectorElementExpr : StmtNode<Expr>;
+def MatrixElementExpr : StmtNode<Expr>;
 def InitListExpr : StmtNode<Expr>;
 def DesignatedInitExpr : StmtNode<Expr>;
 def DesignatedInitUpdateExpr : StmtNode<Expr>;
diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h
index a2faa91d1e54d..953b3529c40b6 100644
--- a/clang/include/clang/Sema/SemaHLSL.h
+++ b/clang/include/clang/Sema/SemaHLSL.h
@@ -214,6 +214,10 @@ class SemaHLSL : public SemaBase {
   bool transformInitList(const InitializedEntity &Entity, InitListExpr *Init);
   bool handleInitialization(VarDecl *VDecl, Expr *&Init);
   void deduceAddressSpace(VarDecl *Decl);
+  QualType CheckMatrixComponent(Sema &S, QualType baseType, ExprValueKind &VK,
+                                SourceLocation OpLoc,
+                                const IdentifierInfo *CompName,
+                                SourceLocation CompLoc);
 
 private:
   // HLSL resource type attributes need to be processed all at once.
diff --git a/clang/include/clang/Serialization/ASTBitCodes.h b/clang/include/clang/Serialization/ASTBitCodes.h
index d7d429eacd67a..2fc71b7f916ac 100644
--- a/clang/include/clang/Serialization/ASTBitCodes.h
+++ b/clang/include/clang/Serialization/ASTBitCodes.h
@@ -1693,6 +1693,9 @@ enum StmtCode {
   /// An ExtVectorElementExpr record.
   EXPR_EXT_VECTOR_ELEMENT,
 
+  /// A MatrixElementExpr record.
+  EXPR_MATRIX_ELEMENT,
+
   /// An InitListExpr record.
   EXPR_INIT_LIST,
 
diff --git a/clang/lib/AST/ComputeDependence.cpp b/clang/lib/AST/ComputeDependence.cpp
index 638080ea781a9..610aa16e9ae13 100644
--- a/clang/lib/AST/ComputeDependence.cpp
+++ b/clang/lib/AST/ComputeDependence.cpp
@@ -252,6 +252,10 @@ ExprDependence clang::computeDependence(ExtVectorElementExpr *E) {
   return E->getBase()->getDependence();
 }
 
+ExprDependence clang::computeDependence(MatrixElementExpr *E) {
+  return E->getBase()->getDependence();
+}
+
 ExprDependence clang::computeDependence(BlockExpr *E,
                                         bool ContainsUnexpandedParameterPack) {
   auto D = toExprDependenceForImpliedType(E->getType()->getDependence());
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index ca7f3e16a9276..2fd071fa3fb45 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -25,6 +25,7 @@
 #include "clang/AST/IgnoreExpr.h"
 #include "clang/AST/Mangle.h"
 #include "clang/AST/RecordLayout.h"
+#include "clang/AST/TypeBase.h"
 #include "clang/Basic/Builtins.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/SourceManager.h"
@@ -3798,6 +3799,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
   case BinaryConditionalOperatorClass:
   case CompoundLiteralExprClass:
   case ExtVectorElementExprClass:
+  case MatrixElementExprClass:
   case DesignatedInitExprClass:
   case DesignatedInitUpdateExprClass:
   case ArrayInitLoopExprClass:
@@ -4418,7 +4420,14 @@ unsigned ExtVectorElementExpr::getNumElements() const {
   return 1;
 }
 
-/// containsDuplicateElements - Return true if any element access is repeated.
+unsigned MatrixElementExpr::getNumElements() const {
+  if (const ConstantMatrixType *MT = getType()->getAs<ConstantMatrixType>())
+    return MT->getNumElementsFlattened();
+  return 1;
+}
+
+/// containsDuplicateElements - Return true if any Vector element access is
+/// repeated.
 bool ExtVectorElementExpr::containsDuplicateElements() const {
   // FIXME: Refactor this code to an accessor on the AST node which returns the
   // "type" of component access, and share with code below and in Sema.
@@ -4439,6 +4448,68 @@ bool ExtVectorElementExpr::containsDuplicateElements() const {
   return false;
 }
 
+/// containsDuplicateElements - Return true if any Matrix element access is
+/// repeated.
+bool MatrixElementExpr::containsDuplicateElements() const {
+  StringRef Comp = Accessor->getName();
+  assert(!Comp.empty() && Comp[0] == '_' && "invalid matrix accessor");
+
+  // Get the matrix type so we know bounds.
+  const ConstantMatrixType *MT =
+      getBase()->getType()->getAs<ConstantMatrixType>();
+  assert(MT && "MatrixElementExpr base must be a matrix type");
+
+  unsigned Rows = MT->getNumRows();
+  unsigned Cols = MT->getNumColumns();
+  unsigned Max = Rows * Cols;
+
+  // Zero-indexed: _mRC  (4 chars per component)
+  // One-indexed: _RC    (3 chars per component)
+  bool IsZeroIndexed = false;
+  unsigned ChunkLen = 0;
+
+  if (Comp.size() >= 2 && Comp[0] == '_' && Comp[1] == 'm') {
+    IsZeroIndexed = true;
+    ChunkLen = 4;
+  } else {
+    IsZeroIndexed = false;
+    ChunkLen = 3;
+  }
+
+  assert(ChunkLen && "unrecognized matrix swizzle format");
+  assert(Comp.size() % ChunkLen == 0 &&
+         "matrix swizzle accessor has invalid length");
+
+  // Track visited elements using real matrix size.
+  SmallVector<bool, 16> Seen(Max, false);
+
+  for (unsigned I = 0, e = Comp.size(); I < e; I += ChunkLen) {
+    unsigned Row = 0, Col = 0;
+
+    if (IsZeroIndexed) {
+      // Pattern: _mRC
+      assert(Comp[I] == '_' && Comp[I + 1] == 'm');
+      Row = Comp[I + 2] - '0'; // 0..(Rows-1)
+      Col = Comp[I + 3] - '0';
+    } else {
+      // Pattern: _RC
+      assert(Comp[I] == '_');
+      Row = (Comp[I + 1] - '1'); // 1..Rows (ie same as 0..Rows-1)
+      Col = (Comp[I + 2] - '1');
+    }
+
+    // Bounds check (Sema should enforce correctness, but we assert anyway)
+    assert(Row < Rows && Col < Cols && "matrix swizzle index out of bounds");
+
+    unsigned Index = Row * Cols + Col;
+    if (Seen[Index])
+      return true;
+
+    Seen[Index] = true;
+  }
+  return false;
+}
+
 /// getEncodedElementAccess - We encode the fields as a llvm ConstantArray.
 void ExtVectorElementExpr::getEncodedElementAccess(
     SmallVectorImpl<uint32_t> &Elts) const {
@@ -4472,6 +4543,59 @@ void ExtVectorElementExpr::getEncodedElementAccess(
   }
 }
 
+void MatrixElementExpr::getEncodedElementAccess(
+    SmallVectorImpl<uint32_t> &Elts) const {
+  StringRef Comp = Accessor->getName();
+  assert(!Comp.empty() && Comp[0] == '_' && "invalid matrix accessor");
+
+  const ConstantMatrixType *MT =
+      getBase()->getType()->getAs<ConstantMatrixType>();
+  assert(MT && "MatrixElementExpr base must be a matrix type");
+
+  unsigned Rows = MT->getNumRows();
+  unsigned Cols = MT->getNumColumns();
+
+  // Zero-indexed: _mRC (4 chars per component: '_', 'm', row, col)
+  // One-indexed:  _RC  (3 chars per component: '_', row, col)
+  bool IsZeroIndexed = false;
+  unsigned ChunkLen = 0;
+
+  if (Comp.size() >= 2 && Comp[0] == '_' && Comp[1] == 'm') {
+    IsZeroIndexed = true;
+    ChunkLen = 4;
+  } else {
+    IsZeroIndexed = false;
+    ChunkLen = 3;
+  }
+
+  assert(ChunkLen != 0 && "unrecognized matrix swizzle format");
+  assert(Comp.size() % ChunkLen == 0 &&
+         "matrix swizzle accessor has invalid length");
+
+  for (unsigned i = 0, e = Comp.size(); i < e; i += ChunkLen) {
+    unsigned Row = 0, Col = 0;
+
+    if (IsZeroIndexed) {
+      // Pattern: _mRC
+      assert(Comp[i] == '_' && Comp[i + 1] == 'm' &&
+             "invalid zero-indexed matrix swizzle component");
+      Row = static_cast<unsigned>(Comp[i + 2] - '0'); // 0..Rows-1
+      Col = static_cast<unsigned>(Comp[i + 3] - '0'); // 0..Cols-1
+    } else {
+      // Pattern: _RC
+      assert(Comp[i] == '_' && "invalid one-indexed matrix swizzle component");
+      Row = static_cast<unsigned>(Comp[i + 1] - '1'); // 1..Rows -> 0..Rows-1
+      Col = static_cast<unsigned>(Comp[i + 2] - '1'); // 1..Cols -> 0..Cols-1
+    }
+
+    // Sema should have validated these, but assert here for sanity.
+    assert(Row < Rows && Col < Cols && "matrix swizzle index out of range");
+
+    unsigned Index = Row * Cols + Col;
+    Elts.push_back(Index);
+  }
+}
+
 ShuffleVectorExpr::ShuffleVectorExpr(const ASTContext &C, ArrayRef<Expr *> args,
                                      QualType Type, SourceLocation BLoc,
                                      SourceLocation RP)
diff --git a/clang/lib/AST/ExprClassification.cpp b/clang/lib/AST/ExprClassification.cpp
index aeacd0dc765ef..7bb5a2202c3bf 100644
--- a/clang/lib/AST/ExprClassification.cpp
+++ b/clang/lib/AST/ExprClassification.cpp
@@ -63,6 +63,7 @@ Cl Expr::ClassifyImpl(ASTContext &Ctx, SourceLocation *Loc) const {
   case Cl::CL_Void:
   case Cl::CL_AddressableVoid:
   case Cl::CL_DuplicateVectorComponents:
+  case Cl::CL_DuplicateMatrixComponents:
   case Cl::CL_MemberFunction:
   case Cl::CL_SubObjCPropertySetting:
   case Cl::CL_ClassTempor...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

clang:as-a-library libclang and C++ API clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:modules C++20 modules and Clang Header Modules clang:static analyzer clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[HLSL] Add _m and _<numeric> based (swizzle) accessors to hlsl::matrix.

2 participants