diff --git a/src/obstransform.jl b/src/obstransform.jl index e5c6168..e2aefbb 100644 --- a/src/obstransform.jl +++ b/src/obstransform.jl @@ -208,11 +208,7 @@ accomplish that, which means that the return value is likely of a different type than `data`. Optionally, a random number generator `rng` can be passed as the -first argument. - -The optional parameter `rng` allows one to specify the -random number generator used for shuffling. This is useful when -reproducible results are desired. +first argument. For this function to work, the type of `data` must implement [`numobs`](@ref) and [`getobs`](@ref). diff --git a/src/splitobs.jl b/src/splitobs.jl index 7061d43..9ea6e15 100644 --- a/src/splitobs.jl +++ b/src/splitobs.jl @@ -2,7 +2,7 @@ splitobs(n::Int; at) -> Tuple Compute the indices for two or more disjoint subsets of -the range `1:n` with splits given by `at`. +the range `1:n` with split sizes determined by `at`. # Examples @@ -18,13 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at) _splitobs(n::Int, at::Integer) = _splitobs(n::Int, at / n) _splitobs(n::Int, at::NTuple{N, <:Integer}) where {N} = _splitobs(n::Int, at ./ n) - _splitobs(n::Int, at::Tuple{}) = (1:n,) function _splitobs(n::Int, at::AbstractFloat) 0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)")) - n1 = clamp(round(Int, at*n), 0, n) - (1:n1, n1+1:n) + n1 = round(Int, n * at) + return (1:n1, n1+1:n) end function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N @@ -37,22 +36,24 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N return (a, rest...) end + """ - splitobs([rng], data; at, shuffle=false) -> Tuple + splitobs([rng,] data; at, shuffle=false, stratified=nothing) -> Tuple Partition the `data` into two or more subsets. -When `at` is a number between 0 and 1, this specifies the proportion in the first subset. - -When `at` is an integer, it specifies the number of observations in the first subset. - -When `at` is a tuple, entries specifies the number or proportion in each subset, except +The argument `at` specifies how to split the data: +- When `at` is a number between 0 and 1, this specifies the proportion in the first subset. +- When `at` is an integer, it specifies the number of observations in the first subset. +- When `at` is a tuple, entries specifies the number or proportion in each subset, except for the last which will contain the remaning observations. The number of returned subsets is `length(at)+1`. If `shuffle=true`, randomly permute the observations before splitting. A random number generator `rng` can be optionally passed as the first argument. +If `stratified` is not `nothing`, it should be an array of labels with the same length as the data. +The observations will be split in such a way that the proportion of each label is preserved in each subset. Supports any datatype implementing [`numobs`](@ref). @@ -74,14 +75,41 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff julia> vec(test[1]) .+ 100 == test[2] true + +julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset +([1, 2, 5, 6, 7], [3, 4, 8, 9, 10]) ``` """ splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...) -function splitobs(rng::AbstractRNG, data; at, shuffle::Bool=false) +function splitobs(rng::AbstractRNG, data; at, + shuffle::Bool=false, + stratified::Union{Nothing,AbstractVector}=nothing) + n = numobs(data) + at = _normalize_at(n, at) if shuffle - data = shuffleobs(rng, data) + perm = randperm(rng, n) + data = obsview(data, perm) # same as shuffleobs(rng, data), but make it explicit to keep perm end - n = numobs(data) - return map(idx -> obsview(data, idx), splitobs(n; at)) + if stratified !== nothing + @assert length(stratified) == n + if shuffle + stratified = stratified[perm] + end + idxs_groups = group_indices(stratified) + idxs_splits = ntuple(i -> Int[], length(at)+1) + for (lbl, idxs) in idxs_groups + new_idxs_splits = splitobs(idxs; at, shuffle=false) + for i in 1:length(idxs_splits) + append!(idxs_splits[i], new_idxs_splits[i]) + end + end + else + idxs_splits = splitobs(n; at) + end + return map(idxs -> obsview(data, idxs), idxs_splits) end + +_normalize_at(n, at::Integer) = at / n +_normalize_at(n, at::NTuple{N, <:Integer}) where N = at ./ n +_normalize_at(n, at) = at \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index b928135..82ca407 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -380,7 +380,7 @@ function batch(xs::Vector{<:NamedTuple}) all_keys = [sort(collect(keys(x))) for x in xs] ks = all_keys[1] @assert all(==(ks), all_keys) "Cannot batch named tuples with different keys" - NamedTuple(k => batch([x[k] for x in xs]) for k in ks) + return NamedTuple(k => batch([x[k] for x in xs]) for k in ks) end function batch(xs::Vector{<:Dict}) @@ -388,7 +388,7 @@ function batch(xs::Vector{<:Dict}) all_keys = [sort(collect(keys(x))) for x in xs] ks = all_keys[1] @assert all(==(ks), all_keys) "cannot batch dicts with different keys" - Dict(k => batch([x[k] for x in xs]) for k in ks) + return Dict(k => batch([x[k] for x in xs]) for k in ks) end """ diff --git a/test/splitobs.jl b/test/splitobs.jl index f5ce335..04a9e62 100644 --- a/test/splitobs.jl +++ b/test/splitobs.jl @@ -90,3 +90,20 @@ end p2, _ = splitobs(rng, data, at=3, shuffle=true) @test p1 == p2 end + +@testset "stratified" begin + data = (a=zeros(Float32, 2, 10), b=[0,0,0,0,1,1,1,1,1,1]) + d1, d2 = splitobs(data, at=0.5, stratified=data.b) + @test d1.b == [0,0,1,1,1] + @test d2.b == [0,0,1,1,1] + d1, d2 = splitobs(data, at=0.25, stratified=data.b) + @test d1.b == [0,1,1] + @test d2.b == [0,0,0,1,1,1,1] + + d1, d2 = splitobs(data, at=0., stratified=data.b) + @test d1.b == [] + @test d2.b == [0,0,0,0,1,1,1,1,1,1] + d1, d2 = splitobs(data, at=1., stratified=data.b) + @test d1.b == [0,0,0,0,1,1,1,1,1,1] + @test d2.b == [] +end