Linearized traces
Tracing
Usually, programs are executed as a sequence of nested function calls, e.g.:
```julia
foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

baz(1.0, 2.0)
```
```
7.0
```
Sometimes, however, it's more convenient to work with a linearized representation of the computation. Example use cases include collecting computational graphs for automatic differentiation, exporting to ONNX, and serializing functions to a library-independent format. trace() lets you obtain such a linearized representation:
```julia
using Umlaut
val, tape = trace(baz, 1.0, 2.0)
```
```
(7.0, Tape{Umlaut.BaseCtx}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = *(2, %2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)
```
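To make the idea of a linearized representation concrete, here is a toy sketch in plain Julia that doesn't use Umlaut at all. The types Var and Op and the function run_tape are hypothetical names made up for this illustration; they are not Umlaut's data structures. Each operation refers to earlier results by slot index, mirroring %1 to %7 in the tape above, so replaying the list in order reproduces baz without any nested calls.

```julia
# Toy illustration only: Var, Op and run_tape are hypothetical names,
# not part of Umlaut's API.

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Reference to the result stored in slot `id` (like %id on the tape)
struct Var
    id::Int
end

# One primitive call; each argument is either a Var or a constant
struct Op
    fn::Function
    args::Vector{Any}
end

# Execute operations in order. The first slots hold the inputs;
# every executed op appends its result as the next slot.
function run_tape(ops::Vector{Op}, inputs...)
    vals = Any[inputs...]
    for op in ops
        args = [a isa Var ? vals[a.id] : a for a in op.args]
        push!(vals, op.fn(args...))
    end
    return vals[end]
end

# Hand-written linearization of baz; slot 1 holds the function itself
ops = [
    Op(*, Any[2, Var(2)]),        # %4 = 2 * %2
    Op(*, Any[3, Var(3)]),        # %5 = 3 * %3
    Op(+, Any[Var(4), Var(5)]),   # %6 = %4 + %5
    Op(-, Any[Var(6), 1]),        # %7 = %6 - 1
]

result = run_tape(ops, baz, 1.0, 2.0)   # same value as baz(1.0, 2.0)
```

Replaying the same list with new inputs, e.g. run_tape(ops, baz, 3.0, 4.0), gives 17.0, just like baz(3.0, 4.0).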
trace() returns two values: the result of the original function call and the generated tape. The structure of the tape is described in the Tape anatomy section; here, just note that trace() recursed into baz(), bar() and foo(), but recorded +, - and * onto the tape as is. This is because +, - and * are considered "primitives", i.e. the most basic operations that all other functions are built from. This behavior can be customized using a tracing context.
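A quick way to check which functions ended up as primitives is to collect the functions of all recorded calls. This sketch assumes, as the counting example further below does, that the tape is iterable and that recorded calls are Umlaut.Call operations with an fn field; input operations are not calls, so the filter skips them.

```julia
using Umlaut

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

val, tape = trace(baz, 1.0, 2.0)

# Functions of all recorded calls, in execution order (inputs are skipped).
# foo, bar and baz were traced through, so only primitives should remain.
recorded = [op.fn for op in tape if op isa Umlaut.Call]
```

For the tape shown above, recorded should contain only the four arithmetic calls and none of the user-defined functions.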
Context
Context is a way to customize tracing and attach arbitrary data to the generated tape. For example, here's how we can add a new function to the list of primitives:
```julia
import Umlaut: isprimitive, BaseCtx

struct MyCtx end

isprimitive(::MyCtx, f, args...) = isprimitive(BaseCtx(), f, args...) || f == foo

val, tape = trace(baz, 1.0, 2.0; ctx=MyCtx())
```
```
(7.0, Tape{Main.MyCtx}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = foo(%2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)
```
In this code:

- MyCtx is a new context type; there are no restrictions on the type of a context.
- isprimitive is a function that decides whether a particular function call f(args...) should be treated as a primitive in this context.
- BaseCtx is the default context that treats all built-in functions from modules Base, Core, etc. as primitives.

So we define a new method for isprimitive() that returns true for all built-in functions and for the function foo.
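Because isprimitive also receives the call arguments, a context can make its decision depend on them. The following sketch (TypedCtx is a hypothetical name) treats bar as a primitive only when it is called with Float64 arguments. It assumes, in line with the description above, that args holds the actual argument values of the call f(args...), not their types.

```julia
using Umlaut
import Umlaut: isprimitive, BaseCtx

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Hypothetical context: no fields, it only exists to dispatch on
struct TypedCtx end

# Keep the default primitives, and additionally stop at bar(::Float64, ::Float64)
isprimitive(::TypedCtx, f, args...) =
    isprimitive(BaseCtx(), f, args...) || (f === bar && all(a -> a isa Float64, args))

_, tape_f = trace(baz, 1.0, 2.0; ctx=TypedCtx())   # bar should be recorded as one call
_, tape_i = trace(baz, 1, 2; ctx=TypedCtx())       # bar and foo should be traced through
```

With Float64 inputs the tape should hold just bar and -, while with Int inputs it should look like the default tape.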
isprimitive() can be arbitrarily complex. For example, if we want to treat all functions defined in a particular module (here, Main) as primitives, we can write:
```julia
isprimitive(::MyCtx, f, args...) = Base.parentmodule(f) == Main
```
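The module-based rule relies on Base.parentmodule, which for a generic function returns the module where it was defined. A quick check in plain Julia shows that, as written, the one-liner above matches user-defined functions but not the arithmetic operators, which live in Base; to keep the default primitives as well, it can be combined with isprimitive(BaseCtx(), f, args...) using ||, as in the MyCtx example. The helper in_module is a hypothetical name, and @__MODULE__ stands in for Main so the snippet also works when it isn't run at top level.

```julia
foo(x) = 2x
bar(x, y) = foo(x) + 3y

# Module where this snippet runs (Main for a script or the REPL)
here = @__MODULE__

# The predicate from the one-liner above, with the module made explicit
in_module(f, m) = parentmodule(f) === m

candidates = [foo, bar, +, *, -]
selected = filter(f -> in_module(f, here), candidates)   # only the user-defined functions
```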
On the other hand, if we only need to mark a few specific functions as primitives, BaseCtx provides a convenient constructor that takes the list of them directly:
```julia
val, tape = trace(baz, 1.0, 2.0; ctx=BaseCtx([+, -, *, foo]))
```
```
(7.0, Tape{Umlaut.BaseCtx}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = foo(%2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)
```
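The same constructor can stop tracing at a different level. Assuming, as the output above suggests, that a BaseCtx built from a list treats the listed functions as primitives, listing bar instead of foo should collapse the whole bar(x, y) call into a single tape entry:

```julia
using Umlaut
import Umlaut: BaseCtx

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# bar is now a primitive, so foo and the multiplications inside it are not traced
val, tape = trace(baz, 1.0, 2.0; ctx=BaseCtx([+, -, *, bar]))

# Functions of the recorded calls, in order
calls = [op.fn for op in tape if op isa Umlaut.Call]
```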
Another useful function is record_primitive!(), which lets you overload the way a primitive call is recorded onto the tape. As a toy example, imagine that we want to replace all invocations of * with + and count how many replacements were made. Even though we haven't covered tape anatomy and utilities yet, try to parse this code:
```julia
import Umlaut: record_primitive!

function loop1(a, n)
    a = 2a
    for i in 1:n
        a = a * n
    end
    return a
end

mutable struct CountingReplacingCtx
    replace::Pair
    count::Int
end

# v_fargs is a tuple of Variables or constant values, representing a function call
# that we are about to invoke (but haven't yet)
function record_primitive!(tape::Tape{CountingReplacingCtx}, v_fargs...)
    # tape.c refers to the provided context
    if v_fargs[1] == tape.c.replace[1]
        tape.c.count += 1
        # record a call to the replacement function instead of the original one
        return push!(tape, mkcall(tape.c.replace[2], v_fargs[2:end]...))
    else
        # any other primitive is recorded unchanged
        return push!(tape, mkcall(v_fargs...))
    end
end

_, tape = trace(loop1, 2.0, 3; ctx=CountingReplacingCtx((*) => (+), 0))
# one * from `2a` plus one per loop iteration (n = 3)
@assert tape.c.count == 4
@assert count(op -> op isa Call && op.fn == (+), tape) == 4
```
Although we could have achieved the same result by postprocessing the tape with replace!(), record_primitive!() has the advantage of running before the original function is invoked, thus avoiding double computation.
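record_primitive!() doesn't have to change what gets recorded. As a variation on the example above, this sketch (CountingCtx is a hypothetical name) only counts calls to a given function and records every call unchanged, so the traced value should match a plain call of loop1. Like CountingReplacingCtx, it defines no isprimitive method of its own and relies on the default behavior for such a context.

```julia
using Umlaut
import Umlaut: record_primitive!

function loop1(a, n)
    a = 2a
    for i in 1:n
        a = a * n
    end
    return a
end

# Hypothetical context: which function to count, and how many calls were seen
mutable struct CountingCtx
    fn::Any
    count::Int
end

function record_primitive!(tape::Umlaut.Tape{CountingCtx}, v_fargs...)
    if v_fargs[1] == tape.c.fn
        tape.c.count += 1
    end
    # Record the call unchanged, whatever it is
    return push!(tape, Umlaut.mkcall(v_fargs...))
end

val, tape = trace(loop1, 2.0, 3; ctx=CountingCtx(*, 0))
```

Because nothing is replaced, val is the ordinary result of loop1(2.0, 3), while tape.c.count holds the same 4 calls to * as before.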