From f2eb9cf763552427962e223ddcc01ee6362d1566 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 22 Aug 2021 11:23:08 -0400 Subject: [PATCH 01/21] convert Tangent to Diagonal --- src/projection.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index 8eba26353..43fea0692 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -421,6 +421,9 @@ ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) +(project::ProjectTo{Diagonal})(dx::Tangent{T}) where T = (@show T; Diagonal(project.diag(dx.diag))) +# (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal, NamedTuple{(:diag,), <:Tuple{AbstractVector}}}) = Diagonal(project.diag(@show dx.diag)) + # Symmetric for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian, :ishermitian, :adjoint)) @@ -455,6 +458,7 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT @eval begin ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) + # Another subspace which is not a subtype, like Diagonal inside Symmetric above, equally unsure function (project::ProjectTo{$UL})(dx::Diagonal) sub = project.parent sub_one = ProjectTo{project_type(sub)}(; @@ -462,6 +466,8 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT ) return Diagonal(sub_one(dx.diag)) end + # Convert "structural" `Tangent`s to array-like "natural" tangents + (project::ProjectTo{$UL})(dx::Tangent{<:$UL, NamedTuple{(:data,), <:Tuple{AbstractMatrix}}}) = $UL(dx.data) end end From 57a658bf0b4a75568d8178208b3b1f9ee8d94093 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 15 Oct 2021 16:20:40 -0400 Subject: [PATCH 02/21] upgrade, more matrices --- src/projection.jl | 31 ++++++++++++++++++++++++++--- src/tangent_types/abstract_zero.jl | 14 +++++++++---- test/projection.jl | 10 ++++++++++ test/tangent_types/abstract_zero.jl | 9 +++++++++ 4 files changed, 57 insertions(+), 7 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 43fea0692..bd507b984 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -401,6 +401,16 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) dy = eltype(dx) <: Real ? vec(dx) : adjoint(dx) return adjoint(project.parent(dy)) end +# structural => natural standardisation, broadest possible signature +function (project::ProjectTo{Adjoint})(dx::Tangent) + if dx.parent isa Tangent + # Can't wrap a structural representation of an array in an Adjoint: + return dx + else + # This case should handle dx.parent isa AbstractZero, too + return Adjoint(project.parent(dx.parent)) + end +end function ProjectTo(x::LinearAlgebra.TransposeAbsVec) return ProjectTo{Transpose}(; parent=ProjectTo(parent(x))) @@ -415,14 +425,22 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) dy = eltype(dx) <: Number ? vec(dx) : transpose(dx) return transpose(project.parent(dy)) end +function (project::ProjectTo{Transpose})( + dx::Tangent{<:Transpose, <:NamedTuple{(:parent,), <:Tuple{AbstractVector}}}, + ) + return Transpose(project.parent(dx.parent)) +end # Diagonal ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) - -(project::ProjectTo{Diagonal})(dx::Tangent{T}) where T = (@show T; Diagonal(project.diag(dx.diag))) -# (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal, NamedTuple{(:diag,), <:Tuple{AbstractVector}}}) = Diagonal(project.diag(@show dx.diag)) +# structural => natural standardisation, very conservative signature: +function (project::ProjectTo{Diagonal})( + dx::Tangent{<:Diagonal, <:NamedTuple{(:diag,), <:Tuple{AbstractVector}}}, + ) + return Diagonal(project.diag(dx.diag)) +end # Symmetric for (SymHerm, chk, fun) in @@ -441,6 +459,13 @@ for (SymHerm, chk, fun) in dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2 return $SymHerm(project.parent(dz), project.uplo) end + function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm}) + if dx.data isa Tangent + return dx + else + return $SymHerm(project.parent(dx.data)) + end + end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index e52d84819..9aa19b36c 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -19,9 +19,6 @@ Base.iterate(::AbstractZero, ::Any) = nothing Base.Broadcast.broadcastable(x::AbstractZero) = Ref(x) Base.Broadcast.broadcasted(::Type{T}) where {T<:AbstractZero} = T() -# Linear operators -Base.adjoint(z::AbstractZero) = z -Base.transpose(z::AbstractZero) = z Base.:/(z::AbstractZero, ::Any) = z Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) @@ -37,7 +34,16 @@ Base.sum(z::AbstractZero; dims=:) = z Base.reshape(z::AbstractZero, size...) = z Base.reverse(z::AbstractZero, args...; kwargs...) = z -(::Type{<:UniformScaling})(z::AbstractZero) = z +# LinearAlgebra +LinearAlgebra.adjoint(z::AbstractZero, ind...) = z +LinearAlgebra.transpose(z::AbstractZero, ind...) = z + +for T in (:UniformScaling, :Adjoint, :Transpose, :Diagonal) + @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z +end +for T in (:Symmetric, :Hermitian) + @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero, uplo=:U) = z +end """ ZeroTangent() <: AbstractZero diff --git a/test/projection.jl b/test/projection.jl index 3e70772ac..c4978aca9 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -232,6 +232,10 @@ struct NoSuperType end @test padj_complex(transpose([4, 5, 6 + 7im])) == [4 5 6 + 7im] @test padj_complex(adjoint([4, 5, 6 + 7im])) == [4 5 6 - 7im] + # structural => natural + @test padj(Tangent{adjT}(; parent=ones(3) .+ im)) isa adjT + @test_skip padj(Tangent{Any}(; parent=ones(3))) isa adjT # only for Adjoint now + # evil test case if VERSION >= v"1.7-" # up to 1.6 Vector[[1,2,3]]' is an error, not sure why it's called xs = adj(Any[Any[1, 2, 3], Any[4 + im, 5 - im, 6 + im, 7 - im]]) @@ -266,6 +270,10 @@ struct NoSuperType end @test psymm(psymm(reshape(1:9, 3, 3))) == psymm(reshape(1:9, 3, 3)) @test psymm(rand(ComplexF32, 3, 3, 1)) isa Symmetric{Float64} @test ProjectTo(Symmetric(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() # Bool + # structural => natural + dx = Tangent{typeof(Symmetric(rand(3, 3)))}(; data=[1 2 3; 4 5 6; 7 8 9im]) + @test psymm(dx) isa Symmetric{Float64} + @test psymm(Tangent{typeof(Symmetric(rand(3, 3)))}(; )) isa AbstractZero pherm = ProjectTo(Hermitian(rand(3, 3) .+ im, :L)) # NB, projection onto Hermitian subspace, not application of Hermitian constructor @@ -292,6 +300,8 @@ struct NoSuperType end @test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0) @test ProjectTo(Diagonal(randn(3) .> 0))(randn(3, 3)) == NoTangent() @test ProjectTo(Diagonal(randn(3) .> 0))(Diagonal(rand(3))) == NoTangent() + # structural => natural + @test pdiag(Tangent{typeof(Diagonal(1:3))}(; diag=ones(3) .+ im)) isa Diagonal{Float64} pbi = ProjectTo(Bidiagonal(rand(3, 3), :L)) @test pbi(reshape(1:9, 3, 3)) == [1.0 0.0 0.0; 2.0 5.0 0.0; 0.0 6.0 9.0] diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 43e433b6c..004ae0f02 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -149,4 +149,13 @@ end @test isempty(detect_ambiguities(M)) end + + @testset "LinearAlgebra constructors" begin + @test adjoint(ZeroTangent()) === ZeroTangent() + @test transpose(ZeroTangent()) === ZeroTangent() + @test Adjoint(ZeroTangent()) === ZeroTangent() + @test Transpose(ZeroTangent()) === ZeroTangent() + @test Symmetric(ZeroTangent()) === ZeroTangent() + @test Hermitian(ZeroTangent(), :U) === ZeroTangent() + end end From 920bea7ac5bd7d8969d136058a68a58f5c61523a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 15 Oct 2021 16:21:01 -0400 Subject: [PATCH 03/21] allow getindex to have several indices --- src/tangent_types/abstract_zero.jl | 3 +-- test/tangent_types/abstract_zero.jl | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 9aa19b36c..a062a0131 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -27,8 +27,7 @@ Base.convert(::Type{T}, x::AbstractZero) where {T<:Number} = zero(T) (::Type{Complex})(x::AbstractZero, y::Real) = Complex(false, y) (::Type{Complex})(x::Real, y::AbstractZero) = Complex(x, false) -Base.getindex(z::AbstractZero, args...) = z - +Base.getindex(z::AbstractZero, ind...) = z Base.view(z::AbstractZero, ind...) = z Base.sum(z::AbstractZero; dims=:) = z Base.reshape(z::AbstractZero, size...) = z diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 004ae0f02..7ee896b16 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -5,6 +5,9 @@ end @testset "Linear operators" begin + @test getindex(ZeroTangent(), 1) === ZeroTangent() + @test getindex(NoTangent(), 1, 2) === NoTangent() + @test view(ZeroTangent(), 1) == ZeroTangent() @test view(NoTangent(), 1, 2) == NoTangent() From 47abb7bc90063249506282dfe6fcdc0f6040b97b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 15 Oct 2021 20:39:12 -0400 Subject: [PATCH 04/21] more forgiving signatures --- src/projection.jl | 26 +++++++++----------------- src/tangent_types/abstract_zero.jl | 5 ++++- 2 files changed, 13 insertions(+), 18 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index bd507b984..3a3c32175 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -425,21 +425,16 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) dy = eltype(dx) <: Number ? vec(dx) : transpose(dx) return transpose(project.parent(dy)) end -function (project::ProjectTo{Transpose})( - dx::Tangent{<:Transpose, <:NamedTuple{(:parent,), <:Tuple{AbstractVector}}}, - ) - return Transpose(project.parent(dx.parent)) +function (project::ProjectTo{Transpose})(dx::Tangent) # structural => natural + return dx.parent isa Tangent ? dx : Transpose(project.parent(dx.parent)) end # Diagonal ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) -# structural => natural standardisation, very conservative signature: -function (project::ProjectTo{Diagonal})( - dx::Tangent{<:Diagonal, <:NamedTuple{(:diag,), <:Tuple{AbstractVector}}}, - ) - return Diagonal(project.diag(dx.diag)) +function (project::ProjectTo{Diagonal})(dx::Tangent) # structural => natural + return dx.diag isa Tangent ? dx.diag : Diagonal(project.diag(dx.diag)) end # Symmetric @@ -459,12 +454,8 @@ for (SymHerm, chk, fun) in dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2 return $SymHerm(project.parent(dz), project.uplo) end - function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm}) - if dx.data isa Tangent - return dx - else - return $SymHerm(project.parent(dx.data)) - end + function (project::ProjectTo{$SymHerm})(dx::Tangent) # structural => natural + return dx.data isa Tangent ? dx : $SymHerm(project.parent(dx.data), project.uplo) end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. @@ -491,8 +482,9 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT ) return Diagonal(sub_one(dx.diag)) end - # Convert "structural" `Tangent`s to array-like "natural" tangents - (project::ProjectTo{$UL})(dx::Tangent{<:$UL, NamedTuple{(:data,), <:Tuple{AbstractMatrix}}}) = $UL(dx.data) + function (project::ProjectTo{$UL})(dx::Tangent) # structural => natural + return dx.data isa Tangent ? dx : $UL(project.parent(dx.data), project.uplo) + end end end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index a062a0131..a3a1710ff 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -37,7 +37,10 @@ Base.reverse(z::AbstractZero, args...; kwargs...) = z LinearAlgebra.adjoint(z::AbstractZero, ind...) = z LinearAlgebra.transpose(z::AbstractZero, ind...) = z -for T in (:UniformScaling, :Adjoint, :Transpose, :Diagonal) +for T in ( + :UniformScaling, :Adjoint, :Transpose, :Diagonal + :UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular, + ) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z end for T in (:Symmetric, :Hermitian) From af192e434a73862fe73c90cfe0076d98aaf3ff10 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 15 Oct 2021 21:22:48 -0400 Subject: [PATCH 05/21] parent -> data, to match field name --- src/projection.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 3a3c32175..6bcc4c789 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -445,22 +445,22 @@ for (SymHerm, chk, fun) in sub = ProjectTo(parent(x)) # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial: sub isa ProjectTo{<:AbstractZero} && return sub - return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), parent=sub) + return ProjectTo{$SymHerm}(; uplo=LinearAlgebra.sym_uplo(x.uplo), data=sub) end function (project::ProjectTo{$SymHerm})(dx::AbstractArray) - dy = project.parent(dx) + dy = project.data(dx) # Here $chk means this is efficient on same-type. # If we could mutate dx, then that could speed up action on dx::Matrix. dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2 - return $SymHerm(project.parent(dz), project.uplo) + return $SymHerm(project.data(dz), project.uplo) end function (project::ProjectTo{$SymHerm})(dx::Tangent) # structural => natural - return dx.data isa Tangent ? dx : $SymHerm(project.parent(dx.data), project.uplo) + return dx.data isa Tangent ? dx : $SymHerm(project.data(dx.data), project.uplo) end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. function (project::ProjectTo{$SymHerm})(dx::Diagonal) - sub = project.parent # this is going to be unhappy about the size + sub = project.data # this is going to be unhappy about the size sub_one = ProjectTo{project_type(sub)}(; element=sub.element, axes=(sub.axes[1],) ) @@ -472,19 +472,19 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg @eval begin - ProjectTo(x::$UL) = ProjectTo{$UL}(; parent=ProjectTo(parent(x))) - (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.parent(dx)) + ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) + (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) + function (project::ProjectTo{$UL})(dx::Tangent) # structural => natural + return dx.data isa Tangent ? dx : $UL(project.data(dx.data), project.uplo) + end # Another subspace which is not a subtype, like Diagonal inside Symmetric above, equally unsure function (project::ProjectTo{$UL})(dx::Diagonal) - sub = project.parent + sub = project.data sub_one = ProjectTo{project_type(sub)}(; element=sub.element, axes=(sub.axes[1],) ) return Diagonal(sub_one(dx.diag)) end - function (project::ProjectTo{$UL})(dx::Tangent) # structural => natural - return dx.data isa Tangent ? dx : $UL(project.parent(dx.data), project.uplo) - end end end From 8579febdfa0d86fe48d8d2b41829fe61e055bf66 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 15 Oct 2021 21:23:18 -0400 Subject: [PATCH 06/21] allow any Tangent, always --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 6bcc4c789..4a15ccd53 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -140,7 +140,7 @@ end # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as # dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through: -(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx +(::ProjectTo)(dx::Tangent) = dx ##### ##### `Base` From 7a49da92de42b921cd0d78e4b1010bca9e8a8718 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Oct 2021 09:29:38 -0400 Subject: [PATCH 07/21] restore Tangent{<:T} constraints --- src/projection.jl | 37 +++++++++++++++++++----------- src/tangent_types/abstract_zero.jl | 3 +++ 2 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 4a15ccd53..7512c1765 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -140,7 +140,7 @@ end # Tangent # We haven't entirely figured out when to convert Tangents to "natural" representations such as # dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through: -(::ProjectTo)(dx::Tangent) = dx +(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx ##### ##### `Base` @@ -380,6 +380,8 @@ end using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec +const ArrayOrZero = Union{AbstractArray, AbstractZero} + # UniformScaling can represent its own cotangent ProjectTo(x::UniformScaling) = ProjectTo{UniformScaling}(; λ=ProjectTo(x.λ)) ProjectTo(x::UniformScaling{Bool}) = ProjectTo(false) @@ -402,13 +404,12 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray) return adjoint(project.parent(dy)) end # structural => natural standardisation, broadest possible signature -function (project::ProjectTo{Adjoint})(dx::Tangent) - if dx.parent isa Tangent +function (project::ProjectTo{Adjoint})(dx::Tangent{<:Adjoint}) + if dx.parent isa ArrayOrZero + return Adjoint(project.parent(dx.parent)) + else # Can't wrap a structural representation of an array in an Adjoint: return dx - else - # This case should handle dx.parent isa AbstractZero, too - return Adjoint(project.parent(dx.parent)) end end @@ -425,16 +426,16 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray) dy = eltype(dx) <: Number ? vec(dx) : transpose(dx) return transpose(project.parent(dy)) end -function (project::ProjectTo{Transpose})(dx::Tangent) # structural => natural - return dx.parent isa Tangent ? dx : Transpose(project.parent(dx.parent)) +function (project::ProjectTo{Transpose})(dx::Tangent{<:Transpose}) # structural => natural + return dx.parent isa ArrayOrZero ? Transpose(project.parent(dx.parent)) : dx end # Diagonal ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) -function (project::ProjectTo{Diagonal})(dx::Tangent) # structural => natural - return dx.diag isa Tangent ? dx.diag : Diagonal(project.diag(dx.diag)) +function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural + dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx end # Symmetric @@ -454,8 +455,8 @@ for (SymHerm, chk, fun) in dz = $chk(dy) ? dy : (dy .+ $fun(dy)) ./ 2 return $SymHerm(project.data(dz), project.uplo) end - function (project::ProjectTo{$SymHerm})(dx::Tangent) # structural => natural - return dx.data isa Tangent ? dx : $SymHerm(project.data(dx.data), project.uplo) + function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm}) # structural => natural + dx.data isa ArrayOrZero ? $SymHerm(project.data(dx.data), project.uplo) : dx end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. @@ -474,8 +475,8 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT @eval begin ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) - function (project::ProjectTo{$UL})(dx::Tangent) # structural => natural - return dx.data isa Tangent ? dx : $UL(project.data(dx.data), project.uplo) + function (project::ProjectTo{$UL})(dx::Tangent{<:$UL}) # structural => natural + dx.data isa ArrayOrZero ? $UL(project.data(dx.data), project.uplo) : dx end # Another subspace which is not a subtype, like Diagonal inside Symmetric above, equally unsure function (project::ProjectTo{$UL})(dx::Diagonal) @@ -507,6 +508,14 @@ function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal) return Bidiagonal(dv, ev, uplo) end end +function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal}) # structural => natural + if dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero + # possibly the various cases should live here, not as methods of constructor? + Bidiagonal(project.dv(dx.dv), project.ev(dx.ev), project.uplo) + else + dx + end +end ProjectTo(x::SymTridiagonal{T}) where {T<:Number} = generic_projector(x) function (project::ProjectTo{SymTridiagonal})(dx::AbstractMatrix) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index a3a1710ff..ff737d088 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -46,6 +46,9 @@ end for T in (:Symmetric, :Hermitian) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero, uplo=:U) = z end +LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent() +LinearAlgebra.Bidiagonal(dv::AbstractArray, ev::AbstractZero, uplo::Symbol) = Bidiagonal(dv, zero(dv)[1:end-1], uplo) +LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractArray, uplo::Symbol) = Bidiagonal(vcat(zero(ev), false), ev, uplo) """ ZeroTangent() <: AbstractZero From fcdfcbc6907a7cac21ec6d38057e5d8a080f9cdd Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 18 Oct 2021 09:39:40 -0400 Subject: [PATCH 08/21] return, etc --- src/projection.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 7512c1765..5861e6d96 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -406,9 +406,11 @@ end # structural => natural standardisation, broadest possible signature function (project::ProjectTo{Adjoint})(dx::Tangent{<:Adjoint}) if dx.parent isa ArrayOrZero + # Adjoint handles ZeroTangent, which could also be produced by project.parent return Adjoint(project.parent(dx.parent)) else - # Can't wrap a structural representation of an array in an Adjoint: + # Can't wrap a structural representation, or a thunk, in an Adjoint. + # But do these happen? return dx end end @@ -435,7 +437,7 @@ ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural - dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx + return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx end # Symmetric @@ -456,7 +458,7 @@ for (SymHerm, chk, fun) in return $SymHerm(project.data(dz), project.uplo) end function (project::ProjectTo{$SymHerm})(dx::Tangent{<:$SymHerm}) # structural => natural - dx.data isa ArrayOrZero ? $SymHerm(project.data(dx.data), project.uplo) : dx + return dx.data isa ArrayOrZero ? $SymHerm(project.data(dx.data), project.uplo) : dx end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. @@ -476,7 +478,7 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) function (project::ProjectTo{$UL})(dx::Tangent{<:$UL}) # structural => natural - dx.data isa ArrayOrZero ? $UL(project.data(dx.data), project.uplo) : dx + return dx.data isa ArrayOrZero ? $UL(project.data(dx.data), project.uplo) : dx end # Another subspace which is not a subtype, like Diagonal inside Symmetric above, equally unsure function (project::ProjectTo{$UL})(dx::Diagonal) @@ -511,9 +513,9 @@ end function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal}) # structural => natural if dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero # possibly the various cases should live here, not as methods of constructor? - Bidiagonal(project.dv(dx.dv), project.ev(dx.ev), project.uplo) + return Bidiagonal(project.dv(dx.dv), project.ev(dx.ev), project.uplo) else - dx + return dx end end From 4badfe1ab3d0d0782bba4d08c462811560eaadcf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Mon, 1 Nov 2021 23:03:52 -0400 Subject: [PATCH 09/21] simplify Bidiagonal, I think --- src/projection.jl | 29 ++++++++++++++++++----------- src/tangent_types/abstract_zero.jl | 3 --- test/projection.jl | 5 +++++ 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 5861e6d96..7aa7886a2 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -491,29 +491,36 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT end end -# Weird -- not exhaustive! +# Weird cases -- not exhaustive! + # one strategy is to recurse into the struct: ProjectTo(x::Bidiagonal{T}) where {T<:Number} = generic_projector(x) function (project::ProjectTo{Bidiagonal})(dx::AbstractMatrix) - uplo = LinearAlgebra.sym_uplo(project.uplo) - dv = project.dv(diag(dx)) - ev = project.ev(uplo === :U ? diag(dx, 1) : diag(dx, -1)) - return Bidiagonal(dv, ev, uplo) + dy = Bidiagonal(dx, LinearAlgebra.sym_uplo(project.uplo)) + return generic_projection(project, dy) end function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal) if project.uplo == dx.uplo return generic_projection(project, dx) # fast path else - uplo = LinearAlgebra.sym_uplo(project.uplo) - dv = project.dv(diag(dx)) - ev = fill!(similar(dv, length(dv) - 1), 0) - return Bidiagonal(dv, ev, uplo) + # Allow Diagonal, subspace which is not a subtype + return Diagonal(project.dv(dx.dv)) end end +function (project::ProjectTo{Bidiagonal})(dx::Diagonal) # subspace which is not a subtype + return Diagonal(project.dv(dx.diag)) +end function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal}) # structural => natural if dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero - # possibly the various cases should live here, not as methods of constructor? - return Bidiagonal(project.dv(dx.dv), project.ev(dx.ev), project.uplo) + dv = project.dv(dx.dv) + ev = project.ev(dx.ev) + if ev isa AbstractZero # then collapse to Diagonal, or possibly Zero: + return Diagonal(dv) + elseif dv isa AbstractZero # a bit ugly, must construct explicit zeros: + dv = fill!(similar(ev, length(ev) + 1), 0) + ev = convert(typeof(dv), ev) # required if ev isa Fill, or a OneElement + end + return Bidiagonal(dv, ev, project.uplo) # neither argument can be a Zero else return dx end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index ff737d088..a3a1710ff 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -46,9 +46,6 @@ end for T in (:Symmetric, :Hermitian) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero, uplo=:U) = z end -LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent() -LinearAlgebra.Bidiagonal(dv::AbstractArray, ev::AbstractZero, uplo::Symbol) = Bidiagonal(dv, zero(dv)[1:end-1], uplo) -LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractArray, uplo::Symbol) = Bidiagonal(vcat(zero(ev), false), ev, uplo) """ ZeroTangent() <: AbstractZero diff --git a/test/projection.jl b/test/projection.jl index c4978aca9..44eb74845 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -312,6 +312,11 @@ struct NoSuperType end bu = Bidiagonal(rand(3, 3) .+ im, :U) # differs but uplo, not type @test pbi(bu) == diagm(0 => diag(real(bu))) @test_throws DimensionMismatch pbi(rand(ComplexF32, 3, 2)) + # structural => natural + @test pbi(Tangent{Bidiagonal}(; ev=(1:2.0))) isa Bidiagonal # constructs the diagonal + # subspace but not a subtype: + @test pbi(Tangent{Bidiagonal}(; dv=[1,2,3+im])) isa Diagonal{Float64} + @test pbi(Diagonal(1:3)) isa Diagonal{Float64} pstri = ProjectTo(SymTridiagonal(Symmetric(rand(3, 3)))) @test pstri(reshape(1:9, 3, 3)) == [1.0 3.0 0.0; 3.0 5.0 7.0; 0.0 7.0 9.0] From 4220f36035eda97d7cfa8bcc2ae6bb9228ae69db Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 3 Nov 2021 14:25:34 -0400 Subject: [PATCH 10/21] rewrite multidiagonal cases --- src/projection.jl | 112 +++++++++++++++++------------ src/tangent_types/abstract_zero.jl | 30 ++++++++ test/projection.jl | 13 +++- 3 files changed, 107 insertions(+), 48 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 7aa7886a2..61f4fc2c6 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -490,60 +490,78 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT end end end - -# Weird cases -- not exhaustive! - -# one strategy is to recurse into the struct: -ProjectTo(x::Bidiagonal{T}) where {T<:Number} = generic_projector(x) -function (project::ProjectTo{Bidiagonal})(dx::AbstractMatrix) - dy = Bidiagonal(dx, LinearAlgebra.sym_uplo(project.uplo)) - return generic_projection(project, dy) -end -function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal) - if project.uplo == dx.uplo - return generic_projection(project, dx) # fast path - else - # Allow Diagonal, subspace which is not a subtype - return Diagonal(project.dv(dx.dv)) - end -end -function (project::ProjectTo{Bidiagonal})(dx::Diagonal) # subspace which is not a subtype - return Diagonal(project.dv(dx.diag)) -end -function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal}) # structural => natural - if dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero - dv = project.dv(dx.dv) - ev = project.ev(dx.ev) - if ev isa AbstractZero # then collapse to Diagonal, or possibly Zero: - return Diagonal(dv) - elseif dv isa AbstractZero # a bit ugly, must construct explicit zeros: - dv = fill!(similar(ev, length(ev) + 1), 0) - ev = convert(typeof(dv), ev) # required if ev isa Fill, or a OneElement +for UUL in (:UnitUpperTriangular, :UnitLowerTriangular) + # UL = Symbol(string(UUL)[5:end]) + @eval begin + ProjectTo(x::$UUL) = ProjectTo{$UUL}(; data=ProjectTo(parent(x))) + function (project::ProjectTo{$UUL})(dx::AbstractArray) + dy = project.data(dx) + # Since x's diagonal is fixed to 1, dx must be zero there: + $UUL(dy) - I # makes an UpperTriangular, etc. end - return Bidiagonal(dv, ev, project.uplo) # neither argument can be a Zero - else - return dx + # (project::ProjectTo{$UUL})(dx::$UL) = project.data(dx) end end - -ProjectTo(x::SymTridiagonal{T}) where {T<:Number} = generic_projector(x) -function (project::ProjectTo{SymTridiagonal})(dx::AbstractMatrix) - dv = project.dv(diag(dx)) - ev = project.ev((diag(dx, 1) .+ diag(dx, -1)) ./ 2) +# Subspaces which aren't subtypes, like Diagonal inside Symmetric above: +(project::ProjectTo{UpperTriangular})(dx::Diagonal) = project.data(dx) +(project::ProjectTo{LowerTriangular})(dx::Diagonal) = project.data(dx) + +(project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx) +(project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) + +(project::ProjectTo{UnitUpperTriangular})(dx::Diagonal) = NoTangent() +(project::ProjectTo{UnitUpperTriangular})(dx::UpperTriangular) = project.data(dx) # produced by projector +(project::ProjectTo{UnitLowerTriangular})(dx::Diagonal) = NoTangent() +(project::ProjectTo{UnitLowerTriangular})(dx::LowerTriangular) = project.data(dx) + +# Multidiagonal +# For all of these, the eltypes must all match, so store one full-size projector for simplicity. +function ProjectTo(x::Bidiagonal) + full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) + full isa ProjectTo{<:AbstractZero} && return full + ProjectTo{Bidiagonal}(full = full, uplo = LinearAlgebra.sym_uplo(x.uplo)) +end +function ProjectTo(x::Tridiagonal) + full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) + # full isa ProjectTo{<:AbstractZero} && return full + ProjectTo{Tridiagonal}(full = full) +end +function ProjectTo(x::SymTridiagonal) + full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) + # full isa ProjectTo{<:AbstractZero} && return full + ProjectTo{SymTridiagonal}(full = full) +end +(project::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), project.uplo) +(project::ProjectTo{Bidiagonal})(dx::AbstractMatrix) = project.full(Bidiagonal(dx, project.uplo)) +(project::ProjectTo{Tridiagonal})(dx::AbstractArray) = Tridiagonal(project.full(dx)) +(project::ProjectTo{Tridiagonal})(dx::AbstractMatrix) = project.full(Tridiagonal(dx)) +(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = SymTridiagonal(project.full(dx)) +(project::ProjectTo{SymTridiagonal})(dx::Symmetric) = project.full(SymTridiagonal(dx)) +function (project::ProjectTo{SymTridiagonal})(dx::AbstractArray) + dz = project.full(dx) + dv = diag(dz) + ev = (diag(dz, 1) .+ diag(dz, -1)) ./ 2 return SymTridiagonal(dv, ev) end -(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = generic_projection(project, dx) - -# another strategy is just to use the AbstractArray method -function ProjectTo(x::Tridiagonal{T}) where {T<:Number} - notparent = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - return ProjectTo{Tridiagonal}(; notparent=notparent) +# Subspaces which aren't subtypes: +(project::ProjectTo{Bidiagonal})(dx::Diagonal) = project.full(dx) +(project::ProjectTo{Tridiagonal})(dx::Diagonal) = project.full(dx) +(project::ProjectTo{Tridiagonal})(dx::Bidiagonal) = project.full(dx) +(project::ProjectTo{SymTridiagonal})(dx::Diagonal) = project.full(dx) +# structural => natural +function (project::ProjectTo{Bidiagonal})(dx::Tangent{<:Bidiagonal}) + dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero || return dx + return project.full(Bidiagonal(dx.dv, dx.ev, project.uplo)) # will return a Diagonal when ev::AbstractZero +end +function (project::ProjectTo{Tridiagonal})(dx::Tangent{<:Tridiagonal}) + dx.dl isa ArrayOrZero && dx.d isa ArrayOrZero && dx.du isa ArrayOrZero || return dx + return project.full(Tridiagonal(dx.dl, dx.d, dx.du)) end -function (project::ProjectTo{Tridiagonal})(dx::AbstractArray) - dy = project.notparent(dx) - return Tridiagonal(dy) +function (project::ProjectTo{SymTridiagonal})(dx::Tangent{<:SymTridiagonal}) + dx.dv isa ArrayOrZero && dx.ev isa ArrayOrZero || return dx + return project.full(SymTridiagonal(dx.dv, dx.ev)) end -# Note that backing(::Tridiagonal) doesn't work, https://github.com/JuliaDiff/ChainRulesCore.jl/issues/392 + ##### ##### `SparseArrays` diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index a3a1710ff..80d92e128 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -40,6 +40,7 @@ LinearAlgebra.transpose(z::AbstractZero, ind...) = z for T in ( :UniformScaling, :Adjoint, :Transpose, :Diagonal :UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular, + :UpperHessenberg ) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z end @@ -47,6 +48,35 @@ for T in (:Symmetric, :Hermitian) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero, uplo=:U) = z end +LinearAlgebra.Bidiagonal(dv::AbstractVector, ev::AbstractZero, uplo::Symbol) = Diagonal(dv) +function LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractVector, uplo::Symbol) + dv = fill!(similar(ev, length(ev) + 1), 0) # can't avoid making a dummy array + Bidiagonal(dv, convert(typeof(dv), ev), uplo) +end +LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent() + +# one Zero: +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d) +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractVector) = Bidiagonal(d, du, :U) +LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(d, dl, :L) +function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector) + d = fill!(similar(dl, length(dl) + 1), 0) + Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du)) +end +# two Zeros: +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d) +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractVector) = Bidiagonal(d, du, :U) +LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractZero) = Bidiagonal(d, dl, :L) +# three Zeros: +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractZero) = NoTangent() + +LinearAlgebra.SymTridiagonal(dv::AbstractVector, ev::AbstractZero) = Diagonal(dv) +function LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractVector) + dv = fill!(similar(ev, length(ev) + 1), 0) + SymTridiagonal(dv, convert(typeof(dv), ev)) +end +LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractZero) = NoTangent() + """ ZeroTangent() <: AbstractZero diff --git a/test/projection.jl b/test/projection.jl index 44eb74845..ca4f72990 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -325,12 +325,23 @@ struct NoSuperType end stri = SymTridiagonal(Symmetric(rand(3, 3) .+ im)) @test pstri(stri) == real(stri) @test_throws DimensionMismatch pstri(rand(ComplexF32, 3, 2)) + # structural => natural + pstri(Tangent{SymTridiagonal}(ev = [1,2])) # matrix + # subspace but not a subtype: + @test pstri(Tangent{SymTridiagonal}(dv = [1, 2, 3])) isa Diagonal{Float64} + @test pstri(Diagonal(rand(3) .+ im)) isa Diagonal{Float64} ptri = ProjectTo(Tridiagonal(rand(3, 3))) @test ptri(reshape(1:9, 3, 3)) == [1.0 4.0 0.0; 2.0 5.0 8.0; 0.0 6.0 9.0] @test ptri(ptri(reshape(1:9, 3, 3))) == ptri(reshape(1:9, 3, 3)) @test ptri(rand(ComplexF32, 3, 3)) isa Tridiagonal{Float64} - @test_throws DimensionMismatch ptri(rand(ComplexF32, 3, 2)) + @test_throws ArgumentError ptri(rand(ComplexF32, 3, 2)) + # structural => natural + ptri(Tangent{Tridiagonal}(du = [1, 2], dl = [3im, 4im])) isa Tridiagonal{Float64} + # subspace but not a subtype: + ptri(Tangent{Tridiagonal}(du = [1, 2])) isa Bidiagonal{Float64} + ptri(Tangent{Tridiagonal}(du = [1, 2], d = [3im, 4im, 5im])) isa Bidiagonal{Float64} + ptri(Tangent{Tridiagonal}(d = [1, 2, 3])) isa Diagonal{Float64} end ##### From 171b896cdc87ea6df2e1f3d07211c8a33a54818e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 3 Nov 2021 20:41:34 -0400 Subject: [PATCH 11/21] cleanup + bugfix triangular etc cases --- src/projection.jl | 20 +++----------------- test/projection.jl | 10 ++++++++++ 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 61f4fc2c6..c5c6b6af9 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -462,31 +462,17 @@ for (SymHerm, chk, fun) in end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. - function (project::ProjectTo{$SymHerm})(dx::Diagonal) - sub = project.data # this is going to be unhappy about the size - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) - return Diagonal(sub_one(dx.diag)) - end + (project::ProjectTo{$SymHerm})(dx::Diagonal) = project.data(dx) end end # Triangular -for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular) # UpperHessenberg +for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg) @eval begin ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) function (project::ProjectTo{$UL})(dx::Tangent{<:$UL}) # structural => natural - return dx.data isa ArrayOrZero ? $UL(project.data(dx.data), project.uplo) : dx - end - # Another subspace which is not a subtype, like Diagonal inside Symmetric above, equally unsure - function (project::ProjectTo{$UL})(dx::Diagonal) - sub = project.data - sub_one = ProjectTo{project_type(sub)}(; - element=sub.element, axes=(sub.axes[1],) - ) - return Diagonal(sub_one(dx.diag)) + return dx.data isa ArrayOrZero ? $UL(project.data(dx.data)) : dx end end end diff --git a/test/projection.jl b/test/projection.jl index ca4f72990..6c7e3ceb8 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -287,6 +287,16 @@ struct NoSuperType end @test pupp(rand(ComplexF32, 3, 3, 1)) isa UpperTriangular{Float64} @test ProjectTo(UpperTriangular(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() + phess = ProjectTo(UpperHessenberg(rand(3, 3))) + @test phess(reshape(1:9,3,3)) == [1 4 7; 2 5 8; 0 6 9] + @test phess(reshape(1:9,3,3) .+ im) isa UpperHessenberg{Float64} + + pdu = ProjectTo(UnitLowerTriangular(rand(3, 3))) + # NB, since the diagonal is constant 1, its gradient is zero: + @test pdu(reshape(1:9, 3, 3)) == [0 0 0; 2 0 0; 3 6 0] + @test pdu(rand(ComplexF32, 3, 3, 1)) isa LowerTriangular{Float64} + @test pdu(Diagonal(1:3)) == NoTangent() + # an experiment with allowing subspaces which aren't subtypes @test psymm(Diagonal([1, 2, 3])) isa Diagonal{Float64} @test pupp(Diagonal([1, 2, 3 + 4im])) isa Diagonal{Float64} From 630d207b7bf69695f93e7efbe8eec6be4923bb82 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:22:42 -0500 Subject: [PATCH 12/21] cleanup --- src/projection.jl | 46 ++++++++++++++++++----------- src/tangent_types/abstract_zero.jl | 27 ++++++++++++----- test/projection.jl | 3 +- test/tangent_types/abstract_zero.jl | 22 ++++++++++++++ 4 files changed, 72 insertions(+), 26 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index c5c6b6af9..3f84d007f 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -462,7 +462,7 @@ for (SymHerm, chk, fun) in end # This is an example of a subspace which is not a subtype, # not clear how broadly it's worthwhile to try to support this. - (project::ProjectTo{$SymHerm})(dx::Diagonal) = project.data(dx) + (project::ProjectTo{$SymHerm})(dx::Diagonal) = project.data(dx) end end @@ -477,15 +477,20 @@ for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg) end end for UUL in (:UnitUpperTriangular, :UnitLowerTriangular) - # UL = Symbol(string(UUL)[5:end]) + UL = Symbol(string(UUL)[5:end]) @eval begin ProjectTo(x::$UUL) = ProjectTo{$UUL}(; data=ProjectTo(parent(x))) function (project::ProjectTo{$UUL})(dx::AbstractArray) dy = project.data(dx) # Since x's diagonal is fixed to 1, dx must be zero there: - $UUL(dy) - I # makes an UpperTriangular, etc. + return $UUL(dy) - I # makes an UpperTriangular, etc. + end + # No type perfectly encodes the gradient of UnitUpperTriangular. + # To avoid unnecessary copies of what projection produces, + # allow any UpperTriangular through: + function (project::ProjectTo{$UUL})(dx::$UL) + dy = project.data(dx) end - # (project::ProjectTo{$UUL})(dx::$UL) = project.data(dx) end end # Subspaces which aren't subtypes, like Diagonal inside Symmetric above: @@ -496,33 +501,40 @@ end (project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) (project::ProjectTo{UnitUpperTriangular})(dx::Diagonal) = NoTangent() -(project::ProjectTo{UnitUpperTriangular})(dx::UpperTriangular) = project.data(dx) # produced by projector (project::ProjectTo{UnitLowerTriangular})(dx::Diagonal) = NoTangent() -(project::ProjectTo{UnitLowerTriangular})(dx::LowerTriangular) = project.data(dx) # Multidiagonal # For all of these, the eltypes must all match, so store one full-size projector for simplicity. function ProjectTo(x::Bidiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - full isa ProjectTo{<:AbstractZero} && return full - ProjectTo{Bidiagonal}(full = full, uplo = LinearAlgebra.sym_uplo(x.uplo)) + # full isa ProjectTo{<:AbstractZero} && return full # never happens, invoke misses the Bool method + ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) # better short-circuit + return ProjectTo{Bidiagonal}(full = full, uplo = LinearAlgebra.sym_uplo(x.uplo)) end function ProjectTo(x::Tridiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - # full isa ProjectTo{<:AbstractZero} && return full - ProjectTo{Tridiagonal}(full = full) + ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) + return ProjectTo{Tridiagonal}(full = full) end function ProjectTo(x::SymTridiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) - # full isa ProjectTo{<:AbstractZero} && return full - ProjectTo{SymTridiagonal}(full = full) + ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) + return ProjectTo{SymTridiagonal}(full = full) +end +# Own type: `project.full` can convert eltype mantaining strucure +function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal) + if LinearAlgebra.sym_uplo(dx.uplo) == project.uplo + return project.full(dx) + else # make a dummy array, better type-stability than returning a Diagonal + return project.full(Bidiagonal(dx.dv, zero(dx.ev), project.uplo)) + end end +(project::ProjectTo{Tridiagonal})(dx::Tridiagonal) = project.full(dx) +(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = project.full(dx) +# AbstractArray (project::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), project.uplo) -(project::ProjectTo{Bidiagonal})(dx::AbstractMatrix) = project.full(Bidiagonal(dx, project.uplo)) (project::ProjectTo{Tridiagonal})(dx::AbstractArray) = Tridiagonal(project.full(dx)) -(project::ProjectTo{Tridiagonal})(dx::AbstractMatrix) = project.full(Tridiagonal(dx)) -(project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = SymTridiagonal(project.full(dx)) -(project::ProjectTo{SymTridiagonal})(dx::Symmetric) = project.full(SymTridiagonal(dx)) +(project::ProjectTo{SymTridiagonal})(dx::Symmetric) = SymTridiagonal(project.full(dx)) function (project::ProjectTo{SymTridiagonal})(dx::AbstractArray) dz = project.full(dx) dv = diag(dz) @@ -643,6 +655,6 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::SparseMatrixCSC) m, n = size(dx) return SparseMatrixCSC(m, n, dx.colptr, dx.rowval, nzval) else - invoke(project, Tuple{AbstractArray}, dx) + return invoke(project, Tuple{AbstractArray}, dx) end end diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 80d92e128..204866e3f 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -39,14 +39,14 @@ LinearAlgebra.transpose(z::AbstractZero, ind...) = z for T in ( :UniformScaling, :Adjoint, :Transpose, :Diagonal - :UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerTriangular, - :UpperHessenberg + :UpperTriangular, :LowerTriangular, :UpperHessenberg, + :UnitUpperTriangular, :UnitLowerTriangular, ) @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z end -for T in (:Symmetric, :Hermitian) - @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero, uplo=:U) = z -end + +LinearAlgebra.Symmetric(z::AbstractZero, uplo=:U) = z +LinearAlgebra.Hermitian(z::AbstractZero, uplo=:U) = z LinearAlgebra.Bidiagonal(dv::AbstractVector, ev::AbstractZero, uplo::Symbol) = Diagonal(dv) function LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractVector, uplo::Symbol) @@ -56,9 +56,8 @@ end LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent() # one Zero: -LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d) -LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractVector) = Bidiagonal(d, du, :U) -LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(d, dl, :L) +LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractVector) = Bidiagonal(_promote_vectors(d, du)..., :U) +LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(_promote_vectors(d, dl)..., :L) function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector) d = fill!(similar(dl, length(dl) + 1), 0) Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du)) @@ -77,6 +76,18 @@ function LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractVector) end LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractZero) = NoTangent() +# These types all demand exactly same-type vectors, but may get e.g. Fill, Vector. +_promote_vectors(x::T, y::T) where {T<:AbstractVector} = (x, y) +function _promote_vectors(x::AbstractVector, y::AbstractVector) + T = Base._return_type(+, Tuple{typeof(x), typeof(y)}) + if isconcretetype(T) + return convert(T, x), convert(T, y) + else + short = map(first∘promote, x, y) + return convert(typeof(short), x), convert(typeof(short), y) + end +end + """ ZeroTangent() <: AbstractZero diff --git a/test/projection.jl b/test/projection.jl index 6c7e3ceb8..cced5bf80 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -322,6 +322,7 @@ struct NoSuperType end bu = Bidiagonal(rand(3, 3) .+ im, :U) # differs but uplo, not type @test pbi(bu) == diagm(0 => diag(real(bu))) @test_throws DimensionMismatch pbi(rand(ComplexF32, 3, 2)) + @test ProjectTo(Bidiagonal(randn(3, 3) .> 0, :L))(rand(3, 3)) == NoTangent() # structural => natural @test pbi(Tangent{Bidiagonal}(; ev=(1:2.0))) isa Bidiagonal # constructs the diagonal # subspace but not a subtype: @@ -345,7 +346,7 @@ struct NoSuperType end @test ptri(reshape(1:9, 3, 3)) == [1.0 4.0 0.0; 2.0 5.0 8.0; 0.0 6.0 9.0] @test ptri(ptri(reshape(1:9, 3, 3))) == ptri(reshape(1:9, 3, 3)) @test ptri(rand(ComplexF32, 3, 3)) isa Tridiagonal{Float64} - @test_throws ArgumentError ptri(rand(ComplexF32, 3, 2)) + @test_throws DimensionMismatch ptri(rand(ComplexF32, 3, 2)) # structural => natural ptri(Tangent{Tridiagonal}(du = [1, 2], dl = [3im, 4im])) isa Tridiagonal{Float64} # subspace but not a subtype: diff --git a/test/tangent_types/abstract_zero.jl b/test/tangent_types/abstract_zero.jl index 7ee896b16..8e7fe18ed 100644 --- a/test/tangent_types/abstract_zero.jl +++ b/test/tangent_types/abstract_zero.jl @@ -160,5 +160,27 @@ @test Transpose(ZeroTangent()) === ZeroTangent() @test Symmetric(ZeroTangent()) === ZeroTangent() @test Hermitian(ZeroTangent(), :U) === ZeroTangent() + + # Multidiagonal + @test Bidiagonal([1, 2, 3], ZeroTangent(), :U) == Diagonal(1:3) + @test Bidiagonal(ZeroTangent(), 1:3, :U) == [0 1 0 0; 0 0 2 0; 0 0 0 3; 0 0 0 0] + @test Tridiagonal(ZeroTangent(), ZeroTangent(), [4, 5]) == [0 4 0; 0 0 5; 0 0 0] + @test Tridiagonal(ZeroTangent(), 1:3, ZeroTangent()) == Diagonal(1:3) + end + + @testset "promote vectors" begin + v64s = (Vector{Float64}, Vector{Float64}) + @test v64s == typeof.(ChainRulesCore._promote_vectors(fill(pi,3), [1,2,3])) + @test v64s == typeof.(ChainRulesCore._promote_vectors(Any[1,2,pi], rand(2))) + @test v64s == typeof.(ChainRulesCore._promote_vectors(randn(3) .> 1, rand(10))) + + @test length.(ChainRulesCore._promote_vectors(Any[1,2,pi], rand(2))) == (3, 2) + @test length.(ChainRulesCore._promote_vectors(randn(3) .> 1, rand(10))) == (3, 10) + + _samet((x, y)) = typeof(x) == typeof(y) + @test _samet(ChainRulesCore._promote_vectors(sparse([1,0,3]), [4,5,6])) + @test _samet(ChainRulesCore._promote_vectors(SA[1/2, 3/4], [5, 6])) + # The last isn't actually realistic, since all the constructors which take several + # vectors demand both same type and different length, so cannot accept StaticArrays. end end From b76e0872ddd2fa85f627dee9cbd172012da83ea6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:40:38 -0500 Subject: [PATCH 13/21] fixup --- src/projection.jl | 13 ++++++------- src/tangent_types/abstract_zero.jl | 6 +++--- test/projection.jl | 8 +++++--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 3f84d007f..84a70d2e3 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -468,6 +468,7 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg) + VERSION < v"1.4" && UL == :UpperHessenberg && continue # not defined in 1.0 @eval begin ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) @@ -488,9 +489,7 @@ for UUL in (:UnitUpperTriangular, :UnitLowerTriangular) # No type perfectly encodes the gradient of UnitUpperTriangular. # To avoid unnecessary copies of what projection produces, # allow any UpperTriangular through: - function (project::ProjectTo{$UUL})(dx::$UL) - dy = project.data(dx) - end + (project::ProjectTo{$UUL})(dx::$UL) = project.data(dx) end end # Subspaces which aren't subtypes, like Diagonal inside Symmetric above: @@ -509,17 +508,17 @@ function ProjectTo(x::Bidiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) # full isa ProjectTo{<:AbstractZero} && return full # never happens, invoke misses the Bool method ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) # better short-circuit - return ProjectTo{Bidiagonal}(full = full, uplo = LinearAlgebra.sym_uplo(x.uplo)) + return ProjectTo{Bidiagonal}(; full = full, uplo = LinearAlgebra.sym_uplo(x.uplo)) end function ProjectTo(x::Tridiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) - return ProjectTo{Tridiagonal}(full = full) + return ProjectTo{Tridiagonal}(; full = full) end function ProjectTo(x::SymTridiagonal) full = invoke(ProjectTo, Tuple{AbstractArray{T2}} where {T2<:Number}, x) ProjectTo(zero(eltype(x))) isa ProjectTo{<:AbstractZero} && return ProjectTo(false) - return ProjectTo{SymTridiagonal}(full = full) + return ProjectTo{SymTridiagonal}(; full = full) end # Own type: `project.full` can convert eltype mantaining strucure function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal) @@ -532,7 +531,7 @@ end (project::ProjectTo{Tridiagonal})(dx::Tridiagonal) = project.full(dx) (project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = project.full(dx) # AbstractArray -(project::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), project.uplo) +(proj::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), proj.uplo) (project::ProjectTo{Tridiagonal})(dx::AbstractArray) = Tridiagonal(project.full(dx)) (project::ProjectTo{SymTridiagonal})(dx::Symmetric) = SymTridiagonal(project.full(dx)) function (project::ProjectTo{SymTridiagonal})(dx::AbstractArray) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 204866e3f..753e54734 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -51,7 +51,7 @@ LinearAlgebra.Hermitian(z::AbstractZero, uplo=:U) = z LinearAlgebra.Bidiagonal(dv::AbstractVector, ev::AbstractZero, uplo::Symbol) = Diagonal(dv) function LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractVector, uplo::Symbol) dv = fill!(similar(ev, length(ev) + 1), 0) # can't avoid making a dummy array - Bidiagonal(dv, convert(typeof(dv), ev), uplo) + return Bidiagonal(dv, convert(typeof(dv), ev), uplo) end LinearAlgebra.Bidiagonal(dv::AbstractZero, ev::AbstractZero, uplo::Symbol) = NoTangent() @@ -60,7 +60,7 @@ LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractVecto LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractVector, du::AbstractZero) = Bidiagonal(_promote_vectors(d, dl)..., :L) function LinearAlgebra.Tridiagonal(dl::AbstractVector, d::AbstractZero, du::AbstractVector) d = fill!(similar(dl, length(dl) + 1), 0) - Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du)) + return Tridiagonal(convert(typeof(d), dl), d, convert(typeof(d), du)) end # two Zeros: LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractVector, du::AbstractZero) = Diagonal(d) @@ -72,7 +72,7 @@ LinearAlgebra.Tridiagonal(dl::AbstractZero, d::AbstractZero, du::AbstractZero) = LinearAlgebra.SymTridiagonal(dv::AbstractVector, ev::AbstractZero) = Diagonal(dv) function LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractVector) dv = fill!(similar(ev, length(ev) + 1), 0) - SymTridiagonal(dv, convert(typeof(dv), ev)) + return SymTridiagonal(dv, convert(typeof(dv), ev)) end LinearAlgebra.SymTridiagonal(dv::AbstractZero, ev::AbstractZero) = NoTangent() diff --git a/test/projection.jl b/test/projection.jl index cced5bf80..79b774647 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -287,9 +287,11 @@ struct NoSuperType end @test pupp(rand(ComplexF32, 3, 3, 1)) isa UpperTriangular{Float64} @test ProjectTo(UpperTriangular(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() - phess = ProjectTo(UpperHessenberg(rand(3, 3))) - @test phess(reshape(1:9,3,3)) == [1 4 7; 2 5 8; 0 6 9] - @test phess(reshape(1:9,3,3) .+ im) isa UpperHessenberg{Float64} + if VERSION >= v"1.4" # not sure 1.4 exactly! + phess = ProjectTo(UpperHessenberg(rand(3, 3))) + @test phess(reshape(1:9,3,3)) == [1 4 7; 2 5 8; 0 6 9] + @test phess(reshape(1:9,3,3) .+ im) isa UpperHessenberg{Float64} + end pdu = ProjectTo(UnitLowerTriangular(rand(3, 3))) # NB, since the diagonal is constant 1, its gradient is zero: From ab8765ccee0f8ebe4e6448f779ee52fe811bd977 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:44:51 -0500 Subject: [PATCH 14/21] a bug --- src/projection.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/projection.jl b/src/projection.jl index 84a70d2e3..0f827f4ff 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -531,7 +531,7 @@ end (project::ProjectTo{Tridiagonal})(dx::Tridiagonal) = project.full(dx) (project::ProjectTo{SymTridiagonal})(dx::SymTridiagonal) = project.full(dx) # AbstractArray -(proj::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), proj.uplo) +(project::ProjectTo{Bidiagonal})(dx::AbstractArray) = Bidiagonal(project.full(dx), project.uplo) (project::ProjectTo{Tridiagonal})(dx::AbstractArray) = Tridiagonal(project.full(dx)) (project::ProjectTo{SymTridiagonal})(dx::Symmetric) = SymTridiagonal(project.full(dx)) function (project::ProjectTo{SymTridiagonal})(dx::AbstractArray) From 72ac3ba4bb9fbbb838b3cbaf535885fc88ef03e5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:49:42 -0500 Subject: [PATCH 15/21] skip --- src/tangent_types/abstract_zero.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 753e54734..e8849d53e 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -42,6 +42,7 @@ for T in ( :UpperTriangular, :LowerTriangular, :UpperHessenberg, :UnitUpperTriangular, :UnitLowerTriangular, ) + VERSION < v"1.4" && f == :UpperHessenberg && continue # not defined in 1.0 @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z end From fce54dfdaebc815f747856c52edabd607276a4d1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 18:55:02 -0500 Subject: [PATCH 16/21] skip' --- src/projection.jl | 8 ++++---- src/tangent_types/abstract_zero.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index 0f827f4ff..ef246fa9d 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -495,10 +495,10 @@ end # Subspaces which aren't subtypes, like Diagonal inside Symmetric above: (project::ProjectTo{UpperTriangular})(dx::Diagonal) = project.data(dx) (project::ProjectTo{LowerTriangular})(dx::Diagonal) = project.data(dx) - -(project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx) -(project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) - +if VERSION >= v"1.4" + (project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx) + (project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) +end (project::ProjectTo{UnitUpperTriangular})(dx::Diagonal) = NoTangent() (project::ProjectTo{UnitLowerTriangular})(dx::Diagonal) = NoTangent() diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index e8849d53e..2616d50b4 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -84,7 +84,7 @@ function _promote_vectors(x::AbstractVector, y::AbstractVector) if isconcretetype(T) return convert(T, x), convert(T, y) else - short = map(first∘promote, x, y) + short = map(first ∘ promote, x, y) return convert(typeof(short), x), convert(typeof(short), y) end end From b3e69f0294d9ffc4a5b1a08e4ea211be1d822667 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 19:00:15 -0500 Subject: [PATCH 17/21] zip on 1.0 --- src/tangent_types/abstract_zero.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index 2616d50b4..c9203ac75 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -84,7 +84,7 @@ function _promote_vectors(x::AbstractVector, y::AbstractVector) if isconcretetype(T) return convert(T, x), convert(T, y) else - short = map(first ∘ promote, x, y) + short = map(Base.splat(first ∘ promote), zip(x, y)) return convert(typeof(short), x), convert(typeof(short), y) end end From eed6da2aface10972be22be5942dd0e4602f1435 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 7 Nov 2021 19:12:19 -0500 Subject: [PATCH 18/21] zip' --- src/tangent_types/abstract_zero.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index c9203ac75..c661d38cb 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -84,7 +84,11 @@ function _promote_vectors(x::AbstractVector, y::AbstractVector) if isconcretetype(T) return convert(T, x), convert(T, y) else - short = map(Base.splat(first ∘ promote), zip(x, y)) + if VERSION > v"1.4" + short = map(first ∘ promote, x, y) + else # on 1.0 and friends, neither map nor zip stop early. So we improvise + short = [promote(x[i], y[i])[1] for i in intersect(axes(x, 1), axes(y, 1))] + end return convert(typeof(short), x), convert(typeof(short), y) end end From 43da543c584a49ca1496a5a62cb5ae7c8cdcac1a Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Feb 2022 23:04:07 -0500 Subject: [PATCH 19/21] allow 3d arrays for Diagonal --- src/projection.jl | 4 ++++ test/projection.jl | 1 + 2 files changed, 5 insertions(+) diff --git a/src/projection.jl b/src/projection.jl index ef246fa9d..ae0c9de73 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -436,6 +436,10 @@ end ProjectTo(x::Diagonal) = ProjectTo{Diagonal}(; diag=ProjectTo(x.diag)) (project::ProjectTo{Diagonal})(dx::AbstractMatrix) = Diagonal(project.diag(diag(dx))) (project::ProjectTo{Diagonal})(dx::Diagonal) = Diagonal(project.diag(dx.diag)) +function (project::ProjectTo{Diagonal})(dx::AbstractArray) + ind = diagind(size(dx,1), size(dx,2), 0) + return Diagonal(project.diag(dx[ind])) +end function (project::ProjectTo{Diagonal})(dx::Tangent{<:Diagonal}) # structural => natural return dx.diag isa ArrayOrZero ? Diagonal(project.diag(dx.diag)) : dx end diff --git a/test/projection.jl b/test/projection.jl index 79b774647..62376dce5 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -307,6 +307,7 @@ struct NoSuperType end @testset "LinearAlgebra: sparse structured matrices" begin pdiag = ProjectTo(Diagonal(1:3)) @test pdiag(reshape(1:9, 3, 3)) == Diagonal([1, 5, 9]) + @test pdiag(reshape(1:9, 3, 3, 1)) == Diagonal([1, 5, 9]) @test pdiag(pdiag(reshape(1:9, 3, 3))) == pdiag(reshape(1:9, 3, 3)) @test pdiag(rand(ComplexF32, 3, 3)) isa Diagonal{Float64} @test pdiag(Diagonal(1.0:3.0)) === Diagonal(1.0:3.0) From 524dc76a1d7b4b19b08df966d0588fceead0f9fa Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Feb 2022 23:19:22 -0500 Subject: [PATCH 20/21] fixup --- src/tangent_types/abstract_zero.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tangent_types/abstract_zero.jl b/src/tangent_types/abstract_zero.jl index c661d38cb..d83680f17 100644 --- a/src/tangent_types/abstract_zero.jl +++ b/src/tangent_types/abstract_zero.jl @@ -38,12 +38,12 @@ LinearAlgebra.adjoint(z::AbstractZero, ind...) = z LinearAlgebra.transpose(z::AbstractZero, ind...) = z for T in ( - :UniformScaling, :Adjoint, :Transpose, :Diagonal + :UniformScaling, :Adjoint, :Transpose, :Diagonal, :UpperTriangular, :LowerTriangular, :UpperHessenberg, :UnitUpperTriangular, :UnitLowerTriangular, ) - VERSION < v"1.4" && f == :UpperHessenberg && continue # not defined in 1.0 - @eval (::Type{<:LinearAlgebra.$T})(z::AbstractZero) = z + VERSION < v"1.4" && T == :UpperHessenberg && continue # not defined in 1.0 + @eval LinearAlgebra.$T(z::AbstractZero) = z end LinearAlgebra.Symmetric(z::AbstractZero, uplo=:U) = z From bc4293efa7e32fb64c2c82aa93a49232308ee0c3 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 2 Mar 2022 15:16:41 -0500 Subject: [PATCH 21/21] remove version < 1.6 checks --- src/projection.jl | 9 ++++----- test/projection.jl | 10 ++++------ 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/src/projection.jl b/src/projection.jl index ae0c9de73..7a55e956d 100644 --- a/src/projection.jl +++ b/src/projection.jl @@ -472,7 +472,6 @@ end # Triangular for UL in (:UpperTriangular, :LowerTriangular, :UpperHessenberg) - VERSION < v"1.4" && UL == :UpperHessenberg && continue # not defined in 1.0 @eval begin ProjectTo(x::$UL) = ProjectTo{$UL}(; data=ProjectTo(parent(x))) (project::ProjectTo{$UL})(dx::AbstractArray) = $UL(project.data(dx)) @@ -499,10 +498,10 @@ end # Subspaces which aren't subtypes, like Diagonal inside Symmetric above: (project::ProjectTo{UpperTriangular})(dx::Diagonal) = project.data(dx) (project::ProjectTo{LowerTriangular})(dx::Diagonal) = project.data(dx) -if VERSION >= v"1.4" - (project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx) - (project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) -end + +(project::ProjectTo{UpperHessenberg})(dx::Diagonal) = project.data(dx) +(project::ProjectTo{UpperHessenberg})(dx::UpperTriangular) = project.data(dx) + (project::ProjectTo{UnitUpperTriangular})(dx::Diagonal) = NoTangent() (project::ProjectTo{UnitLowerTriangular})(dx::Diagonal) = NoTangent() diff --git a/test/projection.jl b/test/projection.jl index 62376dce5..8e7e88be3 100644 --- a/test/projection.jl +++ b/test/projection.jl @@ -287,11 +287,9 @@ struct NoSuperType end @test pupp(rand(ComplexF32, 3, 3, 1)) isa UpperTriangular{Float64} @test ProjectTo(UpperTriangular(randn(3, 3) .> 0))(randn(3, 3)) == NoTangent() - if VERSION >= v"1.4" # not sure 1.4 exactly! - phess = ProjectTo(UpperHessenberg(rand(3, 3))) - @test phess(reshape(1:9,3,3)) == [1 4 7; 2 5 8; 0 6 9] - @test phess(reshape(1:9,3,3) .+ im) isa UpperHessenberg{Float64} - end + phess = ProjectTo(UpperHessenberg(rand(3, 3))) + @test phess(reshape(1:9,3,3)) == [1 4 7; 2 5 8; 0 6 9] + @test phess(reshape(1:9,3,3) .+ im) isa UpperHessenberg{Float64} pdu = ProjectTo(UnitLowerTriangular(rand(3, 3))) # NB, since the diagonal is constant 1, its gradient is zero: @@ -494,7 +492,7 @@ struct NoSuperType end @test eval(Meta.parse(str))(ones(1, 3)) isa Adjoint{Float64,Vector{Float64}} end - VERSION > v"1.1" && @testset "allocation tests" begin + @testset "allocation tests" begin # For sure these fail on Julia 1.0, not sure about 1.3 etc. # We only really care about current stable anyway # Each "@test 33 > ..." is zero on nightly, 32 on 1.5.