Skip to content

Commit

Permalink
refactor Orthogonal collocation a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
rveltz committed Feb 11, 2024
1 parent 09f10de commit a85cd8c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/BifurcationKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module BifurcationKit
using Parameters: @with_kw, @unpack, @with_kw_noshow
using RecursiveArrayTools: VectorOfArray
using DocStringExtensions
using DataStructures: CircularBuffer
using DataStructures: CircularBuffer # used for Polynomial predictor
using ForwardDiff


Expand Down
20 changes: 10 additions & 10 deletions src/periodicorbit/PeriodicOrbitCollocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,12 @@ struct MeshCollocationCache{T}
full_mesh::Vector{T}
end

function MeshCollocationCache(Ntst::Int, m::Int, Ty = Float64)
τs = LinRange{Ty}( 0, 1, Ntst + 1) |> collect
σs = LinRange{Ty}(-1, 1, m + 1)
function MeshCollocationCache(Ntst::Int, m::Int, 𝒯 = Float64)
τs = LinRange{𝒯}( 0, 1, Ntst + 1) |> collect
σs = LinRange{𝒯}(-1, 1, m + 1)
L, ∂L = getL(σs)
zg, wg = gausslegendre(m)
cache = MeshCollocationCache(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(Ty, 1 + m * Ntst))
cache = MeshCollocationCache(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(𝒯, 1 + m * Ntst))
# put the mesh where we removed redundant timing
cache.full_mesh .= get_times(cache)
return cache
Expand All @@ -59,7 +59,6 @@ get_max_time_step(pb::MeshCollocationCache) = maximum(diff(getmesh(pb)))
# get the sigma corresponding to τ in the interval (𝜏s[j], 𝜏s[j+1])
@inline σj(τ, τs, j) = -(2*τ - τs[j] - τs[j + 1])/(-τs[j + 1] + τs[j])

# code from Jacobi.lagrange
function lagrange(i::Int, x, z)
nz = length(z)
l = one(z[1])
Expand Down Expand Up @@ -302,7 +301,7 @@ function Base.show(io::IO, pb::PeriodicOrbitOCollProblem)
println(io, "└─ # unknowns : ", pb.N * (1 + m * Ntst))
end

function matrix_phase_condition(coll::PeriodicOrbitOCollProblem)
function get_matrix_phase_condition(coll::PeriodicOrbitOCollProblem)
n, m, Ntst = size(coll)
L, ∂L = get_Ls(coll.mesh_cache)
ω = coll.mesh_cache.gauss_weight
Expand Down Expand Up @@ -559,7 +558,7 @@ Compute the jacobian of the problem defining the periodic orbits by orthogonal c
ρI = zero(𝒯)) where {𝒯}
n, m, Ntst = size(coll)
L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m)
Ω = matrix_phase_condition(coll)
Ω = get_matrix_phase_condition(coll)
mesh = getmesh(coll)
period = getperiod(coll, u, nothing)
uc = get_time_slices(coll, u)
Expand Down Expand Up @@ -659,7 +658,7 @@ function jacobian_poocoll_block(coll::PeriodicOrbitOCollProblem,
J = BlockArray(spzeros(length(u), length(u)), blocks, blocks)
# temporaries
L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m)
Ω = matrix_phase_condition(coll)
Ω = get_matrix_phase_condition(coll)
mesh = getmesh(coll)
period = getperiod(coll, u, nothing)
uc = get_time_slices(coll, u)
Expand Down Expand Up @@ -740,7 +739,7 @@ end
# J = BlockArray(spzeros(length(u), length(u)), blocks, blocks)
# temporaries
L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m)
Ω = matrix_phase_condition(coll)
Ω = get_matrix_phase_condition(coll)
mesh = getmesh(coll)
period = getperiod(coll, u, nothing)
uc = get_time_slices(coll, u)
Expand Down Expand Up @@ -1034,7 +1033,7 @@ end
(f) = x -> ForwardDiff.derivative(f, x)
(f, n::Int) = n == 0 ? f : ((f), n-1)

@views function (sol::POSolution{ <: PeriodicOrbitOCollProblem})(t0)
function (sol::POSolution{ <: PeriodicOrbitOCollProblem})(t0)
n, m, Ntst = size(sol.pb)
xc = get_time_slices(sol.pb, sol.x)

Expand Down Expand Up @@ -1186,6 +1185,7 @@ function get_blocks(coll::PeriodicOrbitOCollProblem, Jac::SparseMatrixCSC)
end
####################################################################################################
@views function condensation_of_parameters(coll::PeriodicOrbitOCollProblem, J, rhs0)
#https://github.com/DynareJulia/FastLapackInterface.jl
N, m, Ntst = size(coll)
nbcoll = N * m
nⱼ = size(J, 1)
Expand Down

0 comments on commit a85cd8c

Please sign in to comment.