Skip to content

Commit bdc4386

Browse files
fix buffered dataloader (#206)
* fix buffered dataloader * printing test
1 parent f778831 commit bdc4386

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/dataloader.jl

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ The original data is preserved in the `data` field of the DataLoader.
6565
(depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`).
6666
Default `false`.
6767
- **`collate`**: Defines the batching behavior. Default `nothing`.
68-
- If `nothing` , a batch is `getobs(data, indices)`.
68+
- If `nothing`, a batch is `getobs(data, indices)`.
6969
- If `false`, each batch is `[getobs(data, i) for i in indices]`.
7070
- If `true`, applies `MLUtils.batch` to the vector of observations in a batch,
7171
recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information
@@ -235,7 +235,7 @@ _create_buffer(x) = getobs(x, 1)
235235

236236
function _create_buffer(x::BatchView)
237237
obsindices = _batchrange(x, 1)
238-
return [getobs(A.data, idx) for idx in enumerate(obsindices)]
238+
return [getobs(x.data, i) for i in obsindices]
239239
end
240240

241241
function _create_buffer(x::BatchView{TElem,TData,Val{nothing}}) where {TElem,TData}
@@ -322,18 +322,24 @@ end
322322

323323
# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))
324324
function Base.showarg(io::IO, d::DataLoader, toplevel)
325-
print(io, "DataLoader(")
325+
print(io, "DataLoader(data")
326326
Base.showarg(io, d.data, false)
327-
d.buffer == false || print(io, ", buffer=", d.buffer)
327+
if d.buffer != false
328+
print(io, ", buffer")
329+
Base.showarg(io, d.buffer, false)
330+
end
328331
d.parallel == false || print(io, ", parallel=", d.parallel)
329332
d.shuffle == false || print(io, ", shuffle=", d.shuffle)
330333
d.batchsize == 1 || print(io, ", batchsize=", d.batchsize)
331334
d.partial == true || print(io, ", partial=", d.partial)
332-
d.collate === Val(nothing) || print(io, ", collate=", d.collate)
335+
d.collate === Val(nothing) || print(io, ", collate=", _valstr(d.collate))
333336
d.rng == Random.default_rng() || print(io, ", rng=", d.rng)
334337
print(io, ")")
335338
end
336339

340+
_valstr(::Val{T}) where T = string(T)
341+
_valstr(x) = string(x)
342+
337343
Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false)
338344

339345
function Base.show(io::IO, m::MIME"text/plain", d::DataLoader)

test/dataloader.jl

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,21 +271,50 @@
271271

272272
d = DataLoader((X2, Y2), batchsize=3)
273273

274-
@test contains(repr(d), "DataLoader(::Tuple{Matrix")
274+
@test contains(repr(d), "DataLoader(data::Tuple{Matrix")
275275
@test contains(repr(d), "batchsize=3")
276276

277277
@test contains(repr(MIME"text/plain"(), d), "2-element DataLoader")
278278
@test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector")
279279

280280
d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false)
281281

282-
@test contains(repr(d2), "DataLoader(::@NamedTuple")
282+
@test contains(repr(d2), "DataLoader(data::@NamedTuple")
283283
@test contains(repr(d2), "partial=false")
284284

285-
@test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::@NamedTuple")
285+
@test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(data::@NamedTuple")
286286
@test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector")
287287
end
288288
end
289+
290+
@testset "buffer issue 205" begin
291+
292+
function shift_pair(X)
293+
inputs = map(X) do x
294+
T = size(x, 4)
295+
return selectdim(x, 4, 1:(T-1))
296+
end
297+
targets = map(X) do x
298+
T = size(x, 4)
299+
return selectdim(x, 4, 2:T)
300+
end
301+
return (stack(inputs), stack(targets))
302+
end
303+
304+
trajectory = randn(Float32, 32, 32, 4, 3, 5);
305+
306+
loader = DataLoader(
307+
trajectory;
308+
batchsize=2,
309+
partial=false,
310+
buffer=true,
311+
collate = shift_pair,
312+
shuffle = false,
313+
)
314+
315+
@test first(loader)[1] == trajectory[:, :, :, 1:2, 1:2]
316+
@test first(loader)[2] == trajectory[:, :, :, 2:3, 1:2]
317+
end
289318
end
290319

291320
@testset "eachobs" begin

0 commit comments

Comments
 (0)