@@ -59,10 +59,8 @@ function generic_projector(x::T; kw...) where {T}
5959 fields_nt:: NamedTuple = backing (x)
6060 fields_proj = map (_maybe_projector, fields_nt)
6161 # We can't use `T` because if we have `Foo{Matrix{E}}` it should be allowed to make a
62- # `Foo{Diagaonal{E}}` etc. We assume it has a default constructor that has all fields
63- # but if it doesn't `construct` will give a good error message.
62+ # `Foo{Diagaonal{E}}` etc. Official API for this? https://github.com/JuliaLang/julia/issues/35543
6463 wrapT = T. name. wrapper
65- # Official API for this? https://github.com/JuliaLang/julia/issues/35543
6664 return ProjectTo {wrapT} (; fields_proj... , kw... )
6765end
6866
@@ -252,11 +250,18 @@ end
252250
253251# Ref
254252# This can't be its own tangent, so it standardises on a Tangent{<:Ref}
255- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; reftype= typeof (x), x= ProjectTo (x[]))
256- (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.reftype} (; x= project. x (dx[]))
257- (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.reftype} (; x= project. x (dx. x))
253+ function ProjectTo (x:: Ref )
254+ sub = ProjectTo (x[]) # should we worry about isdefined(Ref{Vector{Int}}(), :x)?
255+ if sub isa ProjectTo{<: AbstractZero }
256+ return ProjectTo {NoTangent} ()
257+ else
258+ return ProjectTo {Ref} (; type= typeof (x), x= sub)
259+ end
260+ end
261+ (project:: ProjectTo{Ref} )(dx:: Ref ) = Tangent {project.type} (; x= project. x (dx[]))
262+ (project:: ProjectTo{Ref} )(dx:: Tangent ) = Tangent {project.type} (; x= project. x (dx. x))
258263# Since this works like a zero-array in broadcasting, it should also accept a number:
259- (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.reftype } (; x= project. x (dx))
264+ (project:: ProjectTo{Ref} )(dx:: Number ) = Tangent {project.type } (; x= project. x (dx))
260265
261266# ####
262267# #### `LinearAlgebra`
0 commit comments