Skip to content

Commit

Permalink
refactor collocation methods
Browse files Browse the repository at this point in the history
  • Loading branch information
rveltz committed Feb 11, 2024
1 parent a85cd8c commit 15ffef9
Showing 1 changed file with 53 additions and 47 deletions.
100 changes: 53 additions & 47 deletions src/periodicorbit/PeriodicOrbitCollocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ using FastGaussQuadrature: gausslegendre
"""
cache = MeshCollocationCache(Ntst::Int, m::Int, Ty = Float64)
Structure to hold the cache for the collocation method.
Structure to hold the cache for the collocation method. More precisely, it starts from a partition of [0,1] based on the mesh points:
0 = τ₁ < τ₂ < ... < τₙₜₛₜ₊₁ = 1
On each mesh interval [τⱼ, τⱼ₊₁] mapped to [-1,1], a Legendre polynomial of degree m is formed.
$(TYPEDFIELDS)
Expand All @@ -17,46 +22,45 @@ $(TYPEDFIELDS)
- `m` degree of the collocation polynomials
- `Ty` type of the time variable
"""
struct MeshCollocationCache{T}
struct MeshCollocationCache{𝒯}
"Coarse mesh size"
Ntst::Int
"Collocation degree, usually called m"
degree::Int
"Lagrange matrix"
lagrange_vals::Matrix{T}
lagrange_vals::Matrix{𝒯}
"Lagrange matrix for derivative"
lagrange_∂::Matrix{T}
lagrange_∂::Matrix{𝒯}
"Gauss nodes"
gauss_nodes::Vector{T}
gauss_nodes::Vector{𝒯}
"Gauss weights"
gauss_weight::Vector{T}
"Values for the coarse mesh, call τj. This can be adapted."
mesh::Vector{T}
"Values for collocation poinnts, call σj. These are fixed."
mesh_coll::LinRange{T}
gauss_weight::Vector{𝒯}
"Values of the coarse mesh, call τj. This can be adapted."
τs::Vector{𝒯}
"Values of collocation points, call σj. These are fixed."
σs::LinRange{𝒯}
"Full mesh containing both the coarse mesh and the collocation points."
full_mesh::Vector{T}
full_mesh::Vector{𝒯}
end

function MeshCollocationCache(Ntst::Int, m::Int, 𝒯 = Float64)
τs = LinRange{𝒯}( 0, 1, Ntst + 1) |> collect
σs = LinRange{𝒯}(-1, 1, m + 1)
L, ∂L = getL(σs)
zg, wg = gausslegendre(m)
cache = MeshCollocationCache(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(𝒯, 1 + m * Ntst))
# put the mesh where we removed redundant timing
L, ∂L, zg, wg = compute_legendre_matrices(σs)
cache = MeshCollocationCache{𝒯}(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(𝒯, 1 + m * Ntst))
# save the mesh where we removed redundant timing
cache.full_mesh .= get_times(cache)
return cache
end

@inline Base.eltype(pb::MeshCollocationCache{T}) where T = T
@inline Base.size(pb::MeshCollocationCache) = (pb.degree, pb.Ntst)
@inline get_Ls(pb::MeshCollocationCache) = (pb.lagrange_vals, pb.lagrange_∂)
@inline getmesh(pb::MeshCollocationCache) = pb.mesh
@inline get_mesh_coll(pb::MeshCollocationCache) = pb.mesh_coll
get_max_time_step(pb::MeshCollocationCache) = maximum(diff(getmesh(pb)))
@inline Base.eltype(cache::MeshCollocationCache{𝒯}) where 𝒯 = 𝒯
@inline Base.size(cache::MeshCollocationCache) = (cache.degree, cache.Ntst)
@inline get_Ls(cache::MeshCollocationCache) = (cache.lagrange_vals, cache.lagrange_∂)
@inline getmesh(cache::MeshCollocationCache) = cache.τs
@inline get_mesh_coll(cache::MeshCollocationCache) = cache.σs
get_max_time_step(cache::MeshCollocationCache) = maximum(diff(getmesh(cache)))
@inline τj(σ, τs, j) = τs[j] + (1 + σ)/2 * (τs[j+1] - τs[j])
# get the sigma corresponding to τ in the interval (𝜏s[j], 𝜏s[j+1])
# get the sigma corresponding to τ in the interval (τs[j], τs[j+1])
@inline σj(τ, τs, j) = -(2*τ - τs[j] - τs[j + 1])/(-τs[j + 1] + τs[j])

function lagrange(i::Int, x, z)
Expand All @@ -74,44 +78,44 @@ end
dlagrange(i, x, z) = ForwardDiff.derivative(x -> lagrange(i, x, z), x)

# should accept a range, ie σs = LinRange(-1, 1, m + 1)
function getL(σs::AbstractVector)
function compute_legendre_matrices(σs::AbstractVector{𝒯}) where 𝒯
m = length(σs) - 1
zs, = gausslegendre(m)
L = zeros(m + 1, m)
∂L = zeros(m + 1, m)
zs, ws = gausslegendre(m)
L = zeros(𝒯, m + 1, m)
∂L = zeros(𝒯, m + 1, m)
for j in 1:m+1
for i in 1:m
L[j, i] = lagrange(j, zs[i], σs)
∂L[j, i] = dlagrange(j, zs[i], σs)
end
end
return (;L, ∂L)
return (;L, ∂L, zg = zs, wg = ws)
end

"""
$(SIGNATURES)
Return all the times at which the problem is evaluated.
"""
function get_times(pb::MeshCollocationCache)
m, Ntst = size(pb)
Ty = eltype(pb)
ts = zero(Ty)
tsvec = Ty[0]
τs = pb.mesh
σs = pb.mesh_coll
function get_times(cache::MeshCollocationCache{𝒯}) where 𝒯
m, Ntst = size(cache)
tsvec = zeros(𝒯, m * Ntst + 1)
τs = cache.τs
σs = cache.σs
ind = 2
for j in 1:Ntst
for l in 1:m+1
ts = τj(σs[l], τs, j)
l>1 && push!(tsvec, τj(σs[l], τs, j))
for l in 2:m+1
@inbounds t = τj(σs[l], τs, j)
tsvec[ind] = t
ind +=1
end
end
return vec(tsvec)
return tsvec
end

function update_mesh!(pb::MeshCollocationCache, mesh)
pb.mesh .= mesh
pb.full_mesh .= get_times(pb)
function update_mesh!(cache::MeshCollocationCache, τs)
cache.τs .= τs
cache.full_mesh .= get_times(cache)
end
####################################################################################################
"""
Expand Down Expand Up @@ -274,7 +278,7 @@ get_time_slices(x::AbstractVector, N, degree, Ntst) = reshape(x, N, degree * Nts
get_time_slices(pb::PeriodicOrbitOCollProblem, x) = @views get_time_slices(x[1:end-1], size(pb)...)
get_times(pb::PeriodicOrbitOCollProblem) = get_times(pb.mesh_cache)
"""
Returns the vector of size m+1, 0 = τ1 < τ1 < ... < τm+1 = 1
Returns the vector of size m+1, 0 = τ₁ < τ₂ < ... < τₘ < τₘ₊₁ = 1
"""
getmesh(pb::PeriodicOrbitOCollProblem) = getmesh(pb.mesh_cache)
get_mesh_coll(pb::PeriodicOrbitOCollProblem) = get_mesh_coll(pb.mesh_cache)
Expand Down Expand Up @@ -855,7 +859,9 @@ jacobian(prob::WrapPOColl, x, p) = prob.jacobian(x, p)
# for recording the solution in a branch
function getsolution(wrap::WrapPOColl, x)
if wrap.prob.meshadapt
return (mesh = copy(get_times(wrap.prob)), sol = x, _mesh = copy(wrap.prob.mesh_cache.mesh))
return (mesh = copy(get_times(wrap.prob)),
sol = x,
_mesh = copy(wrap.prob.mesh_cache.τs))
else
return x
end
Expand Down Expand Up @@ -1078,8 +1084,8 @@ function compute_error!(pb::PeriodicOrbitOCollProblem, x::AbstractVector{Ty};
kw...) where Ty
n, m, Ntst = size(pb)
period = getperiod(pb, x, nothing)
# get solution
sol = POSolution(deepcopy(pb), x)
# get solution, we copy x because it is overwritten at the end
sol = POSolution(deepcopy(pb), copy(x))
# derivative of degree m, indeed ∂(sol, m+1) = 0
dmsol = (sol, m)
# we find the values of vm := ∂m(x) at the mid points
Expand All @@ -1091,7 +1097,7 @@ function compute_error!(pb::PeriodicOrbitOCollProblem, x::AbstractVector{Ty};
# this is the function s^{(k)} in the above paper on page 63
# we want to estimate sk = s^{(m+1)} which is 0 by definition, pol of degree m
if isempty(findall(diff(meshT) .<= 0)) == false
@error "[In mesh-adaptation]. The mesh is non monotonic! Please report the error to the website of BifurcationKit.jl"
@error "[Mesh-adaptation]. The mesh is non monotonic! Please report the error to the website of BifurcationKit.jl"
return (success = false, newmeshT = meshT, ϕ = meshT)
end
sk = Ty[]
Expand Down Expand Up @@ -1157,7 +1163,7 @@ function compute_error!(pb::PeriodicOrbitOCollProblem, x::AbstractVector{Ty};

############
# update solution
newsol = generate_solution(pb, t -> sol(t), period)
newsol = generate_solution(pb, sol, period)
x .= newsol

success = true
Expand Down

0 comments on commit 15ffef9

Please sign in to comment.