Cookbook

Value and gradient:

using Yota

f(x, y) = x^2 + sqrt(y)
val, g = grad(f, 2.0, 3.0)
_, dx, dy = g
(ChainRulesCore.ZeroTangent(), 4.0, 0.2886751345948129)
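As a quick sanity check, the gradient above can be compared with central finite differences. The sketch below is plain Julia and does not use Yota. The helper `fd_grad` and the step size `δ` are illustrative names, not part of Yota's API:

```julia
# Same function as above
f(x, y) = x^2 + sqrt(y)

# Hypothetical helper (not Yota API): central finite differences,
# one partial derivative per (scalar) argument.
function fd_grad(f, args...; δ=1e-6)
    n = length(args)
    map(1:n) do i
        up = ntuple(j -> j == i ? args[j] + δ : args[j], n)
        dn = ntuple(j -> j == i ? args[j] - δ : args[j], n)
        (f(up...) - f(dn...)) / (2δ)
    end
end

dx_fd, dy_fd = fd_grad(f, 2.0, 3.0)
# Analytic values: df/dx = 2x = 4.0, df/dy = 1 / (2 * sqrt(y))
```

Both estimates should agree with `dx` and `dy` from `grad` to about six decimal places.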

Gradient tape, i.e. the recorded forward and backward computation as a `Tape` object (useful for further processing):

tape = gradtape(f, 2.0, 3.0)
Tape{Yota.GradCtx}
  inp %1::typeof(Main.f)
  inp %2::Float64
  inp %3::Float64
  %4 = rrule(Yota.YotaRuleConfig(), Core.apply_type, Val, 2)::Tuple{DataType, ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}}} 
  %5 = _getfield(%4, 1)::DataType 
  %6 = _getfield(%4, 2)::ChainRules.var"#apply_type_pullback#42"{Tuple{Int64}} 
  %7 = rrule(Yota.YotaRuleConfig(), %5)::Tuple{Val{2}, Yota.var"#62#63"} 
  %8 = _getfield(%7, 1)::Val{2} 
  %9 = _getfield(%7, 2)::Yota.var"#62#63" 
  %10 = rrule(Yota.YotaRuleConfig(), Base.literal_pow, ^, %2, %8)::Tuple{Float64, ChainRules.var"#square_pullback#1238"{Float64}} 
  %11 = _getfield(%10, 1)::Float64 
  %12 = _getfield(%10, 2)::ChainRules.var"#square_pullback#1238"{Float64} 
  %13 = rrule(Yota.YotaRuleConfig(), sqrt, %3)::Tuple{Float64, ChainRules.var"#sqrt_pullback#1319"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}} 
  %14 = _getfield(%13, 1)::Float64 
  %15 = _getfield(%13, 2)::ChainRules.var"#sqrt_pullback#1319"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}} 
  %16 = rrule(Yota.YotaRuleConfig(), +, %11, %14)::Tuple{Float64, ChainRules.var"#+_pullback#1334"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}} 
  %17 = _getfield(%16, 1)::Float64 
  %18 = _getfield(%16, 2)::ChainRules.var"#+_pullback#1334"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}} 
  const %19 = 1::Int64
  %20 = %18(%19)::Tuple{ChainRulesCore.NoTangent, Float64, Float64} 
  %21 = getfield(%20, 2)::Float64 
  %22 = getfield(%20, 3)::Float64 
  %23 = %15(%22)::Tuple{ChainRulesCore.NoTangent, Float64} 
  %24 = getfield(%23, 2)::Float64 
  %25 = %12(%21)::Tuple{ChainRulesCore.NoTangent, ChainRulesCore.NoTangent, Float64, ChainRulesCore.NoTangent} 
  %26 = getfield(%25, 3)::Float64 
  %27 = getfield(%25, 4)::ChainRulesCore.NoTangent 
  %28 = tuple(ChainRulesCore.ZeroTangent(), %26, %24)::Tuple{ChainRulesCore.ZeroTangent, Float64, Float64} 
  %29 = map(ChainRulesCore.unthunk, %28)::Tuple{ChainRulesCore.ZeroTangent, Float64, Float64} 
  %30 = tuple(%17, %29)::Tuple{Float64, Tuple{ChainRulesCore.ZeroTangent, Float64, Float64}} 
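Reading the tape: the `rrule` calls (`%4` to `%16`) form the forward pass, each returning a value and a pullback; after the seed `const %19 = 1`, the pullbacks are applied in reverse order (`%20`, `%23`, `%25`). The following hand-written sketch reproduces that structure for `f` in plain Julia. It is a simplification, not Yota's generated code: the pullbacks here return only the relevant partial derivative rather than ChainRules tuples with `NoTangent()` entries, and `value_and_grad_f` is an illustrative name.

```julia
# Manual reverse-mode sketch for f(x, y) = x^2 + sqrt(y), mirroring the tape.
function value_and_grad_f(x, y)
    # Forward pass: each primitive yields a value and a pullback closure
    a = x^2
    pb_a = Δ -> 2x * Δ            # pullback of x^2 (cf. %10)
    b = sqrt(y)
    pb_b = Δ -> Δ / (2b)          # pullback of sqrt, reusing b = sqrt(y) (cf. %13)
    z = a + b
    pb_z = Δ -> (Δ, Δ)            # pullback of + (cf. %16)
    # Reverse pass, seeded with 1 (cf. const %19 = 1)
    abar, bbar = pb_z(1.0)
    dy = pb_b(bbar)
    dx = pb_a(abar)
    return z, (dx, dy)
end

val, (dx, dy) = value_and_grad_f(2.0, 3.0)
```

This should give the same value and partial derivatives as `grad(f, 2.0, 3.0)`, minus the `ZeroTangent()` for `f` itself.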

VJP (vector-Jacobian product), value and gradient; for a non-scalar output, pass the output cotangent via `seed`:

h(w, b, x) = w * x .+ b

w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, g = grad(h, w, b, x; seed=ones(3, 5))
([2.4806560082289932 2.037825664891348 … 1.9918014181738262 2.0279417221970393; 1.3864556536714439 1.263987933944183 … 1.4260493714949072 1.1619437555984136; 1.9410562635931086 1.2832254147778754 … 1.1732619334268692 1.5240754456995187], (ChainRulesCore.ZeroTangent(), [2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556; 2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556; 2.2403357925655802 2.9044866056242133 0.959654601573775 2.713544543806556], [5.0, 5.0, 5.0], [1.4982931055547928 1.4982931055547928 … 1.4982931055547928 1.4982931055547928; 1.94903447688742 1.94903447688742 … 1.94903447688742 1.94903447688742; 2.049891766581335 2.049891766581335 … 2.049891766581335 2.049891766581335; 1.6003009377777344 1.6003009377777344 … 1.6003009377777344 1.6003009377777344]))
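Because `h` is linear in `w` and `x` and affine in `b`, its VJP for a seed `S` has a closed form: `dw = S * x'`, `db = vec(sum(S; dims=2))`, `dx = w' * S`. With `S = ones(3, 5)` this explains the shape of the result above. `db` is `[5.0, 5.0, 5.0]` because each row of `S` sums to 5, all rows of `dw` coincide, and each row of `dx` is constant. The plain-Julia sketch below checks these formulas. `vjp_h` is an illustrative name, not Yota API:

```julia
using Random

h(w, b, x) = w * x .+ b

# Hypothetical helper: closed-form VJP of h for an output cotangent S
vjp_h(w, b, x, S) = (S * x', vec(sum(S; dims=2)), w' * S)

Random.seed!(0)
w, b, x = rand(3, 4), rand(3), rand(4, 5)
S = ones(3, 5)
dw, db, dx = vjp_h(w, b, x, S)

# Directional finite difference of sum(S .* h(...)) along a random direction V in w
V = randn(3, 4)
t = 1e-6
fd = (sum(S .* h(w + t * V, b, x)) - sum(S .* h(w - t * V, b, x))) / (2t)
```

Here `fd` should match `sum(dw .* V)` up to rounding, since `h` is linear in `w`.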

VJP, value and pullback (ChainRules-style, via `rrule_via_ad`):

import Yota: YotaRuleConfig, rrule_via_ad

h(w, b, x) = w * x .+ b

w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, pb = rrule_via_ad(YotaRuleConfig(), h, w, b, x)
pb(ones(3, 5))
(ChainRulesCore.ZeroTangent(), InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)), [5.0, 5.0, 5.0], InplaceableThunk(ChainRules.var"#..., Thunk(ChainRules.var"#...)))
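The pullback lets you run the forward pass once and then apply it to as many output cotangents as needed. Thunked entries such as the `InplaceableThunk`s above are evaluated lazily and can be materialized with `ChainRulesCore.unthunk`, which is also what the gradient tape does in its last steps. The sketch below shows the same pattern in plain Julia, using a hand-written closure in place of Yota's generated pullback. `rrule_h` is a hypothetical name, and the tangent for `h` itself is written as `nothing` for simplicity:

```julia
h(w, b, x) = w * x .+ b

# Hypothetical hand-written analogue of rrule_via_ad(YotaRuleConfig(), h, w, b, x):
# the forward pass runs once, and the returned closure maps an output
# cotangent dY to (tangent for h, dw, db, dx).
function rrule_h(w, b, x)
    y = w * x .+ b
    pullback(dY) = (nothing, dY * x', vec(sum(dY; dims=2)), w' * dY)
    return y, pullback
end

w, b, x = rand(3, 4), rand(3), rand(4, 5)
val, pb = rrule_h(w, b, x)

# Reuse the same pullback with different seeds; val is not recomputed
_, dw1, db1, dx1 = pb(ones(3, 5))
_, dw2, db2, dx2 = pb(2 .* ones(3, 5))
```

Doubling the seed doubles every gradient, as expected for a linear map.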

Reset the gradient cache (`grad` caches compiled gradient tapes between calls):

Yota.reset!()
Dict{Any, Any}()