Skip to content

Commit b717822

Browse files
authored
Blocked logical indexing (#475)
This introduces a blocked version of logical indexing. With this PR, if you index into array with an `AbstractBlockVector{Bool}`, it is interpreted as a logical index but also uses the blocking of the index to determine the block structure of the output array, for example: ```julia julia> using BlockArrays julia> a = randn(6, 6) 6×6 Matrix{Float64}: -0.0577235 -0.12942 -0.10982 1.01086 0.196898 0.896616 -0.163481 -1.24784 1.01413 0.244657 -0.49961 -0.435926 1.64239 -0.930051 -0.923835 -0.789608 -0.53113 -0.0502656 0.0999888 -2.41073 -2.03078 0.019679 -0.857197 0.188939 -0.698236 -0.218804 -1.36086 0.77242 0.1388 1.97166 1.77482 -1.58258 -0.042804 1.30733 1.33004 0.930145 julia> mask = [true, true, false, false, true, false] 6-element Vector{Bool}: 1 1 0 0 1 0 julia> I = BlockedVector(mask, [3, 3]) 2-blocked 6-element BlockedVector{Bool}: 1 1 0 ─ 0 1 0 julia> a[I, I] 2×2-blocked 3×3 BlockedMatrix{Float64}: -0.0577235 -0.12942 │ 0.196898 -0.163481 -1.24784 │ -0.49961 ───────────────────────┼─────────── -0.698236 -0.218804 │ 0.1388 ``` Without this PR, the output of `a[I, I]` would not have been blocked.
1 parent 0c0d813 commit b717822

File tree

5 files changed

+100
-5
lines changed

5 files changed

+100
-5
lines changed

src/BlockArrays.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ export blockappend!, blockpush!, blockpushfirst!, blockpop!, blockpopfirst!
2020
import Base: @propagate_inbounds, Array, AbstractArray, to_indices, to_index,
2121
unsafe_indices, first, last, size, length, unsafe_length,
2222
unsafe_convert,
23-
getindex, setindex!, ndims, show, view,
23+
getindex, setindex!, ndims, show, print_array, view,
2424
step,
25-
broadcast, eltype, convert, similar,
25+
broadcast, eltype, convert, similar, collect,
2626
tail, reindex,
2727
RangeIndex, Int, Integer, Number, Tuple,
2828
+, -, *, /, \, min, max, isless, in, copy, copyto!, axes, @deprecate,
29-
BroadcastStyle, checkbounds,
29+
BroadcastStyle, checkbounds, checkindex, ensure_indexable,
3030
oneunit, ones, zeros, intersect, Slice, resize!
3131

32-
using Base: ReshapedArray, dataids, oneto
32+
using Base: ReshapedArray, LogicalIndex, dataids, oneto
3333

3434
import Base: (:), IteratorSize, iterate, axes1, strides, isempty
3535
import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted, broadcastable

src/blockedarray.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ AbstractArray{T,N}(A::BlockedArray) where {T,N} = BlockedArray(AbstractArray{T,N
193193

194194
copy(A::BlockedArray) = BlockedArray(copy(A.blocks), A.axes)
195195

196+
# Blocked version of `collect(::AbstractArray)` that preserves the
197+
# block structure.
198+
blockcollect(a::AbstractArray) = BlockedArray(collect(a), axes(a))
199+
196200
Base.dataids(A::BlockedArray) = Base.dataids(A.blocks)
197201

198202
###########################

src/views.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,50 @@ to_index(::BlockRange) = throw(ArgumentError("BlockRange must be converted by to
6060
@inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndex{1}}, Vararg{Any}}) = to_indices(A, axes(A), I)
6161
@inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndexRange{1}}, Vararg{Any}}) = to_indices(A, axes(A), I)
6262

63+
## BlockedLogicalIndex
64+
# Blocked version of `LogicalIndex`:
65+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L819-L831
66+
const BlockedLogicalIndex{T,R<:LogicalIndex{T},BS<:Tuple{AbstractUnitRange{<:Integer}}} = BlockedVector{T,R,BS}
67+
function BlockedLogicalIndex(I::AbstractVector{Bool})
68+
blocklengths = map(b -> count(view(I, b)), BlockRange(I))
69+
return BlockedVector(LogicalIndex(I), blocklengths)
70+
end
71+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L838-L839
72+
show(io::IO, r::BlockedLogicalIndex) = print(io, blockcollect(r))
73+
print_array(io::IO, X::BlockedLogicalIndex) = print_array(io, blockcollect(X))
74+
75+
# Blocked version of `to_index(::AbstractArray{Bool})`:
76+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/indices.jl#L309
77+
function to_index(I::AbstractBlockVector{Bool})
78+
return BlockedLogicalIndex(I)
79+
end
80+
81+
# Blocked version of `collect(::LogicalIndex)`:
82+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L837
83+
# Without this definition, `collect` will try to call `getindex` on the `LogicalIndex`
84+
# which isn't defined.
85+
collect(I::BlockedLogicalIndex) = collect(I.blocks)
86+
87+
# Iteration of BlockedLogicalIndex is just iteration over the underlying
88+
# LogicalIndex, which is implemented here:
89+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L840-L890
90+
@inline iterate(I::BlockedLogicalIndex) = iterate(I.blocks)
91+
@inline iterate(I::BlockedLogicalIndex, s) = iterate(I.blocks, s)
92+
93+
## Boundscheck for BlockLogicalindex
94+
# Like for LogicalIndex, map all calls to mask:
95+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L892-L897
96+
checkbounds(::Type{Bool}, A::AbstractArray, i::BlockedLogicalIndex) = checkbounds(Bool, A, i.blocks.mask)
97+
# `checkbounds_indices` has been handled via `I::AbstractArray` fallback
98+
checkindex(::Type{Bool}, inds::AbstractUnitRange, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask)
99+
checkindex(::Type{Bool}, inds::Tuple, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask)
100+
101+
# Instantiate the BlockedLogicalIndex when constructing a SubArray, similar to
102+
# `ensure_indexable(I::Tuple{LogicalIndex,Vararg{Any}})`:
103+
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L918
104+
@inline ensure_indexable(I::Tuple{BlockedLogicalIndex,Vararg{Any}}) =
105+
(blockcollect(I[1]), ensure_indexable(tail(I))...)
106+
63107
@propagate_inbounds reindex(idxs::Tuple{BlockSlice{<:BlockRange}, Vararg{Any}},
64108
subidxs::Tuple{BlockSlice{<:BlockIndexRange}, Vararg{Any}}) =
65109
(BlockSlice(BlockIndexRange(Block(idxs[1].block.indices[1][Int(subidxs[1].block.block)]),

test/test_blockarrays.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module TestBlockArrays
22

33
using SparseArrays, BlockArrays, FillArrays, LinearAlgebra, Test, OffsetArrays, Images
4-
import BlockArrays: _BlockArray
4+
import BlockArrays: _BlockArray, blockcollect
55

66
const Fill = FillArrays.Fill
77

@@ -255,6 +255,32 @@ end
255255
@test zero(b) isa typeof(b)
256256
end
257257

258+
@testset "blockcollect" begin
259+
a = randn(6, 6)
260+
@test blockcollect(a) == a
261+
@test blockcollect(a) a
262+
@test blockcollect(a).blocks a
263+
# TODO: Maybe special case this to call `collect` and return a `Matrix`?
264+
@test blockcollect(a) isa BlockedMatrix{Float64,Matrix{Float64}}
265+
@test blockisequal(axes(blockcollect(a)), axes(a))
266+
@test blocksize(blockcollect(a)) == (1, 1)
267+
268+
b = BlockedArray(randn(6, 6), [3, 3], [3, 3])
269+
@test blockcollect(b) == b
270+
@test blockcollect(b) b
271+
@test blockcollect(b).blocks b
272+
@test blockcollect(b) isa BlockedMatrix{Float64,Matrix{Float64}}
273+
@test blockisequal(axes(blockcollect(b)), axes(b))
274+
@test blocksize(blockcollect(b)) == (2, 2)
275+
276+
c = BlockArray(randn(6, 6), [3, 3], [3, 3])
277+
@test blockcollect(c) == c
278+
@test blockcollect(c) c
279+
@test blockcollect(c) isa BlockedMatrix{Float64,Matrix{Float64}}
280+
@test blockisequal(axes(blockcollect(c)), axes(c))
281+
@test blocksize(blockcollect(c)) == (2, 2)
282+
end
283+
258284
@test_throws DimensionMismatch BlockArray([1,2,3],[1,1])
259285

260286
@testset "mortar" begin

test/test_blockviews.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module TestBlockViews
22

33
using BlockArrays, ArrayLayouts, Test
44
using FillArrays
5+
import BlockArrays: BlockedLogicalIndex
6+
import Base: LogicalIndex
57

68
# useds to force SubArray return
79
bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
@@ -353,6 +355,25 @@ bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
353355
@test MemoryLayout(v) == MemoryLayout(a)
354356
@test v[Block(1)] == a[Block(1)]
355357
end
358+
359+
@testset "BlockedLogicalIndex" begin
360+
a = randn(6, 6)
361+
for mask in ([true, true, false, false, true, false], BitVector([true, true, false, false, true, false]))
362+
I = BlockedVector(mask, [3, 3])
363+
@test to_indices(a, (I, I)) == to_indices(a, (mask, mask))
364+
@test to_indices(a, (I, I)) == (BlockedVector(LogicalIndex(mask), [2, 1]), BlockedVector(LogicalIndex(mask), [2, 1]))
365+
@test to_indices(a, (I, I)) isa Tuple{BlockedLogicalIndex{Int},BlockedLogicalIndex{Int}}
366+
@test blocklengths.(Base.axes1.(to_indices(a, (I, I)))) == ([2, 1], [2, 1])
367+
for b in (view(a, I, I), a[I, I])
368+
@test size(b) == (3, 3)
369+
@test blocklengths.(axes(b)) == ([2, 1], [2, 1])
370+
@test b == a[mask, mask]
371+
end
372+
@test parentindices(view(a, I, I)) == (BlockedVector([1, 2, 5], [2, 1]), BlockedVector([1, 2, 5], [2, 1]))
373+
@test parentindices(view(a, I, I)) isa Tuple{BlockedVector{Int,Vector{Int}},BlockedVector{Int,Vector{Int}}}
374+
@test blocklengths.(Base.axes1.(parentindices(view(a, I, I)))) == ([2, 1], [2, 1])
375+
end
376+
end
356377
end
357378

358379
end # module

0 commit comments

Comments
 (0)