Public API
Yota.grad (Function)
grad(f, args...; seed=1)
Find the gradient of a callable f with respect to its arguments.
grad() returns two things: the value of f(args...) and a tuple of gradients with respect to its inputs, including the callable itself (its gradient comes first in the tuple).
```julia
using Yota
val, g = grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
# output
(9.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))
```
By default, grad() expects the callable to return a scalar. Vector-valued functions can be differentiated if a seed (the starting value of the output derivative) is provided. The seed plays the role of the vector v in the vector-Jacobian product (VJP) notation.
```julia
using Yota
val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3))
# output
([2.0, 4.0, 6.0], (ChainRulesCore.ZeroTangent(), [2.0, 2.0, 2.0]))
```
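For the seeded example above, the Jacobian of f(x) = 2x is a scaled identity and the seed supplies the vector v. The returned gradient is the VJP, which reproduces the printed result:

```math
f(x) = 2x, \qquad J = \frac{\partial f}{\partial x} = 2 I_3, \qquad
v = \begin{pmatrix} 1 \\ 1 \\ 1 \end{pmatrix}, \qquad
\bar{x} = J^\top v = \begin{pmatrix} 2 \\ 2 \\ 2 \end{pmatrix}
```

With a scalar output and the default seed=1, J is a single row and the same formula reduces to the ordinary gradient.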
All gradients can be applied to the original variables using the update!() function.
See also: gradtape
Yota.gradtape (Function)
gradtape(f, args...; ctx=GradCtx(), seed=1)
gradtape!(tape::Tape; seed=1)
Calculate the gradients of tape[tape.resultid] with respect to the Input nodes and record them to the tape. See grad() for a higher-level API.
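A short side-by-side sketch of the two entry points, assuming Yota is installed in the active environment: grad returns only the value and the gradients, while gradtape returns the tape itself, which can be printed to inspect the recorded forward and gradient operations. The exact printed form of the tape depends on the installed Yota and Umlaut versions, so no output is shown.

```julia
using Yota

f(x) = sum(x .+ 1)
x = [1.0, 2.0, 3.0]

# High-level API: the value of f(x) and the gradients w.r.t. (f, x).
val, g = grad(f, x)

# Lower-level API: a tape holding the forward pass plus the recorded gradient ops.
tape = Yota.gradtape(f, x; seed=1)
println(tape)
```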
Yota.YotaRuleConfig (Type)
YotaRuleConfig()
The ChainRules.RuleConfig passed to all rrules in Yota. Extends RuleConfig{Union{NoForwardsMode,HasReverseMode}}.
ChainRulesCore.rrule_via_ad (Function)
rrule_via_ad(::YotaRuleConfig, f, args...)
Generate an rrule for f(args...) using Yota.
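A minimal usage sketch, assuming both Yota and ChainRulesCore are available in the active environment. Following the ChainRules convention, rrule_via_ad returns the primal value and a pullback; this is the call a custom rrule for a higher-order function would make when it receives a config that supports reverse mode.

```julia
using Yota
using ChainRulesCore: rrule_via_ad   # ChainRulesCore must be installed directly

f(x) = sum(x .+ 1)
x = [1.0, 2.0, 3.0]

# Ask Yota to derive an rrule for f: the primal value and a pullback.
y, pb = rrule_via_ad(Yota.YotaRuleConfig(), f, x)

# By ChainRules convention, the pullback maps an output cotangent to one
# cotangent per argument of the call, starting with f itself.
tangents = pb(1.0)
```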
Internals
Yota.back! (Function)
Backpropagate through the tape, recording derivatives as new operations.
Yota.step_back! (Function)
Make a single step of backpropagation.
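To make the division of labor between back! and step_back! concrete, here is a minimal sketch on a hypothetical toy tape that only supports scalar addition and multiplication. Unlike Yota, which records derivatives as new tape operations, this sketch computes them eagerly as numbers; ToyOp, ToyTape, toy_record!, toy_step_back! and toy_back! are illustrative names, not Yota's API.

```julia
# Hypothetical toy tape; Yota's real tapes come from Umlaut and are richer.
struct ToyOp
    fn::Symbol          # :input, :add or :mul
    args::Vector{Int}   # ids of the entries this op reads
end

struct ToyTape
    ops::Vector{ToyOp}
    vals::Vector{Float64}
end

# Append an op together with its already-computed value; return its id.
function toy_record!(tape::ToyTape, fn::Symbol, args::Vector{Int}, val::Float64)
    push!(tape.ops, ToyOp(fn, args))
    push!(tape.vals, val)
    return length(tape.ops)
end

# Single step: push the derivative of entry `id` to its arguments.
function toy_step_back!(tape::ToyTape, derivs::Dict{Int,Float64}, id::Int)
    op = tape.ops[id]
    dy = get(derivs, id, 0.0)
    if op.fn == :add
        for a in op.args
            derivs[a] = get(derivs, a, 0.0) + dy
        end
    elseif op.fn == :mul
        a, b = op.args
        derivs[a] = get(derivs, a, 0.0) + dy * tape.vals[b]
        derivs[b] = get(derivs, b, 0.0) + dy * tape.vals[a]
    end
    return derivs
end

# Full pass: seed the result, then step back through entries in reverse order.
function toy_back!(tape::ToyTape; seed::Float64=1.0)
    result = length(tape.ops)
    derivs = Dict{Int,Float64}(result => seed)
    for id in result:-1:1
        tape.ops[id].fn == :input || toy_step_back!(tape, derivs, id)
    end
    return derivs
end

# f(x, y) = x * y + x at x = 3, y = 4
tape = ToyTape(ToyOp[], Float64[])
x = toy_record!(tape, :input, Int[], 3.0)
y = toy_record!(tape, :input, Int[], 4.0)
xy = toy_record!(tape, :mul, [x, y], tape.vals[x] * tape.vals[y])
z = toy_record!(tape, :add, [xy, x], tape.vals[xy] + tape.vals[x])
derivs = toy_back!(tape)   # derivs[x] == 5.0 (y + 1), derivs[y] == 3.0 (x)
```

Accumulating with `get(derivs, a, 0.0) + ...` is what lets a variable used more than once (x here) receive the sum of its contributions.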
Yota.todo_list (Function)
Collect the variables that need to be stepped through during the reverse pass. The returned vector is already deduplicated and sorted in reverse order.
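A minimal sketch of the collect, deduplicate and reverse-sort idea, assuming a hypothetical dependency map from each variable id to the ids of its arguments (the real function works on Umlaut tapes and may apply further filtering). Variables that the result does not depend on are never collected, and later tape positions come first so that each variable is processed only after all of its consumers.

```julia
# Hypothetical helper: `deps[id]` lists the argument ids of variable `id`.
function toy_todo_list(deps::Dict{Int,Vector{Int}}, result::Int)
    seen = Set{Int}()
    stack = [result]
    while !isempty(stack)
        id = pop!(stack)
        id in seen && continue              # deduplicate shared variables
        push!(seen, id)
        append!(stack, get(deps, id, Int[]))
    end
    return sort!(collect(seen); rev=true)   # latest variables first
end

# Variable 4 is recorded but unused by the result (5), so it is skipped.
deps = Dict(1 => Int[], 2 => Int[], 3 => [1, 2], 4 => [2], 5 => [3, 1])
todo = toy_todo_list(deps, 5)   # [5, 3, 2, 1]
```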
Yota.grad_compile (Function)
Like Umlaut.compile, but adds Yota-specific operations.
Yota._getfield (Function)
_getfield(value, fld)
This function can be used instead of getfield() to bypass Yota rules during backpropagation.
Yota.isstruct (Function)
Check whether an object is of a struct type, i.e. not a number or an array.
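One way such a predicate could be written, as a hypothetical illustration rather than Yota's source: rule out numbers and arrays explicitly (some of them, such as Complex, are themselves structs), then ask Base.isstructtype about the remaining type.

```julia
# Hypothetical re-implementation for illustration; not Yota's source.
toy_isstruct(x) = !(x isa Number) && !(x isa AbstractArray) && Base.isstructtype(typeof(x))

struct Point
    x::Float64
    y::Float64
end

toy_isstruct(Point(1.0, 2.0))   # true
toy_isstruct(3.0)               # false: a number
toy_isstruct([1.0, 2.0])        # false: an array
```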