diff --git a/Project.toml b/Project.toml index efc4a6c..6b542ee 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockSparseArrays" uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4" authors = ["ITensor developers and contributors"] -version = "0.7.19" +version = "0.7.20" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/BlockArraysExtensions/blockedunitrange.jl b/src/BlockArraysExtensions/blockedunitrange.jl index 3577905..0856238 100644 --- a/src/BlockArraysExtensions/blockedunitrange.jl +++ b/src/BlockArraysExtensions/blockedunitrange.jl @@ -216,22 +216,111 @@ function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::UnitRange{<: ) end -struct BlockIndexVector{T<:Integer,I<:AbstractVector{T},TB<:Integer} <: - AbstractVector{BlockIndex{1,Tuple{TB},Tuple{T}}} - block::Block{1,TB} +struct GenericBlockIndex{N,TI<:Tuple{Vararg{Integer,N}},Tα<:Tuple{Vararg{Any,N}}} + I::TI + α::Tα +end +@inline function GenericBlockIndex(a::NTuple{N,Block{1}}, b::Tuple) where {N} + return GenericBlockIndex(Int.(a), b) +end +@inline function GenericBlockIndex(::Tuple{}, b::Tuple{}) + return GenericBlockIndex{0,Tuple{},Tuple{}}((), ()) +end +@inline GenericBlockIndex(a::Integer, b) = GenericBlockIndex((a,), (b,)) +@inline GenericBlockIndex(a::Tuple, b) = GenericBlockIndex(a, (b,)) +@inline GenericBlockIndex(a::Integer, b::Tuple) = GenericBlockIndex((a,), b) +@inline GenericBlockIndex() = GenericBlockIndex((), ()) +@inline GenericBlockIndex(a::Block, b::Tuple) = GenericBlockIndex(a.n, b) +@inline GenericBlockIndex(a::Block, b) = GenericBlockIndex(a, (b,)) +@inline function GenericBlockIndex( + I::Tuple{Vararg{Integer,N}}, α::Tuple{Vararg{Any,M}} +) where {M,N} + M <= N || throw(ArgumentError("number of indices must not exceed the number of blocks")) + α2 = ntuple(k -> k <= M ? α[k] : 1, N) + GenericBlockIndex(I, α2) +end +BlockArrays.block(b::GenericBlockIndex) = Block(b.I...) +BlockArrays.blockindex(b::GenericBlockIndex{1}) = b.α[1] +function GenericBlockIndex(indcs::Tuple{Vararg{GenericBlockIndex{1},N}}) where {N} + GenericBlockIndex(block.(indcs), blockindex.(indcs)) +end +function print_tuple_elements(io::IO, @nospecialize(t)) + if !isempty(t) + print(io, t[1]) + for n in t[2:end] + print(io, ", ", n) + end + end + return nothing +end +function Base.show(io::IO, B::GenericBlockIndex) + show(io, Block(B.I...)) + print(io, "[") + print_tuple_elements(io, B.α) + print(io, "]") + return nothing +end + +using Base: @propagate_inbounds +@propagate_inbounds function Base.getindex(b::AbstractVector, K::GenericBlockIndex{1}) + return b[Block(K.I[1])][K.α[1]] +end +@propagate_inbounds function Base.getindex( + b::AbstractArray{T,N}, K::GenericBlockIndex{N} +) where {T,N} + return b[block(K)][K.α...] +end +@propagate_inbounds function Base.getindex( + b::AbstractArray, K::GenericBlockIndex{1}, J::GenericBlockIndex{1}... +) + return b[GenericBlockIndex(tuple(K, J...))] +end + +function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type{<:Integer},N}) where {N} + return BlockIndex{N,NTuple{N,TB},Tuple{TI...}} +end +function blockindextype(TB::Type{<:Integer}, TI::Vararg{Type,N}) where {N} + return GenericBlockIndex{N,NTuple{N,TB},Tuple{TI...}} +end + +struct BlockIndexVector{N,I<:NTuple{N,AbstractVector},TB<:Integer,BT} <: AbstractArray{BT,N} + block::Block{N,TB} indices::I + function BlockIndexVector( + block::Block{N,TB}, indices::I + ) where {N,I<:NTuple{N,AbstractVector},TB<:Integer} + BT = blockindextype(TB, eltype.(indices)...) + return new{N,I,TB,BT}(block, indices) + end +end +function BlockIndexVector(block::Block{1}, indices::AbstractVector) + return BlockIndexVector(block, (indices,)) +end +Base.size(a::BlockIndexVector) = length.(a.indices) +function Base.getindex(a::BlockIndexVector{N}, I::Vararg{Integer,N}) where {N} + return a.block[map((r, i) -> r[i], a.indices, I)...] end -Base.length(a::BlockIndexVector) = length(a.indices) -Base.size(a::BlockIndexVector) = (length(a),) -BlockArrays.Block(a::BlockIndexVector) = a.block -Base.getindex(a::BlockIndexVector, I::Integer) = Block(a)[a.indices[I]] -Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy(a.indices)) +BlockArrays.block(b::BlockIndexVector) = b.block +BlockArrays.Block(b::BlockIndexVector) = b.block + +Base.copy(a::BlockIndexVector) = BlockIndexVector(a.block, copy.(a.indices)) + +using ArrayLayouts: LayoutArray +@propagate_inbounds Base.getindex(b::AbstractArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block( + K +)][K.indices...] +@propagate_inbounds Base.getindex(b::LayoutArray{T,N}, K::BlockIndexVector{N}) where {T,N} = b[block( + K +)][K.indices...] +@propagate_inbounds Base.getindex(b::LayoutArray{T,1}, K::BlockIndexVector{1}) where {T} = b[block( + K +)][K.indices...] function to_blockindices(a::AbstractBlockedUnitRange{<:Integer}, I::AbstractArray{Bool}) I_blocks = blocks(BlockedVector(I, blocklengths(a))) I′_blocks = map(eachindex(I_blocks)) do b I_b = findall(I_blocks[b]) - BlockIndexVector(Block(b), I_b) + return BlockIndexVector(Block(b), I_b) end return mortar(filter(!isempty, I′_blocks)) end diff --git a/src/abstractblocksparsearray/views.jl b/src/abstractblocksparsearray/views.jl index 4a40947..1bfe48f 100644 --- a/src/abstractblocksparsearray/views.jl +++ b/src/abstractblocksparsearray/views.jl @@ -95,7 +95,7 @@ to_block(I::BlockIndexRange{1}) = Block(I) to_block(I::BlockIndexVector) = Block(I) to_block_indices(I::Block{1}) = Colon() to_block_indices(I::BlockIndexRange{1}) = only(I.indices) -to_block_indices(I::BlockIndexVector) = I.indices +to_block_indices(I::BlockIndexVector) = only(I.indices) function Base.view( a::AbstractBlockSparseArray{<:Any,N},