diff --git a/src/driver.jl b/src/driver.jl
index 0bfdec31..2ed08061 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -130,6 +130,12 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob);
         error("Unknown compilation output $output")
 end
 
+# GPUCompiler intrinsic that marks deferred compilation
+function var"gpuc.deferred" end
+
+# GPUCompiler intrinsic that marks deferred compilation, across backends
+function var"gpuc.deferred.with" end
+
 # primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
 # this could both be generalized (e.g. supporting actual function calls, instead of
 # returning a function pointer), and be integrated with the nonrecursive codegen.
@@ -157,6 +163,29 @@ end
     end
 end
 
+function find_base_object(val)
+    while true
+        if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
+                                    opcode(val) == LLVM.API.LLVMBitCast ||
+                                    opcode(val) == LLVM.API.LLVMAddrSpaceCast)
+            val = first(operands(val))
+        elseif val isa LLVM.IntToPtrInst || val isa LLVM.BitCastInst || val isa LLVM.AddrSpaceCastInst
+            val = first(operands(val))
+        elseif val isa LLVM.LoadInst
+            # In 1.11+ we no longer embed integer constants directly.
+            gv = first(operands(val))
+            if gv isa LLVM.GlobalValue
+                val = LLVM.initializer(gv)
+                continue
+            end
+            break
+        else
+            break
+        end
+    end
+    return val
+end
+
 const __llvm_initialized = Ref(false)
 
 @locked function emit_llvm(@nospecialize(job::CompilerJob);
@@ -186,8 +215,8 @@ const __llvm_initialized = Ref(false)
         entry = finish_module!(job, ir, entry)
 
         # deferred code generation
-        has_deferred_jobs = !only_entry && toplevel &&
-                            haskey(functions(ir), "deferred_codegen")
+        has_deferred_jobs = !only_entry && toplevel && haskey(functions(ir), "deferred_codegen")
+        jobs = Dict{CompilerJob, String}(job => entry_fn)
         if has_deferred_jobs
             dyn_marker = functions(ir)["deferred_codegen"]
 
@@ -198,7 +227,6 @@ const __llvm_initialized = Ref(false)
                 changed = false
 
                 # find deferred compiler
-                # TODO: recover this information earlier, from the Julia IR
                 worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
                 for use in uses(dyn_marker)
                     # decode the call
@@ -260,6 +288,50 @@ const __llvm_initialized = Ref(false)
             end
 
             # all deferred compilations should have been resolved
+            if dyn_marker !== nothing
+                @compiler_assert isempty(uses(dyn_marker)) job
+                unsafe_delete!(ir, dyn_marker)
+            end
+        end
+
+        if haskey(functions(ir), "gpuc.lookup")
+            dyn_marker = functions(ir)["gpuc.lookup"]
+
+            worklist = Dict{Any, Vector{LLVM.CallInst}}()
+            for use in uses(dyn_marker)
+                # decode the call
+                call = user(use)::LLVM.CallInst
+                dyn_mi_inst = find_base_object(operands(call)[1])
+                @compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
+                dyn_mi = Base.unsafe_pointer_to_objref(
+                    convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
+                push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
+            end
+
+            for dyn_mi in keys(worklist)
+                dyn_fn_name = compiled[dyn_mi].specfunc
+                dyn_fn = functions(ir)[dyn_fn_name]
+
+                # insert a pointer to the function everywhere the entry is used
+                T_ptr = convert(LLVMType, Ptr{Cvoid})
+                for call in worklist[dyn_mi]
+                    @dispose builder=IRBuilder() begin
+                        position!(builder, call)
+                        fptr = if LLVM.version() >= v"17"
+                            T_ptr = LLVM.PointerType()
+                            bitcast!(builder, dyn_fn, T_ptr)
+                        elseif VERSION >= v"1.12.0-DEV.225"
+                            T_ptr = LLVM.PointerType(LLVM.Int8Type())
+                            bitcast!(builder, dyn_fn, T_ptr)
+                        else
+                            ptrtoint!(builder, dyn_fn, T_ptr)
+                        end
+                        replace_uses!(call, fptr)
+                    end
+                    unsafe_delete!(LLVM.parent(call), call)
+                end
+            end
+
             @compiler_assert isempty(uses(dyn_marker)) job
             unsafe_delete!(ir, dyn_marker)
         end
diff --git a/src/jlgen.jl b/src/jlgen.jl
index a34bd42e..04aa8bdb 100644
--- a/src/jlgen.jl
+++ b/src/jlgen.jl
@@ -318,7 +318,8 @@ else
 get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
 end
 
-struct GPUInterpreter <: CC.AbstractInterpreter
+abstract type AbstractGPUInterpreter <: CC.AbstractInterpreter end
+struct GPUInterpreter <: AbstractGPUInterpreter
     world::UInt
     method_table::GPUMethodTableView
 
@@ -435,6 +436,113 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
     return ret
 end
 
+struct DeferredCallInfo <: CC.CallInfo
+    rt::DataType
+    info::CC.CallInfo
+end
+
+function CC.abstract_call_known(interp::AbstractGPUInterpreter, @nospecialize(f),
+                                arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                max_methods::Int = CC.get_max_methods(interp, f, sv))
+    (; fargs, argtypes) = arginfo
+    if f === var"gpuc.deferred" || f === var"gpuc.deferred.with"
+        first_arg = f === var"gpuc.deferred" ? 2 : 3
+        argvec = argtypes[first_arg:end]
+        call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
+        callinfo = DeferredCallInfo(call.rt, call.info)
+        @static if VERSION < v"1.11.0-"
+            return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
+        else
+            return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
+        end
+    end
+    return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
+                                          arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
+                                          max_methods::Int)
+end
+
+# Use the Inlining infrastructure to perform our refinement
+const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
+function CC.handle_call!(todo::Vector{Pair{Int,Any}},
+        ir::CC.IRCode, idx::CC.Int, stmt::Expr, info::DeferredCallInfo, flag::FlagType, sig::CC.Signature,
+        state::CC.InliningState)
+
+    minfo = info.info
+    results = minfo.results
+    if length(results.matches) != 1
+        return nothing
+    end
+    match = only(results.matches)
+
+    # lookup the target mi with correct edge tracking
+    case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state), info)
+    @assert case isa CC.InvokeCase
+    @assert stmt.head === :call
+
+    f = stmt.args[1]
+    name = f === var"gpuc.deferred" ? "extern gpuc.lookup" : "extern gpuc.lookup.with"
+    with_arg_T = f === var"gpuc.deferred" ? () : (Any,)
+
+    args = Any[
+        name,
+        Ptr{Cvoid},
+        Core.svec(Any, Any, with_arg_T..., match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
+        0,
+        QuoteNode(:llvmcall),
+        case.invoke,
+        stmt.args[2:end]...
+    ]
+    stmt.head = :foreigncall
+    stmt.args = args
+    return nothing
+end
+
+struct DeferredEdges
+    edges::Vector{MethodInstance}
+end
+
+function find_deferred_edges(ir::CC.IRCode)
+    edges = MethodInstance[]
+    # @aviatesk: Can we add this instead in handle_call
+    for stmt in ir.stmts
+        inst = stmt[:inst]
+        inst isa Expr || continue
+        expr = inst::Expr
+        if expr.head === :foreigncall &&
+           expr.args[1] == "extern gpuc.lookup"
+            deferred_mi = expr.args[6]
+            push!(edges, deferred_mi)
+        elseif expr.head === :foreigncall &&
+               expr.args[1] == "extern gpuc.lookup.with"
+            deferred_mi = expr.args[6]
+            with = expr.args[7]
+            @show (deferred_mi, with)
+        end
+    end
+    unique!(edges)
+    return edges
+end
+
+if VERSION >= v"1.11.0-"
+# stack_analysis_result and ipo_dataflow_analysis is 1.11 only
+function CC.ipo_dataflow_analysis!(interp::AbstractGPUInterpreter, ir::CC.IRCode, caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        CC.stack_analysis_result!(caller, DeferredEdges(edges))
+    end
+    @invoke CC.ipo_dataflow_analysis!(interp::CC.AbstractInterpreter, ir::CC.IRCode, caller::CC.InferenceResult)
+end
+else
+# v1.10.0
+function CC.finish(interp::AbstractGPUInterpreter, opt::CC.OptimizationState, ir::CC.IRCode, caller::CC.InferenceResult)
+    edges = find_deferred_edges(ir)
+    if !isempty(edges)
+        # This is a tad bit risky, but nobody should be running EA on our results.
+        caller.argescapes = DeferredEdges(edges)
+    end
+    @invoke CC.finish(interp::CC.AbstractInterpreter, opt::CC.OptimizationState, ir::CC.IRCode, caller::CC.InferenceResult)
+end
+end
 
 ## world view of the cache
 using Core.Compiler: WorldView
@@ -584,6 +692,24 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         error("Cannot compile $(job.source) for world $(job.world); method is only valid in worlds $(job.source.def.primary_world) to $(job.source.def.deleted_world)")
     end
 
+    compiled = IdDict()
+    llvm_mod, outstanding = compile_method_instance(job, compiled)
+    worklist = outstanding
+    while !isempty(worklist)
+        source = pop!(worklist)
+        haskey(compiled, source) && continue
+        job2 = CompilerJob(source, job.config)
+        @debug "Processing..." job2
+        llvm_mod2, outstanding = compile_method_instance(job2, compiled)
+        append!(worklist, outstanding)
+        @assert context(llvm_mod) == context(llvm_mod2)
+        link!(llvm_mod, llvm_mod2)
+    end
+
+    return llvm_mod, compiled
+end
+
+function compile_method_instance(@nospecialize(job::CompilerJob), compiled::IdDict{Any, Any})
     # populate the cache
     interp = get_interpreter(job)
     cache = CC.code_cache(interp)
@@ -594,7 +720,7 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
 
     # create a callback to look-up function in our cache,
     # and keep track of the method instances we needed.
-    method_instances = []
+    method_instances = Any[]
     if Sys.ARCH == :x86 || Sys.ARCH == :x86_64
         function lookup_fun(mi, min_world, max_world)
             push!(method_instances, mi)
@@ -659,7 +785,6 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
     end
 
     # process all compiled method instances
-    compiled = Dict()
     for mi in method_instances
         ci = ci_cache_lookup(cache, mi, job.world, job.world)
         ci === nothing && continue
@@ -696,10 +821,34 @@ function compile_method_instance(@nospecialize(job::CompilerJob))
         compiled[mi] = (; ci, func=llvm_func, specfunc=llvm_specfunc)
     end
 
+    # Collect the deferred edges
+    outstanding = Any[]
+    for mi in method_instances
+        !haskey(compiled, mi) && continue # Equivalent to ci_cache_lookup == nothing
+        ci = compiled[mi].ci
+        @static if VERSION >= v"1.11.0-"
+            edges = CC.traverse_analysis_results(ci) do @nospecialize result
+                return result isa DeferredEdges ? result : return
+            end
+        else
+            edges = ci.argescapes
+            if !(edges isa Union{Nothing, DeferredEdges})
+                edges = nothing
+            end
+        end
+        if edges !== nothing
+            for deferred_mi in (edges::DeferredEdges).edges
+                if !haskey(compiled, deferred_mi)
+                    push!(outstanding, deferred_mi)
+                end
+            end
+        end
+    end
+
     # ensure that the requested method instance was compiled
     @assert haskey(compiled, job.source)
 
-    return llvm_mod, compiled
+    return llvm_mod, outstanding
 end
 
 # partially revert JuliaLang/julia#49391
diff --git a/test/native_tests.jl b/test/native_tests.jl
index 298c1010..9ea6ebd6 100644
--- a/test/native_tests.jl
+++ b/test/native_tests.jl
@@ -162,6 +162,20 @@ end
         ir = fetch(t)
         @test contains(ir, r"add i64 %\d+, 3")
     end
+
+    @testset "deferred" begin
+        @gensym child kernel unrelated
+        @eval @noinline $child(i) = i
+        @eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
+
+        # smoke test
+        job, _ = Native.create_job(eval(kernel), (Int64,))
+
+        ci, rt = only(GPUCompiler.code_typed(job))
+        @test rt === Ptr{Cvoid}
+
+        ir = sprint(io->GPUCompiler.code_llvm(io, job))
+    end
 end
 
 ############################################################################################