Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add stratified options to splitobs #195

Merged
merged 4 commits into from
Feb 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions src/obstransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,7 @@ accomplish that, which means that the return value is likely of a
different type than `data`.

Optionally, a random number generator `rng` can be passed as the
first argument.

The optional parameter `rng` allows one to specify the
random number generator used for shuffling. This is useful when
reproducible results are desired.
first argument.

For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref).
Expand Down
56 changes: 42 additions & 14 deletions src/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
splitobs(n::Int; at) -> Tuple

Compute the indices for two or more disjoint subsets of
the range `1:n` with splits given by `at`.
the range `1:n` with split sizes determined by `at`.

# Examples

Expand All @@ -18,13 +18,12 @@ splitobs(n::Int; at) = _splitobs(n, at)

_splitobs(n::Int, at::Integer) = _splitobs(n::Int, at / n)
_splitobs(n::Int, at::NTuple{N, <:Integer}) where {N} = _splitobs(n::Int, at ./ n)

_splitobs(n::Int, at::Tuple{}) = (1:n,)

function _splitobs(n::Int, at::AbstractFloat)
0 <= at <= 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
n1 = clamp(round(Int, at*n), 0, n)
(1:n1, n1+1:n)
n1 = round(Int, n * at)
return (1:n1, n1+1:n)
end

function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
Expand All @@ -37,22 +36,24 @@ function _splitobs(n::Int, at::NTuple{N,<:AbstractFloat}) where N
return (a, rest...)
end


"""
splitobs([rng], data; at, shuffle=false) -> Tuple
splitobs([rng,] data; at, shuffle=false, stratified=nothing) -> Tuple

Partition the `data` into two or more subsets.

When `at` is a number between 0 and 1, this specifies the proportion in the first subset.

When `at` is an integer, it specifies the number of observations in the first subset.

When `at` is a tuple, entries specifies the number or proportion in each subset, except
The argument `at` specifies how to split the data:
- When `at` is a number between 0 and 1, this specifies the proportion in the first subset.
- When `at` is an integer, it specifies the number of observations in the first subset.
- When `at` is a tuple, entries specifies the number or proportion in each subset, except
for the last which will contain the remaning observations.
The number of returned subsets is `length(at)+1`.

If `shuffle=true`, randomly permute the observations before splitting.
A random number generator `rng` can be optionally passed as the first argument.

If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
The observations will be split in such a way that the proportion of each label is preserved in each subset.

Supports any datatype implementing [`numobs`](@ref).

Expand All @@ -74,14 +75,41 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff

julia> vec(test[1]) .+ 100 == test[2]
true

julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
```
"""
splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)

function splitobs(rng::AbstractRNG, data; at, shuffle::Bool=false)
function splitobs(rng::AbstractRNG, data; at,
shuffle::Bool=false,
stratified::Union{Nothing,AbstractVector}=nothing)
n = numobs(data)
at = _normalize_at(n, at)
if shuffle
data = shuffleobs(rng, data)
perm = randperm(rng, n)
data = obsview(data, perm) # same as shuffleobs(rng, data), but make it explicit to keep perm
end
n = numobs(data)
return map(idx -> obsview(data, idx), splitobs(n; at))
if stratified !== nothing
@assert length(stratified) == n
if shuffle
stratified = stratified[perm]
end
idxs_groups = group_indices(stratified)
idxs_splits = ntuple(i -> Int[], length(at)+1)
for (lbl, idxs) in idxs_groups
new_idxs_splits = splitobs(idxs; at, shuffle=false)
for i in 1:length(idxs_splits)
append!(idxs_splits[i], new_idxs_splits[i])
end
end
else
idxs_splits = splitobs(n; at)
end
return map(idxs -> obsview(data, idxs), idxs_splits)
end

_normalize_at(n, at::Integer) = at / n
_normalize_at(n, at::NTuple{N, <:Integer}) where N = at ./ n
_normalize_at(n, at) = at
4 changes: 2 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,15 +380,15 @@ function batch(xs::Vector{<:NamedTuple})
all_keys = [sort(collect(keys(x))) for x in xs]
ks = all_keys[1]
@assert all(==(ks), all_keys) "Cannot batch named tuples with different keys"
NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
return NamedTuple(k => batch([x[k] for x in xs]) for k in ks)
end

function batch(xs::Vector{<:Dict})
@assert length(xs) > 0 "Input should be non-empty"
all_keys = [sort(collect(keys(x))) for x in xs]
ks = all_keys[1]
@assert all(==(ks), all_keys) "cannot batch dicts with different keys"
Dict(k => batch([x[k] for x in xs]) for k in ks)
return Dict(k => batch([x[k] for x in xs]) for k in ks)
end

"""
Expand Down
17 changes: 17 additions & 0 deletions test/splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,20 @@ end
p2, _ = splitobs(rng, data, at=3, shuffle=true)
@test p1 == p2
end

@testset "stratified" begin
data = (a=zeros(Float32, 2, 10), b=[0,0,0,0,1,1,1,1,1,1])
d1, d2 = splitobs(data, at=0.5, stratified=data.b)
@test d1.b == [0,0,1,1,1]
@test d2.b == [0,0,1,1,1]
d1, d2 = splitobs(data, at=0.25, stratified=data.b)
@test d1.b == [0,1,1]
@test d2.b == [0,0,0,1,1,1,1]

d1, d2 = splitobs(data, at=0., stratified=data.b)
@test d1.b == []
@test d2.b == [0,0,0,0,1,1,1,1,1,1]
d1, d2 = splitobs(data, at=1., stratified=data.b)
@test d1.b == [0,0,0,0,1,1,1,1,1,1]
@test d2.b == []
end
Loading