Reasoning about Shapes in PyTorch
When writing models with PyTorch, the parameters of a given layer often depend on the shape of the output of the previous layer. For example, the in_features of an nn.Linear layer must match the size(-1) of its input. For some layers, such as convolutions, computing the output shape involves more complex equations.
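For reference, the torch.nn.Conv2d documentation gives the output height as floor((H_in + 2 * padding - dilation * (kernel_size - 1) - 1) / stride + 1), and the same for the width; torch.nn.MaxPool2d (with ceil_mode=False) documents the same formula. A minimal sketch of that formula for one spatial dimension follows. The helper name conv_out_size is hypothetical, and it assumes integer padding rather than the string modes "same" and "valid".

```python
def conv_out_size(size_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of one spatial dimension of a convolution or max-pool.

    Hypothetical helper following the formula in the torch.nn.Conv2d docs.
    Assumes integer padding. For a positive stride, floor(a / s + 1) equals
    a // s + 1, so integer floor division is used.
    """
    return (size_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# A 10x10 input with a 2x2 kernel (stride 1, no padding) gives 9x9.
print(conv_out_size(10, kernel_size=2))  # 9
# Conv2d(3, 6, 5) on a 32x32 image gives 28x28.
print(conv_out_size(32, kernel_size=5))  # 28
# MaxPool2d(2, 2) then halves 28x28 to 14x14.
print(conv_out_size(28, kernel_size=2, stride=2))  # 14
```

Chaining such formulas by hand for every layer quickly becomes error-prone, which is the problem the rest of this tutorial addresses.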
One way around this is to run the forward pass with random inputs, but this wastes memory and compute. Instead, we can use the meta device to determine the output shapes of a layer without materializing any data.
import torch
import timeit
t = torch.rand(2, 3, 10, 10, device="meta")
conv = torch.nn.Conv2d(3, 5, 2, device="meta")
start = timeit.default_timer()
out = conv(t)
end = timeit.default_timer()
print(out)
print(f"Time taken: {end-start}")
tensor(..., device='meta', size=(2, 5, 9, 9), grad_fn=<ConvolutionBackward0>)
Time taken: 0.00013915199997427408
Observe that since data is not materialized, passing arbitrarily large inputs will not significantly alter the time taken for shape computation.
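The code for this large-input run is not shown in this section. The following sketch should reproduce the output below. It assumes an input of shape (1024, 3, 65536, 65536), inferred from the printed output size (1024, 5, 65535, 65535) and the 2x2 kernel of conv. As float32 on a real device, that input alone would need 3 * 2^44 bytes (48 TiB).

```python
import timeit

import torch

# Same layer as above: 3 input channels, 5 output channels, 2x2 kernel.
conv = torch.nn.Conv2d(3, 5, 2, device="meta")

# Assumed input size 1024 x 3 x 65536 x 65536. On the meta device only the
# shape and dtype are recorded, so no memory is allocated for the values.
t_large = torch.rand(2**10, 3, 2**16, 2**16, device="meta")

start = timeit.default_timer()
out = conv(t_large)
end = timeit.default_timer()
print(out)
print(f"Time taken: {end-start}")
```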
tensor(..., device='meta', size=(1024, 5, 65535, 65535),
grad_fn=<ConvolutionBackward0>)
Time taken: 7.251199986058054e-05
Consider an arbitrary network such as the following:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
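In Net, the in_features of fc1 is hard-coded as 16 * 5 * 5, which is only correct for 32x32 inputs. The meta device can derive it instead. The sketch below does so under stated assumptions: the names infer_flat_features and NetAutoFC and the input_size parameter are hypothetical, and the feature layers are regrouped into an nn.Sequential so they can be probed separately. Because nn.Module.to() moves a module in place, the helper probes a deep copy so the real parameters stay on the CPU.

```python
import copy

import torch
import torch.nn as nn


def infer_flat_features(features, input_size):
    """Hypothetical helper: flattened output size of `features` for one input.

    Runs a meta-device copy of `features`, so no real data is allocated and the
    original module's parameters are left untouched.
    """
    probe = copy.deepcopy(features).to("meta")  # .to() is in-place, hence the copy
    out = probe(torch.empty(1, *input_size, device="meta"))
    return out.flatten(1).shape[1]


class NetAutoFC(nn.Module):
    """Hypothetical variant of Net whose fc1 size is inferred, not hard-coded."""

    def __init__(self, input_size=(3, 32, 32)):
        super().__init__()
        # The same layers Net.forward applies before flattening.
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        n_flat = infer_flat_features(self.features, input_size)
        self.fc1 = nn.Linear(n_flat, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = torch.flatten(self.features(x), 1)  # flatten all dimensions except batch
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


print(NetAutoFC().fc1.in_features)  # 400 (= 16 * 5 * 5)
print(NetAutoFC(input_size=(3, 64, 64)).fc1.in_features)  # 2704 (= 16 * 13 * 13)
```

The probe costs one extra construction-time pass over the feature layers, but no real memory. The design assumes the feature layers' output shape depends only on the input shape, not on the input values.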
We can view the intermediate shapes within an entire network by registering, on every module returned by named_modules() (including the root Net itself), a forward hook that prints the shape of that module's output.
def fw_hook(module, input, output):
    print(f"Shape of output to {module} is {output.shape}.")


# Any tensor created within this torch.device context manager, including
# the parameters of Net, will be on the meta device.
with torch.device("meta"):
    net = Net()
    inp = torch.randn((1024, 3, 32, 32))

for name, layer in net.named_modules():
    layer.register_forward_hook(fw_hook)

out = net(inp)
Shape of output to Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 6, 28, 28]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 6, 14, 14]).
Shape of output to Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) is torch.Size([1024, 16, 10, 10]).
Shape of output to MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) is torch.Size([1024, 16, 5, 5]).
Shape of output to Linear(in_features=400, out_features=120, bias=True) is torch.Size([1024, 120]).
Shape of output to Linear(in_features=120, out_features=84, bias=True) is torch.Size([1024, 84]).
Shape of output to Linear(in_features=84, out_features=10, bias=True) is torch.Size([1024, 10]).
Shape of output to Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
) is torch.Size([1024, 10]).
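Printing from a hook works for a quick look. The output above also shows two side effects: MaxPool2d appears twice because self.pool is called twice, and the root Net appears last. The sketch below records (name, shape) pairs instead of printing, skips the root module (whose name from named_modules() is the empty string), and removes the hooks afterwards through the handles that register_forward_hook returns. The helper name record_shapes is hypothetical, and Net is repeated so the block runs on its own. F.relu is a function rather than a module, so it gets no hook.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    # Same network as above, repeated so this block is self-contained.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def record_shapes(model, inp):
    """Hypothetical helper: run `model` and return [(module_name, shape), ...] in call order."""
    shapes = []
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, args, output, name=name):
            shapes.append((name, tuple(output.shape)))

        handles.append(module.register_forward_hook(hook))
    try:
        model(inp)
    finally:
        # Detach the hooks so later forward passes are unaffected.
        for handle in handles:
            handle.remove()
    return shapes


with torch.device("meta"):
    net = Net()
    inp = torch.empty(1024, 3, 32, 32)

result = record_shapes(net, inp)
for name, shape in result:
    print(name, shape)
```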