Commit

statified
CarloLucibello committed Feb 2, 2025
1 parent 3aa4372 commit 7f4c1a4
Showing 2 changed files with 11 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/splitobs.jl
@@ -53,7 +53,7 @@ If `shuffle=true`, randomly permute the observations before splitting.
A random number generator `rng` can be optionally passed as the first argument.
If `stratified` is not `nothing`, it should be an array of labels with the same length as the data.
- The observations will be split in a way that the proportion of each label is preserved in each subset.
+ The observations will be split in such a way that the proportion of each label is preserved in each subset.
Supports any datatype implementing [`numobs`](@ref).
@@ -75,6 +75,9 @@ julia> train, test = splitobs((reshape(1.0:100.0, 1, :), 101:200), at=0.7, shuff
julia> vec(test[1]) .+ 100 == test[2]
true
+ julia> splitobs(1:10, at=0.5, stratified=[0,0,0,0,1,1,1,1,1,1]) # 2 zeros and 3 ones in each subset
+ ([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
```
"""
splitobs(data; kws...) = splitobs(Random.default_rng(), data; kws...)
7 changes: 7 additions & 0 deletions test/splitobs.jl
@@ -99,4 +99,11 @@ end
d1, d2 = splitobs(data, at=0.25, stratified=data.b)
@test d1.b == [0,1,1]
@test d2.b == [0,0,0,1,1,1,1]

+ d1, d2 = splitobs(data, at=0., stratified=data.b)
+ @test d1.b == []
+ @test d2.b == [0,0,0,0,1,1,1,1,1,1]
+ d1, d2 = splitobs(data, at=1., stratified=data.b)
+ @test d1.b == [0,0,0,0,1,1,1,1,1,1]
+ @test d2.b == []
end
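
The stratified behaviour described in the docstring and exercised by the tests above can be illustrated with a small standalone sketch. This is not the MLUtils implementation: `stratified_split_indices` is a hypothetical helper, and the per-label rounding rule (`round(Int, at * n)`) is an assumption chosen only because it agrees with the examples shown in this diff.

```julia
# Hypothetical helper (not part of MLUtils): split indices so that each label
# keeps roughly the same proportion `at` in the first subset.
function stratified_split_indices(labels::AbstractVector, at::Real)
    first_part = Int[]
    second_part = Int[]
    for label in unique(labels)
        # Indices of all observations carrying this label, in original order.
        idxs = findall(==(label), labels)
        # Assumed rounding rule: nearest integer count per label.
        n_first = round(Int, at * length(idxs))
        append!(first_part, idxs[1:n_first])
        append!(second_part, idxs[n_first+1:end])
    end
    # Sort so that each subset keeps the original observation order.
    return sort!(first_part), sort!(second_part)
end

labels = [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
stratified_split_indices(labels, 0.5)
# ([1, 2, 5, 6, 7], [3, 4, 8, 9, 10])
```

With `at=0.0` or `at=1.0` the sketch returns one empty subset and one complete subset, which mirrors the edge cases added to `test/splitobs.jl`.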
