Yota
Basic usage
The most important function is grad(), which has the form grad(f, args...) -> (output, gradients), e.g.:
using Yota
f(x) = 5x + 3
val, g = grad(f, 10)
(53, (ChainRulesCore.ZeroTangent(), 5.0))
Here val is the result of f(10), and g is a tuple of gradients w.r.t. the inputs, including the function itself (which is ZeroTangent() in this case).
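As a quick sanity check (independent of Yota's API), the analytic gradient 5.0 can be confirmed with a central finite difference; fdiff below is a hypothetical helper written just for this check, not a Yota function:

```julia
# Central finite-difference approximation of f'(x).
f(x) = 5x + 3
fdiff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2h)

fdiff(f, 10.0)  # ≈ 5.0, matching the gradient returned by grad(f, 10)
```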
A slightly more complex example comes from the ML domain:
using Yota
mutable struct Linear{T}
    W::AbstractArray{T,2}
    b::AbstractArray{T}
end
(m::Linear)(X) = m.W * X .+ m.b
# not very useful, but simple example of a loss function
loss(m::Linear, X) = sum(m(X))
m = Linear(rand(3,4), rand(3))
X = rand(4,5)
val, g = grad(loss, m, X)
@show g[2].W
@show g[2].b
The gradient w.r.t. the bias, g[2].b, is:
3-element Vector{Float64}:
 5.0
 5.0
 5.0
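The 5.0 entries are expected: the loss sums W * X .+ b over all 5 columns of X, so each bias component contributes 5 times. This can be verified with a plain-Julia finite-difference sketch (loss_fn and the one-hot perturbation are illustrative, not part of Yota):

```julia
W, b, X = rand(3, 4), rand(3), rand(4, 5)
loss_fn(W, b, X) = sum(W * X .+ b)

# Perturb each bias component and measure the change in the loss.
h = 1e-6
g_b = [(loss_fn(W, b .+ h .* (1:3 .== i), X) - loss_fn(W, b, X)) / h for i in 1:3]
# Each entry is ≈ 5.0 — one contribution per column of X.
```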
The computed gradients can then be used in the update!() function to modify tensors and fields of (mutable) structs:
for i = 1:100
    val, g = grad(loss, m, X)
    println("Loss value in $(i)th epoch: $val")
    update!(m, g[2], (x, gx) -> x .- 0.01gx)
end
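For intuition, the update rule above is plain in-place gradient descent on each field of the struct. A minimal hand-written sketch of the same step (sgd_step! and LinearSketch are hypothetical names for illustration, not Yota's update! implementation):

```julia
mutable struct LinearSketch
    W::Matrix{Float64}
    b::Vector{Float64}
end

# Apply x .- lr*gx to each field in place, mimicking the callback
# passed to update! above.
function sgd_step!(m::LinearSketch, gW, gb; lr=0.01)
    m.W .-= lr .* gW
    m.b .-= lr .* gb
    return m
end

m = LinearSketch(zeros(2, 2), zeros(2))
sgd_step!(m, ones(2, 2), ones(2))
m.b  # == [-0.01, -0.01]
```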