diff --git a/src/gather.jl b/src/gather.jl
index 5690dfc..f6f24e3 100644
--- a/src/gather.jl
+++ b/src/gather.jl
@@ -63,3 +63,26 @@ function NNlib.gather!(dst::AnyCuArray, src::AnyCuArray, idx::AnyCuArray)
     kernel(args...; threads=threads, blocks=blocks)
     return dst
 end
+
+"""
+    NNlib.gather(src::AnyCuArray, idx::AnyCuArray)
+
+Allocate a destination with `similar(src, Tsrc, dstsize)` and fill it by calling
+`NNlib.gather!(dst, src, idx)`, returning the result of that call.
+
+`dstsize` is the first `Nsrc - M` entries of `size(src)` followed by `size(idx)`,
+where `M = NNlib.typelength(Tidx)`.
+"""
+function NNlib.gather(src::AnyCuArray{Tsrc, Nsrc},
+                      idx::AnyCuArray{Tidx, Nidx}) where {Tsrc, Nsrc, Nidx, Tidx}
+    M = NNlib.typelength(Tidx)
+    dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
+    dst = similar(src, Tsrc, dstsize)
+    return NNlib.gather!(dst, src, idx)
+end
+
+# Less specific than the method above, so it is selected only when `idx` is not an
+# `AnyCuArray`; it rejects the call with an `ArgumentError` instead of gathering.
+function NNlib.gather(src::AnyCuArray, idx::AbstractArray)
+    err_msg = "src and idx both must be on GPU, but received $(typeof(src)) and $(typeof(idx)), respectively."
+    throw(ArgumentError(err_msg))
+end
diff --git a/test/gather.jl b/test/gather.jl
index f200f77..e46f3ea 100644
--- a/test/gather.jl
+++ b/test/gather.jl
@@ -17,6 +17,7 @@
     gputest(src -> NNlib.gather(src, index), src, checkgrad=true)
     @test NNlib.gather!(CUDA.zeros(T, size(index)...), src, index) == output
     @test_throws ArgumentError NNlib.gather!(zeros(T, 3, 5), src, index)
+    @test_throws ArgumentError NNlib.gather(src, collect(index))
 
     ## 1d src, 2d index of tuples -> 2d output
     src = CT([3, 4, 5, 6, 7])