@@ -65,7 +65,7 @@ The original data is preserved in the `data` field of the DataLoader.
6565 (depending on the `collate` and `batchsize` options, could be `getobs!(buffer, data, idxs)` or `getobs!(buffer[i], data, idx)`).
6666 Default `false`.
6767- **`collate`**: Defines the batching behavior. Default `nothing`.
68- - If `nothing` , a batch is `getobs(data, indices)`.
68+ - If `nothing`, a batch is `getobs(data, indices)`.
6969 - If `false`, each batch is `[getobs(data, i) for i in indices]`.
7070 - If `true`, applies `MLUtils.batch` to the vector of observations in a batch,
7171 recursively collating arrays in the last dimensions. See [`MLUtils.batch`](@ref) for more information
@@ -235,7 +235,7 @@ _create_buffer(x) = getobs(x, 1)
235235
236236function _create_buffer (x:: BatchView )
237237 obsindices = _batchrange (x, 1 )
238- return [getobs (A . data, idx ) for idx in enumerate ( obsindices) ]
238+ return [getobs (x . data, i ) for i in obsindices]
239239end
240240
241241function _create_buffer (x:: BatchView{TElem,TData,Val{nothing}} ) where {TElem,TData}
@@ -322,18 +322,24 @@ end
322322
323323# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
324324function Base. showarg (io:: IO , d:: DataLoader , toplevel)
325- print (io, " DataLoader(" )
325+ print (io, " DataLoader(data " )
326326 Base. showarg (io, d. data, false )
327- d. buffer == false || print (io, " , buffer=" , d. buffer)
327+ if d. buffer != false
328+ print (io, " , buffer" )
329+ Base. showarg (io, d. buffer, false )
330+ end
328331 d. parallel == false || print (io, " , parallel=" , d. parallel)
329332 d. shuffle == false || print (io, " , shuffle=" , d. shuffle)
330333 d. batchsize == 1 || print (io, " , batchsize=" , d. batchsize)
331334 d. partial == true || print (io, " , partial=" , d. partial)
332- d. collate === Val (nothing ) || print (io, " , collate=" , d. collate)
335+ d. collate === Val (nothing ) || print (io, " , collate=" , _valstr ( d. collate) )
333336 d. rng == Random. default_rng () || print (io, " , rng=" , d. rng)
334337 print (io, " )" )
335338end
336339
340+ _valstr (:: Val{T} ) where T = string (T)
341+ _valstr (x) = string (x)
342+
337343Base. show (io:: IO , e:: DataLoader ) = Base. showarg (io, e, false )
338344
339345function Base. show (io:: IO , m:: MIME"text/plain" , d:: DataLoader )
0 commit comments