Public API

Yota.grad - Function
grad(f, args...; seed=1)

Find gradient of a callable f w.r.t. its arguments.

grad() returns two things: the value of f(args...) and a tuple of gradients w.r.t. its inputs (including the callable itself).

using Yota

val, g = grad(x -> sum(x .+ 1), [1.0, 2.0, 3.0])

# output
(9.0, (ChainRulesCore.ZeroTangent(), [1.0, 1.0, 1.0]))
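The first element of the gradient tuple belongs to the callable itself, so gradients for the actual arguments start at index 2. The sketch below illustrates the indexing with a struct argument. The `Linear` type and the `loss` function are hypothetical names introduced only for this example, and the sketch assumes that the gradient for a struct is returned as a structural tangent whose fields can be accessed by name.

```julia
using Yota

# Hypothetical model type, used only for illustration.
struct Linear
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Scalar-valued loss, as grad() expects by default.
loss(m::Linear, X) = sum(m.W * X .+ m.b)

m = Linear(ones(2, 3), zeros(2))
X = ones(3, 4)

val, g = grad(loss, m, X)

g[1]    # gradient w.r.t. `loss` itself (a plain function has no fields to differentiate)
g[2]    # gradient w.r.t. `m` (assumed: a structural tangent with fields W and b)
g[2].W  # gradient w.r.t. m.W
g[3]    # gradient w.r.t. X
```

Because the tuple always has one more element than the number of arguments, `g[i + 1]` is the gradient for the i-th argument.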

By default, grad() expects the callable to return a scalar. Vector-valued functions can be differentiated if a seed (starting value) is provided. The seed plays the role of the vector v in a vector-Jacobian product (VJP).

using Yota

val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3))

# output
([2.0, 4.0, 6.0], (ChainRulesCore.ZeroTangent(), [2.0, 2.0, 2.0]))
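As a worked check of the example above, read in VJP terms: for f(x) = 2x on three inputs, the Jacobian J is twice the identity, and the returned gradient is the product of the transposed Jacobian with the seed v. The symbols J and v are notation introduced here for illustration.

```math
J = \frac{\partial f}{\partial x} = 2 I_3, \qquad
v = (1, 1, 1)^\top, \qquad
J^\top v = (2, 2, 2)^\top
```

This agrees with the gradient [2.0, 2.0, 2.0] shown in the output.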

All gradients can be applied to the original variables using the update!() function.

See also: gradtape

Yota.gradtape - Function
gradtape(f, args...; ctx=GradCtx(), seed=1)
gradtape!(tape::Tape; seed=1)

Calculate the gradients of tape[tape.resultid] w.r.t. the Input nodes and record them to the tape. See grad() for a higher-level API.

Yota.YotaRuleConfig - Type
YotaRuleConfig()

ChainRules.RuleConfig passed to all rrules in Yota. Extends RuleConfig{Union{NoForwardsMode,HasReverseMode}}.


Internals

Missing docstring for record_primitive!. Check Documenter's build log for details.

Yota.back! - Function

Backpropagate through the tape, recording derivatives as new operations.

Yota.todo_list - Function

Collect the variables that need to be stepped through during the reverse pass. The returned vector is already deduplicated and reverse-sorted.

Missing docstring for make_rrule. Check Documenter's build log for details.

Yota._getfield - Function
_getfield(value, fld)

This function can be used instead of getfield() to bypass Yota rules during backpropagation.

Yota.isstruct - Function

Check whether an object is of a struct type, i.e., not a number or an array.


Index