Public API

Tracing

Umlaut.trace (Function)
trace(f, args...; ctx=BaseCtx())

Trace a function call and return the result together with the corresponding Tape. trace records primitive method calls to the tape and recursively traces into non-primitive ones.

Tracing can be customized using a context and the following methods:

  • isprimitive(ctx, f, args...) - decides whether f(args...) should be treated as a primitive.
  • record_primitive!(tape::Tape{C}, v_f, v_args...) - records the primitive call defined by the variables v_f(v_args...) to the tape.

The default context is BaseCtx(), which treats all functions from standard Julia modules as primitives and simply pushes the call to the tape. See the docstrings of these functions for further examples of customization.

Examples:

foo(x) = 2x
bar(x) = foo(x) + 1

val, tape = trace(bar, 2.0)
# (5.0, Tape{Dict{Any, Any}}
#   inp %1::typeof(bar)
#   inp %2::Float64
#   %3 = *(2, %2)::Float64
#   %4 = +(%3, 1)::Float64
# )

val, tape = trace(bar, 2.0; ctx=BaseCtx([*, +, foo]))
# (5.0, Tape{Dict{Any, Any}}
#   inp %1::typeof(bar)
#   inp %2::Float64
#   %3 = foo(%2)::Float64
#   %4 = +(%3, 1)::Float64
# )

import Umlaut: isprimitive   # required to add methods to Umlaut's function

struct MyCtx end

isprimitive(ctx::MyCtx, f, args...) = isprimitive(BaseCtx(), f, args...) || f in [foo]
val, tape = trace(bar, 2.0; ctx=MyCtx())
# (5.0, Tape{Dict{Any, Any}}
#   inp %1::typeof(bar)
#   inp %2::Float64
#   %3 = foo(%2)::Float64
#   %4 = +(%3, 1)::Float64
# )
Umlaut.isprimitive (Function)
isprimitive(ctx::BaseCtx, f, args...)

The default implementation of isprimitive used in trace(). Returns true if the method with the provided signature is defined in one of Julia's built-in modules, e.g. Base, Core or Broadcast.
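As a rough illustration of this rule, the sketch below looks up the target method with Base's which() and walks up the module tree with parentmodule(), so submodules such as Base.Broadcast or Base.Math count as Base. defined_in_std is a hypothetical helper, not part of Umlaut, and it assumes that Base and Core are the only standard root modules; the real BaseCtx covers more modules and may use different logic.

```julia
# Hypothetical approximation of the BaseCtx rule; not Umlaut's implementation.
# Assumption: only Base and Core (and their submodules) count as "standard".
const STD_MODULES = (Base, Core)

function defined_in_std(f, args...)
    # Find the method that f(args...) would dispatch to
    m = which(f, map(Core.Typeof, args))
    mod = m.module
    # Walk up the module tree (e.g. Base.Math -> Base) until a root is reached
    while true
        mod in STD_MODULES && return true
        parent = parentmodule(mod)
        parent === mod && return false   # reached a root module such as Main
        mod = parent
    end
end

foo(x) = 2x

defined_in_std(*, 2, 3.0)   # true: *(::Number, ::Number) is defined in Base
defined_in_std(sin, 1.0)    # true: defined in Base.Math
defined_in_std(foo, 2.0)    # false: foo is defined in Main
```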

isprimitive(ctx::Any, f, args...)

Fallback implementation of isprimitive(), behaves the same way as isprimitive(BaseCtx(), f, args...).

Umlaut.record_primitive! (Function)
record_primitive!(tape::Tape{C}, v_fargs...) where C

Record a primitive function call to the tape.

By default, this function simply pushes the function call to the tape, but it can be overridden to implement more complex logic. For example, instead of recording the function call itself, a user can push one or more other calls, essentially performing replace!() during tracing without calling the function twice.

Examples:

The following code shows how to replace f(args...) with ChainRules.rrule(f, args...) during tracing:

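using Umlaut
using ChainRules               # loads rrule definitions for Base functions
using ChainRulesCore: rrule
import Umlaut: record_primitive!

# RRuleContext is assumed to be a user-defined context with a `pullbacks` Dict field
function record_primitive!(tape::Tape{RRuleContext}, v_fargs...)
    v_rr = push!(tape, mkcall(rrule, v_fargs...))    # record rrule(f, args...) instead of f(args...)
    v_val = push!(tape, mkcall(getfield, v_rr, 1))   # primal result
    v_pb = push!(tape, mkcall(getfield, v_rr, 2))    # pullback
    tape.c.pullbacks[v_val] = v_pb
    return v_val   # the function must return the Variable holding the result
end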

See also: isprimitive()

Umlaut.BaseCtx (Type)

Dict-like tracing context that treats all functions from the standard Julia modules (e.g. Base, Core, Statistics) as primitives.

Umlaut.__new__ (Function)
__new__(T, args...)

User-level version of the new() pseudofunction. Can be used to construct most Julia types, including structs without default constructors, closures, etc.
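One way to write such a function is as a generated function that emits a raw :new expression, which bypasses any inner constructors. The sketch below illustrates the general technique under that assumption and is not necessarily Umlaut's implementation. my_new and Positive are hypothetical names.

```julia
# Hypothetical stand-in for a user-level new(); not necessarily how Umlaut does it.
# A generated function may return an Expr(:new, ...) node, which bypasses
# any inner constructor of T.
@generated function my_new(::Type{T}, args...) where {T}
    # At generation time `args` holds the argument *types*; only its length is used
    return Expr(:new, T, (:(args[$i]) for i in 1:length(args))...)
end

struct Positive
    x::Float64
    # The only constructor rejects non-positive values
    Positive(x) = x > 0 ? new(x) : throw(ArgumentError("x must be positive"))
end

p = my_new(Positive, -1.0)   # skips the check in the inner constructor
p.x                          # -1.0
```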


Variables

Umlaut.Variable (Type)

Variable represents a reference to an operation on a tape. Variables can be used to index a tape or to keep a reference to a specific operation on the tape.

Variables (also aliased as V) can be:

  • free, created as V(id) - used for indexing into tape
  • bound, created as V(op) or V(tape, id) - used to keep a robust reference to an operation on the tape
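The difference between free and bound references can be illustrated without Umlaut, using a plain vector as a stand-in for a tape. A free reference is only a position, which goes stale once an operation is inserted before it. A bound reference keeps pointing at the operation itself. ToyOp is a hypothetical type; for real tapes, Umlaut's bound variables provide this robustness.

```julia
# Hypothetical toy operation; mutable so that === compares object identity
mutable struct ToyOp
    name::Symbol
end

ops = [ToyOp(:a), ToyOp(:b), ToyOp(:c)]

free_ref = 2          # "free": only a position in the vector
bound_ref = ops[2]    # "bound": a reference to the operation object

insert!(ops, 1, ToyOp(:new))   # shifts all existing operations by one

ops[free_ref].name                       # :a, the position now points elsewhere
findfirst(op -> op === bound_ref, ops)   # 3, the bound reference still finds :b
```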

See also: bound

Umlaut.bound (Function)
bound(tape::Tape, v::Variable)
V(tape::Tape, v::Integer)
%(tape::Tape, i::Integer)

Create a version of the variable bound to an operation on the tape. The short syntax tape %i is convenient when working in the REPL, but may surprise readers of your code, so use it sparingly.

Examples:

V(3)                # unbound var
V(tape, 3)          # bound var
bound(tape, 3)      # bound var
bound(tape, V(3))   # bound var
tape %3             # bound var
Umlaut.rebind! (Function)
rebind!(tape::Tape, op, st::Dict)
rebind!(tape::Tape, st::Dict; from, to)

Rebind all variables according to a substitution table. Example:

tape = Tape()
_, v1, v2 = inputs!(tape, nothing, 3.0, 5.0)   # the first input stands for the traced function
v3 = push!(tape, mkcall(*, v1, 2))
st = Dict(v1.id => v2.id)                      # substitute v1 with v2
rebind!(tape, st)
@assert tape[v3].args[1].id == v2.id

See also: rebind_context!()

Umlaut.rebind_context! (Function)
rebind_context!(tape::Tape, st::Dict)

Rebind variables in the tape's context according to a substitution table. By default it does nothing, but it can be overridden for a specific Tape{C}.
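For a custom context type, an override would be a method such as Umlaut.rebind_context!(tape::Tape{MyCtx}, st::Dict). The self-contained sketch below shows only the kind of rewrite such a method might perform. It uses a hypothetical context that maps result ids to pullback ids, with plain integers standing in for Umlaut variables, and assumes that the table maps old ids to new ids as in the rebind! example.

```julia
# Hypothetical context that maps result ids to pullback ids (plain Ints here,
# standing in for Umlaut Variables)
struct PullbackCtx
    pullbacks::Dict{Int,Int}
end

# Sketch of the logic an override of rebind_context! might contain:
# every id listed in the substitution table `st` is replaced by its target,
# both in keys and in values; ids not in `st` are left unchanged.
function rebind_ids!(ctx::PullbackCtx, st::Dict{Int,Int})
    updated = Dict(get(st, k, k) => get(st, v, v) for (k, v) in ctx.pullbacks)
    empty!(ctx.pullbacks)
    merge!(ctx.pullbacks, updated)
    return ctx
end

ctx = PullbackCtx(Dict(3 => 4, 5 => 6))
rebind_ids!(ctx, Dict(3 => 7, 6 => 8))
ctx.pullbacks   # Dict(7 => 4, 5 => 8)
```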


Tape structure

Umlaut.Tape (Type)

Linearized representation of a function execution.

Fields

  • ops - vector of operations on the tape
  • result - variable pointing to the operation to be used as the result
  • parent - parent tape if any
  • meta - internal metadata
  • c - application-specific context
Umlaut.Call (Type)

Operation representing a function call on a tape. Typically, calls are constructed using the mkcall function.

Important fields of a Call{T}:

  • fn::T - function or object to be called
  • args::Vector - vector of variables or values used as arguments
  • val::Any - the result of the function call
Umlaut.Loop (Type)

Operation representing a loop in a computational graph. See the online documentation for details.

Umlaut.mkcall (Function)
mkcall(fn, args...; val=missing, kwargs=(;))

Convenience constructor for a Call operation. If val is UncalculatedValue (default) and the value of the call can be calculated from (bound) variables and constants, it is calculated immediately. To prevent this behavior, set val to some neutral value.


Tape transformations

Base.push! (Function)
push!(tape::Tape, op::AbstractOp)

Push a new operation to the end of the tape.

Base.insert! (Function)
insert!(tape::Tape, idx::Integer, ops::AbstractOp...)

Insert new operations into the tape, starting at position idx.

Base.replace! (Function)
replace!(tape, op  => new_ops; rebind_to=length(new_ops), old_new=Dict())

Replace the specified operation with one or more other operations and rebind variables in the remainder of the tape to new_ops[rebind_to].

Operation can be specified directly, by a variable or by ID.

Base.deleteat! (Function)
deleteat!(tape::Tape, idx; rebind_to = nothing)

Remove tape[V(idx)] from the tape. If rebind_to is not nothing, then replace all references to V(idx) with V(rebind_to).

idx may be an index or Variable/AbstractOp directly.

Umlaut.primitivize! (Function)
primitivize!(tape::Tape; ctx=nothing)

Trace non-primitive function calls on a tape and decompose them into a list of corresponding primitive calls.

Example

f(x) = 2x - 1
g(x) = f(x) + 5

tape = Tape()
_, x = inputs!(tape, g, 3.0)
y = push!(tape, mkcall(f, x))
z = push!(tape, mkcall(+, y, 5))
tape.result = z

primitivize!(tape)

# output

Tape{BaseCtx}
  inp %1::typeof(g)
  inp %2::Float64
  %3 = *(2, %2)::Float64
  %4 = -(%3, 1)::Float64
  %5 = +(%4, 5)::Float64

Tape execution

Umlaut.play! (Function)
play!(tape::Tape, args...; debug=false)

Execute operations on the tape one by one. If debug=true, print each operation before execution.
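Conceptually, play! replays a linear list of calls and feeds the results of earlier calls into later ones. The sketch below uses hypothetical toy types (ToyVar, ToyCall, toy_play) rather than Umlaut's Tape. It mirrors the tape recorded for bar in the trace() example.

```julia
# Hypothetical toy tape, not Umlaut's Tape: inputs get ids 1..n and
# each call appends one new value with the next id.
struct ToyVar
    id::Int
end

struct ToyCall
    fn::Any
    args::Vector{Any}
end

function toy_play(ops::Vector{ToyCall}, inputs...; debug=false)
    vals = Any[inputs...]
    for op in ops
        debug && println(length(vals) + 1, " = ", op)
        # Replace variables with the values they refer to; keep constants as is
        args = [a isa ToyVar ? vals[a.id] : a for a in op.args]
        push!(vals, op.fn(args...))
    end
    return vals[end]   # the last operation is used as the result
end

# Mirrors the tape of bar(x) = 2x + 1 from the trace() example:
# %1 = bar, %2 = x, %3 = *(2, %2), %4 = +(%3, 1)
bar_ops = [ToyCall(*, Any[2, ToyVar(2)]), ToyCall(+, Any[ToyVar(3), 1])]
toy_play(bar_ops, nothing, 2.0)   # 5.0
```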

Umlaut.to_expr (Function)
to_expr(tape::Tape)

Generate a Julia expression corresponding to the tape.
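The idea can be sketched on a toy tape. Each call becomes an assignment to a fresh symbol, and the whole list is wrapped into an anonymous function that can be evaluated. ExVar, ExCall and toy_to_expr are hypothetical, and the expressions Umlaut generates may be structured differently.

```julia
# Hypothetical toy tape types (not Umlaut's); ExVar(i) refers to value number i
struct ExVar
    id::Int
end

struct ExCall
    fn::Any
    args::Vector{Any}
end

# Turn a variable into the symbol x<i>; leave constants untouched
to_sym(a) = a isa ExVar ? Symbol(:x, a.id) : a

# Build `function (x1, ..., xn) ... end` with one assignment per operation
function toy_to_expr(ninputs::Int, ops::Vector{ExCall})
    argnames = [Symbol(:x, i) for i in 1:ninputs]
    body = Expr(:block)
    for (k, op) in enumerate(ops)
        lhs = Symbol(:x, ninputs + k)
        push!(body.args, :($lhs = $(op.fn)($(map(to_sym, op.args)...))))
    end
    push!(body.args, Symbol(:x, ninputs + length(ops)))   # return the last value
    return Expr(:function, Expr(:tuple, argnames...), body)
end

ex = toy_to_expr(2, [ExCall(*, Any[2, ExVar(2)]), ExCall(+, Any[ExVar(3), 1])])
compiled = eval(ex)
Base.invokelatest(compiled, nothing, 2.0)   # 5.0
```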


Internal functions

Umlaut.code_signature (Function)
code_signature(ctx, v_fargs)

Returns method signature as a tuple (f, (arg1typ, arg2typ, ...)). This signature is suitable for getcode() and which().
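The (f, (arg1typ, arg2typ, ...)) form matches what Base's which() and code_lowered() accept. signature_of below is a hypothetical helper that builds such a tuple from concrete argument values. Umlaut's code_signature works on a context and tape variables instead.

```julia
# Hypothetical helper returning (f, (argtypes...)) for concrete arguments
signature_of(f, args...) = (f, map(Core.Typeof, args))

double(x) = 2x

sig = signature_of(double, 3.0)   # (double, (Float64,))
m = which(sig...)                 # the Method that double(3.0) dispatches to
ci = code_lowered(sig...)[1]      # its lowered code, as a CodeInfo
```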

Umlaut.call_signature (Function)
call_signature(fn, args...)
call_signature(tape::Tape, op::Call)

Get the signature of a function call. The obtained signature is suitable for is_primitive(sig).

source
Umlaut.trace! (Function)
trace!(t::Tracer, v_fargs)

Trace the call defined by the variables in v_fargs.

Umlaut.trace_call! (Function)
trace_call!(t::Tracer{C}, v_f, v_args...) where C

Customizable handler that controls what to do with a function call. The default implementation checks if the call is a primitive and either records it to the tape or recurses into it.

Umlaut.unsplat! (Function)
unsplat!(t::Tracer, v_fargs)

In the lowered form, the splatting syntax f(xs...) is represented as Core._apply_iterate(Base.iterate, f, xs). unsplat!() reverses this change and transforms v_fargs to a normal form, possibly destructuring xs into separate variables on the tape.
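The lowering rule can be checked directly in Julia. Core._apply_iterate is the builtin that current Julia versions emit for splatting, with the iteration function as its first argument. add3 and splat_call are illustrative names.

```julia
add3(a, b, c) = a + b + c
xs = (1, 2)

# The splatted call add3(xs..., 3) ...
direct = add3(xs..., 3)

# ... is equivalent to calling the builtin Core._apply_iterate, which iterates
# over each collection argument and passes all elements on to add3
via_builtin = Core._apply_iterate(iterate, add3, xs, (3,))

# The lowered code of a function that splats contains the same builtin call
splat_call(t) = add3(t..., 3)
lowered = code_lowered(splat_call, (Tuple{Int,Int},))[1]
occursin("_apply_iterate", string(lowered))
```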

Umlaut.map_vars (Function)

Helper function that applies a function only to the Variable arguments of a Call, leaving constant values as is.
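Here is a minimal sketch of the idea. A hypothetical ToyVariable type stands in for Umlaut's Variable, and a plain vector stands in for the arguments of a Call.

```julia
# Hypothetical stand-in for Umlaut's Variable
struct ToyVariable
    id::Int
end

# Apply `fn` to variable arguments only; constants are returned unchanged
toy_map_vars(fn, args::Vector) = Any[a isa ToyVariable ? fn(a) : a for a in args]

call_args = Any[ToyVariable(1), 2.0, ToyVariable(3)]
shifted = toy_map_vars(v -> ToyVariable(v.id + 10), call_args)
# shifted == Any[ToyVariable(11), 2.0, ToyVariable(13)]
```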

Umlaut.block_expressions (Function)
block_expressions(ir::IRCode)

For each block, compute a vector of its expressions along with their SSA IDs. Returns a Vector{blockinfo}, where each blockinfo is a Vector{ssa_id => expr}.

Umlaut.promote_const_value (Function)

Unwrap constant value from its expression container such as GlobalRef, QuoteNode, etc. No-op if there's no known container.
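The following is a minimal sketch of such unwrapping for two common containers in lowered code, QuoteNode and GlobalRef. unwrap_const is a hypothetical name, and Umlaut may recognize additional containers.

```julia
# Hypothetical sketch of unwrapping constant containers found in lowered code
unwrap_const(x::QuoteNode) = x.value                  # quoted literal, e.g. QuoteNode(:a)
unwrap_const(x::GlobalRef) = getfield(x.mod, x.name)  # reference to a global binding
unwrap_const(x) = x                                   # unknown container: no-op

unwrap_const(QuoteNode(:a))          # :a
unwrap_const(GlobalRef(Base, :sin))  # sin
unwrap_const(42)                     # 42
```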

Umlaut.handle_gotoifnot_node! (Function)
handle_gotoifnot_node!(t::Tracer, cf::Core.GotoIfNot, frame::Frame)

Must return the value associated with a Core.GotoIfNot node. May also have side effects, such as modifying the tape.

Umlaut.is_ho_tracable (Function)
is_ho_tracable(ctx::Any, f, args...)

Is higher-order tracable. Returns true if f is a known higher-order function that Umlaut knows how to trace and its functional argument is a non-primitive.

is_ho_tracable() helps to trace through higher-order functions like Core._apply_iterate() (used internally when splatting arguments with ...) as if they themselves were non-primitives.

Umlaut.__foreigncall__ (Function)
function __foreigncall__(
    ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

__foreigncall__(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.
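To see where such nodes come from, you can inspect the lowered code of an ordinary ccall with Base's code_lowered(). c_strlen and the helper functions are hypothetical, and strlen is assumed to be available from the platform's C library.

```julia
# A function containing a ccall; strlen comes from the C standard library
c_strlen(s::String) = ccall(:strlen, Csize_t, (Cstring,), s)

# Look for a :foreigncall node in the lowered code, either as a statement
# or as the right-hand side of an assignment
is_foreigncall(ex) = ex isa Expr && ex.head === :foreigncall
has_foreigncall(ci) = any(ci.code) do st
    is_foreigncall(st) || (st isa Expr && st.head === :(=) && is_foreigncall(st.args[2]))
end

strlen_ci = code_lowered(c_strlen, (String,))[1]
has_foreigncall(strlen_ci)
c_strlen("hello")   # 5
```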

Umlaut.UncalculatedValue (Type)

UncalculatedValue()

This struct is used to signal that a value on the tape has not been computed. Downstream Call operations that use at least one value of type Umlaut.UncalculatedValue propagate it: their result is also of type Umlaut.UncalculatedValue.
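The propagation rule can be sketched with a hypothetical sentinel type, NotComputed, standing in for Umlaut.UncalculatedValue. call_or_propagate is also illustrative and not part of Umlaut.

```julia
# Hypothetical sentinel, standing in for Umlaut.UncalculatedValue
struct NotComputed end

# Sketch of the propagation rule: if any argument is NotComputed,
# skip the call and return NotComputed; otherwise call fn as usual
function call_or_propagate(fn, args...)
    any(a -> a isa NotComputed, args) && return NotComputed()
    return fn(args...)
end

pending = NotComputed()
y = call_or_propagate(*, 2, pending)   # NotComputed()
z = call_or_propagate(+, y, 1)         # NotComputed(), propagated downstream
w = call_or_propagate(+, 2.0, 1)       # 3.0, computed normally
```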

