Skip to content

Commit ba33df7

Browse files
Merge pull request #261 from jipolanco/jip/recursivefill
Fix `recursivefill!` for `VectorOfArray{<:StaticArray}`
2 parents cba251a + f9712e9 commit ba33df7

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

src/utils.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,14 @@ function recursivefill!(b::AbstractArray{T, N},
8686
end
8787
end
8888

89+
function recursivefill!(bs::AbstractVectorOfArray{T, N},
90+
a::T2) where {T <: StaticArraysCore.StaticArray,
91+
T2 <: StaticArraysCore.StaticArray, N}
92+
@inbounds for b in bs, i in eachindex(b)
93+
b[i] = copy(a)
94+
end
95+
end
96+
8997
function recursivefill!(b::AbstractArray{T, N},
9098
a::T2) where {T <: StaticArraysCore.SArray,
9199
T2 <: Union{Number, Bool}, N}
@@ -94,6 +102,14 @@ function recursivefill!(b::AbstractArray{T, N},
94102
end
95103
end
96104

105+
function recursivefill!(bs::AbstractVectorOfArray{T, N},
106+
a::T2) where {T <: StaticArraysCore.SArray,
107+
T2 <: Union{Number, Bool}, N}
108+
@inbounds for b in bs, i in eachindex(b)
109+
b[i] = fill(a, typeof(b[i]))
110+
end
111+
end
112+
97113
function recursivefill!(b::AbstractArray{T, N}, a::T2) where {T <: Enum, T2 <: Enum, N}
98114
fill!(b, a)
99115
end

test/utils_test.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,21 @@ x = similar(x)
9999
recursivefill!(x, true)
100100
@test x[1] == MVector{10}(ones(10))
101101
@test x[2] == MVector{10}(ones(10))
102+
103+
# Test VectorOfArray + recursivefill! + static arrays
104+
@testset "VectorOfArray + recursivefill! + static arrays" begin
105+
Vec3 = SVector{3, Float64}
106+
x = [randn(Vec3, n) for n in 1:4] # vector of vectors of static arrays
107+
108+
x_voa = VectorOfArray(x)
109+
@test eltype(x_voa) === Vec3
110+
@test first(x_voa) isa AbstractVector{Vec3}
111+
112+
y_voa = recursivecopy(x_voa)
113+
recursivefill!(y_voa, true)
114+
@test all(y_voa[n] == fill(ones(Vec3), n) for n in 1:4)
115+
116+
y_voa = recursivecopy(x_voa)
117+
recursivefill!(y_voa, ones(Vec3))
118+
@test all(y_voa[n] == fill(ones(Vec3), n) for n in 1:4)
119+
end

0 commit comments

Comments
 (0)