torch.utils.checkpoint
Note
Checkpointing is implemented by rerunning a forward-pass segment for each checkpointed segment during backward. This can cause persistent states like the RNG state to be advanced more than they would be without checkpointing. By default, checkpointing includes logic to juggle the RNG state such that checkpointed passes making use of RNG (through dropout, for example) have deterministic output compared to non-checkpointed passes. The logic to stash and restore RNG states can incur a moderate performance hit depending on the runtime of the checkpointed operations. If deterministic output compared to non-checkpointed passes is not required, supply preserve_rng_state=False to checkpoint or checkpoint_sequential to omit stashing and restoring the RNG state during each checkpoint.
The stashing logic saves and restores the RNG state for the current device and the device of all CUDA Tensor arguments to the run_fn. However, the logic has no way to anticipate whether the user will move Tensors to a new device within the run_fn itself. Therefore, if you move Tensors to a new device ("new" meaning not belonging to the set of [current device + devices of Tensor arguments]) within run_fn, deterministic output compared to non-checkpointed passes is never guaranteed.
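For illustration, a minimal sketch (assuming a Dropout module as the RNG-consuming callable; shapes and probabilities are arbitrary) of checkpointing with and without RNG stashing:
>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>> drop = torch.nn.Dropout(p=0.5)
>>> x = torch.randn(2, 4, requires_grad=True)
>>> # Default (preserve_rng_state=True): the RNG state is stashed and restored,
>>> # so the dropout mask recomputed during backward matches the forward pass.
>>> out = checkpoint(drop, x, use_reentrant=False)
>>> # If deterministic recomputation is not required, skip the RNG bookkeeping
>>> # to avoid its overhead.
>>> out_fast = checkpoint(drop, x, use_reentrant=False, preserve_rng_state=False)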
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=True, context_fn=<function noop_context_fn>, **kwargs)
Checkpoint a model or part of the model.
Checkpointing is a technique that trades compute for memory. Instead of storing all intermediate activations of the entire computation graph for the backward pass, the checkpointed part omits saving intermediate activations and recomputes them during the backward pass. This can be applied to any part of a model.
There are currently two checkpointing implementations available, determined by the use_reentrant parameter. It is recommended that you use use_reentrant=False. Please refer to the note below for a discussion of their differences.
Warning
If the function invocation during the backward pass differs from the forward pass, e.g., due to a global variable, the checkpointed version may not be equivalent, potentially causing an error to be raised or leading to silently incorrect gradients.
Warning
If you are using the use_reentrant=True variant (this is currently the default), please refer to the note below for important considerations and potential limitations.
Note
The reentrant variant of checkpoint (use_reentrant=True) and the non-reentrant variant of checkpoint (use_reentrant=False) differ in the following ways:
- Non-reentrant checkpoint stops recomputation as soon as all needed intermediate activations have been recomputed. This feature is enabled by default, but can be disabled with set_checkpoint_early_stop(). Reentrant checkpoint always recomputes function in its entirety during the backward pass.
- The reentrant variant does not record the autograd graph during the forward pass, as it runs the forward pass under torch.no_grad(). The non-reentrant version does record the autograd graph, allowing one to perform backward on the graph within checkpointed regions.
- The reentrant checkpoint only supports the torch.autograd.backward() API for the backward pass, without its inputs argument, while the non-reentrant version supports all ways of performing the backward pass (see the sketch after this list).
- At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.
- The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc.) as participating in autograd, while the non-reentrant version does.
- The reentrant checkpoint does not support checkpointed regions containing tensors detached from the computational graph, whereas the non-reentrant version does. For the reentrant variant, if the checkpointed segment contains tensors detached using detach() or with torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients, and this causes issues when a tensor is defined to have no gradient in the model. To avoid this, detach the tensors outside of the checkpoint function.
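A minimal sketch of the third point above (the Linear layer and tensor shapes are illustrative assumptions): with use_reentrant=False, the backward pass can also be driven through torch.autograd.grad.
>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>> layer = torch.nn.Linear(8, 8)
>>> x = torch.randn(4, 8, requires_grad=True)
>>> # Non-reentrant checkpoint supports torch.autograd.grad in addition to
>>> # torch.autograd.backward()
>>> out = checkpoint(layer, x, use_reentrant=False)
>>> (grad_x,) = torch.autograd.grad(out.sum(), x)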
- Parameters:
  - function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden.
  - preserve_rng_state (bool, optional) – If False, omit stashing and restoring the RNG state during each checkpoint. Default: True
  - use_reentrant (bool, optional) – Use checkpointing implementation that requires re-entrant autograd. If use_reentrant=False is specified, checkpoint will use an implementation that does not require re-entrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function. Note that future versions of PyTorch will default to use_reentrant=False. Default: True
  - context_fn (Callable, optional) – A callable returning a tuple of two context managers. The function and its recomputation will be run under the first and second context managers respectively. This argument is only supported if use_reentrant=False.
  - args – tuple containing inputs to the function (see the example below)
- Returns:
  Output of running function on *args
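A minimal usage sketch (the run_fn, names, and tensor shapes below are illustrative assumptions, not part of the API) showing how the positional args are forwarded to function and how the returned output participates in backward:
>>> import torch
>>> from torch.utils.checkpoint import checkpoint
>>> def run_fn(activation, hidden):
...     # the checkpointed callable receives the unpacked positional args
...     return activation * torch.tanh(hidden)
>>> activation = torch.randn(3, 5, requires_grad=True)
>>> hidden = torch.randn(3, 5, requires_grad=True)
>>> # Intermediate activations inside run_fn are recomputed during backward
>>> out = checkpoint(run_fn, activation, hidden, use_reentrant=False)
>>> out.sum().backward()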
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=True, **kwargs)
A helper function for checkpointing sequential models.
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will not store the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
Warning
If you are using the use_reentrant=True variant (this is the default), please see checkpoint() for the important considerations and limitations of this variant. It is recommended that you use use_reentrant=False.
- Parameters:
  - functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
  - segments – Number of chunks to create in the model
  - input – A Tensor that is input to functions
  - preserve_rng_state (bool, optional) – If False, omit stashing and restoring the RNG state during each checkpoint. Default: True
  - use_reentrant (bool, optional) – Use checkpointing implementation that requires re-entrant autograd. If use_reentrant=False is specified, checkpoint will use an implementation that does not require re-entrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function. Default: True
- Returns:
  Output of running functions sequentially on *inputs
Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
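A slightly fuller, self-contained sketch (the layer sizes and segment count are illustrative assumptions), using the recommended use_reentrant=False:
>>> import torch
>>> import torch.nn as nn
>>> from torch.utils.checkpoint import checkpoint_sequential
>>> model = nn.Sequential(
...     nn.Linear(16, 32), nn.ReLU(),
...     nn.Linear(32, 32), nn.ReLU(),
...     nn.Linear(32, 10),
... )
>>> input_var = torch.randn(8, 16, requires_grad=True)
>>> # Divide the model into 2 segments; only the segment inputs are kept,
>>> # and intermediate activations are recomputed during backward.
>>> out = checkpoint_sequential(model, 2, input_var, use_reentrant=False)
>>> out.sum().backward()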