diff --git a/Project.toml b/Project.toml index b06d2adb..b3bdcfa1 100644 --- a/Project.toml +++ b/Project.toml @@ -15,10 +15,11 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Optim = "429524aa-4258-5aef-a3af-852621145aeb" RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [extensions] -DataInterpolationsChainRulesCoreExt = "ChainRulesCore" +DataInterpolationsChainRulesCoreExt = ["ChainRulesCore", "SparseArrays"] DataInterpolationsOptimExt = "Optim" DataInterpolationsRegularizationToolsExt = "RegularizationTools" DataInterpolationsSymbolicsExt = "Symbolics" @@ -38,6 +39,7 @@ RecipesBase = "1.3" Reexport = "1" RegularizationTools = "0.6" SafeTestsets = "0.1" +SparseArrays = "1.10" StableRNGs = "1" Symbolics = "5.29" Test = "1" diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..37613347 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -1,28 +1,185 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) - using DataInterpolations: _interpolate, derivative, AbstractInterpolation, + using DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, + cumulative_integral, LinearParameterCache, + QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + QuadraticSpline using ChainRulesCore + using LinearAlgebra + using SparseArrays + using ReadOnlyArrays else - using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, + using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, get_idx, + cumulative_integral, LinearParameterCache, + QuadraticSplineParameterCache, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + QuadraticSpline using ..ChainRulesCore + using ..LinearAlgebra + using ..SparseArrays + using ..ReadOnlyArrays end -function ChainRulesCore.rrule(::typeof(_interpolate), - A::Union{ - LagrangeInterpolation, - AkimaInterpolation, - BSplineInterpolation, - BSplineApprox - }, - t::Number) - deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) - return _interpolate(A, t), interpolate_pullback +## Linear interpolation + +function ChainRulesCore.rrule( + ::Type{LinearParameterCache}, u::AbstractArray, t::AbstractVector) + p = LinearParameterCache(u, t) + du = zeros(eltype(p.slope), length(u)) + + function LinearParameterCache_pullback(Δp) + df = NoTangent() + du[2:end] += Δp.slope + du[1:(end - 1)] -= Δp.slope + dt = NoTangent() + return (df, du, dt) + end + + p, LinearParameterCache_pullback +end + +function ChainRulesCore.rrule( + ::Type{LinearInterpolation}, u, t, I, p, extrapolate, safetycopy) + A = LinearInterpolation(u, t, I, p, extrapolate, safetycopy) + + function LinearInterpolation_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = ΔA.p + dextrapolate = NoTangent() + dsafetycopy = NoTangent() + return df, du, dt, dI, dp, dextrapolate, dsafetycopy + end + + A, LinearInterpolation_pullback +end + +function allocate_direct_field_tangents(A::LinearInterpolation) + idx = A.idx_prev[] + u = SparseVector(length(A.u), [idx], zeros(1)) + (; u) +end + +function allocate_parameter_tangents(A::LinearInterpolation) + idx = A.idx_prev[] + slope = SparseVector(length(A.p.slope), [idx], zeros(1)) + return (; slope) +end + +function _tangent_direct_fields!( + direct_field_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ) + (; u) = direct_field_tangents + idx = A.idx_prev[] + u[idx] = Δ +end + +function _tangent_p!(parameter_tangents::NamedTuple, A::LinearInterpolation, Δt, Δ) + (; slope) = parameter_tangents + idx = A.idx_prev[] + slope[idx] = Δt * Δ +end + +## Quadratic Spline + +function ChainRulesCore.rrule(::Type{QuadraticSplineParameterCache}, u, t) + p = QuadraticSplineParameterCache(u, t) + n = length(u) + + Δt = diff(t) + diagonal_main = 2 ./ Δt + pushfirst!(diagonal_main, zero(eltype(diagonal_main))) + diagonal_down = -diagonal_main[2:end] + diagonal_up = zero(diagonal_down) + ∂d_∂u = Tridiagonal(diagonal_down, diagonal_main, diagonal_up) + + ∂σ_∂z = spzeros(n, n - 1) + for i in 1:(n - 1) + ∂σ_∂z[i, i] = -0.5 / Δt[i] + ∂σ_∂z[i + 1, i] = 0.5 / Δt[i] + end + + function QuadraticSplineParameterCache_pullback(Δp) + df = NoTangent() + temp1 = Δp.z + ∂σ_∂z * Δp.σ + temp2 = p.tA' \ temp1 + du = ∂d_∂u' * temp2 + dt = NoTangent() + return (df, du, dt) + end + + p, QuadraticSplineParameterCache_pullback +end + +function ChainRulesCore.rrule(::Type{QuadraticSpline}, u, t, I, p, extrapolate, safetycopy) + A = QuadraticSpline(u, t, I, p, extrapolate, safetycopy) + + function QuadraticSpline_pullback(ΔA) + df = NoTangent() + du = ΔA.u + dt = NoTangent() + dI = NoTangent() + dp = ΔA.p + dextrapolate = NoTangent() + dsafetycopy = NoTangent() + return df, du, dt, dI, dp, dextrapolate, dsafetycopy + end + + A, QuadraticSpline_pullback +end + +function allocate_direct_field_tangents(A::QuadraticSpline) + idx = A.idx_prev[] + u = SparseVector(length(A.u), [idx], zeros(1)) + (; u) +end + +function allocate_parameter_tangents(A::QuadraticSpline) + idx = A.idx_prev[] + z = SparseVector(length(A.p.z), [idx], zeros(1)) + σ = SparseVector(length(A.p.σ), [idx], zeros(1)) + return (; z, σ) +end + +function _tangent_direct_fields!( + direct_field_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) + (; u) = direct_field_tangents + idx = A.idx_prev[] + u[idx] = Δ +end + +function _tangent_p!(parameter_tangents::NamedTuple, A::QuadraticSpline, Δt, Δ) + (; z, σ) = parameter_tangents + idx = A.idx_prev[] + z[idx] = Δ * Δt + σ[idx] = Δ * Δt^2 +end + +## generic + +function ChainRulesCore.rrule(A::AType, t::Number) where {AType <: AbstractInterpolation} + u = A(t) + idx = get_idx(A.t, t, A.idx_prev[]) + direct_field_tangents = allocate_direct_field_tangents(A) + parameter_tangents = allocate_parameter_tangents(A) + + function _interpolate_pullback(Δ) + A.idx_prev[] = idx + Δt = t - A.t[idx] + _tangent_direct_fields!(direct_field_tangents, A, Δt, Δ) + _tangent_p!(parameter_tangents, A, Δt, Δ) + dA = Tangent{AType}(; direct_field_tangents..., + p = Tangent{typeof(A.p)}(; parameter_tangents...)) + dt = @thunk(derivative(A, t)*Δ) + return dA, dt + end + + u, _interpolate_pullback end function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, @@ -30,4 +187,16 @@ function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractIn return _interpolate(A, t), derivative(A, t) * Δt end +function ChainRulesCore.rrule(::Type{ReadOnlyArray}, parent) + read_only_array = ReadOnlyArray(parent) + ReadOnlyArray_pullback(Δ) = NoTangent(), Δ + read_only_array, ReadOnlyArray_pullback +end + +function ChainRulesCore.rrule(::typeof(cumulative_integral), A, u) + I = cumulative_integral(A, u) + cumulative_integral_pullback(Δ) = NoTangent(), NoTangent() + I, cumulative_integral_pullback +end + end # module diff --git a/src/derivatives.jl b/src/derivatives.jl index 30c76fd0..159f37eb 100644 --- a/src/derivatives.jl +++ b/src/derivatives.jl @@ -130,7 +130,8 @@ end function _derivative(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first) σ = A.p.σ[idx - 1] - A.z[idx - 1] + 2σ * (t - A.t[idx - 1]), idx + z = A.p.z[idx - 1] + z + 2σ * (t - A.t[idx - 1]), idx end # CubicSpline Interpolation diff --git a/src/integrals.jl b/src/integrals.jl index 3040189f..d85ab27f 100644 --- a/src/integrals.jl +++ b/src/integrals.jl @@ -6,7 +6,8 @@ end function integral(A::AbstractInterpolation, t1::Number, t2::Number) ((t1 < A.t[1] || t1 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) ((t2 < A.t[1] || t2 > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError()) - !hasfield(typeof(A), :I) && throw(IntegralNotFoundError()) + has_I = hasfield(typeof(A), :I) + (!has_I || (has_I && isnothing(A.I))) && throw(IntegralNotFoundError()) # the index less than or equal to t1 idx1 = get_idx(A.t, t1, 0) # the index less than t2 @@ -61,7 +62,7 @@ end function _integral(A::QuadraticSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt + return A.p.z[idx] * Δt^2 / 2 + A.p.σ[idx] * Δt^3 / 3 + Cᵢ * Δt end function _integral(A::CubicSpline{<:AbstractVector{<:Number}}, idx::Number, t::Number) diff --git a/src/interpolation_caches.jl b/src/interpolation_caches.jl index 83e04fe5..5e59b842 100644 --- a/src/interpolation_caches.jl +++ b/src/interpolation_caches.jl @@ -32,7 +32,7 @@ function LinearInterpolation(u, t; extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) p = LinearParameterCache(u, t) A = LinearInterpolation(u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) LinearInterpolation(u, t, I, p, extrapolate, safetycopy) end @@ -74,7 +74,7 @@ function QuadraticInterpolation(u, t, mode; extrapolate = false, safetycopy = tr u, t = munge_data(u, t, safetycopy) p = QuadraticParameterCache(u, t) A = QuadraticInterpolation(u, t, nothing, p, mode, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) QuadraticInterpolation(u, t, I, p, mode, extrapolate, safetycopy) end @@ -198,7 +198,7 @@ function AkimaInterpolation(u, t; extrapolate = false, safetycopy = true) d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2 A = AkimaInterpolation(u, t, nothing, b, c, d, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) AkimaInterpolation(u, t, I, b, c, d, extrapolate, safetycopy) end @@ -238,7 +238,7 @@ end function ConstantInterpolation(u, t; dir = :left, extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) A = ConstantInterpolation(u, t, nothing, dir, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) ConstantInterpolation(u, t, I, dir, extrapolate, safetycopy) end @@ -263,22 +263,16 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: u::uType t::tType I::IType - p::QuadraticSplineParameterCache{pType} - tA::tAType - d::dType - z::zType + p::QuadraticSplineParameterCache{tAType, dType, zType, pType} extrapolate::Bool idx_prev::Base.RefValue{Int} safetycopy::Bool - function QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) - new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(tA), - typeof(d), typeof(z), eltype(u)}(u, + function QuadraticSpline(u, t, I, p, extrapolate, safetycopy) + new{typeof(u), typeof(t), typeof(I), typeof(p.σ), typeof(p.tA), + typeof(p.d), typeof(p.z), eltype(u)}(u, t, I, p, - tA, - d, - z, extrapolate, Ref(1), safetycopy @@ -287,45 +281,13 @@ struct QuadraticSpline{uType, tType, IType, pType, tAType, dType, zType, T} <: end function QuadraticSpline( - u::uType, t; extrapolate = false, - safetycopy = true) where {uType <: AbstractVector{<:Number}} - u, t = munge_data(u, t, safetycopy) - s = length(t) - dl = ones(eltype(t), s - 1) - d_tmp = ones(eltype(t), s) - du = zeros(eltype(t), s - 1) - tA = Tridiagonal(dl, d_tmp, du) - - # zero for element type of d, which we don't know yet - typed_zero = zero(2 // 1 * (u[begin + 1] - u[begin]) / (t[begin + 1] - t[begin])) - - d = map(i -> i == 1 ? typed_zero : 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), 1:s) - z = tA \ d - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) -end - -function QuadraticSpline( - u::uType, t; extrapolate = false, safetycopy = true) where {uType <: AbstractVector} + u, t; extrapolate = false, + safetycopy = true) u, t = munge_data(u, t, safetycopy) - s = length(t) - dl = ones(eltype(t), s - 1) - d_tmp = ones(eltype(t), s) - du = zeros(eltype(t), s - 1) - tA = Tridiagonal(dl, d_tmp, du) - d_ = map( - i -> i == 1 ? zeros(eltype(t), size(u[1])) : - 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), - 1:s) - d = transpose(reshape(reduce(hcat, d_), :, s)) - z_ = reshape(transpose(tA \ d), size(u[1])..., :) - z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] - p = QuadraticSplineParameterCache(z, t) - A = QuadraticSpline(u, t, nothing, p, tA, d, z, extrapolate, safetycopy) - I = cumulative_integral(A) - QuadraticSpline(u, t, I, p, tA, d, z, extrapolate, safetycopy) + p = QuadraticSplineParameterCache(u, t) + A = QuadraticSpline(u, t, nothing, p, extrapolate, safetycopy) + I = cumulative_integral(A, A.u) + QuadraticSpline(u, t, I, p, extrapolate, safetycopy) end """ @@ -391,7 +353,7 @@ function CubicSpline(u::uType, z = tA \ d p = CubicSplineParameterCache(u, h, z) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) end @@ -413,7 +375,7 @@ function CubicSpline( z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] p = CubicSplineParameterCache(u, h, z) A = CubicSpline(u, t, nothing, p, h[1:(n + 1)], z, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicSpline(u, t, I, p, h[1:(n + 1)], z, extrapolate, safetycopy) end @@ -735,7 +697,7 @@ function CubicHermiteSpline(du, u, t; extrapolate = false, safetycopy = true) u, t = munge_data(u, t, safetycopy) p = CubicHermiteParameterCache(du, u, t) A = CubicHermiteSpline(du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) CubicHermiteSpline(du, u, t, I, p, extrapolate, safetycopy) end @@ -779,6 +741,6 @@ function QuinticHermiteSpline(ddu, du, u, t; extrapolate = false, safetycopy = t u, t = munge_data(u, t, safetycopy) p = QuinticHermiteParameterCache(ddu, du, u, t) A = QuinticHermiteSpline(ddu, du, u, t, nothing, p, extrapolate, safetycopy) - I = cumulative_integral(A) + I = cumulative_integral(A, A.u) QuinticHermiteSpline(ddu, du, u, t, I, p, extrapolate, safetycopy) end diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 5d09ceff..dd2893a7 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -149,7 +149,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess) idx = get_idx(A.t, t, iguess) Cᵢ = A.u[idx] Δt = t - A.t[idx] - return A.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx + return A.p.z[idx] * Δt + A.p.σ[idx] * Δt^2 + Cᵢ, idx end # CubicSpline Interpolation diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 466248b1..ee6db3c9 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -121,12 +121,13 @@ function get_idx(tvec, t, iguess; lb = 1, ub_shift = -1, idx_shift = 0, side = : end end -function cumulative_integral(A) - if isempty(methods(_integral, (typeof(A), Any, Any))) - return nothing - end - integral_values = [_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) - for idx in 1:(length(A.t) - 1)] - pushfirst!(integral_values, zero(first(integral_values))) +function cumulative_integral(A, ::AbstractVector{<:Number}) + integral_prototype = _integral(A, 1, A.t[2]) + + integral_values = [zero(integral_prototype), + (_integral(A, idx, A.t[idx + 1]) - _integral(A, idx, A.t[idx]) + for idx in 1:(length(A.t) - 1))...] return cumsum(integral_values) end + +cumulative_integral(A, ::AbstractArray) = nothing diff --git a/src/online.jl b/src/online.jl index dc500611..a1ca710e 100644 --- a/src/online.jl +++ b/src/online.jl @@ -3,7 +3,7 @@ import Base: append!, push! function push!(A::LinearInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} push!(A.u.parent, u) push!(A.t.parent, t) - slope = linear_interpolation_parameters(A.u, A.t, length(A.t) - 1) + slope = interpolation_parameters(Val(:LinearInterpolation), A.u, A.t, length(A.t) - 1) push!(A.p.slope, slope) A end @@ -11,7 +11,8 @@ end function push!(A::QuadraticInterpolation{U, T}, u::eltype(U), t::eltype(T)) where {U, T} push!(A.u.parent, u) push!(A.t.parent, t) - l₀, l₁, l₂ = quadratic_interpolation_parameters(A.u, A.t, length(A.t) - 2) + l₀, l₁, l₂ = interpolation_parameters( + Val(:QuadraticInterpolation), A.u, A.t, length(A.t) - 2) push!(A.p.l₀, l₀) push!(A.p.l₁, l₁) push!(A.p.l₂, l₂) @@ -31,7 +32,7 @@ function append!( u, t = munge_data(u, t, true) append!(A.u.parent, u) append!(A.t.parent, t) - slope = linear_interpolation_parameters.( + slope = interpolation_parameters.(Val(:LinearInterpolation), Ref(A.u), Ref(A.t), length_old:(length(A.t) - 1)) append!(A.p.slope, slope) A @@ -53,7 +54,7 @@ function append!( u, t = munge_data(u, t, true) append!(A.u.parent, u) append!(A.t.parent, t) - parameters = quadratic_interpolation_parameters.( + parameters = interpolation_parameters.(Val(:QuadraticInterpolation), Ref(A.u), Ref(A.t), (length_old - 1):(length(A.t) - 2)) l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) append!(A.p.l₀, l₀) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..3648120a 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -2,12 +2,14 @@ struct LinearParameterCache{pType} slope::pType end -function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) +function LinearParameterCache(u, t)::LinearParameterCache + idxs = 1:(length(t) - 1) + slope = [interpolation_parameters(Val(:LinearInterpolation), u, t, idx) for idx in idxs] return LinearParameterCache(slope) end -function linear_interpolation_parameters(u, t, idx) +function interpolation_parameters(::Val{:LinearInterpolation}, + u::AbstractArray, t::AbstractVector, idx::Integer) Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt @@ -22,13 +24,13 @@ struct QuadraticParameterCache{pType} end function QuadraticParameterCache(u, t) - parameters = quadratic_interpolation_parameters.( + parameters = interpolation_parameters.(Val(:QuadraticInterpolation), Ref(u), Ref(t), 1:(length(t) - 2)) l₀, l₁, l₂ = collect.(eachrow(hcat(collect.(parameters)...))) return QuadraticParameterCache(l₀, l₁, l₂) end -function quadratic_interpolation_parameters(u, t, idx) +function interpolation_parameters(::Val{:QuadraticInterpolation}, u, t, idx) if u isa AbstractMatrix u₀ = u[:, idx] u₁ = u[:, idx + 1] @@ -50,16 +52,45 @@ function quadratic_interpolation_parameters(u, t, idx) return l₀, l₁, l₂ end -struct QuadraticSplineParameterCache{pType} - σ::pType -end - -function QuadraticSplineParameterCache(z, t) - σ = quadratic_spline_parameters.(Ref(z), Ref(t), 1:(length(t) - 1)) - return QuadraticSplineParameterCache(σ) -end - -function quadratic_spline_parameters(z, t, idx) +struct QuadraticSplineParameterCache{tAType, dType, zType, σType} + tA::tAType + d::dType + z::zType + σ::σType +end + +function QuadraticSplineParameterCache(u::AbstractVector{<:Number}, t) + s = length(t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + + d = [2 // 1 * (u[i] - u[max(1, i - 1)]) / (t[i] - t[1 + abs(i - 2)]) + for i in eachindex(t)] + z = tA \ d + σ = interpolation_parameters.(Val(:QuadraticSpline), Ref(z), Ref(t), 1:(length(t) - 1)) + return QuadraticSplineParameterCache(tA, d, z, σ) +end + +function QuadraticSplineParameterCache(u::AbstractVector{<:AbstractArray{<:Number}}, t) + s = length(t) + dl = ones(eltype(t), s - 1) + d_tmp = ones(eltype(t), s) + du = zeros(eltype(t), s - 1) + tA = Tridiagonal(dl, d_tmp, du) + d_ = map( + i -> i == 1 ? zeros(eltype(t), size(u[1])) : + 2 // 1 * (u[i] - u[i - 1]) / (t[i] - t[i - 1]), + 1:s) + d = transpose(reshape(reduce(hcat, d_), :, s)) + z_ = reshape(transpose(tA \ d), size(u[1])..., :) + z = [z_s for z_s in eachslice(z_, dims = ndims(z_))] + σ = interpolation_parameters.(Val(:QuadraticSpline), Ref(z), Ref(t), 1:(length(t) - 1)) + return QuadraticSplineParameterCache(tA, d, z, σ) +end + +function interpolation_parameters(::Val{:QuadraticSpline}, z, t, idx) σ = 1 // 2 * (z[idx + 1] - z[idx]) / (t[idx + 1] - t[idx]) return σ end @@ -70,13 +101,13 @@ struct CubicSplineParameterCache{pType} end function CubicSplineParameterCache(u, h, z) - parameters = cubic_spline_parameters.( + parameters = interpolation_parameters.(Val(:CubicSpline), Ref(u), Ref(h), Ref(z), 1:(size(u)[end] - 1)) c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) return CubicSplineParameterCache(c₁, c₂) end -function cubic_spline_parameters(u, h, z, idx) +function interpolation_parameters(::Val{:CubicSpline}, u, h, z, idx) c₁ = (u[idx + 1] / h[idx + 1] - z[idx + 1] * h[idx + 1] / 6) c₂ = (u[idx] / h[idx + 1] - z[idx] * h[idx + 1] / 6) return c₁, c₂ @@ -88,13 +119,13 @@ struct CubicHermiteParameterCache{pType} end function CubicHermiteParameterCache(du, u, t) - parameters = cubic_hermite_spline_parameters.( + parameters = interpolation_parameters.(Val(:CubicHermiteSpline), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) c₁, c₂ = collect.(eachrow(hcat(collect.(parameters)...))) return CubicHermiteParameterCache(c₁, c₂) end -function cubic_hermite_spline_parameters(du, u, t, idx) +function interpolation_parameters(::Val{:CubicHermiteSpline}, du, u, t, idx) Δt = t[idx + 1] - t[idx] u₀ = u[idx] u₁ = u[idx + 1] @@ -112,13 +143,13 @@ struct QuinticHermiteParameterCache{pType} end function QuinticHermiteParameterCache(ddu, du, u, t) - parameters = quintic_hermite_spline_parameters.( + parameters = interpolation_parameters.(Val(:QuinticHermiteSpline), Ref(ddu), Ref(du), Ref(u), Ref(t), 1:(length(t) - 1)) c₁, c₂, c₃ = collect.(eachrow(hcat(collect.(parameters)...))) return QuinticHermiteParameterCache(c₁, c₂, c₃) end -function quintic_hermite_spline_parameters(ddu, du, u, t, idx) +function interpolation_parameters(::Val{:QuinticHermiteSpline}, ddu, du, u, t, idx) Δt = t[idx + 1] - t[idx] u₀ = u[idx] u₁ = u[idx + 1]