Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

// Adds to `patterns` the conversion pattern that lowers triton::ExternCallOp
// to an LLVM call. `benefit` is forwarded to the registered pattern.
// `targetInfo` is part of the signature, but the definition in
// ExternCallOpToLLVM.cpp does not read it.
void populateExternCallOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
Expand Down
26 changes: 26 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,32 @@ def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return",
let assemblyFormat = "attr-dict ($result^ `:` type($result))?";
}

//
// External Call op
//
def TT_ExternCallOp : TT_Op<"extern_call", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    ConditionallySpeculatable,
]> {

    let summary = "call a function implemented in an external library";

    let description = [{
        Call the external function $symbol implemented in $libpath/$libname,
        passing $srcs as the arguments, and yield what it returns:

            $result = $libpath/$libname:$symbol($srcs...)

        $result is variadic, so the op may produce no result at all.

        When $pure is true the op reports no memory effects and is
        speculatable. Otherwise it is treated as both reading and writing
        memory and is not speculatable.
    }];

    // $srcs are forwarded to the callee; $libname/$libpath locate the library
    // providing $symbol; $pure selects the side-effect model described above.
    let arguments = (ins Variadic<TT_Type>:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);

    let results = (outs Variadic<TT_Type>:$result);

    let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";

    let extraClassDeclaration = [{
        // Interface method for ConditionallySpeculatable.
        Speculation::Speculatability getSpeculatability();
    }];

}

//
// External Elementwise op
//
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ add_triton_library(TritonGPUToLLVM
TypeConverter.cpp
Utility.cpp
ViewOpToLLVM.cpp
ExternCallOpToLLVM.cpp

DEPENDS
TritonGPUConversionPassIncGen
Expand Down
61 changes: 61 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace {

// Lowers triton::ExternCallOp to an LLVM call.
//
// The callee is the LLVM function returned by
// mlir::triton::gpu::appendOrGetExternFuncOp for the op's `symbol`, `libname`
// and `libpath` attributes. Ops with zero or one result are supported; an op
// with more than one result is rejected with a diagnostic.
class ExternCallOpConversion
    : public ConvertOpToLLVMPattern<triton::ExternCallOp> {
public:
  ExternCallOpConversion(const LLVMTypeConverter &converter,
                         const PatternBenefit &benefit)
      : ConvertOpToLLVMPattern<triton::ExternCallOp>(converter, benefit) {}

  // Rewrites `op` into a call of the external function using the already
  // converted operands from `adaptor`. Returns failure (leaving `op` untouched)
  // when the op has multiple results or its result type cannot be converted.
  LogicalResult
  matchAndRewrite(triton::ExternCallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();

    // Report through the diagnostic engine so the message carries the op's
    // location instead of being dumped unterminated on stderr.
    if (op->getNumResults() > 1)
      return op.emitError("ExternCallConversion does not support multi outs.");

    const bool hasResult = op->getNumResults() == 1;

    // A result-less call is modelled as a function returning void.
    Type retType = void_ty(op->getContext());
    if (hasResult) {
      retType =
          this->getTypeConverter()->convertType(op->getResult(0).getType());
      // convertType yields a null Type on failure; building a function type
      // from it would be invalid, so bail out before touching the IR.
      if (!retType)
        return rewriter.notifyMatchFailure(
            op, "failed to convert the result type of the extern call");
    }

    auto newOperands = adaptor.getOperands();
    std::string funcName = op.getSymbol().str();
    StringRef libname = op.getLibname();
    StringRef libpath = op.getLibpath();

    Type funcType = mlir::triton::gpu::getFunctionType(retType, newOperands);
    LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp(
        rewriter, op, funcName, funcType, libname, libpath);
    Operation *callOp =
        LLVM::createLLVMCallOp(rewriter, loc, funcOp, newOperands);

    if (hasResult) {
      rewriter.replaceOp(op, callOp->getResult(0));
    } else {
      // Nothing to forward: the call has been emitted, so just drop the op.
      rewriter.eraseOp(op);
    }

    return success();
  }
};

} // namespace

// Registers ExternCallOpConversion in `patterns`, constructed from
// `typeConverter` and `benefit`.
void mlir::triton::populateExternCallOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  // The extern-call lowering does not consult the target description; the
  // parameter is only part of the signature.
  (void)targetInfo;
  patterns.add<ExternCallOpConversion>(typeConverter, benefit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -618,6 +618,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::HistogramOp>,
GenericOpPattern<triton::GatherOp>,
GenericOpPattern<triton::ExternElementwiseOp>,
GenericOpPattern<triton::ExternCallOp>,
GenericOpPattern<triton::PrintOp>,
GenericOpPattern<triton::AssertOp>,
GenericOpPattern<triton::AtomicCASOp>,
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,24 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}

// -- ExternCallOp --
// MemoryEffectsOpInterface hook: an op whose `pure` attribute is set reports
// no effects; any other op reports a write followed by a read on the default
// resource.
void ExternCallOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!getPure()) {
    auto *defaultResource = SideEffects::DefaultResource::get();
    effects.emplace_back(MemoryEffects::Write::get(), defaultResource);
    effects.emplace_back(MemoryEffects::Read::get(), defaultResource);
  }
}

// ConditionallySpeculatable hook: the op is speculatable exactly when its
// `pure` attribute is set.
Speculation::Speculatability ExternCallOp::getSpeculatability() {
  return getPure() ? Speculation::Speculatable : Speculation::NotSpeculatable;
}

// -- GatherOp --
LogicalResult GatherOp::verify() {
RankedTensorType indicesTy = getIndices().getType();
Expand Down
8 changes: 8 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,14 @@ void init_triton_ir(py::module &&m) {
return self.create<ExternElementwiseOp>(retType, argList, libName,
libPath, symbol, isPure);
})
// Builds an ExternCallOp from the given library name/path, symbol, argument
// values, result types and purity flag, and returns the created op.
// Note the Python-facing argument order (libName, libPath, symbol, argList,
// retTypes, isPure) differs from the order passed to the op builder below.
.def("create_extern_call",
[](TritonOpBuilder &self, const std::string &libName,
const std::string &libPath, const std::string &symbol,
std::vector<Value> &argList, const std::vector<Type> &retTypes,
bool isPure) -> OpState {
return self.create<ExternCallOp>(retTypes, argList, libName,
libPath, symbol, isPure);
})
// Built-in instruction
.def("create_get_program_id",
[](TritonOpBuilder &self, int axis) -> Value {
Expand Down
30 changes: 29 additions & 1 deletion python/triton/experimental/tle/language/raw/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
import triton.language as tl
from triton.language.core import builtin, constexpr as tl_constexpr, tensor
from triton.experimental.tle.language.gpu import buffered_tensor
import importlib.util


def _pointer_type_hash(self):
return hash((self.name, self.element_ty, "tt_ptr"))


def patch_hash_method_for_pointer_type():
    """Install ``_pointer_type_hash`` as ``__hash__`` on Triton's pointer type class.

    ``__hash__`` is looked up on the class, so a single assignment covers every
    pointer type regardless of its element dtype. There is no need to build one
    pointer type per element dtype just to recover the same class each time.
    Calling this function repeatedly is harmless: it re-assigns the same function.
    """
    tl.core.pointer_type.__hash__ = _pointer_type_hash


def import_from_path(file_path):
    """Load and execute the Python file at ``file_path`` and return the module.

    The module gets a synthetic name derived from ``hash(file_path)``. It is
    not registered in ``sys.modules``, so every call executes the file again
    and returns a fresh module object.

    Args:
        file_path: Path of the source file to load.

    Returns:
        The module object produced by executing the file.

    Raises:
        ImportError: If no import spec/loader can be determined for
            ``file_path`` (for example, an unrecognised file suffix).
        Exception: Whatever executing the file itself raises (including
            ``FileNotFoundError`` when the path does not exist).
    """
    module_name = f"_imported_{abs(hash(file_path))}"
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    # spec_from_file_location returns None when it cannot pick a loader;
    # fail with a clear ImportError instead of an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load a Python module from {file_path!r}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def _resolve_alias_indices(func, llvm, handles, output_indices, _semantic):
Expand Down Expand Up @@ -42,14 +62,22 @@ def _normalize_hint(hint):

def _tle_raw_call(func, args, *, output_indices, hint, smem, _semantic):
hint = _normalize_hint(hint)
handles = [arg.handle for arg in args]
if getattr(func, "deferred", False):
handles = [arg.handle for arg in args]
if output_indices is None:
raise RuntimeError("deferred tle_raw.call requires explicit output_indices=")
alias_indices = output_indices
source_id = func.register_pending_source(hint=hint)
dsl_region_op = func.create_region_deferred(_semantic.builder, source_id, handles, alias_indices, hint)
else:
if func.compiler.lower() == "nvcc" or (func.compiler.lower() == "clang" and func.target.lower() == "bc"):
patch_hash_method_for_pointer_type()
module = import_from_path(func.extern_file)
target_fn = getattr(module, func.extern_func_name)
ret = target_fn(*args, _semantic=_semantic)
return ret

handles = [arg.handle for arg in args]
context = _semantic.builder.get_context()
llvm = func.make_llvm(context)
alias_indices = _resolve_alias_indices(func, llvm, handles, output_indices, _semantic)
Expand Down
Loading
Loading