Skip to content

Commit 21eb0e9

Browse files
authored
Generalize blockwise slicing (#157)
1 parent 4ee09d2 commit 21eb0e9

File tree

7 files changed

+440
-17
lines changed

7 files changed

+440
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "BlockSparseArrays"
22
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.7.20"
4+
version = "0.7.21"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,20 @@ for f in (:axes, :unsafe_indices, :axes1, :first, :last, :size, :length, :unsafe
6868
end
6969
Base.getindex(S::BlockIndices, i::Integer) = getindex(S.indices, i)
7070

71-
function _blockslice(x, y::AbstractUnitRange{<:Integer})
71+
# Generalization of to `BlockArrays._blockslice`:
72+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L13-L14
73+
# Used by `BlockArrays.unblock`, which is used in `to_indices`
74+
# to convert relative blockwise slices to absolute slices, but in a way
75+
# that preserves the original relative blockwise slice information.
76+
# TODO: Ideally this would be handled in BlockArrays.jl
77+
# once slicing like `A[Block(1)[[1, 2]]]` is supported.
78+
function _blockslice(x, y::AbstractUnitRange)
7279
return BlockSlice(x, y)
7380
end
74-
function _blockslice(x, y::AbstractVector{<:Integer})
81+
function _blockslice(x, y::AbstractVector)
7582
return BlockIndices(x, y)
7683
end
84+
7785
function Base.getindex(S::BlockIndices, i::BlockSlice{<:Block{1}})
7886
# TODO: Check that `i.indices` is consistent with `S.indices`.
7987
# It seems like this isn't handling the case where `i` is a
@@ -167,8 +175,14 @@ const BlockIndexRangeSlices = BlockIndices{
167175
const BlockIndexVectorSlices = BlockIndices{
168176
<:BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector}}
169177
}
178+
const GenericBlockIndexVectorSlices = BlockIndices{
179+
<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}
180+
}
170181
const SubBlockSliceCollection = Union{
171-
BlockIndexRangeSlice,BlockIndexRangeSlices,BlockIndexVectorSlices
182+
BlockIndexRangeSlice,
183+
BlockIndexRangeSlices,
184+
BlockIndexVectorSlices,
185+
GenericBlockIndexVectorSlices,
172186
}
173187

174188
# TODO: This is type piracy. This is used in `reindex` when making
@@ -392,6 +406,13 @@ function blockrange(
392406
return map(Block, blocks(r))
393407
end
394408

409+
function blockrange(
410+
axis::AbstractUnitRange,
411+
r::BlockVector{<:GenericBlockIndex{1},<:AbstractVector{<:BlockIndexVector}},
412+
)
413+
return map(Block, blocks(r))
414+
end
415+
395416
function blockrange(axis::AbstractUnitRange, r)
396417
return error("Slicing not implemented for range of type `$(typeof(r))`.")
397418
end

src/BlockArraysExtensions/blockedunitrange.jl

Lines changed: 80 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,15 @@ BlockArrays.blockindex(b::GenericBlockIndex{1}) = b.α[1]
244244
function GenericBlockIndex(indcs::Tuple{Vararg{GenericBlockIndex{1},N}}) where {N}
245245
GenericBlockIndex(block.(indcs), blockindex.(indcs))
246246
end
247+
248+
function Base.checkindex(
249+
::Type{Bool}, axis::AbstractBlockedUnitRange, ind::GenericBlockIndex{1}
250+
)
251+
return checkindex(Bool, axis, block(ind)) &&
252+
checkbounds(Bool, axis[block(ind)], blockindex(ind))
253+
end
254+
Base.to_index(i::GenericBlockIndex) = i
255+
247256
function print_tuple_elements(io::IO, @nospecialize(t))
248257
if !isempty(t)
249258
print(io, t[1])
@@ -261,6 +270,13 @@ function Base.show(io::IO, B::GenericBlockIndex)
261270
return nothing
262271
end
263272

273+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L31-L32
274+
_maybetail(::Tuple{}) = ()
275+
_maybetail(t::Tuple) = Base.tail(t)
276+
@inline function Base.to_indices(A, inds, I::Tuple{GenericBlockIndex{1},Vararg{Any}})
277+
return (inds[1][I[1]], to_indices(A, _maybetail(inds), Base.tail(I))...)
278+
end
279+
264280
using Base: @propagate_inbounds
265281
@propagate_inbounds function Base.getindex(b::AbstractVector, K::GenericBlockIndex{1})
266282
return b[Block(K.I[1])][K.α[1]]
@@ -276,35 +292,65 @@ end
276292
return b[GenericBlockIndex(tuple(K, J...))]
277293
end
278294

279-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N}
280-
return BlockIndex{N,NTuple{N,TB},Tuple{TI...}}
295+
# TODO: Delete this once `BlockArrays.BlockIndex` is generalized.
296+
@inline function Base.to_indices(
297+
A, inds, I::Tuple{AbstractVector{<:GenericBlockIndex{1}},Vararg{Any}}
298+
)
299+
return (unblock(A, inds, I), to_indices(A, _maybetail(inds), Base.tail(I))...)
281300
end
282-
function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N}
283-
return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}}
301+
302+
# This is a specialization of `BlockArrays.unblock`:
303+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L8-L11
304+
# that is used in the `to_indices` logic for blockwise slicing in
305+
# BlockArrays.jl.
306+
# TODO: Ideally this would be defined in BlockArrays.jl once the slicing
307+
# there is made more generic.
308+
function BlockArrays.unblock(A, inds, I::Tuple{GenericBlockIndex{1},Vararg{Any}})
309+
B = first(I)
310+
return _blockslice(B, inds[1][B])
284311
end
285312

286-
struct BlockIndexVector{N,I<:NTuple{N,AbstractVector},TB<:Integer,BT} <: AbstractArray{BT,N}
313+
# Work around the fact that it is type piracy to define
314+
# `Base.getindex(a::Block, b...)`.
315+
_getindex(a::Block{N}, b::Vararg{Any,N}) where {N} = GenericBlockIndex(a, b)
316+
_getindex(a::Block{N}, b::Vararg{Integer,N}) where {N} = a[b...]
317+
# Fix ambiguity.
318+
_getindex(a::Block{0}) = a[]
319+
320+
struct BlockIndexVector{N,BT,I<:NTuple{N,AbstractVector},TB<:Integer} <: AbstractArray{BT,N}
287321
block::Block{N,TB}
288322
indices::I
289-
function BlockIndexVector(
323+
function BlockIndexVector{N,BT}(
290324
block::Block{N,TB}, indices::I
291-
) where {N,I<:NTuple{N,AbstractVector},TB<:Integer}
292-
BT = blockindextype(TB, eltype.(indices)...)
293-
return new{N,I,TB,BT}(block, indices)
325+
) where {N,BT,I<:NTuple{N,AbstractVector},TB<:Integer}
326+
return new{N,BT,I,TB}(block, indices)
294327
end
295328
end
329+
function BlockIndexVector{1,BT}(block::Block{1}, indices::AbstractVector) where {BT}
330+
return BlockIndexVector{1,BT}(block, (indices,))
331+
end
332+
function BlockIndexVector(
333+
block::Block{N,TB}, indices::NTuple{N,AbstractVector}
334+
) where {N,TB<:Integer}
335+
BT = Base.promote_op(_getindex, typeof(block), eltype.(indices)...)
336+
return BlockIndexVector{N,BT}(block, indices)
337+
end
296338
function BlockIndexVector(block::Block{1}, indices::AbstractVector)
297339
return BlockIndexVector(block, (indices,))
298340
end
299341
Base.size(a::BlockIndexVector) = length.(a.indices)
300342
function Base.getindex(a::BlockIndexVector{N}, I::Vararg{Integer,N}) where {N}
301-
return a.block[map((r, i) -> r[i], a.indices, I)...]
343+
return _getindex(Block(a), getindex.(a.indices, I)...)
302344
end
303345
BlockArrays.block(b::BlockIndexVector) = b.block
304346
BlockArrays.Block(b::BlockIndexVector) = b.block
305347

306348
Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices))
307349

350+
function Base.getindex(b::AbstractBlockedUnitRange, Kkr::BlockIndexVector{1})
351+
return b[block(Kkr)][Kkr.indices...]
352+
end
353+
308354
using ArrayLayouts: LayoutArray
309355
@propagate_inbounds Base.getindex(b::AbstractArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block(
310356
K
@@ -316,6 +362,30 @@ using ArrayLayouts: LayoutArray
316362
K
317363
)][K.indices...]
318364

365+
function blockedunitrange_getindices(
366+
a::AbstractBlockedUnitRange,
367+
indices::BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
368+
)
369+
return mortar(map(b -> a[b], blocks(indices)))
370+
end
371+
function blockedunitrange_getindices(
372+
a::AbstractBlockedUnitRange,
373+
indices::BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},
374+
)
375+
return mortar(map(b -> a[b], blocks(indices)))
376+
end
377+
378+
# This is a specialization of `BlockArrays.unblock`:
379+
# https://github.com/JuliaArrays/BlockArrays.jl/blob/v1.6.3/src/views.jl#L8-L11
380+
# that is used in the `to_indices` logic for blockwise slicing in
381+
# BlockArrays.jl.
382+
# TODO: Ideally this would be defined in BlockArrays.jl once the slicing
383+
# there is made more generic.
384+
function BlockArrays.unblock(A, inds, I::Tuple{BlockIndexVector{1},Vararg{Any}})
385+
B = first(I)
386+
return _blockslice(B, inds[1][B])
387+
end
388+
319389
function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool})
320390
I_blocks = blocks(BlockedVector(I, blocklengths(a)))
321391
I′_blocks = map(eachindex(I_blocks)) do b

src/abstractblocksparsearray/views.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ end
9292
# TODO: Move to `GradedUnitRanges` or `BlockArraysExtensions`.
9393
to_block(I::Block{1}) = I
9494
to_block(I::BlockIndexRange{1}) = Block(I)
95-
to_block(I::BlockIndexVector) = Block(I)
95+
to_block(I::BlockIndexVector{1}) = Block(I)
9696
to_block_indices(I::Block{1}) = Colon()
9797
to_block_indices(I::BlockIndexRange{1}) = only(I.indices)
98-
to_block_indices(I::BlockIndexVector) = only(I.indices)
98+
to_block_indices(I::BlockIndexVector{1}) = only(I.indices)
9999

100100
function Base.view(
101101
a::AbstractBlockSparseArray{<:Any,N},
@@ -166,7 +166,7 @@ function BlockArrays.viewblock(
166166
<:AbstractBlockSparseArray{T,N},
167167
<:Tuple{Vararg{Union{BlockSliceCollection,SubBlockSliceCollection},N}},
168168
},
169-
block::Union{Block{N},BlockIndexRange{N}},
169+
block::Union{Block{N},BlockIndexRange{N},BlockIndexVector{N}},
170170
) where {T,N}
171171
return viewblock(a, to_tuple(block)...)
172172
end
@@ -228,6 +228,14 @@ function to_blockindexrange(
228228
# work right now.
229229
return blocks(a.blocks)[Int(I)]
230230
end
231+
function to_blockindexrange(
232+
a::BlockIndices{<:BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector}}},
233+
I::Block{1},
234+
)
235+
# TODO: Ideally we would just use `a.blocks[I]` but that doesn't
236+
# work right now.
237+
return blocks(a.blocks)[Int(I)]
238+
end
231239
function to_blockindexrange(
232240
a::Base.Slice{<:AbstractBlockedUnitRange{<:Integer}}, I::Block{1}
233241
)

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,13 +78,36 @@ function Base.to_indices(
7878
return @interface interface(a) to_indices(a, inds, I)
7979
end
8080

81+
# a[mortar([Block(1)[[1, 2]], Block(2)[[1, 3]]])]
82+
function Base.to_indices(
83+
a::AnyAbstractBlockSparseArray,
84+
inds,
85+
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
86+
)
87+
return @interface interface(a) to_indices(a, inds, I)
88+
end
89+
function Base.to_indices(
90+
a::AnyAbstractBlockSparseArray,
91+
inds,
92+
I::Tuple{BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
93+
)
94+
return @interface interface(a) to_indices(a, inds, I)
95+
end
96+
8197
# a[[Block(1)[1:2], Block(2)[1:2]], [Block(1)[1:2], Block(2)[1:2]]]
8298
function Base.to_indices(
8399
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:BlockIndexRange{1}},Vararg{Any}}
84100
)
85101
return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...))
86102
end
87103

104+
# a[[Block(1)[[1, 2]], Block(2)[[1, 2]]], [Block(1)[[1, 2]], Block(2)[[1, 2]]]]
105+
function Base.to_indices(
106+
a::AnyAbstractBlockSparseArray, inds, I::Tuple{Vector{<:BlockIndexVector{1}},Vararg{Any}}
107+
)
108+
return to_indices(a, inds, (mortar(I[1]), Base.tail(I)...))
109+
end
110+
88111
# BlockArrays `AbstractBlockArray` interface
89112
function BlockArrays.blocks(a::AnyAbstractBlockSparseArray)
90113
@interface interface(a) blocks(a)

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,23 @@ end
229229
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
230230
end
231231

232+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
233+
a,
234+
inds,
235+
I::Tuple{BlockVector{<:BlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
236+
)
237+
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
238+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
239+
end
240+
@interface ::AbstractBlockSparseArrayInterface function Base.to_indices(
241+
a,
242+
inds,
243+
I::Tuple{BlockVector{<:GenericBlockIndex{1},<:Vector{<:BlockIndexVector{1}}},Vararg{Any}},
244+
)
245+
I1 = BlockIndices(I[1], blockedunitrange_getindices(inds[1], I[1]))
246+
return (I1, to_indices(a, Base.tail(inds), Base.tail(I))...)
247+
end
248+
232249
# a[BlockVector([Block(2), Block(1)], [2]), BlockVector([Block(2), Block(1)], [2])]
233250
# Permute and merge blocks.
234251
# TODO: This isn't merging blocks yet, that needs to be implemented that.

0 commit comments

Comments
 (0)