Public API
Yota.grad — Function
grad(f, args...; seed=1)
Find the gradient of a callable f w.r.t. its arguments.
grad() returns two things: the value of f(args...) and a tuple of gradients w.r.t. its inputs (including the callable itself).
using Yota
val, g = grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])
# output
(9.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))
The first element of the gradient tuple corresponds to the callable itself, the rest to its arguments.
By default, grad() expects the callable to return a scalar. Vector-valued functions can be differentiated if a seed (the starting value of the backward pass) is provided. The seed plays the role of the vector v in the vector-Jacobian product (VJP) v' * J.
using Yota
val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3))
# output
([2.0, 4.0, 6.0], (ChainRulesCore.ZeroTangent(), [2.0, 2.0, 2.0]))
All gradients can be applied to the original variables using the update!() function.
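To make the role of seed concrete, the plain-Julia sketch below (no Yota involved) computes by hand the quantity grad() reports for x in the seed example above. It assumes the Jacobian of f(x) = 2x is written out explicitly as J = 2I; with seed v, the gradient w.r.t. x is the VJP J' * v. For a scalar-valued callable, the default seed=1 reduces this to the ordinary gradient.

```julia
using LinearAlgebra

# Illustration only: the Jacobian is built by hand, not by Yota.
f(x) = 2x
x = [1.0, 2.0, 3.0]
J = Matrix(2.0I, length(x), length(x))  # Jacobian of f(x) = 2x is 2I

v = ones(length(x))    # the seed, as in seed=ones(3)
vjp = J' * v           # vector-Jacobian product: the gradient w.r.t. x

# A different seed weights the outputs differently:
# here only the first output f(x)[1] contributes.
vjp_first = J' * [1.0, 0.0, 0.0]
```

With v = ones(3) this gives [2.0, 2.0, 2.0], matching the gradient in the example output.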
See also: gradtape
Yota.gradtape — Function
gradtape(f, args...; ctx=GradCtx(), seed=1)
gradtape!(tape::Tape; seed=1)
Calculate the gradients of tape[tape.resultid] w.r.t. the Input nodes and record them to the tape. See grad() for a higher-level API.
Yota.YotaRuleConfig — Type
YotaRuleConfig()
ChainRules.RuleConfig passed to all rrules in Yota. Extends RuleConfig{Union{NoForwardsMode,HasReverseMode}}.
ChainRulesCore.rrule_via_ad — Function
rrule_via_ad(::YotaRuleConfig, f, args...)
Generate an rrule for f(args...) using Yota.
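A minimal usage sketch, assuming Yota and ChainRulesCore are both installed in the active environment. Following the general ChainRulesCore convention for rrules, the call is expected to return the primal value together with a pullback; calling the pullback with an output cotangent (1.0 for a scalar output) yields a tangent for f followed by one for each argument. The function sqnorm is an illustrative name, not part of Yota.

```julia
using Yota
using ChainRulesCore: rrule_via_ad

# Illustrative function: the gradient of sum(x .* x) w.r.t. x is 2x.
sqnorm(x) = sum(x .* x)
x = [1.0, 2.0]

# Primal value and pullback, as returned by any rrule.
y, pb = rrule_via_ad(Yota.YotaRuleConfig(), sqnorm, x)

# Feed the output cotangent; the first tangent belongs to sqnorm itself.
df, dx = pb(1.0)
```

In practice rrule_via_ad is mostly useful inside a custom rrule that must differentiate a user-supplied callback: the rule receives the config, and YotaRuleConfig advertises HasReverseMode so the rule can call back into Yota.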
Internals
Yota.back! — Function
Backpropagate through the tape, recording derivatives as new operations.
Yota.step_back! — Function
Make a single step of backpropagation.
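The division of labour between back! and step_back! can be pictured with a toy reverse pass. This is a hypothetical, self-contained sketch: ToyOp, toy_forward, toy_step_back! and toy_back are invented for illustration and are not Yota's Tape API. Unlike back!, which records derivative computations as new operations on the tape, the toy evaluates them numerically right away.

```julia
# Toy tape entry: an operation name plus the ids of the entries it reads.
struct ToyOp
    op::Symbol          # :input, :mul or :add
    args::Vector{Int}
end

# Forward pass: compute the value of every entry in tape order.
function toy_forward(tape, inputs)
    vals = zeros(length(tape))
    k = 0
    for (i, o) in enumerate(tape)
        if o.op == :input
            k += 1
            vals[i] = inputs[k]
        elseif o.op == :mul
            vals[i] = vals[o.args[1]] * vals[o.args[2]]
        elseif o.op == :add
            vals[i] = vals[o.args[1]] + vals[o.args[2]]
        end
    end
    return vals
end

# One step (cf. step_back!): push the gradient of entry i onto its arguments.
function toy_step_back!(grads, tape, vals, i)
    o = tape[i]
    g = grads[i]
    if o.op == :mul
        a, b = o.args
        grads[a] += g * vals[b]
        grads[b] += g * vals[a]
    elseif o.op == :add
        a, b = o.args
        grads[a] += g
        grads[b] += g
    end
    return grads
end

# Whole pass (cf. back!): seed the result, then step back in reverse order.
# A real implementation would only visit the entries it needs (cf. todo_list).
function toy_back(tape, vals, resultid; seed=1.0)
    grads = zeros(length(tape))
    grads[resultid] = seed
    for i in resultid:-1:1
        toy_step_back!(grads, tape, vals, i)
    end
    return grads
end

# f(x, y) = x * y + x evaluated at x = 3, y = 4
tape = [ToyOp(:input, Int[]), ToyOp(:input, Int[]), ToyOp(:mul, [1, 2]), ToyOp(:add, [3, 1])]
vals = toy_forward(tape, [3.0, 4.0])   # vals[4] == 15.0
grads = toy_back(tape, vals, 4)        # grads[1] == 5.0 (y + 1), grads[2] == 3.0 (x)
```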
Yota.todo_list — Function
Collect the variables that we need to step through during the reverse pass. The returned vector is already deduplicated and reverse-sorted.
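A hypothetical sketch of what "deduplicated and reverse-sorted" means, on a toy tape where entry i lists the ids of the entries it reads from (this is not Yota's or Umlaut's Tape type). It assumes that arguments always appear on the tape before the operations that consume them; under that assumption, visiting ids in descending order handles each operation only after everything that consumes it.

```julia
# Toy tape: entry i holds the ids of its arguments; inputs have none.
const ToyTape = Vector{Vector{Int}}

# Walk back from the result, keep each reachable id once, sort descending.
function toy_todo_list(tape::ToyTape, resultid::Int)
    seen = Set{Int}()
    stack = [resultid]
    while !isempty(stack)
        id = pop!(stack)
        id in seen && continue       # deduplicate
        push!(seen, id)
        append!(stack, tape[id])     # visit the arguments next
    end
    return sort!(collect(seen); rev=true)
end

# %1 = x, %2 = y (inputs); %3 = x * y; %4 = sin(x); %5 = %3 + %3; %6 = g(y) (unused)
tape = [Int[], Int[], [1, 2], [1], [3, 3], [2]]
toy_todo_list(tape, 5)   # [5, 3, 2, 1]: %4 and %6 are not needed
```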
Yota.grad_compile — Function
Like Umlaut.compile, but adds Yota-specific ops.
Yota._getfield — Function
_getfield(value, fld)
This function can be used instead of getfield() to bypass Yota rules during backpropagation.
Yota.isstruct — Function
Check if an object is of a struct type, i.e. not a number or an array.
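A rough, hypothetical approximation of the check described above, in plain Julia. The names toy_isstruct and Point are illustrative only, and Yota's actual isstruct may treat other types (strings, tuples and so on) differently.

```julia
# Hypothetical approximation: a struct value that is neither a number nor an array.
toy_isstruct(x) = !(x isa Number) && !(x isa AbstractArray) && Base.isstructtype(typeof(x))

struct Point
    x::Float64
    y::Float64
end

toy_isstruct(Point(1.0, 2.0))  # true
toy_isstruct(3.0)              # false: a number
toy_isstruct([1.0, 2.0])       # false: an array
```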