Public API
Tracing
Umlaut.trace — Function
trace(f, args...; ctx=BaseCtx())
Trace a function call and return the result and the corresponding Tape. trace records primitive method calls to the tape and recursively dives into non-primitives.
Tracing can be customized using a context and the following methods:
- isprimitive(ctx, f, args...) - decides whether f(args...) should be treated as a primitive.
- record_primitive!(tape::Tape{C}, v_f, v_args...) - records the primitive call defined by variables `v_f(v_args...)` to the tape.
The default context is BaseCtx(), which treats all functions from standard Julia modules as primitives and simply pushes the call to the tape. See the docstrings of these functions for further examples of customization.
Examples:
using Umlaut: trace, BaseCtx
foo(x) = 2x
bar(x) = foo(x) + 1
val, tape = trace(bar, 2.0)
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = *(2, %2)::Float64
# %4 = +(%3, 1)::Float64
# )
val, tape = trace(bar, 2.0; ctx=BaseCtx([*, +, foo]))
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = foo(%2)::Float64
# %4 = +(%3, 1)::Float64
# )
import Umlaut: isprimitive, BaseCtx  # import so that the method below extends Umlaut's isprimitive
struct MyCtx end
isprimitive(ctx::MyCtx, f, args...) = isprimitive(BaseCtx(), f, args...) || f in [foo]
val, tape = trace(bar, 2.0; ctx=MyCtx())
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = foo(%2)::Float64
# %4 = +(%3, 1)::Float64
# )
Umlaut.isprimitive — Function
isprimitive(ctx::BaseCtx, f, args...)
The default implementation of isprimitive used in trace(). Returns true if the method with the provided signature is defined in one of Julia's built-in modules, e.g. Base, Core, Broadcast, etc.
isprimitive(ctx::Any, f, args...)
Fallback implementation of isprimitive(). Behaves the same way as isprimitive(BaseCtx(), f, args...).
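The module-based rule above can be approximated in plain Julia with `which` and `Base.moduleroot`. The sketch below illustrates the idea only and is not Umlaut's implementation. `is_base_method` and `BUILTIN_ROOTS` are hypothetical names, and the real isprimitive may also treat other modules (e.g. Statistics) or special cases as primitive.

```julia
# Hypothetical set of root modules whose methods are treated as primitives.
const BUILTIN_ROOTS = (Core, Base)

# Hypothetical helper (not Umlaut's implementation): returns true if the method
# that f(args...) dispatches to is defined in Core, Base, or one of their
# submodules such as Base.Broadcast.
function is_base_method(f, args...)
    m = which(f, Tuple{map(typeof, args)...})  # method selected for these argument types
    return Base.moduleroot(m.module) in BUILTIN_ROOTS
end

foo(x) = 2x  # user-defined, so it lives outside Base

is_base_method(+, 1.0, 2.0)  # true
is_base_method(foo, 2.0)     # false
```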
Umlaut.record_primitive! — Function
record_primitive!(tape::Tape{C}, v_fargs...) where C
Record a primitive function call to the tape.
By default, this function simply pushes the function call to the tape, but it can be overloaded for a specific context to implement more complex logic. For example, instead of recording the function call, a user can push one or more other calls, essentially implementing replace!() during tracing without calling the function twice.
Examples:
The following code shows how to replace f(args...) with ChainRules.rrule(f, args...) during tracing:
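import Umlaut: record_primitive!
using Umlaut: Tape, mkcall
using ChainRules: rrule
# RRuleContext is assumed to be a user-defined context type with a `pullbacks` dictionary
function record_primitive!(tape::Tape{RRuleContext}, v_fargs...)
    v_rr = push!(tape, mkcall(rrule, v_fargs...))  # record rrule(f, args...) instead of f(args...)
    v_val = push!(tape, mkcall(getfield, v_rr, 1))  # primal value
    v_pb = push!(tape, mkcall(getfield, v_rr, 2))   # pullback
    tape.c.pullbacks[v_val] = v_pb
    return v_val # the function must return the Variable holding the result
end
See also: isprimitive()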
Umlaut.BaseCtx — Type
Dict-like tracing context that treats as primitives all functions from the standard Julia modules (e.g. Base, Core, Statistics, etc.).
Umlaut.__new__ — Function
__new__(T, args...)
User-level version of the new() pseudofunction. Can be used to construct most Julia types, including structs without default constructors, closures, etc.
Variables
Umlaut.Variable — Type
Variable represents a reference to an operation on a tape. Variables can be used to index a tape or to keep a reference to a specific operation on the tape.
Variables (also aliased as V) can be:
- free, created as V(id) - used for indexing into a tape
- bound, created as V(op) or V(tape, id) - used to keep a robust reference to an operation on the tape
See also: bound
Umlaut.bound — Function
bound(tape::Tape, v::Variable)
V(tape::Tape, v::Integer)
%(tape::Tape, i::Integer)
Create a version of the variable bound to an operation on the tape. The short syntax tape %i is convenient for working in the REPL, but may surprise a reader of your code. Use it wisely.
Examples:
V(3) # unbound var
V(tape, 3) # bound var
bound(tape, 3) # bound var
bound(tape, V(3)) # bound var
tape %3 # bound var
Umlaut.rebind! — Function
rebind!(tape::Tape, op, st::Dict)
rebind!(tape::Tape, st::Dict; from, to)
Rebind all variables according to a substitution table. Example:
using Umlaut: Tape, inputs!, mkcall, rebind!
tape = Tape()
_, v1, v2 = inputs!(tape, nothing, 3.0, 5.0)  # the first input stands for the traced function
v3 = push!(tape, mkcall(*, v1, 2))
st = Dict(v1.id => v2.id)
rebind!(tape, st)
@assert tape[v3].args[1].id == v2.id
See also: rebind_context!()
Umlaut.rebind_context! — Function
rebind_context!(tape::Tape, st::Dict)
Rebind variables in the tape's context according to a substitution table. Does nothing by default, but can be overloaded for a specific Tape{C}.
Tape structure
Umlaut.Tape — Type
Linearized representation of a function execution.
Fields
- ops - vector of operations on the tape
- result - variable pointing to the operation to be used as the result
- parent - parent tape, if any
- meta - internal metadata
- c - application-specific context
Umlaut.AbstractOp — Type
Base type for operations on a tape.
Umlaut.Input — Type
Operation representing input data of a tape.
Umlaut.Constant — Type
Operation representing a constant value on a tape.
Umlaut.Call — Type
Operation representing a function call on a tape. Typically, calls are constructed using the mkcall function.
Important fields of a Call{T}:
- fn::T - function or object to be called
- args::Vector - vector of variables or values used as arguments
- val::Any - the result of the function call
Umlaut.Loop — Type
Operation representing a loop in a computational graph. See the online documentation for details.
Umlaut.inputs — Function
Get the list of a tape's input variables.
Umlaut.inputs! — Function
Set the values of a tape's inputs.
Umlaut.mkcall — Function
mkcall(fn, args...; val=missing, kwargs=(;))
Convenient constructor for the Call operation. If val is UncalculatedValue (default) and the call value can be calculated from (bound) variables and constants, it is calculated. To prevent this behavior, set val to some neutral value.
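The eager-evaluation rule can be shown with a small toy model. It uses `missing` as the "not computed" marker, matching the val=missing default in the signature above. `ToyVar` and `toy_mkcall` are hypothetical names and correspond only loosely to Umlaut's Variable and mkcall.

```julia
# Hypothetical stand-in for a tape variable; `val` is `missing` when unknown.
struct ToyVar
    id::Int
    val::Any
end

# Hypothetical sketch of the eager-evaluation rule (not Umlaut's code):
# if `val` is not given and every argument is a constant or a variable with a
# known value, the call is evaluated right away; otherwise `val` is kept as passed.
function toy_mkcall(fn, args...; val=missing)
    calculable = all(a -> !(a isa ToyVar) || a.val !== missing, args)
    if val === missing && calculable
        argvals = map(a -> a isa ToyVar ? a.val : a, args)  # substitute known values
        val = fn(argvals...)
    end
    return (fn=fn, args=collect(Any, args), val=val)
end

x = ToyVar(1, 3.0)      # variable with a known value
y = ToyVar(2, missing)  # variable whose value is not known

toy_mkcall(*, x, 2).val           # 6.0, computed eagerly
toy_mkcall(*, y, 2).val           # missing, cannot be computed
toy_mkcall(*, x, 2; val=0.0).val  # 0.0, the neutral value prevents evaluation
```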
Tape transformations
Base.push! — Function
push!(tape::Tape, op::AbstractOp)
Push a new operation to the end of the tape.
Base.insert! — Function
insert!(tape::Tape, idx::Integer, ops::AbstractOp...)
Insert new operations into the tape starting from position idx.
Base.replace! — Function
replace!(tape, op => new_ops; rebind_to=length(new_ops), old_new=Dict())
Replace the specified operation with one or more other operations and rebind variables in the remainder of the tape to new_ops[rebind_to].
Operation can be specified directly, by a variable or by ID.
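The id bookkeeping behind replace! can be sketched on a toy tape where operation i produces variable i. This is a minimal model under stated assumptions, not Umlaut's code. `ToyV`, `ToyOp` and `toy_replace!` are hypothetical, operations are addressed only by position, and the old_new argument is not modeled.

```julia
# Hypothetical toy tape: operation i produces the variable with id i.
struct ToyV
    id::Int
end

struct ToyOp
    fn::Any
    args::Vector{Any}  # ToyV references or constant values
end

# Hypothetical sketch of the replace! semantics (not Umlaut's implementation):
# ops[idx] is replaced by new_ops, references to the old operation are redirected
# to new_ops[rebind_to], and ids of later operations shift by length(new_ops) - 1.
# new_ops are assumed to refer to variables by their final ids.
function toy_replace!(ops::Vector{ToyOp}, idx::Int, new_ops::Vector{ToyOp};
                      rebind_to::Int=length(new_ops))
    shift = length(new_ops) - 1
    function remap(a)
        a isa ToyV || return a                          # constants are kept as is
        a.id < idx && return a                          # earlier variables are unaffected
        a.id == idx && return ToyV(idx + rebind_to - 1) # old op -> new_ops[rebind_to]
        return ToyV(a.id + shift)                       # later variables move by `shift`
    end
    tail = [ToyOp(op.fn, Any[remap(a) for a in op.args]) for op in ops[idx+1:end]]
    resize!(ops, idx - 1)
    append!(ops, new_ops)
    append!(ops, tail)
    return ops
end

# Tape for x -> x^2 + 1, where op 2 squares the input.
ops = [ToyOp(:input, Any[]),
       ToyOp(:square, Any[ToyV(1)]),
       ToyOp(:+, Any[ToyV(2), 1])]

# Replace the square with two ops: %2 = x * x, %3 = identity(%2).
toy_replace!(ops, 2, [ToyOp(:*, Any[ToyV(1), ToyV(1)]),
                      ToyOp(:identity, Any[ToyV(2)])])
ops[4].args  # Any[ToyV(3), 1]: the old %3 is now %4 and reads new_ops[end]
```

Passing rebind_to=1 instead would redirect later references to the first new operation (%2 in this example).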
Base.deleteat! — Function
deleteat!(tape::Tape, idx; rebind_to = nothing)
Remove tape[V(idx)] from the tape. If rebind_to is not nothing, replace all references to V(idx) with V(rebind_to).
idx may be an index or Variable/AbstractOp directly.
Umlaut.primitivize! — Function
primitivize!(tape::Tape; ctx=nothing)
Trace non-primitive function calls on a tape and decompose them into a list of corresponding primitive calls.
Example
f(x) = 2x - 1
g(x) = f(x) + 5
tape = Tape()
_, x = inputs!(tape, g, 3.0)
y = push!(tape, mkcall(f, x))
z = push!(tape, mkcall(+, y, 5))
tape.result = z
primitivize!(tape)
# output
Tape{BaseCtx}
inp %1::typeof(g)
inp %2::Float64
%3 = *(2, %2)::Float64
%4 = -(%3, 1)::Float64
%5 = +(%4, 5)::Float64
Tape execution
Umlaut.play! — Function
play!(tape::Tape, args...; debug=false)
Execute operations on the tape one by one. If debug=true, print each operation before execution.
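play! can be pictured as a small interpreter loop. The sketch below uses a simplified tape built from hypothetical `ToyV`/`ToyCall` types. Inputs occupy the first positions and, unlike Umlaut's Tape, the function itself is not stored as the first input. `toy_play` is not part of Umlaut.

```julia
# Hypothetical toy types (not Umlaut's): a call refers to earlier results by position.
struct ToyV
    id::Int
end

struct ToyCall
    fn::Any
    args::Vector{Any}  # ToyV references or constant values
end

# Hypothetical sketch of play!-style execution: the inputs occupy the first
# positions, each call is evaluated in order, and its result is appended so
# later calls can refer to it. The last result is returned.
function toy_play(calls::Vector{ToyCall}, inputs...; debug::Bool=false)
    vals = Any[inputs...]
    for c in calls
        debug && println("executing ", c)
        argvals = [a isa ToyV ? vals[a.id] : a for a in c.args]
        push!(vals, c.fn(argvals...))
    end
    return vals[end]
end

# Toy tape for x -> 2x - 1 with the single input at position 1.
calls = [ToyCall(*, Any[2, ToyV(1)]),   # %2 = 2 * %1
         ToyCall(-, Any[ToyV(2), 1])]   # %3 = %2 - 1
toy_play(calls, 3.0)   # 5.0
```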
Umlaut.compile — Function
compile(tape::Tape)
Compile a tape into a normal Julia function.
Umlaut.to_expr — Function
to_expr(tape::Tape)
Generate a Julia expression corresponding to the tape.
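One plausible way to turn a linear list of calls into a normal Julia function is to build an expression and `eval` it. This is an assumption about the general technique, not a description of how Umlaut's to_expr or compile are implemented. `toy_to_expr` and its tuple-based input format are hypothetical.

```julia
# Hypothetical toy (not Umlaut's representation): each call is a tuple
# (result_name, fn, args), where args are Symbols of earlier variables or constants.
function toy_to_expr(argnames::Vector{Symbol}, calls)
    body = [:($res = $fn($(args...))) for (res, fn, args) in calls]
    result = calls[end][1]
    # Build an anonymous function: function (argnames...) body...; return result end
    return Expr(:function, Expr(:tuple, argnames...), Expr(:block, body..., :(return $result)))
end

calls = [(:y, *, Any[2, :x]),   # y = 2 * x
         (:z, -, Any[:y, 1])]   # z = y - 1
ex = toy_to_expr([:x], calls)
compiled = eval(ex)   # "compile" the expression into an ordinary Julia function
compiled(3.0)         # 5.0
```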
Internal functions
Umlaut.code_signature — Function
code_signature(ctx, v_fargs)
Returns the method signature as a tuple (f, (arg1typ, arg2typ, ...)). This signature is suitable for getcode() and which().
Umlaut.call_signature — Function
call_signature(fn, args...)
call_signature(tape::Tape, op::Call)
Get the signature of a function call. The obtained signature is suitable for is_primitive(sig).
Umlaut.trace! — Function
trace!(t::Tracer, v_fargs)
Trace the call defined by the variables in v_fargs.
Umlaut.trace_call! — Function
trace_call!(t::Tracer{C}, v_f, v_args...) where C
Customizable handler that controls what to do with a function call. The default implementation checks whether the call is a primitive and either records it to the tape or recurses into it.
Umlaut.unsplat! — Function
unsplat!(t::Tracer, v_fargs)
In the lowered form, splatting syntax f(xs...) is represented as Core._apply_iterate(f, xs). unsplat!() reverses this change and transforms v_fargs into a normal form, possibly destructuring xs into separate variables on the tape.
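The splatting transformation that unsplat!() reverses can be observed in plain Julia. In actual lowered code, the first argument of Core._apply_iterate is the iteration function (Base.iterate), followed by the callee and the argument containers. The names `g3` and `xs` are only for illustration.

```julia
g3(a, b, c) = a + 2b + 3c
xs = (1, 2)

direct = g3(xs..., 3)   # splatting as written in source code

# Roughly what lowering produces: every argument after `g3` is an iterable
# whose elements are spliced into the final argument list.
lowered = Core._apply_iterate(Base.iterate, g3, xs, (3,))

direct == lowered   # true, both are 1 + 2*2 + 3*3 = 14

# The lowered code of the splatting call refers to Core._apply_iterate.
code = Meta.lower(@__MODULE__, :(g3(xs..., 3)))
occursin("_apply_iterate", string(code))
```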
Umlaut.map_vars — Function
Helper function to map a function only over the Variable arguments of a Call, leaving constant values as is.
Umlaut.block_expressions — Function
block_expressions(ir::IRCode)
For each block, compute a vector of its expressions along with their SSA IDs. Returns Vector{blockinfo}, where blockinfo is Vector{ssa_id => expr}.
Umlaut.promote_const_value — Function
Unwrap a constant value from its expression container, such as GlobalRef, QuoteNode, etc. No-op if there's no known container.
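The kind of unwrapping promote_const_value performs can be illustrated with a few dispatch rules. `unwrap_const` is a hypothetical helper that covers only QuoteNode and GlobalRef. Umlaut's function may handle more containers.

```julia
# Hypothetical helper in the spirit of promote_const_value (not Umlaut's code).
unwrap_const(x::QuoteNode) = x.value                  # quoted literal, e.g. a Symbol
unwrap_const(x::GlobalRef) = getfield(x.mod, x.name)  # value bound to a global name
unwrap_const(x) = x                                   # no known container: return as is

unwrap_const(QuoteNode(:sym))        # :sym
unwrap_const(GlobalRef(Base, :sin))  # the function sin
unwrap_const(42)                     # 42
```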
Umlaut.loop_exit_vars_at_point — Function
Collect variables which will be used at loop exit if it happens at this point on the tape.
Umlaut.handle_gotoifnot_node! — Function
handle_gotoifnot_node!(t::Tracer, cf::Core.GotoIfNot, frame::Frame)
Must return the value associated with a Core.GotoIfNot node. May also have side effects, such as modifying the tape.
Umlaut.is_ho_tracable — Function
is_ho_tracable(ctx::Any, f, args...)
Is higher-order traceable. Returns true if f is a known higher-order function that Umlaut knows how to trace and its functional argument is a non-primitive.
is_ho_tracable() helps to trace through higher-order functions like Core._apply_iterate() (used internally when splatting arguments with ...) as if they themselves were non-primitives.
Umlaut.__foreigncall__ — Function
function __foreigncall__(
    ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}
:foreigncall nodes get translated into calls to this function. For example,
Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)
becomes
__foreigncall__(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)
Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.
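Because each piece of static information is wrapped in Val, a method with this argument layout can read it back from type parameters. The hypothetical `describe_foreigncall` below mirrors the signature above but only returns what it receives. It does not perform a foreign call, and `unval` is also a hypothetical helper.

```julia
# Hypothetical helper: recover the type parameter carried by a Val.
unval(::Val{T}) where {T} = T

# Hypothetical function with the same argument layout as __foreigncall__:
# static data arrives as Val type parameters, runtime arguments in x.
function describe_foreigncall(::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq},
                              ::Val{cc}, x...) where {name, RT, nreq, cc}
    return (name=name, rettype=RT, argtypes=map(unval, AT), nreq=nreq, cc=cc, args=x)
end

info = describe_foreigncall(Val(:foo), Val(Float64), (Val(Int), Val(Float64)),
                            Val(0), Val(:ccall), 1, 2.0)
```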
Umlaut.UncalculatedValue — Type
UncalculatedValue()
This struct is used to signal that a value on the tape has not been computed. Downstream Call operations that use at least one value of type Umlaut.UncalculatedValue propagate it: their result is also of type Umlaut.UncalculatedValue.
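The propagation rule can be modeled with a marker type and a guard before each call. `Uncalculated` and `toy_eval` are hypothetical stand-ins for Umlaut.UncalculatedValue and the corresponding evaluation logic.

```julia
# Hypothetical marker type standing in for Umlaut.UncalculatedValue.
struct Uncalculated end

# Hypothetical toy evaluator: if any argument is uncalculated, the result is
# uncalculated too; otherwise the function is applied normally.
toy_eval(f, args...) = any(a -> a isa Uncalculated, args) ? Uncalculated() : f(args...)

r1 = toy_eval(+, 1.0, 2.0)             # 3.0
r2 = toy_eval(*, Uncalculated(), 2.0)  # Uncalculated()
r3 = toy_eval(-, r2, r1)               # Uncalculated(): propagates downstream
```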
Index
Umlaut.AbstractOp, Umlaut.BaseCtx, Umlaut.Call, Umlaut.Constant, Umlaut.Input, Umlaut.Loop, Umlaut.Tape, Umlaut.UncalculatedValue, Umlaut.Variable, Base.deleteat!, Base.insert!, Base.push!, Base.replace!, Umlaut.__foreigncall__, Umlaut.__new__, Umlaut.block_expressions, Umlaut.bound, Umlaut.call_signature, Umlaut.code_signature, Umlaut.compile, Umlaut.handle_gotoifnot_node!, Umlaut.inputs, Umlaut.inputs!, Umlaut.is_ho_tracable, Umlaut.isprimitive, Umlaut.loop_exit_vars_at_point, Umlaut.map_vars, Umlaut.mkcall, Umlaut.play!, Umlaut.primitivize!, Umlaut.promote_const_value, Umlaut.rebind!, Umlaut.rebind_context!, Umlaut.record_primitive!, Umlaut.to_expr, Umlaut.trace, Umlaut.trace!, Umlaut.trace_call!, Umlaut.unsplat!