Skip to content

Commit 3519be9

Browse files
committed
make AbstractLazyArrayStyle
1 parent 1cac65f commit 3519be9

File tree

5 files changed

+46
-43
lines changed

5 files changed

+46
-43
lines changed

ext/LazyArraysBandedMatricesExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ArrayLayouts: colsupport, rowsupport, materialize!, MatMulVecAdd, MatMulM
66
OnesLayout, AbstractFillLayout, mulreduce, inv_layout, _fill_lmul!, copyto!_layout, _copy_oftype,
77
layout_getindex, transtype
88
import LazyArrays: sublayout, symmetriclayout, hermitianlayout, applylayout, cachedlayout, transposelayout,
9-
LazyArrayStyle, ApplyArrayBroadcastStyle, AbstractInvLayout, AbstractLazyLayout, LazyLayouts,
9+
LazyArrayStyle, AbstractLazyArrayStyle, ApplyArrayBroadcastStyle, AbstractInvLayout, AbstractLazyLayout, LazyLayouts,
1010
AbstractPaddedLayout, PaddedLayout, AbstractLazyBandedLayout, LazyBandedLayout, PaddedRows,
1111
PaddedColumns, CachedArray, CachedMatrix, LazyLayout, BroadcastLayout, ApplyLayout,
1212
paddeddata, resizedata!, broadcastlayout, _broadcastarray2broadcasted, _broadcast_sub_arguments,
@@ -24,12 +24,12 @@ hermitianlayout(::Type{<:Real}, ::AbstractLazyBandedLayout) = SymmetricLayout{La
2424
hermitianlayout(::Type{<:Complex}, ::AbstractLazyBandedLayout) = HermitianLayout{LazyBandedLayout}()
2525

2626

27-
bandedbroadcaststyle(::LazyArrayStyle) = LazyArrayStyle{2}()
27+
bandedbroadcaststyle(::AbstractLazyArrayStyle) = LazyArrayStyle{2}()
2828

29-
BroadcastStyle(::LazyArrayStyle{1}, ::BandedStyle) = LazyArrayStyle{2}()
30-
BroadcastStyle(::BandedStyle, ::LazyArrayStyle{1}) = LazyArrayStyle{2}()
31-
BroadcastStyle(::LazyArrayStyle{2}, ::BandedStyle) = LazyArrayStyle{2}()
32-
BroadcastStyle(::BandedStyle, ::LazyArrayStyle{2}) = LazyArrayStyle{2}()
29+
BroadcastStyle(::AbstractLazyArrayStyle{1}, ::BandedStyle) = LazyArrayStyle{2}()
30+
BroadcastStyle(::BandedStyle, ::AbstractLazyArrayStyle{1}) = LazyArrayStyle{2}()
31+
BroadcastStyle(::AbstractLazyArrayStyle{2}, ::BandedStyle) = LazyArrayStyle{2}()
32+
BroadcastStyle(::BandedStyle, ::AbstractLazyArrayStyle{2}) = LazyArrayStyle{2}()
3333

3434
bandedcolumns(::AbstractLazyLayout) = BandedColumns{LazyLayout}()
3535
bandedcolumns(::DualLayout{<:AbstractLazyLayout}) = BandedColumns{LazyLayout}()
@@ -287,10 +287,10 @@ copyto!_layout(_, ::BroadcastBandedLayout, dest::AbstractMatrix, bc::AbstractMat
287287
# _banded_broadcast!(dest::AbstractMatrix, f, (A,B)::Tuple{AbstractMatrix{T},AbstractMatrix{V}}, _, ::Tuple{<:Any,ApplyBandedLayout{typeof(*)}}) where {T,V} =
288288
# broadcast!(f, dest, BandedMatrix(A), BandedMatrix(B))
289289

290-
broadcasted(::LazyArrayStyle, ::typeof(*), c::Number, A::BandedMatrix) = _BandedMatrix(c .* A.data, A.raxis, A.l, A.u)
291-
broadcasted(::LazyArrayStyle, ::typeof(*), A::BandedMatrix, c::Number) = _BandedMatrix(A.data .* c, A.raxis, A.l, A.u)
292-
broadcasted(::LazyArrayStyle, ::typeof(\), c::Number, A::BandedMatrix) = _BandedMatrix(c .\ A.data, A.raxis, A.l, A.u)
293-
broadcasted(::LazyArrayStyle, ::typeof(/), A::BandedMatrix, c::Number) = _BandedMatrix(A.data ./ c, A.raxis, A.l, A.u)
290+
broadcasted(::AbstractLazyArrayStyle, ::typeof(*), c::Number, A::BandedMatrix) = _BandedMatrix(c .* A.data, A.raxis, A.l, A.u)
291+
broadcasted(::AbstractLazyArrayStyle, ::typeof(*), A::BandedMatrix, c::Number) = _BandedMatrix(A.data .* c, A.raxis, A.l, A.u)
292+
broadcasted(::AbstractLazyArrayStyle, ::typeof(\), c::Number, A::BandedMatrix) = _BandedMatrix(c .\ A.data, A.raxis, A.l, A.u)
293+
broadcasted(::AbstractLazyArrayStyle, ::typeof(/), A::BandedMatrix, c::Number) = _BandedMatrix(A.data ./ c, A.raxis, A.l, A.u)
294294

295295

296296
copy(M::Mul{BroadcastBandedLayout{typeof(*)}, <:Union{PaddedColumns,PaddedLayout}}) = _broadcast_banded_padded_mul(arguments(BroadcastBandedLayout{typeof(*)}(), M.A), M.B)

src/cache.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -333,13 +333,14 @@ MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayo
333333
# to take advantage of special implementations of the sub-components
334334
######
335335

336-
struct CachedArrayStyle{N} <: AbstractArrayStyle{N} end
336+
struct CachedArrayStyle{N} <: AbstractLazyArrayStyle{N} end
337337

338338
BroadcastStyle(::Type{<:AbstractCachedArray{<:Any,N}}) where N = CachedArrayStyle{N}()
339339
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}) where {N,M} = CachedArrayStyle{M}()
340+
BroadcastStyle(::CachedArrayStyle{N}, ::LazyArrayStyle{M}) where {N,M} = CachedArrayStyle{max(M, N)}()
340341

341342

342-
broadcasted(::LazyArrayStyle, op, A::CachedArray) = CachedArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
343+
broadcasted(::AbstractLazyArrayStyle, op, A::CachedArray) = CachedArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
343344
layout_broadcasted(::CachedLayout, _, op, A::AbstractArray, c::Number) = CachedArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
344345
layout_broadcasted(_, ::CachedLayout, op, c::Number, A::CachedArray) = CachedArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
345346
layout_broadcasted(::CachedLayout, _, op, A::CachedArray, c::Ref) = CachedArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
@@ -501,19 +502,19 @@ CachedAbstractVector(array::AbstractVector{T}) where T = CachedAbstractVector{T}
501502
CachedAbstractMatrix(array::AbstractMatrix{T}) where T = CachedAbstractMatrix{T}(array)
502503

503504

504-
broadcasted(::LazyArrayStyle, op, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
505-
function broadcasted(::LazyArrayStyle, op, A::CachedAbstractVector, B::CachedAbstractVector)
505+
broadcasted(::AbstractLazyArrayStyle, op, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
506+
function broadcasted(::AbstractLazyArrayStyle, op, A::CachedAbstractVector, B::CachedAbstractVector)
506507
n = max(A.datasize[1],B.datasize[1])
507508
resizedata!(A,n)
508509
resizedata!(B,n)
509510
Adat = view(cacheddata(A),1:n)
510511
Bdat = view(cacheddata(B),1:n)
511512
CachedAbstractArray(broadcast(op, Adat, Bdat), broadcast(op, A.array, B.array))
512513
end
513-
broadcasted(::LazyArrayStyle, op, A::CachedAbstractArray, c::Number) = CachedAbstractArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
514-
broadcasted(::LazyArrayStyle, op, c::Number, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
515-
broadcasted(::LazyArrayStyle, op, A::CachedAbstractArray, c::Ref) = CachedAbstractArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
516-
broadcasted(::LazyArrayStyle, op, c::Ref, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
514+
broadcasted(::AbstractLazyArrayStyle, op, A::CachedAbstractArray, c::Number) = CachedAbstractArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
515+
broadcasted(::AbstractLazyArrayStyle, op, c::Number, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
516+
broadcasted(::AbstractLazyArrayStyle, op, A::CachedAbstractArray, c::Ref) = CachedAbstractArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))
517+
broadcasted(::AbstractLazyArrayStyle, op, c::Ref, A::CachedAbstractArray) = CachedAbstractArray(broadcast(op, c, cacheddata(A)), broadcast(op, c, A.array))
517518

518519

519520
###

src/lazybroadcasting.jl

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
struct LazyArrayStyle{N} <: AbstractArrayStyle{N} end
1+
abstract type AbstractLazyArrayStyle{N} <: AbstractArrayStyle{N} end
2+
3+
struct LazyArrayStyle{N} <: AbstractLazyArrayStyle{N} end
24
LazyArrayStyle(::Val{N}) where N = LazyArrayStyle{N}()
35
LazyArrayStyle{M}(::Val{N}) where {N,M} = LazyArrayStyle{N}()
46

@@ -7,8 +9,8 @@ LazyArrayStyle{M}(::Val{N}) where {N,M} = LazyArrayStyle{N}()
79
layout_broadcasted(_, _, op, A, B) = Base.Broadcast.Broadcasted(Base.Broadcast.combine_styles(A,B), op, (A, B))
810
layout_broadcasted(op, A, B) = layout_broadcasted(MemoryLayout(A), MemoryLayout(B), op, A, B)
911

10-
DefaultArrayStyle(::LazyArrayStyle{N}) where N = DefaultArrayStyle{N}()
11-
broadcasted(::LazyArrayStyle, op, A, B) = layout_broadcasted(op, A, B)
12+
DefaultArrayStyle(::AbstractLazyArrayStyle{N}) where N = DefaultArrayStyle{N}()
13+
broadcasted(::AbstractLazyArrayStyle, op, A, B) = layout_broadcasted(op, A, B)
1214

1315
for op in (:*, :/, :+, :-)
1416
@eval layout_broadcasted(::ZerosLayout, _, ::typeof($op), a, b) = broadcasted(DefaultArrayStyle(Base.Broadcast.combine_styles(a,b)), $op, a, b)
@@ -116,7 +118,7 @@ converteltype(::Type{T}, A::AbstractArray) where T = convert(AbstractArray{T}, A
116118
converteltype(::Type{T}, A) where T = convert(T, A)
117119
sub_materialize(::BroadcastLayout, A) = converteltype(eltype(A), sub_materialize(_broadcasted(A)))
118120

119-
copy(bc::Broadcasted{<:LazyArrayStyle}) = BroadcastArray(bc)
121+
copy(bc::Broadcasted{<:AbstractLazyArrayStyle}) = BroadcastArray(bc)
120122

121123
# BroadcastArray are immutable
122124
copy(bc::BroadcastArray) = bc
@@ -159,37 +161,37 @@ BroadcastStyle(::Type{<:UpperOrLowerTriangular{<:Any,<:LazyMatrix}}) = LazyArray
159161
BroadcastStyle(::Type{<:LinearAlgebra.HermOrSym{<:Any,<:LazyMatrix}}) = LazyArrayStyle{2}()
160162

161163

162-
BroadcastStyle(L::LazyArrayStyle{N}, ::StructuredMatrixStyle) where N = L
164+
BroadcastStyle(L::AbstractLazyArrayStyle{N}, ::StructuredMatrixStyle) where N = L
163165

164166

165167

166168
## scalar-range broadcast operations ##
167169
# Ranges already support smart broadcasting
168170
for op in (+, -, big)
169171
@eval begin
170-
broadcasted(::LazyArrayStyle{1}, ::typeof($op), r::AbstractRange) =
172+
broadcasted(::AbstractLazyArrayStyle{1}, ::typeof($op), r::AbstractRange) =
171173
broadcast(DefaultArrayStyle{1}(), $op, r)
172174
end
173175
end
174176

175177
for op in (-, +, *, /)
176-
@eval broadcasted(::LazyArrayStyle{1}, ::typeof($op), r::AbstractRange, x::Real) = broadcast(DefaultArrayStyle{1}(), $op, r, x)
178+
@eval broadcasted(::AbstractLazyArrayStyle{1}, ::typeof($op), r::AbstractRange, x::Real) = broadcast(DefaultArrayStyle{1}(), $op, r, x)
177179
end
178180

179181
for op in (-, +, *, \)
180-
@eval broadcasted(::LazyArrayStyle{1}, ::typeof($op), x::Real, r::AbstractRange) = broadcast(DefaultArrayStyle{1}(), $op, x, r)
182+
@eval broadcasted(::AbstractLazyArrayStyle{1}, ::typeof($op), x::Real, r::AbstractRange) = broadcast(DefaultArrayStyle{1}(), $op, x, r)
181183
end
182184

183-
broadcasted(::LazyArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r)
184-
broadcasted(::LazyArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r, x)
185-
broadcasted(::LazyArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, x, r)
186-
broadcasted(::LazyArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r, x)
187-
broadcasted(::LazyArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, x, r)
188-
broadcasted(::LazyArrayStyle{N}, op, r1::AbstractFill{T,N}, r2::AbstractFill{V,N}) where {T,V,N} = broadcast(DefaultArrayStyle{N}(), op, r1, r2)
189-
broadcasted(::LazyArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) = broadcast(DefaultArrayStyle{1}(), *, a, b)
190-
broadcasted(::LazyArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) = broadcast(DefaultArrayStyle{1}(), *, a, b)
191-
broadcasted(::LazyArrayStyle{1}, ::typeof(*), a::Zeros{<:Any,1}, b::AbstractRange) = broadcast(DefaultArrayStyle{1}(), *, a, b)
192-
broadcasted(::LazyArrayStyle{1}, ::typeof(*), a::AbstractRange, b::Zeros{<:Any,1}) = broadcast(DefaultArrayStyle{1}(), *, a, b)
185+
broadcasted(::AbstractLazyArrayStyle{N}, op, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r)
186+
broadcasted(::AbstractLazyArrayStyle{N}, op, r::AbstractFill{T,N}, x::Number) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r, x)
187+
broadcasted(::AbstractLazyArrayStyle{N}, op, x::Number, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, x, r)
188+
broadcasted(::AbstractLazyArrayStyle{N}, op, r::AbstractFill{T,N}, x::Ref) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, r, x)
189+
broadcasted(::AbstractLazyArrayStyle{N}, op, x::Ref, r::AbstractFill{T,N}) where {T,N} = broadcast(DefaultArrayStyle{N}(), op, x, r)
190+
broadcasted(::AbstractLazyArrayStyle{N}, op, r1::AbstractFill{T,N}, r2::AbstractFill{V,N}) where {T,V,N} = broadcast(DefaultArrayStyle{N}(), op, r1, r2)
191+
broadcasted(::AbstractLazyArrayStyle{1}, ::typeof(*), a::AbstractFill, b::AbstractRange) = broadcast(DefaultArrayStyle{1}(), *, a, b)
192+
broadcasted(::AbstractLazyArrayStyle{1}, ::typeof(*), a::AbstractRange, b::AbstractFill) = broadcast(DefaultArrayStyle{1}(), *, a, b)
193+
broadcasted(::AbstractLazyArrayStyle{1}, ::typeof(*), a::Zeros{<:Any,1}, b::AbstractRange) = broadcast(DefaultArrayStyle{1}(), *, a, b)
194+
broadcasted(::AbstractLazyArrayStyle{1}, ::typeof(*), a::AbstractRange, b::Zeros{<:Any,1}) = broadcast(DefaultArrayStyle{1}(), *, a, b)
193195

194196

195197
###
@@ -308,8 +310,8 @@ arguments(b::BroadcastLayout, A::Transpose) = map(_transpose, arguments(b, paren
308310

309311
# broadcasting a transpose is the same as broadcasting it to the array and transposing
310312
# this allows us to collapse to one broadcast.
311-
broadcasted(::LazyArrayStyle, op, A::Transpose{<:Any,<:BroadcastArray}) = transpose(broadcast(op, parent(A)))
312-
broadcasted(::LazyArrayStyle, op, A::Adjoint{<:Real,<:BroadcastArray}) = adjoint(broadcast(op, parent(A)))
313+
broadcasted(::AbstractLazyArrayStyle, op, A::Transpose{<:Any,<:BroadcastArray}) = transpose(broadcast(op, parent(A)))
314+
broadcasted(::AbstractLazyArrayStyle, op, A::Adjoint{<:Real,<:BroadcastArray}) = adjoint(broadcast(op, parent(A)))
313315

314316
# ensure we benefit from fast linear indexing
315317
getindex(A::Transpose{<:Any,<:BroadcastVector}, k::AbstractVector) = parent(A)[k]

src/lazyconcat.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,9 @@ _flatten_nums(args::Tuple{}, bc::Tuple{}) = ()
450450
_flatten_nums(args::Tuple, bc::Tuple) = (bc[1], _flatten_nums(tail(args), tail(bc))...)
451451
_flatten_nums(args::Tuple{Number, Vararg{Any}}, bc::Tuple{AbstractArray, Vararg{Any}}) = (Fill(bc[1],1), _flatten_nums(tail(args), tail(bc))...)
452452

453-
broadcasted(::LazyArrayStyle, op, A::Vcat) = Vcat(_flatten_nums(A.args, broadcast(x -> broadcast(op, x), A.args))...)
454-
broadcasted(::LazyArrayStyle, op, A::Transpose{<:Any,<:Vcat}) = transpose(broadcast(op, parent(A)))
455-
broadcasted(::LazyArrayStyle, op, A::Adjoint{<:Real,<:Vcat}) = broadcast(op, parent(A))'
453+
broadcasted(::AbstractLazyArrayStyle, op, A::Vcat) = Vcat(_flatten_nums(A.args, broadcast(x -> broadcast(op, x), A.args))...)
454+
broadcasted(::AbstractLazyArrayStyle, op, A::Transpose{<:Any,<:Vcat}) = transpose(broadcast(op, parent(A)))
455+
broadcasted(::AbstractLazyArrayStyle, op, A::Adjoint{<:Real,<:Vcat}) = broadcast(op, parent(A))'
456456

457457

458458
for Cat in (:vcat, :hcat)

src/linalg/mul.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,11 @@ permutedims(A::ApplyArray{<:Any,2,typeof(*)}) = ApplyArray(*, reverse(map(permut
284284
##
285285

286286
for op in (:*, :\)
287-
@eval broadcasted(::LazyArrayStyle{N}, ::typeof($op), a::Number, b::ApplyArray{<:Number,N,typeof(*)}) where N =
287+
@eval broadcasted(::AbstractLazyArrayStyle{N}, ::typeof($op), a::Number, b::ApplyArray{<:Number,N,typeof(*)}) where N =
288288
ApplyArray(*, broadcast($op,a,first(b.args)), tail(b.args)...)
289289
end
290290

291-
broadcasted(::LazyArrayStyle{N}, ::typeof(/), b::ApplyArray{<:Number,N,typeof(*)}, a::Number) where N =
291+
broadcasted(::AbstractLazyArrayStyle{N}, ::typeof(/), b::ApplyArray{<:Number,N,typeof(*)}, a::Number) where N =
292292
ApplyArray(*, Base.front(b.args)..., broadcast(/,last(b.args),a))
293293

294294
for Typ in (:Lmul, :Rmul)

0 commit comments

Comments
 (0)