@@ -37,7 +37,7 @@ struct NeverInlineMeta <: InlineStateMeta end
3737import GPUCompiler: abstract_call_known, GPUInterpreter
3838import Core. Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
3939 StmtInfo, AbsIntState, EFFECTS_TOTAL,
40- MethodResultPure
40+ MethodResultPure, CallInfo, IRCode
4141
4242function abstract_call_known (meta:: InlineStateMeta , interp:: GPUInterpreter , @nospecialize (f),
4343 arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
@@ -70,5 +70,178 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
7070 return nothing
7171end
7272
# Marker type used as the `meta` of the `DeferredCallInfo` that is created for
# the primal function of a mocked `autodiff` call.
struct MockEnzymeMeta end

# Every meta type currently needs its own `inlining_handler` method, which is
# annoying; introducing an `abstract type InferenceMeta` with a default method
# would remove the need for this definition.
# This method unconditionally returns `nothing`.
inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo) = nothing
80+
# Generic function declared without methods here. Calls to it are picked up by
# the `abstract_call_known(::Nothing, ::GPUInterpreter, ::typeof(autodiff), ...)`
# method defined in this file.
function autodiff end

import GPUCompiler: DeferredCallInfo

# Call info attached to an `autodiff` call during abstract interpretation.
#   rt   -- the return type recorded for the `autodiff` call
#   info -- the `DeferredCallInfo` describing the primal call
struct AutodiffCallInfo <: CallInfo
    rt::Any
    info::DeferredCallInfo
end
88+
# Abstract-interpretation hook for calls to `autodiff(f, args...)`.
#
# The leading `autodiff` is stripped from the argument list, the remaining
# call `f(args...)` is inferred with `Core.Compiler.abstract_call`, and the
# result is wrapped in an `AutodiffCallInfo` so that the inliner
# (`Core.Compiler.handle_call!` for `AutodiffCallInfo`) can find it later.
#
# Returns a `CallMeta`. On Julia >= 1.11 `CallMeta` additionally carries the
# exception type, hence the version-gated constructors.
function abstract_call_known(meta::Nothing, interp::GPUInterpreter, f::typeof(autodiff),
                             arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
    # Dispatch on `typeof(autodiff)` already guarantees `f === autodiff`,
    # so no runtime assertion is needed.
    (; fargs, argtypes) = arginfo

    # `autodiff()` without anything to differentiate: report `Union{}`.
    if length(argtypes) <= 1
        @static if VERSION < v"1.11.0-"
            return CallMeta(Union{}, Effects(), NoCallInfo())
        else
            return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
        end
    end

    # Drop the `autodiff` entry so that what remains is the primal call.
    other_fargs = fargs === nothing ? nothing : fargs[2:end]
    other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
    # TODO: Ought we not change absint to use MockEnzymeMeta(), otherwise we fill the cache for nothing.
    call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
    callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)

    # Real Enzyme must compute `rt` and `exct` according to Enzyme semantics
    # and likely perform an unwrapping of fargs...
    rt = call.rt

    # TODO: Edges? Effects?
    @static if VERSION < v"1.11.0-"
        # Can't use call.effects since otherwise this call might be just replaced with rt
        return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
    else
        return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
    end
end
120+
# Under `MockEnzymeMeta` no call gets special treatment: this method always
# returns `nothing`.
abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
                    arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) = nothing
125+
126+ import Core. Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature
127+
# We really need a Compiler stdlib.
# Forward Base's indexing syntax to the Core.Compiler implementations so that
# `ir[ssa]` and `inst[:field] = val` can be written below.
Base.getindex(ir::Core.Compiler.IRCode, i) = Core.Compiler.getindex(ir, i)
Base.setindex!(inst::Core.Compiler.Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)

# Type of the `flag` argument in the `handle_call!` signature below;
# it is `UInt32` on Julia >= 1.11 and `UInt8` on earlier versions.
const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
# Inlining hook for calls annotated with `AutodiffCallInfo`.
#
# Goal: replace the `autodiff(f, args...)` call with IR equivalent to
#     ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
#     ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid}, args...}, ptr, adjoint_args...)
#
# This is done by building a small `IRCode` by hand and pushing it onto `todo`
# as an `InliningTodo` for statement `stmt_idx`.
#
# Always returns `nothing`. If the primal call did not resolve to exactly one
# method match, nothing is pushed onto `todo`.
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
                                    stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
                                    sig::Signature, state::InliningState)

    # 0. Obtain primal mi from DeferredCallInfo
    # TODO: remove this code duplication
    deferred_info = info.info
    minfo = deferred_info.info
    results = minfo.results
    if length(results.matches) != 1
        return nothing
    end
    match = only(results.matches)

    # lookup the target mi with correct edge tracking
    # TODO: Effects?
    case = Core.Compiler.compileable_specialization(
        match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
    @assert case isa Core.Compiler.InvokeCase
    @assert stmt.head === :call

    # Now create the IR we want to inline. A separate name is used so that the
    # `ir` parameter (the caller's IR) is not shadowed.
    new_ir = Core.Compiler.IRCode() # contains a placeholder
    # `stmt.args[1]` is `autodiff` itself; the remaining entries are f, args...
    args = [Core.Compiler.Argument(i) for i in 2:length(stmt.args)]
    idx = 0

    # 0. Enzyme proper: Desugar args
    primal_args = args
    primal_argtypes = match.spec_types.parameters[2:end]

    adjoint_rt = info.rt
    adjoint_args = args # TODO
    adjoint_argtypes = primal_argtypes

    # 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
    expr = Expr(:foreigncall,
        "extern gpuc.lookup",
        Ptr{Cvoid},
        Core.svec(#=meta=# Any, #=mi=# Any, #=f=# Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
        0,
        QuoteNode(:llvmcall),
        deferred_info.meta,
        case.invoke,
        primal_args...
    )
    ptr = insert_node!(new_ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))

    # 2. Call to magic `__autodiff`
    expr = Expr(:foreigncall,
        "extern __autodiff",
        adjoint_rt,
        Core.svec(Ptr{Cvoid}, Any, adjoint_argtypes...),
        0,
        QuoteNode(:llvmcall),
        ptr,
        adjoint_args...
    )
    # Inserted at the same position as the lookup above (idx is still 1).
    ret = insert_node!(new_ir, idx, NewInstruction(expr, adjoint_rt))

    # Finally replace placeholder return
    new_ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
    new_ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}

    new_ir = Core.Compiler.compact!(new_ir)

    # which mi to use here?
    # push inlining todos
    # TODO: Effects
    # aviatesk mentioned using inlining_policy instead...
    itodo = Core.Compiler.InliningTodo(case.invoke, new_ir, Core.Compiler.Effects())
    @assert itodo.linear_inline_eligible
    push!(todo, (stmt_idx => itodo))

    return nothing
end
214+
# GPUCompiler plugin that lowers calls to the `__autodiff` intrinsic.
#
# Every call `__autodiff(ptr, f, args...)` is replaced by a direct call
# `target(args...)`, where `target` is the first operand with one level of
# `ptrtoint`/`bitcast` constant expression stripped. In other words, the mock
# "derivative" is simply the primal function.
#
# Arguments:
#   job       -- the compiler job (unused here)
#   intrinsic -- the LLVM value whose uses are rewritten
#   mod       -- the LLVM module being processed (unused here)
#
# Returns `true` if at least one call was rewritten, `false` otherwise.
function mock_enzyme!(@nospecialize(job), intrinsic, mod::LLVM.Module)
    changed = false

    for use in LLVM.uses(intrinsic)
        call = LLVM.user(use)
        LLVM.@dispose builder=LLVM.IRBuilder() begin
            LLVM.position!(builder, call)
            ops = LLVM.operands(call)
            # ops[1] is the function pointer; it may be wrapped in a constant
            # ptrtoint/bitcast expression, in which case unwrap it once.
            target = ops[1]
            if target isa LLVM.ConstantExpr && (LLVM.opcode(target) == LLVM.API.LLVMPtrToInt ||
                                                LLVM.opcode(target) == LLVM.API.LLVMBitCast)
                target = first(LLVM.operands(target))
            end
            # Drop the first two parameters (the pointer and `f`) from the signature.
            funcT = LLVM.called_type(call)
            funcT = LLVM.FunctionType(LLVM.return_type(funcT), LLVM.parameters(funcT)[3:end])
            # The `end-1` is needed because LLVM stores the callee as the last
            # operand of a call instruction; it must not be forwarded as an argument.
            direct_call = LLVM.call!(builder, funcT, target, ops[3:end-1])

            LLVM.replace_uses!(call, direct_call)
        end
        # A new call was inserted above, so the module has been modified even
        # if the old call cannot be erased below.
        changed = true

        if isempty(LLVM.uses(call))
            LLVM.erase!(call)
        end
        # Otherwise the old call is left in place; the validator will detect this.
    end

    return changed
end

GPUCompiler.register_plugin!("__autodiff", mock_enzyme!)
246+
247+ end # module
0 commit comments