Linearized traces
Tracing
Usually, programs are executed as a sequence of nested function calls, e.g.:
foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1
baz(1.0, 2.0)
7.0
Sometimes, however, it's more convenient to work with a linearized representation of the computation. Example use cases include collecting computational graphs for automatic differentiation, exporting to ONNX, serializing functions to a library-independent format, etc. trace() lets you obtain such a linearized representation:
using Umlaut
val, tape = trace(baz, 1.0, 2.0)
(7.0, Tape{Umlaut.BaseCtx}
inp %1::typeof(Main.baz)
inp %2::Float64
inp %3::Float64
%4 = *(2, %2)::Float64
%5 = *(3, %3)::Float64
%6 = +(%4, %5)::Float64
%7 = -(%6, 1)::Float64
)
trace() returns two values: the result of the original function call and the generated tape. The structure of the tape is described in the Tape anatomy section; for now, just note that trace() recursed into baz(), bar() and foo(), but recorded +, - and * onto the tape as is. This is because +, - and * are considered "primitives", i.e. the most basic operations that all other functions are composed of. This behavior can be customized using a tracing context.
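To make the idea of primitives concrete, here is a toy tracer in plain Julia. It is only a sketch under our own assumptions and not how Umlaut works internally: it relies on operator overloading, i.e. it passes a hypothetical wrapper type Traced through baz and records calls only to the methods of +, - and * that it overloads. The effect is similar to the tape above: baz, bar and foo run normally and vanish from the tape, while the primitives are recorded as is. All names (ToyTape, Traced, record!, toy_trace) are made up for illustration, and the toy tape does not record the function itself as an input, so its variable numbers are shifted by one compared to Umlaut's output.

```julia
# Toy tape: each entry is a printable line such as "%3 = *(2, %1)".
struct ToyTape
    ops::Vector{String}
end

# A traced value remembers its tape, its position on the tape and its concrete value.
struct Traced
    tape::ToyTape
    id::Int
    val::Float64
end

unwrap(x) = x isa Traced ? x.val : x
show_arg(x) = x isa Traced ? "%$(x.id)" : string(x)

# Run a primitive on the concrete values and append the call to the tape.
function record!(f, args...)
    tape = first(a.tape for a in args if a isa Traced)
    val = f(map(unwrap, args)...)
    push!(tape.ops, "%$(length(tape.ops) + 1) = $(nameof(f))($(join(map(show_arg, args), ", ")))")
    return Traced(tape, length(tape.ops), val)
end

# Only +, - and * act as primitives here: they are the only functions overloaded for Traced.
for op in (:+, :-, :*)
    @eval Base.$op(a::Traced, b::Traced) = record!($op, a, b)
    @eval Base.$op(a::Traced, b::Real) = record!($op, a, b)
    @eval Base.$op(a::Real, b::Traced) = record!($op, a, b)
end

# Wrap the inputs, run the function, and return the plain result together with the tape.
function toy_trace(f, xs::Float64...)
    tape = ToyTape(String[])
    args = Traced[]
    for x in xs
        push!(tape.ops, "inp %$(length(tape.ops) + 1)::Float64")
        push!(args, Traced(tape, length(tape.ops), x))
    end
    return unwrap(f(args...)), tape
end

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

val, tape = toy_trace(baz, 1.0, 2.0)
foreach(println, tape.ops)
# inp %1::Float64
# inp %2::Float64
# %3 = *(2, %1)
# %4 = *(3, %2)
# %5 = +(%3, %4)
# %6 = -(%5, 1)
```

The calls to foo and bar leave no trace of their own: the only things that reach the tape are the overloaded primitives they eventually execute.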
Context
Context is a way to customize tracing and attach arbitrary data to the generated tape. For example, here's how we can add a new function to the list of primitives:
import Umlaut: isprimitive, BaseCtx
struct MyCtx end
isprimitive(::MyCtx, f, args...) = isprimitive(BaseCtx(), f, args...) || f == foo
val, tape = trace(baz, 1.0, 2.0; ctx=MyCtx())
(7.0, Tape{Main.MyCtx}
inp %1::typeof(Main.baz)
inp %2::Float64
inp %3::Float64
%4 = foo(%2)::Float64
%5 = *(3, %3)::Float64
%6 = +(%4, %5)::Float64
%7 = -(%6, 1)::Float64
)
In this code:
- MyCtx is a new context type; there are no restrictions on the type of the context.
- isprimitive is a function that decides whether a particular function call f(args...) should be treated as a primitive in this context.
- BaseCtx is the default context that treats all built-in functions from modules Base, Core, etc. as primitives.
So we define a new method for isprimitive() that returns true for all built-in functions and for the function foo.
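The dispatch pattern behind isprimitive() can be illustrated without Umlaut at all. The sketch below uses hypothetical stand-ins (DefaultCtx, FooCtx, is_prim) and assumes, purely for illustration, that the default rule treats a callable as primitive when its type is defined in Base or Core; Umlaut's actual BaseCtx rule may differ in details. The custom context first delegates to the default rule and then adds foo, just like the MyCtx method above.

```julia
# Hypothetical stand-ins for tracing contexts; they only serve as dispatch tags.
struct DefaultCtx end
struct FooCtx end

foo(x) = 2x
bar(x, y) = foo(x) + 3y

# Assumed default rule for this sketch: a callable counts as primitive
# if its type is defined in Base or Core (i.e. it is a built-in function).
is_prim(::DefaultCtx, f, args...) = parentmodule(typeof(f)) in (Base, Core)

# The custom context delegates to the default rule first and then adds foo,
# mirroring the isprimitive(::MyCtx, ...) method above.
is_prim(::FooCtx, f, args...) = is_prim(DefaultCtx(), f, args...) || f === foo

for (f, args) in ((+, (1.0, 2.0)), (foo, (1.0,)), (bar, (1.0, 2.0)))
    println(nameof(f), ": default=", is_prim(DefaultCtx(), f, args...),
            ", custom=", is_prim(FooCtx(), f, args...))
end
# +: default=true, custom=true
# foo: default=false, custom=true
# bar: default=false, custom=false
```

Since the context is just a value to dispatch on, each new context type gets its own method and the default rule stays untouched.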
isprimitive() can be arbitrarily complex. For example, if we also want to treat all functions from a particular module as primitives, we can write:
isprimitive(::MyCtx, f, args...) = isprimitive(BaseCtx(), f, args...) || Base.parentmodule(f) == Main
On the other hand, if we only need to mark a few functions as primitives, BaseCtx() provides a convenient constructor for it:
val, tape = trace(baz, 1.0, 2.0; ctx=BaseCtx([+, -, *, foo]))
(7.0, Tape{Umlaut.BaseCtx}
inp %1::typeof(Main.baz)
inp %2::Float64
inp %3::Float64
%4 = foo(%2)::Float64
%5 = *(3, %3)::Float64
%6 = +(%4, %5)::Float64
%7 = -(%6, 1)::Float64
)
Another useful function is record_primitive!(), which lets you overload the way a primitive call is recorded onto the tape. As a toy example, imagine that we want to replace all invocations of * with + and count how many times this happens. Even though we haven't covered tape anatomy and utilities yet, try to parse this code:
import Umlaut: record_primitive!
function loop1(a, n)
a = 2a
for i in 1:n
a = a * n
end
return a
end
mutable struct CountingReplacingCtx
replace::Pair
count::Int
end
# v_fargs is a tuple of Variables or constant values, representing a function call
# that we are about to invoke (but haven't yet)
function record_primitive!(tape::Tape{CountingReplacingCtx}, v_fargs...)
# tape.c refers to the provided context
if v_fargs[1] == tape.c.replace[1]
tape.c.count += 1
return push!(tape, mkcall(tape.c.replace[2], v_fargs[2:end]...))
else
return push!(tape, mkcall(v_fargs...))
end
end
_, tape = trace(loop1, 2.0, 3; ctx=CountingReplacingCtx((*) => (+), 0))
@assert tape.c.count == 4
@assert count(op -> op isa Call && op.fn == (+), tape) == 4
Although we could have achieved the same result by postprocessing the tape with replace!(), record_primitive!() has the advantage of running before the original function is invoked, thus avoiding double computation.
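Why intercepting a call before it runs saves work can be sketched in plain Julia. Below, a hypothetical hook record_call! receives the function and its arguments before anything is executed, so it can swap * for + and count the swap without ever computing the original product. Unlike with Umlaut, nothing here is traced automatically: the calls to the hook are inserted by hand into loop1_recorded, a hand-instrumented copy of loop1. All names (CountingReplacer, record_call!, loop1_recorded) are made up for this sketch.

```julia
# Hypothetical context: which function to replace and how many times it was replaced.
mutable struct CountingReplacer
    replace::Pair
    count::Int
end

# Hypothetical hook: it sees f and args *before* the call happens, so the replaced
# function (here *) is never evaluated; only the substitute is.
function record_call!(calls::Vector{Any}, ctx::CountingReplacer, f, args...)
    if f === ctx.replace.first
        ctx.count += 1
        f = ctx.replace.second
    end
    push!(calls, (f, args))   # our "tape": one entry per executed primitive
    return f(args...)
end

# Hand-instrumented version of loop1 from the text: every * goes through the hook.
function loop1_recorded(calls, ctx, a, n)
    a = record_call!(calls, ctx, *, 2, a)
    for _ in 1:n
        a = record_call!(calls, ctx, *, a, n)
    end
    return a
end

calls = Any[]
ctx = CountingReplacer((*) => (+), 0)
result = loop1_recorded(calls, ctx, 2.0, 3)
println(result, " ", ctx.count)   # 13.0 4
```

The result is 13.0 rather than the 108.0 that loop1(2.0, 3) would return, because every multiplication was turned into an addition before it could run.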