.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "recipes/compiling_optimizer_lr_scheduler.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_recipes_compiling_optimizer_lr_scheduler.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_recipes_compiling_optimizer_lr_scheduler.py:


(beta) Running the compiled optimizer with an LR Scheduler
============================================================

**Author:** `Michael Lazos `_

.. GENERATED FROM PYTHON SOURCE LINES 9-16

The optimizer is a key algorithm for training any deep learning model.
In this example, we will show how to pair an optimizer that has been compiled
with ``torch.compile`` with an LR scheduler to accelerate training convergence.

.. note::

   This tutorial requires PyTorch 2.3.0 or later.

.. GENERATED FROM PYTHON SOURCE LINES 18-22

Model Setup
~~~~~~~~~~~~~~~~~~~~~
For this example, we'll use a simple sequence of linear layers.

.. GENERATED FROM PYTHON SOURCE LINES 22-38

.. code-block:: default

    import torch

    # Create simple model
    model = torch.nn.Sequential(
        *[torch.nn.Linear(1024, 1024, False, device="cuda") for _ in range(10)]
    )
    input = torch.rand(1024, device="cuda")

    # run forward pass
    output = model(input)

    # run backward to populate the grads for our optimizer below
    output.sum().backward()

.. GENERATED FROM PYTHON SOURCE LINES 39-49

Setting up and running the compiled optimizer with LR Scheduler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In this section, we'll use the Adam optimizer with the LinearLR scheduler
and create a helper function that wraps the ``step()`` call for each of them
in ``torch.compile()``.

.. note::

   ``torch.compile`` is only supported on CUDA devices that have a compute capability of 7.0 or higher.

.. GENERATED FROM PYTHON SOURCE LINES 49-76

.. code-block:: default

    # exit cleanly if we are on a device that doesn't support ``torch.compile``
    if torch.cuda.get_device_capability() < (7, 0):
        print("Exiting because torch.compile is not supported on this device.")
        import sys
        sys.exit(0)


    # !!! IMPORTANT !!! Wrap the LR in a tensor if we are pairing the
    # optimizer with an LR scheduler.
    # Without this, torch.compile will recompile as the value of the LR
    # changes.
    opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

    @torch.compile(fullgraph=False)
    def fn():
        opt.step()
        sched.step()


    # Warmup runs to compile the function
    for _ in range(5):
        fn()
        print(opt.param_groups[0]["lr"])

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(0.0047)
    tensor(0.0060)
    tensor(0.0073)
    tensor(0.0087)
    tensor(0.0100)

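
As a quick illustration of how the compiled ``fn`` might be used once warmup is done,
the sketch below (not part of the generated recipe) reuses the ``model``, ``input``,
``opt``, and ``sched`` objects defined above, with the sum of the outputs standing in
for a real loss.

.. code-block:: python

    # Illustrative training-loop sketch: each iteration backpropagates a
    # stand-in loss and then calls the compiled fn so that both the optimizer
    # step and the scheduler step run inside the compiled region. Because the
    # LR is a tensor, the scheduler updates its value in place and the
    # changing LR does not trigger recompilation.
    for _ in range(5):
        out = model(input)   # forward pass
        loss = out.sum()     # stand-in loss, just for illustration
        opt.zero_grad()      # clear grads from the previous iteration
        loss.backward()      # populate grads for the optimizer
        fn()                 # compiled opt.step() + sched.step()
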
.. GENERATED FROM PYTHON SOURCE LINES 77-81

Extension: What happens with a non-tensor LR?
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
For the curious, we will show how to peek into what happens with ``torch.compile`` when we don't wrap the LR in a tensor.

.. GENERATED FROM PYTHON SOURCE LINES 81-101

.. code-block:: default

    # No longer wrap the LR in a tensor here
    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

    @torch.compile(fullgraph=False)
    def fn():
        opt.step()
        sched.step()


    # Setup logging to view recompiles
    torch._logging.set_logs(recompiles=True)

    # Warmup runs to compile the function
    # We will now recompile on each iteration
    # as the value of the lr is mutated.
    for _ in range(5):
        fn()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    V0715 17:19:09.565000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:135
    V0715 17:19:09.565000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
    V0715 17:19:09.565000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - ___key_to_id(L['self'].state) == [140197331170656,140197331160736,140197330074128,140197330070048,140197330064288,140197330060688,140197330073888,140197330063808,140197330074448,140197330067408]
    V0715 17:19:12.850000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:135
    V0715 17:19:12.850000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
    V0715 17:19:12.850000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.003333333333333333
    V0715 17:19:12.850000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - G['__optimizer_140197482626688_140197351566336_c59']() is not None
    V0715 17:19:15.378000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:135
    V0715 17:19:15.378000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
    V0715 17:19:15.378000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.004666666666666667
    V0715 17:19:15.378000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.003333333333333333
    V0715 17:19:15.378000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - G['__optimizer_140197482626688_140197351566336_c59']() is not None
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:135
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.006000000000000001
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.004666666666666667
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.003333333333333333
    V0715 17:19:17.898000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - G['__optimizer_140197482626688_140197351566336_c59']() is not None
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles] Recompiling function step in /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/optim/adam.py:135
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     triggered by the following guard failure(s):
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.007333333333333335
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.006000000000000001
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.004666666666666667
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - L['self'].param_groups[0]['lr'] == 0.003333333333333333
    V0715 17:19:20.418000 140206346437248 torch/_dynamo/guards.py:1425] [__recompiles]     - G['__optimizer_140197482626688_140197351566336_c59']() is not None

.. GENERATED FROM PYTHON SOURCE LINES 102-104

With this example, we can see that we recompile the optimizer a few times
due to the guard failure on the ``lr`` in ``param_groups[0]``.

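
As a quick check (not part of the generated recipe), the sketch below resets
``torch._dynamo``, rebuilds the same optimizer and scheduler with the LR wrapped
in a tensor again, and reruns the warmup loop while recompile logging is still
enabled. The ``lr``-related guard failures shown above should no longer appear
in the logs.

.. code-block:: python

    # Discard previously compiled variants of step so the comparison starts
    # from a clean slate.
    torch._dynamo.reset()

    # Same setup as before, but with the LR wrapped in a tensor again.
    opt = torch.optim.Adam(model.parameters(), lr=torch.tensor(0.01))
    sched = torch.optim.lr_scheduler.LinearLR(opt, total_iters=5)

    @torch.compile(fullgraph=False)
    def fn():
        opt.step()
        sched.step()

    # With recompile logging still enabled, guard failures on
    # param_groups[0]['lr'] should no longer be reported.
    for _ in range(5):
        fn()

    # Turn the verbose recompile logging back off.
    torch._logging.set_logs(recompiles=False)
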
.. GENERATED FROM PYTHON SOURCE LINES 106-118

Conclusion
~~~~~~~~~~

In this tutorial, we showed how to pair an optimizer compiled with ``torch.compile``
with an LR scheduler to accelerate training convergence. We used a model consisting
of a simple sequence of linear layers, with the Adam optimizer paired with a
LinearLR scheduler, to demonstrate the LR changing across iterations.

See also:

* `Compiled optimizer tutorial `__ - an intro into the compiled optimizer.
* `Compiling the optimizer with PT2 `__ - deeper technical details on the compiled optimizer.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 16.786 seconds)


.. _sphx_glr_download_recipes_compiling_optimizer_lr_scheduler.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: compiling_optimizer_lr_scheduler.py <compiling_optimizer_lr_scheduler.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: compiling_optimizer_lr_scheduler.ipynb <compiling_optimizer_lr_scheduler.ipynb>`


.. only:: html

  .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_