Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/LazyArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module LazyArraysStaticArraysExt

using LazyArrays
using LazyArrays: LazyArrayStyle
using LazyArrays: AbstractLazyArrayStyle
using StaticArrays
using StaticArrays: StaticArrayStyle

Expand All @@ -10,6 +10,6 @@ function LazyArrays._vcat_layout_broadcasted((Ahead,Atail)::Tuple{SVector{M},Any
Vcat(op.(Ahead,Bhead), op.(Atail,Btail))
end

Base.BroadcastStyle(L::LazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L
Base.BroadcastStyle(L::AbstractLazyArrayStyle{N}, ::StaticArrayStyle{N}) where N = L

end
31 changes: 31 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@ MemoryLayout(C::Type{CachedArray{T,N,DAT,ARR}}) where {T,N,DAT,ARR} = cachedlayo
######

struct CachedArrayStyle{N} <: AbstractLazyArrayStyle{N} end
CachedArrayStyle(::Val{N}) where N = CachedArrayStyle{N}()
CachedArrayStyle{M}(::Val{N}) where {N,M} = CachedArrayStyle{N}()

BroadcastStyle(::Type{<:AbstractCachedArray{<:Any,N}}) where N = CachedArrayStyle{N}()
BroadcastStyle(::Type{<:SubArray{<:Any,N,<:AbstractCachedArray{<:Any,M}}}) where {N,M} = CachedArrayStyle{M}()
Expand Down Expand Up @@ -381,7 +383,36 @@ for op in (:*, :\, :+, :-)
@eval layout_broadcasted(::ZerosLayout, ::CachedLayout, ::typeof($op), a::AbstractVector, b::AbstractVector) = broadcast(DefaultArrayStyle{1}(), $op, a, b)
end

function resize_bcargs(bc::Broadcasted{<:CachedArrayStyle}, dest)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
function resize_bcargs(bc::Broadcasted{<:CachedArrayStyle}, dest)
function resize_bcargs!(bc::Broadcasted{<:CachedArrayStyle}, dest)

rsz_args = let len = length(dest)
map(bc.args) do arg
resizedata!(arg, len)
iscached = arg isa AbstractCachedArray || (arg isa SubArray && parent(arg) isa AbstractCachedArray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be all based on memory layouts not types.

iscached ? cacheddata(arg) : arg
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems very unlikely to be type stable. I think we want to rewrite this whole map in a functional programming style, something like:

_bc_resizecacheddata!(n) = ()
_bc_resizecacheddata!(n, a, b...) = __bc_resizecacheddata!(n, MemoryLayout(a), a, b...)
__bc_resizecacheddata!(n, _, a, b...) = (a, _bc_resizecacheddata!(b...))
function __bc_resizecacheddata!(n, ::AbstractCachedLayout, a::AbstractVector, b...)
    resizedata!(a, n)
    (view(cacheddata(a), 1:n), _bc_resizecacheddata!(b...))
end

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that since bc.args is a Tuple these should be functionally equivalent? But I'll change it anyway to not rely on that

end
end
return broadcasted(bc.f, rsz_args...)
end

function similar(bc::Broadcasted{<:CachedArrayStyle}, ::Type{T}) where T
return CachedArray(zeros(T, axes(bc)))
end

function copyto!(dest::AbstractArray, bc::Broadcasted{<:CachedArrayStyle})
#=
Without flatten, we were observing some stack overflows in some cases for nested broadcasts, e.g.
using SemiclassicalOrthogonalPolynomials, ClassicalOrthogonalPolynomials
Q = Normalized(Legendre())
P = SemiclassicalOrthogonalPolynomials.RaisedOP(Q, 1)
A, = ClassicalOrthogonalPolynomials.recurrencecoefficients(Q)
d = -inv(A[1] * SemiclassicalOrthogonalPolynomials._p0(Q) * P.ℓ[1])
κ = d * SemiclassicalOrthogonalPolynomials.normalizationconstant(1, P)
κ[1:2]
leads to a stack overflow.
=#
rsz_bc = resize_bcargs(Base.Broadcast.flatten(bc), dest)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not convinced flatten is preferable to a recursive call. I.e. something like

_bc_resizedata!(n, a::Broadcasted, b...) = (broadcasted(a.f, _bc_resizedata!(n, a.args...)...), _bc_resizedata!(n, b...))

But I think it is fine to flatten for now and come back later if there's an issue.

copyto!(dest, rsz_bc)
end

###
# norm
Expand Down
35 changes: 34 additions & 1 deletion test/cachetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,6 @@ using Infinities
@test a[1:5] == zeros(5)
end


@testset "Issue #327" begin
A = cache(Zeros((1:5, OneToInf())))
B = cache(Zeros((1:5, OneToInf())))
Expand Down Expand Up @@ -487,6 +486,40 @@ using Infinities
B[5, 7] = 3.4
@test A == B
end

@testset "copyto! with CachedArrayStyle" begin
a = Accumulate(*, 1:5);
b = BroadcastVector(*, 2, a);
dest = Vector{Int}(undef, 3)
src = view(b, 1:3)
bc = LazyArrays._broadcastarray2broadcasted(src);
@test similar(bc, Float32) == cache(zeros(Float32, 3)) && similar(bc, Float32) isa CachedArray{Float32}
@test a.datasize == (1,)
@inferred LazyArrays.resize_bcargs(bc, dest);
@test a.datasize == (3,)
dest = Vector{Int}(undef, 1)
src = view(b, 5:5);
bc = LazyArrays._broadcastarray2broadcasted(src);
@inferred LazyArrays.resize_bcargs(bc, dest);
@test a.datasize == (5,)

a = Accumulate(*, 1:5); # reset to test different resizing
b = BroadcastVector(*, 2, a);
dest = Vector{Int}(undef, 4)
src = view(b,2:5)
bc = LazyArrays._broadcastarray2broadcasted(src);
rbc = LazyArrays.resize_bcargs(bc, dest);
@test Base.Broadcast.BroadcastStyle(typeof(rbc)) == Base.Broadcast.DefaultArrayStyle{1}()
@test rbc.f === bc.f
@test rbc.args == (2, a[2:5])

a = Accumulate(*, 1:5); # reset to ensure copyto! is working as intended
b = BroadcastVector(*, 2, a);
dest = Vector{Int}(undef, 3);
src = view(b,2:4);
copyto!(dest, src)
@test dest == [4,12,48]
end
end

end # module
23 changes: 1 addition & 22 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -379,27 +379,6 @@ end
@test a[end] ≈ prod(1 .+ (1:10_000_000).^(-2.0))
@test LazyArrays.AccumulateAbstractVector(*, 1:5) == Accumulate(*, 1:5)
@test LazyArrays.AccumulateAbstractVector(*, 1:5) isa LazyArrays.AccumulateAbstractVector

@testset "Broadcasted Cached" begin
a = Accumulate(*, 1:5)
b = BroadcastVector(*, 2, a);

dest = Vector{Int}(undef, 3)
copyto!(dest, view(b,1:3))

# lets step through the copyto! to reduce to MWE
bc = LazyArrays._broadcastarray2broadcasted(view(b,1:3))
# this is equivalent to
v = view(a,1:3)
bc = broadcasted(*, 2, v)

copyto!(dest, bc)


# what we want:
resizedata!(v, length(dest))
copyto!(dest, broadcasted(*, 2, LazyArrays.cacheddata(v)))
end
end
end

Expand Down Expand Up @@ -465,4 +444,4 @@ end

include("blocktests.jl")
include("bandedtests.jl")
include("blockbandedtests.jl")
include("blockbandedtests.jl")
Loading