@@ -69,5 +69,89 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
6969 return nothing
7070end
7171
72+ struct MockEnzymeMeta end
7273
74+ function autodiff end
75+
76+ import GPUCompiler: DeferredCallInfo
77+ struct AutodiffCallInfo <: CallInfo
78+ info:: DeferredCallInfo
79+ end
80+
81+ function abstract_call_known (meta:: MockEnzymeMeta , interp:: GPUInterpreter , @nospecialize (f),
82+ arginfo:: ArgInfo , si:: StmtInfo , sv:: AbsIntState , max_methods:: Int )
83+ (; fargs, argtypes) = arginfo
84+
85+ if f === autodiff
86+ if length (argtypes) >= 1
87+ @static if VERSION < v " 1.11.0-"
88+ return CallMeta (Union{}, Effects (), NoCallInfo ())
89+ else
90+ return CallMeta (Union{}, Union{}, Effects (), NoCallInfo ())
91+ end
92+ end
93+
94+ other_fargs = fargs === nothing ? nothing : fargs[2 : end ]
95+ other_arginfo = ArgInfo (other_fargs, argtypes[2 : end ])
96+ call = Core. Compiler. abstract_call (interp, other_arginfo, si, sv, max_methods)
97+ callinfo = DeferredCallInfo (MockEnzymeMeta (), call. rt, call. info)
98+
99+ # Real Enzyme must compute `rt` and `exct` according to enzyme semantics
100+ # and likely perform a unwrapping of fargs...
101+ rt = Nothing
102+
103+ # TODO : Edges? Effects?
104+ @static if VERSION < v " 1.11.0-"
105+ return CallMeta (rt, call. effects, AutodiffCallInfo (callinfo))
106+ else
107+ return CallMeta (rt, call. exct, call. effects, AutodiffCallInfo (callinfo))
108+ end
109+ end
110+
111+ return nothing
112+ end
113+
114+ const FlagType = VERSION >= v " 1.11.0-" ? UInt32 : UInt8
115+ function CC. handle_call! (todo:: Vector{Pair{Int,Any}} , ir:: CC.IRCode , idx:: CC.Int ,
116+ stmt:: Expr , info:: AutodiffCallInfo , flag:: FlagType ,
117+ sig:: CC.Signature , state:: CC.InliningState )
118+ # Goal:
119+ # The IR we want to inline here is:
120+ # unpack the args ...
121+ # ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
122+ # ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)
123+
124+ push! (todo, idx=> (info:: AutoDiffTodo ))
125+
126+ # # 1. Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
127+ # deferred_info = info.info
128+ # # TODO : This is code duplication is unfortunate...
129+ # minfo = deferred_info.info
130+ # results = minfo.results
131+ # if length(results.matches) != 1
132+ # return nothing
133+ # end
134+ # match = only(results.matches)
135+
136+ # # lookup the target mi with correct edge tracking
137+ # case = CC.compileable_specialization(match, CC.Effects(), CC.InliningEdgeTracker(state),
138+ # info)
139+ # @assert case isa CC.InvokeCase
140+ # @assert stmt.head === :call
141+
142+ # stmt = Expr(:foreigncall,
143+ # "extern gpuc.lookup",
144+ # Ptr{Cvoid},
145+ # Core.svec(Any, Any, Any, match.spec_types.parameters[2:end]...), # Must use Any for MethodInstance or ftype
146+ # 0,
147+ # QuoteNode(:llvmcall),
148+ # info.meta,
149+ # case.invoke,
150+ # stmt.args[3:end]...
151+ # )
152+
153+ # # 2. Form call to `__autodiff`
154+ # # TODO !
155+
156+ return nothing
73157end
0 commit comments