Skip to content

Commit 653cd7d

Browse files
authored
Add more methods on thunks to make ChainRules work on 1.0 (#381)
1 parent 66bd0a1 commit 653cd7d

3 files changed

Lines changed: 15 additions & 1 deletion

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRulesCore"
22
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
3-
version = "0.10.8"
3+
version = "0.10.9"
44

55
[deps]
66
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

src/differentials/thunks.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,10 @@ Base.mapreduce(f, op, a::AbstractThunk; kws...) = mapreduce(f, op, unthunk(a); k
4141
function Base.mapreduce(f, op, itr, a::AbstractThunk; kws...)
4242
return mapreduce(f, op, itr, unthunk(a); kws...)
4343
end
44+
Base.sum(a::AbstractThunk; kws...) = sum(unthunk(a); kws...)
4445
Base.sum!(r, A::AbstractThunk; kws...) = sum!(r, unthunk(A); kws...)
4546

47+
Base.fill(a::AbstractThunk, b::Integer) = fill(unthunk(a), b)
4648
Base.vec(a::AbstractThunk) = vec(unthunk(a))
4749
Base.reshape(a::AbstractThunk, args...) = reshape(unthunk(a), args...)
4850
Base.getindex(a::AbstractThunk, args...) = getindex(unthunk(a), args...)
@@ -78,6 +80,7 @@ LinearAlgebra.dot(a::AbstractThunk, b::AbstractThunk) = dot(unthunk(a), unthunk(
7880
LinearAlgebra.ldiv!(a, b::AbstractThunk) = throw(MutateThunkException())
7981
LinearAlgebra.rdiv!(a::AbstractThunk, b) = throw(MutateThunkException())
8082

83+
LinearAlgebra.mul!(A, B::AbstractThunk, C) = mul!(A, unthunk(B), C)
8184
LinearAlgebra.mul!(C::AbstractThunk, A, B, α, β) = throw(MutateThunkException())
8285
function LinearAlgebra.mul!(C::AbstractThunk, A::AbstractThunk, B, α, β)
8386
return throw(MutateThunkException())
@@ -190,6 +193,9 @@ end
190193

191194
Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))")
192195

196+
Base.convert(::Type{<:Thunk}, a::AbstractZero) = @thunk(a)
197+
198+
193199
"""
194200
InplaceableThunk(val::Thunk, add!::Function)
195201

test/differentials/thunks.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@
2222
@test occursin(r"Thunk\(.*rand.*\)", rep)
2323
end
2424

25+
@testset "convert" begin
26+
@test convert(Thunk, ZeroTangent()) isa Thunk
27+
end
28+
2529
@testset "unthunk" begin
2630
@test unthunk(@thunk(3)) == 3
2731
@test unthunk(@thunk(@thunk(3))) isa Thunk
@@ -107,8 +111,10 @@
107111
@test 3 == mapreduce(_ -> 1, +, t)
108112
@test 3 == mapreduce((_, _) -> 1, +, v, t)
109113
end
114+
@test 10 == sum(@thunk([1 2; 3 4]))
110115
@test [4 6] == sum!([1 1], @thunk([1 2; 3 4]))
111116

117+
@test fill(3.2, 3) == fill(@thunk(3.2), 3)
112118
@test v == vec(t)
113119
@test [1 2 3] == reshape(t, 1, 3)
114120
@test 1 == getindex(t, 1)
@@ -152,6 +158,8 @@
152158
rdiv!(deepcopy(a), 2.0)
153159
end
154160

161+
@test mul!(deepcopy(a), a, a) == mul!(deepcopy(a), t, a)
162+
155163
res = mul!(deepcopy(a), a, a, true, true)
156164
@test_throws MutateThunkException mul!(deepcopy(t), a, a, true, true)
157165
@test_throws MutateThunkException mul!(deepcopy(t), t, a, true, true)

0 commit comments

Comments
 (0)