Skip to content

Commit 9207b0b

Browse files
authored
Remove setdiff in kfolds (#209)
Rather than using `setdiff` to compute training indices (which can allocate large vectors on big datasets), compute them analytically from fold sizes and offsets. This reduces memory usage and improves performance.
1 parent 9632e82 commit 9207b0b

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

src/folds.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,23 @@ julia> val_idx
3737
9:10
3838
```
3939
"""
40-
function kfolds(n::Integer, k::Integer = 5)
40+
function kfolds(n::Integer, k::Integer=5)
4141
2 <= k <= n || throw(ArgumentError("n must be positive and k must to be within 2:$(max(2,n))"))
4242
# Compute the size of each fold. This is important because
4343
# in general the number of total observations might not be
4444
# divisible by k. In such cases it is custom that the remaining
4545
# observations are divided among the folds. Thus some folds
4646
# have one more observation than others.
47-
sizes = fill(floor(Int, n/k), k)
48-
for i = 1:(n % k)
47+
sizes = fill(floor(Int, n / k), k)
48+
for i = 1:(n%k)
4949
sizes[i] = sizes[i] + 1
5050
end
5151
# Compute start offset for each fold
5252
offsets = cumsum(sizes) .- sizes .+ 1
5353
# Compute the validation indices using the offsets and sizes
54-
val_indices = map((o,s) -> (o:o+s-1), offsets, sizes)
54+
val_indices = map((o, s) -> (o:o+s-1), offsets, sizes)
5555
# The train indices are then the indicies not in validation
56-
train_indices = map(idx -> setdiff(1:n, idx), val_indices)
56+
train_indices = map((o, s) -> vcat(1:o-1, o+s:n), offsets, sizes)
5757
# We return a tuple of arrays
5858
return train_indices, val_indices
5959
end
@@ -106,8 +106,8 @@ function kfolds(data, k::Integer)
106106
n = numobs(data)
107107
train_indices, val_indices = kfolds(n, k)
108108

109-
((obsview(data, itrain), obsview(data, ival))
110-
for (itrain, ival) in zip(train_indices, val_indices))
109+
((obsview(data, itrain), obsview(data, ival))
110+
for (itrain, ival) in zip(train_indices, val_indices))
111111
end
112112

113113
kfolds(data; k) = kfolds(data, k)
@@ -151,8 +151,8 @@ julia> val_idx
151151
9:10
152152
```
153153
"""
154-
function leavepout(n::Integer, p::Integer = 1)
155-
1 <= p <= floor(n/2) || throw(ArgumentError("p must to be within 1:$(floor(Int,n/2))"))
154+
function leavepout(n::Integer, p::Integer=1)
155+
1 <= p <= floor(n / 2) || throw(ArgumentError("p must to be within 1:$(floor(Int,n/2))"))
156156
k = floor(Int, n / p)
157157
kfolds(n, k)
158158
end
@@ -181,9 +181,9 @@ See[`kfolds`](@ref) for a related function.
181181
"""
182182
function leavepout(data, p::Integer)
183183
n = numobs(data)
184-
1 <= p <= floor(n/2) || throw(ArgumentError("p must to be within 1:$(floor(Int,n/2))"))
184+
1 <= p <= floor(n / 2) || throw(ArgumentError("p must to be within 1:$(floor(Int,n/2))"))
185185
k = floor(Int, n / p)
186186
kfolds(data, k)
187187
end
188188

189-
leavepout(data; p::Integer=1) = leavepout(data, p)
189+
leavepout(data; p::Integer=1) = leavepout(data, p)

0 commit comments

Comments
 (0)