Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion clang/docs/MatrixTypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ Otherwise, the result is a glvalue with type ``cv T`` and with the same value
category as ``E1`` which refers to the element at the given row and column in
the matrix.

Programs containing a single subscript expression into a matrix are ill-formed.
A single subscript expression into a matrix is legal in HLSL and denotes a
vector for the specified row lane, but is ill-formed in C and C++.

**Note**: We considered providing an expression of the form
``postfix-expression [expression]`` to access columns of a matrix. We think
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/AST/ComputeDependence.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ParenExpr;
class UnaryOperator;
class UnaryExprOrTypeTraitExpr;
class ArraySubscriptExpr;
class MatrixSingleSubscriptExpr;
class MatrixSubscriptExpr;
class CompoundLiteralExpr;
class ImplicitCastExpr;
Expand Down Expand Up @@ -117,6 +118,7 @@ ExprDependence computeDependence(ParenExpr *E);
ExprDependence computeDependence(UnaryOperator *E, const ASTContext &Ctx);
ExprDependence computeDependence(UnaryExprOrTypeTraitExpr *E);
ExprDependence computeDependence(ArraySubscriptExpr *E);
ExprDependence computeDependence(MatrixSingleSubscriptExpr *E);
ExprDependence computeDependence(MatrixSubscriptExpr *E);
ExprDependence computeDependence(CompoundLiteralExpr *E);
ExprDependence computeDependence(ImplicitCastExpr *E);
Expand Down
66 changes: 66 additions & 0 deletions clang/include/clang/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2790,6 +2790,72 @@ class ArraySubscriptExpr : public Expr {
}
};

/// MatrixSingleSubscriptExpr - Matrix single subscript expression for the
/// MatrixType extension when you want to get\set a vector from a Matrix.
class MatrixSingleSubscriptExpr : public Expr {
enum { BASE, ROW_IDX, END_EXPR };
Stmt *SubExprs[END_EXPR];

public:
/// matrix[row]
///
/// \param Base The matrix expression.
/// \param RowIdx The row index expression.
/// \param T The type of the row (usually a vector type).
/// \param RBracketLoc Location of the closing ']'.
MatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx, QualType T,
SourceLocation RBracketLoc)
: Expr(MatrixSingleSubscriptExprClass, T,
Base->getValueKind(), // lvalue/rvalue follows the matrix base
OK_MatrixComponent) {
SubExprs[BASE] = Base;
SubExprs[ROW_IDX] = RowIdx;
ArrayOrMatrixSubscriptExprBits.RBracketLoc = RBracketLoc;
setDependence(computeDependence(this));
}

/// Create an empty matrix single-subscript expression.
explicit MatrixSingleSubscriptExpr(EmptyShell Shell)
: Expr(MatrixSingleSubscriptExprClass, Shell) {}

Expr *getBase() { return cast<Expr>(SubExprs[BASE]); }
const Expr *getBase() const { return cast<Expr>(SubExprs[BASE]); }
void setBase(Expr *E) { SubExprs[BASE] = E; }

Expr *getRowIdx() { return cast<Expr>(SubExprs[ROW_IDX]); }
const Expr *getRowIdx() const { return cast<Expr>(SubExprs[ROW_IDX]); }
void setRowIdx(Expr *E) { SubExprs[ROW_IDX] = E; }

SourceLocation getBeginLoc() const LLVM_READONLY {
return getBase()->getBeginLoc();
}

SourceLocation getEndLoc() const { return getRBracketLoc(); }

SourceLocation getExprLoc() const LLVM_READONLY {
return getBase()->getExprLoc();
}

SourceLocation getRBracketLoc() const {
return ArrayOrMatrixSubscriptExprBits.RBracketLoc;
}
void setRBracketLoc(SourceLocation L) {
ArrayOrMatrixSubscriptExprBits.RBracketLoc = L;
}

static bool classof(const Stmt *T) {
return T->getStmtClass() == MatrixSingleSubscriptExprClass;
}

// Iterators
child_range children() {
return child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
}
const_child_range children() const {
return const_child_range(&SubExprs[0], &SubExprs[0] + END_EXPR);
}
};

/// MatrixSubscriptExpr - Matrix subscript expression for the MatrixType
/// extension.
/// MatrixSubscriptExpr can be either incomplete (only Base and RowIdx are set
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,7 @@ DEF_TRAVERSE_STMT(CXXMemberCallExpr, {})
// over the children.
DEF_TRAVERSE_STMT(AddrLabelExpr, {})
DEF_TRAVERSE_STMT(ArraySubscriptExpr, {})
DEF_TRAVERSE_STMT(MatrixSingleSubscriptExpr, {})
DEF_TRAVERSE_STMT(MatrixSubscriptExpr, {})
DEF_TRAVERSE_STMT(ArraySectionExpr, {})
DEF_TRAVERSE_STMT(OMPArrayShapingExpr, {})
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ class alignas(void *) Stmt {
class ArrayOrMatrixSubscriptExprBitfields {
friend class ArraySubscriptExpr;
friend class MatrixSubscriptExpr;
friend class MatrixSingleSubscriptExpr;

LLVM_PREFERRED_TYPE(ExprBitfields)
unsigned : NumExprBits;
Expand Down
4 changes: 2 additions & 2 deletions clang/include/clang/Basic/Specifiers.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ namespace clang {
/// A bitfield object is a bitfield on a C or C++ record.
OK_BitField,

/// A vector component is an element or range of elements on a vector.
/// A vector component is an element or range of elements of a vector.
OK_VectorComponent,

/// An Objective-C property is a logical field of an Objective-C
Expand All @@ -165,7 +165,7 @@ namespace clang {
/// Objective-C method calls.
OK_ObjCSubscript,

/// A matrix component is a single element of a matrix.
/// A matrix component is a single element or range of elements of a matrix.
OK_MatrixComponent
};

Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/Basic/StmtNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def UnaryOperator : StmtNode<Expr>;
def OffsetOfExpr : StmtNode<Expr>;
def UnaryExprOrTypeTraitExpr : StmtNode<Expr>;
def ArraySubscriptExpr : StmtNode<Expr>;
def MatrixSingleSubscriptExpr : StmtNode<Expr>;
def MatrixSubscriptExpr : StmtNode<Expr>;
def ArraySectionExpr : StmtNode<Expr>;
def OMPIteratorExpr : StmtNode<Expr>;
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -7405,6 +7405,9 @@ class Sema final : public SemaBase {
ExprResult CreateBuiltinArraySubscriptExpr(Expr *Base, SourceLocation LLoc,
Expr *Idx, SourceLocation RLoc);

ExprResult CreateBuiltinMatrixSingleSubscriptExpr(Expr *Base, Expr *RowIdx,
SourceLocation RBLoc);

ExprResult CreateBuiltinMatrixSubscriptExpr(Expr *Base, Expr *RowIdx,
Expr *ColumnIdx,
SourceLocation RBLoc);
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/ComputeDependence.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ ExprDependence clang::computeDependence(ArraySubscriptExpr *E) {
return E->getLHS()->getDependence() | E->getRHS()->getDependence();
}

ExprDependence clang::computeDependence(MatrixSingleSubscriptExpr *E) {
return E->getBase()->getDependence() | E->getRowIdx()->getDependence();
}

ExprDependence clang::computeDependence(MatrixSubscriptExpr *E) {
return E->getBase()->getDependence() | E->getRowIdx()->getDependence() |
(E->getColumnIdx() ? E->getColumnIdx()->getDependence()
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3792,6 +3792,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,

case ParenExprClass:
case ArraySubscriptExprClass:
case MatrixSingleSubscriptExprClass:
case MatrixSubscriptExprClass:
case ArraySectionExprClass:
case OMPArrayShapingExprClass:
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/AST/ExprClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
}
return Cl::CL_LValue;

case Expr::MatrixSingleSubscriptExprClass:
return ClassifyInternal(Ctx, cast<MatrixSingleSubscriptExpr>(E)->getBase());

// Subscripting matrix types behaves like member accesses.
case Expr::MatrixSubscriptExprClass:
return ClassifyInternal(Ctx, cast<MatrixSubscriptExpr>(E)->getBase());
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20887,6 +20887,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
case Expr::ImaginaryLiteralClass:
case Expr::StringLiteralClass:
case Expr::ArraySubscriptExprClass:
case Expr::MatrixSingleSubscriptExprClass:
case Expr::MatrixSubscriptExprClass:
case Expr::ArraySectionExprClass:
case Expr::OMPArrayShapingExprClass:
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5485,6 +5485,15 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
break;
}

case Expr::MatrixSingleSubscriptExprClass: {
NotPrimaryExpr();
const MatrixSingleSubscriptExpr *ME = cast<MatrixSingleSubscriptExpr>(E);
Out << "ix";
mangleExpression(ME->getBase());
mangleExpression(ME->getRowIdx());
break;
}

case Expr::MatrixSubscriptExprClass: {
NotPrimaryExpr();
const MatrixSubscriptExpr *ME = cast<MatrixSubscriptExpr>(E);
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/AST/StmtPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1690,6 +1690,14 @@ void StmtPrinter::VisitArraySubscriptExpr(ArraySubscriptExpr *Node) {
OS << "]";
}

void StmtPrinter::VisitMatrixSingleSubscriptExpr(
MatrixSingleSubscriptExpr *Node) {
PrintExpr(Node->getBase());
OS << "[";
PrintExpr(Node->getRowIdx());
OS << "]";
}

void StmtPrinter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *Node) {
PrintExpr(Node->getBase());
OS << "[";
Expand Down
5 changes: 5 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1510,6 +1510,11 @@ void StmtProfiler::VisitArraySubscriptExpr(const ArraySubscriptExpr *S) {
VisitExpr(S);
}

void StmtProfiler::VisitMatrixSingleSubscriptExpr(
const MatrixSingleSubscriptExpr *S) {
VisitExpr(S);
}

void StmtProfiler::VisitMatrixSubscriptExpr(const MatrixSubscriptExpr *S) {
VisitExpr(S);
}
Expand Down
84 changes: 84 additions & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/LLVMContext.h"
Expand Down Expand Up @@ -1818,6 +1819,8 @@ LValue CodeGenFunction::EmitLValueHelper(const Expr *E,
return EmitUnaryOpLValue(cast<UnaryOperator>(E));
case Expr::ArraySubscriptExprClass:
return EmitArraySubscriptExpr(cast<ArraySubscriptExpr>(E));
case Expr::MatrixSingleSubscriptExprClass:
return EmitMatrixSingleSubscriptExpr(cast<MatrixSingleSubscriptExpr>(E));
case Expr::MatrixSubscriptExprClass:
return EmitMatrixSubscriptExpr(cast<MatrixSubscriptExpr>(E));
case Expr::ArraySectionExprClass:
Expand Down Expand Up @@ -2462,6 +2465,31 @@ RValue CodeGenFunction::EmitLoadOfLValue(LValue LV, SourceLocation Loc) {
Builder.CreateLoad(LV.getMatrixAddress(), LV.isVolatileQualified());
return RValue::get(Builder.CreateExtractElement(Load, Idx, "matrixext"));
}
if (LV.isMatrixRow()) {
QualType MatTy = LV.getType();
const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();

unsigned NumRows = MT->getNumRows();
unsigned NumCols = MT->getNumColumns();

llvm::Value *MatrixVec = EmitLoadOfScalar(LV, Loc);
llvm::Value *Row = LV.getMatrixRowIdx();
llvm::Type *ElemTy = ConvertType(MT->getElementType());
llvm::Type *RowTy = llvm::FixedVectorType::get(ElemTy, MT->getNumColumns());
llvm::Value *Result = llvm::PoisonValue::get(RowTy); // <NumCols x T>

llvm::MatrixBuilder MB(Builder);

for (unsigned Col = 0; Col < NumCols; ++Col) {
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
llvm::Value *Elt = Builder.CreateExtractElement(MatrixVec, EltIndex);
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
Result = Builder.CreateInsertElement(Result, Elt, Lane);
}

return RValue::get(Result);
}

assert(LV.isBitField() && "Unknown LValue type!");
return EmitLoadOfBitfieldLValue(LV, Loc);
Expand Down Expand Up @@ -2689,6 +2717,31 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
addInstToCurrentSourceAtom(I, Vec);
return;
}
if (Dst.isMatrixRow()) {
QualType MatTy = Dst.getType();
const ConstantMatrixType *MT = MatTy->castAs<ConstantMatrixType>();

unsigned NumRows = MT->getNumRows();
unsigned NumCols = MT->getNumColumns();

llvm::Value *MatrixVec =
Builder.CreateLoad(Dst.getAddress(), "matrix.load");

llvm::Value *Row = Dst.getMatrixRowIdx();
llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
llvm::MatrixBuilder MB(Builder);

for (unsigned Col = 0; Col < NumCols; ++Col) {
llvm::Value *ColIdx = llvm::ConstantInt::get(Row->getType(), Col);
llvm::Value *EltIndex = MB.CreateIndex(Row, ColIdx, NumRows);
llvm::Value *Lane = llvm::ConstantInt::get(Builder.getInt32Ty(), Col);
llvm::Value *NewElt = Builder.CreateExtractElement(RowVal, Lane);
MatrixVec = Builder.CreateInsertElement(MatrixVec, NewElt, EltIndex);
}

Builder.CreateStore(MatrixVec, Dst.getAddress());
return;
}

assert(Dst.isBitField() && "Unknown LValue type");
return EmitStoreThroughBitfieldLValue(Src, Dst);
Expand Down Expand Up @@ -4904,6 +4957,34 @@ llvm::Value *CodeGenFunction::EmitMatrixIndexExpr(const Expr *E) {
return Builder.CreateIntCast(Idx, IntPtrTy, IsSigned);
}

LValue CodeGenFunction::EmitMatrixSingleSubscriptExpr(
const MatrixSingleSubscriptExpr *E) {
LValue Base = EmitLValue(E->getBase());
llvm::Value *RowIdx = EmitMatrixIndexExpr(E->getRowIdx());

if (auto *RowConst = llvm::dyn_cast<llvm::ConstantInt>(RowIdx)) {
// Extract matrix shape from the AST type
const auto *MatTy = E->getBase()->getType()->castAs<ConstantMatrixType>();
unsigned NumCols = MatTy->getNumColumns();
llvm::SmallVector<llvm::Constant *, 8> Indices;
Indices.reserve(NumCols);

unsigned Row = RowConst->getZExtValue();
unsigned Start = Row * NumCols;
for (unsigned C = 0; C < NumCols; ++C)
Indices.push_back(llvm::ConstantInt::get(Int32Ty, Start + C));

llvm::Constant *Elts = llvm::ConstantVector::get(Indices);
return LValue::MakeExtVectorElt(
MaybeConvertMatrixAddress(Base.getAddress(), *this), Elts,
E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
}

return LValue::MakeMatrixRow(
MaybeConvertMatrixAddress(Base.getAddress(), *this), RowIdx,
E->getBase()->getType(), Base.getBaseInfo(), TBAAAccessInfo());
}

LValue CodeGenFunction::EmitMatrixSubscriptExpr(const MatrixSubscriptExpr *E) {
assert(
!E->isIncomplete() &&
Expand Down Expand Up @@ -5176,6 +5257,9 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
return LValue::MakeExtVectorElt(Base.getAddress(), CV, type,
Base.getBaseInfo(), TBAAAccessInfo());
}
if (Base.isMatrixRow())
return EmitUnsupportedLValue(E, "Matrix single index swizzle");

assert(Base.isExtVectorElt() && "Can only subscript lvalue vec elts here!");

llvm::Constant *BaseElts = Base.getExtVectorElts();
Expand Down
Loading