Cookbook
Value and gradient:
using Yota
f(x, y) = x^2 + sqrt(y)
val, g = grad(f, 2.0, 3.0)  # value of f and a tuple of gradients
_, dx, dy = g               # first entry is the gradient w.r.t. f itself
(ChainRulesCore.ZeroTangent(), 4.0, 0.2886751345948129)
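As a sanity check, the two numeric entries of the tuple above can be reproduced by hand from the definition of f. The derivation below is only a worked check at the point (2.0, 3.0), not something Yota prints:

```latex
\frac{\partial f}{\partial x} = 2x \;\Rightarrow\; 2 \cdot 2.0 = 4.0,
\qquad
\frac{\partial f}{\partial y} = \frac{1}{2\sqrt{y}} \;\Rightarrow\; \frac{1}{2\sqrt{3.0}} \approx 0.288675
```

The leading `ZeroTangent()` is the entry for the function object `f` itself.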
Gradient tape (the recorded forward and backward operations, useful for inspection or further processing):
tape = gradtape(f, 2.0, 3.0)
Tape{Yota.GradCtx}
inp %1::typeof(Main.f)
inp %2::Float64
inp %3::Float64
%4 = rrule(Yota.YotaRuleConfig(), Core.apply_type, Val, 2)::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}}
%5 = _getfield(%4, 1)::DataType
%6 = _getfield(%4, 2)::ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}
%7 = rrule(Yota.YotaRuleConfig(), %5)::Tuple{Val{2}, Yota.var"#62#63"}
%8 = _getfield(%7, 1)::Val{2}
%9 = _getfield(%7, 2)::Yota.var"#62#63"
%10 = rrule(Yota.YotaRuleConfig(), Base.literal_pow, ^, %2, %8)::Tuple{Float64, ChainRules.var"#square_pullback#1238"{Float64}}
%11 = _getfield(%10, 1)::Float64
%12 = _getfield(%10, 2)::ChainRules.var"#square_pullback#1238"{Float64}
%13 = rrule(Yota.YotaRuleConfig(), sqrt, %3)::Tuple{Float64, ChainRules.var"#sqrt_pullback#1319"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}
%14 = _getfield(%13, 1)::Float64
%15 = _getfield(%13, 2)::ChainRules.var"#sqrt_pullback#1319"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}
%16 = rrule(Yota.YotaRuleConfig(), +, %11, %14)::Tuple{Float64, ChainRules.var"#+_pullback#1334"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}
%17 = _getfield(%16, 1)::Float64
%18 = _getfield(%16, 2)::ChainRules.var"#+_pullback#1334"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}
const %19 = 1::Int64
%20 = %18(%19)::Tuple{ChainRulesCore.NoTangent, Float64, Float64}
%21 = getfield(%20, 2)::Float64
%22 = getfield(%20, 3)::Float64
%23 = %15(%22)::Tuple{ChainRulesCore.NoTangent, Float64}
%24 = getfield(%23, 2)::Float64
%25 = %12(%21)::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, Float64, ChainRulesCore.NoTangent}
%26 = getfield(%25, 3)::Float64
%27 = getfield(%25, 4)::ChainRulesCore.NoTangent
%28 = tuple(ChainRulesCore.ZeroTangent(), %26, %24)::Tuple{ChainRulesCore.ZeroTangent, Float64, Float64}
%29 = map(ChainRulesCore.unthunk, %28)::Tuple{ChainRulesCore.ZeroTangent, Float64, Float64}
%30 = tuple(%17, %29)::Tuple{Float64, Tuple{ChainRulesCore.ZeroTangent, Float64, Float64}}
Vector-Jacobian product (VJP), value and gradient:
h(w, b, x) = w * x .+ b
w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, g = grad(h, w, b, x; seed=ones(3, 5))  # h is not scalar-valued, so seed supplies the cotangent of its output
([2.4806560082289932 2.037825664891348 … 1.9918014181738262 2.0279417221970393; 1.3864556536714439 1.263987933944183 … 1.4260493714949072 1.1619437555984136; 1.9410562635931086 1.2832254147778754 … 1.1732619334268692 1.5240754456995187], (ChainRulesCore.ZeroTangent(), [2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556; 2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556; 2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556], [5.0, 5.0, 5.0], [1.4982931055547928 1.4982931055547928 … 1.4982931055547928 1.4982931055547928; 1.94903447688742 1.94903447688742 … 1.94903447688742 1.94903447688742; 2.049891766581335 2.049891766581335 … 2.049891766581335 2.049891766581335; 1.6003009377777344 1.6003009377777344 … 1.6003009377777344 1.6003009377777344]))
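The patterns in the output above (identical rows in the gradient for `w`, `[5.0, 5.0, 5.0]` for `b`, identical columns in the gradient for `x`) follow from the standard VJP formulas for an affine map with an all-ones seed. A minimal sketch in plain Julia, without Yota, that writes these formulas out by hand; the names `seed`, `dw`, `db` and `dx` are illustrative only:

```julia
# Hand-written VJP for h(w, b, x) = w * x .+ b, using only Base Julia.
# `seed` plays the role of the cotangent passed via `seed=` above.
w, b, x = rand(3, 4), rand(3), rand(4, 5)
seed = ones(3, 5)

dw = seed * x'               # 3×4, cotangent w.r.t. w
db = vec(sum(seed; dims=2))  # length 3, b is broadcast over 5 columns, so the seed is summed across them
dx = w' * seed               # 4×5, cotangent w.r.t. x
```

With an all-ones seed, every row of `dw` is the vector of row sums of `x`, and every column of `dx` is the vector of column sums of `w`, which is why the rows and columns repeat in the printed result.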
VJP, value and pullback:
import Yota: YotaRuleConfig, rrule_via_ad
h(w, b, x) = w * x .+ b
w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, pb = rrule_via_ad(YotaRuleConfig(), h, w, b, x)
pb(ones(3, 5))
(ChainRulesCore.ZeroTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), [5.0, 5.0, 5.0], InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))
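The pullback result above contains lazy `InplaceableThunk` tangents rather than plain arrays. One way to materialize them is `ChainRulesCore.unthunk`, the same function the gradient tape applies in its `map(ChainRulesCore.unthunk, ...)` step. A small sketch, assuming ChainRulesCore is available in the active environment; the names `dw`, `db` and `dx` are illustrative:

```julia
import Yota: YotaRuleConfig, rrule_via_ad
import ChainRulesCore: unthunk  # assumption: ChainRulesCore is installed alongside Yota

h(w, b, x) = w * x .+ b
w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, pb = rrule_via_ad(YotaRuleConfig(), h, w, b, x)

# unthunk forces lazy tangents and returns any other tangent unchanged
_, dw, db, dx = map(unthunk, pb(ones(3, 5)))
```

After this, `dw`, `db` and `dx` are ordinary arrays with the same shapes as `w`, `b` and `x`.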
Reset gradient cache:
Yota.reset!()
Dict{Any, Any}()