From a85cd8cb56bbf64510a3182370cdebd38b1e7df4 Mon Sep 17 00:00:00 2001 From: romain veltz Date: Sun, 11 Feb 2024 16:33:41 +0100 Subject: [PATCH] refactor Orthogonal collocation a bit --- src/BifurcationKit.jl | 2 +- src/periodicorbit/PeriodicOrbitCollocation.jl | 20 +++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/BifurcationKit.jl b/src/BifurcationKit.jl index 42c18e1d..0ac763cc 100644 --- a/src/BifurcationKit.jl +++ b/src/BifurcationKit.jl @@ -5,7 +5,7 @@ module BifurcationKit using Parameters: @with_kw, @unpack, @with_kw_noshow using RecursiveArrayTools: VectorOfArray using DocStringExtensions - using DataStructures: CircularBuffer + using DataStructures: CircularBuffer # used for Polynomial predictor using ForwardDiff diff --git a/src/periodicorbit/PeriodicOrbitCollocation.jl b/src/periodicorbit/PeriodicOrbitCollocation.jl index 27f0fd49..83e45c7e 100644 --- a/src/periodicorbit/PeriodicOrbitCollocation.jl +++ b/src/periodicorbit/PeriodicOrbitCollocation.jl @@ -38,12 +38,12 @@ struct MeshCollocationCache{T} full_mesh::Vector{T} end -function MeshCollocationCache(Ntst::Int, m::Int, Ty = Float64) - τs = LinRange{Ty}( 0, 1, Ntst + 1) |> collect - σs = LinRange{Ty}(-1, 1, m + 1) +function MeshCollocationCache(Ntst::Int, m::Int, 𝒯 = Float64) + τs = LinRange{𝒯}( 0, 1, Ntst + 1) |> collect + σs = LinRange{𝒯}(-1, 1, m + 1) L, ∂L = getL(σs) zg, wg = gausslegendre(m) - cache = MeshCollocationCache(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(Ty, 1 + m * Ntst)) + cache = MeshCollocationCache(Ntst, m, L, ∂L, zg, wg, τs, σs, zeros(𝒯, 1 + m * Ntst)) # put the mesh where we removed redundant timing cache.full_mesh .= get_times(cache) return cache @@ -59,7 +59,6 @@ get_max_time_step(pb::MeshCollocationCache) = maximum(diff(getmesh(pb))) # get the sigma corresponding to τ in the interval (𝜏s[j], 𝜏s[j+1]) @inline σj(τ, τs, j) = -(2*τ - τs[j] - τs[j + 1])/(-τs[j + 1] + τs[j]) -# code from Jacobi.lagrange function lagrange(i::Int, x, z) nz = length(z) l = one(z[1]) @@ -302,7 +301,7 @@ function Base.show(io::IO, pb::PeriodicOrbitOCollProblem) println(io, "└─ # unknowns : ", pb.N * (1 + m * Ntst)) end -function matrix_phase_condition(coll::PeriodicOrbitOCollProblem) +function get_matrix_phase_condition(coll::PeriodicOrbitOCollProblem) n, m, Ntst = size(coll) L, ∂L = get_Ls(coll.mesh_cache) ω = coll.mesh_cache.gauss_weight @@ -559,7 +558,7 @@ Compute the jacobian of the problem defining the periodic orbits by orthogonal c ρI = zero(𝒯)) where {𝒯} n, m, Ntst = size(coll) L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m) - Ω = matrix_phase_condition(coll) + Ω = get_matrix_phase_condition(coll) mesh = getmesh(coll) period = getperiod(coll, u, nothing) uc = get_time_slices(coll, u) @@ -659,7 +658,7 @@ function jacobian_poocoll_block(coll::PeriodicOrbitOCollProblem, J = BlockArray(spzeros(length(u), length(u)), blocks, blocks) # temporaries L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m) - Ω = matrix_phase_condition(coll) + Ω = get_matrix_phase_condition(coll) mesh = getmesh(coll) period = getperiod(coll, u, nothing) uc = get_time_slices(coll, u) @@ -740,7 +739,7 @@ end # J = BlockArray(spzeros(length(u), length(u)), blocks, blocks) # temporaries L, ∂L = get_Ls(coll.mesh_cache) # L is of size (m+1, m) - Ω = matrix_phase_condition(coll) + Ω = get_matrix_phase_condition(coll) mesh = getmesh(coll) period = getperiod(coll, u, nothing) uc = get_time_slices(coll, u) @@ -1034,7 +1033,7 @@ end ∂(f) = x -> ForwardDiff.derivative(f, x) ∂(f, n::Int) = n == 0 ? f : ∂(∂(f), n-1) -@views function (sol::POSolution{ <: PeriodicOrbitOCollProblem})(t0) +function (sol::POSolution{ <: PeriodicOrbitOCollProblem})(t0) n, m, Ntst = size(sol.pb) xc = get_time_slices(sol.pb, sol.x) @@ -1186,6 +1185,7 @@ function get_blocks(coll::PeriodicOrbitOCollProblem, Jac::SparseMatrixCSC) end #################################################################################################### @views function condensation_of_parameters(coll::PeriodicOrbitOCollProblem, J, rhs0) + #https://github.com/DynareJulia/FastLapackInterface.jl N, m, Ntst = size(coll) nbcoll = N * m nⱼ = size(J, 1)