Skip to content

Commit 02fc15a

Browse files
Merge branch 'patch'
2 parents bb38f0e + 3bccf29 commit 02fc15a

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/common_interface/solve.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ function DiffEqBase.init{uType, tType, isinplace, Method, LinearSolver}(
6666
sizeu = size(prob.u0)
6767

6868
### Fix the more general function to Sundials allowed style
69-
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
69+
if !isinplace && typeof(prob.u0)<:Number
70+
f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0))
71+
elseif !isinplace && typeof(prob.u0)<:Vector{Float64}
7072
f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
7173
elseif !isinplace && typeof(prob.u0)<:AbstractArray
7274
f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0))
@@ -284,7 +286,9 @@ function DiffEqBase.init{uType, tType, isinplace, Method, LinearSolver}(
284286
u0nv = NVector(u0)
285287

286288
### Fix the more general function to Sundials allowed style
287-
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
289+
if !isinplace && typeof(prob.u0)<:Number
290+
f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0))
291+
elseif !isinplace && typeof(prob.u0)<:Vector{Float64}
288292
f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
289293
elseif !isinplace && typeof(prob.u0)<:AbstractArray
290294
f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0))
@@ -298,7 +302,10 @@ function DiffEqBase.init{uType, tType, isinplace, Method, LinearSolver}(
298302
if typeof(prob.problem_type) <: SplitODEProblem
299303

300304
### Fix the more general function to Sundials allowed style
301-
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
305+
if !isinplace && typeof(prob.u0)<:Number
306+
f1! = (du, u, p, t) -> (du .= prob.f.f1(first(u), p, t); Cint(0))
307+
f2! = (du, u, p, t) -> (du .= prob.f.f2(first(u), p, t); Cint(0))
308+
elseif !isinplace && typeof(prob.u0)<:Vector{Float64}
302309
f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0))
303310
f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0))
304311
elseif !isinplace && typeof(prob.u0)<:AbstractArray
@@ -560,10 +567,12 @@ function DiffEqBase.init{uType, duType, tType, isinplace, LinearSolver}(
560567
sizedu = size(prob.du0)
561568

562569
### Fix the more general function to Sundials allowed style
563-
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
564-
f! = (out, du, u, p, t) -> (out[:] = prob.f(du, u, p, t); Cint(0))
570+
if !isinplace && typeof(prob.u0)<:Number
571+
f! = (out, du, u, p, t) -> (out .= prob.f(first(du),first(u), p, t); Cint(0))
572+
elseif !isinplace && typeof(prob.u0)<:Vector{Float64}
573+
f! = (out, du, u, p, t) -> (out .= prob.f(du, u, p, t); Cint(0))
565574
elseif !isinplace && typeof(prob.u0)<:AbstractArray
566-
f! = (out, du, u, p, t) -> (out[:] = vec(
575+
f! = (out, du, u, p, t) -> (out .= vec(
567576
prob.f(reshape(du, sizedu), reshape(u, sizeu), p, t)
568577
);Cint(0))
569578
elseif typeof(prob.u0)<:Vector{Float64}

test/common_interface/cvode.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,9 @@ prob = deepcopy(prob_ode_2Dlinear)
7575
prob2 = ODEProblem(prob.f,prob.u0,(1.0,0.0))
7676
sol = solve(prob2,CVODE_BDF())
7777
@test maximum(diff(sol.t)) < 0 # Make sure all go negative
78+
79+
number_test(u,p,t) = -u^2 + (p[1] + t + p[2])*u + p[2]
80+
u0 = 0.0;
81+
tspan = (0.0, 10)
82+
prob = ODEProblem(number_test,u0,tspan,(2.0,0.01))
83+
sol = solve(prob,CVODE_BDF())

0 commit comments

Comments
 (0)