Skip to content

Commit bf6850b

Browse files
committed
upgrade, more matrices
1 parent 82ad403 commit bf6850b

File tree

4 files changed

+55
-6
lines changed

4 files changed

+55
-6
lines changed

src/projection.jl

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,16 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
343343
dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx)
344344
return adjoint(project.parent(dy))
345345
end
346+
# structural => natural standardisation, broadest possible signature
347+
function (project::ProjectTo{Adjoint})(dx::Tangent)
348+
if dx.parent isa Tangent
349+
# Can't wrap a structural representation of an array in an Adjoint:
350+
return dx
351+
else
352+
# This case should handle dx.parent isa AbstractZero, too
353+
return Adjoint(project.parent(dx.parent))
354+
end
355+
end
346356

347357
function ProjectTo(x::LinearAlgebra.TransposeAbsVec)
348358
return ProjectTo{Transpose}(; parent=ProjectTo(parent(x)))
@@ -357,14 +367,22 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
357367
dy = eltype(dx) <: Number ? vec(dx) : transpose(dx)
358368
return transpose(project.parent(dy))
359369
end
370+
function (project::ProjectTo{Transpose})(
371+
dx::Tangent{<:Transpose, <:NamedTuple{(:parent,), <:Tuple{AbstractVector}}},
372+
)
373+
return Transpose(project.parent(dx.parent))
374+
end
360375

361376
# Diagonal
362377
ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag))
363378
(project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx)))
364379
(project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag))
365-
366-
(project::ProjectTo{Diagonal})(dx::Tangent{T}) where T = (@show T; Diagonal(project.diag(dx.diag)))
367-
# (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal, NamedTuple{(:diag,), <:Tuple{AbstractVector}}}) = Diagonal(project.diag(@show dx.diag))
380+
# structural => natural standardisation, very conservative signature:
381+
function (project::ProjectTo{Diagonal})(
382+
dx::Tangent{<:Diagonal, <:NamedTuple{(:diag,), <:Tuple{AbstractVector}}},
383+
)
384+
return Diagonal(project.diag(dx.diag))
385+
end
368386

369387
# Symmetric
370388
for (SymHerm, chk, fun) in
@@ -383,6 +401,13 @@ for (SymHerm, chk, fun) in
383401
dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2
384402
return $SymHerm(project.parent(dz), project.uplo)
385403
end
404+
function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm})
405+
if dx.data isa Tangent
406+
return dx
407+
else
408+
return $SymHerm(project.parent(dx.data))
409+
end
410+
end
386411
# This is an example of a subspace which is not a subtype,
387412
# not clear how broadly it's worthwhile to try to support this.
388413
function (project::ProjectTo{$SymHerm})(dx::Diagonal)

src/tangent_types/abstract_zero.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,6 @@ Base.iterate(::AbstractZero, ::Any) = nothing
1919
Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x)
2020
Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T()
2121

22-
# Linear operators
23-
Base.adjoint(z::AbstractZero) = z
24-
Base.transpose(z::AbstractZero) = z
2522
Base.:/(z::AbstractZero, ::Any) = z
2623

2724
Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T)
@@ -36,6 +33,14 @@ Base.view(z::AbstractZero, ind...) = z
3633
Base.sum(z::AbstractZero; dims=:) = z
3734
Base.reshape(z::AbstractZero, size...) = z
3835

36+
# LinearAlgebra
37+
for f in (:adjoint, :transpose, :Adjoint, :Transpose, :Diagonal)
38+
@eval LinearAlgebra.$f(z::AbstractZero) = z
39+
end
40+
for f in (:Symmetric, :Hermitian)
41+
@eval LinearAlgebra.$f(z::AbstractZero, uplo=:U) = z
42+
end
43+
3944
"""
4045
ZeroTangent() <: AbstractZero
4146

test/projection.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,10 @@ struct NoSuperType end
187187
@test padj_complex(transpose([4, 5, 6 + 7im])) == [4 5 6 + 7im]
188188
@test padj_complex(adjoint([4, 5, 6 + 7im])) == [4 5 6 - 7im]
189189

190+
# structural => natural
191+
@test padj(Tangent{adjT}(; parent=ones(3) .+ im)) isa adjT
192+
@test_skip padj(Tangent{Any}(; parent=ones(3))) isa adjT # only for Adjoint now
193+
190194
# evil test case
191195
if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called
192196
xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]])
@@ -221,6 +225,10 @@ struct NoSuperType end
221225
@test psymm(psymm(reshape(1:9, 3, 3))) == psymm(reshape(1:9, 3, 3))
222226
@test psymm(rand(ComplexF32, 3, 3, 1)) isa Symmetric{Float64}
223227
@test ProjectTo(Symmetric(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() # Bool
228+
# structural => natural
229+
dx = Tangent{typeof(Symmetric(rand(3, 3)))}(; data=[1 2 3; 4 5 6; 7 8 9im])
230+
@test psymm(dx) isa Symmetric{Float64}
231+
@test psymm(Tangent{typeof(Symmetric(rand(3, 3)))}(; )) isa AbstractZero
224232

225233
pherm = ProjectTo(Hermitian(rand(3, 3) .+ im, :L))
226234
# NB, projection onto Hermitian subspace, not application of Hermitian constructor
@@ -247,6 +255,8 @@ struct NoSuperType end
247255
@test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0)
248256
@test ProjectTo(Diagonal(randn(3) .> 0))(randn(3, 3)) == NoTangent()
249257
@test ProjectTo(Diagonal(randn(3) .> 0))(Diagonal(rand(3))) == NoTangent()
258+
# structural => natural
259+
@test pdiag(Tangent{typeof(Diagonal(1:3))}(; diag=ones(3) .+ im)) isa Diagonal{Float64}
250260

251261
pbi = ProjectTo(Bidiagonal(rand(3, 3), :L))
252262
@test pbi(reshape(1:9, 3, 3)) == [1.0 0.0 0.0; 2.0 5.0 0.0; 0.0 6.0 9.0]

test/tangent_types/abstract_zero.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,13 @@
116116
@test convert(Int64, NoTangent()) == 0
117117
@test convert(Float64, NoTangent()) == 0.0
118118
end
119+
120+
@testset "LinearAlgebra constructors" begin
121+
@test adjoint(ZeroTangent()) === ZeroTangent()
122+
@test transpose(ZeroTangent()) === ZeroTangent()
123+
@test Adjoint(ZeroTangent()) === ZeroTangent()
124+
@test Transpose(ZeroTangent()) === ZeroTangent()
125+
@test Symmetric(ZeroTangent()) === ZeroTangent()
126+
@test Hermitian(ZeroTangent(), :U) === ZeroTangent()
127+
end
119128
end

0 commit comments

Comments
 (0)