From 1f48cbc3a7a031e50dc383859c78133616ee3bd9 Mon Sep 17 00:00:00 2001 From: Joshua Cranmer Date: Mon, 12 Sep 2022 17:02:12 -0400 Subject: [PATCH] Add a helper for type-aware mutateCallInst replacement. (#1611) Opaque pointers requires us to keep more careful track of the pointee types of the builtin methods. While these are discoverable via demangling, the existing interface of mutateCallInst* does not provide easy ways to maintain these types during remapping of OpenCL builtins to SPIR-V builtins. This patch adds a BuiltinCallMutator that keeps track not only of the pointer element types discovered by demangling, but also of the LLVM attributes the arguments originally had, even as arguments are inserted or deleted for OpenCL<->SPIR-V builtin conversions. Additionally, the interface is intended to be more forgiving of the varieties of lambdas needed to do the conversions: for example, there is no need for different entry points to indicate a type-aware mapping, or even to indicate a function whose return type changes. It is intended that all of the existing calls to mutateCallInst be replaced with this new interface, so as to be able to eliminate the getPointerElementType call in the existing implementation of that function. However, the actual migration of the calls will be done in future patches to keep the maximum size of patches relatively small. Original commit: https://github.com/KhronosGroup/SPIRV-LLVM-Translator/commit/139b08b --- llvm-spirv/lib/SPIRV/CMakeLists.txt | 1 + llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp | 4 +- llvm-spirv/lib/SPIRV/OCLToSPIRV.h | 7 +- llvm-spirv/lib/SPIRV/OCLUtil.cpp | 15 +- llvm-spirv/lib/SPIRV/OCLUtil.h | 6 + llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp | 213 +++++++++++++++ llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.h | 279 ++++++++++++++++++++ llvm-spirv/lib/SPIRV/SPIRVReader.cpp | 7 +- llvm-spirv/lib/SPIRV/SPIRVReader.h | 3 +- llvm-spirv/lib/SPIRV/SPIRVToOCL.h | 8 +- 10 files changed, 524 insertions(+), 19 deletions(-) create mode 100644 llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp create mode 100644 llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.h diff --git a/llvm-spirv/lib/SPIRV/CMakeLists.txt b/llvm-spirv/lib/SPIRV/CMakeLists.txt index 86b79befa3358..9e8b9228e8021 100644 --- a/llvm-spirv/lib/SPIRV/CMakeLists.txt +++ b/llvm-spirv/lib/SPIRV/CMakeLists.txt @@ -9,6 +9,7 @@ set(SRC_LIST OCLTypeToSPIRV.cpp OCLUtil.cpp VectorComputeUtil.cpp + SPIRVBuiltinHelper.cpp SPIRVLowerBitCastToNonStandardType.cpp SPIRVLowerBool.cpp SPIRVLowerConstExpr.cpp diff --git a/llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp b/llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp index dc360117d7e98..63438e29d0d32 100644 --- a/llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp +++ b/llvm-spirv/lib/SPIRV/OCLToSPIRV.cpp @@ -160,7 +160,7 @@ void OCLToSPIRVBase::transVecLoadStoreName(std::string &DemangledName, char OCLToSPIRVLegacy::ID = 0; bool OCLToSPIRVBase::runOCLToSPIRV(Module &Module) { - M = &Module; + initialize(Module); Ctx = &M->getContext(); auto Src = getSPIRVSource(&Module); // This is a pre-processing pass, which transform LLVM IR module to a more @@ -1581,7 +1581,7 @@ void OCLToSPIRVBase::visitCallKernelQuery(CallInst *CI, auto *BlockF = cast(getUnderlyingObject(BlockFVal)); AttributeList Attrs = CI->getCalledFunction()->getAttributes(); - mutateCallInst( + ::mutateCallInst( M, CI, [=](CallInst *CI, std::vector &Args) { Value *Param = *Args.rbegin(); diff --git a/llvm-spirv/lib/SPIRV/OCLToSPIRV.h b/llvm-spirv/lib/SPIRV/OCLToSPIRV.h index 8707ad88d41ed..4efead0b659ff 100644 --- a/llvm-spirv/lib/SPIRV/OCLToSPIRV.h +++ b/llvm-spirv/lib/SPIRV/OCLToSPIRV.h @@ -41,6 +41,7 @@ #define SPIRV_OCLTOSPIRV_H #include "OCLUtil.h" +#include "SPIRVBuiltinHelper.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/PassManager.h" @@ -50,10 +51,11 @@ namespace SPIRV { class OCLTypeToSPIRVBase; -class OCLToSPIRVBase : public InstVisitor { +class OCLToSPIRVBase : public InstVisitor, BuiltinCallHelper { public: OCLToSPIRVBase() - : M(nullptr), Ctx(nullptr), CLVer(0), OCLTypeToSPIRVPtr(nullptr) {} + : BuiltinCallHelper(ManglingRules::SPIRV), Ctx(nullptr), CLVer(0), + OCLTypeToSPIRVPtr(nullptr) {} virtual ~OCLToSPIRVBase() {} bool runOCLToSPIRV(Module &M); @@ -264,7 +266,6 @@ class OCLToSPIRVBase : public InstVisitor { OCLTypeToSPIRVBase *getOCLTypeToSPIRV() { return OCLTypeToSPIRVPtr; } private: - Module *M; LLVMContext *Ctx; unsigned CLVer; /// OpenCL version as major*10+minor std::set ValuesToDelete; diff --git a/llvm-spirv/lib/SPIRV/OCLUtil.cpp b/llvm-spirv/lib/SPIRV/OCLUtil.cpp index 4c2e247f08ff0..38b70ce5faed0 100644 --- a/llvm-spirv/lib/SPIRV/OCLUtil.cpp +++ b/llvm-spirv/lib/SPIRV/OCLUtil.cpp @@ -979,9 +979,7 @@ static FunctionType *getBlockInvokeTy(Function *F, unsigned BlockIdx) { class OCLBuiltinFuncMangleInfo : public SPIRV::BuiltinFuncMangleInfo { public: OCLBuiltinFuncMangleInfo(Function *F) : F(F) {} - OCLBuiltinFuncMangleInfo(ArrayRef ArgTypes) - : ArgTypes(ArgTypes.vec()) {} - Type *getArgTy(unsigned I) { return F->getFunctionType()->getParamType(I); } + OCLBuiltinFuncMangleInfo() = default; void init(StringRef UniqName) override { // Make a local copy as we will modify the string in init function std::string TempStorage = UniqName.str(); @@ -1305,10 +1303,7 @@ class OCLBuiltinFuncMangleInfo : public SPIRV::BuiltinFuncMangleInfo { } // Auxiliarry information, it is expected that it is relevant at the moment // the init method is called. - Function *F; // SPIRV decorated function - // TODO: ArgTypes argument should get removed once all SPV-IR related issues - // are resolved - std::vector ArgTypes; // Arguments of OCL builtin + Function *F; // SPIRV decorated function }; CallInst *mutateCallInstOCL( @@ -1330,6 +1325,10 @@ Instruction *mutateCallInstOCL( TakeFuncName); } +std::unique_ptr makeMangler(Function &F) { + return std::make_unique(&F); +} + static StringRef getStructName(Type *Ty) { if (auto *STy = dyn_cast(Ty)) return STy->isLiteral() ? "" : Ty->getStructName(); @@ -1603,6 +1602,6 @@ Value *SPIRV::transSPIRVMemorySemanticsIntoOCLMemFenceFlags( void llvm::mangleOpenClBuiltin(const std::string &UniqName, ArrayRef ArgTypes, std::string &MangledName) { - OCLUtil::OCLBuiltinFuncMangleInfo BtnInfo(ArgTypes); + OCLUtil::OCLBuiltinFuncMangleInfo BtnInfo; MangledName = SPIRV::mangleBuiltin(UniqName, ArgTypes, &BtnInfo); } diff --git a/llvm-spirv/lib/SPIRV/OCLUtil.h b/llvm-spirv/lib/SPIRV/OCLUtil.h index 8429df5c5db1d..8cff65e0647fe 100644 --- a/llvm-spirv/lib/SPIRV/OCLUtil.h +++ b/llvm-spirv/lib/SPIRV/OCLUtil.h @@ -54,6 +54,10 @@ using namespace SPIRV; using namespace llvm; using namespace spv; +namespace SPIRV { +class BuiltinCallMutator; +} // namespace SPIRV + namespace OCLUtil { /////////////////////////////////////////////////////////////////////////////// @@ -520,6 +524,8 @@ std::string getIntelSubgroupBlockDataPostfix(unsigned ElementBitSize, void insertImageNameAccessQualifier(SPIRVAccessQualifierKind Acc, std::string &Name); + +std::unique_ptr makeMangler(Function &F); } // namespace OCLUtil using namespace OCLUtil; diff --git a/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp b/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp new file mode 100644 index 0000000000000..fae86366f40d1 --- /dev/null +++ b/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.cpp @@ -0,0 +1,213 @@ +//===- SPIRVBuiltinHelper.cpp - Helpers for managing calls to builtins ----===// +// +// The LLVM/SPIR-V Translator +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +// Copyright (c) 2022 The Khronos Group Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimers in the documentation +// and/or other materials provided with the distribution. +// Neither the names of The Khronos Group, nor the names of its +// contributors may be used to endorse or promote products derived from this +// Software without specific prior written permission. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH +// THE SOFTWARE. +// +//===----------------------------------------------------------------------===// +// +// This file implements helper functions for adding calls to OpenCL or SPIR-V +// builtin functions, or for rewriting calls to one into calls to the other. +// +//===----------------------------------------------------------------------===// + +#include "SPIRVBuiltinHelper.h" + +#include "OCLUtil.h" +#include "SPIRVInternal.h" + +using namespace llvm; +using namespace SPIRV; + +static std::unique_ptr makeMangler(CallBase *CB, + ManglingRules Rules) { + switch (Rules) { + case ManglingRules::None: + return nullptr; + case ManglingRules::SPIRV: + return std::make_unique(); + case ManglingRules::OpenCL: + return OCLUtil::makeMangler(*CB->getCalledFunction()); + } + llvm_unreachable("Unknown mangling rules to make a name mangler"); +} + +BuiltinCallMutator::BuiltinCallMutator( + CallInst *CI, std::string FuncName, ManglingRules Rules, + std::function NameMapFn) + : CI(CI), FuncName(FuncName), + Attrs(CI->getCalledFunction()->getAttributes()), ReturnTy(CI->getType()), + Args(CI->args()), Rules(Rules), Builder(CI) { + getParameterTypes(CI->getCalledFunction(), PointerTypes, + std::move(NameMapFn)); + PointerTypes.resize(Args.size(), nullptr); +} + +BuiltinCallMutator::BuiltinCallMutator(BuiltinCallMutator &&Other) + : CI(Other.CI), FuncName(std::move(Other.FuncName)), + MutateRet(std::move(Other.MutateRet)), Attrs(Other.Attrs), + ReturnTy(Other.ReturnTy), Args(std::move(Other.Args)), + PointerTypes(std::move(Other.PointerTypes)), + Rules(std::move(Other.Rules)), Builder(CI) { + // Clear the other's CI instance so that it knows not to construct the actual + // call. + Other.CI = nullptr; +} + +Value *BuiltinCallMutator::doConversion() { + assert(CI && "Need to have a call instruction to do the conversion"); + auto Mangler = makeMangler(CI, Rules); + for (unsigned I = 0; I < Args.size(); I++) { + Mangler->getTypeMangleInfo(I).PointerTy = PointerTypes[I]; + } + assert(Attrs.getNumAttrSets() <= Args.size() + 2 && "Too many attributes?"); + CallInst *NewCall = + Builder.Insert(addCallInst(CI->getModule(), FuncName, ReturnTy, Args, + &Attrs, nullptr, Mangler.get())); + Value *Result = MutateRet ? MutateRet(Builder, NewCall) : NewCall; + Result->takeName(CI); + if (!CI->getType()->isVoidTy()) + CI->replaceAllUsesWith(Result); + CI->dropAllReferences(); + CI->eraseFromParent(); + CI = nullptr; + return Result; +} + +BuiltinCallMutator &BuiltinCallMutator::setArgs(ArrayRef NewArgs) { + // Retain only the function attributes, not any parameter attributes. + Attrs = AttributeList::get(CI->getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), {}); + Args.clear(); + PointerTypes.clear(); + for (Value *Arg : NewArgs) { + assert(!Arg->getType()->isPointerTy() && + "Cannot use this signature with pointer types"); + Args.push_back(Arg); + PointerTypes.emplace_back(); + } + return *this; +} + +// This is a helper method to handle splicing of the attribute lists, as +// llvm::AttributeList doesn't have any helper methods for this sort of design. +// (It's designed to be manually built-up, not adjusted to add/remove +// arguments on the fly). +static void moveAttributes(LLVMContext &Ctx, AttributeList &Attrs, + unsigned Start, unsigned Len, unsigned Dest) { + SmallVector, 6> NewAttrs; + for (unsigned Index : Attrs.indexes()) { + AttributeSet AttrSet = Attrs.getAttributes(Index); + if (!AttrSet.hasAttributes()) + continue; + + // If the attribute is a parameter index, check to see how its index should + // be adjusted. + if (Index > AttributeList::FirstArgIndex) { + unsigned ParamIndex = Index - AttributeList::FirstArgIndex; + if (ParamIndex >= Start && ParamIndex < Start + Len) + // A parameter in this range needs to have its index adjusted to its + // destination location. + Index += Dest - Start; + else if (ParamIndex >= Dest && ParamIndex < Dest + Len) + // This parameter will be overwritten by one of the moved parameters, so + // omit it entirely. + continue; + } + + // The array is usually going to be sorted, but because of the above + // adjustment, we might end up out of order. This logic ensures that the + // array always remains in sorted order. + std::pair ToInsert(Index, AttrSet); + NewAttrs.insert(llvm::lower_bound(NewAttrs, ToInsert, llvm::less_first()), + ToInsert); + } + Attrs = AttributeList::get(Ctx, NewAttrs); +} + +// Convert a ValueTypePair to a TypedPointerType for storing in the PointerTypes +// array. +static TypedPointerType *toTPT(BuiltinCallMutator::ValueTypePair Pair) { + if (!Pair.second) + return nullptr; + unsigned AS = 0; + if (auto *TPT = dyn_cast(Pair.first->getType())) + AS = TPT->getAddressSpace(); + else if (isa(Pair.first->getType())) + AS = Pair.first->getType()->getPointerAddressSpace(); + return TypedPointerType::get(Pair.second, AS); +} + +BuiltinCallMutator &BuiltinCallMutator::insertArg(unsigned Index, + ValueTypePair Arg) { + Args.insert(Args.begin() + Index, Arg.first); + PointerTypes.insert(PointerTypes.begin() + Index, toTPT(Arg)); + moveAttributes(CI->getContext(), Attrs, Index, Args.size() - Index, + Index + 1); + return *this; +} + +BuiltinCallMutator &BuiltinCallMutator::replaceArg(unsigned Index, + ValueTypePair Arg) { + Args[Index] = Arg.first; + PointerTypes[Index] = toTPT(Arg); + Attrs = Attrs.removeParamAttributes(CI->getContext(), Index); + return *this; +} + +BuiltinCallMutator &BuiltinCallMutator::removeArg(unsigned Index) { + // If the argument being dropped is the last one, there is nothing to move, so + // just remove the attributes. + if (Index == Args.size() - 1) + Attrs = Attrs.removeParamAttributes(CI->getContext(), Index); + else + moveAttributes(CI->getContext(), Attrs, Index + 1, Args.size() - Index - 1, + Index); + Args.erase(Args.begin() + Index); + PointerTypes.erase(PointerTypes.begin() + Index); + return *this; +} + +BuiltinCallMutator & +BuiltinCallMutator::changeReturnType(Type *NewReturnTy, + MutateRetFuncTy MutateFunc) { + ReturnTy = NewReturnTy; + MutateRet = std::move(MutateFunc); + return *this; +} + +BuiltinCallMutator BuiltinCallHelper::mutateCallInst(CallInst *CI, + spv::Op Opcode) { + return mutateCallInst(CI, getSPIRVFuncName(Opcode)); +} + +BuiltinCallMutator BuiltinCallHelper::mutateCallInst(CallInst *CI, + std::string FuncName) { + return BuiltinCallMutator(CI, std::move(FuncName), Rules, NameMapFn); +} diff --git a/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.h b/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.h new file mode 100644 index 0000000000000..e5658a0f0db89 --- /dev/null +++ b/llvm-spirv/lib/SPIRV/SPIRVBuiltinHelper.h @@ -0,0 +1,279 @@ +//===- SPIRVBuiltinHelper.h - Helpers for managing calls to builtins ------===// +// +// The LLVM/SPIR-V Translator +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +// Copyright (c) 2022 The Khronos Group Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a +// copy of this software and associated documentation files (the "Software"), +// to deal with the Software without restriction, including without limitation +// the rights to use, copy, modify, merge, publish, distribute, sublicense, +// and/or sell copies of the Software, and to permit persons to whom the +// Software is furnished to do so, subject to the following conditions: +// +// Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimers. +// Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimers in the documentation +// and/or other materials provided with the distribution. +// Neither the names of The Khronos Group, nor the names of its +// contributors may be used to endorse or promote products derived from this +// Software without specific prior written permission. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH +// THE SOFTWARE. +// +//===----------------------------------------------------------------------===// +// +// This file implements helper functions for adding calls to OpenCL or SPIR-V +// builtin functions, or for rewriting calls to one into calls to the other. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRVBUILTINHELPER_H +#define SPIRVBUILTINHELPER_H + +#include "LLVMSPIRVLib.h" +#include "libSPIRV/SPIRVOpCode.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/TypedPointerType.h" + +namespace SPIRV { +enum class ManglingRules { None, OpenCL, SPIRV }; + +namespace detail { +/// This is a helper for triggering the static_assert in mapArg. +template constexpr bool LegalFnType = false; +} // namespace detail + +/// A helper class for changing OpenCL builtin function calls to SPIR-V function +/// calls, or vice versa. Most of the functions will return a reference to the +/// current instance, allowing calls to be chained together, for example: +/// mutateCallInst(CI, NewFuncName) +/// .removeArg(3) +/// .appendArg(translateScope()); +/// +/// Only when the destuctor of this object is called will the original CallInst +/// be destroyed and replaced with the new CallInst be created. +class BuiltinCallMutator { + // Original call instruction + llvm::CallInst *CI; + // New unmangled function name + std::string FuncName; + // Return type mutator. This needs to be saved, because we can't call it until + // the new instruction is created. + std::function &, llvm::CallInst *)> MutateRet; + typedef decltype(MutateRet) MutateRetFuncTy; + // The attribute list for the new call instruction. + llvm::AttributeList Attrs; + // The return type for the new call instruction. + llvm::Type *ReturnTy; + // The arguments for the new call instruction. + llvm::SmallVector Args; + // The pointer element types for the new call instruction. + llvm::SmallVector PointerTypes; + // The mangler rules to use for the new call instruction. + ManglingRules Rules; + + friend class BuiltinCallHelper; + BuiltinCallMutator( + llvm::CallInst *CI, std::string FuncName, ManglingRules Rules, + std::function NameMapFn = nullptr); + + // This does the actual work of creating of the new call, and will return the + // new instruction. + llvm::Value *doConversion(); + +public: + ~BuiltinCallMutator() { + if (CI) + doConversion(); + } + BuiltinCallMutator(const BuiltinCallMutator &) = delete; + BuiltinCallMutator &operator=(const BuiltinCallMutator &) = delete; + BuiltinCallMutator &operator=(BuiltinCallMutator &&) = delete; + BuiltinCallMutator(BuiltinCallMutator &&); + + /// The builder used to generate IR for this call. + llvm::IRBuilder<> Builder; + + /// Return the resulting new instruction. It is not possible to use any + /// method on this object after calling this function. + llvm::Value *getMutated() { return doConversion(); } + + /// Return the number of arguments currently specified for the new call. + unsigned arg_size() const { return Args.size(); } + + /// Get the corresponding argument for the new call. + llvm::Value *getArg(unsigned Index) const { return Args[Index]; } + + /// Return the pointer element type of the corresponding index, or nullptr if + /// it is not a pointer. + llvm::Type *getPointerElementType(unsigned Index) const { + llvm::TypedPointerType *ElTy = PointerTypes[Index]; + return ElTy ? ElTy->getElementType() : nullptr; + } + + /// A pair representing both the LLVM value of an argument and its + /// corresponding pointer element type. This type can be constructed from + /// implicit conversion from an LLVM value object (but only if it is not of + /// pointer type), or by the appropriate std::pair type. + struct ValueTypePair : public std::pair { + ValueTypePair(llvm::Value *V) : pair(V, nullptr) { + assert(!V->getType()->isPointerTy() && + "Must specify a pointer element type if value is a pointer."); + } + ValueTypePair(std::pair P) : pair(P) {} + ValueTypePair(llvm::Value *V, llvm::Type *T) : pair(V, T) {} + ValueTypePair() = delete; + using pair::pair; + }; + + /// Use the following arguments as the arguments of the new call, replacing + /// any previous arguments. This version may not be used if any argument is of + /// pointer type. + BuiltinCallMutator &setArgs(llvm::ArrayRef Args); + + /// This will replace the return type of the call with a different return + /// type. The second argument is a function that will be called with an + /// IRBuilder parameter and the newly generated function, and will return the + /// value to replace all uses of the original call instruction with. Example + /// usage: + /// + /// BuiltinCallMutator Mutator = /* ... */; + /// Mutator.changeReturnType(Int16Ty, [](IRBuilder<> &IRB, CallInst *CI) { + /// return IRB.CreateZExt(CI, Int16Ty); + /// }); + BuiltinCallMutator &changeReturnType(llvm::Type *ReturnTy, + MutateRetFuncTy MutateFunc); + + /// Insert an argument before the given index. + BuiltinCallMutator &insertArg(unsigned Index, ValueTypePair Arg); + + /// Add an argument to the end of the argument list. + BuiltinCallMutator &appendArg(ValueTypePair Arg) { + return insertArg(Args.size(), Arg); + } + + /// Replace the argument at the given index with a new value. + BuiltinCallMutator &replaceArg(unsigned Index, ValueTypePair Arg); + + /// Remove the argument at the given index. + BuiltinCallMutator &removeArg(unsigned Index); + + /// Remove all arguments in a range. + BuiltinCallMutator &removeArgs(unsigned Start, unsigned Len) { + for (unsigned I = 0; I < Len; I++) + removeArg(Start); + return *this; + } + + /// Move the argument from the given index to the new index. + BuiltinCallMutator &moveArg(unsigned FromIndex, unsigned ToIndex) { + if (FromIndex == ToIndex) + return *this; + ValueTypePair Pair(Args[FromIndex], getPointerElementType(FromIndex)); + removeArg(FromIndex); + insertArg(ToIndex, Pair); + return *this; + } + + /// Use a callback function or lambda to convert an argument to a new value. + /// The expected return type of the lambda is anything that is convertible + /// to ValueTypePair, which could be a single Value* (but only if it is not + /// pointer-typed), or a std::pair. The possible signatures + /// of the function parameter are as follows: + /// ValueTypePair func(IRBuilder<> &Builder, Value *, Type *); + /// ValueTypePair func(IRBuilder<> &Builder, Value *); + /// ValueTypePair func(Value *, Type *); + /// ValueTypePair func(Value *); + /// + /// When present, the IRBuilder parameter corresponds to a builder that is set + /// to insert immediately before the new call instruction. The Value parameter + /// corresponds to the argument to be mutated. The Type parameter, when + /// present, corresponds to the pointer element type of the argument, or null + /// when it is not present. + template + BuiltinCallMutator &mapArg(unsigned Index, FnType Func) { + using namespace llvm; + using std::is_invocable; + IRBuilder<> Builder(CI); + Value *V = Args[Index]; + [[maybe_unused]] Type *T = getPointerElementType(Index); + + // Dispatch the function call as appropriate, based on the types that the + // function may be called with. + if constexpr (is_invocable &, Value *, Type *>::value) + replaceArg(Index, Func(Builder, V, T)); + else if constexpr (is_invocable &, Value *>::value) + replaceArg(Index, Func(Builder, V)); + else if constexpr (is_invocable::value) + replaceArg(Index, Func(V, T)); + else if constexpr (is_invocable::value) + replaceArg(Index, Func(V)); + else { + // We need a helper value that is always false, but is dependent on the + // template parameter to prevent this static_assert from firing when one + // of the if constexprs above fires. + static_assert(detail::LegalFnType, + "mapArg lambda signature is not satisfied"); + } + return *this; + } + + /// Map all arguments according to the given function, as if mapArg(i, Func) + /// had been called for every argument i. + template BuiltinCallMutator &mapArgs(FnType Func) { + for (unsigned I = 0, E = Args.size(); I < E; I++) + mapArg(I, Func); + return *this; + } +}; + +/// A helper class for generating calls to SPIR-V builtins with appropriate name +/// mangling rules. It is expected that transformation passes inherit from this +/// class. +class BuiltinCallHelper { + ManglingRules Rules; + std::function NameMapFn; + +protected: + llvm::Module *M = nullptr; + +public: + /// Initialize details about how to mangle and demangle builtins correctly. + /// The Rules argument selects which name mangler to use for mangling. + /// The NameMapFn function will map type names during demangling; it defaults + /// to the identity function. + explicit BuiltinCallHelper( + ManglingRules Rules, + std::function NameMapFn = nullptr) + : Rules(Rules), NameMapFn(std::move(NameMapFn)) {} + + /// Initialize the module that will be operated on. This method must be called + /// before future methods. + void initialize(llvm::Module &M) { this->M = &M; } + + /// Return a mutator that will replace the given call instruction with a call + /// to the given function name. The function name will have its name mangled + /// in accordance with the argument types provided to the mutator. + BuiltinCallMutator mutateCallInst(llvm::CallInst *CI, std::string FuncName); + + /// Return a mutator that will replace the given call instruction with a call + /// to the given SPIR-V opcode (whose name is used in the lookup map of + /// getSPIRVFuncName). + BuiltinCallMutator mutateCallInst(llvm::CallInst *CI, spv::Op Opcode); +}; + +} // namespace SPIRV + +#endif // SPIRVBUILTINHELPER_H diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp index 09b31e36ef9cc..c3e276099b847 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.cpp +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.cpp @@ -3119,7 +3119,8 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName, } SPIRVToLLVM::SPIRVToLLVM(Module *LLVMModule, SPIRVModule *TheSPIRVModule) - : M(LLVMModule), BM(TheSPIRVModule) { + : BuiltinCallHelper(ManglingRules::OpenCL), M(LLVMModule), + BM(TheSPIRVModule) { assert(M && "Initialization without an LLVM module is not allowed"); Context = &M->getContext(); DbgTran.reset(new SPIRVToLLVMDbgTran(TheSPIRVModule, LLVMModule, this)); @@ -4457,7 +4458,7 @@ Instruction *SPIRVToLLVM::transAllAny(SPIRVInstruction *I, BasicBlock *BB) { BuiltinFuncMangleInfo BtnInfo; AttributeList Attrs = CI->getCalledFunction()->getAttributes(); return cast(mapValue( - I, mutateCallInst( + I, ::mutateCallInst( M, CI, [=](CallInst *, std::vector &Args) { auto *OldArg = CI->getOperand(0); @@ -4478,7 +4479,7 @@ Instruction *SPIRVToLLVM::transRelational(SPIRVInstruction *I, BasicBlock *BB) { BuiltinFuncMangleInfo BtnInfo; AttributeList Attrs = CI->getCalledFunction()->getAttributes(); return cast(mapValue( - I, mutateCallInst( + I, ::mutateCallInst( M, CI, [=](CallInst *, std::vector &Args, llvm::Type *&RetTy) { if (CI->getType()->isVectorTy()) { diff --git a/llvm-spirv/lib/SPIRV/SPIRVReader.h b/llvm-spirv/lib/SPIRV/SPIRVReader.h index 8b400c29ba751..c451b1a3df48c 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVReader.h +++ b/llvm-spirv/lib/SPIRV/SPIRVReader.h @@ -41,6 +41,7 @@ #ifndef SPIRVREADER_H #define SPIRVREADER_H +#include "SPIRVBuiltinHelper.h" #include "SPIRVInternal.h" #include "SPIRVModule.h" @@ -74,7 +75,7 @@ class SPIRVConstantSampler; class SPIRVConstantPipeStorage; class SPIRVLoopMerge; class SPIRVToLLVMDbgTran; -class SPIRVToLLVM { +class SPIRVToLLVM : private BuiltinCallHelper { public: SPIRVToLLVM(Module *LLVMModule, SPIRVModule *TheSPIRVModule); diff --git a/llvm-spirv/lib/SPIRV/SPIRVToOCL.h b/llvm-spirv/lib/SPIRV/SPIRVToOCL.h index 572d6241ca0d2..dd910826ba467 100644 --- a/llvm-spirv/lib/SPIRV/SPIRVToOCL.h +++ b/llvm-spirv/lib/SPIRV/SPIRVToOCL.h @@ -42,6 +42,7 @@ #define SPIRVTOOCL_H #include "OCLUtil.h" +#include "SPIRVBuiltinHelper.h" #include "SPIRVInternal.h" #include "llvm/IR/InstVisitor.h" #include "llvm/IR/PassManager.h" @@ -51,9 +52,12 @@ namespace SPIRV { -class SPIRVToOCLBase : public InstVisitor { +class SPIRVToOCLBase : public InstVisitor, + protected BuiltinCallHelper { public: - SPIRVToOCLBase() : M(nullptr), Ctx(nullptr) {} + SPIRVToOCLBase() + : BuiltinCallHelper(ManglingRules::OpenCL, translateOpaqueType), + M(nullptr), Ctx(nullptr) {} virtual ~SPIRVToOCLBase() {} virtual bool runSPIRVToOCL(Module &M) = 0;