Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added @materialize convenience DSL #91

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/LazyArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,11 @@ end

export Mul, Applied, MulArray, MulVector, MulMatrix, InvMatrix, PInvMatrix,
Hcat, Vcat, Kron, BroadcastArray, BroadcastMatrix, BroadcastVector, cache, Ldiv, Inv, PInv, Diff, Cumsum,
applied, materialize, materialize!, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray
applied, materialize, materialize!, @materialize, ApplyArray, ApplyMatrix, ApplyVector, apply, ⋆, @~, LazyArray


include("lazyapplying.jl")
include("materialize_dsl.jl")
include("lazybroadcasting.jl")
include("linalg/linalg.jl")
include("cache.jl")
Expand Down
124 changes: 124 additions & 0 deletions src/materialize_dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# For unparametrized destination types
generate_copyto!_signature(dest, dest_type::Symbol, Msig) =
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)))

# For parametrized destination types
function generate_copyto!_signature(dest, dest_type::Expr, Msig)
dest_type.head == :curly ||
throw(ArgumentError("Invalid destination specification $(dest)::$(dest_type)"))
:(Base.copyto!($(dest)::$(dest_type), applied_obj::$(Msig)) where {$(dest_type.args[2:end]...)})
end

function generate_copyto!(body, factor_names, Msig)
body.head == :(->) ||
throw(ArgumentError("Invalid copyto! specification"))
body.args[1].head == :(::) ||
throw(ArgumentError("Invalid destination specification $(body.args[1])"))
(dest,dest_type) = body.args[1].args
copyto!_signature = generate_copyto!_signature(dest, dest_type, Msig)
f_body = quote
axes($dest) == axes(applied_obj) || throw(DimensionMismatch("axes must be same"))
$(factor_names) = applied_obj.args
$(body.args[2].args...)
$(dest)
end
Expr(:function, copyto!_signature, f_body)
end

"""
@materialize function op(args...)

This macro simplifies the setup of a few functions necessary for the
materialization of [`Applied`](@ref) objects:

- `ApplyStyle`, used to ensure dispatch of the applied object to the
routines below

- `copyto!(dest::DestType, applied_obj::Applied{...,op})` performs the
actual materialization of `applied_obj` into the destination object
which has been generated by

- `similar` which usually returns a suitable matrix

- `materialize` which makes use of the above functions

# Example

```julia
@materialize function *(Ac::MyAdjointBasis,
O::MyOperator,
B::MyBasis)
MyApplyStyle # An instance of this type will be returned by ApplyStyle
T -> begin # generates similar
A = parent(Ac)
parent(A) == parent(B) ||
throw(ArgumentError("Incompatible bases"))

# There may be different matrices best representing different
# situations:
if ...
Diagonal(Vector{T}(undef, size(B,1)))
else
Tridiagonal(Vector{T}(undef, size(B,1)-1),
Vector{T}(undef, size(B,1)),
Vector{T}(undef, size(B,1)-1))
end
end
dest::Diagonal{T} -> begin # generate copyto!(dest::Diagonal{T}, ...) where T
dest.diag .= 1
end
dest::Tridiagonal{T} -> begin # generate copyto!(dest::Tridiagonal{T}, ...) where T
dest.dl .= -2
dest.ev .= 1
dest.du .= 3
end
end
```
"""
macro materialize(expr)
expr.head == :function || expr.head == :(=) || error("Must start with a function")
@assert expr.args[1].head == :call
op = expr.args[1].args[1]

bodies = filter(e -> !(e isa LineNumberNode), expr.args[2].args)
length(bodies) < 3 &&
throw(ArgumentError("At least three blocks required (ApplyStyle, similar, and at least one copyto!)"))

factor_types = :(<:Tuple{})
factor_names = :(())
apply_style = first(bodies)
apply_style_fun = :(LazyArrays.ApplyStyle(::typeof($op)) = $(apply_style)())

# Generate Applied signature
for arg in expr.args[1].args[2:end]
arg isa Expr && arg.head == :(::) ||
throw(ArgumentError("Invalid argument specification $(arg)"))
arg_name, arg_typ = arg.args
push!(factor_types.args[1].args, :(<:$(arg_typ)))
push!(factor_names.args, arg_name)
push!(apply_style_fun.args[1].args, :(::Type{<:$(arg_typ)}))
end
Msig = :(LazyArrays.Applied{$(apply_style), typeof($op), $(factor_types)})

sim_body = bodies[2]
sim_body.head == :(->) ||
throw(ArgumentError("Invalid similar specification"))
T = first(sim_body.args)

copytos! = map(body -> generate_copyto!(body, factor_names, Msig), bodies[3:end])

f = quote
$(apply_style_fun)

function Base.similar(applied_obj::$Msig, ::Type{$T}=eltype(applied_obj)) where $T
$(factor_names) = applied_obj.args
$(sim_body.args[2])
end

$(copytos!...)

LazyArrays.materialize(applied_obj::$Msig) =
copyto!(similar(applied_obj, eltype(applied_obj)), applied_obj)
end
esc(f)
end
59 changes: 59 additions & 0 deletions test/materialize_dsl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
struct MyOperator{T}
n::Int
kind::Symbol
end

Base.axes(O::MyOperator) = (Base.OneTo(O.n),Base.OneTo(O.n))
Base.axes(O::MyOperator,i) = axes(O)[i]
Base.size(O::MyOperator) = (O.n,O.n)
Base.eltype(::MyOperator{T}) where T = T

struct MyApplyStyle <: ApplyStyle end

@materialize function *(Ac::Adjoint{<:Any,<:AbstractMatrix},
O::MyOperator,
B::AbstractMatrix)
MyApplyStyle
T -> begin
A = parent(Ac)

if O.kind == :diagonal
Diagonal(Vector{T}(undef, O.n))
else
Tridiagonal(Vector{T}(undef, O.n-1),
Vector{T}(undef, O.n),
Vector{T}(undef, O.n-1))
end
end
dest::Diagonal -> begin
dest.diag .= 1
end
dest::Tridiagonal{T} -> begin
dest.dl .= -2
dest.d .= 1
dest.du .= 3
end
end

@testset "Materialize DSL" begin
o = ones(10)
M = ones(10,10)
D = MyOperator{Float64}(10, :diagonal)
T = MyOperator{ComplexF64}(10, :tridiagonal)

@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(D), typeof(M)) == MyApplyStyle()
@test LazyArrays.ApplyStyle(*, typeof(M'), typeof(T), typeof(M)) == MyApplyStyle()

d = apply(*, M', D, M)
@test d isa Diagonal{Float64}
@test all(d.diag .== 1)

t = apply(*, M', T, M)
@test t isa Tridiagonal
@test all(t.dl .== -2)
@test all(t.d .== 1)
@test all(t.du .== 3)

M̃ = ones(11,11)
@test_throws DimensionMismatch apply(*, M̃', D, M̃)
end
7 changes: 5 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test, LinearAlgebra, LazyArrays, StaticArrays, FillArrays, ArrayLayouts
import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcasted,
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout
PaddedLayout, ApplyLayout, BroadcastLayout, AddArray, LazyLayout,
ApplyStyle

@testset "Lazy MemoryLayout" begin
@testset "ApplyArray" begin
Expand All @@ -25,6 +26,7 @@ import LazyArrays: CachedArray, colsupport, rowsupport, LazyArrayStyle, broadcas
end
end
include("applytests.jl")
include("materialize_dsl.jl")
include("multests.jl")
include("ldivtests.jl")
include("addtests.jl")
Expand Down Expand Up @@ -224,4 +226,5 @@ end

bc = BroadcastArray(broadcasted(+,1:10,broadcasted(+,1,2)))
@test bc.args[2] == 3
end
end