Skip to content

Blocked logical indexing #475

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,16 @@ export blockappend!, blockpush!, blockpushfirst!, blockpop!, blockpopfirst!
import Base: @propagate_inbounds, Array, AbstractArray, to_indices, to_index,
unsafe_indices, first, last, size, length, unsafe_length,
unsafe_convert,
getindex, setindex!, ndims, show, view,
getindex, setindex!, ndims, show, print_array, view,
step,
broadcast, eltype, convert, similar,
broadcast, eltype, convert, similar, collect,
tail, reindex,
RangeIndex, Int, Integer, Number, Tuple,
+, -, *, /, \, min, max, isless, in, copy, copyto!, axes, @deprecate,
BroadcastStyle, checkbounds,
BroadcastStyle, checkbounds, checkindex, ensure_indexable,
oneunit, ones, zeros, intersect, Slice, resize!

using Base: ReshapedArray, dataids, oneto
using Base: ReshapedArray, LogicalIndex, dataids, oneto

import Base: (:), IteratorSize, iterate, axes1, strides, isempty
import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted, broadcastable
Expand Down
4 changes: 4 additions & 0 deletions src/blockedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,10 @@ AbstractArray{T,N}(A::BlockedArray) where {T,N} = BlockedArray(AbstractArray{T,N

copy(A::BlockedArray) = BlockedArray(copy(A.blocks), A.axes)

# Blocked version of `collect(::AbstractArray)` that preserves the
# block structure.
blockcollect(a::AbstractArray) = BlockedArray(collect(a), axes(a))

Base.dataids(A::BlockedArray) = Base.dataids(A.blocks)

###########################
Expand Down
44 changes: 44 additions & 0 deletions src/views.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,50 @@
@inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndex{1}}, Vararg{Any}}) = to_indices(A, axes(A), I)
@inline to_indices(A, I::Tuple{AbstractVector{<:BlockIndexRange{1}}, Vararg{Any}}) = to_indices(A, axes(A), I)

## BlockedLogicalIndex
# Blocked version of `LogicalIndex`:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L819-L831
const BlockedLogicalIndex{T,R<:LogicalIndex{T},BS<:Tuple{AbstractUnitRange{<:Integer}}} = BlockedVector{T,R,BS}
function BlockedLogicalIndex(I::AbstractVector{Bool})
blocklengths = map(b -> count(view(I, b)), BlockRange(I))
return BlockedVector(LogicalIndex(I), blocklengths)
end
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L838-L839
show(io::IO, r::BlockedLogicalIndex) = print(io, blockcollect(r))
print_array(io::IO, X::BlockedLogicalIndex) = print_array(io, blockcollect(X))

Check warning on line 72 in src/views.jl

View check run for this annotation

Codecov / codecov/patch

src/views.jl#L71-L72

Added lines #L71 - L72 were not covered by tests

# Blocked version of `to_index(::AbstractArray{Bool})`:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/indices.jl#L309
function to_index(I::AbstractBlockVector{Bool})
return BlockedLogicalIndex(I)
end

# Blocked version of `collect(::LogicalIndex)`:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L837
# Without this definition, `collect` will try to call `getindex` on the `LogicalIndex`
# which isn't defined.
collect(I::BlockedLogicalIndex) = collect(I.blocks)

# Iteration of BlockedLogicalIndex is just iteration over the underlying
# LogicalIndex, which is implemented here:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L840-L890
@inline iterate(I::BlockedLogicalIndex) = iterate(I.blocks)
@inline iterate(I::BlockedLogicalIndex, s) = iterate(I.blocks, s)

## Boundscheck for BlockLogicalindex
# Like for LogicalIndex, map all calls to mask:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L892-L897
checkbounds(::Type{Bool}, A::AbstractArray, i::BlockedLogicalIndex) = checkbounds(Bool, A, i.blocks.mask)

Check warning on line 95 in src/views.jl

View check run for this annotation

Codecov / codecov/patch

src/views.jl#L95

Added line #L95 was not covered by tests
# `checkbounds_indices` has been handled via `I::AbstractArray` fallback
checkindex(::Type{Bool}, inds::AbstractUnitRange, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask)
checkindex(::Type{Bool}, inds::Tuple, i::BlockedLogicalIndex) = checkindex(Bool, inds, i.blocks.mask)

Check warning on line 98 in src/views.jl

View check run for this annotation

Codecov / codecov/patch

src/views.jl#L98

Added line #L98 was not covered by tests

# Instantiate the BlockedLogicalIndex when constructing a SubArray, similar to
# `ensure_indexable(I::Tuple{LogicalIndex,Vararg{Any}})`:
# https://github.com/JuliaLang/julia/blob/3e2f90fbb8f6b0651f2601d7599c55d4e3efd496/base/multidimensional.jl#L918
@inline ensure_indexable(I::Tuple{BlockedLogicalIndex,Vararg{Any}}) =
(blockcollect(I[1]), ensure_indexable(tail(I))...)

@propagate_inbounds reindex(idxs::Tuple{BlockSlice{<:BlockRange}, Vararg{Any}},
subidxs::Tuple{BlockSlice{<:BlockIndexRange}, Vararg{Any}}) =
(BlockSlice(BlockIndexRange(Block(idxs[1].block.indices[1][Int(subidxs[1].block.block)]),
Expand Down
28 changes: 27 additions & 1 deletion test/test_blockarrays.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module TestBlockArrays

using SparseArrays, BlockArrays, FillArrays, LinearAlgebra, Test, OffsetArrays, Images
import BlockArrays: _BlockArray
import BlockArrays: _BlockArray, blockcollect

const Fill = FillArrays.Fill

Expand Down Expand Up @@ -255,6 +255,32 @@ end
@test zero(b) isa typeof(b)
end

@testset "blockcollect" begin
a = randn(6, 6)
@test blockcollect(a) == a
@test blockcollect(a) ≢ a
@test blockcollect(a).blocks ≢ a
# TODO: Maybe special case this to call `collect` and return a `Matrix`?
@test blockcollect(a) isa BlockedMatrix{Float64,Matrix{Float64}}
@test blockisequal(axes(blockcollect(a)), axes(a))
@test blocksize(blockcollect(a)) == (1, 1)

b = BlockedArray(randn(6, 6), [3, 3], [3, 3])
@test blockcollect(b) == b
@test blockcollect(b) ≢ b
@test blockcollect(b).blocks ≢ b
@test blockcollect(b) isa BlockedMatrix{Float64,Matrix{Float64}}
@test blockisequal(axes(blockcollect(b)), axes(b))
@test blocksize(blockcollect(b)) == (2, 2)

c = BlockArray(randn(6, 6), [3, 3], [3, 3])
@test blockcollect(c) == c
@test blockcollect(c) ≢ c
@test blockcollect(c) isa BlockedMatrix{Float64,Matrix{Float64}}
@test blockisequal(axes(blockcollect(c)), axes(c))
@test blocksize(blockcollect(c)) == (2, 2)
end

@test_throws DimensionMismatch BlockArray([1,2,3],[1,1])

@testset "mortar" begin
Expand Down
21 changes: 21 additions & 0 deletions test/test_blockviews.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ module TestBlockViews

using BlockArrays, ArrayLayouts, Test
using FillArrays
import BlockArrays: BlockedLogicalIndex
import Base: LogicalIndex

# useds to force SubArray return
bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
Expand Down Expand Up @@ -353,6 +355,25 @@ bview(a, b) = Base.invoke(view, Tuple{AbstractArray,Any}, a, b)
@test MemoryLayout(v) == MemoryLayout(a)
@test v[Block(1)] == a[Block(1)]
end

@testset "BlockedLogicalIndex" begin
a = randn(6, 6)
for mask in ([true, true, false, false, true, false], BitVector([true, true, false, false, true, false]))
I = BlockedVector(mask, [3, 3])
@test to_indices(a, (I, I)) == to_indices(a, (mask, mask))
@test to_indices(a, (I, I)) == (BlockedVector(LogicalIndex(mask), [2, 1]), BlockedVector(LogicalIndex(mask), [2, 1]))
@test to_indices(a, (I, I)) isa Tuple{BlockedLogicalIndex{Int},BlockedLogicalIndex{Int}}
@test blocklengths.(Base.axes1.(to_indices(a, (I, I)))) == ([2, 1], [2, 1])
for b in (view(a, I, I), a[I, I])
@test size(b) == (3, 3)
@test blocklengths.(axes(b)) == ([2, 1], [2, 1])
@test b == a[mask, mask]
end
@test parentindices(view(a, I, I)) == (BlockedVector([1, 2, 5], [2, 1]), BlockedVector([1, 2, 5], [2, 1]))
@test parentindices(view(a, I, I)) isa Tuple{BlockedVector{Int,Vector{Int}},BlockedVector{Int,Vector{Int}}}
@test blocklengths.(Base.axes1.(parentindices(view(a, I, I)))) == ([2, 1], [2, 1])
end
end
end

end # module
Loading