Enzyme.ReverseSplitWithPrimal is not supported #215

Open
avik-pal opened this issue Nov 1, 2024 · 4 comments

Comments

@avik-pal
Collaborator

avik-pal commented Nov 1, 2024

using Enzyme, Reactant

f(x) = sum(abs2, x .* x)

function enzyme_split_mode(x)
    dx = Enzyme.make_zero(x)
    forward, reverse = autodiff_thunk(
        ReverseSplitWithPrimal, Const{typeof(f)}, Active, Duplicated{typeof(x)}
    )
    tape, result, shadow_result = forward(Const(f), Duplicated(x, dx))
    reverse(Const(f), Duplicated(x, dx), 1.0, tape)
    return result, dx
end

x = rand(10)

f(x)
enzyme_split_mode(x)

x_ra = Reactant.to_rarray(x)

@code_hlo optimize = true enzyme_split_mode(x_ra)

@avik-pal
Collaborator Author

avik-pal commented Nov 1, 2024

error:

ERROR: AssertionError: Base.allocatedinline(actualRetType) returns false: actualRetType = Union{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64}}, rettype = Active{Union{Reactant.TracedRNumber{Float64}, Reactant.TracedRArray{Float64}}}
Stacktrace:
  [1] create_abi_wrapper(enzymefn::LLVM.Function, TT::Type, rettype::Type, actualRetType::Type, Mode::Enzyme.API.CDerivativeMode, augmented::Ptr{Nothing}, width::Int64, returnPrimal::Bool, shadow_init::Bool, world::UInt64, interp::Enzyme.Compiler.Interpreter.EnzymeInterpreter{Nothing})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:4287
  [2] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:4023
  [3] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:7156
  [4] codegen
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:5972 [inlined]
  [5] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8267
  [6] _thunk
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8267 [inlined]
  [7] cached_compilation
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8308 [inlined]
  [8] thunkbase(ctx::LLVM.Context, mi::Core.MethodInstance, ::Val{0x0000000000007b05}, ::Type{Const{typeof(f)}}, ::Type{Active}, tt::Type{Tuple{Duplicated{…}}}, ::Val{Enzyme.API.DEM_ReverseModeGradient}, ::Val{1}, ::Val{(false, false)}, ::Val{true}, ::Val{false}, ::Type{FFIABI}, ::Val{false}, ::Val{false})
    @ Enzyme.Compiler /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8440
  [9] #s2080#19075
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/compiler.jl:8577 [inlined]
 [10] var"#s2080#19075"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ErrIfFuncWritten::Any, RuntimeActivity::Any, ::Any, ::Type, ::Type, ::Type, tt::Any, ::Type, ::Type, ::Type, ::Type, ::Type, ::Type, ::Type, ::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff_thunk
    @ /mnt/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:969 [inlined]
 [13] enzyme_split_mode
    @ ./REPL[14]:3 [inlined]
 [14] (::Tuple{})(none::Reactant.TracedRArray{Float64, 1})
    @ Base.Experimental ./<missing>:0
 [15] (::Reactant.var"#26#35"{Bool, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}}})()
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:148
 [16] block!(f::Reactant.var"#26#35"{Bool, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}, Vector{Union{ReactantCore.MissingTracedValue, Reactant.TracedRArray, Reactant.TracedRNumber}}, Tuple{Reactant.TracedRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [17] make_mlir_fn(f::Function, args::Vector{ConcreteRArray{Float64, 1}}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant /mnt/software/lux/Reactant.jl/src/utils.jl:112
 [18] make_mlir_fn
    @ /mnt/software/lux/Reactant.jl/src/utils.jl:36 [inlined]
 [19] #6
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:271 [inlined]
 [20] block!(f::Reactant.Compiler.var"#6#11"{typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Block.jl:201
 [21] #5
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:270 [inlined]
 [22] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Module.jl:93
 [23] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Vector{ConcreteRArray{Float64, 1}}; optimize::Bool)
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:267
 [24] compile_mlir!
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:266 [inlined]
 [25] #2
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:261 [inlined]
 [26] context!(f::Reactant.Compiler.var"#2#3"{@Kwargs{optimize::Bool}, typeof(enzyme_split_mode), Vector{ConcreteRArray{Float64, 1}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:71
 [27] compile_mlir(f::Function, args::Vector{ConcreteRArray{Float64, 1}}; kwargs::@Kwargs{optimize::Bool})
    @ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:259
 [28] macro expansion
    @ /mnt/software/lux/Reactant.jl/src/Compiler.jl:409 [inlined]
 [29] top-level scope
    @ REPL[19]:1
Some type information was truncated. Use `show(err)` to see complete types.

From the error, it seems to hit Enzyme proper. (I am blind: of course it hits Enzyme proper, since we don't have an autodiff_thunk in Reactant.)

@mofeing
Collaborator

mofeing commented Nov 1, 2024

Hmm, I believe you no longer need to use autodiff_thunk as much in Enzyme v0.13? @wsmoses

@avik-pal
Collaborator Author

avik-pal commented Nov 2, 2024

Oh, I did not realize autodiff works with split mode (at least the docs still use the thunk there).

@wsmoses
Member

wsmoses commented Nov 2, 2024

No, Enzyme 0.13 improves the usage of deferred mode, not split mode.

We haven't implemented split mode in EnzymeMLIR yet (though, ironically, @Pangoraw and I were discussing this the other day on his callop PRs).
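For comparison with the split-mode reproducer at the top of this issue, here is a minimal sketch of the same primal-plus-gradient computation using combined (non-split) reverse mode through plain `autodiff`. It assumes the Enzyme v0.13 convention that `autodiff(ReverseWithPrimal, ...)` returns `(derivatives, primal)`; the function name `enzyme_combined_mode` is hypothetical. This sidesteps split mode rather than fixing it, and whether it compiles under `@code_hlo` is not established in this thread.

```julia
using Enzyme

f(x) = sum(abs2, x .* x)

# Hypothetical helper: combined reverse mode runs the forward and reverse
# passes inside one autodiff call, so no tape is handled by hand.
function enzyme_combined_mode(x)
    dx = Enzyme.make_zero(x)  # shadow buffer that receives the gradient
    # Assumption (Enzyme v0.13): ReverseWithPrimal returns (derivatives, primal).
    # A Duplicated argument reports `nothing` in the derivatives tuple; its
    # gradient is accumulated into the shadow `dx` instead.
    _, result = autodiff(ReverseWithPrimal, Const(f), Active, Duplicated(x, dx))
    return result, dx
end

x = rand(10)
result, dx = enzyme_combined_mode(x)
```

Since f(x) = Σ xᵢ⁴, `dx` should match `4 .* x .^ 3` and `result` should match `f(x)`.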
