@@ -67,7 +67,7 @@ function (project::ProjectTo{T})(dx::Tangent) where {T}
6767end
6868
6969# Used for encoding fields, leaves alone non-diff types:
70- _maybe_projector (x:: Union{AbstractArray, Number, Ref} ) = ProjectTo (x)
70+ _maybe_projector (x:: Union{AbstractArray,Number,Ref} ) = ProjectTo (x)
7171_maybe_projector (x) = x
7272# Used for re-constructing fields, restores non-diff types:
7373_maybe_call (f:: ProjectTo , x) = f (x)
161161function ProjectTo (xs:: AbstractArray )
162162 elements = map (ProjectTo, xs)
163163 if elements isa AbstractArray{<: ProjectTo{<:AbstractZero} }
164- return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
164+ return ProjectTo {NoTangent} () # short-circuit if all elements project to zero
165165 else
166166 # Arrays of arrays come here, and will apply projectors individually:
167167 return ProjectTo {AbstractArray} (; elements= elements, axes= axes (xs))
@@ -175,7 +175,9 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
175175 dx
176176 else
177177 for d in 1 : max (M, length (project. axes))
178- size (dx, d) == length (get (project. axes, d, 1 )) || throw (_projection_mismatch (project. axes, size (dx)))
178+ if size (dx, d) != length (get (project. axes, d, 1 ))
179+ throw (_projection_mismatch (project. axes, size (dx)))
180+ end
179181 end
180182 reshape (dx, project. axes)
181183 end
@@ -185,29 +187,37 @@ function (project::ProjectTo{AbstractArray})(dx::AbstractArray{S,M}) where {S,M}
185187 T = project_type (project. element)
186188 S <: T ? dy : map (project. element, dy)
187189 else
188- map ((f,y) -> f (y), project. elements, dy)
190+ map ((f, y) -> f (y), project. elements, dy)
189191 end
190192 return dz
191193end
192194
193195# Row vectors aren't acceptable as gradients for 1-row matrices:
194- (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = project (reshape (vec (dx),1 ,:))
196+ function (project:: ProjectTo{AbstractArray} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
197+ return project (reshape (vec (dx), 1 , :))
198+ end
195199
196200# Zero-dimensional arrays -- these have a habit of going missing,
197201# although really Ref() is probably a better structure.
198202function (project:: ProjectTo{AbstractArray} )(dx:: Number ) # ... so we restore from numbers
199- project. axes isa Tuple{} || throw (DimensionMismatch (" array with ndims(x) == $(length (project. axes)) > 0 cannot have as gradient dx::Number" ))
203+ if ! (project. axes isa Tuple{})
204+ throw (DimensionMismatch (
205+ " array with ndims(x) == $(length (project. axes)) > 0 cannot have dx::Number" ,
206+ ))
207+ end
200208 return fill (project. element (dx))
201209end
202210
203211# Ref -- works like a zero-array, also allows restoration from a number:
204- ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x = ProjectTo (x[]))
212+ ProjectTo (x:: Ref ) = ProjectTo {Ref} (; x= ProjectTo (x[]))
205213(project:: ProjectTo{Ref} )(dx:: Ref ) = Ref (project. x (dx[]))
206214(project:: ProjectTo{Ref} )(dx:: Number ) = Ref (project. x (dx))
207215
208216function _projection_mismatch (axes_x:: Tuple , size_dx:: Tuple )
209217 size_x = map (length, axes_x)
210- return DimensionMismatch (" variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx " )
218+ return DimensionMismatch (
219+ " variable with size(x) == $size_x cannot have a gradient with size(dx) == $size_dx "
220+ )
211221end
212222
213223# ####
@@ -217,25 +227,33 @@ end
217227# Row vectors
218228function ProjectTo (x:: LinearAlgebra.AdjointAbsVec )
219229 sub = ProjectTo (parent (x))
220- ProjectTo {Adjoint} (; parent= sub)
230+ return ProjectTo {Adjoint} (; parent= sub)
221231end
222232# Note that while [1 2; 3 4]' isa Adjoint, we use ProjectTo{Adjoint} only to encode AdjointAbsVec.
223233# Transposed matrices are, like PermutedDimsArray, just a storage detail,
224234# but row vectors behave differently, for example [1,2,3]' * [1,2,3] isa Number
225- (project:: ProjectTo{Adjoint} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = adjoint (project. parent (adjoint (dx)))
235+ function (project:: ProjectTo{Adjoint} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
236+ return adjoint (project. parent (adjoint (dx)))
237+ end
226238function (project:: ProjectTo{Adjoint} )(dx:: AbstractArray )
227- size (dx,1 ) == 1 && size (dx,2 ) == length (project. parent. axes[1 ]) || throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
239+ if size (dx, 1 ) != 1 || size (dx, 2 ) != length (project. parent. axes[1 ])
240+ throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
241+ end
228242 dy = eltype (dx) <: Real ? vec (dx) : adjoint (dx)
229243 return adjoint (project. parent (dy))
230244end
231245
232246function ProjectTo (x:: LinearAlgebra.TransposeAbsVec )
233247 sub = ProjectTo (parent (x))
234- ProjectTo {Transpose} (; parent= sub)
248+ return ProjectTo {Transpose} (; parent= sub)
249+ end
250+ function (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec )
251+ return transpose (project. parent (transpose (dx)))
235252end
236- (project:: ProjectTo{Transpose} )(dx:: LinearAlgebra.AdjOrTransAbsVec ) = transpose (project. parent (transpose (dx)))
237253function (project:: ProjectTo{Transpose} )(dx:: AbstractArray )
238- size (dx,1 ) == 1 && size (dx,2 ) == length (project. parent. axes[1 ]) || throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
254+ if size (dx, 1 ) != 1 || size (dx, 2 ) != length (project. parent. axes[1 ])
255+ throw (_projection_mismatch ((1 : 1 , project. parent. axes... ), size (dx)))
256+ end
239257 dy = eltype (dx) <: Number ? vec (dx) : transpose (dx)
240258 return transpose (project. parent (dy))
241259end
250268(project:: ProjectTo{Diagonal} )(dx:: Diagonal ) = Diagonal (project. diag (dx. diag))
251269
252270# Symmetric
253- for (SymHerm, chk, fun) in ((:Symmetric , :issymmetric , :transpose ), (:Hermitian , :ishermitian , :adjoint ))
271+ for (SymHerm, chk, fun) in (
272+ (:Symmetric , :issymmetric , :transpose ),
273+ (:Hermitian , :ishermitian , :adjoint ),
274+ )
254275 @eval begin
255276 function ProjectTo (x:: $SymHerm )
256277 sub = ProjectTo (parent (x))
@@ -268,7 +289,9 @@ for (SymHerm, chk, fun) in ((:Symmetric, :issymmetric, :transpose), (:Hermitian,
268289 # not clear how broadly it's worthwhile to try to support this.
269290 function (project:: ProjectTo{$SymHerm} )(dx:: Diagonal )
270291 sub = project. parent # this is going to be unhappy about the size
271- sub_one = ProjectTo {project_type(sub)} (; element = sub. element, axes = (sub. axes[1 ],))
292+ sub_one = ProjectTo {project_type(sub)} (;
293+ element= sub. element, axes= (sub. axes[1 ],)
294+ )
272295 return Diagonal (sub_one (dx. diag))
273296 end
274297 end
@@ -279,13 +302,16 @@ for UL in (:UpperTriangular, :LowerTriangular, :UnitUpperTriangular, :UnitLowerT
279302 @eval begin
280303 function ProjectTo (x:: $UL )
281304 sub = ProjectTo (parent (x))
282- sub isa ProjectTo{<: AbstractZero } && return sub # TODO not necc if UnitUpperTriangular(NoTangent()) etc. worked
305+ # TODO not nesc if UnitUpperTriangular(NoTangent()) etc. worked
306+ sub isa ProjectTo{<: AbstractZero } && return sub
283307 return ProjectTo {$UL} (; parent= sub)
284308 end
285309 (project:: ProjectTo{$UL} )(dx:: AbstractArray ) = $ UL (project. parent (dx))
286310 function (project:: ProjectTo{$UL} )(dx:: Diagonal )
287311 sub = project. parent
288- sub_one = ProjectTo {project_type(sub)} (; element = sub. element, axes = (sub. axes[1 ],))
312+ sub_one = ProjectTo {project_type(sub)} (;
313+ element= sub. element, axes= (sub. axes[1 ],)
314+ )
289315 return Diagonal (sub_one (dx. diag))
290316 end
291317 end
@@ -306,7 +332,7 @@ function (project::ProjectTo{Bidiagonal})(dx::Bidiagonal)
306332 else
307333 uplo = LinearAlgebra. sym_uplo (project. uplo)
308334 dv = project. dv (diag (dx))
309- ev = fill! (similar (dv, length (dv)- 1 ), 0 )
335+ ev = fill! (similar (dv, length (dv) - 1 ), 0 )
310336 return Bidiagonal (dv, ev, uplo)
311337 end
312338end
321347
322348# another strategy is just to use the AbstractArray method
323349function ProjectTo (x:: Tridiagonal{T} ) where {T<: Number }
324- notparent = invoke (ProjectTo, Tuple{AbstractArray{T}} where T<: Number , x)
325- return ProjectTo {Tridiagonal} (; notparent = notparent)
350+ notparent = invoke (ProjectTo, Tuple{AbstractArray{T}} where { T<: Number } , x)
351+ return ProjectTo {Tridiagonal} (; notparent= notparent)
326352end
327353function (project:: ProjectTo{Tridiagonal} )(dx:: AbstractArray )
328354 dy = project. notparent (dx)
@@ -340,20 +366,26 @@ using SparseArrays
340366# This implementation very naiive, can probably be made more efficient.
341367
342368function ProjectTo (x:: SparseVector{T} ) where {T<: Number }
343- return ProjectTo {SparseVector} (; element = ProjectTo (zero (T)), nzind = x. nzind, axes = axes (x))
369+ return ProjectTo {SparseVector} (;
370+ element= ProjectTo (zero (T)), nzind= x. nzind, axes= axes (x)
371+ )
344372end
345373function (project:: ProjectTo{SparseVector} )(dx:: AbstractArray )
346374 dy = if axes (dx) == project. axes
347375 dx
348376 else
349- size (dx, 1 ) == length (project. axes[1 ]) || throw (_projection_mismatch (project. axes, size (dx)))
377+ if size (dx, 1 ) != length (project. axes[1 ])
378+ throw (_projection_mismatch (project. axes, size (dx)))
379+ end
350380 reshape (dx, project. axes)
351381 end
352382 nzval = map (i -> project. element (dy[i]), project. nzind)
353383 return SparseVector (length (dx), project. nzind, nzval)
354384end
355385function (project:: ProjectTo{SparseVector} )(dx:: SparseVector )
356- size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
386+ if size (dx) != map (length, project. axes)
387+ throw (_projection_mismatch (project. axes, size (dx)))
388+ end
357389 # When sparsity pattern is unchanged, all the time is in checking this,
358390 # perhaps some simple hash/checksum might be good enough?
359391 samepattern = project. nzind == dx. nzind
@@ -373,17 +405,23 @@ function (project::ProjectTo{SparseVector})(dx::SparseVector)
373405end
374406
375407function ProjectTo (x:: SparseMatrixCSC{T} ) where {T<: Number }
376- ProjectTo {SparseMatrixCSC} (; element = ProjectTo (zero (T)), axes = axes (x),
377- rowval = rowvals (x), nzranges = nzrange .(Ref (x), axes (x,2 )), colptr = x. colptr)
408+ return ProjectTo {SparseMatrixCSC} (;
409+ element= ProjectTo (zero (T)),
410+ axes= axes (x),
411+ rowval= rowvals (x),
412+ nzranges= nzrange .(Ref (x), axes (x, 2 )),
413+ colptr= x. colptr,
414+ )
378415end
379416# You need not really store nzranges, you can get them from colptr -- TODO
380417# nzrange(S::AbstractSparseMatrixCSC, col::Integer) = getcolptr(S)[col]:(getcolptr(S)[col+1]-1)
381418function (project:: ProjectTo{SparseMatrixCSC} )(dx:: AbstractArray )
382419 dy = if axes (dx) == project. axes
383420 dx
384421 else
385- size (dx, 1 ) == length (project. axes[1 ]) || throw (_projection_mismatch (project. axes, size (dx)))
386- size (dx, 2 ) == length (project. axes[2 ]) || throw (_projection_mismatch (project. axes, size (dx)))
422+ if size (dx) != (length (project. axes[1 ]), length (project. axes[2 ]))
423+ throw (_projection_mismatch (project. axes, size (dx)))
424+ end
387425 reshape (dx, project. axes)
388426 end
389427 nzval = Vector {project_type(project.element)} (undef, length (project. rowval))
@@ -392,15 +430,17 @@ function (project::ProjectTo{SparseMatrixCSC})(dx::AbstractArray)
392430 for i in project. nzranges[col]
393431 row = project. rowval[i]
394432 val = dy[row, col]
395- nzval[k+= 1 ] = project. element (val)
433+ nzval[k += 1 ] = project. element (val)
396434 end
397435 end
398436 m, n = map (length, project. axes)
399437 return SparseMatrixCSC (m, n, project. colptr, project. rowval, nzval)
400438end
401439
402440function (project:: ProjectTo{SparseMatrixCSC} )(dx:: SparseMatrixCSC )
403- size (dx) == map (length, project. axes) || throw (_projection_mismatch (project. axes, size (dx)))
441+ if size (dx) != map (length, project. axes)
442+ throw (_projection_mismatch (project. axes, size (dx)))
443+ end
404444 samepattern = dx. colptr == project. colptr && dx. rowval == project. rowval
405445 # samepattern = length(dx.colptr) == length(project.colptr) && dx.colptr[end] == project.colptr[end]
406446 if eltype (dx) <: project_type (project. element) && samepattern
0 commit comments