torch.optim¶
torch.optim is a package implementing various optimization algorithms. The most commonly used methods are already supported, and the interface is general enough that more sophisticated ones can be easily integrated in the future.
How to use an optimizer¶
To use torch.optim, you have to construct an optimizer object that will hold the current state and update the parameters based on the computed gradients.
Constructing it¶
To construct an Optimizer, you have to give it an iterable containing the parameters (all should be Variables) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc.
Example:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
Per-parameter options¶
Optimizers also support specifying per-parameter options. To do this, instead of passing an iterable of Variables, pass in an iterable of dicts. Each of them will define a separate parameter group, and should contain a params key, containing a list of parameters belonging to it. Other keys should match the keyword arguments accepted by the optimizers, and will be used as optimization options for this group.
Note
You can still pass options as keyword arguments. They will be used as defaults, in the groups that didn’t override them. This is useful when you only want to vary a single option, while keeping all others consistent between parameter groups.
For example, this is very useful when one wants to specify per-layer learning rates:
optim.SGD([
    {'params': model.base.parameters()},
    {'params': model.classifier.parameters(), 'lr': 1e-3}
], lr=1e-2, momentum=0.9)
This means that model.base’s parameters will use the default learning rate of 1e-2, model.classifier’s parameters will use a learning rate of 1e-3, and a momentum of 0.9 will be used for all parameters.
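As a minimal sketch of how these per-group options resolve, the snippet below builds a small model and reads the settings back from optimizer.param_groups. The TinyNet class and its layer sizes are hypothetical, chosen only so that model.base and model.classifier exist.

```python
import torch
from torch import nn, optim


class TinyNet(nn.Module):
    """Hypothetical two-part model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.base = nn.Linear(4, 8)
        self.classifier = nn.Linear(8, 2)

    def forward(self, x):
        return self.classifier(torch.relu(self.base(x)))


model = TinyNet()
optimizer = optim.SGD([
    {'params': model.base.parameters()},                     # falls back to lr=1e-2
    {'params': model.classifier.parameters(), 'lr': 1e-3},   # overrides lr
], lr=1e-2, momentum=0.9)

# Each dict became one parameter group; missing keys were filled from the defaults.
group_settings = [(g['lr'], g['momentum']) for g in optimizer.param_groups]
for lr, momentum in group_settings:
    print(f"lr={lr}, momentum={momentum}")
```

Inspecting param_groups this way is a quick check that a group picked up the option you intended.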
Taking an optimization step¶
All optimizers implement a step() method that updates the parameters. It can be used in two ways:
optimizer.step()¶
This is a simplified version supported by most optimizers. The function can be called once the gradients are computed using e.g. backward().
Example:
for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
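The loop above can be run end to end on synthetic data. The following is a sketch under stated assumptions: the regression data, the nn.Linear model, the MSE loss and the learning rate are all illustrative choices, not part of the original example.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)  # make the synthetic data reproducible

# Synthetic, noise-free regression problem (an assumption for illustration).
x = torch.randn(64, 3)
w_true = torch.tensor([[1.0], [-2.0], [0.5]])
y = x @ w_true
dataset = list(zip(x.split(16), y.split(16)))  # four mini-batches of 16 samples

model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

with torch.no_grad():
    initial_loss = loss_fn(model(x), y).item()

for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()            # clear gradients from the previous step
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()                  # compute gradients
        optimizer.step()                 # update parameters from those gradients

with torch.no_grad():
    final_loss = loss_fn(model(x), y).item()

print(f"loss before training: {initial_loss:.4f}, after: {final_loss:.6f}")
```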
optimizer.step(closure)¶
Some optimization algorithms such as Conjugate Gradient and LBFGS need to reevaluate the function multiple times, so you have to pass in a closure that allows them to recompute your model. The closure should clear the gradients, compute the loss, and return it.
Example:
for input, target in dataset:
    def closure():
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        return loss
    optimizer.step(closure)
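To see the closure form in isolation, here is a small sketch that minimizes a quadratic with LBFGS. The objective, the target values and the hyperparameters are assumptions chosen so that the result is easy to check; they do not come from the original text.

```python
import torch
from torch import optim

target = torch.tensor([3.0, -2.0])           # minimizer of the toy objective
w = torch.zeros(2, requires_grad=True)       # parameter being optimized
optimizer = optim.LBFGS([w], lr=1.0, max_iter=20)


def closure():
    # LBFGS may call this several times per step, so it must redo all the work:
    # clear the gradients, recompute the loss, backpropagate, and return the loss.
    optimizer.zero_grad()
    loss = ((w - target) ** 2).sum()
    loss.backward()
    return loss


optimizer.step(closure)
print(w.detach())
```

A single call to step(closure) runs up to max_iter inner iterations, which is why the closure, rather than the training loop, owns the forward and backward pass.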
Base class¶
- class torch.optim.Optimizer(params, defaults)¶
Base class for all optimizers.
Warning
Parameters need to be specified as collections that have a deterministic ordering that is consistent between runs. Examples of objects that don’t satisfy those properties are sets and iterators over values of dictionaries.
- Parameters:
params (iterable) – an iterable of torch.Tensors or dicts. Specifies what Tensors should be optimized.
defaults (dict) – a dict containing default values of optimization options (used when a parameter group doesn’t specify them).
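Method | Description |
---|---|
Optimizer.add_param_group | Add a param group to the Optimizer’s param_groups. |
Optimizer.load_state_dict | Loads the optimizer state. |
Optimizer.state_dict | Returns the state of the optimizer as a dict. |
Optimizer.step | Performs a single optimization step (parameter update). |
Optimizer.zero_grad | Resets the gradients of all optimized torch.Tensors. |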
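As an illustration of how the base class is meant to be extended, the sketch below subclasses Optimizer to implement plain gradient descent. PlainGD is a hypothetical name invented for this example; it is not an optimizer shipped with torch.optim.

```python
import torch
from torch.optim import Optimizer


class PlainGD(Optimizer):
    """Hypothetical optimizer: vanilla gradient descent, p <- p - lr * grad."""

    def __init__(self, params, lr=0.1):
        defaults = dict(lr=lr)  # used by any param group that does not set its own lr
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():   # the closure needs gradients enabled
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group['lr'])
        return loss


w = torch.tensor([1.0, 2.0], requires_grad=True)
optimizer = PlainGD([w], lr=0.1)

optimizer.zero_grad()
(w ** 2).sum().backward()    # gradient is 2 * w = [2.0, 4.0]
optimizer.step()             # w becomes [1.0, 2.0] - 0.1 * [2.0, 4.0]

# state_dict() is inherited from the base class.
checkpoint_keys = sorted(optimizer.state_dict().keys())
print(w.detach(), checkpoint_keys)
```

The subclass only supplies the defaults and the update rule; parameter groups, zero_grad(), state_dict() and load_state_dict() come from Optimizer.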
Algorithms¶
Algorithm | Description |
---|---|
Adadelta | Implements Adadelta algorithm. |
Adagrad | Implements Adagrad algorithm. |
Adam | Implements Adam algorithm. |
AdamW | Implements AdamW algorithm. |
SparseAdam | Implements lazy version of Adam algorithm suitable for sparse tensors. |
Adamax | Implements Adamax algorithm (a variant of Adam based on infinity norm). |
ASGD | Implements Averaged Stochastic Gradient Descent. |
LBFGS | Implements L-BFGS algorithm, heavily inspired by minFunc. |
NAdam | Implements NAdam algorithm. |
RAdam | Implements RAdam algorithm. |
RMSprop | Implements RMSprop algorithm. |
Rprop | Implements the resilient backpropagation algorithm. |
SGD | Implements stochastic gradient descent (optionally with momentum). |
Many of our algorithms have various implementations optimized for performance, readability and/or generality, so we attempt to default to the generally fastest implementation for the current device if no particular implementation has been specified by the user.
We have 3 major categories of implementations: for-loop, foreach (multi-tensor), and fused. The most straightforward implementations are for-loops over the parameters with big chunks of computation. For-looping is usually slower than our foreach implementations, which combine parameters into a multi-tensor and run the big chunks of computation all at once, thereby saving many sequential kernel calls. A few of our optimizers have even faster fused implementations, which fuse the big chunks of computation into one kernel. We can think of foreach implementations as fusing horizontally and fused implementations as fusing vertically on top of that.
In general, the performance ordering of the 3 implementations is fused > foreach > for-loop. So when applicable, we default to foreach over for-loop. Applicable means the foreach implementation is available, the user has not specified any implementation-specific kwargs (e.g., fused, foreach, differentiable), and all tensors are native and on CUDA. Note that while fused should be even faster than foreach, the implementations are newer and we would like to give them more bake-in time before flipping the switch everywhere. You are welcome to try them out though!
Below is a table showing the available and default implementations of each algorithm:
Algorithm | Default | Has foreach? | Has fused? |
---|---|---|---|
Adadelta | foreach | yes | no |
Adagrad | foreach | yes | no |
Adam | foreach | yes | yes |
AdamW | foreach | yes | yes |
SparseAdam | for-loop | no | no |
Adamax | foreach | yes | no |
ASGD | foreach | yes | no |
LBFGS | for-loop | no | no |
NAdam | foreach | yes | no |
RAdam | foreach | yes | no |
RMSprop | foreach | yes | no |
Rprop | foreach | yes | no |
SGD | foreach | yes | no |
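If you want to pick an implementation yourself rather than rely on the default, the foreach keyword argument can be passed to the constructor. The sketch below assumes a PyTorch version whose SGD accepts foreach; the model, data and the train helper are hypothetical, and the run is on CPU purely to compare results, not speed.

```python
import torch
from torch import nn, optim


def train(foreach):
    """Run three SGD steps on a fixed problem with the chosen implementation."""
    torch.manual_seed(0)  # identical initial weights and data for both runs
    model = nn.Linear(4, 2)
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, foreach=foreach)
    for _ in range(3):
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()
    return [p.detach().clone() for p in model.parameters()]


loop_params = train(foreach=False)     # force the for-loop implementation
foreach_params = train(foreach=True)   # force the multi-tensor implementation

# Both implementations compute the same update; only the kernel launches differ.
same = all(torch.allclose(a, b, atol=1e-6) for a, b in zip(loop_params, foreach_params))
print("implementations agree:", same)
```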
How to adjust learning rate¶
torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs. torch.optim.lr_scheduler.ReduceLROnPlateau allows dynamic learning rate reduction based on some validation measurements.
Learning rate scheduling should be applied after the optimizer’s update; e.g., you should write your code this way:
Example:
from torch.optim.lr_scheduler import ExponentialLR

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    # advance the schedule once per epoch, after the optimizer updates
    scheduler.step()
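To make the effect of the schedule visible, the following sketch records the learning rate at the start of each epoch. The single-tensor "model" and the toy loss are hypothetical stand-ins so that the snippet runs on its own.

```python
import math

import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR

param = torch.ones(1, requires_grad=True)   # hypothetical stand-in for model parameters
optimizer = optim.SGD([param], lr=0.01, momentum=0.9)
scheduler = ExponentialLR(optimizer, gamma=0.9)

lrs = []
for epoch in range(3):
    lrs.append(scheduler.get_last_lr()[0])  # learning rate in effect for this epoch
    optimizer.zero_grad()
    (param ** 2).sum().backward()           # toy loss, only to produce a gradient
    optimizer.step()                        # optimizer update first
    scheduler.step()                        # then the scheduler
lrs.append(scheduler.get_last_lr()[0])

# Each epoch multiplies the previous learning rate by gamma.
print(lrs)
```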
Most learning rate schedulers can be called back-to-back (also referred to as chaining schedulers). The result is that each scheduler is applied one after the other on the learning rate obtained by the one preceding it.
Example:
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[30,80], gamma=0.1)
for epoch in range(20):
    for input, target in dataset:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
    # each scheduler acts on the learning rate left by the previous one
    scheduler1.step()
    scheduler2.step()
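A small runnable sketch of chaining follows. To make the second scheduler's effect show up within a few epochs, the milestone is moved to epoch 2, which is an assumption for illustration and differs from the milestones in the example above; the single-tensor parameter and toy loss are likewise hypothetical.

```python
import math

import torch
from torch import optim
from torch.optim.lr_scheduler import ExponentialLR, MultiStepLR

param = torch.ones(1, requires_grad=True)   # hypothetical stand-in for model parameters
optimizer = optim.SGD([param], lr=0.01, momentum=0.9)
scheduler1 = ExponentialLR(optimizer, gamma=0.9)
scheduler2 = MultiStepLR(optimizer, milestones=[2], gamma=0.1)  # illustrative milestone

lrs = []
for epoch in range(3):
    optimizer.zero_grad()
    (param ** 2).sum().backward()           # toy loss, only to produce a gradient
    optimizer.step()
    scheduler1.step()                       # multiplies the current lr by 0.9
    scheduler2.step()                       # multiplies by 0.1 when the milestone is reached
    lrs.append(optimizer.param_groups[0]['lr'])

print(lrs)
```

After the second epoch both factors have been applied, so the learning rate reflects the exponential decay and the milestone drop together.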
In many places in the documentation, we will use the following template to refer to scheduler algorithms.
>>> scheduler = ...
>>> for epoch in range(100):
>>>     train(...)
>>>     validate(...)
>>>     scheduler.step()
Warning
Prior to PyTorch 1.1.0, the learning rate scheduler was expected to be called before the optimizer’s update; 1.1.0 changed this behavior in a BC-breaking way. If you use the learning rate scheduler (calling scheduler.step()) before the optimizer’s update (calling optimizer.step()), this will skip the first value of the learning rate schedule. If you are unable to reproduce results after upgrading to PyTorch 1.1.0, please check if you are calling scheduler.step() at the wrong time.
Scheduler | Description |
---|---|
lr_scheduler.LambdaLR | Sets the learning rate of each parameter group to the initial lr times a given function. |
lr_scheduler.MultiplicativeLR | Multiply the learning rate of each parameter group by the factor given in the specified function. |
lr_scheduler.StepLR | Decays the learning rate of each parameter group by gamma every step_size epochs. |
lr_scheduler.MultiStepLR | Decays the learning rate of each parameter group by gamma once the number of epochs reaches one of the milestones. |
lr_scheduler.ConstantLR | Decays the learning rate of each parameter group by a small constant factor until the number of epochs reaches a pre-defined milestone: total_iters. |
lr_scheduler.LinearLR | Decays the learning rate of each parameter group by a linearly changing small multiplicative factor until the number of epochs reaches a pre-defined milestone: total_iters. |
lr_scheduler.ExponentialLR | Decays the learning rate of each parameter group by gamma every epoch. |
lr_scheduler.PolynomialLR | Decays the learning rate of each parameter group using a polynomial function in the given total_iters. |
lr_scheduler.CosineAnnealingLR | Set the learning rate of each parameter group using a cosine annealing schedule, where $\eta_{max}$ is set to the initial lr and $T_{cur}$ is the number of epochs since the last restart in SGDR. |
lr_scheduler.ChainedScheduler | Chains a list of learning rate schedulers. |
lr_scheduler.SequentialLR | Receives the list of schedulers that are expected to be called sequentially during the optimization process, and milestone points that provide exact intervals to reflect which scheduler is supposed to be called at a given epoch. |
lr_scheduler.ReduceLROnPlateau | Reduce learning rate when a metric has stopped improving. |
lr_scheduler.CyclicLR | Sets the learning rate of each parameter group according to cyclical learning rate policy (CLR). |
lr_scheduler.OneCycleLR | Sets the learning rate of each parameter group according to the 1cycle learning rate policy. |
lr_scheduler.CosineAnnealingWarmRestarts | Set the learning rate of each parameter group using a cosine annealing schedule, where $\eta_{max}$ is set to the initial lr, $T_{cur}$ is the number of epochs since the last restart and $T_{i}$ is the number of epochs between two warm restarts in SGDR. |
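Unlike the epoch-based schedulers, ReduceLROnPlateau takes the monitored metric as an argument to step(). The sketch below feeds it a made-up validation loss that never improves; the metric values, the patience and the single-tensor parameter are assumptions for illustration only.

```python
import math

import torch
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.ones(1, requires_grad=True)   # hypothetical stand-in for model parameters
optimizer = optim.SGD([param], lr=0.1)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

val_losses = [1.0, 1.0, 1.0, 1.0]           # made-up metric that stops improving
lrs = []
for val_loss in val_losses:
    # training and validation for one epoch would happen here
    scheduler.step(val_loss)                # pass the metric being monitored
    lrs.append(optimizer.param_groups[0]['lr'])

# The learning rate is cut by `factor` once more than `patience` epochs pass
# without improvement.
print(lrs)
```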
Stochastic Weight Averaging¶
torch.optim.swa_utils implements Stochastic Weight Averaging (SWA). In particular, the torch.optim.swa_utils.AveragedModel class implements SWA models, torch.optim.swa_utils.SWALR implements the SWA learning rate scheduler, and torch.optim.swa_utils.update_bn() is a utility function used to update SWA batch normalization statistics at the end of training.
SWA was proposed in Averaging Weights Leads to Wider Optima and Better Generalization.
Constructing averaged models¶
The AveragedModel class serves to compute the weights of the SWA model. You can create an averaged model by running:
>>> swa_model = AveragedModel(model)
Here the model model can be an arbitrary torch.nn.Module object. swa_model will keep track of the running averages of the parameters of the model. To update these averages, you can use the update_parameters() function:
>>> swa_model.update_parameters(model)
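A minimal sketch of the running average, assuming a one-weight linear layer whose value is set by hand in place of real training:

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

model = nn.Linear(1, 1, bias=False)          # hypothetical one-weight model
with torch.no_grad():
    model.weight.fill_(1.0)

swa_model = AveragedModel(model)
swa_model.update_parameters(model)           # first update copies the weights (1.0)

with torch.no_grad():
    model.weight.fill_(3.0)                  # pretend training moved the weight
swa_model.update_parameters(model)           # equal average of 1.0 and 3.0 is 2.0

averaged_weight = swa_model.module.weight.detach()
num_averaged = int(swa_model.n_averaged)
print(averaged_weight, num_averaged)
```

The averaged copy lives in swa_model.module, and n_averaged counts how many models have been folded into it.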
SWA learning rate schedules¶
Typically, in SWA the learning rate is set to a high constant value. SWALR is a learning rate scheduler that anneals the learning rate to a fixed value, and then keeps it constant. For example, the following code creates a scheduler that linearly anneals the learning rate from its initial value to 0.05 in 5 epochs within each parameter group:
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
>>>     anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)
You can also use cosine annealing to a fixed value instead of linear annealing by setting anneal_strategy="cos".
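The annealing can be observed directly by recording the learning rate after each scheduler step. In this sketch the initial learning rate of 0.5, the single-tensor parameter and the toy loss are assumptions for illustration.

```python
import math

import torch
from torch import optim
from torch.optim.swa_utils import SWALR

param = torch.ones(1, requires_grad=True)   # hypothetical stand-in for model parameters
optimizer = optim.SGD([param], lr=0.5)
swa_scheduler = SWALR(optimizer, anneal_strategy="linear", anneal_epochs=5, swa_lr=0.05)

lrs = []
for epoch in range(8):
    optimizer.zero_grad()
    (param ** 2).sum().backward()           # toy loss, only to produce a gradient
    optimizer.step()
    swa_scheduler.step()
    lrs.append(optimizer.param_groups[0]['lr'])

# The learning rate falls for anneal_epochs steps, then stays at swa_lr.
print(lrs)
```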
Taking care of batch normalization¶
update_bn() is a utility function that lets you compute the batchnorm statistics for the SWA model on a given dataloader loader at the end of training:
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
update_bn() applies the swa_model to every element in the dataloader and computes the activation statistics for each batch normalization layer in the model.
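The following sketch shows update_bn() recomputing the running mean of a batch normalization layer. The data, the small nn.Sequential model and the batch size are hypothetical; the loader yields (input, target) pairs, so the first element of each batch is what gets passed to the model.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
inputs = torch.randn(64, 3) * 2.0 + 5.0      # features with a mean of roughly 5
targets = torch.randn(64, 1)
loader = DataLoader(TensorDataset(inputs, targets), batch_size=16)

model = nn.Sequential(nn.BatchNorm1d(3), nn.Linear(3, 1))
swa_model = AveragedModel(model)
swa_model.update_parameters(model)

# Reset the batchnorm statistics and recompute them with one pass over the loader.
update_bn(loader, swa_model)

bn = swa_model.module[0]
print(bn.running_mean, inputs.mean(dim=0))
```

Because the statistics are accumulated as a plain average over equally sized batches, the running mean ends up matching the per-feature mean of the data.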
Warning
update_bn() assumes that each batch in the dataloader loader is either a tensor or a list of tensors where the first element is the tensor that the network swa_model should be applied to. If your dataloader has a different structure, you can update the batch normalization statistics of the swa_model by doing a forward pass with the swa_model on each element of the dataset.
Custom averaging strategies¶
By default, torch.optim.swa_utils.AveragedModel computes a running equal average of the parameters that you provide, but you can also use custom averaging functions with the avg_fn parameter. In the following example, ema_model computes an exponential moving average.
Example:
>>> ema_avg = lambda averaged_model_parameter, model_parameter, num_averaged:\
>>>     0.1 * averaged_model_parameter + 0.9 * model_parameter
>>> ema_model = torch.optim.swa_utils.AveragedModel(model, avg_fn=ema_avg)
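A runnable sketch of the same averaging function, written as a named function and applied to a hypothetical one-weight model whose value is set by hand. It keeps the weights used in the example above: 0.1 on the running average and 0.9 on the new value.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel


def ema_avg(averaged_model_parameter, model_parameter, num_averaged):
    # Same weighting as the lambda in the example; num_averaged is unused here.
    return 0.1 * averaged_model_parameter + 0.9 * model_parameter


model = nn.Linear(1, 1, bias=False)          # hypothetical one-weight model
with torch.no_grad():
    model.weight.fill_(1.0)

ema_model = AveragedModel(model, avg_fn=ema_avg)
ema_model.update_parameters(model)           # first update copies the weights (1.0)

with torch.no_grad():
    model.weight.fill_(3.0)                  # pretend training moved the weight
ema_model.update_parameters(model)           # 0.1 * 1.0 + 0.9 * 3.0 = 2.8

ema_weight = ema_model.module.weight.detach()
print(ema_weight)
```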
Putting it all together¶
In the example below, swa_model is the SWA model that accumulates the averages of the weights. We train the model for a total of 300 epochs; after epoch 160 we switch to the SWA learning rate schedule and start to collect SWA averages of the parameters:
>>> loader, optimizer, model, loss_fn = ...
>>> swa_model = torch.optim.swa_utils.AveragedModel(model)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)
>>> swa_start = 160
>>> swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=0.05)
>>>
>>> for epoch in range(300):
>>>     for input, target in loader:
>>>         optimizer.zero_grad()
>>>         loss_fn(model(input), target).backward()
>>>         optimizer.step()
>>>     if epoch > swa_start:
>>>         swa_model.update_parameters(model)
>>>         swa_scheduler.step()
>>>     else:
>>>         scheduler.step()
>>>
>>> # Update bn statistics for the swa_model at the end
>>> torch.optim.swa_utils.update_bn(loader, swa_model)
>>> # Use swa_model to make predictions on test data
>>> preds = swa_model(test_input)