-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[HLSL][DXIL] Implement refract
intrinsic
#136026
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6ce233a
83d69dd
83c4f5a
1853e3d
dff5181
3d87d6d
e3c1b0a
d8b079c
d1e1fe9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -71,6 +71,23 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) { | |||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T> constexpr T refract_impl(T I, T N, T Eta) { | ||||||||||||||||||||||||||
T K = 1 - Eta * Eta * (1 - (N * I * N * I)); | ||||||||||||||||||||||||||
T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N); | ||||||||||||||||||||||||||
return select<T>(K < 0, static_cast<T>(0), Result); | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T, int L> | ||||||||||||||||||||||||||
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) { | ||||||||||||||||||||||||||
#if (__has_builtin(__builtin_spirv_refract)) | ||||||||||||||||||||||||||
return __builtin_spirv_refract(I, N, Eta); | ||||||||||||||||||||||||||
#else | ||||||||||||||||||||||||||
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I)); | ||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we should really be storing |
||||||||||||||||||||||||||
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N); | ||||||||||||||||||||||||||
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result); | ||||||||||||||||||||||||||
Comment on lines
+82
to
+87
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Our most recent pattern on these has been to do one function when possible instead of a scalar and vector implementation. I beleive this is a case where that might be possible. Can you experiment with the following? In template <typename T>
struct is_vector {
static const bool value = false;
};
/*NOTE: (don't include this comment) what I am doing here is adding a specialization for vector<T, N>*/
template <typename T, int N>
struct is_vector<vector<T, N>> {
static const bool value = true;
}; Then line 82-87 changes to
Suggested change
|
||||||||||||||||||||||||||
#endif | ||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||
|
||||||||||||||||||||||||||
template <typename T> constexpr T fmod_impl(T X, T Y) { | ||||||||||||||||||||||||||
#if !defined(__DIRECTX__) | ||||||||||||||||||||||||||
return __builtin_elementwise_fmod(X, Y); | ||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -16,88 +16,90 @@ namespace clang { | |||
|
||||
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {} | ||||
|
||||
/// Checks if the first `NumArgsToCheck` arguments of a function call are of | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @spall did a bunch of work to get semantic checking to be more consistent. She should review at least this file for this PR. |
||||
/// vector type. If any of the arguments is not a vector type, it emits a | ||||
/// diagnostic error and returns `true`. Otherwise, it returns `false`. | ||||
/// | ||||
/// \param TheCall The function call expression to check. | ||||
/// \param NumArgsToCheck The number of arguments to check for vector type. | ||||
/// \return `true` if any of the arguments is not a vector type, `false` | ||||
/// otherwise. | ||||
|
||||
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) { | ||||
raoanag marked this conversation as resolved.
Show resolved
Hide resolved
|
||||
for (unsigned i = 0; i < NumArgsToCheck; ++i) { | ||||
ExprResult Arg = TheCall->getArg(i); | ||||
QualType ArgTy = Arg.get()->getType(); | ||||
auto *VTy = ArgTy->getAs<VectorType>(); | ||||
if (VTy == nullptr) { | ||||
SemaRef.Diag(Arg.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTy | ||||
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
return true; | ||||
} | ||||
} | ||||
return false; | ||||
} | ||||
|
||||
bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, | ||||
CallExpr *TheCall) { | ||||
switch (BuiltinID) { | ||||
case SPIRV::BI__builtin_spirv_distance: { | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this not need to check that the element type is a float? Is that checked previously? |
||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||
return true; | ||||
|
||||
ExprResult A = TheCall->getArg(0); | ||||
QualType ArgTyA = A.get()->getType(); | ||||
auto *VTyA = ArgTyA->getAs<VectorType>(); | ||||
if (VTyA == nullptr) { | ||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTyA | ||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
return true; | ||||
} | ||||
|
||||
ExprResult B = TheCall->getArg(1); | ||||
QualType ArgTyB = B.get()->getType(); | ||||
auto *VTyB = ArgTyB->getAs<VectorType>(); | ||||
if (VTyB == nullptr) { | ||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTyB | ||||
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
// Use the helper function to check both arguments | ||||
if (CheckVectorArgs(TheCall, 2)) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if you want to match the new error message and style convention in SemaHLSL file, you can do something like this llvm-project/clang/lib/Sema/SemaHLSL.cpp Line 2855 in cd46354
CheckFloatOrHalfRepresentation checks if its a scalar or a vector but you could check if there is a different existing function which checks for vectors of half or float, or write a new function that does this. |
||||
return true; | ||||
} | ||||
|
||||
QualType RetTy = VTyA->getElementType(); | ||||
QualType RetTy = | ||||
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); | ||||
TheCall->setType(RetTy); | ||||
break; | ||||
} | ||||
case SPIRV::BI__builtin_spirv_length: { | ||||
if (SemaRef.checkArgCount(TheCall, 1)) | ||||
return true; | ||||
ExprResult A = TheCall->getArg(0); | ||||
QualType ArgTyA = A.get()->getType(); | ||||
auto *VTy = ArgTyA->getAs<VectorType>(); | ||||
if (VTy == nullptr) { | ||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTyA | ||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
|
||||
// Use the helper function to check the argument | ||||
if (CheckVectorArgs(TheCall, 1)) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question here about if you should be checking if the element type is float. And Same comment about the style from SemaHLSL |
||||
return true; | ||||
} | ||||
QualType RetTy = VTy->getElementType(); | ||||
|
||||
QualType RetTy = | ||||
TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType(); | ||||
TheCall->setType(RetTy); | ||||
break; | ||||
} | ||||
case SPIRV::BI__builtin_spirv_reflect: { | ||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||
case SPIRV::BI__builtin_spirv_refract: { | ||||
if (SemaRef.checkArgCount(TheCall, 3)) | ||||
return true; | ||||
|
||||
ExprResult A = TheCall->getArg(0); | ||||
QualType ArgTyA = A.get()->getType(); | ||||
auto *VTyA = ArgTyA->getAs<VectorType>(); | ||||
if (VTyA == nullptr) { | ||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTyA | ||||
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
// Use the helper function to check the first two arguments | ||||
if (CheckVectorArgs(TheCall, 2)) | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here about following the SemaHLSL style. |
||||
return true; | ||||
} | ||||
|
||||
ExprResult B = TheCall->getArg(1); | ||||
QualType ArgTyB = B.get()->getType(); | ||||
auto *VTyB = ArgTyB->getAs<VectorType>(); | ||||
if (VTyB == nullptr) { | ||||
SemaRef.Diag(A.get()->getBeginLoc(), | ||||
diag::err_typecheck_convert_incompatible) | ||||
<< ArgTyB | ||||
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 | ||||
<< 0 << 0; | ||||
ExprResult C = TheCall->getArg(2); | ||||
QualType ArgTyC = C.get()->getType(); | ||||
if (!ArgTyC->isFloatingType()) { | ||||
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type) | ||||
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1 << ArgTyC; | ||||
return true; | ||||
} | ||||
|
||||
QualType RetTy = ArgTyA; | ||||
QualType RetTy = TheCall->getArg(0)->getType(); | ||||
TheCall->setType(RetTy); | ||||
break; | ||||
} | ||||
case SPIRV::BI__builtin_spirv_reflect: { | ||||
if (SemaRef.checkArgCount(TheCall, 2)) | ||||
return true; | ||||
|
||||
// Use the helper function to check both arguments | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same question here about if you need to check for float. and same comment about semahlsl style. |
||||
if (CheckVectorArgs(TheCall, 2)) | ||||
return true; | ||||
|
||||
QualType RetTy = TheCall->getArg(0)->getType(); | ||||
TheCall->setType(RetTy); | ||||
break; | ||||
} | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need to add
CustomTypeChecking
.