.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/python_custom_ops.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_python_custom_ops.py:


.. _python-custom-ops-tutorial:

Python Custom Operators
=======================

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * How to integrate custom operators written in Python with PyTorch
       * How to test custom operators using ``torch.library.opcheck``

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.4 or later

PyTorch offers a large library of operators that work on Tensors (for example,
``torch.add`` and ``torch.sum``). However, you might wish to use a new
customized operator with PyTorch, perhaps written by a third-party library.
This tutorial shows how to wrap Python functions so that they behave like
PyTorch native operators. Reasons why you may wish to create a custom operator
in PyTorch include:

- Treating an arbitrary Python function as an opaque callable with respect
  to ``torch.compile`` (that is, preventing ``torch.compile`` from tracing
  into the function).
- Adding training support to an arbitrary Python function.

Please note that if your operation can be expressed as a composition of
existing PyTorch operators, then there is usually no need to use the custom
operator API: everything (for example, ``torch.compile`` and training support)
should just work.

.. GENERATED FROM PYTHON SOURCE LINES 39-42

Example: Wrapping PIL's crop into a custom operator
---------------------------------------------------
Let's say that we are using PIL's ``crop`` operation.

.. GENERATED FROM PYTHON SOURCE LINES 42-61

.. code-block:: default

    import torch
    from torchvision.transforms.functional import to_pil_image, pil_to_tensor
    import PIL.Image
    import matplotlib.pyplot as plt

    def crop(pic, box):
        # Round-trip through PIL: tensor -> PIL image -> cropped tensor in [0, 1].
        img = to_pil_image(pic.cpu())
        cropped_img = img.crop(box)
        return pil_to_tensor(cropped_img).to(pic.device) / 255.

    def display(img):
        # matplotlib expects (H, W, C); PyTorch images are (C, H, W).
        plt.imshow(img.numpy().transpose((1, 2, 0)))

    img = torch.ones(3, 64, 64)
    img *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)
    display(img)

.. GENERATED FROM PYTHON SOURCE LINES 62-66

.. code-block:: default

    cropped_img = crop(img, (10, 10, 50, 50))
    display(cropped_img)

.. GENERATED FROM PYTHON SOURCE LINES 67-74

``torch.compile`` does not handle ``crop`` effectively out of the box:
``torch.compile`` induces a `"graph break" `_ on functions it is unable to
handle, and graph breaks are bad for performance. The following code
demonstrates this by raising an error (``torch.compile`` with
``fullgraph=True`` raises an error if a graph break occurs).

.. GENERATED FROM PYTHON SOURCE LINES 74-82

.. code-block:: default

    @torch.compile(fullgraph=True)
    def f(img):
        return crop(img, (10, 10, 50, 50))

    # The following raises an error. Uncomment the line to see it.
    # cropped_img = f(img)

.. GENERATED FROM PYTHON SOURCE LINES 83-91

In order to black-box ``crop`` for use with ``torch.compile``, we need to do
two things:

1. Wrap the function into a PyTorch custom operator.
2. Add a "``FakeTensor`` kernel" (also known as a "meta kernel") to the
   operator. Given the metadata (for example, shapes) of the input Tensors,
   this function says how to compute the metadata of the output Tensor(s).

.. GENERATED FROM PYTHON SOURCE LINES 91-111

.. code-block:: default

    from typing import Sequence

    # Use torch.library.custom_op to define a new custom operator.
    # If your operator mutates any input Tensors, their names must be specified
    # in the ``mutates_args`` argument.
    @torch.library.custom_op("mylib::crop", mutates_args=())
    def crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:
        img = to_pil_image(pic.cpu())
        cropped_img = img.crop(box)
        return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)

    # Use register_fake to add a ``FakeTensor`` kernel for the operator
    @crop.register_fake
    def _(pic, box):
        channels = pic.shape[0]
        x0, y0, x1, y1 = box
        return pic.new_empty(channels, y1 - y0, x1 - x0)

.. GENERATED FROM PYTHON SOURCE LINES 112-113

After this, ``crop`` works without graph breaks:

.. GENERATED FROM PYTHON SOURCE LINES 113-121

.. code-block:: default

    @torch.compile(fullgraph=True)
    def f(img):
        return crop(img, (10, 10, 50, 50))

    cropped_img = f(img)
    display(img)

.. GENERATED FROM PYTHON SOURCE LINES 122-125

.. code-block:: default

    display(cropped_img)

.. GENERATED FROM PYTHON SOURCE LINES 126-136

Adding training support for crop
--------------------------------
Use ``torch.library.register_autograd`` to add training support for an
operator. Prefer this over directly using ``torch.autograd.Function``; some
combinations of ``autograd.Function`` with PyTorch operator registration APIs
can lead to (and have led to) silent incorrectness under ``torch.compile``.

The gradient formula for ``crop`` is essentially ``PIL.paste`` (we'll leave
the derivation as an exercise to the reader). Let's first wrap ``paste`` into
a custom operator:

.. GENERATED FROM PYTHON SOURCE LINES 136-152

.. code-block:: default

    @torch.library.custom_op("mylib::paste", mutates_args=())
    def paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:
        assert im1.device == im2.device
        assert im1.dtype == im2.dtype
        im1_pil = to_pil_image(im1.cpu())
        im2_pil = to_pil_image(im2.cpu())
        PIL.Image.Image.paste(im1_pil, im2_pil, coord)
        return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)

    @paste.register_fake
    def _(im1, im2, coord):
        assert im1.device == im2.device
        assert im1.dtype == im2.dtype
        return torch.empty_like(im1)

.. GENERATED FROM PYTHON SOURCE LINES 153-154

And now let's use ``register_autograd`` to specify the gradient formula for ``crop``:

.. GENERATED FROM PYTHON SOURCE LINES 154-167

.. code-block:: default

    def backward(ctx, grad_output):
        # Paste the incoming gradient into a zero tensor shaped like the input;
        # positions outside the crop box receive zero gradient.
        grad_input = grad_output.new_zeros(ctx.pic_shape)
        grad_input = paste(grad_input, grad_output, ctx.coords)
        # ``box`` is a list of ints, so it gets no gradient.
        return grad_input, None

    def setup_context(ctx, inputs, output):
        # Save what backward needs: the box's top-left corner and the input shape.
        pic, box = inputs
        ctx.coords = box[:2]
        ctx.pic_shape = pic.shape

    crop.register_autograd(backward, setup_context=setup_context)

.. GENERATED FROM PYTHON SOURCE LINES 168-171

Note that the backward function must be composed of operators that PyTorch
understands, which is why we wrapped ``paste`` into a custom operator instead
of directly using PIL's ``paste``.

.. GENERATED FROM PYTHON SOURCE LINES 171-177

.. code-block:: default

    img = img.requires_grad_()
    result = crop(img, (10, 10, 50, 50))
    result.sum().backward()
    display(img.grad)

.. GENERATED FROM PYTHON SOURCE LINES 178-180

This is the correct gradient, with 1s (white) in the cropped region and 0s
(black) in the unused region.

.. GENERATED FROM PYTHON SOURCE LINES 182-192

Testing Python Custom operators
-------------------------------
Use ``torch.library.opcheck`` to test that the custom operator was registered
correctly. This does not test that the gradients are mathematically correct;
please write separate tests for that (either manual ones or
``torch.autograd.gradcheck``).
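
As a sketch of such a separate gradient test, the example below uses a small
standalone operator rather than ``crop``. The reason is that ``to_pil_image``
converts float inputs to 8-bit integers, so the tiny perturbations that
``gradcheck`` applies would mostly vanish inside ``crop``, and the numerical
gradient would not be meaningful. The operator name
``mylib_sketch::numpy_cube`` and its NumPy-based body are hypothetical and were
chosen only for illustration. Its forward pass runs in NumPy, which autograd
cannot see, so the gradient is supplied with ``register_autograd``.
``opcheck`` then checks the registrations, and ``gradcheck`` compares the
hand-written backward against finite differences on double-precision inputs.

.. code-block:: default

    import numpy as np
    import torch


    # Hypothetical operator used only for this sketch. Its forward runs in NumPy,
    # so autograd cannot differentiate it without help.
    @torch.library.custom_op("mylib_sketch::numpy_cube", mutates_args=())
    def numpy_cube(x: torch.Tensor) -> torch.Tensor:
        x_np = x.detach().cpu().numpy()
        return torch.from_numpy(x_np ** 3).to(x.device)


    # FakeTensor kernel: the output has the same shape, dtype, and device as x.
    @numpy_cube.register_fake
    def _(x):
        return torch.empty_like(x)


    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)


    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d(x**3)/dx = 3 * x**2, expressed with ordinary PyTorch operators.
        return grad_output * 3 * x ** 2


    numpy_cube.register_autograd(backward, setup_context=setup_context)

    # Double precision keeps the finite-difference estimate accurate.
    x = torch.randn(5, dtype=torch.double, requires_grad=True)

    # opcheck verifies the registrations (schema, fake kernel, autograd wiring) ...
    torch.library.opcheck(numpy_cube, (x,))

    # ... while gradcheck compares backward() against numerical derivatives.
    grad_ok = torch.autograd.gradcheck(numpy_cube, (x,))
    print(grad_ok)

By default ``gradcheck`` raises an error when the analytical and numerical
gradients disagree. For instance, returning ``grad_output * x ** 2`` from
``backward`` (dropping the factor of 3) should make it fail, while ``opcheck``
alone would still pass, because it only checks that the operator is registered
consistently.
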

To use ``opcheck``, pass it a set of example inputs to test against. If your
operator supports training, then the examples should include Tensors that
require grad. If your operator supports multiple devices, then the examples
should include Tensors from each device.

.. GENERATED FROM PYTHON SOURCE LINES 192-203

.. code-block:: default

    examples = [
        [torch.randn(3, 64, 64), [0, 0, 10, 10]],
        [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],
        [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],
        [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],
    ]

    for example in examples:
        torch.library.opcheck(crop, example)

.. GENERATED FROM PYTHON SOURCE LINES 204-214

Mutable Python Custom operators
-------------------------------
You can also wrap a Python function that mutates its inputs into a custom
operator. Functions that mutate inputs are common because that is how many
low-level kernels are written; for example, a kernel that computes ``sin``
may take in the input and an output tensor and write ``input.sin()`` to the
output tensor.

We'll use ``numpy.sin`` to demonstrate an example of a mutable Python
custom operator.

.. GENERATED FROM PYTHON SOURCE LINES 214-225

.. code-block:: default

    import numpy as np

    @torch.library.custom_op("mylib::numpy_sin", mutates_args={"output"}, device_types="cpu")
    def numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:
        assert input.device == output.device
        assert input.device.type == "cpu"
        input_np = input.numpy()
        output_np = output.numpy()
        # Writes sin(input) into output's memory in place.
        np.sin(input_np, out=output_np)

.. GENERATED FROM PYTHON SOURCE LINES 226-228

Because the operator doesn't return anything, there is no need to register a
``FakeTensor`` kernel (meta kernel) to get it to work with ``torch.compile``.

.. GENERATED FROM PYTHON SOURCE LINES 228-239

.. code-block:: default

    @torch.compile(fullgraph=True)
    def f(x):
        out = torch.empty(3)
        numpy_sin(x, out)
        return out

    x = torch.randn(3)
    y = f(x)
    assert torch.allclose(y, x.sin())

.. GENERATED FROM PYTHON SOURCE LINES 240-242

And here's an ``opcheck`` run telling us that we did indeed register the
operator correctly. ``opcheck`` would error out if we forgot to add the output
to ``mutates_args``, for example.

.. GENERATED FROM PYTHON SOURCE LINES 242-252

.. code-block:: default

    example_inputs = [
        [torch.randn(3), torch.empty(3)],
        [torch.randn(0, 3), torch.empty(0, 3)],
        [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],
    ]

    for example in example_inputs:
        torch.library.opcheck(numpy_sin, example)

.. GENERATED FROM PYTHON SOURCE LINES 253-265

Conclusion
----------
In this tutorial, we learned how to use ``torch.library.custom_op`` to
create a custom operator in Python that works with PyTorch subsystems
such as ``torch.compile`` and autograd.

This tutorial provides a basic introduction to custom operators.
For more detailed information, see:

- `the torch.library documentation `_
- `the Custom Operators Manual `_