Skip to content

Commit

Permalink
Fix type inference and performance problems of munge_data
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Feb 7, 2025
1 parent e409e9c commit 3abd7d7
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 34 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ EnumX = "1.0.4"
FindFirstFunctions = "1.3"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
JET = "0.9.17"
LinearAlgebra = "1.10"
Optim = "1.6"
PrettyTables = "2"
Expand All @@ -53,6 +54,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
Expand All @@ -64,4 +66,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BenchmarkTools", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"]
test = ["Aqua", "BenchmarkTools", "JET", "SafeTestsets", "ChainRulesCore", "Optim", "RegularizationTools", "Test", "StableRNGs", "FiniteDifferences", "QuadGK", "ForwardDiff", "Symbolics", "Unitful", "Zygote"]
62 changes: 29 additions & 33 deletions src/interpolation_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,51 +104,47 @@ function quadratic_spline_params(t::AbstractVector, sc::AbstractVector)
end

# helper function for data manipulation
function munge_data(u::AbstractVector{<:Real}, t::AbstractVector{<:Real})
return u, t
end

function munge_data(u::AbstractVector, t::AbstractVector)
Tu = Base.nonmissingtype(eltype(u))
Tt = Base.nonmissingtype(eltype(t))
@assert length(t) == length(u)
non_missing_indices = collect(
i for i in 1:length(t)
if !ismissing(u[i]) && !ismissing(t[i])
)
Tu = nonmissingtype(eltype(u))
Tt = nonmissingtype(eltype(t))
if Tu === eltype(u) && Tt === eltype(t)
return u, t
end

u = Tu.([u[i] for i in non_missing_indices])
t = Tt.([t[i] for i in non_missing_indices])
@assert length(t) == length(u)
non_missing_mask = map((ui, ti) -> !ismissing(ui) && !ismissing(ti), u, t)
u = convert(AbstractVector{Tu}, u[non_missing_mask])
t = convert(AbstractVector{Tt}, t[non_missing_mask])

return u, t
end

function munge_data(U::StridedMatrix, t::AbstractVector)
TU = Base.nonmissingtype(eltype(U))
Tt = Base.nonmissingtype(eltype(t))
@assert length(t) == size(U, 2)
non_missing_indices = collect(
i for i in 1:length(t)
if !any(ismissing, U[:, i]) && !ismissing(t[i])
)
function munge_data(U::AbstractMatrix, t::AbstractVector)
TU = nonmissingtype(eltype(U))
Tt = nonmissingtype(eltype(t))
if TU === eltype(U) && Tt === eltype(t)
return U, t
end

U = hcat([TU.(U[:, i]) for i in non_missing_indices]...)
t = Tt.([t[i] for i in non_missing_indices])
@assert length(t) == size(U, 2)
non_missing_mask = map((uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachcol(U), t)
U = convert(AbstractMatrix{TU}, U[:, non_missing_mask])
t = convert(AbstractVector{Tt}, t[non_missing_mask])

return U, t
end

function munge_data(U::AbstractArray{T, N}, t) where {T, N}
TU = Base.nonmissingtype(eltype(U))
Tt = Base.nonmissingtype(eltype(t))
@assert length(t) == size(U, ndims(U))
ax = axes(U)[1:(end - 1)]
non_missing_indices = collect(
i for i in 1:length(t)
if !any(ismissing, U[ax..., i]) && !ismissing(t[i])
)
U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U))
t = Tt.([t[i] for i in non_missing_indices])
TU = nonmissingtype(eltype(U))
Tt = nonmissingtype(eltype(t))
if TU === eltype(U) && Tt === eltype(t)
return U, t
end

@assert length(t) == size(U, N)
non_missing_mask = map((uis, ti) -> !any(ismissing, uis) && !ismissing(ti), eachslice(U; dims=N), t)
U = convert(AbstractArray{TU,N}, copy(selectdim(U, N, non_missing_mask)))
t = convert(AbstractVector{Tt}, t[non_missing_mask])

return U, t
end
Expand Down
20 changes: 20 additions & 0 deletions test/interpolation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using StableRNGs
using Optim, ForwardDiff
using BenchmarkTools
using Unitful
using JET

function test_interpolation_type(T)
@test T <: DataInterpolations.AbstractInterpolation
Expand Down Expand Up @@ -920,3 +921,22 @@ f_cubic_spline = c -> square(CubicSpline, c)
# Floating-point derivative values are compared with `≈`; `@test` needs a single
# boolean expression, not a call followed by a bare literal.
@test ForwardDiff.derivative(f_quadratic_spline, 4.0) ≈ 8.0
@test ForwardDiff.derivative(f_cubic_spline, 2.0) ≈ 4.0
@test ForwardDiff.derivative(f_cubic_spline, 4.0) ≈ 8.0

# Checks `DataInterpolations.munge_data` for inferability, output types, and the
# pass-through behaviour when the element type cannot hold `missing`.
@testset "munge_data" begin
    t0 = [0.1, 0.2, 0.3]
    u0 = ["A", "B", "C"]

    for T in (String, Union{String,Missing}), dims in 1:3
        # Shape (1, ..., 1, 3): the three samples lie along the last of `dims` axes.
        _u0 = convert(Array{T}, reshape(u0, ntuple(i -> i == dims ? 3 : 1, dims)))

        # `@inferred` fails if the return type is not concretely inferred.
        u, t = @inferred(DataInterpolations.munge_data(_u0, t0))
        @test u isa Array{String,dims}
        @test t isa Vector{Float64}
        if T === String
            # With no `Missing` in the element type, the very same input objects
            # are expected back. Compare against the inputs (`_u0`, `t0`), not
            # against the returned values themselves.
            @test u === _u0
            @test t === t0
        end

        # Static analysis of the call via JET's `@test_call`.
        @test_call DataInterpolations.munge_data(_u0, t0)
    end
end

0 comments on commit 3abd7d7

Please sign in to comment.