Public API
Tracing
Umlaut.trace
— Function
trace(f, args...; ctx=BaseCtx())
Trace a function call and return the result together with the corresponding Tape. trace records primitive calls to the tape and recursively descends into non-primitive ones.
Tracing can be customized using a context and the following methods:
- isprimitive(ctx, f, args...) - decides whether f(args...) should be treated as a primitive.
- record_primitive!(tape::Tape{C}, v_f, v_args...) - records the primitive call defined by the variables v_f(v_args...) to the tape.
The default context is BaseCtx(), which treats all functions from standard Julia modules as primitives and simply pushes the call to the tape. See the docstrings of these functions for further examples of customization.
Examples:
using Umlaut
foo(x) = 2x
bar(x) = foo(x) + 1
val, tape = trace(bar, 2.0)
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = *(2, %2)::Float64
# %4 = +(%3, 1)::Float64
# )
val, tape = trace(bar, 2.0; ctx=BaseCtx([*, +, foo]))
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = foo(%2)::Float64
# %4 = +(%3, 1)::Float64
# )
struct MyCtx end
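# extend Umlaut.isprimitive; an unqualified definition would create an unrelated local function
Umlaut.isprimitive(ctx::MyCtx, f, args...) = Umlaut.isprimitive(BaseCtx(), f, args...) || f in [foo]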
val, tape = trace(bar, 2.0; ctx=MyCtx())
# (5.0, Tape{Dict{Any, Any}}
# inp %1::typeof(bar)
# inp %2::Float64
# %3 = foo(%2)::Float64
# %4 = +(%3, 1)::Float64
# )
Umlaut.isprimitive
— Function
isprimitive(ctx::BaseCtx, f, args...)
The default implementation of isprimitive used in trace(). Returns true if the method with the provided signature is defined in one of Julia's built-in modules, e.g. Base, Core, Broadcast, etc.
isprimitive(ctx::Any, f, args...)
Fallback implementation of isprimitive(); behaves the same way as isprimitive(BaseCtx(), f, args...).
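The default rule can be checked by querying isprimitive directly. This is a minimal sketch, assuming Umlaut is installed; double and PlainCtx are hypothetical names introduced for illustration. Following the description above, + on Float64 (defined in Base) should count as a primitive, a user-defined function should not, and a context without its own method should fall back to the BaseCtx behavior.

```julia
using Umlaut

double(x) = 2x          # hypothetical user function, defined outside Base
struct PlainCtx end     # hypothetical context with no isprimitive method of its own

# + on Float64 is defined in Base, so BaseCtx treats it as a primitive
base_plus = Umlaut.isprimitive(BaseCtx(), +, 1.0, 2.0)

# double is defined in the user's module, so the tracer would descend into it
base_double = Umlaut.isprimitive(BaseCtx(), double, 1.0)

# PlainCtx hits the ctx::Any fallback, which mirrors BaseCtx
plain_double = Umlaut.isprimitive(PlainCtx(), double, 1.0)
```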
Umlaut.record_primitive!
— Function
record_primitive!(tape::Tape{C}, v_fargs...) where C
Record a primitive function call to the tape.
By default, this function simply pushes the function call to the tape, but it can be overridden to implement more complex logic. For example, instead of recording the function call itself, a user can push one or more other calls, effectively performing replace!() during tracing and without calling the function twice.
Examples:
The following code shows how to replace f(args...) with ChainRules.rrule(f, args...) during tracing (RRuleContext stands for a user-defined context with a pullbacks dictionary):
import Umlaut: record_primitive!
using ChainRules: rrule
function record_primitive!(tape::Tape{RRuleContext}, v_fargs...)
    v_rr = push!(tape, mkcall(rrule, v_fargs...))   # rrule returns (value, pullback)
    v_val = push!(tape, mkcall(getfield, v_rr, 1))  # primal value
    v_pb = push!(tape, mkcall(getfield, v_rr, 2))   # pullback
    tape.c.pullbacks[v_val] = v_pb
    return v_val # the function should return Variable with the result
end
See also: isprimitive()
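A lighter customization than the rrule example is a context that only observes what gets recorded. The sketch below assumes that trace stores the given context in tape.c (as the rrule example relies on) and that the function in v_fargs may arrive either as a plain value or as a Variable; LoggingCtx, foo and bar are hypothetical names. The override records each call exactly as the default does, via push! and mkcall, and additionally remembers the function.

```julia
using Umlaut
import Umlaut: record_primitive!

# Hypothetical context: collects the primitive functions seen during tracing
struct LoggingCtx
    seen::Vector{Any}
end

function record_primitive!(tape::Tape{LoggingCtx}, v_fargs...)
    v_f = v_fargs[1]
    # the function may be a constant or a Variable pointing to an op on the tape
    f = v_f isa Umlaut.Variable ? tape[v_f].val : v_f
    push!(tape.c.seen, f)
    # record the call itself, the same way the default implementation does
    return push!(tape, mkcall(v_fargs...))
end

foo(x) = 2x
bar(x) = foo(x) + 1
ctx = LoggingCtx(Any[])
val, tape = trace(bar, 2.0; ctx=ctx)
```

Because the override still returns the Variable produced by push!, tracing continues normally and the traced value is unaffected.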
Umlaut.BaseCtx
— Type
Dict-like tracing context that treats all functions from the standard Julia modules (e.g. Base, Core, Statistics, etc.) as primitives.
Umlaut.__new__
— Function
__new__(T, args...)
User-level version of the new() pseudofunction. Can be used to construct most Julia types, including structs without default constructors, closures, etc.
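For example, __new__ can build an instance of a struct whose only constructor enforces an invariant, skipping that constructor entirely. In this sketch Positive is a hypothetical type, and __new__ is reached through the qualified name Umlaut.__new__.

```julia
using Umlaut

# Hypothetical type whose inner constructor replaces the default one
struct Positive
    x::Float64
    Positive(x) = x > 0 ? new(x) : throw(ArgumentError("x must be positive"))
end

# The regular constructor rejects a negative value...
rejected = try
    Positive(-1.0)
    false
catch err
    err isa ArgumentError
end

# ...while __new__ builds the object directly, skipping the constructor's check
p = Umlaut.__new__(Positive, -1.0)
```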
Variables
Umlaut.Variable
— Type
Variable represents a reference to an operation on a tape. Variables can be used to index a tape or to keep a reference to a specific operation on the tape.
Variables (also aliased as V) can be:
- free, created as V(id) - used for indexing into a tape
- bound, created as V(op) or V(tape, id) - used to keep a robust reference to an operation on the tape
See also: bound
Umlaut.bound
— Function
bound(tape::Tape, v::Variable)
V(tape::Tape, v::Integer)
%(tape::Tape, i::Integer)
Create a version of the variable bound to an operation on the tape. The short syntax tape %i is convenient for working in the REPL, but may surprise a reader of your code. Use it wisely.
Examples:
V(3) # unbound var
V(tape, 3) # bound var
bound(tape, 3) # bound var
bound(tape, V(3)) # bound var
tape %3 # bound var
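The difference between free and bound variables becomes visible once the tape is modified. In the sketch below, which assumes that push! and inputs! return bound variables and that insert! renumbers the operations after the insertion point, a new call is inserted at position 3: the bound variable y keeps pointing at its original * call, while the free V(3) now refers to whatever sits at position 3.

```julia
using Umlaut

tape = Tape()
_, x = inputs!(tape, nothing, 3.0)   # inputs occupy positions 1 and 2
y = push!(tape, mkcall(*, x, 2))     # bound variable for the * call at position 3
free = V(3)                          # free variable: just the index 3

# insert a + call at position 3, pushing the * call to position 4
insert!(tape, 3, mkcall(+, x, 1))

y_id = y.id                 # the bound variable follows its operation
at_free = tape[free].fn     # the free variable now points at the inserted call
```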
Umlaut.rebind!
— Function
rebind!(tape::Tape, op, st::Dict)
rebind!(tape::Tape, st::Dict; from, to)
Rebind all variables according to a substitution table. Example:
tape = Tape()
_, v1, v2 = inputs!(tape, nothing, 3.0, 5.0)   # the first input holds the function (here nothing)
v3 = push!(tape, mkcall(*, v1, 2))
st = Dict(v1.id => v2.id)
rebind!(tape, st)
@assert tape[v3].args[1].id == v2.id
See also: rebind_context!()
Umlaut.rebind_context!
— Function
rebind_context!(tape::Tape, st::Dict)
Rebind variables in the tape's context according to a substitution table. By default does nothing, but can be overridden for a specific Tape{C}.
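A context that stores variables needs this hook so that those variables stay valid after rebinding. The sketch below rests on several assumptions: LossCtx is a hypothetical context, the substitution table is assumed to map old operation IDs to new ones (as in the rebind! example), and Tape(ctx) is assumed to create an empty tape with the given context. It calls the hook directly rather than relying on rebind! to invoke it.

```julia
using Umlaut
import Umlaut: rebind_context!

# Hypothetical context that remembers one variable of interest (e.g. a loss value)
mutable struct LossCtx
    loss::Umlaut.Variable
end

function rebind_context!(tape::Tape{LossCtx}, st::Dict)
    # st maps old operation IDs to new ones; keep the ID if it is not listed
    id = tape.c.loss.id
    tape.c.loss = V(get(st, id, id))
    return tape
end

tape = Tape(LossCtx(V(2)))
rebind_context!(tape, Dict(2 => 3))
```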
Tape structure
Umlaut.Tape
— Type
Linearized representation of a function execution.
Fields
- ops - vector of operations on the tape
- result - variable pointing to the operation to be used as the result
- parent - parent tape, if any
- meta - internal metadata
- c - application-specific context
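These fields can be inspected on a traced tape. A minimal sketch reusing the foo/bar functions from the trace examples; the expected count of four operations (two inputs plus two calls) follows the printed tape shown there.

```julia
using Umlaut

foo(x) = 2x
bar(x) = foo(x) + 1
val, tape = trace(bar, 2.0)

n_ops = length(tape.ops)              # inputs plus recorded calls
result_val = tape[tape.result].val    # value of the operation marked as the result
no_parent = tape.parent === nothing   # a top-level tape has no parent
```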
Umlaut.AbstractOp
— Type
Base type for operations on a tape
Umlaut.Input
— Type
Operation representing input data of a tape
Umlaut.Constant
— Type
Operation representing a constant value on a tape
Umlaut.Call
— Type
Operation representing a function call on a tape. Typically, calls are constructed using the mkcall function.
Important fields of a Call{T}:
- fn::T - function or object to be called
- args::Vector - vector of variables or values used as arguments
- val::Any - the result of the function call
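The sketch below records a single call and reads these fields back. It assumes Umlaut is installed and that inputs! returns bound variables, so mkcall can compute val eagerly.

```julia
using Umlaut

tape = Tape()
_, x = inputs!(tape, nothing, 3.0)
v = push!(tape, mkcall(+, x, 1))   # record x + 1

op = tape[v]                                   # the Call operation itself
f = op.fn                                      # the function being called
second_arg = op.args[2]                        # constants are stored as-is
first_is_var = op.args[1] isa Umlaut.Variable  # tape values are referenced via Variables
result = op.val                                # computed eagerly because x is bound
```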
Umlaut.Loop
— Type
Operation representing a loop in a computational graph. See the online documentation for details.
Umlaut.inputs
— Function
Get the list of a tape's input variables.
Umlaut.inputs!
— Function
Set the values of a tape's inputs.
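The two functions work as a pair. In this sketch, the first call to inputs! on an empty tape creates the Input operations; the second call, on a tape that already has inputs, is assumed to overwrite their values in place (based on the "set values" description above), which requires passing the same number of values.

```julia
using Umlaut

tape = Tape()
# on an empty tape, inputs! creates one Input op per value and returns their variables
vars = inputs!(tape, nothing, 3.0)

n_inputs = length(Umlaut.inputs(tape))   # the same variables, read back from the tape

# on a tape that already has inputs, inputs! overwrites their values instead
inputs!(tape, nothing, 10.0)
new_val = tape[V(2)].val
```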
Umlaut.mkcall
— Function
mkcall(fn, args...; val=missing, kwargs=(;))
Convenient constructor for the Call operation. If val is UncalculatedValue (default) and the call value can be calculated from (bound) variables and constants, it is calculated immediately. To prevent this behavior, set val to some neutral value.
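The eager calculation and the way to suppress it can be compared side by side. A minimal sketch, assuming inputs! returns bound variables; the neutral value 0.0 is an arbitrary choice.

```julia
using Umlaut

tape = Tape()
_, x = inputs!(tape, nothing, 3.0)

c1 = mkcall(*, x, 2)            # x is bound, so the value is computed right away
c2 = mkcall(*, x, 2; val=0.0)   # explicit val: the call is not executed
```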
Tape transformations
Base.push!
— Function
push!(tape::Tape, op::AbstractOp)
Push a new operation to the end of the tape.
Base.insert!
— Function
insert!(tape::Tape, idx::Integer, ops::AbstractOp...)
Insert new operations into tape starting from position idx.
Base.replace!
— Function
replace!(tape, op => new_ops; rebind_to=length(new_ops), old_new=Dict())
Replace the specified operation with one or more other operations and rebind variables in the remainder of the tape to new_ops[rebind_to].
The operation can be specified directly, by a variable, or by ID.
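As an illustration, the sketch below rewrites 2x as x + x on a small hand-built tape and then replays it. It specifies the operation by its variable and assumes new_ops can be given as a vector; the input values are arbitrary.

```julia
using Umlaut

tape = Tape()
_, x = inputs!(tape, nothing, 3.0)
y = push!(tape, mkcall(*, x, 2))   # y = 2x
z = push!(tape, mkcall(+, y, 1))   # z = y + 1
tape.result = z

# rewrite 2x as x + x; later references to y are rebound to the new op
replace!(tape, y => [mkcall(+, x, x)])

# re-run the tape with x = 5.0 and read the result
Umlaut.play!(tape, nothing, 5.0)
out = tape[tape.result].val
```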
Base.deleteat!
— Function
deleteat!(tape::Tape, idx; rebind_to = nothing)
Remove tape[V(idx)] from the tape. If rebind_to is not nothing, then replace all references to V(idx) with V(rebind_to).
idx may be an index or a Variable/AbstractOp directly.
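A typical use is removing a redundant operation while redirecting its users elsewhere. In this sketch the identity call is deleted and its consumer is rebound to the input x; passing x.id as rebind_to follows the V(rebind_to) description above, the operations after the deleted one are assumed to be renumbered, and the values are arbitrary.

```julia
using Umlaut

tape = Tape()
_, x = inputs!(tape, nothing, 3.0)
y = push!(tape, mkcall(identity, x))   # redundant op at position 3
z = push!(tape, mkcall(*, y, 2))       # uses y
tape.result = z

# drop the identity call and point its users at x (position 2) instead
deleteat!(tape, y; rebind_to=x.id)

Umlaut.play!(tape, nothing, 4.0)
out = tape[tape.result].val
```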
Umlaut.primitivize!
— Function
primitivize!(tape::Tape; ctx=nothing)
Trace non-primitive function calls on a tape and decompose them into a list of corresponding primitive calls.
Example
f(x) = 2x - 1
g(x) = f(x) + 5
tape = Tape()
_, x = inputs!(tape, g, 3.0)
y = push!(tape, mkcall(f, x))
z = push!(tape, mkcall(+, y, 5))
tape.result = z
primitivize!(tape)
# output
Tape{BaseCtx}
inp %1::typeof(g)
inp %2::Float64
%3 = *(2, %2)::Float64
%4 = -(%3, 1)::Float64
%5 = +(%4, 5)::Float64
Tape execution
Umlaut.play!
— Function
play!(tape::Tape, args...; debug=false)
Execute operations on the tape one by one. If debug=true, print each operation before execution.
Umlaut.compile
— Function
compile(tape::Tape)
Compile tape into a normal Julia function.
Umlaut.to_expr
— Function
to_expr(tape::Tape)
Generate a Julia expression corresponding to the tape.
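The three functions can be combined on a traced tape: play! re-runs the recorded operations with new inputs, to_expr produces the corresponding expression, and compile turns it into a callable. This sketch assumes, as in the trace examples, that the traced function itself is the first tape input; Base.invokelatest guards against world-age errors when calling a function that was just created by eval.

```julia
using Umlaut

foo(x) = 2x
bar(x) = foo(x) + 1
_, tape = trace(bar, 2.0)

# re-run the recorded operations with new inputs (the function itself is input 1)
Umlaut.play!(tape, bar, 3.0)
replayed = tape[tape.result].val

ex = Umlaut.to_expr(tape)    # a Julia Expr defining a function with one argument per tape input
fn = Umlaut.compile(tape)    # evaluates that expression into a callable
compiled = Base.invokelatest(fn, bar, 4.0)
```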
Internal functions
Umlaut.code_signature
— Function
code_signature(ctx, v_fargs)
Returns the method signature as a tuple (f, (arg1typ, arg2typ, ...)). This signature is suitable for getcode() and which().
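A signature of the form (f, (arg1typ, arg2typ, ...)) maps directly onto Base's which. The plain-Julia sketch below makes no Umlaut calls; double is a made-up function. It resolves two such tuples to methods and reads the module each method was defined in, presumably the kind of information behind the BaseCtx rule for isprimitive.

```julia
double(x) = 2x   # hypothetical user-defined function

# signatures in the (f, (arg types...)) shape described above
user_sig = (double, (Float64,))
base_sig = (+, (Float64, Float64))

# which() resolves a signature to the method that would be called
user_method = which(user_sig[1], user_sig[2])
base_method = which(base_sig[1], base_sig[2])

# the module where each method was defined
user_mod = user_method.module
base_mod = base_method.module
```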
Umlaut.call_signature
— Function
call_signature(fn, args...)
call_signature(tape::Tape, op::Call)
Get the signature of a function call. The obtained signature is suitable for is_primitive(sig).
Umlaut.trace!
— Function
trace!(t::Tracer, v_fargs)
Trace the call defined by the variables in v_fargs.
Umlaut.trace_call!
— Function
trace_call!(t::Tracer{C}, v_f, v_args...) where C
Customizable handler that controls what to do with a function call. The default implementation checks if the call is a primitive and either records it to the tape or recurses into it.
Umlaut.unsplat!
— Function
unsplat!(t::Tracer, v_fargs)
In the lowered form, the splatting syntax f(xs...) is represented as Core._apply_iterate(iterate, f, xs). unsplat!() reverses this change and transforms v_fargs to a normal form, possibly destructuring xs into separate variables on the tape.
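The lowering described above can be observed in plain Julia, without Umlaut. In this sketch f and xs are arbitrary; it checks that the lowered form of f(xs...) mentions Core._apply_iterate and that calling that builtin directly matches the splatted call.

```julia
f(a, b, c) = a + b * c
xs = (1, 2, 3)

# lowering turns the splatted call f(xs...) into a Core._apply_iterate call
lowered = Meta.@lower f(xs...)
uses_apply_iterate = occursin("_apply_iterate", string(lowered))

# calling the builtin directly gives the same result as the splatted call
direct = Core._apply_iterate(Base.iterate, f, xs)
splatted = f(xs...)
```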
Umlaut.map_vars
— Function
Helper function to map a function only over the Variable arguments of a Call, leaving constant values as is.
Umlaut.block_expressions
— Function
block_expressions(ir::IRCode)
For each block, compute a vector of its expressions along with their SSA IDs. Returns Vector{blockinfo}, where blockinfo is Vector{ssa_id => expr}.
Umlaut.promote_const_value
— Function
Unwrap a constant value from its expression container, such as GlobalRef, QuoteNode, etc. No-op if there's no known container.
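The idea can be sketched in plain Julia. unwrap_const is a hypothetical helper, not Umlaut's implementation, and it covers only the two containers named above; for a GlobalRef it assumes the referenced binding already exists.

```julia
# Hypothetical helper, not Umlaut's implementation: unwrap common constant containers
unwrap_const(x::QuoteNode) = x.value                  # QuoteNode(:foo) -> :foo
unwrap_const(x::GlobalRef) = getfield(x.mod, x.name)  # GlobalRef(Base, :sin) -> sin
unwrap_const(x) = x                                   # anything else is returned unchanged

quoted = unwrap_const(QuoteNode(:foo))
global_fn = unwrap_const(GlobalRef(Base, :sin))
plain = unwrap_const(42)
```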
Umlaut.loop_exit_vars_at_point
— Function
Collect the variables that will be used at loop exit if the exit happens at this point on the tape.
Umlaut.handle_gotoifnot_node!
— Function
handle_gotoifnot_node!(t::Tracer, cf::Core.GotoIfNot, frame::Frame)
Must return the value associated with a Core.GotoIfNot node. May also have side effects, such as modifying the tape.
Umlaut.is_ho_tracable
— Function
is_ho_tracable(ctx::Any, f, args...)
Is higher-order tracable. Returns true if f is a known higher-order function that Umlaut knows how to trace and its functional argument is a non-primitive.
is_ho_tracable() helps to trace through higher-order functions like Core._apply_iterate() (used internally when splatting arguments with ...) as if they themselves were non-primitives.
Umlaut.__foreigncall__
— Function
function __foreigncall__(
::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
) where {name, RT, nreq, calling_convention}
:foreigncall nodes get translated into calls to this function. For example,
Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)
becomes
__foreigncall__(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)
Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.
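To see such a node, one can lower a function that contains a ccall. The sketch below uses only Base; get_pid is a hypothetical wrapper that is inspected but never called, so it does not matter whether getpid is available on the platform.

```julia
# Hypothetical wrapper; it is only lowered here, never called
get_pid() = ccall(:getpid, Int32, ())

code = code_lowered(get_pid, ())[1].code
# find the statement that lowering produced for the ccall
idx = findfirst(ex -> Meta.isexpr(ex, :foreigncall), code)
fc = idx === nothing ? nothing : code[idx]
```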
Umlaut.UncalculatedValue
— Type
UncalculatedValue()
This struct is used to signal that a value on the tape has not been computed. Downstream Call operations that use at least one value of type Umlaut.UncalculatedValue propagate it: their result is also of type Umlaut.UncalculatedValue.
Index
Umlaut.AbstractOp
Umlaut.BaseCtx
Umlaut.Call
Umlaut.Constant
Umlaut.Input
Umlaut.Loop
Umlaut.Tape
Umlaut.UncalculatedValue
Umlaut.Variable
Base.deleteat!
Base.insert!
Base.push!
Base.replace!
Umlaut.__foreigncall__
Umlaut.__new__
Umlaut.block_expressions
Umlaut.bound
Umlaut.call_signature
Umlaut.code_signature
Umlaut.compile
Umlaut.handle_gotoifnot_node!
Umlaut.inputs
Umlaut.inputs!
Umlaut.is_ho_tracable
Umlaut.isprimitive
Umlaut.loop_exit_vars_at_point
Umlaut.map_vars
Umlaut.mkcall
Umlaut.play!
Umlaut.primitivize!
Umlaut.promote_const_value
Umlaut.rebind!
Umlaut.rebind_context!
Umlaut.record_primitive!
Umlaut.to_expr
Umlaut.trace
Umlaut.trace!
Umlaut.trace_call!
Umlaut.unsplat!