diff --git a/src/SciMLOperators.jl b/src/SciMLOperators.jl index 4409e920..5c7aa847 100644 --- a/src/SciMLOperators.jl +++ b/src/SciMLOperators.jl @@ -87,7 +87,8 @@ export AffineOperator, AddVector, FunctionOperator, - TensorProductOperator + TensorProductOperator, + ConcretizedOperator export update_coefficients!, update_coefficients, diff --git a/src/basic.jl b/src/basic.jl index 3e5dd74f..8bd5067a 100644 --- a/src/basic.jl +++ b/src/basic.jl @@ -38,6 +38,14 @@ has_mul!(::IdentityOperator) = true has_ldiv(::IdentityOperator) = true has_ldiv!(::IdentityOperator) = true +function concretize!(A, L::IdentityOperator, α, β) + @assert size(A) == (L.len, L.len) + A .*= β + for i in 1:L.len + A[i, i] += α + end +end + # opeator application for op in ( :*, :\, @@ -259,6 +267,10 @@ has_mul!(L::ScaledOperator) = has_mul!(L.L) has_ldiv(L::ScaledOperator) = has_ldiv(L.L) & !iszero(L.λ) has_ldiv!(L::ScaledOperator) = has_ldiv!(L.L) & !iszero(L.λ) +function concretize!(A, L::ScaledOperator{T}, α, β) where {T} + concretize!(A, L.L, convert(Number, L.λ) * α, β) +end + function cache_internals(L::ScaledOperator, u::AbstractVecOrMat) @set! L.L = cache_operator(L.L, u) @set! L.λ = cache_operator(L.λ, u) @@ -420,6 +432,13 @@ function update_coefficients(L::AddedOperator, u, p, t) @set! L.ops = ops end +function concretize!(A, L::AddedOperator{T}, α, β) where {T} + A .*= β + for op in L.ops + concretize!(A, op, α, one(T)) + end +end + getops(L::AddedOperator) = L.ops islinear(L::AddedOperator) = all(islinear, getops(L)) Base.iszero(L::AddedOperator) = all(iszero, getops(L)) @@ -819,4 +838,100 @@ function LinearAlgebra.ldiv!(L::InvertedOperator, u::AbstractVecOrMat) copy!(L.cache, u) mul!(u, L.L, L.cache) end + +struct ConcretizedOperator{T, LType, AType} <: AbstractSciMLOperator{T} + L::LType + A::AType + function ConcretizedOperator(L::AbstractSciMLOperator{T}, A::AbstractMatrix) where {T} + new{T,typeof(L),typeof(A)}(L, A) + end +end + + +""" + ConcretizedOperator(L::AbstractSciMLOperator) + +Concretization of a SciMLOperator `L`, with a concrete backing `A = concretize(L)` that is used +for all linear algebra operations. Unlike `A` itself, a concretized operator correctly supports +updates to the operator state, by first updating `L` and then updating `A` accordingly. +""" +function ConcretizedOperator(L::AbstractSciMLOperator{T}) where {T} + ConcretizedOperator(L, convert(AbstractMatrix, L)) +end + +Base.convert(::Type{AbstractMatrix}, L::ConcretizedOperator) = L.A + +function Base.show(io::IO, L::ConcretizedOperator) + print(io, "ConcretizedOperator(") + show(io, L.L) + print(io, ")") +end +Base.size(L::ConcretizedOperator) = size(L.A) +function Base.resize!(L::ConcretizedOperator, n::Integer) + resize!(L.L, n) + resize!(L.cache, n) + # TODO: these next two lines seem dangerous... in which cases do they make sense? + resize!(L.A, n) + concretize!(L.A, L.L) +end + +function update_coefficients(L::ConcretizedOperator, u, p, t; kwargs...) + @set! L.L = update_coefficients(L.L, u, p, t; kwargs...) + @set! L.A = concretize(L.L) +end + +function update_coefficients!(L::ConcretizedOperator, u, p, t; kwargs...) + for op in getops(L) + update_coefficients!(op, u, p, t; kwargs...) + end + concretize!(L.A, L.L) # TODO: this needs to be supported. Also, problematic if L.A is scalar, should we only support matrix? +end + +getops(L::ConcretizedOperator) = (L.L,) +islinear(L::ConcretizedOperator) = islinear(L.L) +isconvertible(::ConcretizedOperator) = true + +for op in ( + :adjoint, + :transpose, + :conj, + ) + @eval Base.$op(L::ConcretizedOperator) = ConcretizedOperator($op(L.L), $op(L.A)) +end + +@forward ConcretizedOperator.L ( + # LinearAlgebra + LinearAlgebra.issymmetric, + LinearAlgebra.ishermitian, + LinearAlgebra.isposdef, + LinearAlgebra.opnorm, + + # SciML + isconstant, + has_adjoint, + has_mul, + has_mul!, + has_ldiv, + has_ldiv! + ) + +Base.:*(L::ConcretizedOperator, u::AbstractVecOrMat) = L.A * u +Base.:\(L::ConcretizedOperator, u::AbstractVecOrMat) = L.A \ u + +function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ConcretizedOperator, u::AbstractVecOrMat) + mul!(v, L.A, u) +end + +function LinearAlgebra.mul!(v::AbstractVecOrMat, L::ConcretizedOperator, u::AbstractVecOrMat, α, β) + mul!(v, L.A, u, α, β) +end + +function LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::ConcretizedOperator, u::AbstractVecOrMat) + ldiv!(v, L.A, u) +end + +function LinearAlgebra.ldiv!(L::ConcretizedOperator, u::AbstractVecOrMat) + ldiv!(L.A, u) +end + # diff --git a/src/interface.jl b/src/interface.jl index 8437f0f3..778e25f1 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -311,6 +311,21 @@ concretize(L::Union{ } ) = convert(Number, L) +function concretize!(A, L) + T = eltype(L) + concretize!(A, L, one(T), zero(T)) +end + +function concretize!(A, L::AbstractMatrix, α, β) + A .= α .* L .+ β .* A +end + +function concretize!(A, L::Union{Factorization, AbstractSciMLOperator}, α, β) + @warn """using concretize-based fallback for concretize! for $(typeof(L))""" + # TODO: could also use a mul! based fallback on the unit vectors + concretize!(A, concretize(L), α, β) +end + """ $SIGNATURES diff --git a/src/matrix.jl b/src/matrix.jl index 30687f2c..2a1daf3b 100644 --- a/src/matrix.jl +++ b/src/matrix.jl @@ -159,6 +159,10 @@ function update_coefficients!(L::MatrixOperator, u, p, t; kwargs...) L.update_func!(L.A, u, p, t; kwargs...) end +function concretize!(A, L::MatrixOperator, α, β) + return concretize!(A, L.A, α, β) +end + SparseArrays.sparse(L::MatrixOperator) = sparse(L.A) SparseArrays.issparse(L::MatrixOperator) = issparse(L.A) diff --git a/test/basic.jl b/test/basic.jl index 89e5d919..38a952c6 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -326,4 +326,16 @@ end v=rand(N); @test ldiv!(v, Di, u) ≈ u .* s v=copy(u); @test ldiv!(Di, u) ≈ v .* s end + +@testset "ConcretizedOperator" begin + A = rand(2, 2); B = rand(2, 2); + L = MatrixOperator(A; update_func=(u,p,t)->t * A) + 2 * MatrixOperator(B; update_func=(u,p,t)->t * B) + I + C = ConcretizedOperator(L) + v = rand(2) + C * v ≈ L * v + @test C.A ≈ convert(AbstractMatrix, L) + update_coefficients!(C, nothing, nothing, 2.0) + C * v ≈ L * v + @test C.A ≈ convert(AbstractMatrix, L) +end #