Guards Overview
From a UX perspective, TorchDynamo is very easy to use. The user invokes torchdynamo.optimize as an annotation:
@torchdynamo.optimize(my_compiler)
def fn_foo(bar):
    ...
A complete example looks like this:
from typing import List
import torch
from torch import _dynamo as torchdynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()
    return gm.forward  # return a python callable

@torchdynamo.optimize(my_compiler)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))
This allows TorchDynamo to capture the interpreted Python frames, grab any and all relevant information, and speed things up wherever it can. The speedup comes from a few places and can depend heavily on the backend provided (my_compiler in the example above), but the one that matters for this section is caching. Caching itself is not a direct speedup; it is a critical enabler that prevents recompilation. Tracing and compiling with Dynamo introduces overhead, and caching is what lets us pay that cost once instead of on every call. It keeps performance roughly neutral, which in turn lets backends, the true source of our speedups, do their work.
With even a pass-through no-op backend provided:
def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    return gm.forward
We can see TorchDynamo speeding up Python execution even on regular Python, not just PyTorch.
Caching and Guards Overview
TorchDynamo operates by caching user bytecode that it has transformed. When TorchDynamo receives a frame for evaluation, it checks whether the objects referenced in the frame have changed in certain ways; if not, TorchDynamo reads the previously transformed user bytecode and evaluates it. In this section, we focus on how to identify whether or not the objects referenced in the frame have changed. This is a critical piece of functionality in TorchDynamo, because it drives the entire invalidation lifecycle. This functionality is called guards.
At a very high level, the flow can be summarized like this:
1. TorchDynamo receives a Python frame.
2. It converts the frame from (1) by passing it through instruction translation.
3. For the objects captured in (2), TorchDynamo creates tracking objects that are:
   * tracked on an output graph, which is an internal specialization of a torch.fx.Tracer
   * guards
4. TorchDynamo processes the guard objects created in (3), turning them into a generated Python function, check_fn, associated with a piece of code.
5. The check_fn is evaluated whenever we encounter this code a subsequent time. If a check_fn passes and evaluates to True, TorchDynamo identifies the code in the cache and the code encountered here as the same, so the cached code can be safely used. If it fails and evaluates to False, TorchDynamo identifies the code in the cache as not valid, and it can be thrown out in favor of a new entry, through recompilation or a graph break.
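To make the lookup-or-recompile cycle described above concrete, here is a minimal pure-Python sketch. It is not TorchDynamo's implementation: frames are modeled as plain f_locals dicts, "compiling" just produces a label string, and CacheEntry, GuardedCache, and toy_compile are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List

@dataclass
class CacheEntry:
    check_fn: Callable[[Dict[str, Any]], bool]  # all guards folded into one predicate
    code: str                                   # stand-in for transformed bytecode

class GuardedCache:
    """Hypothetical per-code-object cache; frames are modeled as f_locals dicts."""

    def __init__(self, compile_fn: Callable[[Dict[str, Any]], CacheEntry]):
        self.entries: List[CacheEntry] = []
        self.compile_fn = compile_fn
        self.compiles = 0

    def lookup_or_compile(self, f_locals: Dict[str, Any]) -> str:
        # Reuse the first entry whose check_fn passes.
        for entry in self.entries:
            if entry.check_fn(f_locals):
                return entry.code
        # No entry's guards passed: "recompile" and cache the result.
        self.compiles += 1
        entry = self.compile_fn(f_locals)
        self.entries.insert(0, entry)
        return entry.code

def toy_compile(f_locals: Dict[str, Any]) -> CacheEntry:
    # Specialize on the current type and value of y, and guard on both.
    y = f_locals["y"]
    y_type = type(y)
    return CacheEntry(
        check_fn=lambda fl: type(fl["y"]) is y_type and fl["y"] == y,
        code=f"specialized for y={y!r}",
    )

cache = GuardedCache(toy_compile)
cache.lookup_or_compile({"y": 2})  # miss: compiles
cache.lookup_or_compile({"y": 2})  # hit: both guards pass
cache.lookup_or_compile({"y": 3})  # miss: value guard fails, compiles again
print(cache.compiles)  # 2
```

Keeping several entries per code object and trying them in turn reflects the idea that a guard failure leads to a new cache entry rather than overwriting the old one.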
Python Frame Evaluation and PEP 523
The functionality of TorchDynamo is based on PEP 523.
TorchDynamo installs a frame evaluation function on Python by using _PyInterpreterState_SetEvalFrameFunc. This gives TorchDynamo a hook through which Python hands control back to it during evaluation.
The function we have installed is convert_frame, or convert_frame_assert in the nopython=True case. Glossing over that nuance for now, let's take a look at convert_frame_assert, as convert_frame proxies to it.
We can find it on line 20 of convert_frame.py, with a signature as follows:
def convert_frame_assert(compiler_fn: Callable, one_graph=True):
This function wraps the entry point through which Python invokes TorchDynamo with a frame:
def _convert_frame_assert(frame: types.FrameType, cache_size: int):
Here is what this function does:
1. Checks if it has seen this code (see: f_code here) before and exits early if it did.
2. Checks if the code is an unsupported case.
3. Checks if the cache_size (second arg above) crosses the limit defined in the config, cache_size_limit. If it has, the function drops the frame and logs warnings. This helps to avoid constant recompilation of a frame, as hitting the limit generally means that the frame is hot in an unexpected way and caching it produces needless overhead, since it is likely to get evicted the next time it is encountered.
4. Passes the frame, alongside a function that creates an InstructionTranslator, through bytecode transformation via transform_code_object. A few crucial things happen under the hood here:
   1. New code is produced through transform_code_object.
   2. An FX tracer named output is produced through InstructionTranslator. This can be a bit confusing, as InstructionTranslator is not an fx tracer, but it is stored in a variable named tracer, and its output is an fx tracer.
   3. The function produces guards and stores them on output above.
   4. The function produces output_instructions and stores them on output above.
   5. The function maps the newly produced transformed code to the initial code it read off the frame. This mapping is worth remembering; we will refer to it much later on below, where we cover guard failures.
5. Using the transformed code from 4.1 and the guards from 4.3, the function produces a GuardedCode.
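The control flow of _convert_frame_assert can be sketched as follows, assuming the bytecode translation and guard compilation are supplied as callables. CACHE_SIZE_LIMIT, convert_frame_sketch, and its translate and build_check_fn parameters are hypothetical; the limit value of 8 is arbitrary, the unsupported-code check is omitted, and the real comparison against config.cache_size_limit may differ in detail.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Set, Tuple

CACHE_SIZE_LIMIT = 8  # hypothetical value standing in for config.cache_size_limit

@dataclass
class GuardedCode:
    code: object        # the transformed code (4.1)
    check_fn: Callable  # built from the guards (4.3)

def convert_frame_sketch(
    f_code: object,
    cache_size: int,
    already_seen: Set[object],
    translate: Callable[[object], Tuple[object, List[str]]],
    build_check_fn: Callable[[List[str]], Callable],
) -> Optional[GuardedCode]:
    # Exit early for code we have already seen.
    if f_code in already_seen:
        return None
    # Too many cache entries: the frame keeps recompiling, so give up on it.
    if cache_size >= CACHE_SIZE_LIMIT:
        return None
    # Translate the bytecode, collecting guards along the way.
    new_code, guards = translate(f_code)
    # Pair the transformed code with a check_fn built from those guards.
    return GuardedCode(code=new_code, check_fn=build_check_fn(guards))

# Hypothetical usage: an identity "translation" that emits one guard.
def add_one(x):
    return x + 1

def identity_translate(code):
    return code, ["isinstance(x, int)"]

def build_from_guards(guards):
    return lambda f_locals: all(eval(g, {}, f_locals) for g in guards)

fresh = convert_frame_sketch(add_one.__code__, 0, set(), identity_translate, build_from_guards)
too_hot = convert_frame_sketch(
    add_one.__code__, CACHE_SIZE_LIMIT, set(), identity_translate, build_from_guards
)
```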
Now that we have learned about frame evaluation, let's review InstructionTranslator and see how it turns the frame we handed it into TorchDynamo internal types.
InstructionTranslator
InstructionTranslator does a lot! We won't cover the details of everything it does, but most importantly for this document, it produces symbolic_locals, which maintains a mapping from the frame's f_locals to TorchDynamo internal Variable objects (more on these in a moment). symbolic_locals is filled by traversing the frame's locals:
self.symbolic_locals = collections.OrderedDict(
    (k, VariableBuilder(self, LocalSource(k))(f_locals[k]))
    for k in vars
    if k in f_locals
)
The important component here is the call into VariableBuilder. VariableBuilder's __call__ implementation proxies into a function called _wrap, which in turn both constructs instances of VariableTracker and calls make_guards on them. More on that later.
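A simplified model of how symbolic_locals gets built may help. In this sketch, VariableBuilder, LocalSource, and VariableTracker are hypothetical stand-ins that share names with TorchDynamo's classes but not their behavior: the builder's __call__ forwards to _wrap, which wraps the value and attaches a single guard represented as a (name, guard kind) tuple. The list of variable names is assumed to be code.co_varnames.

```python
import collections
from dataclasses import dataclass
from typing import Any, Dict, FrozenSet, Tuple

@dataclass(frozen=True)
class LocalSource:
    local_name: str

    def name(self) -> str:
        return self.local_name

@dataclass
class VariableTracker:
    value: Any
    source: LocalSource
    guards: FrozenSet[Tuple[str, str]] = frozenset()

class VariableBuilder:
    """Hypothetical, heavily simplified stand-in for TorchDynamo's VariableBuilder."""

    def __init__(self, tx: Any, source: LocalSource):
        self.tx = tx
        self.source = source

    def __call__(self, value: Any) -> VariableTracker:
        return self._wrap(value)

    def _wrap(self, value: Any) -> VariableTracker:
        # Construct the tracker and attach a guard keyed on where the value came from.
        guard = (self.source.name(), "TYPE_MATCH")
        return VariableTracker(value, self.source, frozenset({guard}))

def build_symbolic_locals(tx: Any, code, f_locals: Dict[str, Any]):
    var_names = code.co_varnames  # arguments first, then other locals
    return collections.OrderedDict(
        (k, VariableBuilder(tx, LocalSource(k))(f_locals[k]))
        for k in var_names
        if k in f_locals
    )

def toy_example(a, b):
    x = a + 1
    return x * b

# At function entry only the arguments are bound, so x is skipped.
symbolic_locals = build_symbolic_locals(None, toy_example.__code__, {"a": 1.0, "b": 2.0})
```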
This mapping, in turn, is critical, as each Variable has associated guards, which are then passed to self.output, the instance of OutputGraph, an fx tracer, mentioned in 4.2 of the section above. If you recall, this OutputGraph, stored in a variable called output, is where our guards are stored before being passed on to become GuardedCode.
How does InstructionTranslator do this? At the heart of it, there is a loop that is pumped, which drives a function, step. step is just that: a single processing step, taking exactly one instruction and doing something with it.
Note
These are real instructions processed by TorchDynamo's transform_code_object, and it is pretty cool.
Note
This section purposely skips the details of dis.get_instructions.
For the example above, here is a snippet of what a few Instruction objects may look like:
Instruction(opcode=124, opname='LOAD_FAST', arg=0, argval='b', offset=32, starts_line=8, is_jump_target=True, target=None)
Instruction(opcode=100, opname='LOAD_CONST', arg=3, argval=-1, offset=34, starts_line=None, is_jump_target=False, target=None)
Instruction(opcode=20, opname='BINARY_MULTIPLY', arg=None, argval=None, offset=36, starts_line=None, is_jump_target=False, target=None)
Dispatching on these instructions is the core functionality of step. Take a look at the opname, and then take a look at this little snippet from inside step:
if not hasattr(self, inst.opname):
    unimplemented(f"missing: {inst.opname}")
getattr(self, inst.opname)(inst)
As we can see, the function checks if the current class, the InstructionTranslator, has an attribute set matching the operator name (for example, LOAD_CONST). If it does, the function invokes it, passing the whole instruction object in. If it does not, the function drops the frame as unimplemented.
For the LOAD_CONST example, we can see that we do indeed support it, with a relatively straightforward definition:
def LOAD_CONST(self, inst):
    self.push(ConstantVariable(value=inst.argval))
We can see that this function creates a new instance of the class ConstantVariable with a value (in our example, -1), and then pushes it onto the stack.
There are dozens of such methods; see symbolic_convert.py for all of them. Generally, we implement methods matching as many Python bytecode instructions as possible.
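The getattr-based dispatch in step, together with LOAD_CONST, can be reproduced in a few lines. This is a hypothetical toy rather than TorchDynamo code: Inst is a namedtuple standing in for dis.Instruction (only the fields used here), and the Unsupported exception plays the role of unimplemented().

```python
from collections import namedtuple
from dataclasses import dataclass, field
from typing import Any, List

# Minimal stand-in for dis.Instruction: only the fields this sketch reads.
Inst = namedtuple("Inst", ["opname", "argval"])

class Unsupported(Exception):
    """Plays the role of unimplemented(): the frame is abandoned."""

@dataclass
class ConstantVariable:
    value: Any

@dataclass
class ToyTranslator:
    stack: List[Any] = field(default_factory=list)

    def push(self, var: Any) -> None:
        self.stack.append(var)

    def step(self, inst: Inst) -> None:
        # Look up a handler method named after the opcode.
        if not hasattr(self, inst.opname):
            raise Unsupported(f"missing: {inst.opname}")
        getattr(self, inst.opname)(inst)

    def LOAD_CONST(self, inst: Inst) -> None:
        self.push(ConstantVariable(value=inst.argval))

tx = ToyTranslator()
tx.step(Inst(opname="LOAD_CONST", argval=-1))
try:
    tx.step(Inst(opname="SOME_UNHANDLED_OP", argval=None))
    handled = True
except Unsupported:
    handled = False
```

With this style of dispatch, supporting a new opcode only requires defining a method with that opcode's name.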
Across both the logic downstream of step and the logic from invoking VariableBuilder, we now have a lot of VariableTrackers, and of course, we've spoken about creating guards quite a bit. Let's dig into what Variables are, and get a little closer to understanding guards.
Variables
A ConstantVariable is an instance of VariableTracker. VariableTracker represents a tracked Python local or stack value.
When it comes to representing an object inside TorchDynamo, a VariableTracker does exactly what it says: it tracks a given variable. It is an extremely flexible class, but there are a few points to keep in mind:
* It manages the guard relationship around the underlying object through:
  * make_guard
  * replace_guards
  * add_guard(s)
  * propagate: propagate(*vars: List[List["VariableTracker"]]). Perhaps the most important of all, in that it combines guards from all the provided VariableTracker instances passed in. It visits the guards and combines the guards from these onto itself.
* It acts as a proxy on behalf of the underlying object, implementing methods for the rest of TorchDynamo to get information about the tracked object:
  * call_method
  * call_function
  * python_type
  * as_proxy
  * is/as_python_proxy
* It stores the variable source of type Source, from torchdynamo/source.py. This source type is a relatively self-contained class that helps us organize and bookkeep where the original source came from, and provides convenience methods for things like getting the name and, importantly for us, producing guards.
This class (VariableTracker) is built around subclassing, somewhere between a full Abstract Base Class and a fully fleshed-out class: it leaves many methods raising NotImplementedError and relies on subclasses to implement them. See torchdynamo/variables/ for the subclasses that fulfill these contracts and add custom behaviors.
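The following sketch models two ideas from above: propagate combining guards from many trackers, and a base class that leaves methods such as python_type raising NotImplementedError for subclasses to fill in. Guards are plain strings here, and this VariableTracker is a hypothetical simplification, not TorchDynamo's class.

```python
from typing import Any, Dict, Iterable, Set

class VariableTracker:
    """Hypothetical, simplified base class: every tracker carries a set of guards."""

    def __init__(self, guards: Iterable[str] = ()):
        self.guards: Set[str] = set(guards)

    @staticmethod
    def propagate(*items: Any) -> Dict[str, Set[str]]:
        # Recursively walk trackers and (nested) lists/tuples of trackers,
        # unioning every guard found into a single options dict.
        guards: Set[str] = set()

        def visit(obj: Any) -> None:
            if isinstance(obj, VariableTracker):
                guards.update(obj.guards)
            elif isinstance(obj, (list, tuple)):
                for child in obj:
                    visit(child)

        visit(items)
        return {"guards": guards}

    def python_type(self) -> type:
        # Left unimplemented on the base class; subclasses fill it in.
        raise NotImplementedError

class ConstantVariable(VariableTracker):
    def __init__(self, value: Any, **options: Any):
        super().__init__(**options)
        self.value = value

    def python_type(self) -> type:
        return type(self.value)

a = ConstantVariable(1, guards={"a == 1"})
b = ConstantVariable(2.0, guards={"b == 2.0"})
options = VariableTracker.propagate([a, [b]])  # nested lists are walked too
```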
Knowing what we know now, we can look at an example of how TorchDynamo handles an instruction from dis, BUILD_TUPLE:
BUILD_TUPLE(count)
Creates a tuple consuming count items from the stack, and pushes the resulting tuple onto the stack.
In our case, our signature will be a little different due to the way we create Instruction objects, but the gist of it will be the same. Instead of passing in count, we pass in an object with a little extra bookkeeping, and of course, we deal with turning regular old Python objects into TorchDynamo notions:
def BUILD_TUPLE(self, inst):
    items = self.popn(inst.argval)
    options = VariableTracker.propagate(items)
    self.push(TupleVariable(items, **options))
Here is what this code does:
1. The function reads argval, which in this case is analogous to count in the pydoc for the equivalent instruction.
2. The function popn's the items. In this case, the signature is def popn(self, n: int) -> List[TensorVariable], which hints at an underlying contract: we are returning TensorVariables. If we take a closer look at symbolic_convert.py and InstructionTranslatorBase/InstructionTranslator, we see that the only things pushed onto and popped from our stack are VariableTrackers.
3. The function calls VariableTracker.propagate. This takes the guards from every single item popped off the stack in (2), recursively traverses them, and combines all the guards into options:
       return {
           "guards": guards,
       }
4. The function then makes a new instance of a VariableTracker, TupleVariable, out of the items and options. This then allows us to install all the appropriate guards from the items that make up the new TupleVariable.
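Putting popn, propagate, and TupleVariable together, a toy version of BUILD_TUPLE might look like this. All classes here are simplified, hypothetical stand-ins; in particular, this propagate only walks one level, relying on each item already carrying the guards it propagated from its own children, and Inst is a namedtuple in place of dis.Instruction.

```python
from collections import namedtuple
from typing import Any, Dict, Iterable, List, Set

Inst = namedtuple("Inst", ["opname", "argval"])  # stand-in for dis.Instruction

class VariableTracker:
    def __init__(self, guards: Iterable[str] = ()):
        self.guards: Set[str] = set(guards)

    @staticmethod
    def propagate(items: List["VariableTracker"]) -> Dict[str, Set[str]]:
        # Flat union: each item already holds the guards of anything it wraps.
        guards: Set[str] = set()
        for item in items:
            guards |= item.guards
        return {"guards": guards}

class ConstantVariable(VariableTracker):
    def __init__(self, value: Any, **options: Any):
        super().__init__(**options)
        self.value = value

class TupleVariable(VariableTracker):
    def __init__(self, items: List[VariableTracker], **options: Any):
        super().__init__(**options)
        self.items = items

class ToyTranslator:
    def __init__(self):
        self.stack: List[VariableTracker] = []

    def push(self, var: VariableTracker) -> None:
        self.stack.append(var)

    def popn(self, n: int) -> List[VariableTracker]:
        # Pop n items, keeping them in the order they were pushed.
        assert 0 <= n <= len(self.stack)
        items = self.stack[len(self.stack) - n:]
        del self.stack[len(self.stack) - n:]
        return items

    def BUILD_TUPLE(self, inst: Inst) -> None:
        items = self.popn(inst.argval)
        options = VariableTracker.propagate(items)
        self.push(TupleVariable(items, **options))

tx = ToyTranslator()
tx.push(ConstantVariable(1, guards={"a == 1"}))
tx.push(ConstantVariable(2, guards={"b == 2"}))
tx.BUILD_TUPLE(Inst(opname="BUILD_TUPLE", argval=2))
tup = tx.stack[-1]
```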
Note
Where did the first guards come from? Propagation is a good technique, but we need something created before it can be propagated. VariableBuilder calls make_guards as it creates VariableTracker instances from f_locals. This in turn calls into the source, to have it create guards.
After all this, bytecode translation is done and we are one step closer to producing GuardedCode. We now understand how locals become VariableTrackers, how instructions are handled, and where guards are called on for creation. Before we can see how code and guards are combined into a GuardedCode object, we need to dig a little into those make_guard and source.make_guard calls above. We can then understand what was going on when we made guards alongside, and on, VariableTracker instances.
Making Guards
Guards are just Python objects of the class Guard. Let's look at them in more detail.
Looking at the definition of the dataclass (and therefore, ctor signature), we see that it has a name, a source, and a create function.
@dataclasses.dataclass
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable
The name should be the name of the variable.
The source here is an enum indicating what kind of source the guard belongs to.
Note
Not to be confused with Source and the other types in source.py, as stored on VariableTracker.
create_fn provides the main functionality: it takes us from a simple dataclass to valid Python code that, when invoked, tells us whether or not things have changed between invocations, and therefore whether we can safely read from the code cache.
The most common code paths for getting an instance of a guard are through make_guards on VariableTracker:
make_guards -> source.make_guard -> return Guard(self.name(), self.guard_source(), fn)
Or, in a concrete example:
...
elif istype(value, range):
    guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
    return RangeVariable(value=value, guards=guards)
Since source was set at the construction time of this VariableTracker, all that was needed here was to provide the fn, GuardBuilder.EQUALS_MATCH, to the create_fn field.
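The make_guards -> source.make_guard -> Guard chain can be sketched with simplified classes. GuardSource, Guard, and LocalSource mirror the names used above, but their bodies, along with ToyVariableBuilder and the placeholder GuardBuilder.EQUALS_MATCH, are assumptions for illustration. The point is that the builder only supplies the create_fn; the name and guard source come from the Source fixed at construction time.

```python
import enum
from dataclasses import dataclass
from typing import Callable, Set

class GuardSource(enum.Enum):
    LOCAL = 0
    GLOBAL = 1

@dataclass(frozen=True)
class Guard:
    name: str
    source: GuardSource
    create_fn: Callable

@dataclass(frozen=True)
class LocalSource:
    local_name: str

    def name(self) -> str:
        return self.local_name

    def guard_source(self) -> GuardSource:
        return GuardSource.LOCAL

    def make_guard(self, fn: Callable) -> Guard:
        return Guard(self.name(), self.guard_source(), fn)

class GuardBuilder:
    def EQUALS_MATCH(self, guard: Guard) -> None:
        """Placeholder: code generation is not modeled in this sketch."""

class RangeVariable:
    def __init__(self, value: range, guards: Set[Guard]):
        self.value = value
        self.guards = guards

class ToyVariableBuilder:
    def __init__(self, source: LocalSource):
        self.source = source  # fixed when the builder is constructed

    def make_guards(self, *guard_fns: Callable) -> Set[Guard]:
        # make_guards -> source.make_guard -> Guard(name, guard_source, fn)
        return {self.source.make_guard(fn) for fn in guard_fns}

    def wrap(self, value):
        if type(value) is range:
            guards = self.make_guards(GuardBuilder.EQUALS_MATCH)
            return RangeVariable(value=value, guards=guards)
        raise NotImplementedError(f"unsupported type: {type(value).__name__}")

var = ToyVariableBuilder(LocalSource("r")).wrap(range(3))
(guard,) = var.guards
```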
This create_fn must be a method on GuardBuilder. The reason for this becomes apparent in our next step. Once we have all the guards created for a frame, we move on to CheckFunctionManager and compile_check_fn.
Before the convert_frame function can produce a GuardedCode, it needs to run the CheckFunctionManager, with all the guards, to produce a check_fn, which will then, in turn, get passed in alongside the code into GuardedCode. This is the same check_fn that we store in our cache entry, and the same one we run to know whether or not to retrieve the code stored alongside it. For reference, here is that code:
static CacheEntry *create_cache_entry(CacheEntry *next,
                                      PyObject *guarded_code) {
  CacheEntry *e = (CacheEntry *)malloc(sizeof(CacheEntry));
  DEBUG_NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject *)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}
We now know how a check_fn function is used, who makes it, and what it is composed of, but what we do not yet know is how. How does a list of Guard objects become a function we can run later on?
First, we iterate these guards:
for guard in sorted(guards or [], key=Guard.sort_key):
    if not config.guard_nn_modules and guard.is_nn_module():
        continue
    guard.create(local_builder, global_builder)
Calling guard.create runs the create_fn we set on the Guard class above (don't confuse it with the check_fn we are working on producing; the names are similar, so it can get a little confusing). In our example above, our create_fn is GuardBuilder.EQUALS_MATCH. So we are now invoking it on a GuardBuilder, passing in the guard itself.
The signature is: def EQUALS_MATCH(self, guard: Guard):
Inside that function, we can use the name on the guard to get back our original object, querying it for data and type information, which in turn gets us to the most important bit: appending code.
At its simplest, EQUALS_MATCH appends just one line of code: self.code.append(f"{ref} == {val!r}"), where ref is the name of the variable and val is the value. It might produce code like this:
y == 2
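Here is a hypothetical, stripped-down EQUALS_MATCH showing the guard's name being used to look up the original object and emit one line of guard code. In this sketch, Guard.create takes a single builder rather than the local and global builders passed in the loop above, and GuardBuilder.scope is an assumed stand-in for however the real builder resolves names.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List

@dataclass(frozen=True)
class Guard:
    name: str
    create_fn: Callable  # a GuardBuilder method, called as create_fn(builder, guard)

    def create(self, builder: "GuardBuilder") -> None:
        self.create_fn(builder, self)

@dataclass
class GuardBuilder:
    scope: Dict[str, Any]  # e.g. the frame's f_locals
    code: List[str] = field(default_factory=list)

    def get(self, name: str) -> Any:
        # Use the guard's name to get back the original object.
        return self.scope[name]

    def EQUALS_MATCH(self, guard: Guard) -> None:
        ref = guard.name
        val = self.get(ref)
        self.code.append(f"{ref} == {val!r}")

builder = GuardBuilder(scope={"y": 2, "s": "hi"})
for g in [Guard("y", GuardBuilder.EQUALS_MATCH), Guard("s", GuardBuilder.EQUALS_MATCH)]:
    g.create(builder)
print(builder.code)  # ['y == 2', "s == 'hi'"]
```

Using !r means string values are emitted as quoted literals, so the generated line stays valid Python.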
That single comparison is a basic example. But if we append a few other kinds of GuardBuilder functions and then join all the statements with and (as we do), we might get something like this:
___guarded_code.valid and ___check_type_id(y, 94367738391392) and y == 2 and ___check_tensors(x)
Here is what this code performs:
1. A check for .valid
2. A type ID check
3. A value check
4. A tensor check
This becomes the heart of our check_fn, which in turn is evaluated the next time we encounter this code. It will then check:
1. Is this code still valid?
2. If (1), does y still have a type whose id is 94367738391392?
3. If (2), is y still 2?
4. If (3), has tensor x changed in some specific ways?
If all of these are still true, then we can use the code cached alongside this check_fn.
Note
For a deeper dive into how and where this happens, you can read static PyCodeObject *lookup(CacheEntry *e, PyObject *f_locals) in _eval_frame.c.
If not, we can move on to recompiling the code anew and storing that in the cache alongside this code, with a whole new check_fn, again to be checked on yet another subsequent frame.
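The step of joining guard expressions with and and turning them into a callable can be sketched as follows. build_check_fn is a hypothetical helper, not TorchDynamo's compile_check_fn; ___check_type_id is reimplemented in Python as a comparison of id(type(obj)), and the .valid and tensor checks are omitted.

```python
from typing import Any, Callable, List

def ___check_type_id(obj: Any, type_id: int) -> bool:
    # Python stand-in for the C++ helper: compare the id of obj's type.
    return id(type(obj)) == type_id

def build_check_fn(code_parts: List[str], arg_names: List[str]) -> Callable[..., bool]:
    # Join every guard expression with "and" and evaluate the result as a lambda.
    body = " and ".join(code_parts) or "True"
    source = f"lambda {', '.join(arg_names)}: {body}"
    return eval(source, {"___check_type_id": ___check_type_id})

y = 2
code_parts = [f"___check_type_id(y, {id(type(y))})", f"y == {y!r}"]
check_fn = build_check_fn(code_parts, ["y"])
```

Because the expressions are joined with and, evaluation stops at the first failing check: check_fn(2.0), for example, fails on the type check without ever evaluating y == 2.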
There are lots of other such functions on GuardBuilder, which get coalesced into (at times massive) strings that are then evaluated as Python code and stored into check_fn. The example above illustrates a simple case. To understand this functionality better, read the other functions on GuardBuilder, or better yet, dump the code variable in compile_check_fn to see what is getting produced, especially on larger, real models.
Summary
In this section, we have reviewed:
* The role of .valid and invalidation around weak references (and potentially soon-to-be NN Module invalidations).
* How the C++ side of guard functions (___check_type_id, ___check_tensors, etc.) operates.
* What happens when guards fail.
* What happens if we produce invalid guard code.
We covered how user-provided code wrapped in a TorchDynamo context goes on to get traced and tracked internally, organized into VariableTrackers, Sources, and subsequently Guards, and how those Guards in turn guide cache entry selection and invalidation when handling Python code.