@@ -72,12 +72,6 @@ function generic_projection(project::ProjectTo{T}, dx::T) where {T}
7272 return construct (T, map (_maybe_call, sub_projects, sub_dxs))
7373end
7474
75- function (project:: ProjectTo{T} )(dx:: Tangent ) where {T}
76- sub_projects = backing (project)
77- sub_dxs = backing (canonicalize (dx))
78- return construct (T, map (_maybe_call, sub_projects, sub_dxs))
79- end
80-
8175# Used for encoding fields, leaves alone non-diff types:
8276_maybe_projector (x:: Union{AbstractArray,Number,Ref} ) = ProjectTo (x)
8377_maybe_projector (x) = x
@@ -135,6 +129,14 @@ ProjectTo(::AbstractZero) = ProjectTo{NoTangent}() # Any x::Zero in forward pas
135129(:: ProjectTo{NoTangent} )(dx) = NoTangent () # but this is the projection only for nonzero gradients,
136130(:: ProjectTo{NoTangent} )(:: NoTangent ) = NoTangent () # and this one solves an ambiguity.
137131
132+ # Tangent
133+ # This may be produced from e.g. x=range(1,2,length=3). There need not be any
134+ # AbstractArray representation of such a tangent, so we just pass it along,
135+ # and trust that projection on fields before the constructor will act if necessary.
136+ (:: ProjectTo{T} )(dx:: Tangent{<:T} ) where {T} = dx
137+
138+ # (project::ProjectTo{<:AbstractArray})(dx::Tangent{<:AbstractArray}) = dx
139+
138140# ####
139141# #### `Base`
140142# ####
@@ -241,18 +243,21 @@ function (project::ProjectTo{AbstractArray})(dx::Number) # ... so we restore fro
241243 return fill (project. element (dx))
242244end
243245
244- # Ref -- works like a zero-array, also allows restoration from a number:
245- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
246- (project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
247- (project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
248-
249246function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
250247 size_x = map (length, axes_x)
251248 return DimensionMismatch (
252249 " variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx "
253250 )
254251end
255252
253+ # Ref
254+ # This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255+ ProjectTo (x:: Ref ) = ProjectTo {Ref} (; reftype= typeof (x), x= ProjectTo (x[]))
256+ (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.reftype} (; x= project. x (dx[]))
257+ (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.reftype} (; x= project. x (dx. x))
258+ # Since this works like a zero-array in broadcasting, it should also accept a number:
259+ (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.reftype} (; x= project. x (dx))
260+
256261# ####
257262# #### `LinearAlgebra`
258263# ####
0 commit comments