
Undefined fill(::ComplexF64, ::Reactant.MLIR.IR.Type) method required for differentiating complex function #238

Open
Todorbsc opened this issue Nov 7, 2024 · 16 comments
Labels
good first issue (Good for newcomers)

Comments

@Todorbsc (Collaborator) commented Nov 7, 2024

CC @mofeing
I get an error when trying to compile a gradient of a function that accepts complex parameters.

julia> using Enzyme

julia> using Reactant
AssertionError("Could not find registered platform with name: \"cuda\". Available platform names are: ")

julia> using Adapt

julia> N = 10
10

julia> params = rand(ComplexF64, N);

julia> expected = rand(ComplexF64, N);

julia> params′ = adapt(ConcreteRArray, params);

julia> expected′ = adapt(ConcreteRArray, expected);

julia> function f1(params, expected)
           return sum(abs.(expected - params))
       end
f1 (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f1, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> ∇fR = Reactant.@compile ∇f(params′, expected′)
ERROR: MethodError: no method matching fill(::ComplexF64, ::Reactant.MLIR.IR.Type)

Closest candidates are:
  fill(::Type{Reactant.MLIR.IR.Attribute}, ::Any, ::Any)
   @ Reactant ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Attribute.jl:494
  fill(::Any, ::Union{Integer, AbstractUnitRange}...)
   @ Base array.jl:582
  fill(::Int8, ::Reactant.MLIR.IR.Type)
   @ Reactant ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Attribute.jl:459
  ...

Stacktrace:
  [1] broadcast_to_size(arg::ComplexF64, rsize::Tuple{Int64})
    @ Reactant ~/.julia/packages/Reactant/e7PeE/src/TracedRArray.jl:614
  [2] make_zero(::Type{Reactant.TracedRArray{ComplexF64, 1}}, seen::IdDict{Any, Any}, prev::Reactant.TracedRArray{ComplexF64, 1}, ::Val{false})
    @ Reactant ~/.julia/packages/Reactant/e7PeE/src/Reactant.jl:72
  [3] make_zero (repeats 2 times)
    @ ~/.julia/packages/EnzymeCore/Gdg5y/src/EnzymeCore.jl:524 [inlined]
  [4] macro expansion
    @ ~/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:1711 [inlined]
  [5] gradient
    @ ~/.julia/packages/Enzyme/VSRgT/src/Enzyme.jl:1661 [inlined]
  [6] ∇f
    @ ./REPL[10]:2 [inlined]
  [7] (::Tuple{})(none::Reactant.TracedRArray{ComplexF64, 1}, none::Reactant.TracedRArray{ComplexF64, 1})
    @ Base.Experimental ./<missing>:0
  [8] (::Reactant.var"#26#35"{Bool, typeof(∇f), Tuple{}, Vector{}, Tuple{}})()
    @ Reactant ~/.julia/packages/Reactant/e7PeE/src/utils.jl:148
  [9] block!(f::Reactant.var"#26#35"{Bool, typeof(∇f), Tuple{}, Vector{}, Tuple{}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Block.jl:201
 [10] make_mlir_fn(f::Function, args::Tuple{…}, kwargs::Tuple{}, name::String, concretein::Bool; toscalar::Bool, return_dialect::Symbol, no_args_in_result::Bool, construct_function_without_args::Bool)
    @ Reactant ~/.julia/packages/Reactant/e7PeE/src/utils.jl:112
 [11] make_mlir_fn
    @ ~/.julia/packages/Reactant/e7PeE/src/utils.jl:36 [inlined]
 [12] #6
    @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:270 [inlined]
 [13] block!(f::Reactant.Compiler.var"#6#11"{typeof(∇f), Tuple{ConcreteRArray{}, ConcreteRArray{}}}, blk::Reactant.MLIR.IR.Block)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Block.jl:201
 [14] #5
    @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:269 [inlined]
 [15] mmodule!(f::Reactant.Compiler.var"#5#10"{Reactant.MLIR.IR.Module, typeof(∇f), Tuple{}}, blk::Reactant.MLIR.IR.Module)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Module.jl:93
 [16] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{ConcreteRArray{…}, ConcreteRArray{…}}; optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:266
 [17] compile_mlir!
    @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:265 [inlined]
 [18] (::Reactant.Compiler.var"#30#32"{Bool, typeof(∇f), Tuple{ConcreteRArray{ComplexF64, 1}, ConcreteRArray{ComplexF64, 1}}})()
    @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:720
 [19] context!(f::Reactant.Compiler.var"#30#32"{Bool, typeof(∇f), Tuple{ConcreteRArray{}, ConcreteRArray{}}}, ctx::Reactant.MLIR.IR.Context)
    @ Reactant.MLIR.IR ~/.julia/packages/Reactant/e7PeE/src/mlir/IR/Context.jl:76
 [20] compile_xla(f::Function, args::Tuple{ConcreteRArray{ComplexF64, 1}, ConcreteRArray{ComplexF64, 1}}; client::Nothing, optimize::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:717
 [21] compile_xla
    @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:712 [inlined]
 [22] compile(f::Function, args::Tuple{ConcreteRArray{ComplexF64, 1}, ConcreteRArray{ComplexF64, 1}}; client::Nothing, optimize::Bool, sync::Bool)
    @ Reactant.Compiler ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:744
 [23] top-level scope
    @ ~/.julia/packages/Reactant/e7PeE/src/Compiler.jl:490
Some type information was truncated. Use `show(err)` to see complete types.
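
The missing method can also be confirmed directly with hasmethod; a minimal check based on the candidate list above, with results reflecting the package state at the time of this report:

julia> hasmethod(Base.fill, Tuple{ComplexF64, Reactant.MLIR.IR.Type})  # no complex overload defined
false

julia> hasmethod(Base.fill, Tuple{Int8, Reactant.MLIR.IR.Type})  # the Int8 overload listed among the candidates
true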

This does NOT happen if the parameters are real:

julia> params_real = rand(N);

julia> expected_real = rand(N);

julia> params_real′ = adapt(ConcreteRArray, params_real);

julia> expected_real′ = adapt(ConcreteRArray, expected_real);

julia> function f1(params, expected)
           return sum(abs.(expected - params))
       end
f1 (generic function with 1 method)

julia> function ∇f(params, expected)
           foo = Enzyme.gradient(ReverseWithPrimal, f1, params, Enzyme.Const(expected))
           return foo.val, foo.derivs[1]
       end
∇f (generic function with 1 method)

julia> ∇fR = Reactant.@compile ∇f(params_real′, expected_real′)
Reactant.Compiler.Thunk{Symbol("##∇f_reactant#225")}()
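
Calling the compiled thunk then evaluates the primal and the gradient in one call; illustrative only, since the concrete values depend on the random inputs above:

julia> ∇fR(params_real′, expected_real′)  # returns (f1(params_real, expected_real), gradient w.r.t. params_real)
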
@mofeing (Collaborator) commented Nov 7, 2024

I think this is related to previous bugs about differentiating complex numbers. At least, it's crashing on the same line.

mofeing changed the title from "Error in Reactant.@compile for complex types parameters" to "Undefined fill(::ComplexF64, ::Reactant.MLIR.IR.Type) method required for differentiating complex function" on Nov 7, 2024
mofeing added the good first issue label on Nov 7, 2024
@avik-pal (Collaborator) commented Nov 8, 2024

On main:

julia> @code_hlo optimize=false ∇f(params′, expected′)
Module:
module {
  func.func private @"-_broadcast_scalar"(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> (tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %2 = stablehlo.subtract %0, %1 : tensor<complex<f64>>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    return %3, %4, %5 : tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>
  }
  func.func private @abs_broadcast_scalar(%arg0: tensor<complex<f64>>) -> (tensor<f64>, tensor<complex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %1 = stablehlo.abs %0 : (tensor<complex<f64>>) -> tensor<f64>
    %2 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    return %2, %3 : tensor<f64>, tensor<complex<f64>>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f64>) -> tensor<f64> {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
    %1 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
    return %1 : tensor<f64>
  }
  func.func private @"Const{typeof(f1)}(Main.f1)_autodiff"(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %3 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %4:3 = enzyme.batch @"-_broadcast_scalar"(%2, %3) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) -> (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
    %5 = stablehlo.broadcast_in_dim %4#0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %6:2 = enzyme.batch @abs_broadcast_scalar(%5) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>) -> (tensor<10xf64>, tensor<10xcomplex<f64>>)
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %7 = stablehlo.broadcast_in_dim %6#0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %8 = enzyme.batch @identity_broadcast_scalar(%7) {batch_shape = array<i64: 10>} : (tensor<10xf64>) -> tensor<10xf64>
    %9 = stablehlo.reduce(%8 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<10xf64>, tensor<f64>) -> tensor<f64>
    %10 = stablehlo.transpose %9, dims = [] : (tensor<f64>) -> tensor<f64>
    %11 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %12 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    return %10, %11, %12 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
  }
  func.func @main(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<10xcomplex<f64>>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %3 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %4 = stablehlo.transpose %cst_0, dims = [] : (tensor<f64>) -> tensor<f64>
    %5 = stablehlo.transpose %cst, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %6:4 = enzyme.autodiff @"Const{typeof(f1)}(Main.f1)_autodiff"(%2, %3, %4, %5) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<f64>, tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
    %7 = stablehlo.transpose %6#0, dims = [] : (tensor<f64>) -> tensor<f64>
    %8 = stablehlo.transpose %6#1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %9 = stablehlo.transpose %6#2, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %10 = stablehlo.transpose %6#3, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %11 = stablehlo.transpose %7, dims = [] : (tensor<f64>) -> tensor<f64>
    %12 = stablehlo.transpose %10, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %13 = stablehlo.transpose %8, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %14 = stablehlo.transpose %9, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    return %11, %12, %13, %14 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
  }
}

We are generating something incorrect here:

error: 'complex.add' op operand #0 must be complex type with floating-point elements, but got 'tensor<10xcomplex<f64>>'

@mofeing (Collaborator) commented Nov 8, 2024

I believe the problem is that complex.add only works on scalar complex types... we should change it to stablehlo.add, which supports complex tensors

but where is that complex.add happening?

@avik-pal (Collaborator) commented Nov 8, 2024

I was under the impression that stablehlo.add is being lowered to complex.add? (at least judging from the IR, which only contains stablehlo.add)

@mofeing (Collaborator) commented Nov 8, 2024

mmm can you show it with optimize=true?

@avik-pal (Collaborator) commented Nov 8, 2024

mmm can you show it with optimize=true?

That errors

@avik-pal (Collaborator) commented Nov 8, 2024

@avik-pal (Collaborator) commented Nov 11, 2024

@wsmoses any idea where #238 (comment) is originating from? Lol you replied too fast on the other issue 😅

@wsmoses (Member) commented Nov 11, 2024

lol yeah, if you can spit it out along with the pass pipeline we run, we can run it with enzymexlamlir-opt manually and see what's up

@avik-pal (Collaborator)

Module:
module {
  func.func private @"-_broadcast_scalar"(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> (tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %2 = stablehlo.subtract %0, %1 : tensor<complex<f64>>
    %3 = stablehlo.transpose %2, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %4 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %5 = stablehlo.transpose %1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    return %3, %4, %5 : tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>
  }
  func.func private @abs_broadcast_scalar(%arg0: tensor<complex<f64>>) -> (tensor<f64>, tensor<complex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    %1 = stablehlo.abs %0 : (tensor<complex<f64>>) -> tensor<f64>
    %2 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
    return %2, %3 : tensor<f64>, tensor<complex<f64>>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f64>) -> tensor<f64> {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
    %1 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
    return %1 : tensor<f64>
  }
  func.func private @"Const{typeof(f1)}(Main.f1)_autodiff"(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %3 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %4:3 = enzyme.batch @"-_broadcast_scalar"(%2, %3) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) -> (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
    %5 = stablehlo.broadcast_in_dim %4#0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %6:2 = enzyme.batch @abs_broadcast_scalar(%5) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>) -> (tensor<10xf64>, tensor<10xcomplex<f64>>)
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
    %7 = stablehlo.broadcast_in_dim %6#0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
    %8 = enzyme.batch @identity_broadcast_scalar(%7) {batch_shape = array<i64: 10>} : (tensor<10xf64>) -> tensor<10xf64>
    %9 = stablehlo.reduce(%8 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<10xf64>, tensor<f64>) -> tensor<f64>
    %10 = stablehlo.transpose %9, dims = [] : (tensor<f64>) -> tensor<f64>
    %11 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %12 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    return %10, %11, %12 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
  }
  func.func @main(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<10xcomplex<f64>>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %2 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %3 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %4 = stablehlo.transpose %cst_0, dims = [] : (tensor<f64>) -> tensor<f64>
    %5 = stablehlo.transpose %cst, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %6:4 = enzyme.autodiff @"Const{typeof(f1)}(Main.f1)_autodiff"(%2, %3, %4, %5) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<f64>, tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
    %7 = stablehlo.transpose %6#0, dims = [] : (tensor<f64>) -> tensor<f64>
    %8 = stablehlo.transpose %6#1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %9 = stablehlo.transpose %6#2, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %10 = stablehlo.transpose %6#3, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %11 = stablehlo.transpose %7, dims = [] : (tensor<f64>) -> tensor<f64>
    %12 = stablehlo.transpose %10, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %13 = stablehlo.transpose %8, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    %14 = stablehlo.transpose %9, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
    return %11, %12, %13, %14 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
  }
}

Does the above work, or do you need additional info?

@wsmoses (Member) commented Nov 11, 2024

I think the critical thing here is the optimization pass pipeline itself, more so than the module.

@avik-pal (Collaborator)

I do see an arith-raise pass:

"arith-raise{stablehlo=true}",

@avik-pal (Collaborator)

This is the full pipeline:

inline{default-pipeline=canonicalize max-iterations=4},canonicalize,cse,canonicalize,enzyme-hlo-generate-td{patterns=compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;reorder_elementwise_and_shape_op<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;negate_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;sqrt_simplify<16>;cos_simplify<16>;sin_simplify<16>;noop_slice<16>;const_prop_through_barrier<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>;negative_pad_to_slice<16>;tanh_simplify<16>;exp_simplify<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;gather_simplify<16>;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;slice_elementwise<1>;slice_pad<1>;dot_reshape_dot<1>;concat_const_prop<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_div<1>;bin_broadcast_splat_mul<1>;reshape_iota<16>;slice_reshape_slice<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;slice_reshape_concat<1>;slice_reshape_elementwise<1>;slice_reshape_transpose<1>;slice_reshape_dot_general<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;slice_reshape_pad<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_pad<1>;transpose_dot_reorder<1>;dot_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;concat_appending_reshape<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;dot_reshape_pad<1>;pad_dot_general<1>(0);dot_reshape_pad<1>;pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform,enzyme-batch,inline{default-pipeline=canonicalize 
max-iterations=4},canonicalize,cse,canonicalize,enzyme-hlo-generate-td{patterns=compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;reorder_elementwise_and_shape_op<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;negate_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;sqrt_simplify<16>;cos_simplify<16>;sin_simplify<16>;noop_slice<16>;const_prop_through_barrier<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>;negative_pad_to_slice<16>;tanh_simplify<16>;exp_simplify<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;gather_simplify<16>;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;slice_elementwise<1>;slice_pad<1>;dot_reshape_dot<1>;concat_const_prop<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_div<1>;bin_broadcast_splat_mul<1>;reshape_iota<16>;slice_reshape_slice<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;slice_reshape_concat<1>;slice_reshape_elementwise<1>;slice_reshape_transpose<1>;slice_reshape_dot_general<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;slice_reshape_pad<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_pad<1>;transpose_dot_reorder<1>;dot_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;concat_appending_reshape<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;dot_reshape_pad<1>;pad_dot_general<1>(0);dot_reshape_pad<1>;pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform,enzyme,arith-raise{stablehlo=true},canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,inline{default-pipeline=canonicalize 
max-iterations=4},canonicalize,cse,canonicalize,enzyme-hlo-generate-td{patterns=compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;reorder_elementwise_and_shape_op<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;negate_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;sqrt_simplify<16>;cos_simplify<16>;sin_simplify<16>;noop_slice<16>;const_prop_through_barrier<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>;negative_pad_to_slice<16>;tanh_simplify<16>;exp_simplify<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;gather_simplify<16>;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;slice_elementwise<1>;slice_pad<1>;dot_reshape_dot<1>;concat_const_prop<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_div<1>;bin_broadcast_splat_mul<1>;reshape_iota<16>;slice_reshape_slice<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;slice_reshape_concat<1>;slice_reshape_elementwise<1>;slice_reshape_transpose<1>;slice_reshape_dot_general<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;slice_reshape_pad<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_pad<1>;transpose_dot_reorder<1>;dot_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;concat_appending_reshape<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;dot_reshape_pad<1>;pad_dot_general<1>(0);dot_reshape_pad<1>;pad_dot_general<1>(1)},transform-interpreter,enzyme-hlo-remove-transform
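
As suggested above, this can be replayed outside Julia. A rough sketch, assuming enzymexlamlir-opt follows the standard mlir-opt command line; the file name is a placeholder and the pipeline string is abbreviated (use the full pipeline quoted above):

# Save the unoptimized module dump above to module.mlir by hand, then:
pipeline = "inline{default-pipeline=canonicalize max-iterations=4},canonicalize,cse"  # abbreviated
run(`enzymexlamlir-opt --pass-pipeline=builtin.module($pipeline) module.mlir`)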

@mofeing (Collaborator) commented Nov 12, 2024

we have the same problem but with complex.conj not being raised to chlo.conj

@Pangoraw (Collaborator)

We probably want to disable the verifier for the pass manager, like in the complex lit tests. See #269.
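
A rough sketch of how that might look on the Julia side; the constructor and helper below are assumptions about Reactant's MLIR bindings (wrapping mlirPassManagerEnableVerifier from the MLIR C API), and the actual change lives in #269:

pm = Reactant.MLIR.IR.PassManager()
# Assumed wrapper over mlirPassManagerEnableVerifier: skip verification between passes,
# so temporarily invalid IR (e.g. complex.add on tensors before arith-raise runs)
# no longer aborts the pipeline.
Reactant.MLIR.IR.enable_verifier!(pm, false)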

@Todorbsc (Collaborator, Author) commented Nov 12, 2024

We probably want to disable the verifier for the pass manager, like in the complex lit tests. See #269.

I tested PR #269 and it solves the complex.add and complex.conj issues.
However, the MWE now crashes with this error:

julia> @code_hlo ∇f(params′, expected′)
julia: external/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::FloatType; From = mlir::Type]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

[12417] signal (6.-6): Aborted
in expression starting at REPL[18]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x719b8862871a)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
cast<mlir::FloatType, mlir::Type> at /proc/self/cwd/external/llvm-project/llvm/include/llvm/Support/Casting.h:566
cast<mlir::FloatType> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Types.h:352
getConstantAttr at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp:30
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme_ad/src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc:3182
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h.inc:381
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.cpp.inc:50
visitChild at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:51
visitChildren at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:66
differentiate at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:174
CreateReverseDiff at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:244
HandleAutoDiffReverse<mlir::enzyme::AutoDiffOp> at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:271
lowerEnzymeCalls at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:325
operator() at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:352
operator() at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:338
callback_fn<mlir::detail::walk<>(mlir::Operation*, (anonymous namespace)::DifferentiatePass::runOnOperation()::<lambda(mlir::FunctionOpInterface)>&&)::<lambda(mlir::Operation*)> > at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
operator() at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
walk<mlir::ForwardIterator> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:186
walk<mlir::ForwardIterator> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:181
walk<> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:340
walk<> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Operation.h:794
runOnOperation at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:351
operator() at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:526
callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::<lambda()> > at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
operator() at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
executeAction<mlir::PassExecutionAction, mlir::Pass&> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280
run at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:520
runPipeline at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:592
runPasses at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:905
run at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:885
mlirPassManagerRunOnOp at /proc/self/cwd/external/llvm-project/mlir/lib/CAPI/IR/Pass.cpp:44
mlirPassManagerRunOnOp at /home/tkrasimi/.julia/dev/Reactant/src/mlir/libMLIR_h.jl:5853 [inlined]
run! at /home/tkrasimi/.julia/dev/Reactant/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:252
run_pass_pipeline! at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:247 [inlined]
#compile_mlir!#8 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:298
compile_mlir! at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:280 [inlined]
#6 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:275 [inlined]
context! at /home/tkrasimi/.julia/dev/Reactant/src/mlir/IR/Context.jl:76
unknown function (ip: 0x719b87fdd109)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#compile_mlir#5 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:273
compile_mlir at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:270
unknown function (ip: 0x719b87fdb540)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#run_repl#59 at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91689.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#1013 at ./client.jl:432
jfptr_YY.1013_82677.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:887 [inlined]
invokelatest at ./essentials.jl:884 [inlined]
run_main_repl at ./client.jl:416
exec_options at ./client.jl:333
_start at ./client.jl:552
jfptr__start_82703.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at julia (unknown line)
unknown function (ip: 0x719b88629d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 38032866 (Pool: 37988934; Big: 43932); GC: 51
Aborted (core dumped)
