@@ -126,6 +126,11 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
126126(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
127127(:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
128128
129+ # Also, any explicit construction with fields, where all fields project to zero, itself
130+ # projects to zero. This simplifies projectors for wrapper types like Diagonal([true, false]).
131+ const _PZ = ProjectTo{<: AbstractZero }
132+ ProjectTo {P} (:: NamedTuple{T, <:Tuple{_PZ, Vararg{<:_PZ}}} ) where {P,T} = ProjectTo {NoTangent} ()
133+
129134# Tangent
130135# This may be produced from e.g. x=range(1,2,length=3). There need not be any
131136# AbstractArray representation of such a tangent, so we just pass it along,
@@ -265,12 +270,10 @@ end
265270# #### `LinearAlgebra`
266271# ####
267272
273+ using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
274+
268275# Row vectors
269- function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
270- sub = ProjectTo (parent (x))
271- sub isa ProjectTo{<: AbstractZero } && return sub
272- return ProjectTo {Adjoint} (; parent= sub)
273- end
276+ ProjectTo (x:: AdjointAbsVec ) = ProjectTo {Adjoint} (; parent= ProjectTo (parent (x)))
274277# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
275278# Transposed matrices are, like PermutedDimsArray, just a storage detail,
276279# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
@@ -285,11 +288,7 @@ function (project::ProjectTo{Adjoint})(dx::AbstractArray)
285288 return adjoint (project. parent (dy))
286289end
287290
288- function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
289- sub = ProjectTo (parent (x))
290- sub isa ProjectTo{<: AbstractZero } && return sub
291- return ProjectTo {Transpose} (; parent= sub)
292- end
291+ ProjectTo (x:: LinearAlgebra.TransposeAbsVec ) = ProjectTo {Transpose} (; parent= ProjectTo (parent (x)))
293292function (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
294293 return transpose (project. parent (transpose (dx)))
295294end
@@ -302,11 +301,7 @@ function (project::ProjectTo{Transpose})(dx::AbstractArray)
302301end
303302
304303# Diagonal
305- function ProjectTo (x:: Diagonal )
306- sub = ProjectTo (x. diag)
307- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Diagonal(NoTangent()) worked
308- return ProjectTo {Diagonal} (; diag= sub)
309- end
304+ ProjectTo (x:: Diagonal ) = ProjectTo {Diagonal} (; diag= ProjectTo (x. diag))
310305(project:: ProjectTo{Diagonal} )(dx:: AbstractMatrix ) = Diagonal (project. diag (diag (dx)))
311306(project:: ProjectTo{Diagonal} )(dx:: Diagonal ) = Diagonal (project. diag (dx. diag))
312307
@@ -318,7 +313,8 @@ for (SymHerm, chk, fun) in (
318313 @eval begin
319314 function ProjectTo (x:: $SymHerm )
320315 sub = ProjectTo (parent (x))
321- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if Hermitian(NoTangent()) etc. worked
316+ # Because the projector stores uplo, ProjectTo(Symmetric(rand(3,3) .> 0)) isn't automatically trivial:
317+ sub isa ProjectTo{<: AbstractZero } && return sub
322318 return ProjectTo {$SymHerm} (; uplo= LinearAlgebra. sym_uplo (x. uplo), parent= sub)
323319 end
324320 function (project:: ProjectTo{$SymHerm} )(dx:: AbstractArray )
343339# Triangular
344340for UL in (:UpperTriangular , :LowerTriangular , :UnitUpperTriangular , :UnitLowerTriangular ) # UpperHessenberg
345341 @eval begin
346- function ProjectTo (x:: $UL )
347- sub = ProjectTo (parent (x))
348- # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
349- sub isa ProjectTo{<: AbstractZero } && return sub
350- return ProjectTo {$UL} (; parent= sub)
351- end
342+ ProjectTo (x:: $UL ) = ProjectTo {$UL} (; parent= ProjectTo (parent (x)))
352343 (project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
353344 function (project:: ProjectTo{$UL} )(dx:: Diagonal )
354345 sub = project. parent
0 commit comments