Linearized traces

Usually, programs are executed as a sequence of nested function calls, e.g.:

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

baz(1.0, 2.0)
7.0

Sometimes, however, it's more convenient to work with a linearized representation of the computation. Example use cases include collecting computational graphs for automatic differentiation, exporting to ONNX, and serializing functions to a library-independent format. trace() lets you obtain such a linearized representation:

using Ghost

val, tape = trace(baz, 1.0, 2.0)
(7.0, Tape{Dict{Any, Any}}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = *(2, %2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)
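
To make the correspondence concrete, the recorded operations can be read as straight-line code. The sketch below is plain Julia and does not use Ghost at all; baz_linear and the names v4 to v7 are hypothetical and simply mirror tape variables %4 to %7:

```julia
foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Hand-written straight-line equivalent of the tape shown above:
# one assignment per recorded primitive call, no nested function calls.
function baz_linear(x, y)
    v4 = 2 * x      # %4 = *(2, %2)
    v5 = 3 * y      # %5 = *(3, %3)
    v6 = v4 + v5    # %6 = +(%4, %5)
    v7 = v6 - 1     # %7 = -(%6, 1)
    return v7
end

# Both forms compute the same value
@assert baz_linear(1.0, 2.0) == baz(1.0, 2.0)
```

The nested calls to bar() and foo() have disappeared; only the primitive operations remain, in execution order.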

trace() returns two values: the result of the original function call and the generated tape. The structure of the tape is described in the Tape anatomy section; for now, just note that trace() recursed into baz(), bar() and foo(), but recorded +, - and * onto the tape as is. This is because +, - and * are considered "primitives", i.e. the most basic operations of which all other functions consist. This behavior can be customized using one of two keyword arguments:

  • primitives - an iterable of functions to be considered primitive
  • is_primitive(sig) - a function which takes a method call signature and returns true if this method must be considered primitive and false otherwise

Here's an example:

val, tape = trace(baz, 1.0, 2.0; primitives=Set([+, -, *, foo]))
(7.0, Tape{Dict{Any, Any}}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = foo(%2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)

The default behavior is defined by the Ghost.is_primitive function and can be extended, for example, like this:

function custom_is_primitive(sig)
    return Ghost.is_primitive(sig) || sig == Tuple{typeof(foo), Float64}
end

val, tape = trace(baz, 1.0, 2.0; is_primitive=custom_is_primitive)
(7.0, Tape{Dict{Any, Any}}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = foo(%2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)
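
The predicate above matches only foo called with a single Float64. If every method of foo should be treated as primitive, one possible variation is to test the signature with a subtype check instead of equality. This is a sketch, not part of Ghost: calls_foo and foo_or_default are hypothetical helper names, and it assumes the signature passed to is_primitive is an ordinary Tuple type, as in the example above:

```julia
using Ghost

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Hypothetical helper: true for any call to `foo`, whatever the argument types
calls_foo(sig) = sig <: Tuple{typeof(foo), Vararg{Any}}

# Keep the default primitives and additionally stop at every `foo` method
foo_or_default(sig) = Ghost.is_primitive(sig) || calls_foo(sig)

val, tape = trace(baz, 1.0, 2.0; is_primitive=foo_or_default)
```

The subtype check relies only on standard Julia tuple-type semantics, so it can be tried out without tracing, e.g. calls_foo(Tuple{typeof(foo), Int}).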

An easy way to get a valid call signature is to use Ghost.call_signature.
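
As a rough sketch of how this might look, the signature can be built from a sample call instead of being written out by hand. The assumption here is that Ghost.call_signature accepts a function followed by example arguments; consult its docstring for the exact forms supported. The names foo_sig and is_prim are hypothetical:

```julia
using Ghost

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Assumption: call_signature(f, args...) returns the signature of the call f(args...)
foo_sig = Ghost.call_signature(foo, 1.0)

# Treat this particular call signature as primitive, on top of the defaults
is_prim(sig) = Ghost.is_primitive(sig) || sig == foo_sig

val, tape = trace(baz, 1.0, 2.0; is_primitive=is_prim)
```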

See also Ghost.FunctionResolver for a better understanding of how is_primitive is implemented.

In complex scenarios it may be useful to carry additional application-specific data along with a tape. For this purpose, Tape is parametrized by a context type, which is Dict{Any, Any} by default but can be anything. A context object can be attached during tracing using the ctx keyword:

mutable struct MyCtx
    a
    b
end

val, tape = trace(baz, 1.0, 2.0; ctx=MyCtx(0, 0))
(7.0, Tape{Main.MyCtx}
  inp %1::typeof(Main.baz)
  inp %2::Float64
  inp %3::Float64
  %4 = *(2, %2)::Float64
  %5 = *(3, %3)::Float64
  %6 = +(%4, %5)::Float64
  %7 = -(%6, 1)::Float64
)

The presence of the context doesn't affect tracing, but the context can be used during further tape processing. See Tape context for more details.
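
A minimal sketch of using the context after tracing follows. It assumes the context object is stored in the tape's c field (check the Tape context section to confirm this); TraceInfo and its label field are hypothetical and exist only for illustration:

```julia
using Ghost

foo(x) = 2x
bar(x, y) = foo(x) + 3y
baz(x, y) = bar(x, y) - 1

# Hypothetical context type holding application-specific metadata
mutable struct TraceInfo
    label::String
end

val, tape = trace(baz, 1.0, 2.0; ctx=TraceInfo("untitled"))

# Assumption: the context attached via `ctx` is reachable as `tape.c`
tape.c.label = "baz(Float64, Float64)"
```

Using a dedicated mutable struct instead of the default Dict{Any, Any} gives the context typed fields, at the cost of having to define the type up front.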