Skip to content

Commit bae3edc

Browse files
refactor: interpolation with higher dim arrays
1 parent caf0b76 commit bae3edc

File tree

3 files changed

+142
-42
lines changed

3 files changed

+142
-42
lines changed

src/interpolation_caches.jl

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -173,29 +173,24 @@ Extrapolation extends the last cubic polynomial on each side.
173173
for a test based on the normalized standard deviation of the difference with respect
174174
to the straight line (see [`looks_linear`](@ref)). Defaults to 1e-2.
175175
"""
176-
struct AkimaInterpolation{uType, tType, IType, bType, cType, dType, T, N} <:
176+
struct AkimaInterpolation{uType, tType, IType, pType, T, N} <:
177177
AbstractInterpolation{T, N}
178178
u::uType
179179
t::tType
180180
I::IType
181-
b::bType
182-
c::cType
183-
d::dType
181+
p::pType
184182
extrapolate::Bool
185183
iguesser::Guesser{tType}
186184
cache_parameters::Bool
187185
linear_lookup::Bool
188186
function AkimaInterpolation(
189-
u, t, I, b, c, d, extrapolate, cache_parameters, assume_linear_t)
187+
u, t, I, p, extrapolate, cache_parameters, assume_linear_t)
190188
linear_lookup = seems_linear(assume_linear_t, t)
191189
N = get_output_dim(u)
192-
new{typeof(u), typeof(t), typeof(I), typeof(b), typeof(c),
193-
typeof(d), eltype(u), N}(u,
190+
new{typeof(u), typeof(t), typeof(I), typeof(p), eltype(u), N}(u,
194191
t,
195192
I,
196-
b,
197-
c,
198-
d,
193+
p,
199194
extrapolate,
200195
Guesser(t),
201196
cache_parameters,
@@ -208,30 +203,11 @@ function AkimaInterpolation(
208203
u, t; extrapolate = false, cache_parameters = false, assume_linear_t = 1e-2)
209204
u, t = munge_data(u, t)
210205
linear_lookup = seems_linear(assume_linear_t, t)
211-
n = length(t)
212-
dt = diff(t)
213-
m = Array{eltype(u)}(undef, n + 3)
214-
m[3:(end - 2)] = diff(u) ./ dt
215-
m[2] = 2m[3] - m[4]
216-
m[1] = 2m[2] - m[3]
217-
m[end - 1] = 2m[end - 2] - m[end - 3]
218-
m[end] = 2m[end - 1] - m[end - 2]
219-
220-
b = 0.5 .* (m[4:end] .+ m[1:(end - 3)])
221-
dm = abs.(diff(m))
222-
f1 = dm[3:(n + 2)]
223-
f2 = dm[1:n]
224-
f12 = f1 + f2
225-
ind = findall(f12 .> 1e-9 * maximum(f12))
226-
b[ind] = (f1[ind] .* m[ind .+ 1] .+
227-
f2[ind] .* m[ind .+ 2]) ./ f12[ind]
228-
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
229-
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2
230-
206+
p = AkimaParameterCache(u, t)
231207
A = AkimaInterpolation(
232-
u, t, nothing, b, c, d, extrapolate, cache_parameters, linear_lookup)
208+
u, t, nothing, p, extrapolate, cache_parameters, linear_lookup)
233209
I = cumulative_integral(A, cache_parameters)
234-
AkimaInterpolation(u, t, I, b, c, d, extrapolate, cache_parameters, linear_lookup)
210+
AkimaInterpolation(u, t, I, p, extrapolate, cache_parameters, linear_lookup)
235211
end
236212

237213
"""

src/interpolation_methods.jl

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,14 +88,14 @@ function _interpolate(A::LagrangeInterpolation{<:AbstractVector}, t::Number, igu
8888
end
8989

9090
function _interpolate(
91-
A::LagrangeInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
91+
A::LagrangeInterpolation{<:AbstractArray}, t::Number, iguess)
9292
idx = get_idx(A, t, iguess)
9393
findRequiredIdxs!(A, t, idx)
9494
ax = axes(A.u)[1:(end - 1)]
9595
if A.t[A.idxs[1]] == t
9696
return A.u[ax..., A.idxs[1]]
9797
end
98-
N1 = zero(A.u[ax..., 1])
98+
N = zero(A.u[ax..., 1])
9999
D = zero(A.t[1])
100100
tmp = D
101101
for i in 1:length(A.idxs)
@@ -113,15 +113,22 @@ function _interpolate(
113113
end
114114
tmp = inv((t - A.t[A.idxs[i]]) * mult)
115115
D += tmp
116-
@. N1 += (tmp * A.u[ax..., A.idxs[i]])
116+
@. N += (tmp * A.u[ax..., A.idxs[i]])
117117
end
118-
N1 / D
118+
N / D
119119
end
120120

121121
function _interpolate(A::AkimaInterpolation{<:AbstractVector}, t::Number, iguess)
122122
idx = get_idx(A, t, iguess)
123123
wj = t - A.t[idx]
124-
@evalpoly wj A.u[idx] A.b[idx] A.c[idx] A.d[idx]
124+
@evalpoly wj A.u[idx] A.p.b[idx] A.p.c[idx] A.p.d[idx]
125+
end
126+
127+
function _interpolate(A::AkimaInterpolation{<:AbstractArray}, t::Number, iguess)
128+
idx = get_idx(A, t, iguess)
129+
wj = t - A.t[idx]
130+
ax = axes(A.u)[1:(end - 1)]
131+
@. @evalpoly wj A.u[ax..., idx] A.p.b[ax..., idx] A.p.c[ax..., idx] A.p.d[ax..., idx]
125132
end
126133

127134
# ConstantInterpolation Interpolation
@@ -137,7 +144,7 @@ function _interpolate(A::ConstantInterpolation{<:AbstractVector}, t::Number, igu
137144
end
138145

139146
function _interpolate(
140-
A::ConstantInterpolation{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
147+
A::ConstantInterpolation{<:AbstractArray}, t::Number, iguess)
141148
if A.dir === :left
142149
# :left means that value to the left is used for interpolation
143150
idx = get_idx(A, t, iguess; lb = 1, ub_shift = 0)
@@ -158,7 +165,7 @@ function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
158165
end
159166

160167
function _interpolate(
161-
A::QuadraticSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
168+
A::QuadraticSpline{<:AbstractArray}, t::Number, iguess)
162169
idx = get_idx(A, t, iguess)
163170
ax = axes(A.u)[1:(end - 1)]
164171
Cᵢ = A.u[ax..., idx]
@@ -179,7 +186,7 @@ function _interpolate(A::CubicSpline{<:AbstractVector}, t::Number, iguess)
179186
I + C + D
180187
end
181188

182-
function _interpolate(A::CubicSpline{<:AbstractArray{T, N}}, t::Number, iguess) where {T, N}
189+
function _interpolate(A::CubicSpline{<:AbstractArray}, t::Number, iguess)
183190
idx = get_idx(A, t, iguess)
184191
Δt₁ = t - A.t[idx]
185192
Δt₂ = A.t[idx + 1] - t
@@ -238,6 +245,18 @@ function _interpolate(
238245
out
239246
end
240247

248+
function _interpolate(
249+
A::CubicHermiteSpline{<:AbstractArray}, t::Number, iguess)
250+
idx = get_idx(A, t, iguess)
251+
Δt₀ = t - A.t[idx]
252+
Δt₁ = t - A.t[idx + 1]
253+
ax = axes(A.u)[1:(end - 1)]
254+
out = A.u[ax..., idx] .+ Δt₀ .* A.du[ax..., idx]
255+
c₁, c₂ = get_parameters(A, idx)
256+
out .+= Δt₀^2 .* (c₁ .+ Δt₁ .* c₂)
257+
out
258+
end
259+
241260
# Quintic Hermite Spline
242261
function _interpolate(
243262
A::QuinticHermiteSpline{<:AbstractVector{<:Number}}, t::Number, iguess)
@@ -249,3 +268,15 @@ function _interpolate(
249268
out += Δt₀^3 * (c₁ + Δt₁ * (c₂ + c₃ * Δt₁))
250269
out
251270
end
271+
272+
function _interpolate(
273+
A::QuinticHermiteSpline{<:AbstractArray}, t::Number, iguess)
274+
idx = get_idx(A, t, iguess)
275+
Δt₀ = t - A.t[idx]
276+
Δt₁ = t - A.t[idx + 1]
277+
ax = axes(A.u)[1:(end - 1)]
278+
out = A.u[ax..., idx] + Δt₀ * (A.du[ax..., idx] + A.ddu[ax..., idx] * Δt₀ / 2)
279+
c₁, c₂, c₃ = get_parameters(A, idx)
280+
out .+= Δt₀^3 .* (c₁ .+ Δt₁ .* (c₂ .+ c₃ .* Δt₁))
281+
out
282+
end

src/parameter_caches.jl

Lines changed: 95 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,72 @@ function quadratic_interpolation_parameters(u::AbstractArray{T, N}, t, idx) wher
7575
return l₀, l₁, l₂
7676
end
7777

78+
struct AkimaParameterCache{pType}
79+
b::pType
80+
c::pType
81+
d::pType
82+
end
83+
84+
function AkimaParameterCache(u, t)
85+
b, c, d = akima_interpolation_parameters(u, t)
86+
AkimaParameterCache(b, c, d)
87+
end
88+
89+
function akima_interpolation_parameters(u::AbstractVector, t)
90+
n = length(t)
91+
dt = diff(t)
92+
m = Array{eltype(u)}(undef, n + 3)
93+
m[3:(end - 2)] = diff(u) ./ dt
94+
m[2] = 2m[3] - m[4]
95+
m[1] = 2m[2] - m[3]
96+
m[end - 1] = 2m[end - 2] - m[end - 3]
97+
m[end] = 2m[end - 1] - m[end - 2]
98+
b = 0.5 .* (m[4:end] .+ m[1:(end - 3)])
99+
dm = abs.(diff(m))
100+
f1 = dm[3:(n + 2)]
101+
f2 = dm[1:n]
102+
f12 = f1 + f2
103+
ind = findall(f12 .> 1e-9 * maximum(f12))
104+
b[ind] = (f1[ind] .* m[ind .+ 1] .+
105+
f2[ind] .* m[ind .+ 2]) ./ f12[ind]
106+
c = (3.0 .* m[3:(end - 2)] .- 2.0 .* b[1:(end - 1)] .- b[2:end]) ./ dt
107+
d = (b[1:(end - 1)] .+ b[2:end] .- 2.0 .* m[3:(end - 2)]) ./ dt .^ 2
108+
return b, c, d
109+
end
110+
111+
function akima_interpolation_parameters(u::AbstractArray, t)
112+
n = length(t)
113+
dt = diff(t)
114+
ax = axes(u)[1:(end - 1)]
115+
su = size(u)
116+
m = zeros(eltype(u), su[1:(end - 1)]..., n + 3)
117+
m[ax..., 3:(end - 2)] .= mapslices(
118+
x -> x ./ dt, diff(u, dims = length(su)); dims = length(su))
119+
m[ax..., 2] .= 2m[ax..., 3] .- m[ax..., 4]
120+
m[ax..., 1] .= 2m[ax..., 2] .- m[3]
121+
m[ax..., end - 1] .= 2m[ax..., end - 2] - m[ax..., end - 3]
122+
m[ax..., end] .= 2m[ax..., end - 1] .- m[ax..., end - 2]
123+
b = 0.5 .* (m[ax..., 4:end] .+ m[ax..., 1:(end - 3)])
124+
dm = abs.(diff(m, dims = length(su)))
125+
f1 = dm[ax..., 3:(n + 2)]
126+
f2 = dm[ax..., 1:n]
127+
f12 = f1 .+ f2
128+
ind = findall(f12 .> 1e-9 * maximum(f12))
129+
indi = map(i -> i.I, ind)
130+
b[ind] .= (f1[ind] .*
131+
m[CartesianIndex.(map(i -> (i[1:(end - 1)]..., i[end] + 1), indi))] .+
132+
f2[ind] .*
133+
m[CartesianIndex.(map(i -> (i[1:(end - 1)]..., i[end] + 2), indi))]) ./
134+
f12[ind]
135+
c = mapslices(x -> x ./ dt,
136+
(3.0 .* m[ax..., 3:(end - 2)] .- 2.0 .* b[ax..., 1:(end - 1)] .- b[ax..., 2:end]);
137+
dims = length(su))
138+
d = mapslices(x -> x ./ dt .^ 2,
139+
(b[ax..., 1:(end - 1)] .+ b[ax..., 2:end] .- 2.0 .* m[ax..., 3:(end - 2)]);
140+
dims = length(su))
141+
return b, c, d
142+
end
143+
78144
struct QuadraticSplineParameterCache{pType}
79145
σ::pType
80146
end
@@ -152,7 +218,19 @@ function CubicHermiteParameterCache(du, u, t, cache_parameters)
152218
end
153219
end
154220

155-
function cubic_hermite_spline_parameters(du, u, t, idx)
221+
function cubic_hermite_spline_parameters(du::AbstractArray, u, t, idx)
222+
ax = axes(u)[1:(end - 1)]
223+
Δt = t[idx + 1] - t[idx]
224+
u₀ = u[ax..., idx]
225+
u₁ = u[ax..., idx + 1]
226+
du₀ = du[ax..., idx]
227+
du₁ = du[ax..., idx + 1]
228+
c₁ = (u₁ - u₀ - du₀ * Δt) / Δt^2
229+
c₂ = (du₁ - du₀ - 2c₁ * Δt) / Δt^2
230+
return c₁, c₂
231+
end
232+
233+
function cubic_hermite_spline_parameters(du::AbstractVector, u, t, idx)
156234
Δt = t[idx + 1] - t[idx]
157235
u₀ = u[idx]
158236
u₁ = u[idx + 1]
@@ -183,7 +261,7 @@ function QuinticHermiteParameterCache(ddu, du, u, t, cache_parameters)
183261
end
184262
end
185263

186-
function quintic_hermite_spline_parameters(ddu, du, u, t, idx)
264+
function quintic_hermite_spline_parameters(ddu::AbstractVector, du, u, t, idx)
187265
Δt = t[idx + 1] - t[idx]
188266
u₀ = u[idx]
189267
u₁ = u[idx + 1]
@@ -196,3 +274,18 @@ function quintic_hermite_spline_parameters(ddu, du, u, t, idx)
196274
c₃ = (6u₁ - 6u₀ - 3(du₀ + du₁)Δt + (ddu₁ - ddu₀)Δt^2 / 2) / Δt^5
197275
return c₁, c₂, c₃
198276
end
277+
278+
function quintic_hermite_spline_parameters(ddu::AbstractArray, du, u, t, idx)
279+
ax = axes(ddu)[1:(end - 1)]
280+
Δt = t[idx + 1] - t[idx]
281+
u₀ = u[ax..., idx]
282+
u₁ = u[ax..., idx + 1]
283+
du₀ = du[ax..., idx]
284+
du₁ = du[ax..., idx + 1]
285+
ddu₀ = ddu[ax..., idx]
286+
ddu₁ = ddu[ax..., idx + 1]
287+
c₁ = (u₁ - u₀ - du₀ * Δt - ddu₀ * Δt^2 / 2) / Δt^3
288+
c₂ = (3u₀ - 3u₁ + 2(du₀ + du₁ / 2)Δt + ddu₀ * Δt^2 / 2) / Δt^4
289+
c₃ = (6u₁ - 6u₀ - 3(du₀ + du₁)Δt + (ddu₁ - ddu₀)Δt^2 / 2) / Δt^5
290+
return c₁, c₂, c₃
291+
end

0 commit comments

Comments
 (0)