Skip to content

Commit 2de9bae

Browse files
committed
CachedArrayStyle for broadcasting with cached arrays
1 parent 04d24fb commit 2de9bae

File tree

4 files changed

+59
-33
lines changed

4 files changed

+59
-33
lines changed

src/cache.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,11 @@ MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayo
333333
# to take advantage of special implementations of the sub-components
334334
######
335335

336-
BroadcastStyle(::Type{<:CachedArray{<:Any,N}}) where N = LazyArrayStyle{N}()
336+
struct CachedArrayStyle{N} <: AbstractArrayStyle{N} end
337+
338+
BroadcastStyle(::Type{<:AbstractCachedArray{<:Any,N}}) where N = CachedArrayStyle{N}()
339+
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}) where {N,M} = CachedArrayStyle{M}()
340+
337341

338342
broadcasted(::LazyArrayStyle, op, A::CachedArray) = CachedArray(broadcast(op, cacheddata(A)), broadcast(op, A.array))
339343
layout_broadcasted(::CachedLayout, _, op, A::AbstractArray, c::Number) = CachedArray(broadcast(op, cacheddata(A), c), broadcast(op, A.array, c))

src/lazybroadcasting.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,10 @@ _BroadcastArray(bc::Broadcasted) = BroadcastArray{combine_eltypes(bc.f, bc.args)
7474
BroadcastArray(bc::Broadcasted{S}) where S =
7575
_BroadcastArray(instantiate(Broadcasted{S}(bc.f, _broadcast2broadcastarray(bc.args...))))
7676

77-
BroadcastArray(f, A, As...) = BroadcastArray(broadcasted(f, A, As...))
77+
BroadcastArray(f, A, As...) = BroadcastArray{combine_eltypes(f, (A,As...))}(f, A, As...)
7878
BroadcastArray{T}(f, A, As...) where T = BroadcastArray{T}(instantiate(broadcasted(f, A, As...)))
79-
BroadcastMatrix(f, A...) = BroadcastMatrix(broadcasted(f, A...))
80-
BroadcastVector(f, A...) = BroadcastVector(broadcasted(f, A...))
79+
BroadcastMatrix(f, A...) = BroadcastMatrix{combine_eltypes(f, A)}(f, A...)
80+
BroadcastVector(f, A...) = BroadcastVector{combine_eltypes(f, A)}(f, A...)
8181

8282
BroadcastArray{T,N}(f, A...) where {T,N} = BroadcastArray{T,N,typeof(f),typeof(A)}(f, A)
8383

test/broadcasttests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,35 @@ using Infinities
449449
@test A*B Matrix(A)*Matrix(B)
450450
@test simplifiable(*,A,B) == Val(false) # TODO: Why False?
451451
end
452+
453+
@testset "misc tests" begin
454+
bc = broadcasted(exp,[1,2,3])
455+
v = BroadcastArray(exp, [1,2,3])
456+
@test BroadcastArray(bc) == BroadcastVector(bc) == BroadcastVector{Float64,typeof(exp),typeof(bc.args)}(bc) ==
457+
v == BroadcastVector(exp, [1,2,3]) == exp.([1,2,3])
458+
459+
@test Base.IndexStyle(typeof(BroadcastVector(exp, [1,2,3]))) == IndexCartesian()
460+
461+
bc = broadcasted(exp,[1 2; 3 4])
462+
M = BroadcastArray(exp, [1 2; 3 4])
463+
@test BroadcastArray(bc) == BroadcastMatrix(bc) == BroadcastMatrix{Float64,typeof(exp),typeof(bc.args)}(bc) ==
464+
M == BroadcastMatrix(BroadcastMatrix(bc)) == BroadcastMatrix(exp,[1 2; 3 4]) == exp.([1 2; 3 4])
465+
466+
@test exp.(v') isa Adjoint{<:Any,<:BroadcastVector}
467+
@test exp.(transpose(v)) isa Transpose{<:Any,<:BroadcastVector}
468+
@test exp.(M') isa Adjoint{<:Any,<:BroadcastMatrix}
469+
@test exp.(transpose(M)) isa Transpose{<:Any,<:BroadcastMatrix}
470+
471+
bc = BroadcastArray(broadcasted(+, 1:10, broadcasted(sin, 1:10)))
472+
@test bc[1:10] == (1:10) .+ sin.(1:10)
473+
474+
bc = BroadcastArray(broadcasted(+,1:10,broadcasted(+,1,2)))
475+
@test bc.args[2] == 3
476+
477+
@testset "_vec_mul_arguments method" begin
478+
@test_throws "MethodError: no method matching _vec_mul_arguments" LazyArrays._vec_mul_arguments(2, [])
479+
end
480+
end
452481
end
453482

454483
end #module

test/runtests.jl

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test, LinearAlgebra, LazyArrays, FillArrays, ArrayLayouts, SparseArrays
22
using StaticArrays
3-
import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted,
3+
import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted, resizedata!,
44
ApplyLayout, BroadcastLayout, AddArray, LazyLayout, PaddedLayout, PaddedRows, PaddedColumns
55
import ArrayLayouts: OnesLayout
66

@@ -379,6 +379,27 @@ end
379379
@test a[end] prod(1 .+ (1:10_000_000).^(-2.0))
380380
@test LazyArrays.AccumulateAbstractVector(*, 1:5) == Accumulate(*, 1:5)
381381
@test LazyArrays.AccumulateAbstractVector(*, 1:5) isa LazyArrays.AccumulateAbstractVector
382+
383+
@testset "Broadcasted Cached" begin
384+
a = Accumulate(*, 1:5)
385+
b = BroadcastVector(*, 2, a);
386+
387+
dest = Vector{Int}(undef, 3)
388+
copyto!(dest, view(b,1:3))
389+
390+
# lets step through the copyto! to reduce to MWE
391+
bc = LazyArrays._broadcastarray2broadcasted(view(b,1:3))
392+
# this is equivalent to
393+
v = view(a,1:3)
394+
bc = broadcasted(*, 2, v)
395+
396+
copyto!(dest, bc)
397+
398+
399+
# what we want:
400+
resizedata!(v, length(dest))
401+
copyto!(dest, broadcasted(*, 2, LazyArrays.cacheddata(v)))
402+
end
382403
end
383404
end
384405

@@ -441,34 +462,6 @@ end
441462
@test tril(A,1) isa ApplyMatrix{Float64,typeof(tril)}
442463
end
443464

444-
@testset "BroadcastArray" begin
445-
bc = broadcasted(exp,[1,2,3])
446-
v = BroadcastArray(exp, [1,2,3])
447-
@test BroadcastArray(bc) == BroadcastVector(bc) == BroadcastVector{Float64,typeof(exp),typeof(bc.args)}(bc) ==
448-
v == BroadcastVector(exp, [1,2,3]) == exp.([1,2,3])
449-
450-
Base.IndexStyle(typeof(BroadcastVector(exp, [1,2,3]))) == IndexLinear()
451-
452-
bc = broadcasted(exp,[1 2; 3 4])
453-
M = BroadcastArray(exp, [1 2; 3 4])
454-
@test BroadcastArray(bc) == BroadcastMatrix(bc) == BroadcastMatrix{Float64,typeof(exp),typeof(bc.args)}(bc) ==
455-
M == BroadcastMatrix(BroadcastMatrix(bc)) == BroadcastMatrix(exp,[1 2; 3 4]) == exp.([1 2; 3 4])
456-
457-
@test exp.(v') isa Adjoint{<:Any,<:BroadcastVector}
458-
@test exp.(transpose(v)) isa Transpose{<:Any,<:BroadcastVector}
459-
@test exp.(M') isa Adjoint{<:Any,<:BroadcastMatrix}
460-
@test exp.(transpose(M)) isa Transpose{<:Any,<:BroadcastMatrix}
461-
462-
bc = BroadcastArray(broadcasted(+, 1:10, broadcasted(sin, 1:10)))
463-
@test bc[1:10] == (1:10) .+ sin.(1:10)
464-
465-
bc = BroadcastArray(broadcasted(+,1:10,broadcasted(+,1,2)))
466-
@test bc.args[2] == 3
467-
468-
@testset "_vec_mul_arguments method" begin
469-
@test_throws "MethodError: no method matching _vec_mul_arguments" LazyArrays._vec_mul_arguments(2, [])
470-
end
471-
end
472465

473466
include("blocktests.jl")
474467
include("bandedtests.jl")

0 commit comments

Comments
 (0)