Undefined fill(::ComplexF64, ::Reactant.MLIR.IR.Type) method required for differentiating complex function #238

I get an error when trying to compile a gradient of a function that accepts complex parameters. This does NOT happen if the parameters are real.
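No reproducer survived extraction here, so the following is a minimal sketch reconstructed from the rest of the thread. The body of f1 (elementwise subtract, abs, sum, read off the IR dumped below), the ∇f wrapper, and the ConcreteRArray setup are all assumptions, not the reporter's exact code:

using Reactant, Enzyme

# Assumed from the IR below: elementwise subtract, then abs, then a sum.
f1(params, expected) = sum(abs.(params .- expected))

# One plausible wrapper matching the enzyme.autodiff activities in the IR
# (params active, expected const).
function ∇f(params, expected)
    dparams = Enzyme.make_zero(params)
    Enzyme.autodiff(Reverse, f1, Active, Duplicated(params, dparams), Const(expected))
    return dparams
end

params′ = Reactant.ConcreteRArray(rand(ComplexF64, 10))
expected′ = Reactant.ConcreteRArray(rand(ComplexF64, 10))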
I think this is related to previous bugs on differentiating complex numbers. At least, it's crashing in the same line.
Related issue: Reactant.@compile for complex types parameters
On main: julia> @code_hlo optimize=false ∇f(params′, expected′)
Module:
module {
func.func private @"-_broadcast_scalar"(%arg0: tensor<complex<f64>>, %arg1: tensor<complex<f64>>) -> (tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
%1 = stablehlo.transpose %arg1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
%2 = stablehlo.subtract %0, %1 : tensor<complex<f64>>
%3 = stablehlo.transpose %2, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
%4 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
%5 = stablehlo.transpose %1, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
return %3, %4, %5 : tensor<complex<f64>>, tensor<complex<f64>>, tensor<complex<f64>>
}
func.func private @abs_broadcast_scalar(%arg0: tensor<complex<f64>>) -> (tensor<f64>, tensor<complex<f64>>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
%1 = stablehlo.abs %0 : (tensor<complex<f64>>) -> tensor<f64>
%2 = stablehlo.transpose %1, dims = [] : (tensor<f64>) -> tensor<f64>
%3 = stablehlo.transpose %0, dims = [] : (tensor<complex<f64>>) -> tensor<complex<f64>>
return %2, %3 : tensor<f64>, tensor<complex<f64>>
}
func.func private @identity_broadcast_scalar(%arg0: tensor<f64>) -> tensor<f64> {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
%1 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
return %1 : tensor<f64>
}
func.func private @"Const{typeof(f1)}(Main.f1)_autodiff"(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%2 = stablehlo.broadcast_in_dim %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%3 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%4:3 = enzyme.batch @"-_broadcast_scalar"(%2, %3) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) -> (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
%5 = stablehlo.broadcast_in_dim %4#0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%6:2 = enzyme.batch @abs_broadcast_scalar(%5) {batch_shape = array<i64: 10>} : (tensor<10xcomplex<f64>>) -> (tensor<10xf64>, tensor<10xcomplex<f64>>)
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%7 = stablehlo.broadcast_in_dim %6#0, dims = [0] : (tensor<10xf64>) -> tensor<10xf64>
%8 = enzyme.batch @identity_broadcast_scalar(%7) {batch_shape = array<i64: 10>} : (tensor<10xf64>) -> tensor<10xf64>
%9 = stablehlo.reduce(%8 init: %cst) applies stablehlo.add across dimensions = [0] : (tensor<10xf64>, tensor<f64>) -> tensor<f64>
%10 = stablehlo.transpose %9, dims = [] : (tensor<f64>) -> tensor<f64>
%11 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%12 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
return %10, %11, %12 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
}
func.func @main(%arg0: tensor<10xcomplex<f64>>, %arg1: tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>) {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%1 = stablehlo.transpose %arg1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<10xcomplex<f64>>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%2 = stablehlo.transpose %0, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%3 = stablehlo.transpose %1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%4 = stablehlo.transpose %cst_0, dims = [] : (tensor<f64>) -> tensor<f64>
%5 = stablehlo.transpose %cst, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%6:4 = enzyme.autodiff @"Const{typeof(f1)}(Main.f1)_autodiff"(%2, %3, %4, %5) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<f64>, tensor<10xcomplex<f64>>) -> (tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>)
%7 = stablehlo.transpose %6#0, dims = [] : (tensor<f64>) -> tensor<f64>
%8 = stablehlo.transpose %6#1, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%9 = stablehlo.transpose %6#2, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%10 = stablehlo.transpose %6#3, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%11 = stablehlo.transpose %7, dims = [] : (tensor<f64>) -> tensor<f64>
%12 = stablehlo.transpose %10, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%13 = stablehlo.transpose %8, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
%14 = stablehlo.transpose %9, dims = [0] : (tensor<10xcomplex<f64>>) -> tensor<10xcomplex<f64>>
return %11, %12, %13, %14 : tensor<f64>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>, tensor<10xcomplex<f64>>
}
}

We are generating something incorrect here:

error: 'complex.add' op operand #0 must be complex type with floating-point elements, but got 'tensor<10xcomplex<f64>>'
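For anyone reproducing: the dump above is the unoptimized module from @code_hlo optimize=false; the verifier error only fires once the full optimization pipeline runs. Assuming the ∇f / params′ / expected′ setup sketched at the top of the thread:

# The unoptimized dump prints fine; the optimized path runs the pass
# pipeline and hits the verifier.
@code_hlo ∇f(params′, expected′)
# error: 'complex.add' op operand #0 must be complex type with
# floating-point elements, but got 'tensor<10xcomplex<f64>>'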
I believe the problem is that a complex.add is being generated somewhere, but where is that coming from?
I was under the impression stablehlo.add is being lowered to complex.add? (At least from the IR, which only has stablehlo.add.)
mmm can you show it with …

That errors
actually nvm, complex.add should be raised to stablehlo: https://github.com/EnzymeAD/Enzyme-JAX/blob/a1b5a6048dc358f54c5aedc2b737a3af64a2c80f/src/enzyme_ad/jax/Passes/ArithRaising.cpp#L52
@wsmoses any idea where #238 (comment) is originating from? Lol you replied too fast on the other issue 😅
lol yeah, if you can spit it out along with the pass pipeline we run, we can run it with enzymexlamlir-opt manually and see what's up
Module: (identical to the module dumped above)

^this works? Or do you need additional info.
I think the critical thing here is really the optimization pass pipeline itself.
I do see an arith raise pass: Line 291 in f2a91bf
This is the full pipeline: …
we have the same problem but with …
We probably want to disable the verifier for the pass manager, like in the complex lit tests: #269.
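A minimal sketch of that suggestion via the generated MLIR C bindings. mlirContextCreate, mlirPassManagerCreate, and mlirPassManagerEnableVerifier are upstream MLIR C API entry points; that Reactant exposes them under MLIR.API (see libMLIR_h.jl in the trace below) is an assumption:

using Reactant
const API = Reactant.MLIR.API

# Build a raw pass manager and switch off inter-pass IR verification,
# mirroring what the complex lit tests do.
ctx = API.mlirContextCreate()
pm = API.mlirPassManagerCreate(ctx)
API.mlirPassManagerEnableVerifier(pm, false)
# ... add passes and run as usual, then clean up:
API.mlirPassManagerDestroy(pm)
API.mlirContextDestroy(ctx)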
I tested the PR #269 and it solves the verifier error, but now compilation aborts:

julia> @code_hlo ∇f(params′, expected′)
julia: external/llvm-project/llvm/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::FloatType; From = mlir::Type]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
[12417] signal (6.-6): Aborted
in expression starting at REPL[18]:1
pthread_kill at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
raise at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
abort at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x719b8862871a)
__assert_fail at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
cast<mlir::FloatType, mlir::Type> at /proc/self/cwd/external/llvm-project/llvm/include/llvm/Support/Casting.h:566
cast<mlir::FloatType> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Types.h:352
getConstantAttr at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp:30
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme_ad/src/enzyme_ad/jax/Implementations/StableHLODerivatives.inc:3182
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.h.inc:381
createReverseModeAdjoint at /proc/self/cwd/bazel-out/k8-dbg/bin/external/enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.cpp.inc:50
visitChild at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:51
visitChildren at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:66
differentiate at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:174
CreateReverseDiff at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp:244
HandleAutoDiffReverse<mlir::enzyme::AutoDiffOp> at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:271
lowerEnzymeCalls at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:325
operator() at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:352
operator() at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:338
callback_fn<mlir::detail::walk<>(mlir::Operation*, (anonymous namespace)::DifferentiatePass::runOnOperation()::<lambda(mlir::FunctionOpInterface)>&&)::<lambda(mlir::Operation*)> > at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
operator() at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
walk<mlir::ForwardIterator> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:186
walk<mlir::ForwardIterator> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:181
walk<> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Visitors.h:340
walk<> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/Operation.h:794
runOnOperation at /proc/self/cwd/external/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp:351
operator() at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:526
callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::<lambda()> > at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45
operator() at /proc/self/cwd/external/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68
executeAction<mlir::PassExecutionAction, mlir::Pass&> at /proc/self/cwd/external/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280
run at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:520
runPipeline at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:592
runPasses at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:905
run at /proc/self/cwd/external/llvm-project/mlir/lib/Pass/Pass.cpp:885
mlirPassManagerRunOnOp at /proc/self/cwd/external/llvm-project/mlir/lib/CAPI/IR/Pass.cpp:44
mlirPassManagerRunOnOp at /home/tkrasimi/.julia/dev/Reactant/src/mlir/libMLIR_h.jl:5853 [inlined]
run! at /home/tkrasimi/.julia/dev/Reactant/src/mlir/IR/Pass.jl:74 [inlined]
#run_pass_pipeline!#1 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:252
run_pass_pipeline! at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:247 [inlined]
#compile_mlir!#8 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:298
compile_mlir! at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:280 [inlined]
#6 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:275 [inlined]
context! at /home/tkrasimi/.julia/dev/Reactant/src/mlir/IR/Context.jl:76
unknown function (ip: 0x719b87fdd109)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#compile_mlir#5 at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:273
compile_mlir at /home/tkrasimi/.julia/dev/Reactant/src/Compiler.jl:270
unknown function (ip: 0x719b87fdb540)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
do_call at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:126
eval_value at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:223
eval_stmt_value at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:174 [inlined]
eval_body at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:617
jl_interpret_toplevel_thunk at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/interpreter.c:775
jl_toplevel_eval_flex at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:934
jl_toplevel_eval_flex at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:877
ijl_toplevel_eval_in at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/toplevel.c:985
eval at ./boot.jl:385 [inlined]
eval_user_input at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:150
repl_backend_loop at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:246
#start_repl_backend#46 at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:231
start_repl_backend at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:228
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#run_repl#59 at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:389
run_repl at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/usr/share/julia/stdlib/v1.10/REPL/src/REPL.jl:375
jfptr_run_repl_91689.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
#1013 at ./client.jl:432
jfptr_YY.1013_82677.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
jl_f__call_latest at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/builtins.c:812
#invokelatest#2 at ./essentials.jl:887 [inlined]
invokelatest at ./essentials.jl:884 [inlined]
run_main_repl at ./client.jl:416
exec_options at ./client.jl:333
_start at ./client.jl:552
jfptr__start_82703.1 at /opt/julia-1.10.0/lib/julia/sys.so (unknown line)
_jl_invoke at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:2894 [inlined]
ijl_apply_generic at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/gf.c:3076
jl_apply at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/julia.h:1982 [inlined]
true_main at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/jlapi.c:582
jl_repl_entrypoint at /cache/build/builder-amdci4-6/julialang/julia-release-1-dot-10/src/jlapi.c:731
main at julia (unknown line)
unknown function (ip: 0x719b88629d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 38032866 (Pool: 37988934; Big: 43932); GC: 51
Aborted (core dumped)
CC @mofeing